From fb422497af69c1267f242275e817257424e08be4 Mon Sep 17 00:00:00 2001
From: Michael Eischer
Date: Sun, 31 Dec 2023 00:18:41 +0100
Subject: [PATCH 1/4] repository: split StreamPack implementation

Move the actual decoding of the pack data into a separate iterator.
---
 internal/repository/repository.go | 173 +++++++++++++++++++-----------
 1 file changed, 111 insertions(+), 62 deletions(-)

diff --git a/internal/repository/repository.go b/internal/repository/repository.go
index 97dc33fdf..e13220741 100644
--- a/internal/repository/repository.go
+++ b/internal/repository/repository.go
@@ -882,7 +882,7 @@ const maxUnusedRange = 4 * 1024 * 1024

 // StreamPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
 // the handleBlobFn callback or an error if decryption failed or the blob hash does not match.
-// handleBlobFn is never called multiple times for the same blob. If the callback returns an error,
+// handleBlobFn is called at most once for each blob. If the callback returns an error,
 // then StreamPack will abort and not retry it.
 func StreamPack(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
 	if len(blobs) == 0 {
@@ -940,72 +940,18 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
 	if bufferSize > MaxStreamBufferSize {
 		bufferSize = MaxStreamBufferSize
 	}
-
 	// create reader here to allow reusing the buffered reader from checker.checkData
 	bufRd := bufio.NewReaderSize(rd, bufferSize)
-	currentBlobEnd := dataStart
-	var buf []byte
-	var decode []byte
-	for len(blobs) > 0 {
-		entry := blobs[0]
+	it := NewPackBlobIterator(packID, bufRd, dataStart, blobs, key, dec)

-		skipBytes := int(entry.Offset - currentBlobEnd)
-		if skipBytes < 0 {
-			return errors.Errorf("overlapping blobs in pack %v", packID)
-		}
-
-		_, err := bufRd.Discard(skipBytes)
-		if err != nil {
+	for {
+		val, err := it.Next()
+		if err == ErrPackEOF {
+			break
+		} else if err != nil {
 			return err
 		}

-		h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
-		debug.Log("  process blob %v, skipped %d, %v", h, skipBytes, entry)
-
-		if uint(cap(buf)) < entry.Length {
-			buf = make([]byte, entry.Length)
-		}
-		buf = buf[:entry.Length]
-
-		n, err := io.ReadFull(bufRd, buf)
-		if err != nil {
-			debug.Log("    read error %v", err)
-			return errors.Wrap(err, "ReadFull")
-		}
-
-		if n != len(buf) {
-			return errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
-				h, packID.Str(), len(buf), n)
-		}
-		currentBlobEnd = entry.Offset + entry.Length
-
-		if int(entry.Length) <= key.NonceSize() {
-			debug.Log("%v", blobs)
-			return errors.Errorf("invalid blob length %v", entry)
-		}
-
-		// decryption errors are likely permanent, give the caller a chance to skip them
-		nonce, ciphertext := buf[:key.NonceSize()], buf[key.NonceSize():]
-		plaintext, err := key.Open(ciphertext[:0], nonce, ciphertext, nil)
-		if err == nil && entry.IsCompressed() {
-			// DecodeAll will allocate a slice if it is not large enough since it
-			// knows the decompressed size (because we're using EncodeAll)
-			decode, err = dec.DecodeAll(plaintext, decode[:0])
-			plaintext = decode
-			if err != nil {
-				err = errors.Errorf("decompressing blob %v failed: %v", h, err)
-			}
-		}
-		if err == nil {
-			id := restic.Hash(plaintext)
-			if !id.Equal(entry.ID) {
-				debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
-					h.Type, h.ID, packID.Str(), id)
-				err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
-					h, packID.Str(), id)
-			}
-		}
-
-		err = handleBlobFn(entry.BlobHandle, plaintext, err)
+		err = handleBlobFn(val.Handle, val.Plaintext, val.Err)
 		if err != nil {
 			cancel()
 			return backoff.Permanent(err)
@@ -1018,6 +964,109 @@ func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key,
 	return errors.Wrap(err, "StreamPack")
 }

+type PackBlobIterator struct {
+	packID        restic.ID
+	rd            *bufio.Reader
+	currentOffset uint
+
+	blobs []restic.Blob
+	key   *crypto.Key
+	dec   *zstd.Decoder
+
+	buf    []byte
+	decode []byte
+}
+
+type PackBlobValue struct {
+	Handle    restic.BlobHandle
+	Plaintext []byte
+	Err       error
+}
+
+var ErrPackEOF = errors.New("reached EOF of pack file")
+
+func NewPackBlobIterator(packID restic.ID, rd *bufio.Reader, currentOffset uint,
+	blobs []restic.Blob, key *crypto.Key, dec *zstd.Decoder) *PackBlobIterator {
+	return &PackBlobIterator{
+		packID:        packID,
+		rd:            rd,
+		currentOffset: currentOffset,
+		blobs:         blobs,
+		key:           key,
+		dec:           dec,
+	}
+}
+
+// Next returns the next blob, an error, or ErrPackEOF if all blobs were read.
+func (b *PackBlobIterator) Next() (PackBlobValue, error) {
+	if len(b.blobs) == 0 {
+		return PackBlobValue{}, ErrPackEOF
+	}
+
+	entry := b.blobs[0]
+	b.blobs = b.blobs[1:]
+
+	skipBytes := int(entry.Offset - b.currentOffset)
+	if skipBytes < 0 {
+		return PackBlobValue{}, errors.Errorf("overlapping blobs in pack %v", b.packID)
+	}
+
+	_, err := b.rd.Discard(skipBytes)
+	if err != nil {
+		return PackBlobValue{}, err
+	}
+	b.currentOffset = entry.Offset
+
+	h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
+	debug.Log("  process blob %v, skipped %d, %v", h, skipBytes, entry)
+
+	if uint(cap(b.buf)) < entry.Length {
+		b.buf = make([]byte, entry.Length)
+	}
+	b.buf = b.buf[:entry.Length]
+
+	n, err := io.ReadFull(b.rd, b.buf)
+	if err != nil {
+		debug.Log("    read error %v", err)
+		return PackBlobValue{}, errors.Wrap(err, "ReadFull")
+	}
+
+	if n != len(b.buf) {
+		return PackBlobValue{}, errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
+			h, b.packID.Str(), len(b.buf), n)
+	}
+	b.currentOffset = entry.Offset + entry.Length
+
+	if int(entry.Length) <= b.key.NonceSize() {
+		debug.Log("%v", b.blobs)
+		return PackBlobValue{}, errors.Errorf("invalid blob length %v", entry)
+	}
+
+	// decryption errors are likely permanent, give the caller a chance to skip them
+	nonce, ciphertext := b.buf[:b.key.NonceSize()], b.buf[b.key.NonceSize():]
+	plaintext, err := b.key.Open(ciphertext[:0], nonce, ciphertext, nil)
+	if err == nil && entry.IsCompressed() {
+		// DecodeAll will allocate a slice if it is not large enough since it
+		// knows the decompressed size (because we're using EncodeAll)
+		b.decode, err = b.dec.DecodeAll(plaintext, b.decode[:0])
+		plaintext = b.decode
+		if err != nil {
+			err = errors.Errorf("decompressing blob %v failed: %v", h, err)
+		}
+	}
+	if err == nil {
+		id := restic.Hash(plaintext)
+		if !id.Equal(entry.ID) {
+			debug.Log("read blob %v/%v from %v: wrong data returned, hash is %v",
+				h.Type, h.ID, b.packID.Str(), id)
+			err = errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
+				h, b.packID.Str(), id)
+		}
+	}
+
+	return PackBlobValue{entry.BlobHandle, plaintext, err}, nil
+}
+
 var zeroChunkOnce sync.Once
 var zeroChunkID restic.ID

From 22d0c3f8dcc3042219a9cad20d32f641d683f058 Mon Sep 17 00:00:00 2001
From: Michael Eischer
Date: Sun, 31 Dec 2023 10:58:26 +0100
Subject: [PATCH 2/4] check: Use PackBlobIterator instead of StreamPack

To stream the content of a pack file only once, the check command used
StreamPack with a custom pack load function. This combination was
always brittle and complicated using StreamPack everywhere else. Now
that StreamPack internally uses PackBlobIterator, use that primitive
instead; it is a much better fit for what the check command requires.
---
 internal/checker/checker.go | 81 +++++++++++++++++++------------------
 1 file changed, 41 insertions(+), 40 deletions(-)

diff --git a/internal/checker/checker.go b/internal/checker/checker.go
index 3bc0fac87..e6a7a9035 100644
--- a/internal/checker/checker.go
+++ b/internal/checker/checker.go
@@ -10,6 +10,7 @@ import (
 	"sort"
 	"sync"

+	"github.com/klauspost/compress/zstd"
 	"github.com/minio/sha256-simd"
 	"github.com/restic/restic/internal/backend"
 	"github.com/restic/restic/internal/backend/s3"
@@ -526,7 +527,7 @@ func (c *Checker) GetPacks() map[restic.ID]int64 {
 }

 // checkPack reads a pack and checks the integrity of all blobs.
-func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []restic.Blob, size int64, bufRd *bufio.Reader) error {
+func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []restic.Blob, size int64, bufRd *bufio.Reader, dec *zstd.Decoder) error {
 	debug.Log("checking pack %v", id.String())

 	if len(blobs) == 0 {
@@ -557,49 +558,44 @@ func checkPack(ctx context.Context, r restic.Repository, id restic.ID, blobs []r
 	// calculate hash on-the-fly while reading the pack and capture pack header
 	var hash restic.ID
 	var hdrBuf []byte
-	hashingLoader := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error {
-		return r.Backend().Load(ctx, h, int(size), 0, func(rd io.Reader) error {
-			hrd := hashing.NewReader(rd, sha256.New())
-			bufRd.Reset(hrd)
+	h := backend.Handle{Type: backend.PackFile, Name: id.String()}
+	err := r.Backend().Load(ctx, h, int(size), 0, func(rd io.Reader) error {
+		hrd := hashing.NewReader(rd, sha256.New())
+		bufRd.Reset(hrd)

-			// skip to start of first blob, offset == 0 for correct pack files
-			_, err := bufRd.Discard(int(offset))
-			if err != nil {
+		it := repository.NewPackBlobIterator(id, bufRd, 0, blobs, r.Key(), dec)
+		for {
+			val, err := it.Next()
+			if err == repository.ErrPackEOF {
+				break
+			} else if err != nil {
 				return err
 			}
-
-			err = fn(bufRd)
-			if err != nil {
-				return err
+			debug.Log("  check blob %v: %v", val.Handle.ID, val.Handle)
+			if val.Err != nil {
+				debug.Log("  error verifying blob %v: %v", val.Handle.ID, val.Err)
+				errs = append(errs, errors.Errorf("blob %v: %v", val.Handle.ID, val.Err))
 			}
-
-			// skip enough bytes until we reach the possible header start
-			curPos := length + int(offset)
-			minHdrStart := int(size) - pack.MaxHeaderSize
-			if minHdrStart > curPos {
-				_, err := bufRd.Discard(minHdrStart - curPos)
-				if err != nil {
-					return err
-				}
-			}
-
-			// read remainder, which should be the pack header
-			hdrBuf, err = io.ReadAll(bufRd)
-			if err != nil {
-				return err
-			}
-
-			hash = restic.IDFromHash(hrd.Sum(nil))
-			return nil
-		})
-	}
-
-	err := repository.StreamPack(ctx, hashingLoader, r.Key(), id, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
-		debug.Log("  check blob %v: %v", blob.ID, blob)
-		if err != nil {
-			debug.Log("  error verifying blob %v: %v", blob.ID, err)
-			errs = append(errs, errors.Errorf("blob %v: %v", blob.ID, err))
 		}
+
+		// skip enough bytes until we reach the possible header start
+		curPos := lastBlobEnd
+		minHdrStart := int(size) - pack.MaxHeaderSize
+		if minHdrStart > curPos {
+			_, err := bufRd.Discard(minHdrStart - curPos)
+			if err != nil {
+				return err
+			}
+		}
+
+		// read remainder, which should be the pack header
+		var err error
+		hdrBuf, err = io.ReadAll(bufRd)
+		if err != nil {
+			return err
+		}
+
+		hash = restic.IDFromHash(hrd.Sum(nil))
 		return nil
 	})
 	if err != nil {
@@ -670,6 +666,11 @@ func (c *Checker) ReadPacks(ctx context.Context, packs map[restic.ID]int64, p *p
 	// create a buffer that is large enough to be reused by repository.StreamPack
 	// this ensures that we can read the pack header later on
 	bufRd := bufio.NewReaderSize(nil, repository.MaxStreamBufferSize)
+	dec, err := zstd.NewReader(nil)
+	if err != nil {
+		panic(err)
+	}
+	defer dec.Close()
 	for {
 		var ps checkTask
 		var ok bool
@@ -683,7 +684,7 @@ func (c *Checker) ReadPacks(ctx context.Context, packs map[restic.ID]int64, p *p
 			}
 		}

-		err := checkPack(ctx, c.repo, ps.id, ps.blobs, ps.size, bufRd)
+		err := checkPack(ctx, c.repo, ps.id, ps.blobs, ps.size, bufRd, dec)
 		p.Add(1)
 		if err == nil {
 			continue

From 6b7b5c89e9109210fb7bc9933677d8a54d67866b Mon Sep 17 00:00:00 2001
From: Michael Eischer
Date: Sun, 31 Dec 2023 15:45:10 +0100
Subject: [PATCH 3/4] repository: prepare StreamPack refactor

---
 .../repository/repository_internal_test.go   | 278 ++++++++++++++++++
 internal/repository/repository_test.go       | 273 -----------------
 2 files changed, 278 insertions(+), 273 deletions(-)

diff --git a/internal/repository/repository_internal_test.go b/internal/repository/repository_internal_test.go
index d8e35b993..fc408910c 100644
--- a/internal/repository/repository_internal_test.go
+++ b/internal/repository/repository_internal_test.go
@@ -1,11 +1,21 @@
 package repository

 import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"io"
 	"math/rand"
 	"sort"
+	"strings"
 	"testing"

+	"github.com/cenkalti/backoff/v4"
+	"github.com/google/go-cmp/cmp"
+	"github.com/klauspost/compress/zstd"
 	"github.com/restic/restic/internal/backend"
+	"github.com/restic/restic/internal/crypto"
+	"github.com/restic/restic/internal/errors"
 	"github.com/restic/restic/internal/restic"
 	rtest "github.com/restic/restic/internal/test"
 )
@@ -73,3 +83,271 @@ func BenchmarkSortCachedPacksFirst(b *testing.B) {
 		sortCachedPacksFirst(cache, cpy[:])
 	}
 }
+
+// buildPackfileWithoutHeader returns a manually built pack file without a header.
+func buildPackfileWithoutHeader(blobSizes []int, key *crypto.Key, compress bool) (blobs []restic.Blob, packfile []byte) {
+	opts := []zstd.EOption{
+		// Set the compression level configured.
+		zstd.WithEncoderLevel(zstd.SpeedDefault),
+		// Disable CRC, we have enough checks in place, makes the
+		// compressed data four bytes shorter.
+		zstd.WithEncoderCRC(false),
+		// Set a window of 512kbyte, so we have good lookbehind for usual
+		// blob sizes.
+		zstd.WithWindowSize(512 * 1024),
+	}
+	enc, err := zstd.NewWriter(nil, opts...)
+	if err != nil {
+		panic(err)
+	}
+
+	var offset uint
+	for i, size := range blobSizes {
+		plaintext := rtest.Random(800+i, size)
+		id := restic.Hash(plaintext)
+		uncompressedLength := uint(0)
+		if compress {
+			uncompressedLength = uint(len(plaintext))
+			plaintext = enc.EncodeAll(plaintext, nil)
+		}
+
+		// we use a deterministic nonce here so the whole process is
+		// deterministic, last byte is the blob index
+		var nonce = []byte{
+			0x15, 0x98, 0xc0, 0xf7, 0xb9, 0x65, 0x97, 0x74,
+			0x12, 0xdc, 0xd3, 0x62, 0xa9, 0x6e, 0x20, byte(i),
+		}
+
+		before := len(packfile)
+		packfile = append(packfile, nonce...)
+ packfile = key.Seal(packfile, nonce, plaintext, nil) + after := len(packfile) + + ciphertextLength := after - before + + blobs = append(blobs, restic.Blob{ + BlobHandle: restic.BlobHandle{ + Type: restic.DataBlob, + ID: id, + }, + Length: uint(ciphertextLength), + UncompressedLength: uncompressedLength, + Offset: offset, + }) + + offset = uint(len(packfile)) + } + + return blobs, packfile +} + +func TestStreamPack(t *testing.T) { + TestAllVersions(t, testStreamPack) +} + +func testStreamPack(t *testing.T, version uint) { + // always use the same key for deterministic output + const jsonKey = `{"mac":{"k":"eQenuI8adktfzZMuC8rwdA==","r":"k8cfAly2qQSky48CQK7SBA=="},"encrypt":"MKO9gZnRiQFl8mDUurSDa9NMjiu9MUifUrODTHS05wo="}` + + var key crypto.Key + err := json.Unmarshal([]byte(jsonKey), &key) + if err != nil { + t.Fatal(err) + } + + blobSizes := []int{ + 5522811, + 10, + 5231, + 18812, + 123123, + 13522811, + 12301, + 892242, + 28616, + 13351, + 252287, + 188883, + 3522811, + 18883, + } + + var compress bool + switch version { + case 1: + compress = false + case 2: + compress = true + default: + t.Fatal("test does not support repository version", version) + } + + packfileBlobs, packfile := buildPackfileWithoutHeader(blobSizes, &key, compress) + + loadCalls := 0 + shortFirstLoad := false + + loadBytes := func(length int, offset int64) []byte { + data := packfile + + if offset > int64(len(data)) { + offset = 0 + length = 0 + } + data = data[offset:] + + if length > len(data) { + length = len(data) + } + if shortFirstLoad { + length /= 2 + shortFirstLoad = false + } + + return data[:length] + } + + load := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { + data := loadBytes(length, offset) + if shortFirstLoad { + data = data[:len(data)/2] + shortFirstLoad = false + } + + loadCalls++ + + err := fn(bytes.NewReader(data)) + if err == nil { + return nil + } + var permanent *backoff.PermanentError + if errors.As(err, &permanent) { + return err + } + + // retry loading once + return fn(bytes.NewReader(loadBytes(length, offset))) + } + + // first, test regular usage + t.Run("regular", func(t *testing.T) { + tests := []struct { + blobs []restic.Blob + calls int + shortFirstLoad bool + }{ + {packfileBlobs[1:2], 1, false}, + {packfileBlobs[2:5], 1, false}, + {packfileBlobs[2:8], 1, false}, + {[]restic.Blob{ + packfileBlobs[0], + packfileBlobs[4], + packfileBlobs[2], + }, 1, false}, + {[]restic.Blob{ + packfileBlobs[0], + packfileBlobs[len(packfileBlobs)-1], + }, 2, false}, + {packfileBlobs[:], 1, true}, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gotBlobs := make(map[restic.ID]int) + + handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error { + gotBlobs[blob.ID]++ + + id := restic.Hash(buf) + if !id.Equal(blob.ID) { + t.Fatalf("wrong id %v for blob %s returned", id, blob.ID) + } + + return err + } + + wantBlobs := make(map[restic.ID]int) + for _, blob := range test.blobs { + wantBlobs[blob.ID] = 1 + } + + loadCalls = 0 + shortFirstLoad = test.shortFirstLoad + err = StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob) + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(wantBlobs, gotBlobs) { + t.Fatal(cmp.Diff(wantBlobs, gotBlobs)) + } + rtest.Equals(t, test.calls, loadCalls) + }) + } + }) + shortFirstLoad = false + + // next, test invalid uses, which should return an error + t.Run("invalid", func(t *testing.T) { + 
tests := []struct { + blobs []restic.Blob + err string + }{ + { + // pass one blob several times + blobs: []restic.Blob{ + packfileBlobs[3], + packfileBlobs[8], + packfileBlobs[3], + packfileBlobs[4], + }, + err: "overlapping blobs in pack", + }, + + { + // pass something that's not a valid blob in the current pack file + blobs: []restic.Blob{ + { + Offset: 123, + Length: 20000, + }, + }, + err: "ciphertext verification failed", + }, + + { + // pass a blob that's too small + blobs: []restic.Blob{ + { + Offset: 123, + Length: 10, + }, + }, + err: "invalid blob length", + }, + } + + for _, test := range tests { + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error { + return err + } + + err = StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob) + if err == nil { + t.Fatalf("wanted error %v, got nil", test.err) + } + + if !strings.Contains(err.Error(), test.err) { + t.Fatalf("wrong error returned, it should contain %q but was %q", test.err, err) + } + }) + } + }) +} diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index 1178a7693..272ea94ac 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -4,8 +4,6 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/json" - "errors" "fmt" "io" "math/rand" @@ -15,9 +13,6 @@ import ( "testing" "time" - "github.com/cenkalti/backoff/v4" - "github.com/google/go-cmp/cmp" - "github.com/klauspost/compress/zstd" "github.com/restic/restic/internal/backend" "github.com/restic/restic/internal/backend/local" "github.com/restic/restic/internal/crypto" @@ -430,274 +425,6 @@ func testRepositoryIncrementalIndex(t *testing.T, version uint) { } -// buildPackfileWithoutHeader returns a manually built pack file without a header. -func buildPackfileWithoutHeader(blobSizes []int, key *crypto.Key, compress bool) (blobs []restic.Blob, packfile []byte) { - opts := []zstd.EOption{ - // Set the compression level configured. - zstd.WithEncoderLevel(zstd.SpeedDefault), - // Disable CRC, we have enough checks in place, makes the - // compressed data four bytes shorter. - zstd.WithEncoderCRC(false), - // Set a window of 512kbyte, so we have good lookbehind for usual - // blob sizes. - zstd.WithWindowSize(512 * 1024), - } - enc, err := zstd.NewWriter(nil, opts...) - if err != nil { - panic(err) - } - - var offset uint - for i, size := range blobSizes { - plaintext := rtest.Random(800+i, size) - id := restic.Hash(plaintext) - uncompressedLength := uint(0) - if compress { - uncompressedLength = uint(len(plaintext)) - plaintext = enc.EncodeAll(plaintext, nil) - } - - // we use a deterministic nonce here so the whole process is - // deterministic, last byte is the blob index - var nonce = []byte{ - 0x15, 0x98, 0xc0, 0xf7, 0xb9, 0x65, 0x97, 0x74, - 0x12, 0xdc, 0xd3, 0x62, 0xa9, 0x6e, 0x20, byte(i), - } - - before := len(packfile) - packfile = append(packfile, nonce...) 
- packfile = key.Seal(packfile, nonce, plaintext, nil) - after := len(packfile) - - ciphertextLength := after - before - - blobs = append(blobs, restic.Blob{ - BlobHandle: restic.BlobHandle{ - Type: restic.DataBlob, - ID: id, - }, - Length: uint(ciphertextLength), - UncompressedLength: uncompressedLength, - Offset: offset, - }) - - offset = uint(len(packfile)) - } - - return blobs, packfile -} - -func TestStreamPack(t *testing.T) { - repository.TestAllVersions(t, testStreamPack) -} - -func testStreamPack(t *testing.T, version uint) { - // always use the same key for deterministic output - const jsonKey = `{"mac":{"k":"eQenuI8adktfzZMuC8rwdA==","r":"k8cfAly2qQSky48CQK7SBA=="},"encrypt":"MKO9gZnRiQFl8mDUurSDa9NMjiu9MUifUrODTHS05wo="}` - - var key crypto.Key - err := json.Unmarshal([]byte(jsonKey), &key) - if err != nil { - t.Fatal(err) - } - - blobSizes := []int{ - 5522811, - 10, - 5231, - 18812, - 123123, - 13522811, - 12301, - 892242, - 28616, - 13351, - 252287, - 188883, - 3522811, - 18883, - } - - var compress bool - switch version { - case 1: - compress = false - case 2: - compress = true - default: - t.Fatal("test does not support repository version", version) - } - - packfileBlobs, packfile := buildPackfileWithoutHeader(blobSizes, &key, compress) - - loadCalls := 0 - shortFirstLoad := false - - loadBytes := func(length int, offset int64) []byte { - data := packfile - - if offset > int64(len(data)) { - offset = 0 - length = 0 - } - data = data[offset:] - - if length > len(data) { - length = len(data) - } - if shortFirstLoad { - length /= 2 - shortFirstLoad = false - } - - return data[:length] - } - - load := func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { - data := loadBytes(length, offset) - if shortFirstLoad { - data = data[:len(data)/2] - shortFirstLoad = false - } - - loadCalls++ - - err := fn(bytes.NewReader(data)) - if err == nil { - return nil - } - var permanent *backoff.PermanentError - if errors.As(err, &permanent) { - return err - } - - // retry loading once - return fn(bytes.NewReader(loadBytes(length, offset))) - } - - // first, test regular usage - t.Run("regular", func(t *testing.T) { - tests := []struct { - blobs []restic.Blob - calls int - shortFirstLoad bool - }{ - {packfileBlobs[1:2], 1, false}, - {packfileBlobs[2:5], 1, false}, - {packfileBlobs[2:8], 1, false}, - {[]restic.Blob{ - packfileBlobs[0], - packfileBlobs[4], - packfileBlobs[2], - }, 1, false}, - {[]restic.Blob{ - packfileBlobs[0], - packfileBlobs[len(packfileBlobs)-1], - }, 2, false}, - {packfileBlobs[:], 1, true}, - } - - for _, test := range tests { - t.Run("", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - gotBlobs := make(map[restic.ID]int) - - handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error { - gotBlobs[blob.ID]++ - - id := restic.Hash(buf) - if !id.Equal(blob.ID) { - t.Fatalf("wrong id %v for blob %s returned", id, blob.ID) - } - - return err - } - - wantBlobs := make(map[restic.ID]int) - for _, blob := range test.blobs { - wantBlobs[blob.ID] = 1 - } - - loadCalls = 0 - shortFirstLoad = test.shortFirstLoad - err = repository.StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob) - if err != nil { - t.Fatal(err) - } - - if !cmp.Equal(wantBlobs, gotBlobs) { - t.Fatal(cmp.Diff(wantBlobs, gotBlobs)) - } - rtest.Equals(t, test.calls, loadCalls) - }) - } - }) - shortFirstLoad = false - - // next, test invalid uses, which should return an error - t.Run("invalid", 
func(t *testing.T) {
-		tests := []struct {
-			blobs []restic.Blob
-			err   string
-		}{
-			{
-				// pass one blob several times
-				blobs: []restic.Blob{
-					packfileBlobs[3],
-					packfileBlobs[8],
-					packfileBlobs[3],
-					packfileBlobs[4],
-				},
-				err: "overlapping blobs in pack",
-			},
-
-			{
-				// pass something that's not a valid blob in the current pack file
-				blobs: []restic.Blob{
-					{
-						Offset: 123,
-						Length: 20000,
-					},
-				},
-				err: "ciphertext verification failed",
-			},
-
-			{
-				// pass a blob that's too small
-				blobs: []restic.Blob{
-					{
-						Offset: 123,
-						Length: 10,
-					},
-				},
-				err: "invalid blob length",
-			},
-		}
-
-		for _, test := range tests {
-			t.Run("", func(t *testing.T) {
-				ctx, cancel := context.WithCancel(context.Background())
-				defer cancel()
-
-				handleBlob := func(blob restic.BlobHandle, buf []byte, err error) error {
-					return err
-				}
-
-				err = repository.StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
-				if err == nil {
-					t.Fatalf("wanted error %v, got nil", test.err)
-				}
-
-				if !strings.Contains(err.Error(), test.err) {
-					t.Fatalf("wrong error returned, it should contain %q but was %q", test.err, err)
-				}
-			})
-		}
-	})
-}
-
 func TestInvalidCompression(t *testing.T) {
 	var comp repository.CompressionMode
 	err := comp.Set("nope")

From 2c310a526e9c0c0c7e26313c4ee06a94328e6395 Mon Sep 17 00:00:00 2001
From: Michael Eischer
Date: Sun, 31 Dec 2023 12:07:19 +0100
Subject: [PATCH 4/4] repository: Replace StreamPack function with
 LoadBlobsFromPack method

LoadBlobsFromPack is now part of the repository struct. This ensures
that users of that method don't have to deal with internals of the
repository implementation.

The filerestorer tests now also contain far fewer pack file
implementation details.
---
 cmd/restic/cmd_repair_packs.go                |   2 +-
 internal/repository/repack.go                 |   2 +-
 internal/repository/repository.go             |  14 +-
 .../repository/repository_internal_test.go    |   4 +-
 internal/restic/repository.go                 |   1 +
 internal/restorer/filerestorer.go             |  16 +--
 internal/restorer/filerestorer_test.go        | 132 ++++++------------
 internal/restorer/restorer.go                 |   2 +-
 8 files changed, 66 insertions(+), 107 deletions(-)

diff --git a/cmd/restic/cmd_repair_packs.go b/cmd/restic/cmd_repair_packs.go
index 7d1a3a392..723bdbccb 100644
--- a/cmd/restic/cmd_repair_packs.go
+++ b/cmd/restic/cmd_repair_packs.go
@@ -116,7 +116,7 @@ func repairPacks(ctx context.Context, gopts GlobalOptions, repo *repository.Repo
 			continue
 		}

-		err = repository.StreamPack(wgCtx, repo.Backend().Load, repo.Key(), b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
+		err = repo.LoadBlobsFromPack(wgCtx, b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
 			if err != nil {
 				// Fallback path
 				buf, err = repo.LoadBlob(wgCtx, blob.Type, blob.ID, nil)
diff --git a/internal/repository/repack.go b/internal/repository/repack.go
index c82e63f28..5588984f6 100644
--- a/internal/repository/repack.go
+++ b/internal/repository/repack.go
@@ -77,7 +77,7 @@ func repack(ctx context.Context, repo restic.Repository, dstRepo restic.Reposito

 	worker := func() error {
 		for t := range downloadQueue {
-			err := StreamPack(wgCtx, repo.Backend().Load, repo.Key(), t.PackID, t.Blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
+			err := repo.LoadBlobsFromPack(wgCtx, t.PackID, t.Blobs, func(blob restic.BlobHandle, buf []byte, err error) error {
 				if err != nil {
 					var ierr error
 					// check whether we can get a valid copy somewhere else
diff --git a/internal/repository/repository.go b/internal/repository/repository.go
index e13220741..407b6429c 100644
--- a/internal/repository/repository.go
+++ b/internal/repository/repository.go
@@ -875,16 +875,20 @@ func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte
 	return newID, known, size, err
 }

-type BackendLoadFn func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error
+type backendLoadFn func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error

 // Skip sections with more than 4MB unused blobs
 const maxUnusedRange = 4 * 1024 * 1024

-// StreamPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
+// LoadBlobsFromPack loads the listed blobs from the specified pack file. The plaintext blob is passed to
 // the handleBlobFn callback or an error if decryption failed or the blob hash does not match.
 // handleBlobFn is called at most once for each blob. If the callback returns an error,
-// then StreamPack will abort and not retry it.
-func StreamPack(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
+// then LoadBlobsFromPack will abort and not retry it.
+func (r *Repository) LoadBlobsFromPack(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
+	return streamPack(ctx, r.Backend().Load, r.key, packID, blobs, handleBlobFn)
+}
+
+func streamPack(ctx context.Context, beLoad backendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
 	if len(blobs) == 0 {
 		// nothing to do
 		return nil
@@ -915,7 +919,7 @@ func (r *Repository) LoadBlobsFromPack(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
 	return streamPackPart(ctx, beLoad, key, packID, blobs[lowerIdx:], handleBlobFn)
 }

-func streamPackPart(ctx context.Context, beLoad BackendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
+func streamPackPart(ctx context.Context, beLoad backendLoadFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error {
 	h := backend.Handle{Type: restic.PackFile, Name: packID.String(), IsMetadata: false}

 	dataStart := blobs[0].Offset
diff --git a/internal/repository/repository_internal_test.go b/internal/repository/repository_internal_test.go
index fc408910c..eed99c7e0 100644
--- a/internal/repository/repository_internal_test.go
+++ b/internal/repository/repository_internal_test.go
@@ -276,7 +276,7 @@ func testStreamPack(t *testing.T, version uint) {

 				loadCalls = 0
 				shortFirstLoad = test.shortFirstLoad
-				err = StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
+				err = streamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
 				if err != nil {
 					t.Fatal(err)
 				}
@@ -339,7 +339,7 @@ func testStreamPack(t *testing.T, version uint) {
 					return err
 				}

-				err = StreamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
+				err = streamPack(ctx, load, &key, restic.ID{}, test.blobs, handleBlob)
 				if err == nil {
 					t.Fatalf("wanted error %v, got nil", test.err)
 				}
diff --git a/internal/restic/repository.go b/internal/restic/repository.go
index 895c930dd..6818847c0 100644
--- a/internal/restic/repository.go
+++ b/internal/restic/repository.go
@@ -44,6 +44,7 @@
ListPack(context.Context, ID, int64) ([]Blob, uint32, error) LoadBlob(context.Context, BlobType, ID, []byte) ([]byte, error) + LoadBlobsFromPack(ctx context.Context, packID ID, blobs []Blob, handleBlobFn func(blob BlobHandle, buf []byte, err error) error) error SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error) // StartPackUploader start goroutines to upload new pack files. The errgroup diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go index 99a460321..f2c134ea9 100644 --- a/internal/restorer/filerestorer.go +++ b/internal/restorer/filerestorer.go @@ -7,7 +7,6 @@ import ( "golang.org/x/sync/errgroup" - "github.com/restic/restic/internal/crypto" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/repository" @@ -45,11 +44,12 @@ type packInfo struct { files map[*fileInfo]struct{} // set of files that use blobs from this pack } +type blobsLoaderFn func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error + // fileRestorer restores set of files type fileRestorer struct { - key *crypto.Key - idx func(restic.BlobHandle) []restic.PackedBlob - packLoader repository.BackendLoadFn + idx func(restic.BlobHandle) []restic.PackedBlob + blobsLoader blobsLoaderFn workerCount int filesWriter *filesWriter @@ -63,8 +63,7 @@ type fileRestorer struct { } func newFileRestorer(dst string, - packLoader repository.BackendLoadFn, - key *crypto.Key, + blobsLoader blobsLoaderFn, idx func(restic.BlobHandle) []restic.PackedBlob, connections uint, sparse bool, @@ -74,9 +73,8 @@ func newFileRestorer(dst string, workerCount := int(connections) return &fileRestorer{ - key: key, idx: idx, - packLoader: packLoader, + blobsLoader: blobsLoader, filesWriter: newFilesWriter(workerCount), zeroChunk: repository.ZeroChunk(), sparse: sparse, @@ -310,7 +308,7 @@ func (r *fileRestorer) downloadBlobs(ctx context.Context, packID restic.ID, for _, entry := range blobs { blobList = append(blobList, entry.blob) } - return repository.StreamPack(ctx, r.packLoader, r.key, packID, blobList, + return r.blobsLoader(ctx, packID, blobList, func(h restic.BlobHandle, blobData []byte, err error) error { processedBlobs.Insert(h) blob := blobs[h.ID] diff --git a/internal/restorer/filerestorer_test.go b/internal/restorer/filerestorer_test.go index c5bc3fe31..befeb5d2c 100644 --- a/internal/restorer/filerestorer_test.go +++ b/internal/restorer/filerestorer_test.go @@ -4,14 +4,11 @@ import ( "bytes" "context" "fmt" - "io" "os" + "sort" "testing" - "github.com/restic/restic/internal/backend" - "github.com/restic/restic/internal/crypto" "github.com/restic/restic/internal/errors" - "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" ) @@ -27,11 +24,6 @@ type TestFile struct { } type TestRepo struct { - key *crypto.Key - - // pack names and ids - packsNameToID map[string]restic.ID - packsIDToName map[restic.ID]string packsIDToData map[restic.ID][]byte // blobs and files @@ -40,7 +32,7 @@ type TestRepo struct { filesPathToContent map[string]string // - loader repository.BackendLoadFn + loader blobsLoaderFn } func (i *TestRepo) Lookup(bh restic.BlobHandle) []restic.PackedBlob { @@ -59,16 +51,6 @@ func newTestRepo(content []TestFile) *TestRepo { blobs map[restic.ID]restic.Blob } packs := make(map[string]Pack) - - key := crypto.NewRandomKey() - seal := func(data []byte) 
[]byte { - ciphertext := crypto.NewBlobBuffer(len(data)) - ciphertext = ciphertext[:0] // truncate the slice - nonce := crypto.NewRandomNonce() - ciphertext = append(ciphertext, nonce...) - return key.Seal(ciphertext, nonce, data, nil) - } - filesPathToContent := make(map[string]string) for _, file := range content { @@ -86,14 +68,15 @@ func newTestRepo(content []TestFile) *TestRepo { // calculate blob id and add to the pack as necessary blobID := restic.Hash([]byte(blob.data)) if _, found := pack.blobs[blobID]; !found { - blobData := seal([]byte(blob.data)) + blobData := []byte(blob.data) pack.blobs[blobID] = restic.Blob{ BlobHandle: restic.BlobHandle{ Type: restic.DataBlob, ID: blobID, }, - Length: uint(len(blobData)), - Offset: uint(len(pack.data)), + Length: uint(len(blobData)), + UncompressedLength: uint(len(blobData)), + Offset: uint(len(pack.data)), } pack.data = append(pack.data, blobData...) } @@ -104,15 +87,11 @@ func newTestRepo(content []TestFile) *TestRepo { } blobs := make(map[restic.ID][]restic.PackedBlob) - packsIDToName := make(map[restic.ID]string) packsIDToData := make(map[restic.ID][]byte) - packsNameToID := make(map[string]restic.ID) for _, pack := range packs { packID := restic.Hash(pack.data) - packsIDToName[packID] = pack.name packsIDToData[packID] = pack.data - packsNameToID[pack.name] = packID for blobID, blob := range pack.blobs { blobs[blobID] = append(blobs[blobID], restic.PackedBlob{Blob: blob, PackID: packID}) } @@ -128,30 +107,44 @@ func newTestRepo(content []TestFile) *TestRepo { } repo := &TestRepo{ - key: key, - packsIDToName: packsIDToName, packsIDToData: packsIDToData, - packsNameToID: packsNameToID, blobs: blobs, files: files, filesPathToContent: filesPathToContent, } - repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { - packID, err := restic.ParseID(h.Name) - if err != nil { - return err + repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { + blobs = append([]restic.Blob{}, blobs...) 
+ sort.Slice(blobs, func(i, j int) bool { + return blobs[i].Offset < blobs[j].Offset + }) + + for _, blob := range blobs { + found := false + for _, e := range repo.blobs[blob.ID] { + if packID == e.PackID { + found = true + buf := repo.packsIDToData[packID][e.Offset : e.Offset+e.Length] + err := handleBlobFn(e.BlobHandle, buf, nil) + if err != nil { + return err + } + } + } + if !found { + return fmt.Errorf("missing blob: %v", blob) + } } - rd := bytes.NewReader(repo.packsIDToData[packID][int(offset) : int(offset)+length]) - return fn(rd) + return nil } return repo } func restoreAndVerify(t *testing.T, tempdir string, content []TestFile, files map[string]bool, sparse bool) { + t.Helper() repo := newTestRepo(content) - r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, sparse, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, sparse, nil) if files == nil { r.files = repo.files @@ -170,6 +163,7 @@ func restoreAndVerify(t *testing.T, tempdir string, content []TestFile, files ma } func verifyRestore(t *testing.T, r *fileRestorer, repo *TestRepo) { + t.Helper() for _, file := range r.files { target := r.targetPath(file.location) data, err := os.ReadFile(target) @@ -283,62 +277,17 @@ func TestErrorRestoreFiles(t *testing.T) { loadError := errors.New("load error") // loader always returns an error - repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { + repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { return loadError } - r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil) r.files = repo.files err := r.restoreFiles(context.TODO()) rtest.Assert(t, errors.Is(err, loadError), "got %v, expected contained error %v", err, loadError) } -func TestDownloadError(t *testing.T) { - for i := 0; i < 100; i += 10 { - testPartialDownloadError(t, i) - } -} - -func testPartialDownloadError(t *testing.T, part int) { - tempdir := rtest.TempDir(t) - content := []TestFile{ - { - name: "file1", - blobs: []TestBlob{ - {"data1-1", "pack1"}, - {"data1-2", "pack1"}, - {"data1-3", "pack1"}, - }, - }} - - repo := newTestRepo(content) - - // loader always returns an error - loader := repo.loader - repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { - // only load partial data to exercise fault handling in different places - err := loader(ctx, h, length*part/100, offset, fn) - if err == nil { - return nil - } - fmt.Println("Retry after error", err) - return loader(ctx, h, length, offset, fn) - } - - r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil) - r.files = repo.files - r.Error = func(s string, e error) error { - // ignore errors as in the `restore` command - fmt.Println("error during restore", s, e) - return nil - } - - err := r.restoreFiles(context.TODO()) - rtest.OK(t, err) - verifyRestore(t, r, repo) -} - func TestFatalDownloadError(t *testing.T) { tempdir := rtest.TempDir(t) content := []TestFile{ @@ -361,12 +310,19 @@ func TestFatalDownloadError(t *testing.T) { repo := newTestRepo(content) loader := repo.loader - repo.loader = func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error { - // only return half the data to break file2 - return loader(ctx, h, length/2, 
offset, fn) + repo.loader = func(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { + ctr := 0 + return loader(ctx, packID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error { + if ctr < 2 { + ctr++ + return handleBlobFn(blob, buf, err) + } + // break file2 + return errors.New("failed to load blob") + }) } - r := newFileRestorer(tempdir, repo.loader, repo.key, repo.Lookup, 2, false, nil) + r := newFileRestorer(tempdir, repo.loader, repo.Lookup, 2, false, nil) r.files = repo.files var errors []string diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index e973316c0..2ce1ee98e 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -231,7 +231,7 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { } idx := NewHardlinkIndex[string]() - filerestorer := newFileRestorer(dst, res.repo.Backend().Load, res.repo.Key(), res.repo.Index().Lookup, + filerestorer := newFileRestorer(dst, res.repo.LoadBlobsFromPack, res.repo.Index().Lookup, res.repo.Connections(), res.sparse, res.progress) filerestorer.Error = res.Error
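
Editor's note (not part of the applied patches): the sketch below shows how a
caller might drive the new PackBlobIterator directly, mirroring the loop that
streamPackPart and checkPack use in the patches above. The function name
dumpPack and the fmt-based output are illustrative assumptions; the iterator
API (NewPackBlobIterator, Next, ErrPackEOF, PackBlobValue) is taken from
patch 1.

	package example

	import (
		"bufio"
		"fmt"

		"github.com/klauspost/compress/zstd"
		"github.com/restic/restic/internal/crypto"
		"github.com/restic/restic/internal/repository"
		"github.com/restic/restic/internal/restic"
	)

	// dumpPack reads all listed blobs from a reader positioned at the start
	// of a pack file. blobs must be sorted by offset, as the iterator only
	// ever skips forward within the pack.
	func dumpPack(rd *bufio.Reader, packID restic.ID, blobs []restic.Blob, key *crypto.Key) error {
		dec, err := zstd.NewReader(nil)
		if err != nil {
			return err
		}
		defer dec.Close()

		// currentOffset is 0 because rd starts at the beginning of the pack
		// file; streamPackPart instead passes the offset of the first blob.
		it := repository.NewPackBlobIterator(packID, rd, 0, blobs, key, dec)
		for {
			val, err := it.Next()
			if err == repository.ErrPackEOF {
				return nil // all listed blobs were processed
			} else if err != nil {
				return err // reading or skipping within the pack failed
			}
			// val.Err carries per-blob decryption/verification failures;
			// callers such as check collect it instead of aborting the pack.
			fmt.Printf("blob %v: %d bytes, err: %v\n", val.Handle, len(val.Plaintext), val.Err)
		}
	}

Note the design choice this illustrates: per-blob failures are reported
through PackBlobValue.Err rather than through Next's error return, which is
what lets check keep scanning a damaged pack instead of stopping at the first
corrupt blob.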