client/web: cache csrf key when running in CGI mode

Indicate to the web client when it is running in CGI mode, and if it is
then cache the csrf key between requests.

Updates tailscale/corp#13775

Signed-off-by: Will Norris <will@tailscale.com>
This commit is contained in:
Will Norris 2023-08-23 15:22:24 -07:00 committed by Will Norris
parent 46b0c9168f
commit 824cd02d6d
3 changed files with 63 additions and 15 deletions

View File

@ -22,6 +22,7 @@
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/gorilla/csrf"
@ -57,20 +58,31 @@ type Server struct {
devMode bool
devProxy *httputil.ReverseProxy // only filled when devMode is on
cgiMode bool
apiHandler http.Handler // csrf-protected api handler
}
// ServerOpts contains options for constructing a new Server.
type ServerOpts struct {
DevMode bool
// CGIMode indicates if the server is running as a CGI script.
CGIMode bool
// LocalClient is the tailscale.LocalClient to use for this web server.
// If nil, a new one will be created.
LocalClient *tailscale.LocalClient
}
// NewServer constructs a new Tailscale web client server.
//
// lc is an optional parameter. When not filled, NewServer
// initializes its own tailscale.LocalClient.
func NewServer(devMode bool, lc *tailscale.LocalClient) (s *Server, cleanup func()) {
if lc == nil {
lc = &tailscale.LocalClient{}
func NewServer(opts ServerOpts) (s *Server, cleanup func()) {
if opts.LocalClient == nil {
opts.LocalClient = &tailscale.LocalClient{}
}
s = &Server{
devMode: devMode,
lc: lc,
devMode: opts.DevMode,
lc: opts.LocalClient,
cgiMode: opts.CGIMode,
}
cleanup = func() {}
if s.devMode {
@ -82,7 +94,7 @@ func NewServer(devMode bool, lc *tailscale.LocalClient) (s *Server, cleanup func
// on network appliances that are served on local non-https URLs.
// The client is secured by limiting the interface it listens on,
// or by authenticating requests before they reach the web client.
csrfProtect := csrf.Protect(csrfKey(), csrf.Secure(false))
csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false))
s.apiHandler = csrfProtect(&api{s: s})
}
s.lc.IncrementCounter(context.Background(), "web_client_initialization", 1)
@ -530,13 +542,42 @@ func (s *Server) tailscaleUp(ctx context.Context, st *ipnstate.Status, postData
}
}
// csrfKey creates a new random csrf token.
// If an error surfaces during key creation,
// the error is logged and the active process terminated.
func csrfKey() []byte {
// csrfKey returns a key that can be used for CSRF protection.
// If an error occurs during key creation, the error is logged and the active process terminated.
// If the server is running in CGI mode, the key is cached to disk and reused between requests.
// If an error occurs during key storage, the error is logged and the active process terminated.
func (s *Server) csrfKey() []byte {
var csrfFile string
// if running in CGI mode, try to read from disk, but ignore errors
if s.cgiMode {
confdir, err := os.UserConfigDir()
if err != nil {
confdir = os.TempDir()
}
csrfFile = filepath.Join(confdir, "tailscale", "web-csrf.key")
key, _ := os.ReadFile(csrfFile)
if len(key) == 32 {
return key
}
}
// create a new key
key := make([]byte, 32)
if _, err := rand.Read(key); err != nil {
log.Fatal("error generating CSRF key: %w", err)
}
// if running in CGI mode, try to write the newly created key to disk, and exit if it fails.
if s.cgiMode {
if err := os.Mkdir(filepath.Dir(csrfFile), 0700); err != nil && !os.IsExist(err) {
log.Fatalf("unable to store CSRF key: %v", err)
}
if err := os.WriteFile(csrfFile, key, 0600); err != nil {
log.Fatalf("unable to store CSRF key: %v", err)
}
}
return key
}

View File

@ -78,7 +78,11 @@ func runWeb(ctx context.Context, args []string) error {
return fmt.Errorf("too many non-flag arguments: %q", args)
}
webServer, cleanup := web.NewServer(webArgs.dev, &localClient)
webServer, cleanup := web.NewServer(web.ServerOpts{
DevMode: webArgs.dev,
CGIMode: webArgs.cgi,
LocalClient: &localClient,
})
defer cleanup()
if webArgs.cgi {

View File

@ -30,7 +30,10 @@ func main() {
}
// Serve the Tailscale web client.
ws, cleanup := web.NewServer(*devMode, lc)
ws, cleanup := web.NewServer(web.ServerOpts{
DevMode: *devMode,
LocalClient: lc,
})
defer cleanup()
log.Printf("Serving Tailscale web client on http://%s", *addr)
if err := http.ListenAndServe(*addr, ws); err != nil {