From f5522e62d1dde2ea966f2454df248a8ea2d43676 Mon Sep 17 00:00:00 2001 From: Patrick O'Doherty Date: Thu, 27 Feb 2025 11:58:45 -0800 Subject: [PATCH] client/web: fix CSRF handler order in web UI (#15143) Fix the order of the CSRF handlers (HTTP plaintext context setting, _then_ enforcement) in the construction of the web UI server. This resolves false-positive "invalid Origin" 403 exceptions when attempting to update settings in the web UI. Add unit test to exercise the CSRF protection failure and success cases for our web UI configuration. Updates #14822 Updates #14872 Signed-off-by: Patrick O'Doherty --- client/web/web.go | 65 ++++++++++++++++++--------------- client/web/web_test.go | 82 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 29 deletions(-) diff --git a/client/web/web.go b/client/web/web.go index 6203b4c18..e9810ccd0 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -203,35 +203,9 @@ func NewServer(opts ServerOpts) (s *Server, err error) { } s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) - var metric string // clientmetric to report on startup - - // Create handler for "/api" requests with CSRF protection. - // We don't require secure cookies, since the web client is regularly used - // on network appliances that are served on local non-https URLs. - // The client is secured by limiting the interface it listens on, - // or by authenticating requests before they reach the web client. - csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) - - // signal to the CSRF middleware that the request is being served over - // plaintext HTTP to skip TLS-only header checks. - withSetPlaintext := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r = csrf.PlaintextHTTPRequest(r) - h.ServeHTTP(w, r) - }) - } - - switch s.mode { - case LoginServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_login_client_initialization" - case ReadOnlyServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI))) - metric = "web_readonly_client_initialization" - case ManageServerMode: - s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI))) - metric = "web_client_initialization" - } + var metric string + s.apiHandler, metric = s.modeAPIHandler(s.mode) + s.apiHandler = s.withCSRF(s.apiHandler) // Don't block startup on reporting metric. // Report in separate go routine with 5 second timeout. @@ -244,6 +218,39 @@ func NewServer(opts ServerOpts) (s *Server, err error) { return s, nil } +func (s *Server) withCSRF(h http.Handler) http.Handler { + csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) + + // ref https://github.com/tailscale/tailscale/pull/14822 + // signal to the CSRF middleware that the request is being served over + // plaintext HTTP to skip TLS-only header checks. + withSetPlaintext := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = csrf.PlaintextHTTPRequest(r) + h.ServeHTTP(w, r) + }) + } + + // NB: the order of the withSetPlaintext and csrfProtect calls is important + // to ensure that we signal to the CSRF middleware that the request is being + // served over plaintext HTTP and not over TLS as it presumes by default. + return withSetPlaintext(csrfProtect(h)) +} + +func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) { + switch mode { + case LoginServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization" + case ReadOnlyServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization" + case ManageServerMode: + return http.HandlerFunc(s.serveAPI), "web_client_initialization" + default: // invalid mode + log.Fatalf("invalid mode: %v", mode) + } + return nil, "" +} + func (s *Server) Shutdown() { s.logf("web.Server: shutting down") if s.assetsCleanup != nil { diff --git a/client/web/web_test.go b/client/web/web_test.go index b9242f6ac..291356260 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/netip" "net/url" @@ -20,6 +21,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/gorilla/csrf" "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" @@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg return nil, errors.New("unknown id") } } + +func TestCSRFProtect(t *testing.T) { + s := &Server{} + + mux := http.NewServeMux() + mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) { + token := csrf.Token(r) + _, err := io.WriteString(w, token) + if err != nil { + t.Fatal(err) + } + }) + mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, "ok") + if err != nil { + t.Fatal(err) + } + }) + h := s.withCSRF(mux) + ser := httptest.NewServer(h) + defer ser.Close() + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("unable to construct cookie jar: %v", err) + } + + client := ser.Client() + client.Jar = jar + + // make GET request to populate cookie jar + resp, err := client.Get(ser.URL + "/test/csrf-token") + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + + csrfToken := strings.TrimSpace(string(tokenBytes)) + if csrfToken == "" { + t.Fatal("empty csrf token") + } + + // make a POST request without the CSRF header; ensure it fails + resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("unexpected status: %v", resp.Status) + } + + // make a POST request with the CSRF header; ensure it succeeds + req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil) + if err != nil { + t.Fatalf("error building request: %v", err) + } + req.Header.Set("X-CSRF-Token", csrfToken) + resp, err = client.Do(req) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + if string(out) != "ok" { + t.Fatalf("unexpected body: %q", out) + } +}