diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 384dc402a5..7cb7ca7af0 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -475,11 +475,8 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string return scopes, nil } } - projectID, err := o.query.ProjectIDFromOIDCClientID(ctx, clientID) - if err != nil { - return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-AEG4d", "Errors.Internal") - } - project, err := o.query.ProjectByID(ctx, false, projectID) + + project, err := o.query.ProjectByOIDCClientID(ctx, clientID) if err != nil { return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-w4wIn", "Errors.Internal") } diff --git a/internal/query/app.go b/internal/query/app.go index 693bc7f0c3..7d69981dd4 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -299,7 +299,7 @@ func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string) (app * ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client) + stmt, scan := prepareSAMLAppQuery(ctx, q.client) eq := sq.Eq{ AppSAMLConfigColumnEntityID.identifier(): entityID, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -341,27 +341,6 @@ func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (project return project, err } -func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string) (id string, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - stmt, scan := prepareProjectIDByAppQuery(ctx, q.client) - eq := sq.Eq{ - AppOIDCConfigColumnClientID.identifier(): appID, - AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), - } - query, args, err := stmt.Where(eq).ToSql() - if err != nil { - return "", zerrors.ThrowInternal(err, "QUERY-7d92U", "Errors.Query.SQLStatement") - } - - err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { - id, err = scan(row) - return err - }, query, args...) - return id, err -} - func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string) (id string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -392,7 +371,7 @@ func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string) (project ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareProjectByAppQuery(ctx, q.client) + stmt, scan := prepareProjectByOIDCAppQuery(ctx, q.client) eq := sq.Eq{ AppOIDCConfigColumnClientID.identifier(): id, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -413,7 +392,7 @@ func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string) (app * ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client) + stmt, scan := prepareOIDCAppQuery() eq := sq.Eq{ AppOIDCConfigColumnClientID.identifier(): clientID, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -615,6 +594,138 @@ func prepareAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, } } +func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return sq.Select( + AppColumnID.identifier(), + AppColumnName.identifier(), + AppColumnProjectID.identifier(), + AppColumnCreationDate.identifier(), + AppColumnChangeDate.identifier(), + AppColumnResourceOwner.identifier(), + AppColumnState.identifier(), + AppColumnSequence.identifier(), + + AppOIDCConfigColumnAppID.identifier(), + AppOIDCConfigColumnVersion.identifier(), + AppOIDCConfigColumnClientID.identifier(), + AppOIDCConfigColumnRedirectUris.identifier(), + AppOIDCConfigColumnResponseTypes.identifier(), + AppOIDCConfigColumnGrantTypes.identifier(), + AppOIDCConfigColumnApplicationType.identifier(), + AppOIDCConfigColumnAuthMethodType.identifier(), + AppOIDCConfigColumnPostLogoutRedirectUris.identifier(), + AppOIDCConfigColumnDevMode.identifier(), + AppOIDCConfigColumnAccessTokenType.identifier(), + AppOIDCConfigColumnAccessTokenRoleAssertion.identifier(), + AppOIDCConfigColumnIDTokenRoleAssertion.identifier(), + AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(), + AppOIDCConfigColumnClockSkew.identifier(), + AppOIDCConfigColumnAdditionalOrigins.identifier(), + AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), + ).From(appsTable.identifier()). + Join(join(AppOIDCConfigColumnAppID, AppColumnID)). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { + app := new(App) + + var ( + oidcConfig = sqlOIDCConfig{} + ) + + err := row.Scan( + &app.ID, + &app.Name, + &app.ProjectID, + &app.CreationDate, + &app.ChangeDate, + &app.ResourceOwner, + &app.State, + &app.Sequence, + + &oidcConfig.appID, + &oidcConfig.version, + &oidcConfig.clientID, + &oidcConfig.redirectUris, + &oidcConfig.responseTypes, + &oidcConfig.grantTypes, + &oidcConfig.applicationType, + &oidcConfig.authMethodType, + &oidcConfig.postLogoutRedirectUris, + &oidcConfig.devMode, + &oidcConfig.accessTokenType, + &oidcConfig.accessTokenRoleAssertion, + &oidcConfig.iDTokenRoleAssertion, + &oidcConfig.iDTokenUserinfoAssertion, + &oidcConfig.clockSkew, + &oidcConfig.additionalOrigins, + &oidcConfig.skipNativeAppSuccessPage, + ) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-Fdfax", "Errors.App.NotExisting") + } + return nil, zerrors.ThrowInternal(err, "QUERY-aE7iE", "Errors.Internal") + } + + oidcConfig.set(app) + + return app, nil + } +} + +func prepareSAMLAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return sq.Select( + AppColumnID.identifier(), + AppColumnName.identifier(), + AppColumnProjectID.identifier(), + AppColumnCreationDate.identifier(), + AppColumnChangeDate.identifier(), + AppColumnResourceOwner.identifier(), + AppColumnState.identifier(), + AppColumnSequence.identifier(), + + AppSAMLConfigColumnAppID.identifier(), + AppSAMLConfigColumnEntityID.identifier(), + AppSAMLConfigColumnMetadata.identifier(), + AppSAMLConfigColumnMetadataURL.identifier(), + ).From(appsTable.identifier()). + Join(join(AppSAMLConfigColumnAppID, AppColumnID)). + PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { + + app := new(App) + var ( + samlConfig = sqlSAMLConfig{} + ) + + err := row.Scan( + &app.ID, + &app.Name, + &app.ProjectID, + &app.CreationDate, + &app.ChangeDate, + &app.ResourceOwner, + &app.State, + &app.Sequence, + + &samlConfig.appID, + &samlConfig.entityID, + &samlConfig.metadata, + &samlConfig.metadataURL, + ) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-d6TO1", "Errors.App.NotExisting") + } + return nil, zerrors.ThrowInternal(err, "QUERY-NAtPg", "Errors.Internal") + } + + samlConfig.set(app) + + return app, nil + } +} + func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) { return sq.Select( AppColumnProjectID.identifier(), @@ -638,6 +749,48 @@ func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.Sel } } +func prepareProjectByOIDCAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { + return sq.Select( + ProjectColumnID.identifier(), + ProjectColumnCreationDate.identifier(), + ProjectColumnChangeDate.identifier(), + ProjectColumnResourceOwner.identifier(), + ProjectColumnState.identifier(), + ProjectColumnSequence.identifier(), + ProjectColumnName.identifier(), + ProjectColumnProjectRoleAssertion.identifier(), + ProjectColumnProjectRoleCheck.identifier(), + ProjectColumnHasProjectCheck.identifier(), + ProjectColumnPrivateLabelingSetting.identifier(), + ).From(projectsTable.identifier()). + Join(join(AppColumnProjectID, ProjectColumnID)). + Join(join(AppOIDCConfigColumnAppID, AppColumnID)). + PlaceholderFormat(sq.Dollar), + func(row *sql.Row) (*Project, error) { + p := new(Project) + err := row.Scan( + &p.ID, + &p.CreationDate, + &p.ChangeDate, + &p.ResourceOwner, + &p.State, + &p.Sequence, + &p.Name, + &p.ProjectRoleAssertion, + &p.ProjectRoleCheck, + &p.HasProjectCheck, + &p.PrivateLabelingSetting, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-yxTMh", "Errors.Project.NotFound") + } + return nil, zerrors.ThrowInternal(err, "QUERY-dj2FF", "Errors.Internal") + } + return p, nil + } +} + func prepareProjectByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { return sq.Select( ProjectColumnID.identifier(),