diff --git a/backend/hashing_reader_test.go b/backend/hashing_reader_test.go deleted file mode 100644 index a825e61c4..000000000 --- a/backend/hashing_reader_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package backend_test - -import ( - "bytes" - "crypto/rand" - "crypto/sha256" - "io" - "testing" - - "github.com/restic/restic/backend" -) - -func TestHashReader(t *testing.T) { - tests := []int{5, 23, 2<<18 + 23, 1 << 20} - - for _, size := range tests { - data := make([]byte, size) - _, err := io.ReadFull(rand.Reader, data) - if err != nil { - t.Fatalf("ReadFull: %v", err) - } - - expectedHash := sha256.Sum256(data) - - rd := backend.NewHashReader(bytes.NewReader(data), sha256.New()) - - target := bytes.NewBuffer(nil) - n, err := io.Copy(target, rd) - ok(t, err) - - assert(t, n == int64(size)+int64(len(expectedHash)), - "HashReader: invalid number of bytes read: got %d, expected %d", - n, size+len(expectedHash)) - - r := target.Bytes() - resultingHash := r[len(r)-len(expectedHash):] - assert(t, bytes.Equal(expectedHash[:], resultingHash), - "HashReader: hashes do not match: expected %02x, got %02x", - expectedHash, resultingHash) - - // try to read again, must return io.EOF - n2, err := rd.Read(make([]byte, 100)) - assert(t, n2 == 0, "HashReader returned %d additional bytes", n) - assert(t, err == io.EOF, "HashReader returned %v instead of EOF", err) - } -} diff --git a/backend/hashing_reader.go b/backend/reader.go similarity index 50% rename from backend/hashing_reader.go rename to backend/reader.go index 938e1aa99..eabe9527b 100644 --- a/backend/hashing_reader.go +++ b/backend/reader.go @@ -5,22 +5,22 @@ import ( "io" ) -type HashReader struct { +type HashAppendReader struct { r io.Reader h hash.Hash sum []byte closed bool } -func NewHashReader(r io.Reader, h hash.Hash) *HashReader { - return &HashReader{ +func NewHashAppendReader(r io.Reader, h hash.Hash) *HashAppendReader { + return &HashAppendReader{ h: h, r: io.TeeReader(r, h), sum: make([]byte, 0, h.Size()), } } -func (h *HashReader) Read(p []byte) (n int, err error) { +func (h *HashAppendReader) Read(p []byte) (n int, err error) { if !h.closed { n, err = h.r.Read(p) @@ -51,3 +51,23 @@ func (h *HashReader) Read(p []byte) (n int, err error) { return } + +type HashingReader struct { + r io.Reader + h hash.Hash +} + +func NewHashingReader(r io.Reader, h hash.Hash) *HashingReader { + return &HashingReader{ + h: h, + r: io.TeeReader(r, h), + } +} + +func (h *HashingReader) Read(p []byte) (int, error) { + return h.r.Read(p) +} + +func (h *HashingReader) Sum(d []byte) []byte { + return h.h.Sum(d) +} diff --git a/backend/reader_test.go b/backend/reader_test.go new file mode 100644 index 000000000..708fbceb2 --- /dev/null +++ b/backend/reader_test.go @@ -0,0 +1,80 @@ +package backend_test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "io" + "io/ioutil" + "testing" + + "github.com/restic/restic/backend" +) + +func TestHashAppendReader(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + rd := backend.NewHashAppendReader(bytes.NewReader(data), sha256.New()) + + target := bytes.NewBuffer(nil) + n, err := io.Copy(target, rd) + ok(t, err) + + assert(t, n == int64(size)+int64(len(expectedHash)), + "HashAppendReader: invalid number of bytes read: got %d, expected %d", + n, size+len(expectedHash)) + + r := target.Bytes() + resultingHash := r[len(r)-len(expectedHash):] + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashAppendReader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + + // try to read again, must return io.EOF + n2, err := rd.Read(make([]byte, 100)) + assert(t, n2 == 0, "HashAppendReader returned %d additional bytes", n) + assert(t, err == io.EOF, "HashAppendReader returned %v instead of EOF", err) + } +} + +func TestHashingReader(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + rd := backend.NewHashingReader(bytes.NewReader(data), sha256.New()) + + n, err := io.Copy(ioutil.Discard, rd) + ok(t, err) + + assert(t, n == int64(size), + "HashAppendReader: invalid number of bytes read: got %d, expected %d", + n, size+len(expectedHash)) + + resultingHash := rd.Sum(nil) + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashAppendReader: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + + // try to read again, must return io.EOF + n2, err := rd.Read(make([]byte, 100)) + assert(t, n2 == 0, "HashAppendReader returned %d additional bytes", n) + assert(t, err == io.EOF, "HashAppendReader returned %v instead of EOF", err) + } +} diff --git a/key.go b/key.go index a7d67dafc..fcc5fba17 100644 --- a/key.go +++ b/key.go @@ -341,7 +341,7 @@ func (k *Key) encryptFrom(ks *keys, rd io.Reader) io.Reader { S: cipher.NewCTR(c, iv), } - return backend.NewHashReader(io.MultiReader(ivReader, encryptReader), + return backend.NewHashAppendReader(io.MultiReader(ivReader, encryptReader), hmac.New(sha256.New, ks.Sign)) }