feat(API): support V2 token and session token usage (#6180)

This PR adds support for userinfo and introspection of V2 tokens. Further V2 access tokens and session tokens can be used for authentication on the ZITADEL API (like the current access tokens).
This commit is contained in:
Livio Spring
2023-07-14 13:16:16 +02:00
committed by GitHub
parent 4589ddad4a
commit 80961125a7
38 changed files with 1309 additions and 181 deletions

View File

@@ -0,0 +1,113 @@
package query
import (
"context"
"strings"
"time"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/oidcsession"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type OIDCSessionAccessTokenReadModel struct {
eventstore.WriteModel
UserID string
SessionID string
ClientID string
Audience []string
Scope []string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
State domain.OIDCSessionState
AccessTokenID string
AccessTokenCreation time.Time
AccessTokenExpiration time.Time
}
func newOIDCSessionAccessTokenWriteModel(id string) *OIDCSessionAccessTokenReadModel {
return &OIDCSessionAccessTokenReadModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
},
}
}
func (wm *OIDCSessionAccessTokenReadModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *oidcsession.AddedEvent:
wm.reduceAdded(e)
case *oidcsession.AccessTokenAddedEvent:
wm.reduceAccessTokenAdded(e)
}
}
return wm.WriteModel.Reduce()
}
func (wm *OIDCSessionAccessTokenReadModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AllowTimeTravel().
AddQuery().
AggregateTypes(oidcsession.AggregateType).
AggregateIDs(wm.AggregateID).
EventTypes(
oidcsession.AddedType,
oidcsession.AccessTokenAddedType,
).
Builder()
}
func (wm *OIDCSessionAccessTokenReadModel) reduceAdded(e *oidcsession.AddedEvent) {
wm.UserID = e.UserID
wm.SessionID = e.SessionID
wm.ClientID = e.ClientID
wm.Audience = e.Audience
wm.Scope = e.Scope
wm.AuthMethods = e.AuthMethods
wm.AuthTime = e.AuthTime
wm.State = domain.OIDCSessionStateActive
}
func (wm *OIDCSessionAccessTokenReadModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) {
wm.AccessTokenID = e.ID
wm.AccessTokenCreation = e.CreationDate()
wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime)
}
// ActiveAccessTokenByToken will check if the token is active by retrieving the OIDCSession events from the eventstore.
// refreshed or expired tokens will return an error
func (q *Queries) ActiveAccessTokenByToken(ctx context.Context, token string) (model *OIDCSessionAccessTokenReadModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
split := strings.Split(token, "-")
if len(split) != 2 {
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-SAhtk", "Errors.OIDCSession.Token.Invalid")
}
model, err = q.accessTokenByOIDCSessionAndTokenID(ctx, split[0], split[1])
if err != nil {
return nil, err
}
if !model.AccessTokenExpiration.After(time.Now()) {
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-SAF3rf", "Errors.OIDCSession.Token.Expired")
}
return
}
func (q *Queries) accessTokenByOIDCSessionAndTokenID(ctx context.Context, oidcSessionID, tokenID string) (model *OIDCSessionAccessTokenReadModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
model = newOIDCSessionAccessTokenWriteModel(oidcSessionID)
if err = q.eventstore.FilterToQueryReducer(ctx, model); err != nil {
return nil, caos_errs.ThrowPermissionDenied(err, "QUERY-ASfe2", "Errors.OIDCSession.Token.Invalid")
}
if model.AccessTokenID != tokenID {
return nil, caos_errs.ThrowPermissionDenied(nil, "QUERY-M2u9w", "Errors.OIDCSession.Token.Invalid")
}
return model, nil
}

View File

@@ -40,6 +40,7 @@ type Session struct {
type SessionUserFactor struct {
UserID string
ResourceOwner string
UserCheckedAt time.Time
LoginName string
DisplayName string
@@ -225,6 +226,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
UserResourceOwnerCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnIntentCheckedAt.identifier(),
SessionColumnPasskeyCheckedAt.identifier(),
@@ -232,7 +234,8 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
SessionColumnToken.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID)).
LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Session, string, error) {
session := new(Session)
@@ -241,6 +244,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
userResourceOwner sql.NullString
passwordCheckedAt sql.NullTime
intentCheckedAt sql.NullTime
passkeyCheckedAt sql.NullTime
@@ -262,6 +266,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
&userCheckedAt,
&loginName,
&displayName,
&userResourceOwner,
&passwordCheckedAt,
&intentCheckedAt,
&passkeyCheckedAt,
@@ -281,6 +286,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.UserFactor.ResourceOwner = userResourceOwner.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time
@@ -304,6 +310,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
SessionColumnUserCheckedAt.identifier(),
LoginNameNameCol.identifier(),
HumanDisplayNameCol.identifier(),
UserResourceOwnerCol.identifier(),
SessionColumnPasswordCheckedAt.identifier(),
SessionColumnIntentCheckedAt.identifier(),
SessionColumnPasskeyCheckedAt.identifier(),
@@ -311,7 +318,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
countColumn.identifier(),
).From(sessionsTable.identifier()).
LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
LeftJoin(join(HumanUserIDCol, SessionColumnUserID)).
LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Sessions, error) {
sessions := &Sessions{Sessions: []*Session{}}
@@ -323,6 +331,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
userCheckedAt sql.NullTime
loginName sql.NullString
displayName sql.NullString
userResourceOwner sql.NullString
passwordCheckedAt sql.NullTime
intentCheckedAt sql.NullTime
passkeyCheckedAt sql.NullTime
@@ -343,6 +352,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
&userCheckedAt,
&loginName,
&displayName,
&userResourceOwner,
&passwordCheckedAt,
&intentCheckedAt,
&passkeyCheckedAt,
@@ -358,6 +368,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui
session.UserFactor.UserCheckedAt = userCheckedAt.Time
session.UserFactor.LoginName = loginName.String
session.UserFactor.DisplayName = displayName.String
session.UserFactor.ResourceOwner = userResourceOwner.String
session.PasswordFactor.PasswordCheckedAt = passwordCheckedAt.Time
session.IntentFactor.IntentCheckedAt = intentCheckedAt.Time
session.PasskeyFactor.PasskeyCheckedAt = passkeyCheckedAt.Time

View File

@@ -29,6 +29,7 @@ var (
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.users8.resource_owner,` +
` projections.sessions3.password_checked_at,` +
` projections.sessions3.intent_checked_at,` +
` projections.sessions3.passkey_checked_at,` +
@@ -37,6 +38,7 @@ var (
` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` LEFT JOIN projections.users8 ON projections.sessions3.user_id = projections.users8.id AND projections.sessions3.instance_id = projections.users8.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions3.id,` +
` projections.sessions3.creation_date,` +
@@ -50,6 +52,7 @@ var (
` projections.sessions3.user_checked_at,` +
` projections.login_names2.login_name,` +
` projections.users8_humans.display_name,` +
` projections.users8.resource_owner,` +
` projections.sessions3.password_checked_at,` +
` projections.sessions3.intent_checked_at,` +
` projections.sessions3.passkey_checked_at,` +
@@ -58,6 +61,7 @@ var (
` FROM projections.sessions3` +
` LEFT JOIN projections.login_names2 ON projections.sessions3.user_id = projections.login_names2.user_id AND projections.sessions3.instance_id = projections.login_names2.instance_id` +
` LEFT JOIN projections.users8_humans ON projections.sessions3.user_id = projections.users8_humans.user_id AND projections.sessions3.instance_id = projections.users8_humans.instance_id` +
` LEFT JOIN projections.users8 ON projections.sessions3.user_id = projections.users8.id AND projections.sessions3.instance_id = projections.users8.instance_id` +
` AS OF SYSTEM TIME '-1 ms'`)
sessionCols = []string{
@@ -73,6 +77,7 @@ var (
"user_checked_at",
"login_name",
"display_name",
"user_resource_owner",
"password_checked_at",
"intent_checked_at",
"passkey_checked_at",
@@ -93,6 +98,7 @@ var (
"user_checked_at",
"login_name",
"display_name",
"user_resource_owner",
"password_checked_at",
"intent_checked_at",
"passkey_checked_at",
@@ -145,6 +151,7 @@ func Test_SessionsPrepare(t *testing.T) {
testNow,
"login-name",
"display-name",
"resourceOwner",
testNow,
testNow,
testNow,
@@ -172,6 +179,7 @@ func Test_SessionsPrepare(t *testing.T) {
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
ResourceOwner: "resourceOwner",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
@@ -210,6 +218,7 @@ func Test_SessionsPrepare(t *testing.T) {
testNow,
"login-name",
"display-name",
"resourceOwner",
testNow,
testNow,
testNow,
@@ -228,6 +237,7 @@ func Test_SessionsPrepare(t *testing.T) {
testNow,
"login-name2",
"display-name2",
"resourceOwner",
testNow,
testNow,
testNow,
@@ -255,6 +265,7 @@ func Test_SessionsPrepare(t *testing.T) {
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
ResourceOwner: "resourceOwner",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
@@ -283,6 +294,7 @@ func Test_SessionsPrepare(t *testing.T) {
UserCheckedAt: testNow,
LoginName: "login-name2",
DisplayName: "display-name2",
ResourceOwner: "resourceOwner",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,
@@ -374,6 +386,7 @@ func Test_SessionPrepare(t *testing.T) {
testNow,
"login-name",
"display-name",
"resourceOwner",
testNow,
testNow,
testNow,
@@ -396,6 +409,7 @@ func Test_SessionPrepare(t *testing.T) {
UserCheckedAt: testNow,
LoginName: "login-name",
DisplayName: "display-name",
ResourceOwner: "resourceOwner",
},
PasswordFactor: SessionPasswordFactor{
PasswordCheckedAt: testNow,

View File

@@ -79,6 +79,12 @@ var (
name: "count",
table: userIDPsCountTable,
}
forceMFATable = loginPolicyTable.setAlias("auth_methods_force_mfa")
forceMFAInstanceID = LoginPolicyColumnInstanceID.setTable(forceMFATable)
forceMFAOrgID = LoginPolicyColumnOrgID.setTable(forceMFATable)
forceMFAIsDefault = LoginPolicyColumnIsDefault.setTable(forceMFATable)
forceMFAForce = LoginPolicyColumnForceMFA.setTable(forceMFATable)
)
type AuthMethods struct {
@@ -170,6 +176,36 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
return userAuthMethodTypes, err
}
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string, withOwnerRemoved bool) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA bool, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
return nil, false, err
}
}
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client)
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
if !withOwnerRemoved {
eq[UserOwnerRemovedCol.identifier()] = false
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, false, errors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil {
return nil, false, errors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
}
return scan(rows)
}
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(UserAuthMethodColumnUserID, value, TextEquals)
}
@@ -311,26 +347,11 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se
}
func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := sq.Select(
"DISTINCT("+authMethodTypeTypes.identifier()+")",
authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()).
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
ToSql()
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
idpsQuery, _, err := sq.Select(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
"COUNT("+userIDPsCountUserID.identifier()+") AS "+userIDPsCountCount.name).
From(userIDPsCountTable.identifier()).
GroupBy(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
).
ToSql()
idpsQuery, err := prepareAuthMethodsIDPsQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
@@ -386,3 +407,106 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
}, nil
}
}
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.UserAuthMethodType, bool, error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
idpsQuery, err := prepareAuthMethodsIDPsQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
return sq.Select(
NotifyPasswordSetCol.identifier(),
authMethodTypeTypes.identifier(),
userIDPsCountCount.identifier(),
forceMFAForce.identifier()).
From(userTable.identifier()).
LeftJoin(join(NotifyUserIDCol, UserIDCol)).
LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+
authMethodTypeUserID.identifier()+" = "+UserIDCol.identifier()+" AND "+
authMethodTypeInstanceID.identifier()+" = "+UserInstanceIDCol.identifier(),
authMethodsArgs...).
LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " +
userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " +
userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, error) {
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
var passwordSet sql.NullBool
var idp sql.NullInt64
var forceMFA sql.NullBool
for rows.Next() {
var authMethodType sql.NullInt16
err := rows.Scan(
&passwordSet,
&authMethodType,
&idp,
&forceMFA,
)
if err != nil {
return nil, false, err
}
if authMethodType.Valid {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
}
}
if passwordSet.Valid && passwordSet.Bool {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword)
}
if idp.Valid && idp.Int64 > 0 {
logging.Error("IDP", idp.Int64)
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP)
}
if err := rows.Close(); err != nil {
return nil, false, errors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
}
return userAuthMethodTypes, forceMFA.Bool, nil
}
}
func prepareAuthMethodsIDPsQuery() (string, error) {
idpsQuery, _, err := sq.Select(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
"COUNT("+userIDPsCountUserID.identifier()+") AS "+userIDPsCountCount.name).
From(userIDPsCountTable.identifier()).
GroupBy(
userIDPsCountUserID.identifier(),
userIDPsCountInstanceID.identifier(),
).
ToSql()
return idpsQuery, err
}
func prepareAuthMethodQuery() (string, []interface{}, error) {
return sq.Select(
"DISTINCT("+authMethodTypeTypes.identifier()+")",
authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()).
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
ToSql()
}
func prepareAuthMethodsForceMFAQuery() (string, error) {
loginPolicyQuery, _, err := sq.Select(
forceMFAForce.identifier(),
forceMFAInstanceID.identifier(),
forceMFAOrgID.identifier(),
).
From(forceMFATable.identifier()).
OrderBy(forceMFAIsDefault.identifier()).
ToSql()
return loginPolicyQuery, err
}

View File

@@ -1,6 +1,7 @@
package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
@@ -8,6 +9,8 @@ import (
"regexp"
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/domain"
)
@@ -53,6 +56,28 @@ var (
"method_type",
"idps_count",
}
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users8_notifications.password_set,` +
` auth_method_types.method_type,` +
` user_idps_count.count,` +
` auth_methods_force_mfa.force_mfa` +
` FROM projections.users8` +
` LEFT JOIN projections.users8_notifications ON projections.users8.id = projections.users8_notifications.user_id AND projections.users8.instance_id = projections.users8_notifications.instance_id` +
` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` +
` WHERE auth_method_types.state = $1) AS auth_method_types` +
` ON auth_method_types.user_id = projections.users8.id AND auth_method_types.instance_id = projections.users8.instance_id` +
` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` +
` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` +
` ON user_idps_count.user_id = projections.users8.id AND user_idps_count.instance_id = projections.users8.instance_id` +
` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id FROM projections.login_policies4 AS auth_methods_force_mfa ORDER BY auth_methods_force_mfa.is_default) AS auth_methods_force_mfa` +
` ON (auth_methods_force_mfa.aggregate_id = projections.users8.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users8.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users8.instance_id` +
` AS OF SYSTEM TIME '-1 ms
`
prepareAuthMethodTypesRequiredCols = []string{
"password_set",
"method_type",
"idps_count",
"force_mfa",
}
)
func Test_UserAuthMethodPrepares(t *testing.T) {
@@ -288,6 +313,131 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
},
object: nil,
},
{
name: "prepareUserAuthMethodTypesRequiredQuery no result",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
authMethods, forceMFA, err := scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
nil,
nil,
),
},
object: &testUserAuthMethodTypesRequired{authMethods: []domain.UserAuthMethodType{}, forceMFA: false},
},
{
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
authMethods, forceMFA, err := scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
prepareAuthMethodTypesRequiredCols,
[][]driver.Value{
{
true,
domain.UserAuthMethodTypePasswordless,
1,
true,
},
},
),
},
object: &testUserAuthMethodTypesRequired{
authMethods: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP,
},
forceMFA: true,
},
},
{
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
authMethods, forceMFA, err := scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
prepareAuthMethodTypesRequiredCols,
[][]driver.Value{
{
true,
domain.UserAuthMethodTypePasswordless,
1,
true,
},
{
true,
domain.UserAuthMethodTypeOTP,
1,
true,
},
},
),
},
object: &testUserAuthMethodTypesRequired{
authMethods: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypeOTP,
domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP,
},
forceMFA: true,
},
},
{
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
authMethods, forceMFA, err := scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA}, nil
}
},
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -295,3 +445,9 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
})
}
}
// testUserAuthMethodTypesRequired is required as assetPrepare is only able to return a single object from scan
type testUserAuthMethodTypesRequired struct {
authMethods []domain.UserAuthMethodType
forceMFA bool
}