diff --git a/cmd/restic/progress.go b/cmd/restic/progress.go index 3caa34a26..4b6025a54 100644 --- a/cmd/restic/progress.go +++ b/cmd/restic/progress.go @@ -37,7 +37,7 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter interval := calculateProgressInterval(show, false) canUpdateStatus := stdoutCanUpdateStatus() - return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) { + return progress.NewCounter(interval, max, func(v uint64, max uint64, d time.Duration, final bool) { var status string if max == 0 { status = fmt.Sprintf("[%s] %d %s", diff --git a/internal/restic/find_test.go b/internal/restic/find_test.go index 80f616513..f5e288b9d 100644 --- a/internal/restic/find_test.go +++ b/internal/restic/find_test.go @@ -93,7 +93,7 @@ func TestFindUsedBlobs(t *testing.T) { snapshots = append(snapshots, sn) } - p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) defer p.Done() for i, sn := range snapshots { @@ -142,7 +142,7 @@ func TestMultiFindUsedBlobs(t *testing.T) { want.Merge(loadIDSet(t, goldenFilename)) } - p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + p := progress.NewCounter(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) defer p.Done() // run twice to check progress bar handling of duplicate tree roots diff --git a/internal/ui/progress/counter.go b/internal/ui/progress/counter.go index 90a09d0d8..c1275d2f2 100644 --- a/internal/ui/progress/counter.go +++ b/internal/ui/progress/counter.go @@ -3,9 +3,6 @@ package progress import ( "sync" "time" - - "github.com/restic/restic/internal/debug" - "github.com/restic/restic/internal/ui/signals" ) // A Func is a callback for a Counter. @@ -19,32 +16,22 @@ type Func func(value uint64, total uint64, runtime time.Duration, final bool) // // The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received. type Counter struct { - report Func - start time.Time - stopped chan struct{} // Closed by run. - stop chan struct{} // Close to stop run. - tick *time.Ticker + Updater valueMutex sync.Mutex value uint64 max uint64 } -// New starts a new Counter. -func New(interval time.Duration, total uint64, report Func) *Counter { +// NewCounter starts a new Counter. +func NewCounter(interval time.Duration, total uint64, report Func) *Counter { c := &Counter{ - report: report, - start: time.Now(), - stopped: make(chan struct{}), - stop: make(chan struct{}), - max: total, + max: total, } - - if interval > 0 { - c.tick = time.NewTicker(interval) - } - - go c.run() + c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) { + v, max := c.Get() + report(v, max, runtime, final) + }) return c } @@ -69,18 +56,6 @@ func (c *Counter) SetMax(max uint64) { c.valueMutex.Unlock() } -// Done tells a Counter to stop and waits for it to report its final value. -func (c *Counter) Done() { - if c == nil { - return - } - if c.tick != nil { - c.tick.Stop() - } - close(c.stop) - <-c.stopped // Wait for last progress report. -} - // Get returns the current value and the maximum of c. // This method is concurrency-safe. func (c *Counter) Get() (v, max uint64) { @@ -91,32 +66,8 @@ func (c *Counter) Get() (v, max uint64) { return v, max } -func (c *Counter) run() { - defer close(c.stopped) - defer func() { - // Must be a func so that time.Since isn't called at defer time. - v, max := c.Get() - c.report(v, max, time.Since(c.start), true) - }() - - var tick <-chan time.Time - if c.tick != nil { - tick = c.tick.C - } - signalsCh := signals.GetProgressChannel() - for { - var now time.Time - - select { - case now = <-tick: - case sig := <-signalsCh: - debug.Log("Signal received: %v\n", sig) - now = time.Now() - case <-c.stop: - return - } - - v, max := c.Get() - c.report(v, max, now.Sub(c.start), false) +func (c *Counter) Done() { + if c != nil { + c.Updater.Done() } } diff --git a/internal/ui/progress/counter_test.go b/internal/ui/progress/counter_test.go index 85695d209..49c694e06 100644 --- a/internal/ui/progress/counter_test.go +++ b/internal/ui/progress/counter_test.go @@ -35,7 +35,7 @@ func TestCounter(t *testing.T) { lastTotal = total ncalls++ } - c := progress.New(10*time.Millisecond, startTotal, report) + c := progress.NewCounter(10*time.Millisecond, startTotal, report) done := make(chan struct{}) go func() { @@ -63,24 +63,6 @@ func TestCounterNil(t *testing.T) { // Shouldn't panic. var c *progress.Counter c.Add(1) + c.SetMax(42) c.Done() } - -func TestCounterNoTick(t *testing.T) { - finalSeen := false - otherSeen := false - - report := func(value, total uint64, d time.Duration, final bool) { - if final { - finalSeen = true - } else { - otherSeen = true - } - } - c := progress.New(0, 1, report) - time.Sleep(time.Millisecond) - c.Done() - - test.Assert(t, finalSeen, "final call did not happen") - test.Assert(t, !otherSeen, "unexpected status update") -} diff --git a/internal/ui/progress/updater.go b/internal/ui/progress/updater.go new file mode 100644 index 000000000..7fb6c8836 --- /dev/null +++ b/internal/ui/progress/updater.go @@ -0,0 +1,84 @@ +package progress + +import ( + "time" + + "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/ui/signals" +) + +// An UpdateFunc is a callback for a (progress) Updater. +// +// The final argument is true if Updater.Done has been called, +// which means that the current call will be the last. +type UpdateFunc func(runtime time.Duration, final bool) + +// An Updater controls a goroutine that periodically calls an UpdateFunc. +// +// The UpdateFunc is also called when SIGUSR1 (or SIGINFO, on BSD) is received. +type Updater struct { + report UpdateFunc + start time.Time + stopped chan struct{} // Closed by run. + stop chan struct{} // Close to stop run. + tick *time.Ticker +} + +// NewUpdater starts a new Updater. +func NewUpdater(interval time.Duration, report UpdateFunc) *Updater { + c := &Updater{ + report: report, + start: time.Now(), + stopped: make(chan struct{}), + stop: make(chan struct{}), + } + + if interval > 0 { + c.tick = time.NewTicker(interval) + } + + go c.run() + return c +} + +// Done tells an Updater to stop and waits for it to report its final value. +// Later calls do nothing. +func (c *Updater) Done() { + if c == nil || c.stop == nil { + return + } + if c.tick != nil { + c.tick.Stop() + } + close(c.stop) + <-c.stopped // Wait for last progress report. + c.stop = nil +} + +func (c *Updater) run() { + defer close(c.stopped) + defer func() { + // Must be a func so that time.Since isn't called at defer time. + c.report(time.Since(c.start), true) + }() + + var tick <-chan time.Time + if c.tick != nil { + tick = c.tick.C + } + signalsCh := signals.GetProgressChannel() + for { + var now time.Time + + select { + case now = <-tick: + case sig := <-signalsCh: + debug.Log("Signal received: %v\n", sig) + now = time.Now() + case <-c.stop: + return + } + + c.report(now.Sub(c.start), false) + } +} diff --git a/internal/ui/progress/updater_test.go b/internal/ui/progress/updater_test.go new file mode 100644 index 000000000..5b5207dd5 --- /dev/null +++ b/internal/ui/progress/updater_test.go @@ -0,0 +1,52 @@ +package progress_test + +import ( + "testing" + "time" + + "github.com/restic/restic/internal/test" + "github.com/restic/restic/internal/ui/progress" +) + +func TestUpdater(t *testing.T) { + finalSeen := false + var ncalls int + + report := func(d time.Duration, final bool) { + if final { + finalSeen = true + } + ncalls++ + } + c := progress.NewUpdater(10*time.Millisecond, report) + time.Sleep(100 * time.Millisecond) + c.Done() + + test.Assert(t, finalSeen, "final call did not happen") + test.Assert(t, ncalls > 0, "no progress was reported") +} + +func TestUpdaterStopTwice(t *testing.T) { + c := progress.NewUpdater(0, func(runtime time.Duration, final bool) {}) + c.Done() + c.Done() +} + +func TestUpdaterNoTick(t *testing.T) { + finalSeen := false + otherSeen := false + + report := func(d time.Duration, final bool) { + if final { + finalSeen = true + } else { + otherSeen = true + } + } + c := progress.NewUpdater(0, report) + time.Sleep(time.Millisecond) + c.Done() + + test.Assert(t, finalSeen, "final call did not happen") + test.Assert(t, !otherSeen, "unexpected status update") +}