mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-11 05:07:33 +00:00
ssh/tailssh: use context.WithCancelCause
It was using a custom implmentation of the context.WithCancelCause, replace usage with stdlib. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
@@ -787,7 +787,8 @@ type sshSession struct {
|
||||
sharedID string // ID that's shared with control
|
||||
logf logger.Logf
|
||||
|
||||
ctx *sshContext // implements context.Context
|
||||
ctx context.Context
|
||||
cancelCtx context.CancelCauseFunc
|
||||
conn *conn
|
||||
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
||||
|
||||
@@ -812,12 +813,14 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
|
||||
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
||||
sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
|
||||
c.logf("starting session: %v", sharedID)
|
||||
ctx, cancel := context.WithCancelCause(s.Context())
|
||||
return &sshSession{
|
||||
Session: s,
|
||||
sharedID: sharedID,
|
||||
ctx: newSSHContext(s.Context()),
|
||||
conn: c,
|
||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||
Session: s,
|
||||
sharedID: sharedID,
|
||||
ctx: ctx,
|
||||
cancelCtx: cancel,
|
||||
conn: c,
|
||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -844,7 +847,7 @@ func (c *conn) checkStillValid() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, s := range c.sessions {
|
||||
s.ctx.CloseWithError(userVisibleError{
|
||||
s.cancelCtx(userVisibleError{
|
||||
fmt.Sprintf("Access revoked.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
@@ -897,7 +900,7 @@ func (ss *sshSession) killProcessOnContextDone() {
|
||||
// Either the process has already exited, in which case this does nothing.
|
||||
// Or, the process is still running in which case this will kill it.
|
||||
ss.exitOnce.Do(func() {
|
||||
err := ss.ctx.Err()
|
||||
err := context.Cause(ss.ctx)
|
||||
if serr, ok := err.(SSHTerminationError); ok {
|
||||
msg := serr.SSHTerminationMessage()
|
||||
if msg != "" {
|
||||
@@ -997,7 +1000,7 @@ var recordSSH = envknob.RegisterBool("TS_DEBUG_LOG_SSH")
|
||||
func (ss *sshSession) run() {
|
||||
metricActiveSessions.Add(1)
|
||||
defer metricActiveSessions.Add(-1)
|
||||
defer ss.ctx.CloseWithError(errSessionDone)
|
||||
defer ss.cancelCtx(errSessionDone)
|
||||
|
||||
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
|
||||
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
|
||||
@@ -1011,7 +1014,7 @@ func (ss *sshSession) run() {
|
||||
|
||||
if ss.conn.finalAction.SessionDuration != 0 {
|
||||
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
|
||||
ss.ctx.CloseWithError(userVisibleError{
|
||||
ss.cancelCtx(userVisibleError{
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
|
||||
context.DeadlineExceeded,
|
||||
})
|
||||
@@ -1066,7 +1069,7 @@ func (ss *sshSession) run() {
|
||||
defer ss.stdin.Close()
|
||||
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
|
||||
logf("stdin copy: %v", err)
|
||||
ss.ctx.CloseWithError(err)
|
||||
ss.cancelCtx(err)
|
||||
}
|
||||
}()
|
||||
var openOutputStreams atomic.Int32
|
||||
@@ -1080,7 +1083,7 @@ func (ss *sshSession) run() {
|
||||
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logf("stdout copy: %v", err)
|
||||
ss.ctx.CloseWithError(err)
|
||||
ss.cancelCtx(err)
|
||||
}
|
||||
if openOutputStreams.Add(-1) == 0 {
|
||||
ss.CloseWrite()
|
||||
@@ -1489,3 +1492,19 @@ var (
|
||||
metricSFTP = clientmetric.NewCounter("ssh_sftp_requests")
|
||||
metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests")
|
||||
)
|
||||
|
||||
// userVisibleError is a wrapper around an error that implements
|
||||
// SSHTerminationError, so msg is written to their session.
|
||||
type userVisibleError struct {
|
||||
msg string
|
||||
error
|
||||
}
|
||||
|
||||
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||
|
||||
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||
// session and should be written to user's sessions.
|
||||
type SSHTerminationError interface {
|
||||
error
|
||||
SSHTerminationMessage() string
|
||||
}
|
||||
|
Reference in New Issue
Block a user