From fa5e590aabda38bd346f1a41484466aebdd8f903 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Mon, 6 Jan 2025 10:47:46 +0100 Subject: [PATCH] fix(idp): prevent server errors for idps using form post for callbacks (#9097) # Which Problems Are Solved Some IdP callbacks use HTTP form POST to return their data on callbacks. For handling CSRF in the login after such calls, a 302 Found to the corresponding non form callback (in ZITADEL) is sent. Depending on the size of the initial form body, this could lead to ZITADEL terminating the connection, resulting in the user not getting a response or an intermediate proxy to return them an HTTP 502. # How the Problems Are Solved - the form body is parsed and stored into the ZITADEL cache (using the configured database by default) - the redirect (302 Found) is performed with the request id - the callback retrieves the data from the cache instead of the query parameters (will fallback to latter to handle open uncached requests) # Additional Changes - fixed a typo in the default (cache) configuration: `LastUsage` -> `LastUseAge` # Additional Context - reported by a customer - needs to be backported to current cloud version (2.66.x) --------- Co-authored-by: Silvan <27845747+adlerhurst@users.noreply.github.com> --- cmd/defaults.yaml | 21 ++++++-- cmd/start/start.go | 3 ++ .../api/ui/login/external_provider_handler.go | 30 ++++++++++-- internal/api/ui/login/login.go | 49 +++++++++++++++++++ internal/cache/cache.go | 1 + internal/cache/connector/connector.go | 7 +-- internal/cache/purpose_enumer.go | 12 +++-- 7 files changed, 108 insertions(+), 15 deletions(-) diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 1e5de1eea1..e993657123 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -198,8 +198,11 @@ Caches: AutoPrune: Interval: 1m TimeOut: 5s + # Postgres connector uses the configured database (postgres or cockraochdb) as cache. + # It is suitable for deployments with multiple containers. + # The cache is enabled by default because it is the default cache states for IdP form callbacks Postgres: - Enabled: false + Enabled: true AutoPrune: Interval: 15m TimeOut: 30s @@ -311,7 +314,7 @@ Caches: # When connector is empty, this cache will be disabled. Connector: "" MaxAge: 1h - LastUsage: 10m + LastUseAge: 10m # Log enables cache-specific logging. Default to error log to stderr when omitted. Log: Level: error @@ -322,7 +325,7 @@ Caches: Milestones: Connector: "" MaxAge: 1h - LastUsage: 10m + LastUseAge: 10m Log: Level: error AddSource: true @@ -332,7 +335,17 @@ Caches: Organization: Connector: "" MaxAge: 1h - LastUsage: 10m + LastUseAge: 10m + Log: + Level: error + AddSource: true + Formatter: + Format: text + # IdP callbacks using form POST cache, required for handling them securely and without possible too big request urls. + IdPFormCallbacks: + Connector: "postgres" + MaxAge: 1h + LastUseAge: 10m Log: Level: error AddSource: true diff --git a/cmd/start/start.go b/cmd/start/start.go index 72ab9ea862..154c683481 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -317,6 +317,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server authZRepo, keys, permissionCheck, + cacheConnectors, ) if err != nil { return err @@ -361,6 +362,7 @@ func startAPIs( authZRepo authz_repo.Repository, keys *encryption.EncryptionKeys, permissionCheck domain.PermissionCheck, + cacheConnectors connector.Connectors, ) (*api.API, error) { repo := struct { authz_repo.Repository @@ -542,6 +544,7 @@ func startAPIs( keys.User, keys.IDPConfig, keys.CSRFCookieKey, + cacheConnectors, ) if err != nil { return nil, fmt.Errorf("unable to start login: %w", err) diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 15046d25e8..6b312317be 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -214,8 +214,20 @@ func (l *Login) handleExternalLoginCallbackForm(w http.ResponseWriter, r *http.R l.renderLogin(w, r, nil, err) return } - r.Form.Add("Method", http.MethodPost) - http.Redirect(w, r, HandlerPrefix+EndpointExternalLoginCallback+"?"+r.Form.Encode(), 302) + state := r.Form.Get("state") + if state == "" { + state = r.Form.Get("RelayState") + } + if state == "" { + l.renderLogin(w, r, nil, zerrors.ThrowInvalidArgument(nil, "LOGIN-dsg3f", "Errors.AuthRequest.NotFound")) + return + } + l.caches.idpFormCallbacks.Set(r.Context(), &idpFormCallback{ + InstanceID: authz.GetInstance(r.Context()).InstanceID(), + State: state, + Form: r.Form, + }) + http.Redirect(w, r, HandlerPrefix+EndpointExternalLoginCallback+"?method=POST&state="+state, 302) } // handleExternalLoginCallback handles the callback from a IDP @@ -232,8 +244,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque } // workaround because of CSRF on external identity provider flows if data.Method == http.MethodPost { - r.Method = http.MethodPost - r.PostForm = r.Form + l.setDataFromFormCallback(r, data.State) } userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) @@ -345,6 +356,17 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque l.handleExternalUserAuthenticated(w, r, authReq, identityProvider, session, user, l.renderNextStep) } +func (l *Login) setDataFromFormCallback(r *http.Request, state string) { + r.Method = http.MethodPost + // fallback to the form data in case the request was started before the cache was implemented + r.PostForm = r.Form + idpCallback, ok := l.caches.idpFormCallbacks.Get(r.Context(), idpFormCallbackIndexRequestID, + idpFormCallbackKey(authz.GetInstance(r.Context()).InstanceID(), state)) + if ok { + r.PostForm = idpCallback.Form + } +} + func (l *Login) tryMigrateExternalUserID(r *http.Request, session idp.Session, authReq *domain.AuthRequest, externalUser *domain.ExternalUser) (previousIDMatched bool, err error) { migration, ok := session.(idp.SessionSupportsMigration) if !ok { diff --git a/internal/api/ui/login/login.go b/internal/api/ui/login/login.go index 57f6a5f9a3..444c5aaa85 100644 --- a/internal/api/ui/login/login.go +++ b/internal/api/ui/login/login.go @@ -3,6 +3,7 @@ package login import ( "context" "net/http" + "net/url" "strings" "time" @@ -15,6 +16,8 @@ import ( _ "github.com/zitadel/zitadel/internal/api/ui/login/statik" auth_repository "github.com/zitadel/zitadel/internal/auth/repository" "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing" + "github.com/zitadel/zitadel/internal/cache" + "github.com/zitadel/zitadel/internal/cache/connector" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" @@ -38,6 +41,7 @@ type Login struct { samlAuthCallbackURL func(context.Context, string) string idpConfigAlg crypto.EncryptionAlgorithm userCodeAlg crypto.EncryptionAlgorithm + caches *Caches } type Config struct { @@ -74,6 +78,7 @@ func CreateLogin(config Config, userCodeAlg crypto.EncryptionAlgorithm, idpConfigAlg crypto.EncryptionAlgorithm, csrfCookieKey []byte, + cacheConnectors connector.Connectors, ) (*Login, error) { login := &Login{ oidcAuthCallbackURL: oidcAuthCallbackURL, @@ -94,6 +99,12 @@ func CreateLogin(config Config, login.router = CreateRouter(login, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor, accessHandler) login.renderer = CreateRenderer(HandlerPrefix, staticStorage, config.LanguageCookieName) login.parser = form.NewParser() + + var err error + login.caches, err = startCaches(context.Background(), cacheConnectors) + if err != nil { + return nil, err + } return login, nil } @@ -201,3 +212,41 @@ func setUserContext(ctx context.Context, userID, resourceOwner string) context.C func (l *Login) baseURL(ctx context.Context) string { return http_utils.DomainContext(ctx).Origin() + HandlerPrefix } + +type Caches struct { + idpFormCallbacks cache.Cache[idpFormCallbackIndex, string, *idpFormCallback] +} + +func startCaches(background context.Context, connectors connector.Connectors) (_ *Caches, err error) { + caches := new(Caches) + caches.idpFormCallbacks, err = connector.StartCache[idpFormCallbackIndex, string, *idpFormCallback](background, []idpFormCallbackIndex{idpFormCallbackIndexRequestID}, cache.PurposeIdPFormCallback, connectors.Config.IdPFormCallbacks, connectors) + if err != nil { + return nil, err + } + return caches, nil +} + +type idpFormCallbackIndex int + +const ( + idpFormCallbackIndexUnspecified idpFormCallbackIndex = iota + idpFormCallbackIndexRequestID +) + +type idpFormCallback struct { + InstanceID string + State string + Form url.Values +} + +// Keys implements cache.Entry +func (c *idpFormCallback) Keys(i idpFormCallbackIndex) []string { + if i == idpFormCallbackIndexRequestID { + return []string{idpFormCallbackKey(c.InstanceID, c.State)} + } + return nil +} + +func idpFormCallbackKey(instanceID, state string) string { + return instanceID + "-" + state +} diff --git a/internal/cache/cache.go b/internal/cache/cache.go index c7dbad6f2c..dc05208caa 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -17,6 +17,7 @@ const ( PurposeAuthzInstance PurposeMilestones PurposeOrganization + PurposeIdPFormCallback ) // Cache stores objects with a value of type `V`. diff --git a/internal/cache/connector/connector.go b/internal/cache/connector/connector.go index 09298fa688..1a0534759a 100644 --- a/internal/cache/connector/connector.go +++ b/internal/cache/connector/connector.go @@ -19,9 +19,10 @@ type CachesConfig struct { Postgres pg.Config Redis redis.Config } - Instance *cache.Config - Milestones *cache.Config - Organization *cache.Config + Instance *cache.Config + Milestones *cache.Config + Organization *cache.Config + IdPFormCallbacks *cache.Config } type Connectors struct { diff --git a/internal/cache/purpose_enumer.go b/internal/cache/purpose_enumer.go index 47ad167d70..a93a978efb 100644 --- a/internal/cache/purpose_enumer.go +++ b/internal/cache/purpose_enumer.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _PurposeName = "unspecifiedauthz_instancemilestonesorganization" +const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" -var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47} +var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65} -const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganization" +const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" func (i Purpose) String() string { if i < 0 || i >= Purpose(len(_PurposeIndex)-1) { @@ -28,9 +28,10 @@ func _PurposeNoOp() { _ = x[PurposeAuthzInstance-(1)] _ = x[PurposeMilestones-(2)] _ = x[PurposeOrganization-(3)] + _ = x[PurposeIdPFormCallback-(4)] } -var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization} +var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback} var _PurposeNameToValueMap = map[string]Purpose{ _PurposeName[0:11]: PurposeUnspecified, @@ -41,6 +42,8 @@ var _PurposeNameToValueMap = map[string]Purpose{ _PurposeLowerName[25:35]: PurposeMilestones, _PurposeName[35:47]: PurposeOrganization, _PurposeLowerName[35:47]: PurposeOrganization, + _PurposeName[47:65]: PurposeIdPFormCallback, + _PurposeLowerName[47:65]: PurposeIdPFormCallback, } var _PurposeNames = []string{ @@ -48,6 +51,7 @@ var _PurposeNames = []string{ _PurposeName[11:25], _PurposeName[25:35], _PurposeName[35:47], + _PurposeName[47:65], } // PurposeString retrieves an enum value from the enum constants string name.