tailscale/stunner/stunner.go
Brad Fitzpatrick 01b4bec33f stunner: re-do how Stunner works
It used to make assumptions based on having Anycast IPs that are super
near. Now we're intentionally going to a bunch of different distant
IPs to measure latency.

Also, optimize how the hairpin detection works. No need to STUN on
that socket. Just use that separate socket for sending, once we know
the other UDP4 socket's endpoint. The trick is: make our test probe
also a STUN packet, so it fits through magicsock's existing STUN
routing.

This drops netcheck from ~5 seconds to ~250-500ms.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
2020-03-11 08:08:48 -07:00

288 lines
7.0 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stunner
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"tailscale.com/net/dnscache"
"tailscale.com/stun"
)
// Stunner sends a STUN request to several servers and handles a response.
//
// It is designed to be used on a connection owned by other code and so
// does not directly reference a net.Conn of any sort. Instead, the user
// should provide a Send function to send packets, and call Receive when
// a new STUN response is received.
//
// In response, a Stunner will call Endpoint with any endpoints determined
// for the connection. (An endpoint may be reported multiple times if
// multiple servers are provided.)
type Stunner struct {
	// Send transmits the packet b to addr.
	// It will typically be a PacketConn.WriteTo method value.
	Send func(b []byte, addr net.Addr) (n int, err error)

	// Endpoint is called whenever a STUN response is received.
	// server is the STUN server that replied, endpoint is the ip:port
	// carried in the STUN response, and d is how long the STUN request
	// took on the wire (not including DNS lookup time).
	Endpoint func(server, endpoint string, d time.Duration)

	// onPacket is the internal version of Endpoint that does de-dup.
	// It is set by Run.
	onPacket func(server, endpoint string, d time.Duration)

	// Servers lists the STUN servers to contact, each in host:port form.
	Servers []string

	// DNSCache optionally specifies a DNSCache to use.
	// If nil, a DNS cache is not used.
	DNSCache *dnscache.Resolver

	// Logf optionally specifies a log function.
	// If nil, logging is disabled.
	Logf func(format string, args ...interface{})

	// OnlyIPv6 controls whether IPv6 is exclusively used.
	// If false, only IPv4 is used. There is currently no mixed mode.
	OnlyIPv6 bool

	// mu is held by addTX, removeTX and cleanUpPostRun while they
	// touch inFlight.
	mu sync.Mutex
	// inFlight tracks requests that have been sent but not yet
	// answered, keyed by STUN transaction ID. It is created by Run and
	// set back to nil when Run returns.
	inFlight map[stun.TxID]request
}
// addTX registers tx as an in-flight transaction addressed to server,
// timestamped with the current time.
//
// It panics if tx is already registered: transaction IDs are expected to
// be unique.
func (s *Stunner) addTX(tx stun.TxID, server string) {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, exists := s.inFlight[tx]
	if exists {
		panic("unexpected duplicate STUN TransactionID")
	}
	s.inFlight[tx] = request{
		sent:   time.Now(),
		server: server,
	}
}
// removeTX forgets the in-flight transaction tx and returns the request
// that was recorded for it.
//
// The boolean result is false when tx is not being tracked; the returned
// request is then the zero value. An unknown tx is logged, unless the
// in-flight table is nil (no Run in progress), in which case removeTX
// returns silently.
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.inFlight == nil {
		return request{}, false
	}

	req, found := s.inFlight[tx]
	if !found {
		s.logf("stunner: got STUN packet for unknown TxID %x", tx)
		return req, false
	}
	delete(s.inFlight, tx)
	return req, true
}
// request is the bookkeeping kept for one in-flight STUN transaction.
type request struct {
	// sent is when the transaction was registered, just before the
	// packet was handed to Send.
	sent time.Time
	// server is the STUN server, as listed in Stunner.Servers, that the
	// request was addressed to.
	server string
}
// logf forwards a formatted message to s.Logf. It is a no-op when no log
// function has been configured.
func (s *Stunner) logf(format string, args ...interface{}) {
	if s.Logf == nil {
		return
	}
	s.Logf(format, args...)
}
// Receive delivers a STUN packet to the stunner.
//
// Packets that are not STUN, that fail to parse as a STUN response, or
// whose transaction ID is not in flight are dropped (the first two cases
// are logged here). Otherwise the reply is matched to its request and
// handed to onPacket together with the measured round-trip time.
//
// fromAddr is currently unused.
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
	if !stun.Is(p) {
		s.logf("[unexpected] stunner: received non-STUN packet")
		return
	}

	// Take the timestamp before parsing so that parse time is not
	// counted in the reported duration.
	received := time.Now()

	txID, ip, port, err := stun.ParseResponse(p)
	if err != nil {
		s.logf("stunner: received bad STUN response: %v", err)
		return
	}

	req, known := s.removeTX(txID)
	if !known {
		return
	}

	elapsed := received.Sub(req.sent)
	endpoint := net.JoinHostPort(net.IP(ip).String(), fmt.Sprint(port))
	s.onPacket(req.server, endpoint, elapsed)
}
// resolver returns the DNS resolver to use for lookups that do not go
// through DNSCache. It is always net.DefaultResolver.
func (s *Stunner) resolver() *net.Resolver {
	r := net.DefaultResolver
	return r
}
// cleanUpPostRun drops the in-flight transaction table once Run is
// finished. This is mostly a debugging aid: a sender that is still
// running afterwards will crash or trip the race detector instead of
// silently carrying on.
func (s *Stunner) cleanUpPostRun() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.inFlight = nil
}
// Run starts a Stunner and blocks until every server has replied or ctx
// is done, whichever happens first. Each server is sent a small, bounded
// number of requests (see sendPackets).
//
// Endpoint is invoked at most once per distinct server, for the first
// reply matched to that server.
//
// Run returns an error if Servers is empty or contains an entry that is
// not in host:port form. Otherwise it returns nil when all servers
// replied, or an error wrapping ctx.Err() that names the servers that
// did not.
//
// It can not be called concurrently with itself.
func (s *Stunner) Run(ctx context.Context) error {
	for _, server := range s.Servers {
		if _, _, err := net.SplitHostPort(server); err != nil {
			return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers)
		}
	}
	if len(s.Servers) == 0 {
		return errors.New("stunner: no Servers")
	}

	// Take mu like every other inFlight accessor does, so a straggling
	// Receive cannot observe the table mid-assignment.
	s.mu.Lock()
	s.inFlight = make(map[stun.TxID]request)
	s.mu.Unlock()
	defer s.cleanUpPostRun()

	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	type sender struct {
		ctx    context.Context
		cancel context.CancelFunc
	}
	var (
		needMu  sync.Mutex
		need    = make(map[string]sender) // keyed by server; deleted when done
		allDone = make(chan struct{})     // closed when need is empty
	)

	s.onPacket = func(server, endpoint string, d time.Duration) {
		needMu.Lock()
		defer needMu.Unlock()
		sender, ok := need[server]
		if !ok {
			// Duplicate reply, or a reply for a server already handled.
			return
		}
		sender.cancel() // stop sending further requests to this server
		delete(need, server)
		s.Endpoint(server, endpoint, d)
		if len(need) == 0 {
			close(allDone)
		}
	}

	var wg sync.WaitGroup
	for _, server := range s.Servers {
		// The per-server cancel funcs are deliberately deferred to
		// function exit; the loop is bounded by len(s.Servers).
		ctx, cancel := context.WithCancel(ctx)
		defer cancel()
		need[server] = sender{ctx, cancel}
	}

	// Hold needMu while iterating: as soon as the first sender goroutine
	// has transmitted, a reply can reach onPacket on another goroutine
	// and delete from need. Iterating the map unlocked would be a
	// concurrent map read and write.
	needMu.Lock()
	for server, sender := range need {
		wg.Add(1)
		server, ctx := server, sender.ctx
		go func() {
			defer wg.Done()
			if err := s.sendPackets(ctx, server); err != nil {
				// A failing server simply ends up in the "missing"
				// list below; log why so the failure is not silent.
				s.logf("stunner: sendPackets(%q): %v", server, err)
			}
		}()
	}
	needMu.Unlock()

	var err error
	select {
	case <-ctx.Done():
		err = ctx.Err()
	case <-allDone:
		cancel()
	}
	wg.Wait()

	var missing []string
	needMu.Lock()
	for server := range need {
		missing = append(missing, server)
	}
	needMu.Unlock()

	if len(missing) == 0 || err == nil {
		return nil
	}
	return fmt.Errorf("got STUN error: %w; missing replies from: %v", err, strings.Join(missing, ", "))
}
// serverAddr resolves server, a host:port string, to the UDP address
// that STUN requests should be sent to.
//
// A port of 0 selects the default STUN port, 3478. The host is resolved
// through s.DNSCache when one is configured and through s.resolver()
// otherwise. The first resolved address of the wanted family (IPv6 if
// s.OnlyIPv6 is set, IPv4 otherwise) is used.
//
// It returns an error if server is malformed, the port is not a valid
// UDP port, the lookup fails, or no address of the wanted family was
// found.
func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) {
	hostStr, portStr, err := net.SplitHostPort(server)
	if err != nil {
		return nil, err
	}
	addrPort, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("port: %w", err)
	}
	if addrPort < 0 || addrPort > 65535 {
		// Reject it here rather than failing later inside Send.
		return nil, fmt.Errorf("port: %d out of range", addrPort)
	}
	if addrPort == 0 {
		addrPort = 3478
	}
	addr := &net.UDPAddr{Port: addrPort}

	var ipAddrs []net.IPAddr
	if s.DNSCache != nil {
		ip, err := s.DNSCache.LookupIP(ctx, hostStr)
		if err != nil {
			return nil, err
		}
		ipAddrs = []net.IPAddr{{IP: ip}}
	} else {
		ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr)
		if err != nil {
			return nil, fmt.Errorf("lookup ip addr (%q): %w", hostStr, err)
		}
	}

	for _, ipAddr := range ipAddrs {
		ip4 := ipAddr.IP.To4()
		if s.OnlyIPv6 {
			if ip4 != nil {
				continue
			}
			addr.IP = ipAddr.IP
			addr.Zone = ipAddr.Zone
			// Stop at the first match, as the IPv4 case does, so the
			// resolver's preference order is honored.
			break
		}
		if ip4 != nil {
			addr.IP = ip4
			break
		}
	}

	if addr.IP == nil {
		if s.OnlyIPv6 {
			return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
		}
		return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
	}
	return addr, nil
}
// sendPackets resolves server and sends it up to maxSend STUN requests,
// pausing a randomized 50-249ms after each one. Every request gets a
// fresh transaction ID, registered with addTX before it is sent.
//
// It returns early with nil once ctx is done, and with a non-nil error
// if the server address cannot be resolved or Send fails.
func (s *Stunner) sendPackets(ctx context.Context, server string) error {
	dst, err := s.serverAddr(ctx, server)
	if err != nil {
		return err
	}

	const maxSend = 2
	for attempt := 0; attempt < maxSend; attempt++ {
		tx := stun.NewTxID()
		pkt := stun.Request(tx)
		s.addTX(tx, server)
		if _, err = s.Send(pkt, dst); err != nil {
			return fmt.Errorf("send: %v", err)
		}

		// Jitter the gap between attempts.
		pause := time.Millisecond * time.Duration(50+rand.Intn(200))
		select {
		case <-time.After(pause):
		case <-ctx.Done():
			// Ignore the error. The caller deals with handling contexts;
			// here the context is only used to determine when to stop
			// spraying STUN packets.
			return nil
		}
	}
	return nil
}