tailscale/stunner/stunner.go

198 lines
4.9 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package stunner
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"log"
"net"
"strconv"
"sync"
"time"
"tailscale.com/stun"
)
// Stunner sends a STUN request to several servers and handles a response.
//
// It is designed to used on a connection owned by other code and so does
// not directly reference a net.Conn of any sort. Instead, the user should
// provide Send function to send packets, and call Receive when a new
// STUN response is received.
//
// In response, a Stunner will call Endpoint with any endpoints determined
// for the connection. (An endpoint may be reported multiple times if
// multiple servers are provided.)
type Stunner struct {
Send func([]byte, net.Addr) (int, error) // sends a packet
Endpoint func(endpoint string) // reports an endpoint
Servers []string // STUN servers to contact
Resolver *net.Resolver
Logf func(format string, args ...interface{})
sessions map[string]*session
tIDs map[string][][12]byte
}
type session struct {
replied chan struct{} // closed when server responds
tIDs [][12]byte // transaction IDs sent to a server
}
// Receive delivers a STUN packet to the stunner.
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
if !stun.Is(p) {
log.Println("stunner: received non-STUN packet")
return
}
responseTID, addr, port, err := stun.ParseResponse(p)
if err != nil {
log.Printf("stunner: received bad STUN response: %v", err)
return
}
// Accept any of the tIDs from any of the active sessions.
for server, session := range s.sessions {
for _, tID := range session.tIDs {
if bytes.Equal(tID[:], responseTID[:]) {
select {
case <-session.replied:
return // already got a reply from this server
default:
}
close(session.replied)
// TODO(crawshaw): use different endpoints returned from
// different STUN servers to detect NAT types.
portStr := fmt.Sprintf("%d", port)
host := net.JoinHostPort(net.IP(addr).String(), portStr)
if s.Logf != nil {
s.Logf("STUN server %s reports public endpoint %s", server, host)
}
s.Endpoint(host)
return
}
}
}
log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID)
}
// Run starts a Stunner and blocks until all servers either respond
// or are tried multiple times and timeout.
func (s *Stunner) Run(ctx context.Context) error {
if s.Resolver == nil {
s.Resolver = net.DefaultResolver
}
for _, server := range s.Servers {
// Generate the transaction IDs for this session.
tIDs := make([][12]byte, len(retryDurations))
for i := range tIDs {
if _, err := rand.Read(tIDs[i][:]); err != nil {
return fmt.Errorf("stunner: rand failed: %v", err)
}
}
if s.sessions == nil {
s.sessions = make(map[string]*session)
}
s.sessions[server] = &session{
replied: make(chan struct{}),
tIDs: tIDs,
}
}
// after this point, the s.sessions map is read-only
var wg sync.WaitGroup
for _, server := range s.Servers {
wg.Add(1)
go func(server string) {
defer wg.Done()
s.runServer(ctx, server)
}(server)
}
wg.Wait()
return nil
}
func (s *Stunner) runServer(ctx context.Context, server string) {
session := s.sessions[server]
for i, d := range retryDurations {
ctx, cancel := context.WithTimeout(ctx, d)
err := s.sendSTUN(ctx, session.tIDs[i], server)
if err != nil {
if s.Logf != nil {
s.Logf("stunner: %s: %v", server, err)
}
}
select {
case <-ctx.Done():
cancel()
case <-session.replied:
cancel()
if i > 0 && s.Logf != nil {
s.Logf("stunner: slow STUN response from %s: %d retries", server, i)
}
return
}
}
if s.Logf != nil {
s.Logf("stunner: no STUN response from %s", server)
}
}
func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error {
host, port, err := net.SplitHostPort(server)
if err != nil {
return err
}
addrPort, err := strconv.Atoi(port)
if err != nil {
return fmt.Errorf("port: %v", err)
}
if addrPort == 0 {
addrPort = 3478
}
addr := &net.UDPAddr{Port: addrPort}
ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
}
for _, ipAddr := range ipAddrs {
if ip4 := ipAddr.IP.To4(); ip4 != nil {
addr.IP = ip4
addr.Zone = ipAddr.Zone
break
}
}
if addr.IP == nil {
return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
}
req := stun.Request(tID)
if _, err := s.Send(req, addr); err != nil {
return fmt.Errorf("Send: %v", err)
}
return nil
}
var retryDurations = []time.Duration{
100 * time.Millisecond,
100 * time.Millisecond,
100 * time.Millisecond,
200 * time.Millisecond,
200 * time.Millisecond,
400 * time.Millisecond,
800 * time.Millisecond,
1600 * time.Millisecond,
3200 * time.Millisecond,
}