From 6e86bbcb06c427df71cd52d0ec5eccd24a9605e7 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sun, 13 Mar 2022 13:01:59 -0700 Subject: [PATCH] ssh/tailssh: add a new sshSession type to clean up existing+future code Updates #3802 Change-Id: I7054dca387f5e5aee1185937ecf41b77a5a07f1a Signed-off-by: Brad Fitzpatrick Co-authored-by: Maisem Ali --- ssh/tailssh/incubator.go | 88 +++++++++++------ ssh/tailssh/tailssh.go | 183 +++++++++++++++++++++++++----------- ssh/tailssh/tailssh_test.go | 6 +- 3 files changed, 188 insertions(+), 89 deletions(-) diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 93864072b..3c24ba38b 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -14,6 +14,7 @@ import ( "context" + "errors" "flag" "fmt" "io" @@ -54,11 +55,15 @@ func init() { // newIncubatorCommand returns a new exec.Cmd configured with // `tailscaled be-child ssh` as the entrypoint. -// If tailscaled is empty, the desired cmd is executed directly. -func newIncubatorCommand(ctx context.Context, ci *sshConnInfo, lu *user.User, tailscaled, name string, args []string) *exec.Cmd { - if tailscaled == "" { +// +// If ss.srv.tailscaledPath is empty, this method is equivalent to +// exec.CommandContext. +func (ss *sshSession) newIncubatorCommand(ctx context.Context, name string, args []string) *exec.Cmd { + if ss.srv.tailscaledPath == "" { return exec.CommandContext(ctx, name, args...) } + lu := ss.localUser + ci := ss.connInfo remoteUser := ci.uprof.LoginName if len(ci.node.Tags) > 0 { remoteUser = strings.Join(ci.node.Tags, ",") @@ -80,7 +85,7 @@ func newIncubatorCommand(ctx context.Context, ci *sshConnInfo, lu *user.User, ta incubatorArgs = append(incubatorArgs, args...) } - return exec.CommandContext(ctx, tailscaled, incubatorArgs...) + return exec.CommandContext(ctx, ss.srv.tailscaledPath, incubatorArgs...) } const debugIncubator = false @@ -158,37 +163,44 @@ func beIncubator(args []string) error { // launchProcess launches an incubator process for the provided session. // It is responsible for configuring the process execution environment. // The caller can wait for the process to exit by calling cmd.Wait(). -func (srv *server) launchProcess(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) (cmd *exec.Cmd, stdin io.WriteCloser, stdout, stderr io.Reader, err error) { - shell := loginShell(lu.Uid) +// +// It sets ss.cmd, stdin, stdout, and stderr. +func (ss *sshSession) launchProcess(ctx context.Context) error { + shell := loginShell(ss.localUser.Uid) var args []string - if rawCmd := s.RawCommand(); rawCmd != "" { + if rawCmd := ss.RawCommand(); rawCmd != "" { args = append(args, "-c", rawCmd) } else { args = append(args, "-l") // login shell } - ptyReq, winCh, isPty := s.Pty() - cmd = newIncubatorCommand(ctx, ci, lu, srv.tailscaledPath, shell, args) - cmd.Dir = lu.HomeDir - cmd.Env = append(cmd.Env, envForUser(lu)...) - cmd.Env = append(cmd.Env, s.Environ()...) + ci := ss.connInfo + cmd := ss.newIncubatorCommand(ctx, shell, args) + cmd.Dir = ss.localUser.HomeDir + cmd.Env = append(cmd.Env, envForUser(ss.localUser)...) + cmd.Env = append(cmd.Env, ss.Environ()...) cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_CLIENT=%s %d %d", ci.src.IP(), ci.src.Port(), ci.dst.Port()), fmt.Sprintf("SSH_CONNECTION=%s %d %s %d", ci.src.IP(), ci.src.Port(), ci.dst.IP(), ci.dst.Port()), ) - srv.logf("ssh: starting: %+v", cmd.Args) + ss.cmd = cmd + + ptyReq, winCh, isPty := ss.Pty() if !isPty { - stdin, stdout, stderr, err = startWithStdPipes(cmd) - return + ss.logf("starting non-pty command: %+v", cmd.Args) + return ss.startWithStdPipes() } - pty, err := srv.startWithPTY(cmd, ptyReq) + ss.ptyReq = &ptyReq + ss.logf("starting pty command: %+v", cmd.Args) + pty, err := ss.startWithPTY() if err != nil { - return nil, nil, nil, nil, err + return err } go resizeWindow(pty, winCh) - // When using a pty we don't get a separate reader for stderr. - return cmd, pty, pty, nil, nil + ss.stdout = pty // no stderr for a pty + ss.stdin = pty + return nil } func resizeWindow(f *os.File, winCh <-chan ssh.Window) { @@ -263,7 +275,16 @@ func resizeWindow(f *os.File, winCh <-chan ssh.Window) { } // startWithPTY starts cmd with a psuedo-terminal attached to Stdin, Stdout and Stderr. -func (srv *server) startWithPTY(cmd *exec.Cmd, ptyReq ssh.Pty) (ptyFile *os.File, err error) { +func (ss *sshSession) startWithPTY() (ptyFile *os.File, err error) { + ptyReq := ss.ptyReq + cmd := ss.cmd + if cmd == nil { + return nil, errors.New("nil ss.cmd") + } + if ptyReq == nil { + return nil, errors.New("nil ss.ptyReq") + } + var tty *os.File ptyFile, tty, err = pty.Open() if err != nil { @@ -305,7 +326,7 @@ func (srv *server) startWithPTY(cmd *exec.Cmd, ptyReq ssh.Pty) (ptyFile *os.File } k, ok := opcodeShortName[c] if !ok { - srv.logf("unknown opcode: %d", c) + ss.logf("unknown opcode: %d", c) continue } if _, ok := tios.CC[k]; ok { @@ -316,7 +337,7 @@ func (srv *server) startWithPTY(cmd *exec.Cmd, ptyReq ssh.Pty) (ptyFile *os.File tios.Opts[k] = v > 0 continue } - srv.logf("unsupported opcode: %v(%d)=%v", k, c, v) + ss.logf("unsupported opcode: %v(%d)=%v", k, c, v) } // Save PTY settings. @@ -355,7 +376,9 @@ func (srv *server) startWithPTY(cmd *exec.Cmd, ptyReq ssh.Pty) (ptyFile *os.File } // startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. -func startWithStdPipes(cmd *exec.Cmd) (stdin io.WriteCloser, stdout, stderr io.ReadCloser, err error) { +func (ss *sshSession) startWithStdPipes() (err error) { + var stdin io.WriteCloser + var stdout, stderr io.ReadCloser defer func() { if err != nil { for _, c := range []io.Closer{stdin, stdout, stderr} { @@ -365,20 +388,29 @@ func startWithStdPipes(cmd *exec.Cmd) (stdin io.WriteCloser, stdout, stderr io.R } } }() + cmd := ss.cmd + if cmd == nil { + return errors.New("nil cmd") + } stdin, err = cmd.StdinPipe() if err != nil { - return + return err } stdout, err = cmd.StdoutPipe() if err != nil { - return + return err } stderr, err = cmd.StderrPipe() if err != nil { - return + return err } - err = cmd.Start() - return + if err := cmd.Start(); err != nil { + return err + } + ss.stdin = stdin + ss.stdout = stdout + ss.stderr = stderr + return nil } func loginShell(uid string) string { diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index fa7b08502..f8b769f12 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -10,6 +10,7 @@ import ( "context" + "crypto/rand" "encoding/json" "errors" "fmt" @@ -95,8 +96,9 @@ type server struct { tailscaledPath string // mu protects activeSessions. - mu sync.Mutex - activeSessions map[string]bool + mu sync.Mutex + activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => that session + activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session } var debugPolicyFile = envknob.String("TS_DEBUG_SSH_POLICY_FILE") @@ -244,19 +246,48 @@ func (srv *server) handleSSH(s ssh.Session) { return } - var ctx context.Context = context.Background() - if action.SesssionDuration != 0 { - sctx := newSSHContext() - ctx = sctx - t := time.AfterFunc(action.SesssionDuration, func() { - sctx.CloseWithError(userVisibleError{ - fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration), - context.DeadlineExceeded, - }) - }) - defer t.Stop() + ss := srv.newSSHSession(s, ci, lu, action) + ss.run() +} + +// sshSession is an accepted Tailscale SSH session. +type sshSession struct { + ssh.Session + idH string // the RFC4253 sec8 hash H; don't share outside process + sharedID string // ID that's shared with control + logf logger.Logf + + ctx *sshContext // implements context.Context + srv *server + connInfo *sshConnInfo + action *tailcfg.SSHAction + localUser *user.User + + // initialized by launchProcess: + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader // nil for pty sessions + ptyReq *ssh.Pty // non-nil for pty sessions + + // We use this sync.Once to ensure that we only terminate the process once, + // either it exits itself or is terminated + exitOnce sync.Once +} + +func (srv *server) newSSHSession(s ssh.Session, ci *sshConnInfo, lu *user.User, action *tailcfg.SSHAction) *sshSession { + sharedID := fmt.Sprintf("%s-%02x", ci.now.UTC().Format("20060102T150405"), randBytes(5)) + return &sshSession{ + Session: s, + idH: s.Context().(ssh.Context).SessionID(), + sharedID: sharedID, + ctx: newSSHContext(), + srv: srv, + action: action, + localUser: lu, + connInfo: ci, + logf: logger.WithPrefix(srv.logf, "ssh-session("+sharedID+"): "), } - srv.handleAcceptedSSH(ctx, s, ci, lu) } func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) { @@ -290,20 +321,22 @@ func (srv *server) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSH } } -func (srv *server) handleSessionTermination(ctx context.Context, s ssh.Session, ci *sshConnInfo, cmd *exec.Cmd, exitOnce *sync.Once) { - <-ctx.Done() +// killProcessOnContextDone waits for ss.ctx to be done and kills the process, +// unless the process has already exited. +func (ss *sshSession) killProcessOnContextDone() { + <-ss.ctx.Done() // Either the process has already existed, in which case this does nothing. // Or, the process is still running in which case this will kill it. - exitOnce.Do(func() { - err := ctx.Err() + ss.exitOnce.Do(func() { + err := ss.ctx.Err() if serr, ok := err.(SSHTerminationError); ok { msg := serr.SSHTerminationMessage() if msg != "" { - io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") + io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") } } - srv.logf("terminating SSH session from %v: %v", ci.src.IP(), err) - cmd.Process.Kill() + ss.logf("terminating SSH session from %v: %v", ss.connInfo.src.IP(), err) + ss.cmd.Process.Kill() }) } @@ -312,110 +345,138 @@ func (srv *server) handleSessionTermination(ctx context.Context, s ssh.Session, func (srv *server) isActiveSession(sctx ssh.Context) bool { srv.mu.Lock() defer srv.mu.Unlock() - return srv.activeSessions[sctx.SessionID()] + _, ok := srv.activeSessionByH[sctx.SessionID()] + return ok } -// startSession registers s as an active session. -func (srv *server) startSession(s ssh.Session) { +// startSession registers ss as an active session. +func (srv *server) startSession(ss *sshSession) { srv.mu.Lock() defer srv.mu.Unlock() - if srv.activeSessions == nil { - srv.activeSessions = make(map[string]bool) + if srv.activeSessionByH == nil { + srv.activeSessionByH = make(map[string]*sshSession) } - srv.activeSessions[s.Context().(ssh.Context).SessionID()] = true + if srv.activeSessionBySharedID == nil { + srv.activeSessionBySharedID = make(map[string]*sshSession) + } + if ss.idH == "" { + panic("empty idH") + } + if _, dup := srv.activeSessionByH[ss.idH]; dup { + panic("dup idH") + } + if ss.sharedID == "" { + panic("empty sharedID") + } + if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup { + panic("dup sharedID") + } + srv.activeSessionByH[ss.idH] = ss + srv.activeSessionBySharedID[ss.sharedID] = ss } // endSession unregisters s from the list of active sessions. -func (srv *server) endSession(s ssh.Session) { +func (srv *server) endSession(ss *sshSession) { srv.mu.Lock() defer srv.mu.Unlock() - delete(srv.activeSessions, s.Context().(ssh.Context).SessionID()) + delete(srv.activeSessionByH, ss.idH) + delete(srv.activeSessionBySharedID, ss.sharedID) } -// handleAcceptedSSH handles s once it's been accepted and determined -// that it should run as local system user lu. +var errSessionDone = errors.New("session is done") + +// run is the entrypoint for a newly accepted SSH session. // // When ctx is done, the session is forcefully terminated. If its Err // is an SSHTerminationError, its SSHTerminationMessage is sent to the // user. -func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) { - srv.startSession(s) - defer srv.endSession(s) +func (ss *sshSession) run() { + srv := ss.srv + srv.startSession(ss) + defer srv.endSession(ss) + + defer ss.ctx.CloseWithError(errSessionDone) + + if ss.action.SesssionDuration != 0 { + t := time.AfterFunc(ss.action.SesssionDuration, func() { + ss.ctx.CloseWithError(userVisibleError{ + fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SesssionDuration), + context.DeadlineExceeded, + }) + }) + defer t.Stop() + } + logf := srv.logf + lu := ss.localUser localUser := lu.Username if euid := os.Geteuid(); euid != 0 { if lu.Uid != fmt.Sprint(euid) { logf("ssh: can't switch to user %q from process euid %v", localUser, euid) - fmt.Fprintf(s, "can't switch user\n") - s.Exit(1) + fmt.Fprintf(ss, "can't switch user\n") + ss.Exit(1) return } } // Take control of the PTY so that we can configure it below. // See https://github.com/tailscale/tailscale/issues/4146 - s.DisablePTYEmulation() + ss.DisablePTYEmulation() - cmd, stdin, stdout, stderr, err := srv.launchProcess(ctx, s, ci, lu) + err := ss.launchProcess(ss.ctx) if err != nil { logf("start failed: %v", err.Error()) - s.Exit(1) + ss.Exit(1) return } - // We use this sync.Once to ensure that we only terminate the process once, - // either it exits itself or is terminated - var exitOnce sync.Once - if ctx.Done() != nil { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - go srv.handleSessionTermination(ctx, s, ci, cmd, &exitOnce) - } + go ss.killProcessOnContextDone() + go func() { - _, err := io.Copy(stdin, s) + _, err := io.Copy(ss.stdin, ss) if err != nil { // TODO: don't log in the success case. logf("ssh: stdin copy: %v", err) } - stdin.Close() + ss.stdin.Close() }() go func() { - _, err := io.Copy(s, stdout) + _, err := io.Copy(ss, ss.stdout) if err != nil { // TODO: don't log in the success case. logf("ssh: stdout copy: %v", err) } }() // stderr is nil for ptys. - if stderr != nil { + if ss.stderr != nil { go func() { - _, err := io.Copy(s.Stderr(), stderr) + _, err := io.Copy(ss.Stderr(), ss.stderr) if err != nil { // TODO: don't log in the success case. logf("ssh: stderr copy: %v", err) } }() } - err = cmd.Wait() + err = ss.cmd.Wait() // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. - exitOnce.Do(func() {}) + ss.exitOnce.Do(func() {}) if err == nil { logf("ssh: Wait: ok") - s.Exit(0) + ss.Exit(0) return } if ee, ok := err.(*exec.ExitError); ok { code := ee.ProcessState.ExitCode() logf("ssh: Wait: code=%v", code) - s.Exit(code) + ss.Exit(code) return } logf("ssh: Wait: %v", err) - s.Exit(1) + ss.Exit(1) return } @@ -509,3 +570,11 @@ func matchesPrincipal(ps []*tailcfg.SSHPrincipal, ci *sshConnInfo) bool { } return false } + +func randBytes(n int) []byte { + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return b +} diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index d3ee8a116..9d580b04a 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -9,7 +9,6 @@ import ( "bytes" - "context" "errors" "fmt" "net" @@ -211,10 +210,9 @@ func TestSSH(t *testing.T) { uprof: &tailcfg.UserProfile{}, } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() ss.Handler = func(s ssh.Session) { - srv.handleAcceptedSSH(ctx, s, ci, u) + ss := srv.newSSHSession(s, ci, u, &tailcfg.SSHAction{Accept: true}) + ss.run() } ln, err := net.Listen("tcp4", "127.0.0.1:0")