diff --git a/archive_escape_test.go b/archive_escape_test.go index f005aed..f994c0c 100644 --- a/archive_escape_test.go +++ b/archive_escape_test.go @@ -5,35 +5,30 @@ import ( "os" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSafeComponent(t *testing.T) { valid := []string{"file", "dir1", "a.b", "..foo", "foo..", "name with space"} for _, n := range valid { - if err := safeComponent(n); err != nil { - t.Errorf("safeComponent(%q) = %v, want nil", n, err) - } + assert.NoError(t, safeComponent(n), "safeComponent(%q)", n) } invalid := []string{"", ".", "..", "evil/passwd", "a/b", "/abs", `a\b`, `\abs`} for _, n := range invalid { - if err := safeComponent(n); err == nil { - t.Errorf("safeComponent(%q) = nil, want error", n) - } + assert.Error(t, safeComponent(n), "safeComponent(%q)", n) } } func TestConfined(t *testing.T) { in := []string{".", "a", "a/b", "a/../b", "./a"} for _, p := range in { - if !confined(p) { - t.Errorf("confined(%q) = false, want true", p) - } + assert.True(t, confined(p), "confined(%q)", p) } out := []string{"..", "../x", "a/../..", "/abs", "/"} for _, p := range out { - if confined(p) { - t.Errorf("confined(%q) = true, want false", p) - } + assert.False(t, confined(p), "confined(%q)", p) } } @@ -52,35 +47,25 @@ func TestArchiveDecoderRejectsEmbeddedSlash(t *testing.T) { Mode: os.ModeDir | 0755, MTime: time.Unix(0, 0), } - if _, err := enc.Encode(entry); err != nil { - t.Fatal(err) - } + _, err := enc.Encode(entry) + require.NoError(t, err) name := "evil/passwd" fn := FormatFilename{ FormatHeader: FormatHeader{Size: uint64(16 + len(name) + 1), Type: CaFormatFilename}, Name: name, } - if _, err := enc.Encode(fn); err != nil { - t.Fatal(err) - } + _, err = enc.Encode(fn) + require.NoError(t, err) d := NewArchiveDecoder(&buf) // First node is the (unnamed) root directory. v, err := d.Next() - if err != nil { - t.Fatalf("decoding root: %v", err) - } - if _, ok := v.(NodeDirectory); !ok { - t.Fatalf("expected NodeDirectory, got %T", v) - } + require.NoError(t, err, "decoding root") + require.IsType(t, NodeDirectory{}, v) // The embedded-slash filename must be rejected. _, err = d.Next() - if err == nil { - t.Fatal("expected error for embedded-slash filename, got nil") - } - if _, ok := err.(InvalidFormat); !ok { - t.Fatalf("expected InvalidFormat, got %T: %v", err, err) - } + require.Error(t, err, "expected error for embedded-slash filename") + require.IsType(t, InvalidFormat{}, err) } diff --git a/archive_test.go b/archive_test.go index 93bb725..80a505f 100644 --- a/archive_test.go +++ b/archive_test.go @@ -3,15 +3,14 @@ package desync import ( "os" "path" - "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestArchiveDecoderTypes(t *testing.T) { f, err := os.Open("testdata/flat.catar") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() d := NewArchiveDecoder(f) @@ -28,20 +27,14 @@ func TestArchiveDecoderTypes(t *testing.T) { for _, exp := range expected { v, err := d.Next() - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(exp) != reflect.TypeOf(v) { - t.Fatalf("expected %s, got %s", reflect.TypeOf(exp), reflect.TypeOf(v)) - } + require.NoError(t, err) + require.IsType(t, exp, v) } } func TestArchiveDecoderNesting(t *testing.T) { f, err := os.Open("testdata/nested.catar") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() d := NewArchiveDecoder(f) @@ -67,30 +60,18 @@ func TestArchiveDecoderNesting(t *testing.T) { for _, e := range expected { v, err := d.Next() - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(e.Type) != reflect.TypeOf(v) { - t.Fatalf("expected %s, got %s", reflect.TypeOf(e.Type), reflect.TypeOf(v)) - } + require.NoError(t, err) + require.IsType(t, e.Type, v) if e.Type == nil { break } switch val := v.(type) { case NodeDirectory: - if val.Name != e.Name { - t.Fatalf("expected name '%s', got '%s'", e.Name, val.Name) - } - if val.UID != e.UID { - t.Fatalf("expected uid '%d', got '%d'", e.UID, val.UID) - } + require.Equal(t, e.Name, val.Name) + require.Equal(t, e.UID, val.UID) case NodeFile: - if val.Name != e.Name { - t.Fatalf("expected name '%s', got '%s'", e.Name, val.Name) - } - if val.UID != e.UID { - t.Fatalf("expected uid '%d', got '%d'", e.UID, val.UID) - } + require.Equal(t, e.Name, val.Name) + require.Equal(t, e.UID, val.UID) } } } diff --git a/chunker_test.go b/chunker_test.go index 2fb4b08..2fdf9bb 100644 --- a/chunker_test.go +++ b/chunker_test.go @@ -5,6 +5,8 @@ import ( "crypto/sha512" "os" "testing" + + "github.com/stretchr/testify/require" ) const ( @@ -15,9 +17,7 @@ const ( func TestChunkerLargeFile(t *testing.T) { f, err := os.Open("testdata/chunker.input") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() expected := []struct { @@ -48,73 +48,42 @@ func TestChunkerLargeFile(t *testing.T) { } c, err := NewChunker(f, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for i, e := range expected { start, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) chunkID := ChunkID(sha512.Sum512_256(buf)) - hash := (&chunkID).String() - if hash != e.ID { - t.Fatalf("chunk #%d, unexpected hash %s, expected %s", i+1, hash, e.ID) - } - if start != e.Start { - t.Fatalf("chunk #%d, unexpected start %d, expected %d", i+1, start, e.Start) - } - if uint64(len(buf)) != e.Size { - t.Fatalf("chunk #%d, unexpected size %d, expected %d", i+1, uint64(len(buf)), e.Size) - } + require.Equal(t, e.ID, chunkID.String(), "chunk #%d hash", i+1) + require.Equal(t, e.Start, start, "chunk #%d start", i+1) + require.Equal(t, e.Size, uint64(len(buf)), "chunk #%d size", i+1) } // Should get a size of 0 at the end _, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } - if len(buf) != 0 { - t.Fatalf("expected size 0 at the end, got %d", len(buf)) - } + require.NoError(t, err) + require.Empty(t, buf, "expected size 0 at the end") } func TestChunkerEmptyFile(t *testing.T) { r := bytes.NewReader([]byte{}) c, err := NewChunker(r, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) start, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } - if len(buf) != 0 { - t.Fatalf("unexpected size %d, expected 0", len(buf)) - } - if start != 0 { - t.Fatalf("unexpected start position %d, expected 0", start) - } + require.NoError(t, err) + require.Empty(t, buf) + require.Equal(t, uint64(0), start) } func TestChunkerSmallFile(t *testing.T) { b := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} r := bytes.NewReader(b) c, err := NewChunker(r, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) start, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } - if len(buf) != len(b) { - t.Fatalf("unexpected size %d, expected %d", len(buf), len(b)) - } - if start != 0 { - t.Fatalf("unexpected start position %d, expected 0", start) - } + require.NoError(t, err) + require.Len(t, buf, len(b)) + require.Equal(t, uint64(0), start) } // There are no chunk boundaries when all data is nil, make sure we get the @@ -123,23 +92,15 @@ func TestChunkerNoBoundary(t *testing.T) { b := make([]byte, 1024*1024) r := bytes.NewReader(b) c, err := NewChunker(r, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for { start, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(buf) == 0 { break } - if uint64(len(buf)) != ChunkSizeMaxDefault { - t.Fatalf("unexpected size %d, expected %d", len(buf), ChunkSizeMaxDefault) - } - if start%ChunkSizeMaxDefault != 0 { - t.Fatalf("unexpected start position %d, expected 0", start) - } + require.Equal(t, ChunkSizeMaxDefault, uint64(len(buf))) + require.Zero(t, start%ChunkSizeMaxDefault, "unexpected start position %d", start) } } @@ -157,20 +118,12 @@ func TestChunkerBounds(t *testing.T) { b := make([]byte, c.size) r := bytes.NewReader(b) c, err := NewChunker(r, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) start, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } - if len(buf) != len(b) { - t.Fatalf("unexpected size %d, expected %d", len(buf), len(b)) - } - if start != 0 { - t.Fatalf("unexpected start position %d, expected 0", start) - } + require.NoError(t, err) + require.Len(t, buf, len(b)) + require.Equal(t, uint64(0), start) }) } } @@ -195,46 +148,28 @@ func TestChunkerAdvance(t *testing.T) { input := join(nullChunk.Data, dataA, nullChunk.Data, dataB) c, err := NewChunker(bytes.NewReader(input), ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Chunk the first part, this should be a null chunk _, buf, err := c.Next() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf, nullChunk.Data) { - t.Fatal("expected null chunk") - } + require.NoError(t, err) + require.Equal(t, nullChunk.Data, buf, "expected null chunk") // Now skip the dataA slice - if err := c.Advance(len(dataA)); err != nil { - t.Fatal(err) - } + require.NoError(t, c.Advance(len(dataA))) // Read the 2nd null chunk _, buf, err = c.Next() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf, nullChunk.Data) { - t.Fatal("expected null chunk") - } + require.NoError(t, err) + require.Equal(t, nullChunk.Data, buf, "expected null chunk") // Skip over dataB - if err := c.Advance(len(dataB)); err != nil { - t.Fatal(err) - } + require.NoError(t, c.Advance(len(dataB))) // Should be at the end, nothing more to chunk _, buf, err = c.Next() - if err != nil { - t.Fatal(err) - } - if len(buf) != 0 { - t.Fatal("expected end of input") - } + require.NoError(t, err) + require.Empty(t, buf, "expected end of input") } // Global vars used for results during the benchmark to prevent optimizer diff --git a/dedupqueue_test.go b/dedupqueue_test.go index 233b65e..5458442 100644 --- a/dedupqueue_test.go +++ b/dedupqueue_test.go @@ -1,21 +1,15 @@ package desync import ( - "reflect" "sync" "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestDedupQueueSimple(t *testing.T) { - // var requests int64 - // store := &TestStore{ - // GetChunkFunc: func(ChunkID) (*Chunk, error) { - // atomic.AddInt64(&requests, 1) - // return NewChunkFromUncompressed([]byte{0}), nil - // }, - // } exists := ChunkID{0} notExists := ChunkID{1} store := &TestStore{ @@ -27,31 +21,19 @@ func TestDedupQueueSimple(t *testing.T) { // First compare we're getting the expected data in the positive case bExpected, err := store.GetChunk(exists) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bActual, err := q.GetChunk(exists) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(bActual, bExpected) { - t.Fatalf("got %v; want %v", bExpected, bActual) - } + require.NoError(t, err) + require.Equal(t, bExpected, bActual) // Now make sure errors too are passed correctly _, err = q.GetChunk(notExists) - if _, ok := err.(ChunkMissing); !ok { - t.Fatalf("got '%v'; want chunk missing error", err) - } + require.IsType(t, ChunkMissing{}, err) // Check HasChunk() as well hasChunk, err := q.HasChunk(exists) - if err != nil { - t.Fatal(err) - } - if !hasChunk { - t.Fatalf("HasChunk() = false; want true") - } + require.NoError(t, err) + require.True(t, hasChunk) } func TestDedupQueueParallel(t *testing.T) { @@ -85,7 +67,5 @@ func TestDedupQueueParallel(t *testing.T) { wg.Wait() // There should ideally be just one requests that was done on the upstream store - if requests > 1 { - t.Fatalf("%d requests to the store; want 1", requests) - } + require.LessOrEqual(t, requests, int64(1), "requests to the store") } diff --git a/failover_test.go b/failover_test.go index b313b13..4be6d3d 100644 --- a/failover_test.go +++ b/failover_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -15,9 +16,7 @@ func TestFailoverMissingChunk(t *testing.T) { s := &TestStore{} g := NewFailoverGroup(s) _, err := g.GetChunk(ChunkID{0}) - if _, ok := err.(ChunkMissing); !ok { - t.Fatalf("expected missing chunk error, got %T", err) - } + require.IsType(t, ChunkMissing{}, err) } func TestFailoverAllError(t *testing.T) { @@ -26,9 +25,8 @@ func TestFailoverAllError(t *testing.T) { GetChunkFunc: func(ChunkID) (*Chunk, error) { return nil, failed }, } g := NewFailoverGroup(storeFail, storeFail) - if _, err := g.GetChunk(ChunkID{0}); err != failed { - t.Fatalf("expected error, got %T", err) - } + _, err := g.GetChunk(ChunkID{0}) + require.ErrorIs(t, err, failed) } func TestFailoverSimple(t *testing.T) { @@ -44,14 +42,11 @@ func TestFailoverSimple(t *testing.T) { g := NewFailoverGroup(storeFail, storeFail, storeSucc) // Request a chunk, should succeed - if _, err := g.GetChunk(ChunkID{0}); err != nil { - t.Fatal(err) - } + _, err := g.GetChunk(ChunkID{0}) + require.NoError(t, err) // Look inside the group to confirm we failed over to the last one - if g.active != 2 { - t.Fatalf("expected g.active=1, but got %d", g.active) - } + require.Equal(t, 2, g.active) } func TestFailoverMutliple(t *testing.T) { @@ -116,8 +111,5 @@ func TestFailoverMutliple(t *testing.T) { } }) - err := eg.Wait() - if err != nil { - t.Fatal(err) - } + require.NoError(t, eg.Wait()) } diff --git a/format_test.go b/format_test.go index 5c7c658..b301489 100644 --- a/format_test.go +++ b/format_test.go @@ -3,15 +3,14 @@ package desync import ( "bytes" "os" - "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestFormatDecoder(t *testing.T) { f, err := os.Open("testdata/flat.catar") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() d := NewFormatDecoder(f) @@ -49,46 +48,30 @@ func TestFormatDecoder(t *testing.T) { for _, exp := range expected { v, err := d.Next() - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(exp) != reflect.TypeOf(v) { - t.Fatalf("expected %s, got %s", reflect.TypeOf(exp), reflect.TypeOf(v)) - } + require.NoError(t, err) + require.IsType(t, exp, v) } } func TestIndexDecoder(t *testing.T) { f, err := os.Open("testdata/index.caibx") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() d := NewFormatDecoder(f) // The file should start with the index e, err := d.Next() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) index, ok := e.(FormatIndex) - if !ok { - t.Fatal("file doesn't start with an index") - } - if index.FeatureFlags != CaFormatSHA512256|CaFormatExcludeNoDump { - t.Fatal("index flags don't match expected") - } + require.True(t, ok, "file doesn't start with an index") + require.Equal(t, uint64(CaFormatSHA512256|CaFormatExcludeNoDump), index.FeatureFlags) // Now get the table with the chunks e, err = d.Next() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) table, ok := e.(FormatTable) - if !ok { - t.Fatal("index table not found") - } + require.True(t, ok, "index table not found") // Define the chunk IDs and the order they should be in the file expected := []string{ @@ -97,15 +80,11 @@ func TestIndexDecoder(t *testing.T) { "fadff4b303624f2be3d0e04c2f105306118a9f608ef1e4f83c1babbd23a2315f", } // Check the expected length of the table - if len(table.Items) != len(expected) { - t.Fatalf("expected %d chunks in index table, got %d", len(expected), len(table.Items)) - } + require.Len(t, table.Items, len(expected)) // And then make sure the IDs and order match for i := range expected { id, _ := ChunkIDFromString(expected[i]) - if table.Items[i].Chunk != id { - t.Fatalf("expected chunk %s, got %s", id, table.Items[i].Chunk) - } + require.Equal(t, id, table.Items[i].Chunk) } } @@ -118,9 +97,7 @@ func TestEncoder(t *testing.T) { } for _, name := range files { in, err := os.ReadFile(name) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Decoder d := NewFormatDecoder(bytes.NewReader(in)) @@ -133,26 +110,18 @@ func TestEncoder(t *testing.T) { var total int64 for { v, err := d.Next() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if v == nil { break } n, err := e.Encode(v) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) total += n } // in/out should match - if !bytes.Equal(in, out.Bytes()) { - t.Fatalf("decoded/encoded don't match for file '%s'", name) - } - if total != int64(out.Len()) { - t.Fatalf("unexpected length for encoding of '%s'", name) - } + require.Equal(t, in, out.Bytes(), "decoded/encoded don't match for file '%s'", name) + require.Equal(t, int64(out.Len()), total, "unexpected length for encoding of '%s'", name) } } @@ -199,7 +168,5 @@ func TestGoodbyeBST(t *testing.T) { out := makeGoodbyeBST(in) - if !reflect.DeepEqual(out, expected) { - t.Fatal("BST doesn't match expected") - } + require.Equal(t, expected, out) } diff --git a/gcs_test.go b/gcs_test.go index 09f01aa..21fa2c9 100644 --- a/gcs_test.go +++ b/gcs_test.go @@ -2,6 +2,8 @@ package desync import ( "testing" + + "github.com/stretchr/testify/require" ) func TestNormalizeGCPrefix(t *testing.T) { @@ -26,9 +28,7 @@ func TestNormalizeGCPrefix(t *testing.T) { prefix := normalizeGCPrefix(test.path) - if prefix != test.expectedPrefix { - t.Fatalf("path '%s' should normalize into '%s' but was normalized into '%s'", test.path, test.expectedPrefix, prefix) - } + require.Equal(t, test.expectedPrefix, prefix) }) } } diff --git a/index_test.go b/index_test.go index eb911fe..8972c61 100644 --- a/index_test.go +++ b/index_test.go @@ -5,21 +5,19 @@ import ( "context" "encoding/binary" "os" - "reflect" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIndexLoad(t *testing.T) { f, err := os.Open("testdata/index.caibx") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() index, err := IndexFromReader(f) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type chunk struct { chunk string @@ -34,78 +32,52 @@ func TestIndexLoad(t *testing.T) { for i := range expected { id, _ := ChunkIDFromString(expected[i].chunk) exp := IndexChunk{ID: id, Start: expected[i].start, Size: expected[i].size} - got := index.Chunks[i] - if !reflect.DeepEqual(exp, got) { - t.Fatalf("expected %v, got %v", exp, got) - } + require.Equal(t, exp, index.Chunks[i]) } } func TestIndexWrite(t *testing.T) { in, err := os.ReadFile("testdata/index.caibx") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) idx, err := IndexFromReader(bytes.NewReader(in)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) out := new(bytes.Buffer) n, err := idx.WriteTo(out) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // in/out should match - if !bytes.Equal(in, out.Bytes()) { - t.Fatalf("decoded/encoded don't match") - } - if n != int64(out.Len()) { - t.Fatalf("unexpected length") - } + require.Equal(t, in, out.Bytes(), "decoded/encoded don't match") + require.Equal(t, int64(out.Len()), n, "unexpected length") } func TestIndexChunking(t *testing.T) { // Open the blob f, err := os.Open("testdata/chunker.input") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() // Create a chunker c, err := NewChunker(f, ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Make a temp local store dir := t.TempDir() s, err := NewLocalStore(dir, StoreOptions{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Split up the blob into chunks and return the index idx, err := ChunkStream(context.Background(), c, s, 10) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Write the index and compare it to the expected one b := new(bytes.Buffer) - if _, err = idx.WriteTo(b); err != nil { - t.Fatal(err) - } + _, err = idx.WriteTo(b) + require.NoError(t, err) i, err := os.ReadFile("testdata/chunker.index") - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(b.Bytes(), i) { - t.Fatal("index doesn't match expected") - } + require.NoError(t, err) + require.Equal(t, i, b.Bytes(), "index doesn't match expected") // Make sure the local store contains all the expected chunks expectedChunks := []string{ @@ -132,16 +104,10 @@ func TestIndexChunking(t *testing.T) { } for _, sid := range expectedChunks { id, err := ChunkIDFromString(sid) - if err != nil { - t.Fatal(id) - } + require.NoError(t, err) hasChunk, err := s.HasChunk(id) - if err != nil { - t.Fatal(err) - } - if !hasChunk { - t.Fatalf("store is missing chunk %s", id) - } + require.NoError(t, err) + require.True(t, hasChunk, "store is missing chunk %s", id) } } @@ -160,34 +126,21 @@ func TestChunkStreamIntegrity(t *testing.T) { } c, err := NewChunker(bytes.NewReader(data), ChunkSizeMinDefault, ChunkSizeAvgDefault, ChunkSizeMaxDefault) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) dir := t.TempDir() s, err := NewLocalStore(dir, StoreOptions{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) index, err := ChunkStream(context.Background(), c, s, 10) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Verify every stored chunk: read it back, decompress, and check // that its content hashes to the expected ChunkID. GetChunk with // SkipVerify=false (the default) returns ChunkInvalid on mismatch. - var corrupted int for i, chunk := range index.Chunks { _, err := s.GetChunk(chunk.ID) - if err != nil { - corrupted++ - t.Errorf("chunk %d (%s): %v", i, chunk.ID, err) - } - } - if corrupted > 0 { - t.Fatalf("%d of %d chunks are corrupted", corrupted, len(index.Chunks)) + assert.NoError(t, err, "chunk %d (%s)", i, chunk.ID) } } diff --git a/localfs_escape_test.go b/localfs_escape_test.go index e1011ec..7ddcb98 100644 --- a/localfs_escape_test.go +++ b/localfs_escape_test.go @@ -7,6 +7,9 @@ import ( "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestLocalFS(root string) *LocalFS { @@ -25,28 +28,19 @@ func TestLocalFSSymlinkWriteEscape(t *testing.T) { outside := t.TempDir() sentinel := filepath.Join(outside, "passwd") const orig = "original-contents\n" - if err := os.WriteFile(sentinel, []byte(orig), 0644); err != nil { - t.Fatal(err) - } + require.NoError(t, os.WriteFile(sentinel, []byte(orig), 0644)) fs := newTestLocalFS(root) defer fs.Close() - if err := fs.CreateSymlink(NodeSymlink{Name: "evil", Target: outside}); err != nil { - t.Fatalf("CreateSymlink: %v", err) - } + require.NoError(t, fs.CreateSymlink(NodeSymlink{Name: "evil", Target: outside})) - if err := fs.CreateFile(NodeFile{Name: "evil/passwd", Data: strings.NewReader("PWNED")}); err == nil { - t.Fatal("CreateFile through escaping symlink succeeded, want error") - } + require.Error(t, fs.CreateFile(NodeFile{Name: "evil/passwd", Data: strings.NewReader("PWNED")}), + "CreateFile through escaping symlink succeeded, want error") got, err := os.ReadFile(sentinel) - if err != nil { - t.Fatal(err) - } - if string(got) != orig { - t.Fatalf("file outside root was modified: got %q, want %q", got, orig) - } + require.NoError(t, err) + require.Equal(t, orig, string(got), "file outside root was modified") } // TestLocalFSLexicalEscape covers the plain ".." traversal at the LocalFS @@ -56,15 +50,12 @@ func TestLocalFSLexicalEscape(t *testing.T) { fs := newTestLocalFS(root) defer fs.Close() - if err := fs.CreateFile(NodeFile{Name: "../escape", Data: strings.NewReader("x")}); err == nil { - t.Error("CreateFile(../escape) succeeded, want error") - } - if err := fs.CreateDir(NodeDirectory{Name: "../evildir"}); err == nil { - t.Error("CreateDir(../evildir) succeeded, want error") - } - if _, err := os.Lstat(filepath.Join(filepath.Dir(root), "escape")); !os.IsNotExist(err) { - t.Errorf("escape file created outside root: err=%v", err) - } + assert.Error(t, fs.CreateFile(NodeFile{Name: "../escape", Data: strings.NewReader("x")}), + "CreateFile(../escape) succeeded, want error") + assert.Error(t, fs.CreateDir(NodeDirectory{Name: "../evildir"}), + "CreateDir(../evildir) succeeded, want error") + _, err := os.Lstat(filepath.Join(filepath.Dir(root), "escape")) + assert.True(t, os.IsNotExist(err), "escape file created outside root: err=%v", err) } // TestLocalFSAbsoluteSymlinkTargetVerbatim confirms an absolute symlink target @@ -75,23 +66,15 @@ func TestLocalFSAbsoluteSymlinkTargetVerbatim(t *testing.T) { fs := newTestLocalFS(root) defer fs.Close() - if err := fs.CreateSymlink(NodeSymlink{Name: "abs", Target: "/etc"}); err != nil { - t.Fatalf("CreateSymlink: %v", err) - } + require.NoError(t, fs.CreateSymlink(NodeSymlink{Name: "abs", Target: "/etc"})) tgt, err := os.Readlink(filepath.Join(root, "abs")) - if err != nil { - t.Fatal(err) - } - if tgt != "/etc" { - t.Fatalf("symlink target = %q, want /etc (verbatim)", tgt) - } - - if err := fs.CreateFile(NodeFile{Name: "abs/desync-should-not-exist", Data: strings.NewReader("x")}); err == nil { - t.Fatal("write through absolute symlink succeeded, want error") - } - if _, err := os.Lstat("/etc/desync-should-not-exist"); !os.IsNotExist(err) { - t.Fatalf("file created under /etc: err=%v", err) - } + require.NoError(t, err) + require.Equal(t, "/etc", tgt, "symlink target should be created verbatim") + + require.Error(t, fs.CreateFile(NodeFile{Name: "abs/desync-should-not-exist", Data: strings.NewReader("x")}), + "write through absolute symlink succeeded, want error") + _, err = os.Lstat("/etc/desync-should-not-exist") + require.True(t, os.IsNotExist(err), "file created under /etc: err=%v", err) } // TestLocalFSBenignRelativeSymlink is a regression guard: a relative symlink @@ -102,20 +85,11 @@ func TestLocalFSBenignRelativeSymlink(t *testing.T) { fs := newTestLocalFS(root) defer fs.Close() - if err := fs.CreateDir(NodeDirectory{Name: "sub"}); err != nil { - t.Fatalf("CreateDir: %v", err) - } - if err := fs.CreateSymlink(NodeSymlink{Name: "link", Target: "sub"}); err != nil { - t.Fatalf("CreateSymlink: %v", err) - } - if err := fs.CreateFile(NodeFile{Name: "link/g", Data: strings.NewReader("hello")}); err != nil { - t.Fatalf("CreateFile through in-root symlink: %v", err) - } + require.NoError(t, fs.CreateDir(NodeDirectory{Name: "sub"})) + require.NoError(t, fs.CreateSymlink(NodeSymlink{Name: "link", Target: "sub"})) + require.NoError(t, fs.CreateFile(NodeFile{Name: "link/g", Data: strings.NewReader("hello")}), + "CreateFile through in-root symlink") got, err := os.ReadFile(filepath.Join(root, "sub", "g")) - if err != nil { - t.Fatalf("expected file via in-root symlink: %v", err) - } - if string(got) != "hello" { - t.Fatalf("content = %q, want hello", got) - } + require.NoError(t, err, "expected file via in-root symlink") + require.Equal(t, "hello", string(got)) } diff --git a/mount-index_linux_test.go b/mount-index_linux_test.go index 80f2ef9..98de0b5 100644 --- a/mount-index_linux_test.go +++ b/mount-index_linux_test.go @@ -1,7 +1,6 @@ package desync import ( - "bytes" "context" "crypto/sha256" "os" @@ -9,6 +8,8 @@ import ( "sync" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestMountIndex(t *testing.T) { @@ -17,27 +18,19 @@ func TestMountIndex(t *testing.T) { // Define the store s, err := NewLocalStore("testdata/blob1.store", StoreOptions{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer s.Close() // Read the index f, err := os.Open("testdata/blob1.caibx") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer f.Close() index, err := IndexFromReader(f) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Calculate the expected hash b, err := os.ReadFile("testdata/blob1") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) wantHash := sha256.Sum256(b) // Make sure that the unmount happens on exit @@ -59,19 +52,15 @@ func TestMountIndex(t *testing.T) { select { case err = <-c: - t.Fatal(err) + require.FailNow(t, "mount exited early", "%v", err) case <-time.After(time.Second): } // Calculate the hash of the file in the mount point b, err = os.ReadFile(filepath.Join(mnt, "blob1")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) gotHash := sha256.Sum256(b) // Compare the checksums - if !bytes.Equal(gotHash[:], wantHash[:]) { - t.Fatalf("unexpected hash of mounted file. Want %x, got %x", gotHash, wantHash) - } + require.Equal(t, wantHash, gotHash) } diff --git a/protocol_test.go b/protocol_test.go index d7d5c02..bfe2eeb 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -7,6 +7,7 @@ import ( "io" "testing" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -89,8 +90,5 @@ func TestProtocol(t *testing.T) { w1.Close() w2.Close() - err := g.Wait() - if err != nil { - t.Fatal(err) - } + require.NoError(t, g.Wait()) } diff --git a/protocolserver_test.go b/protocolserver_test.go index c663db8..be47cc7 100644 --- a/protocolserver_test.go +++ b/protocolserver_test.go @@ -1,10 +1,11 @@ package desync import ( - "bytes" "context" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestProtocolServer(t *testing.T) { @@ -26,26 +27,16 @@ func TestProtocolServer(t *testing.T) { // Client flags, err := server.Initialize(CaProtocolPullChunks) - if err != nil { - t.Fatal(err) - } - if flags&CaProtocolReadableStore == 0 { - t.Fatalf("server not offering chunks") - } + require.NoError(t, err) + require.NotZero(t, flags&CaProtocolReadableStore, "server not offering chunks") // Should find this chunk chunk, err := server.RequestChunk(id) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) b, _ := chunk.Data() - if !bytes.Equal(b, uncompressed) { - t.Fatal("chunk data doesn't match expected") - } + require.Equal(t, uncompressed, b) // This one's missing _, err = server.RequestChunk(ChunkID{0}) - if _, ok := err.(ChunkMissing); !ok { - t.Fatal("expected ChunkMissing error, got:", err) - } + require.IsType(t, ChunkMissing{}, err) } diff --git a/remotehttp_test.go b/remotehttp_test.go index 15a15ac..8d6d9e5 100644 --- a/remotehttp_test.go +++ b/remotehttp_test.go @@ -7,6 +7,9 @@ import ( "net/url" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestHTTPStoreURL(t *testing.T) { @@ -32,13 +35,9 @@ func TestHTTPStoreURL(t *testing.T) { t.Run(name, func(t *testing.T) { u.Path = test.storePath s, err := NewRemoteHTTPStore(u, StoreOptions{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) s.GetChunk(chunkID) - if requestURI != test.serverPath { - t.Fatalf("got request uri '%s', want '%s'", requestURI, test.serverPath) - } + require.Equal(t, test.serverPath, requestURI) }) } } @@ -106,15 +105,17 @@ func TestHasChunk(t *testing.T) { t.Run(name, func(t *testing.T) { u.Path = "/" s, err := NewRemoteHTTPStore(u, StoreOptions{ErrorRetry: 5, ErrorRetryBaseInterval: time.Microsecond}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) attemptCount = 0 hasChunk, err := s.HasChunk(test.chunkId) - if (hasChunk != test.hasChunk) || ((err != nil) != test.hasError) || (attemptCount != test.attemptCount) { - t.Errorf("expected hasChunk = %t / hasError = %t / attemptCount = %d, got %t / %t / %d", test.hasChunk, test.hasError, test.attemptCount, hasChunk, (err != nil), attemptCount) + assert.Equal(t, test.hasChunk, hasChunk) + if test.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) } + assert.Equal(t, test.attemptCount, attemptCount) }) } } @@ -192,9 +193,7 @@ func TestGetChunk(t *testing.T) { t.Run(name, func(t *testing.T) { u.Path = "/" s, err := NewRemoteHTTPStore(u, StoreOptions{ErrorRetry: 5, ErrorRetryBaseInterval: time.Microsecond, Uncompressed: true}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) attemptCount = 0 content, err := s.GetChunk(test.chunkId) @@ -203,9 +202,13 @@ func TestGetChunk(t *testing.T) { uncompressedContent, _ := content.Data() content_string = string(uncompressedContent) } - if (content_string != test.content) || ((err != nil) != test.hasError) || (attemptCount != test.attemptCount) { - t.Errorf("expected content = \"%s\" / hasError = %t / attemptCount = %d, got \"%s\" / %t / %d", test.content, test.hasError, test.attemptCount, content_string, (err != nil), attemptCount) + assert.Equal(t, test.content, content_string) + if test.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) } + assert.Equal(t, test.attemptCount, attemptCount) }) } } @@ -291,9 +294,7 @@ func TestPutChunk(t *testing.T) { t.Run(name, func(t *testing.T) { u.Path = "/" s, err := NewRemoteHTTPStore(u, StoreOptions{ErrorRetry: 5, ErrorRetryBaseInterval: time.Microsecond, Uncompressed: true}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) attemptCount = 0 writtenContent = nil @@ -303,9 +304,13 @@ func TestPutChunk(t *testing.T) { if writtenContent != nil { writtenContentString = string(writtenContent) } - if ((err != nil) != test.hasError) || (attemptCount != test.attemptCount) || (writtenContentString != test.writtenContent) { - t.Errorf("expected writtenContent = \"%s\" / hasError = %t / attemptCount = %d, got \"%s\" / %t / %d", test.writtenContent, test.hasError, test.attemptCount, writtenContentString, (err != nil), attemptCount) + assert.Equal(t, test.writtenContent, writtenContentString) + if test.hasError { + assert.Error(t, err) + } else { + assert.NoError(t, err) } + assert.Equal(t, test.attemptCount, attemptCount) }) } } diff --git a/s3_test.go b/s3_test.go index abee656..943af49 100644 --- a/s3_test.go +++ b/s3_test.go @@ -17,6 +17,7 @@ import ( minio "github.com/minio/minio-go/v6" "github.com/minio/minio-go/v6/pkg/credentials" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) @@ -134,17 +135,11 @@ func getTcpS3Server(t *testing.T, group *errgroup.Group, ctx context.Context, bu var errorTimes int // using localhost + resolver let us work on hosts that support only ipv6 or only ipv4 ip, err := net.DefaultResolver.LookupIP(ctx, "ip", "localhost") - if err != nil { - t.Fatal(err) - } - if len(ip) < 1 { - t.Fatalf("cannot resolve localhost") - } + require.NoError(t, err) + require.NotEmpty(t, ip, "cannot resolve localhost") listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip[0], Port: 0}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) group.Go(func() error { <-ctx.Done() @@ -172,9 +167,7 @@ func getTcpS3Server(t *testing.T, group *errgroup.Group, ctx context.Context, bu func TestS3StoreGetChunk(t *testing.T) { chunkId, err := ChunkIDFromString("dda036db05bc2b99b6b9303d28496000c34b246457ae4bbf00fe625b5cabd7cd") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) location := "vertucon-central" bucket := "doomsdaydevices" provider := MockCredProvider{} @@ -213,9 +206,7 @@ func TestS3StoreGetChunk(t *testing.T) { } }) - if err := group.Wait(); err != nil { - t.Fatal(err) - } + require.NoError(t, group.Wait()) }) t.Run("fail", func(t *testing.T) { @@ -251,9 +242,7 @@ func TestS3StoreGetChunk(t *testing.T) { } }) - if err := group.Wait(); err != nil { - t.Fatal(err) - } + require.NoError(t, group.Wait()) }) t.Run("recover", func(t *testing.T) { @@ -291,8 +280,6 @@ func TestS3StoreGetChunk(t *testing.T) { } }) - if err := group.Wait(); err != nil { - t.Fatal(err) - } + require.NoError(t, group.Wait()) }) } diff --git a/selfseed_test.go b/selfseed_test.go index 93c38e8..1b44cdc 100644 --- a/selfseed_test.go +++ b/selfseed_test.go @@ -6,6 +6,8 @@ import ( "crypto/rand" "os" "testing" + + "github.com/stretchr/testify/require" ) func TestSelfSeed(t *testing.T) { @@ -13,9 +15,7 @@ func TestSelfSeed(t *testing.T) { store := t.TempDir() s, err := NewLocalStore(store, StoreOptions{}) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Build a number of fake chunks that can then be used in the test in any order type rawChunk struct { @@ -30,9 +30,7 @@ func TestSelfSeed(t *testing.T) { b := make([]byte, size) rand.Read(b) chunk := NewChunk(b) - if err = s.StoreChunk(chunk); err != nil { - t.Fatal(err) - } + require.NoError(t, s.StoreChunk(chunk)) chunks[i] = rawChunk{chunk.ID(), b} } @@ -98,9 +96,7 @@ func TestSelfSeed(t *testing.T) { // Build a temp target file to extract into dst, err := os.CreateTemp("", "dst") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer os.Remove(dst.Name()) defer dst.Close() @@ -108,28 +104,17 @@ func TestSelfSeed(t *testing.T) { stats, err := AssembleFile(context.Background(), dst.Name(), idx, s, nil, AssembleOptions{1, InvalidSeedActionBailOut}, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Compare the checksums to that of the input data b, err = os.ReadFile(dst.Name()) - if err != nil { - t.Fatal(err) - } - outSum := md5.Sum(b) - if sum != outSum { - t.Fatal("checksum of extracted file doesn't match expected") - } + require.NoError(t, err) + require.Equal(t, sum, md5.Sum(b), "checksum of extracted file doesn't match expected") // Compare to the expected number of bytes copied or cloned from the seed fromSeed := int(stats.BytesCopied + stats.BytesCloned) - if fromSeed < test.minCloned { - t.Fatalf("expected min %d bytes copied/cloned from self-seed, got %d", test.minCloned, fromSeed) - } - if fromSeed > test.maxCloned { - t.Fatalf("expected max %d bytes copied/cloned from self-seed, got %d", test.maxCloned, fromSeed) - } + require.GreaterOrEqual(t, fromSeed, test.minCloned, "bytes copied/cloned from self-seed") + require.LessOrEqual(t, fromSeed, test.maxCloned, "bytes copied/cloned from self-seed") }) } diff --git a/tar_test.go b/tar_test.go index 5460416..583ff19 100644 --- a/tar_test.go +++ b/tar_test.go @@ -8,8 +8,9 @@ import ( "fmt" "os" "path/filepath" - "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestTar(t *testing.T) { @@ -23,9 +24,7 @@ func TestTar(t *testing.T) { "dir2/sub22", } for _, d := range dirs { - if err := os.MkdirAll(filepath.Join(base, d), 0755); err != nil { - t.Fatal() - } + require.NoError(t, os.MkdirAll(filepath.Join(base, d), 0755)) } files := []string{ @@ -36,16 +35,12 @@ func TestTar(t *testing.T) { os.WriteFile(filepath.Join(base, name), fmt.Appendf(nil, "filecontent%d", i), 0644) } - if err := os.Symlink("dir1", filepath.Join(base, "symlink")); err != nil { - t.Fatal(err) - } + require.NoError(t, os.Symlink("dir1", filepath.Join(base, "symlink"))) // Encode it all into a buffer fs := NewLocalFS(base, LocalFSOptions{}) b := new(bytes.Buffer) - if err := Tar(context.Background(), b, fs); err != nil { - t.Fatal(err) - } + require.NoError(t, Tar(context.Background(), b, fs)) // Decode it again d := NewFormatDecoder(b) @@ -86,11 +81,7 @@ func TestTar(t *testing.T) { for _, exp := range expected { v, err := d.Next() - if err != nil { - t.Fatal(err) - } - if reflect.TypeOf(exp) != reflect.TypeOf(v) { - t.Fatalf("expected %s, got %s", reflect.TypeOf(exp), reflect.TypeOf(v)) - } + require.NoError(t, err) + require.IsType(t, exp, v) } } diff --git a/writededupqueue_test.go b/writededupqueue_test.go index 3cd5fdb..5cced68 100644 --- a/writededupqueue_test.go +++ b/writededupqueue_test.go @@ -3,6 +3,8 @@ package desync import ( "testing" "time" + + "github.com/stretchr/testify/require" ) // Test read access before write access to ensure a failing read doesn't @@ -25,7 +27,5 @@ func TestWriteDedupQueueParallelReadWrite(t *testing.T) { go q.GetChunk(c.ID()) <-sleeping - if err := q.StoreChunk(c); err != nil { - t.Fatal(err) - } + require.NoError(t, q.StoreChunk(c)) }