Merge pull request #5511 from greatroar/atomic

ui/progress: Restore atomics in Counter
This commit is contained in:
Michael Eischer
2025-09-21 22:29:40 +02:00
committed by GitHub
3 changed files with 19 additions and 33 deletions

View File

@@ -1,7 +1,7 @@
package progress package progress
import ( import (
"sync" "sync/atomic"
"time" "time"
) )
@@ -17,17 +17,13 @@ type Func func(value uint64, total uint64, runtime time.Duration, final bool)
// The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received. // The Func is also called when SIGUSR1 (or SIGINFO, on BSD) is received.
type Counter struct { type Counter struct {
Updater Updater
value, max atomic.Uint64
valueMutex sync.Mutex
value uint64
max uint64
} }
// NewCounter starts a new Counter. // NewCounter starts a new Counter.
func NewCounter(interval time.Duration, total uint64, report Func) *Counter { func NewCounter(interval time.Duration, total uint64, report Func) *Counter {
c := &Counter{ c := new(Counter)
max: total, c.max.Store(total)
}
c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) { c.Updater = *NewUpdater(interval, func(runtime time.Duration, final bool) {
v, maxV := c.Get() v, maxV := c.Get()
report(v, maxV, runtime, final) report(v, maxV, runtime, final)
@@ -37,33 +33,22 @@ func NewCounter(interval time.Duration, total uint64, report Func) *Counter {
// Add v to the Counter. This method is concurrency-safe. // Add v to the Counter. This method is concurrency-safe.
func (c *Counter) Add(v uint64) { func (c *Counter) Add(v uint64) {
if c == nil { if c != nil {
return c.value.Add(v)
} }
c.valueMutex.Lock()
c.value += v
c.valueMutex.Unlock()
} }
// SetMax sets the maximum expected counter value. This method is concurrency-safe. // SetMax sets the maximum expected counter value. This method is concurrency-safe.
func (c *Counter) SetMax(max uint64) { func (c *Counter) SetMax(max uint64) {
if c == nil { if c != nil {
return c.max.Store(max)
} }
c.valueMutex.Lock()
c.max = max
c.valueMutex.Unlock()
} }
// Get returns the current value and the maximum of c. // Get returns the current value and the maximum of c.
// This method is concurrency-safe. // This method is concurrency-safe.
func (c *Counter) Get() (v, max uint64) { func (c *Counter) Get() (v, max uint64) {
c.valueMutex.Lock() return c.value.Load(), c.max.Load()
v, max = c.value, c.max
c.valueMutex.Unlock()
return v, max
} }
func (c *Counter) Done() { func (c *Counter) Done() {

View File

@@ -57,17 +57,13 @@ func (c *Updater) Done() {
func (c *Updater) run() { func (c *Updater) run() {
defer close(c.stopped) defer close(c.stopped)
defer func() {
// Must be a func so that time.Since isn't called at defer time.
c.report(time.Since(c.start), true)
}()
var tick <-chan time.Time var tick <-chan time.Time
if c.tick != nil { if c.tick != nil {
tick = c.tick.C tick = c.tick.C
} }
signalsCh := signals.GetProgressChannel() signalsCh := signals.GetProgressChannel()
for { for final := false; !final; {
var now time.Time var now time.Time
select { select {
@@ -76,9 +72,9 @@ func (c *Updater) run() {
debug.Log("Signal received: %v\n", sig) debug.Log("Signal received: %v\n", sig)
now = time.Now() now = time.Now()
case <-c.stop: case <-c.stop:
return final, now = true, time.Now()
} }
c.report(now.Sub(c.start), false) c.report(now.Sub(c.start), final)
} }
} }

View File

@@ -9,13 +9,17 @@ import (
) )
func TestUpdater(t *testing.T) { func TestUpdater(t *testing.T) {
finalSeen := false var (
var ncalls int finalSeen = false
ncalls = 0
dur time.Duration
)
report := func(d time.Duration, final bool) { report := func(d time.Duration, final bool) {
if final { if final {
finalSeen = true finalSeen = true
} }
dur = d
ncalls++ ncalls++
} }
c := progress.NewUpdater(10*time.Millisecond, report) c := progress.NewUpdater(10*time.Millisecond, report)
@@ -24,6 +28,7 @@ func TestUpdater(t *testing.T) {
test.Assert(t, finalSeen, "final call did not happen") test.Assert(t, finalSeen, "final call did not happen")
test.Assert(t, ncalls > 0, "no progress was reported") test.Assert(t, ncalls > 0, "no progress was reported")
test.Assert(t, dur > 0, "duration must be positive")
} }
func TestUpdaterStopTwice(_ *testing.T) { func TestUpdaterStopTwice(_ *testing.T) {