diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 4a340222f..f028b57ad 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -42,8 +42,10 @@ type Repository struct { treePM *packerManager dataPM *packerManager - enc *zstd.Encoder - dec *zstd.Decoder + allocEnc sync.Once + allocDec sync.Once + enc *zstd.Encoder + dec *zstd.Decoder } // New returns a new repository with backend be. @@ -55,16 +57,6 @@ func New(be restic.Backend) *Repository { treePM: newPackerManager(be, nil), } - enc, err := zstd.NewWriter(nil) - if err != nil { - panic(err) - } - repo.enc = enc - dec, err := zstd.NewReader(nil) - if err != nil { - panic(err) - } - repo.dec = dec return repo } @@ -236,7 +228,7 @@ func (r *Repository) LoadBlob(ctx context.Context, t restic.BlobType, id restic. } if blob.IsCompressed() { - plaintext, err = r.dec.DecodeAll(plaintext, make([]byte, 0, blob.DataLength())) + plaintext, err = r.getZstdDecoder().DecodeAll(plaintext, make([]byte, 0, blob.DataLength())) if err != nil { lastError = errors.Errorf("decompressing blob %v failed: %v", id, err) continue @@ -280,6 +272,28 @@ func (r *Repository) LookupBlobSize(id restic.ID, tpe restic.BlobType) (uint, bo return r.idx.LookupSize(restic.BlobHandle{ID: id, Type: tpe}) } +func (r *Repository) getZstdEncoder() *zstd.Encoder { + r.allocEnc.Do(func() { + enc, err := zstd.NewWriter(nil) + if err != nil { + panic(err) + } + r.enc = enc + }) + return r.enc +} + +func (r *Repository) getZstdDecoder() *zstd.Decoder { + r.allocDec.Do(func() { + dec, err := zstd.NewReader(nil) + if err != nil { + panic(err) + } + r.dec = dec + }) + return r.dec +} + // saveAndEncrypt encrypts data and stores it to the backend as type t. If data // is small enough, it will be packed together with other small blobs. // The caller must ensure that the id matches the data. @@ -289,7 +303,7 @@ func (r *Repository) saveAndEncrypt(ctx context.Context, t restic.BlobType, data uncompressedLength := 0 if r.cfg.Version > 1 { uncompressedLength = len(data) - data = r.enc.EncodeAll(data, nil) + data = r.getZstdEncoder().EncodeAll(data, nil) } nonce := crypto.NewRandomNonce() @@ -354,7 +368,7 @@ func (r *Repository) compressUnpacked(p []byte) ([]byte, error) { // version byte out := []byte{2} - out = r.enc.EncodeAll(p, out) + out = r.getZstdEncoder().EncodeAll(p, out) return out, nil } @@ -377,7 +391,7 @@ func (r *Repository) decompressUnpacked(p []byte) ([]byte, error) { return nil, errors.New("not supported encoding format") } - return r.dec.DecodeAll(p[1:], nil) + return r.getZstdDecoder().DecodeAll(p[1:], nil) } // SaveUnpacked encrypts data and stores it in the backend. Returned is the