mirror of
https://github.com/restic/restic.git
synced 2025-08-22 07:07:26 +00:00
Replace restic.Progress with new progress.Counter
This fixes two race conditions while cleaning up the code.
This commit is contained in:
99
internal/ui/progress/counter.go
Normal file
99
internal/ui/progress/counter.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package progress
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/debug"
|
||||
)
|
||||
|
||||
// A Func is a callback for a Counter.
|
||||
//
|
||||
// The final argument is true if Counter.Done has been called,
|
||||
// which means that the current call will be the last.
|
||||
type Func func(value uint64, runtime time.Duration, final bool)
|
||||
|
||||
// A Counter tracks a running count and controls a goroutine that passes its
|
||||
// value periodically to a Func.
|
||||
//
|
||||
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
|
||||
type Counter struct {
|
||||
report Func
|
||||
start time.Time
|
||||
stopped chan struct{} // Closed by run.
|
||||
stop chan struct{} // Close to stop run.
|
||||
tick *time.Ticker
|
||||
value uint64
|
||||
}
|
||||
|
||||
// New starts a new Counter.
|
||||
func New(interval time.Duration, report Func) *Counter {
|
||||
signals.Once.Do(func() {
|
||||
signals.ch = make(chan os.Signal, 1)
|
||||
setupSignals()
|
||||
})
|
||||
|
||||
c := &Counter{
|
||||
report: report,
|
||||
start: time.Now(),
|
||||
stopped: make(chan struct{}),
|
||||
stop: make(chan struct{}),
|
||||
tick: time.NewTicker(interval),
|
||||
}
|
||||
|
||||
go c.run()
|
||||
return c
|
||||
}
|
||||
|
||||
// Add v to the Counter. This method is concurrency-safe.
|
||||
func (c *Counter) Add(v uint64) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
atomic.AddUint64(&c.value, v)
|
||||
}
|
||||
|
||||
// Done tells a Counter to stop and waits for it to report its final value.
|
||||
func (c *Counter) Done() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.tick.Stop()
|
||||
close(c.stop)
|
||||
<-c.stopped // Wait for last progress report.
|
||||
*c = Counter{} // Prevent reuse.
|
||||
}
|
||||
|
||||
func (c *Counter) get() uint64 { return atomic.LoadUint64(&c.value) }
|
||||
|
||||
func (c *Counter) run() {
|
||||
defer close(c.stopped)
|
||||
defer func() {
|
||||
// Must be a func so that time.Since isn't called at defer time.
|
||||
c.report(c.get(), time.Since(c.start), true)
|
||||
}()
|
||||
|
||||
for {
|
||||
var now time.Time
|
||||
|
||||
select {
|
||||
case now = <-c.tick.C:
|
||||
case sig := <-signals.ch:
|
||||
debug.Log("Signal received: %v\n", sig)
|
||||
now = time.Now()
|
||||
case <-c.stop:
|
||||
return
|
||||
}
|
||||
|
||||
c.report(c.get(), now.Sub(c.start), false)
|
||||
}
|
||||
}
|
||||
|
||||
// XXX The fact that signals is a single global variable means that only one
|
||||
// Counter receives each incoming signal.
|
||||
var signals struct {
|
||||
ch chan os.Signal
|
||||
sync.Once
|
||||
}
|
55
internal/ui/progress/counter_test.go
Normal file
55
internal/ui/progress/counter_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package progress_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/restic/restic/internal/test"
|
||||
"github.com/restic/restic/internal/ui/progress"
|
||||
)
|
||||
|
||||
func TestCounter(t *testing.T) {
|
||||
const N = 100
|
||||
|
||||
var (
|
||||
finalSeen = false
|
||||
increasing = true
|
||||
last uint64
|
||||
ncalls int
|
||||
)
|
||||
|
||||
report := func(value uint64, d time.Duration, final bool) {
|
||||
finalSeen = true
|
||||
if value < last {
|
||||
increasing = false
|
||||
}
|
||||
last = value
|
||||
ncalls++
|
||||
}
|
||||
c := progress.New(10*time.Millisecond, report)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for i := 0; i < N; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
c.Add(1)
|
||||
}
|
||||
}()
|
||||
|
||||
<-done
|
||||
c.Done()
|
||||
|
||||
test.Assert(t, finalSeen, "final call did not happen")
|
||||
test.Assert(t, increasing, "values not increasing")
|
||||
test.Equals(t, uint64(N), last)
|
||||
|
||||
t.Log("number of calls:", ncalls)
|
||||
}
|
||||
|
||||
func TestCounterNil(t *testing.T) {
|
||||
// Shouldn't panic.
|
||||
var c *progress.Counter = nil
|
||||
c.Add(1)
|
||||
c.Done()
|
||||
}
|
12
internal/ui/progress/signals_bsd.go
Normal file
12
internal/ui/progress/signals_bsd.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// +build darwin dragonfly freebsd netbsd openbsd
|
||||
|
||||
package progress
|
||||
|
||||
import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setupSignals() {
|
||||
signal.Notify(signals.ch, syscall.SIGINFO, syscall.SIGUSR1)
|
||||
}
|
12
internal/ui/progress/signals_sysv.go
Normal file
12
internal/ui/progress/signals_sysv.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// +build linux solaris
|
||||
|
||||
package progress
|
||||
|
||||
import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setupSignals() {
|
||||
signal.Notify(signals.ch, syscall.SIGUSR1)
|
||||
}
|
3
internal/ui/progress/signals_windows.go
Normal file
3
internal/ui/progress/signals_windows.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package progress
|
||||
|
||||
func setupSignals() {}
|
Reference in New Issue
Block a user