package command import ( "context" "encoding/base64" "encoding/json" "encoding/xml" "net/url" "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/idp" "github.com/zitadel/zitadel/internal/idp/providers/apple" "github.com/zitadel/zitadel/internal/idp/providers/azuread" "github.com/zitadel/zitadel/internal/idp/providers/jwt" "github.com/zitadel/zitadel/internal/idp/providers/oauth" openid "github.com/zitadel/zitadel/internal/idp/providers/oidc" "github.com/zitadel/zitadel/internal/repository/idpintent" "github.com/zitadel/zitadel/internal/zerrors" ) func (c *Commands) prepareCreateIntent(writeModel *IDPIntentWriteModel, idpID string, successURL, failureURL string) preparation.Validation { return func() (_ preparation.CreateCommands, err error) { if idpID == "" { return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-x8j2bk", "Errors.Intent.IDPMissing") } successURL, err := url.Parse(successURL) if err != nil { return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-x8j3bk", "Errors.Intent.SuccessURLMissing") } failureURL, err := url.Parse(failureURL) if err != nil { return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-x8j4bk", "Errors.Intent.FailureURLMissing") } return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { err = getIDPIntentWriteModel(ctx, writeModel, filter) if err != nil { return nil, err } exists, err := ExistsIDP(ctx, filter, idpID, writeModel.ResourceOwner) if !exists || err != nil { return nil, zerrors.ThrowPreconditionFailed(err, "COMMAND-39n221fs", "Errors.IDPConfig.NotExisting") } return []eventstore.Command{ idpintent.NewStartedEvent(ctx, writeModel.aggregate, successURL, failureURL, idpID), }, nil }, nil } } func (c *Commands) CreateIntent(ctx context.Context, idpID, successURL, failureURL, resourceOwner string) (*IDPIntentWriteModel, *domain.ObjectDetails, error) { id, err := c.idGenerator.Next() if err != nil { return nil, nil, err } writeModel := NewIDPIntentWriteModel(id, resourceOwner) if err != nil { return nil, nil, err } cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.prepareCreateIntent(writeModel, idpID, successURL, failureURL)) if err != nil { return nil, nil, err } pushedEvents, err := c.eventstore.Push(ctx, cmds...) if err != nil { return nil, nil, err } err = AppendAndReduce(writeModel, pushedEvents...) if err != nil { return nil, nil, err } return writeModel, writeModelToObjectDetails(&writeModel.WriteModel), nil } func (c *Commands) GetProvider(ctx context.Context, idpID string, idpCallback string, samlRootURL string) (idp.Provider, error) { writeModel, err := IDPProviderWriteModel(ctx, c.eventstore.Filter, idpID) if err != nil { return nil, err } if writeModel.IDPType != domain.IDPTypeSAML { return writeModel.ToProvider(idpCallback, c.idpConfigEncryption) } return writeModel.ToSAMLProvider( samlRootURL, c.idpConfigEncryption, func(ctx context.Context, intentID string) (*samlsp.TrackedRequest, error) { intent, err := c.GetActiveIntent(ctx, intentID) if err != nil { return nil, err } return &samlsp.TrackedRequest{ SAMLRequestID: intent.RequestID, Index: intentID, URI: intent.SuccessURL.String(), }, nil }, func(ctx context.Context, intentID, samlRequestID string) error { intent, err := c.GetActiveIntent(ctx, intentID) if err != nil { return err } return c.RequestSAMLIDPIntent(ctx, intent, samlRequestID) }, ) } func (c *Commands) GetActiveIntent(ctx context.Context, intentID string) (*IDPIntentWriteModel, error) { intent, err := c.GetIntentWriteModel(ctx, intentID, "") if err != nil { return nil, err } if intent.State == domain.IDPIntentStateUnspecified { return nil, zerrors.ThrowNotFound(nil, "IDP-Hk38e", "Errors.Intent.NotStarted") } if intent.State != domain.IDPIntentStateStarted { return nil, zerrors.ThrowInvalidArgument(nil, "IDP-Sfrgs", "Errors.Intent.NotStarted") } return intent, nil } func (c *Commands) AuthFromProvider(ctx context.Context, idpID, state string, idpCallback, samlRootURL string) (string, bool, error) { provider, err := c.GetProvider(ctx, idpID, idpCallback, samlRootURL) if err != nil { return "", false, err } session, err := provider.BeginAuth(ctx, state) if err != nil { return "", false, err } content, redirect := session.GetAuth(ctx) return content, redirect, nil } func getIDPIntentWriteModel(ctx context.Context, writeModel *IDPIntentWriteModel, filter preparation.FilterToQueryReducer) error { events, err := filter(ctx, writeModel.Query()) if err != nil { return err } if len(events) == 0 { return nil } writeModel.AppendEvents(events...) return writeModel.Reduce() } func (c *Commands) SucceedIDPIntent(ctx context.Context, writeModel *IDPIntentWriteModel, idpUser idp.User, idpSession idp.Session, userID string) (string, error) { token, err := c.generateIntentToken(writeModel.AggregateID) if err != nil { return "", err } accessToken, idToken, err := tokensForSucceededIDPIntent(idpSession, c.idpConfigEncryption) if err != nil { return "", err } idpInfo, err := json.Marshal(idpUser) if err != nil { return "", err } cmd := idpintent.NewSucceededEvent( ctx, &idpintent.NewAggregate(writeModel.AggregateID, writeModel.ResourceOwner).Aggregate, idpInfo, idpUser.GetID(), idpUser.GetPreferredUsername(), userID, accessToken, idToken, ) err = c.pushAppendAndReduce(ctx, writeModel, cmd) if err != nil { return "", err } return token, nil } func (c *Commands) SucceedSAMLIDPIntent(ctx context.Context, writeModel *IDPIntentWriteModel, idpUser idp.User, userID string, assertion *saml.Assertion) (string, error) { token, err := c.generateIntentToken(writeModel.AggregateID) if err != nil { return "", err } idpInfo, err := json.Marshal(idpUser) if err != nil { return "", err } assertionData, err := xml.Marshal(assertion) if err != nil { return "", err } assertionEnc, err := crypto.Encrypt(assertionData, c.idpConfigEncryption) if err != nil { return "", err } cmd := idpintent.NewSAMLSucceededEvent( ctx, &idpintent.NewAggregate(writeModel.AggregateID, writeModel.ResourceOwner).Aggregate, idpInfo, idpUser.GetID(), idpUser.GetPreferredUsername(), userID, assertionEnc, ) err = c.pushAppendAndReduce(ctx, writeModel, cmd) if err != nil { return "", err } return token, nil } func (c *Commands) RequestSAMLIDPIntent(ctx context.Context, writeModel *IDPIntentWriteModel, requestID string) error { return c.pushAppendAndReduce(ctx, writeModel, idpintent.NewSAMLRequestEvent( ctx, &idpintent.NewAggregate(writeModel.AggregateID, writeModel.ResourceOwner).Aggregate, requestID, )) } func (c *Commands) generateIntentToken(intentID string) (string, error) { token, err := c.idpConfigEncryption.Encrypt([]byte(intentID)) if err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(token), nil } func (c *Commands) SucceedLDAPIDPIntent(ctx context.Context, writeModel *IDPIntentWriteModel, idpUser idp.User, userID string, attributes map[string][]string) (string, error) { token, err := c.generateIntentToken(writeModel.AggregateID) if err != nil { return "", err } idpInfo, err := json.Marshal(idpUser) if err != nil { return "", err } cmd := idpintent.NewLDAPSucceededEvent( ctx, &idpintent.NewAggregate(writeModel.AggregateID, writeModel.ResourceOwner).Aggregate, idpInfo, idpUser.GetID(), idpUser.GetPreferredUsername(), userID, attributes, ) err = c.pushAppendAndReduce(ctx, writeModel, cmd) if err != nil { return "", err } return token, nil } func (c *Commands) FailIDPIntent(ctx context.Context, writeModel *IDPIntentWriteModel, reason string) error { cmd := idpintent.NewFailedEvent( ctx, &idpintent.NewAggregate(writeModel.AggregateID, writeModel.ResourceOwner).Aggregate, reason, ) _, err := c.eventstore.Push(ctx, cmd) return err } func (c *Commands) GetIntentWriteModel(ctx context.Context, id, resourceOwner string) (*IDPIntentWriteModel, error) { writeModel := NewIDPIntentWriteModel(id, resourceOwner) err := c.eventstore.FilterToQueryReducer(ctx, writeModel) if err != nil { return nil, err } return writeModel, err } // tokensForSucceededIDPIntent extracts the oidc.Tokens if available (and encrypts the access_token) for the succeeded event payload func tokensForSucceededIDPIntent(session idp.Session, encryptionAlg crypto.EncryptionAlgorithm) (*crypto.CryptoValue, string, error) { var tokens *oidc.Tokens[*oidc.IDTokenClaims] switch s := session.(type) { case *oauth.Session: tokens = s.Tokens case *openid.Session: tokens = s.Tokens case *jwt.Session: tokens = s.Tokens case *azuread.Session: tokens = s.Tokens() case *apple.Session: tokens = s.Tokens default: return nil, "", nil } if tokens.Token == nil || tokens.AccessToken == "" { return nil, tokens.IDToken, nil } accessToken, err := crypto.Encrypt([]byte(tokens.AccessToken), encryptionAlg) return accessToken, tokens.IDToken, err }