ssh/tailssh: break a method into half in prep for testing

And add a private context type in the process.

Updates #3802

Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-02-23 15:47:57 -08:00 committed by Brad Fitzpatrick
parent 1b5bb2e81d
commit e2ed06c53c
2 changed files with 108 additions and 9 deletions

61
ssh/tailssh/context.go Normal file
View File

@ -0,0 +1,61 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tailssh
import (
"sync"
"time"
)
// sshContext is the context.Context implementation we use for SSH
// that adds a CloseWithError method. Otherwise it's just a normalish
// Context.
type sshContext struct {
mu sync.Mutex
closed bool
done chan struct{}
err error
}
func newSSHContext() *sshContext {
return &sshContext{done: make(chan struct{})}
}
func (ctx *sshContext) CloseWithError(err error) {
ctx.mu.Lock()
defer ctx.mu.Unlock()
if ctx.closed {
return
}
ctx.closed = true
ctx.err = err
close(ctx.done)
}
func (ctx *sshContext) Err() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
return ctx.err
}
func (ctx *sshContext) Done() <-chan struct{} { return ctx.done }
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
func (ctx *sshContext) Value(interface{}) interface{} { return nil }
// userVisibleError is a wrapper around an error that implements
// SSHTerminationError, so msg is written to their session.
type userVisibleError struct {
msg string
error
}
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
// SSHTerminationError is implemented by errors that terminate an SSH
// session and should be written to user's sessions.
type SSHTerminationError interface {
error
SSHTerminationMessage() string
}

View File

@ -9,6 +9,7 @@
package tailssh package tailssh
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
func (srv *server) handleSSH(s ssh.Session) { func (srv *server) handleSSH(s ssh.Session) {
lb := srv.lb lb := srv.lb
logf := srv.logf logf := srv.logf
sshUser := s.User() sshUser := s.User()
addr := s.RemoteAddr() addr := s.RemoteAddr()
logf("Handling SSH from %v for user %v", addr, sshUser) logf("Handling SSH from %v for user %v", addr, sshUser)
@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
return return
} }
ptyReq, winCh, isPty := s.Pty()
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port)) srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
node, uprof, ok := lb.WhoIs(srcIPP) node, uprof, ok := lb.WhoIs(srcIPP)
if !ok { if !ok {
@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
s.Exit(1) s.Exit(1)
} }
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ()) var ctx context.Context = context.Background()
if action.SesssionDuration != 0 {
sctx := newSSHContext()
ctx = sctx
t := time.AfterFunc(action.SesssionDuration, func() {
sctx.CloseWithError(userVisibleError{
fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
context.DeadlineExceeded,
})
})
defer t.Stop()
}
srv.handleAcceptedSSH(ctx, s, ci, lu)
}
// handleAcceptedSSH handles s once it's been accepted and determined
// that it should run as local system user lu.
//
// When ctx is done, the session is forcefully terminated. If its Err
// is an SSHTerminationError, its SSHTerminationMessage is sent to the
// user.
func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
logf := srv.logf
localUser := lu.Username
var err error
ptyReq, winCh, isPty := s.Pty()
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
var cmd *exec.Cmd var cmd *exec.Cmd
if euid := os.Geteuid(); euid != 0 { if euid := os.Geteuid(); euid != 0 {
if lu.Uid != fmt.Sprint(euid) { if lu.Uid != fmt.Sprint(euid) {
@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
go func() { io.Copy(s.Stderr(), stderr) }() go func() { io.Copy(s.Stderr(), stderr) }()
} }
if action.SesssionDuration != 0 { if ctx.Done() != nil {
t := time.AfterFunc(action.SesssionDuration, func() { done := make(chan struct{})
logf("terminating SSH session from %v after max duration", srcIP) defer close(done)
cmd.Process.Kill() go func() {
}) select {
defer t.Stop() case <-done:
case <-ctx.Done():
err := ctx.Err()
if serr, ok := err.(SSHTerminationError); ok {
msg := serr.SSHTerminationMessage()
if msg != "" {
io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
}
}
logf("terminating SSH session from %v: %v", ci.srcIP, err)
cmd.Process.Kill()
}
}()
} }
go func() { go func() {