diff --git a/cmd/restic/cleanup.go b/cmd/restic/cleanup.go index 5a6cf79e1..90ea93b92 100644 --- a/cmd/restic/cleanup.go +++ b/cmd/restic/cleanup.go @@ -1,89 +1,41 @@ package main import ( + "context" "os" "os/signal" - "sync" "syscall" "github.com/restic/restic/internal/debug" ) -var cleanupHandlers struct { - sync.Mutex - list []func(code int) (int, error) - done bool - ch chan os.Signal +func createGlobalContext() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + + ch := make(chan os.Signal, 1) + go cleanupHandler(ch, cancel) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + + return ctx } -func init() { - cleanupHandlers.ch = make(chan os.Signal, 1) - go CleanupHandler(cleanupHandlers.ch) - signal.Notify(cleanupHandlers.ch, syscall.SIGINT, syscall.SIGTERM) -} +// cleanupHandler handles the SIGINT and SIGTERM signals. +func cleanupHandler(c <-chan os.Signal, cancel context.CancelFunc) { + s := <-c + debug.Log("signal %v received, cleaning up", s) + Warnf("%ssignal %v received, cleaning up\n", clearLine(0), s) -// AddCleanupHandler adds the function f to the list of cleanup handlers so -// that it is executed when all the cleanup handlers are run, e.g. when SIGINT -// is received. -func AddCleanupHandler(f func(code int) (int, error)) { - cleanupHandlers.Lock() - defer cleanupHandlers.Unlock() - - // reset the done flag for integration tests - cleanupHandlers.done = false - - cleanupHandlers.list = append(cleanupHandlers.list, f) -} - -// RunCleanupHandlers runs all registered cleanup handlers -func RunCleanupHandlers(code int) int { - cleanupHandlers.Lock() - defer cleanupHandlers.Unlock() - - if cleanupHandlers.done { - return code + if val, _ := os.LookupEnv("RESTIC_DEBUG_STACKTRACE_SIGINT"); val != "" { + _, _ = os.Stderr.WriteString("\n--- STACKTRACE START ---\n\n") + _, _ = os.Stderr.WriteString(debug.DumpStacktrace()) + _, _ = os.Stderr.WriteString("\n--- STACKTRACE END ---\n") } - cleanupHandlers.done = true - for _, f := range cleanupHandlers.list { - var err error - code, err = f(code) - if err != nil { - Warnf("error in cleanup handler: %v\n", err) - } - } - cleanupHandlers.list = nil - return code + cancel() } -// CleanupHandler handles the SIGINT and SIGTERM signals. -func CleanupHandler(c <-chan os.Signal) { - for s := range c { - debug.Log("signal %v received, cleaning up", s) - Warnf("%ssignal %v received, cleaning up\n", clearLine(0), s) - - if val, _ := os.LookupEnv("RESTIC_DEBUG_STACKTRACE_SIGINT"); val != "" { - _, _ = os.Stderr.WriteString("\n--- STACKTRACE START ---\n\n") - _, _ = os.Stderr.WriteString(debug.DumpStacktrace()) - _, _ = os.Stderr.WriteString("\n--- STACKTRACE END ---\n") - } - - code := 0 - - if s == syscall.SIGINT || s == syscall.SIGTERM { - code = 130 - } else { - code = 1 - } - - Exit(code) - } -} - -// Exit runs the cleanup handlers and then terminates the process with the -// given exit code. +// Exit terminates the process with the given exit code. func Exit(code int) { - code = RunCleanupHandlers(code) debug.Log("exiting with status code %d", code) os.Exit(code) } diff --git a/cmd/restic/global.go b/cmd/restic/global.go index 9f1ec85a2..5b21871dc 100644 --- a/cmd/restic/global.go +++ b/cmd/restic/global.go @@ -96,8 +96,6 @@ var globalOptions = GlobalOptions{ stderr: os.Stderr, } -var internalGlobalCtx context.Context - func init() { backends := location.NewRegistry() backends.Register(azure.NewFactory()) @@ -111,15 +109,6 @@ func init() { backends.Register(swift.NewFactory()) globalOptions.backends = backends - var cancel context.CancelFunc - internalGlobalCtx, cancel = context.WithCancel(context.Background()) - AddCleanupHandler(func(code int) (int, error) { - // Must be called before the unlock cleanup handler to ensure that the latter is - // not blocked due to limited number of backend connections, see #1434 - cancel() - return code, nil - }) - f := cmdRoot.PersistentFlags() f.StringVarP(&globalOptions.Repo, "repo", "r", "", "`repository` to backup to or restore from (default: $RESTIC_REPOSITORY)") f.StringVarP(&globalOptions.RepositoryFile, "repository-file", "", "", "`file` to read the repository location from (default: $RESTIC_REPOSITORY_FILE)") diff --git a/cmd/restic/main.go b/cmd/restic/main.go index 308a432b5..56ddf74a4 100644 --- a/cmd/restic/main.go +++ b/cmd/restic/main.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "context" "fmt" "log" "os" @@ -118,7 +119,13 @@ func main() { debug.Log("main %#v", os.Args) debug.Log("restic %s compiled with %v on %v/%v", version, runtime.Version(), runtime.GOOS, runtime.GOARCH) - err = cmdRoot.ExecuteContext(internalGlobalCtx) + + ctx := createGlobalContext() + err = cmdRoot.ExecuteContext(ctx) + + if err == nil { + err = ctx.Err() + } switch { case restic.IsAlreadyLocked(err): @@ -140,11 +147,13 @@ func main() { } var exitCode int - switch err { - case nil: + switch { + case err == nil: exitCode = 0 - case ErrInvalidSourceData: + case err == ErrInvalidSourceData: exitCode = 3 + case errors.Is(err, context.Canceled): + exitCode = 130 default: exitCode = 1 }