From f172fc42f7cba36afdf3a4d89ff0cb2d97e304f6 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sun, 9 Oct 2022 14:17:38 -0700 Subject: [PATCH] ssh/tailssh: close sshContext on context cancellation This was preventing tailscaled from shutting down properly if there were active sessions in certain states (e.g. waiting in check mode). Signed-off-by: Maisem Ali --- ssh/tailssh/context.go | 21 ++++++++++++--------- ssh/tailssh/tailssh.go | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/ssh/tailssh/context.go b/ssh/tailssh/context.go index b1e2bc12e..4effaa3e0 100644 --- a/ssh/tailssh/context.go +++ b/ssh/tailssh/context.go @@ -5,6 +5,7 @@ package tailssh import ( + "context" "sync" "time" ) @@ -13,14 +14,16 @@ // that adds a CloseWithError method. Otherwise it's just a normalish // Context. type sshContext struct { - mu sync.Mutex - closed bool - done chan struct{} - err error + underlying context.Context + cancel context.CancelFunc // cancels underlying + mu sync.Mutex + closed bool + err error } -func newSSHContext() *sshContext { - return &sshContext{done: make(chan struct{})} +func newSSHContext(ctx context.Context) *sshContext { + ctx, cancel := context.WithCancel(ctx) + return &sshContext{underlying: ctx, cancel: cancel} } func (ctx *sshContext) CloseWithError(err error) { @@ -31,7 +34,7 @@ func (ctx *sshContext) CloseWithError(err error) { } ctx.closed = true ctx.err = err - close(ctx.done) + ctx.cancel() } func (ctx *sshContext) Err() error { @@ -40,9 +43,9 @@ func (ctx *sshContext) Err() error { return ctx.err } -func (ctx *sshContext) Done() <-chan struct{} { return ctx.done } +func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() } func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return } -func (ctx *sshContext) Value(any) any { return nil } +func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) } // userVisibleError is a wrapper around an error that implements // SSHTerminationError, so msg is written to their session. diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 30e0aa14e..870f7f4a9 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -770,7 +770,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession { return &sshSession{ Session: s, sharedID: sharedID, - ctx: newSSHContext(), + ctx: newSSHContext(s.Context()), conn: c, logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), }