ipn/ipnlocal: acquire b.mu once in Start

We used to Lock, Unlock, Lock, Unlock quite a few
times in Start resulting in all sorts of weird race
conditions. Simplify it all and only Lock/Unlock once.

Updates #11649

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-05-08 00:31:10 +00:00 committed by Maisem Ali
parent 9380e2dfc6
commit 32bc596062
3 changed files with 60 additions and 44 deletions

View File

@ -43,7 +43,8 @@ type Client interface {
// Login begins an interactive or non-interactive login process. // Login begins an interactive or non-interactive login process.
// Client will eventually call the Status callback with either a // Client will eventually call the Status callback with either a
// LoginFinished flag (on success) or an auth URL (if further // LoginFinished flag (on success) or an auth URL (if further
// interaction is needed). // interaction is needed). It merely sets the process in motion,
// and doesn't wait for it to complete.
Login(*tailcfg.Oauth2Token, LoginFlags) Login(*tailcfg.Oauth2Token, LoginFlags)
// Logout starts a synchronous logout process. It doesn't return // Logout starts a synchronous logout process. It doesn't return
// until the logout operation has been completed. // until the logout operation has been completed.

View File

@ -230,7 +230,8 @@ type LocalBackend struct {
ccGen clientGen // function for producing controlclient; lazily populated ccGen clientGen // function for producing controlclient; lazily populated
sshServer SSHServer // or nil, initialized lazily. sshServer SSHServer // or nil, initialized lazily.
appConnector *appc.AppConnector // or nil, initialized when configured. appConnector *appc.AppConnector // or nil, initialized when configured.
notify func(ipn.Notify) // notifyCancel cancels notifications to the current SetNotifyCallback.
notifyCancel context.CancelFunc
cc controlclient.Client cc controlclient.Client
ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto
machinePrivKey key.MachinePrivate machinePrivKey key.MachinePrivate
@ -710,6 +711,9 @@ func (b *LocalBackend) Shutdown() {
b.debugSink.Close() b.debugSink.Close()
b.debugSink = nil b.debugSink = nil
} }
if b.notifyCancel != nil {
b.notifyCancel()
}
b.mu.Unlock() b.mu.Unlock()
b.webClientShutdown() b.webClientShutdown()
@ -1557,10 +1561,26 @@ func endpointsEqual(x, y []tailcfg.Endpoint) bool {
return true return true
} }
// SetNotifyCallback sets the function to call when the backend has something to
// notify the frontend about. Only one callback can be set at a time, so calling
// this function will replace the previous callback.
func (b *LocalBackend) SetNotifyCallback(notify func(ipn.Notify)) { func (b *LocalBackend) SetNotifyCallback(notify func(ipn.Notify)) {
ctx, cancel := context.WithCancel(b.ctx)
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() prevCancel := b.notifyCancel
b.notify = notify b.notifyCancel = cancel
b.mu.Unlock()
if prevCancel != nil {
prevCancel()
}
var wg sync.WaitGroup
wg.Add(1)
go b.WatchNotifications(ctx, 0, wg.Done, func(n *ipn.Notify) bool {
notify(*n)
return true
})
wg.Wait()
} }
// SetHTTPTestClient sets an alternate HTTP client to use with // SetHTTPTestClient sets an alternate HTTP client to use with
@ -1806,7 +1826,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
tkaHead = string(head) tkaHead = string(head)
} }
confWantRunning := b.conf != nil && wantRunning confWantRunning := b.conf != nil && wantRunning
unlock.UnlockEarly()
if endpoints != nil { if endpoints != nil {
cc.UpdateEndpoints(endpoints) cc.UpdateEndpoints(endpoints)
@ -1815,16 +1834,23 @@ func (b *LocalBackend) Start(opts ipn.Options) error {
blid := b.backendLogID.String() blid := b.backendLogID.String()
b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID)
b.send(ipn.Notify{BackendLogID: &blid}) b.sendLocked(ipn.Notify{
b.send(ipn.Notify{Prefs: &prefs}) BackendLogID: &blid,
Prefs: &prefs,
})
if !loggedOut && (b.hasNodeKey() || confWantRunning) { if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) {
// Even if !WantRunning, we should verify our key, if there // If we know that we're either logged in or meant to be
// is one. If you want tailscaled to be completely idle, // running, tell the controlclient that it should also assume
// use logout instead. // that we need to be logged in.
//
// Without this, the state machine transitions to "NeedsLogin" implying
// that user interaction is required, which is not the case and can
// regress tsnet.Server restarts.
cc.Login(nil, controlclient.LoginDefault) cc.Login(nil, controlclient.LoginDefault)
} }
b.stateMachine() b.stateMachineLockedOnEntry(unlock)
return nil return nil
} }
@ -2390,6 +2416,13 @@ func (b *LocalBackend) DebugPickNewDERP() error {
// //
// b.mu must not be held. // b.mu must not be held.
func (b *LocalBackend) send(n ipn.Notify) { func (b *LocalBackend) send(n ipn.Notify) {
b.mu.Lock()
defer b.mu.Unlock()
b.sendLocked(n)
}
// sendLocked is like send, but assumes b.mu is already held.
func (b *LocalBackend) sendLocked(n ipn.Notify) {
if n.Prefs != nil { if n.Prefs != nil {
n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs)) n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs))
} }
@ -2397,8 +2430,6 @@ func (b *LocalBackend) send(n ipn.Notify) {
n.Version = version.Long() n.Version = version.Long()
} }
b.mu.Lock()
notifyFunc := b.notify
apiSrv := b.peerAPIServer apiSrv := b.peerAPIServer
if mayDeref(apiSrv).taildrop.HasFilesWaiting() { if mayDeref(apiSrv).taildrop.HasFilesWaiting() {
n.FilesWaiting = &empty.Message{} n.FilesWaiting = &empty.Message{}
@ -2411,12 +2442,6 @@ func (b *LocalBackend) send(n ipn.Notify) {
// Drop the notification if the channel is full. // Drop the notification if the channel is full.
} }
} }
b.mu.Unlock()
if notifyFunc != nil {
notifyFunc(n)
}
} }
func (b *LocalBackend) sendFileNotify() { func (b *LocalBackend) sendFileNotify() {
@ -2426,9 +2451,8 @@ func (b *LocalBackend) sendFileNotify() {
for _, wakeWaiter := range b.fileWaiters { for _, wakeWaiter := range b.fileWaiters {
wakeWaiter() wakeWaiter()
} }
notifyFunc := b.notify
apiSrv := b.peerAPIServer apiSrv := b.peerAPIServer
if notifyFunc == nil || apiSrv == nil { if apiSrv == nil {
b.mu.Unlock() b.mu.Unlock()
return return
} }
@ -4376,14 +4400,6 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock
} }
} }
// hasNodeKey reports whether a non-zero node key is present in the current
// prefs.
func (b *LocalBackend) hasNodeKey() bool {
b.mu.Lock()
defer b.mu.Unlock()
return b.hasNodeKeyLocked()
}
func (b *LocalBackend) hasNodeKeyLocked() bool { func (b *LocalBackend) hasNodeKeyLocked() bool {
// we can't use b.Prefs(), because it strips the keys, oops! // we can't use b.Prefs(), because it strips the keys, oops!
p := b.pm.CurrentPrefs() p := b.pm.CurrentPrefs()
@ -4481,6 +4497,12 @@ func (b *LocalBackend) nextStateLocked() ipn.State {
// Or maybe just call the state machine from fewer places. // Or maybe just call the state machine from fewer places.
func (b *LocalBackend) stateMachine() { func (b *LocalBackend) stateMachine() {
unlock := b.lockAndGetUnlock() unlock := b.lockAndGetUnlock()
b.stateMachineLockedOnEntry(unlock)
}
// stateMachineLockedOnEntry is like stateMachine but requires b.mu be held to
// call it, but it unlocks b.mu when done (via unlock, a once func).
func (b *LocalBackend) stateMachineLockedOnEntry(unlock unlockOnce) {
b.enterStateLockedOnEntry(b.nextStateLocked(), unlock) b.enterStateLockedOnEntry(b.nextStateLocked(), unlock)
} }

View File

@ -97,7 +97,6 @@ type mockControl struct {
paused atomic.Bool paused atomic.Bool
mu sync.Mutex mu sync.Mutex
machineKey key.MachinePrivate
persist *persist.Persist persist *persist.Persist
calls []string calls []string
authBlocked bool authBlocked bool
@ -134,12 +133,6 @@ func (cc *mockControl) populateKeys() (newKeys bool) {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
if cc.machineKey.IsZero() {
cc.logf("Copying machineKey.")
cc.machineKey, _ = cc.opts.GetMachinePrivateKey()
newKeys = true
}
if cc.persist == nil { if cc.persist == nil {
cc.persist = &persist.Persist{} cc.persist = &persist.Persist{}
} }
@ -831,7 +824,7 @@ func TestStateMachine(t *testing.T) {
// The last test case is the most common one: restarting when both // The last test case is the most common one: restarting when both
// logged in and WantRunning. // logged in and WantRunning.
t.Logf("\n\nStart5") t.Logf("\n\nStart5")
notifies.expect(2) notifies.expect(1)
c.Assert(b.Start(ipn.Options{}), qt.IsNil) c.Assert(b.Start(ipn.Options{}), qt.IsNil)
{ {
// NOTE: cc.Shutdown() is correct here, since we didn't call // NOTE: cc.Shutdown() is correct here, since we didn't call
@ -839,27 +832,27 @@ func TestStateMachine(t *testing.T) {
previousCC.assertShutdown(false) previousCC.assertShutdown(false)
cc.assertCalls("New", "Login") cc.assertCalls("New", "Login")
nn := notifies.drain(2) nn := notifies.drain(1)
cc.assertCalls() cc.assertCalls()
c.Assert(nn[0].Prefs, qt.IsNotNil) c.Assert(nn[0].Prefs, qt.IsNotNil)
c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse)
c.Assert(nn[0].Prefs.WantRunning(), qt.IsTrue) c.Assert(nn[0].Prefs.WantRunning(), qt.IsTrue)
c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) c.Assert(b.State(), qt.Equals, ipn.NoState)
} }
// Control server accepts our valid key from before. // Control server accepts our valid key from before.
t.Logf("\n\nLoginFinished5") t.Logf("\n\nLoginFinished5")
notifies.expect(2) notifies.expect(1)
cc.send(nil, "", true, &netmap.NetworkMap{ cc.send(nil, "", true, &netmap.NetworkMap{
SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(),
}) })
{ {
nn := notifies.drain(2) nn := notifies.drain(1)
cc.assertCalls() cc.assertCalls()
// NOTE: No LoginFinished message since no interactive // NOTE: No LoginFinished message since no interactive
// login was needed. // login was needed.
c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].State, qt.IsNotNil)
c.Assert(ipn.Starting, qt.Equals, *nn[1].State) c.Assert(ipn.Starting, qt.Equals, *nn[0].State)
// NOTE: No prefs change this time. WantRunning stays true. // NOTE: No prefs change this time. WantRunning stays true.
// We were in Starting in the first place, so that doesn't // We were in Starting in the first place, so that doesn't
// change either. // change either.