mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
ssh/tailssh: add tests for recording failure
Updates tailscale/corp#9967 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
d5abdd915e
commit
5ba57e4661
@ -1077,6 +1077,13 @@ func (ss *sshSession) run() {
|
|||||||
err := ss.launchProcess()
|
err := ss.launchProcess()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logf("start failed: %v", err.Error())
|
logf("start failed: %v", err.Error())
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
err := context.Cause(ss.ctx)
|
||||||
|
uve := userVisibleError{}
|
||||||
|
if errors.As(err, &uve) {
|
||||||
|
fmt.Fprintf(ss, "%s\r\n", uve)
|
||||||
|
}
|
||||||
|
}
|
||||||
ss.Exit(1)
|
ss.Exit(1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1425,20 +1432,35 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
|||||||
pw.Close()
|
pw.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// We want to wait for the server to respond with 100 Continue to notifiy us
|
||||||
|
// that it's ready to receive data. We do this to block the session from
|
||||||
|
// starting until the server is ready to receive data.
|
||||||
|
// It also allows the server to reject the request before we start sending
|
||||||
|
// data.
|
||||||
|
req.Header.Set("Expect", "100-continue")
|
||||||
go func() {
|
go func() {
|
||||||
defer pw.Close()
|
defer pw.Close()
|
||||||
ss.logf("starting asciinema recording to %s", recorder)
|
ss.logf("starting asciinema recording to %s", recorder)
|
||||||
hc := ss.conn.srv.sessionRecordingClient()
|
hc := ss.conn.srv.sessionRecordingClient()
|
||||||
resp, err := hc.Do(req)
|
resp, err := hc.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ss.cancelCtx(err)
|
err := fmt.Errorf("recording: error sending recording: %w", err)
|
||||||
ss.logf("recording: error sending recording to %s: %v", recorder, err)
|
ss.logf("%v", err)
|
||||||
|
ss.cancelCtx(userVisibleError{
|
||||||
|
msg: "recording: error sending recording",
|
||||||
|
error: err,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
defer ss.cancelCtx(errors.New("recording: done"))
|
defer ss.cancelCtx(errors.New("recording: done"))
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
ss.logf("recording: error sending recording to %s: %v", recorder, resp.Status)
|
err := fmt.Errorf("recording: server responded with %s", resp.Status)
|
||||||
|
ss.logf("%v", err)
|
||||||
|
ss.cancelCtx(userVisibleError{
|
||||||
|
msg: "recording server responded with: " + resp.Status,
|
||||||
|
error: err,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -326,6 +326,108 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
|
||||||
|
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
var handler http.HandlerFunc
|
||||||
|
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
handler(w, r)
|
||||||
|
}))
|
||||||
|
defer recordingServer.Close()
|
||||||
|
|
||||||
|
s := &server{
|
||||||
|
logf: t.Logf,
|
||||||
|
httpc: recordingServer.Client(),
|
||||||
|
lb: &localState{
|
||||||
|
sshEnabled: true,
|
||||||
|
matchingRule: newSSHRule(
|
||||||
|
&tailcfg.SSHAction{
|
||||||
|
Accept: true,
|
||||||
|
Recorders: []netip.AddrPort{
|
||||||
|
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer s.Shutdown()
|
||||||
|
|
||||||
|
const sshUser = "alice"
|
||||||
|
cfg := &gossh.ClientConfig{
|
||||||
|
User: sshUser,
|
||||||
|
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
handler func(w http.ResponseWriter, r *http.Request)
|
||||||
|
sshCommand string
|
||||||
|
wantClientOutput string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "upload-denied",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
},
|
||||||
|
sshCommand: "echo hello",
|
||||||
|
wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upload-fails-after-starting",
|
||||||
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r.Body.Read(make([]byte, 1))
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
},
|
||||||
|
sshCommand: "echo hello && sleep 1 && echo world",
|
||||||
|
wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tstest.Replace(t, &handler, tt.handler)
|
||||||
|
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client := gossh.NewClient(c, chans, reqs)
|
||||||
|
defer client.Close()
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
t.Logf("client established session")
|
||||||
|
got, err := session.CombinedOutput(tt.sshCommand)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("client got: %q: %v", got, err)
|
||||||
|
} else {
|
||||||
|
t.Errorf("client did not get kicked out: %q", got)
|
||||||
|
}
|
||||||
|
if string(got) != tt.wantClientOutput {
|
||||||
|
t.Errorf("client got %q, want %q", got, tt.wantClientOutput)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if err := s.HandleSSHConn(dc); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
|
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
|
||||||
// when the client is not interactive (i.e. no PTY).
|
// when the client is not interactive (i.e. no PTY).
|
||||||
// It starts a local SSH server and a recording server. The recording server
|
// It starts a local SSH server and a recording server. The recording server
|
||||||
@ -346,30 +448,28 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
|||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
}))
|
||||||
defer recordingServer.Close()
|
defer recordingServer.Close()
|
||||||
|
|
||||||
state := &localState{
|
|
||||||
sshEnabled: true,
|
|
||||||
matchingRule: newSSHRule(
|
|
||||||
&tailcfg.SSHAction{
|
|
||||||
Accept: true,
|
|
||||||
Recorders: []netip.AddrPort{
|
|
||||||
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
}
|
|
||||||
s := &server{
|
s := &server{
|
||||||
logf: t.Logf,
|
logf: logger.Discard,
|
||||||
httpc: recordingServer.Client(),
|
httpc: recordingServer.Client(),
|
||||||
|
lb: &localState{
|
||||||
|
sshEnabled: true,
|
||||||
|
matchingRule: newSSHRule(
|
||||||
|
&tailcfg.SSHAction{
|
||||||
|
Accept: true,
|
||||||
|
Recorders: []netip.AddrPort{
|
||||||
|
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
defer s.Shutdown()
|
defer s.Shutdown()
|
||||||
|
|
||||||
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
|
||||||
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
||||||
s.lb = state
|
|
||||||
|
|
||||||
const sshUser = "alice"
|
const sshUser = "alice"
|
||||||
cfg := &gossh.ClientConfig{
|
cfg := &gossh.ClientConfig{
|
||||||
|
Loading…
Reference in New Issue
Block a user