fix(idp): prevent server errors for idps using form post for callbacks (#9097)

# Which Problems Are Solved

Some IdP callbacks use HTTP form POST to return their data on callbacks.
For handling CSRF in the login after such calls, a 302 Found to the
corresponding non form callback (in ZITADEL) is sent. Depending on the
size of the initial form body, this could lead to ZITADEL terminating
the connection, resulting in the user not getting a response or an
intermediate proxy to return them an HTTP 502.

# How the Problems Are Solved

- the form body is parsed and stored into the ZITADEL cache (using the
configured database by default)
- the redirect (302 Found) is performed with the request id
- the callback retrieves the data from the cache instead of the query
parameters (will fallback to latter to handle open uncached requests)

# Additional Changes

- fixed a typo in the default (cache) configuration: `LastUsage` ->
`LastUseAge`

# Additional Context

- reported by a customer
- needs to be backported to current cloud version (2.66.x)

---------

Co-authored-by: Silvan <27845747+adlerhurst@users.noreply.github.com>
This commit is contained in:
Livio Spring
2025-01-06 10:47:46 +01:00
committed by GitHub
parent 79af682c9b
commit fa5e590aab
7 changed files with 108 additions and 15 deletions

View File

@@ -214,8 +214,20 @@ func (l *Login) handleExternalLoginCallbackForm(w http.ResponseWriter, r *http.R
l.renderLogin(w, r, nil, err)
return
}
r.Form.Add("Method", http.MethodPost)
http.Redirect(w, r, HandlerPrefix+EndpointExternalLoginCallback+"?"+r.Form.Encode(), 302)
state := r.Form.Get("state")
if state == "" {
state = r.Form.Get("RelayState")
}
if state == "" {
l.renderLogin(w, r, nil, zerrors.ThrowInvalidArgument(nil, "LOGIN-dsg3f", "Errors.AuthRequest.NotFound"))
return
}
l.caches.idpFormCallbacks.Set(r.Context(), &idpFormCallback{
InstanceID: authz.GetInstance(r.Context()).InstanceID(),
State: state,
Form: r.Form,
})
http.Redirect(w, r, HandlerPrefix+EndpointExternalLoginCallback+"?method=POST&state="+state, 302)
}
// handleExternalLoginCallback handles the callback from a IDP
@@ -232,8 +244,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
}
// workaround because of CSRF on external identity provider flows
if data.Method == http.MethodPost {
r.Method = http.MethodPost
r.PostForm = r.Form
l.setDataFromFormCallback(r, data.State)
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
@@ -345,6 +356,17 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
l.handleExternalUserAuthenticated(w, r, authReq, identityProvider, session, user, l.renderNextStep)
}
func (l *Login) setDataFromFormCallback(r *http.Request, state string) {
r.Method = http.MethodPost
// fallback to the form data in case the request was started before the cache was implemented
r.PostForm = r.Form
idpCallback, ok := l.caches.idpFormCallbacks.Get(r.Context(), idpFormCallbackIndexRequestID,
idpFormCallbackKey(authz.GetInstance(r.Context()).InstanceID(), state))
if ok {
r.PostForm = idpCallback.Form
}
}
func (l *Login) tryMigrateExternalUserID(r *http.Request, session idp.Session, authReq *domain.AuthRequest, externalUser *domain.ExternalUser) (previousIDMatched bool, err error) {
migration, ok := session.(idp.SessionSupportsMigration)
if !ok {

View File

@@ -3,6 +3,7 @@ package login
import (
"context"
"net/http"
"net/url"
"strings"
"time"
@@ -15,6 +16,8 @@ import (
_ "github.com/zitadel/zitadel/internal/api/ui/login/statik"
auth_repository "github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/cache/connector"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@@ -38,6 +41,7 @@ type Login struct {
samlAuthCallbackURL func(context.Context, string) string
idpConfigAlg crypto.EncryptionAlgorithm
userCodeAlg crypto.EncryptionAlgorithm
caches *Caches
}
type Config struct {
@@ -74,6 +78,7 @@ func CreateLogin(config Config,
userCodeAlg crypto.EncryptionAlgorithm,
idpConfigAlg crypto.EncryptionAlgorithm,
csrfCookieKey []byte,
cacheConnectors connector.Connectors,
) (*Login, error) {
login := &Login{
oidcAuthCallbackURL: oidcAuthCallbackURL,
@@ -94,6 +99,12 @@ func CreateLogin(config Config,
login.router = CreateRouter(login, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor, accessHandler)
login.renderer = CreateRenderer(HandlerPrefix, staticStorage, config.LanguageCookieName)
login.parser = form.NewParser()
var err error
login.caches, err = startCaches(context.Background(), cacheConnectors)
if err != nil {
return nil, err
}
return login, nil
}
@@ -201,3 +212,41 @@ func setUserContext(ctx context.Context, userID, resourceOwner string) context.C
func (l *Login) baseURL(ctx context.Context) string {
return http_utils.DomainContext(ctx).Origin() + HandlerPrefix
}
type Caches struct {
idpFormCallbacks cache.Cache[idpFormCallbackIndex, string, *idpFormCallback]
}
func startCaches(background context.Context, connectors connector.Connectors) (_ *Caches, err error) {
caches := new(Caches)
caches.idpFormCallbacks, err = connector.StartCache[idpFormCallbackIndex, string, *idpFormCallback](background, []idpFormCallbackIndex{idpFormCallbackIndexRequestID}, cache.PurposeIdPFormCallback, connectors.Config.IdPFormCallbacks, connectors)
if err != nil {
return nil, err
}
return caches, nil
}
type idpFormCallbackIndex int
const (
idpFormCallbackIndexUnspecified idpFormCallbackIndex = iota
idpFormCallbackIndexRequestID
)
type idpFormCallback struct {
InstanceID string
State string
Form url.Values
}
// Keys implements cache.Entry
func (c *idpFormCallback) Keys(i idpFormCallbackIndex) []string {
if i == idpFormCallbackIndexRequestID {
return []string{idpFormCallbackKey(c.InstanceID, c.State)}
}
return nil
}
func idpFormCallbackKey(instanceID, state string) string {
return instanceID + "-" + state
}