mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-13 22:47:30 +00:00
disco,net/udprelay,wgengine/magicsock: support relay re-binding (#16388)
Relay handshakes may now occur multiple times over the lifetime of a relay server endpoint. Handshake messages now include a handshake generation, which is client specified, as a means to trigger safe challenge reset server-side. Relay servers continue to enforce challenge values as single use. They will only send a given value once, in reply to the first arriving bind message for a handshake generation. VNI has been added to the handshake messages, and we expect the outer Geneve header value to match the sealed value upon reception. Remote peer disco pub key is now also included in handshake messages, and it must match the receiver's expectation for the remote, participating party. Updates tailscale/corp#27502 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
@@ -96,12 +96,13 @@ type serverEndpoint struct {
|
||||
// indexing of this array aligns with the following fields, e.g.
|
||||
// discoSharedSecrets[0] is the shared secret to use when sealing
|
||||
// Disco protocol messages for transmission towards discoPubKeys[0].
|
||||
discoPubKeys pairOfDiscoPubKeys
|
||||
discoSharedSecrets [2]key.DiscoShared
|
||||
handshakeState [2]disco.BindUDPRelayHandshakeState
|
||||
addrPorts [2]netip.AddrPort
|
||||
lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time
|
||||
challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte
|
||||
discoPubKeys pairOfDiscoPubKeys
|
||||
discoSharedSecrets [2]key.DiscoShared
|
||||
handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg
|
||||
handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg
|
||||
boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg
|
||||
lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time
|
||||
challenge [2][disco.BindUDPRelayChallengeLen]byte
|
||||
|
||||
lamportID uint64
|
||||
vni uint32
|
||||
@@ -112,69 +113,77 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
|
||||
if senderIndex != 0 && senderIndex != 1 {
|
||||
return
|
||||
}
|
||||
handshakeState := e.handshakeState[senderIndex]
|
||||
if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived {
|
||||
// this sender is already bound
|
||||
return
|
||||
|
||||
otherSender := 0
|
||||
if senderIndex == 0 {
|
||||
otherSender = 1
|
||||
}
|
||||
|
||||
validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error {
|
||||
if common.VNI != e.vni {
|
||||
return errors.New("mismatching VNI")
|
||||
}
|
||||
if common.RemoteKey.Compare(e.discoPubKeys[otherSender]) != 0 {
|
||||
return errors.New("mismatching RemoteKey")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
switch discoMsg := discoMsg.(type) {
|
||||
case *disco.BindUDPRelayEndpoint:
|
||||
switch handshakeState {
|
||||
case disco.BindUDPRelayHandshakeStateInit:
|
||||
// set sender addr
|
||||
e.addrPorts[senderIndex] = from
|
||||
fallthrough
|
||||
case disco.BindUDPRelayHandshakeStateChallengeSent:
|
||||
if from != e.addrPorts[senderIndex] {
|
||||
// this is a later arriving bind from a different source, or
|
||||
// a retransmit and the sender's source has changed, discard
|
||||
return
|
||||
}
|
||||
m := new(disco.BindUDPRelayEndpointChallenge)
|
||||
copy(m.Challenge[:], e.challenge[senderIndex][:])
|
||||
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
|
||||
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
|
||||
err := gh.Encode(reply)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reply = append(reply, disco.Magic...)
|
||||
reply = serverDisco.AppendTo(reply)
|
||||
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
|
||||
reply = append(reply, box...)
|
||||
uw.WriteMsgUDPAddrPort(reply, nil, from)
|
||||
// set new state
|
||||
e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent
|
||||
return
|
||||
default:
|
||||
// disco.BindUDPRelayEndpoint is unexpected in all other handshake states
|
||||
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
|
||||
if err != nil {
|
||||
// silently drop
|
||||
return
|
||||
}
|
||||
if discoMsg.Generation == 0 {
|
||||
// Generation must be nonzero, silently drop
|
||||
return
|
||||
}
|
||||
if e.handshakeGeneration[senderIndex] == discoMsg.Generation {
|
||||
// we've seen this generation before, silently drop
|
||||
return
|
||||
}
|
||||
e.handshakeGeneration[senderIndex] = discoMsg.Generation
|
||||
e.handshakeAddrPorts[senderIndex] = from
|
||||
m := new(disco.BindUDPRelayEndpointChallenge)
|
||||
m.VNI = e.vni
|
||||
m.Generation = discoMsg.Generation
|
||||
m.RemoteKey = e.discoPubKeys[otherSender]
|
||||
rand.Read(e.challenge[senderIndex][:])
|
||||
copy(m.Challenge[:], e.challenge[senderIndex][:])
|
||||
reply := make([]byte, packet.GeneveFixedHeaderLength, 512)
|
||||
gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco}
|
||||
err = gh.Encode(reply)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
reply = append(reply, disco.Magic...)
|
||||
reply = serverDisco.AppendTo(reply)
|
||||
box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
|
||||
reply = append(reply, box...)
|
||||
uw.WriteMsgUDPAddrPort(reply, nil, from)
|
||||
return
|
||||
case *disco.BindUDPRelayEndpointAnswer:
|
||||
switch handshakeState {
|
||||
case disco.BindUDPRelayHandshakeStateChallengeSent:
|
||||
if from != e.addrPorts[senderIndex] {
|
||||
// sender source has changed
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) {
|
||||
// bad answer
|
||||
return
|
||||
}
|
||||
// sender is now bound
|
||||
// TODO: Consider installing a fast path via netfilter or similar to
|
||||
// relay (NAT) data packets for this serverEndpoint.
|
||||
e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived
|
||||
// record last seen as bound time
|
||||
e.lastSeen[senderIndex] = time.Now()
|
||||
return
|
||||
default:
|
||||
// disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake
|
||||
// states, or we've already handled it
|
||||
err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
|
||||
if err != nil {
|
||||
// silently drop
|
||||
return
|
||||
}
|
||||
generation := e.handshakeGeneration[senderIndex]
|
||||
if generation == 0 || // we have no active handshake
|
||||
generation != discoMsg.Generation || // mismatching generation for the active handshake
|
||||
e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake
|
||||
!bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake
|
||||
// silently drop
|
||||
return
|
||||
}
|
||||
// Handshake complete. Update the binding for this sender.
|
||||
e.boundAddrPorts[senderIndex] = from
|
||||
e.lastSeen[senderIndex] = time.Now() // record last seen as bound time
|
||||
return
|
||||
default:
|
||||
// unexpected Disco message type
|
||||
// unexpected message types, silently drop
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -225,12 +234,12 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade
|
||||
}
|
||||
var to netip.AddrPort
|
||||
switch {
|
||||
case from == e.addrPorts[0]:
|
||||
case from == e.boundAddrPorts[0]:
|
||||
e.lastSeen[0] = time.Now()
|
||||
to = e.addrPorts[1]
|
||||
case from == e.addrPorts[1]:
|
||||
to = e.boundAddrPorts[1]
|
||||
case from == e.boundAddrPorts[1]:
|
||||
e.lastSeen[1] = time.Now()
|
||||
to = e.addrPorts[0]
|
||||
to = e.boundAddrPorts[0]
|
||||
default:
|
||||
// unrecognized source
|
||||
return
|
||||
@@ -240,11 +249,6 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade
|
||||
return
|
||||
}
|
||||
|
||||
if e.isBound() {
|
||||
// control packet, but serverEndpoint is already bound
|
||||
return
|
||||
}
|
||||
|
||||
if gh.Protocol != packet.GeneveProtocolDisco {
|
||||
// control packet, but not Disco
|
||||
return
|
||||
@@ -267,11 +271,11 @@ func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifet
|
||||
return false
|
||||
}
|
||||
|
||||
// isBound returns true if both clients have completed their 3-way handshake,
|
||||
// isBound returns true if both clients have completed a 3-way handshake,
|
||||
// otherwise false.
|
||||
func (e *serverEndpoint) isBound() bool {
|
||||
return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived &&
|
||||
e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived
|
||||
return e.boundAddrPorts[0].IsValid() &&
|
||||
e.boundAddrPorts[1].IsValid()
|
||||
}
|
||||
|
||||
// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet
|
||||
@@ -591,8 +595,6 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
||||
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0])
|
||||
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1])
|
||||
e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:]
|
||||
rand.Read(e.challenge[0][:])
|
||||
rand.Read(e.challenge[1][:])
|
||||
|
||||
s.byDisco[pair] = e
|
||||
s.byVNI[e.vni] = e
|
||||
|
@@ -19,23 +19,27 @@ import (
|
||||
)
|
||||
|
||||
type testClient struct {
|
||||
vni uint32
|
||||
local key.DiscoPrivate
|
||||
server key.DiscoPublic
|
||||
uc *net.UDPConn
|
||||
vni uint32
|
||||
handshakeGeneration uint32
|
||||
local key.DiscoPrivate
|
||||
remote key.DiscoPublic
|
||||
server key.DiscoPublic
|
||||
uc *net.UDPConn
|
||||
}
|
||||
|
||||
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
|
||||
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient {
|
||||
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
|
||||
uc, err := net.DialUDP("udp4", nil, rAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &testClient{
|
||||
vni: vni,
|
||||
local: local,
|
||||
server: server,
|
||||
uc: uc,
|
||||
vni: vni,
|
||||
handshakeGeneration: 1,
|
||||
local: local,
|
||||
remote: remote,
|
||||
server: server,
|
||||
uc: uc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,13 +141,35 @@ func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
|
||||
}
|
||||
|
||||
func (c *testClient) handshake(t *testing.T) {
|
||||
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
|
||||
generation := c.handshakeGeneration
|
||||
c.handshakeGeneration++
|
||||
common := disco.BindUDPRelayEndpointCommon{
|
||||
VNI: c.vni,
|
||||
Generation: generation,
|
||||
RemoteKey: c.remote,
|
||||
}
|
||||
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{
|
||||
BindUDPRelayEndpointCommon: common,
|
||||
})
|
||||
msg := c.readControlDiscoMsg(t)
|
||||
challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
|
||||
if !ok {
|
||||
t.Fatal("unexepcted disco message type")
|
||||
t.Fatal("unexpected disco message type")
|
||||
}
|
||||
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
|
||||
if challenge.Generation != common.Generation {
|
||||
t.Fatalf("rx'd challenge.Generation (%d) != %d", challenge.Generation, common.Generation)
|
||||
}
|
||||
if challenge.VNI != common.VNI {
|
||||
t.Fatalf("rx'd challenge.VNI (%d) != %d", challenge.VNI, common.VNI)
|
||||
}
|
||||
if challenge.RemoteKey != common.RemoteKey {
|
||||
t.Fatalf("rx'd challenge.RemoteKey (%v) != %v", challenge.RemoteKey, common.RemoteKey)
|
||||
}
|
||||
answer := &disco.BindUDPRelayEndpointAnswer{
|
||||
BindUDPRelayEndpointCommon: common,
|
||||
}
|
||||
answer.Challenge = challenge.Challenge
|
||||
c.writeControlDiscoMsg(t, answer)
|
||||
}
|
||||
|
||||
func (c *testClient) close() {
|
||||
@@ -179,9 +205,9 @@ func TestServer(t *testing.T) {
|
||||
if len(endpoint.AddrPorts) != 1 {
|
||||
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
|
||||
}
|
||||
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco)
|
||||
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
defer tcA.close()
|
||||
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco)
|
||||
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
defer tcB.close()
|
||||
|
||||
tcA.handshake(t)
|
||||
@@ -209,4 +235,30 @@ func TestServer(t *testing.T) {
|
||||
if !bytes.Equal(txToA, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
|
||||
tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco)
|
||||
tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1
|
||||
defer tcAOnNewPort.close()
|
||||
|
||||
// Handshake client A on a new source IP:port, verify we receive packets on the new binding
|
||||
tcAOnNewPort.handshake(t)
|
||||
txToAOnNewPort := []byte{7, 8, 9}
|
||||
tcB.writeDataPkt(t, txToAOnNewPort)
|
||||
rxFromB = tcAOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToAOnNewPort, rxFromB) {
|
||||
t.Fatal("unexpected msg B->A")
|
||||
}
|
||||
|
||||
tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco)
|
||||
tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1
|
||||
defer tcBOnNewPort.close()
|
||||
|
||||
// Handshake client B on a new source IP:port, verify we receive packets on the new binding
|
||||
tcBOnNewPort.handshake(t)
|
||||
txToBOnNewPort := []byte{7, 8, 9}
|
||||
tcAOnNewPort.writeDataPkt(t, txToBOnNewPort)
|
||||
rxFromA = tcBOnNewPort.readDataPkt(t)
|
||||
if !bytes.Equal(txToBOnNewPort, rxFromA) {
|
||||
t.Fatal("unexpected msg A->B")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user