mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
fix: improvements for login and oidc (#227)
* add csrf * caching * caching * caching * caching * security headers * csp and security headers * error handler csp * select user with display name * csp * user selection styling * username to loginname * regenerate grpc * regenerate * change to login name
This commit is contained in:
@@ -76,6 +76,7 @@ func createMux(ctx context.Context, g Gateway) *runtime.ServeMux {
|
||||
|
||||
func addInterceptors(handler http.Handler, g Gateway) http.Handler {
|
||||
handler = http_mw.DefaultTraceHandler(handler)
|
||||
handler = http_mw.NoCacheInterceptor(handler)
|
||||
if interceptor, ok := g.(grpcGatewayCustomInterceptor); ok {
|
||||
handler = interceptor.GatewayHTTPInterceptor(handler)
|
||||
}
|
||||
|
@@ -4,12 +4,23 @@ const (
|
||||
Authorization = "authorization"
|
||||
Accept = "accept"
|
||||
AcceptLanguage = "accept-language"
|
||||
CacheControl = "cache-control"
|
||||
ContentType = "content-type"
|
||||
Expires = "expires"
|
||||
Location = "location"
|
||||
Origin = "origin"
|
||||
Pragma = "pragma"
|
||||
UserAgent = "user-agent"
|
||||
ForwardedFor = "x-forwarded-for"
|
||||
|
||||
ContentSecurityPolicy = "content-security-policy"
|
||||
XXSSProtection = "x-xss-protection"
|
||||
StrictTransportSecurity = "strict-transport-security"
|
||||
XFrameOptions = "x-frame-options"
|
||||
XContentTypeOptions = "x-content-type-options"
|
||||
ReferrerPolicy = "referrer-policy"
|
||||
FeaturePolicy = "feature-policy"
|
||||
|
||||
ZitadelOrgID = "x-zitadel-orgid"
|
||||
//TODO: Remove as soon an authentification is implemented
|
||||
ZitadelUserID = "x-zitadel-userid"
|
||||
|
128
internal/api/http/middleware/cache_interceptor.go
Normal file
128
internal/api/http/middleware/cache_interceptor.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/api"
|
||||
"github.com/caos/zitadel/internal/config/types"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
Cacheability Cacheability
|
||||
NoCache bool
|
||||
NoStore bool
|
||||
MaxAge time.Duration
|
||||
SharedMaxAge time.Duration
|
||||
NoTransform bool
|
||||
Revalidation Revalidation
|
||||
}
|
||||
|
||||
type Cacheability string
|
||||
|
||||
const (
|
||||
CacheabilityNotSet Cacheability = ""
|
||||
CacheabilityPublic = "public"
|
||||
CacheabilityPrivate = "private"
|
||||
)
|
||||
|
||||
type Revalidation string
|
||||
|
||||
const (
|
||||
RevalidationNotSet Revalidation = ""
|
||||
RevalidationMust = "must-revalidate"
|
||||
RevalidationProxy = "proxy-revalidate"
|
||||
)
|
||||
|
||||
type CacheConfig struct {
|
||||
MaxAge types.Duration
|
||||
SharedMaxAge types.Duration
|
||||
}
|
||||
|
||||
var (
|
||||
NeverCacheOptions = &Cache{
|
||||
NoStore: true,
|
||||
}
|
||||
AssetOptions = func(maxAge, SharedMaxAge time.Duration) *Cache {
|
||||
return &Cache{
|
||||
Cacheability: CacheabilityPublic,
|
||||
MaxAge: maxAge,
|
||||
SharedMaxAge: SharedMaxAge,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
func DefaultCacheInterceptor(pattern string, maxAge, sharedMaxAge time.Duration) (func(http.Handler) http.Handler, error) {
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if regex.MatchString(r.URL.Path) {
|
||||
AssetsCacheInterceptor(maxAge, sharedMaxAge, handler).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
NoCacheInterceptor(handler).ServeHTTP(w, r)
|
||||
})
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NoCacheInterceptor(h http.Handler) http.Handler {
|
||||
return CacheInterceptorOpts(h, NeverCacheOptions)
|
||||
}
|
||||
|
||||
func AssetsCacheInterceptor(maxAge, sharedMaxAge time.Duration, h http.Handler) http.Handler {
|
||||
return CacheInterceptorOpts(h, AssetOptions(maxAge, sharedMaxAge))
|
||||
}
|
||||
|
||||
func CacheInterceptorOpts(h http.Handler, cache *Cache) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
cache.serializeHeaders(w)
|
||||
h.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Cache) serializeHeaders(w http.ResponseWriter) {
|
||||
control := make([]string, 0, 6)
|
||||
pragma := false
|
||||
|
||||
if c.Cacheability != CacheabilityNotSet {
|
||||
control = append(control, string(c.Cacheability))
|
||||
control = append(control, fmt.Sprintf("max-age=%v", c.MaxAge.Seconds()))
|
||||
if c.SharedMaxAge != c.MaxAge {
|
||||
control = append(control, fmt.Sprintf("s-maxage=%v", c.SharedMaxAge.Seconds()))
|
||||
}
|
||||
}
|
||||
maxAge := c.MaxAge
|
||||
if maxAge == 0 {
|
||||
maxAge = -time.Hour
|
||||
}
|
||||
expires := time.Now().UTC().Add(maxAge).Format(http.TimeFormat)
|
||||
|
||||
if c.NoCache {
|
||||
control = append(control, fmt.Sprintf("no-cache"))
|
||||
pragma = true
|
||||
}
|
||||
|
||||
if c.NoStore {
|
||||
control = append(control, fmt.Sprintf("no-store"))
|
||||
pragma = true
|
||||
}
|
||||
if c.NoTransform {
|
||||
control = append(control, fmt.Sprintf("no-transform"))
|
||||
}
|
||||
|
||||
if c.Revalidation != RevalidationNotSet {
|
||||
control = append(control, string(c.Revalidation))
|
||||
}
|
||||
|
||||
w.Header().Set(api.CacheControl, strings.Join(control, ", "))
|
||||
w.Header().Set(api.Expires, expires)
|
||||
if pragma {
|
||||
w.Header().Set(api.Pragma, "no-cache")
|
||||
}
|
||||
}
|
82
internal/api/http/middleware/cache_interceptor_test.go
Normal file
82
internal/api/http/middleware/cache_interceptor_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCache_serializeHeaders(t *testing.T) {
|
||||
type fields struct {
|
||||
Cacheability Cacheability
|
||||
NoCache bool
|
||||
NoStore bool
|
||||
MaxAge time.Duration
|
||||
SharedMaxAge time.Duration
|
||||
NoTransform bool
|
||||
Revalidation Revalidation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantControl string
|
||||
wantExpires string
|
||||
wantPragma string
|
||||
}{
|
||||
{
|
||||
"no-store",
|
||||
fields{
|
||||
NoStore: true,
|
||||
},
|
||||
"no-store",
|
||||
time.Now().UTC().Add(-1 * time.Hour).Format(http.TimeFormat),
|
||||
"no-cache",
|
||||
},
|
||||
{
|
||||
"private and max-age",
|
||||
fields{
|
||||
Cacheability: CacheabilityPrivate,
|
||||
MaxAge: 1 * time.Hour,
|
||||
SharedMaxAge: 1 * time.Hour,
|
||||
},
|
||||
"private, max-age=3600",
|
||||
time.Now().UTC().Add(1 * time.Hour).Format(http.TimeFormat),
|
||||
"",
|
||||
},
|
||||
{
|
||||
"public, no-cache, proxy-revalidate",
|
||||
fields{
|
||||
Cacheability: CacheabilityPublic,
|
||||
NoCache: true,
|
||||
Revalidation: RevalidationProxy,
|
||||
},
|
||||
"public, max-age=0, no-cache, proxy-revalidate",
|
||||
time.Now().UTC().Add(-1 * time.Hour).Format(http.TimeFormat),
|
||||
"no-cache",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c := &Cache{
|
||||
Cacheability: tt.fields.Cacheability,
|
||||
NoCache: tt.fields.NoCache,
|
||||
NoStore: tt.fields.NoStore,
|
||||
MaxAge: tt.fields.MaxAge,
|
||||
SharedMaxAge: tt.fields.SharedMaxAge,
|
||||
NoTransform: tt.fields.NoTransform,
|
||||
Revalidation: tt.fields.Revalidation,
|
||||
}
|
||||
c.serializeHeaders(recorder)
|
||||
cc := recorder.Result().Header.Get("cache-control")
|
||||
assert.Equal(t, tt.wantControl, cc)
|
||||
exp := recorder.Result().Header.Get("expires")
|
||||
assert.Equal(t, tt.wantExpires, exp)
|
||||
pragma := recorder.Result().Header.Get("pragma")
|
||||
assert.Equal(t, tt.wantPragma, pragma)
|
||||
})
|
||||
}
|
||||
}
|
125
internal/api/http/middleware/csp.go
Normal file
125
internal/api/http/middleware/csp.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CSP struct {
|
||||
DefaultSrc CSPSourceOptions
|
||||
ScriptSrc CSPSourceOptions
|
||||
ObjectSrc CSPSourceOptions
|
||||
StyleSrc CSPSourceOptions
|
||||
ImgSrc CSPSourceOptions
|
||||
MediaSrc CSPSourceOptions
|
||||
FrameSrc CSPSourceOptions
|
||||
FontSrc CSPSourceOptions
|
||||
ConnectSrc CSPSourceOptions
|
||||
FormAction CSPSourceOptions
|
||||
}
|
||||
|
||||
var (
|
||||
DefaultSCP = CSP{
|
||||
DefaultSrc: CSPSourceOptsNone(),
|
||||
ScriptSrc: CSPSourceOptsSelf(),
|
||||
ObjectSrc: CSPSourceOptsNone(),
|
||||
StyleSrc: CSPSourceOptsSelf(),
|
||||
ImgSrc: CSPSourceOptsSelf(),
|
||||
MediaSrc: CSPSourceOptsNone(),
|
||||
FrameSrc: CSPSourceOptsNone(),
|
||||
FontSrc: CSPSourceOptsSelf(),
|
||||
ConnectSrc: CSPSourceOptsSelf(),
|
||||
}
|
||||
)
|
||||
|
||||
func (csp *CSP) Value(nonce string) string {
|
||||
valuesMap := csp.asMap()
|
||||
|
||||
values := make([]string, 0, len(valuesMap))
|
||||
for k, v := range valuesMap {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
values = append(values, fmt.Sprintf("%v %v", k, v.String(nonce)))
|
||||
}
|
||||
|
||||
return strings.Join(values, ";")
|
||||
}
|
||||
|
||||
func (csp *CSP) asMap() map[string]CSPSourceOptions {
|
||||
return map[string]CSPSourceOptions{
|
||||
"default-src": csp.DefaultSrc,
|
||||
"script-src": csp.ScriptSrc,
|
||||
"object-src": csp.ObjectSrc,
|
||||
"style-src": csp.StyleSrc,
|
||||
"img-src": csp.ImgSrc,
|
||||
"media-src": csp.MediaSrc,
|
||||
"frame-src": csp.FrameSrc,
|
||||
"font-src": csp.FontSrc,
|
||||
"connect-src": csp.ConnectSrc,
|
||||
"form-action": csp.FormAction,
|
||||
}
|
||||
}
|
||||
|
||||
type CSPSourceOptions []string
|
||||
|
||||
func CSPSourceOpts() CSPSourceOptions {
|
||||
return CSPSourceOptions{}
|
||||
}
|
||||
|
||||
func CSPSourceOptsNone() CSPSourceOptions {
|
||||
return []string{"'none'"}
|
||||
}
|
||||
|
||||
func CSPSourceOptsSelf() CSPSourceOptions {
|
||||
return []string{"'self'"}
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddSelf() CSPSourceOptions {
|
||||
return append(srcOpts, "'self'")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddInline() CSPSourceOptions {
|
||||
return append(srcOpts, "'unsafe-inline'")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddEval() CSPSourceOptions {
|
||||
return append(srcOpts, "'unsafe-eval'")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddStrictDynamic() CSPSourceOptions {
|
||||
return append(srcOpts, "'strict-dynamic'")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddHost(h ...string) CSPSourceOptions {
|
||||
return append(srcOpts, h...)
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddScheme(s ...string) CSPSourceOptions {
|
||||
return srcOpts.add(s, "%v:")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddNonce() CSPSourceOptions {
|
||||
return append(srcOpts, "'nonce-%v'")
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) AddHash(alg, b64v string) CSPSourceOptions {
|
||||
return append(srcOpts, fmt.Sprintf("'%v-%v'", alg, b64v))
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) String(nonce string) string {
|
||||
value := strings.Join(srcOpts, " ")
|
||||
if !strings.Contains(value, "%v") {
|
||||
return value
|
||||
}
|
||||
return fmt.Sprintf(value, nonce)
|
||||
}
|
||||
|
||||
func (srcOpts CSPSourceOptions) add(values []string, format string) CSPSourceOptions {
|
||||
for i, v := range values {
|
||||
values[i] = fmt.Sprintf(format, v)
|
||||
}
|
||||
|
||||
return append(srcOpts, values...)
|
||||
}
|
90
internal/api/http/middleware/security_headers.go
Normal file
90
internal/api/http/middleware/security_headers.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"github.com/caos/zitadel/internal/api"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
nonceKey key = 0
|
||||
|
||||
DefaultNonceLength = uint(32)
|
||||
)
|
||||
|
||||
func SecurityHeaders(csp *CSP, errorHandler func(error) http.Handler, nonceLength ...uint) func(http.Handler) http.Handler {
|
||||
return func(handler http.Handler) http.Handler {
|
||||
if csp == nil {
|
||||
csp = &DefaultSCP
|
||||
}
|
||||
length := DefaultNonceLength
|
||||
if len(nonceLength) > 0 {
|
||||
length = nonceLength[0]
|
||||
}
|
||||
return &headers{
|
||||
csp: csp,
|
||||
handler: handler,
|
||||
errorHandler: errorHandler,
|
||||
nonceLength: length,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type headers struct {
|
||||
csp *CSP
|
||||
handler http.Handler
|
||||
errorHandler func(err error) http.Handler
|
||||
nonceLength uint
|
||||
}
|
||||
|
||||
func (h *headers) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
nonce := GetNonce(r)
|
||||
if nonce == "" {
|
||||
var err error
|
||||
nonce, err = generateNonce(h.nonceLength)
|
||||
if err != nil {
|
||||
h.errorHandler(err).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
r = saveContext(r, nonceKey, nonce)
|
||||
}
|
||||
headers := w.Header()
|
||||
headers.Set(api.ContentSecurityPolicy, h.csp.Value(nonce))
|
||||
headers.Set(api.XXSSProtection, "1; mode=block")
|
||||
headers.Set(api.StrictTransportSecurity, "max-age=31536000; includeSubDomains")
|
||||
headers.Set(api.XFrameOptions, "DENY")
|
||||
headers.Set(api.XContentTypeOptions, "nosniff")
|
||||
headers.Set(api.ReferrerPolicy, "same-origin")
|
||||
headers.Set(api.FeaturePolicy, "payment 'none'")
|
||||
//PLANNED: add expect-ct
|
||||
|
||||
h.handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func GetNonce(r *http.Request) string {
|
||||
nonce, _ := getContext(r, nonceKey).(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
func generateNonce(length uint) (string, error) {
|
||||
b := make([]byte, length)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func saveContext(r *http.Request, key, value interface{}) *http.Request {
|
||||
ctx := context.WithValue(r.Context(), key, value)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
|
||||
func getContext(r *http.Request, key interface{}) interface{} {
|
||||
return r.Context().Value(key)
|
||||
}
|
@@ -25,11 +25,11 @@ type UserAgentCookieConfig struct {
|
||||
}
|
||||
|
||||
func NewUserAgentHandler(config *UserAgentCookieConfig, idGenerator id.Generator) (*UserAgentHandler, error) {
|
||||
keys, _, err := crypto.LoadKeys(config.Key)
|
||||
key, err := crypto.LoadKey(config.Key, config.Key.EncryptionKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cookieKey := []byte(keys[config.Key.EncryptionKeyID])
|
||||
cookieKey := []byte(key)
|
||||
handler := NewCookieHandler(
|
||||
WithEncryption(cookieKey, cookieKey),
|
||||
WithDomain(config.Domain),
|
||||
|
Reference in New Issue
Block a user