wgengine/netstack: fix crash in userspace netstack TCP forwarding

Fixes #2658

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2021-08-25 14:39:09 -07:00 committed by Brad Fitzpatrick
parent 88bd796622
commit 1925fb584e

View File

@ -468,19 +468,37 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
return filter.DropSilently return filter.DropSilently
} }
func netaddrIPFromNetstackIP(s tcpip.Address) netaddr.IP {
switch len(s) {
case 4:
return netaddr.IPv4(s[0], s[1], s[2], s[3])
case 16:
var a [16]byte
copy(a[:], s)
return netaddr.IPFrom16(a)
}
return netaddr.IP{}
}
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
reqDetails := r.ID() reqDetails := r.ID()
if debugNetstack { if debugNetstack {
ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails)) ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
} }
dialAddr := reqDetails.LocalAddress clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
dialNetAddr, _ := netaddr.FromStdIP(net.IP(dialAddr)) if !clientRemoteIP.IsValid() {
isTailscaleIP := tsaddr.IsTailscaleIP(dialNetAddr) ns.logf("invalid RemoteAddress in TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
r.Complete(true)
return
}
dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
isTailscaleIP := tsaddr.IsTailscaleIP(dialIP)
defer func() { defer func() {
if !isTailscaleIP { if !isTailscaleIP {
// if this is a subnet IP, we added this in before the TCP handshake // if this is a subnet IP, we added this in before the TCP handshake
// so netstack is happy TCP-handshaking as a subnet IP // so netstack is happy TCP-handshaking as a subnet IP
ns.removeSubnetAddress(dialNetAddr) ns.removeSubnetAddress(dialIP)
} }
}() }()
var wq waiter.Queue var wq waiter.Queue
@ -490,21 +508,31 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
return return
} }
r.Complete(false) r.Complete(false)
// Asynchronously start the TCP handshake. Note that the
// gonet.TCPConn methods c.RemoteAddr() and c.LocalAddr() will
// return nil until the handshake actually completes. But we
// have the remote address in reqDetails instead, so we don't
// use RemoteAddr. The byte copies in both directions in
// forwardTCP will block until the TCP handshake is complete.
c := gonet.NewTCPConn(&wq, ep) c := gonet.NewTCPConn(&wq, ep)
if ns.ForwardTCPIn != nil { if ns.ForwardTCPIn != nil {
ns.ForwardTCPIn(c, reqDetails.LocalPort) ns.ForwardTCPIn(c, reqDetails.LocalPort)
return return
} }
if isTailscaleIP { if isTailscaleIP {
dialAddr = tcpip.Address(net.ParseIP("127.0.0.1")).To4() dialIP = netaddr.IPv4(127, 0, 0, 1)
} }
ns.forwardTCP(c, &wq, dialAddr, reqDetails.LocalPort) dialAddr := netaddr.IPPortFrom(dialIP, uint16(reqDetails.LocalPort))
ns.forwardTCP(c, clientRemoteIP, &wq, dialAddr)
} }
func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcpip.Address, dialPort uint16) { func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netaddr.IP, wq *waiter.Queue, dialAddr netaddr.IPPort) {
defer client.Close() defer client.Close()
dialAddrStr := net.JoinHostPort(dialAddr.String(), strconv.Itoa(int(dialPort))) dialAddrStr := dialAddr.String()
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
waitEntry, notifyCh := waiter.NewChannelEntry(nil) waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@ -530,7 +558,6 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp
defer server.Close() defer server.Close()
backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone) backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
clientRemoteIP, _ := netaddr.FromStdIP(client.RemoteAddr().(*net.TCPAddr).IP)
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort) defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
connClosed := make(chan error, 2) connClosed := make(chan error, 2)