mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-20 09:57:31 +00:00
ssh/tailssh: break a method into half in prep for testing
And add a private context type in the process. Updates #3802 Change-Id: I257187f4cfb0f2248d95b81c1dfe0911ef203b60 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:

committed by
Brad Fitzpatrick

parent
1b5bb2e81d
commit
e2ed06c53c
@@ -9,6 +9,7 @@
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -102,7 +103,6 @@ func (srv *server) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) {
|
||||
func (srv *server) handleSSH(s ssh.Session) {
|
||||
lb := srv.lb
|
||||
logf := srv.logf
|
||||
|
||||
sshUser := s.User()
|
||||
addr := s.RemoteAddr()
|
||||
logf("Handling SSH from %v for user %v", addr, sshUser)
|
||||
@@ -131,7 +131,6 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
srcIPP := netaddr.IPPortFrom(tanetaddr, uint16(ta.Port))
|
||||
node, uprof, ok := lb.WhoIs(srcIPP)
|
||||
if !ok {
|
||||
@@ -167,7 +166,34 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
s.Exit(1)
|
||||
}
|
||||
|
||||
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
|
||||
var ctx context.Context = context.Background()
|
||||
if action.SesssionDuration != 0 {
|
||||
sctx := newSSHContext()
|
||||
ctx = sctx
|
||||
t := time.AfterFunc(action.SesssionDuration, func() {
|
||||
sctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", action.SesssionDuration),
|
||||
context.DeadlineExceeded,
|
||||
})
|
||||
})
|
||||
defer t.Stop()
|
||||
}
|
||||
srv.handleAcceptedSSH(ctx, s, ci, lu)
|
||||
}
|
||||
|
||||
// handleAcceptedSSH handles s once it's been accepted and determined
|
||||
// that it should run as local system user lu.
|
||||
//
|
||||
// When ctx is done, the session is forcefully terminated. If its Err
|
||||
// is an SSHTerminationError, its SSHTerminationMessage is sent to the
|
||||
// user.
|
||||
func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *sshConnInfo, lu *user.User) {
|
||||
logf := srv.logf
|
||||
localUser := lu.Username
|
||||
|
||||
var err error
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", ci.srcIP, ci.uprof.LoginName, ci.sshUser, localUser, s.Command(), s.Environ())
|
||||
var cmd *exec.Cmd
|
||||
if euid := os.Geteuid(); euid != 0 {
|
||||
if lu.Uid != fmt.Sprint(euid) {
|
||||
@@ -223,12 +249,24 @@ func (srv *server) handleSSH(s ssh.Session) {
|
||||
go func() { io.Copy(s.Stderr(), stderr) }()
|
||||
}
|
||||
|
||||
if action.SesssionDuration != 0 {
|
||||
t := time.AfterFunc(action.SesssionDuration, func() {
|
||||
logf("terminating SSH session from %v after max duration", srcIP)
|
||||
cmd.Process.Kill()
|
||||
})
|
||||
defer t.Stop()
|
||||
if ctx.Done() != nil {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
err := ctx.Err()
|
||||
if serr, ok := err.(SSHTerminationError); ok {
|
||||
msg := serr.SSHTerminationMessage()
|
||||
if msg != "" {
|
||||
io.WriteString(s.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
|
||||
}
|
||||
}
|
||||
logf("terminating SSH session from %v: %v", ci.srcIP, err)
|
||||
cmd.Process.Kill()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
Reference in New Issue
Block a user