From 953fa80c6f466b8c3ffe7ba3a97dd1ae8e5f6bfe Mon Sep 17 00:00:00 2001 From: James Tucker Date: Fri, 5 Jan 2024 11:14:42 -0800 Subject: [PATCH] cmd/{derper,stund},net/stunserver: add standalone stun server Add a standalone server for STUN that can be hosted independently of the derper, and factor that back into the derper. Fixes #8434 Closes #8435 Closes #10745 Signed-off-by: James Tucker --- Makefile | 6 +- cmd/derper/depaware.txt | 4 +- cmd/derper/derper.go | 88 +++----------- cmd/derper/derper_test.go | 34 ------ cmd/stund/depaware.txt | 190 ++++++++++++++++++++++++++++++ cmd/stund/stund.go | 48 ++++++++ net/stunserver/stunserver.go | 126 ++++++++++++++++++++ net/stunserver/stunserver_test.go | 88 ++++++++++++++ 8 files changed, 474 insertions(+), 110 deletions(-) create mode 100644 cmd/stund/depaware.txt create mode 100644 cmd/stund/stund.go create mode 100644 net/stunserver/stunserver.go create mode 100644 net/stunserver/stunserver_test.go diff --git a/Makefile b/Makefile index 6a14bbc47..3a5db44f4 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,8 @@ updatedeps: ## Update depaware deps PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ - tailscale.com/cmd/derper + tailscale.com/cmd/derper \ + tailscale.com/cmd/stund depaware: ## Run depaware checks # depaware (via x/tools/go/packages) shells back to "go", so make sure the "go" @@ -26,7 +27,8 @@ depaware: ## Run depaware checks PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ - tailscale.com/cmd/derper + tailscale.com/cmd/derper \ + tailscale.com/cmd/stund buildwindows: ## Build tailscale CLI for windows/amd64 GOOS=windows GOARCH=amd64 ./tool/go install tailscale.com/cmd/tailscale tailscale.com/cmd/tailscaled diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 3caf75f46..a661a6e02 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -105,7 +105,8 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/net/netutil from tailscale.com/client/tailscale tailscale.com/net/packet from tailscale.com/wgengine/filter tailscale.com/net/sockstats from tailscale.com/derp/derphttp - tailscale.com/net/stun from tailscale.com/cmd/derper + tailscale.com/net/stun from tailscale.com/net/stunserver + tailscale.com/net/stunserver from tailscale.com/cmd/derper L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/derp/derphttp tailscale.com/net/tsaddr from tailscale.com/ipn+ @@ -263,6 +264,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa net/url from crypto/x509+ os from crypto/rand+ os/exec from golang.zx2c4.com/wireguard/windows/tunnel/winipcfg+ + os/signal from tailscale.com/cmd/derper W os/user from tailscale.com/util/winutil path from golang.org/x/crypto/acme/autocert+ path/filepath from crypto/x509+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index a757745ba..a1747753d 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -17,11 +17,12 @@ "math" "net" "net/http" - "net/netip" "os" + "os/signal" "path/filepath" "regexp" "strings" + "syscall" "time" "go4.org/mem" @@ -30,7 +31,7 @@ "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/metrics" - "tailscale.com/net/stun" + "tailscale.com/net/stunserver" "tailscale.com/tsweb" "tailscale.com/types/key" "tailscale.com/util/cmpx" @@ -59,25 +60,11 @@ ) var ( - stats = new(metrics.Set) - stunDisposition = &metrics.LabelMap{Label: "disposition"} - stunAddrFamily = &metrics.LabelMap{Label: "family"} tlsRequestVersion = &metrics.LabelMap{Label: "version"} tlsActiveVersion = &metrics.LabelMap{Label: "version"} - - stunReadError = stunDisposition.Get("read_error") - stunNotSTUN = stunDisposition.Get("not_stun") - stunWriteError = stunDisposition.Get("write_error") - stunSuccess = stunDisposition.Get("success") - - stunIPv4 = stunAddrFamily.Get("ipv4") - stunIPv6 = stunAddrFamily.Get("ipv6") ) func init() { - stats.Set("counter_requests", stunDisposition) - stats.Set("counter_addrfamily", stunAddrFamily) - expvar.Publish("stun", stats) expvar.Publish("derper_tls_request_version", tlsRequestVersion) expvar.Publish("gauge_derper_tls_active_version", tlsActiveVersion) } @@ -135,6 +122,9 @@ func writeNewConfig() config { func main() { flag.Parse() + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + if *dev { *addr = ":3340" // above the keys DERP log.Printf("Running in dev mode.") @@ -146,6 +136,11 @@ func main() { log.Fatalf("invalid server address: %v", err) } + if *runSTUN { + ss := stunserver.New(ctx) + go ss.ListenAndServe(net.JoinHostPort(listenHost, fmt.Sprint(*stunPort))) + } + cfg := loadConfig() serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual" @@ -221,10 +216,6 @@ func main() { })) debug.Handle("traffic", "Traffic check", http.HandlerFunc(s.ServeDebugTraffic)) - if *runSTUN { - go serveSTUN(listenHost, *stunPort) - } - quietLogger := log.New(logFilter{}, "", 0) httpsrv := &http.Server{ Addr: *addr, @@ -241,6 +232,10 @@ func main() { ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, } + go func() { + <-ctx.Done() + httpsrv.Shutdown(ctx) + }() if serveTLS { log.Printf("derper: serving on %s with TLS", *addr) @@ -351,59 +346,6 @@ func probeHandler(w http.ResponseWriter, r *http.Request) { } } -func serveSTUN(host string, port int) { - pc, err := net.ListenPacket("udp", net.JoinHostPort(host, fmt.Sprint(port))) - if err != nil { - log.Fatalf("failed to open STUN listener: %v", err) - } - log.Printf("running STUN server on %v", pc.LocalAddr()) - serverSTUNListener(context.Background(), pc.(*net.UDPConn)) -} - -func serverSTUNListener(ctx context.Context, pc *net.UDPConn) { - var buf [64 << 10]byte - var ( - n int - ua *net.UDPAddr - err error - ) - for { - n, ua, err = pc.ReadFromUDP(buf[:]) - if err != nil { - if ctx.Err() != nil { - return - } - log.Printf("STUN ReadFrom: %v", err) - time.Sleep(time.Second) - stunReadError.Add(1) - continue - } - pkt := buf[:n] - if !stun.Is(pkt) { - stunNotSTUN.Add(1) - continue - } - txid, err := stun.ParseBindingRequest(pkt) - if err != nil { - stunNotSTUN.Add(1) - continue - } - if ua.IP.To4() != nil { - stunIPv4.Add(1) - } else { - stunIPv6.Add(1) - } - addr, _ := netip.AddrFromSlice(ua.IP) - res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) - _, err = pc.WriteTo(res, ua) - if err != nil { - stunWriteError.Add(1) - } else { - stunSuccess.Add(1) - } - } -} - var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`) func prodAutocertHostPolicy(_ context.Context, host string) error { diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index b1112adea..551800309 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -5,13 +5,11 @@ import ( "context" - "net" "net/http" "net/http/httptest" "strings" "testing" - "tailscale.com/net/stun" "tailscale.com/tstest/deptest" ) @@ -39,38 +37,6 @@ func TestProdAutocertHostPolicy(t *testing.T) { } } -func BenchmarkServerSTUN(b *testing.B) { - b.ReportAllocs() - pc, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - b.Fatal(err) - } - defer pc.Close() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go serverSTUNListener(ctx, pc.(*net.UDPConn)) - addr := pc.LocalAddr().(*net.UDPAddr) - - var resBuf [1500]byte - cc, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) - if err != nil { - b.Fatal(err) - } - - tx := stun.NewTxID() - req := stun.Request(tx) - for i := 0; i < b.N; i++ { - if _, err := cc.WriteToUDP(req, addr); err != nil { - b.Fatal(err) - } - _, _, err := cc.ReadFromUDP(resBuf[:]) - if err != nil { - b.Fatal(err) - } - } - -} - func TestNoContent(t *testing.T) { testCases := []struct { name string diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt new file mode 100644 index 000000000..9b89e92d5 --- /dev/null +++ b/cmd/stund/depaware.txt @@ -0,0 +1,190 @@ +tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depaware) + + github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus + 💣 github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus + github.com/golang/protobuf/proto from github.com/matttproud/golang_protobuf_extensions/pbutil + github.com/google/uuid from tailscale.com/tsweb + github.com/matttproud/golang_protobuf_extensions/pbutil from github.com/prometheus/common/expfmt + 💣 github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz + github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus + github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ + github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ + github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt + github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ + LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus + LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs + LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs + 💣 go4.org/mem from tailscale.com/metrics+ + go4.org/netipx from tailscale.com/net/tsaddr + google.golang.org/protobuf/encoding/prototext from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/encoding/protowire from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/internal/descfmt from google.golang.org/protobuf/internal/filedesc + google.golang.org/protobuf/internal/descopts from google.golang.org/protobuf/internal/filedesc+ + google.golang.org/protobuf/internal/detrand from google.golang.org/protobuf/internal/descfmt+ + google.golang.org/protobuf/internal/encoding/defval from google.golang.org/protobuf/internal/encoding/tag+ + google.golang.org/protobuf/internal/encoding/messageset from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/encoding/tag from google.golang.org/protobuf/internal/impl + google.golang.org/protobuf/internal/encoding/text from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/errors from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/filedesc from google.golang.org/protobuf/internal/encoding/tag+ + google.golang.org/protobuf/internal/filetype from google.golang.org/protobuf/runtime/protoimpl + google.golang.org/protobuf/internal/flags from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/genid from google.golang.org/protobuf/encoding/prototext+ + 💣 google.golang.org/protobuf/internal/impl from google.golang.org/protobuf/internal/filetype+ + google.golang.org/protobuf/internal/order from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/pragma from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/set from google.golang.org/protobuf/encoding/prototext + 💣 google.golang.org/protobuf/internal/strs from google.golang.org/protobuf/encoding/prototext+ + google.golang.org/protobuf/internal/version from google.golang.org/protobuf/runtime/protoimpl + google.golang.org/protobuf/proto from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/reflect/protodesc from github.com/golang/protobuf/proto + 💣 google.golang.org/protobuf/reflect/protoreflect from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/reflect/protoregistry from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/runtime/protoiface from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/runtime/protoimpl from github.com/golang/protobuf/proto+ + google.golang.org/protobuf/types/descriptorpb from google.golang.org/protobuf/reflect/protodesc + google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + tailscale.com from tailscale.com/version + tailscale.com/envknob from tailscale.com/tsweb+ + tailscale.com/metrics from tailscale.com/net/stunserver+ + tailscale.com/net/netaddr from tailscale.com/net/tsaddr + tailscale.com/net/stun from tailscale.com/net/stunserver + tailscale.com/net/stunserver from tailscale.com/cmd/stund + tailscale.com/net/tsaddr from tailscale.com/tsweb + tailscale.com/tailcfg from tailscale.com/version + tailscale.com/tsweb from tailscale.com/cmd/stund + tailscale.com/tsweb/promvarz from tailscale.com/tsweb + tailscale.com/tsweb/varz from tailscale.com/tsweb+ + tailscale.com/types/dnstype from tailscale.com/tailcfg + tailscale.com/types/ipproto from tailscale.com/tailcfg + tailscale.com/types/key from tailscale.com/tailcfg + tailscale.com/types/lazy from tailscale.com/version+ + tailscale.com/types/logger from tailscale.com/tsweb + tailscale.com/types/opt from tailscale.com/envknob+ + tailscale.com/types/ptr from tailscale.com/tailcfg + tailscale.com/types/structs from tailscale.com/tailcfg+ + tailscale.com/types/tkatype from tailscale.com/tailcfg+ + tailscale.com/types/views from tailscale.com/net/tsaddr+ + tailscale.com/util/cmpx from tailscale.com/tailcfg+ + L 💣 tailscale.com/util/dirwalk from tailscale.com/metrics + tailscale.com/util/dnsname from tailscale.com/tailcfg + tailscale.com/util/lineread from tailscale.com/version/distro + tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/slicesx from tailscale.com/tailcfg + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + tailscale.com/version from tailscale.com/envknob+ + tailscale.com/version/distro from tailscale.com/envknob + golang.org/x/crypto/blake2b from golang.org/x/crypto/nacl/box + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 + golang.org/x/crypto/chacha20poly1305 from crypto/tls + golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ + golang.org/x/crypto/hkdf from crypto/tls + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + golang.org/x/net/dns/dnsmessage from net + golang.org/x/net/http/httpguts from net/http + golang.org/x/net/http/httpproxy from net/http + golang.org/x/net/http2/hpack from net/http + golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + D golang.org/x/net/route from net + golang.org/x/sys/cpu from golang.org/x/crypto/blake2b+ + LD golang.org/x/sys/unix from github.com/prometheus/procfs+ + W golang.org/x/sys/windows from github.com/prometheus/client_golang/prometheus + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + bufio from compress/flate+ + bytes from bufio+ + cmp from slices + compress/flate from compress/gzip + compress/gzip from github.com/golang/protobuf/proto+ + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/ecdsa+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509 + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/hmac from crypto/tls+ + crypto/md5 from crypto/tls+ + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/aes+ + crypto/tls from net/http+ + crypto/x509 from crypto/tls + crypto/x509/pkix from crypto/x509 + database/sql/driver from github.com/google/uuid + embed from crypto/internal/nistec+ + encoding from encoding/json+ + encoding/asn1 from crypto/x509+ + encoding/base64 from encoding/json+ + encoding/binary from compress/gzip+ + encoding/hex from crypto/x509+ + encoding/json from expvar+ + encoding/pem from crypto/tls+ + errors from bufio+ + expvar from github.com/prometheus/client_golang/prometheus+ + flag from tailscale.com/cmd/stund + fmt from compress/flate+ + go/token from google.golang.org/protobuf/internal/strs + hash from crypto+ + hash/crc32 from compress/gzip+ + hash/fnv from google.golang.org/protobuf/internal/detrand + hash/maphash from go4.org/mem + html from net/http/pprof+ + io from bufio+ + io/fs from crypto/x509+ + io/ioutil from github.com/golang/protobuf/proto+ + log from expvar+ + log/internal from log + maps from tailscale.com/tailcfg+ + math from compress/flate+ + math/big from crypto/dsa+ + math/bits from compress/flate+ + math/rand from math/big+ + mime from github.com/prometheus/common/expfmt+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from expvar+ + net/http/httptrace from net/http + net/http/internal from net/http + net/http/pprof from tailscale.com/tsweb+ + net/netip from go4.org/netipx+ + net/textproto from golang.org/x/net/http/httpguts+ + net/url from crypto/x509+ + os from crypto/rand+ + os/signal from tailscale.com/cmd/stund + path from github.com/prometheus/client_golang/prometheus/internal+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + regexp from github.com/prometheus/client_golang/prometheus/internal+ + regexp/syntax from regexp + runtime/debug from github.com/prometheus/client_golang/prometheus+ + runtime/metrics from github.com/prometheus/client_golang/prometheus+ + runtime/pprof from net/http/pprof + runtime/trace from net/http/pprof + slices from tailscale.com/metrics+ + sort from compress/flate+ + strconv from compress/flate+ + strings from bufio+ + sync from compress/flate+ + sync/atomic from context+ + syscall from crypto/rand+ + text/tabwriter from runtime/pprof + time from compress/gzip+ + unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ diff --git a/cmd/stund/stund.go b/cmd/stund/stund.go new file mode 100644 index 000000000..c38429169 --- /dev/null +++ b/cmd/stund/stund.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The stund binary is a standalone STUN server. +package main + +import ( + "context" + "flag" + "io" + "log" + "net/http" + "os/signal" + "syscall" + + "tailscale.com/net/stunserver" + "tailscale.com/tsweb" +) + +var ( + stunAddr = flag.String("stun", ":3478", "UDP address on which to start the STUN server") + httpAddr = flag.String("http", ":3479", "address on which to start the debug http server") +) + +func main() { + flag.Parse() + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + log.Printf("HTTP server listening on %s", *httpAddr) + go http.ListenAndServe(*httpAddr, mux()) + + s := stunserver.New(ctx) + if err := s.ListenAndServe(*stunAddr); err != nil { + log.Fatal(err) + } +} + +func mux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "

stund

/debug") + }) + debug := tsweb.Debugger(mux) + debug.KV("stun_addr", *stunAddr) + return mux +} diff --git a/net/stunserver/stunserver.go b/net/stunserver/stunserver.go new file mode 100644 index 000000000..b45bb6331 --- /dev/null +++ b/net/stunserver/stunserver.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package stunserver implements a STUN server. The package publishes a number of stats +// to expvar under the top level label "stun". Logs are sent to the standard log package. +package stunserver + +import ( + "context" + "errors" + "expvar" + "io" + "log" + "net" + "net/netip" + "time" + + "tailscale.com/metrics" + "tailscale.com/net/stun" +) + +var ( + stats = new(metrics.Set) + stunDisposition = &metrics.LabelMap{Label: "disposition"} + stunAddrFamily = &metrics.LabelMap{Label: "family"} + stunReadError = stunDisposition.Get("read_error") + stunNotSTUN = stunDisposition.Get("not_stun") + stunWriteError = stunDisposition.Get("write_error") + stunSuccess = stunDisposition.Get("success") + + stunIPv4 = stunAddrFamily.Get("ipv4") + stunIPv6 = stunAddrFamily.Get("ipv6") +) + +func init() { + stats.Set("counter_requests", stunDisposition) + stats.Set("counter_addrfamily", stunAddrFamily) + expvar.Publish("stun", stats) +} + +type STUNServer struct { + ctx context.Context // ctx signals service shutdown + pc *net.UDPConn // pc is the UDP listener +} + +// New creates a new STUN server. The server is shutdown when ctx is done. +func New(ctx context.Context) *STUNServer { + return &STUNServer{ctx: ctx} +} + +// Listen binds the listen socket for the server at listenAddr. +func (s *STUNServer) Listen(listenAddr string) error { + uaddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return err + } + s.pc, err = net.ListenUDP("udp", uaddr) + if err != nil { + return err + } + log.Printf("STUN server listening on %v", s.LocalAddr()) + // close the listener on shutdown in order to break out of the read loop + go func() { + <-s.ctx.Done() + s.pc.Close() + }() + return nil +} + +// Serve starts serving responses to STUN requests. Listen must be called before Serve. +func (s *STUNServer) Serve() error { + var buf [64 << 10]byte + var ( + n int + ua *net.UDPAddr + err error + ) + for { + n, ua, err = s.pc.ReadFromUDP(buf[:]) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return nil + } + log.Printf("STUN ReadFrom: %v", err) + time.Sleep(time.Second) + stunReadError.Add(1) + continue + } + pkt := buf[:n] + if !stun.Is(pkt) { + stunNotSTUN.Add(1) + continue + } + txid, err := stun.ParseBindingRequest(pkt) + if err != nil { + stunNotSTUN.Add(1) + continue + } + if ua.IP.To4() != nil { + stunIPv4.Add(1) + } else { + stunIPv6.Add(1) + } + addr, _ := netip.AddrFromSlice(ua.IP) + res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) + _, err = s.pc.WriteTo(res, ua) + if err != nil { + stunWriteError.Add(1) + } else { + stunSuccess.Add(1) + } + } +} + +// ListenAndServe starts the STUN server on listenAddr. +func (s *STUNServer) ListenAndServe(listenAddr string) error { + if err := s.Listen(listenAddr); err != nil { + return err + } + return s.Serve() +} + +// LocalAddr returns the local address of the STUN server. It must not be called before ListenAndServe. +func (s *STUNServer) LocalAddr() net.Addr { + return s.pc.LocalAddr() +} diff --git a/net/stunserver/stunserver_test.go b/net/stunserver/stunserver_test.go new file mode 100644 index 000000000..95b2a8c7b --- /dev/null +++ b/net/stunserver/stunserver_test.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package stunserver + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "tailscale.com/net/stun" + "tailscale.com/util/must" +) + +func TestSTUNServer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + s := New(ctx) + must.Do(s.Listen("localhost:0")) + var w sync.WaitGroup + w.Add(1) + var serveErr error + go func() { + defer w.Done() + serveErr = s.Serve() + }() + + c := must.Get(net.DialUDP("udp", nil, s.LocalAddr().(*net.UDPAddr))) + defer c.Close() + c.SetDeadline(time.Now().Add(5 * time.Second)) + txid := stun.NewTxID() + _, err := c.Write(stun.Request(txid)) + if err != nil { + t.Fatalf("failed to write STUN request: %v", err) + } + var buf [64 << 10]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("failed to read STUN response: %v", err) + } + if !stun.Is(buf[:n]) { + t.Fatalf("response is not STUN") + } + tid, _, err := stun.ParseResponse(buf[:n]) + if err != nil { + t.Fatalf("failed to parse STUN response: %v", err) + } + if tid != txid { + t.Fatalf("STUN response has wrong transaction ID; got %d, want %d", tid, txid) + } + + cancel() + w.Wait() + if serveErr != nil { + t.Fatalf("failed to listen and serve: %v", serveErr) + } +} + +func BenchmarkServerSTUN(b *testing.B) { + b.ReportAllocs() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := New(ctx) + s.Listen("localhost:0") + go s.Serve() + addr := s.LocalAddr().(*net.UDPAddr) + + var resBuf [1500]byte + cc, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) + if err != nil { + b.Fatal(err) + } + + tx := stun.NewTxID() + req := stun.Request(tx) + for i := 0; i < b.N; i++ { + if _, err := cc.WriteToUDP(req, addr); err != nil { + b.Fatal(err) + } + _, _, err := cc.ReadFromUDP(resBuf[:]) + if err != nil { + b.Fatal(err) + } + } +}