ssh/tailssh: flesh out env, support non-pty commands

Updates #3802

Change-Id: I7022460117542a5424919144828bf571c7c19ec0
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-02-19 15:37:13 -08:00 committed by Brad Fitzpatrick
parent 7d897229d9
commit 3c2cd854be

View File

@ -150,7 +150,7 @@ func (srv *server) handleSSH(s ssh.Session) {
}
action, localUser, ok := evalSSHPolicy(pol, sctx)
if ok && action.Message != "" {
io.WriteString(s, action.Message)
io.WriteString(s.Stderr(), strings.Replace(action.Message, "\n", "\r\n", -1))
}
if !ok || action.Reject {
logf("ssh: access denied for %q from %v", uprof.LoginName, srcIP)
@ -160,62 +160,102 @@ func (srv *server) handleSSH(s ssh.Session) {
if !action.Accept || action.HoldAndDelegate != "" {
fmt.Fprintf(s, "TODO: other SSHAction outcomes")
s.Exit(1)
}
if !isPty {
fmt.Fprintf(s, "TODO scp etc\n")
lu, err := user.Lookup(localUser)
if err != nil {
logf("ssh: user Lookup %q: %v", localUser, err)
s.Exit(1)
return
}
logf("ssh: connection from %v %v to %v@ => %q. command = %q, env = %q", srcIP, uprof.LoginName, sshUser, localUser, s.Command(), s.Environ())
var cmd *exec.Cmd
if os.Getuid() != 0 {
u, err := user.Current()
if err != nil {
logf("failed to get current user: %v", err)
s.Exit(1)
return
}
if u.Username != localUser {
if euid := os.Geteuid(); euid != 0 {
if lu.Uid != fmt.Sprint(euid) {
logf("ssh: can't switch to user %q from process euid %v", localUser, euid)
fmt.Fprintf(s, "can't switch user\n")
s.Exit(1)
return
}
cmd = exec.Command(loginShell(u.Uid))
cmd = exec.Command(loginShell(lu.Uid))
} else {
cmd = exec.Command("/usr/bin/env", "su", "-", localUser)
if rawCmd := s.RawCommand(); rawCmd != "" {
cmd = exec.Command("/usr/bin/env", "su", "-c", rawCmd, localUser)
cmd.Dir = lu.HomeDir
cmd.Env = append(cmd.Env, envForUser(lu)...)
// TODO: and Env for PATH, SSH_CONNECTION, SSH_CLIENT, XDG_SESSION_TYPE, XDG_*, etc
} else {
cmd = exec.Command("/usr/bin/env", "su", "-", localUser)
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
f, err := pty.Start(cmd)
if err != nil {
logf("running shell: %v", err)
s.Exit(1)
return
if ptyReq.Term != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
}
// TODO(bradfitz,maisem): also blend in user's s.Environ()
logf("Running: %q", cmd.Args)
var toCmd io.WriteCloser
var fromCmd io.ReadCloser
if isPty {
f, err := pty.StartWithSize(cmd, &pty.Winsize{
Rows: uint16(ptyReq.Window.Width),
Cols: uint16(ptyReq.Window.Height),
})
if err != nil {
logf("running shell: %v", err)
s.Exit(1)
return
}
defer f.Close()
toCmd = f
fromCmd = f
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
}()
} else {
stdin, stdout, stderr, err := startWithStdPipes(cmd)
if err != nil {
logf("ssh: start error: %f", err)
s.Exit(1)
return
}
fromCmd, toCmd = stdout, stdin
go func() { io.Copy(s.Stderr(), stderr) }()
}
if action.SesssionDuration != 0 {
t := time.AfterFunc(action.SesssionDuration, func() {
logf("terminating SSH session from %v after max duration", srcIP)
cmd.Process.Kill()
f.Close()
})
defer t.Stop()
}
defer f.Close()
go func() {
for win := range winCh {
setWinsize(f, win.Width, win.Height)
}
_, err := io.Copy(toCmd, s) // stdin
logf("ssh: stdin copy: %v", err)
toCmd.Close()
}()
go func() {
io.Copy(f, s) // stdin
_, err := io.Copy(s, fromCmd) // stdout
logf("ssh: stdout copy: %v", err)
}()
io.Copy(s, f) // stdout
cmd.Process.Kill()
if err := cmd.Wait(); err != nil {
s.Exit(1)
err = cmd.Wait()
if err == nil {
logf("ssh: Wait: ok")
s.Exit(0)
return
}
s.Exit(0)
if ee, ok := err.(*exec.ExitError); ok {
code := ee.ProcessState.ExitCode()
logf("ssh: Wait: code=%v", code)
s.Exit(code)
return
}
logf("ssh: Wait: %v", err)
s.Exit(1)
return
}
@ -327,3 +367,37 @@ func loginShell(uid string) string {
}
return "/bin/bash"
}
func startWithStdPipes(cmd *exec.Cmd) (stdin io.WriteCloser, stdout, stderr io.ReadCloser, err error) {
defer func() {
if err != nil {
for _, c := range []io.Closer{stdin, stdout, stderr} {
if c != nil {
c.Close()
}
}
}
}()
stdin, err = cmd.StdinPipe()
if err != nil {
return
}
stdout, err = cmd.StdoutPipe()
if err != nil {
return
}
stderr, err = cmd.StderrPipe()
if err != nil {
return
}
err = cmd.Start()
return
}
func envForUser(u *user.User) []string {
return []string{
fmt.Sprintf("SHELL=" + loginShell(u.Uid)),
fmt.Sprintf("USER=" + u.Username),
fmt.Sprintf("HOME=" + u.HomeDir),
}
}