ssh/tailssh: close sshContext on context cancellation

This was preventing tailscaled from shutting down properly if there were
active sessions in certain states (e.g. waiting in check mode).

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-10-09 14:17:38 -07:00 committed by Maisem Ali
parent 8fe04b035c
commit f172fc42f7
2 changed files with 13 additions and 10 deletions

View File

@ -5,6 +5,7 @@
package tailssh
import (
"context"
"sync"
"time"
)
@ -13,14 +14,16 @@
// that adds a CloseWithError method. Otherwise it's just a normalish
// Context.
type sshContext struct {
mu sync.Mutex
closed bool
done chan struct{}
err error
underlying context.Context
cancel context.CancelFunc // cancels underlying
mu sync.Mutex
closed bool
err error
}
func newSSHContext() *sshContext {
return &sshContext{done: make(chan struct{})}
func newSSHContext(ctx context.Context) *sshContext {
ctx, cancel := context.WithCancel(ctx)
return &sshContext{underlying: ctx, cancel: cancel}
}
func (ctx *sshContext) CloseWithError(err error) {
@ -31,7 +34,7 @@ func (ctx *sshContext) CloseWithError(err error) {
}
ctx.closed = true
ctx.err = err
close(ctx.done)
ctx.cancel()
}
func (ctx *sshContext) Err() error {
@ -40,9 +43,9 @@ func (ctx *sshContext) Err() error {
return ctx.err
}
func (ctx *sshContext) Done() <-chan struct{} { return ctx.done }
func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() }
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
func (ctx *sshContext) Value(any) any { return nil }
func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) }
// userVisibleError is a wrapper around an error that implements
// SSHTerminationError, so msg is written to their session.

View File

@ -770,7 +770,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession {
return &sshSession{
Session: s,
sharedID: sharedID,
ctx: newSSHContext(),
ctx: newSSHContext(s.Context()),
conn: c,
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
}