diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 373165777..5580b6e65 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -454,9 +454,10 @@ func (s *Server) packetReadLoop() { var ErrServerClosed = errors.New("server closed") -// AllocateEndpoint allocates a ServerEndpoint for the provided pair of -// key.DiscoPublic's. It returns an error (ErrServerClosed) if the server has -// been closed. +// AllocateEndpoint allocates a [ServerEndpoint] for the provided pair of +// [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB it +// is returned without modification/reallocation. AllocateEndpoint returns +// [ErrServerClosed] if the server has been closed. func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoint, error) { s.mu.Lock() defer s.mu.Unlock() @@ -471,36 +472,19 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (ServerEndpoin pair := newPairOfDiscoPubKeys(discoA, discoB) e, ok := s.byDisco[pair] if ok { - if !e.isBound() { - // If the endpoint is not yet bound this is likely an allocation - // race between two clients on the same Server. Instead of - // re-allocating we return the existing allocation. We do not reset - // e.allocatedAt in case a client is "stuck" in an allocation - // loop and will not be able to complete a handshake, for whatever - // reason. Once the endpoint expires a new endpoint will be - // allocated. Clients can resolve duplicate ServerEndpoint details - // via ServerEndpoint.LamportID. - // - // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt - // to give the client a more accurate picture of the bind window. - // Or, some threshold to trigger re-allocation if too much time has - // already passed since it was originally allocated. - return ServerEndpoint{ - ServerDisco: s.discoPublic, - AddrPorts: s.addrPorts, - VNI: e.vni, - LamportID: e.lamportID, - BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, - SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, - }, nil - } - // If an endpoint exists for the pair of key.DiscoPublic's, and is - // already bound, delete it. We will re-allocate a new endpoint. Chances - // are clients cannot make use of the existing, bound allocation if - // they are requesting a new one. - delete(s.byDisco, pair) - delete(s.byVNI, e.vni) - s.vniPool = append(s.vniPool, e.vni) + // Return the existing allocation. Clients can resolve duplicate + // [ServerEndpoint]'s via [ServerEndpoint.LamportID]. + // + // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt + // to give the client a more accurate picture of the bind window. + return ServerEndpoint{ + ServerDisco: s.discoPublic, + AddrPorts: s.addrPorts, + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, + }, nil } if len(s.vniPool) == 0 { diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index fad35ec03..c699e5d15 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -174,8 +174,7 @@ func TestServer(t *testing.T) { t.Fatal(err) } - // We expect the same endpoint details as the 3-way bind handshake has not - // yet been completed for both relay client parties. + // We expect the same endpoint details pre-handshake. if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) } @@ -191,6 +190,15 @@ func TestServer(t *testing.T) { tcA.handshake(t) tcB.handshake(t) + dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + // We expect the same endpoint details post-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + txToB := []byte{1, 2, 3} tcA.writeDataPkt(t, txToB) rxFromA := tcB.readDataPkt(t)