mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-08 09:07:44 +00:00
01b4bec33f
It used to make assumptions based on having Anycast IPs that are super near. Now we're intentionally going to a bunch of different distant IPs to measure latency. Also, optimize how the hairpin detection works. No need to STUN on that socket. Just use that separate socket for sending, once we know the other UDP4 socket's endpoint. The trick is: make our test probe also a STUN packet, so it fits through magicsock's existing STUN routing. This drops netcheck from ~5 seconds to ~250-500ms. Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
288 lines
7.0 KiB
Go
288 lines
7.0 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package stunner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"tailscale.com/net/dnscache"
|
|
"tailscale.com/stun"
|
|
)
|
|
|
|
// Stunner sends a STUN request to several servers and handles a response.
|
|
//
|
|
// It is designed to used on a connection owned by other code and so does
|
|
// not directly reference a net.Conn of any sort. Instead, the user should
|
|
// provide Send function to send packets, and call Receive when a new
|
|
// STUN response is received.
|
|
//
|
|
// In response, a Stunner will call Endpoint with any endpoints determined
|
|
// for the connection. (An endpoint may be reported multiple times if
|
|
// multiple servers are provided.)
|
|
type Stunner struct {
|
|
// Send sends a packet.
|
|
// It will typically be a PacketConn.WriteTo method value.
|
|
Send func([]byte, net.Addr) (int, error) // sends a packet
|
|
|
|
// Endpoint is called whenever a STUN response is received.
|
|
// The server is the STUN server that replied, endpoint is the ip:port
|
|
// from the STUN response, and d is the duration that the STUN request
|
|
// took on the wire (not including DNS lookup time.
|
|
Endpoint func(server, endpoint string, d time.Duration)
|
|
|
|
// onPacket is the internal version of Endpoint that does de-dup.
|
|
// It's set by Run.
|
|
onPacket func(server, endpoint string, d time.Duration)
|
|
|
|
Servers []string // STUN servers to contact
|
|
|
|
// DNSCache optionally specifies a DNSCache to use.
|
|
// If nil, a DNS cache is not used.
|
|
DNSCache *dnscache.Resolver
|
|
|
|
// Logf optionally specifies a log function. If nil, logging is disabled.
|
|
Logf func(format string, args ...interface{})
|
|
|
|
// OnlyIPv6 controls whether IPv6 is exclusively used.
|
|
// If false, only IPv4 is used. There is currently no mixed mode.
|
|
OnlyIPv6 bool
|
|
|
|
mu sync.Mutex
|
|
inFlight map[stun.TxID]request
|
|
}
|
|
|
|
func (s *Stunner) addTX(tx stun.TxID, server string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, dup := s.inFlight[tx]; dup {
|
|
panic("unexpected duplicate STUN TransactionID")
|
|
}
|
|
s.inFlight[tx] = request{sent: time.Now(), server: server}
|
|
}
|
|
|
|
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.inFlight == nil {
|
|
return request{}, false
|
|
}
|
|
r, ok := s.inFlight[tx]
|
|
if ok {
|
|
delete(s.inFlight, tx)
|
|
} else {
|
|
s.logf("stunner: got STUN packet for unknown TxID %x", tx)
|
|
}
|
|
return r, ok
|
|
}
|
|
|
|
// request describes one outstanding STUN request, stored in
// Stunner.inFlight under its transaction ID.
type request struct {
	sent   time.Time // when addTX recorded the request, just before it was sent
	server string    // the STUN server ("host:port") the request was sent to
}
|
|
|
|
func (s *Stunner) logf(format string, args ...interface{}) {
|
|
if s.Logf != nil {
|
|
s.Logf(format, args...)
|
|
}
|
|
}
|
|
|
|
// Receive delivers a STUN packet to the stunner.
|
|
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
|
|
if !stun.Is(p) {
|
|
s.logf("[unexpected] stunner: received non-STUN packet")
|
|
return
|
|
}
|
|
now := time.Now()
|
|
tx, addr, port, err := stun.ParseResponse(p)
|
|
if err != nil {
|
|
s.logf("stunner: received bad STUN response: %v", err)
|
|
return
|
|
}
|
|
r, ok := s.removeTX(tx)
|
|
if !ok {
|
|
return
|
|
}
|
|
d := now.Sub(r.sent)
|
|
|
|
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port))
|
|
s.onPacket(r.server, host, d)
|
|
}
|
|
|
|
func (s *Stunner) resolver() *net.Resolver {
|
|
return net.DefaultResolver
|
|
}
|
|
|
|
// cleanUpPostRun zeros out some fields, mostly for debugging (so
|
|
// things crash or race+fail if there's a sender still running.)
|
|
func (s *Stunner) cleanUpPostRun() {
|
|
s.mu.Lock()
|
|
s.inFlight = nil
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// Run starts a Stunner and blocks until all servers either respond
|
|
// or are tried multiple times and timeout.
|
|
// It can not be called concurrently with itself.
|
|
func (s *Stunner) Run(ctx context.Context) error {
|
|
for _, server := range s.Servers {
|
|
if _, _, err := net.SplitHostPort(server); err != nil {
|
|
return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers)
|
|
}
|
|
}
|
|
if len(s.Servers) == 0 {
|
|
return errors.New("stunner: no Servers")
|
|
}
|
|
|
|
s.inFlight = make(map[stun.TxID]request)
|
|
defer s.cleanUpPostRun()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
type sender struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
var (
|
|
needMu sync.Mutex
|
|
need = make(map[string]sender) // keyed by server; deleted when done
|
|
allDone = make(chan struct{}) // closed when need is empty
|
|
)
|
|
s.onPacket = func(server, endpoint string, d time.Duration) {
|
|
needMu.Lock()
|
|
defer needMu.Unlock()
|
|
sender, ok := need[server]
|
|
if !ok {
|
|
return
|
|
}
|
|
sender.cancel()
|
|
delete(need, server)
|
|
s.Endpoint(server, endpoint, d)
|
|
if len(need) == 0 {
|
|
close(allDone)
|
|
}
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, server := range s.Servers {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
need[server] = sender{ctx, cancel}
|
|
}
|
|
for server, sender := range need {
|
|
wg.Add(1)
|
|
server, ctx := server, sender.ctx
|
|
go func() {
|
|
defer wg.Done()
|
|
s.sendPackets(ctx, server)
|
|
}()
|
|
}
|
|
var err error
|
|
select {
|
|
case <-ctx.Done():
|
|
err = ctx.Err()
|
|
case <-allDone:
|
|
cancel()
|
|
}
|
|
wg.Wait()
|
|
|
|
var missing []string
|
|
needMu.Lock()
|
|
for server := range need {
|
|
missing = append(missing, server)
|
|
}
|
|
needMu.Unlock()
|
|
|
|
if len(missing) == 0 || err == nil {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("got STUN error: %v; missing replies from: %v", err, strings.Join(missing, ", "))
|
|
}
|
|
|
|
func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) {
|
|
hostStr, portStr, err := net.SplitHostPort(server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
addrPort, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("port: %v", err)
|
|
}
|
|
if addrPort == 0 {
|
|
addrPort = 3478
|
|
}
|
|
addr := &net.UDPAddr{Port: addrPort}
|
|
|
|
var ipAddrs []net.IPAddr
|
|
if s.DNSCache != nil {
|
|
ip, err := s.DNSCache.LookupIP(ctx, hostStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ipAddrs = []net.IPAddr{{IP: ip}}
|
|
} else {
|
|
ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err)
|
|
}
|
|
}
|
|
|
|
for _, ipAddr := range ipAddrs {
|
|
ip4 := ipAddr.IP.To4()
|
|
if ip4 != nil {
|
|
if s.OnlyIPv6 {
|
|
continue
|
|
}
|
|
addr.IP = ip4
|
|
break
|
|
} else if s.OnlyIPv6 {
|
|
addr.IP = ipAddr.IP
|
|
addr.Zone = ipAddr.Zone
|
|
}
|
|
}
|
|
if addr.IP == nil {
|
|
if s.OnlyIPv6 {
|
|
return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
return addr, nil
|
|
}
|
|
|
|
func (s *Stunner) sendPackets(ctx context.Context, server string) error {
|
|
addr, err := s.serverAddr(ctx, server)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
const maxSend = 2
|
|
for i := 0; i < maxSend; i++ {
|
|
txID := stun.NewTxID()
|
|
req := stun.Request(txID)
|
|
s.addTX(txID, server)
|
|
_, err = s.Send(req, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("send: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
// Ignore error. The caller deals with handling contexts.
|
|
// We only use it to dermine when to stop spraying STUN packets.
|
|
return nil
|
|
case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))):
|
|
}
|
|
}
|
|
return nil
|
|
}
|