From e2ed06c53c7ccd789cb0e51575f364589bbdb396 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 23 Feb 2022 15:47:57 -0800 Subject: [PATCH] ssh/tailssh: break a method into half in prep for testing And add a private context type in the process. Updates #3802 Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60 Signed-off-by: Brad Fitzpatrick --- ssh/tailssh/context.go | 61 ++++++++++++++++++++++++++++++++++++++++++ ssh/tailssh/tailssh.go | 56 +++++++++++++++++++++++++++++++------- 2 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 ssh/tailssh/context.go diff --git a/ssh/tailssh/context.go b/ssh/tailssh/context.go new file mode 100644 index 000000000..066eacf4a --- /dev/null +++ b/ssh/tailssh/context.go @@ -0,0 +1,61 @@ +// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tailssh + +import ( + "sync" + "time" +) + +// sshContext is the context.Context implementation we use for SSH +// that adds a CloseWithError method. Otherwise it's just a normalish +// Context. +type sshContext struct { + mu sync.Mutex + closed bool + done chan struct{} + err error +} + +func newSSHContext() *sshContext { + return &sshContext{done: make(chan struct{})} +} + +func (ctx *sshContext) CloseWithError(err error) { + ctx.mu.Lock() + defer ctx.mu.Unlock() + if ctx.closed { + return + } + ctx.closed = true + ctx.err = err + close(ctx.done) +} + +func (ctx *sshContext) Err() error { + ctx.mu.Lock() + defer ctx.mu.Unlock() + return ctx.err +} + +func (ctx *sshContext) Done() <-chan struct{} { return ctx.done } +func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return } +func (ctx *sshContext) Value(interface{}) interface{} { return nil } + +// userVisibleError is a wrapper around an error that implements +// SSHTerminationError, so msg is written to their session. +type userVisibleError struct { + msg string + error +} + +func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg } + +// SSHTerminationError is implemented by errors that terminate an SSH +// session and should be written to user's sessions. +type SSHTerminationError interface { + error + SSHTerminationMessage() string +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 943e5297f..6f960cc71 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -9,6 +9,7 @@ package tailssh import ( + "context" "encoding/json" "errors" "fmt" @@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) { func (srv *server) handleSSH(s ssh.Session) { lb := srv.lb logf := srv.logf - sshUser := s.User() addr := s.RemoteAddr() logf("Handling SSH from %v for user %v", addr, sshUser) @@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) { return } - ptyReq, winCh, isPty := s.Pty() srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)) node, uprof, ok := lb.WhoIs(srcIPP) if !ok { @@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) { s.Exit(1) } - logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ()) + var ctx context.Context = context.Background() + if action.SesssionDuration != 0 { + sctx := newSSHContext() + ctx = sctx + t := time.AfterFunc(action.SesssionDuration, func() { + sctx.CloseWithError(userVisibleError{ + fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration), + context.DeadlineExceeded, + }) + }) + defer t.Stop() + } + srv.handleAcceptedSSH(ctx, s, ci, lu) +} + +// handleAcceptedSSH handles s once it's been accepted and determined +// that it should run as local system user lu. +// +// When ctx is done, the session is forcefully terminated. If its Err +// is an SSHTerminationError, its SSHTerminationMessage is sent to the +// user. +func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) { + logf := srv.logf + localUser := lu.Username + + var err error + ptyReq, winCh, isPty := s.Pty() + logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ()) var cmd *exec.Cmd if euid := os.Geteuid(); euid != 0 { if lu.Uid != fmt.Sprint(euid) { @@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) { go func() { io.Copy(s.Stderr(), stderr) }() } - if action.SesssionDuration != 0 { - t := time.AfterFunc(action.SesssionDuration, func() { - logf("terminating SSH session from %v after max duration", srcIP) - cmd.Process.Kill() - }) - defer t.Stop() + if ctx.Done() != nil { + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-done: + case <-ctx.Done(): + err := ctx.Err() + if serr, ok := err.(SSHTerminationError); ok { + msg := serr.SSHTerminationMessage() + if msg != "" { + io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") + } + } + logf("terminating SSH session from %v: %v", ci.srcIP, err) + cmd.Process.Kill() + } + }() } go func() {