mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-29 04:55:31 +00:00
wgengine/netstack: implement UDP relaying to advertised subnets
TCP was done in 662fbd4a09
.
This does the same for UDP.
Tested by hand. Integration tests will have to come later. I'd wanted
to do it in this commit, but the SOCKS5 server needed for interop
testing between two userspace nodes doesn't yet support UDP and I
didn't want to invent some whole new userspace packet injection
interface at this point, as SOCKS seems like a better route, but
that's its own bug.
Fixes #2302
RELNOTE=netstack mode can now UDP relay to subnets
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
3daf27eaad
commit
95a9adbb97
@ -136,6 +136,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// wrapProtoHandler returns protocol handler h wrapped in a version
|
||||
// that dynamically reconfigures ns's subnet addresses as needed for
|
||||
// outbound traffic.
|
||||
func (ns *Impl) wrapProtoHandler(h func(stack.TransportEndpointID, *stack.PacketBuffer) bool) func(stack.TransportEndpointID, *stack.PacketBuffer) bool {
|
||||
return func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
|
||||
addr := tei.LocalAddress
|
||||
ip, ok := netaddr.FromStdIP(net.IP(addr))
|
||||
if !ok {
|
||||
ns.logf("netstack: could not parse local address for incoming connection")
|
||||
return false
|
||||
}
|
||||
if !ns.isLocalIP(ip) {
|
||||
ns.addSubnetAddress(ip)
|
||||
}
|
||||
return h(tei, pb)
|
||||
}
|
||||
}
|
||||
|
||||
// Start sets up all the handlers so netstack can start working. Implements
|
||||
// wgengine.FakeImpl.
|
||||
func (ns *Impl) Start() error {
|
||||
@ -145,25 +163,8 @@ func (ns *Impl) Start() error {
|
||||
const maxInFlightConnectionAttempts = 16
|
||||
tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
|
||||
udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
|
||||
ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
|
||||
addr := tei.LocalAddress
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
if addr.To4() != "" {
|
||||
pn = ipv4.ProtocolNumber
|
||||
} else {
|
||||
pn = ipv6.ProtocolNumber
|
||||
}
|
||||
ip, ok := netaddr.FromStdIP(net.IP(addr))
|
||||
if !ok {
|
||||
ns.logf("netstack: could not parse local address %s for incoming TCP connection", ip)
|
||||
return false
|
||||
}
|
||||
if !ns.isLocalIP(ip) {
|
||||
ns.addSubnetAddress(pn, ip)
|
||||
}
|
||||
return tcpFwd.HandlePacket(tei, pb)
|
||||
})
|
||||
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
|
||||
ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapProtoHandler(tcpFwd.HandlePacket))
|
||||
ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapProtoHandler(udpFwd.HandlePacket))
|
||||
go ns.injectOutbound()
|
||||
ns.tundev.PostFilterIn = ns.injectInbound
|
||||
return nil
|
||||
@ -214,13 +215,19 @@ func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
|
||||
ns.dns = DNSMapFromNetworkMap(nm)
|
||||
}
|
||||
|
||||
func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) {
|
||||
func (ns *Impl) addSubnetAddress(ip netaddr.IP) {
|
||||
ns.mu.Lock()
|
||||
ns.connsOpenBySubnetIP[ip]++
|
||||
needAdd := ns.connsOpenBySubnetIP[ip] == 1
|
||||
ns.mu.Unlock()
|
||||
// Only register address into netstack for first concurrent connection.
|
||||
if needAdd {
|
||||
var pn tcpip.NetworkProtocolNumber
|
||||
if ip.Is4() {
|
||||
pn = ipv4.ProtocolNumber
|
||||
} else if ip.Is6() {
|
||||
pn = ipv6.ProtocolNumber
|
||||
}
|
||||
ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP))
|
||||
}
|
||||
}
|
||||
@ -543,9 +550,9 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcp
|
||||
}
|
||||
|
||||
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
||||
reqDetails := r.ID()
|
||||
sess := r.ID()
|
||||
if debugNetstack {
|
||||
ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(reqDetails))
|
||||
ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(sess))
|
||||
}
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
@ -553,30 +560,50 @@ func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
||||
ns.logf("acceptUDP: could not create endpoint: %v", err)
|
||||
return
|
||||
}
|
||||
localAddr, err := ep.GetLocalAddress()
|
||||
if err != nil {
|
||||
dstAddr, ok := ipPortOfNetstackAddr(sess.LocalAddress, sess.LocalPort)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
remoteAddr, err := ep.GetRemoteAddress()
|
||||
if err != nil {
|
||||
srcAddr, ok := ipPortOfNetstackAddr(sess.RemoteAddress, sess.RemotePort)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
|
||||
go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
|
||||
go ns.forwardUDP(c, &wq, srcAddr, dstAddr)
|
||||
}
|
||||
|
||||
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
|
||||
port := clientLocalAddr.Port
|
||||
// forwardUDP proxies between client (with addr clientAddr) and dstAddr.
|
||||
//
|
||||
// dstAddr may be either a local Tailscale IP, in which we case we proxy to
|
||||
// 127.0.0.1, or any other IP (from an advertised subnet), in which case we
|
||||
// proxy to it directly.
|
||||
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientAddr, dstAddr netaddr.IPPort) {
|
||||
port, srcPort := dstAddr.Port(), clientAddr.Port()
|
||||
ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
|
||||
backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)}
|
||||
backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
||||
backendConn, err := net.ListenUDP("udp4", backendListenAddr)
|
||||
|
||||
var backendListenAddr *net.UDPAddr
|
||||
var backendRemoteAddr *net.UDPAddr
|
||||
isLocal := ns.isLocalIP(dstAddr.IP())
|
||||
if isLocal {
|
||||
backendRemoteAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
||||
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(srcPort)}
|
||||
} else {
|
||||
backendRemoteAddr = dstAddr.UDPAddr()
|
||||
if dstAddr.IP().Is4() {
|
||||
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: int(srcPort)}
|
||||
} else {
|
||||
backendListenAddr = &net.UDPAddr{IP: net.ParseIP("::"), Port: int(srcPort)}
|
||||
}
|
||||
}
|
||||
|
||||
backendConn, err := net.ListenUDP("udp", backendListenAddr)
|
||||
if err != nil {
|
||||
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
|
||||
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", backendListenAddr.Port, err)
|
||||
backendListenAddr.Port = 0
|
||||
backendConn, err = net.ListenUDP("udp4", backendListenAddr)
|
||||
backendConn, err = net.ListenUDP("udp", backendListenAddr)
|
||||
if err != nil {
|
||||
ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
|
||||
ns.logf("netstack: could not create UDP socket, preventing forwarding to %v: %v", dstAddr, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@ -585,28 +612,47 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalA
|
||||
if !ok {
|
||||
ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
|
||||
}
|
||||
clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String()))
|
||||
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
|
||||
if isLocal {
|
||||
ns.e.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.IP())
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
timer := time.AfterFunc(2*time.Minute, func() {
|
||||
ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
|
||||
ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
|
||||
|
||||
idleTimeout := 2 * time.Minute
|
||||
if port == 53 {
|
||||
// Make DNS packet copies time out much sooner.
|
||||
//
|
||||
// TODO(bradfitz): make DNS queries over UDP forwarding even
|
||||
// cheaper by adding an additional idleTimeout post-DNS-reply.
|
||||
// For instance, after the DNS response goes back out, then only
|
||||
// wait a few seconds (or zero, really)
|
||||
idleTimeout = 30 * time.Second
|
||||
}
|
||||
timer := time.AfterFunc(idleTimeout, func() {
|
||||
if isLocal {
|
||||
ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
|
||||
}
|
||||
ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr)
|
||||
cancel()
|
||||
client.Close()
|
||||
backendConn.Close()
|
||||
})
|
||||
extend := func() {
|
||||
timer.Reset(2 * time.Minute)
|
||||
timer.Reset(idleTimeout)
|
||||
}
|
||||
startPacketCopy(ctx, cancel, client, &net.UDPAddr{
|
||||
IP: net.ParseIP(clientRemoteAddr.Addr.String()),
|
||||
Port: int(clientRemoteAddr.Port),
|
||||
}, backendConn, ns.logf, extend)
|
||||
startPacketCopy(ctx, cancel, client, clientAddr.UDPAddr(), backendConn, ns.logf, extend)
|
||||
startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
|
||||
|
||||
if isLocal {
|
||||
// Wait for the copies to be done before decrementing the
|
||||
// subnet address count to potentially remove the route.
|
||||
<-ctx.Done()
|
||||
ns.removeSubnetAddress(dstAddr.IP())
|
||||
}
|
||||
}
|
||||
|
||||
func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
|
||||
if debugNetstack {
|
||||
logf("[v2] netstack: startPacketCopy to %v (%T) from %T", dstAddr, dst, src)
|
||||
}
|
||||
go func() {
|
||||
defer cancel() // tear down the other direction's copy
|
||||
pkt := make([]byte, mtu)
|
||||
@ -643,3 +689,7 @@ func stringifyTEI(tei stack.TransportEndpointID) string {
|
||||
remoteHostPort := net.JoinHostPort(tei.RemoteAddress.String(), strconv.Itoa(int(tei.RemotePort)))
|
||||
return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort)
|
||||
}
|
||||
|
||||
func ipPortOfNetstackAddr(a tcpip.Address, port uint16) (ipp netaddr.IPPort, ok bool) {
|
||||
return netaddr.FromStdAddr(net.IP(a), int(port), "") // TODO(bradfitz): can do without allocs
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user