k8s-operator: handle multiple messages per WebSocket frame

Change-Id: Iafb91ad1cbeed9c5231a1525d4563164fc1f002f
Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
This commit is contained in:
Tom Proctor 2025-07-24 20:33:24 +01:00
parent 2a5d9c7269
commit 225aeda80f
4 changed files with 52 additions and 48 deletions

View File

@ -116,6 +116,7 @@ func (ap *APIServerProxy) Run(ctx context.Context) error {
ap.hs = &http.Server{ ap.hs = &http.Server{
Handler: mux, Handler: mux,
ErrorLog: zap.NewStdLog(ap.log.Desugar()), ErrorLog: zap.NewStdLog(ap.log.Desugar()),
TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)),
} }
mode := "noauth" mode := "noauth"
@ -140,7 +141,6 @@ func (ap *APIServerProxy) Run(ctx context.Context) error {
GetCertificate: ap.lc.GetCertificate, GetCertificate: ap.lc.GetCertificate,
NextProtos: []string{"http/1.1"}, NextProtos: []string{"http/1.1"},
} }
ap.hs.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
} else { } else {
var err error var err error
tsLn, err = ap.ts.Listen("tcp", ":80") tsLn, err = ap.ts.Listen("tcp", ":80")

View File

@ -236,7 +236,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn,
if err := lc.Close(); err != nil { if err := lc.Close(); err != nil {
h.log.Infof("error closing recorder connections: %v", err) h.log.Infof("error closing recorder connections: %v", err)
} }
return
}() }()
return lc, nil return lc, nil
} }

View File

@ -169,16 +169,17 @@ func (c *conn) Read(b []byte) (int, error) {
return n, nil return n, nil
} }
readMsg := &message{typ: typ} // start a new message...
// ... or pick up an already started one if the previous fragment was not final.
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() {
readMsg = c.currentReadMsg
}
if _, err := c.readBuf.Write(b[:n]); err != nil { if _, err := c.readBuf.Write(b[:n]); err != nil {
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
} }
for c.readBuf.Len() != 0 {
readMsg := &message{typ: typ} // start a new message...
// ... or pick up an already started one if the previous fragment was not final.
if c.readMsgIsIncomplete() {
readMsg = c.currentReadMsg
}
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
if err != nil { if err != nil {
return 0, fmt.Errorf("error parsing message: %v", err) return 0, fmt.Errorf("error parsing message: %v", err)
@ -225,6 +226,8 @@ func (c *conn) Read(b []byte) (int, error) {
} }
c.currentReadMsg = readMsg c.currentReadMsg = readMsg
}
return n, nil return n, nil
} }

View File

@ -7,10 +7,10 @@ package ws
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
"github.com/pkg/errors"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
@ -139,6 +139,8 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
return false, errors.New("[unexpected] received a message fragment with no stream ID") return false, errors.New("[unexpected] received a message fragment with no stream ID")
} }
// Stream ID will be one of the constants from:
// https://github.com/kubernetes/kubernetes/blob/f9ed14bf9b1119a2e091f4b487a3b54930661034/staging/src/k8s.io/apimachinery/pkg/util/remotecommand/constants.go#L57-L64
streamID := uint32(msgPayload[0]) streamID := uint32(msgPayload[0])
if !isInitialFragment && msg.streamID.Load() != streamID { if !isInitialFragment && msg.streamID.Load() != streamID {
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID) return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID)