From 31624aeffd4cf9a06afbb16db7ffd7e899187bb6 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sat, 30 Mar 2024 00:19:58 +0100 Subject: [PATCH] Improve command shutdown on context cancellation --- cmd/restic/cmd_check.go | 14 +++++++++++++- cmd/restic/cmd_copy.go | 5 ++++- cmd/restic/cmd_find.go | 3 +++ cmd/restic/cmd_forget.go | 7 +++++++ cmd/restic/cmd_prune.go | 3 +++ cmd/restic/cmd_recover.go | 6 ++++++ cmd/restic/cmd_repair_snapshots.go | 3 +++ cmd/restic/cmd_rewrite.go | 3 +++ cmd/restic/cmd_snapshots.go | 3 +++ cmd/restic/cmd_stats.go | 5 ++--- cmd/restic/cmd_tag.go | 3 +++ internal/archiver/archiver.go | 1 + internal/archiver/tree_saver.go | 4 ++++ internal/index/master_index.go | 3 +++ internal/repository/prune.go | 18 ++++++++++++++++++ internal/repository/repack.go | 2 +- internal/repository/repository.go | 3 +++ 17 files changed, 80 insertions(+), 6 deletions(-) diff --git a/cmd/restic/cmd_check.go b/cmd/restic/cmd_check.go index 83ebf89a6..38623c305 100644 --- a/cmd/restic/cmd_check.go +++ b/cmd/restic/cmd_check.go @@ -219,6 +219,9 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args Verbosef("load indexes\n") bar := newIndexProgress(gopts.Quiet, gopts.JSON) hints, errs := chkr.LoadIndex(ctx, bar) + if ctx.Err() != nil { + return ctx.Err() + } errorsFound := false suggestIndexRebuild := false @@ -280,6 +283,9 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args if orphanedPacks > 0 { Verbosef("%d additional files were found in the repo, which likely contain duplicate data.\nThis is non-critical, you can run `restic prune` to correct this.\n", orphanedPacks) } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("check snapshots, trees and blobs\n") errChan = make(chan error) @@ -313,6 +319,9 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args // Must happen after `errChan` is read from in the above loop to avoid // deadlocking in the case of errors. wg.Wait() + if ctx.Err() != nil { + return ctx.Err() + } if opts.CheckUnused { for _, id := range chkr.UnusedBlobs(ctx) { @@ -392,10 +401,13 @@ func runCheck(ctx context.Context, opts CheckOptions, gopts GlobalOptions, args doReadData(packs) } + if ctx.Err() != nil { + return ctx.Err() + } + if errorsFound { return errors.Fatal("repository contains errors") } - Verbosef("no errors were found\n") return nil diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index de3958def..ad6c58a25 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -103,6 +103,9 @@ func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args [] // also consider identical snapshot copies dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn) } + if ctx.Err() != nil { + return ctx.Err() + } // remember already processed trees across all snapshots visitedTrees := restic.NewIDSet() @@ -147,7 +150,7 @@ func runCopy(ctx context.Context, opts CopyOptions, gopts GlobalOptions, args [] } Verbosef("snapshot %s saved\n", newID.Str()) } - return nil + return ctx.Err() } func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool { diff --git a/cmd/restic/cmd_find.go b/cmd/restic/cmd_find.go index e29fe30dc..77b651c5e 100644 --- a/cmd/restic/cmd_find.go +++ b/cmd/restic/cmd_find.go @@ -608,6 +608,9 @@ func runFind(ctx context.Context, opts FindOptions, gopts GlobalOptions, args [] for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, opts.Snapshots) { filteredSnapshots = append(filteredSnapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } sort.Slice(filteredSnapshots, func(i, j int) bool { return filteredSnapshots[i].Time.Before(filteredSnapshots[j].Time) diff --git a/cmd/restic/cmd_forget.go b/cmd/restic/cmd_forget.go index 9018da211..92eeed4a1 100644 --- a/cmd/restic/cmd_forget.go +++ b/cmd/restic/cmd_forget.go @@ -188,6 +188,9 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args) { snapshots = append(snapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } var jsonGroups []*ForgetGroup @@ -270,6 +273,10 @@ func runForget(ctx context.Context, opts ForgetOptions, pruneOptions PruneOption } } + if ctx.Err() != nil { + return ctx.Err() + } + if len(removeSnIDs) > 0 { if !opts.DryRun { bar := printer.NewCounter("files deleted") diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index ea5acddf3..cbec100df 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -197,6 +197,9 @@ func runPruneWithRepo(ctx context.Context, opts PruneOptions, gopts GlobalOption if err != nil { return err } + if ctx.Err() != nil { + return ctx.Err() + } if popts.DryRun { printer.P("\nWould have made the following changes:") diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index f9a4d419d..cac29a60c 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -66,11 +66,17 @@ func runRecover(ctx context.Context, gopts GlobalOptions) error { trees[blob.Blob.ID] = false } }) + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("load %d trees\n", len(trees)) bar = newProgressMax(!gopts.Quiet, uint64(len(trees)), "trees loaded") for id := range trees { tree, err := restic.LoadTree(ctx, repo, id) + if ctx.Err() != nil { + return ctx.Err() + } if err != nil { Warnf("unable to load tree %v: %v\n", id.Str(), err) continue diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index 4d9745e15..b200d100a 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -145,6 +145,9 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt changedCount++ } } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("\n") if changedCount == 0 { diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index 06d4ddbd1..38a868c5c 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -294,6 +294,9 @@ func runRewrite(ctx context.Context, opts RewriteOptions, gopts GlobalOptions, a changedCount++ } } + if ctx.Err() != nil { + return ctx.Err() + } Verbosef("\n") if changedCount == 0 { diff --git a/cmd/restic/cmd_snapshots.go b/cmd/restic/cmd_snapshots.go index 1a9cd2232..faa86d3a6 100644 --- a/cmd/restic/cmd_snapshots.go +++ b/cmd/restic/cmd_snapshots.go @@ -69,6 +69,9 @@ func runSnapshots(ctx context.Context, opts SnapshotOptions, gopts GlobalOptions for sn := range FindFilteredSnapshots(ctx, repo, repo, &opts.SnapshotFilter, args) { snapshots = append(snapshots, sn) } + if ctx.Err() != nil { + return ctx.Err() + } snapshotGroups, grouped, err := restic.GroupSnapshots(snapshots, opts.GroupBy) if err != nil { return err diff --git a/cmd/restic/cmd_stats.go b/cmd/restic/cmd_stats.go index 6bf0dbf19..2647d78e5 100644 --- a/cmd/restic/cmd_stats.go +++ b/cmd/restic/cmd_stats.go @@ -117,9 +117,8 @@ func runStats(ctx context.Context, opts StatsOptions, gopts GlobalOptions, args return fmt.Errorf("error walking snapshot: %v", err) } } - - if err != nil { - return err + if ctx.Err() != nil { + return ctx.Err() } if opts.countMode == countModeRawData { diff --git a/cmd/restic/cmd_tag.go b/cmd/restic/cmd_tag.go index b0d139fa6..3bf386f2c 100644 --- a/cmd/restic/cmd_tag.go +++ b/cmd/restic/cmd_tag.go @@ -122,6 +122,9 @@ func runTag(ctx context.Context, opts TagOptions, gopts GlobalOptions, args []st changeCnt++ } } + if ctx.Err() != nil { + return ctx.Err() + } if changeCnt == 0 { Verbosef("no snapshots were modified\n") } else { diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 146ff3a7c..c1f73eea6 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -380,6 +380,7 @@ func (fn *FutureNode) take(ctx context.Context) futureNodeResult { return res } case <-ctx.Done(): + return futureNodeResult{err: ctx.Err()} } return futureNodeResult{err: errors.Errorf("no result")} } diff --git a/internal/archiver/tree_saver.go b/internal/archiver/tree_saver.go index eae524a78..9c11b48f0 100644 --- a/internal/archiver/tree_saver.go +++ b/internal/archiver/tree_saver.go @@ -90,6 +90,10 @@ func (s *TreeSaver) save(ctx context.Context, job *saveTreeJob) (*restic.Node, I // return the error if it wasn't ignored if fnr.err != nil { debug.Log("err for %v: %v", fnr.snPath, fnr.err) + if fnr.err == context.Canceled { + return nil, stats, fnr.err + } + fnr.err = s.errFn(fnr.target, fnr.err) if fnr.err == nil { // ignore error diff --git a/internal/index/master_index.go b/internal/index/master_index.go index 4c114b955..9833f9a55 100644 --- a/internal/index/master_index.go +++ b/internal/index/master_index.go @@ -320,6 +320,9 @@ func (mi *MasterIndex) Save(ctx context.Context, repo restic.Repository, exclude newIndex = NewIndex() } } + if wgCtx.Err() != nil { + return wgCtx.Err() + } } err := newIndex.AddToSupersedes(extraObsolete...) diff --git a/internal/repository/prune.go b/internal/repository/prune.go index 8900fffaa..39eb30317 100644 --- a/internal/repository/prune.go +++ b/internal/repository/prune.go @@ -130,6 +130,9 @@ func PlanPrune(ctx context.Context, opts PruneOptions, repo restic.Repository, g } keepBlobs.Delete(blob.BlobHandle) }) + if ctx.Err() != nil { + return nil, ctx.Err() + } if keepBlobs.Len() < blobCount/2 { // replace with copy to shrink map to necessary size if there's a chance to benefit @@ -166,6 +169,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re usedBlobs[bh] = count } }) + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } // Check if all used blobs have been found in index missingBlobs := restic.NewBlobSet() @@ -240,6 +246,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re // update indexPack indexPack[blob.PackID] = ip }) + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } // if duplicate blobs exist, those will be set to either "used" or "unused": // - mark only one occurrence of duplicate blobs as used @@ -286,6 +295,9 @@ func packInfoFromIndex(ctx context.Context, idx restic.MasterIndex, usedBlobs re indexPack[blob.PackID] = ip }) } + if ctx.Err() != nil { + return nil, nil, ctx.Err() + } // Sanity check. If no duplicates exist, all blobs have value 1. After handling // duplicates, this also applies to duplicates. @@ -528,6 +540,9 @@ func (plan *PrunePlan) Execute(ctx context.Context, printer progress.Printer) (e printer.P("deleting unreferenced packs\n") _ = deleteFiles(ctx, true, repo, plan.removePacksFirst, restic.PackFile, printer) } + if ctx.Err() != nil { + return ctx.Err() + } if len(plan.repackPacks) != 0 { printer.P("repacking packs\n") @@ -578,6 +593,9 @@ func (plan *PrunePlan) Execute(ctx context.Context, printer progress.Printer) (e printer.P("removing %d old packs\n", len(plan.removePacks)) _ = deleteFiles(ctx, true, repo, plan.removePacks, restic.PackFile, printer) } + if ctx.Err() != nil { + return ctx.Err() + } if plan.opts.UnsafeRecovery { err = rebuildIndexFiles(ctx, repo, plan.ignorePacks, nil, true, printer) diff --git a/internal/repository/repack.go b/internal/repository/repack.go index 5588984f6..53656252a 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -72,7 +72,7 @@ func repack(ctx context.Context, repo restic.Repository, dstRepo restic.Reposito return wgCtx.Err() } } - return nil + return wgCtx.Err() }) worker := func() error { diff --git a/internal/repository/repository.go b/internal/repository/repository.go index ae4528d80..a43971266 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -713,6 +713,9 @@ func (r *Repository) LoadIndex(ctx context.Context, p *progress.Counter) error { return errors.New("index uses feature not supported by repository version 1") } } + if ctx.Err() != nil { + return ctx.Err() + } // remove index files from the cache which have been removed in the repo return r.prepareCache()