safeweb: allow passing http.Server in safeweb.Config (#13688)

Extend safeweb.Config with the ability to pass a http.Server that
safeweb will use to server traffic.

Updates corp#8207

Signed-off-by: Patrick O'Doherty <patrick@tailscale.com>
This commit is contained in:
Patrick O'Doherty 2024-10-04 19:57:00 +01:00 committed by GitHub
parent 8fdffb8da0
commit 4ad3f01225
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 1 deletions

View File

@ -144,6 +144,12 @@ type Config struct {
// BrowserMux when SecureContext is true.
// If empty, it defaults to max-age of 1 year.
StrictTransportSecurityOptions string
// HTTPServer, if specified, is the underlying http.Server that safeweb will
// use to serve requests. If nil, a new http.Server will be created.
// Do not use the Handler field of http.Server, as it will be ignored.
// Instead, set your handlers using APIMux and BrowserMux.
HTTPServer *http.Server
}
func (c *Config) setDefaults() error {
@ -203,7 +209,11 @@ func NewServer(config Config) (*Server, error) {
if config.CSPAllowInlineStyles {
s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'`
}
s.h = &http.Server{Handler: s}
s.h = cmp.Or(config.HTTPServer, &http.Server{})
if s.h.Handler != nil {
return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler")
}
s.h.Handler = s
return s, nil
}

View File

@ -10,6 +10,7 @@
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/csrf"
@ -609,3 +610,26 @@ func TestStrictTransportSecurityOptions(t *testing.T) {
})
}
}
func TestOverrideHTTPServer(t *testing.T) {
s, err := NewServer(Config{})
if err != nil {
t.Fatalf("NewServer: %v", err)
}
if s.h.IdleTimeout != 0 {
t.Fatalf("got %v; want 0", s.h.IdleTimeout)
}
c := http.Server{
IdleTimeout: 10 * time.Second,
}
s, err = NewServer(Config{HTTPServer: &c})
if err != nil {
t.Fatalf("NewServer: %v", err)
}
if s.h.IdleTimeout != c.IdleTimeout {
t.Fatalf("got %v; want %v", s.h.IdleTimeout, c.IdleTimeout)
}
}