package archiver import ( "context" "fmt" "runtime" "strings" "sync" "sync/atomic" "testing" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" "golang.org/x/sync/errgroup" ) var errTest = errors.New("test error") type saveFail struct { cnt int32 failAt int32 } func (b *saveFail) SaveBlob(_ context.Context, _ restic.BlobType, _ []byte, id restic.ID, _ bool) (restic.ID, bool, int, error) { val := atomic.AddInt32(&b.cnt, 1) if val == b.failAt { return restic.ID{}, false, 0, errTest } return id, false, 0, nil } func TestBlobSaver(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() wg, ctx := errgroup.WithContext(ctx) saver := &saveFail{} b := NewBlobSaver(ctx, wg, saver, uint(runtime.NumCPU())) var wait sync.WaitGroup var results []SaveBlobResponse var lock sync.Mutex wait.Add(20) for i := 0; i < 20; i++ { buf := &Buffer{Data: []byte(fmt.Sprintf("foo%d", i))} idx := i lock.Lock() results = append(results, SaveBlobResponse{}) lock.Unlock() b.Save(ctx, restic.DataBlob, buf, "file", func(res SaveBlobResponse) { lock.Lock() results[idx] = res lock.Unlock() wait.Done() }) } wait.Wait() for i, sbr := range results { if sbr.known { t.Errorf("blob %v is known, that should not be the case", i) } } b.TriggerShutdown() err := wg.Wait() if err != nil { t.Fatal(err) } } func TestBlobSaverError(t *testing.T) { var tests = []struct { blobs int failAt int }{ {20, 2}, {20, 5}, {20, 15}, {200, 150}, } for _, test := range tests { t.Run("", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() wg, ctx := errgroup.WithContext(ctx) saver := &saveFail{ failAt: int32(test.failAt), } b := NewBlobSaver(ctx, wg, saver, uint(runtime.NumCPU())) for i := 0; i < test.blobs; i++ { buf := &Buffer{Data: []byte(fmt.Sprintf("foo%d", i))} b.Save(ctx, restic.DataBlob, buf, "errfile", func(res SaveBlobResponse) {}) } b.TriggerShutdown() err := wg.Wait() if err == nil { t.Errorf("expected error not found") } rtest.Assert(t, errors.Is(err, errTest), "unexpected error %v", err) rtest.Assert(t, strings.Contains(err.Error(), "errfile"), "expected error to contain 'errfile' got: %v", err) }) } }