cmd/k8s-operator,k8s-operator/sessionrecording: support recording kubectl exec sessions over WebSockets (#12947)

cmd/k8s-operator,k8s-operator/sessionrecording: support recording WebSocket sessions

Kubernetes currently supports two streaming protocols, SPDY and WebSockets.
WebSockets are replacing SPDY, see
https://github.com/kubernetes/enhancements/issues/4006.
We were currently only supporting SPDY, erroring out if session
was not SPDY and relying on the kube's built-in SPDY fallback.

This PR:

- adds support for parsing contents of 'kubectl exec' sessions streamed
over WebSockets

- adds logic to distinguish 'kubectl exec' requests for a SPDY/WebSockets
sessions and call the relevant handler

Updates tailscale/corp#19821

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Irbe Krumina 2024-08-14 19:57:50 +03:00 committed by GitHub
parent 4c2e978f1e
commit a15ff1bade
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 1270 additions and 73 deletions

View File

@ -417,6 +417,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+ k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+
k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+ k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+
k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names
k8s.io/apimachinery/pkg/util/remotecommand from tailscale.com/k8s-operator/sessionrecording/ws
k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+
k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+ k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+
k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+ k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+
@ -686,9 +687,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1
tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+
tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator
tailscale.com/k8s-operator/sessionrecording/conn from tailscale.com/k8s-operator/sessionrecording/spdy
tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording
tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+
tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording
tailscale.com/kube from tailscale.com/cmd/k8s-operator+ tailscale.com/kube from tailscale.com/cmd/k8s-operator+
tailscale.com/licenses from tailscale.com/client/web tailscale.com/licenses from tailscale.com/client/web
tailscale.com/log/filelogger from tailscale.com/logpolicy tailscale.com/log/filelogger from tailscale.com/logpolicy
@ -741,7 +742,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/posture from tailscale.com/ipn/ipnlocal tailscale.com/posture from tailscale.com/ipn/ipnlocal
tailscale.com/proxymap from tailscale.com/tsd+ tailscale.com/proxymap from tailscale.com/tsd+
💣 tailscale.com/safesocket from tailscale.com/client/tailscale+ 💣 tailscale.com/safesocket from tailscale.com/client/tailscale+
tailscale.com/sessionrecording from tailscale.com/cmd/k8s-operator+ tailscale.com/sessionrecording from tailscale.com/k8s-operator/sessionrecording+
tailscale.com/syncs from tailscale.com/control/controlknobs+ tailscale.com/syncs from tailscale.com/control/controlknobs+
tailscale.com/tailcfg from tailscale.com/client/tailscale+ tailscale.com/tailcfg from tailscale.com/client/tailscale+
tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+
@ -863,6 +864,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
golang.org/x/net/ipv6 from github.com/miekg/dns+ golang.org/x/net/ipv6 from github.com/miekg/dns+
golang.org/x/net/proxy from tailscale.com/net/netns golang.org/x/net/proxy from tailscale.com/net/netns
D golang.org/x/net/route from net+ D golang.org/x/net/route from net+
golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws
golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+
golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator
golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/oauth2/internal from golang.org/x/oauth2+

View File

@ -22,9 +22,8 @@
"k8s.io/client-go/transport" "k8s.io/client-go/transport"
"tailscale.com/client/tailscale" "tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
kubesessionrecording "tailscale.com/k8s-operator/sessionrecording" ksr "tailscale.com/k8s-operator/sessionrecording"
tskube "tailscale.com/kube" tskube "tailscale.com/kube"
"tailscale.com/sessionrecording"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
@ -168,7 +167,8 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", ap.serveDefault) mux.HandleFunc("/", ap.serveDefault)
mux.HandleFunc("/api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExec) mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY)
mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS)
hs := &http.Server{ hs := &http.Server{
// Kubernetes uses SPDY for exec and port-forward, however SPDY is // Kubernetes uses SPDY for exec and port-forward, however SPDY is
@ -209,9 +209,19 @@ func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
} }
// serveExec serves 'kubectl exec' requests, optionally configuring the kubectl // serveExecSPDY serves 'kubectl exec' requests for sessions streamed over SPDY,
// exec sessions to be recorded. // optionally configuring the kubectl exec sessions to be recorded.
func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) { func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) {
ap.execForProto(w, r, ksr.SPDYProtocol)
}
// serveExecWS serves 'kubectl exec' requests for sessions streamed over WebSocket,
// optionally configuring the kubectl exec sessions to be recorded.
func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) {
ap.execForProto(w, r, ksr.WSProtocol)
}
func (ap *apiserverProxy) execForProto(w http.ResponseWriter, r *http.Request, proto ksr.Protocol) {
who, err := ap.whoIs(r) who, err := ap.whoIs(r)
if err != nil { if err != nil {
ap.authError(w, err) ap.authError(w, err)
@ -227,15 +237,17 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
return return
} }
kubesessionrecording.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
if !failOpen && len(addrs) == 0 { if !failOpen && len(addrs) == 0 {
msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available." msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available."
ap.log.Error(msg) ap.log.Error(msg)
http.Error(w, msg, http.StatusForbidden) http.Error(w, msg, http.StatusForbidden)
return return
} }
if r.Method != "POST" || r.Header.Get("Upgrade") != "SPDY/3.1" {
msg := "'kubectl exec' session recording is configured, but the request is not over SPDY. Session recording is currently only supported for SPDY based clients" wantsHeader := upgradeHeaderForProto[proto]
if h := r.Header.Get("Upgrade"); h != wantsHeader {
msg := fmt.Sprintf("[unexpected] unable to verify that streaming protocol is %s, wants Upgrade header %q, got: %q", proto, wantsHeader, h)
if failOpen { if failOpen {
msg = msg + "; failure mode is 'fail open'; continuing session without recording." msg = msg + "; failure mode is 'fail open'; continuing session without recording."
ap.log.Warn(msg) ap.log.Warn(msg)
@ -247,9 +259,22 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
http.Error(w, msg, http.StatusForbidden) http.Error(w, msg, http.StatusForbidden)
return return
} }
spdyH := kubesessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), kubesessionrecording.SPDYProtocol, addrs, failOpen, sessionrecording.ConnectToRecorder, ap.log)
ap.rp.ServeHTTP(spdyH, r.WithContext(whoIsKey.WithValue(r.Context(), who))) opts := ksr.HijackerOpts{
Req: r,
W: w,
Proto: proto,
TS: ap.ts,
Who: who,
Addrs: addrs,
FailOpen: failOpen,
Pod: r.PathValue("pod"),
Namespace: r.PathValue("namespace"),
Log: ap.log,
}
h := ksr.New(opts)
ap.rp.ServeHTTP(h, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
} }
func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) { func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) {
@ -382,3 +407,8 @@ func determineRecorderConfig(who *apitype.WhoIsResponse) (failOpen bool, recorde
} }
return failOpen, recorderAddresses, nil return failOpen, recorderAddresses, nil
} }
var upgradeHeaderForProto = map[ksr.Protocol]string{
ksr.SPDYProtocol: "SPDY/3.1",
ksr.WSProtocol: "websocket",
}

View File

@ -1,20 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
// Package conn contains shared interface for the hijacked
// connection of a 'kubectl exec' session that is being recorded.
package conn
import "net"
type Conn interface {
net.Conn
// Fail can be called to set connection state to failed. By default any
// bytes left over in write buffer are forwarded to the intended
// destination when the connection is being closed except for when the
// connection state is failed- so set the state to failed when erroring
// out and failure policy is to fail closed.
Fail()
}

View File

@ -13,6 +13,9 @@
"net" "net"
"sync" "sync"
"testing" "testing"
"time"
"math/rand"
"tailscale.com/sessionrecording" "tailscale.com/sessionrecording"
"tailscale.com/tstime" "tailscale.com/tstime"
@ -116,3 +119,20 @@ func AsciinemaResizeMsg(t *testing.T, width, height int) []byte {
} }
return append(bs, '\n') return append(bs, '\n')
} }
func RandomBytes(t *testing.T) [][]byte {
t.Helper()
r := rand.New(rand.NewSource(time.Now().UnixNano()))
n := r.Intn(4096)
b := make([]byte, n)
t.Logf("RandomBytes: generating byte slice of length %d", n)
_, err := r.Read(b)
if err != nil {
t.Fatalf("error generating random byte slice: %v", err)
}
if len(b) < 2 {
return [][]byte{b}
}
split := r.Intn(len(b) - 1)
return [][]byte{b[:split], b[split:]}
}

View File

@ -23,6 +23,7 @@
"tailscale.com/client/tailscale/apitype" "tailscale.com/client/tailscale/apitype"
"tailscale.com/k8s-operator/sessionrecording/spdy" "tailscale.com/k8s-operator/sessionrecording/spdy"
"tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/k8s-operator/sessionrecording/tsrecorder"
"tailscale.com/k8s-operator/sessionrecording/ws"
"tailscale.com/sessionrecording" "tailscale.com/sessionrecording"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tsnet" "tailscale.com/tsnet"
@ -31,11 +32,14 @@
"tailscale.com/util/multierr" "tailscale.com/util/multierr"
) )
const SPDYProtocol protocol = "SPDY" const (
SPDYProtocol Protocol = "SPDY"
WSProtocol Protocol = "WebSocket"
)
// protocol is the streaming protocol of the hijacked session. Supported // Protocol is the streaming protocol of the hijacked session. Supported
// protocols are SPDY. // protocols are SPDY and WebSocket.
type protocol string type Protocol string
var ( var (
// CounterSessionRecordingsAttempted counts the number of session recording attempts. // CounterSessionRecordingsAttempted counts the number of session recording attempts.
@ -45,22 +49,35 @@
counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded") counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded")
) )
func New(ts *tsnet.Server, req *http.Request, who *apitype.WhoIsResponse, w http.ResponseWriter, pod, ns string, proto protocol, addrs []netip.AddrPort, failOpen bool, connFunc RecorderDialFn, log *zap.SugaredLogger) *Hijacker { func New(opts HijackerOpts) *Hijacker {
return &Hijacker{ return &Hijacker{
ts: ts, ts: opts.TS,
req: req, req: opts.Req,
who: who, who: opts.Who,
ResponseWriter: w, ResponseWriter: opts.W,
pod: pod, pod: opts.Pod,
ns: ns, ns: opts.Namespace,
addrs: addrs, addrs: opts.Addrs,
failOpen: failOpen, failOpen: opts.FailOpen,
connectToRecorder: connFunc, proto: opts.Proto,
proto: proto, log: opts.Log,
log: log, connectToRecorder: sessionrecording.ConnectToRecorder,
} }
} }
type HijackerOpts struct {
TS *tsnet.Server
Req *http.Request
W http.ResponseWriter
Who *apitype.WhoIsResponse
Addrs []netip.AddrPort
Log *zap.SugaredLogger
Pod string
Namespace string
FailOpen bool
Proto Protocol
}
// Hijacker implements [net/http.Hijacker] interface. // Hijacker implements [net/http.Hijacker] interface.
// It must be configured with an http request for a 'kubectl exec' session that // It must be configured with an http request for a 'kubectl exec' session that
// needs to be recorded. It knows how to hijack the connection and configure for // needs to be recorded. It knows how to hijack the connection and configure for
@ -76,7 +93,7 @@ type Hijacker struct {
addrs []netip.AddrPort // tsrecorder addresses addrs []netip.AddrPort // tsrecorder addresses
failOpen bool // whether to fail open if recording fails failOpen bool // whether to fail open if recording fails
connectToRecorder RecorderDialFn connectToRecorder RecorderDialFn
proto protocol // streaming protocol proto Protocol // streaming protocol
} }
// RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder // RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder
@ -111,10 +128,14 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
// https://docs.asciinema.org/manual/asciicast/v2/ // https://docs.asciinema.org/manual/asciicast/v2/
asciicastv2 = 2 asciicastv2 = 2
) )
var wc io.WriteCloser var (
wc io.WriteCloser
err error
errChan <-chan error
)
h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen) h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen)
// TODO (irbekrm): send client a message that session will be recorded. // TODO (irbekrm): send client a message that session will be recorded.
rw, _, errChan, err := h.connectToRecorder(ctx, h.addrs, h.ts.Dial) wc, _, errChan, err = h.connectToRecorder(ctx, h.addrs, h.ts.Dial)
if err != nil { if err != nil {
msg := fmt.Sprintf("error connecting to session recorders: %v", err) msg := fmt.Sprintf("error connecting to session recorders: %v", err)
if h.failOpen { if h.failOpen {
@ -131,7 +152,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
// TODO (irbekrm): log which recorder // TODO (irbekrm): log which recorder
h.log.Info("successfully connected to a session recorder") h.log.Info("successfully connected to a session recorder")
wc = rw
cl := tstime.DefaultClock{} cl := tstime.DefaultClock{}
rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen) rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen)
qp := h.req.URL.Query() qp := h.req.URL.Query()
@ -153,7 +173,17 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
} else { } else {
ch.SrcNodeTags = h.who.Node.Tags ch.SrcNodeTags = h.who.Node.Tags
} }
lc := spdy.New(conn, rec, ch, h.log)
var lc net.Conn
switch h.proto {
case SPDYProtocol:
lc = spdy.New(conn, rec, ch, h.log)
case WSProtocol:
lc = ws.New(conn, rec, ch, h.log)
default:
return nil, fmt.Errorf("unknown protocol: %s", h.proto)
}
go func() { go func() {
var err error var err error
select { select {
@ -174,7 +204,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
} }
msg += "; failure mode set to 'fail closed'; closing connection" msg += "; failure mode set to 'fail closed'; closing connection"
h.log.Error(msg) h.log.Error(msg)
lc.Fail()
// TODO (irbekrm): write a message to the client // TODO (irbekrm): write a message to the client
if err := lc.Close(); err != nil { if err := lc.Close(); err != nil {
h.log.Infof("error closing recorder connections: %v", err) h.log.Infof("error closing recorder connections: %v", err)

View File

@ -37,30 +37,40 @@ func Test_Hijacker(t *testing.T) {
failRecorderConnPostConnect bool // send error down the error channel failRecorderConnPostConnect bool // send error down the error channel
wantsConnClosed bool wantsConnClosed bool
wantsSetupErr bool wantsSetupErr bool
proto Protocol
}{ }{
{ {
name: "setup succeeds, conn stays open", name: "setup_succeeds_conn_stays_open",
proto: SPDYProtocol,
}, },
{ {
name: "setup fails, policy is to fail open, conn stays open", name: "setup_succeeds_conn_stays_open_ws",
proto: WSProtocol,
},
{
name: "setup_fails_policy_is_to_fail_open_conn_stays_open",
failOpen: true, failOpen: true,
failRecorderConnect: true, failRecorderConnect: true,
proto: SPDYProtocol,
}, },
{ {
name: "setup fails, policy is to fail closed, conn is closed", name: "setup_fails_policy_is_to_fail_closed_conn_is_closed",
failRecorderConnect: true, failRecorderConnect: true,
wantsSetupErr: true, wantsSetupErr: true,
wantsConnClosed: true, wantsConnClosed: true,
proto: SPDYProtocol,
}, },
{ {
name: "connection fails post-initial connect, policy is to fail open, conn stays open", name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open",
failRecorderConnPostConnect: true, failRecorderConnPostConnect: true,
failOpen: true, failOpen: true,
proto: SPDYProtocol,
}, },
{ {
name: "connection fails post-initial connect, policy is to fail closed, conn is closed", name: "connection_fails_post-initial_connect,_policy_is_to_fail_closed_conn_is_closed",
failRecorderConnPostConnect: true, failRecorderConnPostConnect: true,
wantsConnClosed: true, wantsConnClosed: true,
proto: SPDYProtocol,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -79,6 +89,7 @@ func Test_Hijacker(t *testing.T) {
log: zl.Sugar(), log: zl.Sugar(),
ts: &tsnet.Server{}, ts: &tsnet.Server{},
req: &http.Request{URL: &url.URL{}}, req: &http.Request{URL: &url.URL{}},
proto: tt.proto,
} }
ctx := context.Background() ctx := context.Background()
_, err := h.setUpRecording(ctx, tc) _, err := h.setUpRecording(ctx, tc)

View File

@ -19,12 +19,18 @@
"go.uber.org/zap" "go.uber.org/zap"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
srconn "tailscale.com/k8s-operator/sessionrecording/conn"
"tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/k8s-operator/sessionrecording/tsrecorder"
"tailscale.com/sessionrecording" "tailscale.com/sessionrecording"
) )
func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) srconn.Conn { // New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
// The connection must be a hijacked connection for a 'kubectl exec' session using SPDY.
// The hijacked connection is used to transmit SPDY streams between Kubernetes client ('kubectl') and the destination container.
// Data read from the underlying network connection is data sent via one of the SPDY streams from the client to the container.
// Data written to the underlying connection is data sent from the container to the client.
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#background-remotecommand-subprotocol
func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn {
return &conn{ return &conn{
Conn: nc, Conn: nc,
rec: rec, rec: rec,
@ -49,7 +55,6 @@ type conn struct {
wmu sync.Mutex // sequences writes wmu sync.Mutex // sequences writes
closed bool closed bool
failed bool
rmu sync.Mutex // sequences reads rmu sync.Mutex // sequences reads
writeCastHeaderOnce sync.Once writeCastHeaderOnce sync.Once
@ -172,9 +177,6 @@ func (c *conn) Close() error {
if c.closed { if c.closed {
return nil return nil
} }
if !c.failed && c.writeBuf.Len() > 0 {
c.Conn.Write(c.writeBuf.Bytes())
}
c.writeBuf.Reset() c.writeBuf.Reset()
c.closed = true c.closed = true
err := c.Conn.Close() err := c.Conn.Close()
@ -182,14 +184,8 @@ func (c *conn) Close() error {
return err return err
} }
func (s *conn) Fail() {
s.wmu.Lock()
s.failed = true
s.wmu.Unlock()
}
// storeStreamID parses SYN_STREAM SPDY control frame and updates // storeStreamID parses SYN_STREAM SPDY control frame and updates
// spdyRemoteConnRecorder to store the newly created stream's ID if it is one of // conn to store the newly created stream's ID if it is one of
// the stream types we care about. Storing stream_id:stream_type mapping allows // the stream types we care about. Storing stream_id:stream_type mapping allows
// us to parse received data frames (that have stream IDs) differently depening // us to parse received data frames (that have stream IDs) differently depening
// on which stream they belong to (i.e send data frame payload for stdout stream // on which stream they belong to (i.e send data frame payload for stdout stream

View File

@ -7,6 +7,7 @@
import ( import (
"encoding/json" "encoding/json"
"fmt"
"reflect" "reflect"
"testing" "testing"
@ -234,6 +235,57 @@ func Test_Reads(t *testing.T) {
} }
} }
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
// test that we don't panic when parsing input from a broken or malicious
// client.
func Test_conn_ReadRand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
for i := range 1000 {
tc := &fakes.TestConn{}
tc.ResetReadBuf()
c := &conn{
Conn: tc,
log: zl.Sugar(),
}
bb := fakes.RandomBytes(t)
for j, input := range bb {
if err := tc.WriteReadBufBytes(input); err != nil {
t.Fatalf("[%d] writing bytes to test conn: %v", i, err)
}
f := func() {
c.Read(make([]byte, len(input)))
}
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d", i, j, len(input)))
}
}
}
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that
// it does not panic.
func Test_conn_WriteRand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
for i := range 100 {
tc := &fakes.TestConn{}
c := &conn{
Conn: tc,
log: zl.Sugar(),
}
bb := fakes.RandomBytes(t)
for j, input := range bb {
f := func() {
c.Write(input)
}
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d", i, j, len(input)))
}
}
}
func resizeMsgBytes(t *testing.T, width, height int) []byte { func resizeMsgBytes(t *testing.T, width, height int) []byte {
t.Helper() t.Helper()
bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height}) bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height})

View File

@ -9,11 +9,15 @@
"bytes" "bytes"
"compress/zlib" "compress/zlib"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"net/http" "net/http"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"time"
"math/rand"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"go.uber.org/zap" "go.uber.org/zap"
@ -200,6 +204,29 @@ func Test_spdyFrame_parseHeaders(t *testing.T) {
} }
} }
// Test_spdyFrame_ParseRand calls spdyFrame.Parse with randomly generated bytes
// to test that it doesn't panic.
func Test_spdyFrame_ParseRand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := range 100 {
n := r.Intn(4096)
b := make([]byte, n)
_, err := r.Read(b)
if err != nil {
t.Fatalf("error generating random byte slice: %v", err)
}
sf := &spdyFrame{}
f := func() {
sf.Parse(b, zl.Sugar())
}
testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r))
}
}
// payload takes a control frame type and a map with 0 or more header keys and // payload takes a control frame type and a map with 0 or more header keys and
// values and returns a SPDY control frame payload with the header as SPDY zlib // values and returns a SPDY control frame payload with the header as SPDY zlib
// compressed header name/value block. The payload is padded with arbitrary // compressed header name/value block. The payload is padded with arbitrary
@ -291,3 +318,13 @@ func header(hs map[string]string) http.Header {
} }
return h return h
} }
func testPanic(t *testing.T, f func(), msg string) {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Fatal(msg, r)
}
}()
f()
}

View File

@ -0,0 +1,301 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
// package ws has functionality to parse 'kubectl exec' sessions streamed using
// WebSocket protocol.
package ws
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"sync"
"go.uber.org/zap"
"k8s.io/apimachinery/pkg/util/remotecommand"
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
"tailscale.com/sessionrecording"
"tailscale.com/util/multierr"
)
// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
// The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes.
// Data read from the underlying network connection is data sent via one of the streams from the client to the container.
// Data written to the underlying connection is data sent from the container to the client.
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio
func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn {
return &conn{
Conn: c,
rec: rec,
ch: ch,
log: log,
}
}
// conn is a wrapper around net.Conn. It reads the bytestream
// for a 'kubectl exec' session, sends session recording data to the configured
// recorder and forwards the raw bytes to the original destination.
// A new conn is created per session.
// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
// https://www.rfc-editor.org/rfc/rfc6455
type conn struct {
net.Conn
// rec knows how to send data to a tsrecorder instance.
rec *tsrecorder.Client
// ch is the asiinema CastHeader for a session.
ch sessionrecording.CastHeader
log *zap.SugaredLogger
rmu sync.Mutex // sequences reads
// currentReadMsg contains parsed contents of a websocket binary data message that
// is currently being read from the underlying net.Conn.
currentReadMsg *message
// readBuf contains bytes for a currently parsed binary data message
// read from the underlying conn. If the message is masked, it is
// unmasked in place, so having this buffer allows us to avoid modifying
// the original byte array.
readBuf bytes.Buffer
wmu sync.Mutex // sequences writes
writeCastHeaderOnce sync.Once
closed bool // connection is closed
// writeBuf contains bytes for a currently parsed binary data message
// being written to the underlying conn. If the message is masked, it is
// unmasked in place, so having this buffer allows us to avoid modifying
// the original byte array.
writeBuf bytes.Buffer
// currentWriteMsg contains parsed contents of a websocket binary data message that
// is currently being written to the underlying net.Conn.
currentWriteMsg *message
}
// Read reads bytes from the original connection and parses them as websocket
// message fragments.
// Bytes read from the original connection are the bytes sent from the Kubernetes client (kubectl) to the destination container via kubelet.
// If the message is for the resize stream, sets the width
// and height of the CastHeader for this connection.
// The fragment can be incomplete.
func (c *conn) Read(b []byte) (int, error) {
c.rmu.Lock()
defer c.rmu.Unlock()
n, err := c.Conn.Read(b)
if err != nil {
// It seems that we sometimes get a wrapped io.EOF, but the
// caller checks for io.EOF with ==.
if errors.Is(err, io.EOF) {
err = io.EOF
}
return 0, err
}
if n == 0 {
c.log.Debug("[unexpected] Read called for 0 length bytes")
return 0, nil
}
typ := messageType(opcode(b))
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
if typ, err = c.curReadMsgType(); err != nil {
return 0, err
}
}
// A control message can not be fragmented and we are not interested in
// these messages. Just return.
if isControlMessage(typ) {
return n, nil
}
// The only data message type that Kubernetes supports is binary message.
// If we received another message type, return and let the API server close the connection.
// https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
if typ != binaryMessage {
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ)
return n, nil
}
readMsg := &message{typ: typ} // start a new message...
// ... or pick up an already started one if the previous fragment was not final.
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() {
readMsg = c.currentReadMsg
}
if _, err := c.readBuf.Write(b[:n]); err != nil {
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
}
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
if err != nil {
return 0, fmt.Errorf("error parsing message: %v", err)
}
if !ok { // incomplete fragment
return n, nil
}
c.readBuf.Next(len(readMsg.raw))
if readMsg.isFinalized {
// Stream IDs for websocket streams are static.
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
if readMsg.streamID.Load() == remotecommand.StreamResize {
var err error
var msg tsrecorder.ResizeMsg
if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
return 0, fmt.Errorf("error umarshalling resize message: %w", err)
}
c.ch.Width = msg.Width
c.ch.Height = msg.Height
}
}
c.currentReadMsg = readMsg
return n, nil
}
// Write parses the written bytes as WebSocket message fragment. If the message
// is for stdout or stderr streams, it is written to the configured tsrecorder.
// A message fragment can be incomplete.
func (c *conn) Write(b []byte) (int, error) {
c.wmu.Lock()
defer c.wmu.Unlock()
if len(b) == 0 {
c.log.Debug("[unexpected] Write called with 0 bytes")
return 0, nil
}
typ := messageType(opcode(b))
// If we are in process of parsing a message fragment, the received
// bytes are not structured as a message fragment and can not be used to
// determine a message fragment.
if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment
var err error
if typ, err = c.curWriteMsgType(); err != nil {
return 0, err
}
}
if isControlMessage(typ) {
return c.Conn.Write(b)
}
writeMsg := &message{typ: typ} // start a new message...
// ... or continue the existing one if it has not been finalized.
if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() {
writeMsg = c.currentWriteMsg
}
if _, err := c.writeBuf.Write(b); err != nil {
c.log.Errorf("write: error writing to write buf: %v", err)
return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err)
}
ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log)
if err != nil {
c.log.Errorf("write: parsing a message errored: %v", err)
return 0, fmt.Errorf("write: error parsing message: %v", err)
}
c.currentWriteMsg = writeMsg
if !ok { // incomplete fragment
return len(b), nil
}
c.writeBuf.Next(len(writeMsg.raw)) // advance frame
if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
var err error
c.writeCastHeaderOnce.Do(func() {
var j []byte
j, err = json.Marshal(c.ch)
if err != nil {
c.log.Errorf("error marhsalling conn: %v", err)
return
}
j = append(j, '\n')
err = c.rec.WriteCastLine(j)
if err != nil {
c.log.Errorf("received error from recorder: %v", err)
}
})
if err != nil {
return 0, fmt.Errorf("error writing CastHeader: %w", err)
}
if err := c.rec.Write(writeMsg.payload); err != nil {
return 0, fmt.Errorf("error writing message to recorder: %v", err)
}
}
}
_, err = c.Conn.Write(c.currentWriteMsg.raw)
if err != nil {
c.log.Errorf("write: error writing to conn: %v", err)
}
return len(b), nil
}
func (c *conn) Close() error {
c.wmu.Lock()
defer c.wmu.Unlock()
if c.closed {
return nil
}
c.closed = true
connCloseErr := c.Conn.Close()
recCloseErr := c.rec.Close()
return multierr.New(connCloseErr, recCloseErr)
}
// writeBufHasIncompleteFragment returns true if the latest data message
// fragment written to the connection was incomplete and the following write
// must be the remaining payload bytes of that fragment.
func (c *conn) writeBufHasIncompleteFragment() bool {
return c.writeBuf.Len() != 0
}
// readBufHasIncompleteFragment returns true if the latest data message
// fragment read from the connection was incomplete and the following read
// must be the remaining payload bytes of that fragment.
func (c *conn) readBufHasIncompleteFragment() bool {
return c.readBuf.Len() != 0
}
// writeMsgIsIncomplete returns true if the latest WebSocket message written to
// the connection was fragmented and the next data message fragment written to
// the connection must be a fragment of that message.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
func (c *conn) writeMsgIsIncomplete() bool {
return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized
}
// readMsgIsIncomplete returns true if the latest WebSocket message written to
// the connection was fragmented and the next data message fragment written to
// the connection must be a fragment of that message.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
func (c *conn) readMsgIsIncomplete() bool {
return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized
}
func (c *conn) curReadMsgType() (messageType, error) {
if c.currentReadMsg != nil {
return c.currentReadMsg.typ, nil
}
return 0, errors.New("[unexpected] attempted to determine type for nil message")
}
func (c *conn) curWriteMsgType() (messageType, error) {
if c.currentWriteMsg != nil {
return c.currentWriteMsg.typ, nil
}
return 0, errors.New("[unexpected] attempted to determine type for nil message")
}
// opcode reads the websocket message opcode that denotes the message type.
// opcode is contained in bits [4-8] of the message.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
func opcode(b []byte) int {
// 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b
var mask byte = 0xf
return int(b[0] & mask)
}

View File

@ -0,0 +1,257 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
package ws
import (
"fmt"
"reflect"
"testing"
"go.uber.org/zap"
"k8s.io/apimachinery/pkg/util/remotecommand"
"tailscale.com/k8s-operator/sessionrecording/fakes"
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
"tailscale.com/sessionrecording"
"tailscale.com/tstest"
)
func Test_conn_Read(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}
// Resize stream ID + {"width": 10, "height": 20}
testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}
lenResizeMsgPayload := byte(len(testResizeMsg))
tests := []struct {
name string
inputs [][]byte
wantWidth int
wantHeight int
}{
{
name: "single_read_control_message",
inputs: [][]byte{{0x88, 0x0}},
},
{
name: "single_read_resize_message",
inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
wantWidth: 10,
wantHeight: 20,
},
{
name: "two_reads_resize_message",
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
wantWidth: 10,
wantHeight: 20,
},
{
name: "three_reads_resize_message_with_split_fragment",
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}},
wantWidth: 10,
wantHeight: 20,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := &fakes.TestConn{}
tc.ResetReadBuf()
c := &conn{
Conn: tc,
log: zl.Sugar(),
}
for i, input := range tt.inputs {
if err := tc.WriteReadBufBytes(input); err != nil {
t.Fatalf("writing bytes to test conn: %v", err)
}
_, err := c.Read(make([]byte, len(input)))
if err != nil {
t.Errorf("[%d] conn.Read() errored %v", i, err)
return
}
}
if tt.wantHeight != 0 || tt.wantWidth != 0 {
if tt.wantWidth != c.ch.Width {
t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width)
}
if tt.wantHeight != c.ch.Height {
t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
}
}
})
}
}
func Test_conn_Write(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatal(err)
}
cl := tstest.NewClock(tstest.ClockOpts{})
tests := []struct {
name string
inputs [][]byte
wantForwarded []byte
wantRecorded []byte
firstWrite bool
width int
height int
}{
{
name: "single_write_control_frame",
inputs: [][]byte{{0x88, 0x0}},
wantForwarded: []byte{0x88, 0x0},
},
{
name: "single_write_stdout_data_message",
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
},
{
name: "single_write_stderr_data_message",
inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}},
wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8},
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
},
{
name: "single_write_stdin_data_message",
inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}},
wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8},
},
{
name: "single_write_stdout_data_message_with_cast_header",
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...),
width: 10,
height: 20,
firstWrite: true,
},
{
name: "two_writes_stdout_data_message",
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}},
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
},
{
name: "three_writes_stdout_data_message_with_split_fragment",
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}},
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := &fakes.TestConn{}
sr := &fakes.TestSessionRecorder{}
rec := tsrecorder.New(sr, cl, cl.Now(), true)
c := &conn{
Conn: tc,
log: zl.Sugar(),
ch: sessionrecording.CastHeader{
Width: tt.width,
Height: tt.height,
},
rec: rec,
}
if !tt.firstWrite {
// This test case does not intend to test that cast header gets written once.
c.writeCastHeaderOnce.Do(func() {})
}
for i, input := range tt.inputs {
_, err := c.Write(input)
if err != nil {
t.Fatalf("[%d] conn.Write() errored: %v", i, err)
}
}
// Assert that the expected bytes have been forwarded to the original destination.
gotForwarded := tc.WriteBufBytes()
if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) {
t.Errorf("expected bytes not forwarded, wants\n%x\ngot\n%x", tt.wantForwarded, gotForwarded)
}
// Assert that the expected bytes have been forwarded to the session recorder.
gotRecorded := sr.Bytes()
if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
t.Errorf("expected bytes not recorded, wants\n%b\ngot\n%b", tt.wantRecorded, gotRecorded)
}
})
}
}
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
// test that we don't panic when parsing input from a broken or malicious
// client.
func Test_conn_ReadRand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
for i := range 100 {
tc := &fakes.TestConn{}
tc.ResetReadBuf()
c := &conn{
Conn: tc,
log: zl.Sugar(),
}
bb := fakes.RandomBytes(t)
for j, input := range bb {
if err := tc.WriteReadBufBytes(input); err != nil {
t.Fatalf("[%d] writing bytes to test conn: %v", i, err)
}
f := func() {
c.Read(make([]byte, len(input)))
}
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d first bytes: %v, current read message: %+#v", i, j, len(input), firstBytes(input), c.currentReadMsg))
}
}
}
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that it does not
// panic.
func Test_conn_WriteRand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
cl := tstest.NewClock(tstest.ClockOpts{})
sr := &fakes.TestSessionRecorder{}
rec := tsrecorder.New(sr, cl, cl.Now(), true)
for i := range 100 {
tc := &fakes.TestConn{}
c := &conn{
Conn: tc,
log: zl.Sugar(),
rec: rec,
}
bb := fakes.RandomBytes(t)
for j, input := range bb {
f := func() {
c.Write(input)
}
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg))
}
}
}
func testPanic(t *testing.T, f func(), msg string) {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Fatal(msg, r)
}
}()
f()
}
func firstBytes(b []byte) []byte {
if len(b) < 10 {
return b
}
return b[:10]
}

View File

@ -0,0 +1,267 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
package ws
import (
"encoding/binary"
"fmt"
"sync/atomic"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/net/websocket"
)
const (
noOpcode messageType = 0 // continuation frame for fragmented messages
binaryMessage messageType = 2
)
// messageType is the type of a websocket data or control message as defined by opcode.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
// Known types of control messages are close, ping and pong.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
// The only data message type supported by Kubernetes is binary message
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281
type messageType int
// message is a parsed Websocket Message.
type message struct {
// payload is the contents of the so far parsed Websocket
// data Message payload, potentially from multiple fragments written by
// multiple invocations of Parse. As per RFC 6455 We can assume that the
// fragments will always arrive in order and data messages will not be
// interleaved.
payload []byte
// isFinalized is set to true if msgPayload contains full contents of
// the message (the final fragment has been received).
isFinalized bool
// streamID is the stream to which the message belongs, i.e stdin, stout
// etc. It is one of the stream IDs defined in
// https://github.com/kubernetes/apimachinery/blob/73d12d09c5be8703587b5127416eb83dc3b7e182/pkg/util/httpstream/wsstream/doc.go#L23-L36
streamID atomic.Uint32
// typ is the type of a WebsocketMessage as defined by its opcode
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
typ messageType
raw []byte
}
// Parse accepts a websocket message fragment as a byte slice and parses its contents.
// It returns true if the fragment is complete, false if the fragment is incomplete.
// If the fragment is incomplete, Parse will be called again with the same fragment + more bytes when those are received.
// If the fragment is complete, it will be parsed into msg.
// A complete fragment can be:
// - a fragment that consists of a whole message
// - an initial fragment for a message for which we expect more fragments
// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg.
// Parse must not be called with bytes that don't contain fragment header (so, no less than 2 bytes).
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// +-------------------------------- - - - - - - - - - - - - - - - +
// : Payload Data continued ... :
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
// | Payload Data continued ... |
// +---------------------------------------------------------------+
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
//
// Fragmentation rules:
// An unfragmented message consists of a single frame with the FIN
// bit set (Section 5.2) and an opcode other than 0.
// A fragmented message consists of a single frame with the FIN bit
// clear and an opcode other than 0, followed by zero or more frames
// with the FIN bit clear and the opcode set to 0, and terminated by
// a single frame with the FIN bit set and an opcode of 0.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
if len(b) < 2 {
return false, fmt.Errorf("[unexpected] Parse should not be called with less than 2 bytes, got %d bytes", len(b))
}
if msg.typ != binaryMessage {
return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ)
}
isInitialFragment := len(msg.raw) == 0
msg.isFinalized = isFinalFragment(b)
maskSet := isMasked(b)
payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet)
if err != nil {
return false, fmt.Errorf("error determining payload length: %w", err)
}
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment)
if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment
return false, nil
}
// TODO (irbekrm): perhaps only do this extra allocation if we know we
// will need to unmask?
msg.raw = make([]byte, int(payloadOffset)+int(payloadLength))
copy(msg.raw, b[:payloadOffset+payloadLength])
// Extract the payload.
msgPayload := b[payloadOffset : payloadOffset+payloadLength]
// Unmask the payload if needed.
// TODO (irbekrm): instead of unmasking all of the payload each time,
// determine if the payload is for a resize message early and skip
// unmasking the remaining bytes if not.
if maskSet {
m := b[maskOffset:payloadOffset]
var mask [4]byte
copy(mask[:], m)
maskBytes(mask, msgPayload)
}
// Determine what stream the message is for. Stream ID of a Kubernetes
// streaming session is a 32bit integer, stored in the first byte of the
// message payload.
// https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
if len(msgPayload) == 0 {
return false, errors.New("[unexpected] received a message fragment with no stream ID")
}
streamID := uint32(msgPayload[0])
if !isInitialFragment && msg.streamID.Load() != streamID {
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID)
}
msg.streamID.Store(streamID)
// This is normal, Kubernetes seem to send a couple data messages with
// no payloads at the start.
if len(msgPayload) < 2 {
return true, nil
}
msgPayload = msgPayload[1:] // remove the stream ID byte
msg.payload = append(msg.payload, msgPayload...)
return true, nil
}
// maskBytes applies mask to bytes in place.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
func maskBytes(key [4]byte, b []byte) {
for i := range b {
b[i] = b[i] ^ key[i%4]
}
}
// isControlMessage returns true if the message type is one of the known control
// frame message types.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
func isControlMessage(t messageType) bool {
const (
closeMessage messageType = 8
pingMessage messageType = 9
pongMessage messageType = 10
)
return t == closeMessage || t == pingMessage || t == pongMessage
}
// isFinalFragment can be called with websocket message fragment and returns true if
// the fragment is the final fragment of a websocket message.
func isFinalFragment(b []byte) bool {
return extractFirstBit(b[0]) != 0
}
// isMasked can be called with a websocket message fragment and returns true if
// the payload of the message is masked. It uses the mask bit to determine if
// the payload is masked.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
func isMasked(b []byte) bool {
return extractFirstBit(b[1]) != 0
}
// extractFirstBit extracts first bit of a byte by zeroing out all the other
// bits.
func extractFirstBit(b byte) byte {
return b & 0x80
}
// zeroFirstBit returns the provided byte with the first bit set to 0.
func zeroFirstBit(b byte) byte {
return b & 0x7f
}
// fragmentDimensions returns payload length as well as payload offset and mask offset.
func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset uint64, _ error) {
// payload length can be stored either in bits [9-15] or in bytes 2, 3
// or in bytes 2, 3, 4, 5, 6, 7.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-------+-+-------------+-------------------------------+
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
// |I|S|S|S| (4) |A| (7) | (16/64) |
// |N|V|V|V| |S| | (if payload len==126/127) |
// | |1|2|3| |K| | |
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
// | Extended payload length continued, if payload len == 127 |
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
payloadLengthIndicator := zeroFirstBit(b[1])
switch {
case payloadLengthIndicator < 126:
maskOffset = 2
payloadLength = uint64(payloadLengthIndicator)
case payloadLengthIndicator == 126:
maskOffset = 4
if len(b) < int(maskOffset) {
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:4, but message length is only %d", len(b))
}
payloadLength = uint64(binary.BigEndian.Uint16(b[2:4]))
case payloadLengthIndicator == 127:
maskOffset = 10
if len(b) < int(maskOffset) {
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:10, but message length is only %d", len(b))
}
payloadLength = binary.BigEndian.Uint64(b[2:10])
default:
return 0, 0, 0, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator)
}
// Ensure that a rogue or broken client doesn't cause us attempt to
// allocate a huge array by setting a high payload size.
// websocket.DefaultMaxPayloadBytes is the maximum payload size accepted
// by server side of this connection, so we can safely reject messages
// with larger payload size.
if payloadLength > websocket.DefaultMaxPayloadBytes {
return 0, 0, 0, fmt.Errorf("[unexpected]: too large payload size: %v", payloadLength)
}
// Masking key can take up 0 or 4 bytes- we need to take that into
// account when determining payload offset.
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// ....
// + - - - - - - - - - - - - - - - +-------------------------------+
// | |Masking-key, if MASK set to 1 |
// +-------------------------------+-------------------------------+
// | Masking-key (continued) | Payload Data |
// + - - - - - - - - - - - - - - - +-------------------------------+
// ...
if maskSet {
payloadOffset = maskOffset + 4
} else {
payloadOffset = maskOffset
}
return
}

View File

@ -0,0 +1,215 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
package ws
import (
"encoding/binary"
"fmt"
"reflect"
"testing"
"time"
"math/rand"
"go.uber.org/zap"
"golang.org/x/net/websocket"
)
func Test_msg_Parse(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
testMask := [4]byte{1, 2, 3, 4}
bs126, bs126Len := bytesSlice2ByteLen(t)
bs127, bs127Len := byteSlice8ByteLen(t)
tests := []struct {
name string
b []byte
initialPayload []byte
wantPayload []byte
wantIsFinalized bool
wantStreamID uint32
wantErr bool
}{
{
name: "single_fragment_stdout_stream_no_payload_no_mask",
b: []byte{0x82, 0x1, 0x1},
wantPayload: nil,
wantIsFinalized: true,
wantStreamID: 1,
},
{
name: "single_fragment_stderr_steam_no_payload_has_mask",
b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...),
wantPayload: nil,
wantIsFinalized: true,
wantStreamID: 2,
},
{
name: "single_fragment_stdout_stream_no_mask_has_payload",
b: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
wantPayload: []byte{0x7, 0x8},
wantIsFinalized: true,
wantStreamID: 1,
},
{
name: "single_fragment_stdout_stream_has_mask_has_payload",
b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
wantPayload: []byte{0x7, 0x8},
wantIsFinalized: true,
wantStreamID: 1,
},
{
name: "initial_fragment_stdout_stream_no_mask_has_payload",
b: []byte{0x2, 0x3, 0x1, 0x7, 0x8},
wantPayload: []byte{0x7, 0x8},
wantStreamID: 1,
},
{
name: "initial_fragment_stdout_stream_has_mask_has_payload",
b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
wantPayload: []byte{0x7, 0x8},
wantStreamID: 1,
},
{
name: "subsequent_fragment_stdout_stream_no_mask_has_payload",
b: []byte{0x0, 0x3, 0x1, 0x7, 0x8},
initialPayload: []byte{0x1, 0x2, 0x3},
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
wantStreamID: 1,
},
{
name: "subsequent_fragment_stdout_stream_has_mask_has_payload",
b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
initialPayload: []byte{0x1, 0x2, 0x3},
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
wantStreamID: 1,
},
{
name: "final_fragment_stdout_stream_no_mask_has_payload",
b: []byte{0x80, 0x3, 0x1, 0x7, 0x8},
initialPayload: []byte{0x1, 0x2, 0x3},
wantIsFinalized: true,
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
wantStreamID: 1,
},
{
name: "final_fragment_stdout_stream_has_mask_has_payload",
b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
initialPayload: []byte{0x1, 0x2, 0x3},
wantIsFinalized: true,
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8},
wantStreamID: 1,
},
{
name: "single_large_fragment_no_mask_length_hint_126",
b: append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...),
wantIsFinalized: true,
wantPayload: bs126,
wantStreamID: 1,
},
{
name: "single_large_fragment_no_mask_length_hint_127",
b: append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...),
wantIsFinalized: true,
wantPayload: bs127,
wantStreamID: 1,
},
{
name: "zero_length_bytes",
b: []byte{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msg := &message{
typ: binaryMessage,
payload: tt.initialPayload,
}
if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr {
t.Errorf("msg.Parse() = %v, wantsErr: %t", err, tt.wantErr)
}
if msg.isFinalized != tt.wantIsFinalized {
t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized)
}
if msg.streamID.Load() != tt.wantStreamID {
t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load())
}
if !reflect.DeepEqual(msg.payload, tt.wantPayload) {
t.Errorf("unexpected message payload after Parse, wants %b got %b", tt.wantPayload, msg.payload)
}
})
}
}
// Test_msg_Parse_Rand calls Parse with a randomly generated input to verify
// that it doesn't panic.
func Test_msg_Parse_Rand(t *testing.T) {
zl, err := zap.NewDevelopment()
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := range 100 {
n := r.Intn(4096)
b := make([]byte, n)
_, err := r.Read(b)
if err != nil {
t.Fatalf("error generating random byte slice: %v", err)
}
msg := message{typ: binaryMessage}
f := func() {
msg.Parse(b, zl.Sugar())
}
testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r))
}
}
// byteSlice2ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
// This is used to generate test input representing websocket message with payload length hint 126.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
var n uint16
n = uint16(rand.Intn(65535 - 1)) // space for and additional 1 byte stream ID
b := make([]byte, n)
_, err := r.Read(b)
if err != nil {
t.Fatalf("error generating random byte slice: %v ", err)
}
bb := make([]byte, 2)
binary.BigEndian.PutUint16(bb, n+1) // + stream ID
return b, bb
}
// byteSlice8ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
// This is used to generate test input representing websocket message with payload length hint 127.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) {
nanos := time.Now().UnixNano()
t.Logf("Creating random source with seed %v", nanos)
r := rand.New(rand.NewSource(nanos))
var n uint64
n = uint64(rand.Intn(websocket.DefaultMaxPayloadBytes - 1)) // space for and additional 1 byte stream ID
t.Logf("byteSlice8ByteLen: generating message payload of length %d", n)
b := make([]byte, n)
_, err := r.Read(b)
if err != nil {
t.Fatalf("error generating random byte slice: %v ", err)
}
bb := make([]byte, 8)
binary.BigEndian.PutUint64(bb, n+1) // + stream ID
return b, bb
}
func maskedBytes(mask [4]byte, b []byte) []byte {
maskBytes(mask, b)
return b
}