diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 96d21138e..a38587aa3 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -10,7 +10,6 @@ import ( "errors" "io" "net/http" - "net/netip" "sync" "tailscale.com/envknob" @@ -136,7 +135,7 @@ func (e *extension) relayServerOrInit() (relayServer, error) { return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set") } var err error - e.server, _, err = udprelay.NewServer(*e.port, []netip.Addr{netip.MustParseAddr("127.0.0.1")}) + e.server, _, err = udprelay.NewServer(e.logf, *e.port, nil) if err != nil { return nil, err } diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index c9f03966b..54627f713 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -753,6 +753,7 @@ func newReport() *Report { // GetReportOpts contains options that can be passed to GetReport. Unless // specified, all fields are optional and can be left as their zero value. +// At most one of OnlyTCP443 or OnlySTUN may be set. type GetReportOpts struct { // GetLastDERPActivity is a callback that, if provided, should return // the absolute time that the calling code last communicated with a @@ -765,6 +766,8 @@ type GetReportOpts struct { // OnlyTCP443 constrains netcheck reporting to measurements over TCP port // 443. OnlyTCP443 bool + // OnlySTUN constrains netcheck reporting to STUN measurements over UDP. + OnlySTUN bool } // getLastDERPActivity calls o.GetLastDERPActivity if both o and @@ -790,6 +793,13 @@ func (c *Client) SetForcePreferredDERP(region int) { // // It may not be called concurrently with itself. func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetReportOpts) (_ *Report, reterr error) { + onlySTUN := false + if opts != nil && opts.OnlySTUN { + if opts.OnlyTCP443 { + return nil, errors.New("netcheck: only one of OnlySTUN or OnlyTCP443 may be set in opts") + } + onlySTUN = true + } defer func() { if reterr != nil { metricNumGetReportError.Add(1) @@ -865,6 +875,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe }() if runtime.GOOS == "js" || runtime.GOOS == "tamago" || (runtime.GOOS == "plan9" && hostinfo.IsInVM86()) { + if onlySTUN { + return nil, errors.New("platform is restricted to HTTP, but OnlySTUN is set in opts") + } if err := c.runHTTPOnlyChecks(ctx, last, rs, dm); err != nil { return nil, err } @@ -896,7 +909,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe // it's unnecessary. captivePortalDone := syncs.ClosedChan() captivePortalStop := func() {} - if !rs.incremental { + if !rs.incremental && !onlySTUN { // NOTE(andrew): we can't simply add this goroutine to the // `NewWaitGroupChan` below, since we don't wait for that // waitgroup to finish when exiting this function and thus get @@ -970,9 +983,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe rs.stopTimers() // Try HTTPS and ICMP latency check if all STUN probes failed due to - // UDP presumably being blocked. + // UDP presumably being blocked, and we are not constrained to only STUN. // TODO: this should be moved into the probePlan, using probeProto probeHTTPS. - if !rs.anyUDP() && ctx.Err() == nil { + if !rs.anyUDP() && ctx.Err() == nil && !onlySTUN { var wg sync.WaitGroup var need []*tailcfg.DERPRegion for rid, reg := range dm.Regions { diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 7b63ec95e..f7f5868c0 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -8,6 +8,7 @@ package udprelay import ( "bytes" + "context" "crypto/rand" "errors" "fmt" @@ -19,11 +20,18 @@ import ( "time" "go4.org/mem" + "tailscale.com/client/local" "tailscale.com/disco" + "tailscale.com/net/netcheck" + "tailscale.com/net/netmon" "tailscale.com/net/packet" + "tailscale.com/net/stun" "tailscale.com/net/udprelay/endpoint" "tailscale.com/tstime" "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/util/set" ) const ( @@ -42,25 +50,22 @@ const ( // Server implements an experimental UDP relay server. type Server struct { - // disco keypair used as part of 3-way bind handshake - disco key.DiscoPrivate - discoPublic key.DiscoPublic - + // The following fields are initialized once and never mutated. + logf logger.Logf + disco key.DiscoPrivate + discoPublic key.DiscoPublic bindLifetime time.Duration steadyStateLifetime time.Duration + bus *eventbus.Bus + uc *net.UDPConn + closeOnce sync.Once + wg sync.WaitGroup + closeCh chan struct{} + netChecker *netcheck.Client - // addrPorts contains the ip:port pairs returned as candidate server - // endpoints in response to an allocation request. - addrPorts []netip.AddrPort - - uc *net.UDPConn - - closeOnce sync.Once - wg sync.WaitGroup - closeCh chan struct{} + mu sync.Mutex // guards the following fields + addrPorts []netip.AddrPort // the ip:port pairs returned as candidate endpoints closed bool - - mu sync.Mutex // guards the following fields lamportID uint64 vniPool []uint32 // the pool of available VNIs byVNI map[uint32]*serverEndpoint @@ -270,14 +275,13 @@ func (e *serverEndpoint) isBound() bool { // NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet // supported. Port may be 0, and what ultimately gets bound is returned as -// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as -// [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint() -// requests. +// 'boundPort'. If len(overrideAddrs) > 0 these will be used in place of dynamic +// discovery, which is useful to override in tests. // // TODO: IPv6 support -// TODO: dynamic addrs:port discovery -func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) { +func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Server, boundPort uint16, err error) { s = &Server{ + logf: logger.WithPrefix(logf, "relayserver"), disco: key.NewDisco(), bindLifetime: defaultBindLifetime, steadyStateLifetime: defaultSteadyStateLifetime, @@ -292,26 +296,120 @@ func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err erro for i := 1; i < 1<<24; i++ { s.vniPool = append(s.vniPool, uint32(i)) } + + bus := eventbus.New() + s.bus = bus + netMon, err := netmon.New(s.bus, logf) + if err != nil { + return nil, 0, err + } + s.netChecker = &netcheck.Client{ + NetMon: netMon, + Logf: logger.WithPrefix(logf, "relayserver: netcheck:"), + SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { + return s.uc.WriteToUDPAddrPort(b, addrPort) + }, + } + boundPort, err = s.listenOn(port) if err != nil { return nil, 0, err } - addrPorts := make([]netip.AddrPort, 0, len(addrs)) - for _, addr := range addrs { - addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort))) - if err != nil { - return nil, 0, err - } - addrPorts = append(addrPorts, addrPort) - } - s.addrPorts = addrPorts - s.wg.Add(2) + + s.wg.Add(1) go s.packetReadLoop() + s.wg.Add(1) go s.endpointGCLoop() + if len(overrideAddrs) > 0 { + var addrPorts set.Set[netip.AddrPort] + addrPorts.Make() + for _, addr := range overrideAddrs { + if addr.IsValid() { + addrPorts.Add(netip.AddrPortFrom(addr, boundPort)) + } + } + s.addrPorts = addrPorts.Slice() + } else { + s.wg.Add(1) + go s.addrDiscoveryLoop() + } return s, boundPort, nil } -func (s *Server) listenOn(port int) (int, error) { +func (s *Server) addrDiscoveryLoop() { + defer s.wg.Done() + + timer := time.NewTimer(0) // fire immediately + defer timer.Stop() + + getAddrPorts := func() ([]netip.AddrPort, error) { + var addrPorts set.Set[netip.AddrPort] + addrPorts.Make() + + // get local addresses + localPort := s.uc.LocalAddr().(*net.UDPAddr).Port + ips, _, err := netmon.LocalAddresses() + if err != nil { + return nil, err + } + for _, ip := range ips { + if ip.IsValid() { + addrPorts.Add(netip.AddrPortFrom(ip, uint16(localPort))) + } + } + + // fetch DERPMap to feed to netcheck + derpMapCtx, derpMapCancel := context.WithTimeout(context.Background(), time.Second) + defer derpMapCancel() + localClient := &local.Client{} + // TODO(jwhited): We are in-process so use eventbus or similar. + // local.Client gets us going. + dm, err := localClient.CurrentDERPMap(derpMapCtx) + if err != nil { + return nil, err + } + + // get addrPorts as visible from DERP + netCheckerCtx, netCheckerCancel := context.WithTimeout(context.Background(), netcheck.ReportTimeout) + defer netCheckerCancel() + rep, err := s.netChecker.GetReport(netCheckerCtx, dm, &netcheck.GetReportOpts{ + OnlySTUN: true, + }) + if err != nil { + return nil, err + } + if rep.GlobalV4.IsValid() { + addrPorts.Add(rep.GlobalV4) + } + if rep.GlobalV6.IsValid() { + addrPorts.Add(rep.GlobalV6) + } + // TODO(jwhited): consider logging if rep.MappingVariesByDestIP as + // that's a hint we are not well-positioned to operate as a UDP relay. + return addrPorts.Slice(), nil + } + + for { + select { + case <-timer.C: + // Mirror magicsock behavior for duration between STUN. We consider + // 30s a min bound for NAT timeout. + timer.Reset(tstime.RandomDurationBetween(20*time.Second, 26*time.Second)) + addrPorts, err := getAddrPorts() + if err != nil { + s.logf("error discovering IP:port candidates: %v", err) + } + s.mu.Lock() + s.addrPorts = addrPorts + s.mu.Unlock() + case <-s.closeCh: + return + } + } + +} + +func (s *Server) listenOn(port int) (uint16, error) { uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) if err != nil { return 0, err @@ -322,13 +420,13 @@ func (s *Server) listenOn(port int) (int, error) { s.uc.Close() return 0, err } - boundPort, err := strconv.Atoi(boundPortStr) + boundPort, err := strconv.ParseUint(boundPortStr, 10, 16) if err != nil { s.uc.Close() return 0, err } s.uc = uc - return boundPort, nil + return uint16(boundPort), nil } // Close closes the server. @@ -343,6 +441,7 @@ func (s *Server) Close() error { clear(s.byDisco) s.vniPool = nil s.closed = true + s.bus.Close() }) return nil } @@ -378,6 +477,13 @@ func (s *Server) endpointGCLoop() { } func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { + if stun.Is(b) && b[1] == 0x01 { + // A b[1] value of 0x01 (STUN method binding) is sufficiently + // non-overlapping with the Geneve header where the LSB is always 0 + // (part of 6 "reserved" bits). + s.netChecker.ReceiveSTUNPacket(b, from) + return + } gh := packet.GeneveHeader{} err := gh.Decode(b) if err != nil { @@ -426,6 +532,10 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{}, ErrServerClosed } + if len(s.addrPorts) == 0 { + return endpoint.ServerEndpoint{}, errors.New("server addrPorts are not yet known") + } + if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString()) } @@ -439,8 +549,13 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt // to give the client a more accurate picture of the bind window. return endpoint.ServerEndpoint{ - ServerDisco: s.discoPublic, - AddrPorts: s.addrPorts, + ServerDisco: s.discoPublic, + // Returning the "latest" addrPorts for an existing allocation is + // the simple choice. It may not be the best depending on client + // behaviors and endpoint state (bound or not). We might want to + // consider storing them (maybe interning) in the [*serverEndpoint] + // at allocation time. + AddrPorts: slices.Clone(s.addrPorts), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, @@ -469,7 +584,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv return endpoint.ServerEndpoint{ ServerDisco: s.discoPublic, - AddrPorts: s.addrPorts, + AddrPorts: slices.Clone(s.addrPorts), VNI: e.vni, LamportID: e.lamportID, BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 38c7ae5d9..a4e5ca451 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -156,7 +156,7 @@ func TestServer(t *testing.T) { ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") - server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr}) + server, _, err := NewServer(t.Logf, 0, []netip.Addr{ipv4LoopbackAddr}) if err != nil { t.Fatal(err) }