diff --git a/disco/disco.go b/disco/disco.go index 0854eb4c0..d4623c119 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -321,79 +321,131 @@ const ( BindUDPRelayHandshakeStateAnswerReceived ) -// bindUDPRelayEndpointLen is the length of a marshalled BindUDPRelayEndpoint -// message, without the message header. -const bindUDPRelayEndpointLen = BindUDPRelayEndpointChallengeLen +// bindUDPRelayEndpointCommonLen is the length of a marshalled +// [BindUDPRelayEndpointCommon], without the message header. +const bindUDPRelayEndpointCommonLen = 72 + +// BindUDPRelayChallengeLen is the length of the Challenge field carried in +// [BindUDPRelayEndpointChallenge] & [BindUDPRelayEndpointAnswer] messages. +const BindUDPRelayChallengeLen = 32 + +// BindUDPRelayEndpointCommon contains fields that are common across all 3 +// UDP relay handshake message types. All 4 field values are expected to be +// consistent for the lifetime of a handshake besides Challenge, which is +// irrelevant in a [BindUDPRelayEndpoint] message. +type BindUDPRelayEndpointCommon struct { + // VNI is the Geneve header Virtual Network Identifier field value, which + // must match this disco-sealed value upon reception. If they are + // non-matching it indicates the cleartext Geneve header was tampered with + // and/or mangled. + VNI uint32 + // Generation represents the handshake generation. Clients must set a new, + // nonzero value at the start of every handshake. + Generation uint32 + // RemoteKey is the disco key of the remote peer participating over this + // relay endpoint. + RemoteKey key.DiscoPublic + // Challenge is set by the server in a [BindUDPRelayEndpointChallenge] + // message, and expected to be echoed back by the client in a + // [BindUDPRelayEndpointAnswer] message. Its value is irrelevant in a + // [BindUDPRelayEndpoint] message, where it simply serves a padding purpose + // ensuring all handshake messages are equal in size. + Challenge [BindUDPRelayChallengeLen]byte +} + +// encode encodes m in b. b must be at least bindUDPRelayEndpointCommonLen bytes +// long. +func (m *BindUDPRelayEndpointCommon) encode(b []byte) { + binary.BigEndian.PutUint32(b, m.VNI) + b = b[4:] + binary.BigEndian.PutUint32(b, m.Generation) + b = b[4:] + m.RemoteKey.AppendTo(b[:0]) + b = b[key.DiscoPublicRawLen:] + copy(b, m.Challenge[:]) +} + +// decode decodes m from b. +func (m *BindUDPRelayEndpointCommon) decode(b []byte) error { + if len(b) < bindUDPRelayEndpointCommonLen { + return errShort + } + m.VNI = binary.BigEndian.Uint32(b) + b = b[4:] + m.Generation = binary.BigEndian.Uint32(b) + b = b[4:] + m.RemoteKey = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + copy(m.Challenge[:], b[:BindUDPRelayChallengeLen]) + return nil +} // BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client -// towards UDP relay server as part of the 3-way bind handshake. It is padded to -// match the length of BindUDPRelayEndpointChallenge. This message type is -// currently considered experimental and is not yet tied to a +// towards UDP relay server as part of the 3-way bind handshake. This message +// type is currently considered experimental and is not yet tied to a // tailcfg.CapabilityVersion. type BindUDPRelayEndpoint struct { + BindUDPRelayEndpointCommon } func (m *BindUDPRelayEndpoint) AppendMarshal(b []byte) []byte { - ret, _ := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, bindUDPRelayEndpointLen) + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) return ret } func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, err error) { m = new(BindUDPRelayEndpoint) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } return m, nil } -// BindUDPRelayEndpointChallengeLen is the length of a marshalled -// BindUDPRelayEndpointChallenge message, without the message header. -const BindUDPRelayEndpointChallengeLen = 32 - // BindUDPRelayEndpointChallenge is transmitted from UDP relay server towards // UDP relay client in response to a BindUDPRelayEndpoint message as part of the // 3-way bind handshake. This message type is currently considered experimental // and is not yet tied to a tailcfg.CapabilityVersion. type BindUDPRelayEndpointChallenge struct { - Challenge [BindUDPRelayEndpointChallengeLen]byte + BindUDPRelayEndpointCommon } func (m *BindUDPRelayEndpointChallenge) AppendMarshal(b []byte) []byte { - ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, BindUDPRelayEndpointChallengeLen) - copy(d, m.Challenge[:]) + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) return ret } func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEndpointChallenge, err error) { - if len(p) < BindUDPRelayEndpointChallengeLen { - return nil, errShort - } m = new(BindUDPRelayEndpointChallenge) - copy(m.Challenge[:], p[:]) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } return m, nil } -// bindUDPRelayEndpointAnswerLen is the length of a marshalled -// BindUDPRelayEndpointAnswer message, without the message header. -const bindUDPRelayEndpointAnswerLen = BindUDPRelayEndpointChallengeLen - // BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay // server in response to a BindUDPRelayEndpointChallenge message. This message // type is currently considered experimental and is not yet tied to a // tailcfg.CapabilityVersion. type BindUDPRelayEndpointAnswer struct { - Answer [bindUDPRelayEndpointAnswerLen]byte + BindUDPRelayEndpointCommon } func (m *BindUDPRelayEndpointAnswer) AppendMarshal(b []byte) []byte { - ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointAnswerLen) - copy(d, m.Answer[:]) + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) return ret } func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpointAnswer, err error) { - if len(p) < bindUDPRelayEndpointAnswerLen { - return nil, errShort - } m = new(BindUDPRelayEndpointAnswer) - copy(m.Answer[:], p[:]) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } return m, nil } diff --git a/disco/disco_test.go b/disco/disco_test.go index f2a29a744..9fb71ff83 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -16,6 +16,15 @@ import ( ) func TestMarshalAndParse(t *testing.T) { + relayHandshakeCommon := BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 2, + RemoteKey: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + Challenge: [BindUDPRelayChallengeLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + } + tests := []struct { name string want string @@ -86,26 +95,24 @@ func TestMarshalAndParse(t *testing.T) { }, { name: "bind_udp_relay_endpoint", - m: &BindUDPRelayEndpoint{}, - want: "04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00", + m: &BindUDPRelayEndpoint{ + relayHandshakeCommon, + }, + want: "04 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", }, { name: "bind_udp_relay_endpoint_challenge", m: &BindUDPRelayEndpointChallenge{ - Challenge: [BindUDPRelayEndpointChallengeLen]byte{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - }, + relayHandshakeCommon, }, - want: "05 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + want: "05 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", }, { name: "bind_udp_relay_endpoint_answer", m: &BindUDPRelayEndpointAnswer{ - Answer: [bindUDPRelayEndpointAnswerLen]byte{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, - }, + relayHandshakeCommon, }, - want: "06 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + want: "06 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", }, { name: "call_me_maybe_via", diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 8b9e95fb1..e32f8917c 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -96,12 +96,13 @@ type serverEndpoint struct { // indexing of this array aligns with the following fields, e.g. // discoSharedSecrets[0] is the shared secret to use when sealing // Disco protocol messages for transmission towards discoPubKeys[0]. - discoPubKeys pairOfDiscoPubKeys - discoSharedSecrets [2]key.DiscoShared - handshakeState [2]disco.BindUDPRelayHandshakeState - addrPorts [2]netip.AddrPort - lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time - challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte + discoPubKeys pairOfDiscoPubKeys + discoSharedSecrets [2]key.DiscoShared + handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg + handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg + boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + challenge [2][disco.BindUDPRelayChallengeLen]byte lamportID uint64 vni uint32 @@ -112,69 +113,77 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex if senderIndex != 0 && senderIndex != 1 { return } - handshakeState := e.handshakeState[senderIndex] - if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived { - // this sender is already bound - return + + otherSender := 0 + if senderIndex == 0 { + otherSender = 1 } + + validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error { + if common.VNI != e.vni { + return errors.New("mismatching VNI") + } + if common.RemoteKey.Compare(e.discoPubKeys[otherSender]) != 0 { + return errors.New("mismatching RemoteKey") + } + return nil + } + switch discoMsg := discoMsg.(type) { case *disco.BindUDPRelayEndpoint: - switch handshakeState { - case disco.BindUDPRelayHandshakeStateInit: - // set sender addr - e.addrPorts[senderIndex] = from - fallthrough - case disco.BindUDPRelayHandshakeStateChallengeSent: - if from != e.addrPorts[senderIndex] { - // this is a later arriving bind from a different source, or - // a retransmit and the sender's source has changed, discard - return - } - m := new(disco.BindUDPRelayEndpointChallenge) - copy(m.Challenge[:], e.challenge[senderIndex][:]) - reply := make([]byte, packet.GeneveFixedHeaderLength, 512) - gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco} - err := gh.Encode(reply) - if err != nil { - return - } - reply = append(reply, disco.Magic...) - reply = serverDisco.AppendTo(reply) - box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) - reply = append(reply, box...) - uw.WriteMsgUDPAddrPort(reply, nil, from) - // set new state - e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent - return - default: - // disco.BindUDPRelayEndpoint is unexpected in all other handshake states + err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop return } + if discoMsg.Generation == 0 { + // Generation must be nonzero, silently drop + return + } + if e.handshakeGeneration[senderIndex] == discoMsg.Generation { + // we've seen this generation before, silently drop + return + } + e.handshakeGeneration[senderIndex] = discoMsg.Generation + e.handshakeAddrPorts[senderIndex] = from + m := new(disco.BindUDPRelayEndpointChallenge) + m.VNI = e.vni + m.Generation = discoMsg.Generation + m.RemoteKey = e.discoPubKeys[otherSender] + rand.Read(e.challenge[senderIndex][:]) + copy(m.Challenge[:], e.challenge[senderIndex][:]) + reply := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco} + err = gh.Encode(reply) + if err != nil { + return + } + reply = append(reply, disco.Magic...) + reply = serverDisco.AppendTo(reply) + box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) + reply = append(reply, box...) + uw.WriteMsgUDPAddrPort(reply, nil, from) + return case *disco.BindUDPRelayEndpointAnswer: - switch handshakeState { - case disco.BindUDPRelayHandshakeStateChallengeSent: - if from != e.addrPorts[senderIndex] { - // sender source has changed - return - } - if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) { - // bad answer - return - } - // sender is now bound - // TODO: Consider installing a fast path via netfilter or similar to - // relay (NAT) data packets for this serverEndpoint. - e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived - // record last seen as bound time - e.lastSeen[senderIndex] = time.Now() - return - default: - // disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake - // states, or we've already handled it + err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop return } + generation := e.handshakeGeneration[senderIndex] + if generation == 0 || // we have no active handshake + generation != discoMsg.Generation || // mismatching generation for the active handshake + e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake + !bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake + // silently drop + return + } + // Handshake complete. Update the binding for this sender. + e.boundAddrPorts[senderIndex] = from + e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + return default: - // unexpected Disco message type + // unexpected message types, silently drop return } } @@ -225,12 +234,12 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade } var to netip.AddrPort switch { - case from == e.addrPorts[0]: + case from == e.boundAddrPorts[0]: e.lastSeen[0] = time.Now() - to = e.addrPorts[1] - case from == e.addrPorts[1]: + to = e.boundAddrPorts[1] + case from == e.boundAddrPorts[1]: e.lastSeen[1] = time.Now() - to = e.addrPorts[0] + to = e.boundAddrPorts[0] default: // unrecognized source return @@ -240,11 +249,6 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade return } - if e.isBound() { - // control packet, but serverEndpoint is already bound - return - } - if gh.Protocol != packet.GeneveProtocolDisco { // control packet, but not Disco return @@ -267,11 +271,11 @@ func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifet return false } -// isBound returns true if both clients have completed their 3-way handshake, +// isBound returns true if both clients have completed a 3-way handshake, // otherwise false. func (e *serverEndpoint) isBound() bool { - return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived && - e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived + return e.boundAddrPorts[0].IsValid() && + e.boundAddrPorts[1].IsValid() } // NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet @@ -591,8 +595,6 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0]) e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1]) e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:] - rand.Read(e.challenge[0][:]) - rand.Read(e.challenge[1][:]) s.byDisco[pair] = e s.byVNI[e.vni] = e diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index a4e5ca451..3fcb9b8b1 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -19,23 +19,27 @@ import ( ) type testClient struct { - vni uint32 - local key.DiscoPrivate - server key.DiscoPublic - uc *net.UDPConn + vni uint32 + handshakeGeneration uint32 + local key.DiscoPrivate + remote key.DiscoPublic + server key.DiscoPublic + uc *net.UDPConn } -func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient { +func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient { rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} uc, err := net.DialUDP("udp4", nil, rAddr) if err != nil { t.Fatal(err) } return &testClient{ - vni: vni, - local: local, - server: server, - uc: uc, + vni: vni, + handshakeGeneration: 1, + local: local, + remote: remote, + server: server, + uc: uc, } } @@ -137,13 +141,35 @@ func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message { } func (c *testClient) handshake(t *testing.T) { - c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{}) + generation := c.handshakeGeneration + c.handshakeGeneration++ + common := disco.BindUDPRelayEndpointCommon{ + VNI: c.vni, + Generation: generation, + RemoteKey: c.remote, + } + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{ + BindUDPRelayEndpointCommon: common, + }) msg := c.readControlDiscoMsg(t) challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge) if !ok { - t.Fatal("unexepcted disco message type") + t.Fatal("unexpected disco message type") } - c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge}) + if challenge.Generation != common.Generation { + t.Fatalf("rx'd challenge.Generation (%d) != %d", challenge.Generation, common.Generation) + } + if challenge.VNI != common.VNI { + t.Fatalf("rx'd challenge.VNI (%d) != %d", challenge.VNI, common.VNI) + } + if challenge.RemoteKey != common.RemoteKey { + t.Fatalf("rx'd challenge.RemoteKey (%v) != %v", challenge.RemoteKey, common.RemoteKey) + } + answer := &disco.BindUDPRelayEndpointAnswer{ + BindUDPRelayEndpointCommon: common, + } + answer.Challenge = challenge.Challenge + c.writeControlDiscoMsg(t, answer) } func (c *testClient) close() { @@ -179,9 +205,9 @@ func TestServer(t *testing.T) { if len(endpoint.AddrPorts) != 1 { t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) } - tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco) + tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) defer tcA.close() - tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco) + tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) defer tcB.close() tcA.handshake(t) @@ -209,4 +235,30 @@ func TestServer(t *testing.T) { if !bytes.Equal(txToA, rxFromB) { t.Fatal("unexpected msg B->A") } + + tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) + tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1 + defer tcAOnNewPort.close() + + // Handshake client A on a new source IP:port, verify we receive packets on the new binding + tcAOnNewPort.handshake(t) + txToAOnNewPort := []byte{7, 8, 9} + tcB.writeDataPkt(t, txToAOnNewPort) + rxFromB = tcAOnNewPort.readDataPkt(t) + if !bytes.Equal(txToAOnNewPort, rxFromB) { + t.Fatal("unexpected msg B->A") + } + + tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) + tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1 + defer tcBOnNewPort.close() + + // Handshake client B on a new source IP:port, verify we receive packets on the new binding + tcBOnNewPort.handshake(t) + txToBOnNewPort := []byte{7, 8, 9} + tcAOnNewPort.writeDataPkt(t, txToBOnNewPort) + rxFromA = tcBOnNewPort.readDataPkt(t) + if !bytes.Equal(txToBOnNewPort, rxFromA) { + t.Fatal("unexpected msg A->B") + } } diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index 7b378838a..6418a4364 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -45,6 +45,7 @@ type relayManager struct { handshakeWorkByServerDiscoVNI map[serverDiscoVNI]*relayHandshakeWork handshakeWorkAwaitingPong map[*relayHandshakeWork]addrPortVNI addrPortVNIToHandshakeWork map[addrPortVNI]*relayHandshakeWork + handshakeGeneration uint32 // =================================================================== // The following chan fields serve event inputs to a single goroutine, @@ -590,7 +591,12 @@ func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelay go r.sendCallMeMaybeVia(work.ep, work.se) } - go r.handshakeServerEndpoint(work) + r.handshakeGeneration++ + if r.handshakeGeneration == 0 { // generation must be nonzero + r.handshakeGeneration++ + } + + go r.handshakeServerEndpoint(work, r.handshakeGeneration) } // sendCallMeMaybeVia sends a [disco.CallMeMaybeVia] to ep over DERP. It must be @@ -616,7 +622,7 @@ func (r *relayManager) sendCallMeMaybeVia(ep *endpoint, se udprelay.ServerEndpoi ep.c.sendDiscoMessage(epAddr{ap: derpAddr}, ep.publicKey, epDisco.key, callMeMaybeVia, discoVerboseLog) } -func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { +func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork, generation uint32) { done := relayEndpointHandshakeWorkDoneEvent{work: work} r.ensureDiscoInfoFor(work) @@ -627,8 +633,21 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { work.cancel() }() + epDisco := work.ep.disco.Load() + if epDisco == nil { + return + } + + common := disco.BindUDPRelayEndpointCommon{ + VNI: work.se.VNI, + Generation: generation, + RemoteKey: epDisco.key, + } + sentBindAny := false - bind := &disco.BindUDPRelayEndpoint{} + bind := &disco.BindUDPRelayEndpoint{ + BindUDPRelayEndpointCommon: common, + } vni := virtualNetworkID{} vni.set(work.se.VNI) for _, addrPort := range work.se.AddrPorts { @@ -661,10 +680,6 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { if len(sentPingAt) == limitPings { return } - epDisco := work.ep.disco.Load() - if epDisco == nil { - return - } txid := stun.NewTxID() sentPingAt[txid] = time.Now() ping := &disco.Ping{ @@ -673,13 +688,24 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { } go func() { if withAnswer != nil { - answer := &disco.BindUDPRelayEndpointAnswer{Answer: *withAnswer} + answer := &disco.BindUDPRelayEndpointAnswer{BindUDPRelayEndpointCommon: common} + answer.Challenge = *withAnswer work.ep.c.sendDiscoMessage(epAddr{ap: to, vni: vni}, key.NodePublic{}, work.se.ServerDisco, answer, discoVerboseLog) } work.ep.c.sendDiscoMessage(epAddr{ap: to, vni: vni}, key.NodePublic{}, epDisco.key, ping, discoVerboseLog) }() } + validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error { + if common.VNI != work.se.VNI { + return errors.New("mismatching VNI") + } + if common.RemoteKey.Compare(epDisco.key) != 0 { + return errors.New("mismatching RemoteKey") + } + return nil + } + // This for{select{}} is responsible for handshaking and tx'ing ping/pong // when the handshake is complete. for { @@ -689,6 +715,10 @@ func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { case msgEvent := <-work.rxDiscoMsgCh: switch msg := msgEvent.msg.(type) { case *disco.BindUDPRelayEndpointChallenge: + err := validateVNIAndRemoteKey(msg.BindUDPRelayEndpointCommon) + if err != nil { + continue + } if handshakeState >= disco.BindUDPRelayHandshakeStateAnswerSent { continue }