mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-17 11:08:37 +00:00
feature/relayserver,net/udprelay: add IPv6 support (#16442)
Updates tailscale/corp#27502 Updates tailscale/corp#30043 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
parent
77d19604f4
commit
3a4b439c62
@ -137,7 +137,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) {
|
||||
return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set")
|
||||
}
|
||||
var err error
|
||||
e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil)
|
||||
e.server, err = udprelay.NewServer(e.logf, *e.port, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -57,7 +57,10 @@ type Server struct {
|
||||
bindLifetime time.Duration
|
||||
steadyStateLifetime time.Duration
|
||||
bus *eventbus.Bus
|
||||
uc *net.UDPConn
|
||||
uc4 *net.UDPConn // always non-nil
|
||||
uc4Port uint16 // always nonzero
|
||||
uc6 *net.UDPConn // may be nil if IPv6 bind fails during initialization
|
||||
uc6Port uint16 // may be zero if IPv6 bind fails during initialization
|
||||
closeOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
closeCh chan struct{}
|
||||
@ -278,13 +281,11 @@ func (e *serverEndpoint) isBound() bool {
|
||||
e.boundAddrPorts[1].IsValid()
|
||||
}
|
||||
|
||||
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
|
||||
// supported. Port may be 0, and what ultimately gets bound is returned as
|
||||
// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic
|
||||
// discovery, which is useful to override in tests.
|
||||
//
|
||||
// TODO: IPv6 support
|
||||
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) {
|
||||
// NewServer constructs a [Server] listening on port. If port is zero, then
|
||||
// port selection is left up to the host networking stack. If
|
||||
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
|
||||
// which is useful to override in tests.
|
||||
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
|
||||
s = &Server{
|
||||
logf: logger.WithPrefix(logf, "relayserver"),
|
||||
disco: key.NewDisco(),
|
||||
@ -306,30 +307,36 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
|
||||
s.bus = bus
|
||||
netMon, err := netmon.New(s.bus, logf)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
s.netChecker = &netcheck.Client{
|
||||
NetMon: netMon,
|
||||
Logf: logger.WithPrefix(logf, "relayserver: netcheck:"),
|
||||
SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) {
|
||||
return s.uc.WriteToUDPAddrPort(b, addrPort)
|
||||
if addrPort.Addr().Is4() {
|
||||
return s.uc4.WriteToUDPAddrPort(b, addrPort)
|
||||
} else if s.uc6 != nil {
|
||||
return s.uc6.WriteToUDPAddrPort(b, addrPort)
|
||||
} else {
|
||||
return 0, errors.New("IPv6 socket is not bound")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
boundPort, err = s.listenOn(port)
|
||||
err = s.listenOn(port)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.packetReadLoop()
|
||||
s.wg.Add(1)
|
||||
go s.endpointGCLoop()
|
||||
if len(overrideAddrs) > 0 {
|
||||
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
|
||||
for _, addr := range overrideAddrs {
|
||||
if addr.IsValid() {
|
||||
addrPorts.Add(netip.AddrPortFrom(addr, boundPort))
|
||||
if addr.Is4() {
|
||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
|
||||
} else if s.uc6 != nil {
|
||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
s.addrPorts = addrPorts.Slice()
|
||||
@ -337,7 +344,17 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
|
||||
s.wg.Add(1)
|
||||
go s.addrDiscoveryLoop()
|
||||
}
|
||||
return s, boundPort, nil
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.packetReadLoop(s.uc4)
|
||||
if s.uc6 != nil {
|
||||
s.wg.Add(1)
|
||||
go s.packetReadLoop(s.uc6)
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go s.endpointGCLoop()
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) addrDiscoveryLoop() {
|
||||
@ -351,14 +368,17 @@ func (s *Server) addrDiscoveryLoop() {
|
||||
addrPorts.Make()
|
||||
|
||||
// get local addresses
|
||||
localPort := s.uc.LocalAddr().(*net.UDPAddr).Port
|
||||
ips, _, err := netmon.LocalAddresses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if ip.IsValid() {
|
||||
addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort)))
|
||||
if ip.Is4() {
|
||||
addrPorts.Add(netip.AddrPortFrom(ip, s.uc4Port))
|
||||
} else {
|
||||
addrPorts.Add(netip.AddrPortFrom(ip, s.uc6Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -413,24 +433,52 @@ func (s *Server) addrDiscoveryLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenOn(port int) (uint16, error) {
|
||||
uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
|
||||
// we manage to bind the IPv4 socket.
|
||||
//
|
||||
// The requested port may be zero, in which case port selection is left up to
|
||||
// the host networking stack. We make no attempt to bind a consistent port
|
||||
// across IPv4 and IPv6 if the requested port is zero.
|
||||
//
|
||||
// TODO: make these "re-bindable" in similar fashion to magicsock as a means to
|
||||
// deal with EDR software closing them. http://go/corp/30118
|
||||
func (s *Server) listenOn(port int) error {
|
||||
for _, network := range []string{"udp4", "udp6"} {
|
||||
uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
if network == "udp4" {
|
||||
return err
|
||||
} else {
|
||||
s.logf("ignoring IPv6 bind failure: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
// TODO: set IP_PKTINFO sockopt
|
||||
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
|
||||
if err != nil {
|
||||
uc.Close()
|
||||
if s.uc4 != nil {
|
||||
s.uc4.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
portUint, err := strconv.ParseUint(boundPortStr, 10, 16)
|
||||
if err != nil {
|
||||
uc.Close()
|
||||
if s.uc4 != nil {
|
||||
s.uc4.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
if network == "udp4" {
|
||||
s.uc4 = uc
|
||||
s.uc4Port = uint16(portUint)
|
||||
} else {
|
||||
s.uc6 = uc
|
||||
s.uc6Port = uint16(portUint)
|
||||
}
|
||||
}
|
||||
// TODO: set IP_PKTINFO sockopt
|
||||
_, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String())
|
||||
if err != nil {
|
||||
s.uc.Close()
|
||||
return 0, err
|
||||
}
|
||||
boundPort, err := strconv.ParseUint(boundPortStr, 10, 16)
|
||||
if err != nil {
|
||||
s.uc.Close()
|
||||
return 0, err
|
||||
}
|
||||
s.uc = uc
|
||||
return uint16(boundPort), nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the server.
|
||||
@ -438,7 +486,10 @@ func (s *Server) Close() error {
|
||||
s.closeOnce.Do(func() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.uc.Close()
|
||||
s.uc4.Close()
|
||||
if s.uc6 != nil {
|
||||
s.uc6.Close()
|
||||
}
|
||||
close(s.closeCh)
|
||||
s.wg.Wait()
|
||||
clear(s.byVNI)
|
||||
@ -507,7 +558,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) {
|
||||
e.handlePacket(from, gh, b, uw, s.discoPublic)
|
||||
}
|
||||
|
||||
func (s *Server) packetReadLoop() {
|
||||
func (s *Server) packetReadLoop(uc *net.UDPConn) {
|
||||
defer func() {
|
||||
s.wg.Done()
|
||||
s.Close()
|
||||
@ -515,11 +566,11 @@ func (s *Server) packetReadLoop() {
|
||||
b := make([]byte, 1<<16-1)
|
||||
for {
|
||||
// TODO: extract laddr from IP_PKTINFO for use in reply
|
||||
n, from, err := s.uc.ReadFromUDPAddrPort(b)
|
||||
n, from, err := uc.ReadFromUDPAddrPort(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handlePacket(from, b[:n], s.uc)
|
||||
s.handlePacket(from, b[:n], uc)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,7 @@ type testClient struct {
|
||||
|
||||
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient {
|
||||
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
|
||||
uc, err := net.DialUDP("udp4", nil, rAddr)
|
||||
uc, err := net.DialUDP("udp", nil, rAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -180,85 +180,101 @@ func TestServer(t *testing.T) {
|
||||
discoA := key.NewDisco()
|
||||
discoB := key.NewDisco()
|
||||
|
||||
ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1")
|
||||
|
||||
server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
cases := []struct {
|
||||
name string
|
||||
overrideAddrs []netip.Addr
|
||||
}{
|
||||
{
|
||||
name: "over ipv4",
|
||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
||||
},
|
||||
{
|
||||
name: "over ipv6",
|
||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
|
||||
},
|
||||
}
|
||||
|
||||
// We expect the same endpoint details pre-handshake.
|
||||
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
||||
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
if len(endpoint.AddrPorts) != 1 {
|
||||
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
|
||||
}
|
||||
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
defer tcA.close()
|
||||
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
defer tcB.close()
|
||||
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tcA.handshake(t)
|
||||
tcB.handshake(t)
|
||||
// We expect the same endpoint details pre-handshake.
|
||||
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
||||
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
||||
}
|
||||
|
||||
dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We expect the same endpoint details post-handshake.
|
||||
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
||||
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
||||
}
|
||||
if len(endpoint.AddrPorts) != 1 {
|
||||
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
|
||||
}
|
||||
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
defer tcA.close()
|
||||
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
defer tcB.close()
|
||||
|
||||
txToB := []byte{1, 2, 3}
|
||||
tcA.writeDataPkt(t, txToB)
|
||||
rxFromA := tcB.readDataPkt(t)
|
||||
if !bytes.Equal(txToB, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
}
|
||||
tcA.handshake(t)
|
||||
tcB.handshake(t)
|
||||
|
||||
txToA := []byte{4, 5, 6}
|
||||
tcB.writeDataPkt(t, txToA)
|
||||
rxFromB := tcA.readDataPkt(t)
|
||||
if !bytes.Equal(txToA, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// We expect the same endpoint details post-handshake.
|
||||
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
||||
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
||||
}
|
||||
|
||||
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
|
||||
defer tcAOnNewPort.close()
|
||||
txToB := []byte{1, 2, 3}
|
||||
tcA.writeDataPkt(t, txToB)
|
||||
rxFromA := tcB.readDataPkt(t)
|
||||
if !bytes.Equal(txToB, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
}
|
||||
|
||||
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
|
||||
tcAOnNewPort.handshake(t)
|
||||
txToAOnNewPort := []byte{7, 8, 9}
|
||||
tcB.writeDataPkt(t, txToAOnNewPort)
|
||||
rxFromB = tcAOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToAOnNewPort, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
txToA := []byte{4, 5, 6}
|
||||
tcB.writeDataPkt(t, txToA)
|
||||
rxFromB := tcA.readDataPkt(t)
|
||||
if !bytes.Equal(txToA, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
|
||||
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
|
||||
defer tcBOnNewPort.close()
|
||||
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
|
||||
defer tcAOnNewPort.close()
|
||||
|
||||
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
|
||||
tcBOnNewPort.handshake(t)
|
||||
txToBOnNewPort := []byte{7, 8, 9}
|
||||
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
|
||||
rxFromA = tcBOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToBOnNewPort, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
|
||||
tcAOnNewPort.handshake(t)
|
||||
txToAOnNewPort := []byte{7, 8, 9}
|
||||
tcB.writeDataPkt(t, txToAOnNewPort)
|
||||
rxFromB = tcAOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToAOnNewPort, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
|
||||
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
|
||||
defer tcBOnNewPort.close()
|
||||
|
||||
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
|
||||
tcBOnNewPort.handshake(t)
|
||||
txToBOnNewPort := []byte{7, 8, 9}
|
||||
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
|
||||
rxFromA = tcBOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToBOnNewPort, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user