ssh/tailssh: handle dialing multiple recorders and failing open

This adds support to try dialing out to multiple recorders each
with a 5s timeout and an overall 30s timeout. It also starts respecting
the actions `OnRecordingFailure` field if set, if it is not set
it fails open.

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali
2023-04-19 21:33:33 -07:00
committed by Maisem Ali
parent f66ddb544c
commit 7778d708a6
3 changed files with 259 additions and 76 deletions

View File

@@ -17,6 +17,7 @@ import (
"io"
"net"
"net/http"
"net/http/httptrace"
"net/netip"
"net/url"
"os"
@@ -42,6 +43,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/util/clientmetric"
"tailscale.com/util/mak"
"tailscale.com/util/multierr"
"tailscale.com/version/distro"
)
@@ -79,33 +81,11 @@ type server struct {
// mu protects the following
mu sync.Mutex
httpc *http.Client // for calling out to peers.
activeConns map[*conn]bool // set; value is always true
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
}
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
// dial connections. This is used to make requests to the session recording
// server to upload session recordings.
func (srv *server) sessionRecordingClient() *http.Client {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.httpc != nil {
return srv.httpc
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
return srv.lb.Dialer().UserDial(ctx, network, addr)
}
srv.httpc = &http.Client{
Transport: tr,
}
return srv.httpc
}
func (srv *server) now() time.Time {
if srv != nil && srv.timeNow != nil {
return srv.timeNow()
@@ -1078,7 +1058,7 @@ func (ss *sshSession) run() {
if err != nil {
var uve userVisibleError
if errors.As(err, &uve) {
fmt.Fprintf(ss, "%s\r\n", uve)
fmt.Fprintf(ss, "%s\r\n", uve.SSHTerminationMessage())
} else {
fmt.Fprintf(ss, "can't start new recording\r\n")
}
@@ -1086,7 +1066,9 @@ func (ss *sshSession) run() {
ss.Exit(1)
return
}
defer rec.Close()
if rec != nil {
defer rec.Close()
}
}
}
@@ -1169,15 +1151,16 @@ func (ss *sshSession) run() {
// If the final action has a non-empty list of recorders, that list is
// returned. Otherwise, the list of recorders from the initial action
// is returned.
func (ss *sshSession) recorders() []netip.AddrPort {
func (ss *sshSession) recorders() ([]netip.AddrPort, *tailcfg.SSHRecorderFailureAction) {
if len(ss.conn.finalAction.Recorders) > 0 {
return ss.conn.finalAction.Recorders
return ss.conn.finalAction.Recorders, ss.conn.finalAction.OnRecordingFailure
}
return ss.conn.action0.Recorders
return ss.conn.action0.Recorders, ss.conn.action0.OnRecordingFailure
}
func (ss *sshSession) shouldRecord() bool {
return len(ss.recorders()) > 0
recs, _ := ss.recorders()
return len(recs) > 0
}
type sshConnInfo struct {
@@ -1409,16 +1392,120 @@ type CastHeader struct {
LocalUser string `json:"localUser"`
}
// sessionRecordingClient returns an http.Client that uses srv.lb.Dialer() to
// dial connections. This is used to make requests to the session recording
// server to upload session recordings.
// It uses the provided dialCtx to dial connections, and limits a single dial
// to 5 seconds.
func (ss *sshSession) sessionRecordingClient(dialCtx context.Context) (*http.Client, error) {
dialer := ss.conn.srv.lb.Dialer()
if dialer == nil {
return nil, errors.New("no peer API transport")
}
tr := dialer.PeerAPITransport().Clone()
dialContextFn := tr.DialContext
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
go func() {
select {
case <-perAttemptCtx.Done():
case <-dialCtx.Done():
cancel()
}
}()
return dialContextFn(perAttemptCtx, network, addr)
}
return &http.Client{
Transport: tr,
}, nil
}
// connectToRecorder connects to the recorder at any of the provided addresses.
// It returns the first successful response, or a multierr if all attempts fail.
//
// On success, it returns a WriteCloser that can be used to upload the
// recording, and a channel that will be sent an error (or nil) when the upload
// fails or completes.
func (ss *sshSession) connectToRecorder(ctx context.Context, recs []netip.AddrPort) (io.WriteCloser, <-chan error, error) {
if len(recs) == 0 {
return nil, nil, errors.New("no recorders configured")
}
// We use a special context for dialing the recorder, so that we can
// limit the time we spend dialing to 30 seconds and still have an
// unbounded context for the upload.
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
defer dialCancel()
hc, err := ss.sessionRecordingClient(dialCtx)
if err != nil {
return nil, nil, err
}
var errs []error
for _, ap := range recs {
// We dial the recorder and wait for it to send a 100-continue
// response before returning from this function. This ensures that
// the recorder is ready to accept the recording.
// got100 is closed when we receive the 100-continue response.
got100 := make(chan struct{})
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
Got100Continue: func() {
close(got100)
},
})
pr, pw := io.Pipe()
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
if err != nil {
errs = append(errs, fmt.Errorf("recording: error starting recording: %w", err))
continue
}
// We set the Expect header to 100-continue, so that the recorder
// will send a 100-continue response before it starts reading the
// request body.
req.Header.Set("Expect", "100-continue")
// errChan is used to indicate the result of the request.
errChan := make(chan error, 1)
go func() {
resp, err := hc.Do(req)
if err != nil {
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
return
}
if resp.StatusCode != 200 {
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
return
}
errChan <- nil
}()
select {
case <-got100:
case err := <-errChan:
// If we get an error before we get the 100-continue response,
// we need to try another recorder.
if err == nil {
// If the error is nil, we got a 200 response, which
// is unexpected as we haven't sent any data yet.
err = errors.New("recording: unexpected EOF")
}
errs = append(errs, err)
continue
}
return pw, errChan, nil
}
return nil, nil, multierr.New(errs...)
}
// startNewRecording starts a new SSH session recording.
// It may return a nil recording if recording is not available.
func (ss *sshSession) startNewRecording() (_ *recording, err error) {
recorders := ss.recorders()
recorders, onFailure := ss.recorders()
if len(recorders) == 0 {
return nil, errors.New("no recorders configured")
}
recorder := recorders[0]
if len(recorders) > 1 {
ss.logf("warning: multiple recorders configured, using first one: %v", recorder)
}
var w ssh.Window
if ptyReq, _, isPtyReq := ss.Pty(); isPtyReq {
@@ -1436,51 +1523,43 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
start: now,
}
pr, pw := io.Pipe()
// We want to use a background context for uploading and not ss.ctx.
// ss.ctx is closed when the session closes, but we don't want to break the upload at that time.
// Instead we want to wait for the session to close the writer when it finishes.
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", recorder.Addr(), recorder.Port()), pr)
wc, errChan, err := ss.connectToRecorder(ctx, recorders)
if err != nil {
pr.Close()
pw.Close()
return nil, err
}
// We want to wait for the server to respond with 100 Continue to notifiy us
// that it's ready to receive data. We do this to block the session from
// starting until the server is ready to receive data.
// It also allows the server to reject the request before we start sending
// data.
req.Header.Set("Expect", "100-continue")
go func() {
defer pw.Close()
ss.logf("starting asciinema recording to %s", recorder)
hc := ss.conn.srv.sessionRecordingClient()
resp, err := hc.Do(req)
if err != nil {
err := fmt.Errorf("recording: error sending recording: %w", err)
ss.logf("%v", err)
ss.cancelCtx(userVisibleError{
msg: "recording: error sending recording",
// TODO(catzkorn): notify control here.
if onFailure != nil && onFailure.RejectSessionWithMessage != "" {
ss.logf("recording: error starting recording (rejecting session): %v", err)
return nil, userVisibleError{
error: err,
msg: onFailure.RejectSessionWithMessage,
}
}
ss.logf("recording: error starting recording (failing open): %v", err)
return nil, nil
}
go func() {
err := <-errChan
if err == nil {
// Success.
return
}
// TODO(catzkorn): notify control here.
if onFailure != nil && onFailure.TerminateSessionWithMessage != "" {
ss.logf("recording: error uploading recording (closing session): %v", err)
ss.cancelCtx(userVisibleError{
error: err,
msg: onFailure.TerminateSessionWithMessage,
})
return
}
defer resp.Body.Close()
defer ss.cancelCtx(errors.New("recording: done"))
if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("recording: server responded with %s", resp.Status)
ss.logf("%v", err)
ss.cancelCtx(userVisibleError{
msg: "recording server responded with: " + resp.Status,
error: err,
})
}
ss.logf("recording: error uploading recording (failing open): %v", err)
}()
rec.out = pw
rec.out = wc
ch := CastHeader{
Version: 2,
@@ -1515,7 +1594,7 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
return nil, err
}
j = append(j, '\n')
if _, err := pw.Write(j); err != nil {
if _, err := rec.out.Write(j); err != nil {
if errors.Is(err, io.ErrClosedPipe) && ss.ctx.Err() != nil {
// If we got an io.ErrClosedPipe, it's likely because
// the recording server closed the connection on us. Return