some cleanup

This commit is contained in:
Tim Möhlmann 2023-11-13 19:20:01 +02:00
parent 477d565ffb
commit 3294ba4c4b
10 changed files with 47 additions and 20 deletions

View File

@ -36,7 +36,7 @@ func UserMetadataListFromQuery(c *actions.FieldConfig, metadata *query.UserMetad
func UserMetadataListFromSlice(c *actions.FieldConfig, metadata []query.UserMetadata) goja.Value { func UserMetadataListFromSlice(c *actions.FieldConfig, metadata []query.UserMetadata) goja.Value {
result := &userMetadataList{ result := &userMetadataList{
// Count was the only field ever queries from the DB in the old implementation, // Count was the only field ever queried from the DB in the old implementation,
// so Sequence and LastRun are omitted. // so Sequence and LastRun are omitted.
Count: uint64(len(metadata)), Count: uint64(len(metadata)),
Metadata: make([]*userMetadata, len(metadata)), Metadata: make([]*userMetadata, len(metadata)),

View File

@ -27,7 +27,6 @@ import (
) )
const ( const (
// TODO: remove declarations: (moved to domain package)
ScopeProjectRolePrefix = "urn:zitadel:iam:org:project:role:" ScopeProjectRolePrefix = "urn:zitadel:iam:org:project:role:"
ScopeProjectsRoles = "urn:zitadel:iam:org:projects:roles" ScopeProjectsRoles = "urn:zitadel:iam:org:projects:roles"
ClaimProjectRoles = "urn:zitadel:iam:org:project:roles" ClaimProjectRoles = "urn:zitadel:iam:org:project:roles"
@ -180,6 +179,21 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID)
if err != nil {
return err
}
projectID, err := o.query.ProjectIDFromClientID(ctx, clientID)
if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Adfg5", "client not found")
}
return o.introspect(ctx, introspection,
tokenID, token.UserID, token.ClientID, clientID, projectID,
token.Audience, token.Scope,
token.AccessTokenCreation, token.AccessTokenExpiration)
}
token, err := o.repo.TokenByIDs(ctx, subject, tokenID) token, err := o.repo.TokenByIDs(ctx, subject, tokenID)
if err != nil { if err != nil {
return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired") return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired")
@ -370,7 +384,7 @@ func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, us
if err != nil { if err != nil {
return err return err
} }
setUserInfoRoleClaims(userInfo, projectRoles) o.setUserInfoRoleClaims(userInfo, projectRoles)
return o.userinfoFlows(ctx, user, userGrants, userInfo) return o.userinfoFlows(ctx, user, userGrants, userInfo)
} }
@ -432,7 +446,7 @@ func (o *OPStorage) setUserInfoResourceOwner(ctx context.Context, userInfo *oidc
return nil return nil
} }
func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) { func (o *OPStorage) setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
if roles != nil && len(roles.projects) > 0 { if roles != nil && len(roles.projects) > 0 {
if roles, ok := roles.projects[roles.requestProjectID]; ok { if roles, ok := roles.projects[roles.requestProjectID]; ok {
userInfo.AppendClaims(ClaimProjectRoles, roles) userInfo.AppendClaims(ClaimProjectRoles, roles)

View File

@ -186,7 +186,7 @@ func (s *Server) introspectionToken(ctx context.Context, accessToken string, rc
} }
tokenID, subject = split[0], split[1] tokenID, subject = split[0], split[1]
} else { } else {
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.storage.keySet) verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier) claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, verifier)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -68,7 +68,6 @@ type OPStorage struct {
command *command.Commands command *command.Commands
query *query.Queries query *query.Queries
eventstore *eventstore.Eventstore eventstore *eventstore.Eventstore
keySet *keySetCache
defaultLoginURL string defaultLoginURL string
defaultLoginURLV2 string defaultLoginURLV2 string
defaultLogoutURLV2 string defaultLogoutURLV2 string
@ -123,6 +122,8 @@ func NewServer(
storage: storage, storage: storage,
LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)), LegacyServer: op.NewLegacyServer(provider, endpoints(config.CustomEndpoints)),
query: query, query: query,
command: command,
keySet: newKeySet(context.TODO(), time.Hour, query.GetActivePublicKeyByID),
fallbackLogger: fallbackLogger, fallbackLogger: fallbackLogger,
hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant. hashAlg: crypto.NewBCrypt(10), // as we are only verifying in oidc, the cost is already part of the hash string and the config here is irrelevant.
signingKeyAlgorithm: config.SigningKeyAlgorithm, signingKeyAlgorithm: config.SigningKeyAlgorithm,
@ -179,13 +180,12 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
return opConfig, nil return opConfig, nil
} }
func newStorage(config Config, command *command.Commands, queries *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage { func newStorage(config Config, command *command.Commands, query *query.Queries, repo repository.Repository, encAlg crypto.EncryptionAlgorithm, es *eventstore.Eventstore, db *database.DB, externalSecure bool) *OPStorage {
return &OPStorage{ return &OPStorage{
repo: repo, repo: repo,
command: command, command: command,
query: queries, query: query,
eventstore: es, eventstore: es,
keySet: newKeySet(context.TODO(), time.Hour, queries.GetActivePublicKeyByID),
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
defaultLoginURLV2: config.DefaultLoginURLV2, defaultLoginURLV2: config.DefaultLoginURLV2,
defaultLogoutURLV2: config.DefaultLogoutURLV2, defaultLogoutURLV2: config.DefaultLogoutURLV2,

View File

@ -9,6 +9,7 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
@ -22,6 +23,7 @@ type Server struct {
query *query.Queries query *query.Queries
command *command.Commands command *command.Commands
keySet *keySetCache
fallbackLogger *slog.Logger fallbackLogger *slog.Logger
hashAlg crypto.HashAlgorithm hashAlg crypto.HashAlgorithm

View File

@ -248,6 +248,17 @@ func setUserInfoOrgClaims(user *query.OIDCUserInfo, out *oidc.UserInfo) {
} }
} }
func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
if roles != nil && len(roles.projects) > 0 {
if roles, ok := roles.projects[roles.requestProjectID]; ok {
userInfo.AppendClaims(ClaimProjectRoles, roles)
}
for projectID, roles := range roles.projects {
userInfo.AppendClaims(fmt.Sprintf(ClaimProjectRolesFormat, projectID), roles)
}
}
}
func (s *Server) userinfoFlows(ctx context.Context, user *query.OIDCUserInfo, userGrants *query.UserGrants, userInfo *oidc.UserInfo) error { func (s *Server) userinfoFlows(ctx context.Context, user *query.OIDCUserInfo, userGrants *query.UserGrants, userInfo *oidc.UserInfo) error {
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, user.User.ResourceOwner, false) queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, user.User.ResourceOwner, false)
if err != nil { if err != nil {

View File

@ -52,8 +52,11 @@ func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
return wm.WriteModel.Reduce() return wm.WriteModel.Reduce()
} }
func (wm *OIDCSessionAccessTokenReadModel) addQuery(b *eventstore.SearchQueryBuilder) *eventstore.SearchQueryBuilder { func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
return b.AddQuery(). return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
AllowTimeTravel().
AddQuery().
AggregateTypes(oidcsession.AggregateType). AggregateTypes(oidcsession.AggregateType).
AggregateIDs(wm.AggregateID). AggregateIDs(wm.AggregateID).
EventTypes( EventTypes(
@ -65,14 +68,6 @@ func (wm *OIDCSessionAccessTokenReadModel) addQuery(b *eventstore.SearchQueryBui
Builder() Builder()
} }
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
return wm.addQuery(
eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
AllowTimeTravel(),
)
}
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) { func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.UserID = e.UserID wm.UserID = e.UserID
wm.SessionID = e.SessionID wm.SessionID = e.SessionID

View File

@ -45,6 +45,6 @@ select json_build_object(
left join machine m on u.id = m.user_id left join machine m on u.id = m.user_id
) r ) r
), ),
'organization', (select organization from org), 'org', (select organization from org),
'metadata', (select metadata from metadata) 'metadata', (select metadata from metadata)
); );

View File

@ -387,6 +387,7 @@ func (wm *PublicKeyReadModel) Reduce() error {
func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder { func (wm *PublicKeyReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
ResourceOwner(wm.ResourceOwner). ResourceOwner(wm.ResourceOwner).
AddQuery(). AddQuery().
AggregateTypes(keypair.AggregateType). AggregateTypes(keypair.AggregateType).

View File

@ -9,12 +9,16 @@ import (
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
//go:embed embed/userinfo_by_id.sql //go:embed embed/userinfo_by_id.sql
var oidcUserInfoQuery string var oidcUserInfoQuery string
func (q *Queries) GetOIDCUserInfo(ctx context.Context, userID string) (_ *OIDCUserInfo, err error) { func (q *Queries) GetOIDCUserInfo(ctx context.Context, userID string) (_ *OIDCUserInfo, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo := new(OIDCUserInfo) userInfo := new(OIDCUserInfo)
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
var data []byte var data []byte