From 2efa305e10d1e8b2d3a55206df1d29d7ef822d6f Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Fri, 3 Mar 2023 11:38:49 +0100 Subject: [PATCH] fix: use of generic oauth provider (#5345) Adds a id_attribute to the GenericOAuthProvider, which is used to map the external User. Further mapping can be done in actions by using the `rawInfo` of the new `ctx.v1.providerInfo` field. --- .../apis/actions/external-authentication.md | 3 + internal/api/grpc/admin/idp_converter.go | 2 + internal/api/grpc/management/idp_converter.go | 2 + internal/api/ui/login/custom_action.go | 40 +++++++++-- .../api/ui/login/external_provider_handler.go | 58 +++++++++++---- internal/api/ui/login/jwt_handler.go | 8 +-- internal/command/idp.go | 1 + internal/command/idp_model.go | 11 ++- internal/command/instance_idp.go | 8 +++ internal/command/instance_idp_model.go | 4 +- internal/command/instance_idp_test.go | 51 +++++++++++++ internal/command/org_idp.go | 8 +++ internal/command/org_idp_model.go | 4 +- internal/command/org_idp_test.go | 53 ++++++++++++++ internal/idp/providers/oauth/mapper.go | 71 ++++++++++--------- internal/idp/providers/oauth/session_test.go | 18 ++--- internal/query/idp_template.go | 13 ++++ internal/query/idp_template_test.go | 51 ++++++++----- .../query/projection/idp_login_policy_link.go | 29 ++++++++ .../projection/idp_login_policy_link_test.go | 60 ++++++++++++++++ internal/query/projection/idp_template.go | 8 ++- .../query/projection/idp_template_test.go | 14 ++-- internal/repository/idp/oauth.go | 12 +++- internal/repository/instance/idp.go | 4 +- internal/repository/org/idp.go | 4 +- proto/zitadel/admin.proto | 8 ++- proto/zitadel/idp.proto | 1 + proto/zitadel/management.proto | 8 ++- 28 files changed, 456 insertions(+), 98 deletions(-) diff --git a/docs/docs/apis/actions/external-authentication.md b/docs/docs/apis/actions/external-authentication.md index 8ba447333d..cdd06449f7 100644 --- a/docs/docs/apis/actions/external-authentication.md +++ b/docs/docs/apis/actions/external-authentication.md @@ -26,6 +26,9 @@ The first parameter contains the following fields This is a verification errors string representation. If the verification succeeds, this is "none" - `authRequest` [*auth request*](/docs/apis/actions/objects#auth-request) - `httpRequest` [*http request*](/docs/apis/actions/objects#http-request) + - `providerInfo` *Any* + Returns the response of the provider. In case the provider is a Generic OAuth Provider, the information is accessible through: + - `rawInfo` *Any* - `api` The second parameter contains the following fields - `v1` diff --git a/internal/api/grpc/admin/idp_converter.go b/internal/api/grpc/admin/idp_converter.go index c8f2f78e95..fc1802f388 100644 --- a/internal/api/grpc/admin/idp_converter.go +++ b/internal/api/grpc/admin/idp_converter.go @@ -210,6 +210,7 @@ func addGenericOAuthProviderToCommand(req *admin_pb.AddGenericOAuthProviderReque TokenEndpoint: req.TokenEndpoint, UserEndpoint: req.UserEndpoint, Scopes: req.Scopes, + IDAttribute: req.IdAttribute, IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions), } } @@ -223,6 +224,7 @@ func updateGenericOAuthProviderToCommand(req *admin_pb.UpdateGenericOAuthProvide TokenEndpoint: req.TokenEndpoint, UserEndpoint: req.UserEndpoint, Scopes: req.Scopes, + IDAttribute: req.IdAttribute, IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions), } } diff --git a/internal/api/grpc/management/idp_converter.go b/internal/api/grpc/management/idp_converter.go index 0138392cae..75d8e2bab9 100644 --- a/internal/api/grpc/management/idp_converter.go +++ b/internal/api/grpc/management/idp_converter.go @@ -227,6 +227,7 @@ func addGenericOAuthProviderToCommand(req *mgmt_pb.AddGenericOAuthProviderReques TokenEndpoint: req.TokenEndpoint, UserEndpoint: req.UserEndpoint, Scopes: req.Scopes, + IDAttribute: req.IdAttribute, IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions), } } @@ -240,6 +241,7 @@ func updateGenericOAuthProviderToCommand(req *mgmt_pb.UpdateGenericOAuthProvider TokenEndpoint: req.TokenEndpoint, UserEndpoint: req.UserEndpoint, Scopes: req.Scopes, + IDAttribute: req.IdAttribute, IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions), } } diff --git a/internal/api/ui/login/custom_action.go b/internal/api/ui/login/custom_action.go index 43ac22c017..3a1dc6d890 100644 --- a/internal/api/ui/login/custom_action.go +++ b/internal/api/ui/login/custom_action.go @@ -13,6 +13,7 @@ import ( "github.com/zitadel/zitadel/internal/actions/object" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/idp" ) func (l *Login) runPostExternalAuthenticationActions( @@ -20,6 +21,7 @@ func (l *Login) runPostExternalAuthenticationActions( tokens *oidc.Tokens, authRequest *domain.AuthRequest, httpRequest *http.Request, + idpUser idp.User, authenticationError error, ) (*domain.ExternalUser, error) { ctx := httpRequest.Context() @@ -86,6 +88,9 @@ func (l *Login) runPostExternalAuthenticationActions( actions.SetFields("externalUser", func(c *actions.FieldConfig) interface{} { return object.UserFromExternalUser(c, user) }), + actions.SetFields("providerInfo", func(c *actions.FieldConfig) interface{} { + return c.Runtime.ToValue(idpUser) + }), actions.SetFields("authRequest", object.AuthRequestField(authRequest)), actions.SetFields("httpRequest", object.HTTPRequestField(httpRequest)), actions.SetFields("authError", authErrStr), @@ -337,18 +342,39 @@ func (l *Login) runPostCreationActions( } func tokenCtxFields(tokens *oidc.Tokens) []actions.FieldOption { - return []actions.FieldOption{ - actions.SetFields("accessToken", tokens.AccessToken), - actions.SetFields("idToken", tokens.IDToken), - actions.SetFields("getClaim", func(claim string) interface{} { + var accessToken, idToken string + getClaim := func(claim string) interface{} { + return nil + } + claimsJSON := func() (string, error) { + return "", nil + } + if tokens == nil { + return []actions.FieldOption{ + actions.SetFields("accessToken", accessToken), + actions.SetFields("idToken", idToken), + actions.SetFields("getClaim", getClaim), + actions.SetFields("claimsJSON", claimsJSON), + } + } + accessToken = tokens.AccessToken + idToken = tokens.IDToken + if tokens.IDTokenClaims != nil { + getClaim = func(claim string) interface{} { return tokens.IDTokenClaims.GetClaim(claim) - }), - actions.SetFields("claimsJSON", func() (string, error) { + } + claimsJSON = func() (string, error) { c, err := json.Marshal(tokens.IDTokenClaims) if err != nil { return "", err } return string(c), nil - }), + } + } + return []actions.FieldOption{ + actions.SetFields("accessToken", accessToken), + actions.SetFields("idToken", idToken), + actions.SetFields("getClaim", getClaim), + actions.SetFields("claimsJSON", claimsJSON), } } diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 19d572ae88..6bd0ccf0bc 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -20,6 +20,7 @@ import ( "github.com/zitadel/zitadel/internal/idp" "github.com/zitadel/zitadel/internal/idp/providers/google" "github.com/zitadel/zitadel/internal/idp/providers/jwt" + "github.com/zitadel/zitadel/internal/idp/providers/oauth" openid "github.com/zitadel/zitadel/internal/idp/providers/oidc" "github.com/zitadel/zitadel/internal/query" ) @@ -134,14 +135,15 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai } var provider idp.Provider switch identityProvider.Type { + case domain.IDPTypeOAuth: + provider, err = l.oauthProvider(r.Context(), identityProvider) case domain.IDPTypeOIDC: provider, err = l.oidcProvider(r.Context(), identityProvider) case domain.IDPTypeJWT: provider, err = l.jwtProvider(identityProvider) case domain.IDPTypeGoogle: provider, err = l.googleProvider(r.Context(), identityProvider) - case domain.IDPTypeOAuth, - domain.IDPTypeLDAP, + case domain.IDPTypeLDAP, domain.IDPTypeAzureAD, domain.IDPTypeGitHub, domain.IDPTypeGitHubEE, @@ -177,33 +179,39 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, err) + l.externalAuthFailed(w, r, authReq, nil, nil, err) return } identityProvider, err := l.getIDPByID(r, authReq.SelectedIDPConfigID) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, err) + l.externalAuthFailed(w, r, authReq, nil, nil, err) return } var provider idp.Provider var session idp.Session switch identityProvider.Type { + case domain.IDPTypeOAuth: + provider, err = l.oauthProvider(r.Context(), identityProvider) + if err != nil { + l.externalAuthFailed(w, r, authReq, nil, nil, err) + return + } + session = &oauth.Session{Provider: provider.(*oauth.Provider), Code: data.Code} case domain.IDPTypeOIDC: provider, err = l.oidcProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, err) + l.externalAuthFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*openid.Provider), Code: data.Code} case domain.IDPTypeGoogle: provider, err = l.googleProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, err) + l.externalAuthFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*google.Provider).Provider, Code: data.Code} case domain.IDPTypeJWT, - domain.IDPTypeOAuth, domain.IDPTypeLDAP, domain.IDPTypeAzureAD, domain.IDPTypeGitHub, @@ -219,7 +227,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque user, err := session.FetchUser(r.Context()) if err != nil { - l.externalAuthFailed(w, r, authReq, tokens(session), err) + l.externalAuthFailed(w, r, authReq, tokens(session), user, err) return } l.handleExternalUserAuthenticated(w, r, authReq, identityProvider, session, user, l.renderNextStep) @@ -236,7 +244,7 @@ func (l *Login) handleExternalUserAuthenticated( callback func(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest), ) { externalUser := mapIDPUserToExternalUser(user, provider.ID) - externalUser, err := l.runPostExternalAuthenticationActions(externalUser, tokens(session), authReq, r, nil) + externalUser, err := l.runPostExternalAuthenticationActions(externalUser, tokens(session), authReq, r, user, nil) if err != nil { l.renderError(w, r, authReq, err) return @@ -600,6 +608,31 @@ func (l *Login) jwtProvider(identityProvider *query.IDPTemplate) (*jwt.Provider, ) } +func (l *Login) oauthProvider(ctx context.Context, identityProvider *query.IDPTemplate) (*oauth.Provider, error) { + secret, err := crypto.DecryptString(identityProvider.OAuthIDPTemplate.ClientSecret, l.idpConfigAlg) + if err != nil { + return nil, err + } + config := &oauth2.Config{ + ClientID: identityProvider.OAuthIDPTemplate.ClientID, + ClientSecret: secret, + Endpoint: oauth2.Endpoint{ + AuthURL: identityProvider.OAuthIDPTemplate.AuthorizationEndpoint, + TokenURL: identityProvider.OAuthIDPTemplate.TokenEndpoint, + }, + RedirectURL: l.baseURL(ctx) + EndpointExternalLoginCallback, + Scopes: identityProvider.OAuthIDPTemplate.Scopes, + } + return oauth.New( + config, + identityProvider.Name, + identityProvider.OAuthIDPTemplate.UserEndpoint, + func() idp.User { + return oauth.NewUserMapper(identityProvider.OAuthIDPTemplate.IDAttribute) + }, + ) +} + func (l *Login) appendUserGrants(ctx context.Context, userGrants []*domain.UserGrant, resourceOwner string) error { if len(userGrants) == 0 { return nil @@ -613,11 +646,8 @@ func (l *Login) appendUserGrants(ctx context.Context, userGrants []*domain.UserG return nil } -func (l *Login) externalAuthFailed(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, tokens *oidc.Tokens, err error) { - if tokens == nil { - tokens = &oidc.Tokens{Token: &oauth2.Token{}} - } - if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, tokens, authReq, r, err); actionErr != nil { +func (l *Login) externalAuthFailed(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, tokens *oidc.Tokens, user idp.User, err error) { + if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, tokens, authReq, r, user, err); actionErr != nil { logging.WithError(err).Error("both external user authentication and action post authentication failed") } l.renderLogin(w, r, authReq, err) diff --git a/internal/api/ui/login/jwt_handler.go b/internal/api/ui/login/jwt_handler.go index 8cef3ed747..fd5eac88c6 100644 --- a/internal/api/ui/login/jwt_handler.go +++ b/internal/api/ui/login/jwt_handler.go @@ -66,8 +66,7 @@ func (l *Login) handleJWTRequest(w http.ResponseWriter, r *http.Request) { func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, identityProvider *query.IDPTemplate) { token, err := getToken(r, identityProvider.JWTIDPTemplate.HeaderName) if err != nil { - emptyTokens := &oidc.Tokens{Token: &oauth2.Token{}} - if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, emptyTokens, authReq, r, err); actionErr != nil { + if _, actionErr := l.runPostExternalAuthenticationActions(new(domain.ExternalUser), nil, authReq, r, nil, err); actionErr != nil { logging.WithError(err).Error("both external user authentication and action post authentication failed") } @@ -76,8 +75,7 @@ func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, auth } provider, err := l.jwtProvider(identityProvider) if err != nil { - emptyTokens := &oidc.Tokens{Token: &oauth2.Token{}} - if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, emptyTokens, authReq, r, err); actionErr != nil { + if _, actionErr := l.runPostExternalAuthenticationActions(new(domain.ExternalUser), nil, authReq, r, nil, err); actionErr != nil { logging.WithError(err).Error("both external user authentication and action post authentication failed") } l.renderError(w, r, authReq, err) @@ -86,7 +84,7 @@ func (l *Login) handleJWTExtraction(w http.ResponseWriter, r *http.Request, auth session := &jwt.Session{Provider: provider, Tokens: &oidc.Tokens{IDToken: token, Token: &oauth2.Token{}}} user, err := session.FetchUser(r.Context()) if err != nil { - if _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, tokens(session), authReq, r, err); actionErr != nil { + if _, actionErr := l.runPostExternalAuthenticationActions(new(domain.ExternalUser), tokens(session), authReq, r, user, err); actionErr != nil { logging.WithError(err).Error("both external user authentication and action post authentication failed") } l.renderError(w, r, authReq, err) diff --git a/internal/command/idp.go b/internal/command/idp.go index 811bc2d1d7..d0add970aa 100644 --- a/internal/command/idp.go +++ b/internal/command/idp.go @@ -16,6 +16,7 @@ type GenericOAuthProvider struct { TokenEndpoint string UserEndpoint string Scopes []string + IDAttribute string IDPOptions idp.Options } diff --git a/internal/command/idp_model.go b/internal/command/idp_model.go index eb2944b611..86a7543241 100644 --- a/internal/command/idp_model.go +++ b/internal/command/idp_model.go @@ -21,6 +21,7 @@ type OAuthIDPWriteModel struct { TokenEndpoint string UserEndpoint string Scopes []string + IDAttribute string idp.Options State domain.IDPState @@ -48,6 +49,7 @@ func (wm *OAuthIDPWriteModel) reduceAddedEvent(e *idp.OAuthIDPAddedEvent) { wm.TokenEndpoint = e.TokenEndpoint wm.UserEndpoint = e.UserEndpoint wm.Scopes = e.Scopes + wm.IDAttribute = e.IDAttribute wm.State = domain.IDPStateActive } @@ -73,6 +75,9 @@ func (wm *OAuthIDPWriteModel) reduceChangedEvent(e *idp.OAuthIDPChangedEvent) { if e.Scopes != nil { wm.Scopes = e.Scopes } + if e.IDAttribute != nil { + wm.IDAttribute = *e.IDAttribute + } wm.Options.ReduceChanges(e.OptionChanges) } @@ -83,7 +88,8 @@ func (wm *OAuthIDPWriteModel) NewChanges( secretCrypto crypto.Crypto, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options idp.Options, ) ([]idp.OAuthIDPChanges, error) { @@ -115,6 +121,9 @@ func (wm *OAuthIDPWriteModel) NewChanges( if !reflect.DeepEqual(wm.Scopes, scopes) { changes = append(changes, idp.ChangeOAuthScopes(scopes)) } + if wm.IDAttribute != idAttribute { + changes = append(changes, idp.ChangeOAuthIDAttribute(idAttribute)) + } opts := wm.Options.Changes(options) if !opts.IsZero() { changes = append(changes, idp.ChangeOAuthOptions(opts)) diff --git a/internal/command/instance_idp.go b/internal/command/instance_idp.go index 0139b996bc..80f49711e4 100644 --- a/internal/command/instance_idp.go +++ b/internal/command/instance_idp.go @@ -273,6 +273,9 @@ func (c *Commands) prepareAddInstanceOAuthProvider(a *instance.Aggregate, writeM if provider.UserEndpoint = strings.TrimSpace(provider.UserEndpoint); provider.UserEndpoint == "" { return nil, caos_errs.ThrowInvalidArgument(nil, "INST-Fb8jk", "Errors.Invalid.Argument") } + if provider.IDAttribute = strings.TrimSpace(provider.IDAttribute); provider.IDAttribute == "" { + return nil, caos_errs.ThrowInvalidArgument(nil, "INST-sdf3f", "Errors.Invalid.Argument") + } return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { events, err := filter(ctx, writeModel.Query()) if err != nil { @@ -297,6 +300,7 @@ func (c *Commands) prepareAddInstanceOAuthProvider(a *instance.Aggregate, writeM provider.AuthorizationEndpoint, provider.TokenEndpoint, provider.UserEndpoint, + provider.IDAttribute, provider.Scopes, provider.IDPOptions, ), @@ -322,6 +326,9 @@ func (c *Commands) prepareUpdateInstanceOAuthProvider(a *instance.Aggregate, wri if provider.UserEndpoint = strings.TrimSpace(provider.UserEndpoint); provider.UserEndpoint == "" { return nil, caos_errs.ThrowInvalidArgument(nil, "INST-Fb8jk", "Errors.Invalid.Argument") } + if provider.IDAttribute = strings.TrimSpace(provider.IDAttribute); provider.IDAttribute == "" { + return nil, caos_errs.ThrowInvalidArgument(nil, "INST-asf3fs", "Errors.Invalid.Argument") + } return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { events, err := filter(ctx, writeModel.Query()) if err != nil { @@ -345,6 +352,7 @@ func (c *Commands) prepareUpdateInstanceOAuthProvider(a *instance.Aggregate, wri provider.AuthorizationEndpoint, provider.TokenEndpoint, provider.UserEndpoint, + provider.IDAttribute, provider.Scopes, provider.IDPOptions, ) diff --git a/internal/command/instance_idp_model.go b/internal/command/instance_idp_model.go index 4184892c58..113d764fb0 100644 --- a/internal/command/instance_idp_model.go +++ b/internal/command/instance_idp_model.go @@ -67,7 +67,8 @@ func (wm *InstanceOAuthIDPWriteModel) NewChangedEvent( secretCrypto crypto.Crypto, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options idp.Options, ) (*instance.OAuthIDPChangedEvent, error) { @@ -80,6 +81,7 @@ func (wm *InstanceOAuthIDPWriteModel) NewChangedEvent( authorizationEndpoint, tokenEndpoint, userEndpoint, + idAttribute, scopes, options, ) diff --git a/internal/command/instance_idp_test.go b/internal/command/instance_idp_test.go index ffeddf96b5..6105641a8c 100644 --- a/internal/command/instance_idp_test.go +++ b/internal/command/instance_idp_test.go @@ -145,6 +145,27 @@ func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) { err: caos_errors.IsErrorInvalidArgument, }, }, + { + "invalid id attribute", + fields{ + eventstore: eventstoreExpect(t), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instance1"), + provider: GenericOAuthProvider{ + Name: "name", + ClientID: "clientID", + ClientSecret: "clientSecret", + AuthorizationEndpoint: "auth", + TokenEndpoint: "token", + UserEndpoint: "user", + }, + }, + res{ + err: caos_errors.IsErrorInvalidArgument, + }, + }, { name: "ok", fields: fields{ @@ -167,6 +188,7 @@ func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -185,6 +207,7 @@ func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -214,6 +237,7 @@ func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", []string{"user"}, idp.Options{ IsCreationAllowed: true, @@ -238,6 +262,7 @@ func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) { TokenEndpoint: "token", UserEndpoint: "user", Scopes: []string{"user"}, + IDAttribute: "idAttribute", IDPOptions: idp.Options{ IsCreationAllowed: true, IsLinkingAllowed: true, @@ -390,6 +415,26 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { err: caos_errors.IsErrorInvalidArgument, }, }, + { + "invalid id attribute", + fields{ + eventstore: eventstoreExpect(t), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instance1"), + id: "id1", + provider: GenericOAuthProvider{ + Name: "name", + ClientID: "clientID", + AuthorizationEndpoint: "auth", + TokenEndpoint: "token", + UserEndpoint: "user", + }, + }, + res{ + err: caos_errors.IsErrorInvalidArgument, + }, + }, { name: "not found", fields: fields{ @@ -406,6 +451,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -431,6 +477,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -446,6 +493,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -471,6 +519,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -496,6 +545,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { idp.ChangeOAuthTokenEndpoint("new token"), idp.ChangeOAuthUserEndpoint("new user"), idp.ChangeOAuthScopes([]string{"openid", "profile"}), + idp.ChangeOAuthIDAttribute("newAttribute"), idp.ChangeOAuthOptions(idp.OptionChanges{ IsCreationAllowed: &t, IsLinkingAllowed: &t, @@ -523,6 +573,7 @@ func TestCommandSide_UpdateInstanceGenericOAuthIDP(t *testing.T) { TokenEndpoint: "new token", UserEndpoint: "new user", Scopes: []string{"openid", "profile"}, + IDAttribute: "newAttribute", IDPOptions: idp.Options{ IsCreationAllowed: true, IsLinkingAllowed: true, diff --git a/internal/command/org_idp.go b/internal/command/org_idp.go index 734d0c5b71..7717736763 100644 --- a/internal/command/org_idp.go +++ b/internal/command/org_idp.go @@ -262,6 +262,9 @@ func (c *Commands) prepareAddOrgOAuthProvider(a *org.Aggregate, writeModel *OrgO if provider.UserEndpoint = strings.TrimSpace(provider.UserEndpoint); provider.UserEndpoint == "" { return nil, caos_errs.ThrowInvalidArgument(nil, "ORG-Fb8jk", "Errors.Invalid.Argument") } + if provider.IDAttribute = strings.TrimSpace(provider.IDAttribute); provider.IDAttribute == "" { + return nil, caos_errs.ThrowInvalidArgument(nil, "ORG-sadf3d", "Errors.Invalid.Argument") + } return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { events, err := filter(ctx, writeModel.Query()) if err != nil { @@ -286,6 +289,7 @@ func (c *Commands) prepareAddOrgOAuthProvider(a *org.Aggregate, writeModel *OrgO provider.AuthorizationEndpoint, provider.TokenEndpoint, provider.UserEndpoint, + provider.IDAttribute, provider.Scopes, provider.IDPOptions, ), @@ -314,6 +318,9 @@ func (c *Commands) prepareUpdateOrgOAuthProvider(a *org.Aggregate, writeModel *O if provider.UserEndpoint = strings.TrimSpace(provider.UserEndpoint); provider.UserEndpoint == "" { return nil, caos_errs.ThrowInvalidArgument(nil, "ORG-Fb8jk", "Errors.Invalid.Argument") } + if provider.IDAttribute = strings.TrimSpace(provider.IDAttribute); provider.IDAttribute == "" { + return nil, caos_errs.ThrowInvalidArgument(nil, "ORG-SAe4gh", "Errors.Invalid.Argument") + } return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) { events, err := filter(ctx, writeModel.Query()) if err != nil { @@ -337,6 +344,7 @@ func (c *Commands) prepareUpdateOrgOAuthProvider(a *org.Aggregate, writeModel *O provider.AuthorizationEndpoint, provider.TokenEndpoint, provider.UserEndpoint, + provider.IDAttribute, provider.Scopes, provider.IDPOptions, ) diff --git a/internal/command/org_idp_model.go b/internal/command/org_idp_model.go index f6b3c6ec4d..ee3afd16d2 100644 --- a/internal/command/org_idp_model.go +++ b/internal/command/org_idp_model.go @@ -69,7 +69,8 @@ func (wm *OrgOAuthIDPWriteModel) NewChangedEvent( secretCrypto crypto.Crypto, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options idp.Options, ) (*org.OAuthIDPChangedEvent, error) { @@ -82,6 +83,7 @@ func (wm *OrgOAuthIDPWriteModel) NewChangedEvent( authorizationEndpoint, tokenEndpoint, userEndpoint, + idAttribute, scopes, options, ) diff --git a/internal/command/org_idp_test.go b/internal/command/org_idp_test.go index 0fcb4c6d66..1d35659713 100644 --- a/internal/command/org_idp_test.go +++ b/internal/command/org_idp_test.go @@ -150,6 +150,28 @@ func TestCommandSide_AddOrgGenericOAuthIDP(t *testing.T) { err: caos_errors.IsErrorInvalidArgument, }, }, + { + "invalid id attribute", + fields{ + eventstore: eventstoreExpect(t), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"), + }, + args{ + ctx: context.Background(), + resourceOwner: "org1", + provider: GenericOAuthProvider{ + Name: "name", + ClientID: "clientID", + ClientSecret: "clientSecret", + AuthorizationEndpoint: "auth", + TokenEndpoint: "token", + UserEndpoint: "user", + }, + }, + res{ + err: caos_errors.IsErrorInvalidArgument, + }, + }, { name: "ok", fields: fields{ @@ -170,6 +192,7 @@ func TestCommandSide_AddOrgGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -188,6 +211,7 @@ func TestCommandSide_AddOrgGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -215,6 +239,7 @@ func TestCommandSide_AddOrgGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", []string{"user"}, idp.Options{ IsCreationAllowed: true, @@ -239,6 +264,7 @@ func TestCommandSide_AddOrgGenericOAuthIDP(t *testing.T) { TokenEndpoint: "token", UserEndpoint: "user", Scopes: []string{"user"}, + IDAttribute: "idAttribute", IDPOptions: idp.Options{ IsCreationAllowed: true, IsLinkingAllowed: true, @@ -398,6 +424,27 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { err: caos_errors.IsErrorInvalidArgument, }, }, + { + "invalid id attribute", + fields{ + eventstore: eventstoreExpect(t), + }, + args{ + ctx: context.Background(), + resourceOwner: "org1", + id: "id1", + provider: GenericOAuthProvider{ + Name: "name", + ClientID: "clientID", + AuthorizationEndpoint: "auth", + TokenEndpoint: "token", + UserEndpoint: "user", + }, + }, + res{ + err: caos_errors.IsErrorInvalidArgument, + }, + }, { name: "not found", fields: fields{ @@ -415,6 +462,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -440,6 +488,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -456,6 +505,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { AuthorizationEndpoint: "auth", TokenEndpoint: "token", UserEndpoint: "user", + IDAttribute: "idAttribute", }, }, res: res{ @@ -481,6 +531,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { "auth", "token", "user", + "idAttribute", nil, idp.Options{}, )), @@ -504,6 +555,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { idp.ChangeOAuthTokenEndpoint("new token"), idp.ChangeOAuthUserEndpoint("new user"), idp.ChangeOAuthScopes([]string{"openid", "profile"}), + idp.ChangeOAuthIDAttribute("newAttribute"), idp.ChangeOAuthOptions(idp.OptionChanges{ IsCreationAllowed: &t, IsLinkingAllowed: &t, @@ -531,6 +583,7 @@ func TestCommandSide_UpdateOrgGenericOAuthIDP(t *testing.T) { TokenEndpoint: "new token", UserEndpoint: "new user", Scopes: []string{"openid", "profile"}, + IDAttribute: "newAttribute", IDPOptions: idp.Options{ IsCreationAllowed: true, IsLinkingAllowed: true, diff --git a/internal/idp/providers/oauth/mapper.go b/internal/idp/providers/oauth/mapper.go index fca9d994b6..3493ea4309 100644 --- a/internal/idp/providers/oauth/mapper.go +++ b/internal/idp/providers/oauth/mapper.go @@ -2,6 +2,8 @@ package oauth import ( "encoding/json" + "fmt" + "strconv" "golang.org/x/text/language" @@ -11,92 +13,97 @@ import ( var _ idp.User = (*UserMapper)(nil) // UserMapper is an implementation of [idp.User]. -// It can be used in ZITADEL actions to map the raw `info` +// It can be used in ZITADEL actions to map the `RawInfo` type UserMapper struct { - ID string - FirstName string - LastName string - DisplayName string - NickName string - PreferredUsername string - Email string - EmailVerified bool - Phone string - PhoneVerified bool - PreferredLanguage string - AvatarURL string - Profile string - info map[string]interface{} + idAttribute string + RawInfo map[string]interface{} +} + +func NewUserMapper(idAttribute string) *UserMapper { + return &UserMapper{ + idAttribute: idAttribute, + RawInfo: make(map[string]interface{}), + } } func (u *UserMapper) UnmarshalJSON(data []byte) error { - if u.info == nil { - u.info = make(map[string]interface{}) - } - return json.Unmarshal(data, &u.info) + return json.Unmarshal(data, &u.RawInfo) } // GetID is an implementation of the [idp.User] interface. func (u *UserMapper) GetID() string { - return u.ID + id, ok := u.RawInfo[u.idAttribute] + if !ok { + return "" + } + switch i := id.(type) { + case string: + return i + case int: + return strconv.Itoa(i) + case float64: + return strconv.FormatFloat(i, 'f', -1, 64) + default: + return fmt.Sprint(i) + } } // GetFirstName is an implementation of the [idp.User] interface. func (u *UserMapper) GetFirstName() string { - return u.FirstName + return "" } // GetLastName is an implementation of the [idp.User] interface. func (u *UserMapper) GetLastName() string { - return u.LastName + return "" } // GetDisplayName is an implementation of the [idp.User] interface. func (u *UserMapper) GetDisplayName() string { - return u.DisplayName + return "" } // GetNickname is an implementation of the [idp.User] interface. func (u *UserMapper) GetNickname() string { - return u.NickName + return "" } // GetPreferredUsername is an implementation of the [idp.User] interface. func (u *UserMapper) GetPreferredUsername() string { - return u.PreferredUsername + return "" } // GetEmail is an implementation of the [idp.User] interface. func (u *UserMapper) GetEmail() string { - return u.Email + return "" } // IsEmailVerified is an implementation of the [idp.User] interface. func (u *UserMapper) IsEmailVerified() bool { - return u.EmailVerified + return false } // GetPhone is an implementation of the [idp.User] interface. func (u *UserMapper) GetPhone() string { - return u.Phone + return "" } // IsPhoneVerified is an implementation of the [idp.User] interface. func (u *UserMapper) IsPhoneVerified() bool { - return u.PhoneVerified + return false } // GetPreferredLanguage is an implementation of the [idp.User] interface. func (u *UserMapper) GetPreferredLanguage() language.Tag { - return language.Make(u.PreferredLanguage) + return language.Und } // GetAvatarURL is an implementation of the [idp.User] interface. func (u *UserMapper) GetAvatarURL() string { - return u.AvatarURL + return "" } // GetProfile is an implementation of the [idp.User] interface. func (u *UserMapper) GetProfile() string { - return u.Profile + return "" } diff --git a/internal/idp/providers/oauth/session_test.go b/internal/idp/providers/oauth/session_test.go index f0761b45a7..ed9c10defb 100644 --- a/internal/idp/providers/oauth/session_test.go +++ b/internal/idp/providers/oauth/session_test.go @@ -93,9 +93,7 @@ func TestProvider_FetchUser(t *testing.T) { Reply(http.StatusInternalServerError) }, userMapper: func() idp.User { - return &UserMapper{ - ID: "userID", - } + return NewUserMapper("userID") }, authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState", tokens: &oidc.Tokens{ @@ -135,7 +133,7 @@ func TestProvider_FetchUser(t *testing.T) { }) }, userMapper: func() idp.User { - return &UserMapper{} + return NewUserMapper("userID") }, authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState", tokens: &oidc.Tokens{ @@ -147,12 +145,13 @@ func TestProvider_FetchUser(t *testing.T) { }, want: want{ user: &UserMapper{ - info: map[string]interface{}{ + idAttribute: "userID", + RawInfo: map[string]interface{}{ "userID": "id", "custom": "claim", }, }, - id: "", + id: "id", firstName: "", lastName: "", displayName: "", @@ -202,7 +201,7 @@ func TestProvider_FetchUser(t *testing.T) { }) }, userMapper: func() idp.User { - return &UserMapper{} + return NewUserMapper("userID") }, authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState", tokens: nil, @@ -210,12 +209,13 @@ func TestProvider_FetchUser(t *testing.T) { }, want: want{ user: &UserMapper{ - info: map[string]interface{}{ + idAttribute: "userID", + RawInfo: map[string]interface{}{ "userID": "id", "custom": "claim", }, }, - id: "", + id: "id", firstName: "", lastName: "", displayName: "", diff --git a/internal/query/idp_template.go b/internal/query/idp_template.go index 735efaf174..8d4a909f6d 100644 --- a/internal/query/idp_template.go +++ b/internal/query/idp_template.go @@ -54,6 +54,7 @@ type OAuthIDPTemplate struct { TokenEndpoint string UserEndpoint string Scopes database.StringArray + IDAttribute string } type OIDCIDPTemplate struct { @@ -196,6 +197,10 @@ var ( name: projection.OAuthScopesCol, table: oauthIdpTemplateTable, } + OAuthIDAttributeCol = Column{ + name: projection.OAuthIDAttributeCol, + table: oauthIdpTemplateTable, + } ) var ( @@ -505,6 +510,7 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se OAuthTokenEndpointCol.identifier(), OAuthUserEndpointCol.identifier(), OAuthScopesCol.identifier(), + OAuthIDAttributeCol.identifier(), // oidc OIDCIDCol.identifier(), OIDCIssuerCol.identifier(), @@ -564,6 +570,7 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se oauthTokenEndpoint := sql.NullString{} oauthUserEndpoint := sql.NullString{} oauthScopes := database.StringArray{} + oauthIDAttribute := sql.NullString{} oidcID := sql.NullString{} oidcIssuer := sql.NullString{} @@ -627,6 +634,7 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se &oauthTokenEndpoint, &oauthUserEndpoint, &oauthScopes, + &oauthIDAttribute, // oidc &oidcID, &oidcIssuer, @@ -686,6 +694,7 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se TokenEndpoint: oauthTokenEndpoint.String, UserEndpoint: oauthUserEndpoint.String, Scopes: oauthScopes, + IDAttribute: oauthIDAttribute.String, } } if oidcID.Valid { @@ -770,6 +779,7 @@ func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.Selec OAuthTokenEndpointCol.identifier(), OAuthUserEndpointCol.identifier(), OAuthScopesCol.identifier(), + OAuthIDAttributeCol.identifier(), // oidc OIDCIDCol.identifier(), OIDCIssuerCol.identifier(), @@ -833,6 +843,7 @@ func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.Selec oauthTokenEndpoint := sql.NullString{} oauthUserEndpoint := sql.NullString{} oauthScopes := database.StringArray{} + oauthIDAttribute := sql.NullString{} oidcID := sql.NullString{} oidcIssuer := sql.NullString{} @@ -896,6 +907,7 @@ func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.Selec &oauthTokenEndpoint, &oauthUserEndpoint, &oauthScopes, + &oauthIDAttribute, // oidc &oidcID, &oidcIssuer, @@ -954,6 +966,7 @@ func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.Selec TokenEndpoint: oauthTokenEndpoint.String, UserEndpoint: oauthUserEndpoint.String, Scopes: oauthScopes, + IDAttribute: oauthIDAttribute.String, } } if oidcID.Valid { diff --git a/internal/query/idp_template_test.go b/internal/query/idp_template_test.go index b07c1841c0..3fff2c0088 100644 --- a/internal/query/idp_template_test.go +++ b/internal/query/idp_template_test.go @@ -29,13 +29,14 @@ var ( ` projections.idp_templates2.is_auto_creation,` + ` projections.idp_templates2.is_auto_update,` + // oauth - ` projections.idp_templates2_oauth.idp_id,` + - ` projections.idp_templates2_oauth.client_id,` + - ` projections.idp_templates2_oauth.client_secret,` + - ` projections.idp_templates2_oauth.authorization_endpoint,` + - ` projections.idp_templates2_oauth.token_endpoint,` + - ` projections.idp_templates2_oauth.user_endpoint,` + - ` projections.idp_templates2_oauth.scopes,` + + ` projections.idp_templates2_oauth2.idp_id,` + + ` projections.idp_templates2_oauth2.client_id,` + + ` projections.idp_templates2_oauth2.client_secret,` + + ` projections.idp_templates2_oauth2.authorization_endpoint,` + + ` projections.idp_templates2_oauth2.token_endpoint,` + + ` projections.idp_templates2_oauth2.user_endpoint,` + + ` projections.idp_templates2_oauth2.scopes,` + + ` projections.idp_templates2_oauth2.id_attribute,` + // oidc ` projections.idp_templates2_oidc.idp_id,` + ` projections.idp_templates2_oidc.issuer,` + @@ -77,7 +78,7 @@ var ( ` projections.idp_templates2_ldap.avatar_url_attribute,` + ` projections.idp_templates2_ldap.profile_attribute` + ` FROM projections.idp_templates2` + - ` LEFT JOIN projections.idp_templates2_oauth ON projections.idp_templates2.id = projections.idp_templates2_oauth.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oauth.instance_id` + + ` LEFT JOIN projections.idp_templates2_oauth2 ON projections.idp_templates2.id = projections.idp_templates2_oauth2.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oauth2.instance_id` + ` LEFT JOIN projections.idp_templates2_oidc ON projections.idp_templates2.id = projections.idp_templates2_oidc.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oidc.instance_id` + ` LEFT JOIN projections.idp_templates2_jwt ON projections.idp_templates2.id = projections.idp_templates2_jwt.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_jwt.instance_id` + ` LEFT JOIN projections.idp_templates2_google ON projections.idp_templates2.id = projections.idp_templates2_google.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_google.instance_id` + @@ -105,6 +106,7 @@ var ( "token_endpoint", "user_endpoint", "scopes", + "id_attribute", // oidc config "id_id", "issuer", @@ -160,13 +162,14 @@ var ( ` projections.idp_templates2.is_auto_creation,` + ` projections.idp_templates2.is_auto_update,` + // oauth - ` projections.idp_templates2_oauth.idp_id,` + - ` projections.idp_templates2_oauth.client_id,` + - ` projections.idp_templates2_oauth.client_secret,` + - ` projections.idp_templates2_oauth.authorization_endpoint,` + - ` projections.idp_templates2_oauth.token_endpoint,` + - ` projections.idp_templates2_oauth.user_endpoint,` + - ` projections.idp_templates2_oauth.scopes,` + + ` projections.idp_templates2_oauth2.idp_id,` + + ` projections.idp_templates2_oauth2.client_id,` + + ` projections.idp_templates2_oauth2.client_secret,` + + ` projections.idp_templates2_oauth2.authorization_endpoint,` + + ` projections.idp_templates2_oauth2.token_endpoint,` + + ` projections.idp_templates2_oauth2.user_endpoint,` + + ` projections.idp_templates2_oauth2.scopes,` + + ` projections.idp_templates2_oauth2.id_attribute,` + // oidc ` projections.idp_templates2_oidc.idp_id,` + ` projections.idp_templates2_oidc.issuer,` + @@ -209,7 +212,7 @@ var ( ` projections.idp_templates2_ldap.profile_attribute,` + ` COUNT(*) OVER ()` + ` FROM projections.idp_templates2` + - ` LEFT JOIN projections.idp_templates2_oauth ON projections.idp_templates2.id = projections.idp_templates2_oauth.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oauth.instance_id` + + ` LEFT JOIN projections.idp_templates2_oauth2 ON projections.idp_templates2.id = projections.idp_templates2_oauth2.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oauth2.instance_id` + ` LEFT JOIN projections.idp_templates2_oidc ON projections.idp_templates2.id = projections.idp_templates2_oidc.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_oidc.instance_id` + ` LEFT JOIN projections.idp_templates2_jwt ON projections.idp_templates2.id = projections.idp_templates2_jwt.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_jwt.instance_id` + ` LEFT JOIN projections.idp_templates2_google ON projections.idp_templates2.id = projections.idp_templates2_google.idp_id AND projections.idp_templates2.instance_id = projections.idp_templates2_google.instance_id` + @@ -237,6 +240,7 @@ var ( "token_endpoint", "user_endpoint", "scopes", + "id_attribute", // oidc config "id_id", "issuer", @@ -339,6 +343,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { "token", "user", database.StringArray{"profile"}, + "id-attribute", // oidc nil, nil, @@ -404,6 +409,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { TokenEndpoint: "token", UserEndpoint: "user", Scopes: []string{"profile"}, + IDAttribute: "id-attribute", }, }, }, @@ -436,6 +442,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc "idp-id", "issuer", @@ -531,6 +538,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -626,6 +634,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -720,6 +729,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -833,6 +843,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -957,6 +968,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -1079,6 +1091,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -1176,6 +1189,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -1239,6 +1253,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -1302,6 +1317,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { "token", "user", database.StringArray{"profile"}, + "id-attribute", // oidc nil, nil, @@ -1365,6 +1381,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc "idp-id-oidc", "issuer", @@ -1428,6 +1445,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { nil, nil, nil, + nil, // oidc nil, nil, @@ -1561,6 +1579,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { TokenEndpoint: "token", UserEndpoint: "user", Scopes: []string{"profile"}, + IDAttribute: "id-attribute", }, }, { diff --git a/internal/query/projection/idp_login_policy_link.go b/internal/query/projection/idp_login_policy_link.go index ced104f4bf..4ddf815571 100644 --- a/internal/query/projection/idp_login_policy_link.go +++ b/internal/query/projection/idp_login_policy_link.go @@ -81,6 +81,10 @@ func (p *idpLoginPolicyLinkProjection) reducers() []handler.AggregateReducer { Event: org.IDPConfigRemovedEventType, Reduce: p.reduceIDPConfigRemoved, }, + { + Event: org.IDPRemovedEventType, + Reduce: p.reduceIDPRemoved, + }, { Event: org.OrgRemovedEventType, Reduce: p.reduceOwnerRemoved, @@ -106,6 +110,10 @@ func (p *idpLoginPolicyLinkProjection) reducers() []handler.AggregateReducer { Event: instance.IDPConfigRemovedEventType, Reduce: p.reduceIDPConfigRemoved, }, + { + Event: instance.IDPRemovedEventType, + Reduce: p.reduceIDPRemoved, + }, { Event: instance.InstanceRemovedEventType, Reduce: reduceInstanceRemovedHelper(IDPUserLinkInstanceIDCol), @@ -209,6 +217,27 @@ func (p *idpLoginPolicyLinkProjection) reduceIDPConfigRemoved(event eventstore.E ), nil } +func (p *idpLoginPolicyLinkProjection) reduceIDPRemoved(event eventstore.Event) (*handler.Statement, error) { + var idpID string + + switch e := event.(type) { + case *org.IDPRemovedEvent: + idpID = e.ID + case *instance.IDPRemovedEvent: + idpID = e.ID + default: + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SFED3", "reduce.wrong.event.type %v", []eventstore.EventType{org.IDPRemovedEventType, instance.IDPRemovedEventType}) + } + + return crdb.NewDeleteStatement(event, + []handler.Condition{ + handler.NewCond(IDPLoginPolicyLinkIDPIDCol, idpID), + handler.NewCond(IDPLoginPolicyLinkResourceOwnerCol, event.Aggregate().ResourceOwner), + handler.NewCond(IDPLoginPolicyLinkInstanceIDCol, event.Aggregate().InstanceID), + }, + ), nil +} + func (p *idpLoginPolicyLinkProjection) reducePolicyRemoved(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*org.LoginPolicyRemovedEvent) if !ok { diff --git a/internal/query/projection/idp_login_policy_link_test.go b/internal/query/projection/idp_login_policy_link_test.go index c85fa3d30c..43a6edb175 100644 --- a/internal/query/projection/idp_login_policy_link_test.go +++ b/internal/query/projection/idp_login_policy_link_test.go @@ -331,6 +331,66 @@ func TestIDPLoginPolicyLinkProjection_reduces(t *testing.T) { }, }, }, + { + name: "org IDPRemovedEvent", + args: args{ + event: getEvent(testEvent( + repository.EventType(org.IDPRemovedEventType), + org.AggregateType, + []byte(`{ + "id": "id" + }`), + ), org.IDPRemovedEventMapper), + }, + reduce: (&idpLoginPolicyLinkProjection{}).reduceIDPRemoved, + want: wantReduce{ + aggregateType: org.AggregateType, + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.idp_login_policy_links4 WHERE (idp_id = $1) AND (resource_owner = $2) AND (instance_id = $3)", + expectedArgs: []interface{}{ + "id", + "ro-id", + "instance-id", + }, + }, + }, + }, + }, + }, + { + name: "iam IDPRemovedEvent", + args: args{ + event: getEvent(testEvent( + repository.EventType(instance.IDPRemovedEventType), + instance.AggregateType, + []byte(`{ + "id": "id" + }`), + ), instance.IDPRemovedEventMapper), + }, + reduce: (&idpLoginPolicyLinkProjection{}).reduceIDPRemoved, + want: wantReduce{ + aggregateType: instance.AggregateType, + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.idp_login_policy_links4 WHERE (idp_id = $1) AND (resource_owner = $2) AND (instance_id = $3)", + expectedArgs: []interface{}{ + "id", + "ro-id", + "instance-id", + }, + }, + }, + }, + }, + }, { name: "org.reduceOwnerRemoved", reduce: (&idpLoginPolicyLinkProjection{}).reduceOwnerRemoved, diff --git a/internal/query/projection/idp_template.go b/internal/query/projection/idp_template.go index 09c3c3ad45..6f392bfa15 100644 --- a/internal/query/projection/idp_template.go +++ b/internal/query/projection/idp_template.go @@ -24,7 +24,7 @@ const ( IDPTemplateGoogleTable = IDPTemplateTable + "_" + IDPTemplateGoogleSuffix IDPTemplateLDAPTable = IDPTemplateTable + "_" + IDPTemplateLDAPSuffix - IDPTemplateOAuthSuffix = "oauth" + IDPTemplateOAuthSuffix = "oauth2" IDPTemplateOIDCSuffix = "oidc" IDPTemplateJWTSuffix = "jwt" IDPTemplateGoogleSuffix = "google" @@ -54,6 +54,7 @@ const ( OAuthTokenEndpointCol = "token_endpoint" OAuthUserEndpointCol = "user_endpoint" OAuthScopesCol = "scopes" + OAuthIDAttributeCol = "id_attribute" OIDCIDCol = "idp_id" OIDCInstanceIDCol = "instance_id" @@ -139,6 +140,7 @@ func newIDPTemplateProjection(ctx context.Context, config crdb.StatementHandlerC crdb.NewColumn(OAuthTokenEndpointCol, crdb.ColumnTypeText), crdb.NewColumn(OAuthUserEndpointCol, crdb.ColumnTypeText), crdb.NewColumn(OAuthScopesCol, crdb.ColumnTypeTextArray, crdb.Nullable()), + crdb.NewColumn(OAuthIDAttributeCol, crdb.ColumnTypeText), }, crdb.NewPrimaryKey(OAuthInstanceIDCol, OAuthIDCol), IDPTemplateOAuthSuffix, @@ -417,6 +419,7 @@ func (p *idpTemplateProjection) reduceOAuthIDPAdded(event eventstore.Event) (*ha handler.NewCol(OAuthTokenEndpointCol, idpEvent.TokenEndpoint), handler.NewCol(OAuthUserEndpointCol, idpEvent.UserEndpoint), handler.NewCol(OAuthScopesCol, database.StringArray(idpEvent.Scopes)), + handler.NewCol(OAuthIDAttributeCol, idpEvent.IDAttribute), }, crdb.WithTableSuffix(IDPTemplateOAuthSuffix), ), @@ -1176,6 +1179,9 @@ func reduceOAuthIDPChangedColumns(idpEvent idp.OAuthIDPChangedEvent) []handler.C if idpEvent.Scopes != nil { oauthCols = append(oauthCols, handler.NewCol(OAuthScopesCol, database.StringArray(idpEvent.Scopes))) } + if idpEvent.IDAttribute != nil { + oauthCols = append(oauthCols, handler.NewCol(OAuthIDAttributeCol, *idpEvent.IDAttribute)) + } return oauthCols } diff --git a/internal/query/projection/idp_template_test.go b/internal/query/projection/idp_template_test.go index a44d58b710..3c9ed82f8f 100644 --- a/internal/query/projection/idp_template_test.go +++ b/internal/query/projection/idp_template_test.go @@ -154,6 +154,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "tokenEndpoint": "token", "userEndpoint": "user", "scopes": ["profile"], + "idAttribute": "id-attribute", "isCreationAllowed": true, "isLinkingAllowed": true, "isAutoCreation": true, @@ -188,7 +189,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { }, }, { - expectedStmt: "INSERT INTO projections.idp_templates2_oauth (idp_id, instance_id, client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + expectedStmt: "INSERT INTO projections.idp_templates2_oauth2 (idp_id, instance_id, client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes, id_attribute) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", expectedArgs: []interface{}{ "idp-id", "instance-id", @@ -198,6 +199,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "token", "user", database.StringArray{"profile"}, + "id-attribute", }, }, }, @@ -223,6 +225,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "tokenEndpoint": "token", "userEndpoint": "user", "scopes": ["profile"], + "idAttribute": "id-attribute", "isCreationAllowed": true, "isLinkingAllowed": true, "isAutoCreation": true, @@ -257,7 +260,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { }, }, { - expectedStmt: "INSERT INTO projections.idp_templates2_oauth (idp_id, instance_id, client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", + expectedStmt: "INSERT INTO projections.idp_templates2_oauth2 (idp_id, instance_id, client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes, id_attribute) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", expectedArgs: []interface{}{ "idp-id", "instance-id", @@ -267,6 +270,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "token", "user", database.StringArray{"profile"}, + "id-attribute", }, }, }, @@ -304,7 +308,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { }, }, { - expectedStmt: "UPDATE projections.idp_templates2_oauth SET client_id = $1 WHERE (idp_id = $2) AND (instance_id = $3)", + expectedStmt: "UPDATE projections.idp_templates2_oauth2 SET client_id = $1 WHERE (idp_id = $2) AND (instance_id = $3)", expectedArgs: []interface{}{ "id", "idp-id", @@ -334,6 +338,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "tokenEndpoint": "token", "userEndpoint": "user", "scopes": ["profile"], + "idAttribute": "id-attribute", "isCreationAllowed": true, "isLinkingAllowed": true, "isAutoCreation": true, @@ -363,7 +368,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { }, }, { - expectedStmt: "UPDATE projections.idp_templates2_oauth SET (client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes) = ($1, $2, $3, $4, $5, $6) WHERE (idp_id = $7) AND (instance_id = $8)", + expectedStmt: "UPDATE projections.idp_templates2_oauth2 SET (client_id, client_secret, authorization_endpoint, token_endpoint, user_endpoint, scopes, id_attribute) = ($1, $2, $3, $4, $5, $6, $7) WHERE (idp_id = $8) AND (instance_id = $9)", expectedArgs: []interface{}{ "client_id", anyArg{}, @@ -371,6 +376,7 @@ func TestIDPTemplateProjection_reducesOAuth(t *testing.T) { "token", "user", database.StringArray{"profile"}, + "id-attribute", "idp-id", "instance-id", }, diff --git a/internal/repository/idp/oauth.go b/internal/repository/idp/oauth.go index 85271818ed..b693cd25e5 100644 --- a/internal/repository/idp/oauth.go +++ b/internal/repository/idp/oauth.go @@ -20,6 +20,7 @@ type OAuthIDPAddedEvent struct { TokenEndpoint string `json:"tokenEndpoint,omitempty"` UserEndpoint string `json:"userEndpoint,omitempty"` Scopes []string `json:"scopes,omitempty"` + IDAttribute string `json:"idAttribute,omitempty"` Options } @@ -31,7 +32,8 @@ func NewOAuthIDPAddedEvent( clientSecret *crypto.CryptoValue, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options Options, ) *OAuthIDPAddedEvent { @@ -45,6 +47,7 @@ func NewOAuthIDPAddedEvent( TokenEndpoint: tokenEndpoint, UserEndpoint: userEndpoint, Scopes: scopes, + IDAttribute: idAttribute, Options: options, } } @@ -81,6 +84,7 @@ type OAuthIDPChangedEvent struct { TokenEndpoint *string `json:"tokenEndpoint,omitempty"` UserEndpoint *string `json:"userEndpoint,omitempty"` Scopes []string `json:"scopes,omitempty"` + IDAttribute *string `json:"idAttribute,omitempty"` OptionChanges } @@ -151,6 +155,12 @@ func ChangeOAuthScopes(scopes []string) func(*OAuthIDPChangedEvent) { } } +func ChangeOAuthIDAttribute(idAttribute string) func(*OAuthIDPChangedEvent) { + return func(e *OAuthIDPChangedEvent) { + e.IDAttribute = &idAttribute + } +} + func (e *OAuthIDPChangedEvent) Data() interface{} { return e } diff --git a/internal/repository/instance/idp.go b/internal/repository/instance/idp.go index 9dafb32257..e8250b9fc0 100644 --- a/internal/repository/instance/idp.go +++ b/internal/repository/instance/idp.go @@ -36,7 +36,8 @@ func NewOAuthIDPAddedEvent( clientSecret *crypto.CryptoValue, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options idp.Options, ) *OAuthIDPAddedEvent { @@ -55,6 +56,7 @@ func NewOAuthIDPAddedEvent( authorizationEndpoint, tokenEndpoint, userEndpoint, + idAttribute, scopes, options, ), diff --git a/internal/repository/org/idp.go b/internal/repository/org/idp.go index f0dda3efe1..6cbb188e91 100644 --- a/internal/repository/org/idp.go +++ b/internal/repository/org/idp.go @@ -36,7 +36,8 @@ func NewOAuthIDPAddedEvent( clientSecret *crypto.CryptoValue, authorizationEndpoint, tokenEndpoint, - userEndpoint string, + userEndpoint, + idAttribute string, scopes []string, options idp.Options, ) *OAuthIDPAddedEvent { @@ -55,6 +56,7 @@ func NewOAuthIDPAddedEvent( authorizationEndpoint, tokenEndpoint, userEndpoint, + idAttribute, scopes, options, ), diff --git a/proto/zitadel/admin.proto b/proto/zitadel/admin.proto index 21bf720c84..4b0ae485ae 100644 --- a/proto/zitadel/admin.proto +++ b/proto/zitadel/admin.proto @@ -4343,7 +4343,9 @@ message AddGenericOAuthProviderRequest { string token_endpoint = 5 [(validate.rules).string = {min_len: 1, max_len: 200}]; string user_endpoint = 6 [(validate.rules).string = {min_len: 1, max_len: 200}]; repeated string scopes = 7 [(validate.rules).repeated = {max_items: 20, items: {string: {min_len: 1, max_len: 100}}}]; - zitadel.idp.v1.Options provider_options = 8; + // identifying attribute of the user in the response of the user_endpoint + string id_attribute = 8 [(validate.rules).string = {min_len: 1, max_len: 200}]; + zitadel.idp.v1.Options provider_options = 9; } message AddGenericOAuthProviderResponse { @@ -4361,7 +4363,9 @@ message UpdateGenericOAuthProviderRequest { string token_endpoint = 6 [(validate.rules).string = {min_len: 1, max_len: 200}]; string user_endpoint = 7 [(validate.rules).string = {min_len: 1, max_len: 200}]; repeated string scopes = 8 [(validate.rules).repeated = {max_items: 20, items: {string: {min_len: 1, max_len: 100}}}]; - zitadel.idp.v1.Options provider_options = 9; + // identifying attribute of the user in the response of the user_endpoint + string id_attribute = 9 [(validate.rules).string = {min_len: 1, max_len: 200}]; + zitadel.idp.v1.Options provider_options = 10; } message UpdateGenericOAuthProviderResponse { diff --git a/proto/zitadel/idp.proto b/proto/zitadel/idp.proto index 1594afa752..71afeec27e 100644 --- a/proto/zitadel/idp.proto +++ b/proto/zitadel/idp.proto @@ -275,6 +275,7 @@ message OAuthConfig { string token_endpoint = 3; string user_endpoint = 4; repeated string scopes = 5; + string id_attribute = 6; } message GenericOIDCConfig { diff --git a/proto/zitadel/management.proto b/proto/zitadel/management.proto index 2e919a7a36..529a58d196 100644 --- a/proto/zitadel/management.proto +++ b/proto/zitadel/management.proto @@ -11017,7 +11017,9 @@ message AddGenericOAuthProviderRequest { string token_endpoint = 5 [(validate.rules).string = {min_len: 1, max_len: 200}]; string user_endpoint = 6 [(validate.rules).string = {min_len: 1, max_len: 200}]; repeated string scopes = 7 [(validate.rules).repeated = {max_items: 20, items: {string: {min_len: 1, max_len: 100}}}]; - zitadel.idp.v1.Options provider_options = 8; + // identifying attribute of the user in the response of the user_endpoint + string id_attribute = 8 [(validate.rules).string = {min_len: 1, max_len: 200}]; + zitadel.idp.v1.Options provider_options = 9; } message AddGenericOAuthProviderResponse { @@ -11035,7 +11037,9 @@ message UpdateGenericOAuthProviderRequest { string token_endpoint = 6 [(validate.rules).string = {min_len: 1, max_len: 200}]; string user_endpoint = 7 [(validate.rules).string = {min_len: 1, max_len: 200}]; repeated string scopes = 8 [(validate.rules).repeated = {max_items: 20, items: {string: {min_len: 1, max_len: 100}}}]; - zitadel.idp.v1.Options provider_options = 9; + // identifying attribute of the user in the response of the user_endpoint + string id_attribute = 9 [(validate.rules).string = {min_len: 1, max_len: 200}]; + zitadel.idp.v1.Options provider_options = 10; } message UpdateGenericOAuthProviderResponse {