net/udprelay: implement Server.SetStaticAddrPorts (#17909)

Only used in tests for now.

Updates tailscale/corp#31489

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-11-14 19:43:44 -08:00
committed by GitHub
parent a96ef432cf
commit e1f0ad7a05
4 changed files with 64 additions and 72 deletions

View File

@@ -36,6 +36,7 @@ import (
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/types/views"
"tailscale.com/util/eventbus"
"tailscale.com/util/set"
)
@@ -72,15 +73,16 @@ type Server struct {
closeCh chan struct{}
netChecker *netcheck.Client
mu sync.Mutex // guards the following fields
derpMap *tailcfg.DERPMap
addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully)
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
closed bool
lamportID uint64
nextVNI uint32
byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
mu sync.Mutex // guards the following fields
derpMap *tailcfg.DERPMap
onlyStaticAddrPorts bool // no dynamic addr port discovery when set
staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts]
dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs
closed bool
lamportID uint64
nextVNI uint32
byVNI map[uint32]*serverEndpoint
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
}
const (
@@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool {
// NewServer constructs a [Server] listening on port. If port is zero, then
// port selection is left up to the host networking stack. If
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
// which is useful to override in tests.
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be
// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be
// used.
func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) {
s = &Server{
logf: logf,
disco: key.NewDisco(),
bindLifetime: defaultBindLifetime,
steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}),
onlyStaticAddrPorts: onlyStaticAddrPorts,
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
@@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
return nil, err
}
if len(overrideAddrs) > 0 {
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
for _, addr := range overrideAddrs {
if addr.IsValid() {
if addr.Is4() {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
} else if s.uc6 != nil {
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
}
}
}
s.addrPorts = addrPorts.Slice()
} else {
if !s.onlyStaticAddrPorts {
s.wg.Add(1)
go s.addrDiscoveryLoop()
}
@@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() {
s.logf("error discovering IP:port candidates: %v", err)
}
s.mu.Lock()
s.addrPorts = addrPorts
s.addrDiscoveryOnce = true
s.dynamicAddrPorts = addrPorts
s.mu.Unlock()
case <-s.closeCh:
return
@@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) {
return 0, errors.New("VNI pool exhausted")
}
// getAllAddrPortsCopyLocked returns a copy of the combined
// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices.
func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort {
addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len())
addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...)
addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...)
return addrPorts
}
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
// it is returned without modification/reallocation. AllocateEndpoint returns
@@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint.ServerEndpoint{}, ErrServerClosed
}
if len(s.addrPorts) == 0 {
if !s.addrDiscoveryOnce {
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
}
return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 {
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
}
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
@@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
// consider storing them (maybe interning) in the [*serverEndpoint]
// at allocation time.
ClientDisco: pair.Get(),
AddrPorts: slices.Clone(s.addrPorts),
AddrPorts: s.getAllAddrPortsCopyLocked(),
VNI: e.vni,
LamportID: e.lamportID,
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
return endpoint.ServerEndpoint{
ServerDisco: s.discoPublic,
ClientDisco: pair.Get(),
AddrPorts: slices.Clone(s.addrPorts),
AddrPorts: s.getAllAddrPortsCopyLocked(),
VNI: e.vni,
LamportID: e.lamportID,
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
@@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap {
defer s.mu.Unlock()
return s.derpMap
}
// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise
// as candidates it is potentially reachable over, in combination with
// dynamically discovered pairs. This replaces any previously-provided static
// values.
func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) {
s.mu.Lock()
defer s.mu.Unlock()
s.staticAddrPorts = addrPorts
}