diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index c9db75025..cf97acb23 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -35,6 +35,7 @@ "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/netns" + "tailscale.com/net/netutil" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/types/key" @@ -232,22 +233,5 @@ func (a *dialParams) tryURL(u *url.URL, init []byte) (net.Conn, error) { return nil, errors.New("http Transport did not provide a writable body") } - return &wrappedConn{switchedConn, rwc}, nil -} - -type wrappedConn struct { - net.Conn - rwc io.ReadWriteCloser -} - -func (w *wrappedConn) Read(bs []byte) (int, error) { - return w.rwc.Read(bs) -} - -func (w *wrappedConn) Write(bs []byte) (int, error) { - return w.rwc.Write(bs) -} - -func (w *wrappedConn) Close() error { - return w.rwc.Close() + return netutil.NewAltReadWriteCloserConn(rwc, switchedConn), nil } diff --git a/control/controlhttp/server.go b/control/controlhttp/server.go index 92bd9ec9b..0e38da860 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/server.go @@ -5,15 +5,14 @@ package controlhttp import ( - "bufio" "context" "encoding/base64" "errors" "fmt" - "net" "net/http" "tailscale.com/control/controlbase" + "tailscale.com/net/netutil" "tailscale.com/types/key" ) @@ -62,9 +61,7 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri conn.Close() return nil, fmt.Errorf("flushing hijacked HTTP buffer: %w", err) } - if brw.Reader.Buffered() > 0 { - conn = &drainBufConn{conn, brw.Reader} - } + conn = netutil.NewDrainBufConn(conn, brw.Reader) nc, err := controlbase.Server(ctx, conn, private, init) if err != nil { @@ -74,22 +71,3 @@ func AcceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nc, nil } - -// drainBufConn is a net.Conn with an initial bunch of bytes in a -// bufio.Reader. Read drains the bufio.Reader until empty, then passes -// through subsequent reads to the Conn directly. -type drainBufConn struct { - net.Conn - r *bufio.Reader -} - -func (b *drainBufConn) Read(bs []byte) (int, error) { - if b.r == nil { - return b.Conn.Read(bs) - } - n, err := b.r.Read(bs) - if b.r.Buffered() == 0 { - b.r = nil - } - return n, err -} diff --git a/net/netutil/netutil.go b/net/netutil/netutil.go index 12d61122e..0bca8b0b0 100644 --- a/net/netutil/netutil.go +++ b/net/netutil/netutil.go @@ -6,6 +6,7 @@ package netutil import ( + "bufio" "io" "net" "sync" @@ -63,3 +64,55 @@ func (dummyListener) Accept() (c net.Conn, err error) { return nil, io.EOF } func (a dummyAddr) Network() string { return string(a) } func (a dummyAddr) String() string { return string(a) } + +// NewDrainBufConn returns a net.Conn conditionally wrapping c, +// prefixing any bytes that are in initialReadBuf, which may be nil. +func NewDrainBufConn(c net.Conn, initialReadBuf *bufio.Reader) net.Conn { + r := initialReadBuf + if r != nil && r.Buffered() == 0 { + r = nil + } + return &drainBufConn{c, r} +} + +// drainBufConn is a net.Conn with an initial bunch of bytes in a +// bufio.Reader. Read drains the bufio.Reader until empty, then passes +// through subsequent reads to the Conn directly. +type drainBufConn struct { + net.Conn + r *bufio.Reader +} + +func (b *drainBufConn) Read(bs []byte) (int, error) { + if b.r == nil { + return b.Conn.Read(bs) + } + n, err := b.r.Read(bs) + if b.r.Buffered() == 0 { + b.r = nil + } + return n, err +} + +// NewAltReadWriteCloserConn returns a net.Conn that wraps rwc (for +// Read, Write, and Close) and c (for all other methods). +func NewAltReadWriteCloserConn(rwc io.ReadWriteCloser, c net.Conn) net.Conn { + return wrappedConn{c, rwc} +} + +type wrappedConn struct { + net.Conn + rwc io.ReadWriteCloser +} + +func (w wrappedConn) Read(bs []byte) (int, error) { + return w.rwc.Read(bs) +} + +func (w wrappedConn) Write(bs []byte) (int, error) { + return w.rwc.Write(bs) +} + +func (w wrappedConn) Close() error { + return w.rwc.Close() +}