diff --git a/cmd/setup/48.go b/cmd/setup/48.go new file mode 100644 index 0000000000..2da0ad51a8 --- /dev/null +++ b/cmd/setup/48.go @@ -0,0 +1,27 @@ +package setup + +import ( + "context" + _ "embed" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + //go:embed 48.sql + addSAMLAppLoginVersion string +) + +type Apps7SAMLConfigsLoginVersion struct { + dbClient *database.DB +} + +func (mig *Apps7SAMLConfigsLoginVersion) Execute(ctx context.Context, _ eventstore.Event) error { + _, err := mig.dbClient.ExecContext(ctx, addSAMLAppLoginVersion) + return err +} + +func (mig *Apps7SAMLConfigsLoginVersion) String() string { + return "48_apps7_saml_configs_login_version" +} diff --git a/cmd/setup/48.sql b/cmd/setup/48.sql new file mode 100644 index 0000000000..018231f59e --- /dev/null +++ b/cmd/setup/48.sql @@ -0,0 +1,2 @@ +ALTER TABLE IF EXISTS projections.apps7_saml_configs ADD COLUMN IF NOT EXISTS login_version SMALLINT; +ALTER TABLE IF EXISTS projections.apps7_saml_configs ADD COLUMN IF NOT EXISTS login_base_uri TEXT; diff --git a/cmd/setup/config.go b/cmd/setup/config.go index f3215fd980..d782a32dd6 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -136,6 +136,7 @@ type Steps struct { s45CorrectProjectOwners *CorrectProjectOwners s46InitPermissionFunctions *InitPermissionFunctions s47FillMembershipFields *FillMembershipFields + s48Apps7SAMLConfigsLoginVersion *Apps7SAMLConfigsLoginVersion } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/48_river_queue_repeatable.go b/cmd/setup/river_queue_repeatable.go similarity index 100% rename from cmd/setup/48_river_queue_repeatable.go rename to cmd/setup/river_queue_repeatable.go diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index b78d1fc9cf..bfa289ab36 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -173,6 +173,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s45CorrectProjectOwners = &CorrectProjectOwners{eventstore: eventstoreClient} steps.s46InitPermissionFunctions = &InitPermissionFunctions{eventstoreClient: dbClient} steps.s47FillMembershipFields = &FillMembershipFields{eventstore: eventstoreClient} + steps.s48Apps7SAMLConfigsLoginVersion = &Apps7SAMLConfigsLoginVersion{dbClient: dbClient} err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") @@ -256,6 +257,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s37Apps7OIDConfigsBackChannelLogoutURI, steps.s42Apps7OIDCConfigsLoginVersion, steps.s43CreateFieldsDomainIndex, + steps.s48Apps7SAMLConfigsLoginVersion, } { mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") } diff --git a/docs/docs/guides/integrate/login-ui/oidc-standard.mdx b/docs/docs/guides/integrate/login-ui/oidc-standard.mdx index e9bdfb7cbf..f05d0d99b1 100644 --- a/docs/docs/guides/integrate/login-ui/oidc-standard.mdx +++ b/docs/docs/guides/integrate/login-ui/oidc-standard.mdx @@ -10,9 +10,9 @@ The following flow shows you the different components you need to enable OIDC fo ![OIDC Flow](/img/guides/login-ui/oidc-flow.png) 1. Your application makes an authorization request to your login UI -2. The login UI proxies the request to the ZITADEL API. In the request to the ZITADEL API, a header to identify your client is needed. +2. The login UI proxies the request to the ZITADEL API. 3. The ZITADEL API parses the request and does what it needs to interpret certain parameters (e.g., organization scope, etc.) -4. Redirect to a predefined, relative URL of the login UI that includes the authrequest ID ("/login?authRequest=") +4. Redirect to a predefined, relative URL of the login UI that includes the authrequest ID ("/login?authRequest="), configurable per application. 5. Request to ZITADEL API to get all the information from the auth request. This is optional and only needed if you like to get all the parsed information from the authrequest- 6. Authenticate the user in your login UI by creating and updating a session with all the checks you need. 7. Finalize the auth request by sending the session to the request, you will get the callback URL in the response @@ -37,10 +37,10 @@ https://login.example.com/oauth/v2/authorize?client_id=170086824411201793%40your The auth request includes all the relevant information for the OIDC standard and in this example we also have a login hint for the login name "minnie-mouse". You now have to proxy the auth request from your own UI to the authorize Endpoint of ZITADEL. -Make sure to add the user id of your login UI service/machine user as a header to the request: ```x-zitadel-login-client: ``` +For more information, see [OIDC Proxy](./typescript-repo#oidc-proxy) for the necessary headers. :::note -The user id sent in the 'x-zitadel-login-client' has to match to the PAT you are sending in the request. +The version and the optional custom URI for the available login UI is configurable under the application settings. ::: Read more about the [Authorize Endpoint Documentation](/docs/apis/openidoauth/endpoints#authorization_endpoint) @@ -97,7 +97,7 @@ The latest session token has to be sent to the following request: Read more about the [Finalize Auth Request Documentation](/docs/apis/resources/oidc_service_v2/oidc-service-create-callback) -Make sure that the authorization header is from the same account that you originally sent in the client id header ```x-zitadel-login-client: ``` on the authorize endpoint. +Make sure that the authorization header is from an account which is permitted to finalize the Auth Request through the `IAM_LOGIN_CLIENT` role. ```bash curl --request POST \ --url $ZITADEL_DOMAIN/v2/oidc/auth_requests/V2_224908753244265546 \ diff --git a/docs/docs/guides/integrate/login-ui/saml-standard.mdx b/docs/docs/guides/integrate/login-ui/saml-standard.mdx index c1f282371d..a2cb907874 100644 --- a/docs/docs/guides/integrate/login-ui/saml-standard.mdx +++ b/docs/docs/guides/integrate/login-ui/saml-standard.mdx @@ -10,9 +10,9 @@ The following flow shows you the different components you need to enable SAML fo ![SAML Flow](/img/guides/login-ui/saml-flow.png) 1. Your application makes an SAML request to your login UI -2. The login UI proxies the request to the ZITADEL API. In the request to the ZITADEL API, a header to identify your client is needed. +2. The login UI proxies the request to the ZITADEL API. 3. The ZITADEL API parses the request and does what it needs to interpret certain parameters (e.g., binding, nameID policy, etc.) -4. Redirect to a predefined, relative URL of the login UI that includes the samlrequest ID ("/login?authRequest=") +4. Redirect to a predefined, relative URL of the login UI that includes the samlrequest ID ("/login?authRequest="), configurable per application. 5. Request to ZITADEL API to get all the information from the SAML request. This is optional and only needed if you like to get all the parsed information from the samlrequest- 6. Authenticate the user in your login UI by creating and updating a session with all the checks you need. 7. Finalize the SAML request by sending the session to the request, you will get the URL to redirect to or the body in the response @@ -37,10 +37,10 @@ https://login.example.com/saml/v2/SSO?SAMLRequest=nJLRa9swEMb%2FFXHvjmVTY0fUhqxh The SAML request includes all the relevant information for the SAML standard, which includes the RelayState, the used binding and other information. You now have to proxy the SAML request from your own UI to the SSO Endpoint of ZITADEL. -Make sure to add the user id of your login UI service/machine user as a header to the request: ```x-zitadel-login-client: ``` +For more information, see [OIDC Proxy](./typescript-repo#oidc-proxy) for the necessary headers. :::note -The user id sent in the 'x-zitadel-login-client' has to match to the PAT you are sending in the request. +The version and the optional custom URI for the available login UI is configurable under the application settings. ::: Read more about the [SSO Endpoint Documentation](/docs/apis/saml/endpoints#sso_endpoint) @@ -87,14 +87,14 @@ Read the following resources for more information about the different checks: ### Finalize SAML Request -To finalize the SAML request and connect an existing user session with it you have to update the SAML request with the session token. +To finalize the SAML request and connect an existing user session with it you have to update the SAML Request with the session token. On the create and update user session request you will always get a session token in the response. The latest session token has to be sent to the following request: Read more about the [Finalize SAML Request Documentation](/docs/apis/resources/saml_service_v2/saml-service-create-response) -Make sure that the authorization header is from the same account that you originally sent in the client id header ```x-zitadel-login-client: ``` on the SSO endpoint. +Make sure that the authorization header is from an account which is permitted to finalize the SAML Request through the `IAM_LOGIN_CLIENT` role. ```bash curl --request POST \ --url $ZITADEL_DOMAIN/v2/saml/saml_requests/V2_224908753244265546 \ diff --git a/docs/docs/guides/integrate/login-ui/typescript-repo.mdx b/docs/docs/guides/integrate/login-ui/typescript-repo.mdx index d5fd6d9e4d..d4a0726621 100644 --- a/docs/docs/guides/integrate/login-ui/typescript-repo.mdx +++ b/docs/docs/guides/integrate/login-ui/typescript-repo.mdx @@ -130,7 +130,6 @@ To register your login domain on your instance, [add](/docs/apis/resources/admin When setting up the new login app for OIDC, ensure it meets the following requirements: - The OIDC Proxy is deployed and running on HTTPS -- The OIDC Proxy sets `x-zitadel-login-client` which is the user ID of the service account - The OIDC Proxy sets `x-zitadel-public-host` which is the host, your login is deployed to `ex. login.example.com`. - The OIDC Proxy sets `x-zitadel-instance-host` which is the host of your instance `ex. test-hdujwl.zitadel.cloud`. diff --git a/docs/static/img/guides/login-ui/oidc-flow.png b/docs/static/img/guides/login-ui/oidc-flow.png index b606044770..a427bad4ef 100644 Binary files a/docs/static/img/guides/login-ui/oidc-flow.png and b/docs/static/img/guides/login-ui/oidc-flow.png differ diff --git a/docs/static/img/guides/login-ui/saml-flow.png b/docs/static/img/guides/login-ui/saml-flow.png index 5c91fb4430..2a42642e2d 100644 Binary files a/docs/static/img/guides/login-ui/saml-flow.png and b/docs/static/img/guides/login-ui/saml-flow.png differ diff --git a/internal/api/grpc/management/project_application.go b/internal/api/grpc/management/project_application.go index 15e057c1bd..4b65808776 100644 --- a/internal/api/grpc/management/project_application.go +++ b/internal/api/grpc/management/project_application.go @@ -98,7 +98,11 @@ func (s *Server) AddOIDCApp(ctx context.Context, req *mgmt_pb.AddOIDCAppRequest) }, nil } func (s *Server) AddSAMLApp(ctx context.Context, req *mgmt_pb.AddSAMLAppRequest) (*mgmt_pb.AddSAMLAppResponse, error) { - app, err := s.command.AddSAMLApplication(ctx, AddSAMLAppRequestToDomain(req), authz.GetCtxData(ctx).OrgID) + samlApp, err := AddSAMLAppRequestToDomain(req) + if err != nil { + return nil, err + } + app, err := s.command.AddSAMLApplication(ctx, samlApp, authz.GetCtxData(ctx).OrgID) if err != nil { return nil, err } @@ -150,7 +154,11 @@ func (s *Server) UpdateOIDCAppConfig(ctx context.Context, req *mgmt_pb.UpdateOID } func (s *Server) UpdateSAMLAppConfig(ctx context.Context, req *mgmt_pb.UpdateSAMLAppConfigRequest) (*mgmt_pb.UpdateSAMLAppConfigResponse, error) { - config, err := s.command.ChangeSAMLApplication(ctx, UpdateSAMLAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID) + samlApp, err := UpdateSAMLAppConfigRequestToDomain(req) + if err != nil { + return nil, err + } + config, err := s.command.ChangeSAMLApplication(ctx, samlApp, authz.GetCtxData(ctx).OrgID) if err != nil { return nil, err } diff --git a/internal/api/grpc/management/project_application_converter.go b/internal/api/grpc/management/project_application_converter.go index 787470d9c1..13a0048a5b 100644 --- a/internal/api/grpc/management/project_application_converter.go +++ b/internal/api/grpc/management/project_application_converter.go @@ -67,15 +67,21 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) (*domain.OIDCApp, }, nil } -func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) *domain.SAMLApp { +func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) (*domain.SAMLApp, error) { + loginVersion, loginBaseURI, err := app_grpc.LoginVersionToDomain(req.GetLoginVersion()) + if err != nil { + return nil, err + } return &domain.SAMLApp{ ObjectRoot: models.ObjectRoot{ AggregateID: req.ProjectId, }, - AppName: req.Name, - Metadata: req.GetMetadataXml(), - MetadataURL: req.GetMetadataUrl(), - } + AppName: req.Name, + Metadata: req.GetMetadataXml(), + MetadataURL: req.GetMetadataUrl(), + LoginVersion: loginVersion, + LoginBaseURI: loginBaseURI, + }, nil } func AddAPIAppRequestToDomain(app *mgmt_pb.AddAPIAppRequest) *domain.APIApp { @@ -125,15 +131,21 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest) }, nil } -func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) *domain.SAMLApp { +func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) (*domain.SAMLApp, error) { + loginVersion, loginBaseURI, err := app_grpc.LoginVersionToDomain(app.GetLoginVersion()) + if err != nil { + return nil, err + } return &domain.SAMLApp{ ObjectRoot: models.ObjectRoot{ AggregateID: app.ProjectId, }, - AppID: app.AppId, - Metadata: app.GetMetadataXml(), - MetadataURL: app.GetMetadataUrl(), - } + AppID: app.AppId, + Metadata: app.GetMetadataXml(), + MetadataURL: app.GetMetadataUrl(), + LoginVersion: loginVersion, + LoginBaseURI: loginBaseURI, + }, nil } func UpdateAPIAppConfigRequestToDomain(app *mgmt_pb.UpdateAPIAppConfigRequest) *domain.APIApp { diff --git a/internal/api/grpc/project/application.go b/internal/api/grpc/project/application.go index 573156e637..fc05013c53 100644 --- a/internal/api/grpc/project/application.go +++ b/internal/api/grpc/project/application.go @@ -85,7 +85,8 @@ func loginVersionToPb(version domain.LoginVersion, baseURI *string) *app_pb.Logi func AppSAMLConfigToPb(app *query.SAMLApp) app_pb.AppConfig { return &app_pb.App_SamlConfig{ SamlConfig: &app_pb.SAMLConfig{ - Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata}, + Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata}, + LoginVersion: loginVersionToPb(app.LoginVersion, app.LoginBaseURI), }, } } diff --git a/internal/api/saml/serviceprovider.go b/internal/api/saml/serviceprovider.go new file mode 100644 index 0000000000..98865e0858 --- /dev/null +++ b/internal/api/saml/serviceprovider.go @@ -0,0 +1,53 @@ +package saml + +import ( + "strings" + + "github.com/zitadel/saml/pkg/provider/serviceprovider" + + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/query" +) + +const ( + LoginSamlRequestParam = "samlRequest" + LoginPath = "/login" +) + +type ServiceProvider struct { + SP *query.SAMLServiceProvider + defaultLoginURL string + defaultLoginURLV2 string +} + +func ServiceProviderFromBusiness(spQuery *query.SAMLServiceProvider, defaultLoginURL, defaultLoginURLV2 string) (*serviceprovider.ServiceProvider, error) { + sp := &ServiceProvider{ + SP: spQuery, + defaultLoginURL: defaultLoginURL, + defaultLoginURLV2: defaultLoginURLV2, + } + + return serviceprovider.NewServiceProvider( + spQuery.AppID, + &serviceprovider.Config{Metadata: spQuery.Metadata}, + sp.LoginURL, + ) +} + +func (s *ServiceProvider) LoginURL(id string) string { + // if the authRequest does not have the v2 prefix, it was created for login V1 + if !strings.HasPrefix(id, command.IDPrefixV2) { + return s.defaultLoginURL + id + } + // any v2 login without a specific base uri will be sent to the configured login v2 UI + // this way we're also backwards compatible + if s.SP.LoginBaseURI == nil || s.SP.LoginBaseURI.String() == "" { + return s.defaultLoginURLV2 + id + } + // for clients with a specific URI (internal or external) we only need to add the auth request id + uri := s.SP.LoginBaseURI.JoinPath(LoginPath) + q := uri.Query() + q.Set(LoginSamlRequestParam, id) + uri.RawQuery = q.Encode() + return uri.String() +} diff --git a/internal/api/saml/storage.go b/internal/api/saml/storage.go index 76f1bfd903..5a02619d93 100644 --- a/internal/api/saml/storage.go +++ b/internal/api/saml/storage.go @@ -17,6 +17,7 @@ import ( "github.com/zitadel/zitadel/internal/actions" "github.com/zitadel/zitadel/internal/actions/object" "github.com/zitadel/zitadel/internal/activity" + "github.com/zitadel/zitadel/internal/api/authz" http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/auth/repository" @@ -62,22 +63,12 @@ type Storage struct { } func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) { - app, err := p.query.ActiveAppBySAMLEntityID(ctx, entityID) + sp, err := p.query.ActiveSAMLServiceProviderByID(ctx, entityID) if err != nil { return nil, err } - return serviceprovider.NewServiceProvider( - app.ID, - &serviceprovider.Config{ - Metadata: app.SAMLConfig.Metadata, - }, - func(id string) string { - if strings.HasPrefix(id, command.IDPrefixV2) { - return p.defaultLoginURLv2 + id - } - return p.defaultLoginURL + id - }, - ) + + return ServiceProviderFromBusiness(sp, p.defaultLoginURL, p.defaultLoginURLv2) } func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) { @@ -108,11 +99,34 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + // for backwards compatibility we pass the login client if set headers, _ := http_utils.HeadersFromCtx(ctx) - if loginClient := headers.Get(LoginClientHeader); loginClient != "" { + loginClient := headers.Get(LoginClientHeader) + + // for backwards compatibility we'll use the new login if the header is set (no matter the other configs) + if loginClient != "" { return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient) } - return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID) + + // if the instance requires the v2 login, use it no matter what the application configured + if authz.GetFeatures(ctx).LoginV2.Required { + return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient) + } + version, err := p.query.SAMLAppLoginVersion(ctx, applicationID) + if err != nil { + return nil, err + } + switch version { + case domain.LoginVersion1: + return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID) + case domain.LoginVersion2: + return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient) + case domain.LoginVersionUnspecified: + fallthrough + default: + // since we already checked for a login header, we can fall back to the v1 login + return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID) + } } func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID, loginClient string) (_ models.AuthRequestInt, err error) { diff --git a/internal/command/org_test.go b/internal/command/org_test.go index bf88b55a86..4ec85d61e1 100644 --- a/internal/command/org_test.go +++ b/internal/command/org_test.go @@ -1264,10 +1264,10 @@ func TestCommandSide_RemoveOrg(t *testing.T) { ), expectFilter( eventFromEventPusher( - project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project1", "org1").Aggregate, "app1", "entity1", []byte{}, ""), + project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project1", "org1").Aggregate, "app1", "entity1", []byte{}, "", domain.LoginVersionUnspecified, ""), ), eventFromEventPusher( - project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project2", "org1").Aggregate, "app2", "entity2", []byte{}, ""), + project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project2", "org1").Aggregate, "app2", "entity2", []byte{}, "", domain.LoginVersionUnspecified, ""), ), ), expectPush( diff --git a/internal/command/project_application_oidc_model.go b/internal/command/project_application_oidc_model.go index 3fc07c79a9..603ebdcda2 100644 --- a/internal/command/project_application_oidc_model.go +++ b/internal/command/project_application_oidc_model.go @@ -325,10 +325,10 @@ func (wm *OIDCApplicationWriteModel) NewChangedEvent( changes = append(changes, project.ChangeBackChannelLogoutURI(backChannelLogoutURI)) } if wm.LoginVersion != loginVersion { - changes = append(changes, project.ChangeLoginVersion(loginVersion)) + changes = append(changes, project.ChangeOIDCLoginVersion(loginVersion)) } if wm.LoginBaseURI != loginBaseURI { - changes = append(changes, project.ChangeLoginBaseURI(loginBaseURI)) + changes = append(changes, project.ChangeOIDCLoginBaseURI(loginBaseURI)) } if len(changes) == 0 { diff --git a/internal/command/project_application_oidc_test.go b/internal/command/project_application_oidc_test.go index 8b663afa57..4b9f5bf94f 100644 --- a/internal/command/project_application_oidc_test.go +++ b/internal/command/project_application_oidc_test.go @@ -1297,8 +1297,8 @@ func newOIDCAppChangedEvent(ctx context.Context, appID, projectID, resourceOwner project.ChangeIDTokenRoleAssertion(false), project.ChangeIDTokenUserinfoAssertion(false), project.ChangeClockSkew(time.Second * 2), - project.ChangeLoginVersion(domain.LoginVersion2), - project.ChangeLoginBaseURI("https://login.test.ch"), + project.ChangeOIDCLoginVersion(domain.LoginVersion2), + project.ChangeOIDCLoginBaseURI("https://login.test.ch"), } event, _ := project.NewOIDCConfigChangedEvent(ctx, &project.NewAggregate(projectID, resourceOwner).Aggregate, diff --git a/internal/command/project_application_saml.go b/internal/command/project_application_saml.go index 76297ad93f..b14bed0758 100644 --- a/internal/command/project_application_saml.go +++ b/internal/command/project_application_saml.go @@ -79,6 +79,8 @@ func (c *Commands) addSAMLApplication(ctx context.Context, projectAgg *eventstor string(entity.EntityID), samlApp.Metadata, samlApp.MetadataURL, + samlApp.LoginVersion, + samlApp.LoginBaseURI, ), }, nil } @@ -119,7 +121,10 @@ func (c *Commands) ChangeSAMLApplication(ctx context.Context, samlApp *domain.SA samlApp.AppID, string(entity.EntityID), samlApp.Metadata, - samlApp.MetadataURL) + samlApp.MetadataURL, + samlApp.LoginVersion, + samlApp.LoginBaseURI, + ) if err != nil { return nil, err } diff --git a/internal/command/project_application_saml_model.go b/internal/command/project_application_saml_model.go index 2652acc617..f219039b58 100644 --- a/internal/command/project_application_saml_model.go +++ b/internal/command/project_application_saml_model.go @@ -12,11 +12,13 @@ import ( type SAMLApplicationWriteModel struct { eventstore.WriteModel - AppID string - AppName string - EntityID string - Metadata []byte - MetadataURL string + AppID string + AppName string + EntityID string + Metadata []byte + MetadataURL string + LoginVersion domain.LoginVersion + LoginBaseURI string State domain.AppState saml bool @@ -121,6 +123,8 @@ func (wm *SAMLApplicationWriteModel) appendAddSAMLEvent(e *project.SAMLConfigAdd wm.Metadata = e.Metadata wm.MetadataURL = e.MetadataURL wm.EntityID = e.EntityID + wm.LoginVersion = e.LoginVersion + wm.LoginBaseURI = e.LoginBaseURI } func (wm *SAMLApplicationWriteModel) appendChangeSAMLEvent(e *project.SAMLConfigChangedEvent) { @@ -134,6 +138,12 @@ func (wm *SAMLApplicationWriteModel) appendChangeSAMLEvent(e *project.SAMLConfig if e.EntityID != "" { wm.EntityID = e.EntityID } + if e.LoginVersion != nil { + wm.LoginVersion = *e.LoginVersion + } + if e.LoginBaseURI != nil { + wm.LoginBaseURI = *e.LoginBaseURI + } } func (wm *SAMLApplicationWriteModel) Query() *eventstore.SearchQueryBuilder { @@ -161,6 +171,8 @@ func (wm *SAMLApplicationWriteModel) NewChangedEvent( entityID string, metadata []byte, metadataURL string, + loginVersion domain.LoginVersion, + loginBaseURI string, ) (*project.SAMLConfigChangedEvent, bool, error) { changes := make([]project.SAMLConfigChanges, 0) var err error @@ -173,6 +185,12 @@ func (wm *SAMLApplicationWriteModel) NewChangedEvent( if wm.EntityID != entityID { changes = append(changes, project.ChangeEntityID(entityID)) } + if wm.LoginVersion != loginVersion { + changes = append(changes, project.ChangeSAMLLoginVersion(loginVersion)) + } + if wm.LoginBaseURI != loginBaseURI { + changes = append(changes, project.ChangeSAMLLoginBaseURI(loginBaseURI)) + } if len(changes) == 0 { return nil, false, nil diff --git a/internal/command/project_application_saml_test.go b/internal/command/project_application_saml_test.go index ff774e9f49..3082e87c46 100644 --- a/internal/command/project_application_saml_test.go +++ b/internal/command/project_application_saml_test.go @@ -50,7 +50,7 @@ var testMetadataChangedEntityID = []byte(` func TestCommandSide_AddSAMLApplication(t *testing.T) { type fields struct { - eventstore *eventstore.Eventstore + eventstore func(t *testing.T) *eventstore.Eventstore idGenerator id.Generator httpClient *http.Client } @@ -72,9 +72,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "no aggregate id, invalid argument error", fields: fields{ - eventstore: eventstoreExpect( - t, - ), + eventstore: expectEventstore(), }, args: args{ ctx: authz.WithInstanceID(context.Background(), "instanceID"), @@ -88,8 +86,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "project not existing, not found error", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter(), ), }, @@ -111,8 +108,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "invalid app, invalid argument error", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewProjectAddedEvent(context.Background(), @@ -141,8 +137,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "create saml app, metadata not parsable", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewProjectAddedEvent(context.Background(), @@ -174,8 +169,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "create saml app, ok", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewProjectAddedEvent(context.Background(), @@ -196,6 +190,8 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -229,11 +225,73 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { }, }, }, + { + name: "create saml app, loginversion, ok", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + project.NewProjectAddedEvent(context.Background(), + &project.NewAggregate("project1", "org1").Aggregate, + "project", true, true, true, + domain.PrivateLabelingSettingUnspecified), + ), + ), + expectPush( + project.NewApplicationAddedEvent(context.Background(), + &project.NewAggregate("project1", "org1").Aggregate, + "app1", + "app", + ), + project.NewSAMLConfigAddedEvent(context.Background(), + &project.NewAggregate("project1", "org1").Aggregate, + "app1", + "https://test.com/saml/metadata", + testMetadata, + "", + domain.LoginVersion2, + "https://test.com/login", + ), + ), + ), + idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + samlApp: &domain.SAMLApp{ + ObjectRoot: models.ObjectRoot{ + AggregateID: "project1", + }, + AppName: "app", + EntityID: "https://test.com/saml/metadata", + Metadata: testMetadata, + MetadataURL: "", + LoginVersion: domain.LoginVersion2, + LoginBaseURI: "https://test.com/login", + }, + resourceOwner: "org1", + }, + res: res{ + want: &domain.SAMLApp{ + ObjectRoot: models.ObjectRoot{ + AggregateID: "project1", + ResourceOwner: "org1", + }, + AppID: "app1", + AppName: "app", + EntityID: "https://test.com/saml/metadata", + Metadata: testMetadata, + MetadataURL: "", + State: domain.AppStateActive, + LoginVersion: domain.LoginVersion2, + LoginBaseURI: "https://test.com/login", + }, + }, + }, { name: "create saml app metadataURL, ok", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewProjectAddedEvent(context.Background(), @@ -254,6 +312,8 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "http://localhost:8080/saml/metadata", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -291,8 +351,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { { name: "create saml app metadataURL, http error", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewProjectAddedEvent(context.Background(), @@ -327,7 +386,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Commands{ - eventstore: tt.fields.eventstore, + eventstore: tt.fields.eventstore(t), idGenerator: tt.fields.idGenerator, httpClient: tt.fields.httpClient, } @@ -348,7 +407,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) { func TestCommandSide_ChangeSAMLApplication(t *testing.T) { type fields struct { - eventstore *eventstore.Eventstore + eventstore func(t *testing.T) *eventstore.Eventstore httpClient *http.Client } type args struct { @@ -369,9 +428,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "invalid app, invalid argument error", fields: fields{ - eventstore: eventstoreExpect( - t, - ), + eventstore: expectEventstore(), }, args: args{ ctx: context.Background(), @@ -390,9 +447,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "missing appid, invalid argument error", fields: fields{ - eventstore: eventstoreExpect( - t, - ), + eventstore: expectEventstore(), }, args: args{ ctx: context.Background(), @@ -412,9 +467,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "missing aggregateid, invalid argument error", fields: fields{ - eventstore: eventstoreExpect( - t, - ), + eventstore: expectEventstore(), }, args: args{ ctx: context.Background(), @@ -434,8 +487,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "app not existing, not found error", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter(), ), }, @@ -457,8 +509,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "no changes, precondition error, metadataURL", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewApplicationAddedEvent(context.Background(), @@ -474,6 +525,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "http://localhost:8080/saml/metadata", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -502,8 +555,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "no changes, precondition error, metadata", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewApplicationAddedEvent(context.Background(), @@ -519,6 +571,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -547,8 +601,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "change saml app, ok, metadataURL", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewApplicationAddedEvent(context.Background(), @@ -564,6 +617,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "http://localhost:8080/saml/metadata", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -613,8 +668,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { { name: "change saml app, ok, metadata", fields: fields{ - eventstore: eventstoreExpect( - t, + eventstore: expectEventstore( expectFilter( eventFromEventPusher( project.NewApplicationAddedEvent(context.Background(), @@ -630,6 +684,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { "https://test.com/saml/metadata", testMetadata, "", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -675,13 +731,85 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) { State: domain.AppStateActive, }, }, + }, { + name: "change saml app, ok, loginversion", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + project.NewApplicationAddedEvent(context.Background(), + &project.NewAggregate("project1", "org1").Aggregate, + "app1", + "app", + ), + ), + eventFromEventPusher( + project.NewSAMLConfigAddedEvent(context.Background(), + &project.NewAggregate("project1", "org1").Aggregate, + "app1", + "https://test.com/saml/metadata", + testMetadata, + "", + domain.LoginVersionUnspecified, + "", + ), + ), + ), + expectPush( + newSAMLAppChangedEventLoginVersion(context.Background(), + "app1", + "project1", + "org1", + "https://test.com/saml/metadata", + "https://test2.com/saml/metadata", + testMetadataChangedEntityID, + domain.LoginVersion2, + "https://test.com/login", + ), + ), + ), + httpClient: nil, + }, + args: args{ + ctx: context.Background(), + samlApp: &domain.SAMLApp{ + ObjectRoot: models.ObjectRoot{ + AggregateID: "project1", + ResourceOwner: "org1", + }, + AppID: "app1", + AppName: "app", + EntityID: "https://test2.com/saml/metadata", + Metadata: testMetadataChangedEntityID, + MetadataURL: "", + LoginVersion: domain.LoginVersion2, + LoginBaseURI: "https://test.com/login", + }, + resourceOwner: "org1", + }, + res: res{ + want: &domain.SAMLApp{ + ObjectRoot: models.ObjectRoot{ + AggregateID: "project1", + ResourceOwner: "org1", + }, + AppID: "app1", + AppName: "app", + EntityID: "https://test2.com/saml/metadata", + Metadata: testMetadataChangedEntityID, + MetadataURL: "", + State: domain.AppStateActive, + LoginVersion: domain.LoginVersion2, + LoginBaseURI: "https://test.com/login", + }, + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := &Commands{ - eventstore: tt.fields.eventstore, + eventstore: tt.fields.eventstore(t), httpClient: tt.fields.httpClient, } got, err := r.ChangeSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner) @@ -726,6 +854,22 @@ func newSAMLAppChangedEventMetadataURL(ctx context.Context, appID, projectID, re return event } +func newSAMLAppChangedEventLoginVersion(ctx context.Context, appID, projectID, resourceOwner, oldEntityID, entityID string, metadata []byte, loginVersion domain.LoginVersion, loginURI string) *project.SAMLConfigChangedEvent { + changes := []project.SAMLConfigChanges{ + project.ChangeEntityID(entityID), + project.ChangeMetadata(metadata), + project.ChangeSAMLLoginVersion(loginVersion), + project.ChangeSAMLLoginBaseURI(loginURI), + } + event, _ := project.NewSAMLConfigChangedEvent(ctx, + &project.NewAggregate(projectID, resourceOwner).Aggregate, + appID, + oldEntityID, + changes, + ) + return event +} + type roundTripperFunc func(*http.Request) *http.Response // RoundTrip implements the http.RoundTripper interface. diff --git a/internal/command/project_application_test.go b/internal/command/project_application_test.go index ae2c6c39b0..050a41d29f 100644 --- a/internal/command/project_application_test.go +++ b/internal/command/project_application_test.go @@ -596,6 +596,8 @@ func TestCommandSide_RemoveApplication(t *testing.T) { "https://test.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "", + domain.LoginVersionUnspecified, + "", )), ), expectPush( diff --git a/internal/command/project_converter.go b/internal/command/project_converter.go index 59343aa762..01b5a4e63d 100644 --- a/internal/command/project_converter.go +++ b/internal/command/project_converter.go @@ -55,13 +55,15 @@ func oidcWriteModelToOIDCConfig(writeModel *OIDCApplicationWriteModel) *domain.O func samlWriteModelToSAMLConfig(writeModel *SAMLApplicationWriteModel) *domain.SAMLApp { return &domain.SAMLApp{ - ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel), - AppID: writeModel.AppID, - AppName: writeModel.AppName, - State: writeModel.State, - Metadata: writeModel.Metadata, - MetadataURL: writeModel.MetadataURL, - EntityID: writeModel.EntityID, + ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel), + AppID: writeModel.AppID, + AppName: writeModel.AppName, + State: writeModel.State, + Metadata: writeModel.Metadata, + MetadataURL: writeModel.MetadataURL, + EntityID: writeModel.EntityID, + LoginVersion: writeModel.LoginVersion, + LoginBaseURI: writeModel.LoginBaseURI, } } diff --git a/internal/command/project_test.go b/internal/command/project_test.go index 645371e2fc..842e1aa640 100644 --- a/internal/command/project_test.go +++ b/internal/command/project_test.go @@ -988,6 +988,8 @@ func TestCommandSide_RemoveProject(t *testing.T) { "https://test.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "http://localhost:8080/saml/metadata", + domain.LoginVersionUnspecified, + "", ), ), ), @@ -1039,6 +1041,8 @@ func TestCommandSide_RemoveProject(t *testing.T) { "https://test1.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "", + domain.LoginVersionUnspecified, + "", ), ), eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(), @@ -1053,6 +1057,8 @@ func TestCommandSide_RemoveProject(t *testing.T) { "https://test2.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "", + domain.LoginVersionUnspecified, + "", ), ), eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(), @@ -1067,6 +1073,8 @@ func TestCommandSide_RemoveProject(t *testing.T) { "https://test3.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "", + domain.LoginVersionUnspecified, + "", ), ), ), diff --git a/internal/command/saml_request.go b/internal/command/saml_request.go index 2dfa8756c7..17f56101ec 100644 --- a/internal/command/saml_request.go +++ b/internal/command/saml_request.go @@ -75,7 +75,9 @@ func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID, return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-ttPKNdAIFT", "Errors.SAMLRequest.AlreadyHandled") } if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient { - return nil, nil, zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient") + if err := c.checkPermission(ctx, domain.PermissionSessionLink, writeModel.ResourceOwner, ""); err != nil { + return nil, nil, err + } } sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()) err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) diff --git a/internal/command/saml_request_test.go b/internal/command/saml_request_test.go index ed7363e151..761edde8fb 100644 --- a/internal/command/saml_request_test.go +++ b/internal/command/saml_request_test.go @@ -132,8 +132,9 @@ func TestCommands_AddSAMLRequest(t *testing.T) { func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") type fields struct { - eventstore func(t *testing.T) *eventstore.Eventstore - tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + eventstore func(t *testing.T) *eventstore.Eventstore + tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + checkPermission domain.PermissionCheck } type args struct { ctx context.Context @@ -207,7 +208,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { }, }, { - "wrong login client", + "wrong login client / not permitted", fields{ eventstore: expectEventstore( expectFilter( @@ -225,7 +226,8 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { ), ), ), - tokenVerifier: newMockTokenVerifierValid(), + tokenVerifier: newMockTokenVerifierValid(), + checkPermission: newMockPermissionCheckNotAllowed(), }, args{ ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"), @@ -235,7 +237,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { checkLoginClient: true, }, res{ - wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient"), + wantErr: zerrors.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"), }, }, { @@ -524,6 +526,86 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, }, }, + }, { + "linked with permission", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "destination", + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + "userID", "org1", testNow, &language.Afrikaans), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + testNow), + ), + eventFromEventPusherWithCreationDateNow( + session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + 2*time.Minute), + ), + ), + expectPush( + samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + ), + ), + ), + tokenVerifier: newMockTokenVerifierValid(), + checkPermission: newMockPermissionCheckAllowed(), + }, + args{ + ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"), + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + checkLoginClient: true, + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentSAMLRequest{ + SAMLRequest: &SAMLRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ApplicationID: "application", + ACSURL: "acs", + RelayState: "relaystate", + RequestID: "request", + Binding: "binding", + Issuer: "issuer", + Destination: "destination", + }, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + }, + }, }, { "linked with login client check, application permission check", @@ -669,6 +751,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { c := &Commands{ eventstore: tt.fields.eventstore(t), sessionTokenVerifier: tt.fields.tokenVerifier, + checkPermission: tt.fields.checkPermission, } details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient, tt.args.checkPermission) require.ErrorIs(t, err, tt.res.wantErr) diff --git a/internal/domain/application_saml.go b/internal/domain/application_saml.go index b00366df7e..de7ef789ee 100644 --- a/internal/domain/application_saml.go +++ b/internal/domain/application_saml.go @@ -7,11 +7,13 @@ import ( type SAMLApp struct { models.ObjectRoot - AppID string - AppName string - EntityID string - Metadata []byte - MetadataURL string + AppID string + AppName string + EntityID string + Metadata []byte + MetadataURL string + LoginVersion LoginVersion + LoginBaseURI string State AppState } diff --git a/internal/integration/saml.go b/internal/integration/saml.go index 483543b322..533b0ee515 100644 --- a/internal/integration/saml.go +++ b/internal/integration/saml.go @@ -20,6 +20,7 @@ import ( http_util "github.com/zitadel/zitadel/internal/api/http" oidc_internal "github.com/zitadel/zitadel/internal/api/oidc" + app_pb "github.com/zitadel/zitadel/pkg/grpc/app" "github.com/zitadel/zitadel/pkg/grpc/management" saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" session_pb "github.com/zitadel/zitadel/pkg/grpc/session/v2" @@ -102,7 +103,7 @@ func CreateSAMLSP(root string, idpMetadata *saml.EntityDescriptor, binding strin return sp, nil } -func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *samlsp.Middleware) (*management.AddSAMLAppResponse, error) { +func (i *Instance) CreateSAMLClientLoginVersion(ctx context.Context, projectID string, m *samlsp.Middleware, loginVersion *app_pb.LoginVersion) (*management.AddSAMLAppResponse, error) { spMetadata, err := xml.MarshalIndent(m.ServiceProvider.Metadata(), "", " ") if err != nil { return nil, err @@ -114,9 +115,10 @@ func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *sa } resp, err := i.Client.Mgmt.AddSAMLApp(ctx, &management.AddSAMLAppRequest{ - ProjectId: projectID, - Name: fmt.Sprintf("app-%s", gofakeit.AppName()), - Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: spMetadata}, + ProjectId: projectID, + Name: fmt.Sprintf("app-%s", gofakeit.AppName()), + Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: spMetadata}, + LoginVersion: loginVersion, }) if err != nil { return nil, err @@ -136,7 +138,19 @@ func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *sa }) } -func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient string, acs saml.Endpoint, relayState string, responseBinding string) (now time.Time, authRequestID string, err error) { +func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *samlsp.Middleware) (*management.AddSAMLAppResponse, error) { + return i.CreateSAMLClientLoginVersion(ctx, projectID, m, nil) +} + +func (i *Instance) CreateSAMLAuthRequestWithoutLoginClientHeader(m *samlsp.Middleware, loginBaseURI string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) { + return i.createSAMLAuthRequest(m, "", loginBaseURI, acs, relayState, responseBinding) +} + +func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) { + return i.createSAMLAuthRequest(m, loginClient, "", acs, relayState, responseBinding) +} + +func (i *Instance) createSAMLAuthRequest(m *samlsp.Middleware, loginClient, loginBaseURI string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) { authReq, err := m.ServiceProvider.MakeAuthenticationRequest(acs.Location, acs.Binding, responseBinding) if err != nil { return now, "", err @@ -147,7 +161,11 @@ func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient strin return now, "", err } - req, err := GetRequest(redirectURL.String(), map[string]string{oidc_internal.LoginClientHeader: loginClient}) + var headers map[string]string + if loginClient != "" { + headers = map[string]string{oidc_internal.LoginClientHeader: loginClient} + } + req, err := GetRequest(redirectURL.String(), headers) if err != nil { return now, "", fmt.Errorf("get request: %w", err) } @@ -158,11 +176,13 @@ func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient strin return now, "", fmt.Errorf("check redirect: %w", err) } - prefixWithHost := i.Issuer() + i.Config.LoginURLV2 - if !strings.HasPrefix(loc.String(), prefixWithHost) { - return now, "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String()) + if loginBaseURI == "" { + loginBaseURI = i.Issuer() + i.Config.LoginURLV2 } - return now, strings.TrimPrefix(loc.String(), prefixWithHost), nil + if !strings.HasPrefix(loc.String(), loginBaseURI) { + return now, "", fmt.Errorf("login location has not prefix %s, but is %s", loginBaseURI, loc.String()) + } + return now, strings.TrimPrefix(loc.String(), loginBaseURI), nil } func (i *Instance) FailSAMLAuthRequest(ctx context.Context, id string, reason saml_pb.ErrorReason) *saml_pb.CreateResponseResponse { diff --git a/internal/query/app.go b/internal/query/app.go index 1aa0323a5a..fafbbe72d9 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -66,9 +66,11 @@ type OIDCApp struct { } type SAMLApp struct { - Metadata []byte - MetadataURL string - EntityID string + Metadata []byte + MetadataURL string + EntityID string + LoginVersion domain.LoginVersion + LoginBaseURI *string } type APIApp struct { @@ -137,6 +139,10 @@ var ( name: projection.AppSAMLTable, instanceIDCol: projection.AppSAMLConfigColumnInstanceID, } + AppSAMLConfigColumnInstanceID = Column{ + name: projection.AppSAMLConfigColumnInstanceID, + table: appSAMLConfigsTable, + } AppSAMLConfigColumnAppID = Column{ name: projection.AppSAMLConfigColumnAppID, table: appSAMLConfigsTable, @@ -153,6 +159,14 @@ var ( name: projection.AppSAMLConfigColumnMetadataURL, table: appSAMLConfigsTable, } + AppSAMLConfigColumnLoginVersion = Column{ + name: projection.AppSAMLConfigColumnLoginVersion, + table: appSAMLConfigsTable, + } + AppSAMLConfigColumnLoginBaseURI = Column{ + name: projection.AppSAMLConfigColumnLoginBaseURI, + table: appSAMLConfigsTable, + } ) var ( @@ -320,30 +334,6 @@ func (q *Queries) AppByID(ctx context.Context, appID string, activeOnly bool) (a return app, err } -func (q *Queries) ActiveAppBySAMLEntityID(ctx context.Context, entityID string) (app *App, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - stmt, scan := prepareSAMLAppQuery(ctx, q.client) - eq := sq.Eq{ - AppSAMLConfigColumnEntityID.identifier(): entityID, - AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), - AppColumnState.identifier(): domain.AppStateActive, - ProjectColumnState.identifier(): domain.ProjectStateActive, - OrgColumnState.identifier(): domain.OrgStateActive, - } - query, args, err := stmt.Where(eq).ToSql() - if err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement") - } - - err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { - app, err = scan(row) - return err - }, query, args...) - return app, err -} - func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (project *Project, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -591,7 +581,7 @@ func (q *Queries) OIDCClientLoginVersion(ctx context.Context, clientID string) ( ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginVersionByClientID(ctx, q.client) + query, scan := prepareLoginVersionByOIDCClientID(ctx, q.client) eq := sq.Eq{ AppOIDCConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), AppOIDCConfigColumnClientID.identifier(): clientID, @@ -611,6 +601,30 @@ func (q *Queries) OIDCClientLoginVersion(ctx context.Context, clientID string) ( return loginVersion, nil } +func (q *Queries) SAMLAppLoginVersion(ctx context.Context, appID string) (loginVersion domain.LoginVersion, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + query, scan := prepareLoginVersionBySAMLAppID(ctx, q.client) + eq := sq.Eq{ + AppSAMLConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + AppSAMLConfigColumnAppID.identifier(): appID, + } + stmt, args, err := query.Where(eq).ToSql() + if err != nil { + return domain.LoginVersionUnspecified, zerrors.ThrowInvalidArgument(err, "QUERY-TnaciwZfp3", "Errors.Query.InvalidRequest") + } + + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + loginVersion, err = scan(row) + return err + }, stmt, args...) + if err != nil { + return domain.LoginVersionUnspecified, zerrors.ThrowInternal(err, "QUERY-lvDDwRzIoP", "Errors.Internal") + } + return loginVersion, nil +} + func NewAppNameSearchQuery(method TextComparison, value string) (SearchQuery, error) { return NewTextQuery(AppColumnName, value, method) } @@ -659,6 +673,8 @@ func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) ( AppSAMLConfigColumnEntityID.identifier(), AppSAMLConfigColumnMetadata.identifier(), AppSAMLConfigColumnMetadataURL.identifier(), + AppSAMLConfigColumnLoginVersion.identifier(), + AppSAMLConfigColumnLoginBaseURI.identifier(), ).From(appsTable.identifier()). PlaceholderFormat(sq.Dollar) @@ -726,6 +742,8 @@ func scanApp(row *sql.Row) (*App, error) { &samlConfig.entityID, &samlConfig.metadata, &samlConfig.metadataURL, + &samlConfig.loginVersion, + &samlConfig.loginBaseURI, ) if err != nil { @@ -827,61 +845,6 @@ func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { } } -func prepareSAMLAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return sq.Select( - AppColumnID.identifier(), - AppColumnName.identifier(), - AppColumnProjectID.identifier(), - AppColumnCreationDate.identifier(), - AppColumnChangeDate.identifier(), - AppColumnResourceOwner.identifier(), - AppColumnState.identifier(), - AppColumnSequence.identifier(), - - AppSAMLConfigColumnAppID.identifier(), - AppSAMLConfigColumnEntityID.identifier(), - AppSAMLConfigColumnMetadata.identifier(), - AppSAMLConfigColumnMetadataURL.identifier(), - ).From(appsTable.identifier()). - Join(join(AppSAMLConfigColumnAppID, AppColumnID)). - Join(join(ProjectColumnID, AppColumnProjectID)). - Join(join(OrgColumnID, AppColumnResourceOwner)). - PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { - - app := new(App) - var ( - samlConfig = sqlSAMLConfig{} - ) - - err := row.Scan( - &app.ID, - &app.Name, - &app.ProjectID, - &app.CreationDate, - &app.ChangeDate, - &app.ResourceOwner, - &app.State, - &app.Sequence, - - &samlConfig.appID, - &samlConfig.entityID, - &samlConfig.metadata, - &samlConfig.metadataURL, - ) - - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, zerrors.ThrowNotFound(err, "QUERY-d6TO1", "Errors.App.NotExisting") - } - return nil, zerrors.ThrowInternal(err, "QUERY-NAtPg", "Errors.Internal") - } - - samlConfig.set(app) - - return app, nil - } -} - func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) { return sq.Select( AppColumnProjectID.identifier(), @@ -1031,6 +994,8 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder AppSAMLConfigColumnEntityID.identifier(), AppSAMLConfigColumnMetadata.identifier(), AppSAMLConfigColumnMetadataURL.identifier(), + AppSAMLConfigColumnLoginVersion.identifier(), + AppSAMLConfigColumnLoginBaseURI.identifier(), countColumn.identifier(), ).From(appsTable.identifier()). LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). @@ -1086,6 +1051,8 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder &samlConfig.entityID, &samlConfig.metadata, &samlConfig.metadataURL, + &samlConfig.loginVersion, + &samlConfig.loginBaseURI, &apps.Count, ) @@ -1135,7 +1102,7 @@ func prepareClientIDsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareLoginVersionByClientID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { +func prepareLoginVersionByOIDCClientID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { return sq.Select( AppOIDCConfigColumnLoginVersion.identifier(), ).From(appOIDCConfigsTable.identifier()). @@ -1150,6 +1117,21 @@ func prepareLoginVersionByClientID(ctx context.Context, db prepareDatabase) (sq. } } +func prepareLoginVersionBySAMLAppID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { + return sq.Select( + AppSAMLConfigColumnLoginVersion.identifier(), + ).From(appSAMLConfigsTable.identifier()). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (domain.LoginVersion, error) { + var loginVersion sql.NullInt16 + if err := row.Scan( + &loginVersion, + ); err != nil { + return domain.LoginVersionUnspecified, zerrors.ThrowInternal(err, "QUERY-KbzaCnaziI", "Errors.Internal") + } + return domain.LoginVersion(loginVersion.Int16), nil + } +} + type sqlOIDCConfig struct { appID sql.NullString version sql.NullInt32 @@ -1209,10 +1191,12 @@ func (c sqlOIDCConfig) set(app *App) { } type sqlSAMLConfig struct { - appID sql.NullString - entityID sql.NullString - metadataURL sql.NullString - metadata []byte + appID sql.NullString + entityID sql.NullString + metadataURL sql.NullString + metadata []byte + loginVersion sql.NullInt16 + loginBaseURI sql.NullString } func (c sqlSAMLConfig) set(app *App) { @@ -1220,9 +1204,13 @@ func (c sqlSAMLConfig) set(app *App) { return } app.SAMLConfig = &SAMLApp{ - MetadataURL: c.metadataURL.String, - Metadata: c.metadata, - EntityID: c.entityID.String, + EntityID: c.entityID.String, + MetadataURL: c.metadataURL.String, + Metadata: c.metadata, + LoginVersion: domain.LoginVersion(c.loginVersion.Int16), + } + if c.loginBaseURI.Valid { + app.SAMLConfig.LoginBaseURI = &c.loginBaseURI.String } } diff --git a/internal/query/app_test.go b/internal/query/app_test.go index ea9444f665..dbbcaef47c 100644 --- a/internal/query/app_test.go +++ b/internal/query/app_test.go @@ -56,7 +56,9 @@ var ( ` projections.apps7_saml_configs.app_id,` + ` projections.apps7_saml_configs.entity_id,` + ` projections.apps7_saml_configs.metadata,` + - ` projections.apps7_saml_configs.metadata_url` + + ` projections.apps7_saml_configs.metadata_url,` + + ` projections.apps7_saml_configs.login_version,` + + ` projections.apps7_saml_configs.login_base_uri` + ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + @@ -103,6 +105,8 @@ var ( ` projections.apps7_saml_configs.entity_id,` + ` projections.apps7_saml_configs.metadata,` + ` projections.apps7_saml_configs.metadata_url,` + + ` projections.apps7_saml_configs.login_version,` + + ` projections.apps7_saml_configs.login_base_uri,` + ` COUNT(*) OVER ()` + ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + @@ -178,6 +182,8 @@ var ( "entity_id", "metadata", "metadata_url", + "login_version", + "login_base_uri", } appsCols = append(appCols, "count") ) @@ -252,6 +258,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -321,6 +329,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -393,6 +403,8 @@ func Test_AppsPrepare(t *testing.T) { "https://test.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "https://test.com/saml/metadata", + domain.LoginVersionUnspecified, + nil, }, }, ), @@ -467,6 +479,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -559,6 +573,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -651,6 +667,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -743,6 +761,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -835,6 +855,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -927,6 +949,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1019,6 +1043,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, { "api-app-id", @@ -1059,6 +1085,8 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, { "saml-app-id", @@ -1099,6 +1127,8 @@ func Test_AppsPrepare(t *testing.T) { "https://test.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "https://test.com/saml/metadata", + domain.LoginVersion2, + "https://login.ch/", }, }, ), @@ -1165,9 +1195,11 @@ func Test_AppsPrepare(t *testing.T) { Name: "app-name", ProjectID: "project-id", SAMLConfig: &SAMLApp{ - Metadata: []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), - MetadataURL: "https://test.com/saml/metadata", - EntityID: "https://test.com/saml/metadata", + Metadata: []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), + MetadataURL: "https://test.com/saml/metadata", + EntityID: "https://test.com/saml/metadata", + LoginVersion: domain.LoginVersion2, + LoginBaseURI: gu.Ptr("https://login.ch/"), }, }, }, @@ -1280,6 +1312,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, ), }, @@ -1343,6 +1377,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1411,6 +1447,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1498,6 +1536,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1585,6 +1625,8 @@ func Test_AppPrepare(t *testing.T) { "https://test.com/saml/metadata", []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), "https://test.com/saml/metadata", + domain.LoginVersionUnspecified, + nil, }, }, ), @@ -1599,9 +1641,11 @@ func Test_AppPrepare(t *testing.T) { Name: "app-name", ProjectID: "project-id", SAMLConfig: &SAMLApp{ - Metadata: []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), - MetadataURL: "https://test.com/saml/metadata", - EntityID: "https://test.com/saml/metadata", + Metadata: []byte("\n\n \n urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified\n \n \n \n"), + MetadataURL: "https://test.com/saml/metadata", + EntityID: "https://test.com/saml/metadata", + LoginVersion: domain.LoginVersionUnspecified, + LoginBaseURI: nil, }, }, }, @@ -1654,6 +1698,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1741,6 +1787,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1828,6 +1876,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), @@ -1915,6 +1965,8 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, + nil, }, }, ), diff --git a/internal/query/oidc_client_test.go b/internal/query/oidc_client_test.go index bb0890bff3..25e069da85 100644 --- a/internal/query/oidc_client_test.go +++ b/internal/query/oidc_client_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "database/sql/driver" _ "embed" + "net/url" "regexp" "testing" @@ -19,6 +20,8 @@ import ( var ( //go:embed testdata/oidc_client_jwt.json testdataOidcClientJWT string + //go:embed testdata/oidc_client_jwt_loginversion.json + testdataOidcClientJWTLoginVersion string //go:embed testdata/oidc_client_public.json testdataOidcClientPublic string //go:embed testdata/oidc_client_public_old_id.json @@ -91,6 +94,44 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx }, }, }, + { + name: "jwt client, login version", + mock: mockQuery(expQuery, cols, []driver.Value{testdataOidcClientJWTLoginVersion}, "instanceID", "clientID", true), + want: &OIDCClient{ + InstanceID: "230690539048009730", + AppID: "236647088211886082", + State: domain.AppStateActive, + ClientID: "236647088211951618", + HashedSecret: "", + RedirectURIs: []string{"http://localhost:9999/auth/callback"}, + ResponseTypes: []domain.OIDCResponseType{domain.OIDCResponseTypeCode}, + GrantTypes: []domain.OIDCGrantType{domain.OIDCGrantTypeAuthorizationCode, domain.OIDCGrantTypeRefreshToken}, + ApplicationType: domain.OIDCApplicationTypeWeb, + AuthMethodType: domain.OIDCAuthMethodTypePrivateKeyJWT, + PostLogoutRedirectURIs: []string{"https://example.com/logout"}, + IsDevMode: true, + AccessTokenType: domain.OIDCTokenTypeJWT, + AccessTokenRoleAssertion: true, + IDTokenRoleAssertion: true, + IDTokenUserinfoAssertion: true, + ClockSkew: 1000000000, + AdditionalOrigins: []string{"https://example.com"}, + ProjectID: "236645808328409090", + ProjectRoleAssertion: true, + PublicKeys: map[string][]byte{"236647201860747266": []byte(pubkey)}, + ProjectRoleKeys: []string{"role1", "role2"}, + Settings: &OIDCSettings{ + AccessTokenLifetime: 43200000000000, + IdTokenLifetime: 43200000000000, + }, + LoginVersion: domain.LoginVersion1, + LoginBaseURI: func() *URL { + ret, _ := url.Parse("https://test.com/login") + retURL := URL(*ret) + return &retURL + }(), + }, + }, { name: "public client", mock: mockQuery(expQuery, cols, []driver.Value{testdataOidcClientPublic}, "instanceID", "clientID", true), diff --git a/internal/query/projection/app.go b/internal/query/projection/app.go index 14053cc8dc..c50bf03f40 100644 --- a/internal/query/projection/app.go +++ b/internal/query/projection/app.go @@ -62,12 +62,14 @@ const ( AppOIDCConfigColumnLoginVersion = "login_version" AppOIDCConfigColumnLoginBaseURI = "login_base_uri" - appSAMLTableSuffix = "saml_configs" - AppSAMLConfigColumnAppID = "app_id" - AppSAMLConfigColumnInstanceID = "instance_id" - AppSAMLConfigColumnEntityID = "entity_id" - AppSAMLConfigColumnMetadata = "metadata" - AppSAMLConfigColumnMetadataURL = "metadata_url" + appSAMLTableSuffix = "saml_configs" + AppSAMLConfigColumnAppID = "app_id" + AppSAMLConfigColumnInstanceID = "instance_id" + AppSAMLConfigColumnEntityID = "entity_id" + AppSAMLConfigColumnMetadata = "metadata" + AppSAMLConfigColumnMetadataURL = "metadata_url" + AppSAMLConfigColumnLoginVersion = "login_version" + AppSAMLConfigColumnLoginBaseURI = "login_base_uri" ) type appProjection struct{} @@ -143,6 +145,8 @@ func (*appProjection) Init() *old_handler.Check { handler.NewColumn(AppSAMLConfigColumnEntityID, handler.ColumnTypeText), handler.NewColumn(AppSAMLConfigColumnMetadata, handler.ColumnTypeBytes), handler.NewColumn(AppSAMLConfigColumnMetadataURL, handler.ColumnTypeText), + handler.NewColumn(AppSAMLConfigColumnLoginVersion, handler.ColumnTypeEnum, handler.Nullable()), + handler.NewColumn(AppSAMLConfigColumnLoginBaseURI, handler.ColumnTypeText, handler.Nullable()), }, handler.NewPrimaryKey(AppSAMLConfigColumnInstanceID, AppSAMLConfigColumnAppID), appSAMLTableSuffix, @@ -703,6 +707,8 @@ func (p *appProjection) reduceSAMLConfigAdded(event eventstore.Event) (*handler. handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID), handler.NewCol(AppSAMLConfigColumnMetadata, e.Metadata), handler.NewCol(AppSAMLConfigColumnMetadataURL, e.MetadataURL), + handler.NewCol(AppSAMLConfigColumnLoginVersion, e.LoginVersion), + handler.NewCol(AppSAMLConfigColumnLoginBaseURI, e.LoginBaseURI), }, handler.WithTableSuffix(appSAMLTableSuffix), ), @@ -735,6 +741,12 @@ func (p *appProjection) reduceSAMLConfigChanged(event eventstore.Event) (*handle if e.EntityID != "" { cols = append(cols, handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID)) } + if e.LoginVersion != nil { + cols = append(cols, handler.NewCol(AppSAMLConfigColumnLoginVersion, *e.LoginVersion)) + } + if e.LoginBaseURI != nil { + cols = append(cols, handler.NewCol(AppSAMLConfigColumnLoginBaseURI, *e.LoginBaseURI)) + } if len(cols) == 0 { return handler.NewNoOpStatement(e), nil diff --git a/internal/query/saml_sp.go b/internal/query/saml_sp.go new file mode 100644 index 0000000000..3682375d0b --- /dev/null +++ b/internal/query/saml_sp.go @@ -0,0 +1,104 @@ +package query + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "net/url" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type SAMLServiceProvider struct { + InstanceID string `json:"instance_id,omitempty"` + AppID string `json:"app_id,omitempty"` + State domain.AppState `json:"state,omitempty"` + EntityID string `json:"entity_id,omitempty"` + Metadata []byte `json:"metadata,omitempty"` + MetadataURL string `json:"metadata_url,omitempty"` + ProjectID string `json:"project_id,omitempty"` + ProjectRoleAssertion bool `json:"project_role_assertion,omitempty"` + LoginVersion domain.LoginVersion `json:"login_version,omitempty"` + LoginBaseURI *url.URL `json:"login_base_uri,omitempty"` +} + +//go:embed saml_sp_by_id.sql +var samlSPQuery string + +func (q *Queries) ActiveSAMLServiceProviderByID(ctx context.Context, entityID string) (sp *SAMLServiceProvider, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + sp, err = scanSAMLServiceProviderByID(row) + return err + }, samlSPQuery, + authz.GetInstance(ctx).InstanceID(), + entityID, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-HeOcis2511", "Errors.App.NotFound") + } + if err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-OyJx1Rp30z", "Errors.Internal") + } + instance := authz.GetInstance(ctx) + loginV2 := instance.Features().LoginV2 + if loginV2.Required { + sp.LoginVersion = domain.LoginVersion2 + sp.LoginBaseURI = loginV2.BaseURI + } + return sp, err +} + +func scanSAMLServiceProviderByID(row *sql.Row) (*SAMLServiceProvider, error) { + var instanceID, appID, entityID, metadataURL, projectID sql.NullString + var projectRoleAssertion sql.NullBool + var metadata []byte + var state, loginVersion sql.NullInt16 + var loginBaseURI sql.NullString + + err := row.Scan( + &instanceID, + &appID, + &state, + &entityID, + &metadata, + &metadataURL, + &projectID, + &projectRoleAssertion, + &loginVersion, + &loginBaseURI, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-8cjj8ao6yY", "Errors.App.NotFound") + } + return nil, zerrors.ThrowInternal(err, "QUERY-1xzFD209Bp", "Errors.Internal") + } + sp := &SAMLServiceProvider{ + InstanceID: instanceID.String, + AppID: appID.String, + State: domain.AppState(state.Int16), + EntityID: entityID.String, + Metadata: metadata, + MetadataURL: metadataURL.String, + ProjectID: projectID.String, + ProjectRoleAssertion: projectRoleAssertion.Bool, + } + if loginVersion.Valid { + sp.LoginVersion = domain.LoginVersion(loginVersion.Int16) + } + if loginBaseURI.Valid && loginBaseURI.String != "" { + url, err := url.Parse(loginBaseURI.String) + if err != nil { + return nil, err + } + sp.LoginBaseURI = url + } + return sp, nil +} diff --git a/internal/query/saml_sp_by_id.sql b/internal/query/saml_sp_by_id.sql new file mode 100644 index 0000000000..ff877c7ab9 --- /dev/null +++ b/internal/query/saml_sp_by_id.sql @@ -0,0 +1,19 @@ +select c.instance_id, + c.app_id, + a.state, + c.entity_id, + c.metadata, + c.metadata_url, + a.project_id, + p.project_role_assertion, + c.login_version, + c.login_base_uri +from projections.apps7_saml_configs c + join projections.apps7 a + on a.id = c.app_id and a.instance_id = c.instance_id and a.state = 1 + join projections.projects4 p + on p.id = a.project_id and p.instance_id = a.instance_id and p.state = 1 + join projections.orgs1 o + on o.id = p.resource_owner and o.instance_id = c.instance_id and o.org_state = 1 +where c.instance_id = $1 + and c.entity_id = $2 diff --git a/internal/query/saml_sp_test.go b/internal/query/saml_sp_test.go new file mode 100644 index 0000000000..4aafd95de1 --- /dev/null +++ b/internal/query/saml_sp_test.go @@ -0,0 +1,123 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + _ "embed" + "net/url" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestQueries_ActiveSAMLServiceProviderByID(t *testing.T) { + expQuery := regexp.QuoteMeta(samlSPQuery) + cols := []string{ + "instance_id", + "app_id", + "state", + "entity_id", + "metadata", + "metadata_url", + "project_id", + "project_role_assertion", + "login_version", + "login_base_uri", + } + + tests := []struct { + name string + mock sqlExpectation + want *SAMLServiceProvider + wantErr error + }{ + { + name: "no rows", + mock: mockQueryErr(expQuery, sql.ErrNoRows, "instanceID", "entityID"), + wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-HeOcis2511", "Errors.App.NotFound"), + }, + { + name: "internal error", + mock: mockQueryErr(expQuery, sql.ErrConnDone, "instanceID", "entityID"), + wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-OyJx1Rp30z", "Errors.Internal"), + }, + { + name: "sp", + mock: mockQuery(expQuery, cols, []driver.Value{ + "230690539048009730", + "236647088211886082", + domain.AppStateActive, + "https://test.com/metadata", + "metadata", + "https://test.com/metadata", + "236645808328409090", + true, + domain.LoginVersionUnspecified, + "", + }, "instanceID", "entityID"), + want: &SAMLServiceProvider{ + InstanceID: "230690539048009730", + AppID: "236647088211886082", + State: domain.AppStateActive, + EntityID: "https://test.com/metadata", + Metadata: []byte("metadata"), + MetadataURL: "https://test.com/metadata", + ProjectID: "236645808328409090", + ProjectRoleAssertion: true, + }, + }, + { + name: "sp with loginversion", + mock: mockQuery(expQuery, cols, []driver.Value{ + "230690539048009730", + "236647088211886082", + domain.AppStateActive, + "https://test.com/metadata", + "metadata", + "https://test.com/metadata", + "236645808328409090", + true, + domain.LoginVersion2, + "https://test.com/login", + }, "instanceID", "entityID"), + want: &SAMLServiceProvider{ + InstanceID: "230690539048009730", + AppID: "236647088211886082", + State: domain.AppStateActive, + EntityID: "https://test.com/metadata", + Metadata: []byte("metadata"), + MetadataURL: "https://test.com/metadata", + ProjectID: "236645808328409090", + ProjectRoleAssertion: true, + LoginVersion: domain.LoginVersion2, + LoginBaseURI: func() *url.URL { + ret, _ := url.Parse("https://test.com/login") + return ret + }(), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock(t, tt.mock, func(db *sql.DB) { + q := &Queries{ + client: &database.DB{ + DB: db, + Database: &prepareDB{}, + }, + } + ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") + got, err := q.ActiveSAMLServiceProviderByID(ctx, "entityID") + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + }) + } +} diff --git a/internal/query/testdata/oidc_client_jwt_loginversion.json b/internal/query/testdata/oidc_client_jwt_loginversion.json new file mode 100644 index 0000000000..7664c0abc1 --- /dev/null +++ b/internal/query/testdata/oidc_client_jwt_loginversion.json @@ -0,0 +1,32 @@ +{ + "instance_id": "230690539048009730", + "app_id": "236647088211886082", + "state": 1, + "client_id": "236647088211951618", + "client_secret": null, + "redirect_uris": ["http://localhost:9999/auth/callback"], + "response_types": [0], + "grant_types": [0, 2], + "application_type": 0, + "auth_method_type": 3, + "post_logout_redirect_uris": ["https://example.com/logout"], + "is_dev_mode": true, + "access_token_type": 1, + "access_token_role_assertion": true, + "id_token_role_assertion": true, + "id_token_userinfo_assertion": true, + "clock_skew": 1000000000, + "additional_origins": ["https://example.com"], + "project_id": "236645808328409090", + "project_role_assertion": true, + "project_role_keys": ["role1", "role2"], + "public_keys": { + "236647201860747266": "LS0tLS1CRUdJTiBSU0EgUFVCTElDIEtFWS0tLS0tCk1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFB\nT0NBUThBTUlJQkNnS0NBUUVBMnVmQUwxYjcyYkl5MWFyK1dzNmIKR29oSkpRRkI3ZGZSYXBEcWVx\nTThVa3A2Q1ZkUHpxL3BPejF2aUFxNTB5eldaSnJ5Risyd3NoRkFLR0Y5QTIvQgoyWWY5YkpYUFov\nS2JrRnJZVDNOVHZZRGt2bGFTVGw5bU1uenJVMjlzNDhGMVBUV0tmQitDM2FNc09FRzFCdWZWCnM2\nM3FGNG5yRVBqU2JobGpJY285RlpxNFhwcEl6aE1RMGZEZEEvK1h5Z0NKcXZ1YUwwTGliTTFLcmxV\nZG51NzEKWWVraFNKakVQbnZPaXNYSWs0SVh5d29HSU93dGp4a0R2Tkl0UXZhTVZsZHI0L2tiNnV2\nYmdkV3dxNUV3QlpYcQpsb3cya3lKb3YzOFY0VWsySThrdVhwTGNucnB3NVRpbzJvb2lVRTI3YjB2\nSFpxQktPZWk5VW84OHFDcm4zRUt4CjZRSURBUUFCCi0tLS0tRU5EIFJTQSBQVUJMSUMgS0VZLS0t\nLS0K" + }, + "settings": { + "access_token_lifetime": 43200000000000, + "id_token_lifetime": 43200000000000 + }, + "login_version": 1, + "login_base_uri": "https://test.com/login" +} diff --git a/internal/repository/project/oidc_config.go b/internal/repository/project/oidc_config.go index 8bc918afbe..dd7d3a85b6 100644 --- a/internal/repository/project/oidc_config.go +++ b/internal/repository/project/oidc_config.go @@ -384,13 +384,13 @@ func ChangeBackChannelLogoutURI(backChannelLogoutURI string) func(event *OIDCCon } } -func ChangeLoginVersion(loginVersion domain.LoginVersion) func(event *OIDCConfigChangedEvent) { +func ChangeOIDCLoginVersion(loginVersion domain.LoginVersion) func(event *OIDCConfigChangedEvent) { return func(e *OIDCConfigChangedEvent) { e.LoginVersion = &loginVersion } } -func ChangeLoginBaseURI(loginBaseURI string) func(event *OIDCConfigChangedEvent) { +func ChangeOIDCLoginBaseURI(loginBaseURI string) func(event *OIDCConfigChangedEvent) { return func(e *OIDCConfigChangedEvent) { e.LoginBaseURI = &loginBaseURI } diff --git a/internal/repository/project/saml_config.go b/internal/repository/project/saml_config.go index 97af24a0d9..ddcb9c0eab 100644 --- a/internal/repository/project/saml_config.go +++ b/internal/repository/project/saml_config.go @@ -3,6 +3,7 @@ package project import ( "context" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -16,10 +17,12 @@ const ( type SAMLConfigAddedEvent struct { eventstore.BaseEvent `json:"-"` - AppID string `json:"appId"` - EntityID string `json:"entityId"` - Metadata []byte `json:"metadata,omitempty"` - MetadataURL string `json:"metadata_url,omitempty"` + AppID string `json:"appId"` + EntityID string `json:"entityId"` + Metadata []byte `json:"metadata,omitempty"` + MetadataURL string `json:"metadata_url,omitempty"` + LoginVersion domain.LoginVersion `json:"loginVersion,omitempty"` + LoginBaseURI string `json:"loginBaseURI,omitempty"` } func (e *SAMLConfigAddedEvent) Payload() interface{} { @@ -50,6 +53,8 @@ func NewSAMLConfigAddedEvent( entityID string, metadata []byte, metadataURL string, + loginVersion domain.LoginVersion, + loginBaseURI string, ) *SAMLConfigAddedEvent { return &SAMLConfigAddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -57,10 +62,12 @@ func NewSAMLConfigAddedEvent( aggregate, SAMLConfigAddedType, ), - AppID: appID, - EntityID: entityID, - Metadata: metadata, - MetadataURL: metadataURL, + AppID: appID, + EntityID: entityID, + Metadata: metadata, + MetadataURL: metadataURL, + LoginVersion: loginVersion, + LoginBaseURI: loginBaseURI, } } @@ -80,11 +87,13 @@ func SAMLConfigAddedEventMapper(event eventstore.Event) (eventstore.Event, error type SAMLConfigChangedEvent struct { eventstore.BaseEvent `json:"-"` - AppID string `json:"appId"` - EntityID string `json:"entityId"` - Metadata []byte `json:"metadata,omitempty"` - MetadataURL *string `json:"metadata_url,omitempty"` - oldEntityID string + AppID string `json:"appId"` + EntityID string `json:"entityId"` + Metadata []byte `json:"metadata,omitempty"` + MetadataURL *string `json:"metadata_url,omitempty"` + LoginVersion *domain.LoginVersion `json:"loginVersion,omitempty"` + LoginBaseURI *string `json:"loginBaseURI,omitempty"` + oldEntityID string } func (e *SAMLConfigChangedEvent) Payload() interface{} { @@ -147,6 +156,17 @@ func ChangeEntityID(entityID string) func(event *SAMLConfigChangedEvent) { } } +func ChangeSAMLLoginVersion(loginVersion domain.LoginVersion) func(event *SAMLConfigChangedEvent) { + return func(e *SAMLConfigChangedEvent) { + e.LoginVersion = &loginVersion + } +} +func ChangeSAMLLoginBaseURI(loginBaseURI string) func(event *SAMLConfigChangedEvent) { + return func(e *SAMLConfigChangedEvent) { + e.LoginBaseURI = &loginBaseURI + } +} + func SAMLConfigChangedEventMapper(event eventstore.Event) (eventstore.Event, error) { e := &SAMLConfigChangedEvent{ BaseEvent: *eventstore.BaseEventFromRepo(event), diff --git a/proto/zitadel/app.proto b/proto/zitadel/app.proto index 999e71cabf..08359e3762 100644 --- a/proto/zitadel/app.proto +++ b/proto/zitadel/app.proto @@ -222,6 +222,11 @@ message SAMLConfig { bytes metadata_xml = 1; string metadata_url = 2; } + LoginVersion login_version = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Specify the preferred login UI, where the user is redirected to for authentication. If unset, the login UI is chosen by the instance default."; + } + ]; } enum APIAuthMethodType { diff --git a/proto/zitadel/management.proto b/proto/zitadel/management.proto index 94de141a65..5444983396 100644 --- a/proto/zitadel/management.proto +++ b/proto/zitadel/management.proto @@ -9850,6 +9850,11 @@ message AddSAMLAppRequest { bytes metadata_xml = 3 [(validate.rules).bytes.max_len = 500000]; string metadata_url = 4 [(validate.rules).string.max_len = 200]; } + zitadel.app.v1.LoginVersion login_version = 5 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Specify the preferred login UI, where the user is redirected to for authentication. If unset, the login UI is chosen by the instance default."; + } + ]; } message AddSAMLAppResponse { @@ -10014,6 +10019,11 @@ message UpdateSAMLAppConfigRequest { bytes metadata_xml = 3 [(validate.rules).bytes.max_len = 500000]; string metadata_url = 4 [(validate.rules).string.max_len = 200]; } + zitadel.app.v1.LoginVersion login_version = 5 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Specify the preferred login UI, where the user is redirected to for authentication. If unset, the login UI is chosen by the instance default."; + } + ]; } message UpdateSAMLAppConfigResponse { @@ -13653,7 +13663,7 @@ message SetTriggerActionsRequest { * - Internal Authentication: 3 * - Complement Token: 2 * - Complement SAML Response: 4 - */ + */ string flow_type = 1 [ (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { example: "\"1\""; @@ -13664,11 +13674,11 @@ message SetTriggerActionsRequest { * - External Authentication: * - Post Authentication: TRIGGER_TYPE_POST_AUTHENTICATION or 1 * - Pre Creation: TRIGGER_TYPE_PRE_CREATION or 2 - * - Post Creation: TRIGGER_TYPE_POST_CREATION or 3 + * - Post Creation: TRIGGER_TYPE_POST_CREATION or 3 * - Internal Authentication: * - Post Authentication: TRIGGER_TYPE_POST_AUTHENTICATION or 1 * - Pre Creation: TRIGGER_TYPE_PRE_CREATION or 2 - * - Post Creation: TRIGGER_TYPE_POST_CREATION or 3 + * - Post Creation: TRIGGER_TYPE_POST_CREATION or 3 * - Complement Token: * - Pre Userinfo Creation: 4 * - Pre Access Token Creation: 5