mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-14 15:07:55 +00:00
sessionrecording: implement v2 recording endpoint support (#14105)
The v2 endpoint supports HTTP/2 bidirectional streaming and acks for received bytes. This is used to detect when a recorder disappears to more quickly terminate the session. Updates https://github.com/tailscale/corp/issues/24023 Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
@@ -33,6 +33,8 @@ import (
|
||||
"time"
|
||||
|
||||
gossh "github.com/tailscale/golang-x-crypto/ssh"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/net/memnet"
|
||||
@@ -481,10 +483,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
var handler http.HandlerFunc
|
||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
handler(w, r)
|
||||
}))
|
||||
defer recordingServer.Close()
|
||||
})
|
||||
|
||||
s := &server{
|
||||
logf: t.Logf,
|
||||
@@ -533,9 +534,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
{
|
||||
name: "upload-fails-after-starting",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.(http.Flusher).Flush()
|
||||
r.Body.Read(make([]byte, 1))
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
},
|
||||
sshCommand: "echo hello && sleep 1 && echo world",
|
||||
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
|
||||
@@ -548,6 +550,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s.logf = t.Logf
|
||||
tstest.Replace(t, &handler, tt.handler)
|
||||
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
||||
var wg sync.WaitGroup
|
||||
@@ -597,12 +600,12 @@ func TestMultipleRecorders(t *testing.T) {
|
||||
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
||||
}
|
||||
done := make(chan struct{})
|
||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer close(done)
|
||||
io.ReadAll(r.Body)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer recordingServer.Close()
|
||||
w.(http.Flusher).Flush()
|
||||
io.ReadAll(r.Body)
|
||||
})
|
||||
badRecorder, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -610,15 +613,9 @@ func TestMultipleRecorders(t *testing.T) {
|
||||
badRecorderAddr := badRecorder.Addr().String()
|
||||
badRecorder.Close()
|
||||
|
||||
badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(500)
|
||||
}))
|
||||
defer badRecordingServer500.Close()
|
||||
|
||||
badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
defer badRecordingServer200.Close()
|
||||
badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
})
|
||||
|
||||
s := &server{
|
||||
logf: t.Logf,
|
||||
@@ -630,7 +627,6 @@ func TestMultipleRecorders(t *testing.T) {
|
||||
Recorders: []netip.AddrPort{
|
||||
netip.MustParseAddrPort(badRecorderAddr),
|
||||
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
|
||||
netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
|
||||
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
||||
},
|
||||
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
||||
@@ -701,19 +697,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
||||
}
|
||||
var recording []byte
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
defer cancel()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.(http.Flusher).Flush()
|
||||
|
||||
var err error
|
||||
recording, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer recordingServer.Close()
|
||||
})
|
||||
|
||||
s := &server{
|
||||
logf: logger.Discard,
|
||||
logf: t.Logf,
|
||||
lb: &localState{
|
||||
sshEnabled: true,
|
||||
matchingRule: newSSHRule(
|
||||
@@ -1299,3 +1297,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
|
||||
t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) {
|
||||
t.Errorf("v1 recording endpoint called")
|
||||
})
|
||||
mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
|
||||
mux.HandleFunc("POST /v2/record", handleRecord)
|
||||
|
||||
h2s := &http2.Server{}
|
||||
srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s))
|
||||
if err := http2.ConfigureServer(srv.Config, h2s); err != nil {
|
||||
t.Errorf("configuring HTTP/2 support in recording server: %v", err)
|
||||
}
|
||||
srv.Start()
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
||||
|
Reference in New Issue
Block a user