feature/relayserver,net/udprelay: add IPv6 support (#16442)

Updates tailscale/corp#27502
Updates tailscale/corp#30043

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2025-07-02 20:38:39 -07:00 committed by GitHub
parent 77d19604f4
commit 3a4b439c62
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 179 additions and 112 deletions

View File

@ -137,7 +137,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) {
return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set")
}
var err error
e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil)
e.server, err = udprelay.NewServer(e.logf, *e.port, nil)
if err != nil {
return nil, err
}

View File

@ -57,7 +57,10 @@ type Server struct {
bindLifetime time.Duration
steadyStateLifetime time.Duration
bus *eventbus.Bus
uc *net.UDPConn
uc4 *net.UDPConn // always non-nil
uc4Port uint16 // always nonzero
uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization
uc6Port uint16 // may be zero if IPv6 bind fails during initialization
closeOnce sync.Once
wg sync.WaitGroup
closeCh chan struct{}
@ -278,13 +281,11 @@ func (e *serverEndpoint) isBound() bool {
e.boundAddrPorts[1].IsValid()
}
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
// supported. Port may be 0, and what ultimately gets bound is returned as
// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic
// discovery, which is useful to override in tests.
//
// TODO: IPv6 support
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) {
// NewServer constructs a [Server] listening on port. If port is zero, then
// port selection is left up to the host networking stack. If
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
// which is useful to override in tests.
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
s = &Server{
logf: logger.WithPrefix(logf, "relayserver"),
disco: key.NewDisco(),
@ -306,30 +307,36 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
s.bus = bus
netMon, err := netmon.New(s.bus, logf)
if err != nil {
return nil, 0, err
return nil, err
}
s.netChecker = &netcheck.Client{
NetMon: netMon,
Logf: logger.WithPrefix(logf, "relayserver: netcheck:"),
SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) {
return s.uc.WriteToUDPAddrPort(b, addrPort)
if addrPort.Addr().Is4() {
return s.uc4.WriteToUDPAddrPort(b, addrPort)
} else if s.uc6 != nil {
return s.uc6.WriteToUDPAddrPort(b, addrPort)
} else {
return 0, errors.New("IPv6 socket is not bound")
}
},
}
boundPort, err = s.listenOn(port)
err = s.listenOn(port)
if err != nil {
return nil, 0, err
return nil, err
}
s.wg.Add(1)
go s.packetReadLoop()
s.wg.Add(1)
go s.endpointGCLoop()
if len(overrideAddrs) > 0 {
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
for _, addr := range overrideAddrs {
if addr.IsValid() {
addrPorts.Add(netip.AddrPortFrom(addr, boundPort))
if addr.Is4() {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
} else if s.uc6 != nil {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
}
}
}
s.addrPorts = addrPorts.Slice()
@ -337,7 +344,17 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
s.wg.Add(1)
go s.addrDiscoveryLoop()
}
return s, boundPort, nil
s.wg.Add(1)
go s.packetReadLoop(s.uc4)
if s.uc6 != nil {
s.wg.Add(1)
go s.packetReadLoop(s.uc6)
}
s.wg.Add(1)
go s.endpointGCLoop()
return s, nil
}
func (s *Server) addrDiscoveryLoop() {
@ -351,14 +368,17 @@ func (s *Server) addrDiscoveryLoop() {
addrPorts.Make()
// get local addresses
localPort := s.uc.LocalAddr().(*net.UDPAddr).Port
ips, _, err := netmon.LocalAddresses()
if err != nil {
return nil, err
}
for _, ip := range ips {
if ip.IsValid() {
addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort)))
if ip.Is4() {
addrPorts.Add(netip.AddrPortFrom(ip, s.uc4Port))
} else {
addrPorts.Add(netip.AddrPortFrom(ip, s.uc6Port))
}
}
}
@ -413,24 +433,52 @@ func (s *Server) addrDiscoveryLoop() {
}
}
func (s *Server) listenOn(port int) (uint16, error) {
uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
if err != nil {
return 0, err
// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
// we manage to bind the IPv4 socket.
//
// The requested port may be zero, in which case port selection is left up to
// the host networking stack. We make no attempt to bind a consistent port
// across IPv4 and IPv6 if the requested port is zero.
//
// TODO: make these "re-bindable" in similar fashion to magicsock as a means to
// deal with EDR software closing them. http://go/corp/30118
func (s *Server) listenOn(port int) error {
for _, network := range []string{"udp4", "udp6"} {
uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
if network == "udp4" {
return err
} else {
s.logf("ignoring IPv6 bind failure: %v", err)
break
}
}
// TODO: set IP_PKTINFO sockopt
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
if err != nil {
uc.Close()
if s.uc4 != nil {
s.uc4.Close()
}
return err
}
portUint, err := strconv.ParseUint(boundPortStr, 10, 16)
if err != nil {
uc.Close()
if s.uc4 != nil {
s.uc4.Close()
}
return err
}
if network == "udp4" {
s.uc4 = uc
s.uc4Port = uint16(portUint)
} else {
s.uc6 = uc
s.uc6Port = uint16(portUint)
}
}
// TODO: set IP_PKTINFO sockopt
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
if err != nil {
s.uc.Close()
return 0, err
}
boundPort, err := strconv.ParseUint(boundPortStr, 10, 16)
if err != nil {
s.uc.Close()
return 0, err
}
s.uc = uc
return uint16(boundPort), nil
return nil
}
// Close closes the server.
@ -438,7 +486,10 @@ func (s *Server) Close() error {
s.closeOnce.Do(func() {
s.mu.Lock()
defer s.mu.Unlock()
s.uc.Close()
s.uc4.Close()
if s.uc6 != nil {
s.uc6.Close()
}
close(s.closeCh)
s.wg.Wait()
clear(s.byVNI)
@ -507,7 +558,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
e.handlePacket(from, gh, b, uw, s.discoPublic)
}
func (s *Server) packetReadLoop() {
func (s *Server) packetReadLoop(uc *net.UDPConn) {
defer func() {
s.wg.Done()
s.Close()
@ -515,11 +566,11 @@ func (s *Server) packetReadLoop() {
b := make([]byte, 1<<16-1)
for {
// TODO: extract laddr from IP_PKTINFO for use in reply
n, from, err := s.uc.ReadFromUDPAddrPort(b)
n, from, err := uc.ReadFromUDPAddrPort(b)
if err != nil {
return
}
s.handlePacket(from, b[:n], s.uc)
s.handlePacket(from, b[:n], uc)
}
}

View File

@ -29,7 +29,7 @@ type testClient struct {
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient {
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
uc, err := net.DialUDP("udp4", nil, rAddr)
uc, err := net.DialUDP("udp", nil, rAddr)
if err != nil {
t.Fatal(err)
}
@ -180,85 +180,101 @@ func TestServer(t *testing.T) {
discoA := key.NewDisco()
discoB := key.NewDisco()
ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1")
server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr})
if err != nil {
t.Fatal(err)
}
defer server.Close()
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
}
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
cases := []struct {
name string
overrideAddrs []netip.Addr
}{
{
name: "over ipv4",
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
},
{
name: "over ipv6",
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
},
}
// We expect the same endpoint details pre-handshake.
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
if err != nil {
t.Fatal(err)
}
defer server.Close()
if len(endpoint.AddrPorts) != 1 {
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
}
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
defer tcA.close()
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
defer tcB.close()
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
}
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
}
tcA.handshake(t)
tcB.handshake(t)
// We expect the same endpoint details pre-handshake.
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
}
dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
}
// We expect the same endpoint details post-handshake.
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
}
if len(endpoint.AddrPorts) != 1 {
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
}
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
defer tcA.close()
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
defer tcB.close()
txToB := []byte{1, 2, 3}
tcA.writeDataPkt(t, txToB)
rxFromA := tcB.readDataPkt(t)
if !bytes.Equal(txToB, rxFromA) {
t.Fatal("unexpected msg A->B")
}
tcA.handshake(t)
tcB.handshake(t)
txToA := []byte{4, 5, 6}
tcB.writeDataPkt(t, txToA)
rxFromB := tcA.readDataPkt(t)
if !bytes.Equal(txToA, rxFromB) {
t.Fatal("unexpected msg B->A")
}
dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public())
if err != nil {
t.Fatal(err)
}
// We expect the same endpoint details post-handshake.
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
}
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
defer tcAOnNewPort.close()
txToB := []byte{1, 2, 3}
tcA.writeDataPkt(t, txToB)
rxFromA := tcB.readDataPkt(t)
if !bytes.Equal(txToB, rxFromA) {
t.Fatal("unexpected msg A->B")
}
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
tcAOnNewPort.handshake(t)
txToAOnNewPort := []byte{7, 8, 9}
tcB.writeDataPkt(t, txToAOnNewPort)
rxFromB = tcAOnNewPort.readDataPkt(t)
if !bytes.Equal(txToAOnNewPort, rxFromB) {
t.Fatal("unexpected msg B->A")
}
txToA := []byte{4, 5, 6}
tcB.writeDataPkt(t, txToA)
rxFromB := tcA.readDataPkt(t)
if !bytes.Equal(txToA, rxFromB) {
t.Fatal("unexpected msg B->A")
}
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
defer tcBOnNewPort.close()
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
defer tcAOnNewPort.close()
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
tcBOnNewPort.handshake(t)
txToBOnNewPort := []byte{7, 8, 9}
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
rxFromA = tcBOnNewPort.readDataPkt(t)
if !bytes.Equal(txToBOnNewPort, rxFromA) {
t.Fatal("unexpected msg A->B")
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
tcAOnNewPort.handshake(t)
txToAOnNewPort := []byte{7, 8, 9}
tcB.writeDataPkt(t, txToAOnNewPort)
rxFromB = tcAOnNewPort.readDataPkt(t)
if !bytes.Equal(txToAOnNewPort, rxFromB) {
t.Fatal("unexpected msg B->A")
}
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
defer tcBOnNewPort.close()
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
tcBOnNewPort.handshake(t)
txToBOnNewPort := []byte{7, 8, 9}
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
rxFromA = tcBOnNewPort.readDataPkt(t)
if !bytes.Equal(txToBOnNewPort, rxFromA) {
t.Fatal("unexpected msg A->B")
}
})
}
}