mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-08 09:07:44 +00:00
2cff9016e4
I started to write a full DNS caching resolver and I realized it was overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm doing this tiny one instead for now, just for all our netcheck STUN derp lookups, and connections to DERP servers. (This will be caching exactly 8 DNS entries, all ours.) Fixes #145 (can be better later, of course)
247 lines
6.0 KiB
Go
247 lines
6.0 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
|
|
|
package stunner
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"tailscale.com/net/dnscache"
|
|
"tailscale.com/stun"
|
|
)
|
|
|
|
// Stunner sends a STUN request to several servers and handles a response.
|
|
//
|
|
// It is designed to used on a connection owned by other code and so does
|
|
// not directly reference a net.Conn of any sort. Instead, the user should
|
|
// provide Send function to send packets, and call Receive when a new
|
|
// STUN response is received.
|
|
//
|
|
// In response, a Stunner will call Endpoint with any endpoints determined
|
|
// for the connection. (An endpoint may be reported multiple times if
|
|
// multiple servers are provided.)
|
|
type Stunner struct {
|
|
// Send sends a packet.
|
|
// It will typically be a PacketConn.WriteTo method value.
|
|
Send func([]byte, net.Addr) (int, error) // sends a packet
|
|
|
|
// Endpoint is called whenever a STUN response is received.
|
|
// The server is the STUN server that replied, endpoint is the ip:port
|
|
// from the STUN response, and d is the duration that the STUN request
|
|
// took on the wire (not including DNS lookup time.
|
|
Endpoint func(server, endpoint string, d time.Duration)
|
|
|
|
Servers []string // STUN servers to contact
|
|
|
|
// DNSCache optionally specifies a DNSCache to use.
|
|
// If nil, a DNS cache is not used.
|
|
DNSCache *dnscache.Resolver
|
|
|
|
// Logf optionally specifies a log function. If nil, logging is disabled.
|
|
Logf func(format string, args ...interface{})
|
|
|
|
// OnlyIPv6 controls whether IPv6 is exclusively used.
|
|
// If false, only IPv4 is used. There is currently no mixed mode.
|
|
OnlyIPv6 bool
|
|
|
|
// sessions tracks the state of each server.
|
|
// It's keyed by the STUN server (from the Servers field).
|
|
sessions map[string]*session
|
|
|
|
mu sync.Mutex
|
|
inFlight map[stun.TxID]request
|
|
}
|
|
|
|
func (s *Stunner) addTX(tx stun.TxID, server string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.inFlight == nil {
|
|
s.inFlight = make(map[stun.TxID]request)
|
|
}
|
|
s.inFlight[tx] = request{sent: time.Now(), server: server}
|
|
}
|
|
|
|
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
r, ok := s.inFlight[tx]
|
|
delete(s.inFlight, tx)
|
|
return r, ok
|
|
}
|
|
|
|
// request describes one STUN request that has been handed to Send and is
// awaiting a response.
type request struct {
	// sent is the time at which the request was recorded as in flight.
	sent time.Time

	// server is the STUN server the request was addressed to, as spelled
	// in Stunner.Servers.
	server string
}
|
|
|
|
// session is the per-server state used to signal that a STUN server has
// replied.
type session struct {
	// ctx is done once cancel has been called, which Receive does when a
	// reply from the server arrives.
	ctx context.Context

	// cancel cancels ctx.
	cancel context.CancelFunc
}
|
|
|
|
func (s *Stunner) logf(format string, args ...interface{}) {
|
|
if s.Logf != nil {
|
|
s.Logf(format, args...)
|
|
}
|
|
}
|
|
|
|
// Receive delivers a STUN packet to the stunner.
|
|
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
|
|
if !stun.Is(p) {
|
|
s.logf("stunner: received non-STUN packet")
|
|
return
|
|
}
|
|
now := time.Now()
|
|
tx, addr, port, err := stun.ParseResponse(p)
|
|
if err != nil {
|
|
s.logf("stunner: received bad STUN response: %v", err)
|
|
return
|
|
}
|
|
r, ok := s.removeTX(tx)
|
|
if !ok {
|
|
s.logf("stunner: got STUN packet for unknown TxID %x", tx)
|
|
return
|
|
}
|
|
d := now.Sub(r.sent)
|
|
|
|
session := s.sessions[r.server]
|
|
if session != nil {
|
|
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port))
|
|
s.Endpoint(r.server, host, d)
|
|
session.cancel()
|
|
}
|
|
}
|
|
|
|
func (s *Stunner) resolver() *net.Resolver {
|
|
return net.DefaultResolver
|
|
}
|
|
|
|
// Run starts a Stunner and blocks until all servers either respond
|
|
// or are tried multiple times and timeout.
|
|
//
|
|
// TODO: this always returns success now. It should return errors
|
|
// if certain servers are unavailable probably. Or if all are.
|
|
// Or some configured threshold are.
|
|
func (s *Stunner) Run(ctx context.Context) error {
|
|
s.sessions = map[string]*session{}
|
|
for _, server := range s.Servers {
|
|
sctx, cancel := context.WithCancel(ctx)
|
|
s.sessions[server] = &session{
|
|
ctx: sctx,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
// after this point, the s.sessions map is read-only
|
|
|
|
var wg sync.WaitGroup
|
|
for _, server := range s.Servers {
|
|
wg.Add(1)
|
|
go func(server string) {
|
|
defer wg.Done()
|
|
s.runServer(ctx, server)
|
|
}(server)
|
|
}
|
|
wg.Wait()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Stunner) runServer(ctx context.Context, server string) {
|
|
session := s.sessions[server]
|
|
|
|
for i, d := range retryDurations {
|
|
ctx, cancel := context.WithTimeout(ctx, d)
|
|
err := s.sendSTUN(ctx, server)
|
|
if err != nil {
|
|
s.logf("stunner: %s: %v", server, err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
cancel()
|
|
case <-session.ctx.Done():
|
|
cancel()
|
|
if i > 0 {
|
|
s.logf("stunner: slow STUN response from %s: %d retries", server, i)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
s.logf("stunner: no STUN response from %s", server)
|
|
}
|
|
|
|
func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
|
|
host, port, err := net.SplitHostPort(server)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
addrPort, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
return fmt.Errorf("port: %v", err)
|
|
}
|
|
if addrPort == 0 {
|
|
addrPort = 3478
|
|
}
|
|
addr := &net.UDPAddr{Port: addrPort}
|
|
|
|
var ipAddrs []net.IPAddr
|
|
if s.DNSCache != nil {
|
|
ip, err := s.DNSCache.LookupIP(ctx, host)
|
|
if err != nil {
|
|
return fmt.Errorf("lookup ip addr: %v", err)
|
|
}
|
|
ipAddrs = []net.IPAddr{{IP: ip}}
|
|
} else {
|
|
ipAddrs, err = s.resolver().LookupIPAddr(ctx, host)
|
|
if err != nil {
|
|
return fmt.Errorf("lookup ip addr: %v", err)
|
|
}
|
|
}
|
|
for _, ipAddr := range ipAddrs {
|
|
ip4 := ipAddr.IP.To4()
|
|
if ip4 != nil {
|
|
if s.OnlyIPv6 {
|
|
continue
|
|
}
|
|
addr.IP = ip4
|
|
break
|
|
} else if s.OnlyIPv6 {
|
|
addr.IP = ipAddr.IP
|
|
addr.Zone = ipAddr.Zone
|
|
}
|
|
}
|
|
if addr.IP == nil {
|
|
if s.OnlyIPv6 {
|
|
return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
|
|
txID := stun.NewTxID()
|
|
req := stun.Request(txID)
|
|
s.addTX(txID, server)
|
|
_, err = s.Send(req, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("send: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// retryDurations lists, in order, how long runServer waits for a reply after
// each send before sending again. The number of entries is the maximum number
// of requests sent to one server.
var retryDurations = []time.Duration{
	100 * time.Millisecond,
	100 * time.Millisecond,
	100 * time.Millisecond,
	200 * time.Millisecond,
	200 * time.Millisecond,
	400 * time.Millisecond,
	800 * time.Millisecond,
	1600 * time.Millisecond,
	3200 * time.Millisecond,
}
|