mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-03 18:42:00 +00:00
Move Linux client & common packages into a public repo.
This commit is contained in:
197
stunner/stunner.go
Normal file
197
stunner/stunner.go
Normal file
@@ -0,0 +1,197 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package stunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/stun"
|
||||
)
|
||||
|
||||
// Stunner sends a STUN request to several servers and handles a response.
|
||||
//
|
||||
// It is designed to used on a connection owned by other code and so does
|
||||
// not directly reference a net.Conn of any sort. Instead, the user should
|
||||
// provide Send function to send packets, and call Receive when a new
|
||||
// STUN response is received.
|
||||
//
|
||||
// In response, a Stunner will call Endpoint with any endpoints determined
|
||||
// for the connection. (An endpoint may be reported multiple times if
|
||||
// multiple servers are provided.)
|
||||
type Stunner struct {
|
||||
Send func([]byte, net.Addr) (int, error) // sends a packet
|
||||
Endpoint func(endpoint string) // reports an endpoint
|
||||
Servers []string // STUN servers to contact
|
||||
Resolver *net.Resolver
|
||||
Logf func(format string, args ...interface{})
|
||||
|
||||
sessions map[string]*session
|
||||
tIDs map[string][][12]byte
|
||||
}
|
||||
|
||||
type session struct {
|
||||
replied chan struct{} // closed when server responds
|
||||
tIDs [][12]byte // transaction IDs sent to a server
|
||||
}
|
||||
|
||||
// Receive delivers a STUN packet to the stunner.
|
||||
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
|
||||
if !stun.Is(p) {
|
||||
log.Println("stunner: received non-STUN packet")
|
||||
return
|
||||
}
|
||||
|
||||
responseTID, addr, port, err := stun.ParseResponse(p)
|
||||
if err != nil {
|
||||
log.Printf("stunner: received bad STUN response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Accept any of the tIDs from any of the active sessions.
|
||||
for server, session := range s.sessions {
|
||||
for _, tID := range session.tIDs {
|
||||
if bytes.Equal(tID[:], responseTID[:]) {
|
||||
select {
|
||||
case <-session.replied:
|
||||
return // already got a reply from this server
|
||||
default:
|
||||
}
|
||||
close(session.replied)
|
||||
|
||||
// TODO(crawshaw): use different endpoints returned from
|
||||
// different STUN servers to detect NAT types.
|
||||
portStr := fmt.Sprintf("%d", port)
|
||||
host := net.JoinHostPort(net.IP(addr).String(), portStr)
|
||||
if s.Logf != nil {
|
||||
s.Logf("STUN server %s reports public endpoint %s", server, host)
|
||||
}
|
||||
s.Endpoint(host)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID)
|
||||
}
|
||||
|
||||
// Run starts a Stunner and blocks until all servers either respond
|
||||
// or are tried multiple times and timeout.
|
||||
func (s *Stunner) Run(ctx context.Context) error {
|
||||
if s.Resolver == nil {
|
||||
s.Resolver = net.DefaultResolver
|
||||
}
|
||||
for _, server := range s.Servers {
|
||||
// Generate the transaction IDs for this session.
|
||||
tIDs := make([][12]byte, len(retryDurations))
|
||||
for i := range tIDs {
|
||||
if _, err := rand.Read(tIDs[i][:]); err != nil {
|
||||
return fmt.Errorf("stunner: rand failed: %v", err)
|
||||
}
|
||||
}
|
||||
if s.sessions == nil {
|
||||
s.sessions = make(map[string]*session)
|
||||
}
|
||||
s.sessions[server] = &session{
|
||||
replied: make(chan struct{}),
|
||||
tIDs: tIDs,
|
||||
}
|
||||
}
|
||||
// after this point, the s.sessions map is read-only
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, server := range s.Servers {
|
||||
wg.Add(1)
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
s.runServer(ctx, server)
|
||||
}(server)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Stunner) runServer(ctx context.Context, server string) {
|
||||
session := s.sessions[server]
|
||||
|
||||
for i, d := range retryDurations {
|
||||
ctx, cancel := context.WithTimeout(ctx, d)
|
||||
err := s.sendSTUN(ctx, session.tIDs[i], server)
|
||||
if err != nil {
|
||||
if s.Logf != nil {
|
||||
s.Logf("stunner: %s: %v", server, err)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
case <-session.replied:
|
||||
cancel()
|
||||
if i > 0 && s.Logf != nil {
|
||||
s.Logf("stunner: slow STUN response from %s: %d retries", server, i)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.Logf != nil {
|
||||
s.Logf("stunner: no STUN response from %s", server)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error {
|
||||
host, port, err := net.SplitHostPort(server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
addrPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("port: %v", err)
|
||||
}
|
||||
if addrPort == 0 {
|
||||
addrPort = 3478
|
||||
}
|
||||
addr := &net.UDPAddr{Port: addrPort}
|
||||
|
||||
ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup ip addr: %v", err)
|
||||
}
|
||||
for _, ipAddr := range ipAddrs {
|
||||
if ip4 := ipAddr.IP.To4(); ip4 != nil {
|
||||
addr.IP = ip4
|
||||
addr.Zone = ipAddr.Zone
|
||||
break
|
||||
}
|
||||
}
|
||||
if addr.IP == nil {
|
||||
return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
|
||||
}
|
||||
|
||||
req := stun.Request(tID)
|
||||
if _, err := s.Send(req, addr); err != nil {
|
||||
return fmt.Errorf("Send: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var retryDurations = []time.Duration{
|
||||
100 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
100 * time.Millisecond,
|
||||
200 * time.Millisecond,
|
||||
200 * time.Millisecond,
|
||||
400 * time.Millisecond,
|
||||
800 * time.Millisecond,
|
||||
1600 * time.Millisecond,
|
||||
3200 * time.Millisecond,
|
||||
}
|
||||
150
stunner/stunner_test.go
Normal file
150
stunner/stunner_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package stunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gortc.io/stun"
|
||||
)
|
||||
|
||||
func TestStun(t *testing.T) {
|
||||
conn1, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn1.Close()
|
||||
conn2, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn2.Close()
|
||||
stunServers := []string{
|
||||
conn1.LocalAddr().String(), conn2.LocalAddr().String(),
|
||||
}
|
||||
|
||||
epCh := make(chan string, 16)
|
||||
|
||||
localConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := &Stunner{
|
||||
Send: localConn.WriteTo,
|
||||
Endpoint: func(ep string) { epCh <- ep },
|
||||
Servers: stunServers,
|
||||
}
|
||||
|
||||
stun1Err := make(chan error)
|
||||
go func() {
|
||||
stun1Err <- startSTUN(conn1, s.Receive)
|
||||
}()
|
||||
stun2Err := make(chan error)
|
||||
go func() {
|
||||
stun2Err <- startSTUNDrop1(conn2, s.Receive)
|
||||
}()
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
errCh <- s.Run(context.Background())
|
||||
}()
|
||||
|
||||
var eps []string
|
||||
select {
|
||||
case ep := <-epCh:
|
||||
eps = append(eps, ep)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("missing first endpoint response")
|
||||
}
|
||||
select {
|
||||
case ep := <-epCh:
|
||||
eps = append(eps, ep)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("missing second endpoint response")
|
||||
}
|
||||
sort.Strings(eps)
|
||||
if want := "1.2.3.4:1234"; eps[0] != want {
|
||||
t.Errorf("eps[0]=%q, want %q", eps[0], want)
|
||||
}
|
||||
if want := "4.5.6.7:4567"; eps[1] != want {
|
||||
t.Errorf("eps[1]=%q, want %q", eps[1], want)
|
||||
}
|
||||
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
|
||||
if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil {
|
||||
return fmt.Errorf("first stun server read failed: %v", err)
|
||||
}
|
||||
req := new(stun.Message)
|
||||
res := new(stun.Message)
|
||||
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := conn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p = p[:n]
|
||||
if !stun.IsMessage(p) {
|
||||
return errors.New("not a STUN message")
|
||||
}
|
||||
if _, err := req.Write(p); err != nil {
|
||||
return err
|
||||
}
|
||||
mappedAddr := &stun.XORMappedAddress{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Port: 1234,
|
||||
}
|
||||
software := stun.NewSoftware("endpointer")
|
||||
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
writeTo(res.Raw, addr.(*net.UDPAddr))
|
||||
return nil
|
||||
}
|
||||
|
||||
func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error {
|
||||
req := new(stun.Message)
|
||||
res := new(stun.Message)
|
||||
|
||||
p := make([]byte, 1024)
|
||||
n, addr, err := conn.ReadFrom(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p = p[:n]
|
||||
if !stun.IsMessage(p) {
|
||||
return errors.New("not a STUN message")
|
||||
}
|
||||
if _, err := req.Write(p); err != nil {
|
||||
return err
|
||||
}
|
||||
mappedAddr := &stun.XORMappedAddress{
|
||||
IP: net.ParseIP("4.5.6.7"),
|
||||
Port: 4567,
|
||||
}
|
||||
software := stun.NewSoftware("endpointer")
|
||||
err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
writeTo(res.Raw, addr.(*net.UDPAddr))
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: test retry timeout (overwrite the retryDurations)
|
||||
// TODO: test canceling context passed to Run
|
||||
// TODO: test sending bad packets
|
||||
Reference in New Issue
Block a user