diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index bae4b3f34..d4f95122a 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -81,6 +81,9 @@ type SSHServer interface { // so that existing sessions can be re-evaluated for validity // and closed if they'd no longer be accepted. OnPolicyChange() + + // Shutdown is called when tailscaled is shutting down. + Shutdown() } type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error) @@ -346,6 +349,9 @@ func (b *LocalBackend) Shutdown() { b.mu.Lock() b.shutdownCalled = true cc := b.cc + if b.sshServer != nil { + b.sshServer.Shutdown() + } b.closePeerAPIListenersLocked() b.mu.Unlock() diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index bed2c32fc..43be7498a 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -58,11 +58,14 @@ type server struct { pubKeyHTTPClient *http.Client // or nil for http.DefaultClient timeNow func() time.Time // or nil for time.Now + sessionWaitGroup sync.WaitGroup + // mu protects the following mu sync.Mutex activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL + shutdownCalled bool } func (srv *server) now() time.Time { @@ -101,6 +104,20 @@ func (srv *server) HandleSSHConn(c net.Conn) error { return nil } +// Shutdown terminates all active sessions. +func (srv *server) Shutdown() { + srv.mu.Lock() + srv.shutdownCalled = true + for _, s := range srv.activeSessionByH { + s.ctx.CloseWithError(userVisibleError{ + fmt.Sprintf("Tailscale shutting down.\r\n"), + context.Canceled, + }) + } + srv.mu.Unlock() + srv.sessionWaitGroup.Wait() +} + // OnPolicyChange terminates any active sessions that no longer match // the SSH access policy. func (srv *server) OnPolicyChange() { @@ -227,6 +244,15 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { } func (srv *server) newConn() (*conn, error) { + srv.mu.Lock() + shutdownCalled := srv.shutdownCalled + srv.mu.Unlock() + if shutdownCalled { + // Stop accepting new connections. + // Connections in the auth phase are handled in handleConnPostSSHAuth. + // Existing sessions are terminated by Shutdown. + return nil, gossh.ErrDenied + } c := &conn{srv: srv, now: srv.now()} c.Server = &ssh.Server{ Version: "Tailscale", @@ -756,10 +782,10 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo return } -// startSession registers ss as an active session. -func (srv *server) startSession(ss *sshSession) { - srv.mu.Lock() - defer srv.mu.Unlock() +// startSessionLocked registers ss as an active session. +// It must be called with srv.mu held. +func (srv *server) startSessionLocked(ss *sshSession) { + srv.sessionWaitGroup.Add(1) if ss.idH == "" { panic("empty idH") } @@ -778,6 +804,7 @@ func (srv *server) startSession(ss *sshSession) { // endSession unregisters s from the list of active sessions. func (srv *server) endSession(ss *sshSession) { + defer srv.sessionWaitGroup.Done() srv.mu.Lock() defer srv.mu.Unlock() delete(srv.activeSessionByH, ss.idH) @@ -842,11 +869,21 @@ func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) err // It handles ss once it's been accepted and determined // that it should run. func (ss *sshSession) run() { - srv := ss.conn.srv - srv.startSession(ss) - defer srv.endSession(ss) - defer ss.ctx.CloseWithError(errSessionDone) + srv := ss.conn.srv + + srv.mu.Lock() + if srv.shutdownCalled { + srv.mu.Unlock() + // Do not start any new sessions. + fmt.Fprintf(ss, "Tailscale is shutting down\r\n") + ss.Exit(1) + return + } + srv.startSessionLocked(ss) + srv.mu.Unlock() + + defer srv.endSession(ss) if ss.action.SessionDuration != 0 { t := time.AfterFunc(ss.action.SessionDuration, func() { @@ -858,7 +895,7 @@ func (ss *sshSession) run() { defer t.Stop() } - logf := srv.logf + logf := ss.logf lu := ss.conn.localUser localUser := lu.Username