mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
feat(API): support V2 token and session token usage (#6180)
This PR adds support for userinfo and introspection of V2 tokens. Further V2 access tokens and session tokens can be used for authentication on the ZITADEL API (like the current access tokens).
This commit is contained in:
113
internal/query/access_token.go
Normal file
113
internal/query/access_token.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/oidcsession"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type OIDCSessionAccessTokenReadModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
SessionID string
|
||||
ClientID string
|
||||
Audience []string
|
||||
Scope []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
State domain.OIDCSessionState
|
||||
AccessTokenID string
|
||||
AccessTokenCreation time.Time
|
||||
AccessTokenExpiration time.Time
|
||||
}
|
||||
|
||||
func newOIDCSessionAccessTokenWriteModel(id string) *OIDCSessionAccessTokenReadModel {
|
||||
return &OIDCSessionAccessTokenReadModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *oidcsession.AddedEvent:
|
||||
wm.reduceAdded(e)
|
||||
case *oidcsession.AccessTokenAddedEvent:
|
||||
wm.reduceAccessTokenAdded(e)
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AllowTimeTravel().
|
||||
AddQuery().
|
||||
AggregateTypes(oidcsession.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
oidcsession.AddedType,
|
||||
oidcsession.AccessTokenAddedType,
|
||||
).
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.SessionID = e.SessionID
|
||||
wm.ClientID = e.ClientID
|
||||
wm.Audience = e.Audience
|
||||
wm.Scope = e.Scope
|
||||
wm.AuthMethods = e.AuthMethods
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.State = domain.OIDCSessionStateActive
|
||||
}
|
||||
|
||||
func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
|
||||
wm.AccessTokenID = e.ID
|
||||
wm.AccessTokenCreation = e.CreationDate()
|
||||
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
}
|
||||
|
||||
// ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore.
|
||||
// refreshed or expired tokens will return an error
|
||||
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
split := strings.Split(token, "-")
|
||||
if len(split) != 2 {
|
||||
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-SAhtk", "Errors.OIDCSession.Token.Invalid")
|
||||
}
|
||||
model, err = q.accessTokenByOIDCSessionAndTokenID(ctx, split[0], split[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !model.AccessTokenExpiration.After(time.Now()) {
|
||||
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSessionID, tokenID string) (model *OIDCSessionAccessTokenReadModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
model = newOIDCSessionAccessTokenWriteModel(oidcSessionID)
|
||||
if err = q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
|
||||
return nil, caos_errs.ThrowPermissionDenied(err, "QUERY-ASfe2", "Errors.OIDCSession.Token.Invalid")
|
||||
}
|
||||
if model.AccessTokenID != tokenID {
|
||||
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-M2u9w", "Errors.OIDCSession.Token.Invalid")
|
||||
}
|
||||
return model, nil
|
||||
}
|
@@ -40,6 +40,7 @@ type Session struct {
|
||||
|
||||
type SessionUserFactor struct {
|
||||
UserID string
|
||||
ResourceOwner string
|
||||
UserCheckedAt time.Time
|
||||
LoginName string
|
||||
DisplayName string
|
||||
@@ -225,6 +226,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
|
||||
SessionColumnUserCheckedAt.identifier(),
|
||||
LoginNameNameCol.identifier(),
|
||||
HumanDisplayNameCol.identifier(),
|
||||
UserResourceOwnerCol.identifier(),
|
||||
SessionColumnPasswordCheckedAt.identifier(),
|
||||
SessionColumnIntentCheckedAt.identifier(),
|
||||
SessionColumnPasskeyCheckedAt.identifier(),
|
||||
@@ -232,7 +234,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
|
||||
SessionColumnToken.identifier(),
|
||||
).From(sessionsTable.identifier()).
|
||||
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
|
||||
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
|
||||
LeftJoin(join(HumanUserIDCol, SessionColumnUserID)).
|
||||
LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
|
||||
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Session, string, error) {
|
||||
session := new(Session)
|
||||
|
||||
@@ -241,6 +244,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
|
||||
userCheckedAt sql.NullTime
|
||||
loginName sql.NullString
|
||||
displayName sql.NullString
|
||||
userResourceOwner sql.NullString
|
||||
passwordCheckedAt sql.NullTime
|
||||
intentCheckedAt sql.NullTime
|
||||
passkeyCheckedAt sql.NullTime
|
||||
@@ -262,6 +266,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
|
||||
&userCheckedAt,
|
||||
&loginName,
|
||||
&displayName,
|
||||
&userResourceOwner,
|
||||
&passwordCheckedAt,
|
||||
&intentCheckedAt,
|
||||
&passkeyCheckedAt,
|
||||
@@ -281,6 +286,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
|
||||
session.UserFactor.UserCheckedAt = userCheckedAt.Time
|
||||
session.UserFactor.LoginName = loginName.String
|
||||
session.UserFactor.DisplayName = displayName.String
|
||||
session.UserFactor.ResourceOwner = userResourceOwner.String
|
||||
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
|
||||
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
|
||||
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time
|
||||
@@ -304,6 +310,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
|
||||
SessionColumnUserCheckedAt.identifier(),
|
||||
LoginNameNameCol.identifier(),
|
||||
HumanDisplayNameCol.identifier(),
|
||||
UserResourceOwnerCol.identifier(),
|
||||
SessionColumnPasswordCheckedAt.identifier(),
|
||||
SessionColumnIntentCheckedAt.identifier(),
|
||||
SessionColumnPasskeyCheckedAt.identifier(),
|
||||
@@ -311,7 +318,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
|
||||
countColumn.identifier(),
|
||||
).From(sessionsTable.identifier()).
|
||||
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
|
||||
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
|
||||
LeftJoin(join(HumanUserIDCol, SessionColumnUserID)).
|
||||
LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
|
||||
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Sessions, error) {
|
||||
sessions := &Sessions{Sessions: []*Session{}}
|
||||
|
||||
@@ -323,6 +331,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
|
||||
userCheckedAt sql.NullTime
|
||||
loginName sql.NullString
|
||||
displayName sql.NullString
|
||||
userResourceOwner sql.NullString
|
||||
passwordCheckedAt sql.NullTime
|
||||
intentCheckedAt sql.NullTime
|
||||
passkeyCheckedAt sql.NullTime
|
||||
@@ -343,6 +352,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
|
||||
&userCheckedAt,
|
||||
&loginName,
|
||||
&displayName,
|
||||
&userResourceOwner,
|
||||
&passwordCheckedAt,
|
||||
&intentCheckedAt,
|
||||
&passkeyCheckedAt,
|
||||
@@ -358,6 +368,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
|
||||
session.UserFactor.UserCheckedAt = userCheckedAt.Time
|
||||
session.UserFactor.LoginName = loginName.String
|
||||
session.UserFactor.DisplayName = displayName.String
|
||||
session.UserFactor.ResourceOwner = userResourceOwner.String
|
||||
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
|
||||
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
|
||||
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time
|
||||
|
@@ -29,6 +29,7 @@ var (
|
||||
` projections.sessions3.user_checked_at,` +
|
||||
` projections.login_names2.login_name,` +
|
||||
` projections.users8_humans.display_name,` +
|
||||
` projections.users8.resource_owner,` +
|
||||
` projections.sessions3.password_checked_at,` +
|
||||
` projections.sessions3.intent_checked_at,` +
|
||||
` projections.sessions3.passkey_checked_at,` +
|
||||
@@ -37,6 +38,7 @@ var (
|
||||
` FROM projections.sessions3` +
|
||||
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
|
||||
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
|
||||
` LEFT JOIN projections.users8 ON projections.sessions3.user_id = projections.users8.id AND projections.sessions3.instance_id = projections.users8.instance_id` +
|
||||
` AS OF SYSTEM TIME '-1 ms'`)
|
||||
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
|
||||
` projections.sessions3.creation_date,` +
|
||||
@@ -50,6 +52,7 @@ var (
|
||||
` projections.sessions3.user_checked_at,` +
|
||||
` projections.login_names2.login_name,` +
|
||||
` projections.users8_humans.display_name,` +
|
||||
` projections.users8.resource_owner,` +
|
||||
` projections.sessions3.password_checked_at,` +
|
||||
` projections.sessions3.intent_checked_at,` +
|
||||
` projections.sessions3.passkey_checked_at,` +
|
||||
@@ -58,6 +61,7 @@ var (
|
||||
` FROM projections.sessions3` +
|
||||
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
|
||||
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
|
||||
` LEFT JOIN projections.users8 ON projections.sessions3.user_id = projections.users8.id AND projections.sessions3.instance_id = projections.users8.instance_id` +
|
||||
` AS OF SYSTEM TIME '-1 ms'`)
|
||||
|
||||
sessionCols = []string{
|
||||
@@ -73,6 +77,7 @@ var (
|
||||
"user_checked_at",
|
||||
"login_name",
|
||||
"display_name",
|
||||
"user_resource_owner",
|
||||
"password_checked_at",
|
||||
"intent_checked_at",
|
||||
"passkey_checked_at",
|
||||
@@ -93,6 +98,7 @@ var (
|
||||
"user_checked_at",
|
||||
"login_name",
|
||||
"display_name",
|
||||
"user_resource_owner",
|
||||
"password_checked_at",
|
||||
"intent_checked_at",
|
||||
"passkey_checked_at",
|
||||
@@ -145,6 +151,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
testNow,
|
||||
"login-name",
|
||||
"display-name",
|
||||
"resourceOwner",
|
||||
testNow,
|
||||
testNow,
|
||||
testNow,
|
||||
@@ -172,6 +179,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
UserCheckedAt: testNow,
|
||||
LoginName: "login-name",
|
||||
DisplayName: "display-name",
|
||||
ResourceOwner: "resourceOwner",
|
||||
},
|
||||
PasswordFactor: SessionPasswordFactor{
|
||||
PasswordCheckedAt: testNow,
|
||||
@@ -210,6 +218,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
testNow,
|
||||
"login-name",
|
||||
"display-name",
|
||||
"resourceOwner",
|
||||
testNow,
|
||||
testNow,
|
||||
testNow,
|
||||
@@ -228,6 +237,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
testNow,
|
||||
"login-name2",
|
||||
"display-name2",
|
||||
"resourceOwner",
|
||||
testNow,
|
||||
testNow,
|
||||
testNow,
|
||||
@@ -255,6 +265,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
UserCheckedAt: testNow,
|
||||
LoginName: "login-name",
|
||||
DisplayName: "display-name",
|
||||
ResourceOwner: "resourceOwner",
|
||||
},
|
||||
PasswordFactor: SessionPasswordFactor{
|
||||
PasswordCheckedAt: testNow,
|
||||
@@ -283,6 +294,7 @@ func Test_SessionsPrepare(t *testing.T) {
|
||||
UserCheckedAt: testNow,
|
||||
LoginName: "login-name2",
|
||||
DisplayName: "display-name2",
|
||||
ResourceOwner: "resourceOwner",
|
||||
},
|
||||
PasswordFactor: SessionPasswordFactor{
|
||||
PasswordCheckedAt: testNow,
|
||||
@@ -374,6 +386,7 @@ func Test_SessionPrepare(t *testing.T) {
|
||||
testNow,
|
||||
"login-name",
|
||||
"display-name",
|
||||
"resourceOwner",
|
||||
testNow,
|
||||
testNow,
|
||||
testNow,
|
||||
@@ -396,6 +409,7 @@ func Test_SessionPrepare(t *testing.T) {
|
||||
UserCheckedAt: testNow,
|
||||
LoginName: "login-name",
|
||||
DisplayName: "display-name",
|
||||
ResourceOwner: "resourceOwner",
|
||||
},
|
||||
PasswordFactor: SessionPasswordFactor{
|
||||
PasswordCheckedAt: testNow,
|
||||
|
@@ -79,6 +79,12 @@ var (
|
||||
name: "count",
|
||||
table: userIDPsCountTable,
|
||||
}
|
||||
|
||||
forceMFATable = loginPolicyTable.setAlias("auth_methods_force_mfa")
|
||||
forceMFAInstanceID = LoginPolicyColumnInstanceID.setTable(forceMFATable)
|
||||
forceMFAOrgID = LoginPolicyColumnOrgID.setTable(forceMFATable)
|
||||
forceMFAIsDefault = LoginPolicyColumnIsDefault.setTable(forceMFATable)
|
||||
forceMFAForce = LoginPolicyColumnForceMFA.setTable(forceMFATable)
|
||||
)
|
||||
|
||||
type AuthMethods struct {
|
||||
@@ -170,6 +176,36 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
|
||||
return userAuthMethodTypes, err
|
||||
}
|
||||
|
||||
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string, withOwnerRemoved bool) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA bool, err error) {
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
if ctxData.UserID != userID {
|
||||
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client)
|
||||
eq := sq.Eq{
|
||||
UserIDCol.identifier(): userID,
|
||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
}
|
||||
if !withOwnerRemoved {
|
||||
eq[UserOwnerRemovedCol.identifier()] = false
|
||||
}
|
||||
stmt, args, err := query.Where(eq).ToSql()
|
||||
if err != nil {
|
||||
return nil, false, errors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
|
||||
}
|
||||
|
||||
rows, err := q.client.QueryContext(ctx, stmt, args...)
|
||||
if err != nil || rows.Err() != nil {
|
||||
return nil, false, errors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
||||
}
|
||||
return scan(rows)
|
||||
}
|
||||
|
||||
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
|
||||
return NewTextQuery(UserAuthMethodColumnUserID, value, TextEquals)
|
||||
}
|
||||
@@ -311,26 +347,11 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se
|
||||
}
|
||||
|
||||
func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
|
||||
authMethodsQuery, authMethodsArgs, err := sq.Select(
|
||||
"DISTINCT("+authMethodTypeTypes.identifier()+")",
|
||||
authMethodTypeUserID.identifier(),
|
||||
authMethodTypeInstanceID.identifier()).
|
||||
From(authMethodTypeTable.identifier()).
|
||||
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
|
||||
ToSql()
|
||||
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
idpsQuery, _, err := sq.Select(
|
||||
userIDPsCountUserID.identifier(),
|
||||
userIDPsCountInstanceID.identifier(),
|
||||
"COUNT("+userIDPsCountUserID.identifier()+") AS "+userIDPsCountCount.name).
|
||||
From(userIDPsCountTable.identifier()).
|
||||
GroupBy(
|
||||
userIDPsCountUserID.identifier(),
|
||||
userIDPsCountInstanceID.identifier(),
|
||||
).
|
||||
ToSql()
|
||||
idpsQuery, err := prepareAuthMethodsIDPsQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
@@ -386,3 +407,106 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.UserAuthMethodType, bool, error)) {
|
||||
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
idpsQuery, err := prepareAuthMethodsIDPsQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
return sq.Select(
|
||||
NotifyPasswordSetCol.identifier(),
|
||||
authMethodTypeTypes.identifier(),
|
||||
userIDPsCountCount.identifier(),
|
||||
forceMFAForce.identifier()).
|
||||
From(userTable.identifier()).
|
||||
LeftJoin(join(NotifyUserIDCol, UserIDCol)).
|
||||
LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+
|
||||
authMethodTypeUserID.identifier()+" = "+UserIDCol.identifier()+" AND "+
|
||||
authMethodTypeInstanceID.identifier()+" = "+UserInstanceIDCol.identifier(),
|
||||
authMethodsArgs...).
|
||||
LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " +
|
||||
userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " +
|
||||
userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
|
||||
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
|
||||
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
|
||||
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
|
||||
PlaceholderFormat(sq.Dollar),
|
||||
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, error) {
|
||||
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
|
||||
var passwordSet sql.NullBool
|
||||
var idp sql.NullInt64
|
||||
var forceMFA sql.NullBool
|
||||
for rows.Next() {
|
||||
var authMethodType sql.NullInt16
|
||||
err := rows.Scan(
|
||||
&passwordSet,
|
||||
&authMethodType,
|
||||
&idp,
|
||||
&forceMFA,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if authMethodType.Valid {
|
||||
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
|
||||
}
|
||||
}
|
||||
if passwordSet.Valid && passwordSet.Bool {
|
||||
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword)
|
||||
}
|
||||
if idp.Valid && idp.Int64 > 0 {
|
||||
logging.Error("IDP", idp.Int64)
|
||||
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP)
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, false, errors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
|
||||
}
|
||||
|
||||
return userAuthMethodTypes, forceMFA.Bool, nil
|
||||
}
|
||||
}
|
||||
|
||||
func prepareAuthMethodsIDPsQuery() (string, error) {
|
||||
idpsQuery, _, err := sq.Select(
|
||||
userIDPsCountUserID.identifier(),
|
||||
userIDPsCountInstanceID.identifier(),
|
||||
"COUNT("+userIDPsCountUserID.identifier()+") AS "+userIDPsCountCount.name).
|
||||
From(userIDPsCountTable.identifier()).
|
||||
GroupBy(
|
||||
userIDPsCountUserID.identifier(),
|
||||
userIDPsCountInstanceID.identifier(),
|
||||
).
|
||||
ToSql()
|
||||
return idpsQuery, err
|
||||
}
|
||||
|
||||
func prepareAuthMethodQuery() (string, []interface{}, error) {
|
||||
return sq.Select(
|
||||
"DISTINCT("+authMethodTypeTypes.identifier()+")",
|
||||
authMethodTypeUserID.identifier(),
|
||||
authMethodTypeInstanceID.identifier()).
|
||||
From(authMethodTypeTable.identifier()).
|
||||
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
|
||||
ToSql()
|
||||
}
|
||||
|
||||
func prepareAuthMethodsForceMFAQuery() (string, error) {
|
||||
loginPolicyQuery, _, err := sq.Select(
|
||||
forceMFAForce.identifier(),
|
||||
forceMFAInstanceID.identifier(),
|
||||
forceMFAOrgID.identifier(),
|
||||
).
|
||||
From(forceMFATable.identifier()).
|
||||
OrderBy(forceMFAIsDefault.identifier()).
|
||||
ToSql()
|
||||
return loginPolicyQuery, err
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
@@ -8,6 +9,8 @@ import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
@@ -53,6 +56,28 @@ var (
|
||||
"method_type",
|
||||
"idps_count",
|
||||
}
|
||||
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users8_notifications.password_set,` +
|
||||
` auth_method_types.method_type,` +
|
||||
` user_idps_count.count,` +
|
||||
` auth_methods_force_mfa.force_mfa` +
|
||||
` FROM projections.users8` +
|
||||
` LEFT JOIN projections.users8_notifications ON projections.users8.id = projections.users8_notifications.user_id AND projections.users8.instance_id = projections.users8_notifications.instance_id` +
|
||||
` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` +
|
||||
` WHERE auth_method_types.state = $1) AS auth_method_types` +
|
||||
` ON auth_method_types.user_id = projections.users8.id AND auth_method_types.instance_id = projections.users8.instance_id` +
|
||||
` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` +
|
||||
` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` +
|
||||
` ON user_idps_count.user_id = projections.users8.id AND user_idps_count.instance_id = projections.users8.instance_id` +
|
||||
` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id FROM projections.login_policies4 AS auth_methods_force_mfa ORDER BY auth_methods_force_mfa.is_default) AS auth_methods_force_mfa` +
|
||||
` ON (auth_methods_force_mfa.aggregate_id = projections.users8.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users8.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users8.instance_id` +
|
||||
` AS OF SYSTEM TIME '-1 ms
|
||||
`
|
||||
prepareAuthMethodTypesRequiredCols = []string{
|
||||
"password_set",
|
||||
"method_type",
|
||||
"idps_count",
|
||||
"force_mfa",
|
||||
}
|
||||
)
|
||||
|
||||
func Test_UserAuthMethodPrepares(t *testing.T) {
|
||||
@@ -288,6 +313,131 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
||||
},
|
||||
object: nil,
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery no result",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
||||
authMethods, forceMFA, err := scan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
object: &testUserAuthMethodTypesRequired{authMethods: []domain.UserAuthMethodType{}, forceMFA: false},
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
||||
authMethods, forceMFA, err := scan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
prepareAuthMethodTypesRequiredCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
true,
|
||||
domain.UserAuthMethodTypePasswordless,
|
||||
1,
|
||||
true,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &testUserAuthMethodTypesRequired{
|
||||
authMethods: []domain.UserAuthMethodType{
|
||||
domain.UserAuthMethodTypePasswordless,
|
||||
domain.UserAuthMethodTypePassword,
|
||||
domain.UserAuthMethodTypeIDP,
|
||||
},
|
||||
forceMFA: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
||||
authMethods, forceMFA, err := scan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
prepareAuthMethodTypesRequiredCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
true,
|
||||
domain.UserAuthMethodTypePasswordless,
|
||||
1,
|
||||
true,
|
||||
},
|
||||
{
|
||||
true,
|
||||
domain.UserAuthMethodTypeOTP,
|
||||
1,
|
||||
true,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
|
||||
object: &testUserAuthMethodTypesRequired{
|
||||
authMethods: []domain.UserAuthMethodType{
|
||||
domain.UserAuthMethodTypePasswordless,
|
||||
domain.UserAuthMethodTypeOTP,
|
||||
domain.UserAuthMethodTypePassword,
|
||||
domain.UserAuthMethodTypeIDP,
|
||||
},
|
||||
forceMFA: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
||||
authMethods, forceMFA, err := scan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !errors.Is(err, sql.ErrConnDone) {
|
||||
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
object: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -295,3 +445,9 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testUserAuthMethodTypesRequired is required as assetPrepare is only able to return a single object from scan
|
||||
type testUserAuthMethodTypesRequired struct {
|
||||
authMethods []domain.UserAuthMethodType
|
||||
forceMFA bool
|
||||
}
|
||||
|
Reference in New Issue
Block a user