ssh/tailssh: handle dialing multiple recorders and failing open

This adds support to try dialing out to multiple recorders each
with a 5s timeout and an overall 30s timeout. It also starts respecting
the actions `OnRecordingFailure` field if set, if it is not set
it fails open.

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali
2023-04-19 21:33:33 -07:00
committed by Maisem Ali
parent f66ddb544c
commit 7778d708a6
3 changed files with 259 additions and 76 deletions

View File

@@ -240,7 +240,7 @@ var (
)
func (ts *localState) Dialer() *tsdial.Dialer {
return nil
return &tsdial.Dialer{}
}
func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) {
@@ -338,8 +338,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
defer recordingServer.Close()
s := &server{
logf: t.Logf,
httpc: recordingServer.Client(),
logf: t.Logf,
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
@@ -348,6 +347,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
Recorders: []netip.AddrPort{
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},
@@ -374,7 +377,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
w.WriteHeader(http.StatusForbidden)
},
sshCommand: "echo hello",
wantClientOutput: "recording: server responded with 403 Forbidden\r\n",
wantClientOutput: "session rejected\r\n",
clientOutputMustNotContain: []string{"hello"},
},
@@ -386,7 +389,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
w.WriteHeader(http.StatusInternalServerError)
},
sshCommand: "echo hello && sleep 1 && echo world",
wantClientOutput: "\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n",
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
clientOutputMustNotContain: []string{"world"},
},
@@ -440,6 +443,103 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
}
}
func TestMultipleRecorders(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
}
done := make(chan struct{})
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer close(done)
io.ReadAll(r.Body)
w.WriteHeader(http.StatusOK)
}))
defer recordingServer.Close()
badRecorder, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
badRecorderAddr := badRecorder.Addr().String()
badRecorder.Close()
badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(500)
}))
defer badRecordingServer500.Close()
badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer badRecordingServer200.Close()
s := &server{
logf: t.Logf,
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
netip.MustParseAddrPort(badRecorderAddr),
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},
}
defer s.Shutdown()
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
sc, dc := memnet.NewTCPConn(src, dst, 1024)
const sshUser = "alice"
cfg := &gossh.ClientConfig{
User: sshUser,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
if err != nil {
t.Errorf("client: %v", err)
return
}
client := gossh.NewClient(c, chans, reqs)
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Errorf("client: %v", err)
return
}
defer session.Close()
t.Logf("client established session")
out, err := session.CombinedOutput("echo Ran echo!")
if err != nil {
t.Errorf("client: %v", err)
}
if string(out) != "Ran echo!\n" {
t.Errorf("client: unexpected output: %q", out)
}
}()
if err := s.HandleSSHConn(dc); err != nil {
t.Errorf("unexpected error: %v", err)
}
wg.Wait()
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for recording")
}
}
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
// when the client is not interactive (i.e. no PTY).
// It starts a local SSH server and a recording server. The recording server
@@ -464,8 +564,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
defer recordingServer.Close()
s := &server{
logf: logger.Discard,
httpc: recordingServer.Client(),
logf: logger.Discard,
lb: &localState{
sshEnabled: true,
matchingRule: newSSHRule(
@@ -474,6 +573,10 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
Recorders: []netip.AddrPort{
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "session rejected",
TerminateSessionWithMessage: "session terminated",
},
},
),
},