stunner: support IPv6, add latency info to callbacks, use unique TxIDs per retry

And some more docs.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-02-27 09:23:20 -08:00
parent 00ad93ec25
commit c185e6b4b0
3 changed files with 111 additions and 71 deletions

View File

@ -5,11 +5,8 @@
package stunner package stunner
import ( import (
"bytes"
"context" "context"
"crypto/rand"
"fmt" "fmt"
"log"
"net" "net"
"strconv" "strconv"
"sync" "sync"
@ -29,79 +26,114 @@
// for the connection. (An endpoint may be reported multiple times if // for the connection. (An endpoint may be reported multiple times if
// multiple servers are provided.) // multiple servers are provided.)
type Stunner struct { type Stunner struct {
Send func([]byte, net.Addr) (int, error) // sends a packet // Send sends a packet.
Endpoint func(endpoint string) // reports an endpoint // It will typically be a PacketConn.WriteTo method value.
Servers []string // STUN servers to contact Send func([]byte, net.Addr) (int, error) // sends a packet
Resolver *net.Resolver
Logf func(format string, args ...interface{})
// Endpoint is called whenever a STUN response is received.
// The server is the STUN server that replied, endpoint is the ip:port
// from the STUN response, and d is the duration that the STUN request
// took on the wire (not including DNS lookup time.
Endpoint func(server, endpoint string, d time.Duration)
Servers []string // STUN servers to contact
// Resolver optionally specifies a resolver to use for DNS lookups.
// If nil, net.DefaultResolver is used.
Resolver *net.Resolver
// Logf optionally specifies a log function. If nil, logging is disabled.
Logf func(format string, args ...interface{})
// OnlyIPv6 controls whether IPv6 is exclusively used.
// If false, only IPv4 is used. There is currently no mixed mode.
OnlyIPv6 bool
// sessions tracks the state of each server.
// It's keyed by the STUN server (from the Servers field).
sessions map[string]*session sessions map[string]*session
mu sync.Mutex
inFlight map[stun.TxID]request
}
func (s *Stunner) addTX(tx stun.TxID, server string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.inFlight == nil {
s.inFlight = make(map[stun.TxID]request)
}
s.inFlight[tx] = request{sent: time.Now(), server: server}
}
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
s.mu.Lock()
defer s.mu.Unlock()
r, ok := s.inFlight[tx]
delete(s.inFlight, tx)
return r, ok
}
type request struct {
sent time.Time
server string
} }
type session struct { type session struct {
replied chan struct{} // closed when server responds ctx context.Context // closed via call to done when reply received
tIDs []stun.TxID // transaction IDs sent to a server cancel context.CancelFunc
}
func (s *Stunner) logf(format string, args ...interface{}) {
if s.Logf != nil {
s.Logf(format, args...)
}
} }
// Receive delivers a STUN packet to the stunner. // Receive delivers a STUN packet to the stunner.
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
if !stun.Is(p) { if !stun.Is(p) {
log.Println("stunner: received non-STUN packet") s.logf("stunner: received non-STUN packet")
return return
} }
now := time.Now()
responseTID, addr, port, err := stun.ParseResponse(p) tx, addr, port, err := stun.ParseResponse(p)
if err != nil { if err != nil {
log.Printf("stunner: received bad STUN response: %v", err) s.logf("stunner: received bad STUN response: %v", err)
return return
} }
r, ok := s.removeTX(tx)
// Accept any of the tIDs from any of the active sessions. if !ok {
for server, session := range s.sessions { s.logf("stunner: got STUN packet for unknown TxID %x", tx)
for _, tID := range session.tIDs { return
if bytes.Equal(tID[:], responseTID[:]) {
select {
case <-session.replied:
return // already got a reply from this server
default:
}
close(session.replied)
// TODO(crawshaw): use different endpoints returned from
// different STUN servers to detect NAT types.
portStr := fmt.Sprintf("%d", port)
host := net.JoinHostPort(net.IP(addr).String(), portStr)
if s.Logf != nil {
s.Logf("STUN server %s reports public endpoint %s", server, host)
}
s.Endpoint(host)
return
}
}
} }
log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID) d := now.Sub(r.sent)
session := s.sessions[r.server]
if session != nil {
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port))
s.logf("STUN server %s reports public endpoint %s after %v", r.server, host, d)
s.Endpoint(r.server, host, d)
session.cancel()
}
}
func (s *Stunner) resolver() *net.Resolver {
if s.Resolver != nil {
return s.Resolver
}
return net.DefaultResolver
} }
// Run starts a Stunner and blocks until all servers either respond // Run starts a Stunner and blocks until all servers either respond
// or are tried multiple times and timeout. // or are tried multiple times and timeout.
func (s *Stunner) Run(ctx context.Context) error { func (s *Stunner) Run(ctx context.Context) error {
if s.Resolver == nil { s.sessions = map[string]*session{}
s.Resolver = net.DefaultResolver
}
for _, server := range s.Servers { for _, server := range s.Servers {
// Generate the transaction IDs for this session. sctx, cancel := context.WithCancel(ctx)
tIDs := make([]stun.TxID, len(retryDurations))
for i := range tIDs {
if _, err := rand.Read(tIDs[i][:]); err != nil {
return fmt.Errorf("stunner: rand failed: %v", err)
}
}
if s.sessions == nil {
s.sessions = make(map[string]*session)
}
s.sessions[server] = &session{ s.sessions[server] = &session{
replied: make(chan struct{}), ctx: sctx,
tIDs: tIDs, cancel: cancel,
} }
} }
// after this point, the s.sessions map is read-only // after this point, the s.sessions map is read-only
@ -124,30 +156,26 @@ func (s *Stunner) runServer(ctx context.Context, server string) {
for i, d := range retryDurations { for i, d := range retryDurations {
ctx, cancel := context.WithTimeout(ctx, d) ctx, cancel := context.WithTimeout(ctx, d)
err := s.sendSTUN(ctx, session.tIDs[i], server) err := s.sendSTUN(ctx, server)
if err != nil { if err != nil {
if s.Logf != nil { s.logf("stunner: %s: %v", server, err)
s.Logf("stunner: %s: %v", server, err)
}
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
cancel() cancel()
case <-session.replied: case <-session.ctx.Done():
cancel() cancel()
if i > 0 && s.Logf != nil { if i > 0 {
s.Logf("stunner: slow STUN response from %s: %d retries", server, i) s.logf("stunner: slow STUN response from %s: %d retries", server, i)
} }
return return
} }
} }
if s.Logf != nil { s.logf("stunner: no STUN response from %s", server)
s.Logf("stunner: no STUN response from %s", server)
}
} }
func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error { func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
host, port, err := net.SplitHostPort(server) host, port, err := net.SplitHostPort(server)
if err != nil { if err != nil {
return err return err
@ -161,23 +189,35 @@ func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) er
} }
addr := &net.UDPAddr{Port: addrPort} addr := &net.UDPAddr{Port: addrPort}
ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host) ipAddrs, err := s.resolver().LookupIPAddr(ctx, host)
if err != nil { if err != nil {
return fmt.Errorf("lookup ip addr: %v", err) return fmt.Errorf("lookup ip addr: %v", err)
} }
for _, ipAddr := range ipAddrs { for _, ipAddr := range ipAddrs {
if ip4 := ipAddr.IP.To4(); ip4 != nil { ip4 := ipAddr.IP.To4()
if ip4 != nil {
if s.OnlyIPv6 {
continue
}
addr.IP = ip4 addr.IP = ip4
addr.Zone = ipAddr.Zone
break break
} else if s.OnlyIPv6 {
addr.IP = ipAddr.IP
addr.Zone = ipAddr.Zone
} }
} }
if addr.IP == nil { if addr.IP == nil {
if s.OnlyIPv6 {
return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
}
return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
} }
req := stun.Request(tID) txID := stun.NewTxID()
if _, err := s.Send(req, addr); err != nil { req := stun.Request(txID)
s.addTX(txID, server)
_, err = s.Send(req, addr)
if err != nil {
return fmt.Errorf("send: %v", err) return fmt.Errorf("send: %v", err)
} }
return nil return nil

View File

@ -40,7 +40,7 @@ func TestStun(t *testing.T) {
s := &Stunner{ s := &Stunner{
Send: localConn.WriteTo, Send: localConn.WriteTo,
Endpoint: func(ep string) { epCh <- ep }, Endpoint: func(server, ep string, d time.Duration) { epCh <- ep },
Servers: stunServers, Servers: stunServers,
} }

View File

@ -232,7 +232,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) {
s := &stunner.Stunner{ s := &stunner.Stunner{
Send: c.pconn.WriteTo, Send: c.pconn.WriteTo,
Endpoint: func(s string) { addAddr(s, "stun") }, Endpoint: func(server, endpoint string, d time.Duration) { addAddr(endpoint, "stun") },
Servers: c.stunServers, Servers: c.stunServers,
Logf: c.logf, Logf: c.logf,
} }