mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 08:07:42 +00:00
wgengine/netstack: fix crash in userspace netstack TCP forwarding
Fixes #2658 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
88bd796622
commit
1925fb584e
@ -468,19 +468,37 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
|
|||||||
return filter.DropSilently
|
return filter.DropSilently
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func netaddrIPFromNetstackIP(s tcpip.Address) netaddr.IP {
|
||||||
|
switch len(s) {
|
||||||
|
case 4:
|
||||||
|
return netaddr.IPv4(s[0], s[1], s[2], s[3])
|
||||||
|
case 16:
|
||||||
|
var a [16]byte
|
||||||
|
copy(a[:], s)
|
||||||
|
return netaddr.IPFrom16(a)
|
||||||
|
}
|
||||||
|
return netaddr.IP{}
|
||||||
|
}
|
||||||
|
|
||||||
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
||||||
reqDetails := r.ID()
|
reqDetails := r.ID()
|
||||||
if debugNetstack {
|
if debugNetstack {
|
||||||
ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
|
ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
|
||||||
}
|
}
|
||||||
dialAddr := reqDetails.LocalAddress
|
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
|
||||||
dialNetAddr, _ := netaddr.FromStdIP(net.IP(dialAddr))
|
if !clientRemoteIP.IsValid() {
|
||||||
isTailscaleIP := tsaddr.IsTailscaleIP(dialNetAddr)
|
ns.logf("invalid RemoteAddress in TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
|
||||||
|
isTailscaleIP := tsaddr.IsTailscaleIP(dialIP)
|
||||||
defer func() {
|
defer func() {
|
||||||
if !isTailscaleIP {
|
if !isTailscaleIP {
|
||||||
// if this is a subnet IP, we added this in before the TCP handshake
|
// if this is a subnet IP, we added this in before the TCP handshake
|
||||||
// so netstack is happy TCP-handshaking as a subnet IP
|
// so netstack is happy TCP-handshaking as a subnet IP
|
||||||
ns.removeSubnetAddress(dialNetAddr)
|
ns.removeSubnetAddress(dialIP)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
var wq waiter.Queue
|
var wq waiter.Queue
|
||||||
@ -490,21 +508,31 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.Complete(false)
|
r.Complete(false)
|
||||||
|
|
||||||
|
// Asynchronously start the TCP handshake. Note that the
|
||||||
|
// gonet.TCPConn methods c.RemoteAddr() and c.LocalAddr() will
|
||||||
|
// return nil until the handshake actually completes. But we
|
||||||
|
// have the remote address in reqDetails instead, so we don't
|
||||||
|
// use RemoteAddr. The byte copies in both directions in
|
||||||
|
// forwardTCP will block until the TCP handshake is complete.
|
||||||
c := gonet.NewTCPConn(&wq, ep)
|
c := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
if ns.ForwardTCPIn != nil {
|
if ns.ForwardTCPIn != nil {
|
||||||
ns.ForwardTCPIn(c, reqDetails.LocalPort)
|
ns.ForwardTCPIn(c, reqDetails.LocalPort)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if isTailscaleIP {
|
if isTailscaleIP {
|
||||||
dialAddr = tcpip.Address(net.ParseIP("127.0.0.1")).To4()
|
dialIP = netaddr.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
ns.forwardTCP(c, &wq, dialAddr, reqDetails.LocalPort)
|
dialAddr := netaddr.IPPortFrom(dialIP, uint16(reqDetails.LocalPort))
|
||||||
|
ns.forwardTCP(c, clientRemoteIP, &wq, dialAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcpip.Address, dialPort uint16) {
|
func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netaddr.IP, wq *waiter.Queue, dialAddr netaddr.IPPort) {
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
dialAddrStr := net.JoinHostPort(dialAddr.String(), strconv.Itoa(int(dialPort)))
|
dialAddrStr := dialAddr.String()
|
||||||
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
|
ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
||||||
@ -530,7 +558,6 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp
|
|||||||
defer server.Close()
|
defer server.Close()
|
||||||
backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
|
backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
|
||||||
backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
|
backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
|
||||||
clientRemoteIP, _ := netaddr.FromStdIP(client.RemoteAddr().(*net.TCPAddr).IP)
|
|
||||||
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
|
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
|
||||||
defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
|
defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
|
||||||
connClosed := make(chan error, 2)
|
connClosed := make(chan error, 2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user