ssh/tailssh: better handling of signals and exits

We were not handling errors occurred while copying data between the subprocess and the connection.
This makes it so that we pass the appropriate signals when to the process and the connection.

This also fixes mosh.

Updates #4919

Co-authored-by: James Tucker <raggi@tailscale.com>
Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-07-15 15:40:20 +00:00 committed by Maisem Ali
parent 004f0ca3e0
commit af412e8874
2 changed files with 29 additions and 12 deletions

View File

@ -310,15 +310,25 @@ func (ss *sshSession) launchProcess() error {
if err != nil {
return err
}
go resizeWindow(pty, winCh)
ss.stdout = pty // no stderr for a pty
// We need to be able to close stdin and stdout separately later so make a
// dup.
ptyDup, err := syscall.Dup(int(pty.Fd()))
if err != nil {
return err
}
go resizeWindow(ptyDup /* arbitrary fd */, winCh)
ss.stdin = pty
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.stderr = nil // not available for pty
return nil
}
func resizeWindow(f *os.File, winCh <-chan ssh.Window) {
func resizeWindow(fd int, winCh <-chan ssh.Window) {
for win := range winCh {
unix.IoctlSetWinsize(int(f.Fd()), syscall.TIOCSWINSZ, &unix.Winsize{
unix.IoctlSetWinsize(fd, syscall.TIOCSWINSZ, &unix.Winsize{
Row: uint16(win.Height),
Col: uint16(win.Width),
})

View File

@ -732,7 +732,7 @@ type sshSession struct {
// initialized by launchProcess:
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.Reader
stdout io.ReadCloser
stderr io.Reader // nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions
@ -843,6 +843,8 @@ func (ss *sshSession) killProcessOnContextDone() {
ss.logf("terminating SSH session from %v: %v", ss.conn.info.src.IP(), err)
// We don't need to Process.Wait here, sshSession.run() does
// the waiting regardless of termination reason.
// TODO(maisem): should this be a SIGTERM followed by a SIGKILL?
ss.cmd.Process.Kill()
})
}
@ -1004,20 +1006,23 @@ func (ss *sshSession) run() {
go ss.killProcessOnContextDone()
go func() {
_, err := io.Copy(rec.writer("i", ss.stdin), ss)
if err != nil {
// TODO: don't log in the success case.
defer ss.stdin.Close()
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
logf("stdin copy: %v", err)
ss.ctx.CloseWithError(err)
} else if ss.ptyReq != nil {
const EOT = 4 // https://en.wikipedia.org/wiki/End-of-Transmission_character
ss.stdin.Write([]byte{EOT})
}
ss.stdin.Close()
}()
go func() {
defer ss.stdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
if err != nil {
if err != nil && !errors.Is(err, io.EOF) {
logf("stdout copy: %v", err)
// If we got an error here, it's probably because the client has
// disconnected.
ss.ctx.CloseWithError(err)
} else {
ss.CloseWrite()
}
}()
// stderr is nil for ptys.
@ -1029,6 +1034,7 @@ func (ss *sshSession) run() {
}
}()
}
err = ss.cmd.Wait()
// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
@ -1036,6 +1042,7 @@ func (ss *sshSession) run() {
ss.exitOnce.Do(func() {})
if err == nil {
ss.logf("Session complete")
ss.Exit(0)
return
}