diff --git a/ssh/tailssh/ctxreader.go b/ssh/tailssh/ctxreader.go deleted file mode 100644 index a437eaab7..000000000 --- a/ssh/tailssh/ctxreader.go +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package tailssh - -import ( - "context" - "io" - "sync" - - "tailscale.com/tempfork/gliderlabs/ssh" -) - -// readResult is a result from a io.Reader.Read call, -// as used by contextReader. -type readResult struct { - buf []byte // ownership passed on chan send - err error -} - -// contextReader wraps an io.Reader, providing a ReadContext method -// that can be aborted before yielding bytes. If it's aborted, subsequent -// reads can get those byte(s) later. -type contextReader struct { - r io.Reader - - // buffered is leftover data from a previous read call that wasn't entirely - // consumed. - buffered []byte - // readErr is a previous read error that was seen while filling buffered. It - // should be returned to the caller after buffered is consumed. - readErr error - - mu sync.Mutex // guards ch only - - // ch is non-nil if a goroutine had been started and has a result to be - // read. The goroutine may be either still running or done and has - // send to the channel. - ch chan readResult -} - -// HasOutstandingRead reports whether there's an outstanding Read call that's -// either currently blocked in a Read or whose result hasn't been consumed. -func (w *contextReader) HasOutstandingRead() bool { - w.mu.Lock() - defer w.mu.Unlock() - return w.ch != nil -} - -func (w *contextReader) setChan(c chan readResult) { - w.mu.Lock() - defer w.mu.Unlock() - w.ch = c -} - -// ReadContext is like Read, but takes a context permitting the read to be canceled. -// -// If the context becomes done, the underlying Read call continues and its result -// will be given to the next caller to ReadContext. -func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - - n = copy(p, w.buffered) - if n > 0 { - w.buffered = w.buffered[n:] - if len(w.buffered) == 0 { - err = w.readErr - } - return n, err - } - - if w.ch == nil { - ch := make(chan readResult, 1) - w.setChan(ch) - go func() { - rbuf := make([]byte, len(p)) - n, err := w.r.Read(rbuf) - ch <- readResult{rbuf[:n], err} - }() - } - - select { - case <-ctx.Done(): - return 0, ctx.Err() - case rr := <-w.ch: - w.setChan(nil) - n = copy(p, rr.buf) - w.buffered = rr.buf[n:] - w.readErr = rr.err - if len(w.buffered) == 0 { - err = rr.err - } - return n, err - } -} - -// contextReaderSession implements ssh.Session, wrapping another -// ssh.Session but changing its Read method to use contextReader. -type contextReaderSession struct { - ssh.Session - cr *contextReader -} - -func (a contextReaderSession) Read(p []byte) (n int, err error) { - if a.cr.HasOutstandingRead() { - return a.cr.ReadContext(context.Background(), p) - } - return a.Session.Read(p) -} diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index ac1257c95..86f7b835d 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -86,11 +86,9 @@ func (ss *sshSession) newIncubatorCommand() *exec.Cmd { // TODO(maisem): this doesn't work with sftp return exec.CommandContext(ss.ctx, name, args...) } - ss.conn.mu.Lock() lu := ss.conn.localUser ci := ss.conn.info gids := strings.Join(ss.conn.userGroupIDs, ",") - ss.conn.mu.Unlock() remoteUser := ci.uprof.LoginName if len(ci.node.Tags) > 0 { remoteUser = strings.Join(ci.node.Tags, ",") diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index a4c76110f..ec3889c20 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -29,7 +29,6 @@ "strconv" "strings" "sync" - "sync/atomic" "time" gossh "github.com/tailscale/golang-x-crypto/ssh" @@ -87,6 +86,21 @@ func init() { }) } +// attachSessionToConnIfNotShutdown ensures that srv is not shutdown before +// attaching the session to the conn. This ensures that once Shutdown is called, +// new sessions are not allowed and existing ones are cleaned up. +// It reports whether ss was attached to the conn. +func (srv *server) attachSessionToConnIfNotShutdown(ss *sshSession) bool { + srv.mu.Lock() + defer srv.mu.Unlock() + if srv.shutdownCalled { + // Do not start any new sessions. + return false + } + ss.conn.attachSession(ss) + return true +} + func (srv *server) trackActiveConn(c *conn, add bool) { srv.mu.Lock() defer srv.mu.Unlock() @@ -121,12 +135,7 @@ func (srv *server) Shutdown() { srv.mu.Lock() srv.shutdownCalled = true for c := range srv.activeConns { - for _, s := range c.sessions { - s.ctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Tailscale SSH is shutting down.\r\n"), - context.Canceled, - }) - } + c.Close() } srv.mu.Unlock() srv.sessionWaitGroup.Wait() @@ -138,10 +147,7 @@ func (srv *server) OnPolicyChange() { srv.mu.Lock() defer srv.mu.Unlock() for c := range srv.activeConns { - c.mu.Lock() - ci := c.info - c.mu.Unlock() - if ci == nil { + if c.info == nil { // c.info is nil when the connection hasn't been authenticated yet. // In that case, the connection will be terminated when it is. continue @@ -152,28 +158,53 @@ func (srv *server) OnPolicyChange() { // conn represents a single SSH connection and its associated // ssh.Server. +// +// During the lifecycle of a connection, the following are called in order: +// Setup and discover server info +// - ServerConfigCallback +// +// Do the user auth +// - BannerHandler +// - NoClientAuthHandler +// - PublicKeyHandler (only if NoClientAuthHandler returns errPubKeyRequired) +// +// Once auth is done, the conn can be multiplexed with multiple sessions and +// channels concurrently. At which point any of the following can be called +// in any order. +// - c.handleSessionPostSSHAuth +// - c.mayForwardLocalPortTo followed by ssh.DirectTCPIPHandler type conn struct { *ssh.Server + srv *server insecureSkipTailscaleAuth bool // used by tests. - connID string // ID that's shared with control - action0 *tailcfg.SSHAction // first matching action - srv *server - - mu sync.Mutex // protects the following - localUser *user.User // set by checkAuth - userGroupIDs []string // set by checkAuth - info *sshConnInfo // set by setInfo // idH is the RFC4253 sec8 hash H. It is used to identify the connection, // and is shared among all sessions. It should not be shared outside // process. It is confusingly referred to as SessionID by the gliderlabs/ssh // library. - idH string - pubKey gossh.PublicKey // set by authorizeSession - finalAction *tailcfg.SSHAction // set by authorizeSession - finalActionErr error // set by authorizeSession - sessions []*sshSession + idH string + connID string // ID that's shared with control + + noPubKeyPolicyAuthError error // set by BannerCallback + + action0 *tailcfg.SSHAction // set by doPolicyAuth; first matching action + currentAction *tailcfg.SSHAction // set by doPolicyAuth, updated by resolveNextAction + finalAction *tailcfg.SSHAction // set by doPolicyAuth or resolveNextAction + finalActionErr error // set by doPolicyAuth or resolveNextAction + + info *sshConnInfo // set by setInfo + localUser *user.User // set by doPolicyAuth + userGroupIDs []string // set by doPolicyAuth + pubKey gossh.PublicKey // set by doPolicyAuth + + // mu protects the following fields. + // + // srv.mu should be acquired prior to mu. + // It is safe to just acquire mu, but unsafe to + // acquire mu and then srv.mu. + mu sync.Mutex // protects the following + sessions []*sshSession } func (c *conn) logf(format string, args ...any) { @@ -181,49 +212,108 @@ func (c *conn) logf(format string, args ...any) { c.srv.logf(format, args...) } -// PublicKeyHandler implements ssh.PublicKeyHandler is called by the -// ssh.Server when the client presents a public key. -func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { - c.mu.Lock() - ci := c.info - c.mu.Unlock() - if ci == nil { - return gossh.ErrDenied +// isAuthorized returns nil if the connection is authorized to proceed. +func (c *conn) isAuthorized(ctx ssh.Context) error { + action := c.currentAction + for { + if action.Accept { + if c.pubKey != nil { + metricPublicKeyAccepts.Add(1) + } + return nil + } + if action.Reject || action.HoldAndDelegate == "" { + return gossh.ErrDenied + } + var err error + action, err = c.resolveNextAction(ctx) + if err != nil { + return err + } } - - if err := c.checkAuth(pubKey); err != nil { - // TODO(maisem/bradfitz): surface the error here. - c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) - return err - } - c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey))) - return nil } // errPubKeyRequired is returned by NoClientAuthCallback to make the client // resort to public-key auth; not user visible. var errPubKeyRequired = errors.New("ssh publickey required") +// BannerCallback implements ssh.BannerCallback. +// It is responsible for starting the policy evaluation, and returns +// the first message found in the action chain. It stops the evaluation +// on the first "accept" or "reject" action, and returns the message +// associated with that action (if any). +func (c *conn) BannerCallback(ctx ssh.Context) string { + if err := c.setInfo(ctx); err != nil { + c.logf("failed to get conninfo: %v", err) + return gossh.ErrDenied.Error() + } + if err := c.doPolicyAuth(ctx, nil /* no pub key */); err != nil { + // Stash the error for NoClientAuthCallback to return it. + c.noPubKeyPolicyAuthError = err + return "" + } + action := c.currentAction + for { + if action.Reject || action.Accept || action.Message != "" { + return action.Message + } + if action.HoldAndDelegate == "" { + // Do not send user-visible messages to the user. + // Let the SSH level authentication fail instead. + return "" + } + var err error + action, err = c.resolveNextAction(ctx) + if err != nil { + return "" + } + } +} + // NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by // the ssh.Server when the client first connects with the "none" // authentication method. -func (c *conn) NoClientAuthCallback(cm gossh.ConnMetadata) (*gossh.Permissions, error) { +// +// It is responsible for continuing policy evaluation from BannerCallback (or +// starting it afresh). It returns an error if the policy evaluation fails, or +// if the decision is "reject" +// +// It either returns nil (accept) or errPubKeyRequired or gossh.ErrDenied +// (reject). The errors may be wrapped. +func (c *conn) NoClientAuthCallback(ctx ssh.Context) error { if c.insecureSkipTailscaleAuth { - return nil, nil + return nil } - if err := c.setInfo(cm); err != nil { - c.logf("failed to get conninfo: %v", err) - return nil, gossh.ErrDenied + if c.noPubKeyPolicyAuthError != nil { + return c.noPubKeyPolicyAuthError + } else if c.currentAction == nil { + // This should never happen, but if it does, we want to know. + panic("no current action") } - return nil, c.checkAuth(nil /* no pub key */) + return c.isAuthorized(ctx) } -// checkAuth verifies that conn can proceed with the specified (optional) +// PublicKeyHandler implements ssh.PublicKeyHandler is called by the +// ssh.Server when the client presents a public key. +func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { + if err := c.doPolicyAuth(ctx, pubKey); err != nil { + // TODO(maisem/bradfitz): surface the error here. + c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) + return err + } + if err := c.isAuthorized(ctx); err != nil { + return err + } + c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey))) + return nil +} + +// doPolicyAuth verifies that conn can proceed with the specified (optional) // pubKey. It returns nil if the matching policy action is Accept or // HoldAndDelegate. If pubKey is nil, there was no policy match but there is a // policy that might match a public key it returns errPubKeyRequired. Otherwise, // it returns gossh.ErrDenied possibly wrapped in gossh.WithBannerError. -func (c *conn) checkAuth(pubKey ssh.PublicKey) error { +func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error { a, localUser, err := c.evaluatePolicy(pubKey) if err != nil { if pubKey == nil && c.havePubKeyPolicy() { @@ -232,7 +322,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error { return fmt.Errorf("%w: %v", gossh.ErrDenied, err) } c.action0 = a + c.currentAction = a + c.pubKey = pubKey if a.Accept || a.HoldAndDelegate != "" { + if a.Accept { + c.finalAction = a + } lu, err := user.Lookup(localUser) if err != nil { c.logf("failed to lookup %v: %v", localUser, err) @@ -245,13 +340,12 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error { if err != nil { return err } - c.mu.Lock() - defer c.mu.Unlock() c.userGroupIDs = gids c.localUser = lu return nil } if a.Reject { + c.finalAction = a err := gossh.ErrDenied if a.Message != "" { err = gossh.WithBannerError{ @@ -269,9 +363,8 @@ func (c *conn) checkAuth(pubKey ssh.PublicKey) error { func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ // OpenSSH presents this on failure as `Permission denied (tailscale).` - ImplicitAuthMethod: "tailscale", - NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: c.NoClientAuthCallback, + ImplicitAuthMethod: "tailscale", + NoClientAuth: true, // required for the NoClientAuthCallback to run } } @@ -289,23 +382,25 @@ func (srv *server) newConn() (*conn, error) { now := srv.now() c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5)) c.Server = &ssh.Server{ - Version: "Tailscale", - Handler: c.handleSessionPostSSHAuth, - RequestHandlers: map[string]ssh.RequestHandler{}, + Version: "Tailscale", + ServerConfigCallback: c.ServerConfig, + + BannerHandler: c.BannerCallback, + NoClientAuthHandler: c.NoClientAuthCallback, + PublicKeyHandler: c.PublicKeyHandler, + + Handler: c.handleSessionPostSSHAuth, + LocalPortForwardingCallback: c.mayForwardLocalPortTo, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": c.handleSessionPostSSHAuth, }, - // Note: the direct-tcpip channel handler and LocalPortForwardingCallback // only adds support for forwarding ports from the local machine. // TODO(maisem/bradfitz): add remote port forwarding support. ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": ssh.DirectTCPIPHandler, }, - LocalPortForwardingCallback: c.mayForwardLocalPortTo, - - PublicKeyHandler: c.PublicKeyHandler, - ServerConfigCallback: c.ServerConfig, + RequestHandlers: map[string]ssh.RequestHandler{}, } ss := c.Server for k, v := range ssh.DefaultRequestHandlers { @@ -341,10 +436,7 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de // havePubKeyPolicy reports whether any policy rule may provide access by means // of a ssh.PublicKey. func (c *conn) havePubKeyPolicy() bool { - c.mu.Lock() - ci := c.info - c.mu.Unlock() - if ci == nil { + if c.info == nil { panic("havePubKeyPolicy called before setInfo") } // Is there any rule that looks like it'd require a public key for this @@ -357,7 +449,7 @@ func (c *conn) havePubKeyPolicy() bool { if c.ruleExpired(r) { continue } - if mapLocalUser(r.SSHUsers, ci.sshUser) == "" { + if mapLocalUser(r.SSHUsers, c.info.sshUser) == "" { continue } for _, p := range r.Principals { @@ -416,11 +508,11 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo returns a populated sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. -func (c *conn) setInfo(cm gossh.ConnMetadata) error { +func (c *conn) setInfo(ctx ssh.Context) error { ci := &sshConnInfo{ - sshUser: cm.User(), - src: toIPPort(cm.RemoteAddr()), - dst: toIPPort(cm.LocalAddr()), + sshUser: ctx.User(), + src: toIPPort(ctx.RemoteAddr()), + dst: toIPPort(ctx.LocalAddr()), } if !tsaddr.IsTailscaleIP(ci.dst.Addr()) { return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst) @@ -432,11 +524,10 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error { if !ok { return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) } - c.mu.Lock() - defer c.mu.Unlock() ci.node = node ci.uprof = &uprof + c.idH = ctx.SessionID() c.info = ci c.logf("handling conn: %v", ci.String()) return nil @@ -554,50 +645,10 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { return lines, err } -func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() - idH := s.Context().(ssh.Context).SessionID() - if c.idH == "" { - c.idH = idH - } else if c.idH != idH { - c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH) - s.Exit(1) - return nil, false - } - cr := &contextReader{r: s} - action, err := c.resolveTerminalActionLocked(s, cr) - if err != nil { - c.logf("resolveTerminalAction: %v", err) - io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n") - s.Exit(1) - return nil, false - } - if action.Reject || !action.Accept { - c.logf("access denied for %v", c.info.uprof.LoginName) - s.Exit(1) - return nil, false - } - return cr, true -} - // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, // but not necessarily before all the Tailscale-level extra verification has // completed. It also handles SFTP requests. func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { - // Now that we have passed the SSH-level authentication, we can start the - // Tailscale-level extra verification. This means that we are going to - // evaluate the policy provided by control against the incoming SSH session. - cr, ok := c.authorizeSession(s) - if !ok { - return - } - if cr.HasOutstandingRead() { - // There was some buffered input while we were waiting for the policy - // decision. - s = contextReaderSession{s, cr} - } - // Do this check after auth, but before starting the session. switch s.Subsystem() { case "sftp", "": @@ -609,45 +660,35 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { } ss := c.newSSHSession(s) - c.mu.Lock() ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.Addr(), c.localUser.Username) ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Username) - c.mu.Unlock() ss.run() } -// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or -// else loops, fetching new SSHActions from the control plane. -// -// Any action with a Message in the chain will be printed to s. -// -// The returned SSHAction will be either Reject or Accept. -// -// c.mu must be held. -func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) { +// resolveNextAction starts at c.currentAction and makes it way through the +// action chain one step at a time. An action without a HoldAndDelegate is +// considered the final action. Once a final action is reached, this function +// will keep returning that action. It updates c.currentAction to the next +// action in the chain. When the final action is reached, it also sets +// c.finalAction to the final action. +func (c *conn) resolveNextAction(sctx ssh.Context) (action *tailcfg.SSHAction, err error) { if c.finalAction != nil || c.finalActionErr != nil { return c.finalAction, c.finalActionErr } - if s.PublicKey() != nil { - metricPublicKeyConnections.Add(1) - } defer func() { - c.finalAction = action - c.finalActionErr = err - c.pubKey = s.PublicKey() - if c.pubKey != nil && action.Accept { - metricPublicKeyAccepts.Add(1) + if action != nil { + c.currentAction = action + if action.Accept || action.Reject { + c.finalAction = action + } + } + if err != nil { + c.finalActionErr = err } }() - action = c.action0 - var awaitReadOnce sync.Once // to start Reads on cr - var sawInterrupt atomic.Bool - var wg sync.WaitGroup - defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit - - ctx, cancel := context.WithCancel(s.Context()) + ctx, cancel := context.WithCancel(sctx) defer cancel() // Loop processing/fetching Actions until one reaches a @@ -656,56 +697,28 @@ func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (ac // done (client disconnect) or its 30 minute timeout passes. // (Which is a long time for somebody to see login // instructions and go to a URL to do something.) - for { - if action.Message != "" { - io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1)) - } - if action.Accept || action.Reject { - if action.Reject { - metricTerminalReject.Add(1) - } else { - metricTerminalAccept.Add(1) - } - return action, nil - } - url := action.HoldAndDelegate - if url == "" { - metricTerminalMalformed.Add(1) - return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") - } - metricHolds.Add(1) - awaitReadOnce.Do(func() { - wg.Add(1) - go func() { - defer wg.Done() - buf := make([]byte, 1) - for { - n, err := cr.ReadContext(ctx, buf) - if err != nil { - return - } - if n > 0 && buf[0] == 0x03 { // Ctrl-C - sawInterrupt.Store(true) - s.Stderr().Write([]byte("Canceled.\r\n")) - s.Exit(1) - return - } - } - }() - }) - url = c.expandDelegateURLLocked(url) - var err error - action, err = c.fetchSSHAction(ctx, url) - if err != nil { - if sawInterrupt.Load() { - metricTerminalInterrupt.Add(1) - return nil, fmt.Errorf("aborted by user") - } else { - metricTerminalFetchError.Add(1) - } - return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) + action = c.currentAction + if action.Accept || action.Reject { + if action.Reject { + metricTerminalReject.Add(1) + } else { + metricTerminalAccept.Add(1) } + return action, nil } + url := action.HoldAndDelegate + if url == "" { + metricTerminalMalformed.Add(1) + return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") + } + metricHolds.Add(1) + url = c.expandDelegateURLLocked(url) + nextAction, err := c.fetchSSHAction(ctx, url) + if err != nil { + metricTerminalFetchError.Add(1) + return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) + } + return nextAction, nil } func (c *conn) expandDelegateURLLocked(actionURL string) string { @@ -732,12 +745,10 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string { } var localPart string var loginName string - c.mu.Lock() if c.info.uprof != nil { loginName = c.info.uprof.LoginName localPart, _, _ = strings.Cut(loginName, "@") } - c.mu.Unlock() return strings.NewReplacer( "$LOGINNAME_EMAIL", loginName, "$LOGINNAME_LOCALPART", localPart, @@ -793,8 +804,6 @@ func (c *conn) isStillValid() bool { if !a.Accept && a.HoldAndDelegate == "" { return false } - c.mu.Lock() - defer c.mu.Unlock() return c.localUser.Username == localUser } @@ -806,6 +815,8 @@ func (c *conn) checkStillValid() { } metricPolicyChangeKick.Add(1) c.logf("session no longer valid per new SSH policy; closing") + c.mu.Lock() + defer c.mu.Unlock() for _, s := range c.sessions { s.ctx.CloseWithError(userVisibleError{ fmt.Sprintf("Access revoked.\r\n"), @@ -876,21 +887,22 @@ func (ss *sshSession) killProcessOnContextDone() { }) } -// startSessionLocked registers ss as an active session. -// It must be called with srv.mu held. -func (c *conn) startSessionLocked(ss *sshSession) { +// attachSession registers ss as an active session. +func (c *conn) attachSession(ss *sshSession) { c.srv.sessionWaitGroup.Add(1) if ss.sharedID == "" { panic("empty sharedID") } + c.mu.Lock() + defer c.mu.Unlock() c.sessions = append(c.sessions, ss) } -// endSession unregisters s from the list of active sessions. -func (c *conn) endSession(ss *sshSession) { +// detachSession unregisters s from the list of active sessions. +func (c *conn) detachSession(ss *sshSession) { defer c.srv.sessionWaitGroup.Done() - c.srv.mu.Lock() - defer c.srv.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() for i, s := range c.sessions { if s == ss { c.sessions = append(c.sessions[:i], c.sessions[i+1:]...) @@ -960,22 +972,16 @@ func (ss *sshSession) run() { metricActiveSessions.Add(1) defer metricActiveSessions.Add(-1) defer ss.ctx.CloseWithError(errSessionDone) - srv := ss.conn.srv - srv.mu.Lock() - if srv.shutdownCalled { - srv.mu.Unlock() - // Do not start any new sessions. + if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached { fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") ss.Exit(1) return } - ss.conn.startSessionLocked(ss) - lu := ss.conn.localUser - localUser := lu.Username - srv.mu.Unlock() + defer ss.conn.detachSession(ss) - defer ss.conn.endSession(ss) + lu := ss.conn.localUser + logf := ss.logf if ss.conn.finalAction.SessionDuration != 0 { t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() { @@ -987,11 +993,9 @@ func (ss *sshSession) run() { defer t.Stop() } - logf := ss.logf - if euid := os.Geteuid(); euid != 0 { if lu.Uid != fmt.Sprint(euid) { - ss.logf("can't switch to user %q from process euid %v", localUser, euid) + ss.logf("can't switch to user %q from process euid %v", lu.Username, euid) fmt.Fprintf(ss, "can't switch user\r\n") ss.Exit(1) return @@ -1141,10 +1145,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg if c == nil { return nil, "", errInvalidConn } - c.mu.Lock() - ci := c.info - c.mu.Unlock() - if ci == nil { + if c.info == nil { c.logf("invalid connection state") return nil, "", errInvalidConn } @@ -1161,7 +1162,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg // For all but Reject rules, SSHUsers is required. // If SSHUsers is nil or empty, mapLocalUser will return an // empty string anyway. - localUser = mapLocalUser(r.SSHUsers, ci.sshUser) + localUser = mapLocalUser(r.SSHUsers, c.info.sshUser) if localUser == "" { return nil, "", errUserMatch } @@ -1210,9 +1211,7 @@ func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey) // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). // This function does not consider PubKeys. func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { - c.mu.Lock() ci := c.info - c.mu.Unlock() if p.Any { return true } diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index cf9a7c804..65db39667 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -38,9 +38,11 @@ type Server struct { HostSigners []Signer // private keys for the host key, must have at least one Version string // server version to be sent before the initial handshake - KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler + KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler + BannerHandler BannerHandler PasswordHandler PasswordHandler // password authentication handler PublicKeyHandler PublicKeyHandler // public key authentication handler + NoClientAuthHandler NoClientAuthHandler // no client authentication handler PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil @@ -160,6 +162,21 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { return ctx.Permissions().Permissions, nil } } + if srv.NoClientAuthHandler != nil { + config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { + applyConnMetadata(ctx, conn) + if err := srv.NoClientAuthHandler(ctx); err != nil { + return ctx.Permissions().Permissions, err + } + return ctx.Permissions().Permissions, nil + } + } + if srv.BannerHandler != nil { + config.BannerCallback = func(conn gossh.ConnMetadata) string { + applyConnMetadata(ctx, conn) + return srv.BannerHandler(ctx) + } + } return config } diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 0c7f45de8..644cb257d 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -38,6 +38,10 @@ // PublicKeyHandler is a callback for performing public key authentication. type PublicKeyHandler func(ctx Context, key PublicKey) error +type NoClientAuthHandler func(ctx Context) error + +type BannerHandler func(ctx Context) string + // PasswordHandler is a callback for performing password authentication. type PasswordHandler func(ctx Context, password string) bool