From cc326ea820c15a2d704a79685ad45f762b1d0a76 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Sat, 2 Nov 2024 17:24:18 -0700 Subject: [PATCH] cmd/lopower: add udp forwarding Signed-off-by: Maisem Ali --- cmd/lopower/lopower.go | 212 ++++++++++++++++++++++++++++------------- net/tstun/wrap.go | 32 +++++-- 2 files changed, 171 insertions(+), 73 deletions(-) diff --git a/cmd/lopower/lopower.go b/cmd/lopower/lopower.go index c8f9e7f41..460d480d5 100644 --- a/cmd/lopower/lopower.go +++ b/cmd/lopower/lopower.go @@ -9,11 +9,11 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" - "errors" "flag" "fmt" "io" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -23,6 +23,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" qrcode "github.com/skip2/go-qrcode" @@ -34,15 +35,16 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tsnet" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -160,8 +162,7 @@ func newLP(ctx context.Context) *lpServer { logf := log.Printf deviceLogger := &device.Logger{ Verbosef: logger.Discard, - // Verbosef: logf, - Errorf: logf, + Errorf: logf, } lp := &lpServer{ dir: *confDir, @@ -196,6 +197,10 @@ type lpServer struct { linkEP *channel.Endpoint readCh chan *stack.PacketBuffer + // protocolConns tracks the number of active connections for each connection. + // It is used to add and remove protocol addresses from netstack as needed. + protocolConns syncs.Map[tcpip.ProtocolAddress, *atomic.Int32] + mu sync.Mutex // protects following c *config } @@ -203,83 +208,58 @@ type lpServer struct { // MaxPacketSize is the maximum size (in bytes) // of a packet that can be injected into lpServer. const MaxPacketSize = device.MaxContentSize +const nicID = 1 func (lp *lpServer) initNetstack(ctx context.Context) error { - ns := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, ipv6.NewProtocol, - arp.NewProtocol, }, TransportProtocols: []stack.TransportProtocolFactory{ tcp.NewProtocol, icmp.NewProtocol4, + udp.NewProtocol, }, }) + lp.ns = ns sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default - tcpipErr := ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) - if tcpipErr != nil { + if tcpipErr := ns.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt); tcpipErr != nil { return fmt.Errorf("SetTransportProtocolOption SACK: %v", tcpipErr) } lp.linkEP = channel.New(512, 1280, "") - const nicID = 1 if tcpipProblem := ns.CreateNIC(nicID, lp.linkEP); tcpipProblem != nil { return fmt.Errorf("CreateNIC: %v", tcpipProblem) } ns.SetPromiscuousMode(nicID, true) - ns.SetSpoofing(nicID, true) - var routes []tcpip.Route - lp.mu.Lock() - v4, v6 := lp.c.V4, lp.c.V6 - lp.mu.Unlock() - - { - prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix() - if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: prefix, - }, stack.AddressProperties{}); tcpProb != nil { - return errors.New(tcpProb.String()) - } - - ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) - if err != nil { - return fmt.Errorf("could not create IPv4 subnet: %v", err) - } - routes = append(routes, tcpip.Route{ - Destination: ipv4Subnet, - NIC: nicID, - }) + ipv4Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 4)), tcpip.MaskFromBytes(make([]byte, 4))) + if err != nil { + return fmt.Errorf("could not create IPv4 subnet: %v", err) } - { - prefix := tcpip.AddrFrom16(v6.As16()).WithPrefix() - if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: prefix, - }, stack.AddressProperties{}); tcpProb != nil { - return errors.New(tcpProb.String()) - } - - ipv6Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16))) - if err != nil { - return fmt.Errorf("could not create IPv6 subnet: %v", err) - } - routes = append(routes, tcpip.Route{ - Destination: ipv6Subnet, - NIC: nicID, - }) + ipv6Subnet, err := tcpip.NewSubnet(tcpip.AddrFromSlice(make([]byte, 16)), tcpip.MaskFromBytes(make([]byte, 16))) + if err != nil { + return fmt.Errorf("could not create IPv6 subnet: %v", err) } - ns.SetRouteTable(routes) + ns.SetRouteTable([]tcpip.Route{{ + Destination: ipv4Subnet, + NIC: nicID, + }, { + Destination: ipv6Subnet, + NIC: nicID, + }}) const tcpReceiveBufferSize = 0 // default const maxInFlightConnectionAttempts = 8192 tcpFwd := tcp.NewForwarder(ns, tcpReceiveBufferSize, maxInFlightConnectionAttempts, lp.acceptTCP) + udpFwd := udp.NewForwarder(ns, lp.acceptUDP) ns.SetTransportProtocolHandler(tcp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { return tcpFwd.HandlePacket(tei, pb) }) + ns.SetTransportProtocolHandler(udp.ProtocolNumber, func(tei stack.TransportEndpointID, pb *stack.PacketBuffer) (handled bool) { + return udpFwd.HandlePacket(tei, pb) + }) go func() { for { @@ -316,8 +296,125 @@ func netaddrIPFromNetstackIP(s tcpip.Address) netip.Addr { return netip.Addr{} } +func (lp *lpServer) trackProtocolAddr(destIP netip.Addr) (untrack func()) { + pa := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.AddrFromSlice(destIP.AsSlice()).WithPrefix(), + } + if destIP.Is4() { + pa.Protocol = ipv4.ProtocolNumber + } else if destIP.Is6() { + pa.Protocol = ipv6.ProtocolNumber + } + + addrConns, _ := lp.protocolConns.LoadOrInit(pa, func() *atomic.Int32 { return new(atomic.Int32) }) + if addrConns.Add(1) == 1 { + lp.ns.AddProtocolAddress(nicID, pa, stack.AddressProperties{ + PEB: stack.CanBePrimaryEndpoint, // zero value default + ConfigType: stack.AddressConfigStatic, // zero value default + }) + } + return func() { + if addrConns.Add(-1) == 0 { + lp.ns.RemoveAddress(nicID, pa.AddressWithPrefix.Address) + } + } +} + +func (lp *lpServer) acceptUDP(r *udp.ForwarderRequest) { + log.Printf("acceptUDP: %v", r.ID()) + destIP := netaddrIPFromNetstackIP(r.ID().LocalAddress) + untrack := lp.trackProtocolAddr(destIP) + var wq waiter.Queue + ep, udpErr := r.CreateEndpoint(&wq) + if udpErr != nil { + log.Printf("CreateEndpoint: %v", udpErr) + return + } + go func() { + defer untrack() + defer ep.Close() + reqDetails := r.ID() + + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + destPort := reqDetails.LocalPort + if !clientRemoteIP.IsValid() { + log.Printf("acceptUDP: invalid remote IP %v", reqDetails.RemoteAddress) + return + } + + randPort := rand.IntN(65536-1024) + 1024 + v4, v6 := lp.tsnet.TailscaleIPs() + var listenAddr netip.Addr + if destIP.Is4() { + listenAddr = v4 + } else { + listenAddr = v6 + } + backendConn, err := lp.tsnet.ListenPacket("udp", fmt.Sprintf("%s:%d", listenAddr, randPort)) + if err != nil { + log.Printf("ListenPacket: %v", err) + return + } + defer backendConn.Close() + clientConn := gonet.NewUDPConn(&wq, ep) + defer clientConn.Close() + errCh := make(chan error, 2) + go func() (err error) { + defer func() { errCh <- err }() + var buf [64]byte + for { + n, _, err := backendConn.ReadFrom(buf[:]) + if err != nil { + log.Printf("UDP read: %v", err) + return err + } + _, err = clientConn.Write(buf[:n]) + if err != nil { + return err + } + } + }() + dstAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", destIP, destPort)) + if err != nil { + log.Printf("ResolveUDPAddr: %v", err) + return + } + go func() (err error) { + defer func() { errCh <- err }() + var buf [2048]byte + for { + n, err := clientConn.Read(buf[:]) + if err != nil { + log.Printf("UDP read: %v", err) + return err + } + _, err = backendConn.WriteTo(buf[:n], dstAddr) + if err != nil { + return err + } + } + }() + err = <-errCh + if err != nil { + log.Printf("io.Copy: %v", err) + } + }() +} + func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { log.Printf("acceptTCP: %v", r.ID()) + reqDetails := r.ID() + destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) + clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) + destPort := reqDetails.LocalPort + if !clientRemoteIP.IsValid() { + log.Printf("acceptTCP: invalid remote IP %v", reqDetails.RemoteAddress) + r.Complete(true) // sends a RST + return + } + untrack := lp.trackProtocolAddr(destIP) + defer untrack() + var wq waiter.Queue ep, tcpErr := r.CreateEndpoint(&wq) if tcpErr != nil { @@ -325,22 +422,10 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { r.Complete(true) return } - log.Printf("created endpoint %v", ep) defer ep.Close() ep.SocketOptions().SetKeepAlive(true) - reqDetails := r.ID() - clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress) - destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) - destPort := reqDetails.LocalPort - if !clientRemoteIP.IsValid() { - log.Printf("acceptTCP: invalid remote IP %v", reqDetails.RemoteAddress) - r.Complete(true) // sends a RST - return - } - log.Printf("request from %v to %v:%d", clientRemoteIP, destIP, destPort) - - dialCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) c, err := lp.tsnet.Dial(dialCtx, "tcp", fmt.Sprintf("%s:%d", destIP, destPort)) cancel() if err != nil { @@ -349,7 +434,6 @@ func (lp *lpServer) acceptTCP(r *tcp.ForwarderRequest) { return } defer c.Close() - log.Printf("Connected to %s:%d", destIP, destPort) tc := gonet.NewTCPConn(&wq, ep) defer tc.Close() diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index c384abf9d..65c9f8932 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -861,7 +861,22 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf return res, gro } } + if resp := t.filtRunOut(p, pc); resp != filter.Accept { + return resp, gro + } + if t.PostFilterPacketOutboundToWireGuard != nil { + if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() { + return res, gro + } + } + return filter.Accept, gro +} + +// filtRunOut runs the outbound packet filter on p. +// It uses pc to determine if the packet is to a jailed peer and should be +// filtered with the jailed filter. +func (t *Wrapper) filtRunOut(p *packet.Parsed, pc *peerConfigTable) filter.Response { // If the outbound packet is to a jailed peer, use our jailed peer // packet filter. var filt *filter.Filter @@ -871,7 +886,7 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf filt = t.filter.Load() } if filt == nil { - return filter.Drop, gro + return filter.Drop } if filt.RunOut(p, t.filterFlags) != filter.Accept { @@ -879,15 +894,9 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{ Reason: usermetric.ReasonACL, }, 1) - return filter.Drop, gro + return filter.Drop } - - if t.PostFilterPacketOutboundToWireGuard != nil { - if res := t.PostFilterPacketOutboundToWireGuard(p, t); res.IsDrop() { - return res, gro - } - } - return filter.Accept, gro + return filter.Accept } // noteActivity records that there was a read or write at the current time. @@ -1051,6 +1060,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i p := parsedPacketPool.Get().(*packet.Parsed) defer parsedPacketPool.Put(p) p.Decode(pkt) + response, _ := t.filterPacketOutboundToWireGuard(p, pc, nil) + if response != filter.Accept { + metricPacketOutDrop.Add(1) + return + } invertGSOChecksum(pkt, gso) pc.snat(p)