ipn/ipnlocal,wgengine/netstack: unify TCP forwarders

This also has the side effect of registering the local ports with the
port mapper so that subsequent calls to WhoIs return the correct information
(resolves a TODO in code).

Updates #13513
This commit is contained in:
Maisem Ali 2024-09-18 17:53:03 -07:00
parent 5f89c93274
commit f51c968b2a
3 changed files with 92 additions and 77 deletions

View File

@ -479,14 +479,6 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort,
if backDst := tcph.TCPForward(); backDst != "" { if backDst := tcph.TCPForward(); backDst != "" {
return func(conn net.Conn) error { return func(conn net.Conn) error {
defer conn.Close() defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst)
cancel()
if err != nil {
b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err)
return nil
}
defer backConn.Close()
if sni := tcph.TerminateTLS(); sni != "" { if sni := tcph.TerminateTLS(); sni != "" {
conn = tls.Server(conn, &tls.Config{ conn = tls.Server(conn, &tls.Config{
GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
@ -505,18 +497,17 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort,
}) })
} }
// TODO(bradfitz): do the RegisterIPPortIdentity and ns, ok := b.sys.Netstack.GetOK()
// UnregisterIPPortIdentity stuff that netstack does if !ok {
errc := make(chan error, 1) return errors.New("netstack not available")
go func() { }
_, err := io.Copy(backConn, conn) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
errc <- err h, err := ns.ForwardTCPHandler(ctx, backDst)
}() cancel()
go func() { if err != nil {
_, err := io.Copy(conn, backConn) return fmt.Errorf("ForwardTCPHandler(%q): %w", backDst, err)
errc <- err }
}() return h(conn)
return <-errc
} }
} }

View File

@ -18,7 +18,9 @@
package tsd package tsd
import ( import (
"context"
"fmt" "fmt"
"net"
"reflect" "reflect"
"tailscale.com/control/controlknobs" "tailscale.com/control/controlknobs"
@ -73,6 +75,7 @@ type System struct {
// references LocalBackend, and LocalBackend has a tsd.System. // references LocalBackend, and LocalBackend has a tsd.System.
type NetstackImpl interface { type NetstackImpl interface {
UpdateNetstackIPs(*netmap.NetworkMap) UpdateNetstackIPs(*netmap.NetworkMap)
ForwardTCPHandler(dialCtx context.Context, dialAddr string) (func(net.Conn) error, error)
} }
// Set is a convenience method to set a subsystem value. // Set is a convenience method to set a subsystem value.

View File

@ -1217,21 +1217,21 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
clientRemotePort := reqDetails.RemotePort clientRemotePort := reqDetails.RemotePort
clientRemoteAddrPort := netip.AddrPortFrom(clientRemoteIP, clientRemotePort) clientRemoteAddrPort := netip.AddrPortFrom(clientRemoteIP, clientRemotePort)
dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) dstAddr := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) isTailscaleIP := tsaddr.IsTailscaleIP(dstAddr)
dstAddrPort := netip.AddrPortFrom(dialIP, reqDetails.LocalPort) dstAddrPort := netip.AddrPortFrom(dstAddr, reqDetails.LocalPort)
if viaRange.Contains(dialIP) { if viaRange.Contains(dstAddr) {
isTailscaleIP = false isTailscaleIP = false
dialIP = tsaddr.UnmapVia(dialIP) dstAddr = tsaddr.UnmapVia(dstAddr)
} }
defer func() { defer func() {
if !isTailscaleIP { if !isTailscaleIP {
// if this is a subnet IP, we added this in before the TCP handshake // if this is a subnet IP, we added this in before the TCP handshake
// so netstack is happy TCP-handshaking as a subnet IP // so netstack is happy TCP-handshaking as a subnet IP
ns.removeSubnetAddress(dialIP) ns.removeSubnetAddress(dstAddr)
} }
}() }()
@ -1287,7 +1287,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
} }
// Local Services (DNS and WebDAV) // Local Services (DNS and WebDAV)
hittingServiceIP := dialIP == serviceIP || dialIP == serviceIPv6 hittingServiceIP := dstAddr == serviceIP || dstAddr == serviceIPv6
hittingDNS := hittingServiceIP && reqDetails.LocalPort == 53 hittingDNS := hittingServiceIP && reqDetails.LocalPort == 53
if hittingDNS { if hittingDNS {
c := getConnOrReset() c := getConnOrReset()
@ -1326,8 +1326,10 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
return return
} }
} }
var dialIP netip.Addr
switch { switch {
case hittingServiceIP && ns.isLoopbackPort(reqDetails.LocalPort): case hittingServiceIP && ns.isLoopbackPort(dstAddrPort.Port()):
if dialIP == serviceIPv6 { if dialIP == serviceIPv6 {
dialIP = ipv6Loopback dialIP = ipv6Loopback
} else { } else {
@ -1336,20 +1338,14 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
case isTailscaleIP: case isTailscaleIP:
dialIP = ipv4Loopback dialIP = ipv4Loopback
} }
dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort))
if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr) { if !ns.forwardTCP(getConnOrReset, fmt.Sprintf("%v:%d", dialIP, dstAddrPort.Port()), &wq) {
r.Complete(true) // sends a RST r.Complete(true) // sends a RST
} }
} }
func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) { func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, dialAddr string, wq *waiter.Queue) (handled bool) {
dialAddrStr := dialAddr.String() dialCtx, cancel := context.WithCancel(context.Background())
if debugNetstack() {
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask? waitEntry, notifyCh := waiter.NewChannelEntry(waiter.EventHUp) // TODO(bradfitz): right EventMask?
@ -1363,13 +1359,47 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
select { select {
case <-notifyCh: case <-notifyCh:
if debugNetstack() { if debugNetstack() {
ns.logf("[v2] netstack: forwardTCP notifyCh fired; canceling context for %s", dialAddrStr) ns.logf("[v2] netstack: forwardTCP notifyCh fired; canceling context for %s", dialAddr)
} }
case <-done: case <-done:
} }
cancel() cancel()
}() }()
h, err := ns.ForwardTCPHandler(dialCtx, dialAddr)
if err != nil {
return false
}
cancel()
// If we get here, either the getClient call below will succeed and
// return something we can Close, or it will fail and will properly
// respond to the client with a RST. Either way, the caller no longer
// needs to clean up the client connection.
handled = true
client := getClient()
if client == nil {
return
}
if err := h(client); err != nil {
ns.logf("forwardTCP: %v", err)
}
return
}
// ForwardTCPHandler returns a function that forwards an incoming TCP connection
// to the given dialAddr. The returned function should be called immediately
// after the client connection is accepted. The returned function will block
// until the connection is closed. It returns an error if the connection could
// not be established.
//
// It also registers the mapping between the local and remote IP:port pairs with
// the portmapper such that the connection can be identified using calls to
// [ipnlocal.LocalBackend.WhoIs].
func (ns *Impl) ForwardTCPHandler(dialCtx context.Context, dialAddrStr string) (func(net.Conn) error, error) {
if debugNetstack() {
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
}
// Attempt to dial the outbound connection before we accept the inbound one. // Attempt to dial the outbound connection before we accept the inbound one.
var dialFunc func(context.Context, string, string) (net.Conn, error) var dialFunc func(context.Context, string, string) (net.Conn, error)
if ns.forwardDialFunc != nil { if ns.forwardDialFunc != nil {
@ -1381,49 +1411,40 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
// TODO: this is racy, dialing before we register our local address. See // TODO: this is racy, dialing before we register our local address. See
// https://github.com/tailscale/tailscale/issues/1616. // https://github.com/tailscale/tailscale/issues/1616.
backend, err := dialFunc(ctx, "tcp", dialAddrStr) backend, err := dialFunc(dialCtx, "tcp", dialAddrStr)
if err != nil { if err != nil {
ns.logf("netstack: could not connect to local backend server at %s: %v", dialAddr.String(), err) ns.logf("netstack: could not connect to local backend server at %s: %v", dialAddrStr, err)
return return nil, err
} }
defer backend.Close()
backendLocalAddr := backend.LocalAddr().(*net.TCPAddr)
backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort())
if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP); err != nil {
ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err)
return
}
defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort)
// If we get here, either the getClient call below will succeed and
// return something we can Close, or it will fail and will properly
// respond to the client with a RST. Either way, the caller no longer
// needs to clean up the client connection.
handled = true
// We dialed the connection; we can complete the client's TCP handshake. // We dialed the connection; we can complete the client's TCP handshake.
client := getClient() backendLocalAddr := backend.LocalAddr().(*net.TCPAddr)
if client == nil { backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort())
return return func(client net.Conn) error {
} defer backend.Close()
defer client.Close() defer client.Close()
caller := netaddr.Unmap(client.RemoteAddr().(*net.TCPAddr).AddrPort())
connClosed := make(chan error, 2) if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, caller.Addr()); err != nil {
go func() { ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err)
_, err := io.Copy(backend, client) return err
connClosed <- err }
}() defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort)
go func() { connClosed := make(chan error, 2)
_, err := io.Copy(client, backend) go func() {
connClosed <- err _, err := io.Copy(backend, client)
}() connClosed <- err
err = <-connClosed }()
if err != nil { go func() {
ns.logf("proxy connection closed with error: %v", err) _, err := io.Copy(client, backend)
} connClosed <- err
ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) }()
return err = <-connClosed
if err != nil {
ns.logf("proxy connection closed with error: %v", err)
}
ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr)
return err
}, nil
} }
// ListenPacket listens for incoming packets for the given network and address. // ListenPacket listens for incoming packets for the given network and address.