mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-10 01:53:49 +00:00
b8fb8264a5
The earlier eb06ec172f1d984bb87c589da1dd2d3f15dc6d82 fixed the flaky SSH issue (tailscale/corp#1725) by making sure that packets addressed to Tailscale IPs in hybrid netstack mode weren't delivered to netstack, but another issue remained: All traffic handled by netstack was also potentially being handled by the host networking stack, as the filter hook returned "Accept", which made it keep processing. This could lead to various random racy chaos as a function of OS/firewalls/routes/etc. Instead, once we inject into netstack, stop our caller's packet processing. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
601 lines
18 KiB
Go
601 lines
18 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
// Package netstack wires up gVisor's netstack into Tailscale.
|
|
package netstack
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"inet.af/netaddr"
|
|
"inet.af/netstack/tcpip"
|
|
"inet.af/netstack/tcpip/adapters/gonet"
|
|
"inet.af/netstack/tcpip/buffer"
|
|
"inet.af/netstack/tcpip/header"
|
|
"inet.af/netstack/tcpip/link/channel"
|
|
"inet.af/netstack/tcpip/network/ipv4"
|
|
"inet.af/netstack/tcpip/network/ipv6"
|
|
"inet.af/netstack/tcpip/stack"
|
|
"inet.af/netstack/tcpip/transport/icmp"
|
|
"inet.af/netstack/tcpip/transport/tcp"
|
|
"inet.af/netstack/tcpip/transport/udp"
|
|
"inet.af/netstack/waiter"
|
|
"tailscale.com/net/packet"
|
|
"tailscale.com/net/tsaddr"
|
|
"tailscale.com/net/tstun"
|
|
"tailscale.com/types/logger"
|
|
"tailscale.com/types/netmap"
|
|
"tailscale.com/util/dnsname"
|
|
"tailscale.com/wgengine"
|
|
"tailscale.com/wgengine/filter"
|
|
"tailscale.com/wgengine/magicsock"
|
|
)
|
|
|
|
// debugNetstack is a compile-time switch for verbose per-packet logging
// (hex dumps of packets and details of forwarder requests). Off by default.
const debugNetstack = false
|
|
|
|
// Impl contains the state for the netstack implementation,
|
|
// and implements wgengine.FakeImpl to act as a userspace network
|
|
// stack when Tailscale is running in fake mode.
|
|
type Impl struct {
|
|
ipstack *stack.Stack
|
|
linkEP *channel.Endpoint
|
|
tundev *tstun.Wrapper
|
|
e wgengine.Engine
|
|
mc *magicsock.Conn
|
|
logf logger.Logf
|
|
onlySubnets bool // whether we only want to handle subnet relaying
|
|
|
|
// atomicIsLocalIPFunc holds a func that reports whether an IP
|
|
// is a local (non-subnet) Tailscale IP address of this
|
|
// machine. It's always a non-nil func. It's changed on netmap
|
|
// updates.
|
|
atomicIsLocalIPFunc atomic.Value // of func(netaddr.IP) bool
|
|
|
|
mu sync.Mutex
|
|
dns DNSMap
|
|
// connsOpenBySubnetIP keeps track of number of connections open
|
|
// for each subnet IP temporarily registered on netstack for active
|
|
// TCP connections, so they can be unregistered when connections are
|
|
// closed.
|
|
connsOpenBySubnetIP map[netaddr.IP]int
|
|
}
|
|
|
|
// Identity and MTU of the single fake NIC through which all Tailscale
// traffic enters and leaves netstack.
const (
	nicID = 1
	mtu   = 1500
)
|
|
|
|
// Create creates and populates a new Impl. It returns an error if any of
// the required dependencies is nil or if the netstack NIC cannot be
// configured. Protocol handlers and the outbound packet pump are not
// installed here; that happens in Start.
func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, onlySubnets bool) (*Impl, error) {
	if mc == nil {
		return nil, errors.New("nil magicsock.Conn")
	}
	if tundev == nil {
		return nil, errors.New("nil tundev")
	}
	if logf == nil {
		return nil, errors.New("nil logger")
	}
	if e == nil {
		return nil, errors.New("nil Engine")
	}
	ipstack := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
	})
	linkEP := channel.New(512, mtu, "")
	if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
	}
	// By default the netstack NIC will only accept packets for the IPs
	// registered to it. Since in some cases we dynamically register IPs
	// based on the packets that arrive, the NIC needs to accept all
	// incoming packets. The NIC won't receive anything it isn't meant to
	// since WireGuard will only send us packets that are meant for us.
	// Subnet relaying depends on this, so a failure here is fatal.
	if tcpipProblem := ipstack.SetPromiscuousMode(nicID, true); tcpipProblem != nil {
		return nil, fmt.Errorf("could not enable promiscuous mode on netstack NIC: %v", tcpipProblem)
	}
	// Add IPv4 and IPv6 default routes, so all incoming packets from the Tailscale side
	// are handled by the one fake NIC we use. The subnet inputs are constant
	// all-zero address/mask pairs, so the NewSubnet errors are ignored.
	ipv4Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 4)), tcpip.AddressMask(strings.Repeat("\x00", 4)))
	ipv6Subnet, _ := tcpip.NewSubnet(tcpip.Address(strings.Repeat("\x00", 16)), tcpip.AddressMask(strings.Repeat("\x00", 16)))
	ipstack.SetRouteTable([]tcpip.Route{
		{
			Destination: ipv4Subnet,
			NIC:         nicID,
		},
		{
			Destination: ipv6Subnet,
			NIC:         nicID,
		},
	})
	ns := &Impl{
		logf:                logf,
		ipstack:             ipstack,
		linkEP:              linkEP,
		tundev:              tundev,
		e:                   e,
		mc:                  mc,
		connsOpenBySubnetIP: make(map[netaddr.IP]int),
		onlySubnets:         onlySubnets,
	}
	// Until the first netmap arrives, no IP is considered local.
	ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nil))
	return ns, nil
}
|
|
|
|
// Start sets up all the handlers so netstack can start working. It
// implements wgengine.FakeImpl.
//
// It subscribes to network map updates and installs the TCP and UDP
// forwarders. It also starts the goroutine that drains netstack's outbound
// packets into the TUN wrapper, and installs injectInbound as the TUN
// wrapper's PostFilterIn hook.
func (ns *Impl) Start() error {
	ns.e.AddNetworkMapCallback(ns.updateIPs)
	// size = 0 means use default buffer size
	const tcpReceiveBufferSize = 0
	const maxInFlightConnectionAttempts = 16
	tcpFwd := tcp.NewForwarder(ns.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, ns.acceptTCP)
	udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDP)
	ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) bool {
		addr := tei.LocalAddress
		var pn tcpip.NetworkProtocolNumber
		if addr.To4() != "" {
			pn = ipv4.ProtocolNumber
		} else {
			pn = ipv6.ProtocolNumber
		}
		ip, ok := netaddr.FromStdIP(net.IP(addr))
		if !ok {
			// Log the raw netstack address; ip is just the zero value here.
			ns.logf("netstack: could not parse local address %s for incoming TCP connection", addr)
			return false
		}
		if !ns.isLocalIP(ip) {
			// Subnet-routed destination. Temporarily register the IP so
			// netstack will complete the TCP handshake as that address.
			// acceptTCP releases the registration when the connection ends.
			ns.addSubnetAddress(pn, ip)
		}
		return tcpFwd.HandlePacket(tei, pb)
	})
	ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, udpFwd.HandlePacket)
	go ns.injectOutbound()
	ns.tundev.PostFilterIn = ns.injectInbound
	return nil
}
|
|
|
|
// DNSMap maps MagicDNS names (both base + FQDN) to their first IP.
|
|
// It should not be mutated once created.
|
|
type DNSMap map[string]netaddr.IP
|
|
|
|
// DNSMapFromNetworkMap builds a DNSMap for this node and all of its peers
// in nm. Nodes with an empty name or no addresses are skipped. Peers are
// added after this node, so a peer wins if two names collide.
func DNSMapFromNetworkMap(nm *netmap.NetworkMap) DNSMap {
	m := make(DNSMap)
	suffix := nm.MagicDNSSuffix()

	// register records ip under name (without its trailing dot) and, when
	// name is under the MagicDNS suffix, under the short form as well.
	register := func(name string, ip netaddr.IP) {
		m[strings.TrimRight(name, ".")] = ip
		if dnsname.HasSuffix(name, suffix) {
			m[dnsname.TrimSuffix(name, suffix)] = ip
		}
	}

	if nm.Name != "" && len(nm.Addresses) > 0 {
		register(nm.Name, nm.Addresses[0].IP)
	}
	for _, peer := range nm.Peers {
		if peer.Name == "" || len(peer.Addresses) == 0 {
			continue
		}
		register(peer.Name, peer.Addresses[0].IP)
	}
	return m
}
|
|
|
|
func (ns *Impl) updateDNS(nm *netmap.NetworkMap) {
|
|
ns.mu.Lock()
|
|
defer ns.mu.Unlock()
|
|
ns.dns = DNSMapFromNetworkMap(nm)
|
|
}
|
|
|
|
func (ns *Impl) addSubnetAddress(pn tcpip.NetworkProtocolNumber, ip netaddr.IP) {
|
|
ns.mu.Lock()
|
|
ns.connsOpenBySubnetIP[ip]++
|
|
needAdd := ns.connsOpenBySubnetIP[ip] == 1
|
|
ns.mu.Unlock()
|
|
// Only register address into netstack for first concurrent connection.
|
|
if needAdd {
|
|
ns.ipstack.AddAddress(nicID, pn, tcpip.Address(ip.IPAddr().IP))
|
|
}
|
|
}
|
|
|
|
func (ns *Impl) removeSubnetAddress(ip netaddr.IP) {
|
|
ns.mu.Lock()
|
|
defer ns.mu.Unlock()
|
|
ns.connsOpenBySubnetIP[ip]--
|
|
// Only unregister address from netstack after last concurrent connection.
|
|
if ns.connsOpenBySubnetIP[ip] == 0 {
|
|
ns.ipstack.RemoveAddress(nicID, tcpip.Address(ip.IPAddr().IP))
|
|
delete(ns.connsOpenBySubnetIP, ip)
|
|
}
|
|
}
|
|
|
|
func ipPrefixToAddressWithPrefix(ipp netaddr.IPPrefix) tcpip.AddressWithPrefix {
|
|
return tcpip.AddressWithPrefix{
|
|
Address: tcpip.Address(ipp.IP.IPAddr().IP),
|
|
PrefixLen: int(ipp.Bits),
|
|
}
|
|
}
|
|
|
|
func (ns *Impl) updateIPs(nm *netmap.NetworkMap) {
|
|
ns.atomicIsLocalIPFunc.Store(tsaddr.NewContainsIPFunc(nm.Addresses))
|
|
ns.updateDNS(nm)
|
|
|
|
oldIPs := make(map[tcpip.AddressWithPrefix]bool)
|
|
for _, protocolAddr := range ns.ipstack.AllAddresses()[nicID] {
|
|
oldIPs[protocolAddr.AddressWithPrefix] = true
|
|
}
|
|
newIPs := make(map[tcpip.AddressWithPrefix]bool)
|
|
|
|
isAddr := map[netaddr.IPPrefix]bool{}
|
|
for _, ipp := range nm.SelfNode.Addresses {
|
|
isAddr[ipp] = true
|
|
}
|
|
for _, ipp := range nm.SelfNode.AllowedIPs {
|
|
if ns.onlySubnets && isAddr[ipp] {
|
|
continue
|
|
}
|
|
newIPs[ipPrefixToAddressWithPrefix(ipp)] = true
|
|
}
|
|
|
|
ipsToBeAdded := make(map[tcpip.AddressWithPrefix]bool)
|
|
for ipp := range newIPs {
|
|
if !oldIPs[ipp] {
|
|
ipsToBeAdded[ipp] = true
|
|
}
|
|
}
|
|
ipsToBeRemoved := make(map[tcpip.AddressWithPrefix]bool)
|
|
for ip := range oldIPs {
|
|
if !newIPs[ip] {
|
|
ipsToBeRemoved[ip] = true
|
|
}
|
|
}
|
|
ns.mu.Lock()
|
|
for ip := range ns.connsOpenBySubnetIP {
|
|
ipp := tcpip.Address(ip.IPAddr().IP).WithPrefix()
|
|
delete(ipsToBeRemoved, ipp)
|
|
}
|
|
ns.mu.Unlock()
|
|
|
|
for ipp := range ipsToBeRemoved {
|
|
err := ns.ipstack.RemoveAddress(nicID, ipp.Address)
|
|
if err != nil {
|
|
ns.logf("netstack: could not deregister IP %s: %v", ipp, err)
|
|
} else {
|
|
ns.logf("[v2] netstack: deregistered IP %s", ipp)
|
|
}
|
|
}
|
|
for ipp := range ipsToBeAdded {
|
|
var err tcpip.Error
|
|
if ipp.Address.To4() == "" {
|
|
err = ns.ipstack.AddAddressWithPrefix(nicID, ipv6.ProtocolNumber, ipp)
|
|
} else {
|
|
err = ns.ipstack.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, ipp)
|
|
}
|
|
if err != nil {
|
|
ns.logf("netstack: could not register IP %s: %v", ipp, err)
|
|
} else {
|
|
ns.logf("[v2] netstack: registered IP %s", ipp)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Resolve turns addr, a "host:port" string, into an IP:port.
//
// A literal IP host is returned as-is. Otherwise the host is looked up
// first in m (MagicDNS) and then, failing that, with the system resolver,
// whose first answer is used.
func (m DNSMap) Resolve(ctx context.Context, addr string) (netaddr.IPPort, error) {
	// Fast path: addr is already a literal IP:port.
	ipp, parseErr := netaddr.ParseIPPort(addr)
	if parseErr == nil {
		return ipp, nil
	}

	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		// addr is malformed.
		return netaddr.IPPort{}, err
	}
	if net.ParseIP(host) != nil {
		// The host is an IP, so ParseIPPort must have objected to the
		// port. Its error is the most specific one to report.
		return netaddr.IPPort{}, parseErr
	}
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		return netaddr.IPPort{}, fmt.Errorf("invalid port in address %q", addr)
	}

	// The host is a DNS name: prefer MagicDNS, then fall back to real DNS.
	if ip := m[host]; !ip.IsZero() {
		return netaddr.IPPort{IP: ip, Port: uint16(port)}, nil
	}
	var resolver net.Resolver
	ips, err := resolver.LookupIP(ctx, "ip", host)
	if err != nil {
		return netaddr.IPPort{}, err
	}
	if len(ips) == 0 {
		return netaddr.IPPort{}, fmt.Errorf("DNS lookup returned no results for %q", host)
	}
	ip, _ := netaddr.FromStdIP(ips[0])
	return netaddr.IPPort{IP: ip, Port: uint16(port)}, nil
}
|
|
|
|
// DialContextTCP opens a TCP connection to addr ("host:port") through
// netstack. The host is resolved against the current MagicDNS map, with
// DNSMap.Resolve falling back to the system resolver.
func (ns *Impl) DialContextTCP(ctx context.Context, addr string) (*gonet.TCPConn, error) {
	ns.mu.Lock()
	names := ns.dns
	ns.mu.Unlock()

	dst, err := names.Resolve(ctx, addr)
	if err != nil {
		return nil, err
	}

	var proto tcpip.NetworkProtocolNumber
	switch {
	case dst.IP.Is4():
		proto = ipv4.ProtocolNumber
	default:
		proto = ipv6.ProtocolNumber
	}

	return gonet.DialContextTCP(ctx, ns.ipstack, tcpip.FullAddress{
		NIC:  nicID,
		Addr: tcpip.Address(dst.IP.IPAddr().IP),
		Port: dst.Port,
	}, proto)
}
|
|
|
|
// injectOutbound loops forever, taking packets that netstack writes to its
// link endpoint and sending them out via the TUN wrapper's InjectOutbound.
// It returns only when InjectOutbound fails.
func (ns *Impl) injectOutbound() {
	for {
		info, ok := ns.linkEP.ReadContext(context.Background())
		if !ok {
			ns.logf("[v2] ReadContext-for-write = ok=false")
			continue
		}
		pkt := info.Pkt

		// Flatten the network header, transport header and payload into
		// one contiguous buffer.
		buf := make([]byte, 0, pkt.Size())
		for _, part := range [][]byte{
			pkt.NetworkHeader().View(),
			pkt.TransportHeader().View(),
			pkt.Data().AsRange().AsView(),
		} {
			buf = append(buf, part...)
		}

		if debugNetstack {
			ns.logf("[v2] packet Write out: % x", buf)
		}
		if err := ns.tundev.InjectOutbound(buf); err != nil {
			log.Printf("netstack inject outbound: %v", err)
			return
		}
	}
}
|
|
|
|
// isLocalIP reports whether ip is a Tailscale IP assigned to this
|
|
// node directly (but not a subnet-routed IP).
|
|
func (ns *Impl) isLocalIP(ip netaddr.IP) bool {
|
|
return ns.atomicIsLocalIPFunc.Load().(func(netaddr.IP) bool)(ip)
|
|
}
|
|
|
|
func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
|
|
if ns.onlySubnets && ns.isLocalIP(p.Dst.IP) {
|
|
// In hybrid ("only subnets") mode, bail out early if
|
|
// the traffic is destined for an actual Tailscale
|
|
// address. The real host OS interface will handle it.
|
|
return filter.Accept
|
|
}
|
|
var pn tcpip.NetworkProtocolNumber
|
|
switch p.IPVersion {
|
|
case 4:
|
|
pn = header.IPv4ProtocolNumber
|
|
case 6:
|
|
pn = header.IPv6ProtocolNumber
|
|
}
|
|
if debugNetstack {
|
|
ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
|
|
}
|
|
vv := buffer.View(append([]byte(nil), p.Buffer()...)).ToVectorisedView()
|
|
packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
Data: vv,
|
|
})
|
|
ns.linkEP.InjectInbound(pn, packetBuf)
|
|
|
|
// We've now delivered this to netstack, so we're done.
|
|
// Instead of returning a filter.Accept here (which would also
|
|
// potentially deliver it to the host OS), and instead of
|
|
// filter.Drop (which would log about rejected traffic),
|
|
// instead return filter.DropSilently which just quietly stops
|
|
// processing it in the tstun TUN wrapper.
|
|
return filter.DropSilently
|
|
}
|
|
|
|
// acceptTCP is the netstack TCP forwarder callback. It completes the
// handshake and proxies the connection via forwardTCP:
//   - For a Tailscale IP, the connection goes to the same port on 127.0.0.1.
//   - Otherwise it goes to the original (subnet) destination address.
//
// For non-Tailscale destinations it also releases the temporary NIC address
// registration that was taken before the handshake.
func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) {
	reqDetails := r.ID()
	if debugNetstack {
		ns.logf("[v2] TCP ForwarderRequest: %s", stringifyTEI(reqDetails))
	}
	dialAddr := reqDetails.LocalAddress
	dialNetAddr, _ := netaddr.FromStdIP(net.IP(dialAddr))
	isTailscaleIP := tsaddr.IsTailscaleIP(dialNetAddr)
	defer func() {
		if !isTailscaleIP {
			// if this is a subnet IP, we added this in before the TCP handshake
			// so netstack is happy TCP-handshaking as a subnet IP
			ns.removeSubnetAddress(dialNetAddr)
		}
	}()
	var wq waiter.Queue
	ep, err := r.CreateEndpoint(&wq)
	if err != nil {
		ns.logf("acceptTCP: could not create endpoint: %v", err)
		r.Complete(true)
		return
	}
	if isTailscaleIP {
		dialAddr = tcpip.Address(net.ParseIP("127.0.0.1")).To4()
	}
	r.Complete(false)
	c := gonet.NewTCPConn(&wq, ep)
	ns.forwardTCP(c, &wq, dialAddr, reqDetails.LocalPort)
}
|
|
|
|
// forwardTCP proxies bytes between client (the netstack side of an accepted
// connection) and a new host TCP connection to dialAddr:dialPort. It
// returns once either direction finishes, and the dial is aborted if the
// client hangs up first.
//
// While the proxy runs, the backend connection's local IP:port is
// registered with the engine as belonging to the client's IP.
func (ns *Impl) forwardTCP(client *gonet.TCPConn, wq *waiter.Queue, dialAddr tcpip.Address, dialPort uint16) {
	defer client.Close()
	dialAddrStr := net.JoinHostPort(dialAddr.String(), strconv.Itoa(int(dialPort)))
	ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Cancel ctx, which aborts the dial below, if the client hangs up.
	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
	wq.EventRegister(&waitEntry, waiter.EventHUp)
	defer wq.EventUnregister(&waitEntry)
	done := make(chan struct{})
	// netstack doesn't close the notification channel automatically if there was no
	// hup signal, so we close done after we're done to not leak the goroutine below.
	defer close(done)
	go func() {
		select {
		case <-notifyCh:
		case <-done:
		}
		cancel()
	}()

	var stdDialer net.Dialer
	server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr)
	if err != nil {
		ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err)
		return
	}
	defer server.Close()

	backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
	backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)

	// client.RemoteAddr() can return a nil net.Addr. Presumably this happens
	// once the netstack endpoint has lost its peer, e.g. the client went away
	// during the dial. An unchecked type assertion would panic then.
	clientTCPAddr, ok := client.RemoteAddr().(*net.TCPAddr)
	if !ok || clientTCPAddr == nil {
		ns.logf("netstack: could not get client remote address for %s", dialAddrStr)
		return
	}
	clientRemoteIP, _ := netaddr.FromStdIP(clientTCPAddr.IP)
	ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
	defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort)

	// Buffered for both senders so neither copy goroutine blocks on exit.
	connClosed := make(chan error, 2)
	go func() {
		_, err := io.Copy(server, client)
		connClosed <- err
	}()
	go func() {
		_, err := io.Copy(client, server)
		connClosed <- err
	}()
	// Return after the first direction ends; the deferred Closes then
	// unblock the other copy.
	err = <-connClosed
	if err != nil {
		ns.logf("proxy connection closed with error: %v", err)
	}
	ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr)
}
|
|
|
|
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
|
|
reqDetails := r.ID()
|
|
if debugNetstack {
|
|
ns.logf("[v2] UDP ForwarderRequest: %v", stringifyTEI(reqDetails))
|
|
}
|
|
var wq waiter.Queue
|
|
ep, err := r.CreateEndpoint(&wq)
|
|
if err != nil {
|
|
ns.logf("acceptUDP: could not create endpoint: %v", err)
|
|
return
|
|
}
|
|
localAddr, err := ep.GetLocalAddress()
|
|
if err != nil {
|
|
return
|
|
}
|
|
remoteAddr, err := ep.GetRemoteAddress()
|
|
if err != nil {
|
|
return
|
|
}
|
|
c := gonet.NewUDPConn(ns.ipstack, &wq, ep)
|
|
go ns.forwardUDP(c, &wq, localAddr, remoteAddr)
|
|
}
|
|
|
|
func (ns *Impl) forwardUDP(client *gonet.UDPConn, wq *waiter.Queue, clientLocalAddr, clientRemoteAddr tcpip.FullAddress) {
|
|
port := clientLocalAddr.Port
|
|
ns.logf("[v2] netstack: forwarding incoming UDP connection on port %v", port)
|
|
backendListenAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(clientRemoteAddr.Port)}
|
|
backendRemoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}
|
|
backendConn, err := net.ListenUDP("udp4", backendListenAddr)
|
|
if err != nil {
|
|
ns.logf("netstack: could not bind local port %v: %v, trying again with random port", clientRemoteAddr.Port, err)
|
|
backendListenAddr.Port = 0
|
|
backendConn, err = net.ListenUDP("udp4", backendListenAddr)
|
|
if err != nil {
|
|
ns.logf("netstack: could not connect to local UDP server on port %v: %v", port, err)
|
|
return
|
|
}
|
|
}
|
|
backendLocalAddr := backendConn.LocalAddr().(*net.UDPAddr)
|
|
backendLocalIPPort, ok := netaddr.FromStdAddr(backendListenAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone)
|
|
if !ok {
|
|
ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
|
|
}
|
|
clientRemoteIP, _ := netaddr.FromStdIP(net.ParseIP(clientRemoteAddr.Addr.String()))
|
|
ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP)
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
timer := time.AfterFunc(2*time.Minute, func() {
|
|
ns.e.UnregisterIPPortIdentity(backendLocalIPPort)
|
|
ns.logf("netstack: UDP session between %s and %s timed out", clientRemoteAddr, backendRemoteAddr)
|
|
cancel()
|
|
client.Close()
|
|
backendConn.Close()
|
|
})
|
|
extend := func() {
|
|
timer.Reset(2 * time.Minute)
|
|
}
|
|
startPacketCopy(ctx, cancel, client, &net.UDPAddr{
|
|
IP: net.ParseIP(clientRemoteAddr.Addr.String()),
|
|
Port: int(clientRemoteAddr.Port),
|
|
}, backendConn, ns.logf, extend)
|
|
startPacketCopy(ctx, cancel, backendConn, backendRemoteAddr, client, ns.logf, extend)
|
|
|
|
}
|
|
|
|
// startPacketCopy starts a goroutine that reads packets from src and
// writes each one to dst, addressed to dstAddr. It calls extend after every
// successful write.
//
// The goroutine exits when ctx is done (checked between packets) or on the
// first read or write error. On exit it calls cancel, so the copy in the
// opposite direction stops as well. Errors are logged only if ctx was not
// already done.
func startPacketCopy(ctx context.Context, cancel context.CancelFunc, dst net.PacketConn, dstAddr net.Addr, src net.PacketConn, logf logger.Logf, extend func()) {
	go func() {
		defer cancel() // tear down the other direction's copy
		buf := make([]byte, mtu)
		for ctx.Err() == nil {
			n, from, err := src.ReadFrom(buf)
			if err != nil {
				if ctx.Err() == nil {
					logf("read packet from %s failed: %v", from, err)
				}
				return
			}
			if _, err := dst.WriteTo(buf[:n], dstAddr); err != nil {
				if ctx.Err() == nil {
					logf("write packet to %s failed: %v", dstAddr, err)
				}
				return
			}
			if debugNetstack {
				logf("[v2] wrote UDP packet %s -> %s", from, dstAddr)
			}
			extend()
		}
	}()
}
|
|
|
|
// stringifyTEI renders a transport endpoint ID as "remote -> local", with
// each side formatted as host:port.
func stringifyTEI(tei stack.TransportEndpointID) string {
	hostPort := func(addr tcpip.Address, port uint16) string {
		return net.JoinHostPort(addr.String(), strconv.Itoa(int(port)))
	}
	return fmt.Sprintf("%s -> %s", hostPort(tei.RemoteAddress, tei.RemotePort), hostPort(tei.LocalAddress, tei.LocalPort))
}
|