ssh/tailssh: break a method into half in prep for testing

And add a private context type in the process.

Updates #3802

Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2022-02-23 15:47:57 -08:00
committed by Brad Fitzpatrick
parent 1b5bb2e81d
commit e2ed06c53c
2 changed files with 108 additions and 9 deletions

View File

@@ -9,6 +9,7 @@
package tailssh
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
func (srv *server) handleSSH(s ssh.Session) {
lb := srv.lb
logf := srv.logf
sshUser := s.User()
addr := s.RemoteAddr()
logf("Handling SSH from %v for user %v", addr, sshUser)
@@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
return
}
ptyReq, winCh, isPty := s.Pty()
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
node, uprof, ok := lb.WhoIs(srcIPP)
if !ok {
@@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
s.Exit(1)
}
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
var ctx context.Context = context.Background()
if action.SesssionDuration != 0 {
sctx := newSSHContext()
ctx = sctx
t := time.AfterFunc(action.SesssionDuration, func() {
sctx.CloseWithError(userVisibleError{
fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
context.DeadlineExceeded,
})
})
defer t.Stop()
}
srv.handleAcceptedSSH(ctx, s, ci, lu)
}
// handleAcceptedSSH handles s once it's been accepted and determined
// that it should run as local system user lu.
//
// When ctx is done, the session is forcefully terminated. If its Err
// is an SSHTerminationError, its SSHTerminationMessage is sent to the
// user.
func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
logf := srv.logf
localUser := lu.Username
var err error
ptyReq, winCh, isPty := s.Pty()
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
var cmd *exec.Cmd
if euid := os.Geteuid(); euid != 0 {
if lu.Uid != fmt.Sprint(euid) {
@@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
go func() { io.Copy(s.Stderr(), stderr) }()
}
if action.SesssionDuration != 0 {
t := time.AfterFunc(action.SesssionDuration, func() {
logf("terminating SSH session from %v after max duration", srcIP)
cmd.Process.Kill()
})
defer t.Stop()
if ctx.Done() != nil {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
case <-ctx.Done():
err := ctx.Err()
if serr, ok := err.(SSHTerminationError); ok {
msg := serr.SSHTerminationMessage()
if msg != "" {
io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
}
}
logf("terminating SSH session from %v: %v", ci.srcIP, err)
cmd.Process.Kill()
}
}()
}
go func() {