cmd/k8s-operator,k8s-operator/sessionrecording: support recording kubectl exec sessions over WebSockets (#12947)
cmd/k8s-operator,k8s-operator/sessionrecording: support recording WebSocket sessions

Kubernetes currently supports two streaming protocols, SPDY and WebSockets, with WebSockets replacing SPDY (see https://github.com/kubernetes/enhancements/issues/4006). Until now we only supported SPDY, erroring out if the session was not SPDY and relying on kube's built-in SPDY fallback.

This PR:
- adds support for parsing the contents of 'kubectl exec' sessions streamed over WebSockets
- adds logic to distinguish whether a 'kubectl exec' request is for a SPDY or a WebSocket session and to call the relevant handler

Updates tailscale/corp#19821

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in: parent 4c2e978f1e, commit a15ff1bade
@ -417,6 +417,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
k8s.io/apimachinery/pkg/util/naming from k8s.io/apimachinery/pkg/runtime+
k8s.io/apimachinery/pkg/util/net from k8s.io/apimachinery/pkg/watch+
k8s.io/apimachinery/pkg/util/rand from k8s.io/apiserver/pkg/storage/names
k8s.io/apimachinery/pkg/util/remotecommand from tailscale.com/k8s-operator/sessionrecording/ws
k8s.io/apimachinery/pkg/util/runtime from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+
k8s.io/apimachinery/pkg/util/sets from k8s.io/apimachinery/pkg/api/meta+
k8s.io/apimachinery/pkg/util/strategicpatch from k8s.io/client-go/tools/record+
@ -686,9 +687,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1
tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+
tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator
tailscale.com/k8s-operator/sessionrecording/conn from tailscale.com/k8s-operator/sessionrecording/spdy
tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording
tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+
tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording
tailscale.com/kube from tailscale.com/cmd/k8s-operator+
tailscale.com/licenses from tailscale.com/client/web
tailscale.com/log/filelogger from tailscale.com/logpolicy
@ -741,7 +742,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
tailscale.com/posture from tailscale.com/ipn/ipnlocal
tailscale.com/proxymap from tailscale.com/tsd+
💣 tailscale.com/safesocket from tailscale.com/client/tailscale+
tailscale.com/sessionrecording from tailscale.com/cmd/k8s-operator+
tailscale.com/sessionrecording from tailscale.com/k8s-operator/sessionrecording+
tailscale.com/syncs from tailscale.com/control/controlknobs+
tailscale.com/tailcfg from tailscale.com/client/tailscale+
tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+
@ -863,6 +864,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
golang.org/x/net/ipv6 from github.com/miekg/dns+
golang.org/x/net/proxy from tailscale.com/net/netns
D golang.org/x/net/route from net+
golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws
golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+
golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator
golang.org/x/oauth2/internal from golang.org/x/oauth2+
@ -22,9 +22,8 @@
"k8s.io/client-go/transport"
"tailscale.com/client/tailscale"
"tailscale.com/client/tailscale/apitype"
kubesessionrecording "tailscale.com/k8s-operator/sessionrecording"
ksr "tailscale.com/k8s-operator/sessionrecording"
tskube "tailscale.com/kube"
"tailscale.com/sessionrecording"
"tailscale.com/tailcfg"
"tailscale.com/tsnet"
"tailscale.com/util/clientmetric"
@ -168,7 +167,8 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL

mux := http.NewServeMux()
mux.HandleFunc("/", ap.serveDefault)
mux.HandleFunc("/api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExec)
mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY)
mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS)

hs := &http.Server{
// Kubernetes uses SPDY for exec and port-forward, however SPDY is
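The two new routes rely on Go 1.22+ method-aware ServeMux patterns: SPDY clients issue a POST to the exec endpoint, while WebSocket clients begin with a GET upgrade handshake, so the HTTP method alone selects the right handler. A minimal, standalone sketch of that routing behaviour (not the operator's code; the listen address and handler bodies are hypothetical):

package main

import (
	"fmt"
	"net/http"
)

func main() {
	mux := http.NewServeMux()
	// SPDY 'kubectl exec' requests arrive as POSTs to the exec endpoint.
	mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "SPDY exec for %s/%s\n", r.PathValue("namespace"), r.PathValue("pod"))
	})
	// WebSocket 'kubectl exec' requests start as a GET upgrade handshake.
	mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "WebSocket exec for %s/%s\n", r.PathValue("namespace"), r.PathValue("pod"))
	})
	// Hypothetical listen address, just to make the sketch runnable.
	_ = http.ListenAndServe("127.0.0.1:8080", mux)
}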
@ -209,9 +209,19 @@ func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
}

// serveExec serves 'kubectl exec' requests, optionally configuring the kubectl
// exec sessions to be recorded.
func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
// serveExecSPDY serves 'kubectl exec' requests for sessions streamed over SPDY,
// optionally configuring the kubectl exec sessions to be recorded.
func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) {
ap.execForProto(w, r, ksr.SPDYProtocol)
}

// serveExecWS serves 'kubectl exec' requests for sessions streamed over WebSocket,
// optionally configuring the kubectl exec sessions to be recorded.
func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) {
ap.execForProto(w, r, ksr.WSProtocol)
}

func (ap *apiserverProxy) execForProto(w http.ResponseWriter, r *http.Request, proto ksr.Protocol) {
who, err := ap.whoIs(r)
if err != nil {
ap.authError(w, err)
@ -227,15 +237,17 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
return
}
kubesessionrecording.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded
if !failOpen && len(addrs) == 0 {
msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available."
ap.log.Error(msg)
http.Error(w, msg, http.StatusForbidden)
return
}
if r.Method != "POST" || r.Header.Get("Upgrade") != "SPDY/3.1" {
msg := "'kubectl exec' session recording is configured, but the request is not over SPDY. Session recording is currently only supported for SPDY based clients"

wantsHeader := upgradeHeaderForProto[proto]
if h := r.Header.Get("Upgrade"); h != wantsHeader {
msg := fmt.Sprintf("[unexpected] unable to verify that streaming protocol is %s, wants Upgrade header %q, got: %q", proto, wantsHeader, h)
if failOpen {
msg = msg + "; failure mode is 'fail open'; continuing session without recording."
ap.log.Warn(msg)
@ -247,9 +259,22 @@ func (ap *apiserverProxy) serveExec(w http.ResponseWriter, r *http.Request) {
http.Error(w, msg, http.StatusForbidden)
return
}
spdyH := kubesessionrecording.New(ap.ts, r, who, w, r.PathValue("pod"), r.PathValue("namespace"), kubesessionrecording.SPDYProtocol, addrs, failOpen, sessionrecording.ConnectToRecorder, ap.log)

ap.rp.ServeHTTP(spdyH, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
opts := ksr.HijackerOpts{
Req: r,
W: w,
Proto: proto,
TS: ap.ts,
Who: who,
Addrs: addrs,
FailOpen: failOpen,
Pod: r.PathValue("pod"),
Namespace: r.PathValue("namespace"),
Log: ap.log,
}
h := ksr.New(opts)

ap.rp.ServeHTTP(h, r.WithContext(whoIsKey.WithValue(r.Context(), who)))
}

func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) {
@ -382,3 +407,8 @@ func determineRecorderConfig(who *apitype.WhoIsResponse) (failOpen bool, recorde
}
return failOpen, recorderAddresses, nil
}

var upgradeHeaderForProto = map[ksr.Protocol]string{
ksr.SPDYProtocol: "SPDY/3.1",
ksr.WSProtocol: "websocket",
}
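execForProto then cross-checks the request's Upgrade header against the value expected for the chosen protocol (the upgradeHeaderForProto map above). A small self-contained sketch of that check; the request and its header values here are made up:

package main

import (
	"fmt"
	"net/http"
)

type Protocol string

const (
	SPDYProtocol Protocol = "SPDY"
	WSProtocol   Protocol = "WebSocket"
)

// Mirrors the map added above: the Upgrade header value a client of each
// streaming protocol is expected to send on a 'kubectl exec' request.
var upgradeHeaderForProto = map[Protocol]string{
	SPDYProtocol: "SPDY/3.1",
	WSProtocol:   "websocket",
}

// verifyUpgrade is a simplified stand-in for the check in execForProto.
func verifyUpgrade(r *http.Request, proto Protocol) error {
	want := upgradeHeaderForProto[proto]
	if got := r.Header.Get("Upgrade"); got != want {
		return fmt.Errorf("streaming protocol %s wants Upgrade header %q, got %q", proto, want, got)
	}
	return nil
}

func main() {
	r, _ := http.NewRequest("GET", "https://example.invalid/exec", nil) // hypothetical request
	r.Header.Set("Upgrade", "websocket")
	fmt.Println(verifyUpgrade(r, WSProtocol))   // <nil>
	fmt.Println(verifyUpgrade(r, SPDYProtocol)) // error: mismatched Upgrade header
}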
|
@ -1,20 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
// Package conn contains shared interface for the hijacked
|
||||
// connection of a 'kubectl exec' session that is being recorded.
|
||||
package conn
|
||||
|
||||
import "net"
|
||||
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
// Fail can be called to set connection state to failed. By default any
|
||||
// bytes left over in write buffer are forwarded to the intended
|
||||
// destination when the connection is being closed except for when the
|
||||
// connection state is failed- so set the state to failed when erroring
|
||||
// out and failure policy is to fail closed.
|
||||
Fail()
|
||||
}
|
@ -13,6 +13,9 @@
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"math/rand"
|
||||
|
||||
"tailscale.com/sessionrecording"
|
||||
"tailscale.com/tstime"
|
||||
@ -116,3 +119,20 @@ func AsciinemaResizeMsg(t *testing.T, width, height int) []byte {
|
||||
}
|
||||
return append(bs, '\n')
|
||||
}
|
||||
|
||||
func RandomBytes(t *testing.T) [][]byte {
|
||||
t.Helper()
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
n := r.Intn(4096)
|
||||
b := make([]byte, n)
|
||||
t.Logf("RandomBytes: generating byte slice of length %d", n)
|
||||
_, err := r.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating random byte slice: %v", err)
|
||||
}
|
||||
if len(b) < 2 {
|
||||
return [][]byte{b}
|
||||
}
|
||||
split := r.Intn(len(b) - 1)
|
||||
return [][]byte{b[:split], b[split:]}
|
||||
}
|
||||
|
@ -23,6 +23,7 @@
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/k8s-operator/sessionrecording/spdy"
|
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||
"tailscale.com/k8s-operator/sessionrecording/ws"
|
||||
"tailscale.com/sessionrecording"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tsnet"
|
||||
@ -31,11 +32,14 @@
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
const SPDYProtocol protocol = "SPDY"
|
||||
const (
|
||||
SPDYProtocol Protocol = "SPDY"
|
||||
WSProtocol Protocol = "WebSocket"
|
||||
)
|
||||
|
||||
// protocol is the streaming protocol of the hijacked session. Supported
|
||||
// protocols are SPDY.
|
||||
type protocol string
|
||||
// Protocol is the streaming protocol of the hijacked session. Supported
|
||||
// protocols are SPDY and WebSocket.
|
||||
type Protocol string
|
||||
|
||||
var (
|
||||
// CounterSessionRecordingsAttempted counts the number of session recording attempts.
|
||||
@ -45,22 +49,35 @@
|
||||
counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded")
|
||||
)
|
||||
|
||||
func New(ts *tsnet.Server, req *http.Request, who *apitype.WhoIsResponse, w http.ResponseWriter, pod, ns string, proto protocol, addrs []netip.AddrPort, failOpen bool, connFunc RecorderDialFn, log *zap.SugaredLogger) *Hijacker {
|
||||
func New(opts HijackerOpts) *Hijacker {
|
||||
return &Hijacker{
|
||||
ts: ts,
|
||||
req: req,
|
||||
who: who,
|
||||
ResponseWriter: w,
|
||||
pod: pod,
|
||||
ns: ns,
|
||||
addrs: addrs,
|
||||
failOpen: failOpen,
|
||||
connectToRecorder: connFunc,
|
||||
proto: proto,
|
||||
log: log,
|
||||
ts: opts.TS,
|
||||
req: opts.Req,
|
||||
who: opts.Who,
|
||||
ResponseWriter: opts.W,
|
||||
pod: opts.Pod,
|
||||
ns: opts.Namespace,
|
||||
addrs: opts.Addrs,
|
||||
failOpen: opts.FailOpen,
|
||||
proto: opts.Proto,
|
||||
log: opts.Log,
|
||||
connectToRecorder: sessionrecording.ConnectToRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
type HijackerOpts struct {
|
||||
TS *tsnet.Server
|
||||
Req *http.Request
|
||||
W http.ResponseWriter
|
||||
Who *apitype.WhoIsResponse
|
||||
Addrs []netip.AddrPort
|
||||
Log *zap.SugaredLogger
|
||||
Pod string
|
||||
Namespace string
|
||||
FailOpen bool
|
||||
Proto Protocol
|
||||
}
|
||||
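The constructor change above replaces an eleven-parameter New with a single options struct, so call sites name every field and new fields can be added later without breaking callers. A trimmed-down illustration of the same pattern with a hypothetical type (not the real Hijacker):

package main

import "fmt"

// RecorderOpts plays the role of HijackerOpts in this sketch: fields are
// named at the call site and zero values stand in for omitted options.
type RecorderOpts struct {
	Pod       string
	Namespace string
	FailOpen  bool
}

type recorder struct{ opts RecorderOpts }

// New mirrors the shape of ksr.New(opts): one struct in, one value out.
func New(opts RecorderOpts) *recorder { return &recorder{opts: opts} }

func main() {
	r := New(RecorderOpts{
		Pod:       "my-pod", // hypothetical values
		Namespace: "default",
		FailOpen:  true,
	})
	fmt.Printf("%+v\n", r.opts)
}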
|
||||
// Hijacker implements [net/http.Hijacker] interface.
|
||||
// It must be configured with an http request for a 'kubectl exec' session that
|
||||
// needs to be recorded. It knows how to hijack the connection and configure for
|
||||
@ -76,7 +93,7 @@ type Hijacker struct {
|
||||
addrs []netip.AddrPort // tsrecorder addresses
|
||||
failOpen bool // whether to fail open if recording fails
|
||||
connectToRecorder RecorderDialFn
|
||||
proto protocol // streaming protocol
|
||||
proto Protocol // streaming protocol
|
||||
}
|
||||
|
||||
// RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder
|
||||
@ -111,10 +128,14 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
|
||||
// https://docs.asciinema.org/manual/asciicast/v2/
|
||||
asciicastv2 = 2
|
||||
)
|
||||
var wc io.WriteCloser
|
||||
var (
|
||||
wc io.WriteCloser
|
||||
err error
|
||||
errChan <-chan error
|
||||
)
|
||||
h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen)
|
||||
// TODO (irbekrm): send client a message that session will be recorded.
|
||||
rw, _, errChan, err := h.connectToRecorder(ctx, h.addrs, h.ts.Dial)
|
||||
wc, _, errChan, err = h.connectToRecorder(ctx, h.addrs, h.ts.Dial)
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("error connecting to session recorders: %v", err)
|
||||
if h.failOpen {
|
||||
@ -131,7 +152,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
|
||||
|
||||
// TODO (irbekrm): log which recorder
|
||||
h.log.Info("successfully connected to a session recorder")
|
||||
wc = rw
|
||||
cl := tstime.DefaultClock{}
|
||||
rec := tsrecorder.New(wc, cl, cl.Now(), h.failOpen)
|
||||
qp := h.req.URL.Query()
|
||||
@ -153,7 +173,17 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
|
||||
} else {
|
||||
ch.SrcNodeTags = h.who.Node.Tags
|
||||
}
|
||||
lc := spdy.New(conn, rec, ch, h.log)
|
||||
|
||||
var lc net.Conn
|
||||
switch h.proto {
|
||||
case SPDYProtocol:
|
||||
lc = spdy.New(conn, rec, ch, h.log)
|
||||
case WSProtocol:
|
||||
lc = ws.New(conn, rec, ch, h.log)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown protocol: %s", h.proto)
|
||||
}
|
||||
|
||||
go func() {
|
||||
var err error
|
||||
select {
|
||||
@ -174,7 +204,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
|
||||
}
|
||||
msg += "; failure mode set to 'fail closed'; closing connection"
|
||||
h.log.Error(msg)
|
||||
lc.Fail()
|
||||
// TODO (irbekrm): write a message to the client
|
||||
if err := lc.Close(); err != nil {
|
||||
h.log.Infof("error closing recorder connections: %v", err)
|
||||
|
@ -37,30 +37,40 @@ func Test_Hijacker(t *testing.T) {
|
||||
failRecorderConnPostConnect bool // send error down the error channel
|
||||
wantsConnClosed bool
|
||||
wantsSetupErr bool
|
||||
proto Protocol
|
||||
}{
|
||||
{
|
||||
name: "setup succeeds, conn stays open",
|
||||
name: "setup_succeeds_conn_stays_open",
|
||||
proto: SPDYProtocol,
|
||||
},
|
||||
{
|
||||
name: "setup fails, policy is to fail open, conn stays open",
|
||||
name: "setup_succeeds_conn_stays_open_ws",
|
||||
proto: WSProtocol,
|
||||
},
|
||||
{
|
||||
name: "setup_fails_policy_is_to_fail_open_conn_stays_open",
|
||||
failOpen: true,
|
||||
failRecorderConnect: true,
|
||||
proto: SPDYProtocol,
|
||||
},
|
||||
{
|
||||
name: "setup fails, policy is to fail closed, conn is closed",
|
||||
name: "setup_fails_policy_is_to_fail_closed_conn_is_closed",
|
||||
failRecorderConnect: true,
|
||||
wantsSetupErr: true,
|
||||
wantsConnClosed: true,
|
||||
proto: SPDYProtocol,
|
||||
},
|
||||
{
|
||||
name: "connection fails post-initial connect, policy is to fail open, conn stays open",
|
||||
name: "connection_fails_post-initial_connect_policy_is_to_fail_open_conn_stays_open",
|
||||
failRecorderConnPostConnect: true,
|
||||
failOpen: true,
|
||||
proto: SPDYProtocol,
|
||||
},
|
||||
{
|
||||
name: "connection fails post-initial connect, policy is to fail closed, conn is closed",
|
||||
name: "connection_fails_post-initial_connect,_policy_is_to_fail_closed_conn_is_closed",
|
||||
failRecorderConnPostConnect: true,
|
||||
wantsConnClosed: true,
|
||||
proto: SPDYProtocol,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@ -79,6 +89,7 @@ func Test_Hijacker(t *testing.T) {
|
||||
log: zl.Sugar(),
|
||||
ts: &tsnet.Server{},
|
||||
req: &http.Request{URL: &url.URL{}},
|
||||
proto: tt.proto,
|
||||
}
|
||||
ctx := context.Background()
|
||||
_, err := h.setUpRecording(ctx, tc)
|
||||
|
@ -19,12 +19,18 @@
|
||||
|
||||
"go.uber.org/zap"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
srconn "tailscale.com/k8s-operator/sessionrecording/conn"
|
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||
"tailscale.com/sessionrecording"
|
||||
)
|
||||
|
||||
func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) srconn.Conn {
|
||||
// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
|
||||
// The connection must be a hijacked connection for a 'kubectl exec' session using SPDY.
|
||||
// The hijacked connection is used to transmit SPDY streams between Kubernetes client ('kubectl') and the destination container.
|
||||
// Data read from the underlying network connection is data sent via one of the SPDY streams from the client to the container.
|
||||
// Data written to the underlying connection is data sent from the container to the client.
|
||||
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
|
||||
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#background-remotecommand-subprotocol
|
||||
func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn {
|
||||
return &conn{
|
||||
Conn: nc,
|
||||
rec: rec,
|
||||
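Both the SPDY wrapper below and the new WebSocket wrapper feed stdout/stderr payloads to tsrecorder as an asciicast v2 stream: one JSON header line, then one JSON event array per chunk of output. A sketch of what that stream looks like; the field names follow the public asciicast v2 format (https://docs.asciinema.org/manual/asciicast/v2/) and the concrete values are made up:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// First line of the recording: the CastHeader-equivalent JSON object.
	header := map[string]any{
		"version":   2,
		"width":     80,
		"height":    24,
		"timestamp": 1700000000,
	}
	hb, _ := json.Marshal(header)
	fmt.Println(string(hb))

	// Each recorded stdout/stderr payload becomes an
	// [elapsed-seconds, "o", data] event on its own line.
	event := []any{0.123456, "o", "hello from the pod\r\n"}
	eb, _ := json.Marshal(event)
	fmt.Println(string(eb)) // [0.123456,"o","hello from the pod\r\n"]
}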
@ -49,7 +55,6 @@ type conn struct {
|
||||
|
||||
wmu sync.Mutex // sequences writes
|
||||
closed bool
|
||||
failed bool
|
||||
|
||||
rmu sync.Mutex // sequences reads
|
||||
writeCastHeaderOnce sync.Once
|
||||
@ -172,9 +177,6 @@ func (c *conn) Close() error {
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
if !c.failed && c.writeBuf.Len() > 0 {
|
||||
c.Conn.Write(c.writeBuf.Bytes())
|
||||
}
|
||||
c.writeBuf.Reset()
|
||||
c.closed = true
|
||||
err := c.Conn.Close()
|
||||
@ -182,14 +184,8 @@ func (c *conn) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *conn) Fail() {
|
||||
s.wmu.Lock()
|
||||
s.failed = true
|
||||
s.wmu.Unlock()
|
||||
}
|
||||
|
||||
// storeStreamID parses SYN_STREAM SPDY control frame and updates
|
||||
// spdyRemoteConnRecorder to store the newly created stream's ID if it is one of
|
||||
// conn to store the newly created stream's ID if it is one of
|
||||
// the stream types we care about. Storing stream_id:stream_type mapping allows
|
||||
// us to parse received data frames (that have stream IDs) differently depending
|
||||
// on which stream they belong to (i.e send data frame payload for stdout stream
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@ -234,6 +235,57 @@ func Test_Reads(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
|
||||
// test that we don't panic when parsing input from a broken or malicious
|
||||
// client.
|
||||
func Test_conn_ReadRand(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatalf("error creating a test logger: %v", err)
|
||||
}
|
||||
for i := range 1000 {
|
||||
tc := &fakes.TestConn{}
|
||||
tc.ResetReadBuf()
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
}
|
||||
bb := fakes.RandomBytes(t)
|
||||
for j, input := range bb {
|
||||
if err := tc.WriteReadBufBytes(input); err != nil {
|
||||
t.Fatalf("[%d] writing bytes to test conn: %v", i, err)
|
||||
}
|
||||
f := func() {
|
||||
c.Read(make([]byte, len(input)))
|
||||
}
|
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d", i, j, len(input)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that
|
||||
// it does not panic.
|
||||
func Test_conn_WriteRand(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatalf("error creating a test logger: %v", err)
|
||||
}
|
||||
for i := range 100 {
|
||||
tc := &fakes.TestConn{}
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
}
|
||||
bb := fakes.RandomBytes(t)
|
||||
for j, input := range bb {
|
||||
f := func() {
|
||||
c.Write(input)
|
||||
}
|
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d", i, j, len(input)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func resizeMsgBytes(t *testing.T, width, height int) []byte {
|
||||
t.Helper()
|
||||
bs, err := json.Marshal(spdyResizeMsg{Width: width, Height: height})
|
||||
|
@ -9,11 +9,15 @@
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"math/rand"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.uber.org/zap"
|
||||
@ -200,6 +204,29 @@ func Test_spdyFrame_parseHeaders(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Test_spdyFrame_ParseRand calls spdyFrame.Parse with randomly generated bytes
|
||||
// to test that it doesn't panic.
|
||||
func Test_spdyFrame_ParseRand(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
for i := range 100 {
|
||||
n := r.Intn(4096)
|
||||
b := make([]byte, n)
|
||||
_, err := r.Read(b)
|
||||
if err != nil {
|
||||
t.Fatalf("error generating random byte slice: %v", err)
|
||||
}
|
||||
sf := &spdyFrame{}
|
||||
f := func() {
|
||||
sf.Parse(b, zl.Sugar())
|
||||
}
|
||||
testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r))
|
||||
}
|
||||
}
|
||||
|
||||
// payload takes a control frame type and a map with 0 or more header keys and
|
||||
// values and returns a SPDY control frame payload with the header as SPDY zlib
|
||||
// compressed header name/value block. The payload is padded with arbitrary
|
||||
@ -291,3 +318,13 @@ func header(hs map[string]string) http.Header {
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func testPanic(t *testing.T, f func(), msg string) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatal(msg, r)
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
|
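The *Rand tests above push arbitrary bytes through the parser and assert only that it does not panic. The same property could also be expressed with Go's built-in fuzzing; the sketch below is an illustrative alternative under that assumption, not part of this change:

package spdy

import (
	"testing"

	"go.uber.org/zap"
)

// FuzzParse states the same "never panic on arbitrary input" property as
// Test_spdyFrame_ParseRand, but lets 'go test -fuzz=FuzzParse' search for
// inputs instead of sampling a fixed number of random slices.
func FuzzParse(f *testing.F) {
	f.Add([]byte{0x80, 0x03, 0x00, 0x01}) // seed: an arbitrary header-like prefix
	f.Fuzz(func(t *testing.T, b []byte) {
		sf := &spdyFrame{}
		sf.Parse(b, zap.NewNop().Sugar()) // must not panic; returned errors are fine
	})
}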
k8s-operator/sessionrecording/ws/conn.go (new file, 301 lines)
@ -0,0 +1,301 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
// Package ws contains functionality to parse 'kubectl exec' sessions streamed using the
|
||||
// WebSocket protocol.
|
||||
package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||
"tailscale.com/sessionrecording"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
|
||||
// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
|
||||
// The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes.
|
||||
// Data read from the underlying network connection is data sent via one of the streams from the client to the container.
|
||||
// Data written to the underlying connection is data sent from the container to the client.
|
||||
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
|
||||
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio
|
||||
func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn {
|
||||
return &conn{
|
||||
Conn: c,
|
||||
rec: rec,
|
||||
ch: ch,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// conn is a wrapper around net.Conn. It reads the bytestream
|
||||
// for a 'kubectl exec' session, sends session recording data to the configured
|
||||
// recorder and forwards the raw bytes to the original destination.
|
||||
// A new conn is created per session.
|
||||
// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455
|
||||
type conn struct {
|
||||
net.Conn
|
||||
// rec knows how to send data to a tsrecorder instance.
|
||||
rec *tsrecorder.Client
|
||||
// ch is the asciinema CastHeader for a session.
|
||||
ch sessionrecording.CastHeader
|
||||
log *zap.SugaredLogger
|
||||
|
||||
rmu sync.Mutex // sequences reads
|
||||
// currentReadMsg contains parsed contents of a websocket binary data message that
|
||||
// is currently being read from the underlying net.Conn.
|
||||
currentReadMsg *message
|
||||
// readBuf contains bytes for a currently parsed binary data message
|
||||
// read from the underlying conn. If the message is masked, it is
|
||||
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||
// the original byte array.
|
||||
readBuf bytes.Buffer
|
||||
|
||||
wmu sync.Mutex // sequences writes
|
||||
writeCastHeaderOnce sync.Once
|
||||
closed bool // connection is closed
|
||||
// writeBuf contains bytes for a currently parsed binary data message
|
||||
// being written to the underlying conn. If the message is masked, it is
|
||||
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||
// the original byte array.
|
||||
writeBuf bytes.Buffer
|
||||
// currentWriteMsg contains parsed contents of a websocket binary data message that
|
||||
// is currently being written to the underlying net.Conn.
|
||||
currentWriteMsg *message
|
||||
}
|
||||
|
||||
// Read reads bytes from the original connection and parses them as websocket
|
||||
// message fragments.
|
||||
// Bytes read from the original connection are the bytes sent from the Kubernetes client (kubectl) to the destination container via kubelet.
|
||||
|
||||
// If the message is for the resize stream, sets the width
|
||||
// and height of the CastHeader for this connection.
|
||||
// The fragment can be incomplete.
|
||||
func (c *conn) Read(b []byte) (int, error) {
|
||||
c.rmu.Lock()
|
||||
defer c.rmu.Unlock()
|
||||
n, err := c.Conn.Read(b)
|
||||
if err != nil {
|
||||
// It seems that we sometimes get a wrapped io.EOF, but the
|
||||
// caller checks for io.EOF with ==.
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = io.EOF
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
if n == 0 {
|
||||
c.log.Debug("[unexpected] Read called for 0 length bytes")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
typ := messageType(opcode(b))
|
||||
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
|
||||
if typ, err = c.curReadMsgType(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// A control message can not be fragmented and we are not interested in
|
||||
// these messages. Just return.
|
||||
if isControlMessage(typ) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// The only data message type that Kubernetes supports is binary message.
|
||||
// If we received another message type, return and let the API server close the connection.
|
||||
// https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
|
||||
if typ != binaryMessage {
|
||||
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
readMsg := &message{typ: typ} // start a new message...
|
||||
// ... or pick up an already started one if the previous fragment was not final.
|
||||
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() {
|
||||
readMsg = c.currentReadMsg
|
||||
}
|
||||
|
||||
if _, err := c.readBuf.Write(b[:n]); err != nil {
|
||||
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
|
||||
}
|
||||
|
||||
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error parsing message: %v", err)
|
||||
}
|
||||
if !ok { // incomplete fragment
|
||||
return n, nil
|
||||
}
|
||||
c.readBuf.Next(len(readMsg.raw))
|
||||
|
||||
if readMsg.isFinalized {
|
||||
// Stream IDs for websocket streams are static.
|
||||
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
|
||||
if readMsg.streamID.Load() == remotecommand.StreamResize {
|
||||
var err error
|
||||
var msg tsrecorder.ResizeMsg
|
||||
if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
|
||||
return 0, fmt.Errorf("error umarshalling resize message: %w", err)
|
||||
}
|
||||
c.ch.Width = msg.Width
|
||||
c.ch.Height = msg.Height
|
||||
}
|
||||
}
|
||||
c.currentReadMsg = readMsg
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write parses the written bytes as WebSocket message fragment. If the message
|
||||
// is for stdout or stderr streams, it is written to the configured tsrecorder.
|
||||
// A message fragment can be incomplete.
|
||||
func (c *conn) Write(b []byte) (int, error) {
|
||||
c.wmu.Lock()
|
||||
defer c.wmu.Unlock()
|
||||
if len(b) == 0 {
|
||||
c.log.Debug("[unexpected] Write called with 0 bytes")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
typ := messageType(opcode(b))
|
||||
// If we are in the process of parsing a message fragment, the received
// bytes are not structured as a message fragment and cannot be used to
// determine the message type.
|
||||
if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment
|
||||
var err error
|
||||
if typ, err = c.curWriteMsgType(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if isControlMessage(typ) {
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
writeMsg := &message{typ: typ} // start a new message...
|
||||
// ... or continue the existing one if it has not been finalized.
|
||||
if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() {
|
||||
writeMsg = c.currentWriteMsg
|
||||
}
|
||||
|
||||
if _, err := c.writeBuf.Write(b); err != nil {
|
||||
c.log.Errorf("write: error writing to write buf: %v", err)
|
||||
return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err)
|
||||
}
|
||||
|
||||
ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log)
|
||||
if err != nil {
|
||||
c.log.Errorf("write: parsing a message errored: %v", err)
|
||||
return 0, fmt.Errorf("write: error parsing message: %v", err)
|
||||
}
|
||||
c.currentWriteMsg = writeMsg
|
||||
if !ok { // incomplete fragment
|
||||
return len(b), nil
|
||||
}
|
||||
c.writeBuf.Next(len(writeMsg.raw)) // advance frame
|
||||
|
||||
if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
|
||||
if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
|
||||
var err error
|
||||
c.writeCastHeaderOnce.Do(func() {
|
||||
var j []byte
|
||||
j, err = json.Marshal(c.ch)
|
||||
if err != nil {
|
||||
c.log.Errorf("error marhsalling conn: %v", err)
|
||||
return
|
||||
}
|
||||
j = append(j, '\n')
|
||||
err = c.rec.WriteCastLine(j)
|
||||
if err != nil {
|
||||
c.log.Errorf("received error from recorder: %v", err)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing CastHeader: %w", err)
|
||||
}
|
||||
if err := c.rec.Write(writeMsg.payload); err != nil {
|
||||
return 0, fmt.Errorf("error writing message to recorder: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
_, err = c.Conn.Write(c.currentWriteMsg.raw)
|
||||
if err != nil {
|
||||
c.log.Errorf("write: error writing to conn: %v", err)
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
c.wmu.Lock()
|
||||
defer c.wmu.Unlock()
|
||||
if c.closed {
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
connCloseErr := c.Conn.Close()
|
||||
recCloseErr := c.rec.Close()
|
||||
return multierr.New(connCloseErr, recCloseErr)
|
||||
}
|
||||
|
||||
// writeBufHasIncompleteFragment returns true if the latest data message
|
||||
// fragment written to the connection was incomplete and the following write
|
||||
// must be the remaining payload bytes of that fragment.
|
||||
func (c *conn) writeBufHasIncompleteFragment() bool {
|
||||
return c.writeBuf.Len() != 0
|
||||
}
|
||||
|
||||
// readBufHasIncompleteFragment returns true if the latest data message
|
||||
// fragment read from the connection was incomplete and the following read
|
||||
// must be the remaining payload bytes of that fragment.
|
||||
func (c *conn) readBufHasIncompleteFragment() bool {
|
||||
return c.readBuf.Len() != 0
|
||||
}
|
||||
|
||||
// writeMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||
// the connection was fragmented and the next data message fragment written to
|
||||
// the connection must be a fragment of that message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (c *conn) writeMsgIsIncomplete() bool {
|
||||
return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized
|
||||
}
|
||||
|
||||
// readMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||
// the connection was fragmented and the next data message fragment written to
|
||||
// the connection must be a fragment of that message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (c *conn) readMsgIsIncomplete() bool {
|
||||
return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized
|
||||
}
|
||||
func (c *conn) curReadMsgType() (messageType, error) {
|
||||
if c.currentReadMsg != nil {
|
||||
return c.currentReadMsg.typ, nil
|
||||
}
|
||||
return 0, errors.New("[unexpected] attempted to determine type for nil message")
|
||||
}
|
||||
|
||||
func (c *conn) curWriteMsgType() (messageType, error) {
|
||||
if c.currentWriteMsg != nil {
|
||||
return c.currentWriteMsg.typ, nil
|
||||
}
|
||||
return 0, errors.New("[unexpected] attempted to determine type for nil message")
|
||||
}
|
||||
|
||||
// opcode reads the websocket message opcode that denotes the message type.
|
||||
// The opcode is contained in bits [4-7] of the first byte of the message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
func opcode(b []byte) int {
|
||||
// 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b
|
||||
var mask byte = 0xf
|
||||
return int(b[0] & mask)
|
||||
}
|
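The helpers above only need a few bit-level facts about a WebSocket frame header: FIN is the top bit of byte 0, the opcode is its low nibble, MASK is the top bit of byte 1 and the 7-bit length indicator is the rest of it. A standalone sketch decoding one of the unmasked test fixtures used further down ({0x82, 0x3, 0x1, 0x7, 0x8}):

package main

import "fmt"

func main() {
	// FIN=1, opcode=2 (binary), unmasked, 3-byte payload: stream ID 1 (stdout)
	// followed by the two data bytes that end up in the recording.
	frame := []byte{0x82, 0x03, 0x01, 0x07, 0x08}

	fin := frame[0]&0x80 != 0          // true: final fragment
	op := int(frame[0] & 0x0f)         // 2: binary data message
	masked := frame[1]&0x80 != 0       // false: server-to-client frames are unmasked
	payloadLen := int(frame[1] & 0x7f) // 3

	payload := frame[2 : 2+payloadLen]
	fmt.Println(fin, op, masked, payloadLen) // true 2 false 3
	// The first payload byte is the Kubernetes stream ID (1 = stdout); the
	// rest is the data forwarded to the session recorder.
	fmt.Println(payload[0], payload[1:]) // 1 [7 8]
}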
k8s-operator/sessionrecording/ws/conn_test.go (new file, 257 lines)
@ -0,0 +1,257 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"k8s.io/apimachinery/pkg/util/remotecommand"
|
||||
"tailscale.com/k8s-operator/sessionrecording/fakes"
|
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder"
|
||||
"tailscale.com/sessionrecording"
|
||||
"tailscale.com/tstest"
|
||||
)
|
||||
|
||||
func Test_conn_Read(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Resize stream ID + {"width": 10, "height": 20}
|
||||
testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}
|
||||
lenResizeMsgPayload := byte(len(testResizeMsg))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
inputs [][]byte
|
||||
wantWidth int
|
||||
wantHeight int
|
||||
}{
|
||||
{
|
||||
name: "single_read_control_message",
|
||||
inputs: [][]byte{{0x88, 0x0}},
|
||||
},
|
||||
{
|
||||
name: "single_read_resize_message",
|
||||
inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)},
|
||||
wantWidth: 10,
|
||||
wantHeight: 20,
|
||||
},
|
||||
{
|
||||
name: "two_reads_resize_message",
|
||||
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}},
|
||||
wantWidth: 10,
|
||||
wantHeight: 20,
|
||||
},
|
||||
{
|
||||
name: "three_reads_resize_message_with_split_fragment",
|
||||
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}},
|
||||
wantWidth: 10,
|
||||
wantHeight: 20,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := &fakes.TestConn{}
|
||||
tc.ResetReadBuf()
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
}
|
||||
for i, input := range tt.inputs {
|
||||
if err := tc.WriteReadBufBytes(input); err != nil {
|
||||
t.Fatalf("writing bytes to test conn: %v", err)
|
||||
}
|
||||
_, err := c.Read(make([]byte, len(input)))
|
||||
if err != nil {
|
||||
t.Errorf("[%d] conn.Read() errored %v", i, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
if tt.wantHeight != 0 || tt.wantWidth != 0 {
|
||||
if tt.wantWidth != c.ch.Width {
|
||||
t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width)
|
||||
}
|
||||
if tt.wantHeight != c.ch.Height {
|
||||
t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_conn_Write(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cl := tstest.NewClock(tstest.ClockOpts{})
|
||||
tests := []struct {
|
||||
name string
|
||||
inputs [][]byte
|
||||
wantForwarded []byte
|
||||
wantRecorded []byte
|
||||
firstWrite bool
|
||||
width int
|
||||
height int
|
||||
}{
|
||||
{
|
||||
name: "single_write_control_frame",
|
||||
inputs: [][]byte{{0x88, 0x0}},
|
||||
wantForwarded: []byte{0x88, 0x0},
|
||||
},
|
||||
{
|
||||
name: "single_write_stdout_data_message",
|
||||
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
|
||||
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
|
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
|
||||
},
|
||||
{
|
||||
name: "single_write_stderr_data_message",
|
||||
inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}},
|
||||
wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8},
|
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl),
|
||||
},
|
||||
{
|
||||
name: "single_write_stdin_data_message",
|
||||
inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}},
|
||||
wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8},
|
||||
},
|
||||
{
|
||||
name: "single_write_stdout_data_message_with_cast_header",
|
||||
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}},
|
||||
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8},
|
||||
wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...),
|
||||
width: 10,
|
||||
height: 20,
|
||||
firstWrite: true,
|
||||
},
|
||||
{
|
||||
name: "two_writes_stdout_data_message",
|
||||
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}},
|
||||
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
|
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
|
||||
},
|
||||
{
|
||||
name: "three_writes_stdout_data_message_with_split_fragment",
|
||||
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}},
|
||||
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5},
|
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := &fakes.TestConn{}
|
||||
sr := &fakes.TestSessionRecorder{}
|
||||
rec := tsrecorder.New(sr, cl, cl.Now(), true)
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
ch: sessionrecording.CastHeader{
|
||||
Width: tt.width,
|
||||
Height: tt.height,
|
||||
},
|
||||
rec: rec,
|
||||
}
|
||||
if !tt.firstWrite {
|
||||
// This test case does not intend to test that cast header gets written once.
|
||||
c.writeCastHeaderOnce.Do(func() {})
|
||||
}
|
||||
for i, input := range tt.inputs {
|
||||
_, err := c.Write(input)
|
||||
if err != nil {
|
||||
t.Fatalf("[%d] conn.Write() errored: %v", i, err)
|
||||
}
|
||||
}
|
||||
// Assert that the expected bytes have been forwarded to the original destination.
|
||||
gotForwarded := tc.WriteBufBytes()
|
||||
if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) {
|
||||
t.Errorf("expected bytes not forwarded, wants\n%x\ngot\n%x", tt.wantForwarded, gotForwarded)
|
||||
}
|
||||
|
||||
// Assert that the expected bytes have been forwarded to the session recorder.
|
||||
gotRecorded := sr.Bytes()
|
||||
if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) {
|
||||
t.Errorf("expected bytes not recorded, wants\n%b\ngot\n%b", tt.wantRecorded, gotRecorded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
|
||||
// test that we don't panic when parsing input from a broken or malicious
|
||||
// client.
|
||||
func Test_conn_ReadRand(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatalf("error creating a test logger: %v", err)
|
||||
}
|
||||
for i := range 100 {
|
||||
tc := &fakes.TestConn{}
|
||||
tc.ResetReadBuf()
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
}
|
||||
bb := fakes.RandomBytes(t)
|
||||
for j, input := range bb {
|
||||
if err := tc.WriteReadBufBytes(input); err != nil {
|
||||
t.Fatalf("[%d] writing bytes to test conn: %v", i, err)
|
||||
}
|
||||
f := func() {
|
||||
c.Read(make([]byte, len(input)))
|
||||
}
|
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d first bytes: %v, current read message: %+#v", i, j, len(input), firstBytes(input), c.currentReadMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that it does not
|
||||
// panic.
|
||||
func Test_conn_WriteRand(t *testing.T) {
|
||||
zl, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
t.Fatalf("error creating a test logger: %v", err)
|
||||
}
|
||||
cl := tstest.NewClock(tstest.ClockOpts{})
|
||||
sr := &fakes.TestSessionRecorder{}
|
||||
rec := tsrecorder.New(sr, cl, cl.Now(), true)
|
||||
for i := range 100 {
|
||||
tc := &fakes.TestConn{}
|
||||
c := &conn{
|
||||
Conn: tc,
|
||||
log: zl.Sugar(),
|
||||
rec: rec,
|
||||
}
|
||||
bb := fakes.RandomBytes(t)
|
||||
for j, input := range bb {
|
||||
f := func() {
|
||||
c.Write(input)
|
||||
}
|
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testPanic(t *testing.T, f func(), msg string) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Fatal(msg, r)
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}
|
||||
|
||||
func firstBytes(b []byte) []byte {
|
||||
if len(b) < 10 {
|
||||
return b
|
||||
}
|
||||
return b[:10]
|
||||
}
|
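The resize fixtures in these tests are hand-encoded. For reference, a small sketch assembling the same kind of unmasked, single-fragment binary frame: stream ID byte first, then the JSON payload. Only the 7-bit length form is handled, and the stream ID constant is assumed to be the resize channel (remotecommand.StreamResize, value 4) from the *.channel.k8s.io protocol:

package main

import (
	"encoding/json"
	"fmt"
)

// frameFor builds an unmasked, single-fragment binary WebSocket frame whose
// payload is a Kubernetes remotecommand stream ID followed by data, matching
// the shape of the test fixtures above. Payloads must be shorter than 126
// bytes so the length fits the 7-bit indicator.
func frameFor(streamID byte, data []byte) []byte {
	body := append([]byte{streamID}, data...)
	header := []byte{
		0x82,            // FIN set, opcode 2 (binary message)
		byte(len(body)), // MASK bit clear, 7-bit payload length
	}
	return append(header, body...)
}

func main() {
	resize, _ := json.Marshal(map[string]int{"width": 10, "height": 20})
	const streamResize = 4 // assumed value of remotecommand.StreamResize
	fmt.Printf("% x\n", frameFor(streamResize, resize))
	// 82 19 04 7b 22 68 65 69 67 68 74 22 3a 32 30 2c 22 77 69 64 74 68 22 3a 31 30 7d
}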
k8s-operator/sessionrecording/ws/message.go (new file, 267 lines)
@ -0,0 +1,267 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
noOpcode messageType = 0 // continuation frame for fragmented messages
|
||||
binaryMessage messageType = 2
|
||||
)
|
||||
|
||||
// messageType is the type of a websocket data or control message as defined by opcode.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
// Known types of control messages are close, ping and pong.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||
// The only data message type supported by Kubernetes is binary message
|
||||
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281
|
||||
type messageType int
|
||||
|
||||
// message is a parsed Websocket Message.
|
||||
type message struct {
|
||||
// payload is the contents of the so far parsed Websocket
|
||||
// data Message payload, potentially from multiple fragments written by
|
||||
// multiple invocations of Parse. As per RFC 6455 We can assume that the
|
||||
// fragments will always arrive in order and data messages will not be
|
||||
// interleaved.
|
||||
payload []byte
|
||||
|
||||
// isFinalized is set to true if msgPayload contains full contents of
|
||||
// the message (the final fragment has been received).
|
||||
isFinalized bool
|
||||
|
||||
// streamID is the stream to which the message belongs, i.e. stdin, stdout,
|
||||
// etc. It is one of the stream IDs defined in
|
||||
// https://github.com/kubernetes/apimachinery/blob/73d12d09c5be8703587b5127416eb83dc3b7e182/pkg/util/httpstream/wsstream/doc.go#L23-L36
|
||||
streamID atomic.Uint32
|
||||
|
||||
// typ is the type of a WebsocketMessage as defined by its opcode
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
typ messageType
|
||||
raw []byte
|
||||
}
|
||||
|
||||
// Parse accepts a websocket message fragment as a byte slice and parses its contents.
|
||||
// It returns true if the fragment is complete, false if the fragment is incomplete.
|
||||
// If the fragment is incomplete, Parse will be called again with the same fragment + more bytes when those are received.
|
||||
// If the fragment is complete, it will be parsed into msg.
|
||||
// A complete fragment can be:
|
||||
// - a fragment that consists of a whole message
|
||||
// - an initial fragment for a message for which we expect more fragments
|
||||
// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg.
|
||||
// Parse must not be called with bytes that don't contain fragment header (so, no less than 2 bytes).
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||
// | |1|2|3| |K| | |
|
||||
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||
// | Extended payload length continued, if payload len == 127 |
|
||||
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
// | |Masking-key, if MASK set to 1 |
|
||||
// +-------------------------------+-------------------------------+
|
||||
// | Masking-key (continued) | Payload Data |
|
||||
// +-------------------------------- - - - - - - - - - - - - - - - +
|
||||
// : Payload Data continued ... :
|
||||
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||
// | Payload Data continued ... |
|
||||
// +---------------------------------------------------------------+
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
//
|
||||
// Fragmentation rules:
|
||||
// An unfragmented message consists of a single frame with the FIN
|
||||
// bit set (Section 5.2) and an opcode other than 0.
|
||||
// A fragmented message consists of a single frame with the FIN bit
|
||||
// clear and an opcode other than 0, followed by zero or more frames
|
||||
// with the FIN bit clear and the opcode set to 0, and terminated by
|
||||
// a single frame with the FIN bit set and an opcode of 0.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
|
||||
if len(b) < 2 {
|
||||
return false, fmt.Errorf("[unexpected] Parse should not be called with less than 2 bytes, got %d bytes", len(b))
|
||||
}
|
||||
if msg.typ != binaryMessage {
|
||||
return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ)
|
||||
}
|
||||
isInitialFragment := len(msg.raw) == 0
|
||||
|
||||
msg.isFinalized = isFinalFragment(b)
|
||||
|
||||
maskSet := isMasked(b)
|
||||
|
||||
payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error determining payload length: %w", err)
|
||||
}
|
||||
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment)
|
||||
|
||||
if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment
|
||||
return false, nil
|
||||
}
|
||||
// TODO (irbekrm): perhaps only do this extra allocation if we know we
|
||||
// will need to unmask?
|
||||
msg.raw = make([]byte, int(payloadOffset)+int(payloadLength))
|
||||
copy(msg.raw, b[:payloadOffset+payloadLength])
|
||||
|
||||
// Extract the payload.
|
||||
msgPayload := b[payloadOffset : payloadOffset+payloadLength]
|
||||
|
||||
// Unmask the payload if needed.
|
||||
// TODO (irbekrm): instead of unmasking all of the payload each time,
|
||||
// determine if the payload is for a resize message early and skip
|
||||
// unmasking the remaining bytes if not.
|
||||
if maskSet {
|
||||
m := b[maskOffset:payloadOffset]
|
||||
var mask [4]byte
|
||||
copy(mask[:], m)
|
||||
maskBytes(mask, msgPayload)
|
||||
}
|
||||
|
||||
// Determine what stream the message is for. Stream ID of a Kubernetes
|
||||
// streaming session is a 32bit integer, stored in the first byte of the
|
||||
// message payload.
|
||||
// https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
|
||||
if len(msgPayload) == 0 {
|
||||
return false, errors.New("[unexpected] received a message fragment with no stream ID")
|
||||
}
|
||||
|
||||
streamID := uint32(msgPayload[0])
|
||||
if !isInitialFragment && msg.streamID.Load() != streamID {
|
||||
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID)
|
||||
}
|
||||
msg.streamID.Store(streamID)
|
||||
|
||||
// This is normal; Kubernetes seems to send a couple of data messages with
|
||||
// no payloads at the start.
|
||||
if len(msgPayload) < 2 {
|
||||
return true, nil
|
||||
}
|
||||
msgPayload = msgPayload[1:] // remove the stream ID byte
|
||||
msg.payload = append(msg.payload, msgPayload...)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// maskBytes applies mask to bytes in place.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||
func maskBytes(key [4]byte, b []byte) {
|
||||
for i := range b {
|
||||
b[i] = b[i] ^ key[i%4]
|
||||
}
|
||||
}
|
||||
|
||||
// isControlMessage returns true if the message type is one of the known control
|
||||
// frame message types.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||
func isControlMessage(t messageType) bool {
|
||||
const (
|
||||
closeMessage messageType = 8
|
||||
pingMessage messageType = 9
|
||||
pongMessage messageType = 10
|
||||
)
|
||||
return t == closeMessage || t == pingMessage || t == pongMessage
|
||||
}
|
||||
|
||||
// isFinalFragment can be called with websocket message fragment and returns true if
|
||||
// the fragment is the final fragment of a websocket message.
|
||||
func isFinalFragment(b []byte) bool {
|
||||
return extractFirstBit(b[0]) != 0
|
||||
}
|
||||
|
||||
// isMasked can be called with a websocket message fragment and returns true if
|
||||
// the payload of the message is masked. It uses the mask bit to determine if
|
||||
// the payload is masked.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||
func isMasked(b []byte) bool {
|
||||
return extractFirstBit(b[1]) != 0
|
||||
}
|
||||
|
||||
// extractFirstBit extracts first bit of a byte by zeroing out all the other
|
||||
// bits.
|
||||
func extractFirstBit(b byte) byte {
|
||||
return b & 0x80
|
||||
}
|
||||
|
||||
// zeroFirstBit returns the provided byte with the first bit set to 0.
|
||||
func zeroFirstBit(b byte) byte {
|
||||
return b & 0x7f
|
||||
}
|
||||
|
||||
// fragmentDimensions returns payload length as well as payload offset and mask offset.
func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset uint64, _ error) {

	// Payload length can be stored either in bits [9-15], in bytes 2, 3
	// or in bytes 2, 3, 4, 5, 6, 7.
	// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
	//  0                   1                   2                   3
	//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	// +-+-+-+-+-------+-+-------------+-------------------------------+
	// |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
	// |I|S|S|S|  (4)  |A|     (7)     |            (16/64)            |
	// |N|V|V|V|       |S|             |   (if payload len==126/127)   |
	// | |1|2|3|       |K|             |                               |
	// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
	// |    Extended payload length continued, if payload len == 127   |
	// + - - - - - - - - - - - - - - - +-------------------------------+
	// |                               | Masking-key, if MASK set to 1 |
	// +-------------------------------+-------------------------------+
	payloadLengthIndicator := zeroFirstBit(b[1])
	switch {
	case payloadLengthIndicator < 126:
		maskOffset = 2
		payloadLength = uint64(payloadLengthIndicator)
	case payloadLengthIndicator == 126:
		maskOffset = 4
		if len(b) < int(maskOffset) {
			return 0, 0, 0, fmt.Errorf("invalid message fragment: length indicator suggests that length is stored in bytes 2:4, but message length is only %d", len(b))
		}
		payloadLength = uint64(binary.BigEndian.Uint16(b[2:4]))
	case payloadLengthIndicator == 127:
		maskOffset = 10
		if len(b) < int(maskOffset) {
			return 0, 0, 0, fmt.Errorf("invalid message fragment: length indicator suggests that length is stored in bytes 2:10, but message length is only %d", len(b))
		}
		payloadLength = binary.BigEndian.Uint64(b[2:10])
	default:
		return 0, 0, 0, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator)
	}

	// Ensure that a rogue or broken client doesn't cause us to attempt to
	// allocate a huge array by setting a high payload size.
	// websocket.DefaultMaxPayloadBytes is the maximum payload size accepted
	// by the server side of this connection, so we can safely reject messages
	// with a larger payload size.
	if payloadLength > websocket.DefaultMaxPayloadBytes {
		return 0, 0, 0, fmt.Errorf("[unexpected]: too large payload size: %v", payloadLength)
	}

	// The masking key can take up 0 or 4 bytes - we need to take that into
	// account when determining the payload offset.
	//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
	// ....
	// + - - - - - - - - - - - - - - - +-------------------------------+
	// |                               | Masking-key, if MASK set to 1 |
	// +-------------------------------+-------------------------------+
	// |    Masking-key (continued)    |          Payload Data         |
	// + - - - - - - - - - - - - - - - +-------------------------------+
	// ...
	if maskSet {
		payloadOffset = maskOffset + 4
	} else {
		payloadOffset = maskOffset
	}
	return
}
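
// Illustrative sketch (not part of the original change): for a masked frame
// whose 7-bit length indicator is 126, the real payload length lives in bytes
// 2-3 and the 4-byte masking key follows it. The header bytes below are
// hypothetical and show only the frame header (the 256 payload bytes would follow).
//
//	hdr := []byte{0x82, 0xfe, 0x01, 0x00, 0xa, 0xb, 0xc, 0xd}
//	plen, poff, moff, err := fragmentDimensions(hdr, isMasked(hdr))
//	// plen == 256 (0x0100), moff == 4, poff == 8, err == nil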

k8s-operator/sessionrecording/ws/message_test.go (new file, 215 lines)
@ -0,0 +1,215 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build !plan9

package ws

import (
	"encoding/binary"
	"fmt"
	"math/rand"
	"reflect"
	"testing"
	"time"

	"go.uber.org/zap"
	"golang.org/x/net/websocket"
)

func Test_msg_Parse(t *testing.T) {
	zl, err := zap.NewDevelopment()
	if err != nil {
		t.Fatalf("error creating a test logger: %v", err)
	}
	testMask := [4]byte{1, 2, 3, 4}
	bs126, bs126Len := bytesSlice2ByteLen(t)
	bs127, bs127Len := byteSlice8ByteLen(t)
	tests := []struct {
		name            string
		b               []byte
		initialPayload  []byte
		wantPayload     []byte
		wantIsFinalized bool
		wantStreamID    uint32
		wantErr         bool
	}{
		{
			name:            "single_fragment_stdout_stream_no_payload_no_mask",
			b:               []byte{0x82, 0x1, 0x1},
			wantPayload:     nil,
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:            "single_fragment_stderr_stream_no_payload_has_mask",
			b:               append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...),
			wantPayload:     nil,
			wantIsFinalized: true,
			wantStreamID:    2,
		},
		{
			name:            "single_fragment_stdout_stream_no_mask_has_payload",
			b:               []byte{0x82, 0x3, 0x1, 0x7, 0x8},
			wantPayload:     []byte{0x7, 0x8},
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:            "single_fragment_stdout_stream_has_mask_has_payload",
			b:               append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			wantPayload:     []byte{0x7, 0x8},
			wantIsFinalized: true,
			wantStreamID:    1,
		},
		{
			name:         "initial_fragment_stdout_stream_no_mask_has_payload",
			b:            []byte{0x2, 0x3, 0x1, 0x7, 0x8},
			wantPayload:  []byte{0x7, 0x8},
			wantStreamID: 1,
		},
		{
			name:         "initial_fragment_stdout_stream_has_mask_has_payload",
			b:            append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			wantPayload:  []byte{0x7, 0x8},
			wantStreamID: 1,
		},
		{
			name:           "subsequent_fragment_stdout_stream_no_mask_has_payload",
			b:              []byte{0x0, 0x3, 0x1, 0x7, 0x8},
			initialPayload: []byte{0x1, 0x2, 0x3},
			wantPayload:    []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:   1,
		},
		{
			name:           "subsequent_fragment_stdout_stream_has_mask_has_payload",
			b:              append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			initialPayload: []byte{0x1, 0x2, 0x3},
			wantPayload:    []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:   1,
		},
		{
			name:            "final_fragment_stdout_stream_no_mask_has_payload",
			b:               []byte{0x80, 0x3, 0x1, 0x7, 0x8},
			initialPayload:  []byte{0x1, 0x2, 0x3},
			wantIsFinalized: true,
			wantPayload:     []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:    1,
		},
		{
			name:            "final_fragment_stdout_stream_has_mask_has_payload",
			b:               append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...),
			initialPayload:  []byte{0x1, 0x2, 0x3},
			wantIsFinalized: true,
			wantPayload:     []byte{0x1, 0x2, 0x3, 0x7, 0x8},
			wantStreamID:    1,
		},
		{
			name:            "single_large_fragment_no_mask_length_hint_126",
			b:               append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...),
			wantIsFinalized: true,
			wantPayload:     bs126,
			wantStreamID:    1,
		},
		{
			name:            "single_large_fragment_no_mask_length_hint_127",
			b:               append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...),
			wantIsFinalized: true,
			wantPayload:     bs127,
			wantStreamID:    1,
		},
		{
			name:    "zero_length_bytes",
			b:       []byte{},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			msg := &message{
				typ:     binaryMessage,
				payload: tt.initialPayload,
			}
			if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr {
				t.Errorf("msg.Parse() = %v, wantErr: %t", err, tt.wantErr)
			}
			if msg.isFinalized != tt.wantIsFinalized {
				t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized)
			}
			if msg.streamID.Load() != tt.wantStreamID {
				t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load())
			}
			if !reflect.DeepEqual(msg.payload, tt.wantPayload) {
				t.Errorf("unexpected message payload after Parse, wants %b got %b", tt.wantPayload, msg.payload)
			}
		})
	}
}

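// Note on the hand-built frames in Test_msg_Parse above (illustrative, not
// part of the original change): in the first header byte, 0x82 is FIN plus the
// binary opcode (0x2), 0x2 alone is a non-final initial fragment, 0x0 a
// continuation fragment and 0x80 a final continuation fragment. In the second
// byte the MASK bit (0x80) is OR-ed onto the 7-bit payload length, so 0x83
// means a masked 3-byte payload (1 stream ID byte plus 2 data bytes) whose
// 4-byte masking key follows immediately.
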
// Test_msg_Parse_Rand calls Parse with randomly generated inputs to verify
// that it doesn't panic.
func Test_msg_Parse_Rand(t *testing.T) {
	zl, err := zap.NewDevelopment()
	if err != nil {
		t.Fatalf("error creating a test logger: %v", err)
	}
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	for i := range 100 {
		n := r.Intn(4096)
		b := make([]byte, n)
		_, err := r.Read(b)
		if err != nil {
			t.Fatalf("error generating random byte slice: %v", err)
		}
		msg := message{typ: binaryMessage}
		f := func() {
			msg.Parse(b, zl.Sugar())
		}
		testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r))
	}
}

// bytesSlice2ByteLen generates a random payload length that fits in 2 bytes and encodes it,
// incremented by one to account for the stream ID byte, in a 2-byte slice.
// It returns a slice of arbitrary bytes of the generated length as well as the 2-byte length slice.
// This is used to generate test input representing a websocket message with payload length hint 126.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) {
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	n := uint16(r.Intn(65535 - 1)) // leave space for an additional 1 byte stream ID
	b := make([]byte, n)
	_, err := r.Read(b)
	if err != nil {
		t.Fatalf("error generating random byte slice: %v", err)
	}
	bb := make([]byte, 2)
	binary.BigEndian.PutUint16(bb, n+1) // + stream ID
	return b, bb
}

// byteSlice8ByteLen generates a random payload length and encodes it,
// incremented by one to account for the stream ID byte, in an 8-byte slice.
// It returns a slice of arbitrary bytes of the generated length as well as the 8-byte length slice.
// This is used to generate test input representing a websocket message with payload length hint 127.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) {
	nanos := time.Now().UnixNano()
	t.Logf("Creating random source with seed %v", nanos)
	r := rand.New(rand.NewSource(nanos))
	n := uint64(r.Intn(websocket.DefaultMaxPayloadBytes - 1)) // leave space for an additional 1 byte stream ID
	t.Logf("byteSlice8ByteLen: generating message payload of length %d", n)
	b := make([]byte, n)
	_, err := r.Read(b)
	if err != nil {
		t.Fatalf("error generating random byte slice: %v", err)
	}
	bb := make([]byte, 8)
	binary.BigEndian.PutUint64(bb, n+1) // + stream ID
	return b, bb
}

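// Illustrative note (not part of the original change): the two helpers above
// pair with the 0x7e (126) and 0x7f (127) length hints in Test_msg_Parse. A
// test frame is assembled as the header byte, the length hint, the 2- or
// 8-byte extended length returned here, then the stream ID byte and the
// generated payload, e.g.:
//
//	bs126, bs126Len := bytesSlice2ByteLen(t)
//	frame := append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...)
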
// maskedBytes masks the provided byte slice in place with the given masking
// key and returns it.
func maskedBytes(mask [4]byte, b []byte) []byte {
	maskBytes(mask, b)
	return b
}
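
// Illustrative note (not part of the original change), assuming maskBytes
// implements the RFC 6455 XOR masking starting at offset 0: masking is
// symmetric, so applying the same 4-byte key twice restores the original
// payload. That is why the tests can build masked fixtures with maskedBytes
// and expect Parse to hand back the unmasked payload.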