portlist: fix data race

Maisem spotted the bug. The initial getList call in NewPoller wasn't
making a clone (only the Run loop's getList calls).

Fixes #6314

Change-Id: I8ab8799fcccea8e799140340d0ff88a825bb6ff0
Co-authored-by: Maisem Ali <maisem@tailscale.com>
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-14 11:54:52 -08:00 committed by Brad Fitzpatrick
parent 42855d219b
commit f81351fdef
2 changed files with 85 additions and 10 deletions

View File

@ -14,6 +14,7 @@
"sync" "sync"
"time" "time"
"golang.org/x/exp/slices"
"tailscale.com/envknob" "tailscale.com/envknob"
) )
@ -84,14 +85,20 @@ func NewPoller() (*Poller, error) {
// Do one initial poll synchronously so we can return an error // Do one initial poll synchronously so we can return an error
// early. // early.
var err error if pl, err := p.getList(); err != nil {
p.prev, err = p.getList()
if err != nil {
return nil, err return nil, err
} else {
p.setPrev(pl)
} }
return p, nil return p, nil
} }
func (p *Poller) setPrev(pl List) {
// Make a copy, as the pass in pl slice aliases pl.scratch and we don't want
// that to except to the caller.
p.prev = slices.Clone(pl)
}
func (p *Poller) initOSField() { func (p *Poller) initOSField() {
if newOSImpl != nil { if newOSImpl != nil {
p.os = newOSImpl() p.os = newOSImpl()
@ -131,11 +138,14 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) {
// //
// Run may only be called once. // Run may only be called once.
func (p *Poller) Run(ctx context.Context) error { func (p *Poller) Run(ctx context.Context) error {
defer close(p.runDone)
defer close(p.c)
tick := time.NewTicker(pollInterval) tick := time.NewTicker(pollInterval)
defer tick.Stop() defer tick.Stop()
return p.runWithTickChan(ctx, tick.C)
}
func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) error {
defer close(p.runDone)
defer close(p.c)
// Send out the pre-generated initial value. // Send out the pre-generated initial value.
if sent, err := p.send(ctx, p.prev); !sent { if sent, err := p.send(ctx, p.prev); !sent {
@ -144,7 +154,7 @@ func (p *Poller) Run(ctx context.Context) error {
for { for {
select { select {
case <-tick.C: case <-tickChan:
pl, err := p.getList() pl, err := p.getList()
if err != nil { if err != nil {
return err return err
@ -152,9 +162,7 @@ func (p *Poller) Run(ctx context.Context) error {
if pl.equal(p.prev) { if pl.equal(p.prev) {
continue continue
} }
// New value. Make a copy, as pl might alias pl.scratch p.setPrev(pl)
// and prev must not.
p.prev = append([]Port(nil), pl...)
if sent, err := p.send(ctx, p.prev); !sent { if sent, err := p.send(ctx, p.prev); !sent {
return err return err
} }

View File

@ -5,10 +5,13 @@
package portlist package portlist
import ( import (
"context"
"flag" "flag"
"net" "net"
"runtime" "runtime"
"sync"
"testing" "testing"
"time"
"tailscale.com/tstest" "tailscale.com/tstest"
) )
@ -182,6 +185,70 @@ func TestEqualLessThan(t *testing.T) {
} }
} }
func TestPoller(t *testing.T) {
p, err := NewPoller()
if err != nil {
t.Skipf("not running test: %v", err)
}
defer p.Close()
var wg sync.WaitGroup
wg.Add(2)
gotUpdate := make(chan bool, 16)
go func() {
defer wg.Done()
for pl := range p.Updates() {
// Look at all the pl slice memory to maximize
// chance of race detector seeing violations.
for _, v := range pl {
if v == (Port{}) {
// Force use
panic("empty port")
}
}
select {
case gotUpdate <- true:
default:
}
}
}()
tick := make(chan time.Time, 16)
go func() {
defer wg.Done()
if err := p.runWithTickChan(context.Background(), tick); err != nil {
t.Error("runWithTickChan:", err)
}
}()
for i := 0; i < 10; i++ {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
tick <- time.Time{}
select {
case <-gotUpdate:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for update")
}
}
// And a bunch of ticks without waiting for updates,
// to make race tests more likely to fail, if any present.
for i := 0; i < 10; i++ {
tick <- time.Time{}
}
if err := p.Close(); err != nil {
t.Fatal(err)
}
wg.Wait()
}
func BenchmarkGetList(b *testing.B) { func BenchmarkGetList(b *testing.B) {
benchmarkGetList(b, false) benchmarkGetList(b, false)
} }