mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
sessionrecording: implement v2 recording endpoint support (#14105)
The v2 endpoint supports HTTP/2 bidirectional streaming and acks for received bytes. This is used to detect when a recorder disappears to more quickly terminate the session. Updates https://github.com/tailscale/corp/issues/24023 Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
parent
5cae7c51bf
commit
c2a7f17f2b
@ -102,7 +102,7 @@ type Hijacker struct {
|
|||||||
// connection succeeds. In case of success, returns a list with a single
|
// connection succeeds. In case of success, returns a list with a single
|
||||||
// successful recording attempt and an error channel. If the connection errors
|
// successful recording attempt and an error channel. If the connection errors
|
||||||
// after having been established, an error is sent down the channel.
|
// after having been established, an error is sent down the channel.
|
||||||
type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
|
type RecorderDialFn func(context.Context, []netip.AddrPort, sessionrecording.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error)
|
||||||
|
|
||||||
// Hijack hijacks a 'kubectl exec' session and configures for the session
|
// Hijack hijacks a 'kubectl exec' session and configures for the session
|
||||||
// contents to be sent to a recorder.
|
// contents to be sent to a recorder.
|
||||||
|
@ -10,7 +10,6 @@
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -20,6 +19,7 @@
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"tailscale.com/client/tailscale/apitype"
|
"tailscale.com/client/tailscale/apitype"
|
||||||
"tailscale.com/k8s-operator/sessionrecording/fakes"
|
"tailscale.com/k8s-operator/sessionrecording/fakes"
|
||||||
|
"tailscale.com/sessionrecording"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tsnet"
|
"tailscale.com/tsnet"
|
||||||
"tailscale.com/tstest"
|
"tailscale.com/tstest"
|
||||||
@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) {
|
|||||||
h := &Hijacker{
|
h := &Hijacker{
|
||||||
connectToRecorder: func(context.Context,
|
connectToRecorder: func(context.Context,
|
||||||
[]netip.AddrPort,
|
[]netip.AddrPort,
|
||||||
func(context.Context, string, string) (net.Conn, error),
|
sessionrecording.DialFunc,
|
||||||
) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) {
|
) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) {
|
||||||
if tt.failRecorderConnect {
|
if tt.failRecorderConnect {
|
||||||
err = errors.New("test")
|
err = errors.New("test")
|
||||||
|
@ -7,6 +7,8 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -14,12 +16,33 @@
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/util/httpm"
|
||||||
"tailscale.com/util/multierr"
|
"tailscale.com/util/multierr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Timeout for an individual DialFunc call for a single recorder address.
|
||||||
|
perDialAttemptTimeout = 5 * time.Second
|
||||||
|
// Timeout for the V2 API HEAD probe request (supportsV2).
|
||||||
|
http2ProbeTimeout = 10 * time.Second
|
||||||
|
// Maximum timeout for trying all available recorders, including V2 API
|
||||||
|
// probes and dial attempts.
|
||||||
|
allDialAttemptsTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// uploadAckWindow is the period of time to wait for an ackFrame from recorder
|
||||||
|
// before terminating the connection. This is a variable to allow overriding it
|
||||||
|
// in tests.
|
||||||
|
var uploadAckWindow = 30 * time.Second
|
||||||
|
|
||||||
|
// DialFunc is a function for dialing the recorder.
|
||||||
|
type DialFunc func(ctx context.Context, network, host string) (net.Conn, error)
|
||||||
|
|
||||||
// ConnectToRecorder connects to the recorder at any of the provided addresses.
|
// ConnectToRecorder connects to the recorder at any of the provided addresses.
|
||||||
// It returns the first successful response, or a multierr if all attempts fail.
|
// It returns the first successful response, or a multierr if all attempts fail.
|
||||||
//
|
//
|
||||||
@ -32,19 +55,15 @@
|
|||||||
// attempts are in order the recorder(s) was attempted. If successful a
|
// attempts are in order the recorder(s) was attempted. If successful a
|
||||||
// successful connection is made, the last attempt in the slice is the
|
// successful connection is made, the last attempt in the slice is the
|
||||||
// attempt for connected recorder.
|
// attempt for connected recorder.
|
||||||
func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
|
func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
|
||||||
if len(recs) == 0 {
|
if len(recs) == 0 {
|
||||||
return nil, nil, nil, errors.New("no recorders configured")
|
return nil, nil, nil, errors.New("no recorders configured")
|
||||||
}
|
}
|
||||||
// We use a special context for dialing the recorder, so that we can
|
// We use a special context for dialing the recorder, so that we can
|
||||||
// limit the time we spend dialing to 30 seconds and still have an
|
// limit the time we spend dialing to 30 seconds and still have an
|
||||||
// unbounded context for the upload.
|
// unbounded context for the upload.
|
||||||
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout)
|
||||||
defer dialCancel()
|
defer dialCancel()
|
||||||
hc, err := SessionRecordingClientForDialer(dialCtx, dial)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var errs []error
|
var errs []error
|
||||||
var attempts []*tailcfg.SSHRecordingAttempt
|
var attempts []*tailcfg.SSHRecordingAttempt
|
||||||
@ -54,6 +73,55 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
|
|||||||
}
|
}
|
||||||
attempts = append(attempts, attempt)
|
attempts = append(attempts, attempt)
|
||||||
|
|
||||||
|
var pw io.WriteCloser
|
||||||
|
var errChan <-chan error
|
||||||
|
var err error
|
||||||
|
hc := clientHTTP2(dialCtx, dial)
|
||||||
|
// We need to probe V2 support using a separate HEAD request. Sending
|
||||||
|
// an HTTP/2 POST request to a HTTP/1 server will just "hang" until the
|
||||||
|
// request body is closed (instead of returning a 404 as one would
|
||||||
|
// expect). Sending a HEAD request without a body does not have that
|
||||||
|
// problem.
|
||||||
|
if supportsV2(ctx, hc, ap) {
|
||||||
|
pw, errChan, err = connectV2(ctx, hc, ap)
|
||||||
|
} else {
|
||||||
|
pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err)
|
||||||
|
attempt.FailureMessage = err.Error()
|
||||||
|
errs = append(errs, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return pw, attempts, errChan, nil
|
||||||
|
}
|
||||||
|
return nil, attempts, nil, multierr.New(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportsV2 checks whether a recorder instance supports the /v2/record
|
||||||
|
// endpoint.
|
||||||
|
func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout)
|
||||||
|
defer cancel()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
resp, err := hc.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectV1 connects to the legacy /record endpoint on the recorder. It is
|
||||||
|
// used for backwards-compatibility with older tsrecorder instances.
|
||||||
|
//
|
||||||
|
// On success, it returns a WriteCloser that can be used to upload the
|
||||||
|
// recording, and a channel that will be sent an error (or nil) when the upload
|
||||||
|
// fails or completes.
|
||||||
|
func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
|
||||||
// We dial the recorder and wait for it to send a 100-continue
|
// We dial the recorder and wait for it to send a 100-continue
|
||||||
// response before returning from this function. This ensures that
|
// response before returning from this function. This ensures that
|
||||||
// the recorder is ready to accept the recording.
|
// the recorder is ready to accept the recording.
|
||||||
@ -67,12 +135,9 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
|
|||||||
})
|
})
|
||||||
|
|
||||||
pr, pw := io.Pipe()
|
pr, pw := io.Pipe()
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
|
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("recording: error starting recording: %w", err)
|
return nil, nil, err
|
||||||
attempt.FailureMessage = err.Error()
|
|
||||||
errs = append(errs, err)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// We set the Expect header to 100-continue, so that the recorder
|
// We set the Expect header to 100-continue, so that the recorder
|
||||||
// will send a 100-continue response before it starts reading the
|
// will send a 100-continue response before it starts reading the
|
||||||
@ -82,19 +147,21 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
|
|||||||
// errChan is used to indicate the result of the request.
|
// errChan is used to indicate the result of the request.
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(errChan)
|
||||||
resp, err := hc.Do(req)
|
resp, err := hc.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
|
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
errChan <- nil
|
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-got100:
|
case <-got100:
|
||||||
|
return pw, errChan, nil
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
// If we get an error before we get the 100-continue response,
|
// If we get an error before we get the 100-continue response,
|
||||||
// we need to try another recorder.
|
// we need to try another recorder.
|
||||||
@ -103,25 +170,133 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con
|
|||||||
// is unexpected as we haven't sent any data yet.
|
// is unexpected as we haven't sent any data yet.
|
||||||
err = errors.New("recording: unexpected EOF")
|
err = errors.New("recording: unexpected EOF")
|
||||||
}
|
}
|
||||||
attempt.FailureMessage = err.Error()
|
return nil, nil, err
|
||||||
errs = append(errs, err)
|
|
||||||
continue // try the next recorder
|
|
||||||
}
|
}
|
||||||
return pw, attempts, errChan, nil
|
|
||||||
}
|
|
||||||
return nil, attempts, nil, multierr.New(errs...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionRecordingClientForDialer returns an http.Client that uses a clone of
|
// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2.
|
||||||
// the provided Dialer's PeerTransport to dial connections. This is used to make
|
// It explicitly tracks ack frames sent in the response and terminates the
|
||||||
// requests to the session recording server to upload session recordings. It
|
// connection if sent recording data is un-acked for uploadAckWindow.
|
||||||
// uses the provided dialCtx to dial connections, and limits a single dial to 5
|
//
|
||||||
// seconds.
|
// On success, it returns a WriteCloser that can be used to upload the
|
||||||
func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) {
|
// recording, and a channel that will be sent an error (or nil) when the upload
|
||||||
|
// fails or completes.
|
||||||
|
func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) {
|
||||||
|
pr, pw := io.Pipe()
|
||||||
|
upload := &readCounter{r: pr}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// With HTTP/2, hc.Do will not block while the request body is being sent.
|
||||||
|
// It will return immediately and allow us to consume the response body at
|
||||||
|
// the same time.
|
||||||
|
resp, err := hc.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
acks := make(chan int64)
|
||||||
|
// Read acks from the response and send them to the acks channel.
|
||||||
|
go func() {
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(acks)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
defer pw.Close()
|
||||||
|
dec := json.NewDecoder(resp.Body)
|
||||||
|
for {
|
||||||
|
var frame v2ResponseFrame
|
||||||
|
if err := dec.Decode(&frame); err != nil {
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if frame.Error != "" {
|
||||||
|
errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case acks <- frame.Ack:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
// Track acks from the acks channel.
|
||||||
|
go func() {
|
||||||
|
// Hack for tests: some tests modify uploadAckWindow and reset it when
|
||||||
|
// the test ends. This can race with t.Reset call below. Making a copy
|
||||||
|
// here is a lazy workaround to not wait for this goroutine to exit in
|
||||||
|
// the test cases.
|
||||||
|
uploadAckWindow := uploadAckWindow
|
||||||
|
// This timer fires if we didn't receive an ack for too long.
|
||||||
|
t := time.NewTimer(uploadAckWindow)
|
||||||
|
defer t.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
// Close the pipe which terminates the connection and cleans up
|
||||||
|
// other goroutines. Note that tsrecorder will send us ack
|
||||||
|
// frames even if there is no new data to ack. This helps
|
||||||
|
// detect broken recorder connection if the session is idle.
|
||||||
|
pr.CloseWithError(errNoAcks)
|
||||||
|
resp.Body.Close()
|
||||||
|
return
|
||||||
|
case _, ok := <-acks:
|
||||||
|
if !ok {
|
||||||
|
// acks channel closed means that the goroutine reading them
|
||||||
|
// finished, which means that the request has ended.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// TODO(awly): limit how far behind the received acks can be. This
|
||||||
|
// should handle scenarios where a session suddenly dumps a lot of
|
||||||
|
// output.
|
||||||
|
t.Reset(uploadAckWindow)
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pw, errChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s")
|
||||||
|
|
||||||
|
type v2ResponseFrame struct {
|
||||||
|
// Ack is the number of bytes received from the client so far. The bytes
|
||||||
|
// are not guaranteed to be durably stored yet.
|
||||||
|
Ack int64 `json:"ack,omitempty"`
|
||||||
|
// Error is an error encountered while storing the recording. Error is only
|
||||||
|
// ever set as the last frame in the response.
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// readCounter is an io.Reader that counts how many bytes were read.
|
||||||
|
type readCounter struct {
|
||||||
|
r io.Reader
|
||||||
|
sent atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *readCounter) Read(buf []byte) (int, error) {
|
||||||
|
n, err := u.r.Read(buf)
|
||||||
|
u.sent.Add(int64(n))
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// clientHTTP1 returns a claassic http.Client with a per-dial context. It uses
|
||||||
|
// dialCtx and adds a 5s timeout to it.
|
||||||
|
func clientHTTP1(dialCtx context.Context, dial DialFunc) *http.Client {
|
||||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
|
||||||
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
@ -132,7 +307,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.
|
|||||||
}()
|
}()
|
||||||
return dial(perAttemptCtx, network, addr)
|
return dial(perAttemptCtx, network, addr)
|
||||||
}
|
}
|
||||||
return &http.Client{
|
return &http.Client{Transport: tr}
|
||||||
Transport: tr,
|
}
|
||||||
}, nil
|
|
||||||
|
// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c
|
||||||
|
// requests (HTTP/2 over plaintext). Unfortunately the same client does not
|
||||||
|
// work for HTTP/1 so we need to split these up.
|
||||||
|
func clientHTTP2(dialCtx context.Context, dial DialFunc) *http.Client {
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &http2.Transport{
|
||||||
|
// Allow "http://" scheme in URLs.
|
||||||
|
AllowHTTP: true,
|
||||||
|
// Pretend like we're using TLS, but actually use the provided
|
||||||
|
// DialFunc underneath. This is necessary to convince the transport
|
||||||
|
// to actually dial.
|
||||||
|
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||||
|
perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-perAttemptCtx.Done():
|
||||||
|
case <-dialCtx.Done():
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return dial(perAttemptCtx, network, addr)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
189
sessionrecording/connect_test.go
Normal file
189
sessionrecording/connect_test.go
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package sessionrecording
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/h2c"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConnectToRecorder(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
desc string
|
||||||
|
http2 bool
|
||||||
|
// setup returns a recorder server mux, and a channel which sends the
|
||||||
|
// hash of the recording uploaded to it. The channel is expected to
|
||||||
|
// fire only once.
|
||||||
|
setup func(t *testing.T) (*http.ServeMux, <-chan []byte)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "v1 recorder",
|
||||||
|
setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
|
||||||
|
uploadHash := make(chan []byte, 1)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
hash := sha256.New()
|
||||||
|
if _, err := io.Copy(hash, r.Body); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
uploadHash <- hash.Sum(nil)
|
||||||
|
})
|
||||||
|
return mux, uploadHash
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "v2 recorder",
|
||||||
|
http2: true,
|
||||||
|
setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
|
||||||
|
uploadHash := make(chan []byte, 1)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("received request to v1 endpoint")
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
})
|
||||||
|
mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Force the status to send to unblock the client waiting
|
||||||
|
// for it.
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
|
|
||||||
|
body := &readCounter{r: r.Body}
|
||||||
|
hash := sha256.New()
|
||||||
|
ctx, cancel := context.WithCancel(r.Context())
|
||||||
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
if _, err := io.Copy(hash, body); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Send acks for received bytes.
|
||||||
|
tick := time.NewTicker(time.Millisecond)
|
||||||
|
defer tick.Stop()
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
outer:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
break outer
|
||||||
|
case <-tick.C:
|
||||||
|
if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil {
|
||||||
|
t.Errorf("writing ack frame: %v", err)
|
||||||
|
break outer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadHash <- hash.Sum(nil)
|
||||||
|
})
|
||||||
|
// Probing HEAD endpoint which always returns 200 OK.
|
||||||
|
mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
|
||||||
|
return mux, uploadHash
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "v2 recorder no acks",
|
||||||
|
http2: true,
|
||||||
|
wantErr: true,
|
||||||
|
setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) {
|
||||||
|
// Make the client no-ack timeout quick for the test.
|
||||||
|
oldAckWindow := uploadAckWindow
|
||||||
|
uploadAckWindow = 100 * time.Millisecond
|
||||||
|
t.Cleanup(func() { uploadAckWindow = oldAckWindow })
|
||||||
|
|
||||||
|
uploadHash := make(chan []byte, 1)
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("received request to v1 endpoint")
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
})
|
||||||
|
mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Force the status to send to unblock the client waiting
|
||||||
|
// for it.
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
|
|
||||||
|
// Consume the whole request body but don't send any acks
|
||||||
|
// back.
|
||||||
|
hash := sha256.New()
|
||||||
|
if _, err := io.Copy(hash, r.Body); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
// Goes in the channel buffer, non-blocking.
|
||||||
|
uploadHash <- hash.Sum(nil)
|
||||||
|
|
||||||
|
// Block until the parent test case ends to prevent the
|
||||||
|
// request termination. We want to exercise the ack
|
||||||
|
// tracking logic specifically.
|
||||||
|
ctx, cancel := context.WithCancel(r.Context())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
<-ctx.Done()
|
||||||
|
})
|
||||||
|
mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
|
||||||
|
return mux, uploadHash
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.desc, func(t *testing.T) {
|
||||||
|
mux, uploadHash := tt.setup(t)
|
||||||
|
|
||||||
|
srv := httptest.NewUnstartedServer(mux)
|
||||||
|
if tt.http2 {
|
||||||
|
// Wire up h2c-compatible HTTP/2 server. This is optional
|
||||||
|
// because the v1 recorder didn't support HTTP/2 and we try to
|
||||||
|
// mimic that.
|
||||||
|
h2s := &http2.Server{}
|
||||||
|
srv.Config.Handler = h2c.NewHandler(mux, h2s)
|
||||||
|
if err := http2.ConfigureServer(srv.Config, h2s); err != nil {
|
||||||
|
t.Errorf("configuring HTTP/2 support in server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
srv.Start()
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
d := new(net.Dialer)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConnectToRecorder: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send some random data and hash it to compare with the recorded
|
||||||
|
// data hash.
|
||||||
|
hash := sha256.New()
|
||||||
|
const numBytes = 1 << 20 // 1MB
|
||||||
|
if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil {
|
||||||
|
t.Fatalf("writing recording data: %v", err)
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
t.Fatalf("closing recording stream: %v", err)
|
||||||
|
}
|
||||||
|
if err := <-errc; err != nil && !tt.wantErr {
|
||||||
|
t.Fatalf("error from the channel: %v", err)
|
||||||
|
} else if err == nil && tt.wantErr {
|
||||||
|
t.Fatalf("did not receive expected error from the channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) {
|
||||||
|
t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -1170,7 +1170,7 @@ func (ss *sshSession) run() {
|
|||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
|
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
|
||||||
if !isErrBecauseProcessExited {
|
if !isErrBecauseProcessExited {
|
||||||
logf("stdout copy: %v, %T", err)
|
logf("stdout copy: %v", err)
|
||||||
ss.cancelCtx(err)
|
ss.cancelCtx(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1520,9 +1520,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) {
|
|||||||
go func() {
|
go func() {
|
||||||
err := <-errChan
|
err := <-errChan
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
select {
|
||||||
|
case <-ss.ctx.Done():
|
||||||
// Success.
|
// Success.
|
||||||
ss.logf("recording: finished uploading recording")
|
ss.logf("recording: finished uploading recording")
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
err = errors.New("recording upload ended before the SSH session")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 {
|
if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 {
|
||||||
lastAttempt := attempts[len(attempts)-1]
|
lastAttempt := attempts[len(attempts)-1]
|
||||||
|
@ -33,6 +33,8 @@
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
gossh "github.com/tailscale/golang-x-crypto/ssh"
|
gossh "github.com/tailscale/golang-x-crypto/ssh"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/h2c"
|
||||||
"tailscale.com/ipn/ipnlocal"
|
"tailscale.com/ipn/ipnlocal"
|
||||||
"tailscale.com/ipn/store/mem"
|
"tailscale.com/ipn/store/mem"
|
||||||
"tailscale.com/net/memnet"
|
"tailscale.com/net/memnet"
|
||||||
@ -481,10 +483,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var handler http.HandlerFunc
|
var handler http.HandlerFunc
|
||||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
handler(w, r)
|
handler(w, r)
|
||||||
}))
|
})
|
||||||
defer recordingServer.Close()
|
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
logf: t.Logf,
|
logf: t.Logf,
|
||||||
@ -533,9 +534,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "upload-fails-after-starting",
|
name: "upload-fails-after-starting",
|
||||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
r.Body.Read(make([]byte, 1))
|
r.Body.Read(make([]byte, 1))
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
},
|
},
|
||||||
sshCommand: "echo hello && sleep 1 && echo world",
|
sshCommand: "echo hello && sleep 1 && echo world",
|
||||||
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
|
wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n",
|
||||||
@ -548,6 +550,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s.logf = t.Logf
|
||||||
tstest.Replace(t, &handler, tt.handler)
|
tstest.Replace(t, &handler, tt.handler)
|
||||||
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
sc, dc := memnet.NewTCPConn(src, dst, 1024)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
@ -597,12 +600,12 @@ func TestMultipleRecorders(t *testing.T) {
|
|||||||
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
|
||||||
}
|
}
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer close(done)
|
defer close(done)
|
||||||
io.ReadAll(r.Body)
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
w.(http.Flusher).Flush()
|
||||||
defer recordingServer.Close()
|
io.ReadAll(r.Body)
|
||||||
|
})
|
||||||
badRecorder, err := net.Listen("tcp", ":0")
|
badRecorder, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -610,15 +613,9 @@ func TestMultipleRecorders(t *testing.T) {
|
|||||||
badRecorderAddr := badRecorder.Addr().String()
|
badRecorderAddr := badRecorder.Addr().String()
|
||||||
badRecorder.Close()
|
badRecorder.Close()
|
||||||
|
|
||||||
badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(500)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
}))
|
})
|
||||||
defer badRecordingServer500.Close()
|
|
||||||
|
|
||||||
badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(200)
|
|
||||||
}))
|
|
||||||
defer badRecordingServer200.Close()
|
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
logf: t.Logf,
|
logf: t.Logf,
|
||||||
@ -630,7 +627,6 @@ func TestMultipleRecorders(t *testing.T) {
|
|||||||
Recorders: []netip.AddrPort{
|
Recorders: []netip.AddrPort{
|
||||||
netip.MustParseAddrPort(badRecorderAddr),
|
netip.MustParseAddrPort(badRecorderAddr),
|
||||||
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
|
netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()),
|
||||||
netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()),
|
|
||||||
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()),
|
||||||
},
|
},
|
||||||
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
|
||||||
@ -701,19 +697,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) {
|
|||||||
}
|
}
|
||||||
var recording []byte
|
var recording []byte
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.(http.Flusher).Flush()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
recording, err = io.ReadAll(r.Body)
|
recording, err = io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}))
|
})
|
||||||
defer recordingServer.Close()
|
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
logf: logger.Discard,
|
logf: t.Logf,
|
||||||
lb: &localState{
|
lb: &localState{
|
||||||
sshEnabled: true,
|
sshEnabled: true,
|
||||||
matchingRule: newSSHRule(
|
matchingRule: newSSHRule(
|
||||||
@ -1299,3 +1297,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) {
|
|||||||
t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
|
t.Errorf("os/user.User has %v fields; this package assumes %v", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server {
|
||||||
|
t.Helper()
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) {
|
||||||
|
t.Errorf("v1 recording endpoint called")
|
||||||
|
})
|
||||||
|
mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {})
|
||||||
|
mux.HandleFunc("POST /v2/record", handleRecord)
|
||||||
|
|
||||||
|
h2s := &http2.Server{}
|
||||||
|
srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s))
|
||||||
|
if err := http2.ConfigureServer(srv.Config, h2s); err != nil {
|
||||||
|
t.Errorf("configuring HTTP/2 support in recording server: %v", err)
|
||||||
|
}
|
||||||
|
srv.Start()
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user