mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-23 17:16:29 +00:00
net/udprelay: implement Server.SetStaticAddrPorts (#17909)
Only used in tests for now. Updates tailscale/corp#31489 Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
@@ -8,14 +8,10 @@ package relayserver
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"tailscale.com/disco"
|
"tailscale.com/disco"
|
||||||
"tailscale.com/envknob"
|
|
||||||
"tailscale.com/feature"
|
"tailscale.com/feature"
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/ipn/ipnext"
|
"tailscale.com/ipn/ipnext"
|
||||||
@@ -71,8 +67,8 @@ func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r *
|
|||||||
// imported.
|
// imported.
|
||||||
func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) {
|
func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) {
|
||||||
e := &extension{
|
e := &extension{
|
||||||
newServerFn: func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||||
return udprelay.NewServer(logf, port, overrideAddrs)
|
return udprelay.NewServer(logf, port, onlyStaticAddrPorts)
|
||||||
},
|
},
|
||||||
logf: logger.WithPrefix(logf, featureName+": "),
|
logf: logger.WithPrefix(logf, featureName+": "),
|
||||||
}
|
}
|
||||||
@@ -94,7 +90,7 @@ type relayServer interface {
|
|||||||
// extension is an [ipnext.Extension] managing the relay server on platforms
|
// extension is an [ipnext.Extension] managing the relay server on platforms
|
||||||
// that import this package.
|
// that import this package.
|
||||||
type extension struct {
|
type extension struct {
|
||||||
newServerFn func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) // swappable for tests
|
newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
ec *eventbus.Client
|
ec *eventbus.Client
|
||||||
respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp]
|
respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp]
|
||||||
@@ -170,7 +166,7 @@ func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *extension) tryStartRelayServerLocked() {
|
func (e *extension) tryStartRelayServerLocked() {
|
||||||
rs, err := e.newServerFn(e.logf, *e.port, overrideAddrs())
|
rs, err := e.newServerFn(e.logf, *e.port, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
e.logf("error initializing server: %v", err)
|
e.logf("error initializing server: %v", err)
|
||||||
return
|
return
|
||||||
@@ -217,26 +213,6 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV
|
|||||||
e.handleRelayServerLifetimeLocked()
|
e.handleRelayServerLifetimeLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// overrideAddrs returns TS_DEBUG_RELAY_SERVER_ADDRS as []netip.Addr, if set. It
|
|
||||||
// can be between 0 and 3 comma-separated Addrs. TS_DEBUG_RELAY_SERVER_ADDRS is
|
|
||||||
// not a stable interface, and is subject to change.
|
|
||||||
var overrideAddrs = sync.OnceValue(func() (ret []netip.Addr) {
|
|
||||||
all := envknob.String("TS_DEBUG_RELAY_SERVER_ADDRS")
|
|
||||||
const max = 3
|
|
||||||
remain := all
|
|
||||||
for remain != "" && len(ret) < max {
|
|
||||||
var s string
|
|
||||||
s, remain, _ = strings.Cut(remain, ",")
|
|
||||||
addr, err := netip.ParseAddr(s)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("ignoring invalid Addr %q in TS_DEBUG_RELAY_SERVER_ADDRS %q: %v", s, all, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ret = append(ret, addr)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
})
|
|
||||||
|
|
||||||
func (e *extension) stopRelayServerLocked() {
|
func (e *extension) stopRelayServerLocked() {
|
||||||
if e.rs != nil {
|
if e.rs != nil {
|
||||||
e.rs.Close()
|
e.rs.Close()
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ package relayserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -157,7 +156,7 @@ func Test_extension_profileStateChanged(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
e := ipne.(*extension)
|
e := ipne.(*extension)
|
||||||
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||||
return &mockRelayServer{}, nil
|
return &mockRelayServer{}, nil
|
||||||
}
|
}
|
||||||
e.port = tt.fields.port
|
e.port = tt.fields.port
|
||||||
@@ -289,7 +288,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
e := ipne.(*extension)
|
e := ipne.(*extension)
|
||||||
e.newServerFn = func(logf logger.Logf, port int, overrideAddrs []netip.Addr) (relayServer, error) {
|
e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) {
|
||||||
return &mockRelayServer{}, nil
|
return &mockRelayServer{}, nil
|
||||||
}
|
}
|
||||||
e.shutdown = tt.shutdown
|
e.shutdown = tt.shutdown
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import (
|
|||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/types/nettype"
|
"tailscale.com/types/nettype"
|
||||||
|
"tailscale.com/types/views"
|
||||||
"tailscale.com/util/eventbus"
|
"tailscale.com/util/eventbus"
|
||||||
"tailscale.com/util/set"
|
"tailscale.com/util/set"
|
||||||
)
|
)
|
||||||
@@ -72,15 +73,16 @@ type Server struct {
|
|||||||
closeCh chan struct{}
|
closeCh chan struct{}
|
||||||
netChecker *netcheck.Client
|
netChecker *netcheck.Client
|
||||||
|
|
||||||
mu sync.Mutex // guards the following fields
|
mu sync.Mutex // guards the following fields
|
||||||
derpMap *tailcfg.DERPMap
|
derpMap *tailcfg.DERPMap
|
||||||
addrDiscoveryOnce bool // addrDiscovery completed once (successfully or unsuccessfully)
|
onlyStaticAddrPorts bool // no dynamic addr port discovery when set
|
||||||
addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints
|
staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts]
|
||||||
closed bool
|
dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs
|
||||||
lamportID uint64
|
closed bool
|
||||||
nextVNI uint32
|
lamportID uint64
|
||||||
byVNI map[uint32]*serverEndpoint
|
nextVNI uint32
|
||||||
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
|
byVNI map[uint32]*serverEndpoint
|
||||||
|
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -278,15 +280,17 @@ func (e *serverEndpoint) isBound() bool {
|
|||||||
|
|
||||||
// NewServer constructs a [Server] listening on port. If port is zero, then
|
// NewServer constructs a [Server] listening on port. If port is zero, then
|
||||||
// port selection is left up to the host networking stack. If
|
// port selection is left up to the host networking stack. If
|
||||||
// len(overrideAddrs) > 0 these will be used in place of dynamic discovery,
|
// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be
|
||||||
// which is useful to override in tests.
|
// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be
|
||||||
func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, err error) {
|
// used.
|
||||||
|
func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) {
|
||||||
s = &Server{
|
s = &Server{
|
||||||
logf: logf,
|
logf: logf,
|
||||||
disco: key.NewDisco(),
|
disco: key.NewDisco(),
|
||||||
bindLifetime: defaultBindLifetime,
|
bindLifetime: defaultBindLifetime,
|
||||||
steadyStateLifetime: defaultSteadyStateLifetime,
|
steadyStateLifetime: defaultSteadyStateLifetime,
|
||||||
closeCh: make(chan struct{}),
|
closeCh: make(chan struct{}),
|
||||||
|
onlyStaticAddrPorts: onlyStaticAddrPorts,
|
||||||
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
|
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
|
||||||
nextVNI: minVNI,
|
nextVNI: minVNI,
|
||||||
byVNI: make(map[uint32]*serverEndpoint),
|
byVNI: make(map[uint32]*serverEndpoint),
|
||||||
@@ -321,19 +325,7 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(overrideAddrs) > 0 {
|
if !s.onlyStaticAddrPorts {
|
||||||
addrPorts := make(set.Set[netip.AddrPort], len(overrideAddrs))
|
|
||||||
for _, addr := range overrideAddrs {
|
|
||||||
if addr.IsValid() {
|
|
||||||
if addr.Is4() {
|
|
||||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc4Port))
|
|
||||||
} else if s.uc6 != nil {
|
|
||||||
addrPorts.Add(netip.AddrPortFrom(addr, s.uc6Port))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.addrPorts = addrPorts.Slice()
|
|
||||||
} else {
|
|
||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go s.addrDiscoveryLoop()
|
go s.addrDiscoveryLoop()
|
||||||
}
|
}
|
||||||
@@ -429,8 +421,7 @@ func (s *Server) addrDiscoveryLoop() {
|
|||||||
s.logf("error discovering IP:port candidates: %v", err)
|
s.logf("error discovering IP:port candidates: %v", err)
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.addrPorts = addrPorts
|
s.dynamicAddrPorts = addrPorts
|
||||||
s.addrDiscoveryOnce = true
|
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
case <-s.closeCh:
|
case <-s.closeCh:
|
||||||
return
|
return
|
||||||
@@ -747,6 +738,15 @@ func (s *Server) getNextVNILocked() (uint32, error) {
|
|||||||
return 0, errors.New("VNI pool exhausted")
|
return 0, errors.New("VNI pool exhausted")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAllAddrPortsCopyLocked returns a copy of the combined
|
||||||
|
// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices.
|
||||||
|
func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort {
|
||||||
|
addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len())
|
||||||
|
addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...)
|
||||||
|
addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...)
|
||||||
|
return addrPorts
|
||||||
|
}
|
||||||
|
|
||||||
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
|
// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair
|
||||||
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
|
// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB
|
||||||
// it is returned without modification/reallocation. AllocateEndpoint returns
|
// it is returned without modification/reallocation. AllocateEndpoint returns
|
||||||
@@ -760,11 +760,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|||||||
return endpoint.ServerEndpoint{}, ErrServerClosed
|
return endpoint.ServerEndpoint{}, ErrServerClosed
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(s.addrPorts) == 0 {
|
if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 {
|
||||||
if !s.addrDiscoveryOnce {
|
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
|
||||||
return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter}
|
|
||||||
}
|
|
||||||
return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
|
if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 {
|
||||||
@@ -787,7 +784,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|||||||
// consider storing them (maybe interning) in the [*serverEndpoint]
|
// consider storing them (maybe interning) in the [*serverEndpoint]
|
||||||
// at allocation time.
|
// at allocation time.
|
||||||
ClientDisco: pair.Get(),
|
ClientDisco: pair.Get(),
|
||||||
AddrPorts: slices.Clone(s.addrPorts),
|
AddrPorts: s.getAllAddrPortsCopyLocked(),
|
||||||
VNI: e.vni,
|
VNI: e.vni,
|
||||||
LamportID: e.lamportID,
|
LamportID: e.lamportID,
|
||||||
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
||||||
@@ -817,7 +814,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
|
|||||||
return endpoint.ServerEndpoint{
|
return endpoint.ServerEndpoint{
|
||||||
ServerDisco: s.discoPublic,
|
ServerDisco: s.discoPublic,
|
||||||
ClientDisco: pair.Get(),
|
ClientDisco: pair.Get(),
|
||||||
AddrPorts: slices.Clone(s.addrPorts),
|
AddrPorts: s.getAllAddrPortsCopyLocked(),
|
||||||
VNI: e.vni,
|
VNI: e.vni,
|
||||||
LamportID: e.lamportID,
|
LamportID: e.lamportID,
|
||||||
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
BindLifetime: tstime.GoDuration{Duration: s.bindLifetime},
|
||||||
@@ -880,3 +877,13 @@ func (s *Server) getDERPMap() *tailcfg.DERPMap {
|
|||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
return s.derpMap
|
return s.derpMap
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise
|
||||||
|
// as candidates it is potentially reachable over, in combination with
|
||||||
|
// dynamically discovered pairs. This replaces any previously-provided static
|
||||||
|
// values.
|
||||||
|
func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.staticAddrPorts = addrPorts
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"tailscale.com/disco"
|
"tailscale.com/disco"
|
||||||
"tailscale.com/net/packet"
|
"tailscale.com/net/packet"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
|
"tailscale.com/types/views"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testClient struct {
|
type testClient struct {
|
||||||
@@ -185,31 +186,40 @@ func TestServer(t *testing.T) {
|
|||||||
|
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
overrideAddrs []netip.Addr
|
staticAddrs []netip.Addr
|
||||||
forceClientsMixedAF bool
|
forceClientsMixedAF bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "over ipv4",
|
name: "over ipv4",
|
||||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "over ipv6",
|
name: "over ipv6",
|
||||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")},
|
staticAddrs: []netip.Addr{netip.MustParseAddr("::1")},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed address families",
|
name: "mixed address families",
|
||||||
overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
|
staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")},
|
||||||
forceClientsMixedAF: true,
|
forceClientsMixedAF: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
server, err := NewServer(t.Logf, 0, tt.overrideAddrs)
|
server, err := NewServer(t.Logf, 0, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs))
|
||||||
|
for _, addr := range tt.staticAddrs {
|
||||||
|
if addr.Is4() {
|
||||||
|
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port))
|
||||||
|
} else if server.uc6Port != 0 {
|
||||||
|
addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
server.SetStaticAddrPorts(views.SliceOf(addrPorts))
|
||||||
|
|
||||||
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user