disco,net/udprelay,wgengine/magicsock: support relay re-binding (#16388)

Relay handshakes may now occur multiple times over the lifetime of a
relay server endpoint. Handshake messages now include a
client-specified handshake generation as a means to trigger a safe
challenge reset server-side.

Relay servers continue to enforce challenge values as single-use. They
will only send a given value once, in reply to the first arriving bind
message for a handshake generation.

VNI has been added to the handshake messages, and we expect the outer
Geneve header value to match the sealed value upon reception.

Remote peer disco pub key is now also included in handshake messages,
and it must match the receiver's expectation for the remote,
participating party.

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-06-26 19:30:14 -07:00
committed by GitHub
parent b2bf7e988e
commit b32a01b2dc
5 changed files with 278 additions and 135 deletions

View File

@@ -96,12 +96,13 @@ type serverEndpoint struct {
// indexing of this array aligns with the following fields, e.g.
// discoSharedSecrets[0] is the shared secret to use when sealing
// Disco protocol messages for transmission towards discoPubKeys[0].
discoPubKeys pairOfDiscoPubKeys
discoSharedSecrets [2]key.DiscoShared
handshakeState [2]disco.BindUDPRelayHandshakeState
addrPorts [2]netip.AddrPort
lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time
challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte
discoPubKeys pairOfDiscoPubKeys
discoSharedSecrets [2]key.DiscoShared
handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg
handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg
boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg
lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time
challenge [2][disco.BindUDPRelayChallengeLen]byte
lamportID uint64
vni uint32
@@ -112,69 +113,77 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
if senderIndex != 0 && senderIndex != 1 {
return
}
handshakeState := e.handshakeState[senderIndex]
if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived {
// this sender is already bound
return
otherSender := 0
if senderIndex == 0 {
otherSender = 1
}
validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error {
if common.VNI != e.vni {
return errors.New("mismatching VNI")
}
if common.RemoteKey.Compare(e.discoPubKeys[otherSender]) != 0 {
return errors.New("mismatching RemoteKey")
}
return nil
}
switch discoMsg := discoMsg.(type) {
case *disco.BindUDPRelayEndpoint:
switch handshakeState {
case disco.BindUDPRelayHandshakeStateInit:
// set sender addr
e.addrPorts[senderIndex] = from
fallthrough
case disco.BindUDPRelayHandshakeStateChallengeSent:
if from != e.addrPorts[senderIndex] {
// this is a later arriving bind from a different source, or
// a retransmit and the sender's source has changed, discard
return
}
m := new(disco.BindUDPRelayEndpointChallenge)
copy(m.Challenge[:], e.challenge[senderIndex][:])
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
err := gh.Encode(reply)
if err != nil {
return
}
reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply)
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
reply = append(reply, box...)
uw.WriteMsgUDPAddrPort(reply, nil, from)
// set new state
e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent
return
default:
// disco.BindUDPRelayEndpoint is unexpected in all other handshake states
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil {
// silently drop
return
}
if discoMsg.Generation == 0 {
// Generation must be nonzero, silently drop
return
}
if e.handshakeGeneration[senderIndex] == discoMsg.Generation {
// we've seen this generation before, silently drop
return
}
e.handshakeGeneration[senderIndex] = discoMsg.Generation
e.handshakeAddrPorts[senderIndex] = from
m := new(disco.BindUDPRelayEndpointChallenge)
m.VNI = e.vni
m.Generation = discoMsg.Generation
m.RemoteKey = e.discoPubKeys[otherSender]
rand.Read(e.challenge[senderIndex][:])
copy(m.Challenge[:], e.challenge[senderIndex][:])
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
err = gh.Encode(reply)
if err != nil {
return
}
reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply)
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
reply = append(reply, box...)
uw.WriteMsgUDPAddrPort(reply, nil, from)
return
case *disco.BindUDPRelayEndpointAnswer:
switch handshakeState {
case disco.BindUDPRelayHandshakeStateChallengeSent:
if from != e.addrPorts[senderIndex] {
// sender source has changed
return
}
if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
// bad answer
return
}
// sender is now bound
// TODO: Consider installing a fast path via netfilter or similar to
// relay (NAT) data packets for this serverEndpoint.
e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
// record last seen as bound time
e.lastSeen[senderIndex] = time.Now()
return
default:
// disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake
// states, or we've already handled it
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
if err != nil {
// silently drop
return
}
generation := e.handshakeGeneration[senderIndex]
if generation == 0 || // we have no active handshake
generation != discoMsg.Generation || // mismatching generation for the active handshake
e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake
!bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake
// silently drop
return
}
// Handshake complete. Update the binding for this sender.
e.boundAddrPorts[senderIndex] = from
e.lastSeen[senderIndex] = time.Now() // record last seen as bound time
return
default:
// unexpected Disco message type
// unexpected message types, silently drop
return
}
}
@@ -225,12 +234,12 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade
}
var to netip.AddrPort
switch {
case from == e.addrPorts[0]:
case from == e.boundAddrPorts[0]:
e.lastSeen[0] = time.Now()
to = e.addrPorts[1]
case from == e.addrPorts[1]:
to = e.boundAddrPorts[1]
case from == e.boundAddrPorts[1]:
e.lastSeen[1] = time.Now()
to = e.addrPorts[0]
to = e.boundAddrPorts[0]
default:
// unrecognized source
return
@@ -240,11 +249,6 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade
return
}
if e.isBound() {
// control packet, but serverEndpoint is already bound
return
}
if gh.Protocol != packet.GeneveProtocolDisco {
// control packet, but not Disco
return
@@ -267,11 +271,11 @@ func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifet
return false
}
// isBound returns true if both clients have completed their 3-way handshake,
// isBound returns true if both clients have completed a 3-way handshake,
// otherwise false.
func (e *serverEndpoint) isBound() bool {
return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived &&
e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived
return e.boundAddrPorts[0].IsValid() &&
e.boundAddrPorts[1].IsValid()
}
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
@@ -591,8 +595,6 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
rand.Read(e.challenge[0][:])
rand.Read(e.challenge[1][:])
s.byDisco[pair] = e
s.byVNI[e.vni] = e

View File

@@ -19,23 +19,27 @@ import (
)
type testClient struct {
vni uint32
local key.DiscoPrivate
server key.DiscoPublic
uc *net.UDPConn
vni uint32
handshakeGeneration uint32
local key.DiscoPrivate
remote key.DiscoPublic
server key.DiscoPublic
uc *net.UDPConn
}
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient {
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
uc, err := net.DialUDP("udp4", nil, rAddr)
if err != nil {
t.Fatal(err)
}
return &testClient{
vni: vni,
local: local,
server: server,
uc: uc,
vni: vni,
handshakeGeneration: 1,
local: local,
remote: remote,
server: server,
uc: uc,
}
}
@@ -137,13 +141,35 @@ func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
}
func (c *testClient) handshake(t *testing.T) {
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
generation := c.handshakeGeneration
c.handshakeGeneration++
common := disco.BindUDPRelayEndpointCommon{
VNI: c.vni,
Generation: generation,
RemoteKey: c.remote,
}
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{
BindUDPRelayEndpointCommon: common,
})
msg := c.readControlDiscoMsg(t)
challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
if !ok {
t.Fatal("unexepcted disco message type")
t.Fatal("unexpected disco message type")
}
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
if challenge.Generation != common.Generation {
t.Fatalf("rx'd challenge.Generation (%d) != %d", challenge.Generation, common.Generation)
}
if challenge.VNI != common.VNI {
t.Fatalf("rx'd challenge.VNI (%d) != %d", challenge.VNI, common.VNI)
}
if challenge.RemoteKey != common.RemoteKey {
t.Fatalf("rx'd challenge.RemoteKey (%v) != %v", challenge.RemoteKey, common.RemoteKey)
}
answer := &disco.BindUDPRelayEndpointAnswer{
BindUDPRelayEndpointCommon: common,
}
answer.Challenge = challenge.Challenge
c.writeControlDiscoMsg(t, answer)
}
func (c *testClient) close() {
@@ -179,9 +205,9 @@ func TestServer(t *testing.T) {
if len(endpoint.AddrPorts) != 1 {
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
}
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco)
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
defer tcA.close()
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco)
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
defer tcB.close()
tcA.handshake(t)
@@ -209,4 +235,30 @@ func TestServer(t *testing.T) {
if !bytes.Equal(txToA, rxFromB) {
t.Fatal("unexpected msg B->A")
}
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
defer tcAOnNewPort.close()
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
tcAOnNewPort.handshake(t)
txToAOnNewPort := []byte{7, 8, 9}
tcB.writeDataPkt(t, txToAOnNewPort)
rxFromB = tcAOnNewPort.readDataPkt(t)
if !bytes.Equal(txToAOnNewPort, rxFromB) {
t.Fatal("unexpected msg B->A")
}
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
defer tcBOnNewPort.close()
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
tcBOnNewPort.handshake(t)
txToBOnNewPort := []byte{7, 8, 9}
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
rxFromA = tcBOnNewPort.readDataPkt(t)
if !bytes.Equal(txToBOnNewPort, rxFromA) {
t.Fatal("unexpected msg A->B")
}
}