tailscale/safeweb/http_test.go
Patrick O'Doherty 4a39a47248
safeweb: replace gorilla with Sec-Fetch-Site check
Require that all non-(GET|OPTIONS|HEAD) requests to the browser mux
specify Sec-Fetch-Site=same-origin to prohibit cross-origin requests.

Optionally allow for requests to specify "same-site" indicating a
cross-origin request from an origin that shares a root domain with the
application's own.

Updates tailscale/corp#25340
2025-04-24 17:44:15 -07:00

635 lines
16 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package safeweb
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
// TestCompleteCORSConfig checks that NewServer returns an error when only one
// of AccessControlAllowOrigin and AccessControlAllowMethods is set, and
// succeeds when both are set together.
func TestCompleteCORSConfig(t *testing.T) {
	origins := []string{"https://foobar.com"}
	methods := []string{"GET", "POST"}

	// Origins alone are an incomplete CORS configuration.
	if _, err := NewServer(Config{AccessControlAllowOrigin: origins}); err == nil {
		t.Fatalf("expected error when AccessControlAllowOrigin is provided without AccessControlAllowMethods")
	}

	// Methods alone are an incomplete CORS configuration.
	if _, err := NewServer(Config{AccessControlAllowMethods: methods}); err == nil {
		t.Fatalf("expected error when AccessControlAllowMethods is provided without AccessControlAllowOrigin")
	}

	// Origins and methods together must be accepted.
	complete := Config{
		AccessControlAllowOrigin:  origins,
		AccessControlAllowMethods: methods,
	}
	if _, err := NewServer(complete); err != nil {
		t.Fatalf("error creating server with complete CORS configuration: %v", err)
	}
}
// TestPostRequestContentTypeValidation sends a POST with a given Content-Type
// to a handler registered on either the API mux or the browser mux and checks
// whether the server answers with 400 Bad Request.
//
// Both directions are asserted: cases that expect a rejection must see a 400,
// and cases that expect acceptance must not see a 400.
func TestPostRequestContentTypeValidation(t *testing.T) {
	tests := []struct {
		name         string
		browserRoute bool
		contentType  string
		wantErr      bool
	}{
		{
			name:         "API routes should accept `application/json` content-type",
			browserRoute: false,
			contentType:  "application/json",
			wantErr:      false,
		},
		{
			name:         "API routes should reject `application/x-www-form-urlencoded` content-type",
			browserRoute: false,
			contentType:  "application/x-www-form-urlencoded",
			wantErr:      true,
		},
		{
			name:         "Browser routes should accept `application/x-www-form-urlencoded` content-type",
			browserRoute: true,
			contentType:  "application/x-www-form-urlencoded",
			wantErr:      false,
		},
		{
			name:         "Browser routes should accept `application/json` content-type",
			browserRoute: true,
			contentType:  "application/json",
			wantErr:      false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &http.ServeMux{}
			h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			}))

			var s *Server
			var err error
			if tt.browserRoute {
				s, err = NewServer(Config{BrowserMux: h})
			} else {
				s, err = NewServer(Config{APIMux: h})
			}
			if err != nil {
				t.Fatal(err)
			}
			defer s.Close()

			req := httptest.NewRequest("POST", "/", nil)
			req.Header.Set("Content-Type", tt.contentType)
			w := httptest.NewRecorder()
			s.h.Handler.ServeHTTP(w, req)
			resp := w.Result()

			// The request carries no Sec-Fetch-Site header, so a browser-route
			// POST may be refused for reasons unrelated to its content type
			// (TestCSRFProtection expects 403 for that situation). Only a 400
			// is therefore treated as a content-type rejection here.
			rejected := resp.StatusCode == http.StatusBadRequest
			switch {
			case tt.wantErr && !rejected:
				t.Fatalf("content type validation failed: got %v; want %v", resp.StatusCode, http.StatusBadRequest)
			case !tt.wantErr && rejected:
				t.Fatalf("content type validation failed: got %v; want any status other than %v", resp.StatusCode, http.StatusBadRequest)
			}
		})
	}
}
// TestAPIMuxCrossOriginResourceSharingHeaders checks when the
// Access-Control-Allow-Origin response header is present on requests routed
// to the API mux: it is expected only for OPTIONS requests on a server that
// was configured with CORS origins and methods.
func TestAPIMuxCrossOriginResourceSharingHeaders(t *testing.T) {
	tests := []struct {
		name            string
		httpMethod      string
		wantCORSHeaders bool
		corsOrigins     []string
		corsMethods     []string
	}{
		{
			name:            "do not set CORS headers for non-OPTIONS requests",
			corsOrigins:     []string{"https://foobar.com"},
			corsMethods:     []string{"GET", "POST", "HEAD"},
			httpMethod:      "GET",
			wantCORSHeaders: false,
		},
		{
			name:            "set CORS headers for OPTIONS requests",
			corsOrigins:     []string{"https://foobar.com"},
			corsMethods:     []string{"GET", "POST", "HEAD"},
			httpMethod:      "OPTIONS",
			wantCORSHeaders: true,
		},
		{
			name:            "do not serve CORS headers for OPTIONS requests with no configured origins",
			httpMethod:      "OPTIONS",
			wantCORSHeaders: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &http.ServeMux{}
			h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			}))
			s, err := NewServer(Config{
				APIMux:                    h,
				AccessControlAllowOrigin:  tt.corsOrigins,
				AccessControlAllowMethods: tt.corsMethods,
			})
			if err != nil {
				t.Fatal(err)
			}
			defer s.Close()

			req := httptest.NewRequest(tt.httpMethod, "/", nil)
			w := httptest.NewRecorder()
			s.h.Handler.ServeHTTP(w, req)
			resp := w.Result()

			// Only the presence of the header is checked, not its value.
			got := resp.Header.Get("Access-Control-Allow-Origin")
			if hasHeader := got != ""; hasHeader != tt.wantCORSHeaders {
				t.Fatalf("access-control-allow-origin want: %v; got: %v", tt.wantCORSHeaders, got)
			}
		})
	}
}
// TestCSRFProtection checks the status code returned for requests carrying
// different Sec-Fetch-Site header values, on both browser and API routes.
//
// Each case names the HTTP method, which mux the handler is registered on,
// the Sec-Fetch-Site value to send (empty means the header is omitted), and
// whether AllowSecFetchSiteSameSite is switched on for the server.
func TestCSRFProtection(t *testing.T) {
	tests := []struct {
		name           string
		httpMethod     string
		apiRoute       bool
		secFetchSite   string // value of the Sec-Fetch-Site header; "" sends no header
		permitSameSite bool
		wantStatus     int
	}{
		{
			name:       "GET requests to browser routes do not require Sec-Fetch-Site header",
			httpMethod: http.MethodGet,
			apiRoute:   false,
			wantStatus: http.StatusOK,
		},
		{
			name:       "POST requests to browser routes require Sec-Fetch-Site=same-origin and fail if not provided",
			httpMethod: http.MethodPost,
			apiRoute:   false,
			wantStatus: http.StatusForbidden,
		},
		{
			name:         "POST requests to browser routes require Sec-Fetch-Site=same-origin and pass if provided",
			secFetchSite: "same-origin",
			httpMethod:   http.MethodPost,
			apiRoute:     false,
			wantStatus:   http.StatusOK,
		},
		{
			name:         "POST requests to browser routes with Sec-Fetch-Site=none fail",
			secFetchSite: "none",
			httpMethod:   http.MethodPost,
			apiRoute:     false,
			wantStatus:   http.StatusForbidden,
		},
		{
			name:         "POST requests to browser routes with Sec-Fetch-Site=same-site fail by default",
			secFetchSite: "same-site",
			httpMethod:   http.MethodPost,
			apiRoute:     false,
			wantStatus:   http.StatusForbidden,
		},
		{
			name:           "POST requests to browser routes with Sec-Fetch-Site=same-site pass if configured",
			secFetchSite:   "same-site",
			permitSameSite: true,
			httpMethod:     http.MethodPost,
			apiRoute:       false,
			wantStatus:     http.StatusOK,
		},
		{
			// The cross-site request is the one the check exists to stop;
			// previously no case exercised it.
			name:         "POST requests to browser routes with Sec-Fetch-Site=cross-site fail",
			secFetchSite: "cross-site",
			httpMethod:   http.MethodPost,
			apiRoute:     false,
			wantStatus:   http.StatusForbidden,
		},
		{
			name:       "POST requests to API routes do not require Sec-Fetch-Site header",
			httpMethod: http.MethodPost,
			apiRoute:   true,
			wantStatus: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := &http.ServeMux{}
			h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			}))

			var s *Server
			var err error
			if tt.apiRoute {
				s, err = NewServer(Config{APIMux: h})
			} else {
				s, err = NewServer(Config{BrowserMux: h})
			}
			if err != nil {
				t.Fatal(err)
			}
			defer s.Close()

			// construct the test request
			req := httptest.NewRequest(tt.httpMethod, "/", nil)

			// send JSON for API routes, form data for browser routes
			if tt.apiRoute {
				req.Header.Set("Content-Type", "application/json")
			} else {
				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
			}

			// The same-site allowance is toggled on the already-constructed
			// server, before the request is served.
			if tt.permitSameSite {
				s.Config.AllowSecFetchSiteSameSite = true
			}
			if tt.secFetchSite != "" {
				req.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
			}

			w := httptest.NewRecorder()
			s.h.Handler.ServeHTTP(w, req)
			resp := w.Result()
			if resp.StatusCode != tt.wantStatus {
				t.Fatalf("csrf protection check failed: got %v; want %v", resp.StatusCode, tt.wantStatus)
			}
		})
	}
}
// TestContentSecurityPolicyHeader compares the Content-Security-Policy
// response header against an expected string for a server built with the
// zero-value CSP, for one built with a custom CSP, and for a handler
// registered on the API mux (where no header is expected).
func TestContentSecurityPolicyHeader(t *testing.T) {
	type testCase struct {
		name     string
		csp      CSP
		apiRoute bool
		wantCSP  string
	}
	cases := []testCase{
		{
			name:    "default CSP",
			wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`,
		},
		{
			name: "custom CSP",
			csp: CSP{
				"default-src":               {"'self'", "https://tailscale.com"},
				"upgrade-insecure-requests": nil,
			},
			wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`,
		},
		{
			name:     "`/api/*` routes do not get CSP headers",
			apiRoute: true,
			wantCSP:  "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mux := &http.ServeMux{}
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			})

			// Attach the mux on the side of the server under test.
			cfg := Config{CSP: tc.csp}
			if tc.apiRoute {
				cfg.APIMux = mux
			} else {
				cfg.BrowserMux = mux
			}
			srv, err := NewServer(cfg)
			if err != nil {
				t.Fatal(err)
			}
			defer srv.Close()

			rec := httptest.NewRecorder()
			srv.h.Handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))

			got := rec.Result().Header.Get("Content-Security-Policy")
			if got != tc.wantCSP {
				t.Fatalf("content security policy want: %q; got: %q", tc.wantCSP, got)
			}
		})
	}
}
func TestRefererPolicy(t *testing.T) {
tests := []struct {
name string
browserRoute bool
wantRefererPolicy bool
}{
{
name: "BrowserMux routes get Referer-Policy headers",
browserRoute: true,
wantRefererPolicy: true,
},
{
name: "APIMux routes do not get Referer-Policy headers",
browserRoute: false,
wantRefererPolicy: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &http.ServeMux{}
h.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}))
var s *Server
var err error
if tt.browserRoute {
s, err = NewServer(Config{BrowserMux: h})
} else {
s, err = NewServer(Config{APIMux: h})
}
if err != nil {
t.Fatal(err)
}
defer s.Close()
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
s.h.Handler.ServeHTTP(w, req)
resp := w.Result()
if (resp.Header.Get("Referer-Policy") == "") == tt.wantRefererPolicy {
t.Fatalf("referer policy want: %v; got: %v", tt.wantRefererPolicy, resp.Header.Get("Referer-Policy"))
}
})
}
}
// TestCSPAllowInlineStyles builds a browser-mux server with
// CSPAllowInlineStyles set to false and then true, and checks that the
// Content-Security-Policy header contains "style-src 'self' 'unsafe-inline'"
// exactly when the option is enabled.
func TestCSPAllowInlineStyles(t *testing.T) {
	const inlineDirective = "style-src 'self' 'unsafe-inline'"

	check := func(allow bool) func(*testing.T) {
		return func(t *testing.T) {
			mux := &http.ServeMux{}
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			})
			srv, err := NewServer(Config{BrowserMux: mux, CSPAllowInlineStyles: allow})
			if err != nil {
				t.Fatal(err)
			}
			defer srv.Close()

			rec := httptest.NewRecorder()
			srv.h.Handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))

			csp := rec.Result().Header.Get("Content-Security-Policy")
			if allowsStyles := strings.Contains(csp, inlineDirective); allowsStyles != allow {
				t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp)
			}
		}
	}

	// Subtests are named "false" and "true" after the option value.
	for _, allow := range []bool{false, true} {
		t.Run(strconv.FormatBool(allow), check(allow))
	}
}
// TestRouting registers handlers that reply "browser" or "api" on the browser
// and API muxes respectively, issues a GET for a path, and checks which body
// comes back (or the 404 text when neither mux has a matching pattern).
func TestRouting(t *testing.T) {
	type routeCase struct {
		desc            string
		browserPatterns []string
		apiPatterns     []string
		requestPath     string
		want            string
	}
	cases := []routeCase{
		{
			desc:            "only browser mux",
			browserPatterns: []string{"/"},
			requestPath:     "/index.html",
			want:            "browser",
		},
		{
			desc:        "only API mux",
			apiPatterns: []string{"/api/"},
			requestPath: "/api/foo",
			want:        "api",
		},
		{
			desc:            "browser mux match",
			browserPatterns: []string{"/content/"},
			apiPatterns:     []string{"/api/"},
			requestPath:     "/content/index.html",
			want:            "browser",
		},
		{
			desc:            "API mux match",
			browserPatterns: []string{"/content/"},
			apiPatterns:     []string{"/api/"},
			requestPath:     "/api/foo",
			want:            "api",
		},
		{
			desc:            "browser wildcard match",
			browserPatterns: []string{"/"},
			apiPatterns:     []string{"/api/"},
			requestPath:     "/index.html",
			want:            "browser",
		},
		{
			desc:            "API wildcard match",
			browserPatterns: []string{"/content/"},
			apiPatterns:     []string{"/"},
			requestPath:     "/api/foo",
			want:            "api",
		},
		{
			desc:            "path conflict",
			browserPatterns: []string{"/foo/"},
			apiPatterns:     []string{"/foo/bar/"},
			requestPath:     "/foo/bar/baz",
			want:            "api",
		},
		{
			desc:            "no match",
			browserPatterns: []string{"/foo/"},
			apiPatterns:     []string{"/bar/"},
			requestPath:     "/baz",
			want:            "404 page not found",
		},
	}

	// muxReplying returns a mux on which every given pattern answers with body.
	muxReplying := func(patterns []string, body string) *http.ServeMux {
		mux := &http.ServeMux{}
		for _, pattern := range patterns {
			mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte(body))
			})
		}
		return mux
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			srv, err := NewServer(Config{
				BrowserMux: muxReplying(tc.browserPatterns, "browser"),
				APIMux:     muxReplying(tc.apiPatterns, "api"),
			})
			if err != nil {
				t.Fatal(err)
			}
			defer srv.Close()

			rec := httptest.NewRecorder()
			srv.h.Handler.ServeHTTP(rec, httptest.NewRequest("GET", tc.requestPath, nil))

			body, err := io.ReadAll(rec.Result().Body)
			if err != nil {
				t.Fatal(err)
			}
			// Surrounding whitespace is trimmed before comparing.
			got := strings.TrimSpace(string(body))
			if got != tc.want {
				t.Errorf("got response %q, want %q", got, tc.want)
			}
		})
	}
}
// TestGetMoreSpecificPattern feeds pairs of patterns to checkHandlerType and
// compares the returned handlerType with the expected one.
//
// NOTE(review): judging by the expected values, a is presumably the API
// pattern and b the browser pattern — confirm against checkHandlerType.
func TestGetMoreSpecificPattern(t *testing.T) {
	for _, tt := range []struct {
		desc string
		a    string
		b    string
		want handlerType
	}{
		{
			desc: "identical",
			a:    "/foo/bar",
			b:    "/foo/bar",
			want: unknownHandler,
		},
		{
			desc: "identical prefix",
			a:    "/foo/bar/",
			b:    "/foo/bar/",
			want: unknownHandler,
		},
		{
			desc: "trailing slash",
			a:    "/foo",
			b:    "/foo/", // path.Clean will strip the trailing slash.
			want: unknownHandler,
		},
		{
			desc: "same prefix",
			a:    "/foo/bar/quux",
			b:    "/foo/bar/", // path.Clean will strip the trailing slash.
			want: apiHandler,
		},
		{
			desc: "almost same prefix, but not a path component",
			a:    "/goat/sheep/cheese",
			b:    "/goat/sheepcheese/", // path.Clean will strip the trailing slash.
			want: apiHandler,
		},
		{
			desc: "attempt to make less-specific pattern look more specific",
			a:    "/goat/cat/buddy",
			b:    "/goat/../../../../../../../cat", // path.Clean catches this foolishness
			want: apiHandler,
		},
		{
			desc: "2 names for / (1)",
			a:    "/",
			b:    "/../../../../../../",
			want: unknownHandler,
		},
		{
			desc: "2 names for / (2)",
			a:    "/",
			b:    "///////",
			want: unknownHandler,
		},
		{
			desc: "root-level",
			a:    "/latest",
			b:    "/", // path.Clean will NOT strip the trailing slash.
			want: apiHandler,
		},
	} {
		t.Run(tt.desc, func(t *testing.T) {
			got := checkHandlerType(tt.a, tt.b)
			if got != tt.want {
				// %v rather than %q: handlerType is declared outside this
				// file, and %q would render an integer-based value as a rune
				// literal instead of something readable.
				t.Errorf("checkHandlerType(%q, %q) = %v, want %v", tt.a, tt.b, got, tt.want)
			}
		})
	}
}
// TestStrictTransportSecurityOptions checks the Strict-Transport-Security
// response header: empty when ServeHSTS is off, equal to
// DefaultStrictTransportSecurityOptions when ServeHSTS is on with no custom
// options, and equal to the custom options string when one is supplied.
func TestStrictTransportSecurityOptions(t *testing.T) {
	customOptions := DefaultStrictTransportSecurityOptions + "; includeSubDomains"

	cases := []struct {
		name          string
		options       string
		secureContext bool
		expect        string
	}{
		{
			name: "off by default",
		},
		{
			name:          "default HSTS options in the secure context",
			secureContext: true,
			expect:        DefaultStrictTransportSecurityOptions,
		},
		{
			name:          "custom options sent in the secure context",
			options:       customOptions,
			secureContext: true,
			expect:        customOptions,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mux := &http.ServeMux{}
			mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
				w.Write([]byte("ok"))
			})
			srv, err := NewServer(Config{
				BrowserMux:                     mux,
				ServeHSTS:                      tc.secureContext,
				StrictTransportSecurityOptions: tc.options,
			})
			if err != nil {
				t.Fatal(err)
			}
			defer srv.Close()

			rec := httptest.NewRecorder()
			srv.h.Handler.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))

			got := rec.Result().Header.Get("Strict-Transport-Security")
			if diff := cmp.Diff(tc.expect, got); diff != "" {
				t.Fatalf("HSTS want: %q; got: %q", tc.expect, got)
			}
		})
	}
}
// TestOverrideHTTPServer checks that a server built from an empty Config has
// a zero IdleTimeout on its underlying HTTP server, and that supplying
// Config.HTTPServer makes the underlying server report that override's
// IdleTimeout.
func TestOverrideHTTPServer(t *testing.T) {
	base, err := NewServer(Config{})
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	// Close each server once done, as the other tests in this file do.
	defer base.Close()
	if base.h.IdleTimeout != 0 {
		t.Fatalf("got %v; want 0", base.h.IdleTimeout)
	}

	override := &http.Server{
		IdleTimeout: 10 * time.Second,
	}
	s, err := NewServer(Config{HTTPServer: override})
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	defer s.Close()
	if s.h.IdleTimeout != override.IdleTimeout {
		t.Fatalf("got %v; want %v", s.h.IdleTimeout, override.IdleTimeout)
	}
}