mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
ssh/tailssh: handle Control-C during hold-and-delegate prompt
Fixes #4549 Change-Id: Iafc61af5e08cd03564d39cf667e940b2417714cc Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
f9e86e64b7
commit
c1445155ef
112
ssh/tailssh/ctxreader.go
Normal file
112
ssh/tailssh/ctxreader.go
Normal file
@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/tempfork/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
// readResult is a result from a io.Reader.Read call,
|
||||
// as used by contextReader.
|
||||
type readResult struct {
|
||||
buf []byte // ownership passed on chan send
|
||||
err error
|
||||
}
|
||||
|
||||
// contextReader wraps an io.Reader, providing a ReadContext method
|
||||
// that can be aborted before yielding bytes. If it's aborted, subsequent
|
||||
// reads can get those byte(s) later.
|
||||
type contextReader struct {
|
||||
r io.Reader
|
||||
|
||||
// buffered is leftover data from a previous read call that wasn't entirely
|
||||
// consumed.
|
||||
buffered []byte
|
||||
// readErr is a previous read error that was seen while filling buffered. It
|
||||
// should be returned to the caller after bufffered is consumed.
|
||||
readErr error
|
||||
|
||||
mu sync.Mutex // guards ch only
|
||||
|
||||
// ch is non-nil if a goroutine had been started and has a result to be
|
||||
// read. The goroutine may be either still running or done and has
|
||||
// send to the channel.
|
||||
ch chan readResult
|
||||
}
|
||||
|
||||
// HasOutstandingRead reports whether there's an oustanding Read call that's
|
||||
// either currently blocked in a Read or whose result hasn't been consumed.
|
||||
func (w *contextReader) HasOutstandingRead() bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
return w.ch != nil
|
||||
}
|
||||
|
||||
func (w *contextReader) setChan(c chan readResult) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.ch = c
|
||||
}
|
||||
|
||||
// ReadContext is like Read, but takes a context permitting the read to be canceled.
|
||||
//
|
||||
// If the context becomes done, the underlying Read call continues and its result
|
||||
// will be given to the next caller to ReadContext.
|
||||
func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n = copy(p, w.buffered)
|
||||
if n > 0 {
|
||||
w.buffered = w.buffered[n:]
|
||||
if len(w.buffered) == 0 {
|
||||
err = w.readErr
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.ch == nil {
|
||||
ch := make(chan readResult, 1)
|
||||
w.setChan(ch)
|
||||
go func() {
|
||||
rbuf := make([]byte, len(p))
|
||||
n, err := w.r.Read(rbuf)
|
||||
ch <- readResult{rbuf[:n], err}
|
||||
}()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
case rr := <-w.ch:
|
||||
w.setChan(nil)
|
||||
n = copy(p, rr.buf)
|
||||
w.buffered = rr.buf[n:]
|
||||
w.readErr = rr.err
|
||||
if len(w.buffered) == 0 {
|
||||
err = rr.err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// contextReaderSesssion implements ssh.Session, wrapping another
|
||||
// ssh.Session but changing its Read method to use contextReader.
|
||||
type contextReaderSesssion struct {
|
||||
ssh.Session
|
||||
cr *contextReader
|
||||
}
|
||||
|
||||
func (a contextReaderSesssion) Read(p []byte) (n int, err error) {
|
||||
if a.cr.HasOutstandingRead() {
|
||||
return a.cr.ReadContext(context.Background(), p)
|
||||
}
|
||||
return a.Session.Read(p)
|
||||
}
|
@ -37,6 +37,7 @@ import (
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tempfork/gliderlabs/ssh"
|
||||
"tailscale.com/types/logger"
|
||||
@ -488,7 +489,8 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
||||
// completed. It also handles SFTP requests.
|
||||
func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
|
||||
sshUser := s.User()
|
||||
action, err := c.resolveTerminalAction(s)
|
||||
cr := &contextReader{r: s}
|
||||
action, err := c.resolveTerminalAction(s, cr)
|
||||
if err != nil {
|
||||
c.logf("resolveTerminalAction: %v", err)
|
||||
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
|
||||
@ -501,6 +503,10 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
|
||||
return
|
||||
}
|
||||
|
||||
if cr.HasOutstandingRead() {
|
||||
s = contextReaderSesssion{s, cr}
|
||||
}
|
||||
|
||||
// Do this check after auth, but before starting the session.
|
||||
switch s.Subsystem() {
|
||||
case "sftp", "":
|
||||
@ -522,8 +528,17 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
|
||||
// Any action with a Message in the chain will be printed to s.
|
||||
//
|
||||
// The returned SSHAction will be either Reject or Accept.
|
||||
func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error) {
|
||||
func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) {
|
||||
action := c.action0
|
||||
|
||||
var awaitReadOnce sync.Once // to start Reads on cr
|
||||
var sawInterrupt syncs.AtomicBool
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait() // wait for awaitIntrOnce's goroutine to exit
|
||||
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
defer cancel()
|
||||
|
||||
// Loop processing/fetching Actions until one reaches a
|
||||
// terminal state (Accept, Reject, or invalid Action), or
|
||||
// until fetchSSHAction times out due to the context being
|
||||
@ -541,10 +556,32 @@ func (c *conn) resolveTerminalAction(s ssh.Session) (*tailcfg.SSHAction, error)
|
||||
if url == "" {
|
||||
return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate")
|
||||
}
|
||||
awaitReadOnce.Do(func() {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buf := make([]byte, 1)
|
||||
for {
|
||||
n, err := cr.ReadContext(ctx, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if n > 0 && buf[0] == 0x03 { // Ctrl-C
|
||||
sawInterrupt.Set(true)
|
||||
s.Stderr().Write([]byte("Canceled.\r\n"))
|
||||
s.Exit(1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
url = c.expandDelegateURL(url)
|
||||
var err error
|
||||
action, err = c.fetchSSHAction(s.Context(), url)
|
||||
action, err = c.fetchSSHAction(ctx, url)
|
||||
if err != nil {
|
||||
if sawInterrupt.Get() {
|
||||
return nil, fmt.Errorf("aborted by user")
|
||||
}
|
||||
return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user