tsweb: add EnforceHostname helper for DNS rebinding mitigation

Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
Will Norris 2022-11-23 09:54:48 -08:00
parent b45b948776
commit f9fa174e11

View File

@ -25,9 +25,11 @@
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"go4.org/mem" "go4.org/mem"
"tailscale.com/client/tailscale"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/metrics" "tailscale.com/metrics"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
@ -779,3 +781,67 @@ func (r expVarPromStructRoot) String() string { panic("unused
_ PrometheusMetricsReflectRooter = expVarPromStructRoot{} _ PrometheusMetricsReflectRooter = expVarPromStructRoot{}
_ expvar.Var = expVarPromStructRoot{} _ expvar.Var = expVarPromStructRoot{}
) )
type hostEnforcer struct {
expiry time.Time
hosts []string
once sync.Once
}
func (he *hostEnforcer) allowedHosts(ctx context.Context, client *tailscale.LocalClient) ([]string, error) {
var statusErr error
he.once.Do(func() {
status, err := client.StatusWithoutPeers(ctx)
if err != nil {
statusErr = err
return
}
dnsName := strings.TrimSuffix(status.Self.DNSName, ".") // node.tailnet.ts.net
bareHost := strings.TrimSuffix(dnsName, "."+status.MagicDNSSuffix) // node
he.hosts = []string{dnsName, bareHost}
})
if statusErr != nil {
return nil, statusErr
}
return he.hosts, nil
}
// EnforceHostname wraps a handler and enforces that all inbound requests have a Host header that matches the local Tailscale node.
// Acceptable Host headers are bare IP addresses, the node's full DNS name (node.tailnet.ts.net) and the node's bare hostname.
// This is used to help protect against DNS rebinding attacks.
// If client is nil, a default client will be used.
func EnforceHostname(h http.Handler, client *tailscale.LocalClient) http.Handler {
if client == nil {
client = &tailscale.LocalClient{}
}
enforcerAtomic := atomic.Pointer[hostEnforcer]{}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// always allow bare IP addresses
if _, err := netip.ParseAddr(r.Host); err == nil {
h.ServeHTTP(w, r)
return
}
enforcer := enforcerAtomic.Load()
// refresh the list of allowed hosts every 5 seconds
if now := time.Now(); enforcer == nil || now.After(enforcer.expiry) {
newEnforcer := &hostEnforcer{expiry: now.Add(5 * time.Second)}
if enforcerAtomic.CompareAndSwap(enforcer, newEnforcer) {
enforcer = enforcerAtomic.Load()
}
}
allowedHosts, err := enforcer.allowedHosts(r.Context(), client)
if err != nil {
http.Error(w, "unable to get tailscale status", http.StatusInternalServerError)
return
}
for _, host := range allowedHosts {
if r.Host == host {
h.ServeHTTP(w, r)
return
}
}
http.Error(w, "invalid Host header", http.StatusBadRequest)
})
}