diff --git a/cmd/tailscale/netcheck.go b/cmd/tailscale/netcheck.go index c58103dce..7de01f72a 100644 --- a/cmd/tailscale/netcheck.go +++ b/cmd/tailscale/netcheck.go @@ -9,6 +9,7 @@ "fmt" "log" "sort" + "time" "github.com/peterbourgon/ff/v2/ffcli" "tailscale.com/derp/derpmap" @@ -26,12 +27,12 @@ func runNetcheck(ctx context.Context, args []string) error { c := &netcheck.Client{ - DERP: derpmap.Prod(), Logf: logger.WithPrefix(log.Printf, "netcheck: "), DNSCache: dnscache.Get(), } - report, err := c.GetReport(ctx) + dm := derpmap.Prod() + report, err := c.GetReport(ctx, dm) if err != nil { log.Fatalf("netcheck: %v", err) } @@ -55,18 +56,23 @@ func runNetcheck(ctx context.Context, args []string) error { // When DERP latency checking failed, // magicsock will try to pick the DERP server that // most of your other nodes are also using - if len(report.DERPLatency) == 0 { + if len(report.RegionLatency) == 0 { fmt.Printf("\t* Nearest DERP: unknown (no response to latency probes)\n") } else { - fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, c.DERP.LocationOfID(report.PreferredDERP)) + fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, dm.Regions[report.PreferredDERP].RegionCode) fmt.Printf("\t* DERP latency:\n") - var ss []string - for s := range report.DERPLatency { - ss = append(ss, s) + var rids []int + for rid := range dm.Regions { + rids = append(rids, rid) } - sort.Strings(ss) - for _, s := range ss { - fmt.Printf("\t\t- %s = %v\n", s, report.DERPLatency[s]) + sort.Ints(rids) + for _, rid := range rids { + d, ok := report.RegionLatency[rid] + var latency string + if ok { + latency = d.Round(time.Millisecond / 10).String() + } + fmt.Printf("\t\t- %v, %3s = %s\n", rid, dm.Regions[rid].RegionCode, latency) } } return nil diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 0e7afb6f8..9078e9374 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -541,6 +541,8 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM } }() + var lastDERPMap *tailcfg.DERPMap + // If allowStream, then the server will use an HTTP long poll to // return incremental results. There is always one response right // away, followed by a delay, and eventually others. @@ -582,6 +584,11 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM } vlogf("netmap: got new map") + if resp.DERPMap != nil { + vlogf("netmap: new map contains DERP map") + lastDERPMap = resp.DERPMap + } + nm := &NetworkMap{ NodeKey: tailcfg.NodeKey(persist.PrivateNodeKey.Public()), PrivateKey: persist.PrivateNodeKey, @@ -597,6 +604,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkM DNSDomains: resp.SearchPaths, Hostinfo: resp.Node.Hostinfo, PacketFilter: c.parsePacketFilter(resp.PacketFilter), + DERPMap: lastDERPMap, } for _, profile := range resp.UserProfiles { nm.UserProfiles[profile.ID] = profile diff --git a/control/controlclient/netmap.go b/control/controlclient/netmap.go index b78dd6afe..764c56b38 100644 --- a/control/controlclient/netmap.go +++ b/control/controlclient/netmap.go @@ -33,6 +33,10 @@ type NetworkMap struct { Hostinfo tailcfg.Hostinfo PacketFilter filter.Matches + // DERPMap is the last DERP server map received. It's reused + // between updates and should not be modified. 
+ DERPMap *tailcfg.DERPMap + // ACLs User tailcfg.UserID diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index bcaa08ea4..c63644bb1 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -24,9 +24,11 @@ "sync" "time" + "inet.af/netaddr" "tailscale.com/derp" "tailscale.com/net/dnscache" "tailscale.com/net/tlsdial" + "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -43,7 +45,10 @@ type Client struct { privateKey key.Private logf logger.Logf - url *url.URL + + // Either url or getRegion is non-nil: + url *url.URL + getRegion func() *tailcfg.DERPRegion ctx context.Context // closed via cancelCtx in Client.Close cancelCtx context.CancelFunc @@ -55,8 +60,22 @@ type Client struct { client *derp.Client } +// NewRegionClient returns a new DERP-over-HTTP client. It connects lazily. +// To trigger a connection, use Connect. +func NewRegionClient(privateKey key.Private, logf logger.Logf, getRegion func() *tailcfg.DERPRegion) *Client { + ctx, cancel := context.WithCancel(context.Background()) + c := &Client{ + privateKey: privateKey, + logf: logf, + getRegion: getRegion, + ctx: ctx, + cancelCtx: cancel, + } + return c +} + // NewClient returns a new DERP-over-HTTP client. It connects lazily. -// To trigger a connection use Connect. +// To trigger a connection, use Connect. func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) { u, err := url.Parse(serverURL) if err != nil { @@ -65,6 +84,7 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli if urlPort(u) == "" { return nil, fmt.Errorf("derphttp.NewClient: invalid URL scheme %q", u.Scheme) } + ctx, cancel := context.WithCancel(context.Background()) c := &Client{ privateKey: privateKey, @@ -101,6 +121,37 @@ func urlPort(u *url.URL) string { return "" } +func (c *Client) targetString(reg *tailcfg.DERPRegion) string { + if c.url != nil { + return c.url.String() + } + return fmt.Sprintf("region %d (%v)", reg.RegionID, reg.RegionCode) +} + +func (c *Client) useHTTPS() bool { + if c.url != nil && c.url.Scheme == "http" { + return false + } + return true +} + +func (c *Client) tlsServerName(node *tailcfg.DERPNode) string { + if c.url != nil { + return c.url.Host + } + if node.CertName != "" { + return node.CertName + } + return node.HostName +} + +func (c *Client) urlString(node *tailcfg.DERPNode) string { + if c.url != nil { + return c.url.String() + } + return fmt.Sprintf("https://%s/derp", node.HostName) +} + func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) { c.mu.Lock() defer c.mu.Unlock() @@ -111,8 +162,6 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien return c.client, nil } - c.logf("%s: connecting to %v", caller, c.url) - // timeout is the fallback maximum time (if ctx doesn't limit // it further) to do all of: DNS + TCP + TLS + HTTP Upgrade + // DERP upgrade. 
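// Illustrative sketch, not part of this patch: one way a caller might use the
// new NewRegionClient constructor, resolving the DERP region lazily from
// whatever DERP map it currently holds. The package name, newRegionDERPClient,
// and the getDERPMap parameter are hypothetical stand-ins for the caller's own
// state; only the derphttp/tailcfg/key APIs shown come from this change.
package derphttpexample

import (
	"log"

	"tailscale.com/derp/derphttp"
	"tailscale.com/tailcfg"
	"tailscale.com/types/key"
)

// newRegionDERPClient returns a lazily-connecting DERP client for regionID.
// getDERPMap should return the most recently received map, or nil if none yet.
func newRegionDERPClient(private key.Private, regionID int, getDERPMap func() *tailcfg.DERPMap) *derphttp.Client {
	return derphttp.NewRegionClient(private, log.Printf, func() *tailcfg.DERPRegion {
		dm := getDERPMap()
		if dm == nil {
			return nil // connect() will then fail with "DERP region not available"
		}
		return dm.Regions[regionID] // nil if the region is unknown
	})
}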
@@ -132,46 +181,42 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien }() defer cancel() + var reg *tailcfg.DERPRegion // nil when using c.url to dial + if c.getRegion != nil { + reg = c.getRegion() + if reg == nil { + return nil, errors.New("DERP region not available") + } + } + var tcpConn net.Conn + defer func() { if err != nil { if ctx.Err() != nil { err = fmt.Errorf("%v: %v", ctx.Err(), err) } - err = fmt.Errorf("%s connect to %v: %v", caller, c.url, err) + err = fmt.Errorf("%s connect to %v: %v", caller, c.targetString(reg), err) if tcpConn != nil { go tcpConn.Close() } } }() - host := c.url.Hostname() - hostOrIP := host - - var stdDialer dialer = new(net.Dialer) - var dialer = stdDialer - if wrapDialer != nil { - dialer = wrapDialer(dialer) + var node *tailcfg.DERPNode // nil when using c.url to dial + if c.url != nil { + c.logf("%s: connecting to %v", caller, c.url) + tcpConn, err = c.dialURL(ctx) + } else { + c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode) + tcpConn, node, err = c.dialRegion(ctx, reg) } - - if c.DNSCache != nil { - ip, err := c.DNSCache.LookupIP(ctx, host) - if err == nil { - hostOrIP = ip.String() - } - if err != nil && dialer == stdDialer { - // Return an error if we're not using a dial - // proxy that can do DNS lookups for us. - return nil, err - } - } - - tcpConn, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) if err != nil { - return nil, fmt.Errorf("dial of %q: %v", host, err) + return nil, err } - // Now that we have a TCP connection, force close it. + // Now that we have a TCP connection, force close it if the + // TLS handshake + DERP setup takes too long. done := make(chan struct{}) defer close(done) go func() { @@ -195,15 +240,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien }() var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to - if c.url.Scheme == "https" { - httpConn = tls.Client(tcpConn, tlsdial.Config(c.url.Host, c.TLSConfig)) + if c.useHTTPS() { + tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig) + if node != nil && node.DERPTestPort != 0 { + tlsConf.InsecureSkipVerify = true + } + httpConn = tls.Client(tcpConn, tlsConf) } else { httpConn = tcpConn } brw := bufio.NewReadWriter(bufio.NewReader(httpConn), bufio.NewWriter(httpConn)) - req, err := http.NewRequest("GET", c.url.String(), nil) + req, err := http.NewRequest("GET", c.urlString(node), nil) if err != nil { return nil, err } @@ -243,6 +292,148 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien return c.client, nil } +func (c *Client) dialURL(ctx context.Context) (net.Conn, error) { + host := c.url.Hostname() + hostOrIP := host + + var stdDialer dialer = new(net.Dialer) + var dialer = stdDialer + if wrapDialer != nil { + dialer = wrapDialer(dialer) + } + + if c.DNSCache != nil { + ip, err := c.DNSCache.LookupIP(ctx, host) + if err == nil { + hostOrIP = ip.String() + } + if err != nil && dialer == stdDialer { + // Return an error if we're not using a dial + // proxy that can do DNS lookups for us. + return nil, err + } + } + + tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url))) + if err != nil { + return nil, fmt.Errorf("dial of %v: %v", host, err) + } + return tcpConn, nil +} + +// dialRegion returns a TCP connection to the provided region, trying +// each node in order (with dialNode) until one connects or ctx is +// done. 
+func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) { + if len(reg.Nodes) == 0 { + return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg)) + } + var firstErr error + for _, n := range reg.Nodes { + if n.STUNOnly { + continue + } + c, err := c.dialNode(ctx, n) + if err == nil { + return c, n, nil + } + if firstErr == nil { + firstErr = err + } + } + return nil, nil, firstErr +} + +func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) { + var stdDialer dialer = new(net.Dialer) + var dialer = stdDialer + if wrapDialer != nil { + dialer = wrapDialer(dialer) + } + return dialer.DialContext(ctx, proto, addr) +} + +// shouldDialProto reports whether an explicitly provided IPv4 or IPv6 +// address (given in s) is valid. An empty value means to dial, but to +// use DNS. The predicate function reports whether the non-empty +// string s contained a valid IP address of the right family. +func shouldDialProto(s string, pred func(netaddr.IP) bool) bool { + if s == "" { + return true + } + ip, _ := netaddr.ParseIP(s) + return pred(ip) +} + +const dialNodeTimeout = 1500 * time.Millisecond + +// dialNode returns a TCP connection to node n, racing IPv4 and IPv6 +// (both as applicable) against each other. +// A node is only given dialNodeTimeout to connect. +// +// TODO(bradfitz): longer if no options remain perhaps? ... Or longer +// overall but have dialRegion start overlapping races? +func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, error) { + type res struct { + c net.Conn + err error + } + resc := make(chan res) // must be unbuffered + ctx, cancel := context.WithTimeout(ctx, dialNodeTimeout) + defer cancel() + + nwait := 0 + startDial := func(dstPrimary, proto string) { + nwait++ + go func() { + dst := dstPrimary + if dst == "" { + dst = n.HostName + } + port := "443" + if n.DERPTestPort != 0 { + port = fmt.Sprint(n.DERPTestPort) + } + c, err := c.dialContext(ctx, proto, net.JoinHostPort(dst, port)) + select { + case resc <- res{c, err}: + case <-ctx.Done(): + if c != nil { + c.Close() + } + } + }() + } + if shouldDialProto(n.IPv4, netaddr.IP.Is4) { + startDial(n.IPv4, "tcp4") + } + if shouldDialProto(n.IPv6, netaddr.IP.Is6) { + startDial(n.IPv6, "tcp6") + } + if nwait == 0 { + return nil, errors.New("both IPv4 and IPv6 are explicitly disabled for node") + } + + var firstErr error + for { + select { + case res := <-resc: + nwait-- + if res.err == nil { + return res.c, nil + } + if firstErr == nil { + firstErr = res.err + } + if nwait == 0 { + return nil, firstErr + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + func (c *Client) Send(dstKey key.Public, b []byte) error { client, err := c.connect(context.TODO(), "derphttp.Client.Send") if err != nil { diff --git a/derp/derpmap/derpmap.go b/derp/derpmap/derpmap.go index e0c55eb4b..587f8b61c 100644 --- a/derp/derpmap/derpmap.go +++ b/derp/derpmap/derpmap.go @@ -7,151 +7,59 @@ import ( "fmt" - "net" + "strings" - "tailscale.com/types/structs" + "tailscale.com/tailcfg" ) -// World is a set of DERP server. -type World struct { - servers []*Server - ids []int - byID map[int]*Server - stun4 []string - stun6 []string -} - -func (w *World) IDs() []int { return w.ids } -func (w *World) STUN4() []string { return w.stun4 } -func (w *World) STUN6() []string { return w.stun6 } -func (w *World) ServerByID(id int) *Server { return w.byID[id] } - -// LocationOfID returns the geographic name of a node, if present. 
-func (w *World) LocationOfID(id int) string { - if s, ok := w.byID[id]; ok { - return s.Geo - } - return "" -} - -func (w *World) NodeIDOfSTUNServer(server string) int { - // TODO: keep reverse map? Small enough to not matter for now. - for _, s := range w.servers { - if s.STUN4 == server || s.STUN6 == server { - return s.ID - } - } - return 0 -} - -// ForeachServer calls fn for each DERP server, in an unspecified order. -func (w *World) ForeachServer(fn func(*Server)) { - for _, s := range w.byID { - fn(s) +func derpNode(suffix, v4, v6 string) *tailcfg.DERPNode { + return &tailcfg.DERPNode{ + Name: suffix, // updated later + RegionID: 0, // updated later + IPv4: v4, + IPv6: v6, } } -// Prod returns the production DERP nodes. -func Prod() *World { - return prod +func derpRegion(id int, code string, nodes ...*tailcfg.DERPNode) *tailcfg.DERPRegion { + region := &tailcfg.DERPRegion{ + RegionID: id, + RegionCode: code, + Nodes: nodes, + } + for _, n := range nodes { + n.Name = fmt.Sprintf("%d%s", id, n.Name) + n.RegionID = id + n.HostName = fmt.Sprintf("derp%s.tailscale.com", strings.TrimSuffix(n.Name, "a")) + } + return region } -func NewTestWorld(stun ...string) *World { - w := &World{} - for i, s := range stun { - w.add(&Server{ - ID: i + 1, - Geo: fmt.Sprintf("Testopolis-%d", i+1), - STUN4: s, - }) - } - return w -} - -func NewTestWorldWith(servers ...*Server) *World { - w := &World{} - for _, s := range servers { - w.add(s) - } - return w -} - -var prod = new(World) // ... a dazzling place I never knew - -func addProd(id int, geo string) { - prod.add(&Server{ - ID: id, - Geo: geo, - HostHTTPS: fmt.Sprintf("derp%v.tailscale.com", id), - STUN4: fmt.Sprintf("derp%v.tailscale.com:3478", id), - STUN6: fmt.Sprintf("derp%v-v6.tailscale.com:3478", id), - }) -} - -func (w *World) add(s *Server) { - if s.ID == 0 { - panic("ID required") - } - if _, dup := w.byID[s.ID]; dup { - panic("duplicate prod server") - } - if w.byID == nil { - w.byID = make(map[int]*Server) - } - w.byID[s.ID] = s - w.ids = append(w.ids, s.ID) - w.servers = append(w.servers, s) - if s.STUN4 != "" { - w.stun4 = append(w.stun4, s.STUN4) - if _, _, err := net.SplitHostPort(s.STUN4); err != nil { - panic("not a host:port: " + s.STUN4) - } - } - if s.STUN6 != "" { - w.stun6 = append(w.stun6, s.STUN6) - if _, _, err := net.SplitHostPort(s.STUN6); err != nil { - panic("not a host:port: " + s.STUN6) - } +// Prod returns Tailscale's map of relay servers. +// +// This list is only used by cmd/tailscale's netcheck subcommand. In +// normal operation the Tailscale nodes get this sent to them from the +// control server. +// +// This list is subject to change and should not be relied on. +func Prod() *tailcfg.DERPMap { + return &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: derpRegion(1, "nyc", + derpNode("a", "159.89.225.99", "2604:a880:400:d1::828:b001"), + ), + 2: derpRegion(2, "sfo", + derpNode("a", "167.172.206.31", "2604:a880:2:d1::c5:7001"), + ), + 3: derpRegion(3, "sin", + derpNode("a", "68.183.179.66", "2400:6180:0:d1::67d:8001"), + ), + 4: derpRegion(4, "fra", + derpNode("a", "167.172.182.26", "2a03:b0c0:3:e0::36e:9001"), + ), + 5: derpRegion(5, "syd", + derpNode("a", "103.43.75.49", "2001:19f0:5801:10b7:5400:2ff:feaa:284c"), + ), + }, } } - -func init() { - addProd(1, "New York") - addProd(2, "San Francisco") - addProd(3, "Singapore") - addProd(4, "Frankfurt") - addProd(5, "Sydney") -} - -// Server is configuration for a DERP server. 
-type Server struct { - _ structs.Incomparable - - ID int - - // HostHTTPS is the HTTPS hostname. - HostHTTPS string - - // STUN4 is the host:port of the IPv4 STUN server on this DERP - // node. Required. - STUN4 string - - // STUN6 optionally provides the IPv6 host:port of the STUN - // server on the DERP node. - // It should be an IPv6-only address for now. (We currently make lazy - // assumptions that the server names are unique.) - STUN6 string - - // Geo is a human-readable geographic region name of this server. - Geo string -} - -func (s *Server) String() string { - if s == nil { - return "" - } - if s.Geo != "" { - return fmt.Sprintf("%v (%v)", s.HostHTTPS, s.Geo) - } - return s.HostHTTPS -} diff --git a/go.mod b/go.mod index f3ab8dc99..6d57af576 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/sys v0.0.0-20200501052902-10377860bb8e golang.org/x/time v0.0.0-20191024005414-555d28b269f0 - gortc.io/stun v1.22.1 inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc rsc.io/goversion v1.2.0 ) diff --git a/go.sum b/go.sum index e091d13bd..1ccc12628 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gortc.io/stun v1.22.1 h1:96mOdDATYRqhYB+TZdenWBg4CzL2Ye5kPyBXQ8KAB+8= -gortc.io/stun v1.22.1/go.mod h1:XD5lpONVyjvV3BgOyJFNo0iv6R2oZB4L+weMqxts+zg= inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc h1:We3b/z+7i9LV4Ls0yWve5vYIlnAPSPeqxKVgZseRDBs= inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww= rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w= diff --git a/ipn/local.go b/ipn/local.go index 1a4e091f4..e608fc601 100644 --- a/ipn/local.go +++ b/ipn/local.go @@ -240,10 +240,8 @@ func (b *LocalBackend) Start(opts Options) error { b.notify = opts.Notify b.netMapCache = nil persist := b.prefs.Persist - wantDERP := !b.prefs.DisableDERP b.mu.Unlock() - b.e.SetDERPEnabled(wantDERP) b.updateFilter(nil) var err error @@ -307,11 +305,17 @@ func (b *LocalBackend) Start(opts Options) error { b.logf("netmap diff:\n%v", diff) } } + disableDERP := b.prefs != nil && b.prefs.DisableDERP b.netMapCache = newSt.NetMap b.mu.Unlock() b.send(Notify{NetMap: newSt.NetMap}) b.updateFilter(newSt.NetMap) + if disableDERP { + b.e.SetDERPMap(nil) + } else { + b.e.SetDERPMap(newSt.NetMap.DERPMap) + } } if newSt.URL != "" { b.logf("Received auth URL: %.20v...", newSt.URL) diff --git a/netcheck/netcheck.go b/netcheck/netcheck.go index bbd1f163b..f74bae5ec 100644 --- a/netcheck/netcheck.go +++ b/netcheck/netcheck.go @@ -20,26 +20,28 @@ "time" "github.com/tcnksm/go-httpstat" - "golang.org/x/sync/errgroup" - "tailscale.com/derp/derpmap" + "inet.af/netaddr" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" "tailscale.com/stun" - "tailscale.com/stunner" + "tailscale.com/syncs" + "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/opt" ) type Report struct { - UDP bool // UDP works - IPv6 bool // IPv6 works - MappingVariesByDestIP opt.Bool // for IPv4 - HairPinning opt.Bool // for IPv4 - PreferredDERP int // or 0 for unknown - DERPLatency map[string]time.Duration // keyed by STUN host:port + UDP bool // UDP works + IPv6 bool // IPv6 works + 
MappingVariesByDestIP opt.Bool // for IPv4 + HairPinning opt.Bool // for IPv4 + PreferredDERP int // or 0 for unknown + RegionLatency map[int]time.Duration // keyed by DERP Region ID + RegionV4Latency map[int]time.Duration // keyed by DERP Region ID + RegionV6Latency map[int]time.Duration // keyed by DERP Region ID GlobalV4 string // ip:port of global IPv4 - GlobalV6 string // [ip]:port of global IPv6 // TODO + GlobalV6 string // [ip]:port of global IPv6 // TODO: update Clone when adding new fields } @@ -49,40 +51,50 @@ func (r *Report) Clone() *Report { return nil } r2 := *r - if r2.DERPLatency != nil { - r2.DERPLatency = map[string]time.Duration{} - for k, v := range r.DERPLatency { - r2.DERPLatency[k] = v - } - } + r2.RegionLatency = cloneDurationMap(r2.RegionLatency) + r2.RegionV4Latency = cloneDurationMap(r2.RegionV4Latency) + r2.RegionV6Latency = cloneDurationMap(r2.RegionV6Latency) return &r2 } +func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration { + if m == nil { + return nil + } + m2 := make(map[int]time.Duration, len(m)) + for k, v := range m { + m2[k] = v + } + return m2 +} + // Client generates a netcheck Report. type Client struct { - // DERP is the DERP world to use. - DERP *derpmap.World - // DNSCache optionally specifies a DNSCache to use. // If nil, a DNS cache is not used. DNSCache *dnscache.Resolver // Logf optionally specifies where to log to. + // If nil, log.Printf is used. Logf logger.Logf // TimeNow, if non-nil, is used instead of time.Now. TimeNow func() time.Time + // GetSTUNConn4 optionally provides a func to return the + // connection to use for sending & receiving IPv4 packets. If + // nil, an emphemeral one is created as needed. GetSTUNConn4 func() STUNConn + + // GetSTUNConn6 is like GetSTUNConn4, but for IPv6. GetSTUNConn6 func() STUNConn - mu sync.Mutex // guards following - prev map[time.Time]*Report // some previous reports - last *Report // most recent report - s4 *stunner.Stunner - s6 *stunner.Stunner - hairTX stun.TxID - gotHairSTUN chan *net.UDPAddr // non-nil if we're in GetReport + mu sync.Mutex // guards following + nextFull bool // do a full region scan, even if last != nil + prev map[time.Time]*Report // some previous reports + last *Report // most recent report + lastFull time.Time // time of last full (non-incremental) report + curState *reportState // non-nil if we're in a call to GetReportn } // STUNConn is the interface required by the netcheck Client when @@ -102,16 +114,14 @@ func (c *Client) logf(format string, a ...interface{}) { // handleHairSTUN reports whether pkt (from src) was our magic hairpin // probe packet that we sent to ourselves. -func (c *Client) handleHairSTUN(pkt []byte, src *net.UDPAddr) bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.handleHairSTUNLocked(pkt, src) -} - func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool { - if tx, err := stun.ParseBindingRequest(pkt); err == nil && tx == c.hairTX { + rs := c.curState + if rs == nil { + return false + } + if tx, err := stun.ParseBindingRequest(pkt); err == nil && tx == rs.hairTX { select { - case c.gotHairSTUN <- src: + case rs.gotHairSTUN <- src: default: } return true @@ -119,381 +129,570 @@ func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool { return false } +// MakeNextReportFull forces the next GetReport call to be a full +// (non-incremental) probe of all DERP regions. 
+func (c *Client) MakeNextReportFull() { + c.mu.Lock() + c.nextFull = true + c.mu.Unlock() +} + func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { if src == nil || src.IP == nil { panic("bogus src") } c.mu.Lock() - if c.handleHairSTUNLocked(pkt, src) { c.mu.Unlock() return } - - var st *stunner.Stunner - if src.IP.To4() != nil { - st = c.s4 - } else { - st = c.s6 - } - + rs := c.curState c.mu.Unlock() - if st != nil { - st.Receive(pkt, src) + if rs == nil { + return + } + + tx, addr, port, err := stun.ParseResponse(pkt) + if err != nil { + c.mu.Unlock() + if _, err := stun.ParseBindingRequest(pkt); err == nil { + // This was probably our own netcheck hairpin + // check probe coming in late. Ignore. + return + } + c.logf("netcheck: received unexpected STUN message response from %v: %v", src, err) + return + } + + rs.mu.Lock() + onDone, ok := rs.inFlight[tx] + if ok { + delete(rs.inFlight, tx) + } + rs.mu.Unlock() + if ok { + if ipp, ok := netaddr.FromStdAddr(addr, int(port), ""); ok { + onDone(ipp) + } } } -// pickSubset selects a subset of IPv4 and IPv6 STUN server addresses -// to hit based on history. -// -// maxTries is the max number of tries per server. -// -// The caller owns the returned values. -func (c *Client) pickSubset() (stuns4, stuns6 []string, maxTries map[string]int, err error) { - c.mu.Lock() - defer c.mu.Unlock() +// probeProto is the protocol used to time a node's latency. +type probeProto uint8 - const defaultMaxTries = 2 - maxTries = map[string]int{} +const ( + probeIPv4 probeProto = iota // STUN IPv4 + probeIPv6 // STUN IPv6 + probeHTTPS // HTTPS +) - var prev4, prev6 []string // sorted fastest to slowest - if c.last != nil { - condAppend := func(dst []string, server string) []string { - if server != "" && c.last.DERPLatency[server] != 0 { - return append(dst, server) - } - return dst - } - c.DERP.ForeachServer(func(s *derpmap.Server) { - prev4 = condAppend(prev4, s.STUN4) - prev6 = condAppend(prev6, s.STUN6) - }) - sort.Slice(prev4, func(i, j int) bool { return c.last.DERPLatency[prev4[i]] < c.last.DERPLatency[prev4[j]] }) - sort.Slice(prev6, func(i, j int) bool { return c.last.DERPLatency[prev6[i]] < c.last.DERPLatency[prev6[j]] }) +type probe struct { + // delay is when the probe is started, relative to the time + // that GetReport is called. One probe in each probePlan + // should have a delay of 0. Non-zero values are for retries + // on UDP loss or timeout. + delay time.Duration + + // node is the name of the node name. DERP node names are globally + // unique so there's no region ID. + node string + + // proto is how the node should be probed. + proto probeProto + + // wait is how long to wait until the probe is considered failed. + // 0 means to use a default value. + wait time.Duration +} + +// probePlan is a set of node probes to run. +// The map key is a descriptive name, only used for tests. +// +// The values are logically an unordered set of tests to run concurrently. +// In practice there's some order to them based on their delay fields, +// but multiple probes can have the same delay time or be running concurrently +// both within and between sets. +// +// A set of probes is done once either one of the probes completes, or +// the next probe to run wouldn't yield any new information not +// already discovered by any previous probe in any set. +type probePlan map[string][]probe + +// sortRegions returns the regions of dm first sorted +// from fastest to slowest (based on the 'last' report), +// end in regions that have no data. 
+func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) { + prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions)) + for _, reg := range dm.Regions { + prev = append(prev, reg) } + sort.Slice(prev, func(i, j int) bool { + da, db := last.RegionLatency[prev[i].RegionID], last.RegionLatency[prev[j].RegionID] + if db == 0 && da != 0 { + // Non-zero sorts before zero. + return true + } + if da == 0 { + // Zero can't sort before anything else. + return false + } + return da < db + }) + return prev +} - c.DERP.ForeachServer(func(s *derpmap.Server) { - if s.STUN4 == "" { +// numIncrementalRegions is the number of fastest regions to +// periodically re-query during incremental netcheck reports. (During +// a full report, all regions are scanned.) +const numIncrementalRegions = 3 + +// makeProbePlan generates the probe plan for a DERPMap, given the most +// recent report and whether IPv6 is configured on an interface. +func makeProbePlan(dm *tailcfg.DERPMap, have6if bool, last *Report) (plan probePlan) { + if last == nil || len(last.RegionLatency) == 0 { + return makeProbePlanInitial(dm, have6if) + } + plan = make(probePlan) + had4 := len(last.RegionV4Latency) > 0 + had6 := len(last.RegionV6Latency) > 0 + hadBoth := have6if && had4 && had6 + for ri, reg := range sortRegions(dm, last) { + if ri == numIncrementalRegions { + break + } + var p4, p6 []probe + do4 := true + do6 := have6if + + // By default, each node only gets one STUN packet sent, + // except the fastest two from the previous round. + tries := 1 + isFastestTwo := ri < 2 + + if isFastestTwo { + tries = 2 + } else if hadBoth { + // For dual stack machines, make the 3rd & slower nodes alternate + // breetween + if ri%2 == 0 { + do4, do6 = true, false + } else { + do4, do6 = false, true + } + } + if !isFastestTwo && !had6 { + do6 = false + } + + for try := 0; try < tries; try++ { + if len(reg.Nodes) == 0 { + // Shouldn't be possible. + continue + } + if try != 0 && !had6 { + do6 = false + } + n := reg.Nodes[try%len(reg.Nodes)] + prevLatency := last.RegionLatency[reg.RegionID] * 120 / 100 + if prevLatency == 0 { + prevLatency = 200 * time.Millisecond + } + delay := time.Duration(try) * prevLatency + if do4 { + p4 = append(p4, probe{delay: delay, node: n.Name, proto: probeIPv4}) + } + if do6 { + p6 = append(p6, probe{delay: delay, node: n.Name, proto: probeIPv6}) + } + } + if len(p4) > 0 { + plan[fmt.Sprintf("region-%d-v4", reg.RegionID)] = p4 + } + if len(p6) > 0 { + plan[fmt.Sprintf("region-%d-v6", reg.RegionID)] = p6 + } + } + return plan +} + +func makeProbePlanInitial(dm *tailcfg.DERPMap, have6if bool) (plan probePlan) { + plan = make(probePlan) + + // initialSTUNTimeout is only 100ms because some extra retransmits + // when starting up is tolerable. + const initialSTUNTimeout = 100 * time.Millisecond + + for _, reg := range dm.Regions { + var p4 []probe + var p6 []probe + for try := 0; try < 3; try++ { + n := reg.Nodes[try%len(reg.Nodes)] + delay := time.Duration(try) * initialSTUNTimeout + if nodeMight4(n) { + p4 = append(p4, probe{delay: delay, node: n.Name, proto: probeIPv4}) + } + if have6if && nodeMight6(n) { + p6 = append(p6, probe{delay: delay, node: n.Name, proto: probeIPv6}) + } + } + if len(p4) > 0 { + plan[fmt.Sprintf("region-%d-v4", reg.RegionID)] = p4 + } + if len(p6) > 0 { + plan[fmt.Sprintf("region-%d-v6", reg.RegionID)] = p6 + } + } + return plan +} + +// nodeMight6 reports whether n might reply to STUN over IPv6 based on +// its config alone, without DNS lookups. 
It only returns false if +// it's not explicitly disabled. +func nodeMight6(n *tailcfg.DERPNode) bool { + if n.IPv6 == "" { + return true + } + ip, _ := netaddr.ParseIP(n.IPv6) + return ip.Is6() + +} + +// nodeMight4 reports whether n might reply to STUN over IPv4 based on +// its config alone, without DNS lookups. It only returns false if +// it's not explicitly disabled. +func nodeMight4(n *tailcfg.DERPNode) bool { + if n.IPv4 == "" { + return true + } + ip, _ := netaddr.ParseIP(n.IPv4) + return ip.Is4() +} + +// readPackets reads STUN packets from pc until there's an error or ctx is done. +// In either case, it closes pc. +func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) { + done := make(chan struct{}) + defer close(done) + + go func() { + select { + case <-ctx.Done(): + case <-done: + } + pc.Close() + }() + + var buf [64 << 10]byte + for { + n, addr, err := pc.ReadFrom(buf[:]) + if err != nil { + if ctx.Err() != nil { + return + } + c.logf("ReadFrom: %v", err) return } - // STUN against all DERP's IPv4 endpoints, but - // if the previous report had results from - // more than 2 servers, only do 1 try against - // all but the first two. - stuns4 = append(stuns4, s.STUN4) - tries := defaultMaxTries - if len(prev4) > 2 && !stringsContains(prev4[:2], s.STUN4) { - tries = 1 + ua, ok := addr.(*net.UDPAddr) + if !ok { + c.logf("ReadFrom: unexpected addr %T", addr) + continue } - maxTries[s.STUN4] = tries - if s.STUN6 != "" && tries == defaultMaxTries { - // For IPv6, we mostly care whether the user has IPv6 at all, - // so we don't need to send to all servers. The IPv4 timing - // information is enough for now. (We don't yet support IPv6-only) - // So only add the two fastest ones, or all if this is a fresh one. - stuns6 = append(stuns6, s.STUN6) - maxTries[s.STUN6] = 1 + pkt := buf[:n] + if !stun.Is(pkt) { + continue + } + c.ReceiveSTUNPacket(pkt, ua) + } +} + +// reportState holds the state for a single invocation of Client.GetReport. +type reportState struct { + c *Client + hairTX stun.TxID + gotHairSTUN chan *net.UDPAddr + hairTimeout chan struct{} // closed on timeout + pc4 STUNConn + pc6 STUNConn + pc4Hair net.PacketConn + + mu sync.Mutex + sentHairCheck bool + report *Report // to be returned by GetReport + inFlight map[stun.TxID]func(netaddr.IPPort) // called without c.mu held + gotEP4 string +} + +func (rs *reportState) anyUDP() bool { + rs.mu.Lock() + defer rs.mu.Unlock() + return rs.report.UDP +} + +func (rs *reportState) haveRegionLatency(regionID int) bool { + rs.mu.Lock() + defer rs.mu.Unlock() + _, ok := rs.report.RegionLatency[regionID] + return ok +} + +// probeWouldHelp reports whether executing the given probe would +// yield any new information. +// The given node is provided just because the sole caller already has it +// and it saves a lookup. +func (rs *reportState) probeWouldHelp(probe probe, node *tailcfg.DERPNode) bool { + rs.mu.Lock() + defer rs.mu.Unlock() + + // If the probe is for a region we don't yet know about, that + // would help. + if _, ok := rs.report.RegionLatency[node.RegionID]; !ok { + return true + } + + // If the probe is for IPv6 and we don't yet have an IPv6 + // report, that would help. + if probe.proto == probeIPv6 && len(rs.report.RegionV6Latency) == 0 { + return true + } + + // For IPv4, we need at least two IPv4 results overall to + // determine whether we're behind a NAT that shows us as + // different source IPs and/or ports depending on who we're + // talking to. 
If we don't yet have two results yet + // (MappingVariesByDestIP is blank), then another IPv4 probe + // would be good. + if probe.proto == probeIPv4 && rs.report.MappingVariesByDestIP == "" { + return true + } + + // Otherwise not interesting. + return false +} + +func (rs *reportState) startHairCheckLocked(dst netaddr.IPPort) { + if rs.sentHairCheck { + return + } + rs.sentHairCheck = true + rs.pc4Hair.WriteTo(stun.Request(rs.hairTX), dst.UDPAddr()) + time.AfterFunc(500*time.Millisecond, func() { close(rs.hairTimeout) }) +} + +func (rs *reportState) waitHairCheck(ctx context.Context) { + rs.mu.Lock() + defer rs.mu.Unlock() + if !rs.sentHairCheck { + return + } + ret := rs.report + + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + case <-rs.hairTimeout: + ret.HairPinning.Set(false) + default: + select { + case <-rs.gotHairSTUN: + ret.HairPinning.Set(true) + case <-rs.hairTimeout: + ret.HairPinning.Set(false) + case <-ctx.Done(): + } + } +} + +// addNodeLatency updates rs to note that node's latency is d. If ipp +// is non-zero (for all but HTTPS replies), it's recorded as our UDP +// IP:port. +func (rs *reportState) addNodeLatency(node *tailcfg.DERPNode, ipp netaddr.IPPort, d time.Duration) { + var ipPortStr string + if ipp != (netaddr.IPPort{}) { + ipPortStr = net.JoinHostPort(ipp.IP.String(), fmt.Sprint(ipp.Port)) + } + + rs.mu.Lock() + defer rs.mu.Unlock() + ret := rs.report + + ret.UDP = true + updateLatency(&ret.RegionLatency, node.RegionID, d) + + switch { + case ipp.IP.Is6(): + updateLatency(&ret.RegionV6Latency, node.RegionID, d) + ret.IPv6 = true + ret.GlobalV6 = ipPortStr + // TODO: track MappingVariesByDestIP for IPv6 + // too? Would be sad if so, but who knows. + case ipp.IP.Is4(): + updateLatency(&ret.RegionV4Latency, node.RegionID, d) + if rs.gotEP4 == "" { + rs.gotEP4 = ipPortStr + ret.GlobalV4 = ipPortStr + rs.startHairCheckLocked(ipp) + } else { + if rs.gotEP4 != ipPortStr { + ret.MappingVariesByDestIP.Set(true) + } else if ret.MappingVariesByDestIP == "" { + ret.MappingVariesByDestIP.Set(false) + } } - }) - - if len(stuns4) == 0 { - // TODO: make this work? if we ever need it - // to. Requirement for self-hosted Tailscale might be - // to run a DERP+STUN server co-resident with the - // Control server. - return nil, nil, nil, errors.New("netcheck: GetReport: no STUN servers, no Report") } - sort.Strings(stuns4) - sort.Strings(stuns6) - return stuns4, stuns6, maxTries, nil } // GetReport gets a report. // // It may not be called concurrently with itself. -func (c *Client) GetReport(ctx context.Context) (*Report, error) { +func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, error) { // Mask user context with ours that we guarantee to cancel so // we can depend on it being closed in goroutines later. 
// (User ctx might be context.Background, etc) ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() - if c.DERP == nil { - return nil, errors.New("netcheck: GetReport: Client.DERP is nil") + if dm == nil { + return nil, errors.New("netcheck: GetReport: DERP map is nil") } c.mu.Lock() - if c.gotHairSTUN != nil { + if c.curState != nil { c.mu.Unlock() return nil, errors.New("invalid concurrent call to GetReport") } - hairTX := stun.NewTxID() // random payload - c.hairTX = hairTX - gotHairSTUN := make(chan *net.UDPAddr, 1) - c.gotHairSTUN = gotHairSTUN + rs := &reportState{ + c: c, + report: new(Report), + inFlight: map[stun.TxID]func(netaddr.IPPort){}, + hairTX: stun.NewTxID(), // random payload + gotHairSTUN: make(chan *net.UDPAddr, 1), + hairTimeout: make(chan struct{}), + } + c.curState = rs + last := c.last + now := c.timeNow() + if c.nextFull || now.Sub(c.lastFull) > 5*time.Minute { + last = nil // causes makeProbePlan below to do a full (initial) plan + c.nextFull = false + c.lastFull = now + } c.mu.Unlock() defer func() { c.mu.Lock() defer c.mu.Unlock() - c.s4 = nil - c.s6 = nil - c.gotHairSTUN = nil + c.curState = nil }() - stuns4, stuns6, maxTries, err := c.pickSubset() - if err != nil { - return nil, err - } - - closeOnCtx := func(c io.Closer) { - <-ctx.Done() - c.Close() - } - v6iface, err := interfaces.HaveIPv6GlobalAddress() if err != nil { c.logf("interfaces: %v", err) } // Create a UDP4 socket used for sending to our discovered IPv4 address. - pc4Hair, err := net.ListenPacket("udp4", ":0") + rs.pc4Hair, err = net.ListenPacket("udp4", ":0") if err != nil { c.logf("udp4: %v", err) return nil, err } - defer pc4Hair.Close() - hairTimeout := make(chan bool, 1) - startHairCheck := func(dstEP string) { - if dst, err := net.ResolveUDPAddr("udp4", dstEP); err == nil { - pc4Hair.WriteTo(stun.Request(hairTX), dst) - time.AfterFunc(500*time.Millisecond, func() { hairTimeout <- true }) - } - } - - var ( - mu sync.Mutex - ret = &Report{ - DERPLatency: map[string]time.Duration{}, - } - gotEP = map[string]string{} // server -> ipPort - gotEP4 string - ) - anyV6 := func() bool { - mu.Lock() - defer mu.Unlock() - return ret.IPv6 - } - anyV4 := func() bool { - mu.Lock() - defer mu.Unlock() - return gotEP4 != "" - } - add := func(server, ipPort string, d time.Duration) { - ua, err := net.ResolveUDPAddr("udp", ipPort) - if err != nil { - c.logf("[unexpected] STUN addr %q", ipPort) - return - } - isV6 := ua.IP.To4() == nil - - mu.Lock() - defer mu.Unlock() - ret.UDP = true - ret.DERPLatency[server] = d - if isV6 { - ret.IPv6 = true - ret.GlobalV6 = ipPort - // TODO: track MappingVariesByDestIP for IPv6 - // too? Would be sad if so, but who knows. 
- } else { - // IPv4 - if gotEP4 == "" { - gotEP4 = ipPort - ret.GlobalV4 = ipPort - startHairCheck(ipPort) - } else { - if gotEP4 != ipPort { - ret.MappingVariesByDestIP.Set(true) - } else if ret.MappingVariesByDestIP == "" { - ret.MappingVariesByDestIP.Set(false) - } - } - } - gotEP[server] = ipPort - } - - var pc4, pc6 STUNConn + defer rs.pc4Hair.Close() if f := c.GetSTUNConn4; f != nil { - pc4 = f() + rs.pc4 = f() } else { u4, err := net.ListenPacket("udp4", ":0") if err != nil { c.logf("udp4: %v", err) return nil, err } - pc4 = u4 - go closeOnCtx(u4) + rs.pc4 = u4 + go c.readPackets(ctx, u4) } if v6iface { if f := c.GetSTUNConn6; f != nil { - pc6 = f() + rs.pc6 = f() } else { u6, err := net.ListenPacket("udp6", ":0") if err != nil { c.logf("udp6: %v", err) } else { - pc6 = u6 - go closeOnCtx(u6) + rs.pc6 = u6 + go c.readPackets(ctx, u6) } } } - reader := func(s *stunner.Stunner, pc STUNConn) { - var buf [64 << 10]byte - for { - n, addr, err := pc.ReadFrom(buf[:]) - if err != nil { - if ctx.Err() != nil { - return - } - c.logf("ReadFrom: %v", err) - return + plan := makeProbePlan(dm, v6iface, last) + + wg := syncs.NewWaitGroupChan() + wg.Add(len(plan)) + for _, probeSet := range plan { + setCtx, cancelSet := context.WithCancel(ctx) + go func(probeSet []probe) { + for _, probe := range probeSet { + go rs.runProbe(setCtx, dm, probe, cancelSet) } - ua, ok := addr.(*net.UDPAddr) - if !ok { - c.logf("ReadFrom: unexpected addr %T", addr) - continue - } - if c.handleHairSTUN(buf[:n], ua) { - continue - } - s.Receive(buf[:n], ua) - } - + <-setCtx.Done() + wg.Decr() + }(probeSet) } - var grp errgroup.Group - - s4 := &stunner.Stunner{ - Send: pc4.WriteTo, - Endpoint: add, - Servers: stuns4, - Logf: c.logf, - DNSCache: dnscache.Get(), - MaxTries: maxTries, + select { + case <-ctx.Done(): + case <-wg.DoneChan(): } - c.mu.Lock() - c.s4 = s4 - c.mu.Unlock() + rs.waitHairCheck(ctx) - grp.Go(func() error { - err := s4.Run(ctx) - if errors.Is(err, context.DeadlineExceeded) { - if !anyV4() { - c.logf("netcheck: no IPv4 UDP STUN replies") - } - return nil - } - return err - }) - if c.GetSTUNConn4 == nil { - go reader(s4, pc4) - } - - if pc6 != nil && len(stuns6) > 0 { - s6 := &stunner.Stunner{ - Endpoint: add, - Send: pc6.WriteTo, - Servers: stuns6, - Logf: c.logf, - OnlyIPv6: true, - DNSCache: dnscache.Get(), - MaxTries: maxTries, - } - - c.mu.Lock() - c.s6 = s6 - c.mu.Unlock() - - grp.Go(func() error { - err := s6.Run(ctx) - if errors.Is(err, context.DeadlineExceeded) { - if !anyV6() { - // IPv6 seemed like it was configured, but actually failed. - // Just log and return a nil error. - c.logf("IPv6 seemed configured, but no UDP STUN replies") - } - return nil - } - // Otherwise must be some invalid use of Stunner. - return err // - }) - if c.GetSTUNConn6 == nil { - go reader(s6, pc6) - } - } - - err = grp.Wait() - if err != nil { - return nil, err - } - - mu.Lock() - // Check hairpinning. - if ret.MappingVariesByDestIP == "false" && gotEP4 != "" { - select { - case <-gotHairSTUN: - ret.HairPinning.Set(true) - case <-hairTimeout: - ret.HairPinning.Set(false) - } - } - mu.Unlock() - - // Try HTTPS latency check if UDP is blocked and all checkings failed - if !anyV4() { - c.logf("netcheck: UDP is blocked, try HTTPS") + // Try HTTPS latency check if all STUN probes failed due to UDP presumably being blocked. 
+ if !rs.anyUDP() { var wg sync.WaitGroup - for _, server := range stuns4 { - server := server - if _, ok := ret.DERPLatency[server]; ok { - continue + var need []*tailcfg.DERPRegion + for rid, reg := range dm.Regions { + if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) { + need = append(need, reg) } - - wg.Add(1) - go func() { + } + if len(need) > 0 { + wg.Add(len(need)) + c.logf("netcheck: UDP is blocked, trying HTTPS") + } + for _, reg := range need { + go func(reg *tailcfg.DERPRegion) { defer wg.Done() - if d, err := c.measureHTTPSLatency(server); err != nil { - c.logf("netcheck: measuring HTTPS latency of %v: %v", server, err) + if d, err := c.measureHTTPSLatency(reg); err != nil { + c.logf("netcheck: measuring HTTPS latency of %v (%d): %v", reg.RegionCode, reg.RegionID, err) } else { - mu.Lock() - ret.DERPLatency[server] = d - mu.Unlock() + rs.mu.Lock() + rs.report.RegionLatency[reg.RegionID] = d + rs.mu.Unlock() } - }() + }(reg) } wg.Wait() } - report := ret.Clone() + rs.mu.Lock() + report := rs.report.Clone() + rs.mu.Unlock() c.addReportHistoryAndSetPreferredDERP(report) - c.logConciseReport(report) + c.logConciseReport(report, dm) return report, nil } -func (c *Client) measureHTTPSLatency(server string) (time.Duration, error) { - host, _, err := net.SplitHostPort(server) - if err != nil { - return 0, err +// TODO: have caller pass in context +func (c *Client) measureHTTPSLatency(reg *tailcfg.DERPRegion) (time.Duration, error) { + if len(reg.Nodes) == 0 { + return 0, errors.New("no nodes") } + node := reg.Nodes[0] // TODO: use all nodes per region + host := node.HostName + // TODO: connect using provided IPv4/IPv6; use a Trasport & set the dialer var result httpstat.Result hctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &result), 5*time.Second) @@ -522,7 +721,7 @@ func (c *Client) measureHTTPSLatency(server string) (time.Duration, error) { return result.ServerProcessing, nil } -func (c *Client) logConciseReport(r *Report) { +func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) { buf := bytes.NewBuffer(make([]byte, 0, 256)) // empirically: 5 DERPs + IPv6 == ~233 bytes fmt.Fprintf(buf, "udp=%v", r.UDP) fmt.Fprintf(buf, " v6=%v", r.IPv6) @@ -537,21 +736,20 @@ func (c *Client) logConciseReport(r *Report) { fmt.Fprintf(buf, " derp=%v", r.PreferredDERP) if r.PreferredDERP != 0 { fmt.Fprintf(buf, " derpdist=") - for i, id := range c.DERP.IDs() { + for i, rid := range dm.RegionIDs() { if i != 0 { buf.WriteByte(',') } - s := c.DERP.ServerByID(id) needComma := false - if d := r.DERPLatency[s.STUN4]; d != 0 { - fmt.Fprintf(buf, "%dv4:%v", id, d.Round(time.Millisecond)) + if d := r.RegionV4Latency[rid]; d != 0 { + fmt.Fprintf(buf, "%dv4:%v", rid, d.Round(time.Millisecond)) needComma = true } - if d := r.DERPLatency[s.STUN6]; d != 0 { + if d := r.RegionV6Latency[rid]; d != 0 { if needComma { buf.WriteByte(',') } - fmt.Fprintf(buf, "%dv6:%v", id, d.Round(time.Millisecond)) + fmt.Fprintf(buf, "%dv6:%v", rid, d.Round(time.Millisecond)) } } } @@ -581,15 +779,15 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { const maxAge = 5 * time.Minute - // STUN host:port => its best recent latency in last maxAge - bestRecent := map[string]time.Duration{} + // region ID => its best recent latency in last maxAge + bestRecent := map[int]time.Duration{} for t, pr := range c.prev { if now.Sub(t) > maxAge { delete(c.prev, t) continue } - for hp, d := range pr.DERPLatency { + for hp, d := range pr.RegionLatency { if bd, ok := bestRecent[hp]; !ok 
|| d < bd { bestRecent[hp] = d } @@ -599,18 +797,133 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(r *Report) { // Then, pick which currently-alive DERP server from the // current report has the best latency over the past maxAge. var bestAny time.Duration - for hp := range r.DERPLatency { + for hp := range r.RegionLatency { best := bestRecent[hp] if r.PreferredDERP == 0 || best < bestAny { bestAny = best - r.PreferredDERP = c.DERP.NodeIDOfSTUNServer(hp) + r.PreferredDERP = hp } } } -func stringsContains(ss []string, s string) bool { - for _, v := range ss { - if s == v { +func updateLatency(mp *map[int]time.Duration, regionID int, d time.Duration) { + if *mp == nil { + *mp = make(map[int]time.Duration) + } + m := *mp + if prev, ok := m[regionID]; !ok || d < prev { + m[regionID] = d + } +} + +func namedNode(dm *tailcfg.DERPMap, nodeName string) *tailcfg.DERPNode { + if dm == nil { + return nil + } + for _, r := range dm.Regions { + for _, n := range r.Nodes { + if n.Name == nodeName { + return n + } + } + } + return nil +} + +func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe probe, cancelSet func()) { + c := rs.c + node := namedNode(dm, probe.node) + if node == nil { + c.logf("netcheck.runProbe: named node %q not found", probe.node) + return + } + + if probe.delay > 0 { + delayTimer := time.NewTimer(probe.delay) + select { + case <-delayTimer.C: + case <-ctx.Done(): + delayTimer.Stop() + return + } + } + + if !rs.probeWouldHelp(probe, node) { + cancelSet() + return + } + + addr := c.nodeAddr(ctx, node, probe.proto) + if addr == nil { + return + } + + txID := stun.NewTxID() + req := stun.Request(txID) + + sent := time.Now() // after DNS lookup above + + rs.mu.Lock() + rs.inFlight[txID] = func(ipp netaddr.IPPort) { + rs.addNodeLatency(node, ipp, time.Since(sent)) + cancelSet() // abort other nodes in this set + } + rs.mu.Unlock() + + switch probe.proto { + case probeIPv4: + rs.pc4.WriteTo(req, addr) + case probeIPv6: + rs.pc6.WriteTo(req, addr) + default: + panic("bad probe proto " + fmt.Sprint(probe.proto)) + } +} + +// proto is 4 or 6 +// If it returns nil, the node is skipped. +func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeProto) *net.UDPAddr { + port := n.STUNPort + if port == 0 { + port = 3478 + } + if port < 0 || port > 1<<16-1 { + return nil + } + switch proto { + case probeIPv4: + if n.IPv4 != "" { + ip, _ := netaddr.ParseIP(n.IPv4) + if !ip.Is4() { + return nil + } + return netaddr.IPPort{ip, uint16(port)}.UDPAddr() + } + case probeIPv6: + if n.IPv6 != "" { + ip, _ := netaddr.ParseIP(n.IPv6) + if !ip.Is6() { + return nil + } + return netaddr.IPPort{ip, uint16(port)}.UDPAddr() + } + default: + return nil + } + + // TODO(bradfitz): add singleflight+dnscache here. 
+ addrs, _ := net.DefaultResolver.LookupIPAddr(ctx, n.HostName) + for _, a := range addrs { + if (a.IP.To4() != nil) == (proto == probeIPv4) { + return &net.UDPAddr{IP: a.IP, Port: port} + } + } + return nil +} + +func regionHasDERPNode(r *tailcfg.DERPRegion) bool { + for _, n := range r.Nodes { + if !n.STUNOnly { return true } } diff --git a/netcheck/netcheck_test.go b/netcheck/netcheck_test.go index 6e9894df9..58a4eef74 100644 --- a/netcheck/netcheck_test.go +++ b/netcheck/netcheck_test.go @@ -9,28 +9,34 @@ "fmt" "net" "reflect" + "sort" + "strconv" + "strings" "testing" "time" - "tailscale.com/derp/derpmap" "tailscale.com/stun" "tailscale.com/stun/stuntest" + "tailscale.com/tailcfg" ) func TestHairpinSTUN(t *testing.T) { + tx := stun.NewTxID() c := &Client{ - hairTX: stun.NewTxID(), - gotHairSTUN: make(chan *net.UDPAddr, 1), + curState: &reportState{ + hairTX: tx, + gotHairSTUN: make(chan *net.UDPAddr, 1), + }, } - req := stun.Request(c.hairTX) + req := stun.Request(tx) if !stun.Is(req) { t.Fatal("expected STUN message") } - if !c.handleHairSTUN(req, nil) { + if !c.handleHairSTUNLocked(req, nil) { t.Fatal("expected true") } select { - case <-c.gotHairSTUN: + case <-c.curState.gotHairSTUN: default: t.Fatal("expected value") } @@ -41,25 +47,24 @@ func TestBasic(t *testing.T) { defer cleanup() c := &Client{ - DERP: derpmap.NewTestWorld(stunAddr), Logf: t.Logf, } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - r, err := c.GetReport(ctx) + r, err := c.GetReport(ctx, stuntest.DERPMapOf(stunAddr.String())) if err != nil { t.Fatal(err) } if !r.UDP { t.Error("want UDP") } - if len(r.DERPLatency) != 1 { - t.Errorf("expected 1 key in DERPLatency; got %+v", r.DERPLatency) + if len(r.RegionLatency) != 1 { + t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency) } - if _, ok := r.DERPLatency[stunAddr]; !ok { - t.Errorf("expected key %q in DERPLatency; got %+v", stunAddr, r.DERPLatency) + if _, ok := r.RegionLatency[1]; !ok { + t.Errorf("expected key 1 in DERPLatency; got %+v", r.RegionLatency) } if r.GlobalV4 == "" { t.Error("expected GlobalV4 set") @@ -78,20 +83,20 @@ func TestWorksWhenUDPBlocked(t *testing.T) { stunAddr := blackhole.LocalAddr().String() + dm := stuntest.DERPMapOf(stunAddr) + dm.Regions[1].Nodes[0].STUNOnly = true + c := &Client{ - DERP: derpmap.NewTestWorld(stunAddr), Logf: t.Logf, } ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) defer cancel() - r, err := c.GetReport(ctx) + r, err := c.GetReport(ctx, dm) if err != nil { t.Fatal(err) } - want := &Report{ - DERPLatency: map[string]time.Duration{}, - } + want := new(Report) if !reflect.DeepEqual(r, want) { t.Errorf("mismatch\n got: %+v\nwant: %+v\n", r, want) @@ -99,30 +104,24 @@ func TestWorksWhenUDPBlocked(t *testing.T) { } func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { - derps := derpmap.NewTestWorldWith( - &derpmap.Server{ - ID: 1, - STUN4: "d1:1", - }, - &derpmap.Server{ - ID: 2, - STUN4: "d2:1", - }, - &derpmap.Server{ - ID: 3, - STUN4: "d3:1", - }, - ) // report returns a *Report from (DERP host, time.Duration)+ pairs. 
report := func(a ...interface{}) *Report { - r := &Report{DERPLatency: map[string]time.Duration{}} + r := &Report{RegionLatency: map[int]time.Duration{}} for i := 0; i < len(a); i += 2 { - k := a[i].(string) + ":1" + s := a[i].(string) + if !strings.HasPrefix(s, "d") { + t.Fatalf("invalid derp server key %q", s) + } + regionID, err := strconv.Atoi(s[1:]) + if err != nil { + t.Fatalf("invalid derp server key %q", s) + } + switch v := a[i+1].(type) { case time.Duration: - r.DERPLatency[k] = v + r.RegionLatency[regionID] = v case int: - r.DERPLatency[k] = time.Second * time.Duration(v) + r.RegionLatency[regionID] = time.Second * time.Duration(v) default: panic(fmt.Sprintf("unexpected type %T", v)) } @@ -194,7 +193,6 @@ type step struct { t.Run(tt.name, func(t *testing.T) { fakeTime := time.Unix(123, 0) c := &Client{ - DERP: derps, TimeNow: func() time.Time { return fakeTime }, } for _, s := range tt.steps { @@ -212,81 +210,217 @@ type step struct { } } -func TestPickSubset(t *testing.T) { - derps := derpmap.NewTestWorldWith( - &derpmap.Server{ - ID: 1, - STUN4: "d1:4", - STUN6: "d1:6", - }, - &derpmap.Server{ - ID: 2, - STUN4: "d2:4", - STUN6: "d2:6", - }, - &derpmap.Server{ - ID: 3, - STUN4: "d3:4", - STUN6: "d3:6", - }, - ) +func TestMakeProbePlan(t *testing.T) { + // basicMap has 5 regions. each region has a number of nodes + // equal to the region number (1 has 1a, 2 has 2a and 2b, etc.) + basicMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{}, + } + for rid := 1; rid <= 5; rid++ { + var nodes []*tailcfg.DERPNode + for nid := 0; nid < rid; nid++ { + nodes = append(nodes, &tailcfg.DERPNode{ + Name: fmt.Sprintf("%d%c", rid, 'a'+rune(nid)), + RegionID: rid, + HostName: fmt.Sprintf("derp%d-%d", rid, nid), + IPv4: fmt.Sprintf("%d.0.0.%d", rid, nid), + IPv6: fmt.Sprintf("%d::%d", rid, nid), + }) + } + basicMap.Regions[rid] = &tailcfg.DERPRegion{ + RegionID: rid, + Nodes: nodes, + } + } + + const ms = time.Millisecond + p := func(name string, c rune, d ...time.Duration) probe { + var proto probeProto + switch c { + case 4: + proto = probeIPv4 + case 6: + proto = probeIPv6 + case 'h': + proto = probeHTTPS + } + pr := probe{node: name, proto: proto} + if len(d) == 1 { + pr.delay = d[0] + } else if len(d) > 1 { + panic("too many args") + } + return pr + } tests := []struct { - name string - last *Report - want4 []string - want6 []string - wantTries map[string]int + name string + dm *tailcfg.DERPMap + have6if bool + last *Report + want probePlan }{ { - name: "fresh", - last: nil, - want4: []string{"d1:4", "d2:4", "d3:4"}, - want6: []string{"d1:6", "d2:6", "d3:6"}, - wantTries: map[string]int{ - "d1:4": 2, - "d2:4": 2, - "d3:4": 2, - "d1:6": 1, - "d2:6": 1, - "d3:6": 1, + name: "initial_v6", + dm: basicMap, + have6if: true, + last: nil, // initial + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 100*ms), p("1a", 6, 200*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 100*ms), p("2a", 6, 200*ms)}, + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c + "region-3-v6": []probe{p("3a", 6), p("3b", 6, 100*ms), p("3c", 6, 200*ms)}, + "region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)}, + "region-4-v6": []probe{p("4a", 6), p("4b", 6, 100*ms), p("4c", 6, 200*ms)}, + "region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 
200*ms)}, + "region-5-v6": []probe{p("5a", 6), p("5b", 6, 100*ms), p("5c", 6, 200*ms)}, }, }, { - name: "1_and_3_closest", + name: "initial_no_v6", + dm: basicMap, + have6if: false, + last: nil, // initial + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c + "region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)}, + "region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 200*ms)}, + }, + }, + { + name: "second_v4_no_6if", + dm: basicMap, + have6if: false, last: &Report{ - DERPLatency: map[string]time.Duration{ - "d1:4": 15 * time.Millisecond, - "d2:4": 300 * time.Millisecond, - "d3:4": 25 * time.Millisecond, + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, }, }, - want4: []string{"d1:4", "d2:4", "d3:4"}, - want6: []string{"d1:6", "d3:6"}, - wantTries: map[string]int{ - "d1:4": 2, - "d3:4": 2, - "d2:4": 1, - "d1:6": 1, - "d3:6": 1, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-3-v4": []probe{p("3a", 4)}, + }, + }, + { + name: "second_v4_only_with_6if", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-1-v6": []probe{p("1a", 6)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6)}, + "region-3-v4": []probe{p("3a", 4)}, + }, + }, + { + name: "second_mixed", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + // Pretend 5 is missing + }, + RegionV4Latency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + }, + RegionV6Latency: map[int]time.Duration{ + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)}, + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 12*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)}, + "region-3-v4": []probe{p("3a", 4)}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &Client{DERP: derps, last: tt.last} - got4, got6, gotTries, err := c.pickSubset() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got4, tt.want4) { - t.Errorf("stuns4 = %q; want %q", got4, tt.want4) - } - if !reflect.DeepEqual(got6, tt.want6) { - t.Errorf("stuns6 = %q; want %q", got6, tt.want6) - } - if !reflect.DeepEqual(gotTries, tt.wantTries) { - t.Errorf("tries = %v; want %v", gotTries, tt.wantTries) + got := 
makeProbePlan(tt.dm, tt.have6if, tt.last) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want) } }) } } + +func (plan probePlan) String() string { + var sb strings.Builder + keys := []string{} + for k := range plan { + keys = append(keys, k) + } + sort.Strings(keys) + + for _, key := range keys { + fmt.Fprintf(&sb, "[%s]", key) + pv := plan[key] + for _, p := range pv { + fmt.Fprintf(&sb, " %v", p) + } + sb.WriteByte('\n') + } + return sb.String() +} + +func (p probe) String() string { + wait := "" + if p.wait > 0 { + wait = "+" + p.wait.String() + } + delay := "" + if p.delay > 0 { + delay = "@" + p.delay.String() + } + return fmt.Sprintf("%s-%s%s%s", p.node, p.proto, delay, wait) +} + +func (p probeProto) String() string { + switch p { + case probeIPv4: + return "v4" + case probeIPv6: + return "v4" + case probeHTTPS: + return "https" + } + return "?" +} diff --git a/stun/stuntest/stuntest.go b/stun/stuntest/stuntest.go index ba244fc37..b53db0e0a 100644 --- a/stun/stuntest/stuntest.go +++ b/stun/stuntest/stuntest.go @@ -6,12 +6,16 @@ package stuntest import ( + "fmt" "net" + "strconv" "strings" "sync" "testing" + "inet.af/netaddr" "tailscale.com/stun" + "tailscale.com/tailcfg" ) type stunStats struct { @@ -20,7 +24,7 @@ type stunStats struct { readIPv6 int } -func Serve(t *testing.T) (addr string, cleanupFn func()) { +func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) { t.Helper() // TODO(crawshaw): use stats to test re-STUN logic @@ -30,13 +34,13 @@ func Serve(t *testing.T) (addr string, cleanupFn func()) { if err != nil { t.Fatalf("failed to open STUN listener: %v", err) } - - stunAddr := pc.LocalAddr().String() - stunAddr = strings.Replace(stunAddr, "0.0.0.0:", "127.0.0.1:", 1) - + addr = &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: pc.LocalAddr().(*net.UDPAddr).Port, + } doneCh := make(chan struct{}) go runSTUN(t, pc, &stats, doneCh) - return stunAddr, func() { + return addr, func() { pc.Close() <-doneCh } @@ -79,3 +83,47 @@ func runSTUN(t *testing.T, pc net.PacketConn, stats *stunStats, done chan<- stru } } } + +func DERPMapOf(stun ...string) *tailcfg.DERPMap { + m := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{}, + } + for i, hostPortStr := range stun { + regionID := i + 1 + host, portStr, err := net.SplitHostPort(hostPortStr) + if err != nil { + panic(fmt.Sprintf("bogus STUN hostport: %q", hostPortStr)) + } + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("bogus port %q in %q", portStr, hostPortStr)) + } + var ipv4, ipv6 string + ip, err := netaddr.ParseIP(host) + if err != nil { + panic(fmt.Sprintf("bogus non-IP STUN host %q in %q", host, hostPortStr)) + } + if ip.Is4() { + ipv4 = host + ipv6 = "none" + } + if ip.Is6() { + ipv6 = host + ipv4 = "none" + } + node := &tailcfg.DERPNode{ + Name: fmt.Sprint(regionID) + "a", + RegionID: regionID, + HostName: fmt.Sprintf("d%d.invalid", regionID), + IPv4: ipv4, + IPv6: ipv6, + STUNPort: port, + STUNOnly: true, + } + m.Regions[regionID] = &tailcfg.DERPRegion{ + RegionID: regionID, + Nodes: []*tailcfg.DERPNode{node}, + } + } + return m +} diff --git a/stunner/stunner.go b/stunner/stunner.go deleted file mode 100644 index a99ad3c49..000000000 --- a/stunner/stunner.go +++ /dev/null @@ -1,310 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. 
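For context, a minimal sketch of how the two new stuntest helpers above might be combined in a test; the package and test names here are illustrative, not part of this change:

package example_test // hypothetical

import (
	"testing"

	"tailscale.com/stun/stuntest"
)

func TestLocalDERPMap(t *testing.T) {
	// Serve binds a loopback STUN responder and returns its *net.UDPAddr.
	addr, cleanup := stuntest.Serve(t)
	defer cleanup()

	// DERPMapOf wraps each "host:port" in its own STUN-only region,
	// numbered from 1, each with a single node named "1a", "2a", ...
	dm := stuntest.DERPMapOf(addr.String())
	if got := len(dm.Regions); got != 1 {
		t.Fatalf("regions = %d; want 1", got)
	}
	if !dm.Regions[1].Nodes[0].STUNOnly {
		t.Fatalf("expected a STUN-only node")
	}
}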
- -package stunner - -import ( - "context" - "errors" - "fmt" - "math/rand" - "net" - "strconv" - "strings" - "sync" - "time" - - "tailscale.com/net/dnscache" - "tailscale.com/stun" - "tailscale.com/types/structs" -) - -// Stunner sends a STUN request to several servers and handles a response. -// -// It is designed to used on a connection owned by other code and so does -// not directly reference a net.Conn of any sort. Instead, the user should -// provide Send function to send packets, and call Receive when a new -// STUN response is received. -// -// In response, a Stunner will call Endpoint with any endpoints determined -// for the connection. (An endpoint may be reported multiple times if -// multiple servers are provided.) -type Stunner struct { - // Send sends a packet. - // It will typically be a PacketConn.WriteTo method value. - Send func([]byte, net.Addr) (int, error) // sends a packet - - // Endpoint is called whenever a STUN response is received. - // The server is the STUN server that replied, endpoint is the ip:port - // from the STUN response, and d is the duration that the STUN request - // took on the wire (not including DNS lookup time. - Endpoint func(server, endpoint string, d time.Duration) - - // onPacket is the internal version of Endpoint that does de-dup. - // It's set by Run. - onPacket func(server, endpoint string, d time.Duration) - - Servers []string // STUN servers to contact - - // DNSCache optionally specifies a DNSCache to use. - // If nil, a DNS cache is not used. - DNSCache *dnscache.Resolver - - // Logf optionally specifies a log function. If nil, logging is disabled. - Logf func(format string, args ...interface{}) - - // OnlyIPv6 controls whether IPv6 is exclusively used. - // If false, only IPv4 is used. There is currently no mixed mode. - OnlyIPv6 bool - - // MaxTries optionally provides a mapping from server name to the maximum - // number of tries that should be made for a given server. - // If nil or a server is not present in the map, the default is 1. - // Values less than 1 are ignored. - MaxTries map[string]int - - mu sync.Mutex - inFlight map[stun.TxID]request -} - -func (s *Stunner) addTX(tx stun.TxID, server string) { - s.mu.Lock() - defer s.mu.Unlock() - if _, dup := s.inFlight[tx]; dup { - panic("unexpected duplicate STUN TransactionID") - } - s.inFlight[tx] = request{sent: time.Now(), server: server} -} - -func (s *Stunner) removeTX(tx stun.TxID) (request, bool) { - s.mu.Lock() - defer s.mu.Unlock() - if s.inFlight == nil { - return request{}, false - } - r, ok := s.inFlight[tx] - if ok { - delete(s.inFlight, tx) - } else { - s.logf("stunner: got STUN packet for unknown TxID %x", tx) - } - return r, ok -} - -type request struct { - _ structs.Incomparable - sent time.Time - server string -} - -func (s *Stunner) logf(format string, args ...interface{}) { - if s.Logf != nil { - s.Logf(format, args...) - } -} - -// Receive delivers a STUN packet to the stunner. -func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { - if !stun.Is(p) { - s.logf("[unexpected] stunner: received non-STUN packet") - return - } - now := time.Now() - tx, addr, port, err := stun.ParseResponse(p) - if err != nil { - if _, err := stun.ParseBindingRequest(p); err == nil { - // This was probably our own netcheck hairpin - // check probe coming in late. Ignore. 
- return - } - s.logf("stunner: received unexpected STUN message response from %v: %v", fromAddr, err) - return - } - r, ok := s.removeTX(tx) - if !ok { - return - } - d := now.Sub(r.sent) - - host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) - s.onPacket(r.server, host, d) -} - -func (s *Stunner) resolver() *net.Resolver { - return net.DefaultResolver -} - -// cleanUpPostRun zeros out some fields, mostly for debugging (so -// things crash or race+fail if there's a sender still running.) -func (s *Stunner) cleanUpPostRun() { - s.mu.Lock() - s.inFlight = nil - s.mu.Unlock() -} - -// Run starts a Stunner and blocks until all servers either respond -// or are tried multiple times and timeout. -// It can not be called concurrently with itself. -func (s *Stunner) Run(ctx context.Context) error { - for _, server := range s.Servers { - if _, _, err := net.SplitHostPort(server); err != nil { - return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers) - } - } - if len(s.Servers) == 0 { - return errors.New("stunner: no Servers") - } - - s.inFlight = make(map[stun.TxID]request) - defer s.cleanUpPostRun() - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - type sender struct { - ctx context.Context - cancel context.CancelFunc - } - var ( - needMu sync.Mutex - need = make(map[string]sender) // keyed by server; deleted when done - allDone = make(chan struct{}) // closed when need is empty - ) - s.onPacket = func(server, endpoint string, d time.Duration) { - needMu.Lock() - defer needMu.Unlock() - sender, ok := need[server] - if !ok { - return - } - sender.cancel() - delete(need, server) - s.Endpoint(server, endpoint, d) - if len(need) == 0 { - close(allDone) - } - } - - var wg sync.WaitGroup - for _, server := range s.Servers { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - need[server] = sender{ctx, cancel} - } - needMu.Lock() - for server, sender := range need { - wg.Add(1) - server, ctx := server, sender.ctx - go func() { - defer wg.Done() - s.sendPackets(ctx, server) - }() - } - needMu.Unlock() - var err error - select { - case <-ctx.Done(): - err = ctx.Err() - case <-allDone: - cancel() - } - wg.Wait() - - var missing []string - needMu.Lock() - for server := range need { - missing = append(missing, server) - } - needMu.Unlock() - - if len(missing) == 0 || err == nil { - return nil - } - return fmt.Errorf("got STUN error: %w; missing replies from: %v", err, strings.Join(missing, ", ")) -} - -func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) { - hostStr, portStr, err := net.SplitHostPort(server) - if err != nil { - return nil, err - } - addrPort, err := strconv.Atoi(portStr) - if err != nil { - return nil, fmt.Errorf("port: %v", err) - } - if addrPort == 0 { - addrPort = 3478 - } - addr := &net.UDPAddr{Port: addrPort} - - var ipAddrs []net.IPAddr - if s.DNSCache != nil { - ip, err := s.DNSCache.LookupIP(ctx, hostStr) - if err != nil { - return nil, err - } - ipAddrs = []net.IPAddr{{IP: ip}} - } else { - ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr) - if err != nil { - return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err) - } - } - - for _, ipAddr := range ipAddrs { - ip4 := ipAddr.IP.To4() - if ip4 != nil { - if s.OnlyIPv6 { - continue - } - addr.IP = ip4 - break - } else if s.OnlyIPv6 { - addr.IP = ipAddr.IP - addr.Zone = ipAddr.Zone - } - } - if addr.IP == nil { - if s.OnlyIPv6 { - return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, 
ipAddrs) - } - return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) - } - return addr, nil -} - -// maxTriesForServer returns the maximum number of STUN queries that -// will be sent to server (for one call to Run). The default is 1. -func (s *Stunner) maxTriesForServer(server string) int { - if v, ok := s.MaxTries[server]; ok && v > 0 { - return v - } - return 1 -} - -func (s *Stunner) sendPackets(ctx context.Context, server string) error { - addr, err := s.serverAddr(ctx, server) - if err != nil { - return err - } - maxTries := s.maxTriesForServer(server) - for i := 0; i < maxTries; i++ { - txID := stun.NewTxID() - req := stun.Request(txID) - s.addTX(txID, server) - _, err = s.Send(req, addr) - if err != nil { - return fmt.Errorf("send: %v", err) - } - - select { - case <-ctx.Done(): - // Ignore error. The caller deals with handling contexts. - // We only use it to dermine when to stop spraying STUN packets. - return nil - case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))): - } - } - return nil -} diff --git a/stunner/stunner_test.go b/stunner/stunner_test.go deleted file mode 100644 index a3555f1be..000000000 --- a/stunner/stunner_test.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package stunner - -import ( - "context" - "errors" - "fmt" - "net" - "sort" - "testing" - "time" - - "gortc.io/stun" -) - -func TestStun(t *testing.T) { - conn1, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn1.Close() - conn2, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn2.Close() - stunServers := []string{ - conn1.LocalAddr().String(), conn2.LocalAddr().String(), - } - - epCh := make(chan string, 16) - - localConn, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - - s := &Stunner{ - Send: localConn.WriteTo, - Endpoint: func(server, ep string, d time.Duration) { epCh <- ep }, - Servers: stunServers, - MaxTries: map[string]int{ - stunServers[0]: 2, - stunServers[1]: 2, - }, - } - - stun1Err := make(chan error) - go func() { - stun1Err <- startSTUN(conn1, s.Receive) - }() - stun2Err := make(chan error) - go func() { - stun2Err <- startSTUNDrop1(conn2, s.Receive) - }() - - errCh := make(chan error) - go func() { - errCh <- s.Run(context.Background()) - }() - - var eps []string - select { - case ep := <-epCh: - eps = append(eps, ep) - case <-time.After(100 * time.Millisecond): - t.Fatal("missing first endpoint response") - } - select { - case ep := <-epCh: - eps = append(eps, ep) - case <-time.After(500 * time.Millisecond): - t.Fatal("missing second endpoint response") - } - sort.Strings(eps) - if want := "1.2.3.4:1234"; eps[0] != want { - t.Errorf("eps[0]=%q, want %q", eps[0], want) - } - if want := "4.5.6.7:4567"; eps[1] != want { - t.Errorf("eps[1]=%q, want %q", eps[1], want) - } - - if err := <-errCh; err != nil { - t.Fatal(err) - } -} - -func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error { - if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil { - return fmt.Errorf("first stun server read failed: %v", err) - } - req := new(stun.Message) - res := new(stun.Message) - - p := make([]byte, 1024) - n, addr, err := conn.ReadFrom(p) - if err != nil { - return err - } - p = p[:n] - if !stun.IsMessage(p) { - return 
errors.New("not a STUN message") - } - if _, err := req.Write(p); err != nil { - return err - } - mappedAddr := &stun.XORMappedAddress{ - IP: net.ParseIP("1.2.3.4"), - Port: 1234, - } - software := stun.NewSoftware("endpointer") - err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint) - if err != nil { - return err - } - writeTo(res.Raw, addr.(*net.UDPAddr)) - return nil -} - -func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error { - req := new(stun.Message) - res := new(stun.Message) - - p := make([]byte, 1024) - n, addr, err := conn.ReadFrom(p) - if err != nil { - return err - } - p = p[:n] - if !stun.IsMessage(p) { - return errors.New("not a STUN message") - } - if _, err := req.Write(p); err != nil { - return err - } - mappedAddr := &stun.XORMappedAddress{ - IP: net.ParseIP("4.5.6.7"), - Port: 4567, - } - software := stun.NewSoftware("endpointer") - err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint) - if err != nil { - return err - } - writeTo(res.Raw, addr.(*net.UDPAddr)) - return nil -} - -// TODO: test retry timeout (overwrite the retryDurations) -// TODO: test canceling context passed to Run -// TODO: test sending bad packets diff --git a/tailcfg/derpmap.go b/tailcfg/derpmap.go index c7553545c..69aa92157 100644 --- a/tailcfg/derpmap.go +++ b/tailcfg/derpmap.go @@ -4,6 +4,8 @@ package tailcfg +import "sort" + // DERPMap describes the set of DERP packet relay servers that are available. type DERPMap struct { // Regions is the set of geographic regions running DERP node(s). @@ -14,6 +16,16 @@ type DERPMap struct { Regions map[int]*DERPRegion } +/// RegionIDs returns the sorted region IDs. +func (m *DERPMap) RegionIDs() []int { + ret := make([]int, 0, len(m.Regions)) + for rid := range m.Regions { + ret = append(ret, rid) + } + sort.Ints(ret) + return ret +} + // DERPRegion is a geographic region running DERP relay node(s). // // Client nodes discover which region they're closest to, advertise @@ -85,9 +97,29 @@ type DERPNode struct { // IPv4 optionally forces an IPv4 address to use, instead of using DNS. // If empty, A record(s) from DNS lookups of HostName are used. + // If the string is not an IPv4 address, IPv4 is not used; the + // conventional string to disable IPv4 (and not use DNS) is + // "none". IPv4 string `json:",omitempty"` // IPv6 optionally forces an IPv6 address to use, instead of using DNS. // If empty, AAAA record(s) from DNS lookups of HostName are used. + // If the string is not an IPv6 address, IPv6 is not used; the + // conventional string to disable IPv6 (and not use DNS) is + // "none". IPv6 string `json:",omitempty"` + + // Port optionally specifies a STUN port to use. + // Zero means 3478. + // To disable STUN on this node, use -1. + STUNPort int `json:",omitempty"` + + // STUNOnly marks a node as only a STUN server and not a DERP + // server. + STUNOnly bool `json:",omitempty"` + + // DERPTestPort is used in tests to override the port, instead + // of using the default port of 443. If non-zero, TLS + // verification is skipped. + DERPTestPort int `json:",omitempty"` } diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index f1ddb1ee8..5afe05129 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -315,8 +315,9 @@ type NetInfo struct { LinkType string // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc) // DERPLatency is the fastest recent time to reach various - // DERP STUN servers, in seconds. The map key is the DERP - // server's STUN host:port. 
+ // DERP STUN servers, in seconds. The map key is the + // "regionID-v4" or "-v6"; it was previously the DERP server's + // STUN host:port. // // This should only be updated rarely, or when there's a // material change, as any change here also gets uploaded to @@ -336,7 +337,7 @@ func (ni *NetInfo) String() string { } // BasicallyEqual reports whether ni and ni2 are basically equal, ignoring -// changes in DERPLatency. +// changes in DERP ServerLatency & RegionLatency. func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool { if (ni == nil) != (ni2 == nil) { return false diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index d9e3a4800..1b8cb0cc4 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,7 +9,6 @@ import ( "bytes" "context" - "crypto/tls" "encoding/binary" "errors" "fmt" @@ -17,6 +16,7 @@ "math/rand" "net" "os" + "reflect" "sort" "strconv" "strings" @@ -32,7 +32,6 @@ "inet.af/netaddr" "tailscale.com/derp" "tailscale.com/derp/derphttp" - "tailscale.com/derp/derpmap" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dnscache" "tailscale.com/net/interfaces" @@ -55,7 +54,6 @@ type Conn struct { epFunc func(endpoints []string) logf logger.Logf sendLogLimit *rate.Limiter - derps *derpmap.World netChecker *netcheck.Client // bufferedIPv4From and bufferedIPv4Packet are owned by @@ -76,7 +74,8 @@ type Conn struct { mu sync.Mutex // guards all following fields - closed bool + started bool + closed bool endpointsUpdateWaiter *sync.Cond endpointsUpdateActive bool @@ -104,13 +103,12 @@ type Conn struct { netInfoFunc func(*tailcfg.NetInfo) // nil until set netInfoLast *tailcfg.NetInfo - wantDerp bool - privateKey key.Private - myDerp int // nearest DERP server; 0 means none/unknown - derpStarted chan struct{} // closed on first connection to DERP; for tests - activeDerp map[int]activeDerp - prevDerp map[int]*syncs.WaitGroupChan - derpTLSConfig *tls.Config // normally nil; used by tests + derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled + privateKey key.Private + myDerp int // nearest DERP region ID; 0 means none/unknown + derpStarted chan struct{} // closed on first connection to DERP; for tests + activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region + prevDerp map[int]*syncs.WaitGroupChan // derpRoute contains optional alternate routes to use as an // optimization instead of contacting a peer via their home @@ -196,14 +194,9 @@ type Options struct { // Zero means to pick one automatically. Port uint16 - // DERPs, if non-nil, is used instead of derpmap.Prod. - DERPs *derpmap.World - // EndpointsFunc optionally provides a func to be called when // endpoints change. The called func does not own the slice. EndpointsFunc func(endpoint []string) - - derpTLSConfig *tls.Config // normally nil; used by tests } func (o *Options) logf() logger.Logf { @@ -220,37 +213,39 @@ func (o *Options) endpointsFunc() func([]string) { return o.EndpointsFunc } -// Listen creates a magic Conn listening on opts.Port. -// As the set of possible endpoints for a Conn changes, the -// callback opts.EndpointsFunc is called. -func Listen(opts Options) (*Conn, error) { +// newConn is the error-free, network-listening-side-effect-free based +// of NewConn. Mostly for tests. 
+func newConn() *Conn { c := &Conn{ - pconnPort: opts.Port, - logf: opts.logf(), - epFunc: opts.endpointsFunc(), - sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), - addrsByUDP: make(map[netaddr.IPPort]*AddrSet), - addrsByKey: make(map[key.Public]*AddrSet), - wantDerp: true, - derpRecvCh: make(chan derpReadResult), - udpRecvCh: make(chan udpReadResult), - derpTLSConfig: opts.derpTLSConfig, - derpStarted: make(chan struct{}), - derps: opts.DERPs, - peerLastDerp: make(map[key.Public]int), + sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), + addrsByUDP: make(map[netaddr.IPPort]*AddrSet), + addrsByKey: make(map[key.Public]*AddrSet), + derpRecvCh: make(chan derpReadResult), + udpRecvCh: make(chan udpReadResult), + derpStarted: make(chan struct{}), + peerLastDerp: make(map[key.Public]int), } c.endpointsUpdateWaiter = sync.NewCond(&c.mu) + return c +} + +// NewConn creates a magic Conn listening on opts.Port. +// As the set of possible endpoints for a Conn changes, the +// callback opts.EndpointsFunc is called. +// +// It doesn't start doing anything until Start is called. +func NewConn(opts Options) (*Conn, error) { + c := newConn() + c.pconnPort = opts.Port + c.logf = opts.logf() + c.epFunc = opts.endpointsFunc() if err := c.initialBind(); err != nil { return nil, err } c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) - if c.derps == nil { - c.derps = derpmap.Prod() - } c.netChecker = &netcheck.Client{ - DERP: c.derps, Logf: logger.WithPrefix(c.logf, "netcheck: "), GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, } @@ -259,6 +254,18 @@ func Listen(opts Options) (*Conn, error) { } c.ignoreSTUNPackets() + + return c, nil +} + +func (c *Conn) Start() { + c.mu.Lock() + if c.started { + panic("duplicate Start call") + } + c.started = true + c.mu.Unlock() + c.ReSTUN("initial") // We assume that LinkChange notifications are plumbed through well @@ -267,8 +274,6 @@ func Listen(opts Options) (*Conn, error) { go c.periodicReSTUN() } go c.periodicDerpCleanup() - - return c, nil } func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() } @@ -278,10 +283,6 @@ func (c *Conn) ignoreSTUNPackets() { c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) } -// runs in its own goroutine until ctx is shut down. -// Whenever c.startEpUpdate receives a value, it starts an -// STUN endpoint lookup. -// // c.mu must NOT be held. 
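A rough sketch of the resulting two-step lifecycle from a caller's point of view (the options and map here are placeholders); note that endpoint discovery and the periodic workers only begin once Start is called, and SetDERPMap should follow Start since a map update triggers ReSTUN:

package example // hypothetical

import (
	"log"

	"tailscale.com/tailcfg"
	"tailscale.com/wgengine/magicsock"
)

func startMagicsock(dm *tailcfg.DERPMap) (*magicsock.Conn, error) {
	conn, err := magicsock.NewConn(magicsock.Options{
		Port:          0, // pick a port automatically
		Logf:          log.Printf,
		EndpointsFunc: func(eps []string) { log.Printf("endpoints: %q", eps) },
	})
	if err != nil {
		return nil, err
	}
	conn.Start()        // begins STUN/endpoint discovery; required before ReSTUN
	conn.SetDERPMap(dm) // nil leaves DERP disabled
	return conn, nil
}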
func (c *Conn) updateEndpoints(why string) { defer func() { @@ -326,7 +327,11 @@ func (c *Conn) setEndpoints(endpoints []string) (changed bool) { } func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { - if DisableSTUNForTesting { + c.mu.Lock() + dm := c.derpMap + c.mu.Unlock() + + if DisableSTUNForTesting || dm == nil { return new(netcheck.Report), nil } @@ -336,7 +341,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { c.stunReceiveFunc.Store(c.netChecker.ReceiveSTUNPacket) defer c.ignoreSTUNPackets() - report, err := c.netChecker.GetReport(ctx) + report, err := c.netChecker.GetReport(ctx, dm) if err != nil { return nil, err } @@ -346,8 +351,11 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { MappingVariesByDestIP: report.MappingVariesByDestIP, HairPinning: report.HairPinning, } - for server, d := range report.DERPLatency { - ni.DERPLatency[server] = d.Seconds() + for rid, d := range report.RegionV4Latency { + ni.DERPLatency[fmt.Sprintf("%d-v4", rid)] = d.Seconds() + } + for rid, d := range report.RegionV6Latency { + ni.DERPLatency[fmt.Sprintf("%d-v6", rid)] = d.Seconds() } ni.WorkingIPv6.Set(report.IPv6) ni.WorkingUDP.Set(report.UDP) @@ -380,9 +388,12 @@ func (c *Conn) pickDERPFallback() int { c.mu.Lock() defer c.mu.Unlock() - ids := c.derps.IDs() + if !c.wantDerpLocked() { + return 0 + } + ids := c.derpMap.RegionIDs() if len(ids) == 0 { - // No DERP nodes registered. + // No DERP regions in non-nil map. return 0 } @@ -458,7 +469,7 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) { func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { c.mu.Lock() defer c.mu.Unlock() - if !c.wantDerp { + if !c.wantDerpLocked() { c.myDerp = 0 return false } @@ -476,7 +487,7 @@ func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { // On change, notify all currently connected DERP servers and // start connecting to our home DERP if we are not already. - c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derps.ServerByID(derpNum).Geo) + c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derpMap.Regions[derpNum].RegionCode) for i, ad := range c.activeDerp { go ad.c.NotePreferred(i == c.myDerp) } @@ -791,11 +802,11 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !addr.IP.Equal(derpMagicIP) { return nil } - nodeID := addr.Port + regionID := addr.Port c.mu.Lock() defer c.mu.Unlock() - if !c.wantDerp || c.closed { + if !c.wantDerpLocked() || c.closed { return nil } if c.privateKey.IsZero() { @@ -807,10 +818,10 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de // first. If so, might as well use it. (It's a little // arbitrary whether we use this one vs. the reverse route // below when we have both.) 
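Since NetInfo.DERPLatency now mixes address families in its string keys, a small sketch of how a consumer might split a key such as "1-v4" back apart; the helper name is hypothetical:

package example // hypothetical

import (
	"strconv"
	"strings"
)

// parseDERPLatencyKey splits a NetInfo.DERPLatency key of the form
// "<regionID>-v4" or "<regionID>-v6" into its parts.
func parseDERPLatencyKey(k string) (regionID int, ipv6 bool, ok bool) {
	i := strings.LastIndexByte(k, '-')
	if i < 0 {
		return 0, false, false
	}
	id, err := strconv.Atoi(k[:i])
	if err != nil {
		return 0, false, false
	}
	suffix := k[i+1:]
	if suffix != "v4" && suffix != "v6" {
		return 0, false, false
	}
	return id, suffix == "v6", true
}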
- ad, ok := c.activeDerp[nodeID] + ad, ok := c.activeDerp[regionID] if ok { *ad.lastWrite = time.Now() - c.setPeerLastDerpLocked(peer, nodeID, nodeID) + c.setPeerLastDerpLocked(peer, regionID, regionID) return ad.writeCh } @@ -823,7 +834,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !peer.IsZero() && debugUseDerpRoute { if r, ok := c.derpRoute[peer]; ok { if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc { - c.setPeerLastDerpLocked(peer, r.derpID, nodeID) + c.setPeerLastDerpLocked(peer, r.derpID, regionID) *ad.lastWrite = time.Now() return ad.writeCh } @@ -834,7 +845,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de if !peer.IsZero() { why = peerShort(peer) } - c.logf("magicsock: adding connection to derp-%v for %v", nodeID, why) + c.logf("magicsock: adding connection to derp-%v for %v", regionID, why) firstDerp := false if c.activeDerp == nil { @@ -842,22 +853,23 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de c.activeDerp = make(map[int]activeDerp) c.prevDerp = make(map[int]*syncs.WaitGroupChan) } - derpSrv := c.derps.ServerByID(nodeID) - if derpSrv == nil || derpSrv.HostHTTPS == "" { + if c.derpMap == nil || c.derpMap.Regions[regionID] == nil { return nil } // Note that derphttp.NewClient does not dial the server // so it is safe to do under the mu lock. - dc, err := derphttp.NewClient(c.privateKey, "https://"+derpSrv.HostHTTPS+"/derp", c.logf) - if err != nil { - c.logf("magicsock: derphttp.NewClient: node %d, host %q invalid? err: %v", nodeID, derpSrv.HostHTTPS, err) - return nil - } + dc := derphttp.NewRegionClient(c.privateKey, c.logf, func() *tailcfg.DERPRegion { + c.mu.Lock() + defer c.mu.Unlock() + if c.derpMap == nil { + return nil + } + return c.derpMap.Regions[regionID] + }) - dc.NotePreferred(c.myDerp == nodeID) + dc.NotePreferred(c.myDerp == regionID) dc.DNSCache = dnscache.Get() - dc.TLSConfig = c.derpTLSConfig ctx, cancel := context.WithCancel(c.connCtx) ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) @@ -868,21 +880,21 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de ad.lastWrite = new(time.Time) *ad.lastWrite = time.Now() ad.createTime = time.Now() - c.activeDerp[nodeID] = ad + c.activeDerp[regionID] = ad c.logActiveDerpLocked() - c.setPeerLastDerpLocked(peer, nodeID, nodeID) + c.setPeerLastDerpLocked(peer, regionID, regionID) // Build a startGate for the derp reader+writer // goroutines, so they don't start running until any // previous generation is closed. startGate := syncs.ClosedChan() - if prev := c.prevDerp[nodeID]; prev != nil { + if prev := c.prevDerp[regionID]; prev != nil { startGate = prev.DoneChan() } // And register a WaitGroup(Chan) for this generation. wg := syncs.NewWaitGroupChan() wg.Add(2) - c.prevDerp[nodeID] = wg + c.prevDerp[regionID] = wg if firstDerp { startGate = c.derpStarted @@ -899,37 +911,37 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de } // setPeerLastDerpLocked notes that peer is now being written to via -// provided DERP node nodeID, and that that advertises a DERP home -// node of homeID. +// the provided DERP regionID, and that the peer advertises a DERP +// home region ID of homeID. // // If there's any change, it logs. // // c.mu must be held. 
-func (c *Conn) setPeerLastDerpLocked(peer key.Public, nodeID, homeID int) { +func (c *Conn) setPeerLastDerpLocked(peer key.Public, regionID, homeID int) { if peer.IsZero() { return } old := c.peerLastDerp[peer] - if old == nodeID { + if old == regionID { return } - c.peerLastDerp[peer] = nodeID + c.peerLastDerp[peer] = regionID var newDesc string switch { - case nodeID == homeID && nodeID == c.myDerp: + case regionID == homeID && regionID == c.myDerp: newDesc = "shared home" - case nodeID == homeID: + case regionID == homeID: newDesc = "their home" - case nodeID == c.myDerp: + case regionID == c.myDerp: newDesc = "our home" - case nodeID != homeID: + case regionID != homeID: newDesc = "alt" } if old == 0 { - c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), nodeID, newDesc) + c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), regionID, newDesc) } else { - c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, nodeID, newDesc) + c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, regionID, newDesc) } } @@ -1284,18 +1296,27 @@ func (c *Conn) UpdatePeers(newPeers map[key.Public]struct{}) { } } -// SetDERPEnabled controls whether DERP is used. -// New connections have it enabled by default. -func (c *Conn) SetDERPEnabled(wantDerp bool) { +// SetDERPMap controls which (if any) DERP servers are used. +// A nil value means to disable DERP; it's disabled by default. +func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { c.mu.Lock() defer c.mu.Unlock() - c.wantDerp = wantDerp - if !wantDerp { - c.closeAllDerpLocked("derp-disabled") + if reflect.DeepEqual(dm, c.derpMap) { + return } + + c.derpMap = dm + if dm == nil { + c.closeAllDerpLocked("derp-disabled") + return + } + + go c.ReSTUN("derp-map-update") } +func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } + // c.mu must be held. func (c *Conn) closeAllDerpLocked(why string) { if len(c.activeDerp) == 0 { @@ -1352,7 +1373,7 @@ func (c *Conn) logEndpointChange(endpoints []string, reasons map[string]string) } // c.mu must be held. -func (c *Conn) foreachActiveDerpSortedLocked(fn func(nodeID int, ad activeDerp)) { +func (c *Conn) foreachActiveDerpSortedLocked(fn func(regionID int, ad activeDerp)) { if len(c.activeDerp) < 2 { for id, ad := range c.activeDerp { fn(id, ad) @@ -1473,6 +1494,9 @@ func (c *Conn) periodicDerpCleanup() { func (c *Conn) ReSTUN(why string) { c.mu.Lock() defer c.mu.Unlock() + if !c.started { + panic("call to ReSTUN before Start") + } if c.closed { // raced with a shutdown. 
return diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9b7b30aac..992768adc 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -27,6 +27,7 @@ "tailscale.com/derp/derphttp" "tailscale.com/derp/derpmap" "tailscale.com/stun/stuntest" + "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -54,7 +55,7 @@ func (c *Conn) WaitReady(t *testing.T) { } } -func TestListen(t *testing.T) { +func TestNewConn(t *testing.T) { tstest.PanicOnLog() rc := tstest.NewResourceCheck() defer rc.Assert(t) @@ -70,9 +71,8 @@ func TestListen(t *testing.T) { defer stunCleanupFn() port := pickPort(t) - conn, err := Listen(Options{ + conn, err := NewConn(Options{ Port: port, - DERPs: derpmap.NewTestWorld(stunAddr), EndpointsFunc: epFunc, Logf: t.Logf, }) @@ -80,6 +80,8 @@ func TestListen(t *testing.T) { t.Fatal(err) } defer conn.Close() + conn.Start() + conn.SetDERPMap(stuntest.DERPMapOf(stunAddr.String())) go func() { var pkt [64 << 10]byte @@ -136,9 +138,8 @@ func TestPickDERPFallback(t *testing.T) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - c := &Conn{ - derps: derpmap.Prod(), - } + c := newConn() + c.derpMap = derpmap.Prod() a := c.pickDERPFallback() if a == 0 { t.Fatalf("pickDERPFallback returned 0") @@ -156,7 +157,8 @@ func TestPickDERPFallback(t *testing.T) { // distribution over nodes works. got := map[int]int{} for i := 0; i < 50; i++ { - c = &Conn{derps: derpmap.Prod()} + c = newConn() + c.derpMap = derpmap.Prod() got[c.pickDERPFallback()]++ } t.Logf("distribution: %v", got) @@ -236,7 +238,7 @@ func parseCIDR(t *testing.T, addr string) wgcfg.CIDR { return cidr } -func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, cleanupFn func()) { +func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr *net.TCPAddr, cleanupFn func()) { var serverPrivateKey key.Private if _, err := crand.Read(serverPrivateKey[:]); err != nil { t.Fatal(err) @@ -250,14 +252,13 @@ func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, clean httpsrv.StartTLS() logf("DERP server URL: %s", httpsrv.URL) - addr = strings.TrimPrefix(httpsrv.URL, "https://") cleanupFn = func() { httpsrv.CloseClientConnections() httpsrv.Close() s.Close() } - return s, addr, cleanupFn + return s, httpsrv.Listener.Addr().(*net.TCPAddr), cleanupFn } // devLogger returns a wireguard-go device.Logger that writes @@ -286,13 +287,14 @@ func TestDeviceStartStop(t *testing.T) { rc := tstest.NewResourceCheck() defer rc.Assert(t) - conn, err := Listen(Options{ + conn, err := NewConn(Options{ EndpointsFunc: func(eps []string) {}, Logf: t.Logf, }) if err != nil { t.Fatal(err) } + conn.Start() defer conn.Close() tun := tuntest.NewChannelTUN() @@ -337,48 +339,58 @@ func TestTwoDevicePing(t *testing.T) { // all log using the "current" t.Logf function. Sigh. logf, setT := makeNestable(t) - // Wipe default DERP list, add local server. - // (Do it now, or derpHost will try to connect to derp1.tailscale.com.) 
derpServer, derpAddr, derpCleanupFn := runDERP(t, logf) defer derpCleanupFn() - stunAddr, stunCleanupFn := stuntest.Serve(t) defer stunCleanupFn() - derps := derpmap.NewTestWorldWith(&derpmap.Server{ - ID: 1, - HostHTTPS: derpAddr, - STUN4: stunAddr, - Geo: "Testopolis", - }) + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: &tailcfg.DERPRegion{ + RegionID: 1, + RegionCode: "test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + HostName: "test-node.unused", + IPv4: "127.0.0.1", + IPv6: "none", + STUNPort: stunAddr.Port, + DERPTestPort: derpAddr.Port, + }, + }, + }, + }, + } epCh1 := make(chan []string, 16) - conn1, err := Listen(Options{ - Logf: logger.WithPrefix(logf, "conn1: "), - DERPs: derps, + conn1, err := NewConn(Options{ + Logf: logger.WithPrefix(logf, "conn1: "), EndpointsFunc: func(eps []string) { epCh1 <- eps }, - derpTLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { t.Fatal(err) } defer conn1.Close() + conn1.Start() + conn1.SetDERPMap(derpMap) epCh2 := make(chan []string, 16) - conn2, err := Listen(Options{ - Logf: logger.WithPrefix(logf, "conn2: "), - DERPs: derps, + conn2, err := NewConn(Options{ + Logf: logger.WithPrefix(logf, "conn2: "), EndpointsFunc: func(eps []string) { epCh2 <- eps }, - derpTLSConfig: &tls.Config{InsecureSkipVerify: true}, }) if err != nil { t.Fatal(err) } defer conn2.Close() + conn2.Start() + conn2.SetDERPMap(derpMap) ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} cfgs := makeConfigs(t, ports) diff --git a/wgengine/userspace.go b/wgengine/userspace.go index e0a4fda34..9a960fde6 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -8,6 +8,7 @@ "bufio" "bytes" "context" + "errors" "fmt" "io" "log" @@ -49,7 +50,7 @@ type userspaceEngine struct { logf logger.Logf reqCh chan struct{} - waitCh chan struct{} + waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool tundev *tstun.TUN wgdev *device.Device router router.Router @@ -61,6 +62,7 @@ type userspaceEngine struct { lastCfg wgcfg.Config mu sync.Mutex // guards following; see lock order comment below + closing bool // Close was called (even if we're still closing) statusCallback StatusCallback peerSequence []wgcfg.Key endpoints []string @@ -149,7 +151,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R Port: listenPort, EndpointsFunc: endpointsFn, } - e.magicConn, err = magicsock.Listen(magicsockOpts) + e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { tundev.Close() return nil, fmt.Errorf("wgengine: %v", err) @@ -210,6 +212,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R // routers do not Read or Write, but do access native interfaces. e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap()) if err != nil { + e.magicConn.Close() return nil, err } @@ -235,16 +238,19 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R e.wgdev.Up() if err := e.router.Up(); err != nil { + e.magicConn.Close() e.wgdev.Close() return nil, err } // TODO(danderson): we should delete this. It's pointless to apply // a no-op settings here. 
if err := e.router.Set(nil); err != nil { + e.magicConn.Close() e.wgdev.Close() return nil, err } e.linkMon.Start() + e.magicConn.Start() return e, nil } @@ -407,6 +413,13 @@ func (e *userspaceEngine) getStatus() (*Status, error) { e.wgLock.Lock() defer e.wgLock.Unlock() + e.mu.Lock() + closing := e.closing + e.mu.Unlock() + if closing { + return nil, errors.New("engine closing; no status") + } + if e.wgdev == nil { // RequestStatus was invoked before the wgengine has // finished initializing. This can happen when wgegine @@ -553,6 +566,11 @@ func (e *userspaceEngine) RequestStatus() { func (e *userspaceEngine) Close() { e.mu.Lock() + if e.closing { + e.mu.Unlock() + return + } + e.closing = true for key, cancel := range e.pingers { delete(e.pingers, key) cancel() @@ -614,8 +632,8 @@ func (e *userspaceEngine) SetNetInfoCallback(cb NetInfoCallback) { e.magicConn.SetNetInfoCallback(cb) } -func (e *userspaceEngine) SetDERPEnabled(v bool) { - e.magicConn.SetDERPEnabled(v) +func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) { + e.magicConn.SetDERPMap(dm) } func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 9d409ece7..ef9393a47 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -12,6 +12,7 @@ "github.com/tailscale/wireguard-go/wgcfg" "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" ) @@ -88,8 +89,8 @@ func (e *watchdogEngine) RequestStatus() { func (e *watchdogEngine) LinkChange(isExpensive bool) { e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) }) } -func (e *watchdogEngine) SetDERPEnabled(v bool) { - e.watchdog("SetDERPEnabled", func() { e.wrap.SetDERPEnabled(v) }) +func (e *watchdogEngine) SetDERPMap(m *tailcfg.DERPMap) { + e.watchdog("SetDERPMap", func() { e.wrap.SetDERPMap(m) }) } func (e *watchdogEngine) Close() { e.watchdog("Close", e.wrap.Close) diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 3a229c713..81dcee80e 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -95,9 +95,10 @@ type Engine interface { // action on. LinkChange(isExpensive bool) - // SetDERPEnabled controls whether DERP is enabled. - // It starts enabled by default. - SetDERPEnabled(bool) + // SetDERPMap controls which (if any) DERP servers are used. + // If nil, DERP is disabled. It starts disabled until a DERP map + // is configured. + SetDERPMap(*tailcfg.DERPMap) // SetNetInfoCallback sets the function to call when a // new NetInfo summary is available.
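A short sketch of driving the replacement API from a caller's perspective; the helper is hypothetical and only illustrates that the old boolean maps onto passing or withholding a DERP map:

package example // hypothetical

import (
	"tailscale.com/tailcfg"
	"tailscale.com/wgengine"
)

// applyDERP replaces a SetDERPEnabled(bool) call site: disabling DERP is
// now expressed as SetDERPMap(nil), which also tears down any active
// DERP connections.
func applyDERP(e wgengine.Engine, dm *tailcfg.DERPMap, enabled bool) {
	if !enabled || dm == nil {
		e.SetDERPMap(nil)
		return
	}
	e.SetDERPMap(dm)
}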