diff --git a/internal/api/oidc/client.go b/internal/api/oidc/client.go index 8dc7c58ad5..77fe7451b2 100644 --- a/internal/api/oidc/client.go +++ b/internal/api/oidc/client.go @@ -50,13 +50,10 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl err = oidcError(err) span.EndWithError(err) }() - client, err := o.query.GetOIDCClientByID(ctx, id, false) + client, err := o.query.ActiveOIDCClientByID(ctx, id, false) if err != nil { return nil, err } - if client.State != domain.AppStateActive { - return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-sdaGg", "client is not active") - } return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2), nil } @@ -979,16 +976,13 @@ func (s *Server) VerifyClient(ctx context.Context, r *op.Request[op.ClientCreden if err != nil { return nil, err } - client, err := s.query.GetOIDCClientByID(ctx, clientID, assertion) + client, err := s.query.ActiveOIDCClientByID(ctx, clientID, assertion) if zerrors.IsNotFound(err) { - return nil, oidc.ErrInvalidClient().WithParent(err).WithDescription("client not found") + return nil, oidc.ErrInvalidClient().WithParent(err).WithDescription("no active client not found") } if err != nil { return nil, err // defaults to server error } - if client.State != domain.AppStateActive { - return nil, oidc.ErrInvalidClient().WithDescription("client is not active") - } if client.Settings == nil { client.Settings = &query.OIDCSettings{ AccessTokenLifetime: s.defaultAccessTokenLifetime, diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index 65cc9309d5..aff193070e 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -190,9 +190,13 @@ func TestServer_VerifyClient(t *testing.T) { sessionID, sessionToken, startTime, changeTime := Tester.CreateVerifiedWebAuthNSession(t, CTXLOGIN, User.GetUserId()) project, err := Tester.CreateProject(CTX) require.NoError(t, err) + projectInactive, err := Tester.CreateProject(CTX) + require.NoError(t, err) inactiveClient, err := Tester.CreateOIDCInactivateClient(CTX, redirectURI, logoutRedirectURI, project.GetId()) require.NoError(t, err) + inactiveProjectClient, err := Tester.CreateOIDCInactivateProjectClient(CTX, redirectURI, logoutRedirectURI, projectInactive.GetId()) + require.NoError(t, err) nativeClient, err := Tester.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId(), false) require.NoError(t, err) basicWebClient, err := Tester.CreateOIDCWebClientBasic(CTX, redirectURI, logoutRedirectURI, project.GetId()) @@ -234,6 +238,14 @@ func TestServer_VerifyClient(t *testing.T) { }, wantErr: true, }, + { + name: "client inactive (project) error", + client: clientDetails{ + authReqClientID: nativeClient.GetClientId(), + clientID: inactiveProjectClient.GetClientId(), + }, + wantErr: true, + }, { name: "native client success", client: clientDetails{ diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 99602393c5..1d53b78110 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -216,7 +216,7 @@ func (s *Server) clientFromCredentials(ctx context.Context, cc *op.ClientCredent if err != nil { return nil, err } - client, err = s.query.GetIntrospectionClientByID(ctx, clientID, assertion) + client, err = s.query.ActiveIntrospectionClientByID(ctx, clientID, assertion) if errors.Is(err, sql.ErrNoRows) { return nil, oidc.ErrUnauthorizedClient().WithParent(err) } diff --git a/internal/api/saml/storage.go b/internal/api/saml/storage.go index bd8afffe54..45b9af1b94 100644 --- a/internal/api/saml/storage.go +++ b/internal/api/saml/storage.go @@ -55,13 +55,10 @@ type Storage struct { } func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) { - app, err := p.query.AppBySAMLEntityID(ctx, entityID) + app, err := p.query.ActiveAppBySAMLEntityID(ctx, entityID) if err != nil { return nil, err } - if app.State != domain.AppStateActive { - return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sdaGg", "app is not active") - } return serviceprovider.NewServiceProvider( app.ID, &serviceprovider.Config{ @@ -72,13 +69,10 @@ func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*servicep } func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) { - app, err := p.query.AppByID(ctx, appID) + app, err := p.query.AppByID(ctx, appID, true) if err != nil { return "", err } - if app.State != domain.AppStateActive { - return "", zerrors.ThrowPreconditionFailed(nil, "SAML-sdaGg", "app is not active") - } return app.SAMLConfig.EntityID, nil } diff --git a/internal/integration/oidc.go b/internal/integration/oidc.go index 3e90cb6856..9c4130600b 100644 --- a/internal/integration/oidc.go +++ b/internal/integration/oidc.go @@ -88,6 +88,20 @@ func (s *Tester) CreateOIDCInactivateClient(ctx context.Context, redirectURI, lo return client, err } +func (s *Tester) CreateOIDCInactivateProjectClient(ctx context.Context, redirectURI, logoutRedirectURI, projectID string) (*management.AddOIDCAppResponse, error) { + client, err := s.CreateOIDCNativeClient(ctx, redirectURI, logoutRedirectURI, projectID, false) + if err != nil { + return nil, err + } + _, err = s.Client.Mgmt.DeactivateProject(ctx, &management.DeactivateProjectRequest{ + Id: projectID, + }) + if err != nil { + return nil, err + } + return client, err +} + func (s *Tester) CreateOIDCImplicitFlowClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) { project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{ Name: fmt.Sprintf("project-%d", time.Now().UnixNano()), diff --git a/internal/query/app.go b/internal/query/app.go index 7d69981dd4..b94cb9cdaf 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -256,7 +256,7 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo traceSpan.EndWithError(err) } - stmt, scan := prepareAppQuery(ctx, q.client) + stmt, scan := prepareAppQuery(ctx, q.client, false) eq := sq.Eq{ AppColumnID.identifier(): appID, AppColumnProjectID.identifier(): projectID, @@ -274,15 +274,20 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo return app, err } -func (q *Queries) AppByID(ctx context.Context, appID string) (app *App, err error) { +func (q *Queries) AppByID(ctx context.Context, appID string, activeOnly bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client) + stmt, scan := prepareAppQuery(ctx, q.client, activeOnly) eq := sq.Eq{ AppColumnID.identifier(): appID, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } + if activeOnly { + eq[AppColumnState.identifier()] = domain.AppStateActive + eq[ProjectColumnState.identifier()] = domain.ProjectStateActive + eq[OrgColumnState.identifier()] = domain.OrgStateActive + } query, args, err := stmt.Where(eq).ToSql() if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement") @@ -295,7 +300,7 @@ func (q *Queries) AppByID(ctx context.Context, appID string) (app *App, err erro return app, err } -func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string) (app *App, err error) { +func (q *Queries) ActiveAppBySAMLEntityID(ctx context.Context, entityID string) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -303,6 +308,9 @@ func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string) (app * eq := sq.Eq{ AppSAMLConfigColumnEntityID.identifier(): entityID, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + AppColumnState.identifier(): domain.AppStateActive, + ProjectColumnState.identifier(): domain.ProjectStateActive, + OrgColumnState.identifier(): domain.OrgStateActive, } query, args, err := stmt.Where(eq).ToSql() if err != nil { @@ -413,8 +421,13 @@ func (q *Queries) AppByClientID(ctx context.Context, clientID string) (app *App, ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client) - eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} + stmt, scan := prepareAppQuery(ctx, q.client, true) + eq := sq.Eq{ + AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + AppColumnState.identifier(): domain.AppStateActive, + ProjectColumnState.identifier(): domain.ProjectStateActive, + OrgColumnState.identifier(): domain.OrgStateActive, + } query, args, err := stmt.Where(sq.And{ eq, sq.Or{ @@ -491,107 +504,121 @@ func NewAppProjectIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(AppColumnProjectID, id, TextEquals) } -func prepareAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return sq.Select( - AppColumnID.identifier(), - AppColumnName.identifier(), - AppColumnProjectID.identifier(), - AppColumnCreationDate.identifier(), - AppColumnChangeDate.identifier(), - AppColumnResourceOwner.identifier(), - AppColumnState.identifier(), - AppColumnSequence.identifier(), +func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + query := sq.Select( + AppColumnID.identifier(), + AppColumnName.identifier(), + AppColumnProjectID.identifier(), + AppColumnCreationDate.identifier(), + AppColumnChangeDate.identifier(), + AppColumnResourceOwner.identifier(), + AppColumnState.identifier(), + AppColumnSequence.identifier(), - AppAPIConfigColumnAppID.identifier(), - AppAPIConfigColumnClientID.identifier(), - AppAPIConfigColumnAuthMethod.identifier(), + AppAPIConfigColumnAppID.identifier(), + AppAPIConfigColumnClientID.identifier(), + AppAPIConfigColumnAuthMethod.identifier(), - AppOIDCConfigColumnAppID.identifier(), - AppOIDCConfigColumnVersion.identifier(), - AppOIDCConfigColumnClientID.identifier(), - AppOIDCConfigColumnRedirectUris.identifier(), - AppOIDCConfigColumnResponseTypes.identifier(), - AppOIDCConfigColumnGrantTypes.identifier(), - AppOIDCConfigColumnApplicationType.identifier(), - AppOIDCConfigColumnAuthMethodType.identifier(), - AppOIDCConfigColumnPostLogoutRedirectUris.identifier(), - AppOIDCConfigColumnDevMode.identifier(), - AppOIDCConfigColumnAccessTokenType.identifier(), - AppOIDCConfigColumnAccessTokenRoleAssertion.identifier(), - AppOIDCConfigColumnIDTokenRoleAssertion.identifier(), - AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(), - AppOIDCConfigColumnClockSkew.identifier(), - AppOIDCConfigColumnAdditionalOrigins.identifier(), - AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), + AppOIDCConfigColumnAppID.identifier(), + AppOIDCConfigColumnVersion.identifier(), + AppOIDCConfigColumnClientID.identifier(), + AppOIDCConfigColumnRedirectUris.identifier(), + AppOIDCConfigColumnResponseTypes.identifier(), + AppOIDCConfigColumnGrantTypes.identifier(), + AppOIDCConfigColumnApplicationType.identifier(), + AppOIDCConfigColumnAuthMethodType.identifier(), + AppOIDCConfigColumnPostLogoutRedirectUris.identifier(), + AppOIDCConfigColumnDevMode.identifier(), + AppOIDCConfigColumnAccessTokenType.identifier(), + AppOIDCConfigColumnAccessTokenRoleAssertion.identifier(), + AppOIDCConfigColumnIDTokenRoleAssertion.identifier(), + AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(), + AppOIDCConfigColumnClockSkew.identifier(), + AppOIDCConfigColumnAdditionalOrigins.identifier(), + AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), - AppSAMLConfigColumnAppID.identifier(), - AppSAMLConfigColumnEntityID.identifier(), - AppSAMLConfigColumnMetadata.identifier(), - AppSAMLConfigColumnMetadataURL.identifier(), - ).From(appsTable.identifier()). + AppSAMLConfigColumnAppID.identifier(), + AppSAMLConfigColumnEntityID.identifier(), + AppSAMLConfigColumnMetadata.identifier(), + AppSAMLConfigColumnMetadataURL.identifier(), + ).From(appsTable.identifier()). + PlaceholderFormat(sq.Dollar) + + if activeOnly { + return query. + LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). + LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)). + LeftJoin(join(ProjectColumnID, AppColumnProjectID)). + LeftJoin(join(OrgColumnID, AppColumnResourceOwner) + db.Timetravel(call.Took(ctx))), + scanApp + } + return query. LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))). - PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { - app := new(App) + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))), + scanApp +} - var ( - apiConfig = sqlAPIConfig{} - oidcConfig = sqlOIDCConfig{} - samlConfig = sqlSAMLConfig{} - ) +func scanApp(row *sql.Row) (*App, error) { + app := new(App) - err := row.Scan( - &app.ID, - &app.Name, - &app.ProjectID, - &app.CreationDate, - &app.ChangeDate, - &app.ResourceOwner, - &app.State, - &app.Sequence, + var ( + apiConfig = sqlAPIConfig{} + oidcConfig = sqlOIDCConfig{} + samlConfig = sqlSAMLConfig{} + ) - &apiConfig.appID, - &apiConfig.clientID, - &apiConfig.authMethod, + err := row.Scan( + &app.ID, + &app.Name, + &app.ProjectID, + &app.CreationDate, + &app.ChangeDate, + &app.ResourceOwner, + &app.State, + &app.Sequence, - &oidcConfig.appID, - &oidcConfig.version, - &oidcConfig.clientID, - &oidcConfig.redirectUris, - &oidcConfig.responseTypes, - &oidcConfig.grantTypes, - &oidcConfig.applicationType, - &oidcConfig.authMethodType, - &oidcConfig.postLogoutRedirectUris, - &oidcConfig.devMode, - &oidcConfig.accessTokenType, - &oidcConfig.accessTokenRoleAssertion, - &oidcConfig.iDTokenRoleAssertion, - &oidcConfig.iDTokenUserinfoAssertion, - &oidcConfig.clockSkew, - &oidcConfig.additionalOrigins, - &oidcConfig.skipNativeAppSuccessPage, + &apiConfig.appID, + &apiConfig.clientID, + &apiConfig.authMethod, - &samlConfig.appID, - &samlConfig.entityID, - &samlConfig.metadata, - &samlConfig.metadataURL, - ) + &oidcConfig.appID, + &oidcConfig.version, + &oidcConfig.clientID, + &oidcConfig.redirectUris, + &oidcConfig.responseTypes, + &oidcConfig.grantTypes, + &oidcConfig.applicationType, + &oidcConfig.authMethodType, + &oidcConfig.postLogoutRedirectUris, + &oidcConfig.devMode, + &oidcConfig.accessTokenType, + &oidcConfig.accessTokenRoleAssertion, + &oidcConfig.iDTokenRoleAssertion, + &oidcConfig.iDTokenUserinfoAssertion, + &oidcConfig.clockSkew, + &oidcConfig.additionalOrigins, + &oidcConfig.skipNativeAppSuccessPage, - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, zerrors.ThrowNotFound(err, "QUERY-pCP8P", "Errors.App.NotExisting") - } - return nil, zerrors.ThrowInternal(err, "QUERY-4SJlx", "Errors.Internal") - } + &samlConfig.appID, + &samlConfig.entityID, + &samlConfig.metadata, + &samlConfig.metadataURL, + ) - apiConfig.set(app) - oidcConfig.set(app) - samlConfig.set(app) - - return app, nil + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-pCP8P", "Errors.App.NotExisting") } + return nil, zerrors.ThrowInternal(err, "QUERY-4SJlx", "Errors.Internal") + } + + apiConfig.set(app) + oidcConfig.set(app) + samlConfig.set(app) + + return app, nil } func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { @@ -690,6 +717,8 @@ func prepareSAMLAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil AppSAMLConfigColumnMetadataURL.identifier(), ).From(appsTable.identifier()). Join(join(AppSAMLConfigColumnAppID, AppColumnID)). + Join(join(ProjectColumnID, AppColumnProjectID)). + Join(join(OrgColumnID, AppColumnResourceOwner)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { app := new(App) diff --git a/internal/query/app_test.go b/internal/query/app_test.go index 3ccec2e4c8..9a9c613868 100644 --- a/internal/query/app_test.go +++ b/internal/query/app_test.go @@ -1,6 +1,7 @@ package query import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -9,13 +10,15 @@ import ( "testing" "time" + sq "github.com/Masterminds/squirrel" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) var ( - expectedAppQuery = regexp.QuoteMeta(`SELECT projections.apps7.id,` + + expectedAppQueryBase = `SELECT projections.apps7.id,` + ` projections.apps7.name,` + ` projections.apps7.project_id,` + ` projections.apps7.creation_date,` + @@ -53,8 +56,11 @@ var ( ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + - ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id` + expectedAppQuery = regexp.QuoteMeta(expectedAppQueryBase) + expectedActiveAppQuery = regexp.QuoteMeta(expectedAppQueryBase + + ` LEFT JOIN projections.projects4 ON projections.apps7.project_id = projections.projects4.id AND projections.apps7.instance_id = projections.projects4.instance_id` + + ` LEFT JOIN projections.orgs1 ON projections.apps7.resource_owner = projections.orgs1.id AND projections.apps7.instance_id = projections.orgs1.instance_id`) expectedAppsQuery = regexp.QuoteMeta(`SELECT projections.apps7.id,` + ` projections.apps7.name,` + ` projections.apps7.project_id,` + @@ -1140,8 +1146,10 @@ func Test_AppPrepare(t *testing.T) { object interface{} }{ { - name: "prepareAppQuery no result", - prepare: prepareAppQuery, + name: "prepareAppQuery no result", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueriesScanErr( expectedAppQuery, @@ -1158,8 +1166,10 @@ func Test_AppPrepare(t *testing.T) { object: (*App)(nil), }, { - name: "prepareAppQuery found", - prepare: prepareAppQuery, + name: "prepareAppQuery found", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQuery( expectedAppQuery, @@ -1215,8 +1225,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery api app", - prepare: prepareAppQuery, + name: "prepareAppQuery api app", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1278,8 +1290,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery oidc app", - prepare: prepareAppQuery, + name: "prepareAppQuery oidc app", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1355,9 +1369,93 @@ func Test_AppPrepare(t *testing.T) { SkipNativeAppSuccessPage: false, }, }, - }, { - name: "prepareAppQuery saml app", - prepare: prepareAppQuery, + }, + { + name: "prepareAppQuery oidc app active only", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, true) + }, + want: want{ + sqlExpectations: mockQueries( + expectedActiveAppQuery, + appCols, + [][]driver.Value{ + { + "app-id", + "app-name", + "project-id", + testNow, + testNow, + "ro", + domain.AppStateActive, + uint64(20211109), + // api config + nil, + nil, + nil, + // oidc config + "app-id", + domain.OIDCVersionV1, + "oidc-client-id", + database.TextArray[string]{"https://redirect.to/me"}, + database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken}, + database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit}, + domain.OIDCApplicationTypeUserAgent, + domain.OIDCAuthMethodTypeNone, + database.TextArray[string]{"post.logout.ch"}, + true, + domain.OIDCTokenTypeJWT, + true, + true, + true, + 1 * time.Second, + database.TextArray[string]{"additional.origin"}, + false, + // saml config + nil, + nil, + nil, + nil, + }, + }, + ), + }, + object: &App{ + ID: "app-id", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "ro", + State: domain.AppStateActive, + Sequence: 20211109, + Name: "app-name", + ProjectID: "project-id", + OIDCConfig: &OIDCApp{ + Version: domain.OIDCVersionV1, + ClientID: "oidc-client-id", + RedirectURIs: database.TextArray[string]{"https://redirect.to/me"}, + ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken}, + GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit}, + AppType: domain.OIDCApplicationTypeUserAgent, + AuthMethodType: domain.OIDCAuthMethodTypeNone, + PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"}, + IsDevMode: true, + AccessTokenType: domain.OIDCTokenTypeJWT, + AssertAccessTokenRole: true, + AssertIDTokenRole: true, + AssertIDTokenUserinfo: true, + ClockSkew: 1 * time.Second, + AdditionalOrigins: database.TextArray[string]{"additional.origin"}, + ComplianceProblems: nil, + AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, + SkipNativeAppSuccessPage: false, + }, + }, + }, + { + name: "prepareAppQuery saml app", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1420,8 +1518,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery oidc app IsDevMode inactive", - prepare: prepareAppQuery, + name: "prepareAppQuery oidc app IsDevMode inactive", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1499,8 +1599,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery oidc app AssertAccessTokenRole inactive", - prepare: prepareAppQuery, + name: "prepareAppQuery oidc app AssertAccessTokenRole inactive", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1578,8 +1680,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery oidc app AssertIDTokenRole inactive", - prepare: prepareAppQuery, + name: "prepareAppQuery oidc app AssertIDTokenRole inactive", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1657,8 +1761,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery oidc app AssertIDTokenUserinfo inactive", - prepare: prepareAppQuery, + name: "prepareAppQuery oidc app AssertIDTokenUserinfo inactive", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueries( expectedAppQuery, @@ -1736,8 +1842,10 @@ func Test_AppPrepare(t *testing.T) { }, }, { - name: "prepareAppQuery sql err", - prepare: prepareAppQuery, + name: "prepareAppQuery sql err", + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(ctx, db, false) + }, want: want{ sqlExpectations: mockQueryErr( expectedAppQuery, diff --git a/internal/query/introspection.go b/internal/query/introspection.go index a7fdaab718..ee96bf576b 100644 --- a/internal/query/introspection.go +++ b/internal/query/introspection.go @@ -52,7 +52,7 @@ type IntrospectionClient struct { //go:embed introspection_client_by_id.sql var introspectionClientByIDQuery string -func (q *Queries) GetIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) { +func (q *Queries) ActiveIntrospectionClientByID(ctx context.Context, clientID string, getKeys bool) (_ *IntrospectionClient, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/query/introspection_client_by_id.sql b/internal/query/introspection_client_by_id.sql index 5129d99a70..1cc6baf9ad 100644 --- a/internal/query/introspection_client_by_id.sql +++ b/internal/query/introspection_client_by_id.sql @@ -20,6 +20,7 @@ keys as ( ) select config.app_id, config.client_id, config.client_secret, config.app_type, apps.project_id, apps.resource_owner, p.project_role_assertion, keys.public_keys from config -join projections.apps7 apps on apps.id = config.app_id and apps.instance_id = config.instance_id -join projections.projects4 p on p.id = apps.project_id and p.instance_id = $1 +join projections.apps7 apps on apps.id = config.app_id and apps.instance_id = config.instance_id and apps.state = 1 +join projections.projects4 p on p.id = apps.project_id and p.instance_id = $1 and p.state = 1 +join projections.orgs1 o on o.id = p.resource_owner and o.instance_id = config.instance_id and o.org_state = 1 left join keys on keys.client_id = config.client_id; diff --git a/internal/query/introspection_test.go b/internal/query/introspection_test.go index 6535bd1639..4346842bf9 100644 --- a/internal/query/introspection_test.go +++ b/internal/query/introspection_test.go @@ -14,7 +14,7 @@ import ( "github.com/zitadel/zitadel/internal/database" ) -func TestQueries_GetIntrospectionClientByID(t *testing.T) { +func TestQueries_ActiveIntrospectionClientByID(t *testing.T) { pubkeys := database.Map[[]byte]{ "key1": {1, 2, 3}, "key2": {4, 5, 6}, @@ -96,7 +96,7 @@ func TestQueries_GetIntrospectionClientByID(t *testing.T) { }, } ctx := authz.NewMockContext("instanceID", "orgID", "userID") - got, err := q.GetIntrospectionClientByID(ctx, tt.args.clientID, tt.args.getKeys) + got, err := q.ActiveIntrospectionClientByID(ctx, tt.args.clientID, tt.args.getKeys) require.ErrorIs(t, err, tt.wantErr) assert.Equal(t, tt.want, got) }) diff --git a/internal/query/oidc_client.go b/internal/query/oidc_client.go index 6669b398b5..dd815c2e64 100644 --- a/internal/query/oidc_client.go +++ b/internal/query/oidc_client.go @@ -43,7 +43,7 @@ type OIDCClient struct { //go:embed oidc_client_by_id.sql var oidcClientQuery string -func (q *Queries) GetOIDCClientByID(ctx context.Context, clientID string, getKeys bool) (client *OIDCClient, err error) { +func (q *Queries) ActiveOIDCClientByID(ctx context.Context, clientID string, getKeys bool) (client *OIDCClient, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/query/oidc_client_by_id.sql b/internal/query/oidc_client_by_id.sql index 3a0a0a0c95..ef471387b3 100644 --- a/internal/query/oidc_client_by_id.sql +++ b/internal/query/oidc_client_by_id.sql @@ -5,8 +5,9 @@ with client as ( c.access_token_type, c.access_token_role_assertion, c.id_token_role_assertion, c.id_token_userinfo_assertion, c.clock_skew, c.additional_origins, a.project_id, p.project_role_assertion from projections.apps7_oidc_configs c - join projections.apps7 a on a.id = c.app_id and a.instance_id = c.instance_id - join projections.projects4 p on p.id = a.project_id and p.instance_id = a.instance_id + join projections.apps7 a on a.id = c.app_id and a.instance_id = c.instance_id and a.state = 1 + join projections.projects4 p on p.id = a.project_id and p.instance_id = a.instance_id and p.state = 1 + join projections.orgs1 o on o.id = p.resource_owner and o.instance_id = c.instance_id and o.org_state = 1 where c.instance_id = $1 and c.client_id = $2 ), diff --git a/internal/query/oidc_client_test.go b/internal/query/oidc_client_test.go index 93bd428015..73a1aa800a 100644 --- a/internal/query/oidc_client_test.go +++ b/internal/query/oidc_client_test.go @@ -27,7 +27,7 @@ var ( testdataOidcClientNoSettings string ) -func TestQueries_GetOIDCClientByID(t *testing.T) { +func TestQueries_ActiveOIDCClientByID(t *testing.T) { expQuery := regexp.QuoteMeta(oidcClientQuery) cols := []string{"client"} pubkey := `-----BEGIN RSA PUBLIC KEY----- @@ -198,7 +198,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") - got, err := q.GetOIDCClientByID(ctx, "clientID", true) + got, err := q.ActiveOIDCClientByID(ctx, "clientID", true) require.ErrorIs(t, err, tt.wantErr) assert.Equal(t, tt.want, got) })