From 69c27b23cb8ae46e6f0845817e879d636f26e70a Mon Sep 17 00:00:00 2001 From: Irbe Krumina Date: Fri, 26 Jul 2024 20:05:49 +0300 Subject: [PATCH] cmd/k8s-operator,k8s-operator/session-recording: implement support for WebSocket protocol Kubernetes currently supports two streaming protocols- SPDY and WebSockets. WebSockets are replacing SPDY, see https://github.com/kubernetes/enhancements/issues/4006 Our 'kubectl exec' session recording was only supporting SPDY. This change: - adds functionality to parse streaming sessions over WebSockets - for sessions that the API server proxy has determined need to be recorded, determines if the session is over SPDY or WebSockets and invoke the relevant parser accordingly - refactors the session recording logic into its own package Updates tailscale/corp#19821 Signed-off-by: Irbe Krumina --- cmd/k8s-operator/depaware.txt | 7 +- cmd/k8s-operator/proxy.go | 81 ++++-- k8s-operator/session-recording/fakes/fakes.go | 117 ++++++++ .../session-recording/hijacker.go | 130 +++++---- .../session-recording/hijacker_test.go | 32 ++- .../session-recording/spdy/conn.go | 37 ++- .../session-recording/spdy/conn_test.go | 123 ++------- .../session-recording/spdy/frame.go | 2 +- .../session-recording/spdy/frame_test.go | 2 +- .../session-recording/spdy}/zlib-reader.go | 2 +- .../session-recording/tsrecorder/header.go | 54 ++++ .../tsrecorder/tsrecorder.go | 56 ++-- k8s-operator/session-recording/ws/conn.go | 244 +++++++++++++++++ .../session-recording/ws/conn_test.go | 171 ++++++++++++ k8s-operator/session-recording/ws/message.go | 253 ++++++++++++++++++ .../session-recording/ws/message_test.go | 125 +++++++++ 16 files changed, 1192 insertions(+), 244 deletions(-) create mode 100644 k8s-operator/session-recording/fakes/fakes.go rename cmd/k8s-operator/spdy-hijacker.go => k8s-operator/session-recording/hijacker.go (70%) rename cmd/k8s-operator/spdy-hijacker_test.go => k8s-operator/session-recording/hijacker_test.go (71%) rename cmd/k8s-operator/spdy-remote-conn-recorder.go => k8s-operator/session-recording/spdy/conn.go (89%) rename cmd/k8s-operator/spdy-remote-conn-recorder_test.go => k8s-operator/session-recording/spdy/conn_test.go (76%) rename cmd/k8s-operator/spdy-frame.go => k8s-operator/session-recording/spdy/frame.go (99%) rename cmd/k8s-operator/spdy-frame_test.go => k8s-operator/session-recording/spdy/frame_test.go (99%) rename {cmd/k8s-operator => k8s-operator/session-recording/spdy}/zlib-reader.go (99%) create mode 100644 k8s-operator/session-recording/tsrecorder/header.go rename cmd/k8s-operator/recorder.go => k8s-operator/session-recording/tsrecorder/tsrecorder.go (57%) create mode 100644 k8s-operator/session-recording/ws/conn.go create mode 100644 k8s-operator/session-recording/ws/conn_test.go create mode 100644 k8s-operator/session-recording/ws/message.go create mode 100644 k8s-operator/session-recording/ws/message_test.go diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index b5c0ed517..c12fd89b7 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -423,6 +423,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+ k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+ k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names + k8s.io/apimachinery/pkg/util/remotecommand from tailscale.com/k8s-operator/session-recording/ws k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+ k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+ @@ -692,6 +693,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/k8s-operator from tailscale.com/cmd/k8s-operator tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ + tailscale.com/k8s-operator/session-recording from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/session-recording/spdy from tailscale.com/k8s-operator/session-recording + tailscale.com/k8s-operator/session-recording/tsrecorder from tailscale.com/k8s-operator/session-recording+ + tailscale.com/k8s-operator/session-recording/ws from tailscale.com/k8s-operator/session-recording tailscale.com/kube from tailscale.com/cmd/k8s-operator+ tailscale.com/licenses from tailscale.com/client/web tailscale.com/log/filelogger from tailscale.com/logpolicy @@ -752,7 +757,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/tka from tailscale.com/client/tailscale+ W tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tsd from tailscale.com/ipn/ipnlocal+ - tailscale.com/tsnet from tailscale.com/cmd/k8s-operator + tailscale.com/tsnet from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/derp+ diff --git a/cmd/k8s-operator/proxy.go b/cmd/k8s-operator/proxy.go index 258a958fa..45b048f6f 100644 --- a/cmd/k8s-operator/proxy.go +++ b/cmd/k8s-operator/proxy.go @@ -22,6 +22,7 @@ import ( "k8s.io/client-go/transport" "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" + sessionrecording "tailscale.com/k8s-operator/session-recording" tskube "tailscale.com/kube" "tailscale.com/ssh/tailssh" "tailscale.com/tailcfg" @@ -36,12 +37,6 @@ var whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) var ( // counterNumRequestsproxies counts the number of API server requests proxied via this proxy. counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied") - - // counterSessionRecordingsAttempted counts the number of session recording attempts. - counterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy__session_recordings_attempted") - - // counterSessionRecordingsUploaded counts the number of successfully uploaded session recordings. - counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded") ) type apiServerProxyMode int @@ -173,7 +168,9 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL mux := http.NewServeMux() mux.HandleFunc("/", ap.serveDefault) - mux.HandleFunc("/api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExec) + mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY) + + mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS) hs := &http.Server{ // Kubernetes uses SPDY for exec and port-forward, however SPDY is @@ -214,9 +211,10 @@ func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) { ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) } -// serveExec serves 'kubectl exec' requests, optionally configuring the kubectl -// exec sessions to be recorded. -func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { +// serveExecWS serves 'kubectl exec' requests, optionally configuring the +// kubectl exec sessions to be recorded. It should only be called for requests +// for sessions that use WebSockets protocol for streaming. +func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) { who, err := ap.whoIs(r) if err != nil { ap.authError(w, err) @@ -232,14 +230,59 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) return } - counterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded + sessionrecording.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded if !failOpen && len(addrs) == 0 { msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available." ap.log.Error(msg) http.Error(w, msg, http.StatusForbidden) return } - if r.Method != "POST" || r.Header.Get("Upgrade") != "SPDY/3.1" { + if h := r.Header.Get("Upgrade"); h != "websocket" { + msg := fmt.Sprintf("[unexpected] 'kubectl exec' session was initiated for WebSocket protocol, but the request does not contain expected upgrade header, wants: 'websocket', got: %q", h) + if failOpen { + msg = msg + "; failure mode is 'fail open'; continuing session without recording." + ap.log.Warn(msg) + ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) + return + } + ap.log.Error(msg) + msg += "; failure mode is 'fail closed'; closing connection." + http.Error(w, msg, 403) + return + } else { + ap.log.Debugf("detected 'kubectl exec' session streaming protocol is WebSockets") + } + wsH := sessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), sessionrecording.WebSocketsProtocol, addrs, failOpen, tailssh.ConnectToRecorder, ap.log) + + ap.rp.ServeHTTP(wsH, r.WithContext(whoIsKey.WithValue(r.Context(), who))) +} + +// serveExecSPDY serves 'kubectl exec' requests, optionally configuring the +// kubectl exec sessions to be recorded. It should only be called for requests +// that initate 'kubectl exec' sessions using the SPDY protocol for streaming. +func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) { + who, err := ap.whoIs(r) + if err != nil { + ap.authError(w, err) + return + } + counterNumRequestsProxied.Add(1) + failOpen, addrs, err := determineRecorderConfig(who) + if err != nil { + ap.log.Errorf("error trying to determine whether the 'kubectl exec' session needs to be recorded: %v", err) + return + } + if failOpen && len(addrs) == 0 { // will not record + ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) + return + } + if !failOpen && len(addrs) == 0 { + msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available." + ap.log.Error(msg) + http.Error(w, msg, 403) + return + } + if r.Header.Get("Upgrade") != "SPDY/3.1" { msg := "'kubectl exec' session recording is configured, but the request is not over SPDY. Session recording is currently only supported for SPDY based clients" if failOpen { msg = msg + "; failure mode is 'fail open'; continuing session without recording." @@ -252,19 +295,7 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { http.Error(w, msg, http.StatusForbidden) return } - spdyH := &spdyHijacker{ - ts: ap.ts, - req: r, - who: who, - ResponseWriter: w, - log: ap.log, - pod: r.PathValue("pod"), - ns: r.PathValue("namespace"), - addrs: addrs, - failOpen: failOpen, - connectToRecorder: tailssh.ConnectToRecorder, - } - + spdyH := sessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), sessionrecording.SPDYProtocol, addrs, failOpen, tailssh.ConnectToRecorder, ap.log) ap.rp.ServeHTTP(spdyH, r.WithContext(whoIsKey.WithValue(r.Context(), who))) } diff --git a/k8s-operator/session-recording/fakes/fakes.go b/k8s-operator/session-recording/fakes/fakes.go new file mode 100644 index 000000000..9f5c349d4 --- /dev/null +++ b/k8s-operator/session-recording/fakes/fakes.go @@ -0,0 +1,117 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// package fakes contains utils for testing session recording behaviour. +package fakes + +import ( + "bytes" + "encoding/json" + "net" + "sync" + "testing" + + "tailscale.com/k8s-operator/session-recording/tsrecorder" + "tailscale.com/tstime" +) + +func New(conn net.Conn, wb bytes.Buffer, rb bytes.Buffer, closed bool) net.Conn { + return &TestConn{ + Conn: conn, + writeBuf: wb, + readBuf: rb, + closed: closed, + } +} + +type TestConn struct { + net.Conn + // writeBuf contains whatever was send to the conn via Write. + writeBuf bytes.Buffer + // readBuf contains whatever was sent to the conn via Read. + readBuf bytes.Buffer + sync.RWMutex // protects the following + closed bool +} + +var _ net.Conn = &TestConn{} + +func (tc *TestConn) Read(b []byte) (int, error) { + return tc.readBuf.Read(b) +} + +func (tc *TestConn) Write(b []byte) (int, error) { + return tc.writeBuf.Write(b) +} + +func (tc *TestConn) Close() error { + tc.Lock() + defer tc.Unlock() + tc.closed = true + return nil +} + +func (tc *TestConn) IsClosed() bool { + tc.Lock() + defer tc.Unlock() + return tc.closed +} + +func (tc *TestConn) WriteBufBytes() []byte { + return tc.writeBuf.Bytes() +} + +func (tc *TestConn) ResetReadBuf() { + tc.readBuf.Reset() +} + +func (tc *TestConn) WriteReadBufBytes(b []byte) error { + _, err := tc.readBuf.Write(b) + return err +} + +type TestSessionRecorder struct { + // buf holds data that was sent to the session recorder. + buf bytes.Buffer +} + +func (t *TestSessionRecorder) Write(b []byte) (int, error) { + return t.buf.Write(b) +} + +func (t *TestSessionRecorder) Close() error { + t.buf.Reset() + return nil +} + +func (t *TestSessionRecorder) Bytes() []byte { + return t.buf.Bytes() +} + +func CastLine(t *testing.T, p []byte, clock tstime.Clock) []byte { + t.Helper() + j, err := json.Marshal([]any{ + clock.Now().Sub(clock.Now()).Seconds(), + "o", + string(p), + }) + if err != nil { + t.Fatalf("error marshalling cast line: %v", err) + } + return append(j, '\n') +} + +func AsciinemaResizeMsg(t *testing.T, width, height int) []byte { + t.Helper() + ch := tsrecorder.CastHeader{ + Width: width, + Height: height, + } + bs, err := json.Marshal(ch) + if err != nil { + t.Fatalf("error marshalling CastHeader: %v", err) + } + return append(bs, '\n') +} diff --git a/cmd/k8s-operator/spdy-hijacker.go b/k8s-operator/session-recording/hijacker.go similarity index 70% rename from cmd/k8s-operator/spdy-hijacker.go rename to k8s-operator/session-recording/hijacker.go index f74771e42..bbaee3ba7 100644 --- a/cmd/k8s-operator/spdy-hijacker.go +++ b/k8s-operator/session-recording/hijacker.go @@ -3,12 +3,15 @@ //go:build !plan9 -package main +// Package sessionrecording has functionality for recording 'kubectl exec' +// sessions and sending to a tsrecorder. +package sessionrecording import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "net" @@ -16,20 +19,52 @@ import ( "net/netip" "strings" - "github.com/pkg/errors" "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" + "tailscale.com/k8s-operator/session-recording/spdy" + "tailscale.com/k8s-operator/session-recording/tsrecorder" + "tailscale.com/k8s-operator/session-recording/ws" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstime" + "tailscale.com/util/clientmetric" "tailscale.com/util/multierr" ) +const ( + SPDYProtocol = "SPDY" + WebSocketsProtocol = "WebSockets" +) + +var ( + // counterSessionRecordingsAttempted counts the number of session recording attempts. + CounterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy__session_recordings_attempted") + + // counterSessionRecordingsUploaded counts the number of successfully uploaded session recordings. + CounterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded") +) + +func New(ts *tsnet.Server, req *http.Request, who *apitype.WhoIsResponse, w http.ResponseWriter, pod, ns string, proto protocol, addrs []netip.AddrPort, failOpen bool, connFunc RecorderDialFn, log *zap.SugaredLogger) *SpdyHijacker { + return &SpdyHijacker{ + ts: ts, + req: req, + who: who, + ResponseWriter: w, + pod: pod, + ns: ns, + addrs: addrs, + failOpen: failOpen, + connectToRecorder: connFunc, + proto: proto, + log: log, + } +} + // spdyHijacker implements [net/http.Hijacker] interface. // It must be configured with an http request for a 'kubectl exec' session that // needs to be recorded. It knows how to hijack the connection and configure for // the session contents to be sent to a tsrecorder instance. -type spdyHijacker struct { +type SpdyHijacker struct { http.ResponseWriter ts *tsnet.Server req *http.Request @@ -40,8 +75,13 @@ type spdyHijacker struct { addrs []netip.AddrPort // tsrecorder addresses failOpen bool // whether to fail open if recording fails connectToRecorder RecorderDialFn + proto protocol } +// protocol is the streaming protocol of the hijacked session. Supported +// protocols are SPDY and WebSockets. +type protocol string + // RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder // addresses. It tries to connect to recorder endpoints one by one, till one // connection succeeds. In case of success, returns a list with a single @@ -51,7 +91,7 @@ type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context // Hijack hijacks a 'kubectl exec' session and configures for the session // contents to be sent to a recorder. -func (h *spdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { +func (h *SpdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { h.log.Infof("recorder addrs: %v, failOpen: %v", h.addrs, h.failOpen) reqConn, brw, err := h.ResponseWriter.(http.Hijacker).Hijack() if err != nil { @@ -69,7 +109,7 @@ func (h *spdyHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { // spdyHijacker.addrs. Returns conn from provided opts, wrapped in recording // logic. If connecting to the recorder fails or an error is received during the // session and spdyHijacker.failOpen is false, connection will be closed. -func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, error) { +func (h *SpdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, error) { const ( // https://docs.asciinema.org/manual/asciicast/v2/ asciicastv2 = 2 @@ -91,30 +131,20 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C } return nil, errors.New(msg) } - // TODO (irbekrm): log which recorder h.log.Info("successfully connected to a session recorder") wc = rw cl := tstime.DefaultClock{} - lc := &spdyRemoteConnRecorder{ - log: h.log, - Conn: conn, - rec: &recorder{ - start: cl.Now(), - clock: cl, - failOpen: h.failOpen, - conn: wc, - }, - } + rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen) qp := h.req.URL.Query() - ch := CastHeader{ + ch := tsrecorder.CastHeader{ Version: asciicastv2, - Timestamp: lc.rec.start.Unix(), + Timestamp: cl.Now().Unix(), Command: strings.Join(qp["command"], " "), SrcNode: strings.TrimSuffix(h.who.Node.Name, "."), SrcNodeID: h.who.Node.StableID, - Kubernetes: &Kubernetes{ + Kubernetes: &tsrecorder.Kubernetes{ PodName: h.pod, Namespace: h.ns, Container: strings.Join(qp["container"], " "), @@ -126,7 +156,16 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C } else { ch.SrcNodeTags = h.who.Node.Tags } - lc.ch = ch + var lc net.Conn + switch h.proto { + case SPDYProtocol: + lc = spdy.New(conn, rec, ch, h.log) + case WebSocketsProtocol: + lc = ws.New(conn, rec, ch, h.log) + default: + return nil, fmt.Errorf("unknown protocol: %s", h.proto) + } + go func() { var err error select { @@ -135,7 +174,7 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C case err = <-errChan: } if err == nil { - counterSessionRecordingsUploaded.Add(1) + CounterSessionRecordingsUploaded.Add(1) h.log.Info("finished uploading the recording") return } @@ -147,62 +186,15 @@ func (h *spdyHijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.C } msg += "; failure mode set to 'fail closed'; closing connection" h.log.Error(msg) - lc.failed = true - // TODO (irbekrm): write a message to the client if err := lc.Close(); err != nil { h.log.Infof("error closing recorder connections: %v", err) } return }() + return lc, nil } -// CastHeader is the asciicast header to be sent to the recorder at the start of -// the recording of a session. -// https://docs.asciinema.org/manual/asciicast/v2/#header -type CastHeader struct { - // Version is the asciinema file format version. - Version int `json:"version"` - - // Width is the terminal width in characters. - Width int `json:"width"` - - // Height is the terminal height in characters. - Height int `json:"height"` - - // Timestamp is the unix timestamp of when the recording started. - Timestamp int64 `json:"timestamp"` - - // Tailscale-specific fields: SrcNode is the full MagicDNS name of the - // tailnet node originating the connection, without the trailing dot. - SrcNode string `json:"srcNode"` - - // SrcNodeID is the node ID of the tailnet node originating the connection. - SrcNodeID tailcfg.StableNodeID `json:"srcNodeID"` - - // SrcNodeTags is the list of tags on the node originating the connection (if any). - SrcNodeTags []string `json:"srcNodeTags,omitempty"` - - // SrcNodeUserID is the user ID of the node originating the connection (if not tagged). - SrcNodeUserID tailcfg.UserID `json:"srcNodeUserID,omitempty"` // if not tagged - - // SrcNodeUser is the LoginName of the node originating the connection (if not tagged). - SrcNodeUser string `json:"srcNodeUser,omitempty"` - - Command string - - // Kubernetes-specific fields: - Kubernetes *Kubernetes `json:"kubernetes,omitempty"` -} - -// Kubernetes contains 'kubectl exec' session specific information for -// tsrecorder. -type Kubernetes struct { - PodName string - Namespace string - Container string -} - func closeConnWithWarning(conn net.Conn, msg string) error { b := io.NopCloser(bytes.NewBuffer([]byte(msg))) resp := http.Response{Status: http.StatusText(http.StatusForbidden), StatusCode: http.StatusForbidden, Body: b} diff --git a/cmd/k8s-operator/spdy-hijacker_test.go b/k8s-operator/session-recording/hijacker_test.go similarity index 71% rename from cmd/k8s-operator/spdy-hijacker_test.go rename to k8s-operator/session-recording/hijacker_test.go index 7ac79d7f0..cfc694d26 100644 --- a/cmd/k8s-operator/spdy-hijacker_test.go +++ b/k8s-operator/session-recording/hijacker_test.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package sessionrecording import ( "context" @@ -19,6 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" + "tailscale.com/k8s-operator/session-recording/fakes" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -34,39 +35,49 @@ func Test_SPDYHijacker(t *testing.T) { failOpen bool failRecorderConnect bool // fail initial connect to the recorder failRecorderConnPostConnect bool // send error down the error channel + proto protocol wantsConnClosed bool wantsSetupErr bool }{ { - name: "setup succeeds, conn stays open", + name: "spdy_setup_succeeds_conn_stays_open", + proto: SPDYProtocol, }, { - name: "setup fails, policy is to fail open, conn stays open", + name: "ws_setup_succeeds_conn_stays_open", + proto: WebSocketsProtocol, + }, + { + name: "setup_fails_policy_is_to_fail_open_conn_stays_open", failOpen: true, failRecorderConnect: true, + proto: SPDYProtocol, }, { - name: "setup fails, policy is to fail closed, conn is closed", + name: "setup_fails_policy_is_to_fail_closed_conn_is_closed", failRecorderConnect: true, wantsSetupErr: true, wantsConnClosed: true, + proto: SPDYProtocol, }, { - name: "connection fails post-initial connect, policy is to fail open, conn stays open", + name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open", failRecorderConnPostConnect: true, failOpen: true, + proto: SPDYProtocol, }, { - name: "connection fails post-initial connect, policy is to fail closed, conn is closed", + name: "connection_fails_post-initial_connect_policy_is_to_fail_closed_conn_is_closed", failRecorderConnPostConnect: true, wantsConnClosed: true, + proto: SPDYProtocol, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := &testConn{} + tc := &fakes.TestConn{} ch := make(chan error) - h := &spdyHijacker{ + h := &SpdyHijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") @@ -78,6 +89,7 @@ func Test_SPDYHijacker(t *testing.T) { log: zl.Sugar(), ts: &tsnet.Server{}, req: &http.Request{URL: &url.URL{}}, + proto: tt.proto, } ctx := context.Background() _, err := h.setUpRecording(ctx, tc) @@ -98,8 +110,8 @@ func Test_SPDYHijacker(t *testing.T) { // (test that connection remains open over some period // of time). if err := tstest.WaitFor(timeout, func() (err error) { - if tt.wantsConnClosed != tc.isClosed() { - return fmt.Errorf("got connection state: %t, wants connection state: %t", tc.isClosed(), tt.wantsConnClosed) + if tt.wantsConnClosed != tc.IsClosed() { + return fmt.Errorf("got conIection state: %t, wants connection state: %t", tc.IsClosed(), tt.wantsConnClosed) } return nil }); err != nil { diff --git a/cmd/k8s-operator/spdy-remote-conn-recorder.go b/k8s-operator/session-recording/spdy/conn.go similarity index 89% rename from cmd/k8s-operator/spdy-remote-conn-recorder.go rename to k8s-operator/session-recording/spdy/conn.go index 563b2a241..af27f27e6 100644 --- a/cmd/k8s-operator/spdy-remote-conn-recorder.go +++ b/k8s-operator/session-recording/spdy/conn.go @@ -3,7 +3,9 @@ //go:build !plan9 -package main +// Package spdy has functionality to parse 'kubectl exec' sessions streamed over +// SPDY. +package spdy import ( "bytes" @@ -15,18 +17,30 @@ import ( "sync" "sync/atomic" + "tailscale.com/k8s-operator/session-recording/tsrecorder" + "go.uber.org/zap" corev1 "k8s.io/api/core/v1" ) +func New(conn net.Conn, rec *tsrecorder.Client, ch tsrecorder.CastHeader, log *zap.SugaredLogger) net.Conn { + return &spdyRemoteConnRecorder{ + Conn: conn, + rec: rec, + ch: ch, + log: log, + } + +} + // spdyRemoteConnRecorder is a wrapper around net.Conn. It reads the bytestream // for a 'kubectl exec' session, sends session recording data to the configured // recorder and forwards the raw bytes to the original destination. type spdyRemoteConnRecorder struct { net.Conn // rec knows how to send data written to it to a tsrecorder instance. - rec *recorder - ch CastHeader + rec *tsrecorder.Client + ch tsrecorder.CastHeader stdoutStreamID atomic.Uint32 stderrStreamID atomic.Uint32 @@ -34,7 +48,6 @@ type spdyRemoteConnRecorder struct { wmu sync.Mutex // sequences writes closed bool - failed bool rmu sync.Mutex // sequences reads writeCastHeaderOnce sync.Once @@ -78,9 +91,9 @@ func (c *spdyRemoteConnRecorder) Read(b []byte) (int, error) { switch sf.StreamID { case c.resizeStreamID.Load(): var err error - var msg spdyResizeMsg + var msg tsrecorder.ResizeMsg if err = json.Unmarshal(sf.Payload, &msg); err != nil { - return 0, fmt.Errorf("error umarshalling resize msg: %w", err) + return 0, err } c.ch.Width = msg.Width c.ch.Height = msg.Height @@ -127,13 +140,14 @@ func (c *spdyRemoteConnRecorder) Write(b []byte) (int, error) { case c.stdoutStreamID.Load(), c.stderrStreamID.Load(): var err error c.writeCastHeaderOnce.Do(func() { + var j []byte j, err = json.Marshal(c.ch) if err != nil { return } j = append(j, '\n') - err = c.rec.writeCastLine(j) + err = c.rec.WriteCastLine(j) if err != nil { c.log.Errorf("received error from recorder: %v", err) } @@ -157,7 +171,9 @@ func (c *spdyRemoteConnRecorder) Close() error { if c.closed { return nil } - if !c.failed && c.writeBuf.Len() > 0 { + // TODO: only do this if this is a normal closure rather than the + // reocrding has failed. + if c.writeBuf.Len() > 0 { c.Conn.Write(c.writeBuf.Bytes()) } c.writeBuf.Reset() @@ -187,8 +203,3 @@ func (c *spdyRemoteConnRecorder) storeStreamID(sf spdyFrame, header http.Header) c.resizeStreamID.Store(id) } } - -type spdyResizeMsg struct { - Width int `json:"width"` - Height int `json:"height"` -} diff --git a/cmd/k8s-operator/spdy-remote-conn-recorder_test.go b/k8s-operator/session-recording/spdy/conn_test.go similarity index 76% rename from cmd/k8s-operator/spdy-remote-conn-recorder_test.go rename to k8s-operator/session-recording/spdy/conn_test.go index 95f5a8bfc..ce8c9ae49 100644 --- a/cmd/k8s-operator/spdy-remote-conn-recorder_test.go +++ b/k8s-operator/session-recording/spdy/conn_test.go @@ -3,19 +3,17 @@ //go:build !plan9 -package main +package spdy import ( - "bytes" "encoding/json" - "net" "reflect" - "sync" "testing" "go.uber.org/zap" + "tailscale.com/k8s-operator/session-recording/fakes" + "tailscale.com/k8s-operator/session-recording/tsrecorder" "tailscale.com/tstest" - "tailscale.com/tstime" ) // Test_Writes tests that 1 or more Write calls to spdyRemoteConnRecorder @@ -56,13 +54,13 @@ func Test_Writes(t *testing.T) { name: "single_write_stdout_data_frame_with_payload", inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), + wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, { name: "single_write_stderr_data_frame_with_payload", inputs: [][]byte{{0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, wantForwarded: []byte{0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), + wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, { name: "single_data_frame_unknow_stream_with_payload", @@ -73,13 +71,13 @@ func Test_Writes(t *testing.T) { name: "control_frame_and_data_frame_split_across_two_writes", inputs: [][]byte{{0x80, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, {0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, wantForwarded: []byte{0x80, 0x3, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), + wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, { name: "single_first_write_stdout_data_frame_with_payload", inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: append(asciinemaResizeMsg(t, 10, 20), castLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...), + wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...), width: 10, height: 20, firstWrite: true, @@ -87,25 +85,21 @@ func Test_Writes(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := &testConn{} - sr := &testSessionRecorder{} - rec := &recorder{ - conn: sr, - clock: cl, - start: cl.Now(), - } + tc := &fakes.TestConn{} + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true) c := &spdyRemoteConnRecorder{ Conn: tc, log: zl.Sugar(), rec: rec, - ch: CastHeader{ + ch: tsrecorder.CastHeader{ Width: tt.width, Height: tt.height, }, } if !tt.firstWrite { - // this test case does not intend to test that cast header gets written once + // This test case does not intend to test that cast header gets written once. c.writeCastHeaderOnce.Do(func() {}) } @@ -118,13 +112,13 @@ func Test_Writes(t *testing.T) { } // Assert that the expected bytes have been forwarded to the original destination. - gotForwarded := tc.writeBuf.Bytes() + gotForwarded := tc.WriteBufBytes() if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) { t.Errorf("expected bytes not forwarded, wants\n%v\ngot\n%v", tt.wantForwarded, gotForwarded) } // Assert that the expected bytes have been forwarded to the session recorder. - gotRecorded := sr.buf.Bytes() + gotRecorded := sr.Bytes() if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded) } @@ -197,13 +191,9 @@ func Test_Reads(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := &testConn{} - sr := &testSessionRecorder{} - rec := &recorder{ - conn: sr, - clock: cl, - start: cl.Now(), - } + tc := &fakes.TestConn{} + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true) c := &spdyRemoteConnRecorder{ Conn: tc, log: zl.Sugar(), @@ -213,9 +203,8 @@ func Test_Reads(t *testing.T) { for i, input := range tt.inputs { c.zlibReqReader = reader - tc.readBuf.Reset() - _, err := tc.readBuf.Write(input) - if err != nil { + tc.ResetReadBuf() + if err := tc.WriteReadBufBytes(input); err != nil { t.Fatalf("writing bytes to test conn: %v", err) } _, err = c.Read(make([]byte, len(input))) @@ -244,83 +233,11 @@ func Test_Reads(t *testing.T) { } } -func castLine(t *testing.T, p []byte, clock tstime.Clock) []byte { - t.Helper() - j, err := json.Marshal([]any{ - clock.Now().Sub(clock.Now()).Seconds(), - "o", - string(p), - }) - if err != nil { - t.Fatalf("error marshalling cast line: %v", err) - } - return append(j, '\n') -} - func resizeMsgBytes(t *testing.T, width, height int) []byte { t.Helper() - bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height}) + bs, err := json.Marshal(tsrecorder.ResizeMsg{Width: width, Height: height}) if err != nil { t.Fatalf("error marshalling resizeMsg: %v", err) } return bs } - -func asciinemaResizeMsg(t *testing.T, width, height int) []byte { - t.Helper() - ch := CastHeader{ - Width: width, - Height: height, - } - bs, err := json.Marshal(ch) - if err != nil { - t.Fatalf("error marshalling CastHeader: %v", err) - } - return append(bs, '\n') -} - -type testConn struct { - net.Conn - // writeBuf contains whatever was send to the conn via Write. - writeBuf bytes.Buffer - // readBuf contains whatever was sent to the conn via Read. - readBuf bytes.Buffer - sync.RWMutex // protects the following - closed bool -} - -var _ net.Conn = &testConn{} - -func (tc *testConn) Read(b []byte) (int, error) { - return tc.readBuf.Read(b) -} - -func (tc *testConn) Write(b []byte) (int, error) { - return tc.writeBuf.Write(b) -} - -func (tc *testConn) Close() error { - tc.Lock() - defer tc.Unlock() - tc.closed = true - return nil -} -func (tc *testConn) isClosed() bool { - tc.Lock() - defer tc.Unlock() - return tc.closed -} - -type testSessionRecorder struct { - // buf holds data that was sent to the session recorder. - buf bytes.Buffer -} - -func (t *testSessionRecorder) Write(b []byte) (int, error) { - return t.buf.Write(b) -} - -func (t *testSessionRecorder) Close() error { - t.buf.Reset() - return nil -} diff --git a/cmd/k8s-operator/spdy-frame.go b/k8s-operator/session-recording/spdy/frame.go similarity index 99% rename from cmd/k8s-operator/spdy-frame.go rename to k8s-operator/session-recording/spdy/frame.go index 0ddefdfa1..54b29d33a 100644 --- a/cmd/k8s-operator/spdy-frame.go +++ b/k8s-operator/session-recording/spdy/frame.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package spdy import ( "bytes" diff --git a/cmd/k8s-operator/spdy-frame_test.go b/k8s-operator/session-recording/spdy/frame_test.go similarity index 99% rename from cmd/k8s-operator/spdy-frame_test.go rename to k8s-operator/session-recording/spdy/frame_test.go index 416ddfc8b..c6aa4cf01 100644 --- a/cmd/k8s-operator/spdy-frame_test.go +++ b/k8s-operator/session-recording/spdy/frame_test.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package spdy import ( "bytes" diff --git a/cmd/k8s-operator/zlib-reader.go b/k8s-operator/session-recording/spdy/zlib-reader.go similarity index 99% rename from cmd/k8s-operator/zlib-reader.go rename to k8s-operator/session-recording/spdy/zlib-reader.go index b29772be3..1eb654be3 100644 --- a/cmd/k8s-operator/zlib-reader.go +++ b/k8s-operator/session-recording/spdy/zlib-reader.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package spdy import ( "bytes" diff --git a/k8s-operator/session-recording/tsrecorder/header.go b/k8s-operator/session-recording/tsrecorder/header.go new file mode 100644 index 000000000..45c50ca1e --- /dev/null +++ b/k8s-operator/session-recording/tsrecorder/header.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package tsrecorder + +import "tailscale.com/tailcfg" + +// CastHeader is the asciicast header to be sent to the recorder at the start of +// the recording of a session. +// https://docs.asciinema.org/manual/asciicast/v2/#header +type CastHeader struct { + // Version is the asciinema file format version. + Version int `json:"version"` + + // Width is the terminal width in characters. + Width int `json:"width"` + + // Height is the terminal height in characters. + Height int `json:"height"` + + // Timestamp is the unix timestamp of when the recording started. + Timestamp int64 `json:"timestamp"` + + // Tailscale-specific fields: SrcNode is the full MagicDNS name of the + // tailnet node originating the connection, without the trailing dot. + SrcNode string `json:"srcNode"` + + // SrcNodeID is the node ID of the tailnet node originating the connection. + SrcNodeID tailcfg.StableNodeID `json:"srcNodeID"` + + // SrcNodeTags is the list of tags on the node originating the connection (if any). + SrcNodeTags []string `json:"srcNodeTags,omitempty"` + + // SrcNodeUserID is the user ID of the node originating the connection (if not tagged). + SrcNodeUserID tailcfg.UserID `json:"srcNodeUserID,omitempty"` // if not tagged + + // SrcNodeUser is the LoginName of the node originating the connection (if not tagged). + SrcNodeUser string `json:"srcNodeUser,omitempty"` + + Command string + + // Kubernetes-specific fields: + Kubernetes *Kubernetes `json:"kubernetes,omitempty"` +} + +// Kubernetes contains 'kubectl exec' session specific information for +// tsrecorder. +type Kubernetes struct { + PodName string + Namespace string + Container string +} diff --git a/cmd/k8s-operator/recorder.go b/k8s-operator/session-recording/tsrecorder/tsrecorder.go similarity index 57% rename from cmd/k8s-operator/recorder.go rename to k8s-operator/session-recording/tsrecorder/tsrecorder.go index ae17f3820..4ce78a882 100644 --- a/cmd/k8s-operator/recorder.go +++ b/k8s-operator/session-recording/tsrecorder/tsrecorder.go @@ -3,7 +3,9 @@ //go:build !plan9 -package main +// Package tsrecorder contains functionality to send recorded kubectl-exec +// sessions to tsrecorder. +package tsrecorder import ( "encoding/json" @@ -16,9 +18,18 @@ import ( "tailscale.com/tstime" ) -// recorder knows how to send the provided bytes to the configured tsrecorder +func New(conn io.WriteCloser, clock tstime.Clock, start time.Time, failOpen bool) *Client { + return &Client{ + start: start, + clock: clock, + conn: conn, + failOpen: failOpen, + } +} + +// Client knows how to send the provided bytes to the configured tsrecorder // instance in asciinema format. -type recorder struct { +type Client struct { start time.Time clock tstime.Clock @@ -36,15 +47,15 @@ type recorder struct { // Write appends timestamp to the provided bytes and sends them to the // configured tsrecorder. -func (rec *recorder) Write(p []byte) (err error) { +func (c *Client) Write(p []byte) (err error) { if len(p) == 0 { return nil } - if rec.backOff { + if c.backOff { return nil } j, err := json.Marshal([]any{ - rec.clock.Now().Sub(rec.start).Seconds(), + c.clock.Now().Sub(c.start).Seconds(), "o", string(p), }) @@ -52,37 +63,42 @@ func (rec *recorder) Write(p []byte) (err error) { return fmt.Errorf("error marhalling payload: %w", err) } j = append(j, '\n') - if err := rec.writeCastLine(j); err != nil { - if !rec.failOpen { + if err := c.WriteCastLine(j); err != nil { + if !c.failOpen { return fmt.Errorf("error writing payload to recorder: %w", err) } - rec.backOff = true + c.backOff = true } return nil } -func (rec *recorder) Close() error { - rec.mu.Lock() - defer rec.mu.Unlock() - if rec.conn == nil { +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { return nil } - err := rec.conn.Close() - rec.conn = nil + err := c.conn.Close() + c.conn = nil return err } // writeCastLine sends bytes to the tsrecorder. The bytes should be in // asciinema format. -func (rec *recorder) writeCastLine(j []byte) error { - rec.mu.Lock() - defer rec.mu.Unlock() - if rec.conn == nil { +func (c *Client) WriteCastLine(j []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.conn == nil { return errors.New("recorder closed") } - _, err := rec.conn.Write(j) + _, err := c.conn.Write(j) if err != nil { return fmt.Errorf("recorder write error: %w", err) } return nil } + +type ResizeMsg struct { + Width int `json:"width"` + Height int `json:"height"` +} diff --git a/k8s-operator/session-recording/ws/conn.go b/k8s-operator/session-recording/ws/conn.go new file mode 100644 index 000000000..88bbc2a7f --- /dev/null +++ b/k8s-operator/session-recording/ws/conn.go @@ -0,0 +1,244 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// package ws has functionality to parse 'kubectl exec' sessions streamed using +// WebSockets protocol. +package ws + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/remotecommand" + "tailscale.com/k8s-operator/session-recording/tsrecorder" + "tailscale.com/util/multierr" +) + +// New returns a wrapper around net.Conn that intercepts reads and writes for a +// websocket streaming session over the provided net.Conn, parses the data as +// websocket messages and sends message payloads for STDIN/STDOUT streams to a +// tsrecorder instance using the provided client. Caller must ensure that the +// session is streamed using WebSockets protocol. +func New(c net.Conn, rec *tsrecorder.Client, ch tsrecorder.CastHeader, log *zap.SugaredLogger) net.Conn { + return &conn{ + Conn: c, + rec: rec, + ch: ch, + log: log, + } +} + +// conn is a wrapper around net.Conn. It reads the bytestream +// for a 'kubectl exec' session, sends session recording data to the configured +// recorder and forwards the raw bytes to the original destination. +// A new conn is created per session. +// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol. +// https://www.rfc-editor.org/rfc/rfc6455 +type conn struct { + net.Conn + // rec knows how to send data to a tsrecorder instance. + rec *tsrecorder.Client + // ch is the asiinema CastHeader for a session. + ch tsrecorder.CastHeader + log *zap.SugaredLogger + + rmu sync.Mutex // sequences reads + // currentReadMsg contains parsed contents of a websocket binary data message that + // is currently being read from the underlying net.Conn. + currentReadMsg *message + // readBuf contains bytes for a currently parsed binary data message + // read from the underlying conn. If the message is masked, it is + // unmasked in place, so having this buffer allows us to avoid modifying + // the original byte array. + readBuf bytes.Buffer + + wmu sync.Mutex // sequences writes + writeCastHeaderOnce sync.Once + closed bool + // writeBuf contains bytes for a currently parsed binary data message + // being written to the underlying conn. If the message is masked, it is + // unmasked in place, so having this buffer allows us to avoid modifying + // the original byte array. + writeBuf bytes.Buffer + // currentWriteMsg contains parsed contents of a websocket binary data message that + // is currently being written to the underlying net.Conn. + currentWriteMsg *message +} + +// Read reads bytes from the original connection and parses them as websocket +// message fragments. If the message is for the resize stream, sets the width +// and height of the CastHeader for this connection. +// The fragment can be incomplete. +func (c *conn) Read(b []byte) (int, error) { + c.rmu.Lock() + defer c.rmu.Unlock() + n, err := c.Conn.Read(b) + if err != nil { + // It seems that we sometimes get a wrapped io.EOF, but the + // caller checks for io.EOF with ==. + if errors.Is(err, io.EOF) { + err = io.EOF + } + return 0, err + } + + typ := messageType(opcode(b)) + if typ == noOpcode && c.currentReadMsg != nil && !c.currentReadMsg.isFinalized { // subsequent fragment + typ = c.currentReadMsg.typ + } + + // A control message can not be fragmented and we are not interested in + // these messages. Just return. + if isControlMessage(typ) { + return n, nil + } + + // The only data message type that Kubernetes supports is binary message. + // If we received another message type, return and let the API server close the connection. + // https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281 + if typ != binaryMessage { + c.log.Info("[unexpected] received a data message with a type that is not binary message type %d", typ) + return n, nil + } + if _, err := c.readBuf.Write(b[:n]); err != nil { + return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) + } + + readMsg := &message{typ: typ} // start a new message... + // ... or pick up an already started one if the previous fragment was not final. + if c.currentReadMsg != nil && !c.currentReadMsg.isFinalized { + readMsg = c.currentReadMsg + } + + ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) + if err != nil { + return 0, fmt.Errorf("error parsing message: %v", err) + } + if !ok { // incomplete fragment + return n, nil + } + c.readBuf.Next(len(readMsg.raw)) + + if readMsg.isFinalized { + // Stream IDs for websocket streams are static. + // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 + if readMsg.streamID.Load() == remotecommand.StreamResize { + var err error + var msg tsrecorder.ResizeMsg + if err = json.Unmarshal(readMsg.payload, &msg); err != nil { + return 0, fmt.Errorf("error umarshalling resize message: %w", err) + } + c.ch.Width = msg.Width + c.ch.Height = msg.Height + } + } + c.currentReadMsg = readMsg + return n, err +} + +// Write parses the written bytes as WebSocket message fragment. If the message +// is for stdout or stderr streams, it is written to the configured tsrecorder. +// A message fragment can be incomplete. +func (c *conn) Write(b []byte) (int, error) { + c.wmu.Lock() + defer c.wmu.Unlock() + + typ := messageType(opcode(b)) + // If we are in process of parsing a message fragment, the received + // bytes are not structured as a message fragment and can not be used to + // determine a message fragment. + if len(c.writeBuf.Bytes()) > 0 { // buffer contains previous incomplete fragment + typ = c.currentWriteMsg.typ + } + + if isControlMessage(typ) { + n, err := c.Conn.Write(b) + return n, err + } + + if _, err := c.writeBuf.Write(b); err != nil { + c.log.Errorf("write: error writing to write buf: %v", err) + return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err) + } + + writeMsg := &message{typ: typ} // start a new message... + // ... or continue the existing one if it has not been finalized. + if c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized { + writeMsg = c.currentWriteMsg + } + + ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log) + if err != nil { + c.log.Errorf("write: parsing a message errored: %v", err) + return 0, fmt.Errorf("write: error parsing message: %v", err) + } + c.currentWriteMsg = writeMsg + if !ok { // incomplete fragment + return len(b), nil + } + c.writeBuf.Next(len(writeMsg.raw)) // advance frame + + if len(writeMsg.payload) != 0 && writeMsg.isFinalized { + if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { + var err error + c.writeCastHeaderOnce.Do(func() { + var j []byte + j, err = json.Marshal(c.ch) + if err != nil { + c.log.Infof("error marhsalling conn: %v", err) + return + } + j = append(j, '\n') + err = c.rec.WriteCastLine(j) + if err != nil { + c.log.Errorf("received error from recorder: %v", err) + } + }) + if err != nil { + return 0, fmt.Errorf("error writing CastHeader: %w", err) + } + if err := c.rec.Write(writeMsg.payload); err != nil { + return 0, fmt.Errorf("error writing message to recorder: %v", err) + } + } + } + _, err = c.Conn.Write(c.currentWriteMsg.raw) + if err != nil { + c.log.Errorf("write: error writing to conn: %v", err) + } + return len(b), err +} + +func (c *conn) Close() error { + c.wmu.Lock() + defer c.wmu.Unlock() + if c.closed { + return nil + } + // TODO: only do this if this is a normal closure rather than the + // reocrding has failed. + if c.writeBuf.Len() > 0 { + c.Conn.Write(c.writeBuf.Bytes()) + } + c.closed = true + connCloseErr := c.Conn.Close() + recCloseErr := c.rec.Close() + return multierr.New(connCloseErr, recCloseErr) +} + +// opcode reads the websocket message opcode that denotes the message type. +// opcode is contained in bits [4-8] of the message. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +func opcode(b []byte) int { + // 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b + var mask byte = 0xf + return int(b[0] & mask) +} diff --git a/k8s-operator/session-recording/ws/conn_test.go b/k8s-operator/session-recording/ws/conn_test.go new file mode 100644 index 000000000..a64b89c56 --- /dev/null +++ b/k8s-operator/session-recording/ws/conn_test.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "reflect" + "testing" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/remotecommand" + "tailscale.com/k8s-operator/session-recording/fakes" + "tailscale.com/k8s-operator/session-recording/tsrecorder" + "tailscale.com/tstest" +) + +func Test_conn_Read(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + // Resize stream ID + {"width": 10, "height": 20} + testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d} + lenResizeMsgPayload := byte(len(testResizeMsg)) + + tests := []struct { + name string + inputs [][]byte + wantWidth int + wantHeight int + }{ + { + name: "single_read_control_message", + inputs: [][]byte{{0x88, 0x0}}, + }, + { + name: "single_read_resize_message", + inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)}, + wantWidth: 10, + wantHeight: 20, + }, + { + name: "two_reads_resize_message", + inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}}, + wantWidth: 10, + wantHeight: 20, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &fakes.TestConn{} + tc.ResetReadBuf() + c := &conn{ + Conn: tc, + log: zl.Sugar(), + } + for i, input := range tt.inputs { + if err := tc.WriteReadBufBytes(input); err != nil { + t.Fatalf("writing bytes to test conn: %v", err) + } + _, err := c.Read(make([]byte, len(input))) + if err != nil { + t.Errorf("[%d] conn.Read() errored %v", i, err) + return + } + } + if tt.wantHeight != 0 || tt.wantWidth != 0 { + if tt.wantWidth != c.ch.Width { + t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width) + } + if tt.wantHeight != c.ch.Height { + t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height) + } + } + }) + } +} + +func Test_conn_Write(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + tests := []struct { + name string + inputs [][]byte + wantForwarded []byte + wantRecorded []byte + firstWrite bool + width int + height int + }{ + { + name: "single_write_control_frame", + inputs: [][]byte{{0x88, 0x0}}, + wantForwarded: []byte{0x88, 0x0}, + }, + { + name: "single_write_stdout_data_message", + inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), + }, + { + name: "single_write_stderr_data_message", + inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), + }, + { + name: "single_write_stdin_data_message", + inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8}, + }, + { + name: "single_write_stdout_data_message_with_cast_header", + inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...), + width: 10, + height: 20, + firstWrite: true, + }, + { + name: "two_writes_stdout_data_message", + inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}}, + wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &fakes.TestConn{} + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true) + c := &conn{ + Conn: tc, + log: zl.Sugar(), + ch: tsrecorder.CastHeader{ + Width: tt.width, + Height: tt.height, + }, + rec: rec, + } + if !tt.firstWrite { + // This test case does not intend to test that cast header gets written once. + c.writeCastHeaderOnce.Do(func() {}) + } + for i, input := range tt.inputs { + _, err := c.Write(input) + if err != nil { + t.Fatalf("[%d] conn.Write() errored: %v", i, err) + } + } + // Assert that the expected bytes have been forwarded to the original destination. + gotForwarded := tc.WriteBufBytes() + if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) { + t.Errorf("expected bytes not forwarded, wants\n%v\ngot\n%v", tt.wantForwarded, gotForwarded) + } + + // Assert that the expected bytes have been forwarded to the session recorder. + gotRecorded := sr.Bytes() + if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { + t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded) + } + }) + } +} diff --git a/k8s-operator/session-recording/ws/message.go b/k8s-operator/session-recording/ws/message.go new file mode 100644 index 000000000..bf33e6bb2 --- /dev/null +++ b/k8s-operator/session-recording/ws/message.go @@ -0,0 +1,253 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "encoding/binary" + "fmt" + "sync/atomic" + + "github.com/pkg/errors" + "go.uber.org/zap" +) + +const ( + noOpcode messageType = 0 // continuation frame for fragmented messages + binaryMessage messageType = 2 +) + +// messageType is the type of a websocket data or control message as defined by opcode. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +// Known types of control messages are close, ping and pong. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.5 +// The only data message type supported by Kubernetes is binary message +// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281 +type messageType int + +// message is a parsed Websocket Message. +type message struct { + // payload is the contents of the so far parsed Websocket + // data Message payload, potentially from multiple fragments written by + // multiple invocations of Parse. As per RFC 6455 We can assume that the + // fragments will always arrive in order and data messages will not be + // interleaved. + payload []byte + + // isFinalized is set to true if msgPayload contains full contents of + // the message (the final fragment has been received). + isFinalized bool + + // streamID is the stream to which the message belongs, i.e stdin, stout + // etc. It is one of the stream IDs defined in + // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36 + streamID atomic.Uint32 + + // typ is the type of a WebsocketMessage as defined by its opcode + // https://www.rfc-editor.org/rfc/rfc6455#section-5.2 + typ messageType + raw []byte +} + +// Parse accepts a websocket message fragment as a byte slice and parses its contents. +// The fragment can be: +// - a fragment that consists of a whole message +// - an initial fragment for a message for which we expect more fragments +// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg. +// It is not expected that the byte slice would contain an incomplete fragment or fragment for a different message than the one currently being parsed (if any). +// Message fragment structure: +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-------+-+-------------+-------------------------------+ +// |F|R|R|R| opcode|M| Payload len | Extended payload length | +// |I|S|S|S| (4) |A| (7) | (16/64) | +// |N|V|V|V| |S| | (if payload len==126/127) | +// | |1|2|3| |K| | | +// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +// | Extended payload length continued, if payload len == 127 | +// + - - - - - - - - - - - - - - - +-------------------------------+ +// | |Masking-key, if MASK set to 1 | +// +-------------------------------+-------------------------------+ +// | Masking-key (continued) | Payload Data | +// +-------------------------------- - - - - - - - - - - - - - - - + +// : Payload Data continued ... : +// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// | Payload Data continued ... | +// +---------------------------------------------------------------+ +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +// +// Fragmentation rules: +// An unfragmented message consists of a single frame with the FIN +// bit set (Section 5.2) and an opcode other than 0. +// A fragmented message consists of a single frame with the FIN bit +// clear and an opcode other than 0, followed by zero or more frames +// with the FIN bit clear and the opcode set to 0, and terminated by +// a single frame with the FIN bit set and an opcode of 0. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 +func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { + if msg.typ != binaryMessage { + return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ) + } + + msg.isFinalized = isFinalFragment(b) + + maskSet := isMasked(b) + + payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet) + if err != nil { + return false, fmt.Errorf("error determining payload length: %w", err) + } + log.Debugf("parse: parsing a message with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized) + + if len(b) < int(payloadOffset)+int(payloadLength) { // incomplete fragment + return false, nil + } + msg.raw = make([]byte, int(payloadOffset)+int(payloadLength)) + copy(msg.raw, b[:payloadOffset+payloadLength]) + + // Extract the payload. + msgPayload := b[payloadOffset : payloadOffset+payloadLength] + + // Unmask the payload if needed. + if maskSet { + m := b[maskOffset:payloadOffset] + var mask [4]byte + copy(mask[:], m) + maskBytes(mask, msgPayload) + } + + // Determine what stream the message is for. Stream ID of a Kubernetes + // streaming session is a 32bit integer, stored in the first byte of the + // message payload. + // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36 + if len(msgPayload) == 0 { + return false, errors.New("[unexpected] received a message fragment with no stream ID") + } + + streamId := uint32(msgPayload[0]) + if msg.streamID.Load() != 0 && msg.streamID.Load() != streamId { + return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamId) + } + msg.streamID.Store(streamId) + + // This is normal, Kubernetes seem to send a couple data messages with + // no payloads at the start. + if len(msgPayload) < 2 { + return true, nil + } + msgPayload = msgPayload[1:] // remove the stream ID byte + msg.payload = append(msg.payload, msgPayload...) + return true, nil +} + +// maskBytes applies mask to bytes in place. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.3 +func maskBytes(key [4]byte, b []byte) { + for i := range b { + b[i] = b[i] ^ key[i%4] + } +} + +// isControlMessage returns true if the message type is one of the know control +// frame message types. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.5 +func isControlMessage(t messageType) bool { + const ( + closeMessage messageType = 8 + pingMessage messageType = 9 + pongMessage messageType = 10 + ) + return t == closeMessage || t == pingMessage || t == pongMessage +} + +// isFinalFragment can be called with websocket message fragment and returns true if +// the fragment is the final fragment of a websocket message. +func isFinalFragment(b []byte) bool { + // Extract FIN bit. FIN bit is the first bit of a message fragment. + const finBitMask byte = 1 << 7 + finBit := b[0] & finBitMask + return finBit != 0 +} + +// isMasked can be called with a websocket message fragment and returns true if +// the payload of the message is masked. It uses the mask bit to determine if +// the payload is masked. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.3 +func isMasked(b []byte) bool { + return extractFirstBit(b[1]) != 0 +} + +// extractFirstBit extracts first bit of a byte by zeroing out all the other +// bits. +func extractFirstBit(b byte) byte { + const mask byte = 1 << 7 + return b & mask +} + +// zeroFirstBit returns the provided byte with the first bit set to 0. +func zeroFirstBit(b byte) byte { + const revMask byte = 1 << 7 + return b & (^revMask) +} + +// fragmentDimensions returns payload length as well as payload offset and mask offset. +func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset int64, _ error) { + + // payload length can be stored either in bits [9-15] or in bytes 2, 3 + // or in bytes 2, 3, 4, 5, 6, 7. + // https://www.rfc-editor.org/rfc/rfc6455#section-5.2 + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + payloadLengthIndicator := zeroFirstBit(b[1]) + var lengthOffset int64 + switch { + case payloadLengthIndicator < 126: + lengthOffset = 1 + maskOffset = 2 + payloadLength = int64(payloadLengthIndicator) + case payloadLengthIndicator == 126: + maskOffset = 4 + lengthOffset = 2 + payloadLength = extractInt64(b, lengthOffset, 2) + case payloadLengthIndicator == 127: + maskOffset = 10 + lengthOffset = 2 + payloadLength = extractInt64(b, lengthOffset, 6) + default: + return -1, -1, -1, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator) + } + + // Masking key can take up 0 or 4 bytes- we need to take that into + // account when determining payload offset. + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // .... + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // ... + if maskSet { + payloadOffset = maskOffset + 4 + } else { + payloadOffset = maskOffset + } + return +} + +func extractInt64(b []byte, offset, length int64) int64 { + payloadLengthBytes := b[offset : offset+length] + payloadLengthBytesPadded := append(make([]byte, 8-len(payloadLengthBytes)), payloadLengthBytes...) + + return int64(binary.BigEndian.Uint64(payloadLengthBytesPadded)) +} diff --git a/k8s-operator/session-recording/ws/message_test.go b/k8s-operator/session-recording/ws/message_test.go new file mode 100644 index 000000000..63a80ade9 --- /dev/null +++ b/k8s-operator/session-recording/ws/message_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "reflect" + "testing" + + "go.uber.org/zap" +) + +func Test_msg_Parse(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + testMask := [4]byte{1, 2, 3, 4} + tests := []struct { + name string + b []byte + initialPayload []byte + wantPayload []byte + wantIsFinalized bool + wantStreamID uint32 + }{ + { + name: "single_fragment_stdout_stream_no_payload_no_mask", + b: []byte{0x82, 0x1, 0x1}, + wantPayload: nil, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "single_fragment_stderr_steam_no_payload_has_mask", + b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...), + wantPayload: nil, + wantIsFinalized: true, + wantStreamID: 2, + }, + { + name: "single_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantPayload: []byte{0x7, 0x8}, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "single_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + wantPayload: []byte{0x7, 0x8}, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "initial_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x2, 0x3, 0x1, 0x7, 0x8}, + wantPayload: []byte{0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "initial_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + wantPayload: []byte{0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "subsequent_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x0, 0x3, 0x1, 0x7, 0x8}, + initialPayload: []byte{0x1, 0x2, 0x3}, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "subsequent_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + initialPayload: []byte{0x1, 0x2, 0x3}, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "final_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x80, 0x3, 0x1, 0x7, 0x8}, + initialPayload: []byte{0x1, 0x2, 0x3}, + wantIsFinalized: true, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "final_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + initialPayload: []byte{0x1, 0x2, 0x3}, + wantIsFinalized: true, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &message{ + typ: binaryMessage, + payload: tt.initialPayload, + } + if _, err := msg.Parse(tt.b, zl.Sugar()); err != nil { + t.Errorf("msg.Parse() errored %v", err) + } + if msg.isFinalized != tt.wantIsFinalized { + t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized) + } + if msg.streamID.Load() != tt.wantStreamID { + t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load()) + } + if !reflect.DeepEqual(msg.payload, tt.wantPayload) { + t.Errorf("unexpected message payload after Parse, wants %b, got %b", tt.wantPayload, msg.payload) + } + }) + } +} + +func maskedBytes(mask [4]byte, b []byte) []byte { + maskBytes(mask, b) + return b +}