ssh/tailssh: fix double race condition with non-pty command (#8405)

There are two race conditions in output handling.

The first race condition is due to a misuse of exec.Cmd.StdoutPipe.
The documentation explicitly forbids concurrent use of StdoutPipe
with exec.Cmd.Wait (see golang/go#60908) because Wait will
close both sides of the pipe once the process ends without
any guarantees that all data has been read from the pipe.
To fix this, we allocate the os.Pipes ourselves and
manage cleanup ourselves when the process has ended.

The second race condition is because sshSession.run waits
upon exec.Cmd to finish and then immediately proceeds to call ss.Exit,
which will close all output streams going to the SSH client.
This may interrupt any asynchronous io.Copy still copying data.
To fix this, we close the write-side of the os.Pipes after
the process has finished (and before calling ss.Exit) and
synchronously wait for the io.Copy routines to finish.

Fixes #7601

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
Co-authored-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Joe Tsai 2023-06-21 19:57:45 -07:00 committed by GitHub
parent d4de60c3ae
commit 61886e031e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 44 deletions

View File

@ -476,10 +476,10 @@ func (ss *sshSession) launchProcess() error {
}
go resizeWindow(ptyDup /* arbitrary fd */, winCh)
ss.tty = tty
ss.stdin = pty
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.stderr = nil // not available for pty
ss.wrStdin = pty
ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.rdStderr = nil // not available for pty
ss.childPipes = []io.Closer{tty}
return nil
}
@ -658,40 +658,29 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) {
// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr.
func (ss *sshSession) startWithStdPipes() (err error) {
var stdin io.WriteCloser
var stdout, stderr io.ReadCloser
var rdStdin, wrStdout, wrStderr io.ReadWriteCloser
defer func() {
if err != nil {
for _, c := range []io.Closer{stdin, stdout, stderr} {
if c != nil {
c.Close()
}
}
closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr)
}
}()
cmd := ss.cmd
if cmd == nil {
if ss.cmd == nil {
return errors.New("nil cmd")
}
stdin, err = cmd.StdinPipe()
if err != nil {
if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil {
return err
}
stdout, err = cmd.StdoutPipe()
if err != nil {
if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil {
return err
}
stderr, err = cmd.StderrPipe()
if err != nil {
if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil {
return err
}
if err := cmd.Start(); err != nil {
return err
}
ss.stdin = stdin
ss.stdout = stdout
ss.stderr = stderr
return nil
ss.cmd.Stdin = rdStdin
ss.cmd.Stdout = wrStdout
ss.cmd.Stderr = wrStderr
ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr}
return ss.cmd.Start()
}
func envForUser(u *userMeta) []string {

View File

@ -823,12 +823,16 @@ type sshSession struct {
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
// initialized by launchProcess:
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.Reader // nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions
tty *os.File // non-nil for pty sessions, must be closed after process exits
cmd *exec.Cmd
wrStdin io.WriteCloser
rdStdout io.ReadCloser
rdStderr io.ReadCloser // rdStderr is nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions
// childPipes is a list of pipes that need to be closed when the process exits.
// For pty sessions, this is the tty fd.
// For non-pty sessions, this is the stdin, stdout, stderr fds.
childPipes []io.Closer
// We use this sync.Once to ensure that we only terminate the process once,
// either it exits itself or is terminated
@ -1107,21 +1111,22 @@ func (ss *sshSession) run() {
var processDone atomic.Bool
go func() {
defer ss.stdin.Close()
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
defer ss.wrStdin.Close()
if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil {
logf("stdin copy: %v", err)
ss.cancelCtx(err)
}
}()
outputDone := make(chan struct{})
var openOutputStreams atomic.Int32
if ss.stderr != nil {
if ss.rdStderr != nil {
openOutputStreams.Store(2)
} else {
openOutputStreams.Store(1)
}
go func() {
defer ss.stdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
defer ss.rdStdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.rdStdout)
if err != nil && !errors.Is(err, io.EOF) {
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
if !isErrBecauseProcessExited {
@ -1131,32 +1136,41 @@ func (ss *sshSession) run() {
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
// stderr is nil for ptys.
if ss.stderr != nil {
// rdStderr is nil for ptys.
if ss.rdStderr != nil {
go func() {
_, err := io.Copy(ss.Stderr(), ss.stderr)
defer ss.rdStderr.Close()
_, err := io.Copy(ss.Stderr(), ss.rdStderr)
if err != nil {
logf("stderr copy: %v", err)
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
}
if ss.tty != nil {
// If running a tty session, close the tty when the session is done.
defer ss.tty.Close()
}
err = ss.cmd.Wait()
processDone.Store(true)
// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
// aforementioned goroutine.
ss.exitOnce.Do(func() {})
// Close the process-side of all pipes to signal the asynchronous
// io.Copy routines reading/writing from the pipes to terminate.
// Block for the io.Copy to finish before calling ss.Exit below.
closeAll(ss.childPipes...)
select {
case <-outputDone:
case <-ss.ctx.Done():
}
if err == nil {
ss.logf("Session complete")
ss.Exit(0)
@ -1894,3 +1908,11 @@ type SSHTerminationError interface {
error
SSHTerminationMessage() string
}
func closeAll(cs ...io.Closer) {
for _, c := range cs {
if c != nil {
c.Close()
}
}
}

View File

@ -25,6 +25,7 @@
"os/user"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -947,6 +948,19 @@ func TestSSH(t *testing.T) {
// "foo\n" and "bar\n", not "\n" and "bar\n".
})
t.Run("large_file", func(t *testing.T) {
const wantSize = 1e6
var outBuf bytes.Buffer
cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
cmd.Stdout = &outBuf
if err := cmd.Run(); err != nil {
t.Fatal(err)
}
if gotSize := outBuf.Len(); gotSize != wantSize {
t.Fatalf("got %d, want %d", gotSize, int(wantSize))
}
})
t.Run("stdin", func(t *testing.T) {
if cibuild.On() {
t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")