diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 1454fe5f6..d8ccbf1b4 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -417,6 +417,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+ k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+ k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names + k8s.io/apimachinery/pkg/util/remotecommand from tailscale.com/k8s-operator/sessionrecording/ws k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+ k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+ @@ -686,9 +687,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator - tailscale.com/k8s-operator/sessionrecording/conn from tailscale.com/k8s-operator/sessionrecording/spdy tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ + tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording tailscale.com/kube from tailscale.com/cmd/k8s-operator+ tailscale.com/licenses from tailscale.com/client/web tailscale.com/log/filelogger from tailscale.com/logpolicy @@ -741,7 +742,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/posture from tailscale.com/ipn/ipnlocal tailscale.com/proxymap from tailscale.com/tsd+ 💣 tailscale.com/safesocket from tailscale.com/client/tailscale+ - tailscale.com/sessionrecording from tailscale.com/cmd/k8s-operator+ + tailscale.com/sessionrecording from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/syncs from tailscale.com/control/controlknobs+ tailscale.com/tailcfg from tailscale.com/client/tailscale+ tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ @@ -863,6 +864,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/net/ipv6 from github.com/miekg/dns+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from net+ + golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator golang.org/x/oauth2/internal from golang.org/x/oauth2+ diff --git a/cmd/k8s-operator/proxy.go b/cmd/k8s-operator/proxy.go index f31b881e2..3d092fe34 100644 --- a/cmd/k8s-operator/proxy.go +++ b/cmd/k8s-operator/proxy.go @@ -22,9 +22,8 @@ "k8s.io/client-go/transport" "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" - kubesessionrecording "tailscale.com/k8s-operator/sessionrecording" + ksr "tailscale.com/k8s-operator/sessionrecording" tskube "tailscale.com/kube" - "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/util/clientmetric" @@ -168,7 +167,8 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL mux := http.NewServeMux() 
mux.HandleFunc("/", ap.serveDefault) - mux.HandleFunc("/api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExec) + mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY) + mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS) hs := &http.Server{ // Kubernetes uses SPDY for exec and port-forward, however SPDY is @@ -209,9 +209,19 @@ func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) { ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) } -// serveExec serves 'kubectl exec' requests, optionally configuring the kubectl -// exec sessions to be recorded. -func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { +// serveExecSPDY serves 'kubectl exec' requests for sessions streamed over SPDY, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) { + ap.execForProto(w, r, ksr.SPDYProtocol) +} + +// serveExecWS serves 'kubectl exec' requests for sessions streamed over WebSocket, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) { + ap.execForProto(w, r, ksr.WSProtocol) +} + +func (ap *apiserverProxy) execForProto(w http.ResponseWriter, r *http.Request, proto ksr.Protocol) { who, err := ap.whoIs(r) if err != nil { ap.authError(w, err) @@ -227,15 +237,17 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) return } - kubesessionrecording.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded + ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded if !failOpen && len(addrs) == 0 { msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available." ap.log.Error(msg) http.Error(w, msg, http.StatusForbidden) return } - if r.Method != "POST" || r.Header.Get("Upgrade") != "SPDY/3.1" { - msg := "'kubectl exec' session recording is configured, but the request is not over SPDY. Session recording is currently only supported for SPDY based clients" + + wantsHeader := upgradeHeaderForProto[proto] + if h := r.Header.Get("Upgrade"); h != wantsHeader { + msg := fmt.Sprintf("[unexpected] unable to verify that streaming protocol is %s, wants Upgrade header %q, got: %q", proto, wantsHeader, h) if failOpen { msg = msg + "; failure mode is 'fail open'; continuing session without recording." 
ap.log.Warn(msg) @@ -247,9 +259,22 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { http.Error(w, msg, http.StatusForbidden) return } - spdyH := kubesessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), kubesessionrecording.SPDYProtocol, addrs, failOpen, sessionrecording.ConnectToRecorder, ap.log) - ap.rp.ServeHTTP(spdyH, r.WithContext(whoIsKey.WithValue(r.Context(), who))) + opts := ksr.HijackerOpts{ + Req: r, + W: w, + Proto: proto, + TS: ap.ts, + Who: who, + Addrs: addrs, + FailOpen: failOpen, + Pod: r.PathValue("pod"), + Namespace: r.PathValue("namespace"), + Log: ap.log, + } + h := ksr.New(opts) + + ap.rp.ServeHTTP(h, r.WithContext(whoIsKey.WithValue(r.Context(), who))) } func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) { @@ -382,3 +407,8 @@ func determineRecorderConfig(who *apitype.WhoIsResponse) (failOpen bool, recorde } return failOpen, recorderAddresses, nil } + +var upgradeHeaderForProto = map[ksr.Protocol]string{ + ksr.SPDYProtocol: "SPDY/3.1", + ksr.WSProtocol: "websocket", +} diff --git a/k8s-operator/sessionrecording/conn/conn.go b/k8s-operator/sessionrecording/conn/conn.go deleted file mode 100644 index 4be98338d..000000000 --- a/k8s-operator/sessionrecording/conn/conn.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -// Package conn contains shared interface for the hijacked -// connection of a 'kubectl exec' session that is being recorded. -package conn - -import "net" - -type Conn interface { - net.Conn - // Fail can be called to set connection state to failed. By default any - // bytes left over in write buffer are forwarded to the intended - // destination when the connection is being closed except for when the - // connection state is failed- so set the state to failed when erroring - // out and failure policy is to fail closed. 
- Fail() -} diff --git a/k8s-operator/sessionrecording/fakes/fakes.go b/k8s-operator/sessionrecording/fakes/fakes.go index fb5911cf2..9eb1047e4 100644 --- a/k8s-operator/sessionrecording/fakes/fakes.go +++ b/k8s-operator/sessionrecording/fakes/fakes.go @@ -13,6 +13,9 @@ "net" "sync" "testing" + "time" + + "math/rand" "tailscale.com/sessionrecording" "tailscale.com/tstime" @@ -116,3 +119,20 @@ func AsciinemaResizeMsg(t *testing.T, width, height int) []byte { } return append(bs, '\n') } + +func RandomBytes(t *testing.T) [][]byte { + t.Helper() + r := rand.New(rand.NewSource(time.Now().UnixNano())) + n := r.Intn(4096) + b := make([]byte, n) + t.Logf("RandomBytes: generating byte slice of length %d", n) + _, err := r.Read(b) + if err != nil { + t.Fatalf("error generating random byte slice: %v", err) + } + if len(b) < 2 { + return [][]byte{b} + } + split := r.Intn(len(b) - 1) + return [][]byte{b[:split], b[split:]} +} diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index 3f3d85cd8..2e7ec7598 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -23,6 +23,7 @@ "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/spdy" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" + "tailscale.com/k8s-operator/sessionrecording/ws" "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" @@ -31,11 +32,14 @@ "tailscale.com/util/multierr" ) -const SPDYProtocol protocol = "SPDY" +const ( + SPDYProtocol Protocol = "SPDY" + WSProtocol Protocol = "WebSocket" +) -// protocol is the streaming protocol of the hijacked session. Supported -// protocols are SPDY. -type protocol string +// Protocol is the streaming protocol of the hijacked session. Supported +// protocols are SPDY and WebSocket. +type Protocol string var ( // CounterSessionRecordingsAttempted counts the number of session recording attempts. @@ -45,22 +49,35 @@ counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded") ) -func New(ts *tsnet.Server, req *http.Request, who *apitype.WhoIsResponse, w http.ResponseWriter, pod, ns string, proto protocol, addrs []netip.AddrPort, failOpen bool, connFunc RecorderDialFn, log *zap.SugaredLogger) *Hijacker { +func New(opts HijackerOpts) *Hijacker { return &Hijacker{ - ts: ts, - req: req, - who: who, - ResponseWriter: w, - pod: pod, - ns: ns, - addrs: addrs, - failOpen: failOpen, - connectToRecorder: connFunc, - proto: proto, - log: log, + ts: opts.TS, + req: opts.Req, + who: opts.Who, + ResponseWriter: opts.W, + pod: opts.Pod, + ns: opts.Namespace, + addrs: opts.Addrs, + failOpen: opts.FailOpen, + proto: opts.Proto, + log: opts.Log, + connectToRecorder: sessionrecording.ConnectToRecorder, } } +type HijackerOpts struct { + TS *tsnet.Server + Req *http.Request + W http.ResponseWriter + Who *apitype.WhoIsResponse + Addrs []netip.AddrPort + Log *zap.SugaredLogger + Pod string + Namespace string + FailOpen bool + Proto Protocol +} + // Hijacker implements [net/http.Hijacker] interface. // It must be configured with an http request for a 'kubectl exec' session that // needs to be recorded. 
It knows how to hijack the connection and configure for @@ -76,7 +93,7 @@ type Hijacker struct { addrs []netip.AddrPort // tsrecorder addresses failOpen bool // whether to fail open if recording fails connectToRecorder RecorderDialFn - proto protocol // streaming protocol + proto Protocol // streaming protocol } // RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder @@ -111,10 +128,14 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, // https://docs.asciinema.org/manual/asciicast/v2/ asciicastv2 = 2 ) - var wc io.WriteCloser + var ( + wc io.WriteCloser + err error + errChan <-chan error + ) h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen) // TODO (irbekrm): send client a message that session will be recorded. - rw, _, errChan, err := h.connectToRecorder(ctx, h.addrs, h.ts.Dial) + wc, _, errChan, err = h.connectToRecorder(ctx, h.addrs, h.ts.Dial) if err != nil { msg := fmt.Sprintf("error connecting to session recorders: %v", err) if h.failOpen { @@ -131,7 +152,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, // TODO (irbekrm): log which recorder h.log.Info("successfully connected to a session recorder") - wc = rw cl := tstime.DefaultClock{} rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen) qp := h.req.URL.Query() @@ -153,7 +173,17 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, } else { ch.SrcNodeTags = h.who.Node.Tags } - lc := spdy.New(conn, rec, ch, h.log) + + var lc net.Conn + switch h.proto { + case SPDYProtocol: + lc = spdy.New(conn, rec, ch, h.log) + case WSProtocol: + lc = ws.New(conn, rec, ch, h.log) + default: + return nil, fmt.Errorf("unknown protocol: %s", h.proto) + } + go func() { var err error select { @@ -174,7 +204,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, } msg += "; failure mode set to 'fail closed'; closing connection" h.log.Error(msg) - lc.Fail() // TODO (irbekrm): write a message to the client if err := lc.Close(); err != nil { h.log.Infof("error closing recorder connections: %v", err) diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index 9f7fb1930..5c19d3a1d 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -37,30 +37,40 @@ func Test_Hijacker(t *testing.T) { failRecorderConnPostConnect bool // send error down the error channel wantsConnClosed bool wantsSetupErr bool + proto Protocol }{ { - name: "setup succeeds, conn stays open", + name: "setup_succeeds_conn_stays_open", + proto: SPDYProtocol, }, { - name: "setup fails, policy is to fail open, conn stays open", + name: "setup_succeeds_conn_stays_open_ws", + proto: WSProtocol, + }, + { + name: "setup_fails_policy_is_to_fail_open_conn_stays_open", failOpen: true, failRecorderConnect: true, + proto: SPDYProtocol, }, { - name: "setup fails, policy is to fail closed, conn is closed", + name: "setup_fails_policy_is_to_fail_closed_conn_is_closed", failRecorderConnect: true, wantsSetupErr: true, wantsConnClosed: true, + proto: SPDYProtocol, }, { - name: "connection fails post-initial connect, policy is to fail open, conn stays open", + name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open", failRecorderConnPostConnect: true, failOpen: true, + proto: SPDYProtocol, }, { - name: "connection fails post-initial connect, policy is to fail 
closed, conn is closed", + name: "connection_fails_post-initial_connect,_policy_is_to_fail_closed_conn_is_closed", failRecorderConnPostConnect: true, wantsConnClosed: true, + proto: SPDYProtocol, }, } for _, tt := range tests { @@ -79,6 +89,7 @@ func Test_Hijacker(t *testing.T) { log: zl.Sugar(), ts: &tsnet.Server{}, req: &http.Request{URL: &url.URL{}}, + proto: tt.proto, } ctx := context.Background() _, err := h.setUpRecording(ctx, tc) diff --git a/k8s-operator/sessionrecording/spdy/conn.go b/k8s-operator/sessionrecording/spdy/conn.go index 08e221feb..19a01641e 100644 --- a/k8s-operator/sessionrecording/spdy/conn.go +++ b/k8s-operator/sessionrecording/spdy/conn.go @@ -19,12 +19,18 @@ "go.uber.org/zap" corev1 "k8s.io/api/core/v1" - srconn "tailscale.com/k8s-operator/sessionrecording/conn" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/sessionrecording" ) -func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) srconn.Conn { +// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection. +// The connection must be a hijacked connection for a 'kubectl exec' session using SPDY. +// The hijacked connection is used to transmit SPDY streams between Kubernetes client ('kubectl') and the destination container. +// Data read from the underlying network connection is data sent via one of the SPDY streams from the client to the container. +// Data written to the underlying connection is data sent from the container to the client. +// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header. +// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#background-remotecommand-subprotocol +func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn { return &conn{ Conn: nc, rec: rec, @@ -49,7 +55,6 @@ type conn struct { wmu sync.Mutex // sequences writes closed bool - failed bool rmu sync.Mutex // sequences reads writeCastHeaderOnce sync.Once @@ -172,9 +177,6 @@ func (c *conn) Close() error { if c.closed { return nil } - if !c.failed && c.writeBuf.Len() > 0 { - c.Conn.Write(c.writeBuf.Bytes()) - } c.writeBuf.Reset() c.closed = true err := c.Conn.Close() @@ -182,14 +184,8 @@ func (c *conn) Close() error { return err } -func (s *conn) Fail() { - s.wmu.Lock() - s.failed = true - s.wmu.Unlock() -} - // storeStreamID parses SYN_STREAM SPDY control frame and updates -// spdyRemoteConnRecorder to store the newly created stream's ID if it is one of +// conn to store the newly created stream's ID if it is one of // the stream types we care about. 
Storing stream_id:stream_type mapping allows // us to parse received data frames (that have stream IDs) differently depening // on which stream they belong to (i.e send data frame payload for stdout stream diff --git a/k8s-operator/sessionrecording/spdy/conn_test.go b/k8s-operator/sessionrecording/spdy/conn_test.go index 0046ae298..629536b2e 100644 --- a/k8s-operator/sessionrecording/spdy/conn_test.go +++ b/k8s-operator/sessionrecording/spdy/conn_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" + "fmt" "reflect" "testing" @@ -234,6 +235,57 @@ func Test_Reads(t *testing.T) { } } +// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to +// test that we don't panic when parsing input from a broken or malicious +// client. +func Test_conn_ReadRand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + for i := range 1000 { + tc := &fakes.TestConn{} + tc.ResetReadBuf() + c := &conn{ + Conn: tc, + log: zl.Sugar(), + } + bb := fakes.RandomBytes(t) + for j, input := range bb { + if err := tc.WriteReadBufBytes(input); err != nil { + t.Fatalf("[%d] writing bytes to test conn: %v", i, err) + } + f := func() { + c.Read(make([]byte, len(input))) + } + testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d", i, j, len(input))) + } + } +} + +// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that +// it does not panic. +func Test_conn_WriteRand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + for i := range 100 { + tc := &fakes.TestConn{} + c := &conn{ + Conn: tc, + log: zl.Sugar(), + } + bb := fakes.RandomBytes(t) + for j, input := range bb { + f := func() { + c.Write(input) + } + testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d", i, j, len(input))) + } + } +} + func resizeMsgBytes(t *testing.T, width, height int) []byte { t.Helper() bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height}) diff --git a/k8s-operator/sessionrecording/spdy/frame_test.go b/k8s-operator/sessionrecording/spdy/frame_test.go index c6aa4cf01..4896cdcbf 100644 --- a/k8s-operator/sessionrecording/spdy/frame_test.go +++ b/k8s-operator/sessionrecording/spdy/frame_test.go @@ -9,11 +9,15 @@ "bytes" "compress/zlib" "encoding/binary" + "fmt" "io" "net/http" "reflect" "strings" "testing" + "time" + + "math/rand" "github.com/google/go-cmp/cmp" "go.uber.org/zap" @@ -200,6 +204,29 @@ func Test_spdyFrame_parseHeaders(t *testing.T) { } } +// Test_spdyFrame_ParseRand calls spdyFrame.Parse with randomly generated bytes +// to test that it doesn't panic. +func Test_spdyFrame_ParseRand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := range 100 { + n := r.Intn(4096) + b := make([]byte, n) + _, err := r.Read(b) + if err != nil { + t.Fatalf("error generating random byte slice: %v", err) + } + sf := &spdyFrame{} + f := func() { + sf.Parse(b, zl.Sugar()) + } + testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r)) + } +} + // payload takes a control frame type and a map with 0 or more header keys and // values and returns a SPDY control frame payload with the header as SPDY zlib // compressed header name/value block. 
The payload is padded with arbitrary @@ -291,3 +318,13 @@ func header(hs map[string]string) http.Header { } return h } + +func testPanic(t *testing.T, f func(), msg string) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Fatal(msg, r) + } + }() + f() +} diff --git a/k8s-operator/sessionrecording/ws/conn.go b/k8s-operator/sessionrecording/ws/conn.go new file mode 100644 index 000000000..82fd094d1 --- /dev/null +++ b/k8s-operator/sessionrecording/ws/conn.go @@ -0,0 +1,301 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package ws has functionality to parse 'kubectl exec' sessions streamed using +// the WebSocket protocol. +package ws + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "sync" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/remotecommand" + "tailscale.com/k8s-operator/sessionrecording/tsrecorder" + "tailscale.com/sessionrecording" + "tailscale.com/util/multierr" +) + +// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection. +// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol. +// The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes. +// Data read from the underlying network connection is data sent via one of the streams from the client to the container. +// Data written to the underlying connection is data sent from the container to the client. +// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header. +// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio +func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn { + return &conn{ + Conn: c, + rec: rec, + ch: ch, + log: log, + } +} + +// conn is a wrapper around net.Conn. It reads the bytestream +// for a 'kubectl exec' session, sends session recording data to the configured +// recorder and forwards the raw bytes to the original destination. +// A new conn is created per session. +// conn only knows how to read a 'kubectl exec' session that is streamed using the WebSocket protocol. +// https://www.rfc-editor.org/rfc/rfc6455 +type conn struct { + net.Conn + // rec knows how to send data to a tsrecorder instance. + rec *tsrecorder.Client + // ch is the asciinema CastHeader for a session. + ch sessionrecording.CastHeader + log *zap.SugaredLogger + + rmu sync.Mutex // sequences reads + // currentReadMsg contains parsed contents of a websocket binary data message that + // is currently being read from the underlying net.Conn. + currentReadMsg *message + // readBuf contains bytes for a currently parsed binary data message + // read from the underlying conn. If the message is masked, it is + // unmasked in place, so having this buffer allows us to avoid modifying + // the original byte array. + readBuf bytes.Buffer + + wmu sync.Mutex // sequences writes + writeCastHeaderOnce sync.Once + closed bool // connection is closed + // writeBuf contains bytes for a currently parsed binary data message + // being written to the underlying conn.
If the message is masked, it is + // unmasked in place, so having this buffer allows us to avoid modifying + // the original byte array. + writeBuf bytes.Buffer + // currentWriteMsg contains parsed contents of a websocket binary data message that + // is currently being written to the underlying net.Conn. + currentWriteMsg *message +} + +// Read reads bytes from the original connection and parses them as websocket +// message fragments. +// Bytes read from the original connection are the bytes sent from the Kubernetes client (kubectl) to the destination container via kubelet. + +// If the message is for the resize stream, sets the width +// and height of the CastHeader for this connection. +// The fragment can be incomplete. +func (c *conn) Read(b []byte) (int, error) { + c.rmu.Lock() + defer c.rmu.Unlock() + n, err := c.Conn.Read(b) + if err != nil { + // It seems that we sometimes get a wrapped io.EOF, but the + // caller checks for io.EOF with ==. + if errors.Is(err, io.EOF) { + err = io.EOF + } + return 0, err + } + if n == 0 { + c.log.Debug("[unexpected] Read called for 0 length bytes") + return 0, nil + } + + typ := messageType(opcode(b)) + if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment + if typ, err = c.curReadMsgType(); err != nil { + return 0, err + } + } + + // A control message can not be fragmented and we are not interested in + // these messages. Just return. + if isControlMessage(typ) { + return n, nil + } + + // The only data message type that Kubernetes supports is binary message. + // If we received another message type, return and let the API server close the connection. + // https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281 + if typ != binaryMessage { + c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) + return n, nil + } + + readMsg := &message{typ: typ} // start a new message... + // ... or pick up an already started one if the previous fragment was not final. + if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() { + readMsg = c.currentReadMsg + } + + if _, err := c.readBuf.Write(b[:n]); err != nil { + return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) + } + + ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) + if err != nil { + return 0, fmt.Errorf("error parsing message: %v", err) + } + if !ok { // incomplete fragment + return n, nil + } + c.readBuf.Next(len(readMsg.raw)) + + if readMsg.isFinalized { + // Stream IDs for websocket streams are static. + // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 + if readMsg.streamID.Load() == remotecommand.StreamResize { + var err error + var msg tsrecorder.ResizeMsg + if err = json.Unmarshal(readMsg.payload, &msg); err != nil { + return 0, fmt.Errorf("error umarshalling resize message: %w", err) + } + c.ch.Width = msg.Width + c.ch.Height = msg.Height + } + } + c.currentReadMsg = readMsg + return n, nil +} + +// Write parses the written bytes as WebSocket message fragment. If the message +// is for stdout or stderr streams, it is written to the configured tsrecorder. +// A message fragment can be incomplete. 
+func (c *conn) Write(b []byte) (int, error) { + c.wmu.Lock() + defer c.wmu.Unlock() + if len(b) == 0 { + c.log.Debug("[unexpected] Write called with 0 bytes") + return 0, nil + } + + typ := messageType(opcode(b)) + // If we are in the process of parsing a message fragment, the received + // bytes do not start with a fragment header and cannot be used to + // determine the message type; take it from the fragment currently being parsed. + if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment + var err error + if typ, err = c.curWriteMsgType(); err != nil { + return 0, err + } + } + + if isControlMessage(typ) { + return c.Conn.Write(b) + } + + writeMsg := &message{typ: typ} // start a new message... + // ... or continue the existing one if it has not been finalized. + if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() { + writeMsg = c.currentWriteMsg + } + + if _, err := c.writeBuf.Write(b); err != nil { + c.log.Errorf("write: error writing to write buf: %v", err) + return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err) + } + + ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log) + if err != nil { + c.log.Errorf("write: parsing a message errored: %v", err) + return 0, fmt.Errorf("write: error parsing message: %v", err) + } + c.currentWriteMsg = writeMsg + if !ok { // incomplete fragment + return len(b), nil + } + c.writeBuf.Next(len(writeMsg.raw)) // advance frame + + if len(writeMsg.payload) != 0 && writeMsg.isFinalized { + if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { + var err error + c.writeCastHeaderOnce.Do(func() { + var j []byte + j, err = json.Marshal(c.ch) + if err != nil { + c.log.Errorf("error marshalling CastHeader: %v", err) + return + } + j = append(j, '\n') + err = c.rec.WriteCastLine(j) + if err != nil { + c.log.Errorf("received error from recorder: %v", err) + } + }) + if err != nil { + return 0, fmt.Errorf("error writing CastHeader: %w", err) + } + if err := c.rec.Write(writeMsg.payload); err != nil { + return 0, fmt.Errorf("error writing message to recorder: %v", err) + } + } + } + _, err = c.Conn.Write(c.currentWriteMsg.raw) + if err != nil { + c.log.Errorf("write: error writing to conn: %v", err) + } + return len(b), nil +} + +func (c *conn) Close() error { + c.wmu.Lock() + defer c.wmu.Unlock() + if c.closed { + return nil + } + c.closed = true + connCloseErr := c.Conn.Close() + recCloseErr := c.rec.Close() + return multierr.New(connCloseErr, recCloseErr) +} + +// writeBufHasIncompleteFragment returns true if the latest data message +// fragment written to the connection was incomplete and the following write +// must be the remaining payload bytes of that fragment. +func (c *conn) writeBufHasIncompleteFragment() bool { + return c.writeBuf.Len() != 0 +} + +// readBufHasIncompleteFragment returns true if the latest data message +// fragment read from the connection was incomplete and the following read +// must be the remaining payload bytes of that fragment. +func (c *conn) readBufHasIncompleteFragment() bool { + return c.readBuf.Len() != 0 +} + +// writeMsgIsIncomplete returns true if the latest WebSocket message written to +// the connection was fragmented and the next data message fragment written to +// the connection must be a fragment of that message.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 +func (c *conn) writeMsgIsIncomplete() bool { + return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized +} + +// readMsgIsIncomplete returns true if the latest WebSocket message written to +// the connection was fragmented and the next data message fragment written to +// the connection must be a fragment of that message. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 +func (c *conn) readMsgIsIncomplete() bool { + return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized +} +func (c *conn) curReadMsgType() (messageType, error) { + if c.currentReadMsg != nil { + return c.currentReadMsg.typ, nil + } + return 0, errors.New("[unexpected] attempted to determine type for nil message") +} + +func (c *conn) curWriteMsgType() (messageType, error) { + if c.currentWriteMsg != nil { + return c.currentWriteMsg.typ, nil + } + return 0, errors.New("[unexpected] attempted to determine type for nil message") +} + +// opcode reads the websocket message opcode that denotes the message type. +// opcode is contained in bits [4-8] of the message. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +func opcode(b []byte) int { + // 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b + var mask byte = 0xf + return int(b[0] & mask) +} diff --git a/k8s-operator/sessionrecording/ws/conn_test.go b/k8s-operator/sessionrecording/ws/conn_test.go new file mode 100644 index 000000000..2fcbeb7ca --- /dev/null +++ b/k8s-operator/sessionrecording/ws/conn_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "fmt" + "reflect" + "testing" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/remotecommand" + "tailscale.com/k8s-operator/sessionrecording/fakes" + "tailscale.com/k8s-operator/sessionrecording/tsrecorder" + "tailscale.com/sessionrecording" + "tailscale.com/tstest" +) + +func Test_conn_Read(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + // Resize stream ID + {"width": 10, "height": 20} + testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d} + lenResizeMsgPayload := byte(len(testResizeMsg)) + + tests := []struct { + name string + inputs [][]byte + wantWidth int + wantHeight int + }{ + { + name: "single_read_control_message", + inputs: [][]byte{{0x88, 0x0}}, + }, + { + name: "single_read_resize_message", + inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)}, + wantWidth: 10, + wantHeight: 20, + }, + { + name: "two_reads_resize_message", + inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}}, + wantWidth: 10, + wantHeight: 20, + }, + { + name: "three_reads_resize_message_with_split_fragment", + inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}}, + wantWidth: 10, + wantHeight: 20, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &fakes.TestConn{} + tc.ResetReadBuf() + c := &conn{ + Conn: tc, + log: zl.Sugar(), + } + for i, input := range tt.inputs { + if err := tc.WriteReadBufBytes(input); err != 
nil { + t.Fatalf("writing bytes to test conn: %v", err) + } + _, err := c.Read(make([]byte, len(input))) + if err != nil { + t.Errorf("[%d] conn.Read() errored %v", i, err) + return + } + } + if tt.wantHeight != 0 || tt.wantWidth != 0 { + if tt.wantWidth != c.ch.Width { + t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width) + } + if tt.wantHeight != c.ch.Height { + t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height) + } + } + }) + } +} + +func Test_conn_Write(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + tests := []struct { + name string + inputs [][]byte + wantForwarded []byte + wantRecorded []byte + firstWrite bool + width int + height int + }{ + { + name: "single_write_control_frame", + inputs: [][]byte{{0x88, 0x0}}, + wantForwarded: []byte{0x88, 0x0}, + }, + { + name: "single_write_stdout_data_message", + inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), + }, + { + name: "single_write_stderr_data_message", + inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), + }, + { + name: "single_write_stdin_data_message", + inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8}, + }, + { + name: "single_write_stdout_data_message_with_cast_header", + inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, + wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...), + width: 10, + height: 20, + firstWrite: true, + }, + { + name: "two_writes_stdout_data_message", + inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}}, + wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), + }, + { + name: "three_writes_stdout_data_message_with_split_fragment", + inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}}, + wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &fakes.TestConn{} + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true) + c := &conn{ + Conn: tc, + log: zl.Sugar(), + ch: sessionrecording.CastHeader{ + Width: tt.width, + Height: tt.height, + }, + rec: rec, + } + if !tt.firstWrite { + // This test case does not intend to test that cast header gets written once. + c.writeCastHeaderOnce.Do(func() {}) + } + for i, input := range tt.inputs { + _, err := c.Write(input) + if err != nil { + t.Fatalf("[%d] conn.Write() errored: %v", i, err) + } + } + // Assert that the expected bytes have been forwarded to the original destination. + gotForwarded := tc.WriteBufBytes() + if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) { + t.Errorf("expected bytes not forwarded, wants\n%x\ngot\n%x", tt.wantForwarded, gotForwarded) + } + + // Assert that the expected bytes have been forwarded to the session recorder. 
+ gotRecorded := sr.Bytes() + if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { + t.Errorf("expected bytes not recorded, wants\n%b\ngot\n%b", tt.wantRecorded, gotRecorded) + } + }) + } +} + +// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to +// test that we don't panic when parsing input from a broken or malicious +// client. +func Test_conn_ReadRand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + for i := range 100 { + tc := &fakes.TestConn{} + tc.ResetReadBuf() + c := &conn{ + Conn: tc, + log: zl.Sugar(), + } + bb := fakes.RandomBytes(t) + for j, input := range bb { + if err := tc.WriteReadBufBytes(input); err != nil { + t.Fatalf("[%d] writing bytes to test conn: %v", i, err) + } + f := func() { + c.Read(make([]byte, len(input))) + } + testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d first bytes: %v, current read message: %+#v", i, j, len(input), firstBytes(input), c.currentReadMsg)) + } + } +} + +// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that it does not +// panic. +func Test_conn_WriteRand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true) + for i := range 100 { + tc := &fakes.TestConn{} + c := &conn{ + Conn: tc, + log: zl.Sugar(), + rec: rec, + } + bb := fakes.RandomBytes(t) + for j, input := range bb { + f := func() { + c.Write(input) + } + testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg)) + } + } +} + +func testPanic(t *testing.T, f func(), msg string) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Fatal(msg, r) + } + }() + f() +} + +func firstBytes(b []byte) []byte { + if len(b) < 10 { + return b + } + return b[:10] +} diff --git a/k8s-operator/sessionrecording/ws/message.go b/k8s-operator/sessionrecording/ws/message.go new file mode 100644 index 000000000..713febec7 --- /dev/null +++ b/k8s-operator/sessionrecording/ws/message.go @@ -0,0 +1,267 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "encoding/binary" + "fmt" + "sync/atomic" + + "github.com/pkg/errors" + "go.uber.org/zap" + + "golang.org/x/net/websocket" +) + +const ( + noOpcode messageType = 0 // continuation frame for fragmented messages + binaryMessage messageType = 2 +) + +// messageType is the type of a websocket data or control message as defined by opcode. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +// Known types of control messages are close, ping and pong. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.5 +// The only data message type supported by Kubernetes is binary message +// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281 +type messageType int + +// message is a parsed Websocket Message. +type message struct { + // payload is the contents of the so far parsed Websocket + // data Message payload, potentially from multiple fragments written by + // multiple invocations of Parse. As per RFC 6455 We can assume that the + // fragments will always arrive in order and data messages will not be + // interleaved. 
+ payload []byte + + // isFinalized is set to true if payload contains the full contents of + // the message (the final fragment has been received). + isFinalized bool + + // streamID is the stream to which the message belongs, i.e. stdin, stdout + // etc. It is one of the stream IDs defined in + // https://github.com/kubernetes/apimachinery/blob/73d12d09c5be8703587b5127416eb83dc3b7e182/pkg/util/httpstream/wsstream/doc.go#L23-L36 + streamID atomic.Uint32 + + // typ is the type of a websocket message as defined by its opcode + // https://www.rfc-editor.org/rfc/rfc6455#section-5.2 + typ messageType + raw []byte +} + +// Parse accepts a websocket message fragment as a byte slice and parses its contents. +// It returns true if the fragment is complete, false if the fragment is incomplete. +// If the fragment is incomplete, Parse will be called again with the same fragment + more bytes when those are received. +// If the fragment is complete, it will be parsed into msg. +// A complete fragment can be: +// - a fragment that consists of a whole message +// - an initial fragment for a message for which we expect more fragments +// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg. +// Parse must not be called with bytes that don't contain a fragment header (so, never with less than 2 bytes). +// 0 1 2 3 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +// +-+-+-+-+-------+-+-------------+-------------------------------+ +// |F|R|R|R| opcode|M| Payload len | Extended payload length | +// |I|S|S|S| (4) |A| (7) | (16/64) | +// |N|V|V|V| |S| | (if payload len==126/127) | +// | |1|2|3| |K| | | +// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + +// | Extended payload length continued, if payload len == 127 | +// + - - - - - - - - - - - - - - - +-------------------------------+ +// | |Masking-key, if MASK set to 1 | +// +-------------------------------+-------------------------------+ +// | Masking-key (continued) | Payload Data | +// +-------------------------------- - - - - - - - - - - - - - - - + +// : Payload Data continued ... : +// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +// | Payload Data continued ... | +// +---------------------------------------------------------------+ +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +// +// Fragmentation rules: +// An unfragmented message consists of a single frame with the FIN +// bit set (Section 5.2) and an opcode other than 0. +// A fragmented message consists of a single frame with the FIN bit +// clear and an opcode other than 0, followed by zero or more frames +// with the FIN bit clear and the opcode set to 0, and terminated by +// a single frame with the FIN bit set and an opcode of 0.
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 +func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { + if len(b) < 2 { + return false, fmt.Errorf("[unexpected] Parse should not be called with less than 2 bytes, got %d bytes", len(b)) + } + if msg.typ != binaryMessage { + return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ) + } + isInitialFragment := len(msg.raw) == 0 + + msg.isFinalized = isFinalFragment(b) + + maskSet := isMasked(b) + + payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet) + if err != nil { + return false, fmt.Errorf("error determining payload length: %w", err) + } + log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment) + + if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment + return false, nil + } + // TODO (irbekrm): perhaps only do this extra allocation if we know we + // will need to unmask? + msg.raw = make([]byte, int(payloadOffset)+int(payloadLength)) + copy(msg.raw, b[:payloadOffset+payloadLength]) + + // Extract the payload. + msgPayload := b[payloadOffset : payloadOffset+payloadLength] + + // Unmask the payload if needed. + // TODO (irbekrm): instead of unmasking all of the payload each time, + // determine if the payload is for a resize message early and skip + // unmasking the remaining bytes if not. + if maskSet { + m := b[maskOffset:payloadOffset] + var mask [4]byte + copy(mask[:], m) + maskBytes(mask, msgPayload) + } + + // Determine what stream the message is for. Stream ID of a Kubernetes + // streaming session is a 32bit integer, stored in the first byte of the + // message payload. + // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36 + if len(msgPayload) == 0 { + return false, errors.New("[unexpected] received a message fragment with no stream ID") + } + + streamID := uint32(msgPayload[0]) + if !isInitialFragment && msg.streamID.Load() != streamID { + return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID) + } + msg.streamID.Store(streamID) + + // This is normal, Kubernetes seem to send a couple data messages with + // no payloads at the start. + if len(msgPayload) < 2 { + return true, nil + } + msgPayload = msgPayload[1:] // remove the stream ID byte + msg.payload = append(msg.payload, msgPayload...) + return true, nil +} + +// maskBytes applies mask to bytes in place. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.3 +func maskBytes(key [4]byte, b []byte) { + for i := range b { + b[i] = b[i] ^ key[i%4] + } +} + +// isControlMessage returns true if the message type is one of the known control +// frame message types. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.5 +func isControlMessage(t messageType) bool { + const ( + closeMessage messageType = 8 + pingMessage messageType = 9 + pongMessage messageType = 10 + ) + return t == closeMessage || t == pingMessage || t == pongMessage +} + +// isFinalFragment can be called with websocket message fragment and returns true if +// the fragment is the final fragment of a websocket message. 
+func isFinalFragment(b []byte) bool { + return extractFirstBit(b[0]) != 0 +} + +// isMasked can be called with a websocket message fragment and returns true if +// the payload of the message is masked. It uses the mask bit to determine if +// the payload is masked. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.3 +func isMasked(b []byte) bool { + return extractFirstBit(b[1]) != 0 +} + +// extractFirstBit extracts first bit of a byte by zeroing out all the other +// bits. +func extractFirstBit(b byte) byte { + return b & 0x80 +} + +// zeroFirstBit returns the provided byte with the first bit set to 0. +func zeroFirstBit(b byte) byte { + return b & 0x7f +} + +// fragmentDimensions returns payload length as well as payload offset and mask offset. +func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset uint64, _ error) { + + // payload length can be stored either in bits [9-15] or in bytes 2, 3 + // or in bytes 2, 3, 4, 5, 6, 7. + // https://www.rfc-editor.org/rfc/rfc6455#section-5.2 + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + payloadLengthIndicator := zeroFirstBit(b[1]) + switch { + case payloadLengthIndicator < 126: + maskOffset = 2 + payloadLength = uint64(payloadLengthIndicator) + case payloadLengthIndicator == 126: + maskOffset = 4 + if len(b) < int(maskOffset) { + return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:4, but message length is only %d", len(b)) + } + payloadLength = uint64(binary.BigEndian.Uint16(b[2:4])) + case payloadLengthIndicator == 127: + maskOffset = 10 + if len(b) < int(maskOffset) { + return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:10, but message length is only %d", len(b)) + } + payloadLength = binary.BigEndian.Uint64(b[2:10]) + default: + return 0, 0, 0, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator) + } + + // Ensure that a rogue or broken client doesn't cause us attempt to + // allocate a huge array by setting a high payload size. + // websocket.DefaultMaxPayloadBytes is the maximum payload size accepted + // by server side of this connection, so we can safely reject messages + // with larger payload size. + if payloadLength > websocket.DefaultMaxPayloadBytes { + return 0, 0, 0, fmt.Errorf("[unexpected]: too large payload size: %v", payloadLength) + } + + // Masking key can take up 0 or 4 bytes- we need to take that into + // account when determining payload offset. + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // .... + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // ... 
+ if maskSet { + payloadOffset = maskOffset + 4 + } else { + payloadOffset = maskOffset + } + return +} diff --git a/k8s-operator/sessionrecording/ws/message_test.go b/k8s-operator/sessionrecording/ws/message_test.go new file mode 100644 index 000000000..f634f86dc --- /dev/null +++ b/k8s-operator/sessionrecording/ws/message_test.go @@ -0,0 +1,215 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package ws + +import ( + "encoding/binary" + "fmt" + "reflect" + "testing" + "time" + + "math/rand" + + "go.uber.org/zap" + "golang.org/x/net/websocket" +) + +func Test_msg_Parse(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + testMask := [4]byte{1, 2, 3, 4} + bs126, bs126Len := bytesSlice2ByteLen(t) + bs127, bs127Len := byteSlice8ByteLen(t) + tests := []struct { + name string + b []byte + initialPayload []byte + wantPayload []byte + wantIsFinalized bool + wantStreamID uint32 + wantErr bool + }{ + { + name: "single_fragment_stdout_stream_no_payload_no_mask", + b: []byte{0x82, 0x1, 0x1}, + wantPayload: nil, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "single_fragment_stderr_steam_no_payload_has_mask", + b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...), + wantPayload: nil, + wantIsFinalized: true, + wantStreamID: 2, + }, + { + name: "single_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, + wantPayload: []byte{0x7, 0x8}, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "single_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + wantPayload: []byte{0x7, 0x8}, + wantIsFinalized: true, + wantStreamID: 1, + }, + { + name: "initial_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x2, 0x3, 0x1, 0x7, 0x8}, + wantPayload: []byte{0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "initial_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + wantPayload: []byte{0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "subsequent_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x0, 0x3, 0x1, 0x7, 0x8}, + initialPayload: []byte{0x1, 0x2, 0x3}, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "subsequent_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + initialPayload: []byte{0x1, 0x2, 0x3}, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "final_fragment_stdout_stream_no_mask_has_payload", + b: []byte{0x80, 0x3, 0x1, 0x7, 0x8}, + initialPayload: []byte{0x1, 0x2, 0x3}, + wantIsFinalized: true, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "final_fragment_stdout_stream_has_mask_has_payload", + b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), + initialPayload: []byte{0x1, 0x2, 0x3}, + wantIsFinalized: true, + wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, + wantStreamID: 1, + }, + { + name: "single_large_fragment_no_mask_length_hint_126", + b: append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...), + wantIsFinalized: true, + wantPayload: bs126, + wantStreamID: 1, + }, + { + name: 
"single_large_fragment_no_mask_length_hint_127", + b: append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...), + wantIsFinalized: true, + wantPayload: bs127, + wantStreamID: 1, + }, + { + name: "zero_length_bytes", + b: []byte{}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msg := &message{ + typ: binaryMessage, + payload: tt.initialPayload, + } + if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr { + t.Errorf("msg.Parse() = %v, wantsErr: %t", err, tt.wantErr) + } + if msg.isFinalized != tt.wantIsFinalized { + t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized) + } + if msg.streamID.Load() != tt.wantStreamID { + t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load()) + } + if !reflect.DeepEqual(msg.payload, tt.wantPayload) { + t.Errorf("unexpected message payload after Parse, wants %b got %b", tt.wantPayload, msg.payload) + } + }) + } +} + +// Test_msg_Parse_Rand calls Parse with a randomly generated input to verify +// that it doesn't panic. +func Test_msg_Parse_Rand(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("error creating a test logger: %v", err) + } + r := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := range 100 { + n := r.Intn(4096) + b := make([]byte, n) + _, err := r.Read(b) + if err != nil { + t.Fatalf("error generating random byte slice: %v", err) + } + msg := message{typ: binaryMessage} + f := func() { + msg.Parse(b, zl.Sugar()) + } + testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r)) + } +} + +// byteSlice2ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice. +// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length. +// This is used to generate test input representing websocket message with payload length hint 126. +// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + var n uint16 + n = uint16(rand.Intn(65535 - 1)) // space for and additional 1 byte stream ID + b := make([]byte, n) + _, err := r.Read(b) + if err != nil { + t.Fatalf("error generating random byte slice: %v ", err) + } + bb := make([]byte, 2) + binary.BigEndian.PutUint16(bb, n+1) // + stream ID + return b, bb +} + +// byteSlice8ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice. +// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length. +// This is used to generate test input representing websocket message with payload length hint 127. 
+// https://www.rfc-editor.org/rfc/rfc6455#section-5.2 +func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) { + nanos := time.Now().UnixNano() + t.Logf("Creating random source with seed %v", nanos) + r := rand.New(rand.NewSource(nanos)) + var n uint64 + n = uint64(rand.Intn(websocket.DefaultMaxPayloadBytes - 1)) // space for an additional 1 byte stream ID + t.Logf("byteSlice8ByteLen: generating message payload of length %d", n) + b := make([]byte, n) + _, err := r.Read(b) + if err != nil { + t.Fatalf("error generating random byte slice: %v ", err) + } + bb := make([]byte, 8) + binary.BigEndian.PutUint64(bb, n+1) // + stream ID + return b, bb +} + +func maskedBytes(mask [4]byte, b []byte) []byte { + maskBytes(mask, b) + return b +}
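The following is an illustrative, standalone Go sketch (not part of the diff above). It builds and then decodes one short masked, unfragmented binary frame to show the RFC 6455 header fields that opcode, isFinalFragment, isMasked and maskBytes operate on; the frame contents, mask and stream ID byte are made up for illustration, and the sketch assumes the simple case of a 7-bit payload length with no extended length bytes.

package main

import "fmt"

func main() {
	// Stream ID 1 (stdout) followed by the payload "hi".
	payload := []byte{0x1, 'h', 'i'}
	mask := [4]byte{0xa, 0xb, 0xc, 0xd}

	// Build one masked, unfragmented binary frame: FIN=1, opcode=2, MASK=1,
	// 7-bit payload length (len < 126, so no extended length bytes follow).
	frame := []byte{0x80 | 0x2, 0x80 | byte(len(payload))}
	frame = append(frame, mask[:]...)
	for i, b := range payload {
		frame = append(frame, b^mask[i%4]) // mask the payload byte by byte
	}

	// Decode it with the same bit operations the ws package uses.
	fin := frame[0]&0x80 != 0    // isFinalFragment: first bit of byte 0
	op := int(frame[0] & 0xf)    // opcode: low 4 bits of byte 0 (2 = binary)
	masked := frame[1]&0x80 != 0 // isMasked: first bit of byte 1
	plen := int(frame[1] & 0x7f) // payload length indicator (< 126)
	maskOffset, payloadOffset := 2, 2+4

	// Unmask the payload; byte 0 of the result is the stream ID.
	unmasked := make([]byte, plen)
	for i, b := range frame[payloadOffset : payloadOffset+plen] {
		unmasked[i] = b ^ frame[maskOffset+i%4]
	}
	fmt.Println(fin, op, masked, plen, unmasked) // true 2 true 3 [1 104 105]
}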