mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-23 17:16:29 +00:00
net/udprelay: implement Server.SetStaticAddrPorts (#17909)
Only used in tests for now. Updates tailscale/corp#31489 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
@@ -8,14 +8,10 @@ package relayserver
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/feature"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnext"
|
||||
@@ -71,8 +67,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r *
|
||||
// imported.
|
||||
func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) {
|
||||
e := &extension{
|
||||
newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
||||
return udprelay.NewServer(logf, port, overrideAddrs)
|
||||
newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||
return udprelay.NewServer(logf, port, onlyStaticAddrPorts)
|
||||
},
|
||||
logf: logger.WithPrefix(logf, featureName+": "),
|
||||
}
|
||||
@@ -94,7 +90,7 @@ type relayServer interface {
|
||||
// extension is an [ipnext.Extension] managing the relay server on platforms
|
||||
// that import this package.
|
||||
type extension struct {
|
||||
newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests
|
||||
newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests
|
||||
logf logger.Logf
|
||||
ec *eventbus.Client
|
||||
respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp]
|
||||
@@ -170,7 +166,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) {
|
||||
}
|
||||
|
||||
func (e *extension) tryStartRelayServerLocked() {
|
||||
rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs())
|
||||
rs, err := e.newServerFn(e.logf, *e.port, false)
|
||||
if err != nil {
|
||||
e.logf("error initializing server: %v", err)
|
||||
return
|
||||
@@ -217,26 +213,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV
|
||||
e.handleRelayServerLifetimeLocked()
|
||||
}
|
||||
|
||||
// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It
|
||||
// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is
|
||||
// not a stable interface, and is subject to change.
|
||||
var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) {
|
||||
all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS")
|
||||
const max = 3
|
||||
remain := all
|
||||
for remain != "" && len(ret) < max {
|
||||
var s string
|
||||
s, remain, _ = strings.Cut(remain, ",")
|
||||
addr, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err)
|
||||
continue
|
||||
}
|
||||
ret = append(ret, addr)
|
||||
}
|
||||
return
|
||||
})
|
||||
|
||||
func (e *extension) stopRelayServerLocked() {
|
||||
if e.rs != nil {
|
||||
e.rs.Close()
|
||||
|
||||
@@ -5,7 +5,6 @@ package relayserver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -157,7 +156,7 @@ func Test_extension_profileStateChanged(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e := ipne.(*extension)
|
||||
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
||||
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||
return &mockRelayServer{}, nil
|
||||
}
|
||||
e.port = tt.fields.port
|
||||
@@ -289,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e := ipne.(*extension)
|
||||
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
||||
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||
return &mockRelayServer{}, nil
|
||||
}
|
||||
e.shutdown = tt.shutdown
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
@@ -72,15 +73,16 @@ type Server struct {
|
||||
closeCh chan struct{}
|
||||
netChecker *netcheck.Client
|
||||
|
||||
mu sync.Mutex // guards the following fields
|
||||
derpMap *tailcfg.DERPMap
|
||||
addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully)
|
||||
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
|
||||
closed bool
|
||||
lamportID uint64
|
||||
nextVNI uint32
|
||||
byVNI map[uint32]*serverEndpoint
|
||||
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
|
||||
mu sync.Mutex // guards the following fields
|
||||
derpMap *tailcfg.DERPMap
|
||||
onlyStaticAddrPorts bool // no dynamic addr port discovery when set
|
||||
staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts]
|
||||
dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs
|
||||
closed bool
|
||||
lamportID uint64
|
||||
nextVNI uint32
|
||||
byVNI map[uint32]*serverEndpoint
|
||||
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool {
|
||||
|
||||
// NewServer constructs a [Server] listening on port. If port is zero, then
|
||||
// port selection is left up to the host networking stack. If
|
||||
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
|
||||
// which is useful to override in tests.
|
||||
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
|
||||
// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be
|
||||
// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be
|
||||
// used.
|
||||
func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) {
|
||||
s = &Server{
|
||||
logf: logf,
|
||||
disco: key.NewDisco(),
|
||||
bindLifetime: defaultBindLifetime,
|
||||
steadyStateLifetime: defaultSteadyStateLifetime,
|
||||
closeCh: make(chan struct{}),
|
||||
onlyStaticAddrPorts: onlyStaticAddrPorts,
|
||||
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
|
||||
nextVNI: minVNI,
|
||||
byVNI: make(map[uint32]*serverEndpoint),
|
||||
@@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(overrideAddrs) > 0 {
|
||||
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
|
||||
for _, addr := range overrideAddrs {
|
||||
if addr.IsValid() {
|
||||
if addr.Is4() {
|
||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
|
||||
} else if s.uc6 != nil {
|
||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
|
||||
}
|
||||
}
|
||||
}
|
||||
s.addrPorts = addrPorts.Slice()
|
||||
} else {
|
||||
if !s.onlyStaticAddrPorts {
|
||||
s.wg.Add(1)
|
||||
go s.addrDiscoveryLoop()
|
||||
}
|
||||
@@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() {
|
||||
s.logf("error discovering IP:port candidates: %v", err)
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.addrPorts = addrPorts
|
||||
s.addrDiscoveryOnce = true
|
||||
s.dynamicAddrPorts = addrPorts
|
||||
s.mu.Unlock()
|
||||
case <-s.closeCh:
|
||||
return
|
||||
@@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) {
|
||||
return 0, errors.New("VNI pool exhausted")
|
||||
}
|
||||
|
||||
// getAllAddrPortsCopyLocked returns a copy of the combined
|
||||
// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices.
|
||||
func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort {
|
||||
addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len())
|
||||
addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...)
|
||||
addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...)
|
||||
return addrPorts
|
||||
}
|
||||
|
||||
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
|
||||
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
|
||||
// it is returned without modification/reallocation. AllocateEndpoint returns
|
||||
@@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
||||
return endpoint.ServerEndpoint{}, ErrServerClosed
|
||||
}
|
||||
|
||||
if len(s.addrPorts) == 0 {
|
||||
if !s.addrDiscoveryOnce {
|
||||
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
|
||||
}
|
||||
return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
|
||||
if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 {
|
||||
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
|
||||
}
|
||||
|
||||
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
|
||||
@@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
||||
// consider storing them (maybe interning) in the [*serverEndpoint]
|
||||
// at allocation time.
|
||||
ClientDisco: pair.Get(),
|
||||
AddrPorts: slices.Clone(s.addrPorts),
|
||||
AddrPorts: s.getAllAddrPortsCopyLocked(),
|
||||
VNI: e.vni,
|
||||
LamportID: e.lamportID,
|
||||
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
||||
@@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
||||
return endpoint.ServerEndpoint{
|
||||
ServerDisco: s.discoPublic,
|
||||
ClientDisco: pair.Get(),
|
||||
AddrPorts: slices.Clone(s.addrPorts),
|
||||
AddrPorts: s.getAllAddrPortsCopyLocked(),
|
||||
VNI: e.vni,
|
||||
LamportID: e.lamportID,
|
||||
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
||||
@@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap {
|
||||
defer s.mu.Unlock()
|
||||
return s.derpMap
|
||||
}
|
||||
|
||||
// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise
|
||||
// as candidates it is potentially reachable over, in combination with
|
||||
// dynamically discovered pairs. This replaces any previously-provided static
|
||||
// values.
|
||||
func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.staticAddrPorts = addrPorts
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"tailscale.com/disco"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
type testClient struct {
|
||||
@@ -185,31 +186,40 @@ func TestServer(t *testing.T) {
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
overrideAddrs []netip.Addr
|
||||
staticAddrs []netip.Addr
|
||||
forceClientsMixedAF bool
|
||||
}{
|
||||
{
|
||||
name: "over ipv4",
|
||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
||||
name: "over ipv4",
|
||||
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
||||
},
|
||||
{
|
||||
name: "over ipv6",
|
||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
|
||||
name: "over ipv6",
|
||||
staticAddrs: []netip.Addr{netip.MustParseAddr("::1")},
|
||||
},
|
||||
{
|
||||
name: "mixed address families",
|
||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
|
||||
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
|
||||
forceClientsMixedAF: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
|
||||
server, err := NewServer(t.Logf, 0, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer server.Close()
|
||||
addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs))
|
||||
for _, addr := range tt.staticAddrs {
|
||||
if addr.Is4() {
|
||||
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port))
|
||||
} else if server.uc6Port != 0 {
|
||||
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port))
|
||||
}
|
||||
}
|
||||
server.SetStaticAddrPorts(views.SliceOf(addrPorts))
|
||||
|
||||
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user