ssh/tailssh: terminate sessions when tailscaled shutsdown

Ideally we would re-establish these sessions when tailscaled comes back
up, however we do not do that yet so this is better than leaking the
sessions.

Updates #3802

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-05-28 04:33:46 -07:00 committed by Maisem Ali
parent 760740905e
commit 7cd8c3e839
2 changed files with 52 additions and 9 deletions

View File

@ -81,6 +81,9 @@ type SSHServer interface {
// so that existing sessions can be re-evaluated for validity
// and closed if they'd no longer be accepted.
OnPolicyChange()
// Shutdown is called when tailscaled is shutting down.
Shutdown()
}
type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error)
@ -346,6 +349,9 @@ func (b *LocalBackend) Shutdown() {
b.mu.Lock()
b.shutdownCalled = true
cc := b.cc
if b.sshServer != nil {
b.sshServer.Shutdown()
}
b.closePeerAPIListenersLocked()
b.mu.Unlock()

View File

@ -58,11 +58,14 @@ type server struct {
pubKeyHTTPClient *http.Client // or nil for http.DefaultClient
timeNow func() time.Time // or nil for time.Now
sessionWaitGroup sync.WaitGroup
// mu protects the following
mu sync.Mutex
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
}
func (srv *server) now() time.Time {
@ -101,6 +104,20 @@ func (srv *server) HandleSSHConn(c net.Conn) error {
return nil
}
// Shutdown terminates all active sessions.
func (srv *server) Shutdown() {
srv.mu.Lock()
srv.shutdownCalled = true
for _, s := range srv.activeSessionByH {
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Tailscale shutting down.\r\n"),
context.Canceled,
})
}
srv.mu.Unlock()
srv.sessionWaitGroup.Wait()
}
// OnPolicyChange terminates any active sessions that no longer match
// the SSH access policy.
func (srv *server) OnPolicyChange() {
@ -227,6 +244,15 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
}
func (srv *server) newConn() (*conn, error) {
srv.mu.Lock()
shutdownCalled := srv.shutdownCalled
srv.mu.Unlock()
if shutdownCalled {
// Stop accepting new connections.
// Connections in the auth phase are handled in handleConnPostSSHAuth.
// Existing sessions are terminated by Shutdown.
return nil, gossh.ErrDenied
}
c := &conn{srv: srv, now: srv.now()}
c.Server = &ssh.Server{
Version: "Tailscale",
@ -756,10 +782,10 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo
return
}
// startSession registers ss as an active session.
func (srv *server) startSession(ss *sshSession) {
srv.mu.Lock()
defer srv.mu.Unlock()
// startSessionLocked registers ss as an active session.
// It must be called with srv.mu held.
func (srv *server) startSessionLocked(ss *sshSession) {
srv.sessionWaitGroup.Add(1)
if ss.idH == "" {
panic("empty idH")
}
@ -778,6 +804,7 @@ func (srv *server) startSession(ss *sshSession) {
// endSession unregisters s from the list of active sessions.
func (srv *server) endSession(ss *sshSession) {
defer srv.sessionWaitGroup.Done()
srv.mu.Lock()
defer srv.mu.Unlock()
delete(srv.activeSessionByH, ss.idH)
@ -842,11 +869,21 @@ func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) err
// It handles ss once it's been accepted and determined
// that it should run.
func (ss *sshSession) run() {
srv := ss.conn.srv
srv.startSession(ss)
defer srv.endSession(ss)
defer ss.ctx.CloseWithError(errSessionDone)
srv := ss.conn.srv
srv.mu.Lock()
if srv.shutdownCalled {
srv.mu.Unlock()
// Do not start any new sessions.
fmt.Fprintf(ss, "Tailscale is shutting down\r\n")
ss.Exit(1)
return
}
srv.startSessionLocked(ss)
srv.mu.Unlock()
defer srv.endSession(ss)
if ss.action.SessionDuration != 0 {
t := time.AfterFunc(ss.action.SessionDuration, func() {
@ -858,7 +895,7 @@ func (ss *sshSession) run() {
defer t.Stop()
}
logf := srv.logf
logf := ss.logf
lu := ss.conn.localUser
localUser := lu.Username