mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
ssh/tailssh: handle dialing multiple recorders and failing open
This adds support to try dialing out to multiple recorders each with a 5s timeout and an overall 30s timeout. It also starts respecting the actions `OnRecordingFailure` field if set, if it is not set it fails open. Updates tailscale/corp#9967 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
f66ddb544c
commit
7778d708a6
@ -17,6 +17,7 @@
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -42,6 +43,7 @@
|
||||
"tailscale.com/types/netmap"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/multierr"
|
||||
"tailscale.com/version/distro"
|
||||
)
|
||||
|
||||
@ -79,33 +81,11 @@ type server struct {
|
||||
|
||||
// mu protects the following
|
||||
mu sync.Mutex
|
||||
httpc *http.Client // for calling out to peers.
|
||||
activeConns map[*conn]bool // set; value is always true
|
||||
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
|
||||
shutdownCalled bool
|
||||
}
|
||||
|
||||
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
|
||||
// dial connections. This is used to make requests to the session recording
|
||||
// server to upload session recordings.
|
||||
func (srv *server) sessionRecordingClient() *http.Client {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
if srv.httpc != nil {
|
||||
return srv.httpc
|
||||
}
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
return srv.lb.Dialer().UserDial(ctx, network, addr)
|
||||
}
|
||||
srv.httpc = &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
return srv.httpc
|
||||
}
|
||||
|
||||
func (srv *server) now() time.Time {
|
||||
if srv != nil && srv.timeNow != nil {
|
||||
return srv.timeNow()
|
||||
@ -1078,7 +1058,7 @@ func (ss *sshSession) run() {
|
||||
if err != nil {
|
||||
var uve userVisibleError
|
||||
if errors.As(err, &uve) {
|
||||
fmt.Fprintf(ss, "%s\r\n", uve)
|
||||
fmt.Fprintf(ss, "%s\r\n", uve.SSHTerminationMessage())
|
||||
} else {
|
||||
fmt.Fprintf(ss, "can't start new recording\r\n")
|
||||
}
|
||||
@ -1086,7 +1066,9 @@ func (ss *sshSession) run() {
|
||||
ss.Exit(1)
|
||||
return
|
||||
}
|
||||
defer rec.Close()
|
||||
if rec != nil {
|
||||
defer rec.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1169,15 +1151,16 @@ func (ss *sshSession) run() {
|
||||
// If the final action has a non-empty list of recorders, that list is
|
||||
// returned. Otherwise, the list of recorders from the initial action
|
||||
// is returned.
|
||||
func (ss *sshSession) recorders() []netip.AddrPort {
|
||||
func (ss *sshSession) recorders() ([]netip.AddrPort, *tailcfg.SSHRecorderFailureAction) {
|
||||
if len(ss.conn.finalAction.Recorders) > 0 {
|
||||
return ss.conn.finalAction.Recorders
|
||||
return ss.conn.finalAction.Recorders, ss.conn.finalAction.OnRecordingFailure
|
||||
}
|
||||
return ss.conn.action0.Recorders
|
||||
return ss.conn.action0.Recorders, ss.conn.action0.OnRecordingFailure
|
||||
}
|
||||
|
||||
func (ss *sshSession) shouldRecord() bool {
|
||||
return len(ss.recorders()) > 0
|
||||
recs, _ := ss.recorders()
|
||||
return len(recs) > 0
|
||||
}
|
||||
|
||||
type sshConnInfo struct {
|
||||
@ -1409,16 +1392,120 @@ type CastHeader struct {
|
||||
LocalUser string `json:"localUser"`
|
||||
}
|
||||
|
||||
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
|
||||
// dial connections. This is used to make requests to the session recording
|
||||
// server to upload session recordings.
|
||||
// It uses the provided dialCtx to dial connections, and limits a single dial
|
||||
// to 5 seconds.
|
||||
func (ss *sshSession) sessionRecordingClient(dialCtx context.Context) (*http.Client, error) {
|
||||
dialer := ss.conn.srv.lb.Dialer()
|
||||
if dialer == nil {
|
||||
return nil, errors.New("no peer API transport")
|
||||
}
|
||||
tr := dialer.PeerAPITransport().Clone()
|
||||
dialContextFn := tr.DialContext
|
||||
|
||||
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
go func() {
|
||||
select {
|
||||
case <-perAttemptCtx.Done():
|
||||
case <-dialCtx.Done():
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
return dialContextFn(perAttemptCtx, network, addr)
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// connectToRecorder connects to the recorder at any of the provided addresses.
|
||||
// It returns the first successful response, or a multierr if all attempts fail.
|
||||
//
|
||||
// On success, it returns a WriteCloser that can be used to upload the
|
||||
// recording, and a channel that will be sent an error (or nil) when the upload
|
||||
// fails or completes.
|
||||
func (ss *sshSession) connectToRecorder(ctx context.Context, recs []netip.AddrPort) (io.WriteCloser, <-chan error, error) {
|
||||
if len(recs) == 0 {
|
||||
return nil, nil, errors.New("no recorders configured")
|
||||
}
|
||||
|
||||
// We use a special context for dialing the recorder, so that we can
|
||||
// limit the time we spend dialing to 30 seconds and still have an
|
||||
// unbounded context for the upload.
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer dialCancel()
|
||||
hc, err := ss.sessionRecordingClient(dialCtx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var errs []error
|
||||
for _, ap := range recs {
|
||||
// We dial the recorder and wait for it to send a 100-continue
|
||||
// response before returning from this function. This ensures that
|
||||
// the recorder is ready to accept the recording.
|
||||
|
||||
// got100 is closed when we receive the 100-continue response.
|
||||
got100 := make(chan struct{})
|
||||
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
||||
Got100Continue: func() {
|
||||
close(got100)
|
||||
},
|
||||
})
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
|
||||
if err != nil {
|
||||
errs = append(errs, fmt.Errorf("recording: error starting recording: %w", err))
|
||||
continue
|
||||
}
|
||||
// We set the Expect header to 100-continue, so that the recorder
|
||||
// will send a 100-continue response before it starts reading the
|
||||
// request body.
|
||||
req.Header.Set("Expect", "100-continue")
|
||||
|
||||
// errChan is used to indicate the result of the request.
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
|
||||
return
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
|
||||
return
|
||||
}
|
||||
errChan <- nil
|
||||
}()
|
||||
select {
|
||||
case <-got100:
|
||||
case err := <-errChan:
|
||||
// If we get an error before we get the 100-continue response,
|
||||
// we need to try another recorder.
|
||||
if err == nil {
|
||||
// If the error is nil, we got a 200 response, which
|
||||
// is unexpected as we haven't sent any data yet.
|
||||
err = errors.New("recording: unexpected EOF")
|
||||
}
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
return pw, errChan, nil
|
||||
}
|
||||
return nil, nil, multierr.New(errs...)
|
||||
}
|
||||
|
||||
// startNewRecording starts a new SSH session recording.
|
||||
// It may return a nil recording if recording is not available.
|
||||
func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
recorders := ss.recorders()
|
||||
recorders, onFailure := ss.recorders()
|
||||
if len(recorders) == 0 {
|
||||
return nil, errors.New("no recorders configured")
|
||||
}
|
||||
recorder := recorders[0]
|
||||
if len(recorders) > 1 {
|
||||
ss.logf("warning: multiple recorders configured, using first one: %v", recorder)
|
||||
}
|
||||
|
||||
var w ssh.Window
|
||||
if ptyReq, _, isPtyReq := ss.Pty(); isPtyReq {
|
||||
@ -1436,51 +1523,43 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
start: now,
|
||||
}
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
// We want to use a background context for uploading and not ss.ctx.
|
||||
// ss.ctx is closed when the session closes, but we don't want to break the upload at that time.
|
||||
// Instead we want to wait for the session to close the writer when it finishes.
|
||||
ctx := context.Background()
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", recorder.Addr(), recorder.Port()), pr)
|
||||
wc, errChan, err := ss.connectToRecorder(ctx, recorders)
|
||||
if err != nil {
|
||||
pr.Close()
|
||||
pw.Close()
|
||||
return nil, err
|
||||
}
|
||||
// We want to wait for the server to respond with 100 Continue to notifiy us
|
||||
// that it's ready to receive data. We do this to block the session from
|
||||
// starting until the server is ready to receive data.
|
||||
// It also allows the server to reject the request before we start sending
|
||||
// data.
|
||||
req.Header.Set("Expect", "100-continue")
|
||||
go func() {
|
||||
defer pw.Close()
|
||||
ss.logf("starting asciinema recording to %s", recorder)
|
||||
hc := ss.conn.srv.sessionRecordingClient()
|
||||
resp, err := hc.Do(req)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("recording: error sending recording: %w", err)
|
||||
ss.logf("%v", err)
|
||||
ss.cancelCtx(userVisibleError{
|
||||
msg: "recording: error sending recording",
|
||||
// TODO(catzkorn): notify control here.
|
||||
if onFailure != nil && onFailure.RejectSessionWithMessage != "" {
|
||||
ss.logf("recording: error starting recording (rejecting session): %v", err)
|
||||
return nil, userVisibleError{
|
||||
error: err,
|
||||
msg: onFailure.RejectSessionWithMessage,
|
||||
}
|
||||
}
|
||||
ss.logf("recording: error starting recording (failing open): %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := <-errChan
|
||||
if err == nil {
|
||||
// Success.
|
||||
return
|
||||
}
|
||||
// TODO(catzkorn): notify control here.
|
||||
if onFailure != nil && onFailure.TerminateSessionWithMessage != "" {
|
||||
ss.logf("recording: error uploading recording (closing session): %v", err)
|
||||
ss.cancelCtx(userVisibleError{
|
||||
error: err,
|
||||
msg: onFailure.TerminateSessionWithMessage,
|
||||
})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer ss.cancelCtx(errors.New("recording: done"))
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := fmt.Errorf("recording: server responded with %s", resp.Status)
|
||||
ss.logf("%v", err)
|
||||
ss.cancelCtx(userVisibleError{
|
||||
msg: "recording server responded with: " + resp.Status,
|
||||
error: err,
|
||||
})
|
||||
}
|
||||
ss.logf("recording: error uploading recording (failing open): %v", err)
|
||||
}()
|
||||
|
||||
rec.out = pw
|
||||
rec.out = wc
|
||||
|
||||
ch := CastHeader{
|
||||
Version: 2,
|
||||
@ -1515,7 +1594,7 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
||||
return nil, err
|
||||
}
|
||||
j = append(j, '\n')
|
||||
if _, err := pw.Write(j); err != nil {
|
||||
if _, err := rec.out.Write(j); err != nil {
|
||||
if errors.Is(err, io.ErrClosedPipe) && ss.ctx.Err() != nil {
|
||||
// If we got an io.ErrClosedPipe, it's likely because
|
||||
// the recording server closed the connection on us. Return
|
||||
|
@ -240,7 +240,7 @@ type localState struct {
|
||||
)
|
||||
|
||||
func (ts *localState) Dialer() *tsdial.Dialer {
|
||||
return nil
|
||||
return &tsdial.Dialer{}
|
||||
}
|
||||
|
||||
func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
|
||||
@ -338,8 +338,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
defer recordingServer.Close()
|
||||
|
||||
s := &server{
|
||||
logf: t.Logf,
|
||||
httpc: recordingServer.Client(),
|
||||
logf: t.Logf,
|
||||
lb: &localState{
|
||||
sshEnabled: true,
|
||||
matchingRule: newSSHRule(
|
||||
@ -348,6 +347,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
Recorders: []netip.AddrPort{
|
||||
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
||||
},
|
||||
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
||||
RejectSessionWithMessage: "session rejected",
|
||||
TerminateSessionWithMessage: "session terminated",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
@ -374,7 +377,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
},
|
||||
sshCommand: "echo hello",
|
||||
wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
|
||||
wantClientOutput: "session rejected\r\n",
|
||||
|
||||
clientOutputMustNotContain: []string{"hello"},
|
||||
},
|
||||
@ -386,7 +389,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
},
|
||||
sshCommand: "echo hello && sleep 1 && echo world",
|
||||
wantClientOutput: "\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
|
||||
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
|
||||
|
||||
clientOutputMustNotContain: []string{"world"},
|
||||
},
|
||||
@ -440,6 +443,103 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleRecorders(t *testing.T) {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer recordingServer.Close()
|
||||
badRecorder, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
badRecorderAddr := badRecorder.Addr().String()
|
||||
badRecorder.Close()
|
||||
|
||||
badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(500)
|
||||
}))
|
||||
defer badRecordingServer500.Close()
|
||||
|
||||
badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer badRecordingServer200.Close()
|
||||
|
||||
s := &server{
|
||||
logf: t.Logf,
|
||||
lb: &localState{
|
||||
sshEnabled: true,
|
||||
matchingRule: newSSHRule(
|
||||
&tailcfg.SSHAction{
|
||||
Accept: true,
|
||||
Recorders: []netip.AddrPort{
|
||||
netip.MustParseAddrPort(badRecorderAddr),
|
||||
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
|
||||
netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
|
||||
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
||||
},
|
||||
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
||||
RejectSessionWithMessage: "session rejected",
|
||||
TerminateSessionWithMessage: "session terminated",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
}
|
||||
defer s.Shutdown()
|
||||
|
||||
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
||||
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
||||
|
||||
const sshUser = "alice"
|
||||
cfg := &gossh.ClientConfig{
|
||||
User: sshUser,
|
||||
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
||||
if err != nil {
|
||||
t.Errorf("client: %v", err)
|
||||
return
|
||||
}
|
||||
client := gossh.NewClient(c, chans, reqs)
|
||||
defer client.Close()
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Errorf("client: %v", err)
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
t.Logf("client established session")
|
||||
out, err := session.CombinedOutput("echo Ran echo!")
|
||||
if err != nil {
|
||||
t.Errorf("client: %v", err)
|
||||
}
|
||||
if string(out) != "Ran echo!\n" {
|
||||
t.Errorf("client: unexpected output: %q", out)
|
||||
}
|
||||
}()
|
||||
if err := s.HandleSSHConn(dc); err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for recording")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
|
||||
// when the client is not interactive (i.e. no PTY).
|
||||
// It starts a local SSH server and a recording server. The recording server
|
||||
@ -464,8 +564,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
||||
defer recordingServer.Close()
|
||||
|
||||
s := &server{
|
||||
logf: logger.Discard,
|
||||
httpc: recordingServer.Client(),
|
||||
logf: logger.Discard,
|
||||
lb: &localState{
|
||||
sshEnabled: true,
|
||||
matchingRule: newSSHRule(
|
||||
@ -474,6 +573,10 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
||||
Recorders: []netip.AddrPort{
|
||||
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
|
||||
},
|
||||
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
||||
RejectSessionWithMessage: "session rejected",
|
||||
TerminateSessionWithMessage: "session terminated",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
|
@ -97,7 +97,8 @@
|
||||
// - 58: 2023-03-10: Client retries lite map updates before restarting map poll.
|
||||
// - 59: 2023-03-16: Client understands Peers[].SelfNodeV4MasqAddrForThisPeer
|
||||
// - 60: 2023-04-06: Client understands IsWireGuardOnly
|
||||
const CurrentCapabilityVersion CapabilityVersion = 60
|
||||
// - 61: 2023-04-18: Client understand SSHAction.SSHRecorderFailureAction
|
||||
const CurrentCapabilityVersion CapabilityVersion = 61
|
||||
|
||||
type StableID string
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user