diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index fdac1a037..c0df501a8 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -468,19 +468,37 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons return filter.DropSilently } +func netaddrIPFromNetstackIP(s tcpip.Address) netaddr.IP { + switch len(s) { + case 4: + return netaddr.IPv4(s[0], s[1], s[2], s[3]) + case 16: + var a [16]byte + copy(a[:], s) + return netaddr.IPFrom16(a) + } + return netaddr.IP{} +} + func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { reqDetails := r.ID() if debugNetstack { ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails)) } - dialAddr := reqDetails.LocalAddress - dialNetAddr, _ := netaddr.FromStdIP(net.IP(dialAddr)) - isTailscaleIP := tsaddr.IsTailscaleIP(dialNetAddr) + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + if !clientRemoteIP.IsValid() { + ns.logf("invalid RemoteAddress in TCP ForwarderRequest: %s", stringifyTEI(reqDetails)) + r.Complete(true) + return + } + + dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) defer func() { if !isTailscaleIP { // if this is a subnet IP, we added this in before the TCP handshake // so netstack is happy TCP-handshaking as a subnet IP - ns.removeSubnetAddress(dialNetAddr) + ns.removeSubnetAddress(dialIP) } }() var wq waiter.Queue @@ -490,21 +508,31 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { return } r.Complete(false) + + // Asynchronously start the TCP handshake. Note that the + // gonet.TCPConn methods c.RemoteAddr() and c.LocalAddr() will + // return nil until the handshake actually completes. But we + // have the remote address in reqDetails instead, so we don't + // use RemoteAddr. The byte copies in both directions in + // forwardTCP will block until the TCP handshake is complete. c := gonet.NewTCPConn(&wq, ep) + if ns.ForwardTCPIn != nil { ns.ForwardTCPIn(c, reqDetails.LocalPort) return } if isTailscaleIP { - dialAddr = tcpip.Address(net.ParseIP("127.0.0.1")).To4() + dialIP = netaddr.IPv4(127, 0, 0, 1) } - ns.forwardTCP(c, &wq, dialAddr, reqDetails.LocalPort) + dialAddr := netaddr.IPPortFrom(dialIP, uint16(reqDetails.LocalPort)) + ns.forwardTCP(c, clientRemoteIP, &wq, dialAddr) } -func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcpip.Address, dialPort uint16) { +func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netaddr.IP, wq *waiter.Queue, dialAddr netaddr.IPPort) { defer client.Close() - dialAddrStr := net.JoinHostPort(dialAddr.String(), strconv.Itoa(int(dialPort))) + dialAddrStr := dialAddr.String() ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -530,7 +558,6 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp defer server.Close() backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone) - clientRemoteIP, _ := netaddr.FromStdIP(client.RemoteAddr().(*net.TCPAddr).IP) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort) connClosed := make(chan error, 2)