From f51c968b2a840804ac66b3a75e4950c14ab3d78c Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Wed, 18 Sep 2024 17:53:03 -0700 Subject: [PATCH] ipn/ipnlocal,wgengine/netstack: unify TCP forwarders This also has the side effect of registering the local ports with the port mapper so that subsequent calls to WhoIs return the correct information (resolves a TODO in code). Updates #13513 --- ipn/ipnlocal/serve.go | 31 +++----- tsd/tsd.go | 3 + wgengine/netstack/netstack.go | 135 ++++++++++++++++++++-------------- 3 files changed, 92 insertions(+), 77 deletions(-) diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 67d521f09..f8d5d36b6 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -479,14 +479,6 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, if backDst := tcph.TCPForward(); backDst != "" { return func(conn net.Conn) error { defer conn.Close() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst) - cancel() - if err != nil { - b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) - return nil - } - defer backConn.Close() if sni := tcph.TerminateTLS(); sni != "" { conn = tls.Server(conn, &tls.Config{ GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -505,18 +497,17 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, }) } - // TODO(bradfitz): do the RegisterIPPortIdentity and - // UnregisterIPPortIdentity stuff that netstack does - errc := make(chan error, 1) - go func() { - _, err := io.Copy(backConn, conn) - errc <- err - }() - go func() { - _, err := io.Copy(conn, backConn) - errc <- err - }() - return <-errc + ns, ok := b.sys.Netstack.GetOK() + if !ok { + return errors.New("netstack not available") + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + h, err := ns.ForwardTCPHandler(ctx, backDst) + cancel() + if err != nil { + return fmt.Errorf("ForwardTCPHandler(%q): %w", backDst, err) + } + return h(conn) } } diff --git a/tsd/tsd.go b/tsd/tsd.go index 2b5e65626..6d5ef51bc 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -18,7 +18,9 @@ package tsd import ( + "context" "fmt" + "net" "reflect" "tailscale.com/control/controlknobs" @@ -73,6 +75,7 @@ type System struct { // references LocalBackend, and LocalBackend has a tsd.System. type NetstackImpl interface { UpdateNetstackIPs(*netmap.NetworkMap) + ForwardTCPHandler(dialCtx context.Context, dialAddr string) (func(net.Conn) error, error) } // Set is a convenience method to set a subsystem value. diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index d029b6c19..f57134ed1 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -1217,21 +1217,21 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { clientRemotePort := reqDetails.RemotePort clientRemoteAddrPort := netip.AddrPortFrom(clientRemoteIP, clientRemotePort) - dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) - isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) + dstAddr := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + isTailscaleIP := tsaddr.IsTailscaleIP(dstAddr) - dstAddrPort := netip.AddrPortFrom(dialIP, reqDetails.LocalPort) + dstAddrPort := netip.AddrPortFrom(dstAddr, reqDetails.LocalPort) - if viaRange.Contains(dialIP) { + if viaRange.Contains(dstAddr) { isTailscaleIP = false - dialIP = tsaddr.UnmapVia(dialIP) + dstAddr = tsaddr.UnmapVia(dstAddr) } defer func() { if !isTailscaleIP { // if this is a subnet IP, we added this in before the TCP handshake // so netstack is happy TCP-handshaking as a subnet IP - ns.removeSubnetAddress(dialIP) + ns.removeSubnetAddress(dstAddr) } }() @@ -1287,7 +1287,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } // Local Services (DNS and WebDAV) - hittingServiceIP := dialIP == serviceIP || dialIP == serviceIPv6 + hittingServiceIP := dstAddr == serviceIP || dstAddr == serviceIPv6 hittingDNS := hittingServiceIP && reqDetails.LocalPort == 53 if hittingDNS { c := getConnOrReset() @@ -1326,8 +1326,10 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { return } } + + var dialIP netip.Addr switch { - case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort): + case hittingServiceIP && ns.isLoopbackPort(dstAddrPort.Port()): if dialIP == serviceIPv6 { dialIP = ipv6Loopback } else { @@ -1336,20 +1338,14 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { case isTailscaleIP: dialIP = ipv4Loopback } - dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) - if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr) { + if !ns.forwardTCP(getConnOrReset, fmt.Sprintf("%v:%d", dialIP, dstAddrPort.Port()), &wq) { r.Complete(true) // sends a RST } } -func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) { - dialAddrStr := dialAddr.String() - if debugNetstack() { - ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) - } - - ctx, cancel := context.WithCancel(context.Background()) +func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, dialAddr string, wq *waiter.Queue) (handled bool) { + dialCtx, cancel := context.WithCancel(context.Background()) defer cancel() waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask? @@ -1363,13 +1359,47 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. select { case <-notifyCh: if debugNetstack() { - ns.logf("[v2] netstack: forwardTCP notifyCh fired; canceling context for %s", dialAddrStr) + ns.logf("[v2] netstack: forwardTCP notifyCh fired; canceling context for %s", dialAddr) } case <-done: } cancel() }() + h, err := ns.ForwardTCPHandler(dialCtx, dialAddr) + if err != nil { + return false + } + cancel() + // If we get here, either the getClient call below will succeed and + // return something we can Close, or it will fail and will properly + // respond to the client with a RST. Either way, the caller no longer + // needs to clean up the client connection. + handled = true + client := getClient() + if client == nil { + return + } + if err := h(client); err != nil { + ns.logf("forwardTCP: %v", err) + } + return +} + +// ForwardTCPHandler returns a function that forwards an incoming TCP connection +// to the given dialAddr. The returned function should be called immediately +// after the client connection is accepted. The returned function will block +// until the connection is closed. It returns an error if the connection could +// not be established. +// +// It also registers the mapping between the local and remote IP:port pairs with +// the portmapper such that the connection can be identified using calls to +// [ipnlocal.LocalBackend.WhoIs]. +func (ns *Impl) ForwardTCPHandler(dialCtx context.Context, dialAddrStr string) (func(net.Conn) error, error) { + if debugNetstack() { + ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) + } + // Attempt to dial the outbound connection before we accept the inbound one. var dialFunc func(context.Context, string, string) (net.Conn, error) if ns.forwardDialFunc != nil { @@ -1381,49 +1411,40 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. // TODO: this is racy, dialing before we register our local address. See // https://github.com/tailscale/tailscale/issues/1616. - backend, err := dialFunc(ctx, "tcp", dialAddrStr) + backend, err := dialFunc(dialCtx, "tcp", dialAddrStr) if err != nil { - ns.logf("netstack: could not connect to local backend server at %s: %v", dialAddr.String(), err) - return + ns.logf("netstack: could not connect to local backend server at %s: %v", dialAddrStr, err) + return nil, err } - defer backend.Close() - - backendLocalAddr := backend.LocalAddr().(*net.TCPAddr) - backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) - if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP); err != nil { - ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err) - return - } - defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort) - - // If we get here, either the getClient call below will succeed and - // return something we can Close, or it will fail and will properly - // respond to the client with a RST. Either way, the caller no longer - // needs to clean up the client connection. - handled = true // We dialed the connection; we can complete the client's TCP handshake. - client := getClient() - if client == nil { - return - } - defer client.Close() - - connClosed := make(chan error, 2) - go func() { - _, err := io.Copy(backend, client) - connClosed <- err - }() - go func() { - _, err := io.Copy(client, backend) - connClosed <- err - }() - err = <-connClosed - if err != nil { - ns.logf("proxy connection closed with error: %v", err) - } - ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) - return + backendLocalAddr := backend.LocalAddr().(*net.TCPAddr) + backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) + return func(client net.Conn) error { + defer backend.Close() + defer client.Close() + caller := netaddr.Unmap(client.RemoteAddr().(*net.TCPAddr).AddrPort()) + if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, caller.Addr()); err != nil { + ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err) + return err + } + defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort) + connClosed := make(chan error, 2) + go func() { + _, err := io.Copy(backend, client) + connClosed <- err + }() + go func() { + _, err := io.Copy(client, backend) + connClosed <- err + }() + err = <-connClosed + if err != nil { + ns.logf("proxy connection closed with error: %v", err) + } + ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) + return err + }, nil } // ListenPacket listens for incoming packets for the given network and address.