diff --git a/stunner/stunner.go b/stunner/stunner.go index 6c23af38e..ab1c6a908 100644 --- a/stunner/stunner.go +++ b/stunner/stunner.go @@ -5,11 +5,8 @@ package stunner import ( - "bytes" "context" - "crypto/rand" "fmt" - "log" "net" "strconv" "sync" @@ -29,79 +26,114 @@ // for the connection. (An endpoint may be reported multiple times if // multiple servers are provided.) type Stunner struct { - Send func([]byte, net.Addr) (int, error) // sends a packet - Endpoint func(endpoint string) // reports an endpoint - Servers []string // STUN servers to contact - Resolver *net.Resolver - Logf func(format string, args ...interface{}) + // Send sends a packet. + // It will typically be a PacketConn.WriteTo method value. + Send func([]byte, net.Addr) (int, error) // sends a packet + // Endpoint is called whenever a STUN response is received. + // The server is the STUN server that replied, endpoint is the ip:port + // from the STUN response, and d is the duration that the STUN request + // took on the wire (not including DNS lookup time. + Endpoint func(server, endpoint string, d time.Duration) + + Servers []string // STUN servers to contact + + // Resolver optionally specifies a resolver to use for DNS lookups. + // If nil, net.DefaultResolver is used. + Resolver *net.Resolver + + // Logf optionally specifies a log function. If nil, logging is disabled. + Logf func(format string, args ...interface{}) + + // OnlyIPv6 controls whether IPv6 is exclusively used. + // If false, only IPv4 is used. There is currently no mixed mode. + OnlyIPv6 bool + + // sessions tracks the state of each server. + // It's keyed by the STUN server (from the Servers field). sessions map[string]*session + + mu sync.Mutex + inFlight map[stun.TxID]request +} + +func (s *Stunner) addTX(tx stun.TxID, server string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.inFlight == nil { + s.inFlight = make(map[stun.TxID]request) + } + s.inFlight[tx] = request{sent: time.Now(), server: server} +} + +func (s *Stunner) removeTX(tx stun.TxID) (request, bool) { + s.mu.Lock() + defer s.mu.Unlock() + r, ok := s.inFlight[tx] + delete(s.inFlight, tx) + return r, ok +} + +type request struct { + sent time.Time + server string } type session struct { - replied chan struct{} // closed when server responds - tIDs []stun.TxID // transaction IDs sent to a server + ctx context.Context // closed via call to done when reply received + cancel context.CancelFunc +} + +func (s *Stunner) logf(format string, args ...interface{}) { + if s.Logf != nil { + s.Logf(format, args...) + } } // Receive delivers a STUN packet to the stunner. func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { if !stun.Is(p) { - log.Println("stunner: received non-STUN packet") + s.logf("stunner: received non-STUN packet") return } - - responseTID, addr, port, err := stun.ParseResponse(p) + now := time.Now() + tx, addr, port, err := stun.ParseResponse(p) if err != nil { - log.Printf("stunner: received bad STUN response: %v", err) + s.logf("stunner: received bad STUN response: %v", err) return } - - // Accept any of the tIDs from any of the active sessions. - for server, session := range s.sessions { - for _, tID := range session.tIDs { - if bytes.Equal(tID[:], responseTID[:]) { - select { - case <-session.replied: - return // already got a reply from this server - default: - } - close(session.replied) - - // TODO(crawshaw): use different endpoints returned from - // different STUN servers to detect NAT types. - portStr := fmt.Sprintf("%d", port) - host := net.JoinHostPort(net.IP(addr).String(), portStr) - if s.Logf != nil { - s.Logf("STUN server %s reports public endpoint %s", server, host) - } - s.Endpoint(host) - return - } - } + r, ok := s.removeTX(tx) + if !ok { + s.logf("stunner: got STUN packet for unknown TxID %x", tx) + return } - log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID) + d := now.Sub(r.sent) + + session := s.sessions[r.server] + if session != nil { + host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) + s.logf("STUN server %s reports public endpoint %s after %v", r.server, host, d) + s.Endpoint(r.server, host, d) + session.cancel() + } +} + +func (s *Stunner) resolver() *net.Resolver { + if s.Resolver != nil { + return s.Resolver + } + return net.DefaultResolver } // Run starts a Stunner and blocks until all servers either respond // or are tried multiple times and timeout. func (s *Stunner) Run(ctx context.Context) error { - if s.Resolver == nil { - s.Resolver = net.DefaultResolver - } + s.sessions = map[string]*session{} for _, server := range s.Servers { - // Generate the transaction IDs for this session. - tIDs := make([]stun.TxID, len(retryDurations)) - for i := range tIDs { - if _, err := rand.Read(tIDs[i][:]); err != nil { - return fmt.Errorf("stunner: rand failed: %v", err) - } - } - if s.sessions == nil { - s.sessions = make(map[string]*session) - } + sctx, cancel := context.WithCancel(ctx) s.sessions[server] = &session{ - replied: make(chan struct{}), - tIDs: tIDs, + ctx: sctx, + cancel: cancel, } } // after this point, the s.sessions map is read-only @@ -124,30 +156,26 @@ func (s *Stunner) runServer(ctx context.Context, server string) { for i, d := range retryDurations { ctx, cancel := context.WithTimeout(ctx, d) - err := s.sendSTUN(ctx, session.tIDs[i], server) + err := s.sendSTUN(ctx, server) if err != nil { - if s.Logf != nil { - s.Logf("stunner: %s: %v", server, err) - } + s.logf("stunner: %s: %v", server, err) } select { case <-ctx.Done(): cancel() - case <-session.replied: + case <-session.ctx.Done(): cancel() - if i > 0 && s.Logf != nil { - s.Logf("stunner: slow STUN response from %s: %d retries", server, i) + if i > 0 { + s.logf("stunner: slow STUN response from %s: %d retries", server, i) } return } } - if s.Logf != nil { - s.Logf("stunner: no STUN response from %s", server) - } + s.logf("stunner: no STUN response from %s", server) } -func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error { +func (s *Stunner) sendSTUN(ctx context.Context, server string) error { host, port, err := net.SplitHostPort(server) if err != nil { return err @@ -161,23 +189,35 @@ func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) er } addr := &net.UDPAddr{Port: addrPort} - ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host) + ipAddrs, err := s.resolver().LookupIPAddr(ctx, host) if err != nil { return fmt.Errorf("lookup ip addr: %v", err) } for _, ipAddr := range ipAddrs { - if ip4 := ipAddr.IP.To4(); ip4 != nil { + ip4 := ipAddr.IP.To4() + if ip4 != nil { + if s.OnlyIPv6 { + continue + } addr.IP = ip4 - addr.Zone = ipAddr.Zone break + } else if s.OnlyIPv6 { + addr.IP = ipAddr.IP + addr.Zone = ipAddr.Zone } } if addr.IP == nil { + if s.OnlyIPv6 { + return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs) + } return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) } - req := stun.Request(tID) - if _, err := s.Send(req, addr); err != nil { + txID := stun.NewTxID() + req := stun.Request(txID) + s.addTX(txID, server) + _, err = s.Send(req, addr) + if err != nil { return fmt.Errorf("send: %v", err) } return nil diff --git a/stunner/stunner_test.go b/stunner/stunner_test.go index 839104d71..9f8e4d5a7 100644 --- a/stunner/stunner_test.go +++ b/stunner/stunner_test.go @@ -40,7 +40,7 @@ func TestStun(t *testing.T) { s := &Stunner{ Send: localConn.WriteTo, - Endpoint: func(ep string) { epCh <- ep }, + Endpoint: func(server, ep string, d time.Duration) { epCh <- ep }, Servers: stunServers, } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 6258943d5..aa5d938d6 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -232,7 +232,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) { s := &stunner.Stunner{ Send: c.pconn.WriteTo, - Endpoint: func(s string) { addAddr(s, "stun") }, + Endpoint: func(server, endpoint string, d time.Duration) { addAddr(endpoint, "stun") }, Servers: c.stunServers, Logf: c.logf, }