Livio Spring 4d66a786c8
feat: JWT IdP intent (#9966)
# Which Problems Are Solved

Login V1 allowed JWTs to be used for authentication through the JWT IdP. Login V2
uses IdP intents for such cases, which were not yet able to handle JWT
IdPs.

# How the Problems Are Solved

- Added handling of JWT IdPs in `StartIdPIntent` and `RetrieveIdPIntent`
- The redirect returned by the start uses the existing `authRequestID`
and `userAgentID` parameter names for compatibility reasons.
- Added the `/idps/jwt` endpoint to handle the proxied (callback) request;
it extracts and validates the JWT against the configured endpoint.

# Additional Changes

None

# Additional Context

- closes #9758
2025-05-27 16:26:46 +02:00

601 lines
20 KiB
Go

package idp
import (
"bytes"
"context"
"encoding/base64"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/crewjam/saml"
"github.com/gorilla/mux"
"github.com/muhlemmer/gu"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/cache"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain/federatedlogout"
"github.com/zitadel/zitadel/internal/form"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/apple"
"github.com/zitadel/zitadel/internal/idp/providers/azuread"
"github.com/zitadel/zitadel/internal/idp/providers/github"
"github.com/zitadel/zitadel/internal/idp/providers/gitlab"
"github.com/zitadel/zitadel/internal/idp/providers/google"
"github.com/zitadel/zitadel/internal/idp/providers/jwt"
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
saml2 "github.com/zitadel/zitadel/internal/idp/providers/saml"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Paths, path variables and query parameter names of the IdP handler.
const (
	// HandlerPrefix is the path prefix used when building the public URLs of the handler.
	HandlerPrefix = "/idps"

	// varIDPID names the path variable carrying the numeric IdP ID.
	varIDPID  = "idpid"
	idpPrefix = "/{" + varIDPID + ":[0-9]+}"

	// Routes registered on the router.
	callbackPath    = "/callback"
	jwtPath         = "/jwt"
	metadataPath    = idpPrefix + "/saml/metadata"
	certificatePath = idpPrefix + "/saml/certificate"
	acsPath         = idpPrefix + "/saml/acs"
	sloPath         = idpPrefix + "/saml/slo"

	// Query parameters: id, token and user are set on the success redirect,
	// id, error and error_description on the failure redirect; internalUI is
	// read from the metadata request.
	paramIntentID         = "id"
	paramToken            = "token"
	paramUserID           = "user"
	paramError            = "error"
	paramErrorDescription = "error_description"
	paramInternalUI       = "internalUI"
)
// Handler holds the dependencies shared by the IdP HTTP endpoints that
// NewHandler registers (callback, SAML metadata/certificate/ACS/SLO and JWT).
type Handler struct {
commands *command.Commands
queries *query.Queries
// parser decodes the callback request into externalIDPCallbackData.
parser *form.Parser
// encryptionAlgorithm decrypts the encrypted intent ID passed to the JWT endpoint.
encryptionAlgorithm crypto.EncryptionAlgorithm
// callbackURL returns the instance specific URL of the callback endpoint.
callbackURL func(ctx context.Context) string
// samlRootURL returns the instance specific root URL of the SAML endpoints of the given IdP.
samlRootURL func(ctx context.Context, idpID string) string
// loginSAMLRootURL returns the instance specific SAML ACS URL of the login UI.
loginSAMLRootURL func(ctx context.Context) string
caches *Caches
}
// externalIDPCallbackData holds the parameters parsed from a request to the
// callback endpoint; the schema tags name the request parameters.
type externalIDPCallbackData struct {
// State is required; it is used as the ID of the intent the callback belongs to.
State string `schema:"state"`
Code string `schema:"code"`
// Error and ErrorDescription are set if the provider reported a failure.
Error string `schema:"error"`
ErrorDescription string `schema:"error_description"`
// Apple returns a user on first registration
User string `schema:"user"`
}
// externalSAMLIDPCallbackData holds the values parseSAMLRequest extracts from
// a request to one of the SAML endpoints.
type externalSAMLIDPCallbackData struct {
// IDPID is taken from the "idpid" path variable.
IDPID string
// Response is the "SAMLResponse" form value.
Response string
// RelayState is the "RelayState" form value.
RelayState string
}
// CallbackURL returns a function that builds the instance specific URL of the
// IdP callback handler from the origin of the request's domain context.
func CallbackURL() func(ctx context.Context) string {
	return func(ctx context.Context) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + HandlerPrefix + callbackPath
	}
}
// SAMLRootURL returns a function that builds the instance specific root URL
// (including a trailing slash) of the SAML endpoints of the given IdP.
func SAMLRootURL() func(ctx context.Context, idpID string) string {
	return func(ctx context.Context, idpID string) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + HandlerPrefix + "/" + idpID + "/"
	}
}
// LoginSAMLRootURL returns a function that builds the instance specific URL
// of the SAML ACS endpoint of the login UI.
func LoginSAMLRootURL() func(ctx context.Context) string {
	return func(ctx context.Context) string {
		origin := http_utils.DomainContext(ctx).Origin()
		return origin + login.HandlerPrefix + login.EndpointSAMLACS
	}
}
// NewHandler creates the Handler and returns a router serving its endpoints:
// the callback, the SAML metadata, certificate, ACS and SLO endpoints and the
// JWT endpoint. The instanceInterceptor is installed as router middleware.
func NewHandler(
	commands *command.Commands,
	queries *query.Queries,
	encryptionAlgorithm crypto.EncryptionAlgorithm,
	instanceInterceptor func(next http.Handler) http.Handler,
	federatedLogoutCache cache.Cache[federatedlogout.Index, string, *federatedlogout.FederatedLogout],
) http.Handler {
	h := new(Handler)
	h.commands = commands
	h.queries = queries
	h.parser = form.NewParser()
	h.encryptionAlgorithm = encryptionAlgorithm
	h.callbackURL = CallbackURL()
	h.samlRootURL = SAMLRootURL()
	h.loginSAMLRootURL = LoginSAMLRootURL()
	h.caches = &Caches{federatedLogouts: federatedLogoutCache}

	router := mux.NewRouter()
	router.Use(instanceInterceptor)

	// The routes are registered in this exact order.
	routes := []struct {
		path   string
		handle func(http.ResponseWriter, *http.Request)
	}{
		{callbackPath, h.handleCallback},
		{metadataPath, h.handleMetadata},
		{certificatePath, h.handleCertificate},
		{acsPath, h.handleACS},
		{sloPath, h.handleSLO},
		{jwtPath, h.handleJWT},
	}
	for _, route := range routes {
		router.HandleFunc(route.path, route.handle)
	}
	return router
}
// Caches groups the caches used by the Handler.
type Caches struct {
// federatedLogouts is looked up and cleaned up by the SAML single logout endpoint.
federatedLogouts cache.Cache[federatedlogout.Index, string, *federatedlogout.FederatedLogout]
}
// parseSAMLRequest collects the IdP ID from the path variables and the
// SAMLResponse and RelayState form values of the request.
func parseSAMLRequest(r *http.Request) *externalSAMLIDPCallbackData {
	data := new(externalSAMLIDPCallbackData)
	data.IDPID = mux.Vars(r)[varIDPID]
	data.Response = r.FormValue("SAMLResponse")
	data.RelayState = r.FormValue("RelayState")
	return data
}
// getProvider loads the provider of the given IdP, passing the instance
// specific callback URL and SAML root URL to the command side.
func (h *Handler) getProvider(ctx context.Context, idpID string) (idp.Provider, error) {
	callbackURL := h.callbackURL(ctx)
	samlRootURL := h.samlRootURL(ctx, idpID)
	return h.commands.GetProvider(ctx, idpID, callbackURL, samlRootURL)
}
// handleCertificate serves the certificate of the SAML provider identified by
// the IdP ID path variable as a download named idp.crt.
//
// It responds with 400 if the provider cannot be loaded or is not a SAML
// provider.
func (h *Handler) handleCertificate(w http.ResponseWriter, r *http.Request) {
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(r.Context(), data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, isSAML := provider.(*saml2.Provider)
	if !isSAML {
		invalidIDP := zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid")
		http.Error(w, invalidIDP.Error(), http.StatusBadRequest)
		return
	}

	var certPem bytes.Buffer
	if _, err = certPem.Write(samlProvider.Certificate); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	header := w.Header()
	header.Set("Content-Disposition", "attachment; filename=idp.crt")
	// NOTE(review): the response echoes the Content-Type of the request; confirm
	// this is intended rather than a fixed certificate media type.
	header.Set("Content-Type", r.Header.Get("Content-Type"))

	if _, err = io.Copy(w, &certPem); err != nil {
		http.Error(w, fmt.Errorf("failed to response with certificate: %w", err).Error(), http.StatusInternalServerError)
	}
}
// handleMetadata serves the SAML service provider metadata of the IdP
// identified by the IdP ID path variable as indented XML.
//
// The assertion consumer services of the metadata are extended with the
// endpoints of the login UI; the internalUI query parameter controls whether
// these are added in front (as default) or appended.
//
// It responds with 400 if the provider cannot be loaded, is not a SAML provider
// or its service provider cannot be built, and with 500 if the metadata cannot
// be marshalled.
func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	samlProvider, ok := provider.(*saml2.Provider)
	if !ok {
		http.Error(w, zerrors.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
		return
	}
	sp, err := samlProvider.GetSP()
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	metadata := sp.ServiceProvider.Metadata()
	// An absent or unparsable internalUI parameter falls back to false.
	internalUI, _ := strconv.ParseBool(r.URL.Query().Get(paramInternalUI))
	h.assertionConsumerServices(ctx, metadata, internalUI)

	// Do not answer with an empty 200 response if the metadata cannot be marshalled.
	buf, err := xml.MarshalIndent(metadata, "", " ")
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/samlmetadata+xml")
	_, err = w.Write(buf)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
}
// assertionConsumerServices adds the SAML ACS endpoints of the login UI (HTTP
// POST and HTTP artifact binding) to every SP SSO descriptor of the metadata.
//
// If internalUI is false, the two endpoints are appended after the existing
// ones, continuing their index. If internalUI is true, they are placed first
// with index 0 (marked as default) and 1, and the existing endpoints are
// re-indexed starting at 2.
func (h *Handler) assertionConsumerServices(ctx context.Context, metadata *saml.EntityDescriptor, internalUI bool) {
	// loginEndpoint builds an endpoint pointing to the login UI.
	loginEndpoint := func(binding string, index int) saml.IndexedEndpoint {
		return saml.IndexedEndpoint{
			Binding:  binding,
			Location: h.loginSAMLRootURL(ctx),
			Index:    index,
		}
	}

	for idx := range metadata.SPSSODescriptors {
		descriptor := metadata.SPSSODescriptors[idx]
		existing := descriptor.AssertionConsumerServices

		if internalUI {
			defaultEndpoint := loginEndpoint(saml.HTTPPostBinding, 0)
			defaultEndpoint.IsDefault = gu.Ptr(true)

			merged := make([]saml.IndexedEndpoint, 0, len(existing)+2)
			merged = append(merged, defaultEndpoint, loginEndpoint(saml.HTTPArtifactBinding, 1))
			for pos := range existing {
				// shift the existing endpoints behind the two login UI endpoints
				existing[pos].Index = 2 + pos
				merged = append(merged, existing[pos])
			}
			descriptor.AssertionConsumerServices = merged
		} else {
			descriptor.AssertionConsumerServices = append(
				existing,
				loginEndpoint(saml.HTTPPostBinding, len(existing)+1),
				loginEndpoint(saml.HTTPArtifactBinding, len(existing)+2),
			)
		}
		metadata.SPSSODescriptors[idx] = descriptor
	}
}
// handleACS processes a request to the SAML assertion consumer service
// endpoint of an IdP.
//
// The relay state is used as the ID of the active intent. Once the IdP user was
// fetched from the SAML session, the intent is marked as succeeded and the
// user agent is redirected to the success URL of the intent. Errors before the
// session exists are answered with 400 (apart from a non "not found" intent
// error, which redirects to the failure URL); later errors redirect to the
// failure URL of the intent.
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	badRequest := func(cause error) {
		http.Error(w, cause.Error(), http.StatusBadRequest)
	}

	data := parseSAMLRequest(r)
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		badRequest(err)
		return
	}
	samlProvider, isSAML := provider.(*saml2.Provider)
	if !isSAML {
		badRequest(zerrors.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid"))
		return
	}

	intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
	switch {
	case err == nil:
		// active intent found, continue
	case zerrors.IsNotFound(err):
		badRequest(err)
		return
	default:
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	session, err := saml2.NewSession(samlProvider, intent.RequestID, r)
	if err != nil {
		badRequest(err)
		return
	}
	idpUser, err := session.FetchUser(ctx)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// A failed lookup is only logged; the flow continues with an empty user ID.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")

	token, err := h.commands.SucceedSAMLIDPIntent(ctx, intent, idpUser, userID, session)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// handleJWT processes a request to the JWT endpoint.
//
// It verifies the intent ID of the request, loads the active intent and its
// provider and, if the provider is a JWT provider, hands over to
// handleJWTExtraction. An invalid request or an unknown intent is answered
// with 400; other errors redirect to the failure URL of the intent.
func (h *Handler) handleJWT(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	intentID, err := h.intentIDFromJWTRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	intent, err := h.commands.GetActiveIntent(ctx, intentID)
	switch {
	case err == nil:
		// active intent found, continue
	case zerrors.IsNotFound(err):
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	default:
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// failIntent marks the intent as failed and redirects to its failure URL.
	failIntent := func(cause error) {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, cause.Error())
		logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, cause)
	}

	idpConfig, err := h.getProvider(ctx, intent.IDPID)
	if err != nil {
		failIntent(err)
		return
	}
	jwtIDP, isJWT := idpConfig.(*jwt.Provider)
	if !isJWT {
		failIntent(zerrors.ThrowInvalidArgument(nil, "IDP-JK23ed", "Errors.ExternalIDP.IDPTypeNotImplemented"))
		return
	}
	h.handleJWTExtraction(w, r, intent, jwtIDP)
}
// intentIDFromJWTRequest reads the intent ID and its encrypted counterpart
// from the request and returns the intent ID if both match.
func (h *Handler) intentIDFromJWTRequest(r *http.Request) (string, error) {
	var (
		// for compatibility with the old JWT provider, the auth request id parameter carries the intent id
		intentID = r.FormValue(jwt.QueryAuthRequestID)
		// for compatibility with the old JWT provider, the user agent id parameter carries the encrypted intent id
		encryptedIntentID = r.FormValue(jwt.QueryUserAgentID)
	)
	err := h.checkIntentID(intentID, encryptedIntentID)
	if err != nil {
		return "", err
	}
	return intentID, nil
}
// checkIntentID verifies that encryptedIntentID (raw URL base64 encoded)
// decrypts to intentID.
//
// It returns an invalid argument error if one of the values is empty or the
// decrypted value differs, and the decoding or decryption error otherwise.
func (h *Handler) checkIntentID(intentID, encryptedIntentID string) error {
	if !(intentID != "" && encryptedIntentID != "") {
		return zerrors.ThrowInvalidArgument(nil, "LOGIN-adfzz", "Errors.AuthRequest.MissingParameters")
	}
	cipherText, err := base64.RawURLEncoding.DecodeString(encryptedIntentID)
	if err != nil {
		return err
	}
	keyID := h.encryptionAlgorithm.EncryptionKeyID()
	plainText, err := h.encryptionAlgorithm.DecryptString(cipherText, keyID)
	if err != nil {
		return err
	}
	if plainText == intentID {
		return nil
	}
	return zerrors.ThrowInvalidArgument(nil, "LOGIN-adfzz", "Errors.AuthRequest.MissingParameters")
}
// handleJWTExtraction fetches the IdP user through a JWT session created from
// the request, marks the intent as succeeded and redirects to its success URL.
//
// If the user cannot be fetched, the intent is marked as failed; any error
// redirects to the failure URL of the intent.
func (h *Handler) handleJWTExtraction(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, identityProvider *jwt.Provider) {
	ctx := r.Context()
	intentLog := func() *logging.Entry {
		return logging.WithFields("intent", intent.AggregateID)
	}

	session := jwt.NewSessionFromRequest(identityProvider, r)
	user, err := session.FetchUser(ctx)
	if err != nil {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
		intentLog().OnError(cmdErr).Error("failed to push failed event on idp intent")
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// A failed lookup is only logged; the flow continues with an empty user ID.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, user.GetID())
	intentLog().OnError(err).Error("could not check if idp user already exists")

	token, err := h.commands.SucceedIDPIntent(ctx, intent, user, session, userID)
	if err == nil {
		redirectToSuccessURL(w, r, intent, token, userID)
		return
	}
	redirectToFailureURLErr(w, r, intent, err)
}
// handleCallback processes the callback of an external IdP.
//
// The state parameter is used as the ID of the active intent. If the provider
// reported an error or the user cannot be fetched, the intent is marked as
// failed and the user agent is redirected to the failure URL. Otherwise the
// linked user is looked up (including a migration attempt if none is found),
// the intent is marked as succeeded and the user agent is redirected to the
// success URL. An unparsable request or an unknown intent is answered with 400.
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	data, err := h.parseCallbackRequest(r)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	intent, err := h.commands.GetActiveIntent(ctx, data.State)
	switch {
	case err == nil:
		// active intent found, continue
	case zerrors.IsNotFound(err):
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	default:
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	intentLog := func() *logging.Entry {
		return logging.WithFields("intent", intent.AggregateID)
	}
	// recordFailure pushes the failed event; a push error is only logged.
	recordFailure := func(cause string) {
		cmdErr := h.commands.FailIDPIntent(ctx, intent, cause)
		intentLog().OnError(cmdErr).Error("failed to push failed event on idp intent")
	}

	// the provider might have returned an error
	if data.Error != "" {
		recordFailure(reason(data.Error, data.ErrorDescription))
		redirectToFailureURL(w, r, intent, data.Error, data.ErrorDescription)
		return
	}

	provider, err := h.getProvider(ctx, intent.IDPID)
	if err != nil {
		recordFailure(err.Error())
		redirectToFailureURLErr(w, r, intent, err)
		return
	}
	idpUser, idpSession, err := h.fetchIDPUserFromCode(ctx, provider, data.Code, data.User, intent.IDPArguments)
	if err != nil {
		recordFailure(err.Error())
		redirectToFailureURLErr(w, r, intent, err)
		return
	}

	// Lookup and migration errors are only logged; the flow continues regardless.
	userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
	intentLog().OnError(err).Error("could not check if idp user already exists")
	if userID == "" {
		userID, err = h.tryMigrateExternalUser(ctx, intent.IDPID, idpUser, idpSession)
		intentLog().OnError(err).Error("migration check failed")
	}

	token, err := h.commands.SucceedIDPIntent(ctx, intent, idpUser, idpSession, userID)
	if err != nil {
		redirectToFailureURLErr(w, r, intent, zerrors.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
		return
	}
	redirectToSuccessURL(w, r, intent, token, userID)
}
// handleSLO processes a request to the SAML single logout endpoint of an IdP.
//
// The relay state is used to look up the federated logout of the instance in
// the cache, which must be in the redirected state. If the IdP of the path is
// a SAML provider, the cache entry is deleted and the user agent is redirected
// to the post logout redirect URI. All failures are answered with 400.
func (h *Handler) handleSLO(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	data := parseSAMLRequest(r)

	instanceID := authz.GetInstance(ctx).InstanceID()
	lookupKey := federatedlogout.Key(instanceID, data.RelayState)
	logoutState, found := h.caches.federatedLogouts.Get(ctx, federatedlogout.IndexRequestID, lookupKey)
	if !found || logoutState.State != federatedlogout.StateRedirected {
		notFound := zerrors.ThrowNotFound(nil, "SAML-3uor2", "Errors.Intent.NotFound")
		http.Error(w, notFound.Error(), http.StatusBadRequest)
		return
	}

	// For the moment we just make sure the callback matches the IDP it was started on / intended for.
	provider, err := h.getProvider(ctx, data.IDPID)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if _, isSAML := provider.(*saml2.Provider); !isSAML {
		invalidIDP := zerrors.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid")
		http.Error(w, invalidIDP.Error(), http.StatusBadRequest)
		return
	}

	// We could also parse and validate the response here, but for example Azure does not sign it and thus would already fail.
	// Also we can't really act on it if it fails.
	deleteKey := federatedlogout.Key(logoutState.InstanceID, logoutState.SessionID)
	err = h.caches.federatedLogouts.Delete(ctx, federatedlogout.IndexRequestID, deleteKey)
	logging.WithFields("instanceID", logoutState.InstanceID, "sessionID", logoutState.SessionID).OnError(err).Error("could not delete federated logout")

	http.Redirect(w, r, logoutState.PostLogoutRedirectURI, http.StatusFound)
}
// tryMigrateExternalUser looks for a user linked to the IdP under the previous
// external ID reported by the session and, if one exists, migrates the link to
// the current external ID of idpUser.
//
// It returns an empty userID and a nil error if the session does not implement
// idp.SessionSupportsMigration, if there is no previous ID or if no user is
// linked to the previous ID. If the migration command fails, the found userID
// is returned together with the error.
func (h *Handler) tryMigrateExternalUser(ctx context.Context, idpID string, idpUser idp.User, idpSession idp.Session) (userID string, err error) {
	migration, ok := idpSession.(idp.SessionSupportsMigration)
	if !ok {
		return "", nil
	}
	previousID, err := migration.RetrievePreviousID()
	if err != nil || previousID == "" {
		return "", err
	}
	userID, err = h.checkExternalUser(ctx, idpID, previousID)
	if err != nil {
		return "", err
	}
	if userID == "" {
		// No user is linked to the previous ID, so there is nothing to migrate.
		// Do not issue the migration command with an empty user ID.
		return "", nil
	}
	return userID, h.commands.MigrateUserIDP(ctx, userID, "", idpID, previousID, idpUser.GetID())
}
// parseCallbackRequest parses the request into externalIDPCallbackData.
// It returns the parser error, or an invalid argument error if the state is
// missing.
func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackData, error) {
	var data externalIDPCallbackData
	if err := h.parser.Parse(r, &data); err != nil {
		return nil, err
	}
	if data.State != "" {
		return &data, nil
	}
	return nil, zerrors.ThrowInvalidArgument(nil, "IDP-Hk38e", "Errors.Intent.StateMissing")
}
// redirectToSuccessURL redirects (302) to the success URL of the intent,
// with the intent ID, the token and, if not empty, the user ID set as query
// parameters. The query of the intent's success URL is updated in place.
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
	params := intent.SuccessURL.Query()
	params.Set(paramIntentID, intent.AggregateID)
	params.Set(paramToken, token)
	if userID != "" {
		params.Set(paramUserID, userID)
	}
	intent.SuccessURL.RawQuery = params.Encode()

	target := intent.SuccessURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
// redirectToFailureURLErr redirects to the failure URL of the intent using
// the given error. For a ZitadelError its ID is used as error and its message
// as description; for any other error the error text is used as error and the
// description stays empty.
func redirectToFailureURLErr(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err error) {
	code, description := err.Error(), ""
	var zErr *zerrors.ZitadelError
	if errors.As(err, &zErr) {
		// TODO: i18n?
		code, description = zErr.GetID(), zErr.GetMessage()
	}
	redirectToFailureURL(w, r, i, code, description)
}
// redirectToFailureURL redirects (302) to the failure URL of the intent,
// with the intent ID, the error and its description set as query parameters.
// The query of the intent's failure URL is updated in place.
func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDPIntentWriteModel, err, description string) {
	params := i.FailureURL.Query()
	for _, param := range [][2]string{
		{paramIntentID, i.AggregateID},
		{paramError, err},
		{paramErrorDescription, description},
	} {
		params.Set(param[0], param[1])
	}
	i.FailureURL.RawQuery = params.Encode()

	target := i.FailureURL.String()
	http.Redirect(w, r, target, http.StatusFound)
}
// fetchIDPUserFromCode creates the session matching the type of the provider
// from the code and fetches the IdP user with it.
//
// JWT, LDAP and SAML providers are rejected with an invalid argument error,
// unknown provider types with an unimplemented error. On any error neither a
// user nor a session is returned.
func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp.Provider, code string, appleUser string, idpArguments map[string]any) (user idp.User, idpTokens idp.Session, err error) {
	switch provider := identityProvider.(type) {
	case *oauth.Provider:
		idpTokens = oauth.NewSession(provider, code, idpArguments)
	case *openid.Provider:
		idpTokens = openid.NewSession(provider, code, idpArguments)
	case *azuread.Provider:
		idpTokens = azuread.NewSession(provider, code)
	case *github.Provider:
		idpTokens = oauth.NewSession(provider.Provider, code, idpArguments)
	case *gitlab.Provider:
		idpTokens = openid.NewSession(provider.Provider, code, idpArguments)
	case *google.Provider:
		idpTokens = openid.NewSession(provider.Provider, code, idpArguments)
	case *apple.Provider:
		idpTokens = apple.NewSession(provider, code, appleUser)
	case *jwt.Provider, *ldap.Provider, *saml2.Provider:
		// these providers do not use the code callback
		return nil, nil, zerrors.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
	default:
		return nil, nil, zerrors.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
	}

	if user, err = idpTokens.FetchUser(ctx); err != nil {
		return nil, nil, err
	}
	return user, idpTokens, nil
}
// checkExternalUser searches the IdP user links for the given IdP ID and
// external user ID and returns the linked user ID.
//
// An empty userID with a nil error is returned unless exactly one link is
// found.
func (h *Handler) checkExternalUser(ctx context.Context, idpID, externalUserID string) (userID string, err error) {
	byIDP, err := query.NewIDPUserLinkIDPIDSearchQuery(idpID)
	if err != nil {
		return "", err
	}
	byExternalID, err := query.NewIDPUserLinksExternalIDSearchQuery(externalUserID)
	if err != nil {
		return "", err
	}

	search := &query.IDPUserLinksSearchQuery{
		Queries: []query.SearchQuery{byIDP, byExternalID},
	}
	result, err := h.queries.IDPUserLinks(ctx, search, nil)
	if err != nil {
		return "", err
	}
	if len(result.Links) == 1 {
		return result.Links[0].UserID, nil
	}
	return "", nil
}
// reason combines an error and its optional description as
// "<err>: <description>"; without a description only the error is returned.
func reason(err, description string) string {
	switch description {
	case "":
		return err
	default:
		return err + ": " + description
	}
}