diff --git a/cmd/restic/lock.go b/cmd/restic/lock.go index 20ac4dd34..69d433df1 100644 --- a/cmd/restic/lock.go +++ b/cmd/restic/lock.go @@ -2,16 +2,10 @@ package main import ( "context" - "sync" "github.com/restic/restic/internal/repository" - "github.com/restic/restic/internal/restic" ) -var globalLocks struct { - sync.Once -} - func internalOpenWithLocked(ctx context.Context, gopts GlobalOptions, dryRun bool, exclusive bool) (context.Context, *repository.Repository, func(), error) { repo, err := OpenRepository(ctx, gopts) if err != nil { @@ -20,22 +14,22 @@ func internalOpenWithLocked(ctx context.Context, gopts GlobalOptions, dryRun boo unlock := func() {} if !dryRun { - var lock *restic.Lock - - // make sure that a repository is unlocked properly and after cancel() was - // called by the cleanup handler in global.go - globalLocks.Do(func() { - AddCleanupHandler(repository.UnlockAll) - }) + var lock *repository.Unlocker lock, ctx, err = repository.Lock(ctx, repo, exclusive, gopts.RetryLock, func(msg string) { if !gopts.JSON { Verbosef("%s", msg) } }, Warnf) - unlock = func() { - repository.Unlock(lock) - } + + unlock = lock.Unlock + // make sure that a repository is unlocked properly and after cancel() was + // called by the cleanup handler in global.go + AddCleanupHandler(func(code int) (int, error) { + lock.Unlock() + return code, nil + }) + if err != nil { return nil, nil, nil, err } diff --git a/internal/repository/lock.go b/internal/repository/lock.go index c64cb9222..e3360cac0 100644 --- a/internal/repository/lock.go +++ b/internal/repository/lock.go @@ -18,11 +18,6 @@ type lockContext struct { refreshWG sync.WaitGroup } -var globalLocks struct { - locks map[*restic.Lock]*lockContext - sync.Mutex -} - var ( retrySleepStart = 5 * time.Second retrySleepMax = 60 * time.Second @@ -37,7 +32,7 @@ func minDuration(a, b time.Duration) time.Duration { // Lock wraps the ctx such that it is cancelled when the repository is unlocked // cancelling the original context also stops the lock refresh -func Lock(ctx context.Context, repo restic.Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (*restic.Lock, context.Context, error) { +func Lock(ctx context.Context, repo restic.Repository, exclusive bool, retryLock time.Duration, printRetry func(msg string), logger func(format string, args ...interface{})) (*Unlocker, context.Context, error) { lockFn := restic.NewLock if exclusive { @@ -97,13 +92,10 @@ retryLoop: refreshChan := make(chan struct{}) forceRefreshChan := make(chan refreshLockRequest) - globalLocks.Lock() - globalLocks.locks[lock] = lockInfo go refreshLocks(ctx, repo.Backend(), lockInfo, refreshChan, forceRefreshChan, logger) go monitorLockRefresh(ctx, lockInfo, refreshChan, forceRefreshChan, logger) - globalLocks.Unlock() - return lock, ctx, err + return &Unlocker{lockInfo}, ctx, nil } var refreshInterval = 5 * time.Minute @@ -261,41 +253,11 @@ func tryRefreshStaleLock(ctx context.Context, be backend.Backend, lock *restic.L return true } -func Unlock(lock *restic.Lock) { - if lock == nil { - return - } - - globalLocks.Lock() - lockInfo, exists := globalLocks.locks[lock] - delete(globalLocks.locks, lock) - globalLocks.Unlock() - - if !exists { - debug.Log("unable to find lock %v in the global list of locks, ignoring", lock) - return - } - lockInfo.cancel() - lockInfo.refreshWG.Wait() +type Unlocker struct { + info *lockContext } -func UnlockAll(code int) (int, error) { - globalLocks.Lock() - locks := globalLocks.locks - debug.Log("unlocking %d locks", len(globalLocks.locks)) - for _, lockInfo := range globalLocks.locks { - lockInfo.cancel() - } - globalLocks.locks = make(map[*restic.Lock]*lockContext) - globalLocks.Unlock() - - for _, lockInfo := range locks { - lockInfo.refreshWG.Wait() - } - - return code, nil -} - -func init() { - globalLocks.locks = make(map[*restic.Lock]*lockContext) +func (l *Unlocker) Unlock() { + l.info.cancel() + l.info.refreshWG.Wait() } diff --git a/internal/repository/lock_test.go b/internal/repository/lock_test.go index fb48a566f..2975ed7ff 100644 --- a/internal/repository/lock_test.go +++ b/internal/repository/lock_test.go @@ -37,11 +37,11 @@ func openLockTestRepo(t *testing.T, wrapper backendWrapper) restic.Repository { return repo } -func checkedLockRepo(ctx context.Context, t *testing.T, repo restic.Repository, retryLock time.Duration) (*restic.Lock, context.Context) { +func checkedLockRepo(ctx context.Context, t *testing.T, repo restic.Repository, retryLock time.Duration) (*Unlocker, context.Context) { lock, wrappedCtx, err := Lock(ctx, repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) test.OK(t, err) test.OK(t, wrappedCtx.Err()) - if lock.Stale() { + if lock.info.lock.Stale() { t.Fatal("lock returned stale lock") } return lock, wrappedCtx @@ -51,7 +51,7 @@ func TestLock(t *testing.T) { repo := openLockTestRepo(t, nil) lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, 0) - Unlock(lock) + lock.Unlock() if wrappedCtx.Err() == nil { t.Fatal("unlock did not cancel context") } @@ -69,21 +69,7 @@ func TestLockCancel(t *testing.T) { } // Unlock should not crash - Unlock(lock) -} - -func TestLockUnlockAll(t *testing.T) { - repo := openLockTestRepo(t, nil) - - lock, wrappedCtx := checkedLockRepo(context.Background(), t, repo, 0) - _, err := UnlockAll(0) - test.OK(t, err) - if wrappedCtx.Err() == nil { - t.Fatal("canceled parent context did not cancel context") - } - - // Unlock should not crash - Unlock(lock) + lock.Unlock() } func TestLockConflict(t *testing.T) { @@ -94,7 +80,7 @@ func TestLockConflict(t *testing.T) { lock, _, err := Lock(context.Background(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) test.OK(t, err) - defer Unlock(lock) + defer lock.Unlock() _, _, err = Lock(context.Background(), repo2, false, 0, func(msg string) {}, func(format string, args ...interface{}) {}) if err == nil { t.Fatal("second lock should have failed") @@ -137,7 +123,7 @@ func TestLockFailedRefresh(t *testing.T) { t.Fatal("failed lock refresh did not cause context cancellation") } // Unlock should not crash - Unlock(lock) + lock.Unlock() } type loggingBackend struct { @@ -186,7 +172,7 @@ func TestLockSuccessfulRefresh(t *testing.T) { // expected lock refresh to work } // Unlock should not crash - Unlock(lock) + lock.Unlock() } type slowBackend struct { @@ -248,19 +234,21 @@ func TestLockSuccessfulStaleRefresh(t *testing.T) { } // Unlock should not crash - Unlock(lock) + lock.Unlock() } func TestLockWaitTimeout(t *testing.T) { + t.Parallel() repo := openLockTestRepo(t, nil) elock, _, err := Lock(context.TODO(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) test.OK(t, err) + defer elock.Unlock() retryLock := 200 * time.Millisecond start := time.Now() - lock, _, err := Lock(context.TODO(), repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) + _, _, err = Lock(context.TODO(), repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) duration := time.Since(start) test.Assert(t, err != nil, @@ -269,16 +257,15 @@ func TestLockWaitTimeout(t *testing.T) { "create normal lock with exclusively locked repo didn't return the correct error") test.Assert(t, retryLock <= duration && duration < retryLock*3/2, "create normal lock with exclusively locked repo didn't wait for the specified timeout") - - test.OK(t, lock.Unlock()) - test.OK(t, elock.Unlock()) } func TestLockWaitCancel(t *testing.T) { + t.Parallel() repo := openLockTestRepo(t, nil) elock, _, err := Lock(context.TODO(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) test.OK(t, err) + defer elock.Unlock() retryLock := 200 * time.Millisecond cancelAfter := 40 * time.Millisecond @@ -287,7 +274,7 @@ func TestLockWaitCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) time.AfterFunc(cancelAfter, cancel) - lock, _, err := Lock(ctx, repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) + _, _, err = Lock(ctx, repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) duration := time.Since(start) test.Assert(t, err != nil, @@ -296,12 +283,10 @@ func TestLockWaitCancel(t *testing.T) { "create normal lock with exclusively locked repo didn't return the correct error") test.Assert(t, cancelAfter <= duration && duration < retryLock-10*time.Millisecond, "create normal lock with exclusively locked repo didn't return in time, duration %v", duration) - - test.OK(t, lock.Unlock()) - test.OK(t, elock.Unlock()) } func TestLockWaitSuccess(t *testing.T) { + t.Parallel() repo := openLockTestRepo(t, nil) elock, _, err := Lock(context.TODO(), repo, true, 0, func(msg string) {}, func(format string, args ...interface{}) {}) @@ -311,11 +296,10 @@ func TestLockWaitSuccess(t *testing.T) { unlockAfter := 40 * time.Millisecond time.AfterFunc(unlockAfter, func() { - test.OK(t, elock.Unlock()) + elock.Unlock() }) lock, _, err := Lock(context.TODO(), repo, false, retryLock, func(msg string) {}, func(format string, args ...interface{}) {}) test.OK(t, err) - - test.OK(t, lock.Unlock()) + lock.Unlock() }