diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index a82a60722..f86b66bd3 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -53,6 +53,7 @@ "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" + "tailscale.com/util/set" "tailscale.com/util/testenv" "tailscale.com/wgengine" "tailscale.com/wgengine/netstack" @@ -133,12 +134,26 @@ type Server struct { logtail *logtail.Logger logid logid.PublicID - mu sync.Mutex - listeners map[listenKey]*listener - dialer *tsdial.Dialer - closed bool + mu sync.Mutex + listeners map[listenKey]*listener + fallbackTCPHandlers set.HandleSet[FallbackTCPHandler] + dialer *tsdial.Dialer + closed bool } +// FallbackTCPHandler describes the callback which +// conditionally handles an incoming TCP flow for the +// provided (src/port, dst/port) 4-tuple. These are registered +// as handlers of last resort, and are called only if no +// listener could handle the incoming flow. +// +// If the callback returns intercept=false, the flow is rejected. +// +// When intercept=true, the behavior depends on whether the returned handler +// is non-nil: if nil, the connection is rejected. If non-nil, handler takes +// over the TCP conn. +type FallbackTCPHandler func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) + // Dial connects to the address on the tailnet. // It will start the server if it has not been started yet. func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, error) { @@ -755,6 +770,14 @@ func (s *Server) getTCPHandlerForFunnelFlow(src netip.AddrPort, dstPort uint16) func (s *Server) getTCPHandlerForFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { ln, ok := s.listenerForDstAddr("tcp", dst, false) if !ok { + s.mu.Lock() + defer s.mu.Unlock() + for _, handler := range s.fallbackTCPHandlers { + connHandler, intercept := handler(src, dst) + if intercept { + return connHandler, intercept + } + } return nil, true // don't handle, don't forward to localhost } return ln.handle, true @@ -858,6 +881,24 @@ func (s *Server) ListenTLS(network, addr string) (net.Listener, error) { }), nil } +// RegisterFallbackTCPHandler registers a callback which will be called +// to handle a TCP flow to this tsnet node, for which no listeners will handle. +// +// If multiple fallback handlers are registered, they will be called in an +// undefined order. See FallbackTCPHandler for details on handling a flow. +// +// The returned function can be used to deregister this callback. +func (s *Server) RegisterFallbackTCPHandler(cb FallbackTCPHandler) func() { + s.mu.Lock() + defer s.mu.Unlock() + hnd := s.fallbackTCPHandlers.Add(cb) + return func() { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.fallbackTCPHandlers, hnd) + } +} + // getCert is the GetCertificate function used by ListenTLS. // // It calls GetCertificate on the localClient, passing in the ClientHelloInfo. diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 981e8c9b2..0d1689683 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -630,3 +630,45 @@ type bufferedConn struct { func (c *bufferedConn) Read(b []byte) (int, error) { return c.reader.Read(b) } + +func TestFallbackTCPHandler(t *testing.T) { + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL := startControl(t) + s1, s1ip := startServer(t, ctx, controlURL, "s1") + s2, _ := startServer(t, ctx, controlURL, "s2") + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // ping to make sure the connection is up. + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatal(err) + } + t.Logf("ping success: %#+v", res) + + s1TcpConnCount := 0 + deregister := s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + s1TcpConnCount++ + return nil, false + }) + + if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { + t.Fatal("Expected dial error because fallback handler did not intercept") + } + if s1TcpConnCount != 1 { + t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1) + } + deregister() + if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { + t.Fatal("Expected dial error because nothing would intercept") + } + if s1TcpConnCount != 1 { + t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1) + } +}