From 1677fb190519710d66354600f659b50af77d7759 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 21 Jul 2025 10:02:37 -0700 Subject: [PATCH] wgengine/magicsock,all: allocate peer relay over disco instead of PeerAPI (#16603) Updates tailscale/corp#30583 Updates tailscale/corp#30534 Updates tailscale/corp#30557 Signed-off-by: Dylan Bargatze Signed-off-by: Jordan Whited Co-authored-by: Dylan Bargatze --- disco/disco.go | 261 ++++++++---- disco/disco_test.go | 41 +- feature/relayserver/relayserver.go | 151 +++---- feature/relayserver/relayserver_test.go | 10 + ipn/ipnlocal/local.go | 2 +- ipn/localapi/localapi.go | 2 +- net/udprelay/endpoint/endpoint.go | 9 + net/udprelay/server.go | 45 +- tailcfg/tailcfg.go | 3 +- types/key/disco.go | 38 ++ types/key/disco_test.go | 18 + wgengine/magicsock/endpoint.go | 17 +- wgengine/magicsock/magicsock.go | 367 +++++++++++++---- wgengine/magicsock/magicsock_test.go | 519 ++++++++++++------------ wgengine/magicsock/relaymanager.go | 508 +++++++++++++---------- wgengine/magicsock/relaymanager_test.go | 42 +- 16 files changed, 1290 insertions(+), 743 deletions(-) diff --git a/disco/disco.go b/disco/disco.go index d4623c119..1689d2a93 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -42,13 +42,15 @@ const NonceLen = 24 type MessageType byte const ( - TypePing = MessageType(0x01) - TypePong = MessageType(0x02) - TypeCallMeMaybe = MessageType(0x03) - TypeBindUDPRelayEndpoint = MessageType(0x04) - TypeBindUDPRelayEndpointChallenge = MessageType(0x05) - TypeBindUDPRelayEndpointAnswer = MessageType(0x06) - TypeCallMeMaybeVia = MessageType(0x07) + TypePing = MessageType(0x01) + TypePong = MessageType(0x02) + TypeCallMeMaybe = MessageType(0x03) + TypeBindUDPRelayEndpoint = MessageType(0x04) + TypeBindUDPRelayEndpointChallenge = MessageType(0x05) + TypeBindUDPRelayEndpointAnswer = MessageType(0x06) + TypeCallMeMaybeVia = MessageType(0x07) + TypeAllocateUDPRelayEndpointRequest = MessageType(0x08) + TypeAllocateUDPRelayEndpointResponse = MessageType(0x09) ) const v0 = byte(0) @@ -97,6 +99,10 @@ func Parse(p []byte) (Message, error) { return parseBindUDPRelayEndpointAnswer(ver, p) case TypeCallMeMaybeVia: return parseCallMeMaybeVia(ver, p) + case TypeAllocateUDPRelayEndpointRequest: + return parseAllocateUDPRelayEndpointRequest(ver, p) + case TypeAllocateUDPRelayEndpointResponse: + return parseAllocateUDPRelayEndpointResponse(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -381,9 +387,7 @@ func (m *BindUDPRelayEndpointCommon) decode(b []byte) error { } // BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client -// towards UDP relay server as part of the 3-way bind handshake. This message -// type is currently considered experimental and is not yet tied to a -// tailcfg.CapabilityVersion. +// towards UDP relay server as part of the 3-way bind handshake. type BindUDPRelayEndpoint struct { BindUDPRelayEndpointCommon } @@ -405,8 +409,7 @@ func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, er // BindUDPRelayEndpointChallenge is transmitted from UDP relay server towards // UDP relay client in response to a BindUDPRelayEndpoint message as part of the -// 3-way bind handshake. This message type is currently considered experimental -// and is not yet tied to a tailcfg.CapabilityVersion. +// 3-way bind handshake. 
type BindUDPRelayEndpointChallenge struct { BindUDPRelayEndpointCommon } @@ -427,9 +430,7 @@ func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEnd } // BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay -// server in response to a BindUDPRelayEndpointChallenge message. This message -// type is currently considered experimental and is not yet tied to a -// tailcfg.CapabilityVersion. +// server in response to a BindUDPRelayEndpointChallenge message. type BindUDPRelayEndpointAnswer struct { BindUDPRelayEndpointCommon } @@ -449,6 +450,168 @@ func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpoi return m, nil } +// AllocateUDPRelayEndpointRequest is a message sent only over DERP to request +// allocation of a relay endpoint on a [tailscale.com/net/udprelay.Server] +type AllocateUDPRelayEndpointRequest struct { + // ClientDisco are the Disco public keys of the clients that should be + // permitted to handshake with the endpoint. + ClientDisco [2]key.DiscoPublic + // Generation represents the allocation request generation. The server must + // echo it back in the [AllocateUDPRelayEndpointResponse] to enable request + // and response alignment client-side. + Generation uint32 +} + +// allocateUDPRelayEndpointRequestLen is the length of a marshaled +// [AllocateUDPRelayEndpointRequest] message without the message header. +const allocateUDPRelayEndpointRequestLen = key.DiscoPublicRawLen*2 + // ClientDisco + 4 // Generation + +func (m *AllocateUDPRelayEndpointRequest) AppendMarshal(b []byte) []byte { + ret, p := appendMsgHeader(b, TypeAllocateUDPRelayEndpointRequest, v0, allocateUDPRelayEndpointRequestLen) + for i := 0; i < len(m.ClientDisco); i++ { + disco := m.ClientDisco[i].AppendTo(nil) + copy(p, disco) + p = p[key.DiscoPublicRawLen:] + } + binary.BigEndian.PutUint32(p, m.Generation) + return ret +} + +func parseAllocateUDPRelayEndpointRequest(ver uint8, p []byte) (m *AllocateUDPRelayEndpointRequest, err error) { + m = new(AllocateUDPRelayEndpointRequest) + if ver != 0 { + return + } + if len(p) < allocateUDPRelayEndpointRequestLen { + return m, errShort + } + for i := 0; i < len(m.ClientDisco); i++ { + m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(p[:key.DiscoPublicRawLen])) + p = p[key.DiscoPublicRawLen:] + } + m.Generation = binary.BigEndian.Uint32(p) + return m, nil +} + +// AllocateUDPRelayEndpointResponse is a message sent only over DERP in response +// to a [AllocateUDPRelayEndpointRequest]. +type AllocateUDPRelayEndpointResponse struct { + // Generation represents the allocation request generation. The server must + // echo back the [AllocateUDPRelayEndpointRequest.Generation] here to enable + // request and response alignment client-side. 
+ Generation uint32 + UDPRelayEndpoint +} + +func (m *AllocateUDPRelayEndpointResponse) AppendMarshal(b []byte) []byte { + endpointsLen := epLength * len(m.AddrPorts) + generationLen := 4 + ret, d := appendMsgHeader(b, TypeAllocateUDPRelayEndpointResponse, v0, generationLen+udpRelayEndpointLenMinusAddrPorts+endpointsLen) + binary.BigEndian.PutUint32(d, m.Generation) + m.encode(d[4:]) + return ret +} + +func parseAllocateUDPRelayEndpointResponse(ver uint8, p []byte) (m *AllocateUDPRelayEndpointResponse, err error) { + m = new(AllocateUDPRelayEndpointResponse) + if ver != 0 { + return m, nil + } + if len(p) < 4 { + return m, errShort + } + m.Generation = binary.BigEndian.Uint32(p) + err = m.decode(p[4:]) + return m, err +} + +const udpRelayEndpointLenMinusAddrPorts = key.DiscoPublicRawLen + // ServerDisco + (key.DiscoPublicRawLen * 2) + // ClientDisco + 8 + // LamportID + 4 + // VNI + 8 + // BindLifetime + 8 // SteadyStateLifetime + +// UDPRelayEndpoint is a mirror of [tailscale.com/net/udprelay/endpoint.ServerEndpoint], +// refer to it for field documentation. [UDPRelayEndpoint] is carried in both +// [CallMeMaybeVia] and [AllocateUDPRelayEndpointResponse] messages. +type UDPRelayEndpoint struct { + // ServerDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ServerDisco] + ServerDisco key.DiscoPublic + // ClientDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ClientDisco] + ClientDisco [2]key.DiscoPublic + // LamportID is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.LamportID] + LamportID uint64 + // VNI is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.VNI] + VNI uint32 + // BindLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.BindLifetime] + BindLifetime time.Duration + // SteadyStateLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.SteadyStateLifetime] + SteadyStateLifetime time.Duration + // AddrPorts is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.AddrPorts] + AddrPorts []netip.AddrPort +} + +// encode encodes m in b. b must be at least [udpRelayEndpointLenMinusAddrPorts] +// + [epLength] * len(m.AddrPorts) bytes long. +func (m *UDPRelayEndpoint) encode(b []byte) { + disco := m.ServerDisco.AppendTo(nil) + copy(b, disco) + b = b[key.DiscoPublicRawLen:] + for i := 0; i < len(m.ClientDisco); i++ { + disco = m.ClientDisco[i].AppendTo(nil) + copy(b, disco) + b = b[key.DiscoPublicRawLen:] + } + binary.BigEndian.PutUint64(b[:8], m.LamportID) + b = b[8:] + binary.BigEndian.PutUint32(b[:4], m.VNI) + b = b[4:] + binary.BigEndian.PutUint64(b[:8], uint64(m.BindLifetime)) + b = b[8:] + binary.BigEndian.PutUint64(b[:8], uint64(m.SteadyStateLifetime)) + b = b[8:] + for _, ipp := range m.AddrPorts { + a := ipp.Addr().As16() + copy(b, a[:]) + binary.BigEndian.PutUint16(b[16:18], ipp.Port()) + b = b[epLength:] + } +} + +// decode decodes m from b. 
+func (m *UDPRelayEndpoint) decode(b []byte) error { + if len(b) < udpRelayEndpointLenMinusAddrPorts+epLength || + (len(b)-udpRelayEndpointLenMinusAddrPorts)%epLength != 0 { + return errShort + } + m.ServerDisco = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + for i := 0; i < len(m.ClientDisco); i++ { + m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + } + m.LamportID = binary.BigEndian.Uint64(b[:8]) + b = b[8:] + m.VNI = binary.BigEndian.Uint32(b[:4]) + b = b[4:] + m.BindLifetime = time.Duration(binary.BigEndian.Uint64(b[:8])) + b = b[8:] + m.SteadyStateLifetime = time.Duration(binary.BigEndian.Uint64(b[:8])) + b = b[8:] + m.AddrPorts = make([]netip.AddrPort, 0, len(b)-udpRelayEndpointLenMinusAddrPorts/epLength) + for len(b) > 0 { + var a [16]byte + copy(a[:], b) + m.AddrPorts = append(m.AddrPorts, netip.AddrPortFrom( + netip.AddrFrom16(a).Unmap(), + binary.BigEndian.Uint16(b[16:18]))) + b = b[epLength:] + } + return nil +} + // CallMeMaybeVia is a message sent only over DERP to request that the recipient // try to open up a magicsock path back to the sender. The 'Via' in // CallMeMaybeVia highlights that candidate paths are served through an @@ -464,78 +627,22 @@ func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpoi // The recipient may choose to not open a path back if it's already happy with // its path. Direct connections, e.g. [CallMeMaybe]-signaled, take priority over // CallMeMaybeVia paths. -// -// This message type is currently considered experimental and is not yet tied to -// a [tailscale.com/tailcfg.CapabilityVersion]. type CallMeMaybeVia struct { - // ServerDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ServerDisco] - ServerDisco key.DiscoPublic - // LamportID is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.LamportID] - LamportID uint64 - // VNI is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.VNI] - VNI uint32 - // BindLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.BindLifetime] - BindLifetime time.Duration - // SteadyStateLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.SteadyStateLifetime] - SteadyStateLifetime time.Duration - // AddrPorts is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.AddrPorts] - AddrPorts []netip.AddrPort + UDPRelayEndpoint } -const cmmvDataLenMinusEndpoints = key.DiscoPublicRawLen + // ServerDisco - 8 + // LamportID - 4 + // VNI - 8 + // BindLifetime - 8 // SteadyStateLifetime - func (m *CallMeMaybeVia) AppendMarshal(b []byte) []byte { endpointsLen := epLength * len(m.AddrPorts) - ret, p := appendMsgHeader(b, TypeCallMeMaybeVia, v0, cmmvDataLenMinusEndpoints+endpointsLen) - disco := m.ServerDisco.AppendTo(nil) - copy(p, disco) - p = p[key.DiscoPublicRawLen:] - binary.BigEndian.PutUint64(p[:8], m.LamportID) - p = p[8:] - binary.BigEndian.PutUint32(p[:4], m.VNI) - p = p[4:] - binary.BigEndian.PutUint64(p[:8], uint64(m.BindLifetime)) - p = p[8:] - binary.BigEndian.PutUint64(p[:8], uint64(m.SteadyStateLifetime)) - p = p[8:] - for _, ipp := range m.AddrPorts { - a := ipp.Addr().As16() - copy(p, a[:]) - binary.BigEndian.PutUint16(p[16:18], ipp.Port()) - p = p[epLength:] - } + ret, p := appendMsgHeader(b, TypeCallMeMaybeVia, v0, udpRelayEndpointLenMinusAddrPorts+endpointsLen) + m.encode(p) return ret } func parseCallMeMaybeVia(ver uint8, p []byte) (m *CallMeMaybeVia, err error) { m = new(CallMeMaybeVia) - if len(p) < cmmvDataLenMinusEndpoints+epLength || - 
(len(p)-cmmvDataLenMinusEndpoints)%epLength != 0 || - ver != 0 { + if ver != 0 { return m, nil } - m.ServerDisco = key.DiscoPublicFromRaw32(mem.B(p[:key.DiscoPublicRawLen])) - p = p[key.DiscoPublicRawLen:] - m.LamportID = binary.BigEndian.Uint64(p[:8]) - p = p[8:] - m.VNI = binary.BigEndian.Uint32(p[:4]) - p = p[4:] - m.BindLifetime = time.Duration(binary.BigEndian.Uint64(p[:8])) - p = p[8:] - m.SteadyStateLifetime = time.Duration(binary.BigEndian.Uint64(p[:8])) - p = p[8:] - m.AddrPorts = make([]netip.AddrPort, 0, len(p)-cmmvDataLenMinusEndpoints/epLength) - for len(p) > 0 { - var a [16]byte - copy(a[:], p) - m.AddrPorts = append(m.AddrPorts, netip.AddrPortFrom( - netip.AddrFrom16(a).Unmap(), - binary.BigEndian.Uint16(p[16:18]))) - p = p[epLength:] - } - return m, nil + err = m.decode(p) + return m, err } diff --git a/disco/disco_test.go b/disco/disco_test.go index 9fb71ff83..71b68338a 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -25,6 +25,19 @@ func TestMarshalAndParse(t *testing.T) { }, } + udpRelayEndpoint := UDPRelayEndpoint{ + ServerDisco: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + ClientDisco: [2]key.DiscoPublic{key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31})), key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 4: 4, 30: 30, 31: 31}))}, + LamportID: 123, + VNI: 456, + BindLifetime: time.Second, + SteadyStateLifetime: time.Minute, + AddrPorts: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + } + tests := []struct { name string want string @@ -117,17 +130,25 @@ func TestMarshalAndParse(t *testing.T) { { name: "call_me_maybe_via", m: &CallMeMaybeVia{ - ServerDisco: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), - LamportID: 123, - VNI: 456, - BindLifetime: time.Second, - SteadyStateLifetime: time.Minute, - AddrPorts: []netip.AddrPort{ - netip.MustParseAddrPort("1.2.3.4:567"), - netip.MustParseAddrPort("[2001::3456]:789"), - }, + UDPRelayEndpoint: udpRelayEndpoint, }, - want: "07 00 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + want: "07 00 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, + { + name: "allocate_udp_relay_endpoint_request", + m: &AllocateUDPRelayEndpointRequest{ + ClientDisco: [2]key.DiscoPublic{key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31})), key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 4: 4, 30: 30, 31: 31}))}, + Generation: 1, + }, + want: "08 00 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 01", + }, + { + name: "allocate_udp_relay_endpoint_response", + m: &AllocateUDPRelayEndpointResponse{ + Generation: 1, + UDPRelayEndpoint: udpRelayEndpoint, + }, + want: "09 
00 00 00 00 01 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", }, } for _, tt := range tests { diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index d0ad27624..f4077b5f9 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -6,25 +6,21 @@ package relayserver import ( - "encoding/json" "errors" - "fmt" - "io" - "net/http" "sync" - "time" + "tailscale.com/disco" "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnext" - "tailscale.com/ipn/ipnlocal" "tailscale.com/net/udprelay" "tailscale.com/net/udprelay/endpoint" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/ptr" - "tailscale.com/util/httpm" + "tailscale.com/util/eventbus" + "tailscale.com/wgengine/magicsock" ) // featureName is the name of the feature implemented by this package. @@ -34,26 +30,34 @@ const featureName = "relayserver" func init() { feature.Register(featureName) ipnext.RegisterExtension(featureName, newExtension) - ipnlocal.RegisterPeerAPIHandler("/v0/relay/endpoint", handlePeerAPIRelayAllocateEndpoint) } // newExtension is an [ipnext.NewExtensionFn] that creates a new relay server // extension. It is registered with [ipnext.RegisterExtension] if the package is // imported. -func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { - return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil +func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{ + logf: logger.WithPrefix(logf, featureName+": "), + bus: sb.Sys().Bus.Get(), + }, nil } // extension is an [ipnext.Extension] managing the relay server on platforms // that import this package. type extension struct { logf logger.Logf + bus *eventbus.Bus - mu sync.Mutex // guards the following fields + mu sync.Mutex // guards the following fields + eventClient *eventbus.Client // closed to stop consumeEventbusTopics + reqSub *eventbus.Subscriber[magicsock.UDPRelayAllocReq] // receives endpoint alloc requests from magicsock + respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] // publishes endpoint alloc responses to magicsock shutdown bool - port *int // ipn.Prefs.RelayServerPort, nil if disabled - hasNodeAttrDisableRelayServer bool // tailcfg.NodeAttrDisableRelayServer - server relayServer // lazily initialized + port *int // ipn.Prefs.RelayServerPort, nil if disabled + busDoneCh chan struct{} // non-nil if port is non-nil, closed when consumeEventbusTopics returns + hasNodeAttrDisableRelayServer bool // tailcfg.NodeAttrDisableRelayServer + server relayServer // lazily initialized + } // relayServer is the interface of [udprelay.Server]. @@ -77,6 +81,18 @@ func (e *extension) Init(host ipnext.Host) error { return nil } +// initBusConnection initializes the [*eventbus.Client], [*eventbus.Subscriber], +// [*eventbus.Publisher], and [chan struct{}] used to publish/receive endpoint +// allocation messages to/from the [*eventbus.Bus]. It also starts +// consumeEventbusTopics in a separate goroutine. 
+func (e *extension) initBusConnection() { + e.eventClient = e.bus.Client("relayserver.extension") + e.reqSub = eventbus.Subscribe[magicsock.UDPRelayAllocReq](e.eventClient) + e.respPub = eventbus.Publish[magicsock.UDPRelayAllocResp](e.eventClient) + e.busDoneCh = make(chan struct{}) + go e.consumeEventbusTopics() +} + func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { e.mu.Lock() defer e.mu.Unlock() @@ -98,13 +114,59 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV e.server.Close() e.server = nil } + if e.port != nil { + e.eventClient.Close() + <-e.busDoneCh + } e.port = nil if ok { e.port = ptr.To(newPort) + e.initBusConnection() } } } +func (e *extension) consumeEventbusTopics() { + defer close(e.busDoneCh) + + for { + select { + case <-e.reqSub.Done(): + // If reqSub is done, the eventClient has been closed, which is a + // signal to return. + return + case req := <-e.reqSub.Events(): + rs, err := e.relayServerOrInit() + if err != nil { + e.logf("error initializing server: %v", err) + continue + } + se, err := rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) + if err != nil { + e.logf("error allocating endpoint: %v", err) + continue + } + e.respPub.Publish(magicsock.UDPRelayAllocResp{ + ReqRxFromNodeKey: req.RxFromNodeKey, + ReqRxFromDiscoKey: req.RxFromDiscoKey, + Message: &disco.AllocateUDPRelayEndpointResponse{ + Generation: req.Message.Generation, + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, + }, + }) + } + } + +} + // Shutdown implements [ipnlocal.Extension]. 
func (e *extension) Shutdown() error { e.mu.Lock() @@ -114,6 +176,10 @@ func (e *extension) Shutdown() error { e.server.Close() e.server = nil } + if e.port != nil { + e.eventClient.Close() + <-e.busDoneCh + } return nil } @@ -139,60 +205,3 @@ func (e *extension) relayServerOrInit() (relayServer, error) { } return e.server, nil } - -func handlePeerAPIRelayAllocateEndpoint(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { - e, ok := ipnlocal.GetExt[*extension](h.LocalBackend()) - if !ok { - http.Error(w, "relay failed to initialize", http.StatusServiceUnavailable) - return - } - - httpErrAndLog := func(message string, code int) { - http.Error(w, message, code) - h.Logf("relayserver: request from %v returned code %d: %s", h.RemoteAddr(), code, message) - } - - if !h.PeerCaps().HasCapability(tailcfg.PeerCapabilityRelay) { - httpErrAndLog("relay not permitted", http.StatusForbidden) - return - } - - if r.Method != httpm.POST { - httpErrAndLog("only POST method is allowed", http.StatusMethodNotAllowed) - return - } - - var allocateEndpointReq struct { - DiscoKeys []key.DiscoPublic - } - err := json.NewDecoder(io.LimitReader(r.Body, 512)).Decode(&allocateEndpointReq) - if err != nil { - httpErrAndLog(err.Error(), http.StatusBadRequest) - return - } - if len(allocateEndpointReq.DiscoKeys) != 2 { - httpErrAndLog("2 disco public keys must be supplied", http.StatusBadRequest) - return - } - - rs, err := e.relayServerOrInit() - if err != nil { - httpErrAndLog(err.Error(), http.StatusServiceUnavailable) - return - } - ep, err := rs.AllocateEndpoint(allocateEndpointReq.DiscoKeys[0], allocateEndpointReq.DiscoKeys[1]) - if err != nil { - var notReady udprelay.ErrServerNotReady - if errors.As(err, ¬Ready) { - w.Header().Set("Retry-After", fmt.Sprintf("%d", notReady.RetryAfter.Round(time.Second)/time.Second)) - httpErrAndLog(err.Error(), http.StatusServiceUnavailable) - return - } - httpErrAndLog(err.Error(), http.StatusInternalServerError) - return - } - err = json.NewEncoder(w).Encode(&ep) - if err != nil { - httpErrAndLog(err.Error(), http.StatusInternalServerError) - } -} diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index cc7f05f67..84158188e 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -9,6 +9,7 @@ import ( "tailscale.com/ipn" "tailscale.com/net/udprelay/endpoint" + "tailscale.com/tsd" "tailscale.com/types/key" "tailscale.com/types/ptr" ) @@ -108,9 +109,18 @@ func Test_extension_profileStateChanged(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + sys := tsd.NewSystem() + bus := sys.Bus.Get() e := &extension{ port: tt.fields.port, server: tt.fields.server, + bus: bus, + } + if e.port != nil { + // Entering profileStateChanged with a non-nil port requires + // bus init, which is called in profileStateChanged when + // transitioning port from nil to non-nil. 
+ e.initBusConnection() } e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) if tt.wantNilServer != (e.server == nil) { diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index d3754e540..8665a88c4 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -6957,7 +6957,7 @@ func (b *LocalBackend) DebugReSTUN() error { return nil } -func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.AddrPort] { +func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { return b.MagicConn().PeerRelays() } diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 2409aa1ae..0acc5a65f 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -699,7 +699,7 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { h.b.DebugForcePreferDERP(n) case "peer-relay-servers": servers := h.b.DebugPeerRelayServers().Slice() - slices.SortFunc(servers, func(a, b netip.AddrPort) int { + slices.SortFunc(servers, func(a, b netip.Addr) int { return a.Compare(b) }) err = json.NewEncoder(w).Encode(servers) diff --git a/net/udprelay/endpoint/endpoint.go b/net/udprelay/endpoint/endpoint.go index 2672a856b..0d2a14e96 100644 --- a/net/udprelay/endpoint/endpoint.go +++ b/net/udprelay/endpoint/endpoint.go @@ -7,11 +7,16 @@ package endpoint import ( "net/netip" + "time" "tailscale.com/tstime" "tailscale.com/types/key" ) +// ServerRetryAfter is the default +// [tailscale.com/net/udprelay.ErrServerNotReady.RetryAfter] value. +const ServerRetryAfter = time.Second * 3 + // ServerEndpoint contains details for an endpoint served by a // [tailscale.com/net/udprelay.Server]. type ServerEndpoint struct { @@ -21,6 +26,10 @@ type ServerEndpoint struct { // unique ServerEndpoint allocation. ServerDisco key.DiscoPublic + // ClientDisco are the Disco public keys of the relay participants permitted + // to handshake with this endpoint. + ClientDisco [2]key.DiscoPublic + // LamportID is unique and monotonically non-decreasing across // ServerEndpoint allocations for the lifetime of Server. It enables clients // to dedup and resolve allocation event order. Clients may race to allocate diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 7651bf295..c34a4b5f6 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -73,23 +73,7 @@ type Server struct { lamportID uint64 vniPool []uint32 // the pool of available VNIs byVNI map[uint32]*serverEndpoint - byDisco map[pairOfDiscoPubKeys]*serverEndpoint -} - -// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via -// newPairOfDiscoPubKeys to ensure lexicographical ordering. -type pairOfDiscoPubKeys [2]key.DiscoPublic - -func (p pairOfDiscoPubKeys) String() string { - return fmt.Sprintf("%s <=> %s", p[0].ShortString(), p[1].ShortString()) -} - -func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys { - pair := pairOfDiscoPubKeys([2]key.DiscoPublic{discoA, discoB}) - slices.SortFunc(pair[:], func(a, b key.DiscoPublic) int { - return a.Compare(b) - }) - return pair + byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. @@ -99,7 +83,7 @@ type serverEndpoint struct { // indexing of this array aligns with the following fields, e.g. // discoSharedSecrets[0] is the shared secret to use when sealing // Disco protocol messages for transmission towards discoPubKeys[0]. 
- discoPubKeys pairOfDiscoPubKeys + discoPubKeys key.SortedPairOfDiscoPublic discoSharedSecrets [2]key.DiscoShared handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg @@ -126,7 +110,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex if common.VNI != e.vni { return errors.New("mismatching VNI") } - if common.RemoteKey.Compare(e.discoPubKeys[otherSender]) != 0 { + if common.RemoteKey.Compare(e.discoPubKeys.Get()[otherSender]) != 0 { return errors.New("mismatching RemoteKey") } return nil @@ -152,7 +136,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex m := new(disco.BindUDPRelayEndpointChallenge) m.VNI = e.vni m.Generation = discoMsg.Generation - m.RemoteKey = e.discoPubKeys[otherSender] + m.RemoteKey = e.discoPubKeys.Get()[otherSender] rand.Read(e.challenge[senderIndex][:]) copy(m.Challenge[:], e.challenge[senderIndex][:]) reply := make([]byte, packet.GeneveFixedHeaderLength, 512) @@ -200,9 +184,9 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by sender := key.DiscoPublicFromRaw32(mem.B(senderRaw)) senderIndex := -1 switch { - case sender.Compare(e.discoPubKeys[0]) == 0: + case sender.Compare(e.discoPubKeys.Get()[0]) == 0: senderIndex = 0 - case sender.Compare(e.discoPubKeys[1]) == 0: + case sender.Compare(e.discoPubKeys.Get()[1]) == 0: senderIndex = 1 default: // unknown Disco public key @@ -291,12 +275,12 @@ func (e *serverEndpoint) isBound() bool { // which is useful to override in tests. func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) { s = &Server{ - logf: logger.WithPrefix(logf, "relayserver"), + logf: logf, disco: key.NewDisco(), bindLifetime: defaultBindLifetime, steadyStateLifetime: defaultSteadyStateLifetime, closeCh: make(chan struct{}), - byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint), + byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), byVNI: make(map[uint32]*serverEndpoint), } s.discoPublic = s.disco.Public() @@ -315,7 +299,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve } s.netChecker = &netcheck.Client{ NetMon: netMon, - Logf: logger.WithPrefix(logf, "relayserver: netcheck:"), + Logf: logger.WithPrefix(logf, "netcheck: "), SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { if addrPort.Addr().Is4() { return s.uc4.WriteToUDPAddrPort(b, addrPort) @@ -615,7 +599,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv if len(s.addrPorts) == 0 { if !s.addrDiscoveryOnce { - return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: 3 * time.Second} + return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} } return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known") } @@ -624,7 +608,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString()) } - pair := newPairOfDiscoPubKeys(discoA, discoB) + pair := key.NewSortedPairOfDiscoPublic(discoA, discoB) e, ok := s.byDisco[pair] if ok { // Return the existing allocation. Clients can resolve duplicate @@ -639,6 +623,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv // behaviors and endpoint state (bound or not). 
We might want to // consider storing them (maybe interning) in the [*serverEndpoint] // at allocation time. + ClientDisco: pair.Get(), AddrPorts: slices.Clone(s.addrPorts), VNI: e.vni, LamportID: e.lamportID, @@ -657,15 +642,17 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv lamportID: s.lamportID, allocatedAt: time.Now(), } - e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0]) - e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1]) + e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) + e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:] s.byDisco[pair] = e s.byVNI[e.vni] = e + s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString()) return endpoint.ServerEndpoint{ ServerDisco: s.discoPublic, + ClientDisco: pair.Get(), AddrPorts: slices.Clone(s.addrPorts), VNI: e.vni, LamportID: e.lamportID, diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 398a2c8a2..550914b96 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -165,7 +165,8 @@ type CapabilityVersion int // - 118: 2025-07-01: Client sends Hostinfo.StateEncrypted to report whether the state file is encrypted at rest (#15830) // - 119: 2025-07-10: Client uses Hostinfo.Location.Priority to prioritize one route over another. // - 120: 2025-07-15: Client understands peer relay disco messages, and implements peer client and relay server functions -const CurrentCapabilityVersion CapabilityVersion = 120 +// - 121: 2025-07-19: Client understands peer relay endpoint alloc with [disco.AllocateUDPRelayEndpointRequest] & [disco.AllocateUDPRelayEndpointResponse] +const CurrentCapabilityVersion CapabilityVersion = 121 // ID is an integer ID for a user, node, or login allocated by the // control plane. diff --git a/types/key/disco.go b/types/key/disco.go index 1013ce5bf..ce5f9b36f 100644 --- a/types/key/disco.go +++ b/types/key/disco.go @@ -73,6 +73,44 @@ func (k DiscoPrivate) Shared(p DiscoPublic) DiscoShared { return ret } +// SortedPairOfDiscoPublic is a lexicographically sorted container of two +// [DiscoPublic] keys. +type SortedPairOfDiscoPublic struct { + k [2]DiscoPublic +} + +// Get returns the underlying keys. +func (s SortedPairOfDiscoPublic) Get() [2]DiscoPublic { + return s.k +} + +// NewSortedPairOfDiscoPublic returns a SortedPairOfDiscoPublic from a and b. +func NewSortedPairOfDiscoPublic(a, b DiscoPublic) SortedPairOfDiscoPublic { + s := SortedPairOfDiscoPublic{} + if a.Compare(b) < 0 { + s.k[0] = a + s.k[1] = b + } else { + s.k[0] = b + s.k[1] = a + } + return s +} + +func (s SortedPairOfDiscoPublic) String() string { + return fmt.Sprintf("%s <=> %s", s.k[0].ShortString(), s.k[1].ShortString()) +} + +// Equal returns true if s and b are equal, otherwise it returns false. +func (s SortedPairOfDiscoPublic) Equal(b SortedPairOfDiscoPublic) bool { + for i := range s.k { + if s.k[i].Compare(b.k[i]) != 0 { + return false + } + } + return true +} + // DiscoPublic is the public portion of a DiscoPrivate. 
type DiscoPublic struct { k [32]byte diff --git a/types/key/disco_test.go b/types/key/disco_test.go index c62c13cbf..131fe350f 100644 --- a/types/key/disco_test.go +++ b/types/key/disco_test.go @@ -81,3 +81,21 @@ func TestDiscoShared(t *testing.T) { t.Error("k1.Shared(k2) != k2.Shared(k1)") } } + +func TestSortedPairOfDiscoPublic(t *testing.T) { + pubA := DiscoPublic{} + pubA.k[0] = 0x01 + pubB := DiscoPublic{} + pubB.k[0] = 0x02 + sortedInput := NewSortedPairOfDiscoPublic(pubA, pubB) + unsortedInput := NewSortedPairOfDiscoPublic(pubB, pubA) + if sortedInput.Get() != unsortedInput.Get() { + t.Fatal("sortedInput.Get() != unsortedInput.Get()") + } + if unsortedInput.Get()[0] != pubA { + t.Fatal("unsortedInput.Get()[0] != pubA") + } + if unsortedInput.Get()[1] != pubB { + t.Fatal("unsortedInput.Get()[1] != pubB") + } +} diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 48d5ef5a1..6381b0210 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -879,8 +879,14 @@ func (de *endpoint) setHeartbeatDisabled(v bool) { // discoverUDPRelayPathsLocked starts UDP relay path discovery. func (de *endpoint) discoverUDPRelayPathsLocked(now mono.Time) { - // TODO(jwhited): return early if there are no relay servers set, otherwise - // we spin up and down relayManager.runLoop unnecessarily. + if !de.c.hasPeerRelayServers.Load() { + // Changes in this value between its access and the logic following + // are fine, we will eventually do the "right" thing during future path + // discovery. The worst case is we suppress path discovery for the + // current cycle, or we unnecessarily call into [relayManager] and do + // some wasted work. + return + } de.lastUDPRelayPathDiscovery = now lastBest := de.bestAddr lastBestIsTrusted := mono.Now().Before(de.trustBestAddrUntil) @@ -2035,8 +2041,15 @@ func (de *endpoint) numStopAndReset() int64 { return atomic.LoadInt64(&de.numStopAndResetAtomic) } +// setDERPHome sets the provided regionID as home for de. Calls to setDERPHome +// must never run concurrent to [Conn.updateRelayServersSet], otherwise +// [candidatePeerRelay] DERP home changes may be missed from the perspective of +// [relayManager]. func (de *endpoint) setDERPHome(regionID uint16) { de.mu.Lock() defer de.mu.Unlock() de.derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) + if de.c.hasPeerRelayServers.Load() { + de.c.relayManager.handleDERPHomeChange(de.publicKey, regionID) + } } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index ad07003f7..ee0ee40ca 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -175,13 +175,15 @@ type Conn struct { // These [eventbus.Subscriber] fields are solely accessed by // consumeEventbusTopics once initialized. 
- pmSub *eventbus.Subscriber[portmapper.Mapping] - filterSub *eventbus.Subscriber[FilterUpdate] - nodeViewsSub *eventbus.Subscriber[NodeViewsUpdate] - nodeMutsSub *eventbus.Subscriber[NodeMutationsUpdate] - syncSub *eventbus.Subscriber[syncPoint] - syncPub *eventbus.Publisher[syncPoint] - subsDoneCh chan struct{} // closed when consumeEventbusTopics returns + pmSub *eventbus.Subscriber[portmapper.Mapping] + filterSub *eventbus.Subscriber[FilterUpdate] + nodeViewsSub *eventbus.Subscriber[NodeViewsUpdate] + nodeMutsSub *eventbus.Subscriber[NodeMutationsUpdate] + syncSub *eventbus.Subscriber[syncPoint] + syncPub *eventbus.Publisher[syncPoint] + allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] + allocRelayEndpointSub *eventbus.Subscriber[UDPRelayAllocResp] + subsDoneCh chan struct{} // closed when consumeEventbusTopics returns // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock @@ -271,6 +273,14 @@ type Conn struct { // captureHook, if non-nil, is the pcap logging callback when capturing. captureHook syncs.AtomicValue[packet.CaptureCallback] + // hasPeerRelayServers is whether [relayManager] is configured with at least + // one peer relay server via [relayManager.handleRelayServersSet]. It is + // only accessed by [Conn.updateRelayServersSet], [endpoint.setDERPHome], + // and [endpoint.discoverUDPRelayPathsLocked]. It exists to suppress + // calls into [relayManager] leading to wasted work involving channel + // operations and goroutine creation. + hasPeerRelayServers atomic.Bool + // discoPrivate is the private naclbox key used for active // discovery traffic. It is always present, and immutable. discoPrivate key.DiscoPrivate @@ -567,6 +577,36 @@ func (s syncPoint) Signal() { close(s) } +// UDPRelayAllocReq represents a [*disco.AllocateUDPRelayEndpointRequest] +// reception event. This is signaled over an [eventbus.Bus] from +// [magicsock.Conn] towards [relayserver.extension]. +type UDPRelayAllocReq struct { + // RxFromNodeKey is the unauthenticated (DERP server claimed src) node key + // of the transmitting party, noted at disco message reception time over + // DERP. This node key is unambiguously-aligned with RxFromDiscoKey being + // that the disco message is received over DERP. + RxFromNodeKey key.NodePublic + // RxFromDiscoKey is the disco key of the transmitting party, noted and + // authenticated at reception time. + RxFromDiscoKey key.DiscoPublic + // Message is the disco message. + Message *disco.AllocateUDPRelayEndpointRequest +} + +// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse] +// that is yet to be transmitted over DERP (or delivered locally if +// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from +// [relayserver.extension] towards [magicsock.Conn]. +type UDPRelayAllocResp struct { + // ReqRxFromNodeKey is copied from [UDPRelayAllocReq.RxFromNodeKey]. It + // enables peer lookup leading up to transmission over DERP. + ReqRxFromNodeKey key.NodePublic + // ReqRxFromDiscoKey is copied from [UDPRelayAllocReq.RxFromDiscoKey]. + ReqRxFromDiscoKey key.DiscoPublic + // Message is the disco message. + Message *disco.AllocateUDPRelayEndpointResponse +} + // newConn is the error-free, network-listening-side-effect-free based // of NewConn. Mostly for tests. 
func newConn(logf logger.Logf) *Conn { @@ -625,10 +665,40 @@ func (c *Conn) consumeEventbusTopics() { case syncPoint := <-c.syncSub.Events(): c.dlogf("magicsock: received sync point after reconfig") syncPoint.Signal() + case allocResp := <-c.allocRelayEndpointSub.Events(): + c.onUDPRelayAllocResp(allocResp) } } } +func (c *Conn) onUDPRelayAllocResp(allocResp UDPRelayAllocResp) { + c.mu.Lock() + defer c.mu.Unlock() + ep, ok := c.peerMap.endpointForNodeKey(allocResp.ReqRxFromNodeKey) + if !ok { + // If it's not a peer, it might be for self (we can peer relay through + // ourselves), in which case we want to hand it down to [relayManager] + // now versus taking a network round-trip through DERP. + selfNodeKey := c.publicKeyAtomic.Load() + if selfNodeKey.Compare(allocResp.ReqRxFromNodeKey) == 0 && + allocResp.ReqRxFromDiscoKey.Compare(c.discoPublic) == 0 { + c.relayManager.handleRxDiscoMsg(c, allocResp.Message, selfNodeKey, allocResp.ReqRxFromDiscoKey, epAddr{}) + } + return + } + disco := ep.disco.Load() + if disco == nil { + return + } + if disco.key.Compare(allocResp.ReqRxFromDiscoKey) != 0 { + return + } + ep.mu.Lock() + defer ep.mu.Unlock() + derpAddr := ep.derpAddr + go c.sendDiscoMessage(epAddr{ap: derpAddr}, ep.publicKey, disco.key, allocResp.Message, discoVerboseLog) +} + // Synchronize waits for all [eventbus] events published // prior to this call to be processed by the receiver. func (c *Conn) Synchronize() { @@ -670,6 +740,8 @@ func NewConn(opts Options) (*Conn, error) { c.nodeMutsSub = eventbus.Subscribe[NodeMutationsUpdate](c.eventClient) c.syncSub = eventbus.Subscribe[syncPoint](c.eventClient) c.syncPub = eventbus.Publish[syncPoint](c.eventClient) + c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](c.eventClient) + c.allocRelayEndpointSub = eventbus.Subscribe[UDPRelayAllocResp](c.eventClient) c.subsDoneCh = make(chan struct{}) go c.consumeEventbusTopics() } @@ -1847,6 +1919,24 @@ func (v *virtualNetworkID) get() uint32 { return v._vni & vniGetMask } +// sendDiscoAllocateUDPRelayEndpointRequest is primarily an alias for +// sendDiscoMessage, but it will alternatively send m over the eventbus if dst +// is a DERP IP:port, and dstKey is self. This saves a round-trip through DERP +// when we are attempting to allocate on a self (in-process) peer relay server. +func (c *Conn) sendDiscoAllocateUDPRelayEndpointRequest(dst epAddr, dstKey key.NodePublic, dstDisco key.DiscoPublic, allocReq *disco.AllocateUDPRelayEndpointRequest, logLevel discoLogLevel) (sent bool, err error) { + isDERP := dst.ap.Addr() == tailcfg.DerpMagicIPAddr + selfNodeKey := c.publicKeyAtomic.Load() + if isDERP && dstKey.Compare(selfNodeKey) == 0 { + c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ + RxFromNodeKey: selfNodeKey, + RxFromDiscoKey: c.discoPublic, + Message: allocReq, + }) + return true, nil + } + return c.sendDiscoMessage(dst, dstKey, dstDisco, allocReq, logLevel) +} + // sendDiscoMessage sends discovery message m to dstDisco at dst. // // If dst.ap is a DERP IP:port, then dstKey must be non-zero. 
@@ -2176,7 +2266,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake c.logf("[unexpected] %T packets should not come from a relay server with Geneve control bit set", dm) return } - c.relayManager.handleGeneveEncapDiscoMsg(c, challenge, di, src) + c.relayManager.handleRxDiscoMsg(c, challenge, key.NodePublic{}, di.discoKey, src) return } @@ -2201,7 +2291,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake // If it's an unknown TxID, and it's Geneve-encapsulated, then // make [relayManager] aware. It might be in the middle of probing // src. - c.relayManager.handleGeneveEncapDiscoMsg(c, dm, di, src) + c.relayManager.handleRxDiscoMsg(c, dm, key.NodePublic{}, di.discoKey, src) } case *disco.CallMeMaybe, *disco.CallMeMaybeVia: var via *disco.CallMeMaybeVia @@ -2276,7 +2366,95 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake len(cmm.MyNumber)) go ep.handleCallMeMaybe(cmm) } + case *disco.AllocateUDPRelayEndpointRequest, *disco.AllocateUDPRelayEndpointResponse: + var resp *disco.AllocateUDPRelayEndpointResponse + isResp := false + msgType := "AllocateUDPRelayEndpointRequest" + req, ok := dm.(*disco.AllocateUDPRelayEndpointRequest) + if ok { + metricRecvDiscoAllocUDPRelayEndpointRequest.Add(1) + } else { + metricRecvDiscoAllocUDPRelayEndpointResponse.Add(1) + resp = dm.(*disco.AllocateUDPRelayEndpointResponse) + msgType = "AllocateUDPRelayEndpointResponse" + isResp = true + } + if !isDERP { + // These messages should only come via DERP. + c.logf("[unexpected] %s packets should only come via DERP", msgType) + return + } + nodeKey := derpNodeSrc + ep, ok := c.peerMap.endpointForNodeKey(nodeKey) + if !ok { + c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString()) + return + } + epDisco := ep.disco.Load() + if epDisco == nil { + return + } + if epDisco.key != di.discoKey { + if isResp { + metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco.Add(1) + } else { + metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco.Add(1) + } + c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) + return + } + + if isResp { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints", + c.discoShort, epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + msgType, + len(resp.AddrPorts)) + c.relayManager.handleRxDiscoMsg(c, resp, nodeKey, di.discoKey, src) + return + } else if sender.Compare(req.ClientDisco[0]) != 0 && sender.Compare(req.ClientDisco[1]) != 0 { + // An allocation request must contain the sender's disco key in + // ClientDisco. One of the relay participants must be the sender. + c.logf("magicsock: disco: %s from %v; %v does not contain sender's disco key", + msgType, sender.ShortString(), derpNodeSrc.ShortString()) + return + } else { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, for %d<->%d", + c.discoShort, epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + msgType, + req.ClientDisco[0], req.ClientDisco[1]) + } + + if c.filt == nil { + return + } + // Binary search of peers is O(log n) while c.mu is held. + // TODO: We might be able to use ep.nodeAddr instead of all addresses, + // or we might be able to release c.mu before doing this work. Keep it + // simple and slow for now. c.peers.AsSlice is a copy. We may need to + // write our own binary search for a [views.Slice]. 
+ peerI, ok := slices.BinarySearchFunc(c.peers.AsSlice(), ep.nodeID, func(peer tailcfg.NodeView, target tailcfg.NodeID) int { + if peer.ID() < target { + return -1 + } else if peer.ID() > target { + return 1 + } + return 0 + }) + if !ok { + // unexpected + return + } + if !nodeHasCap(c.filt, c.peers.At(peerI), c.self, tailcfg.PeerCapabilityRelay) { + return + } + c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ + RxFromDiscoKey: sender, + RxFromNodeKey: nodeKey, + Message: req, + }) } return } @@ -2337,7 +2515,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpN // Geneve-encapsulated [disco.Ping] messages in the interest of // simplicity. It might be in the middle of probing src, so it must be // made aware. - c.relayManager.handleGeneveEncapDiscoMsg(c, dm, di, src) + c.relayManager.handleRxDiscoMsg(c, dm, key.NodePublic{}, di.discoKey, src) return } @@ -2687,7 +2865,7 @@ func (c *Conn) SetProbeUDPLifetime(v bool) { // capVerIsRelayCapable returns true if version is relay client and server // capable, otherwise it returns false. func capVerIsRelayCapable(version tailcfg.CapabilityVersion) bool { - return version >= 120 + return version >= 121 } // onFilterUpdate is called when a [FilterUpdate] is received over the @@ -2715,6 +2893,11 @@ func (c *Conn) onFilterUpdate(f FilterUpdate) { // peers are passed as args (vs c.mu-guarded fields) to enable callers to // release c.mu before calling as this is O(m * n) (we iterate all cap rules 'm' // in filt for every peer 'n'). +// +// Calls to updateRelayServersSet must never run concurrent to +// [endpoint.setDERPHome], otherwise [candidatePeerRelay] DERP home changes may +// be missed from the perspective of [relayManager]. +// // TODO: Optimize this so that it's not O(m * n). This might involve: // 1. Changes to [filter.Filter], e.g. adding a CapsWithValues() to check for // a given capability instead of building and returning a map of all of @@ -2722,69 +2905,75 @@ func (c *Conn) onFilterUpdate(f FilterUpdate) { // 2. Moving this work upstream into [nodeBackend] or similar, and publishing // the computed result over the eventbus instead. func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers views.Slice[tailcfg.NodeView]) { - relayServers := make(set.Set[netip.AddrPort]) + relayServers := make(set.Set[candidatePeerRelay]) nodes := append(peers.AsSlice(), self) for _, maybeCandidate := range nodes { - peerAPI := peerAPIIfCandidateRelayServer(filt, self, maybeCandidate) - if peerAPI.IsValid() { - relayServers.Add(peerAPI) - } - } - c.relayManager.handleRelayServersSet(relayServers) -} - -// peerAPIIfCandidateRelayServer returns the peer API address of maybeCandidate -// if it is considered to be a candidate relay server upon evaluation against -// filt and self, otherwise it returns a zero value. self and maybeCandidate -// may be equal. -func peerAPIIfCandidateRelayServer(filt *filter.Filter, self, maybeCandidate tailcfg.NodeView) netip.AddrPort { - if filt == nil || - !self.Valid() || - !maybeCandidate.Valid() || - !maybeCandidate.Hostinfo().Valid() { - return netip.AddrPort{} - } - if maybeCandidate.ID() != self.ID() && !capVerIsRelayCapable(maybeCandidate.Cap()) { - // If maybeCandidate's [tailcfg.CapabilityVersion] is not relay-capable, - // we skip it. If maybeCandidate happens to be self, then this check is - // unnecessary as self is always capable from this point (the statically - // compiled [tailcfg.CurrentCapabilityVersion]) forward. 
- return netip.AddrPort{} - } - for _, maybeCandidatePrefix := range maybeCandidate.Addresses().All() { - if !maybeCandidatePrefix.IsSingleIP() { + if maybeCandidate.ID() != self.ID() && !capVerIsRelayCapable(maybeCandidate.Cap()) { + // If maybeCandidate's [tailcfg.CapabilityVersion] is not relay-capable, + // we skip it. If maybeCandidate happens to be self, then this check is + // unnecessary as self is always capable from this point (the statically + // compiled [tailcfg.CurrentCapabilityVersion]) forward. continue } - maybeCandidateAddr := maybeCandidatePrefix.Addr() - for _, selfPrefix := range self.Addresses().All() { - if !selfPrefix.IsSingleIP() { + if !nodeHasCap(filt, maybeCandidate, self, tailcfg.PeerCapabilityRelayTarget) { + continue + } + relayServers.Add(candidatePeerRelay{ + nodeKey: maybeCandidate.Key(), + discoKey: maybeCandidate.DiscoKey(), + derpHomeRegionID: uint16(maybeCandidate.HomeDERP()), + }) + } + c.relayManager.handleRelayServersSet(relayServers) + if len(relayServers) > 0 { + c.hasPeerRelayServers.Store(true) + } else { + c.hasPeerRelayServers.Store(false) + } +} + +// nodeHasCap returns true if src has cap on dst, otherwise it returns false. +func nodeHasCap(filt *filter.Filter, src, dst tailcfg.NodeView, cap tailcfg.PeerCapability) bool { + if filt == nil || + !src.Valid() || + !dst.Valid() { + return false + } + for _, srcPrefix := range src.Addresses().All() { + if !srcPrefix.IsSingleIP() { + continue + } + srcAddr := srcPrefix.Addr() + for _, dstPrefix := range dst.Addresses().All() { + if !dstPrefix.IsSingleIP() { continue } - selfAddr := selfPrefix.Addr() - if selfAddr.BitLen() == maybeCandidateAddr.BitLen() { // same address family - if filt.CapsWithValues(maybeCandidateAddr, selfAddr).HasCapability(tailcfg.PeerCapabilityRelayTarget) { - for _, s := range maybeCandidate.Hostinfo().Services().All() { - if maybeCandidateAddr.Is4() && s.Proto == tailcfg.PeerAPI4 || - maybeCandidateAddr.Is6() && s.Proto == tailcfg.PeerAPI6 { - return netip.AddrPortFrom(maybeCandidateAddr, s.Port) - } - } - return netip.AddrPort{} // no peerAPI - } else { - // [nodeBackend.peerCapsLocked] only returns/considers the - // [tailcfg.PeerCapMap] between the passed src and the - // _first_ host (/32 or /128) address for self. We are - // consistent with that behavior here. If self and - // maybeCandidate host addresses are of the same address - // family they either have the capability or not. We do not - // check against additional host addresses of the same - // address family. - return netip.AddrPort{} - } + dstAddr := dstPrefix.Addr() + if dstAddr.BitLen() == srcAddr.BitLen() { // same address family + // [nodeBackend.peerCapsLocked] only returns/considers the + // [tailcfg.PeerCapMap] between the passed src and the _first_ + // host (/32 or /128) address for self. We are consistent with + // that behavior here. If src and dst host addresses are of the + // same address family they either have the capability or not. + // We do not check against additional host addresses of the same + // address family. + return filt.CapsWithValues(srcAddr, dstAddr).HasCapability(cap) } } } - return netip.AddrPort{} + return false +} + +// candidatePeerRelay represents the identifiers and DERP home region ID for a +// peer relay server. 
+type candidatePeerRelay struct { + nodeKey key.NodePublic + discoKey key.DiscoPublic + derpHomeRegionID uint16 +} + +func (c *candidatePeerRelay) isValid() bool { + return !c.nodeKey.IsZero() && !c.discoKey.IsZero() } // onNodeViewsUpdate is called when a [NodeViewsUpdate] is received over the @@ -3792,18 +3981,22 @@ var ( metricRecvDiscoBadKey = clientmetric.NewCounter("magicsock_disco_recv_bad_key") metricRecvDiscoBadParse = clientmetric.NewCounter("magicsock_disco_recv_bad_parse") - metricRecvDiscoUDP = clientmetric.NewCounter("magicsock_disco_recv_udp") - metricRecvDiscoDERP = clientmetric.NewCounter("magicsock_disco_recv_derp") - metricRecvDiscoPing = clientmetric.NewCounter("magicsock_disco_recv_ping") - metricRecvDiscoPong = clientmetric.NewCounter("magicsock_disco_recv_pong") - metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe") - metricRecvDiscoCallMeMaybeVia = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia") - metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node") - metricRecvDiscoCallMeMaybeViaBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_node") - metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco") - metricRecvDiscoCallMeMaybeViaBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_disco") - metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") - metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") + metricRecvDiscoUDP = clientmetric.NewCounter("magicsock_disco_recv_udp") + metricRecvDiscoDERP = clientmetric.NewCounter("magicsock_disco_recv_derp") + metricRecvDiscoPing = clientmetric.NewCounter("magicsock_disco_recv_ping") + metricRecvDiscoPong = clientmetric.NewCounter("magicsock_disco_recv_pong") + metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe") + metricRecvDiscoCallMeMaybeVia = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia") + metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node") + metricRecvDiscoCallMeMaybeViaBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_node") + metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco") + metricRecvDiscoCallMeMaybeViaBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_disco") + metricRecvDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request") + metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request_bad_disco") + metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco") + metricRecvDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response") + metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") + metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") // metricDERPHomeChange is how many times our DERP home region DI has // changed from non-zero to a different non-zero. 
metricDERPHomeChange = clientmetric.NewCounter("derp_home_change") @@ -3985,6 +4178,22 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) { } // PeerRelays returns the current set of candidate peer relays. -func (c *Conn) PeerRelays() set.Set[netip.AddrPort] { - return c.relayManager.getServers() +func (c *Conn) PeerRelays() set.Set[netip.Addr] { + candidatePeerRelays := c.relayManager.getServers() + servers := make(set.Set[netip.Addr], len(candidatePeerRelays)) + c.mu.Lock() + defer c.mu.Unlock() + for relay := range candidatePeerRelays { + pi, ok := c.peerMap.byNodeKey[relay.nodeKey] + if !ok { + if c.self.Key().Compare(relay.nodeKey) == 0 { + if c.self.Addresses().Len() > 0 { + servers.Add(c.self.Addresses().At(0).Addr()) + } + } + continue + } + servers.Add(pi.ep.nodeAddr) + } + return servers } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 1d76e6c59..8a09df27d 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -19,7 +19,6 @@ import ( "net/http/httptest" "net/netip" "os" - "reflect" "runtime" "strconv" "strings" @@ -64,6 +63,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/eventbus" "tailscale.com/util/must" @@ -3384,258 +3384,6 @@ func Test_virtualNetworkID(t *testing.T) { } } -func Test_peerAPIIfCandidateRelayServer(t *testing.T) { - hostInfo := &tailcfg.Hostinfo{ - Services: []tailcfg.Service{ - { - Proto: tailcfg.PeerAPI4, - Port: 4, - }, - { - Proto: tailcfg.PeerAPI6, - Port: 6, - }, - }, - } - - selfOnlyIPv4 := &tailcfg.Node{ - ID: 1, - // Intentionally set a value < 120 to verify the statically compiled - // [tailcfg.CurrentCapabilityVersion] is used when self is - // maybeCandidate. 
- Cap: 119, - Addresses: []netip.Prefix{ - netip.MustParsePrefix("1.1.1.1/32"), - }, - Hostinfo: hostInfo.View(), - } - selfOnlyIPv6 := selfOnlyIPv4.Clone() - selfOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::1/128") - - peerOnlyIPv4 := &tailcfg.Node{ - ID: 2, - Cap: 120, - Addresses: []netip.Prefix{ - netip.MustParsePrefix("2.2.2.2/32"), - }, - Hostinfo: hostInfo.View(), - } - - peerOnlyIPv4NotCapable := peerOnlyIPv4.Clone() - peerOnlyIPv4NotCapable.Cap = 119 - - peerOnlyIPv6 := peerOnlyIPv4.Clone() - peerOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::2/128") - - peerOnlyIPv4ZeroCapVer := peerOnlyIPv4.Clone() - peerOnlyIPv4ZeroCapVer.Cap = 0 - - peerOnlyIPv4NilHostinfo := peerOnlyIPv4.Clone() - peerOnlyIPv4NilHostinfo.Hostinfo = tailcfg.HostinfoView{} - - tests := []struct { - name string - filt *filter.Filter - self tailcfg.NodeView - maybeCandidate tailcfg.NodeView - want netip.AddrPort - }{ - { - name: "match v4", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv4.View(), - maybeCandidate: peerOnlyIPv4.View(), - want: netip.MustParseAddrPort("2.2.2.2:4"), - }, - { - name: "match v4 self", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{selfOnlyIPv4.Addresses[0]}, - Caps: []filtertype.CapMatch{ - { - Dst: selfOnlyIPv4.Addresses[0], - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv4.View(), - maybeCandidate: selfOnlyIPv4.View(), - want: netip.AddrPortFrom(selfOnlyIPv4.Addresses[0].Addr(), 4), - }, - { - name: "match v6", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("::1/128"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv6.View(), - maybeCandidate: peerOnlyIPv6.View(), - want: netip.MustParseAddrPort("[::2]:6"), - }, - { - name: "match v6 self", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{selfOnlyIPv6.Addresses[0]}, - Caps: []filtertype.CapMatch{ - { - Dst: selfOnlyIPv6.Addresses[0], - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv6.View(), - maybeCandidate: selfOnlyIPv6.View(), - want: netip.AddrPortFrom(selfOnlyIPv6.Addresses[0].Addr(), 6), - }, - { - name: "peer incapable", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv4.View(), - maybeCandidate: peerOnlyIPv4NotCapable.View(), - }, - { - name: "no match dst", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("::3/128"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv6.View(), - maybeCandidate: peerOnlyIPv6.View(), - }, - { - name: "no match peer cap", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("::1/128"), - Cap: 
tailcfg.PeerCapabilityIngress, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv6.View(), - maybeCandidate: peerOnlyIPv6.View(), - }, - { - name: "cap ver not relay capable", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: peerOnlyIPv4.View(), - maybeCandidate: peerOnlyIPv4ZeroCapVer.View(), - }, - { - name: "nil filt", - filt: nil, - self: selfOnlyIPv4.View(), - maybeCandidate: peerOnlyIPv4.View(), - }, - { - name: "nil self", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: tailcfg.NodeView{}, - maybeCandidate: peerOnlyIPv4.View(), - }, - { - name: "nil peer", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv4.View(), - maybeCandidate: tailcfg.NodeView{}, - }, - { - name: "nil peer hostinfo", - filt: filter.New([]filtertype.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, - Caps: []filtertype.CapMatch{ - { - Dst: netip.MustParsePrefix("1.1.1.1/32"), - Cap: tailcfg.PeerCapabilityRelayTarget, - }, - }, - }, - }, nil, nil, nil, nil, nil), - self: selfOnlyIPv4.View(), - maybeCandidate: peerOnlyIPv4NilHostinfo.View(), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := peerAPIIfCandidateRelayServer(tt.filt, tt.self, tt.maybeCandidate); !reflect.DeepEqual(got, tt.want) { - t.Errorf("peerAPIIfCandidateRelayServer() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_looksLikeInitiationMsg(t *testing.T) { initMsg := make([]byte, device.MessageInitiationSize) binary.BigEndian.PutUint32(initMsg, device.MessageInitiationType) @@ -3675,3 +3423,268 @@ func Test_looksLikeInitiationMsg(t *testing.T) { }) } } + +func Test_nodeHasCap(t *testing.T) { + nodeAOnlyIPv4 := &tailcfg.Node{ + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + } + nodeBOnlyIPv6 := nodeAOnlyIPv4.Clone() + nodeBOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::1/128") + + nodeCOnlyIPv4 := &tailcfg.Node{ + ID: 2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("2.2.2.2/32"), + }, + } + nodeDOnlyIPv6 := nodeCOnlyIPv4.Clone() + nodeDOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::2/128") + + tests := []struct { + name string + filt *filter.Filter + src tailcfg.NodeView + dst tailcfg.NodeView + cap tailcfg.PeerCapability + want bool + }{ + { + name: "match v4", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeCOnlyIPv4.View(), + dst: nodeAOnlyIPv4.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: true, + }, + { + name: "match v6", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: 
netip.MustParsePrefix("::1/128"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: true, + }, + { + name: "no match CapMatch Dst", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("::3/128"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "no match peer cap", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("::1/128"), + Cap: tailcfg.PeerCapabilityIngress, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "nil src", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: tailcfg.NodeView{}, + dst: nodeAOnlyIPv4.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "nil dst", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeCOnlyIPv4.View(), + dst: tailcfg.NodeView{}, + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nodeHasCap(tt.filt, tt.src, tt.dst, tt.cap); got != tt.want { + t.Errorf("nodeHasCap() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConn_updateRelayServersSet(t *testing.T) { + peerNodeCandidateRelay := &tailcfg.Node{ + Cap: 121, + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + HomeDERP: 1, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + peerNodeNotCandidateRelayCapVer := &tailcfg.Node{ + Cap: 120, // intentionally lower to fail capVer check + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + HomeDERP: 1, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + selfNode := &tailcfg.Node{ + Cap: 120, // intentionally lower than capVerIsRelayCapable to verify self check + ID: 2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("2.2.2.2/32"), + }, + HomeDERP: 2, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + tests := []struct { + name string + filt *filter.Filter + self tailcfg.NodeView + peers views.Slice[tailcfg.NodeView] + wantRelayServers set.Set[candidatePeerRelay] + }{ + { + name: "candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: views.SliceOf([]tailcfg.NodeView{peerNodeCandidateRelay.View()}), + wantRelayServers: 
set.SetOf([]candidatePeerRelay{ + { + nodeKey: peerNodeCandidateRelay.Key, + discoKey: peerNodeCandidateRelay.DiscoKey, + derpHomeRegionID: 1, + }, + }), + }, + { + name: "self candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: selfNode.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: views.SliceOf([]tailcfg.NodeView{selfNode.View()}), + wantRelayServers: set.SetOf([]candidatePeerRelay{ + { + nodeKey: selfNode.Key, + discoKey: selfNode.DiscoKey, + derpHomeRegionID: 2, + }, + }), + }, + { + name: "no candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeNotCandidateRelayCapVer.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: views.SliceOf([]tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()}), + wantRelayServers: make(set.Set[candidatePeerRelay]), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Conn{} + c.updateRelayServersSet(tt.filt, tt.self, tt.peers) + got := c.relayManager.getServers() + if !got.Equal(tt.wantRelayServers) { + t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers) + } + if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() { + t.Fatalf("c.hasPeerRelayServers: %v != wantRelayServers: %v", c.hasPeerRelayServers.Load(), tt.wantRelayServers) + } + }) + } +} diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index d7acf80b5..ad8c5fc76 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -4,23 +4,18 @@ package magicsock import ( - "bytes" "context" - "encoding/json" "errors" - "fmt" - "io" - "net/http" "net/netip" - "strconv" "sync" "time" "tailscale.com/disco" "tailscale.com/net/stun" udprelay "tailscale.com/net/udprelay/endpoint" + "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/key" - "tailscale.com/util/httpm" "tailscale.com/util/set" ) @@ -38,26 +33,28 @@ type relayManager struct { // =================================================================== // The following fields are owned by a single goroutine, runLoop(). 
- serversByAddrPort map[netip.AddrPort]key.DiscoPublic - serversByDisco map[key.DiscoPublic]netip.AddrPort - allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork - handshakeWorkByEndpointByServerDisco map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork - handshakeWorkByServerDiscoVNI map[serverDiscoVNI]*relayHandshakeWork - handshakeWorkAwaitingPong map[*relayHandshakeWork]addrPortVNI - addrPortVNIToHandshakeWork map[addrPortVNI]*relayHandshakeWork - handshakeGeneration uint32 + serversByNodeKey map[key.NodePublic]candidatePeerRelay + allocWorkByCandidatePeerRelayByEndpoint map[*endpoint]map[candidatePeerRelay]*relayEndpointAllocWork + allocWorkByDiscoKeysByServerNodeKey map[key.NodePublic]map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork + handshakeWorkByServerDiscoByEndpoint map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork + handshakeWorkByServerDiscoVNI map[serverDiscoVNI]*relayHandshakeWork + handshakeWorkAwaitingPong map[*relayHandshakeWork]addrPortVNI + addrPortVNIToHandshakeWork map[addrPortVNI]*relayHandshakeWork + handshakeGeneration uint32 + allocGeneration uint32 // =================================================================== // The following chan fields serve event inputs to a single goroutine, // runLoop(). - startDiscoveryCh chan endpointWithLastBest - allocateWorkDoneCh chan relayEndpointAllocWorkDoneEvent - handshakeWorkDoneCh chan relayEndpointHandshakeWorkDoneEvent - cancelWorkCh chan *endpoint - newServerEndpointCh chan newRelayServerEndpointEvent - rxHandshakeDiscoMsgCh chan relayHandshakeDiscoMsgEvent - serversCh chan set.Set[netip.AddrPort] - getServersCh chan chan set.Set[netip.AddrPort] + startDiscoveryCh chan endpointWithLastBest + allocateWorkDoneCh chan relayEndpointAllocWorkDoneEvent + handshakeWorkDoneCh chan relayEndpointHandshakeWorkDoneEvent + cancelWorkCh chan *endpoint + newServerEndpointCh chan newRelayServerEndpointEvent + rxDiscoMsgCh chan relayDiscoMsgEvent + serversCh chan set.Set[candidatePeerRelay] + getServersCh chan chan set.Set[candidatePeerRelay] + derpHomeChangeCh chan derpHomeChangeEvent discoInfoMu sync.Mutex // guards the following field discoInfoByServerDisco map[key.DiscoPublic]*relayHandshakeDiscoInfo @@ -86,7 +83,7 @@ type relayHandshakeWork struct { // relayManager.handshakeWorkDoneCh if runLoop() can receive it. runLoop() // must select{} read on doneCh to prevent deadlock when attempting to write // to rxDiscoMsgCh. - rxDiscoMsgCh chan relayHandshakeDiscoMsgEvent + rxDiscoMsgCh chan relayDiscoMsgEvent doneCh chan relayEndpointHandshakeWorkDoneEvent ctx context.Context @@ -100,14 +97,15 @@ type relayHandshakeWork struct { type newRelayServerEndpointEvent struct { wlb endpointWithLastBest se udprelay.ServerEndpoint - server netip.AddrPort // zero value if learned via [disco.CallMeMaybeVia] + server candidatePeerRelay // zero value if learned via [disco.CallMeMaybeVia] } // relayEndpointAllocWorkDoneEvent indicates relay server endpoint allocation // work for an [*endpoint] has completed. This structure is immutable once // initialized. 
type relayEndpointAllocWorkDoneEvent struct { - work *relayEndpointAllocWork + work *relayEndpointAllocWork + allocated udprelay.ServerEndpoint // !allocated.ServerDisco.IsZero() if allocation succeeded } // relayEndpointHandshakeWorkDoneEvent indicates relay server endpoint handshake @@ -122,18 +120,42 @@ type relayEndpointHandshakeWorkDoneEvent struct { // hasActiveWorkRunLoop returns true if there is outstanding allocation or // handshaking work for any endpoint, otherwise it returns false. func (r *relayManager) hasActiveWorkRunLoop() bool { - return len(r.allocWorkByEndpoint) > 0 || len(r.handshakeWorkByEndpointByServerDisco) > 0 + return len(r.allocWorkByCandidatePeerRelayByEndpoint) > 0 || len(r.handshakeWorkByServerDiscoByEndpoint) > 0 } // hasActiveWorkForEndpointRunLoop returns true if there is outstanding // allocation or handshaking work for the provided endpoint, otherwise it // returns false. func (r *relayManager) hasActiveWorkForEndpointRunLoop(ep *endpoint) bool { - _, handshakeWork := r.handshakeWorkByEndpointByServerDisco[ep] - _, allocWork := r.allocWorkByEndpoint[ep] + _, handshakeWork := r.handshakeWorkByServerDiscoByEndpoint[ep] + _, allocWork := r.allocWorkByCandidatePeerRelayByEndpoint[ep] return handshakeWork || allocWork } +// derpHomeChangeEvent represents a change in the DERP home region for the +// node identified by nodeKey. This structure is immutable once initialized. +type derpHomeChangeEvent struct { + nodeKey key.NodePublic + regionID uint16 +} + +// handleDERPHomeChange handles a DERP home change event for nodeKey and +// regionID. +func (r *relayManager) handleDERPHomeChange(nodeKey key.NodePublic, regionID uint16) { + relayManagerInputEvent(r, nil, &r.derpHomeChangeCh, derpHomeChangeEvent{ + nodeKey: nodeKey, + regionID: regionID, + }) +} + +func (r *relayManager) handleDERPHomeChangeRunLoop(event derpHomeChangeEvent) { + c, ok := r.serversByNodeKey[event.nodeKey] + if ok { + c.derpHomeRegionID = event.regionID + r.serversByNodeKey[event.nodeKey] = c + } +} + // runLoop is a form of event loop. It ensures exclusive access to most of // [relayManager] state. func (r *relayManager) runLoop() { @@ -151,13 +173,7 @@ func (r *relayManager) runLoop() { return } case done := <-r.allocateWorkDoneCh: - work, ok := r.allocWorkByEndpoint[done.work.ep] - if ok && work == done.work { - // Verify the work in the map is the same as the one that we're - // cleaning up. New events on r.startDiscoveryCh can - // overwrite pre-existing keys. 
- delete(r.allocWorkByEndpoint, done.work.ep) - } + r.handleAllocWorkDoneRunLoop(done) if !r.hasActiveWorkRunLoop() { return } @@ -176,8 +192,8 @@ func (r *relayManager) runLoop() { if !r.hasActiveWorkRunLoop() { return } - case discoMsgEvent := <-r.rxHandshakeDiscoMsgCh: - r.handleRxHandshakeDiscoMsgRunLoop(discoMsgEvent) + case discoMsgEvent := <-r.rxDiscoMsgCh: + r.handleRxDiscoMsgRunLoop(discoMsgEvent) if !r.hasActiveWorkRunLoop() { return } @@ -191,69 +207,77 @@ func (r *relayManager) runLoop() { if !r.hasActiveWorkRunLoop() { return } + case derpHomeChange := <-r.derpHomeChangeCh: + r.handleDERPHomeChangeRunLoop(derpHomeChange) + if !r.hasActiveWorkRunLoop() { + return + } } } } -func (r *relayManager) handleGetServersRunLoop(getServersCh chan set.Set[netip.AddrPort]) { - servers := make(set.Set[netip.AddrPort], len(r.serversByAddrPort)) - for server := range r.serversByAddrPort { - servers.Add(server) +func (r *relayManager) handleGetServersRunLoop(getServersCh chan set.Set[candidatePeerRelay]) { + servers := make(set.Set[candidatePeerRelay], len(r.serversByNodeKey)) + for _, v := range r.serversByNodeKey { + servers.Add(v) } getServersCh <- servers } -func (r *relayManager) getServers() set.Set[netip.AddrPort] { - ch := make(chan set.Set[netip.AddrPort]) +func (r *relayManager) getServers() set.Set[candidatePeerRelay] { + ch := make(chan set.Set[candidatePeerRelay]) relayManagerInputEvent(r, nil, &r.getServersCh, ch) return <-ch } -func (r *relayManager) handleServersUpdateRunLoop(update set.Set[netip.AddrPort]) { - for k, v := range r.serversByAddrPort { - if !update.Contains(k) { - delete(r.serversByAddrPort, k) - delete(r.serversByDisco, v) +func (r *relayManager) handleServersUpdateRunLoop(update set.Set[candidatePeerRelay]) { + for _, v := range r.serversByNodeKey { + if !update.Contains(v) { + delete(r.serversByNodeKey, v.nodeKey) } } for _, v := range update.Slice() { - _, ok := r.serversByAddrPort[v] - if ok { - // don't zero known disco keys - continue - } - r.serversByAddrPort[v] = key.DiscoPublic{} + r.serversByNodeKey[v.nodeKey] = v } } -type relayHandshakeDiscoMsgEvent struct { - conn *Conn // for access to [Conn] if there is no associated [relayHandshakeWork] - msg disco.Message - disco key.DiscoPublic - from netip.AddrPort - vni uint32 - at time.Time +type relayDiscoMsgEvent struct { + conn *Conn // for access to [Conn] if there is no associated [relayHandshakeWork] + msg disco.Message + relayServerNodeKey key.NodePublic // nonzero if msg is a [*disco.AllocateUDPRelayEndpointResponse] + disco key.DiscoPublic + from netip.AddrPort + vni uint32 + at time.Time } // relayEndpointAllocWork serves to track in-progress relay endpoint allocation // for an [*endpoint]. This structure is immutable once initialized. type relayEndpointAllocWork struct { - // ep is the [*endpoint] associated with the work - ep *endpoint - // cancel() will signal all associated goroutines to return + wlb endpointWithLastBest + discoKeys key.SortedPairOfDiscoPublic + candidatePeerRelay candidatePeerRelay + + // allocateServerEndpoint() always writes to doneCh (len 1) when it + // returns. It may end up writing the same event afterward to + // [relayManager.allocateWorkDoneCh] if runLoop() can receive it. runLoop() + // must select{} read on doneCh to prevent deadlock when attempting to write + // to rxDiscoMsgCh. 
+ rxDiscoMsgCh chan *disco.AllocateUDPRelayEndpointResponse + doneCh chan relayEndpointAllocWorkDoneEvent + + ctx context.Context cancel context.CancelFunc - // wg.Wait() will return once all associated goroutines have returned - wg *sync.WaitGroup } // init initializes [relayManager] if it is not already initialized. func (r *relayManager) init() { r.initOnce.Do(func() { r.discoInfoByServerDisco = make(map[key.DiscoPublic]*relayHandshakeDiscoInfo) - r.serversByDisco = make(map[key.DiscoPublic]netip.AddrPort) - r.serversByAddrPort = make(map[netip.AddrPort]key.DiscoPublic) - r.allocWorkByEndpoint = make(map[*endpoint]*relayEndpointAllocWork) - r.handshakeWorkByEndpointByServerDisco = make(map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork) + r.serversByNodeKey = make(map[key.NodePublic]candidatePeerRelay) + r.allocWorkByCandidatePeerRelayByEndpoint = make(map[*endpoint]map[candidatePeerRelay]*relayEndpointAllocWork) + r.allocWorkByDiscoKeysByServerNodeKey = make(map[key.NodePublic]map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork) + r.handshakeWorkByServerDiscoByEndpoint = make(map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork) r.handshakeWorkByServerDiscoVNI = make(map[serverDiscoVNI]*relayHandshakeWork) r.handshakeWorkAwaitingPong = make(map[*relayHandshakeWork]addrPortVNI) r.addrPortVNIToHandshakeWork = make(map[addrPortVNI]*relayHandshakeWork) @@ -262,9 +286,10 @@ func (r *relayManager) init() { r.handshakeWorkDoneCh = make(chan relayEndpointHandshakeWorkDoneEvent) r.cancelWorkCh = make(chan *endpoint) r.newServerEndpointCh = make(chan newRelayServerEndpointEvent) - r.rxHandshakeDiscoMsgCh = make(chan relayHandshakeDiscoMsgEvent) - r.serversCh = make(chan set.Set[netip.AddrPort]) - r.getServersCh = make(chan chan set.Set[netip.AddrPort]) + r.rxDiscoMsgCh = make(chan relayDiscoMsgEvent) + r.serversCh = make(chan set.Set[candidatePeerRelay]) + r.getServersCh = make(chan chan set.Set[candidatePeerRelay]) + r.derpHomeChangeCh = make(chan derpHomeChangeEvent) r.runLoopStoppedCh = make(chan struct{}, 1) r.runLoopStoppedCh <- struct{}{} }) @@ -330,6 +355,7 @@ func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok func (r *relayManager) handleCallMeMaybeVia(ep *endpoint, lastBest addrQuality, lastBestIsTrusted bool, dm *disco.CallMeMaybeVia) { se := udprelay.ServerEndpoint{ ServerDisco: dm.ServerDisco, + ClientDisco: dm.ClientDisco, LamportID: dm.LamportID, AddrPorts: dm.AddrPorts, VNI: dm.VNI, @@ -346,14 +372,25 @@ func (r *relayManager) handleCallMeMaybeVia(ep *endpoint, lastBest addrQuality, }) } -// handleGeneveEncapDiscoMsg handles reception of Geneve-encapsulated disco -// messages. -func (r *relayManager) handleGeneveEncapDiscoMsg(conn *Conn, dm disco.Message, di *discoInfo, src epAddr) { - relayManagerInputEvent(r, nil, &r.rxHandshakeDiscoMsgCh, relayHandshakeDiscoMsgEvent{conn: conn, msg: dm, disco: di.discoKey, from: src.ap, vni: src.vni.get(), at: time.Now()}) +// handleRxDiscoMsg handles reception of disco messages that [relayManager] +// may be interested in. This includes all Geneve-encapsulated disco messages +// and [*disco.AllocateUDPRelayEndpointResponse]. If dm is a +// [*disco.AllocateUDPRelayEndpointResponse] then relayServerNodeKey must be +// nonzero. 
+func (r *relayManager) handleRxDiscoMsg(conn *Conn, dm disco.Message, relayServerNodeKey key.NodePublic, discoKey key.DiscoPublic, src epAddr) {
+	relayManagerInputEvent(r, nil, &r.rxDiscoMsgCh, relayDiscoMsgEvent{
+		conn:               conn,
+		msg:                dm,
+		relayServerNodeKey: relayServerNodeKey,
+		disco:              discoKey,
+		from:               src.ap,
+		vni:                src.vni.get(),
+		at:                 time.Now(),
+	})
 }
 
 // handleRelayServersSet handles an update of the complete relay server set.
-func (r *relayManager) handleRelayServersSet(servers set.Set[netip.AddrPort]) {
+func (r *relayManager) handleRelayServersSet(servers set.Set[candidatePeerRelay]) {
 	relayManagerInputEvent(r, nil, &r.serversCh, servers)
 }
 
@@ -396,7 +433,11 @@ type endpointWithLastBest struct {
 // startUDPRelayPathDiscoveryFor starts UDP relay path discovery for ep on all
 // known relay servers if ep has no in-progress work.
 func (r *relayManager) startUDPRelayPathDiscoveryFor(ep *endpoint, lastBest addrQuality, lastBestIsTrusted bool) {
-	relayManagerInputEvent(r, nil, &r.startDiscoveryCh, endpointWithLastBest{ep, lastBest, lastBestIsTrusted})
+	relayManagerInputEvent(r, nil, &r.startDiscoveryCh, endpointWithLastBest{
+		ep:                ep,
+		lastBest:          lastBest,
+		lastBestIsTrusted: lastBestIsTrusted,
+	})
 }
 
 // stopWork stops all outstanding allocation & handshaking work for 'ep'.
@@ -407,13 +448,15 @@ func (r *relayManager) stopWork(ep *endpoint) {
 // stopWorkRunLoop cancels & clears outstanding allocation and handshaking
 // work for 'ep'.
 func (r *relayManager) stopWorkRunLoop(ep *endpoint) {
-	allocWork, ok := r.allocWorkByEndpoint[ep]
+	byDiscoKeys, ok := r.allocWorkByCandidatePeerRelayByEndpoint[ep]
 	if ok {
-		allocWork.cancel()
-		allocWork.wg.Wait()
-		delete(r.allocWorkByEndpoint, ep)
+		for _, work := range byDiscoKeys {
+			work.cancel()
+			done := <-work.doneCh
+			r.handleAllocWorkDoneRunLoop(done)
+		}
 	}
-	byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[ep]
+	byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[ep]
 	if ok {
 		for _, handshakeWork := range byServerDisco {
 			handshakeWork.cancel()
@@ -430,13 +473,33 @@ type addrPortVNI struct {
 	vni uint32
 }
 
-func (r *relayManager) handleRxHandshakeDiscoMsgRunLoop(event relayHandshakeDiscoMsgEvent) {
+func (r *relayManager) handleRxDiscoMsgRunLoop(event relayDiscoMsgEvent) {
 	var (
 		work *relayHandshakeWork
 		ok   bool
 	)
 	apv := addrPortVNI{event.from, event.vni}
 	switch msg := event.msg.(type) {
+	case *disco.AllocateUDPRelayEndpointResponse:
+		sorted := key.NewSortedPairOfDiscoPublic(msg.ClientDisco[0], msg.ClientDisco[1])
+		byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[event.relayServerNodeKey]
+		if !ok {
+			// No outstanding work tied to this relay server, discard.
+			return
+		}
+		allocWork, ok := byDiscoKeys[sorted]
+		if !ok {
+			// No outstanding work tied to these disco keys, discard. 
+ return + } + select { + case done := <-allocWork.doneCh: + // allocateServerEndpoint returned, clean up its state + r.handleAllocWorkDoneRunLoop(done) + return + case allocWork.rxDiscoMsgCh <- msg: + return + } case *disco.BindUDPRelayEndpointChallenge: work, ok = r.handshakeWorkByServerDiscoVNI[serverDiscoVNI{event.disco, event.vni}] if !ok { @@ -504,8 +567,39 @@ func (r *relayManager) handleRxHandshakeDiscoMsgRunLoop(event relayHandshakeDisc } } +func (r *relayManager) handleAllocWorkDoneRunLoop(done relayEndpointAllocWorkDoneEvent) { + byCandidatePeerRelay, ok := r.allocWorkByCandidatePeerRelayByEndpoint[done.work.wlb.ep] + if !ok { + return + } + work, ok := byCandidatePeerRelay[done.work.candidatePeerRelay] + if !ok || work != done.work { + return + } + delete(byCandidatePeerRelay, done.work.candidatePeerRelay) + if len(byCandidatePeerRelay) == 0 { + delete(r.allocWorkByCandidatePeerRelayByEndpoint, done.work.wlb.ep) + } + byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[done.work.candidatePeerRelay.nodeKey] + if !ok { + // unexpected + return + } + delete(byDiscoKeys, done.work.discoKeys) + if len(byDiscoKeys) == 0 { + delete(r.allocWorkByDiscoKeysByServerNodeKey, done.work.candidatePeerRelay.nodeKey) + } + if !done.allocated.ServerDisco.IsZero() { + r.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ + wlb: done.work.wlb, + se: done.allocated, + server: done.work.candidatePeerRelay, + }) + } +} + func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshakeWorkDoneEvent) { - byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[done.work.wlb.ep] + byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[done.work.wlb.ep] if !ok { return } @@ -515,7 +609,7 @@ func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshak } delete(byServerDisco, done.work.se.ServerDisco) if len(byServerDisco) == 0 { - delete(r.handshakeWorkByEndpointByServerDisco, done.work.wlb.ep) + delete(r.handshakeWorkByServerDiscoByEndpoint, done.work.wlb.ep) } delete(r.handshakeWorkByServerDiscoVNI, serverDiscoVNI{done.work.se.ServerDisco, done.work.se.VNI}) apv, ok := r.handshakeWorkAwaitingPong[work] @@ -562,7 +656,7 @@ func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelay } // Check for duplicate work by [*endpoint] + server disco. - byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.wlb.ep] + byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] if ok { existingWork, ok := byServerDisco[newServerEndpoint.se.ServerDisco] if ok { @@ -580,33 +674,9 @@ func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelay // We're now reasonably sure we're dealing with the latest // [udprelay.ServerEndpoint] from a server event order perspective - // (LamportID). Update server disco key tracking if appropriate. - if newServerEndpoint.server.IsValid() { - serverDisco, ok := r.serversByAddrPort[newServerEndpoint.server] - if !ok { - // Allocation raced with an update to our known servers set. This - // server is no longer known. Return early. - return - } - if serverDisco.Compare(newServerEndpoint.se.ServerDisco) != 0 { - // The server's disco key has either changed, or simply become - // known for the first time. In the former case we end up detaching - // any in-progress handshake work from a "known" relay server. 
- // Practically speaking we expect the detached work to fail - // if the server key did in fact change (server restart) while we - // were attempting to handshake with it. It is possible, though - // unlikely, for a server addr:port to effectively move between - // nodes. Either way, there is no harm in detaching existing work, - // and we explicitly let that happen for the rare case the detached - // handshake would complete and remain functional. - delete(r.serversByDisco, serverDisco) - delete(r.serversByAddrPort, newServerEndpoint.server) - r.serversByDisco[serverDisco] = newServerEndpoint.server - r.serversByAddrPort[newServerEndpoint.server] = serverDisco - } - } + // (LamportID). - if newServerEndpoint.server.IsValid() { + if newServerEndpoint.server.isValid() { // Send a [disco.CallMeMaybeVia] to the remote peer if we allocated this // endpoint, regardless of if we start a handshake below. go r.sendCallMeMaybeVia(newServerEndpoint.wlb.ep, newServerEndpoint.se) @@ -641,14 +711,14 @@ func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelay work := &relayHandshakeWork{ wlb: newServerEndpoint.wlb, se: newServerEndpoint.se, - rxDiscoMsgCh: make(chan relayHandshakeDiscoMsgEvent), + rxDiscoMsgCh: make(chan relayDiscoMsgEvent), doneCh: make(chan relayEndpointHandshakeWorkDoneEvent, 1), ctx: ctx, cancel: cancel, } if byServerDisco == nil { byServerDisco = make(map[key.DiscoPublic]*relayHandshakeWork) - r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.wlb.ep] = byServerDisco + r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] = byServerDisco } byServerDisco[newServerEndpoint.se.ServerDisco] = work r.handshakeWorkByServerDiscoVNI[sdv] = work @@ -674,12 +744,15 @@ func (r *relayManager) sendCallMeMaybeVia(ep *endpoint, se udprelay.ServerEndpoi return } callMeMaybeVia := &disco.CallMeMaybeVia{ - ServerDisco: se.ServerDisco, - LamportID: se.LamportID, - VNI: se.VNI, - BindLifetime: se.BindLifetime.Duration, - SteadyStateLifetime: se.SteadyStateLifetime.Duration, - AddrPorts: se.AddrPorts, + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, } ep.c.sendDiscoMessage(epAddr{ap: derpAddr}, ep.publicKey, epDisco.key, callMeMaybeVia, discoVerboseLog) } @@ -800,7 +873,7 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork, generat // one. // // We don't need to TX a pong, that was already handled for us - // in handleRxHandshakeDiscoMsgRunLoop(). + // in handleRxDiscoMsgRunLoop(). 
txPing(msgEvent.from, nil) case *disco.Pong: at, ok := sentPingAt[msg.TxID] @@ -823,104 +896,113 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork, generat } } +const allocateUDPRelayEndpointRequestTimeout = time.Second * 10 + +func (r *relayManager) allocateServerEndpoint(work *relayEndpointAllocWork, generation uint32) { + done := relayEndpointAllocWorkDoneEvent{work: work} + + defer func() { + work.doneCh <- done + relayManagerInputEvent(r, work.ctx, &r.allocateWorkDoneCh, done) + work.cancel() + }() + + dm := &disco.AllocateUDPRelayEndpointRequest{ + ClientDisco: work.discoKeys.Get(), + Generation: generation, + } + + sendAllocReq := func() { + work.wlb.ep.c.sendDiscoAllocateUDPRelayEndpointRequest( + epAddr{ + ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, work.candidatePeerRelay.derpHomeRegionID), + }, + work.candidatePeerRelay.nodeKey, + work.candidatePeerRelay.discoKey, + dm, + discoVerboseLog, + ) + } + go sendAllocReq() + + returnAfterTimer := time.NewTimer(allocateUDPRelayEndpointRequestTimeout) + defer returnAfterTimer.Stop() + // While connections to DERP are over TCP, they can be lossy on the DERP + // server when data moves between the two independent streams. Also, the + // peer relay server may not be "ready" (see [tailscale.com/net/udprelay.ErrServerNotReady]). + // So, start a timer to retry once if needed. + retryAfterTimer := time.NewTimer(udprelay.ServerRetryAfter) + defer retryAfterTimer.Stop() + + for { + select { + case <-work.ctx.Done(): + return + case <-returnAfterTimer.C: + return + case <-retryAfterTimer.C: + go sendAllocReq() + case resp := <-work.rxDiscoMsgCh: + if resp.Generation != generation || + !work.discoKeys.Equal(key.NewSortedPairOfDiscoPublic(resp.ClientDisco[0], resp.ClientDisco[1])) { + continue + } + done.allocated = udprelay.ServerEndpoint{ + ServerDisco: resp.ServerDisco, + ClientDisco: resp.ClientDisco, + LamportID: resp.LamportID, + AddrPorts: resp.AddrPorts, + VNI: resp.VNI, + BindLifetime: tstime.GoDuration{Duration: resp.BindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: resp.SteadyStateLifetime}, + } + return + } + } +} + func (r *relayManager) allocateAllServersRunLoop(wlb endpointWithLastBest) { - if len(r.serversByAddrPort) == 0 { + if len(r.serversByNodeKey) == 0 { return } - ctx, cancel := context.WithCancel(context.Background()) - started := &relayEndpointAllocWork{ep: wlb.ep, cancel: cancel, wg: &sync.WaitGroup{}} - for k := range r.serversByAddrPort { - started.wg.Add(1) - go r.allocateSingleServer(ctx, started.wg, k, wlb) - } - r.allocWorkByEndpoint[wlb.ep] = started - go func() { - started.wg.Wait() - relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{work: started}) - // cleanup context cancellation must come after the - // relayManagerInputEvent call, otherwise it returns early without - // writing the event to runLoop(). 
- started.cancel() - }() -} - -type errNotReady struct{ retryAfter time.Duration } - -func (e errNotReady) Error() string { - return fmt.Sprintf("server not ready, retry after %v", e.retryAfter) -} - -const reqTimeout = time.Second * 10 - -func doAllocate(ctx context.Context, server netip.AddrPort, discoKeys [2]key.DiscoPublic) (udprelay.ServerEndpoint, error) { - var reqBody bytes.Buffer - type allocateRelayEndpointReq struct { - DiscoKeys []key.DiscoPublic - } - a := &allocateRelayEndpointReq{ - DiscoKeys: []key.DiscoPublic{discoKeys[0], discoKeys[1]}, - } - err := json.NewEncoder(&reqBody).Encode(a) - if err != nil { - return udprelay.ServerEndpoint{}, err - } - reqCtx, cancel := context.WithTimeout(ctx, reqTimeout) - defer cancel() - req, err := http.NewRequestWithContext(reqCtx, httpm.POST, "http://"+server.String()+"/v0/relay/endpoint", &reqBody) - if err != nil { - return udprelay.ServerEndpoint{}, err - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return udprelay.ServerEndpoint{}, err - } - defer resp.Body.Close() - switch resp.StatusCode { - case http.StatusOK: - var se udprelay.ServerEndpoint - err = json.NewDecoder(io.LimitReader(resp.Body, 4096)).Decode(&se) - return se, err - case http.StatusServiceUnavailable: - raHeader := resp.Header.Get("Retry-After") - raSeconds, err := strconv.ParseUint(raHeader, 10, 32) - if err == nil { - return udprelay.ServerEndpoint{}, errNotReady{retryAfter: time.Second * time.Duration(raSeconds)} - } - fallthrough - default: - return udprelay.ServerEndpoint{}, fmt.Errorf("non-200 status: %d", resp.StatusCode) - } -} - -func (r *relayManager) allocateSingleServer(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, wlb endpointWithLastBest) { - // TODO(jwhited): introduce client metrics counters for notable failures - defer wg.Done() remoteDisco := wlb.ep.disco.Load() if remoteDisco == nil { return } - firstTry := true - for { - se, err := doAllocate(ctx, server, [2]key.DiscoPublic{wlb.ep.c.discoPublic, remoteDisco.key}) - if err == nil { - relayManagerInputEvent(r, ctx, &r.newServerEndpointCh, newRelayServerEndpointEvent{ - wlb: wlb, - se: se, - server: server, // we allocated this endpoint (vs CallMeMaybeVia reception), mark it as such - }) - return - } - wlb.ep.c.logf("[v1] magicsock: relayManager: error allocating endpoint on %v for %v: %v", server, wlb.ep.discoShort(), err) - var notReady errNotReady - if firstTry && errors.As(err, ¬Ready) { - select { - case <-ctx.Done(): - return - case <-time.After(min(notReady.retryAfter, reqTimeout)): - firstTry = false + discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoPublic, remoteDisco.key) + for _, v := range r.serversByNodeKey { + byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] + if !ok { + byDiscoKeys = make(map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork) + r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] = byDiscoKeys + } else { + _, ok = byDiscoKeys[discoKeys] + if ok { + // If there is an existing key, a disco key collision may have + // occurred across peers ([*endpoint]). Do not overwrite the + // existing work, let it finish. 
+ wlb.ep.c.logf("[unexpected] magicsock: relayManager: suspected disco key collision on server %v for keys: %v", v.nodeKey.ShortString(), discoKeys) continue } } - return + ctx, cancel := context.WithCancel(context.Background()) + started := &relayEndpointAllocWork{ + wlb: wlb, + discoKeys: discoKeys, + candidatePeerRelay: v, + rxDiscoMsgCh: make(chan *disco.AllocateUDPRelayEndpointResponse), + doneCh: make(chan relayEndpointAllocWorkDoneEvent, 1), + ctx: ctx, + cancel: cancel, + } + byDiscoKeys[discoKeys] = started + byCandidatePeerRelay, ok := r.allocWorkByCandidatePeerRelayByEndpoint[wlb.ep] + if !ok { + byCandidatePeerRelay = make(map[candidatePeerRelay]*relayEndpointAllocWork) + r.allocWorkByCandidatePeerRelayByEndpoint[wlb.ep] = byCandidatePeerRelay + } + byCandidatePeerRelay[v] = started + r.allocGeneration++ + go r.allocateServerEndpoint(started, r.allocGeneration) } } diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index 01f9258ad..e4891f567 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -4,7 +4,6 @@ package magicsock import ( - "net/netip" "testing" "tailscale.com/disco" @@ -22,26 +21,57 @@ func TestRelayManagerInitAndIdle(t *testing.T) { <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, addrQuality{}, false, &disco.CallMeMaybeVia{ServerDisco: key.NewDisco().Public()}) + rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleGeneveEncapDiscoMsg(&Conn{discoPrivate: key.NewDisco()}, &disco.BindUDPRelayEndpointChallenge{}, &discoInfo{}, epAddr{}) + rm.handleRxDiscoMsg(&Conn{discoPrivate: key.NewDisco()}, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleRelayServersSet(make(set.Set[netip.AddrPort])) + rm.handleRelayServersSet(make(set.Set[candidatePeerRelay])) <-rm.runLoopStoppedCh rm = relayManager{} rm.getServers() <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.handleDERPHomeChange(key.NodePublic{}, 1) + <-rm.runLoopStoppedCh +} + +func TestRelayManagerHandleDERPHomeChange(t *testing.T) { + rm := relayManager{} + servers := make(set.Set[candidatePeerRelay], 1) + c := candidatePeerRelay{ + nodeKey: key.NewNode().Public(), + discoKey: key.NewDisco().Public(), + derpHomeRegionID: 1, + } + servers.Add(c) + rm.handleRelayServersSet(servers) + want := c + want.derpHomeRegionID = 2 + rm.handleDERPHomeChange(c.nodeKey, 2) + got := rm.getServers() + if len(got) != 1 { + t.Fatalf("got %d servers, want 1", len(got)) + } + _, ok := got[want] + if !ok { + t.Fatal("DERP home change failed to propagate") + } } func TestRelayManagerGetServers(t *testing.T) { rm := relayManager{} - servers := make(set.Set[netip.AddrPort], 1) - servers.Add(netip.MustParseAddrPort("192.0.2.1:7")) + servers := make(set.Set[candidatePeerRelay], 1) + c := candidatePeerRelay{ + nodeKey: key.NewNode().Public(), + discoKey: key.NewDisco().Public(), + } + servers.Add(c) rm.handleRelayServersSet(servers) got := rm.getServers() if !servers.Equal(got) {