From f9fa174e11411d1d9ff8d7725acfefbc26f0688a Mon Sep 17 00:00:00 2001 From: Will Norris Date: Wed, 23 Nov 2022 09:54:48 -0800 Subject: [PATCH] tsweb: add EnforceHostname helper for DNS rebinding mitigation Signed-off-by: Will Norris --- tsweb/tsweb.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 8b52e656d..7191d7d0a 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -25,9 +25,11 @@ "strconv" "strings" "sync" + "sync/atomic" "time" "go4.org/mem" + "tailscale.com/client/tailscale" "tailscale.com/envknob" "tailscale.com/metrics" "tailscale.com/net/tsaddr" @@ -779,3 +781,67 @@ func (r expVarPromStructRoot) String() string { panic("unused _ PrometheusMetricsReflectRooter = expVarPromStructRoot{} _ expvar.Var = expVarPromStructRoot{} ) + +type hostEnforcer struct { + expiry time.Time + hosts []string + once sync.Once +} + +func (he *hostEnforcer) allowedHosts(ctx context.Context, client *tailscale.LocalClient) ([]string, error) { + var statusErr error + he.once.Do(func() { + status, err := client.StatusWithoutPeers(ctx) + if err != nil { + statusErr = err + return + } + dnsName := strings.TrimSuffix(status.Self.DNSName, ".") // node.tailnet.ts.net + bareHost := strings.TrimSuffix(dnsName, "."+status.MagicDNSSuffix) // node + he.hosts = []string{dnsName, bareHost} + }) + if statusErr != nil { + return nil, statusErr + } + return he.hosts, nil +} + +// EnforceHostname wraps a handler and enforces that all inbound requests have a Host header that matches the local Tailscale node. +// Acceptable Host headers are bare IP addresses, the node's full DNS name (node.tailnet.ts.net) and the node's bare hostname. +// This is used to help protect against DNS rebinding attacks. +// If client is nil, a default client will be used. +func EnforceHostname(h http.Handler, client *tailscale.LocalClient) http.Handler { + if client == nil { + client = &tailscale.LocalClient{} + } + enforcerAtomic := atomic.Pointer[hostEnforcer]{} + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // always allow bare IP addresses + if _, err := netip.ParseAddr(r.Host); err == nil { + h.ServeHTTP(w, r) + return + } + + enforcer := enforcerAtomic.Load() + // refresh the list of allowed hosts every 5 seconds + if now := time.Now(); enforcer == nil || now.After(enforcer.expiry) { + newEnforcer := &hostEnforcer{expiry: now.Add(5 * time.Second)} + if enforcerAtomic.CompareAndSwap(enforcer, newEnforcer) { + enforcer = enforcerAtomic.Load() + } + } + allowedHosts, err := enforcer.allowedHosts(r.Context(), client) + if err != nil { + http.Error(w, "unable to get tailscale status", http.StatusInternalServerError) + return + } + for _, host := range allowedHosts { + if r.Host == host { + h.ServeHTTP(w, r) + return + } + } + http.Error(w, "invalid Host header", http.StatusBadRequest) + }) +}