mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
ssh/tailssh: fix double race condition with non-pty command (#8405)
There are two race conditions in output handling. The first race condition is due to a misuse of exec.Cmd.StdoutPipe. The documentation explicitly forbids concurrent use of StdoutPipe with exec.Cmd.Wait (see golang/go#60908) because Wait will close both sides of the pipe once the process ends without any guarantees that all data has been read from the pipe. To fix this, we allocate the os.Pipes ourselves and manage cleanup ourselves when the process has ended. The second race condition is because sshSession.run waits upon exec.Cmd to finish and then immediately proceeds to call ss.Exit, which will close all output streams going to the SSH client. This may interrupt any asynchronous io.Copy still copying data. To fix this, we close the write-side of the os.Pipes after the process has finished (and before calling ss.Exit) and synchronously wait for the io.Copy routines to finish. Fixes #7601 Signed-off-by: Joe Tsai <joetsai@digital-static.net> Co-authored-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
d4de60c3ae
commit
61886e031e
@ -476,10 +476,10 @@ func (ss *sshSession) launchProcess() error {
|
|||||||
}
|
}
|
||||||
go resizeWindow(ptyDup /* arbitrary fd */, winCh)
|
go resizeWindow(ptyDup /* arbitrary fd */, winCh)
|
||||||
|
|
||||||
ss.tty = tty
|
ss.wrStdin = pty
|
||||||
ss.stdin = pty
|
ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name())
|
||||||
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
|
ss.rdStderr = nil // not available for pty
|
||||||
ss.stderr = nil // not available for pty
|
ss.childPipes = []io.Closer{tty}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -658,40 +658,29 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) {
|
|||||||
|
|
||||||
// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr.
|
// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr.
|
||||||
func (ss *sshSession) startWithStdPipes() (err error) {
|
func (ss *sshSession) startWithStdPipes() (err error) {
|
||||||
var stdin io.WriteCloser
|
var rdStdin, wrStdout, wrStderr io.ReadWriteCloser
|
||||||
var stdout, stderr io.ReadCloser
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, c := range []io.Closer{stdin, stdout, stderr} {
|
closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr)
|
||||||
if c != nil {
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
cmd := ss.cmd
|
if ss.cmd == nil {
|
||||||
if cmd == nil {
|
|
||||||
return errors.New("nil cmd")
|
return errors.New("nil cmd")
|
||||||
}
|
}
|
||||||
stdin, err = cmd.StdinPipe()
|
if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stdout, err = cmd.StdoutPipe()
|
if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stderr, err = cmd.StderrPipe()
|
if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := cmd.Start(); err != nil {
|
ss.cmd.Stdin = rdStdin
|
||||||
return err
|
ss.cmd.Stdout = wrStdout
|
||||||
}
|
ss.cmd.Stderr = wrStderr
|
||||||
ss.stdin = stdin
|
ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr}
|
||||||
ss.stdout = stdout
|
return ss.cmd.Start()
|
||||||
ss.stderr = stderr
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func envForUser(u *userMeta) []string {
|
func envForUser(u *userMeta) []string {
|
||||||
|
@ -823,12 +823,16 @@ type sshSession struct {
|
|||||||
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
||||||
|
|
||||||
// initialized by launchProcess:
|
// initialized by launchProcess:
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
stdin io.WriteCloser
|
wrStdin io.WriteCloser
|
||||||
stdout io.ReadCloser
|
rdStdout io.ReadCloser
|
||||||
stderr io.Reader // nil for pty sessions
|
rdStderr io.ReadCloser // rdStderr is nil for pty sessions
|
||||||
ptyReq *ssh.Pty // non-nil for pty sessions
|
ptyReq *ssh.Pty // non-nil for pty sessions
|
||||||
tty *os.File // non-nil for pty sessions, must be closed after process exits
|
|
||||||
|
// childPipes is a list of pipes that need to be closed when the process exits.
|
||||||
|
// For pty sessions, this is the tty fd.
|
||||||
|
// For non-pty sessions, this is the stdin, stdout, stderr fds.
|
||||||
|
childPipes []io.Closer
|
||||||
|
|
||||||
// We use this sync.Once to ensure that we only terminate the process once,
|
// We use this sync.Once to ensure that we only terminate the process once,
|
||||||
// either it exits itself or is terminated
|
// either it exits itself or is terminated
|
||||||
@ -1107,21 +1111,22 @@ func (ss *sshSession) run() {
|
|||||||
|
|
||||||
var processDone atomic.Bool
|
var processDone atomic.Bool
|
||||||
go func() {
|
go func() {
|
||||||
defer ss.stdin.Close()
|
defer ss.wrStdin.Close()
|
||||||
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
|
if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil {
|
||||||
logf("stdin copy: %v", err)
|
logf("stdin copy: %v", err)
|
||||||
ss.cancelCtx(err)
|
ss.cancelCtx(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
outputDone := make(chan struct{})
|
||||||
var openOutputStreams atomic.Int32
|
var openOutputStreams atomic.Int32
|
||||||
if ss.stderr != nil {
|
if ss.rdStderr != nil {
|
||||||
openOutputStreams.Store(2)
|
openOutputStreams.Store(2)
|
||||||
} else {
|
} else {
|
||||||
openOutputStreams.Store(1)
|
openOutputStreams.Store(1)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
defer ss.stdout.Close()
|
defer ss.rdStdout.Close()
|
||||||
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
|
_, err := io.Copy(rec.writer("o", ss), ss.rdStdout)
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
|
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
|
||||||
if !isErrBecauseProcessExited {
|
if !isErrBecauseProcessExited {
|
||||||
@ -1131,32 +1136,41 @@ func (ss *sshSession) run() {
|
|||||||
}
|
}
|
||||||
if openOutputStreams.Add(-1) == 0 {
|
if openOutputStreams.Add(-1) == 0 {
|
||||||
ss.CloseWrite()
|
ss.CloseWrite()
|
||||||
|
close(outputDone)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
// stderr is nil for ptys.
|
// rdStderr is nil for ptys.
|
||||||
if ss.stderr != nil {
|
if ss.rdStderr != nil {
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(ss.Stderr(), ss.stderr)
|
defer ss.rdStderr.Close()
|
||||||
|
_, err := io.Copy(ss.Stderr(), ss.rdStderr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("stderr copy: %v", err)
|
logf("stderr copy: %v", err)
|
||||||
}
|
}
|
||||||
if openOutputStreams.Add(-1) == 0 {
|
if openOutputStreams.Add(-1) == 0 {
|
||||||
ss.CloseWrite()
|
ss.CloseWrite()
|
||||||
|
close(outputDone)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if ss.tty != nil {
|
|
||||||
// If running a tty session, close the tty when the session is done.
|
|
||||||
defer ss.tty.Close()
|
|
||||||
}
|
|
||||||
err = ss.cmd.Wait()
|
err = ss.cmd.Wait()
|
||||||
processDone.Store(true)
|
processDone.Store(true)
|
||||||
|
|
||||||
// This will either make the SSH Termination goroutine be a no-op,
|
// This will either make the SSH Termination goroutine be a no-op,
|
||||||
// or itself will be a no-op because the process was killed by the
|
// or itself will be a no-op because the process was killed by the
|
||||||
// aforementioned goroutine.
|
// aforementioned goroutine.
|
||||||
ss.exitOnce.Do(func() {})
|
ss.exitOnce.Do(func() {})
|
||||||
|
|
||||||
|
// Close the process-side of all pipes to signal the asynchronous
|
||||||
|
// io.Copy routines reading/writing from the pipes to terminate.
|
||||||
|
// Block for the io.Copy to finish before calling ss.Exit below.
|
||||||
|
closeAll(ss.childPipes...)
|
||||||
|
select {
|
||||||
|
case <-outputDone:
|
||||||
|
case <-ss.ctx.Done():
|
||||||
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ss.logf("Session complete")
|
ss.logf("Session complete")
|
||||||
ss.Exit(0)
|
ss.Exit(0)
|
||||||
@ -1894,3 +1908,11 @@ type SSHTerminationError interface {
|
|||||||
error
|
error
|
||||||
SSHTerminationMessage() string
|
SSHTerminationMessage() string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeAll(cs ...io.Closer) {
|
||||||
|
for _, c := range cs {
|
||||||
|
if c != nil {
|
||||||
|
c.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -25,6 +25,7 @@
|
|||||||
"os/user"
|
"os/user"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@ -947,6 +948,19 @@ func TestSSH(t *testing.T) {
|
|||||||
// "foo\n" and "bar\n", not "\n" and "bar\n".
|
// "foo\n" and "bar\n", not "\n" and "bar\n".
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("large_file", func(t *testing.T) {
|
||||||
|
const wantSize = 1e6
|
||||||
|
var outBuf bytes.Buffer
|
||||||
|
cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
|
||||||
|
cmd.Stdout = &outBuf
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if gotSize := outBuf.Len(); gotSize != wantSize {
|
||||||
|
t.Fatalf("got %d, want %d", gotSize, int(wantSize))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("stdin", func(t *testing.T) {
|
t.Run("stdin", func(t *testing.T) {
|
||||||
if cibuild.On() {
|
if cibuild.On() {
|
||||||
t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
|
t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
|
||||||
|
Loading…
Reference in New Issue
Block a user