client/web: cache csrf key when running in CGI mode

Indicate to the web client when it is running in CGI mode, and if it is
then cache the csrf key between requests.

Updates tailscale/corp#13775

Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
Will Norris 2023-08-23 15:22:24 -07:00 committed by Will Norris
parent 46b0c9168f
commit 824cd02d6d
3 changed files with 63 additions and 15 deletions

View File

@ -22,6 +22,7 @@
"net/url" "net/url"
"os" "os"
"os/exec" "os/exec"
"path/filepath"
"strings" "strings"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
@ -57,20 +58,31 @@ type Server struct {
devMode bool devMode bool
devProxy *httputil.ReverseProxy // only filled when devMode is on devProxy *httputil.ReverseProxy // only filled when devMode is on
cgiMode bool
apiHandler http.Handler // csrf-protected api handler apiHandler http.Handler // csrf-protected api handler
} }
// ServerOpts contains options for constructing a new Server.
type ServerOpts struct {
DevMode bool
// CGIMode indicates if the server is running as a CGI script.
CGIMode bool
// LocalClient is the tailscale.LocalClient to use for this web server.
// If nil, a new one will be created.
LocalClient *tailscale.LocalClient
}
// NewServer constructs a new Tailscale web client server. // NewServer constructs a new Tailscale web client server.
// func NewServer(opts ServerOpts) (s *Server, cleanup func()) {
// lc is an optional parameter. When not filled, NewServer if opts.LocalClient == nil {
// initializes its own tailscale.LocalClient. opts.LocalClient = &tailscale.LocalClient{}
func NewServer(devMode bool, lc *tailscale.LocalClient) (s *Server, cleanup func()) {
if lc == nil {
lc = &tailscale.LocalClient{}
} }
s = &Server{ s = &Server{
devMode: devMode, devMode: opts.DevMode,
lc: lc, lc: opts.LocalClient,
cgiMode: opts.CGIMode,
} }
cleanup = func() {} cleanup = func() {}
if s.devMode { if s.devMode {
@ -82,7 +94,7 @@ func NewServer(devMode bool, lc *tailscale.LocalClient) (s *Server, cleanup func
// on network appliances that are served on local non-https URLs. // on network appliances that are served on local non-https URLs.
// The client is secured by limiting the interface it listens on, // The client is secured by limiting the interface it listens on,
// or by authenticating requests before they reach the web client. // or by authenticating requests before they reach the web client.
csrfProtect := csrf.Protect(csrfKey(), csrf.Secure(false)) csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false))
s.apiHandler = csrfProtect(&api{s: s}) s.apiHandler = csrfProtect(&api{s: s})
} }
s.lc.IncrementCounter(context.Background(), "web_client_initialization", 1) s.lc.IncrementCounter(context.Background(), "web_client_initialization", 1)
@ -530,13 +542,42 @@ func (s *Server) tailscaleUp(ctx context.Context, st *ipnstate.Status, postData
} }
} }
// csrfKey creates a new random csrf token. // csrfKey returns a key that can be used for CSRF protection.
// If an error surfaces during key creation, // If an error occurs during key creation, the error is logged and the active process terminated.
// the error is logged and the active process terminated. // If the server is running in CGI mode, the key is cached to disk and reused between requests.
func csrfKey() []byte { // If an error occurs during key storage, the error is logged and the active process terminated.
func (s *Server) csrfKey() []byte {
var csrfFile string
// if running in CGI mode, try to read from disk, but ignore errors
if s.cgiMode {
confdir, err := os.UserConfigDir()
if err != nil {
confdir = os.TempDir()
}
csrfFile = filepath.Join(confdir, "tailscale", "web-csrf.key")
key, _ := os.ReadFile(csrfFile)
if len(key) == 32 {
return key
}
}
// create a new key
key := make([]byte, 32) key := make([]byte, 32)
if _, err := rand.Read(key); err != nil { if _, err := rand.Read(key); err != nil {
log.Fatal("error generating CSRF key: %w", err) log.Fatal("error generating CSRF key: %w", err)
} }
// if running in CGI mode, try to write the newly created key to disk, and exit if it fails.
if s.cgiMode {
if err := os.Mkdir(filepath.Dir(csrfFile), 0700); err != nil && !os.IsExist(err) {
log.Fatalf("unable to store CSRF key: %v", err)
}
if err := os.WriteFile(csrfFile, key, 0600); err != nil {
log.Fatalf("unable to store CSRF key: %v", err)
}
}
return key return key
} }

View File

@ -78,7 +78,11 @@ func runWeb(ctx context.Context, args []string) error {
return fmt.Errorf("too many non-flag arguments: %q", args) return fmt.Errorf("too many non-flag arguments: %q", args)
} }
webServer, cleanup := web.NewServer(webArgs.dev, &localClient) webServer, cleanup := web.NewServer(web.ServerOpts{
DevMode: webArgs.dev,
CGIMode: webArgs.cgi,
LocalClient: &localClient,
})
defer cleanup() defer cleanup()
if webArgs.cgi { if webArgs.cgi {

View File

@ -30,7 +30,10 @@ func main() {
} }
// Serve the Tailscale web client. // Serve the Tailscale web client.
ws, cleanup := web.NewServer(*devMode, lc) ws, cleanup := web.NewServer(web.ServerOpts{
DevMode: *devMode,
LocalClient: lc,
})
defer cleanup() defer cleanup()
log.Printf("Serving Tailscale web client on http://%s", *addr) log.Printf("Serving Tailscale web client on http://%s", *addr)
if err := http.ListenAndServe(*addr, ws); err != nil { if err := http.ListenAndServe(*addr, ws); err != nil {