diff --git a/ssh/tailssh/context.go b/ssh/tailssh/context.go index b1e2bc12e..4effaa3e0 100644 --- a/ssh/tailssh/context.go +++ b/ssh/tailssh/context.go @@ -5,6 +5,7 @@ package tailssh import ( + "context" "sync" "time" ) @@ -13,14 +14,16 @@ import ( // that adds a CloseWithError method. Otherwise it's just a normalish // Context. type sshContext struct { - mu sync.Mutex - closed bool - done chan struct{} - err error + underlying context.Context + cancel context.CancelFunc // cancels underlying + mu sync.Mutex + closed bool + err error } -func newSSHContext() *sshContext { - return &sshContext{done: make(chan struct{})} +func newSSHContext(ctx context.Context) *sshContext { + ctx, cancel := context.WithCancel(ctx) + return &sshContext{underlying: ctx, cancel: cancel} } func (ctx *sshContext) CloseWithError(err error) { @@ -31,7 +34,7 @@ func (ctx *sshContext) CloseWithError(err error) { } ctx.closed = true ctx.err = err - close(ctx.done) + ctx.cancel() } func (ctx *sshContext) Err() error { @@ -40,9 +43,9 @@ func (ctx *sshContext) Err() error { return ctx.err } -func (ctx *sshContext) Done() <-chan struct{} { return ctx.done } +func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() } func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return } -func (ctx *sshContext) Value(any) any { return nil } +func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) } // userVisibleError is a wrapper around an error that implements // SSHTerminationError, so msg is written to their session. diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 30e0aa14e..870f7f4a9 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -770,7 +770,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession { return &sshSession{ Session: s, sharedID: sharedID, - ctx: newSSHContext(), + ctx: newSSHContext(s.Context()), conn: c, logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), }