k8s-operator: handle multiple WebSocket frames per read (#16678) (#16679)

Cherry picks bug fix #16678 and flake fix #16680 onto the 1.86 release branch.

When kubectl starts an interactive attach session, it sends 2 resize
messages in quick succession. It seems that particularly in HTTP mode,
we often receive both of these WebSocket frames from the underlying
connection in a single read. However, our parser currently assumes 0-1
frames per read, and leaves the second frame in the read buffer until
the next read from the underlying connection. It doesn't take long after
that before we end up failing to skip a control message as we normally
should, and then we parse a control message as though it will have a
stream ID (part of the Kubernetes protocol) and error out.

Instead, we should keep parsing frames from the read buffer for as long
as we're able to parse complete frames, so this commit refactors the
message-parsing logic into a loop that runs while the read buffer is
non-empty.

See k/k staging/src/k8s.io/kubectl/pkg/cmd/attach/attach.go for full
details of the resize messages.

There are at least a couple more multiple-frame read edge cases we
should handle, but this commit is very conservatively fixing a single
observed issue to make it a low-risk candidate for cherry picking.

Updates #13358

Change-Id: Iafb91ad1cbeed9c5231a1525d4563164fc1f002f

Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Tom Proctor 2025-07-28 14:11:30 +01:00 committed by GitHub
parent fdcff402fb
commit 91d65e03e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 107 additions and 65 deletions

View File

@ -114,8 +114,9 @@ func (ap *APIServerProxy) Run(ctx context.Context) error {
mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachWS) mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachWS)
ap.hs = &http.Server{ ap.hs = &http.Server{
Handler: mux, Handler: mux,
ErrorLog: zap.NewStdLog(ap.log.Desugar()), ErrorLog: zap.NewStdLog(ap.log.Desugar()),
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
} }
mode := "noauth" mode := "noauth"
@ -140,7 +141,6 @@ func (ap *APIServerProxy) Run(ctx context.Context) error {
GetCertificate: ap.lc.GetCertificate, GetCertificate: ap.lc.GetCertificate,
NextProtos: []string{"http/1.1"}, NextProtos: []string{"http/1.1"},
} }
ap.hs.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
} else { } else {
var err error var err error
tsLn, err = ap.ts.Listen("tcp", ":80") tsLn, err = ap.ts.Listen("tcp", ":80")

View File

@ -236,7 +236,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
if err := lc.Close(); err != nil { if err := lc.Close(); err != nil {
h.log.Infof("error closing recorder connections: %v", err) h.log.Infof("error closing recorder connections: %v", err)
} }
return
}() }()
return lc, nil return lc, nil
} }

View File

@ -148,6 +148,8 @@ func (c *conn) Read(b []byte) (int, error) {
return 0, nil return 0, nil
} }
// TODO(tomhjp): If we get multiple frames in a single Read with different
// types, we may parse the second frame with the wrong type.
typ := messageType(opcode(b)) typ := messageType(opcode(b))
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
if typ, err = c.curReadMsgType(); err != nil { if typ, err = c.curReadMsgType(); err != nil {
@ -157,6 +159,8 @@ func (c *conn) Read(b []byte) (int, error) {
// A control message can not be fragmented and we are not interested in // A control message can not be fragmented and we are not interested in
// these messages. Just return. // these messages. Just return.
// TODO(tomhjp): If we get multiple frames in a single Read, we may skip
// some non-control messages.
if isControlMessage(typ) { if isControlMessage(typ) {
return n, nil return n, nil
} }
@ -169,62 +173,65 @@ func (c *conn) Read(b []byte) (int, error) {
return n, nil return n, nil
} }
readMsg := &message{typ: typ} // start a new message...
// ... or pick up an already started one if the previous fragment was not final.
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() {
readMsg = c.currentReadMsg
}
if _, err := c.readBuf.Write(b[:n]); err != nil { if _, err := c.readBuf.Write(b[:n]); err != nil {
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
} }
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) for c.readBuf.Len() != 0 {
if err != nil { readMsg := &message{typ: typ} // start a new message...
return 0, fmt.Errorf("error parsing message: %v", err) // ... or pick up an already started one if the previous fragment was not final.
} if c.readMsgIsIncomplete() {
if !ok { // incomplete fragment readMsg = c.currentReadMsg
return n, nil }
}
c.readBuf.Next(len(readMsg.raw))
if readMsg.isFinalized && !c.readMsgIsIncomplete() { ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
// we want to send stream resize messages for terminal sessions if err != nil {
// Stream IDs for websocket streams are static. return 0, fmt.Errorf("error parsing message: %v", err)
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 }
if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm { if !ok { // incomplete fragment
var msg tsrecorder.ResizeMsg return n, nil
if err = json.Unmarshal(readMsg.payload, &msg); err != nil { }
return 0, fmt.Errorf("error umarshalling resize message: %w", err) c.readBuf.Next(len(readMsg.raw))
}
c.ch.Width = msg.Width if readMsg.isFinalized && !c.readMsgIsIncomplete() {
c.ch.Height = msg.Height // we want to send stream resize messages for terminal sessions
// Stream IDs for websocket streams are static.
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm {
var msg tsrecorder.ResizeMsg
if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
return 0, fmt.Errorf("error umarshalling resize message: %w", err)
}
var isInitialResize bool c.ch.Width = msg.Width
c.writeCastHeaderOnce.Do(func() { c.ch.Height = msg.Height
isInitialResize = true
// If this is a session with a terminal attached,
// we must wait for the terminal width and
// height to be parsed from a resize message
// before sending CastHeader, else tsrecorder
// will not be able to play this recording.
err = c.rec.WriteCastHeader(c.ch)
close(c.initialCastHeaderSent)
})
if err != nil {
return 0, fmt.Errorf("error writing CastHeader: %w", err)
}
if !isInitialResize { var isInitialResize bool
if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil { c.writeCastHeaderOnce.Do(func() {
return 0, fmt.Errorf("error writing resize message: %w", err) isInitialResize = true
// If this is a session with a terminal attached,
// we must wait for the terminal width and
// height to be parsed from a resize message
// before sending CastHeader, else tsrecorder
// will not be able to play this recording.
err = c.rec.WriteCastHeader(c.ch)
close(c.initialCastHeaderSent)
})
if err != nil {
return 0, fmt.Errorf("error writing CastHeader: %w", err)
}
if !isInitialResize {
if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil {
return 0, fmt.Errorf("error writing resize message: %w", err)
}
} }
} }
} }
c.currentReadMsg = readMsg
} }
c.currentReadMsg = readMsg
return n, nil return n, nil
} }

View File

@ -9,6 +9,7 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"runtime/debug"
"testing" "testing"
"time" "time"
@ -58,15 +59,39 @@ func Test_conn_Read(t *testing.T) {
wantCastHeaderHeight: 20, wantCastHeaderHeight: 20,
}, },
{ {
name: "two_reads_resize_message", name: "resize_data_frame_two_in_one_read",
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}}, inputs: [][]byte{
fmt.Appendf(nil, "%s%s",
append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...),
append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...),
),
},
wantRecorded: append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...),
wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20,
},
{
name: "two_reads_resize_message",
inputs: [][]byte{
// op, len, stream ID, `{"width`
{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22},
// op, len, stream ID, `:10,"height":20}`
{0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d},
},
wantCastHeaderWidth: 10, wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20, wantCastHeaderHeight: 20,
wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20),
}, },
{ {
name: "three_reads_resize_message_with_split_fragment", name: "three_reads_resize_message_with_split_fragment",
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}}, inputs: [][]byte{
// op, len, stream ID, `{"width"`
{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22},
// op, len, stream ID, `:10,"height`
{0x00, 0x0c, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74},
// op, len, stream ID, `":20}`
{0x80, 0x06, 0x4, 0x22, 0x3a, 0x32, 0x30, 0x7d},
},
wantCastHeaderWidth: 10, wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20, wantCastHeaderHeight: 20,
wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20),
@ -260,19 +285,28 @@ func Test_conn_WriteRand(t *testing.T) {
sr := &fakes.TestSessionRecorder{} sr := &fakes.TestSessionRecorder{}
rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
for i := range 100 { for i := range 100 {
tc := &fakes.TestConn{} t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
c := &conn{ tc := &fakes.TestConn{}
Conn: tc, c := &conn{
log: zl.Sugar(), Conn: tc,
rec: rec, log: zl.Sugar(),
} rec: rec,
bb := fakes.RandomBytes(t)
for j, input := range bb { ctx: context.Background(), // ctx must be non-nil.
f := func() { initialCastHeaderSent: make(chan struct{}),
c.Write(input)
} }
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg)) // Never block for random data.
} c.writeCastHeaderOnce.Do(func() {
close(c.initialCastHeaderSent)
})
bb := fakes.RandomBytes(t)
for j, input := range bb {
f := func() {
c.Write(input)
}
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg))
}
})
} }
} }
@ -280,7 +314,7 @@ func testPanic(t *testing.T, f func(), msg string) {
t.Helper() t.Helper()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
t.Fatal(msg, r) t.Fatal(msg, r, string(debug.Stack()))
} }
}() }()
f() f()

View File

@ -7,10 +7,10 @@ package ws
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"github.com/pkg/errors"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
@ -139,6 +139,8 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
return false, errors.New("[unexpected] received a message fragment with no stream ID") return false, errors.New("[unexpected] received a message fragment with no stream ID")
} }
// Stream ID will be one of the constants from:
// https://github.com/kubernetes/kubernetes/blob/f9ed14bf9b1119a2e091f4b487a3b54930661034/staging/src/k8s.io/apimachinery/pkg/util/remotecommand/constants.go#L57-L64
streamID := uint32(msgPayload[0]) streamID := uint32(msgPayload[0])
if !isInitialFragment && msg.streamID.Load() != streamID { if !isInitialFragment && msg.streamID.Load() != streamID {
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID) return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID)