all: make client use server-provided DERP map, add DERP region support

Instead of hard-coding the DERP map (except for cmd/tailscale netcheck
for now), get it from the control server at runtime.

And make the DERP map support multiple nodes per region, with clients
picking the first one that's available. (The server will balance the
order presented to clients to spread load.)

This deletes the stunner package, merging it into the netcheck package
instead, to minimize all the config hooks that would've been
required.

Also fix some test flakes & races.

Fixes #387 (Don't hard-code the DERP map)
Updates #388 (Add DERP region support)
Fixes #399 (wgengine: flaky tests)

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Brad Fitzpatrick <bradfitz@tailscale.com> authored and committed, 2020-05-17 09:51:38 -07:00
commit e6b84f2159 (parent e8b3a5e7a1)
20 changed files with 1439 additions and 1201 deletions
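The region-dialing strategy added below (see dialRegion in derphttp_client.go) reduces to: walk a region's nodes in the order the server sent them, skip STUN-only nodes, and take the first node that connects, remembering only the first error. A condensed sketch of that strategy, under the tailcfg types in this commit (the function name and fixed port 443 are illustrative; the real dialNode also races IPv4 and IPv6 and honors DERPTestPort):

	package sketch

	import (
		"context"
		"fmt"
		"net"

		"tailscale.com/tailcfg"
	)

	// dialFirstAvailable mirrors the shape of dialRegion: nodes are tried
	// in the server-provided order, so the control server can load-balance
	// simply by rotating that order per client.
	func dialFirstAvailable(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) {
		if len(reg.Nodes) == 0 {
			return nil, nil, fmt.Errorf("no nodes in region %d (%s)", reg.RegionID, reg.RegionCode)
		}
		var firstErr error
		var d net.Dialer
		for _, n := range reg.Nodes {
			if n.STUNOnly {
				continue // STUN-only nodes don't serve DERP
			}
			conn, err := d.DialContext(ctx, "tcp", net.JoinHostPort(n.HostName, "443"))
			if err == nil {
				return conn, n, nil
			}
			if firstErr == nil {
				firstErr = err
			}
		}
		return nil, nil, firstErr
	}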

cmd/tailscale/netcheck.go

@@ -9,6 +9,7 @@
 	"fmt"
 	"log"
 	"sort"
+	"time"

 	"github.com/peterbourgon/ff/v2/ffcli"
 	"tailscale.com/derp/derpmap"
@@ -26,12 +27,12 @@
 func runNetcheck(ctx context.Context, args []string) error {
 	c := &netcheck.Client{
-		DERP:     derpmap.Prod(),
 		Logf:     logger.WithPrefix(log.Printf, "netcheck: "),
 		DNSCache: dnscache.Get(),
 	}
-	report, err := c.GetReport(ctx)
+	dm := derpmap.Prod()
+	report, err := c.GetReport(ctx, dm)
 	if err != nil {
 		log.Fatalf("netcheck: %v", err)
 	}
@@ -55,18 +56,23 @@ func runNetcheck(ctx context.Context, args []string) error {
 	// When DERP latency checking failed,
 	// magicsock will try to pick the DERP server that
 	// most of your other nodes are also using
-	if len(report.DERPLatency) == 0 {
+	if len(report.RegionLatency) == 0 {
 		fmt.Printf("\t* Nearest DERP: unknown (no response to latency probes)\n")
 	} else {
-		fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, c.DERP.LocationOfID(report.PreferredDERP))
+		fmt.Printf("\t* Nearest DERP: %v (%v)\n", report.PreferredDERP, dm.Regions[report.PreferredDERP].RegionCode)
 		fmt.Printf("\t* DERP latency:\n")
-		var ss []string
-		for s := range report.DERPLatency {
-			ss = append(ss, s)
+		var rids []int
+		for rid := range dm.Regions {
+			rids = append(rids, rid)
 		}
-		sort.Strings(ss)
-		for _, s := range ss {
-			fmt.Printf("\t\t- %s = %v\n", s, report.DERPLatency[s])
+		sort.Ints(rids)
+		for _, rid := range rids {
+			d, ok := report.RegionLatency[rid]
+			var latency string
+			if ok {
+				latency = d.Round(time.Millisecond / 10).String()
+			}
+			fmt.Printf("\t\t- %v, %3s = %s\n", rid, dm.Regions[rid].RegionCode, latency)
 		}
 	}
 	return nil

controlclient/direct.go

@@ -541,6 +541,8 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
 		}
 	}()

+	var lastDERPMap *tailcfg.DERPMap
+
 	// If allowStream, then the server will use an HTTP long poll to
 	// return incremental results. There is always one response right
 	// away, followed by a delay, and eventually others.
@@ -582,6 +584,11 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
 		}
 		vlogf("netmap: got new map")

+		if resp.DERPMap != nil {
+			vlogf("netmap: new map contains DERP map")
+			lastDERPMap = resp.DERPMap
+		}
+
 		nm := &NetworkMap{
 			NodeKey:    tailcfg.NodeKey(persist.PrivateNodeKey.Public()),
 			PrivateKey: persist.PrivateNodeKey,
@@ -597,6 +604,7 @@ func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
 			DNSDomains:   resp.SearchPaths,
 			Hostinfo:     resp.Node.Hostinfo,
 			PacketFilter: c.parsePacketFilter(resp.PacketFilter),
+			DERPMap:      lastDERPMap,
 		}
 		for _, profile := range resp.UserProfiles {
 			nm.UserProfiles[profile.ID] = profile

controlclient/netmap.go

@@ -33,6 +33,10 @@ type NetworkMap struct {
 	Hostinfo     tailcfg.Hostinfo
 	PacketFilter filter.Matches

+	// DERPMap is the last DERP server map received. It's reused
+	// between updates and should not be modified.
+	DERPMap *tailcfg.DERPMap
+
 	// ACLs

 	User tailcfg.UserID

derp/derphttp/derphttp_client.go

@@ -24,9 +24,11 @@
 	"sync"
 	"time"

+	"inet.af/netaddr"
 	"tailscale.com/derp"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/tlsdial"
+	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
 	"tailscale.com/types/logger"
 )
@@ -43,7 +45,10 @@ type Client struct {
 	privateKey key.Private
 	logf       logger.Logf
-	url        *url.URL
+
+	// Either url or getRegion is non-nil:
+	url       *url.URL
+	getRegion func() *tailcfg.DERPRegion

 	ctx       context.Context // closed via cancelCtx in Client.Close
 	cancelCtx context.CancelFunc
@@ -55,8 +60,22 @@ type Client struct {
 	client *derp.Client
 }

+// NewRegionClient returns a new DERP-over-HTTP client. It connects lazily.
+// To trigger a connection, use Connect.
+func NewRegionClient(privateKey key.Private, logf logger.Logf, getRegion func() *tailcfg.DERPRegion) *Client {
+	ctx, cancel := context.WithCancel(context.Background())
+	c := &Client{
+		privateKey: privateKey,
+		logf:       logf,
+		getRegion:  getRegion,
+		ctx:        ctx,
+		cancelCtx:  cancel,
+	}
+	return c
+}
+
 // NewClient returns a new DERP-over-HTTP client. It connects lazily.
-// To trigger a connection use Connect.
+// To trigger a connection, use Connect.
 func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) {
 	u, err := url.Parse(serverURL)
 	if err != nil {
@@ -65,6 +84,7 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Client, error) {
 	if urlPort(u) == "" {
 		return nil, fmt.Errorf("derphttp.NewClient: invalid URL scheme %q", u.Scheme)
 	}
+
 	ctx, cancel := context.WithCancel(context.Background())
 	c := &Client{
 		privateKey: privateKey,
@@ -101,6 +121,37 @@ func urlPort(u *url.URL) string {
 	return ""
 }

+func (c *Client) targetString(reg *tailcfg.DERPRegion) string {
+	if c.url != nil {
+		return c.url.String()
+	}
+	return fmt.Sprintf("region %d (%v)", reg.RegionID, reg.RegionCode)
+}
+
+func (c *Client) useHTTPS() bool {
+	if c.url != nil && c.url.Scheme == "http" {
+		return false
+	}
+	return true
+}
+
+func (c *Client) tlsServerName(node *tailcfg.DERPNode) string {
+	if c.url != nil {
+		return c.url.Host
+	}
+	if node.CertName != "" {
+		return node.CertName
+	}
+	return node.HostName
+}
+
+func (c *Client) urlString(node *tailcfg.DERPNode) string {
+	if c.url != nil {
+		return c.url.String()
+	}
+	return fmt.Sprintf("https://%s/derp", node.HostName)
+}
+
 func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -111,8 +162,6 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
 		return c.client, nil
 	}

-	c.logf("%s: connecting to %v", caller, c.url)
-
 	// timeout is the fallback maximum time (if ctx doesn't limit
 	// it further) to do all of: DNS + TCP + TLS + HTTP Upgrade +
 	// DERP upgrade.
@@ -132,46 +181,42 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
 	}()
 	defer cancel()

+	var reg *tailcfg.DERPRegion // nil when using c.url to dial
+	if c.getRegion != nil {
+		reg = c.getRegion()
+		if reg == nil {
+			return nil, errors.New("DERP region not available")
+		}
+	}
+
 	var tcpConn net.Conn
 	defer func() {
 		if err != nil {
 			if ctx.Err() != nil {
 				err = fmt.Errorf("%v: %v", ctx.Err(), err)
 			}
-			err = fmt.Errorf("%s connect to %v: %v", caller, c.url, err)
+			err = fmt.Errorf("%s connect to %v: %v", caller, c.targetString(reg), err)
 			if tcpConn != nil {
 				go tcpConn.Close()
 			}
 		}
 	}()

-	host := c.url.Hostname()
-	hostOrIP := host
-
-	var stdDialer dialer = new(net.Dialer)
-	var dialer = stdDialer
-	if wrapDialer != nil {
-		dialer = wrapDialer(dialer)
-	}
-
-	if c.DNSCache != nil {
-		ip, err := c.DNSCache.LookupIP(ctx, host)
-		if err == nil {
-			hostOrIP = ip.String()
-		}
-		if err != nil && dialer == stdDialer {
-			// Return an error if we're not using a dial
-			// proxy that can do DNS lookups for us.
-			return nil, err
-		}
-	}
-
-	tcpConn, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url)))
+	var node *tailcfg.DERPNode // nil when using c.url to dial
+	if c.url != nil {
+		c.logf("%s: connecting to %v", caller, c.url)
+		tcpConn, err = c.dialURL(ctx)
+	} else {
+		c.logf("%s: connecting to derp-%d (%v)", caller, reg.RegionID, reg.RegionCode)
+		tcpConn, node, err = c.dialRegion(ctx, reg)
+	}
 	if err != nil {
-		return nil, fmt.Errorf("dial of %q: %v", host, err)
+		return nil, err
 	}

-	// Now that we have a TCP connection, force close it.
+	// Now that we have a TCP connection, force close it if the
+	// TLS handshake + DERP setup takes too long.
 	done := make(chan struct{})
 	defer close(done)
 	go func() {
@@ -195,15 +240,19 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
 	}()

 	var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
-	if c.url.Scheme == "https" {
-		httpConn = tls.Client(tcpConn, tlsdial.Config(c.url.Host, c.TLSConfig))
+	if c.useHTTPS() {
+		tlsConf := tlsdial.Config(c.tlsServerName(node), c.TLSConfig)
+		if node != nil && node.DERPTestPort != 0 {
+			tlsConf.InsecureSkipVerify = true
+		}
+		httpConn = tls.Client(tcpConn, tlsConf)
 	} else {
 		httpConn = tcpConn
 	}

 	brw := bufio.NewReadWriter(bufio.NewReader(httpConn), bufio.NewWriter(httpConn))

-	req, err := http.NewRequest("GET", c.url.String(), nil)
+	req, err := http.NewRequest("GET", c.urlString(node), nil)
 	if err != nil {
 		return nil, err
 	}
@@ -243,6 +292,148 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
 	return c.client, nil
 }

+func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
+	host := c.url.Hostname()
+	hostOrIP := host
+
+	var stdDialer dialer = new(net.Dialer)
+	var dialer = stdDialer
+	if wrapDialer != nil {
+		dialer = wrapDialer(dialer)
+	}
+
+	if c.DNSCache != nil {
+		ip, err := c.DNSCache.LookupIP(ctx, host)
+		if err == nil {
+			hostOrIP = ip.String()
+		}
+		if err != nil && dialer == stdDialer {
+			// Return an error if we're not using a dial
+			// proxy that can do DNS lookups for us.
+			return nil, err
+		}
+	}
+
+	tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url)))
+	if err != nil {
+		return nil, fmt.Errorf("dial of %v: %v", host, err)
+	}
+	return tcpConn, nil
+}
+
+// dialRegion returns a TCP connection to the provided region, trying
+// each node in order (with dialNode) until one connects or ctx is
+// done.
+func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.Conn, *tailcfg.DERPNode, error) {
+	if len(reg.Nodes) == 0 {
+		return nil, nil, fmt.Errorf("no nodes for %s", c.targetString(reg))
+	}
+	var firstErr error
+	for _, n := range reg.Nodes {
+		if n.STUNOnly {
+			continue
+		}
+		c, err := c.dialNode(ctx, n)
+		if err == nil {
+			return c, n, nil
+		}
+		if firstErr == nil {
+			firstErr = err
+		}
+	}
+	return nil, nil, firstErr
+}
+
+func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
+	var stdDialer dialer = new(net.Dialer)
+	var dialer = stdDialer
+	if wrapDialer != nil {
+		dialer = wrapDialer(dialer)
+	}
+	return dialer.DialContext(ctx, proto, addr)
+}
+
+// shouldDialProto reports whether an explicitly provided IPv4 or IPv6
+// address (given in s) is valid. An empty value means to dial, but to
+// use DNS. The predicate function reports whether the non-empty
+// string s contained a valid IP address of the right family.
+func shouldDialProto(s string, pred func(netaddr.IP) bool) bool {
+	if s == "" {
+		return true
+	}
+	ip, _ := netaddr.ParseIP(s)
+	return pred(ip)
+}
+
+const dialNodeTimeout = 1500 * time.Millisecond
+
+// dialNode returns a TCP connection to node n, racing IPv4 and IPv6
+// (both as applicable) against each other.
+// A node is only given dialNodeTimeout to connect.
+//
+// TODO(bradfitz): longer if no options remain perhaps? ... Or longer
+// overall but have dialRegion start overlapping races?
+func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, error) {
+	type res struct {
+		c   net.Conn
+		err error
+	}
+	resc := make(chan res) // must be unbuffered
+	ctx, cancel := context.WithTimeout(ctx, dialNodeTimeout)
+	defer cancel()
+
+	nwait := 0
+	startDial := func(dstPrimary, proto string) {
+		nwait++
+		go func() {
+			dst := dstPrimary
+			if dst == "" {
+				dst = n.HostName
+			}
+			port := "443"
+			if n.DERPTestPort != 0 {
+				port = fmt.Sprint(n.DERPTestPort)
+			}
+			c, err := c.dialContext(ctx, proto, net.JoinHostPort(dst, port))
+			select {
+			case resc <- res{c, err}:
+			case <-ctx.Done():
+				if c != nil {
+					c.Close()
+				}
+			}
+		}()
+	}
+	if shouldDialProto(n.IPv4, netaddr.IP.Is4) {
+		startDial(n.IPv4, "tcp4")
+	}
+	if shouldDialProto(n.IPv6, netaddr.IP.Is6) {
+		startDial(n.IPv6, "tcp6")
+	}
+	if nwait == 0 {
+		return nil, errors.New("both IPv4 and IPv6 are explicitly disabled for node")
+	}
+
+	var firstErr error
+	for {
+		select {
+		case res := <-resc:
+			nwait--
+			if res.err == nil {
+				return res.c, nil
+			}
+			if firstErr == nil {
+				firstErr = res.err
+			}
+			if nwait == 0 {
+				return nil, firstErr
+			}
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		}
+	}
+}
+
 func (c *Client) Send(dstKey key.Public, b []byte) error {
 	client, err := c.connect(context.TODO(), "derphttp.Client.Send")
 	if err != nil {
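
One way the new NewRegionClient constructor above can be wired up (the currentDERPMap accessor and homeRegionID are hypothetical; magicsock wires this to its server-provided map in the same spirit):

	dc := derphttp.NewRegionClient(privateKey, logf, func() *tailcfg.DERPRegion {
		dm := currentDERPMap() // hypothetical: returns the latest server-provided *tailcfg.DERPMap, or nil
		if dm == nil {
			return nil // connect then fails with "DERP region not available" until a map arrives
		}
		return dm.Regions[homeRegionID] // hypothetical: this node's preferred region ID
	})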

derp/derpmap/derpmap.go

@@ -7,151 +7,59 @@
 import (
 	"fmt"
-	"net"
+	"strings"

-	"tailscale.com/types/structs"
+	"tailscale.com/tailcfg"
 )

-// World is a set of DERP server.
-type World struct {
-	servers []*Server
-	ids     []int
-	byID    map[int]*Server
-	stun4   []string
-	stun6   []string
-}
-
-func (w *World) IDs() []int      { return w.ids }
-func (w *World) STUN4() []string { return w.stun4 }
-func (w *World) STUN6() []string { return w.stun6 }
-
-func (w *World) ServerByID(id int) *Server { return w.byID[id] }
-
-// LocationOfID returns the geographic name of a node, if present.
-func (w *World) LocationOfID(id int) string {
-	if s, ok := w.byID[id]; ok {
-		return s.Geo
-	}
-	return ""
-}
-
-func (w *World) NodeIDOfSTUNServer(server string) int {
-	// TODO: keep reverse map? Small enough to not matter for now.
-	for _, s := range w.servers {
-		if s.STUN4 == server || s.STUN6 == server {
-			return s.ID
-		}
-	}
-	return 0
-}
-
-// ForeachServer calls fn for each DERP server, in an unspecified order.
-func (w *World) ForeachServer(fn func(*Server)) {
-	for _, s := range w.byID {
-		fn(s)
-	}
-}
-
-// Prod returns the production DERP nodes.
-func Prod() *World {
-	return prod
-}
-
-func NewTestWorld(stun ...string) *World {
-	w := &World{}
-	for i, s := range stun {
-		w.add(&Server{
-			ID:    i + 1,
-			Geo:   fmt.Sprintf("Testopolis-%d", i+1),
-			STUN4: s,
-		})
-	}
-	return w
-}
-
-func NewTestWorldWith(servers ...*Server) *World {
-	w := &World{}
-	for _, s := range servers {
-		w.add(s)
-	}
-	return w
-}
-
-var prod = new(World) // ... a dazzling place I never knew
-
-func addProd(id int, geo string) {
-	prod.add(&Server{
-		ID:        id,
-		Geo:       geo,
-		HostHTTPS: fmt.Sprintf("derp%v.tailscale.com", id),
-		STUN4:     fmt.Sprintf("derp%v.tailscale.com:3478", id),
-		STUN6:     fmt.Sprintf("derp%v-v6.tailscale.com:3478", id),
-	})
-}
-
-func (w *World) add(s *Server) {
-	if s.ID == 0 {
-		panic("ID required")
-	}
-	if _, dup := w.byID[s.ID]; dup {
-		panic("duplicate prod server")
-	}
-	if w.byID == nil {
-		w.byID = make(map[int]*Server)
-	}
-	w.byID[s.ID] = s
-	w.ids = append(w.ids, s.ID)
-	w.servers = append(w.servers, s)
-	if s.STUN4 != "" {
-		w.stun4 = append(w.stun4, s.STUN4)
-		if _, _, err := net.SplitHostPort(s.STUN4); err != nil {
-			panic("not a host:port: " + s.STUN4)
-		}
-	}
-	if s.STUN6 != "" {
-		w.stun6 = append(w.stun6, s.STUN6)
-		if _, _, err := net.SplitHostPort(s.STUN6); err != nil {
-			panic("not a host:port: " + s.STUN6)
-		}
-	}
-}
-
-func init() {
-	addProd(1, "New York")
-	addProd(2, "San Francisco")
-	addProd(3, "Singapore")
-	addProd(4, "Frankfurt")
-	addProd(5, "Sydney")
-}
-
-// Server is configuration for a DERP server.
-type Server struct {
-	_  structs.Incomparable
-	ID int
-
-	// HostHTTPS is the HTTPS hostname.
-	HostHTTPS string
-
-	// STUN4 is the host:port of the IPv4 STUN server on this DERP
-	// node. Required.
-	STUN4 string
-
-	// STUN6 optionally provides the IPv6 host:port of the STUN
-	// server on the DERP node.
-	// It should be an IPv6-only address for now. (We currently make lazy
-	// assumptions that the server names are unique.)
-	STUN6 string
-
-	// Geo is a human-readable geographic region name of this server.
-	Geo string
-}
-
-func (s *Server) String() string {
-	if s == nil {
-		return "<nil *derpmap.Server>"
-	}
-	if s.Geo != "" {
-		return fmt.Sprintf("%v (%v)", s.HostHTTPS, s.Geo)
-	}
-	return s.HostHTTPS
-}
+func derpNode(suffix, v4, v6 string) *tailcfg.DERPNode {
+	return &tailcfg.DERPNode{
+		Name:     suffix, // updated later
+		RegionID: 0,      // updated later
+		IPv4:     v4,
+		IPv6:     v6,
+	}
+}
+
+func derpRegion(id int, code string, nodes ...*tailcfg.DERPNode) *tailcfg.DERPRegion {
+	region := &tailcfg.DERPRegion{
+		RegionID:   id,
+		RegionCode: code,
+		Nodes:      nodes,
+	}
+	for _, n := range nodes {
+		n.Name = fmt.Sprintf("%d%s", id, n.Name)
+		n.RegionID = id
+		n.HostName = fmt.Sprintf("derp%s.tailscale.com", strings.TrimSuffix(n.Name, "a"))
+	}
+	return region
+}
+
+// Prod returns Tailscale's map of relay servers.
+//
+// This list is only used by cmd/tailscale's netcheck subcommand. In
+// normal operation the Tailscale nodes get this sent to them from the
+// control server.
+//
+// This list is subject to change and should not be relied on.
+func Prod() *tailcfg.DERPMap {
+	return &tailcfg.DERPMap{
+		Regions: map[int]*tailcfg.DERPRegion{
+			1: derpRegion(1, "nyc",
+				derpNode("a", "159.89.225.99", "2604:a880:400:d1::828:b001"),
+			),
+			2: derpRegion(2, "sfo",
+				derpNode("a", "167.172.206.31", "2604:a880:2:d1::c5:7001"),
+			),
+			3: derpRegion(3, "sin",
+				derpNode("a", "68.183.179.66", "2400:6180:0:d1::67d:8001"),
+			),
+			4: derpRegion(4, "fra",
+				derpNode("a", "167.172.182.26", "2a03:b0c0:3:e0::36e:9001"),
+			),
+			5: derpRegion(5, "syd",
+				derpNode("a", "103.43.75.49", "2001:19f0:5801:10b7:5400:2ff:feaa:284c"),
+			),
+		},
+	}
+}

go.mod

@@ -28,7 +28,6 @@
 	golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
 	golang.org/x/sys v0.0.0-20200501052902-10377860bb8e
 	golang.org/x/time v0.0.0-20191024005414-555d28b269f0
-	gortc.io/stun v1.22.1
 	inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc
 	rsc.io/goversion v1.2.0
 )

go.sum

@@ -142,8 +142,6 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
 gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
-gortc.io/stun v1.22.1 h1:96mOdDATYRqhYB+TZdenWBg4CzL2Ye5kPyBXQ8KAB+8=
-gortc.io/stun v1.22.1/go.mod h1:XD5lpONVyjvV3BgOyJFNo0iv6R2oZB4L+weMqxts+zg=
 inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc h1:We3b/z+7i9LV4Ls0yWve5vYIlnAPSPeqxKVgZseRDBs=
 inet.af/netaddr v0.0.0-20200430175045-5aaf2097c7fc/go.mod h1:qqYzz/2whtrbWJvt+DNWQyvekNN4ePQZcg2xc2/Yjww=
 rsc.io/goversion v1.2.0 h1:SPn+NLTiAG7w30IRK/DKp1BjvpWabYgxlLp/+kx5J8w=

ipn/local.go

@@ -240,10 +240,8 @@ func (b *LocalBackend) Start(opts Options) error {
 	b.notify = opts.Notify
 	b.netMapCache = nil
 	persist := b.prefs.Persist
-	wantDERP := !b.prefs.DisableDERP
 	b.mu.Unlock()

-	b.e.SetDERPEnabled(wantDERP)
 	b.updateFilter(nil)

 	var err error
@@ -307,11 +305,17 @@ func (b *LocalBackend) Start(opts Options) error {
 					b.logf("netmap diff:\n%v", diff)
 				}
 			}
+			disableDERP := b.prefs != nil && b.prefs.DisableDERP
 			b.netMapCache = newSt.NetMap
 			b.mu.Unlock()

 			b.send(Notify{NetMap: newSt.NetMap})
 			b.updateFilter(newSt.NetMap)
+
+			if disableDERP {
+				b.e.SetDERPMap(nil)
+			} else {
+				b.e.SetDERPMap(newSt.NetMap.DERPMap)
+			}
 		}
 		if newSt.URL != "" {
 			b.logf("Received auth URL: %.20v...", newSt.URL)

netcheck/netcheck.go (diff suppressed because it is too large)

netcheck/netcheck_test.go

@@ -9,28 +9,34 @@
 	"fmt"
 	"net"
 	"reflect"
+	"sort"
+	"strconv"
+	"strings"
 	"testing"
 	"time"

-	"tailscale.com/derp/derpmap"
 	"tailscale.com/stun"
 	"tailscale.com/stun/stuntest"
+	"tailscale.com/tailcfg"
 )

 func TestHairpinSTUN(t *testing.T) {
+	tx := stun.NewTxID()
 	c := &Client{
-		hairTX:      stun.NewTxID(),
-		gotHairSTUN: make(chan *net.UDPAddr, 1),
+		curState: &reportState{
+			hairTX:      tx,
+			gotHairSTUN: make(chan *net.UDPAddr, 1),
+		},
 	}
-	req := stun.Request(c.hairTX)
+	req := stun.Request(tx)
 	if !stun.Is(req) {
 		t.Fatal("expected STUN message")
 	}
-	if !c.handleHairSTUN(req, nil) {
+	if !c.handleHairSTUNLocked(req, nil) {
 		t.Fatal("expected true")
 	}
 	select {
-	case <-c.gotHairSTUN:
+	case <-c.curState.gotHairSTUN:
 	default:
 		t.Fatal("expected value")
 	}
@@ -41,25 +47,24 @@ func TestBasic(t *testing.T) {
 	defer cleanup()

 	c := &Client{
-		DERP: derpmap.NewTestWorld(stunAddr),
 		Logf: t.Logf,
 	}

 	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 	defer cancel()

-	r, err := c.GetReport(ctx)
+	r, err := c.GetReport(ctx, stuntest.DERPMapOf(stunAddr.String()))
 	if err != nil {
 		t.Fatal(err)
 	}
 	if !r.UDP {
 		t.Error("want UDP")
 	}
-	if len(r.DERPLatency) != 1 {
-		t.Errorf("expected 1 key in DERPLatency; got %+v", r.DERPLatency)
+	if len(r.RegionLatency) != 1 {
+		t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency)
 	}
-	if _, ok := r.DERPLatency[stunAddr]; !ok {
-		t.Errorf("expected key %q in DERPLatency; got %+v", stunAddr, r.DERPLatency)
+	if _, ok := r.RegionLatency[1]; !ok {
+		t.Errorf("expected key 1 in DERPLatency; got %+v", r.RegionLatency)
 	}
 	if r.GlobalV4 == "" {
 		t.Error("expected GlobalV4 set")
@@ -78,20 +83,20 @@ func TestWorksWhenUDPBlocked(t *testing.T) {
 	stunAddr := blackhole.LocalAddr().String()

+	dm := stuntest.DERPMapOf(stunAddr)
+	dm.Regions[1].Nodes[0].STUNOnly = true
+
 	c := &Client{
-		DERP: derpmap.NewTestWorld(stunAddr),
 		Logf: t.Logf,
 	}

 	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
 	defer cancel()

-	r, err := c.GetReport(ctx)
+	r, err := c.GetReport(ctx, dm)
 	if err != nil {
 		t.Fatal(err)
 	}
-	want := &Report{
-		DERPLatency: map[string]time.Duration{},
-	}
+	want := new(Report)

 	if !reflect.DeepEqual(r, want) {
 		t.Errorf("mismatch\n got: %+v\nwant: %+v\n", r, want)
@@ -99,30 +104,24 @@
 	}
 }

 func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) {
-	derps := derpmap.NewTestWorldWith(
-		&derpmap.Server{
-			ID:    1,
-			STUN4: "d1:1",
-		},
-		&derpmap.Server{
-			ID:    2,
-			STUN4: "d2:1",
-		},
-		&derpmap.Server{
-			ID:    3,
-			STUN4: "d3:1",
-		},
-	)
 	// report returns a *Report from (DERP host, time.Duration)+ pairs.
 	report := func(a ...interface{}) *Report {
-		r := &Report{DERPLatency: map[string]time.Duration{}}
+		r := &Report{RegionLatency: map[int]time.Duration{}}
 		for i := 0; i < len(a); i += 2 {
-			k := a[i].(string) + ":1"
+			s := a[i].(string)
+			if !strings.HasPrefix(s, "d") {
+				t.Fatalf("invalid derp server key %q", s)
+			}
+			regionID, err := strconv.Atoi(s[1:])
+			if err != nil {
+				t.Fatalf("invalid derp server key %q", s)
+			}
 			switch v := a[i+1].(type) {
 			case time.Duration:
-				r.DERPLatency[k] = v
+				r.RegionLatency[regionID] = v
 			case int:
-				r.DERPLatency[k] = time.Second * time.Duration(v)
+				r.RegionLatency[regionID] = time.Second * time.Duration(v)
 			default:
 				panic(fmt.Sprintf("unexpected type %T", v))
 			}
@@ -194,7 +193,6 @@ type step struct {
 		t.Run(tt.name, func(t *testing.T) {
 			fakeTime := time.Unix(123, 0)
 			c := &Client{
-				DERP:    derps,
 				TimeNow: func() time.Time { return fakeTime },
 			}
 			for _, s := range tt.steps {
@@ -212,81 +210,217 @@
 		})
 	}
 }

-func TestPickSubset(t *testing.T) {
-	derps := derpmap.NewTestWorldWith(
-		&derpmap.Server{
-			ID:    1,
-			STUN4: "d1:4",
-			STUN6: "d1:6",
-		},
-		&derpmap.Server{
-			ID:    2,
-			STUN4: "d2:4",
-			STUN6: "d2:6",
-		},
-		&derpmap.Server{
-			ID:    3,
-			STUN4: "d3:4",
-			STUN6: "d3:6",
-		},
-	)
-	tests := []struct {
-		name      string
-		last      *Report
-		want4     []string
-		want6     []string
-		wantTries map[string]int
-	}{
-		{
-			name:  "fresh",
-			last:  nil,
-			want4: []string{"d1:4", "d2:4", "d3:4"},
-			want6: []string{"d1:6", "d2:6", "d3:6"},
-			wantTries: map[string]int{
-				"d1:4": 2,
-				"d2:4": 2,
-				"d3:4": 2,
-				"d1:6": 1,
-				"d2:6": 1,
-				"d3:6": 1,
-			},
-		},
-		{
-			name: "1_and_3_closest",
-			last: &Report{
-				DERPLatency: map[string]time.Duration{
-					"d1:4": 15 * time.Millisecond,
-					"d2:4": 300 * time.Millisecond,
-					"d3:4": 25 * time.Millisecond,
-				},
-			},
-			want4: []string{"d1:4", "d2:4", "d3:4"},
-			want6: []string{"d1:6", "d3:6"},
-			wantTries: map[string]int{
-				"d1:4": 2,
-				"d3:4": 2,
-				"d2:4": 1,
-				"d1:6": 1,
-				"d3:6": 1,
-			},
-		},
-	}
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			c := &Client{DERP: derps, last: tt.last}
-			got4, got6, gotTries, err := c.pickSubset()
-			if err != nil {
-				t.Fatal(err)
-			}
-			if !reflect.DeepEqual(got4, tt.want4) {
-				t.Errorf("stuns4 = %q; want %q", got4, tt.want4)
-			}
-			if !reflect.DeepEqual(got6, tt.want6) {
-				t.Errorf("stuns6 = %q; want %q", got6, tt.want6)
-			}
-			if !reflect.DeepEqual(gotTries, tt.wantTries) {
-				t.Errorf("tries = %v; want %v", gotTries, tt.wantTries)
-			}
-		})
-	}
-}
+func TestMakeProbePlan(t *testing.T) {
+	// basicMap has 5 regions. each region has a number of nodes
+	// equal to the region number (1 has 1a, 2 has 2a and 2b, etc.)
+	basicMap := &tailcfg.DERPMap{
+		Regions: map[int]*tailcfg.DERPRegion{},
+	}
+	for rid := 1; rid <= 5; rid++ {
+		var nodes []*tailcfg.DERPNode
+		for nid := 0; nid < rid; nid++ {
+			nodes = append(nodes, &tailcfg.DERPNode{
+				Name:     fmt.Sprintf("%d%c", rid, 'a'+rune(nid)),
+				RegionID: rid,
+				HostName: fmt.Sprintf("derp%d-%d", rid, nid),
+				IPv4:     fmt.Sprintf("%d.0.0.%d", rid, nid),
+				IPv6:     fmt.Sprintf("%d::%d", rid, nid),
+			})
+		}
+		basicMap.Regions[rid] = &tailcfg.DERPRegion{
+			RegionID: rid,
+			Nodes:    nodes,
+		}
+	}
+
+	const ms = time.Millisecond
+	p := func(name string, c rune, d ...time.Duration) probe {
+		var proto probeProto
+		switch c {
+		case 4:
+			proto = probeIPv4
+		case 6:
+			proto = probeIPv6
+		case 'h':
+			proto = probeHTTPS
+		}
+		pr := probe{node: name, proto: proto}
+		if len(d) == 1 {
+			pr.delay = d[0]
+		} else if len(d) > 1 {
+			panic("too many args")
+		}
+		return pr
+	}
+	tests := []struct {
+		name    string
+		dm      *tailcfg.DERPMap
+		have6if bool
+		last    *Report
+		want    probePlan
+	}{
+		{
+			name:    "initial_v6",
+			dm:      basicMap,
+			have6if: true,
+			last:    nil, // initial
+			want: probePlan{
+				"region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a
+				"region-1-v6": []probe{p("1a", 6), p("1a", 6, 100*ms), p("1a", 6, 200*ms)},
+				"region-2-v4": []probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a
+				"region-2-v6": []probe{p("2a", 6), p("2b", 6, 100*ms), p("2a", 6, 200*ms)},
+				"region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c
+				"region-3-v6": []probe{p("3a", 6), p("3b", 6, 100*ms), p("3c", 6, 200*ms)},
+				"region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)},
+				"region-4-v6": []probe{p("4a", 6), p("4b", 6, 100*ms), p("4c", 6, 200*ms)},
+				"region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 200*ms)},
+				"region-5-v6": []probe{p("5a", 6), p("5b", 6, 100*ms), p("5c", 6, 200*ms)},
+			},
+		},
+		{
+			name:    "initial_no_v6",
+			dm:      basicMap,
+			have6if: false,
+			last:    nil, // initial
+			want: probePlan{
+				"region-1-v4": []probe{p("1a", 4), p("1a", 4, 100*ms), p("1a", 4, 200*ms)}, // all a
+				"region-2-v4": []probe{p("2a", 4), p("2b", 4, 100*ms), p("2a", 4, 200*ms)}, // a -> b -> a
+				"region-3-v4": []probe{p("3a", 4), p("3b", 4, 100*ms), p("3c", 4, 200*ms)}, // a -> b -> c
+				"region-4-v4": []probe{p("4a", 4), p("4b", 4, 100*ms), p("4c", 4, 200*ms)},
+				"region-5-v4": []probe{p("5a", 4), p("5b", 4, 100*ms), p("5c", 4, 200*ms)},
+			},
+		},
+		{
+			name:    "second_v4_no_6if",
+			dm:      basicMap,
+			have6if: false,
+			last: &Report{
+				RegionLatency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+					// Pretend 5 is missing
+				},
+				RegionV4Latency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+				},
+			},
+			want: probePlan{
+				"region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)},
+				"region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)},
+				"region-3-v4": []probe{p("3a", 4)},
+			},
+		},
+		{
+			name:    "second_v4_only_with_6if",
+			dm:      basicMap,
+			have6if: true,
+			last: &Report{
+				RegionLatency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+					// Pretend 5 is missing
+				},
+				RegionV4Latency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+				},
+			},
+			want: probePlan{
+				"region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)},
+				"region-1-v6": []probe{p("1a", 6)},
+				"region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)},
+				"region-2-v6": []probe{p("2a", 6)},
+				"region-3-v4": []probe{p("3a", 4)},
+			},
+		},
+		{
+			name:    "second_mixed",
+			dm:      basicMap,
+			have6if: true,
+			last: &Report{
+				RegionLatency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+					// Pretend 5 is missing
+				},
+				RegionV4Latency: map[int]time.Duration{
+					1: 10 * time.Millisecond,
+					2: 20 * time.Millisecond,
+				},
+				RegionV6Latency: map[int]time.Duration{
+					3: 30 * time.Millisecond,
+					4: 40 * time.Millisecond,
+				},
+			},
+			want: probePlan{
+				"region-1-v4": []probe{p("1a", 4), p("1a", 4, 12*ms)},
+				"region-1-v6": []probe{p("1a", 6), p("1a", 6, 12*ms)},
+				"region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)},
+				"region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)},
+				"region-3-v4": []probe{p("3a", 4)},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := makeProbePlan(tt.dm, tt.have6if, tt.last)
+			if !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want)
+			}
+		})
+	}
+}
+
+func (plan probePlan) String() string {
+	var sb strings.Builder
+	keys := []string{}
+	for k := range plan {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, key := range keys {
+		fmt.Fprintf(&sb, "[%s]", key)
+		pv := plan[key]
+		for _, p := range pv {
+			fmt.Fprintf(&sb, " %v", p)
+		}
+		sb.WriteByte('\n')
+	}
+	return sb.String()
+}
+
+func (p probe) String() string {
+	wait := ""
+	if p.wait > 0 {
+		wait = "+" + p.wait.String()
+	}
+	delay := ""
+	if p.delay > 0 {
+		delay = "@" + p.delay.String()
+	}
+	return fmt.Sprintf("%s-%s%s%s", p.node, p.proto, delay, wait)
+}
+
+func (p probeProto) String() string {
+	switch p {
+	case probeIPv4:
+		return "v4"
+	case probeIPv6:
+		return "v6"
+	case probeHTTPS:
+		return "https"
+	}
+	return "?"
+}

stun/stuntest/stuntest.go

@@ -6,12 +6,16 @@
 package stuntest

 import (
+	"fmt"
 	"net"
+	"strconv"
 	"strings"
 	"sync"
 	"testing"

+	"inet.af/netaddr"
 	"tailscale.com/stun"
+	"tailscale.com/tailcfg"
 )

 type stunStats struct {
@@ -20,7 +24,7 @@ type stunStats struct {
 	readIPv6 int
 }

-func Serve(t *testing.T) (addr string, cleanupFn func()) {
+func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) {
 	t.Helper()

 	// TODO(crawshaw): use stats to test re-STUN logic
@@ -30,13 +34,13 @@ func Serve(t *testing.T) (addr *net.UDPAddr, cleanupFn func()) {
 	if err != nil {
 		t.Fatalf("failed to open STUN listener: %v", err)
 	}
-
-	stunAddr := pc.LocalAddr().String()
-	stunAddr = strings.Replace(stunAddr, "0.0.0.0:", "127.0.0.1:", 1)
-
+	addr = &net.UDPAddr{
+		IP:   net.ParseIP("127.0.0.1"),
+		Port: pc.LocalAddr().(*net.UDPAddr).Port,
+	}
 	doneCh := make(chan struct{})
 	go runSTUN(t, pc, &stats, doneCh)
-	return stunAddr, func() {
+	return addr, func() {
 		pc.Close()
 		<-doneCh
 	}
@@ -79,3 +83,47 @@ func runSTUN(t *testing.T, pc net.PacketConn, stats *stunStats, done chan<- struct{}) {
 		}
 	}
 }
+
+func DERPMapOf(stun ...string) *tailcfg.DERPMap {
+	m := &tailcfg.DERPMap{
+		Regions: map[int]*tailcfg.DERPRegion{},
+	}
+	for i, hostPortStr := range stun {
+		regionID := i + 1
+		host, portStr, err := net.SplitHostPort(hostPortStr)
+		if err != nil {
+			panic(fmt.Sprintf("bogus STUN hostport: %q", hostPortStr))
+		}
+		port, err := strconv.Atoi(portStr)
+		if err != nil {
+			panic(fmt.Sprintf("bogus port %q in %q", portStr, hostPortStr))
+		}
+		var ipv4, ipv6 string
+		ip, err := netaddr.ParseIP(host)
+		if err != nil {
+			panic(fmt.Sprintf("bogus non-IP STUN host %q in %q", host, hostPortStr))
+		}
+		if ip.Is4() {
+			ipv4 = host
+			ipv6 = "none"
+		}
+		if ip.Is6() {
+			ipv6 = host
+			ipv4 = "none"
+		}
+		node := &tailcfg.DERPNode{
+			Name:     fmt.Sprint(regionID) + "a",
+			RegionID: regionID,
+			HostName: fmt.Sprintf("d%d.invalid", regionID),
+			IPv4:     ipv4,
+			IPv6:     ipv6,
+			STUNPort: port,
+			STUNOnly: true,
+		}
+		m.Regions[regionID] = &tailcfg.DERPRegion{
+			RegionID: regionID,
+			Nodes:    []*tailcfg.DERPNode{node},
+		}
+	}
+	return m
+}
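
Typical use of the two helpers above in the netcheck tests: pair Serve with DERPMapOf to get a one-region, STUN-only map pointing at the local test server (c and ctx as in TestBasic):

	addr, cleanup := stuntest.Serve(t)
	defer cleanup()
	dm := stuntest.DERPMapOf(addr.String()) // region 1, node "1a", STUNOnly
	r, err := c.GetReport(ctx, dm)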

stunner/stunner.go (deleted)

@@ -1,310 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stunner
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"tailscale.com/net/dnscache"
"tailscale.com/stun"
"tailscale.com/types/structs"
)
// Stunner sends a STUN request to several servers and handles a response.
//
// It is designed to used on a connection owned by other code and so does
// not directly reference a net.Conn of any sort. Instead, the user should
// provide Send function to send packets, and call Receive when a new
// STUN response is received.
//
// In response, a Stunner will call Endpoint with any endpoints determined
// for the connection. (An endpoint may be reported multiple times if
// multiple servers are provided.)
type Stunner struct {
// Send sends a packet.
// It will typically be a PacketConn.WriteTo method value.
Send func([]byte, net.Addr) (int, error) // sends a packet
// Endpoint is called whenever a STUN response is received.
// The server is the STUN server that replied, endpoint is the ip:port
// from the STUN response, and d is the duration that the STUN request
// took on the wire (not including DNS lookup time.
Endpoint func(server, endpoint string, d time.Duration)
// onPacket is the internal version of Endpoint that does de-dup.
// It's set by Run.
onPacket func(server, endpoint string, d time.Duration)
Servers []string // STUN servers to contact
// DNSCache optionally specifies a DNSCache to use.
// If nil, a DNS cache is not used.
DNSCache *dnscache.Resolver
// Logf optionally specifies a log function. If nil, logging is disabled.
Logf func(format string, args ...interface{})
// OnlyIPv6 controls whether IPv6 is exclusively used.
// If false, only IPv4 is used. There is currently no mixed mode.
OnlyIPv6 bool
// MaxTries optionally provides a mapping from server name to the maximum
// number of tries that should be made for a given server.
// If nil or a server is not present in the map, the default is 1.
// Values less than 1 are ignored.
MaxTries map[string]int
mu sync.Mutex
inFlight map[stun.TxID]request
}
func (s *Stunner) addTX(tx stun.TxID, server string) {
s.mu.Lock()
defer s.mu.Unlock()
if _, dup := s.inFlight[tx]; dup {
panic("unexpected duplicate STUN TransactionID")
}
s.inFlight[tx] = request{sent: time.Now(), server: server}
}
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
s.mu.Lock()
defer s.mu.Unlock()
if s.inFlight == nil {
return request{}, false
}
r, ok := s.inFlight[tx]
if ok {
delete(s.inFlight, tx)
} else {
s.logf("stunner: got STUN packet for unknown TxID %x", tx)
}
return r, ok
}
type request struct {
_ structs.Incomparable
sent time.Time
server string
}
func (s *Stunner) logf(format string, args ...interface{}) {
if s.Logf != nil {
s.Logf(format, args...)
}
}
// Receive delivers a STUN packet to the stunner.
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
if !stun.Is(p) {
s.logf("[unexpected] stunner: received non-STUN packet")
return
}
now := time.Now()
tx, addr, port, err := stun.ParseResponse(p)
if err != nil {
if _, err := stun.ParseBindingRequest(p); err == nil {
// This was probably our own netcheck hairpin
// check probe coming in late. Ignore.
return
}
s.logf("stunner: received unexpected STUN message response from %v: %v", fromAddr, err)
return
}
r, ok := s.removeTX(tx)
if !ok {
return
}
d := now.Sub(r.sent)
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port))
s.onPacket(r.server, host, d)
}
func (s *Stunner) resolver() *net.Resolver {
return net.DefaultResolver
}
// cleanUpPostRun zeros out some fields, mostly for debugging (so
// things crash or race+fail if there's a sender still running.)
func (s *Stunner) cleanUpPostRun() {
s.mu.Lock()
s.inFlight = nil
s.mu.Unlock()
}
// Run starts a Stunner and blocks until all servers either respond
// or are tried multiple times and timeout.
// It can not be called concurrently with itself.
func (s *Stunner) Run(ctx context.Context) error {
for _, server := range s.Servers {
if _, _, err := net.SplitHostPort(server); err != nil {
return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers)
}
}
if len(s.Servers) == 0 {
return errors.New("stunner: no Servers")
}
s.inFlight = make(map[stun.TxID]request)
defer s.cleanUpPostRun()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
type sender struct {
ctx context.Context
cancel context.CancelFunc
}
var (
needMu sync.Mutex
need = make(map[string]sender) // keyed by server; deleted when done
allDone = make(chan struct{}) // closed when need is empty
)
s.onPacket = func(server, endpoint string, d time.Duration) {
needMu.Lock()
defer needMu.Unlock()
sender, ok := need[server]
if !ok {
return
}
sender.cancel()
delete(need, server)
s.Endpoint(server, endpoint, d)
if len(need) == 0 {
close(allDone)
}
}
var wg sync.WaitGroup
for _, server := range s.Servers {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
need[server] = sender{ctx, cancel}
}
needMu.Lock()
for server, sender := range need {
wg.Add(1)
server, ctx := server, sender.ctx
go func() {
defer wg.Done()
s.sendPackets(ctx, server)
}()
}
needMu.Unlock()
var err error
select {
case <-ctx.Done():
err = ctx.Err()
case <-allDone:
cancel()
}
wg.Wait()
var missing []string
needMu.Lock()
for server := range need {
missing = append(missing, server)
}
needMu.Unlock()
if len(missing) == 0 || err == nil {
return nil
}
return fmt.Errorf("got STUN error: %w; missing replies from: %v", err, strings.Join(missing, ", "))
}
func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) {
hostStr, portStr, err := net.SplitHostPort(server)
if err != nil {
return nil, err
}
addrPort, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("port: %v", err)
}
if addrPort == 0 {
addrPort = 3478
}
addr := &net.UDPAddr{Port: addrPort}
var ipAddrs []net.IPAddr
if s.DNSCache != nil {
ip, err := s.DNSCache.LookupIP(ctx, hostStr)
if err != nil {
return nil, err
}
ipAddrs = []net.IPAddr{{IP: ip}}
} else {
ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr)
if err != nil {
return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err)
}
}
for _, ipAddr := range ipAddrs {
ip4 := ipAddr.IP.To4()
if ip4 != nil {
if s.OnlyIPv6 {
continue
}
addr.IP = ip4
break
} else if s.OnlyIPv6 {
addr.IP = ipAddr.IP
addr.Zone = ipAddr.Zone
}
}
if addr.IP == nil {
if s.OnlyIPv6 {
return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
}
return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
}
return addr, nil
}
// maxTriesForServer returns the maximum number of STUN queries that
// will be sent to server (for one call to Run). The default is 1.
func (s *Stunner) maxTriesForServer(server string) int {
if v, ok := s.MaxTries[server]; ok && v > 0 {
return v
}
return 1
}
func (s *Stunner) sendPackets(ctx context.Context, server string) error {
addr, err := s.serverAddr(ctx, server)
if err != nil {
return err
}
maxTries := s.maxTriesForServer(server)
for i := 0; i < maxTries; i++ {
txID := stun.NewTxID()
req := stun.Request(txID)
s.addTX(txID, server)
_, err = s.Send(req, addr)
if err != nil {
return fmt.Errorf("send: %v", err)
}
select {
case <-ctx.Done():
// Ignore error. The caller deals with handling contexts.
// We only use it to dermine when to stop spraying STUN packets.
return nil
case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))):
}
}
return nil
}

stunner/stunner_test.go (deleted)

@@ -1,154 +0,0 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stunner
import (
"context"
"errors"
"fmt"
"net"
"sort"
"testing"
"time"
"gortc.io/stun"
)
func TestStun(t *testing.T) {
conn1, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer conn1.Close()
conn2, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer conn2.Close()
stunServers := []string{
conn1.LocalAddr().String(), conn2.LocalAddr().String(),
}
epCh := make(chan string, 16)
localConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
s := &Stunner{
Send: localConn.WriteTo,
Endpoint: func(server, ep string, d time.Duration) { epCh <- ep },
Servers: stunServers,
MaxTries: map[string]int{
stunServers[0]: 2,
stunServers[1]: 2,
},
}
stun1Err := make(chan error)
go func() {
stun1Err <- startSTUN(conn1, s.Receive)
}()
stun2Err := make(chan error)
go func() {
stun2Err <- startSTUNDrop1(conn2, s.Receive)
}()
errCh := make(chan error)
go func() {
errCh <- s.Run(context.Background())
}()
var eps []string
select {
case ep := <-epCh:
eps = append(eps, ep)
case <-time.After(100 * time.Millisecond):
t.Fatal("missing first endpoint response")
}
select {
case ep := <-epCh:
eps = append(eps, ep)
case <-time.After(500 * time.Millisecond):
t.Fatal("missing second endpoint response")
}
sort.Strings(eps)
if want := "1.2.3.4:1234"; eps[0] != want {
t.Errorf("eps[0]=%q, want %q", eps[0], want)
}
if want := "4.5.6.7:4567"; eps[1] != want {
t.Errorf("eps[1]=%q, want %q", eps[1], want)
}
if err := <-errCh; err != nil {
t.Fatal(err)
}
}
func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil {
return fmt.Errorf("first stun server read failed: %v", err)
}
req := new(stun.Message)
res := new(stun.Message)
p := make([]byte, 1024)
n, addr, err := conn.ReadFrom(p)
if err != nil {
return err
}
p = p[:n]
if !stun.IsMessage(p) {
return errors.New("not a STUN message")
}
if _, err := req.Write(p); err != nil {
return err
}
mappedAddr := &stun.XORMappedAddress{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
software := stun.NewSoftware("endpointer")
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
if err != nil {
return err
}
writeTo(res.Raw, addr.(*net.UDPAddr))
return nil
}
func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
req := new(stun.Message)
res := new(stun.Message)
p := make([]byte, 1024)
n, addr, err := conn.ReadFrom(p)
if err != nil {
return err
}
p = p[:n]
if !stun.IsMessage(p) {
return errors.New("not a STUN message")
}
if _, err := req.Write(p); err != nil {
return err
}
mappedAddr := &stun.XORMappedAddress{
IP: net.ParseIP("4.5.6.7"),
Port: 4567,
}
software := stun.NewSoftware("endpointer")
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
if err != nil {
return err
}
writeTo(res.Raw, addr.(*net.UDPAddr))
return nil
}
// TODO: test retry timeout (overwrite the retryDurations)
// TODO: test canceling context passed to Run
// TODO: test sending bad packets

tailcfg/derpmap.go

@@ -4,6 +4,8 @@
 package tailcfg

+import "sort"
+
 // DERPMap describes the set of DERP packet relay servers that are available.
 type DERPMap struct {
 	// Regions is the set of geographic regions running DERP node(s).
@@ -14,6 +16,16 @@ type DERPMap struct {
 	Regions map[int]*DERPRegion
 }

+// RegionIDs returns the sorted region IDs.
+func (m *DERPMap) RegionIDs() []int {
+	ret := make([]int, 0, len(m.Regions))
+	for rid := range m.Regions {
+		ret = append(ret, rid)
+	}
+	sort.Ints(ret)
+	return ret
+}
+
 // DERPRegion is a geographic region running DERP relay node(s).
 //
 // Client nodes discover which region they're closest to, advertise
@@ -85,9 +97,29 @@ type DERPNode struct {

 	// IPv4 optionally forces an IPv4 address to use, instead of using DNS.
 	// If empty, A record(s) from DNS lookups of HostName are used.
+	// If the string is not an IPv4 address, IPv4 is not used; the
+	// conventional string to disable IPv4 (and not use DNS) is
+	// "none".
 	IPv4 string `json:",omitempty"`

 	// IPv6 optionally forces an IPv6 address to use, instead of using DNS.
 	// If empty, AAAA record(s) from DNS lookups of HostName are used.
+	// If the string is not an IPv6 address, IPv6 is not used; the
+	// conventional string to disable IPv6 (and not use DNS) is
+	// "none".
 	IPv6 string `json:",omitempty"`
+
+	// STUNPort optionally specifies a STUN port to use.
+	// Zero means 3478.
+	// To disable STUN on this node, use -1.
+	STUNPort int `json:",omitempty"`
+
+	// STUNOnly marks a node as only a STUN server and not a DERP
+	// server.
+	STUNOnly bool `json:",omitempty"`
+
+	// DERPTestPort is used in tests to override the port, instead
+	// of using the default port of 443. If non-zero, TLS
+	// verification is skipped.
+	DERPTestPort int `json:",omitempty"`
 }
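
To illustrate the new fields (hostnames here are made up), a control server could publish a region that mixes a DERP node with a STUN-only helper:

	m := &tailcfg.DERPMap{
		Regions: map[int]*tailcfg.DERPRegion{
			1: {
				RegionID:   1,
				RegionCode: "nyc",
				Nodes: []*tailcfg.DERPNode{
					{Name: "1a", RegionID: 1, HostName: "derp1.example.com"},
					{Name: "1b", RegionID: 1, HostName: "stun1.example.com", STUNOnly: true},
				},
			},
		},
	}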

tailcfg/tailcfg.go

@@ -315,8 +315,9 @@ type NetInfo struct {
 	LinkType string // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc)

 	// DERPLatency is the fastest recent time to reach various
-	// DERP STUN servers, in seconds. The map key is the DERP
-	// server's STUN host:port.
+	// DERP STUN servers, in seconds. The map key is the
+	// "regionID-v4" or "-v6"; it was previously the DERP server's
+	// STUN host:port.
 	//
 	// This should only be updated rarely, or when there's a
 	// material change, as any change here also gets uploaded to
@@ -336,7 +337,7 @@ func (ni *NetInfo) String() string {
 }

 // BasicallyEqual reports whether ni and ni2 are basically equal, ignoring
-// changes in DERPLatency.
+// changes in DERP ServerLatency & RegionLatency.
 func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool {
 	if (ni == nil) != (ni2 == nil) {
 		return false

wgengine/magicsock/magicsock.go

@@ -9,7 +9,6 @@
 import (
 	"bytes"
 	"context"
-	"crypto/tls"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -17,6 +16,7 @@
 	"math/rand"
 	"net"
 	"os"
+	"reflect"
 	"sort"
 	"strconv"
 	"strings"
@@ -32,7 +32,6 @@
 	"inet.af/netaddr"
 	"tailscale.com/derp"
 	"tailscale.com/derp/derphttp"
-	"tailscale.com/derp/derpmap"
 	"tailscale.com/ipn/ipnstate"
 	"tailscale.com/net/dnscache"
 	"tailscale.com/net/interfaces"
@@ -55,7 +54,6 @@ type Conn struct {
 	epFunc       func(endpoints []string)
 	logf         logger.Logf
 	sendLogLimit *rate.Limiter
-	derps        *derpmap.World
 	netChecker   *netcheck.Client

 	// bufferedIPv4From and bufferedIPv4Packet are owned by
@@ -76,7 +74,8 @@ type Conn struct {
 	mu sync.Mutex // guards all following fields

-	closed bool
+	started bool
+	closed  bool

 	endpointsUpdateWaiter *sync.Cond
 	endpointsUpdateActive bool
@@ -104,13 +103,12 @@ type Conn struct {
 	netInfoFunc func(*tailcfg.NetInfo) // nil until set
 	netInfoLast *tailcfg.NetInfo

-	wantDerp      bool
+	derpMap       *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled
 	privateKey    key.Private
-	myDerp        int           // nearest DERP server; 0 means none/unknown
+	myDerp        int           // nearest DERP region ID; 0 means none/unknown
 	derpStarted   chan struct{} // closed on first connection to DERP; for tests
-	activeDerp    map[int]activeDerp
+	activeDerp    map[int]activeDerp // DERP regionID -> connection to a node in that region
 	prevDerp      map[int]*syncs.WaitGroupChan
-	derpTLSConfig *tls.Config // normally nil; used by tests

 	// derpRoute contains optional alternate routes to use as an
 	// optimization instead of contacting a peer via their home
@@ -196,14 +194,9 @@ type Options struct {
 	// Zero means to pick one automatically.
 	Port uint16

-	// DERPs, if non-nil, is used instead of derpmap.Prod.
-	DERPs *derpmap.World
-
 	// EndpointsFunc optionally provides a func to be called when
 	// endpoints change. The called func does not own the slice.
 	EndpointsFunc func(endpoint []string)
-
-	derpTLSConfig *tls.Config // normally nil; used by tests
 }

 func (o *Options) logf() logger.Logf {
@@ -220,37 +213,39 @@ func (o *Options) endpointsFunc() func([]string) {
 	return o.EndpointsFunc
 }

-// Listen creates a magic Conn listening on opts.Port.
-// As the set of possible endpoints for a Conn changes, the
-// callback opts.EndpointsFunc is called.
-func Listen(opts Options) (*Conn, error) {
+// newConn is the error-free, network-listening-side-effect-free basis
+// of NewConn. Mostly for tests.
+func newConn() *Conn {
 	c := &Conn{
-		pconnPort:     opts.Port,
-		logf:          opts.logf(),
-		epFunc:        opts.endpointsFunc(),
-		sendLogLimit:  rate.NewLimiter(rate.Every(1*time.Minute), 1),
-		addrsByUDP:    make(map[netaddr.IPPort]*AddrSet),
-		addrsByKey:    make(map[key.Public]*AddrSet),
-		wantDerp:      true,
-		derpRecvCh:    make(chan derpReadResult),
-		udpRecvCh:     make(chan udpReadResult),
-		derpTLSConfig: opts.derpTLSConfig,
-		derpStarted:   make(chan struct{}),
-		derps:         opts.DERPs,
-		peerLastDerp:  make(map[key.Public]int),
+		sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1),
+		addrsByUDP:   make(map[netaddr.IPPort]*AddrSet),
+		addrsByKey:   make(map[key.Public]*AddrSet),
+		derpRecvCh:   make(chan derpReadResult),
+		udpRecvCh:    make(chan udpReadResult),
+		derpStarted:  make(chan struct{}),
+		peerLastDerp: make(map[key.Public]int),
 	}
 	c.endpointsUpdateWaiter = sync.NewCond(&c.mu)
+	return c
+}
+
+// NewConn creates a magic Conn listening on opts.Port.
+// As the set of possible endpoints for a Conn changes, the
+// callback opts.EndpointsFunc is called.
+//
+// It doesn't start doing anything until Start is called.
+func NewConn(opts Options) (*Conn, error) {
+	c := newConn()
+	c.pconnPort = opts.Port
+	c.logf = opts.logf()
+	c.epFunc = opts.endpointsFunc()
+
 	if err := c.initialBind(); err != nil {
 		return nil, err
 	}

 	c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
-	if c.derps == nil {
-		c.derps = derpmap.Prod()
-	}
 	c.netChecker = &netcheck.Client{
-		DERP:         c.derps,
 		Logf:         logger.WithPrefix(c.logf, "netcheck: "),
 		GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 },
 	}
@@ -259,6 +254,18 @@
 	}

 	c.ignoreSTUNPackets()

+	return c, nil
+}
+
+func (c *Conn) Start() {
+	c.mu.Lock()
+	if c.started {
+		panic("duplicate Start call")
+	}
+	c.started = true
+	c.mu.Unlock()
+
 	c.ReSTUN("initial")

 	// We assume that LinkChange notifications are plumbed through well
@@ -267,8 +274,6 @@
 		go c.periodicReSTUN()
 	}
 	go c.periodicDerpCleanup()
-
-	return c, nil
 }

 func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() }
@ -278,10 +283,6 @@ func (c *Conn) ignoreSTUNPackets() {
c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {})
} }
// runs in its own goroutine until ctx is shut down.
// Whenever c.startEpUpdate receives a value, it starts a
// STUN endpoint lookup.
//
// c.mu must NOT be held. // c.mu must NOT be held.
func (c *Conn) updateEndpoints(why string) { func (c *Conn) updateEndpoints(why string) {
defer func() { defer func() {
@ -326,7 +327,11 @@ func (c *Conn) setEndpoints(endpoints []string) (changed bool) {
} }
func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) {
if DisableSTUNForTesting { c.mu.Lock()
dm := c.derpMap
c.mu.Unlock()
if DisableSTUNForTesting || dm == nil {
return new(netcheck.Report), nil return new(netcheck.Report), nil
} }
@ -336,7 +341,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) {
c.stunReceiveFunc.Store(c.netChecker.ReceiveSTUNPacket) c.stunReceiveFunc.Store(c.netChecker.ReceiveSTUNPacket)
defer c.ignoreSTUNPackets() defer c.ignoreSTUNPackets()
report, err := c.netChecker.GetReport(ctx) report, err := c.netChecker.GetReport(ctx, dm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -346,8 +351,11 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) {
MappingVariesByDestIP: report.MappingVariesByDestIP, MappingVariesByDestIP: report.MappingVariesByDestIP,
HairPinning: report.HairPinning, HairPinning: report.HairPinning,
} }
for server, d := range report.DERPLatency { for rid, d := range report.RegionV4Latency {
ni.DERPLatency[server] = d.Seconds() ni.DERPLatency[fmt.Sprintf("%d-v4", rid)] = d.Seconds()
}
for rid, d := range report.RegionV6Latency {
ni.DERPLatency[fmt.Sprintf("%d-v6", rid)] = d.Seconds()
} }
ni.WorkingIPv6.Set(report.IPv6) ni.WorkingIPv6.Set(report.IPv6)
ni.WorkingUDP.Set(report.UDP) ni.WorkingUDP.Set(report.UDP)
@ -380,9 +388,12 @@ func (c *Conn) pickDERPFallback() int {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
ids := c.derps.IDs() if !c.wantDerpLocked() {
return 0
}
ids := c.derpMap.RegionIDs()
if len(ids) == 0 { if len(ids) == 0 {
// No DERP nodes registered. // No DERP regions in non-nil map.
return 0 return 0
} }
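RegionIDs is new with the DERP map change; it is presumably just the sorted key set of the Regions map, along these lines (a sketch, not the commit's verbatim code):

	func (m *DERPMap) RegionIDs() []int {
		ret := make([]int, 0, len(m.Regions))
		for rid := range m.Regions {
			ret = append(ret, rid)
		}
		sort.Ints(ret)
		return ret
	}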
@ -458,7 +469,7 @@ func (c *Conn) SetNetInfoCallback(fn func(*tailcfg.NetInfo)) {
func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) { func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if !c.wantDerp { if !c.wantDerpLocked() {
c.myDerp = 0 c.myDerp = 0
return false return false
} }
@ -476,7 +487,7 @@ func (c *Conn) setNearestDERP(derpNum int) (wantDERP bool) {
// On change, notify all currently connected DERP servers and // On change, notify all currently connected DERP servers and
// start connecting to our home DERP if we are not already. // start connecting to our home DERP if we are not already.
c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derps.ServerByID(derpNum).Geo) c.logf("magicsock: home is now derp-%v (%v)", derpNum, c.derpMap.Regions[derpNum].RegionCode)
for i, ad := range c.activeDerp { for i, ad := range c.activeDerp {
go ad.c.NotePreferred(i == c.myDerp) go ad.c.NotePreferred(i == c.myDerp)
} }
@ -791,11 +802,11 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
if !addr.IP.Equal(derpMagicIP) { if !addr.IP.Equal(derpMagicIP) {
return nil return nil
} }
nodeID := addr.Port regionID := addr.Port
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if !c.wantDerp || c.closed { if !c.wantDerpLocked() || c.closed {
return nil return nil
} }
if c.privateKey.IsZero() { if c.privateKey.IsZero() {
@ -807,10 +818,10 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
// first. If so, might as well use it. (It's a little // first. If so, might as well use it. (It's a little
// arbitrary whether we use this one vs. the reverse route // arbitrary whether we use this one vs. the reverse route
// below when we have both.) // below when we have both.)
ad, ok := c.activeDerp[nodeID] ad, ok := c.activeDerp[regionID]
if ok { if ok {
*ad.lastWrite = time.Now() *ad.lastWrite = time.Now()
c.setPeerLastDerpLocked(peer, nodeID, nodeID) c.setPeerLastDerpLocked(peer, regionID, regionID)
return ad.writeCh return ad.writeCh
} }
@ -823,7 +834,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
if !peer.IsZero() && debugUseDerpRoute { if !peer.IsZero() && debugUseDerpRoute {
if r, ok := c.derpRoute[peer]; ok { if r, ok := c.derpRoute[peer]; ok {
if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc { if ad, ok := c.activeDerp[r.derpID]; ok && ad.c == r.dc {
c.setPeerLastDerpLocked(peer, r.derpID, nodeID) c.setPeerLastDerpLocked(peer, r.derpID, regionID)
*ad.lastWrite = time.Now() *ad.lastWrite = time.Now()
return ad.writeCh return ad.writeCh
} }
@ -834,7 +845,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
if !peer.IsZero() { if !peer.IsZero() {
why = peerShort(peer) why = peerShort(peer)
} }
c.logf("magicsock: adding connection to derp-%v for %v", nodeID, why) c.logf("magicsock: adding connection to derp-%v for %v", regionID, why)
firstDerp := false firstDerp := false
if c.activeDerp == nil { if c.activeDerp == nil {
@ -842,22 +853,23 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
c.activeDerp = make(map[int]activeDerp) c.activeDerp = make(map[int]activeDerp)
c.prevDerp = make(map[int]*syncs.WaitGroupChan) c.prevDerp = make(map[int]*syncs.WaitGroupChan)
} }
derpSrv := c.derps.ServerByID(nodeID) if c.derpMap == nil || c.derpMap.Regions[regionID] == nil {
if derpSrv == nil || derpSrv.HostHTTPS == "" {
return nil return nil
} }
// Note that derphttp.NewClient does not dial the server // Note that derphttp.NewClient does not dial the server
// so it is safe to do under the mu lock. // so it is safe to do under the mu lock.
dc, err := derphttp.NewClient(c.privateKey, "https://"+derpSrv.HostHTTPS+"/derp", c.logf) dc := derphttp.NewRegionClient(c.privateKey, c.logf, func() *tailcfg.DERPRegion {
if err != nil { c.mu.Lock()
c.logf("magicsock: derphttp.NewClient: node %d, host %q invalid? err: %v", nodeID, derpSrv.HostHTTPS, err) defer c.mu.Unlock()
return nil if c.derpMap == nil {
} return nil
}
return c.derpMap.Regions[regionID]
})
dc.NotePreferred(c.myDerp == nodeID) dc.NotePreferred(c.myDerp == regionID)
dc.DNSCache = dnscache.Get() dc.DNSCache = dnscache.Get()
dc.TLSConfig = c.derpTLSConfig
ctx, cancel := context.WithCancel(c.connCtx) ctx, cancel := context.WithCancel(c.connCtx)
ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop) ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop)
@ -868,21 +880,21 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
ad.lastWrite = new(time.Time) ad.lastWrite = new(time.Time)
*ad.lastWrite = time.Now() *ad.lastWrite = time.Now()
ad.createTime = time.Now() ad.createTime = time.Now()
c.activeDerp[nodeID] = ad c.activeDerp[regionID] = ad
c.logActiveDerpLocked() c.logActiveDerpLocked()
c.setPeerLastDerpLocked(peer, nodeID, nodeID) c.setPeerLastDerpLocked(peer, regionID, regionID)
// Build a startGate for the derp reader+writer // Build a startGate for the derp reader+writer
// goroutines, so they don't start running until any // goroutines, so they don't start running until any
// previous generation is closed. // previous generation is closed.
startGate := syncs.ClosedChan() startGate := syncs.ClosedChan()
if prev := c.prevDerp[nodeID]; prev != nil { if prev := c.prevDerp[regionID]; prev != nil {
startGate = prev.DoneChan() startGate = prev.DoneChan()
} }
// And register a WaitGroup(Chan) for this generation. // And register a WaitGroup(Chan) for this generation.
wg := syncs.NewWaitGroupChan() wg := syncs.NewWaitGroupChan()
wg.Add(2) wg.Add(2)
c.prevDerp[nodeID] = wg c.prevDerp[regionID] = wg
if firstDerp { if firstDerp {
startGate = c.derpStarted startGate = c.derpStarted
@ -899,37 +911,37 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr, peer key.Public) chan<- de
} }
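For context on the renaming above: a DERP destination is a synthetic UDP address, where derpMagicIP marks the packet as DERP-bound and the port field, which used to carry a node ID, now carries a region ID. A sketch of the addressing convention (peerKey is a placeholder):

	// Fake endpoint that routes through DERP region 2; derpMagicIP is
	// never dialed, it only tags the address as DERP traffic.
	addr := &net.UDPAddr{IP: derpMagicIP, Port: 2}
	ch := c.derpWriteChanOfAddr(addr, peerKey) // nil if DERP is disabled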
// setPeerLastDerpLocked notes that peer is now being written to via // setPeerLastDerpLocked notes that peer is now being written to via
// provided DERP node nodeID, and that that advertises a DERP home // the provided DERP regionID, and that the peer advertises a DERP
// node of homeID. // home region ID of homeID.
// //
// If there's any change, it logs. // If there's any change, it logs.
// //
// c.mu must be held. // c.mu must be held.
func (c *Conn) setPeerLastDerpLocked(peer key.Public, nodeID, homeID int) { func (c *Conn) setPeerLastDerpLocked(peer key.Public, regionID, homeID int) {
if peer.IsZero() { if peer.IsZero() {
return return
} }
old := c.peerLastDerp[peer] old := c.peerLastDerp[peer]
if old == nodeID { if old == regionID {
return return
} }
c.peerLastDerp[peer] = nodeID c.peerLastDerp[peer] = regionID
var newDesc string var newDesc string
switch { switch {
case nodeID == homeID && nodeID == c.myDerp: case regionID == homeID && regionID == c.myDerp:
newDesc = "shared home" newDesc = "shared home"
case nodeID == homeID: case regionID == homeID:
newDesc = "their home" newDesc = "their home"
case nodeID == c.myDerp: case regionID == c.myDerp:
newDesc = "our home" newDesc = "our home"
case nodeID != homeID: case regionID != homeID:
newDesc = "alt" newDesc = "alt"
} }
if old == 0 { if old == 0 {
c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), nodeID, newDesc) c.logf("magicsock: derp route for %s set to derp-%d (%s)", peerShort(peer), regionID, newDesc)
} else { } else {
c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, nodeID, newDesc) c.logf("magicsock: derp route for %s changed from derp-%d => derp-%d (%s)", peerShort(peer), old, regionID, newDesc)
} }
} }
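To make the switch above concrete, assuming our own home region (c.myDerp) is 1:

	// regionID  homeID  desc
	//    1        1     "shared home"  (we and the peer share region 1)
	//    3        3     "their home"   (writing via the peer's home)
	//    1        3     "our home"     (peer reached via our home region)
	//    3        5     "alt"          (neither side's home)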
@ -1284,18 +1296,27 @@ func (c *Conn) UpdatePeers(newPeers map[key.Public]struct{}) {
} }
} }
// SetDERPEnabled controls whether DERP is used. // SetDERPMap controls which (if any) DERP servers are used.
// New connections have it enabled by default. // A nil value means to disable DERP; it's disabled by default.
func (c *Conn) SetDERPEnabled(wantDerp bool) { func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.wantDerp = wantDerp if reflect.DeepEqual(dm, c.derpMap) {
if !wantDerp { return
c.closeAllDerpLocked("derp-disabled")
} }
c.derpMap = dm
if dm == nil {
c.closeAllDerpLocked("derp-disabled")
return
}
go c.ReSTUN("derp-map-update")
} }
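SetDERPMap replaces the old boolean knob, so callers that used SetDERPEnabled(true/false) now look roughly like this (dm being whatever map control handed down):

	c.SetDERPMap(dm)  // enables DERP and triggers ReSTUN on a real change
	c.SetDERPMap(dm)  // second identical call: DeepEqual makes it a no-op
	c.SetDERPMap(nil) // disables DERP, closing all active DERP connections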
func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil }
// c.mu must be held. // c.mu must be held.
func (c *Conn) closeAllDerpLocked(why string) { func (c *Conn) closeAllDerpLocked(why string) {
if len(c.activeDerp) == 0 { if len(c.activeDerp) == 0 {
@ -1352,7 +1373,7 @@ func (c *Conn) logEndpointChange(endpoints []string, reasons map[string]string)
} }
// c.mu must be held. // c.mu must be held.
func (c *Conn) foreachActiveDerpSortedLocked(fn func(nodeID int, ad activeDerp)) { func (c *Conn) foreachActiveDerpSortedLocked(fn func(regionID int, ad activeDerp)) {
if len(c.activeDerp) < 2 { if len(c.activeDerp) < 2 {
for id, ad := range c.activeDerp { for id, ad := range c.activeDerp {
fn(id, ad) fn(id, ad)
@ -1473,6 +1494,9 @@ func (c *Conn) periodicDerpCleanup() {
func (c *Conn) ReSTUN(why string) { func (c *Conn) ReSTUN(why string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if !c.started {
panic("call to ReSTUN before Start")
}
if c.closed { if c.closed {
// raced with a shutdown. // raced with a shutdown.
return return

View File

@ -27,6 +27,7 @@
"tailscale.com/derp/derphttp" "tailscale.com/derp/derphttp"
"tailscale.com/derp/derpmap" "tailscale.com/derp/derpmap"
"tailscale.com/stun/stuntest" "tailscale.com/stun/stuntest"
"tailscale.com/tailcfg"
"tailscale.com/tstest" "tailscale.com/tstest"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
@ -54,7 +55,7 @@ func (c *Conn) WaitReady(t *testing.T) {
} }
} }
func TestListen(t *testing.T) { func TestNewConn(t *testing.T) {
tstest.PanicOnLog() tstest.PanicOnLog()
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
@ -70,9 +71,8 @@ func TestListen(t *testing.T) {
defer stunCleanupFn() defer stunCleanupFn()
port := pickPort(t) port := pickPort(t)
conn, err := Listen(Options{ conn, err := NewConn(Options{
Port: port, Port: port,
DERPs: derpmap.NewTestWorld(stunAddr),
EndpointsFunc: epFunc, EndpointsFunc: epFunc,
Logf: t.Logf, Logf: t.Logf,
}) })
@ -80,6 +80,8 @@ func TestListen(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
defer conn.Close() defer conn.Close()
conn.Start()
conn.SetDERPMap(stuntest.DERPMapOf(stunAddr.String()))
go func() { go func() {
var pkt [64 << 10]byte var pkt [64 << 10]byte
@ -136,9 +138,8 @@ func TestPickDERPFallback(t *testing.T) {
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
c := &Conn{ c := newConn()
derps: derpmap.Prod(), c.derpMap = derpmap.Prod()
}
a := c.pickDERPFallback() a := c.pickDERPFallback()
if a == 0 { if a == 0 {
t.Fatalf("pickDERPFallback returned 0") t.Fatalf("pickDERPFallback returned 0")
@ -156,7 +157,8 @@ func TestPickDERPFallback(t *testing.T) {
// distribution over nodes works. // distribution over nodes works.
got := map[int]int{} got := map[int]int{}
for i := 0; i < 50; i++ { for i := 0; i < 50; i++ {
c = &Conn{derps: derpmap.Prod()} c = newConn()
c.derpMap = derpmap.Prod()
got[c.pickDERPFallback()]++ got[c.pickDERPFallback()]++
} }
t.Logf("distribution: %v", got) t.Logf("distribution: %v", got)
@ -236,7 +238,7 @@ func parseCIDR(t *testing.T, addr string) wgcfg.CIDR {
return cidr return cidr
} }
func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, cleanupFn func()) { func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr *net.TCPAddr, cleanupFn func()) {
var serverPrivateKey key.Private var serverPrivateKey key.Private
if _, err := crand.Read(serverPrivateKey[:]); err != nil { if _, err := crand.Read(serverPrivateKey[:]); err != nil {
t.Fatal(err) t.Fatal(err)
@ -250,14 +252,13 @@ func runDERP(t *testing.T, logf logger.Logf) (s *derp.Server, addr string, clean
httpsrv.StartTLS() httpsrv.StartTLS()
logf("DERP server URL: %s", httpsrv.URL) logf("DERP server URL: %s", httpsrv.URL)
addr = strings.TrimPrefix(httpsrv.URL, "https://")
cleanupFn = func() { cleanupFn = func() {
httpsrv.CloseClientConnections() httpsrv.CloseClientConnections()
httpsrv.Close() httpsrv.Close()
s.Close() s.Close()
} }
return s, addr, cleanupFn return s, httpsrv.Listener.Addr().(*net.TCPAddr), cleanupFn
} }
// devLogger returns a wireguard-go device.Logger that writes // devLogger returns a wireguard-go device.Logger that writes
@ -286,13 +287,14 @@ func TestDeviceStartStop(t *testing.T) {
rc := tstest.NewResourceCheck() rc := tstest.NewResourceCheck()
defer rc.Assert(t) defer rc.Assert(t)
conn, err := Listen(Options{ conn, err := NewConn(Options{
EndpointsFunc: func(eps []string) {}, EndpointsFunc: func(eps []string) {},
Logf: t.Logf, Logf: t.Logf,
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
conn.Start()
defer conn.Close() defer conn.Close()
tun := tuntest.NewChannelTUN() tun := tuntest.NewChannelTUN()
@ -337,48 +339,58 @@ func TestTwoDevicePing(t *testing.T) {
// all log using the "current" t.Logf function. Sigh. // all log using the "current" t.Logf function. Sigh.
logf, setT := makeNestable(t) logf, setT := makeNestable(t)
// Wipe default DERP list, add local server.
// (Do it now, or derpHost will try to connect to derp1.tailscale.com.)
derpServer, derpAddr, derpCleanupFn := runDERP(t, logf) derpServer, derpAddr, derpCleanupFn := runDERP(t, logf)
defer derpCleanupFn() defer derpCleanupFn()
stunAddr, stunCleanupFn := stuntest.Serve(t) stunAddr, stunCleanupFn := stuntest.Serve(t)
defer stunCleanupFn() defer stunCleanupFn()
derps := derpmap.NewTestWorldWith(&derpmap.Server{ derpMap := &tailcfg.DERPMap{
ID: 1, Regions: map[int]*tailcfg.DERPRegion{
HostHTTPS: derpAddr, 1: &tailcfg.DERPRegion{
STUN4: stunAddr, RegionID: 1,
Geo: "Testopolis", RegionCode: "test",
}) Nodes: []*tailcfg.DERPNode{
{
Name: "t1",
RegionID: 1,
HostName: "test-node.unused",
IPv4: "127.0.0.1",
IPv6: "none",
STUNPort: stunAddr.Port,
DERPTestPort: derpAddr.Port,
},
},
},
},
}
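	// The TLS test plumbing now travels in the map itself: DERPTestPort
	// hands the httptest server's port to the DERP client in place of
	// the deleted derpTLSConfig option (and, by assumption here, a
	// non-zero test port also relaxes certificate verification).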
epCh1 := make(chan []string, 16) epCh1 := make(chan []string, 16)
conn1, err := Listen(Options{ conn1, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn1: "), Logf: logger.WithPrefix(logf, "conn1: "),
DERPs: derps,
EndpointsFunc: func(eps []string) { EndpointsFunc: func(eps []string) {
epCh1 <- eps epCh1 <- eps
}, },
derpTLSConfig: &tls.Config{InsecureSkipVerify: true},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer conn1.Close() defer conn1.Close()
conn1.Start()
conn1.SetDERPMap(derpMap)
epCh2 := make(chan []string, 16) epCh2 := make(chan []string, 16)
conn2, err := Listen(Options{ conn2, err := NewConn(Options{
Logf: logger.WithPrefix(logf, "conn2: "), Logf: logger.WithPrefix(logf, "conn2: "),
DERPs: derps,
EndpointsFunc: func(eps []string) { EndpointsFunc: func(eps []string) {
epCh2 <- eps epCh2 <- eps
}, },
derpTLSConfig: &tls.Config{InsecureSkipVerify: true},
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer conn2.Close() defer conn2.Close()
conn2.Start()
conn2.SetDERPMap(derpMap)
ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} ports := []uint16{conn1.LocalPort(), conn2.LocalPort()}
cfgs := makeConfigs(t, ports) cfgs := makeConfigs(t, ports)

View File

@ -8,6 +8,7 @@
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -49,7 +50,7 @@
type userspaceEngine struct { type userspaceEngine struct {
logf logger.Logf logf logger.Logf
reqCh chan struct{} reqCh chan struct{}
waitCh chan struct{} waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool
tundev *tstun.TUN tundev *tstun.TUN
wgdev *device.Device wgdev *device.Device
router router.Router router router.Router
@ -61,6 +62,7 @@ type userspaceEngine struct {
lastCfg wgcfg.Config lastCfg wgcfg.Config
mu sync.Mutex // guards following; see lock order comment below mu sync.Mutex // guards following; see lock order comment below
closing bool // Close was called (even if we're still closing)
statusCallback StatusCallback statusCallback StatusCallback
peerSequence []wgcfg.Key peerSequence []wgcfg.Key
endpoints []string endpoints []string
@ -149,7 +151,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
Port: listenPort, Port: listenPort,
EndpointsFunc: endpointsFn, EndpointsFunc: endpointsFn,
} }
e.magicConn, err = magicsock.Listen(magicsockOpts) e.magicConn, err = magicsock.NewConn(magicsockOpts)
if err != nil { if err != nil {
tundev.Close() tundev.Close()
return nil, fmt.Errorf("wgengine: %v", err) return nil, fmt.Errorf("wgengine: %v", err)
@ -210,6 +212,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
// routers do not Read or Write, but do access native interfaces. // routers do not Read or Write, but do access native interfaces.
e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap()) e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap())
if err != nil { if err != nil {
e.magicConn.Close()
return nil, err return nil, err
} }
@ -235,16 +238,19 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
e.wgdev.Up() e.wgdev.Up()
if err := e.router.Up(); err != nil { if err := e.router.Up(); err != nil {
e.magicConn.Close()
e.wgdev.Close() e.wgdev.Close()
return nil, err return nil, err
} }
// TODO(danderson): we should delete this. It's pointless to apply // TODO(danderson): we should delete this. It's pointless to apply
// no-op settings here. // no-op settings here.
if err := e.router.Set(nil); err != nil { if err := e.router.Set(nil); err != nil {
e.magicConn.Close()
e.wgdev.Close() e.wgdev.Close()
return nil, err return nil, err
} }
e.linkMon.Start() e.linkMon.Start()
e.magicConn.Start()
return e, nil return e, nil
} }
@ -407,6 +413,13 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
e.wgLock.Lock() e.wgLock.Lock()
defer e.wgLock.Unlock() defer e.wgLock.Unlock()
e.mu.Lock()
closing := e.closing
e.mu.Unlock()
if closing {
return nil, errors.New("engine closing; no status")
}
if e.wgdev == nil { if e.wgdev == nil {
// RequestStatus was invoked before the wgengine has // RequestStatus was invoked before the wgengine has
// finished initializing. This can happen when wgengine // finished initializing. This can happen when wgengine
@ -553,6 +566,11 @@ func (e *userspaceEngine) RequestStatus() {
func (e *userspaceEngine) Close() { func (e *userspaceEngine) Close() {
e.mu.Lock() e.mu.Lock()
if e.closing {
e.mu.Unlock()
return
}
e.closing = true
for key, cancel := range e.pingers { for key, cancel := range e.pingers {
delete(e.pingers, key) delete(e.pingers, key)
cancel() cancel()
@ -614,8 +632,8 @@ func (e *userspaceEngine) SetNetInfoCallback(cb NetInfoCallback) {
e.magicConn.SetNetInfoCallback(cb) e.magicConn.SetNetInfoCallback(cb)
} }
func (e *userspaceEngine) SetDERPEnabled(v bool) { func (e *userspaceEngine) SetDERPMap(dm *tailcfg.DERPMap) {
e.magicConn.SetDERPEnabled(v) e.magicConn.SetDERPMap(dm)
} }
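The expected caller of the new engine method is the control-client plumbing that tracks lastDERPMap; a sketch of how a netmap update might reach the engine (the callback shape and field name are assumptions for illustration, not verbatim from this commit):

	func onNetMap(e wgengine.Engine, nm *controlclient.NetworkMap) {
		if nm.DERPMap != nil {
			e.SetDERPMap(nm.DERPMap)
		}
		// a literal e.SetDERPMap(nil) would disable DERP entirely
	}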
func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) {

View File

@ -12,6 +12,7 @@
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/ipn/ipnstate" "tailscale.com/ipn/ipnstate"
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter" "tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router" "tailscale.com/wgengine/router"
) )
@ -88,8 +89,8 @@ func (e *watchdogEngine) RequestStatus() {
func (e *watchdogEngine) LinkChange(isExpensive bool) { func (e *watchdogEngine) LinkChange(isExpensive bool) {
e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) }) e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) })
} }
func (e *watchdogEngine) SetDERPEnabled(v bool) { func (e *watchdogEngine) SetDERPMap(m *tailcfg.DERPMap) {
e.watchdog("SetDERPEnabled", func() { e.wrap.SetDERPEnabled(v) }) e.watchdog("SetDERPMap", func() { e.wrap.SetDERPMap(m) })
} }
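For readers new to this file, e.watchdog runs the wrapped call under a time limit so a wedged engine surfaces loudly instead of hanging. Conceptually something like this (an assumption about the wrapper, not code from this commit):

	func (e *watchdogEngine) watchdog(name string, fn func()) {
		done := make(chan struct{})
		go func() { defer close(done); fn() }()
		select {
		case <-done:
		case <-time.After(45 * time.Second):
			panic("wgengine: watchdog: " + name + " did not return")
		}
	}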
func (e *watchdogEngine) Close() { func (e *watchdogEngine) Close() {
e.watchdog("Close", e.wrap.Close) e.watchdog("Close", e.wrap.Close)

View File

@ -95,9 +95,10 @@ type Engine interface {
// action on. // action on.
LinkChange(isExpensive bool) LinkChange(isExpensive bool)
// SetDERPEnabled controls whether DERP is enabled. // SetDERPMap controls which (if any) DERP servers are used.
// It starts enabled by default. // If nil, DERP is disabled. It starts disabled until a DERP map
SetDERPEnabled(bool) // is configured.
SetDERPMap(*tailcfg.DERPMap)
// SetNetInfoCallback sets the function to call when a // SetNetInfoCallback sets the function to call when a
// new NetInfo summary is available. // new NetInfo summary is available.