ssh/tailssh: better handling of signals and exits

We were not handling errors occurred while copying data between the subprocess and the connection.
This makes it so that we pass the appropriate signals when to the process and the connection.

This also fixes mosh.

Updates #4919

Co-authored-by: James Tucker <raggi@tailscale.com>
Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-07-15 15:40:20 +00:00 committed by Maisem Ali
parent 004f0ca3e0
commit af412e8874
2 changed files with 29 additions and 12 deletions

View File

@ -310,15 +310,25 @@ func (ss *sshSession) launchProcess() error {
if err != nil { if err != nil {
return err return err
} }
go resizeWindow(pty, winCh)
ss.stdout = pty // no stderr for a pty // We need to be able to close stdin and stdout separately later so make a
// dup.
ptyDup, err := syscall.Dup(int(pty.Fd()))
if err != nil {
return err
}
go resizeWindow(ptyDup /* arbitrary fd */, winCh)
ss.stdin = pty ss.stdin = pty
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.stderr = nil // not available for pty
return nil return nil
} }
func resizeWindow(f *os.File, winCh <-chan ssh.Window) { func resizeWindow(fd int, winCh <-chan ssh.Window) {
for win := range winCh { for win := range winCh {
unix.IoctlSetWinsize(int(f.Fd()), syscall.TIOCSWINSZ, &unix.Winsize{ unix.IoctlSetWinsize(fd, syscall.TIOCSWINSZ, &unix.Winsize{
Row: uint16(win.Height), Row: uint16(win.Height),
Col: uint16(win.Width), Col: uint16(win.Width),
}) })

View File

@ -732,7 +732,7 @@ type sshSession struct {
// initialized by launchProcess: // initialized by launchProcess:
cmd *exec.Cmd cmd *exec.Cmd
stdin io.WriteCloser stdin io.WriteCloser
stdout io.Reader stdout io.ReadCloser
stderr io.Reader // nil for pty sessions stderr io.Reader // nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions ptyReq *ssh.Pty // non-nil for pty sessions
@ -843,6 +843,8 @@ func (ss *sshSession) killProcessOnContextDone() {
ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.IP(), err) ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.IP(), err)
// We don't need to Process.Wait here, sshSession.run() does // We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason. // the waiting regardless of termination reason.
// TODO(maisem): should this be a SIGTERM followed by a SIGKILL?
ss.cmd.Process.Kill() ss.cmd.Process.Kill()
}) })
} }
@ -1004,20 +1006,23 @@ func (ss *sshSession) run() {
go ss.killProcessOnContextDone() go ss.killProcessOnContextDone()
go func() { go func() {
_, err := io.Copy(rec.writer("i", ss.stdin), ss) defer ss.stdin.Close()
if err != nil { if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
// TODO: don't log in the success case.
logf("stdin copy: %v", err) logf("stdin copy: %v", err)
ss.ctx.CloseWithError(err)
} else if ss.ptyReq != nil {
const EOT = 4 // https://en.wikipedia.org/wiki/End-of-Transmission_character
ss.stdin.Write([]byte{EOT})
} }
ss.stdin.Close()
}() }()
go func() { go func() {
defer ss.stdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.stdout) _, err := io.Copy(rec.writer("o", ss), ss.stdout)
if err != nil { if err != nil && !errors.Is(err, io.EOF) {
logf("stdout copy: %v", err) logf("stdout copy: %v", err)
// If we got an error here, it's probably because the client has
// disconnected.
ss.ctx.CloseWithError(err) ss.ctx.CloseWithError(err)
} else {
ss.CloseWrite()
} }
}() }()
// stderr is nil for ptys. // stderr is nil for ptys.
@ -1029,6 +1034,7 @@ func (ss *sshSession) run() {
} }
}() }()
} }
err = ss.cmd.Wait() err = ss.cmd.Wait()
// This will either make the SSH Termination goroutine be a no-op, // This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the // or itself will be a no-op because the process was killed by the
@ -1036,6 +1042,7 @@ func (ss *sshSession) run() {
ss.exitOnce.Do(func() {}) ss.exitOnce.Do(func() {})
if err == nil { if err == nil {
ss.logf("Session complete")
ss.Exit(0) ss.Exit(0)
return return
} }