From 3c9be072147dc87de8f6e7a3bbedba7c5e93fe50 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 22 May 2024 10:34:57 -0700 Subject: [PATCH] cmd/derper: support TXT-mediated unpublished bootstrap DNS rollouts Updates tailscale/coral#127 Change-Id: I2712c50630d0d1272c30305fa5a1899a19ffacef Signed-off-by: Brad Fitzpatrick --- cmd/derper/bootstrap_dns.go | 133 ++++++++++++++++++++++++------- cmd/derper/bootstrap_dns_test.go | 87 +++++++++++++++++--- cmd/derper/depaware.txt | 2 +- cmd/derper/derper.go | 2 +- 4 files changed, 178 insertions(+), 46 deletions(-) diff --git a/cmd/derper/bootstrap_dns.go b/cmd/derper/bootstrap_dns.go index ee33899f6..a58f040ba 100644 --- a/cmd/derper/bootstrap_dns.go +++ b/cmd/derper/bootstrap_dns.go @@ -5,35 +5,45 @@ package main import ( "context" + "encoding/binary" "encoding/json" "expvar" "log" + "math/rand/v2" "net" "net/http" + "net/netip" + "strconv" "strings" + "sync/atomic" "time" "tailscale.com/syncs" + "tailscale.com/util/mak" "tailscale.com/util/slicesx" ) const refreshTimeout = time.Minute -type dnsEntryMap map[string][]net.IP +type dnsEntryMap struct { + IPs map[string][]net.IP + Percent map[string]float64 // "foo.com" => 0.5 for 50% +} var ( - dnsCache syncs.AtomicValue[dnsEntryMap] + dnsCache atomic.Pointer[dnsEntryMap] dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON - unpublishedDNSCache syncs.AtomicValue[dnsEntryMap] + unpublishedDNSCache atomic.Pointer[dnsEntryMap] bootstrapLookupMap syncs.Map[string, bool] ) var ( - bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") - publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits") - publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses") - unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits") - unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses") + bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests") + publishedDNSHits = 
expvar.NewInt("counter_bootstrap_dns_published_hits") + publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses") + unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits") + unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses") + unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses") ) func init() { @@ -59,15 +69,13 @@ func refreshBootstrapDNS() { } ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) defer cancel() - dnsEntries := resolveList(ctx, strings.Split(*bootstrapDNS, ",")) + dnsEntries := resolveList(ctx, *bootstrapDNS) // Randomize the order of the IPs for each name to avoid the client biasing // to IPv6 - for k := range dnsEntries { - ips := dnsEntries[k] - slicesx.Shuffle(ips) - dnsEntries[k] = ips + for _, vv := range dnsEntries.IPs { + slicesx.Shuffle(vv) } - j, err := json.MarshalIndent(dnsEntries, "", "\t") + j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t") if err != nil { // leave the old values in place return @@ -81,27 +89,50 @@ func refreshUnpublishedDNS() { if *unpublishedDNS == "" { return } - ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout) defer cancel() - - dnsEntries := resolveList(ctx, strings.Split(*unpublishedDNS, ",")) + dnsEntries := resolveList(ctx, *unpublishedDNS) unpublishedDNSCache.Store(dnsEntries) } -func resolveList(ctx context.Context, names []string) dnsEntryMap { - dnsEntries := make(dnsEntryMap) +// resolveList takes a comma-separated list of DNS names to resolve. +// +// If an entry contains a slash, it's two DNS names: the first is the one to +// resolve and the second is that of a TXT record containing the rollout +// percentage in range "0".."100". If the TXT record doesn't exist or is +// malformed, the percentage is 0. If the TXT record is not provided (there's no +// slash), then the percentage is 100. 
+func resolveList(ctx context.Context, list string) *dnsEntryMap { + ents := strings.Split(list, ",") + + ret := &dnsEntryMap{} var r net.Resolver - for _, name := range names { + for _, ent := range ents { + name, txtName, _ := strings.Cut(ent, "/") addrs, err := r.LookupIP(ctx, "ip", name) if err != nil { log.Printf("bootstrap DNS lookup %q: %v", name, err) continue } - dnsEntries[name] = addrs + mak.Set(&ret.IPs, name, addrs) + + if txtName == "" { + mak.Set(&ret.Percent, name, 1.0) + continue + } + vals, err := r.LookupTXT(ctx, txtName) + if err != nil { + log.Printf("bootstrap DNS lookup %q: %v", txtName, err) + continue + } + for _, v := range vals { + if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 { + mak.Set(&ret.Percent, name, float64(v)/100) + } + } } - return dnsEntries + return ret } func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { @@ -115,22 +146,36 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { // Try answering a query from our hidden map first if q := r.URL.Query().Get("q"); q != "" { bootstrapLookupMap.Store(q, true) - if ips, ok := unpublishedDNSCache.Load()[q]; ok && len(ips) > 0 { + if bootstrapLookupMap.Len() > 500 { // defensive + bootstrapLookupMap.Clear() + } + if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 { unpublishedDNSHits.Add(1) - // Only return the specific query, not everything. - m := dnsEntryMap{q: ips} - j, err := json.MarshalIndent(m, "", "\t") - if err == nil { - w.Write(j) - return + percent := m.Percent[q] + if remoteAddrMatchesPercent(r.RemoteAddr, percent) { + // Only return the specific query, not everything. + m := map[string][]net.IP{q: m.IPs[q]} + j, err := json.MarshalIndent(m, "", "\t") + if err == nil { + w.Write(j) + return + } + } else { + unpublishedDNSPercentMisses.Add(1) } } // If we have a "q" query for a name in the published cache // list, then track whether that's a hit/miss. 
- if m, ok := dnsCache.Load()[q]; ok { - if len(m) > 0 { + m := dnsCache.Load() + var inPub bool + var ips []net.IP + if m != nil { + ips, inPub = m.IPs[q] + } + if inPub { + if len(ips) > 0 { publishedDNSHits.Add(1) } else { publishedDNSMisses.Add(1) @@ -146,3 +191,29 @@ func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) { j := dnsCacheBytes.Load() w.Write(j) } + +// percent is [0.0, 1.0]. +func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool { + if percent == 0 { + return false + } + if percent == 1 { + return true + } + reqIPStr, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return false + } + reqIP, err := netip.ParseAddr(reqIPStr) + if err != nil { + return false + } + if reqIP.IsLoopback() { + // For local testing. + return rand.Float64() < 0.5 + } + reqIP16 := reqIP.As16() + rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:])) + rnd := rand.New(rndSrc) + return percent > rnd.Float64() +} diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index 0a4496f61..d151bc2b0 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -4,10 +4,13 @@ package main import ( + "bytes" "encoding/json" + "io" "net" "net/http" "net/http/httptest" + "net/netip" "net/url" "reflect" "testing" @@ -38,7 +41,7 @@ func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} -func getBootstrapDNS(t *testing.T, q string) dnsEntryMap { +func getBootstrapDNS(t *testing.T, q string) map[string][]net.IP { t.Helper() req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil) w := httptest.NewRecorder() @@ -48,11 +51,12 @@ func getBootstrapDNS(t *testing.T, q string) dnsEntryMap { if res.StatusCode != 200 { t.Fatalf("got status=%d; want %d", res.StatusCode, 200) } - var ips dnsEntryMap - if err := 
json.NewDecoder(res.Body).Decode(&ips); err != nil { - t.Fatalf("error decoding response body: %v", err) + var m map[string][]net.IP + var buf bytes.Buffer + if err := json.NewDecoder(io.TeeReader(res.Body, &buf)).Decode(&m); err != nil { + t.Fatalf("error decoding response body %q: %v", buf.Bytes(), err) } - return ips + return m } func TestUnpublishedDNS(t *testing.T) { @@ -107,15 +111,21 @@ func resetMetrics() { // Verify that we don't count an empty list in the unpublishedDNSCache as a // cache hit in our metrics. func TestUnpublishedDNSEmptyList(t *testing.T) { - pub := dnsEntryMap{ - "tailscale.com": {net.IPv4(10, 10, 10, 10)}, + pub := &dnsEntryMap{ + IPs: map[string][]net.IP{"tailscale.com": {net.IPv4(10, 10, 10, 10)}}, } dnsCache.Store(pub) dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`)) - unpublishedDNSCache.Store(dnsEntryMap{ - "log.tailscale.io": {}, - "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, + unpublishedDNSCache.Store(&dnsEntryMap{ + IPs: map[string][]net.IP{ + "log.tailscale.io": {}, + "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, + }, + Percent: map[string]float64{ + "log.tailscale.io": 1.0, + "controlplane.tailscale.com": 1.0, + }, }) t.Run("CacheMiss", func(t *testing.T) { @@ -125,8 +135,8 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { ips := getBootstrapDNS(t, q) // Expected our public map to be returned on a cache miss - if !reflect.DeepEqual(ips, pub) { - t.Errorf("got ips=%+v; want %+v", ips, pub) + if !reflect.DeepEqual(ips, pub.IPs) { + t.Errorf("got ips=%+v; want %+v", ips, pub.IPs) } if v := unpublishedDNSHits.Value(); v != 0 { t.Errorf("got hits=%d; want 0", v) @@ -141,7 +151,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { t.Run("CacheHit", func(t *testing.T) { resetMetrics() ips := getBootstrapDNS(t, "controlplane.tailscale.com") - want := dnsEntryMap{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} + want := map[string][]net.IP{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} if 
!reflect.DeepEqual(ips, want) { t.Errorf("got ips=%+v; want %+v", ips, want) } @@ -166,3 +176,54 @@ func TestLookupMetric(t *testing.T) { t.Errorf("bootstrapLookupMap.Len() want=5, got %v", bootstrapLookupMap.Len()) } } + +func TestRemoteAddrMatchesPercent(t *testing.T) { + tests := []struct { + remoteAddr string + percent float64 + want bool + }{ + // 0% and 100%. + {"10.0.0.1:1234", 0.0, false}, + {"10.0.0.1:1234", 1.0, true}, + + // Invalid IP. + {"", 1.0, true}, + {"", 0.0, false}, + {"", 0.5, false}, + + // Small manual sample at 50%. The func uses a deterministic PRNG seed. + {"1.2.3.4:567", 0.5, true}, + {"1.2.3.5:567", 0.5, true}, + {"1.2.3.6:567", 0.5, false}, + {"1.2.3.7:567", 0.5, true}, + {"1.2.3.8:567", 0.5, false}, + {"1.2.3.9:567", 0.5, true}, + {"1.2.3.10:567", 0.5, true}, + } + for _, tt := range tests { + got := remoteAddrMatchesPercent(tt.remoteAddr, tt.percent) + if got != tt.want { + t.Errorf("remoteAddrMatchesPercent(%q, %v) = %v; want %v", tt.remoteAddr, tt.percent, got, tt.want) + } + } + + var match, all int + const wantPercent = 0.5 + for a := range 256 { + for b := range 256 { + all++ + if remoteAddrMatchesPercent( + netip.AddrPortFrom(netip.AddrFrom4([4]byte{1, 2, byte(a), byte(b)}), 12345).String(), + wantPercent) { + match++ + } + } + } + gotPercent := float64(match) / float64(all) + const tolerance = 0.005 + t.Logf("got percent %v (goal %v)", gotPercent, wantPercent) + if gotPercent < wantPercent-tolerance || gotPercent > wantPercent+tolerance { + t.Errorf("got %v; want %v ± %v", gotPercent, wantPercent, tolerance) + } +} diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 22e910179..b5839b4e9 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -253,7 +253,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ - math/rand/v2 from tailscale.com/util/fastuuid + 
math/rand/v2 from tailscale.com/util/fastuuid+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index abf254215..60a269b7d 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -55,7 +55,7 @@ var ( meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") - unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list") + unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list. If an entry contains a slash, the second part names a DNS record to poll for its TXT record with a `0` to `100` value for rollout percentage.") verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") verifyClientURL = flag.String("verify-client-url", "", "if non-empty, an admission controller URL for permitting client connections; see tailcfg.DERPAdmitClientRequest") verifyFailOpen = flag.Bool("verify-client-url-fail-open", true, "whether we fail open if --verify-client-url is unreachable")