ipn/ipnlocal: use lockAndGetUnlock in Start

This removes one of the Lock,Unlock,Lock,Unlock at least in
the Start function. Still has 3 more of these.

Updates #11649

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-05-07 23:42:45 +00:00 committed by Maisem Ali
parent e1011f1387
commit 9380e2dfc6
2 changed files with 22 additions and 25 deletions

View File

@ -402,7 +402,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
osshare.SetFileSharingEnabled(false, logf) osshare.SetFileSharingEnabled(false, logf)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
portpoll := new(portlist.Poller)
clock := tstime.StdClock{} clock := tstime.StdClock{}
b := &LocalBackend{ b := &LocalBackend{
@ -420,7 +419,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo
pm: pm, pm: pm,
backendLogID: logID, backendLogID: logID,
state: ipn.NoState, state: ipn.NoState,
portpoll: portpoll, portpoll: new(portlist.Poller),
em: newExpiryManager(logf), em: newExpiryManager(logf),
gotPortPollRes: make(chan struct{}), gotPortPollRes: make(chan struct{}),
loginFlags: loginFlags, loginFlags: loginFlags,
@ -1593,6 +1592,14 @@ func (b *LocalBackend) NodeViewByIDForTest(id tailcfg.NodeID) (_ tailcfg.NodeVie
return n, ok return n, ok
} }
// DisablePortMapperForTest disables the portmapper for tests.
// It must be called before Start.
func (b *LocalBackend) DisablePortMapperForTest() {
b.mu.Lock()
defer b.mu.Unlock()
b.portpoll = nil
}
// PeersForTest returns all the current peers, sorted by Node.ID, // PeersForTest returns all the current peers, sorted by Node.ID,
// for integration tests in another repo. // for integration tests in another repo.
func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { func (b *LocalBackend) PeersForTest() []tailcfg.NodeView {
@ -1605,9 +1612,7 @@ func (b *LocalBackend) PeersForTest() []tailcfg.NodeView {
return ret return ret
} }
func (b *LocalBackend) getNewControlClientFunc() clientGen { func (b *LocalBackend) getNewControlClientFuncLocked() clientGen {
b.mu.Lock()
defer b.mu.Unlock()
if b.ccGen == nil { if b.ccGen == nil {
// Initialize it rather than just returning the // Initialize it rather than just returning the
// default to make any future call to // default to make any future call to
@ -1632,11 +1637,17 @@ func (b *LocalBackend) getNewControlClientFunc() clientGen {
func (b *LocalBackend) Start(opts ipn.Options) error { func (b *LocalBackend) Start(opts ipn.Options) error {
b.logf("Start") b.logf("Start")
b.mu.Lock() var clientToShutdown controlclient.Client
defer func() {
if clientToShutdown != nil {
clientToShutdown.Shutdown()
}
}()
unlock := b.lockAndGetUnlock()
defer unlock()
if opts.UpdatePrefs != nil { if opts.UpdatePrefs != nil {
if err := b.checkPrefsLocked(opts.UpdatePrefs); err != nil { if err := b.checkPrefsLocked(opts.UpdatePrefs); err != nil {
b.mu.Unlock()
return err return err
} }
} }
@ -1668,7 +1679,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
// into sync with the minimal changes. But that's not how it // into sync with the minimal changes. But that's not how it
// is right now, which is a sign that the code is still too // is right now, which is a sign that the code is still too
// complicated. // complicated.
prevCC := b.resetControlClientLocked() clientToShutdown = b.resetControlClientLocked()
httpTestClient := b.httpTestClient httpTestClient := b.httpTestClient
if b.hostinfo != nil { if b.hostinfo != nil {
@ -1697,7 +1708,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
wantRunning := prefs.WantRunning() wantRunning := prefs.WantRunning()
if wantRunning { if wantRunning {
if err := b.initMachineKeyLocked(); err != nil { if err := b.initMachineKeyLocked(); err != nil {
b.mu.Unlock()
return fmt.Errorf("initMachineKeyLocked: %w", err) return fmt.Errorf("initMachineKeyLocked: %w", err)
} }
} }
@ -1716,7 +1726,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
persistv = new(persist.Persist) persistv = new(persist.Persist)
} }
b.updateFilterLocked(nil, ipn.PrefsView{}) b.updateFilterLocked(nil, ipn.PrefsView{})
b.mu.Unlock()
if b.portpoll != nil { if b.portpoll != nil {
b.portpollOnce.Do(func() { b.portpollOnce.Do(func() {
@ -1748,15 +1757,11 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
debugFlags = append([]string{"netstack"}, debugFlags...) debugFlags = append([]string{"netstack"}, debugFlags...)
} }
if prevCC != nil {
prevCC.Shutdown()
}
// TODO(apenwarr): The only way to change the ServerURL is to // TODO(apenwarr): The only way to change the ServerURL is to
// re-run b.Start, because this is the only place we create a // re-run b.Start, because this is the only place we create a
// new controlclient. EditPrefs allows you to overwrite ServerURL, // new controlclient. EditPrefs allows you to overwrite ServerURL,
// but it won't take effect until the next Start. // but it won't take effect until the next Start.
cc, err := b.getNewControlClientFunc()(controlclient.Options{ cc, err := b.getNewControlClientFuncLocked()(controlclient.Options{
GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(), GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(),
Logf: logger.WithPrefix(b.logf, "control: "), Logf: logger.WithPrefix(b.logf, "control: "),
Persist: *persistv, Persist: *persistv,
@ -1786,14 +1791,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
return err return err
} }
b.mu.Lock()
// Even though we reset b.cc above, we might have raced with
// another Start() call. If so, shut down the previous one again
// as we do not know if it was created with the same options.
prevCC = b.resetControlClientLocked()
if prevCC != nil {
defer prevCC.Shutdown() // must be called after b.mu is unlocked
}
b.setControlClientLocked(cc) b.setControlClientLocked(cc)
endpoints := b.endpoints endpoints := b.endpoints
@ -1804,13 +1801,12 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
if b.tka != nil { if b.tka != nil {
head, err := b.tka.authority.Head().MarshalText() head, err := b.tka.authority.Head().MarshalText()
if err != nil { if err != nil {
b.mu.Unlock()
return fmt.Errorf("marshalling tka head: %w", err) return fmt.Errorf("marshalling tka head: %w", err)
} }
tkaHead = string(head) tkaHead = string(head)
} }
confWantRunning := b.conf != nil && wantRunning confWantRunning := b.conf != nil && wantRunning
b.mu.Unlock() unlock.UnlockEarly()
if endpoints != nil { if endpoints != nil {
cc.UpdateEndpoints(endpoints) cc.UpdateEndpoints(endpoints)

View File

@ -316,6 +316,7 @@ func TestStateMachine(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("NewLocalBackend: %v", err) t.Fatalf("NewLocalBackend: %v", err)
} }
b.DisablePortMapperForTest()
var cc, previousCC *mockControl var cc, previousCC *mockControl
b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) {