ssh/tailssh: add tests for recording failure

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-03-25 10:19:51 -07:00 committed by Maisem Ali
parent d5abdd915e
commit 5ba57e4661
2 changed files with 139 additions and 17 deletions

View File

@ -1077,6 +1077,13 @@ func (ss *sshSession) run() {
err := ss.launchProcess() err := ss.launchProcess()
if err != nil { if err != nil {
logf("start failed: %v", err.Error()) logf("start failed: %v", err.Error())
if errors.Is(err, context.Canceled) {
err := context.Cause(ss.ctx)
uve := userVisibleError{}
if errors.As(err, &uve) {
fmt.Fprintf(ss, "%s\r\n", uve)
}
}
ss.Exit(1) ss.Exit(1)
return return
} }
@ -1425,20 +1432,35 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
pw.Close() pw.Close()
return nil, err return nil, err
} }
// We want to wait for the server to respond with 100 Continue to notifiy us
// that it's ready to receive data. We do this to block the session from
// starting until the server is ready to receive data.
// It also allows the server to reject the request before we start sending
// data.
req.Header.Set("Expect", "100-continue")
go func() { go func() {
defer pw.Close() defer pw.Close()
ss.logf("starting asciinema recording to %s", recorder) ss.logf("starting asciinema recording to %s", recorder)
hc := ss.conn.srv.sessionRecordingClient() hc := ss.conn.srv.sessionRecordingClient()
resp, err := hc.Do(req) resp, err := hc.Do(req)
if err != nil { if err != nil {
ss.cancelCtx(err) err := fmt.Errorf("recording: error sending recording: %w", err)
ss.logf("recording: error sending recording to %s: %v", recorder, err) ss.logf("%v", err)
ss.cancelCtx(userVisibleError{
msg: "recording: error sending recording",
error: err,
})
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
defer ss.cancelCtx(errors.New("recording: done")) defer ss.cancelCtx(errors.New("recording: done"))
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
ss.logf("recording: error sending recording to %s: %v", recorder, resp.Status) err := fmt.Errorf("recording: server responded with %s", resp.Status)
ss.logf("%v", err)
ss.cancelCtx(userVisibleError{
msg: "recording server responded with: " + resp.Status,
error: err,
})
} }
}() }()

View File

@ -326,6 +326,108 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
} }
} }
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
}
var handler http.HandlerFunc
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler(w, r)
}))
defer recordingServer.Close()
s := &server{
logf: t.Logf,
httpc: recordingServer.Client(),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
},
},
),
},
}
defer s.Shutdown()
const sshUser = "alice"
cfg := &gossh.ClientConfig{
User: sshUser,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}
tests := []struct {
name string
handler func(w http.ResponseWriter, r *http.Request)
sshCommand string
wantClientOutput string
}{
{
name: "upload-denied",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
},
sshCommand: "echo hello",
wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
},
{
name: "upload-fails-after-starting",
handler: func(w http.ResponseWriter, r *http.Request) {
r.Body.Read(make([]byte, 1))
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusInternalServerError)
},
sshCommand: "echo hello && sleep 1 && echo world",
wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
},
}
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tstest.Replace(t, &handler, tt.handler)
sc, dc := memnet.NewTCPConn(src, dst, 1024)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
if err != nil {
t.Errorf("client: %v", err)
return
}
client := gossh.NewClient(c, chans, reqs)
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Errorf("client: %v", err)
return
}
defer session.Close()
t.Logf("client established session")
got, err := session.CombinedOutput(tt.sshCommand)
if err != nil {
t.Logf("client got: %q: %v", got, err)
} else {
t.Errorf("client did not get kicked out: %q", got)
}
if string(got) != tt.wantClientOutput {
t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
}
}()
if err := s.HandleSSHConn(dc); err != nil {
t.Errorf("unexpected error: %v", err)
}
wg.Wait()
})
}
}
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session // TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
// when the client is not interactive (i.e. no PTY). // when the client is not interactive (i.e. no PTY).
// It starts a local SSH server and a recording server. The recording server // It starts a local SSH server and a recording server. The recording server
@ -346,30 +448,28 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
w.WriteHeader(http.StatusOK)
})) }))
defer recordingServer.Close() defer recordingServer.Close()
state := &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
},
},
),
}
s := &server{ s := &server{
logf: t.Logf, logf: logger.Discard,
httpc: recordingServer.Client(), httpc: recordingServer.Client(),
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
},
},
),
},
} }
defer s.Shutdown() defer s.Shutdown()
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
sc, dc := memnet.NewTCPConn(src, dst, 1024) sc, dc := memnet.NewTCPConn(src, dst, 1024)
s.lb = state
const sshUser = "alice" const sshUser = "alice"
cfg := &gossh.ClientConfig{ cfg := &gossh.ClientConfig{