diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8c8f342e1..573d9a910 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -292,10 +292,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo osshare.SetFileSharingEnabled(false, logf) ctx, cancel := context.WithCancel(context.Background()) - portpoll, err := portlist.NewPoller() - if err != nil { - logf("skipping portlist: %s", err) - } + portpoll := new(portlist.Poller) b := &LocalBackend{ ctx: ctx, @@ -1377,7 +1374,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if b.portpoll != nil { b.portpollOnce.Do(func() { - go b.portpoll.Run(b.ctx) go b.readPoller() // Give the poller a second to get results to @@ -1812,11 +1808,30 @@ func dnsMapsEqual(new, old *netmap.NetworkMap) bool { // readPoller is a goroutine that receives service lists from // b.portpoll and propagates them into the controlclient's HostInfo. func (b *LocalBackend) readPoller() { - n := 0 + isFirst := true + ticker := time.NewTicker(portlist.PollInterval()) + defer ticker.Stop() + initChan := make(chan struct{}) + close(initChan) for { - ports, ok := <-b.portpoll.Updates() - if !ok { + select { + case <-ticker.C: + case <-b.ctx.Done(): return + case <-initChan: + // Preserving old behavior: readPoller should + // immediately poll the first time, then wait + // for a tick after. + initChan = nil + } + + ports, changed, err := b.portpoll.Poll() + if err != nil { + b.logf("error polling for open ports: %v", err) + return + } + if !changed { + continue } sl := []tailcfg.Service{} for _, p := range ports { @@ -1840,8 +1855,8 @@ func (b *LocalBackend) readPoller() { b.doSetHostinfoFilterServices(hi) - n++ - if n == 1 { + if isFirst { + isFirst = false close(b.gotPortPollRes) } } diff --git a/portlist/poller.go b/portlist/poller.go index 16dd8e74c..90c8e7838 100644 --- a/portlist/poller.go +++ b/portlist/poller.go @@ -9,6 +9,7 @@ import ( "context" "errors" + "fmt" "runtime" "sync" "time" @@ -17,9 +18,16 @@ "tailscale.com/envknob" ) -var pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs +var ( + pollInterval = 5 * time.Second // default; changed by some OS-specific init funcs + debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") +) -var debugDisablePortlist = envknob.RegisterBool("TS_DEBUG_DISABLE_PORTLIST") +// PollInterval is the recommended OS-specific interval +// to wait between *Poller.Poll method calls. +func PollInterval() time.Duration { + return pollInterval +} // Poller scans the systems for listening ports periodically and sends // the results to C. @@ -37,8 +45,9 @@ type Poller struct { // addProcesses is not used. // A nil values means we don't have code for getting the list on the current // operating system. - os osImpl - osOnce sync.Once // guards init of os + os osImpl + initOnce sync.Once // guards init of os + initErr error // closeCtx is the context that's canceled on Close. closeCtx context.Context @@ -69,24 +78,23 @@ type osImpl interface { // newOSImpl, if non-nil, constructs a new osImpl. var newOSImpl func(includeLocalhost bool) osImpl -var errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) +var ( + errUnimplemented = errors.New("portlist poller not implemented on " + runtime.GOOS) + errDisabled = errors.New("portlist disabled by envknob") +) // NewPoller returns a new portlist Poller. It returns an error // if the portlist couldn't be obtained. func NewPoller() (*Poller, error) { - if debugDisablePortlist() { - return nil, errors.New("portlist disabled by envknob") - } p := &Poller{ c: make(chan List), runDone: make(chan struct{}), } - p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) - p.osOnce.Do(p.initOSField) - if p.os == nil { - return nil, errUnimplemented + p.initOnce.Do(p.init) + if p.initErr != nil { + return nil, p.initErr } - + p.closeCtx, p.closeCtxCancel = context.WithCancel(context.Background()) // Do one initial poll synchronously so we can return an error // early. if pl, err := p.getList(); err != nil { @@ -103,10 +111,18 @@ func (p *Poller) setPrev(pl List) { p.prev = slices.Clone(pl) } -func (p *Poller) initOSField() { - if newOSImpl != nil { - p.os = newOSImpl(p.IncludeLocalhost) +// init initializes the Poller by ensuring it has an underlying +// OS implementation and is not turned off by envknob. +func (p *Poller) init() { + if debugDisablePortlist() { + p.initErr = errDisabled + return } + if newOSImpl == nil { + p.initErr = errUnimplemented + return + } + p.os = newOSImpl(p.IncludeLocalhost) } // Updates return the channel that receives port list updates. @@ -115,14 +131,18 @@ func (p *Poller) initOSField() { func (p *Poller) Updates() <-chan List { return p.c } // Close closes the Poller. -// Run will return with a nil error. func (p *Poller) Close() error { - p.closeCtxCancel() - <-p.runDone - if p.os != nil { - p.os.Close() + if p.initErr != nil { + return p.initErr } - return nil + if p.os == nil { + return nil + } + if p.closeCtxCancel != nil { + p.closeCtxCancel() + <-p.runDone + } + return p.os.Close() } // send sends pl to p.c and returns whether it was successfully sent. @@ -137,6 +157,24 @@ func (p *Poller) send(ctx context.Context, pl List) (sent bool, err error) { } } +// Poll returns the list of listening ports, if changed from +// a previous call as indicated by the changed result. +func (p *Poller) Poll() (ports []Port, changed bool, err error) { + p.initOnce.Do(p.init) + if p.initErr != nil { + return nil, false, fmt.Errorf("error initializing poller: %w", p.initErr) + } + pl, err := p.getList() + if err != nil { + return nil, false, err + } + if pl.equal(p.prev) { + return nil, false, nil + } + p.setPrev(pl) + return p.prev, true, nil +} + // Run runs the Poller periodically until either the context // is done, or the Close is called. // @@ -179,10 +217,13 @@ func (p *Poller) runWithTickChan(ctx context.Context, tickChan <-chan time.Time) } func (p *Poller) getList() (List, error) { - if debugDisablePortlist() { + // TODO(marwan): this method does not + // need to do any init logic. Update tests + // once async API is removed. + p.initOnce.Do(p.init) + if p.initErr == errDisabled { return nil, nil } - p.osOnce.Do(p.initOSField) var err error p.scratch, err = p.os.AppendListeningPorts(p.scratch[:0]) return p.scratch, err diff --git a/portlist/portlist_test.go b/portlist/portlist_test.go index 6055a8426..86e8bd335 100644 --- a/portlist/portlist_test.go +++ b/portlist/portlist_test.go @@ -175,12 +175,33 @@ func TestEqualLessThan(t *testing.T) { } } +func TestClose(t *testing.T) { + var p Poller + err := p.Close() + if err != nil { + t.Fatal(err) + } + p = Poller{} + _, _, err = p.Poll() + if err != nil { + t.Skipf("skipping due to poll error: %v", err) + } + err = p.Close() + if err != nil { + t.Fatal(err) + } +} + func TestPoller(t *testing.T) { p, err := NewPoller() if err != nil { t.Skipf("not running test: %v", err) } - defer p.Close() + t.Cleanup(func() { + if err := p.Close(); err != nil { + t.Errorf("error closing poller in test: %v", err) + } + }) var wg sync.WaitGroup wg.Add(2)