ssh/tailssh: allow multiple sessions on the same conn

Fixes #4920
Fixes tailscale/corp#5633
Updates #4479

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-06-27 11:50:11 -07:00 committed by Maisem Ali
parent 1d04e01d1e
commit a7d2024e35
2 changed files with 145 additions and 106 deletions

View File

@ -62,11 +62,10 @@ type server struct {
sessionWaitGroup sync.WaitGroup sessionWaitGroup sync.WaitGroup
// mu protects the following // mu protects the following
mu sync.Mutex mu sync.Mutex
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session activeConns map[*conn]bool // set; value is always true
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL shutdownCalled bool
shutdownCalled bool
} }
func (srv *server) now() time.Time { func (srv *server) now() time.Time {
@ -91,14 +90,28 @@ func init() {
}) })
} }
func (srv *server) trackActiveConn(c *conn, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()
if add {
mak.Set(&srv.activeConns, c, true)
return
}
delete(srv.activeConns, c)
}
// HandleSSHConn handles a Tailscale SSH connection from c. // HandleSSHConn handles a Tailscale SSH connection from c.
func (srv *server) HandleSSHConn(c net.Conn) error { // This is the entry point for all SSH connections.
// When this returns, the connection is closed.
func (srv *server) HandleSSHConn(nc net.Conn) error {
metricIncomingConnections.Add(1) metricIncomingConnections.Add(1)
ss, err := srv.newConn() c, err := srv.newConn()
if err != nil { if err != nil {
return err return err
} }
ss.HandleConn(c) srv.trackActiveConn(c, true) // add
defer srv.trackActiveConn(c, false) // remove
c.HandleConn(nc)
// Return nil to signal to netstack's interception that it doesn't need to // Return nil to signal to netstack's interception that it doesn't need to
// log. If ss.HandleConn had problems, it can log itself (ideally on an // log. If ss.HandleConn had problems, it can log itself (ideally on an
@ -110,11 +123,13 @@ func (srv *server) HandleSSHConn(c net.Conn) error {
func (srv *server) Shutdown() { func (srv *server) Shutdown() {
srv.mu.Lock() srv.mu.Lock()
srv.shutdownCalled = true srv.shutdownCalled = true
for _, s := range srv.activeSessionByH { for c := range srv.activeConns {
s.ctx.CloseWithError(userVisibleError{ for _, s := range c.sessions {
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"), s.ctx.CloseWithError(userVisibleError{
context.Canceled, fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
}) context.Canceled,
})
}
} }
srv.mu.Unlock() srv.mu.Unlock()
srv.sessionWaitGroup.Wait() srv.sessionWaitGroup.Wait()
@ -125,8 +140,8 @@ func (srv *server) Shutdown() {
func (srv *server) OnPolicyChange() { func (srv *server) OnPolicyChange() {
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
for _, s := range srv.activeSessionByH { for c := range srv.activeConns {
go s.checkStillValid() go c.checkStillValid()
} }
} }
@ -135,25 +150,33 @@ func (srv *server) OnPolicyChange() {
type conn struct { type conn struct {
*ssh.Server *ssh.Server
insecureSkipTailscaleAuth bool // used by tests.
// now is the time to consider the present moment for the // now is the time to consider the present moment for the
// purposes of rule evaluation. // purposes of rule evaluation.
now time.Time now time.Time
connID string // ID that's shared with control
action0 *tailcfg.SSHAction // first matching action action0 *tailcfg.SSHAction // first matching action
srv *server srv *server
info *sshConnInfo // set by setInfo info *sshConnInfo // set by setInfo
localUser *user.User // set by checkAuth localUser *user.User // set by checkAuth
userGroupIDs []string // set by checkAuth userGroupIDs []string // set by checkAuth
insecureSkipTailscaleAuth bool // used by tests. mu sync.Mutex // protects the following
// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
// and is shared among all sessions. It should not be shared outside
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
// library.
idH string
pubKey gossh.PublicKey // set by authorizeSession
finalAction *tailcfg.SSHAction // set by authorizeSession
finalActionErr error // set by authorizeSession
sessions []*sshSession
} }
func (c *conn) logf(format string, args ...any) { func (c *conn) logf(format string, args ...any) {
if c.info == nil { format = fmt.Sprintf("%v: %v", c.connID, format)
c.srv.logf(format, args...)
return
}
format = fmt.Sprintf("%v: %v", c.info.String(), format)
c.srv.logf(format, args...) c.srv.logf(format, args...)
} }
@ -247,21 +270,22 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
func (srv *server) newConn() (*conn, error) { func (srv *server) newConn() (*conn, error) {
srv.mu.Lock() srv.mu.Lock()
shutdownCalled := srv.shutdownCalled if srv.shutdownCalled {
srv.mu.Unlock() srv.mu.Unlock()
if shutdownCalled {
// Stop accepting new connections. // Stop accepting new connections.
// Connections in the auth phase are handled in handleConnPostSSHAuth. // Connections in the auth phase are handled in handleConnPostSSHAuth.
// Existing sessions are terminated by Shutdown. // Existing sessions are terminated by Shutdown.
return nil, gossh.ErrDenied return nil, gossh.ErrDenied
} }
srv.mu.Unlock()
c := &conn{srv: srv, now: srv.now()} c := &conn{srv: srv, now: srv.now()}
c.connID = fmt.Sprintf("conn-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
c.Server = &ssh.Server{ c.Server = &ssh.Server{
Version: "Tailscale", Version: "Tailscale",
Handler: c.handleConnPostSSHAuth, Handler: c.handleSessionPostSSHAuth,
RequestHandlers: map[string]ssh.RequestHandler{}, RequestHandlers: map[string]ssh.RequestHandler{},
SubsystemHandlers: map[string]ssh.SubsystemHandler{ SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": c.handleConnPostSSHAuth, "sftp": c.handleSessionPostSSHAuth,
}, },
// Note: the direct-tcpip channel handler and LocalPortForwardingCallback // Note: the direct-tcpip channel handler and LocalPortForwardingCallback
@ -270,7 +294,7 @@ func (srv *server) newConn() (*conn, error) {
ChannelHandlers: map[string]ssh.ChannelHandler{ ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler, "direct-tcpip": ssh.DirectTCPIPHandler,
}, },
LocalPortForwardingCallback: srv.mayForwardLocalPortTo, LocalPortForwardingCallback: c.mayForwardLocalPortTo,
PublicKeyHandler: c.PublicKeyHandler, PublicKeyHandler: c.PublicKeyHandler,
ServerConfigCallback: c.ServerConfig, ServerConfigCallback: c.ServerConfig,
@ -298,16 +322,12 @@ func (srv *server) newConn() (*conn, error) {
// mayForwardLocalPortTo reports whether the ctx should be allowed to port forward // mayForwardLocalPortTo reports whether the ctx should be allowed to port forward
// to the specified host and port. // to the specified host and port.
// TODO(bradfitz/maisem): should we have more checks on host/port? // TODO(bradfitz/maisem): should we have more checks on host/port?
func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
ss, ok := srv.getSessionForContext(ctx) if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding {
if !ok { metricLocalPortForward.Add(1)
return false return true
} }
if !ss.action.AllowLocalPortForwarding { return false
return false
}
metricLocalPortForward.Add(1)
return true
} }
// havePubKeyPolicy reports whether any policy rule may provide access by means // havePubKeyPolicy reports whether any policy rule may provide access by means
@ -401,6 +421,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
ci.uprof = &uprof ci.uprof = &uprof
c.info = ci c.info = ci
c.logf("handling conn: %v", ci.String())
return nil return nil
} }
@ -516,32 +537,47 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
return lines, err return lines, err
} }
// handleConnPostSSHAuth runs an SSH session after the SSH-level authentication, func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
// but not necessarily before all the Tailscale-level extra verification has c.mu.Lock()
// completed. It also handles SFTP requests. defer c.mu.Unlock()
func (c *conn) handleConnPostSSHAuth(s ssh.Session) { idH := s.Context().(ssh.Context).SessionID()
if s.PublicKey() != nil { if c.idH == "" {
metricPublicKeyConnections.Add(1) c.idH = idH
} else if c.idH != idH {
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
s.Exit(1)
return nil, false
} }
sshUser := s.User()
cr := &contextReader{r: s} cr := &contextReader{r: s}
action, err := c.resolveTerminalAction(s, cr) action, err := c.resolveTerminalActionLocked(s, cr)
if err != nil { if err != nil {
c.logf("resolveTerminalAction: %v", err) c.logf("resolveTerminalAction: %v", err)
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n") io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
s.Exit(1) s.Exit(1)
return return nil, false
} }
if action.Reject || !action.Accept { if action.Reject || !action.Accept {
c.logf("access denied for %v", c.info.uprof.LoginName) c.logf("access denied for %v", c.info.uprof.LoginName)
s.Exit(1) s.Exit(1)
return nil, false
}
return cr, true
}
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
// but not necessarily before all the Tailscale-level extra verification has
// completed. It also handles SFTP requests.
func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
// Now that we have passed the SSH-level authentication, we can start the
// Tailscale-level extra verification. This means that we are going to
// evaluate the policy provided by control against the incoming SSH session.
cr, ok := c.authorizeSession(s)
if !ok {
return return
} }
if s.PublicKey() != nil {
metricPublicKeyAccepts.Add(1)
}
if cr.HasOutstandingRead() { if cr.HasOutstandingRead() {
// There was some buffered input while we were waiting for the policy
// decision.
s = contextReaderSesssion{s, cr} s = contextReaderSesssion{s, cr}
} }
@ -555,20 +591,37 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
return return
} }
ss := c.newSSHSession(s, action) ss := c.newSSHSession(s)
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), sshUser) ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, sshUser) ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Name)
ss.run() ss.run()
} }
// resolveTerminalAction either returns action0 (if it's Accept or Reject) or // resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or
// else loops, fetching new SSHActions from the control plane. // else loops, fetching new SSHActions from the control plane.
// //
// Any action with a Message in the chain will be printed to s. // Any action with a Message in the chain will be printed to s.
// //
// The returned SSHAction will be either Reject or Accept. // The returned SSHAction will be either Reject or Accept.
func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) { //
action := c.action0 // c.mu must be held.
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
if c.finalAction != nil || c.finalActionErr != nil {
return c.finalAction, c.finalActionErr
}
if s.PublicKey() != nil {
metricPublicKeyConnections.Add(1)
}
defer func() {
c.finalAction = action
c.finalActionErr = err
c.pubKey = s.PublicKey()
if c.pubKey != nil && action.Accept {
metricPublicKeyAccepts.Add(1)
}
}()
action = c.action0
var awaitReadOnce sync.Once // to start Reads on cr var awaitReadOnce sync.Once // to start Reads on cr
var sawInterrupt syncs.AtomicBool var sawInterrupt syncs.AtomicBool
@ -672,13 +725,11 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
// sshSession is an accepted Tailscale SSH session. // sshSession is an accepted Tailscale SSH session.
type sshSession struct { type sshSession struct {
ssh.Session ssh.Session
idH string // the RFC4253 sec8 hash H; don't share outside process
sharedID string // ID that's shared with control sharedID string // ID that's shared with control
logf logger.Logf logf logger.Logf
ctx *sshContext // implements context.Context ctx *sshContext // implements context.Context
conn *conn conn *conn
action *tailcfg.SSHAction
agentListener net.Listener // non-nil if agent-forwarding requested+allowed agentListener net.Listener // non-nil if agent-forwarding requested+allowed
// initialized by launchProcess: // initialized by launchProcess:
@ -699,22 +750,21 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
} }
} }
func (c *conn) newSSHSession(s ssh.Session, action *tailcfg.SSHAction) *sshSession { func (c *conn) newSSHSession(s ssh.Session) *sshSession {
sharedID := fmt.Sprintf("%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5)) sharedID := fmt.Sprintf("sess-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
c.logf("starting session: %v", sharedID) c.logf("starting session: %v", sharedID)
return &sshSession{ return &sshSession{
Session: s, Session: s,
idH: s.Context().(ssh.Context).SessionID(),
sharedID: sharedID, sharedID: sharedID,
ctx: newSSHContext(), ctx: newSSHContext(),
conn: c, conn: c,
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
action: action,
} }
} }
func (c *conn) isStillValid(pubKey ssh.PublicKey) bool { // isStillValid reports whether the conn is still valid.
a, localUser, err := c.evaluatePolicy(pubKey) func (c *conn) isStillValid() bool {
a, localUser, err := c.evaluatePolicy(c.pubKey)
if err != nil { if err != nil {
return false return false
} }
@ -724,18 +774,20 @@ func (c *conn) isStillValid(pubKey ssh.PublicKey) bool {
return c.localUser.Username == localUser return c.localUser.Username == localUser
} }
// checkStillValid checks that the session is still valid per the latest SSHPolicy. // checkStillValid checks that the conn is still valid per the latest SSHPolicy.
// If not, it terminates the session. // If not, it terminates all sessions associated with the conn.
func (ss *sshSession) checkStillValid() { func (c *conn) checkStillValid() {
if ss.conn.isStillValid(ss.PublicKey()) { if c.isStillValid() {
return return
} }
metricPolicyChangeKick.Add(1) metricPolicyChangeKick.Add(1)
ss.logf("session no longer valid per new SSH policy; closing") c.logf("session no longer valid per new SSH policy; closing")
ss.ctx.CloseWithError(userVisibleError{ for _, s := range c.sessions {
fmt.Sprintf("Access revoked.\r\n"), s.ctx.CloseWithError(userVisibleError{
context.Canceled, fmt.Sprintf("Access revoked.\r\n"),
}) context.Canceled,
})
}
} }
func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) { func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
@ -798,41 +850,27 @@ func (ss *sshSession) killProcessOnContextDone() {
}) })
} }
// sessionAction returns the SSHAction associated with the session.
func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bool) {
srv.mu.Lock()
defer srv.mu.Unlock()
ss, ok = srv.activeSessionByH[sctx.SessionID()]
return
}
// startSessionLocked registers ss as an active session. // startSessionLocked registers ss as an active session.
// It must be called with srv.mu held. // It must be called with srv.mu held.
func (srv *server) startSessionLocked(ss *sshSession) { func (c *conn) startSessionLocked(ss *sshSession) {
srv.sessionWaitGroup.Add(1) c.srv.sessionWaitGroup.Add(1)
if ss.idH == "" {
panic("empty idH")
}
if ss.sharedID == "" { if ss.sharedID == "" {
panic("empty sharedID") panic("empty sharedID")
} }
if _, dup := srv.activeSessionByH[ss.idH]; dup { c.sessions = append(c.sessions, ss)
panic("dup idH")
}
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
panic("dup sharedID")
}
mak.Set(&srv.activeSessionByH, ss.idH, ss)
mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss)
} }
// endSession unregisters s from the list of active sessions. // endSession unregisters s from the list of active sessions.
func (srv *server) endSession(ss *sshSession) { func (c *conn) endSession(ss *sshSession) {
defer srv.sessionWaitGroup.Done() defer c.srv.sessionWaitGroup.Done()
srv.mu.Lock() c.srv.mu.Lock()
defer srv.mu.Unlock() defer c.srv.mu.Unlock()
delete(srv.activeSessionByH, ss.idH) for i, s := range c.sessions {
delete(srv.activeSessionBySharedID, ss.sharedID) if s == ss {
c.sessions = append(c.sessions[:i], c.sessions[i+1:]...)
break
}
}
} }
var errSessionDone = errors.New("session is done") var errSessionDone = errors.New("session is done")
@ -841,7 +879,7 @@ func (srv *server) endSession(ss *sshSession) {
// forwards agent connections between the listener and the ssh.Session. // forwards agent connections between the listener and the ssh.Session.
// On success, it assigns ss.agentListener. // On success, it assigns ss.agentListener.
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error { func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error {
if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding { if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
return nil return nil
} }
ss.logf("ssh: agent forwarding requested") ss.logf("ssh: agent forwarding requested")
@ -906,15 +944,15 @@ func (ss *sshSession) run() {
ss.Exit(1) ss.Exit(1)
return return
} }
srv.startSessionLocked(ss) ss.conn.startSessionLocked(ss)
srv.mu.Unlock() srv.mu.Unlock()
defer srv.endSession(ss) defer ss.conn.endSession(ss)
if ss.action.SessionDuration != 0 { if ss.conn.finalAction.SessionDuration != 0 {
t := time.AfterFunc(ss.action.SessionDuration, func() { t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
ss.ctx.CloseWithError(userVisibleError{ ss.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SessionDuration), fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
context.DeadlineExceeded, context.DeadlineExceeded,
}) })
}) })

View File

@ -238,9 +238,10 @@ func TestSSH(t *testing.T) {
node: &tailcfg.Node{}, node: &tailcfg.Node{},
uprof: &tailcfg.UserProfile{}, uprof: &tailcfg.UserProfile{},
} }
sc.finalAction = &tailcfg.SSHAction{Accept: true}
sc.Handler = func(s ssh.Session) { sc.Handler = func(s ssh.Session) {
sc.newSSHSession(s, &tailcfg.SSHAction{Accept: true}).run() sc.newSSHSession(s).run()
} }
ln, err := net.Listen("tcp4", "127.0.0.1:0") ln, err := net.Listen("tcp4", "127.0.0.1:0")