fix(oidc): IDP and machine user auth methods (#7992)

# Which Problems Are Solved

After https://github.com/zitadel/zitadel/pull/7822 was merged we
discovered that
v2 tokens that where obtained through an IDP using the v1 login, can't
be used for
zitadel API calls.

- Because we used to store the AMR claim on the auth request, but
internally use the domain.UserAuthMethod type. AMR has no notion of an
IDP login, so that "factor" was lost
during conversion. Rendering those v2 tokens invalid on the zitadel API.
- A wrong check on machine user tokens falsly allowed some tokens to be
valid
- The client ID was set to tokens from client credentials and JWT
profile, which made client queries fail in the validation middleware.
The middleware expects client ID unset for machine users.

# How the Problems Are Solved

Store the domain.AuthMethods directly in  the auth requests and session,
instead of using AMR claims with lossy conversion.

- IDPs have seperate auth method, which is not an AMR claim
- Machine users are treated specialy, eg auth methods are not required.
- Do not set the client ID for client credentials and JWT profile

# Additional Changes

Cleaned up mostly unused `oidc.getInfoFromRequest()`.

# Additional Context

- Bugs were introduced in https://github.com/zitadel/zitadel/pull/7822
and not yet part of a release.
- Reported internally.
This commit is contained in:
Tim Möhlmann 2024-05-23 07:35:10 +02:00 committed by GitHub
parent e57a9b57c8
commit f5e9d4f57f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 126 additions and 112 deletions

View File

@ -12,7 +12,6 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http" http_utils "github.com/zitadel/zitadel/internal/api/http"
@ -199,23 +198,6 @@ func (*OPStorage) panicErr(method string) error {
return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method) return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method)
} }
func getInfoFromRequest(req op.TokenRequest) (agentID, clientID, userOrgID string, authTime time.Time, amr []string, preferredLanguage *language.Tag, reason domain.TokenReason, actor *domain.TokenActor) {
switch r := req.(type) {
case *AuthRequest:
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), r.PreferredLanguage, domain.TokenReasonAuthRequest, nil
case *RefreshTokenRequest:
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, nil, domain.TokenReasonRefresh, r.Actor
case op.IDTokenRequest:
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), nil, domain.TokenReasonAuthRequest, nil
case *oidc.JWTTokenRequest:
return "", "", "", r.GetAuthTime(), nil, nil, domain.TokenReasonJWTProfile, nil
case *clientCredentialsRequest:
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonClientCredentials, nil
default:
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonAuthRequest, nil
}
}
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) { func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues") panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues")
} }
@ -511,8 +493,8 @@ func implicitFlowComplianceChecker() command.AuthRequestComplianceChecker {
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) { func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
authorizer := s.Provider() authorizer := s.Provider()
authReq, err := func() (authReq op.AuthRequest, err error) { authReq, err := func(ctx context.Context) (authReq *AuthRequest, err error) {
ctx, span := tracing.NewSpan(r.Context()) ctx, span := tracing.NewSpan(ctx)
r = r.WithContext(ctx) r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -520,7 +502,7 @@ func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request
if err != nil { if err != nil {
return nil, err return nil, err
} }
authReq, err = authorizer.Storage().AuthRequestByID(r.Context(), id) authReq, err = s.getAuthRequestV1ByID(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -528,13 +510,13 @@ func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request
return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.") return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
} }
return authReq, s.authResponse(authReq, authorizer, w, r) return authReq, s.authResponse(authReq, authorizer, w, r)
}() }(r.Context())
if err != nil { if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer) op.AuthRequestError(w, r, authReq, err, authorizer)
} }
} }
func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) { func (s *Server) authResponse(authReq *AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
ctx, span := tracing.NewSpan(r.Context()) ctx, span := tracing.NewSpan(r.Context())
r = r.WithContext(ctx) r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -551,7 +533,7 @@ func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer,
return s.authResponseToken(authReq, authorizer, client, w, r) return s.authResponseToken(authReq, authorizer, client, w, r)
} }
func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) { func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
ctx, span := tracing.NewSpan(r.Context()) ctx, span := tracing.NewSpan(r.Context())
r = r.WithContext(ctx) r = r.WithContext(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -561,23 +543,20 @@ func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authori
return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal") return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
} }
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
scope := authReq.GetScopes() scope := authReq.GetScopes()
session, err := s.command.CreateOIDCSession(ctx, session, err := s.command.CreateOIDCSession(ctx,
authReq.GetSubject(), authReq.UserID,
userOrgID, authReq.UserOrgID,
client.client.ClientID, client.client.ClientID,
scope, scope,
authReq.GetAudience(), authReq.Audience,
AMRToAuthMethodTypes(authMethodsReferences), authReq.AuthMethods(),
authTime, authReq.AuthTime,
authReq.GetNonce(), authReq.GetNonce(),
preferredLanguage, authReq.PreferredLanguage,
&domain.UserAgent{ authReq.BrowserInfo.ToUserAgent(),
FingerprintID: &userAgentID, domain.TokenReasonAuthRequest,
}, nil,
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
) )
if err != nil { if err != nil {

View File

@ -94,7 +94,7 @@ func (a *AuthRequest) oidc() *domain.AuthRequestOIDC {
return a.Request.(*domain.AuthRequestOIDC) return a.Request.(*domain.AuthRequestOIDC)
} }
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ op.AuthRequest, err error) { func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ *AuthRequest, err error) {
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); !ok { if _, ok := authReq.Request.(*domain.AuthRequestOIDC); !ok {
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Haz7A", "auth request is not of type oidc") return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Haz7A", "auth request is not of type oidc")
} }

View File

@ -34,7 +34,7 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
session, err := s.command.CreateOIDCSession(ctx, session, err := s.command.CreateOIDCSession(ctx,
client.user.ID, client.user.ID,
client.user.ResourceOwner, client.user.ResourceOwner,
r.Data.ClientID, "",
scope, scope,
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope), domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},

View File

@ -8,6 +8,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/oidc" "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
@ -69,36 +70,33 @@ func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.A
if req.RedirectURI != authReq.GetRedirectURI() { if req.RedirectURI != authReq.GetRedirectURI() {
return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond") return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
} }
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
scope := authReq.GetScopes() scope := authReq.GetScopes()
session, err = s.command.CreateOIDCSession(ctx, session, err = s.command.CreateOIDCSession(ctx,
authReq.GetSubject(), authReq.UserID,
userOrgID, authReq.UserOrgID,
client.client.ClientID, client.client.ClientID,
scope, scope,
authReq.GetAudience(), authReq.Audience,
AMRToAuthMethodTypes(authMethodsReferences), authReq.AuthMethods(),
authTime, authReq.AuthTime,
authReq.GetNonce(), authReq.GetNonce(),
preferredLanguage, authReq.PreferredLanguage,
&domain.UserAgent{ authReq.BrowserInfo.ToUserAgent(),
FingerprintID: &userAgentID, domain.TokenReasonAuthRequest,
}, nil,
reason,
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
) )
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
return session, authReq.GetState(), s.repo.DeleteAuthRequest(ctx, authReq.GetID()) return session, authReq.TransferState, s.repo.DeleteAuthRequest(ctx, authReq.ID)
} }
// getAuthRequestV1ByCode finds the v1 auth request by code. // getAuthRequestV1ByCode finds the v1 auth request by code.
// code needs to be the encrypted version of the ID, // code needs to be the encrypted version of the ID,
// this is required by the underlying repo. // this is required by the underlying repo.
func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.AuthRequest, error) { func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (*AuthRequest, error) {
authReq, err := s.repo.AuthRequestByCode(ctx, code) authReq, err := s.repo.AuthRequestByCode(ctx, code)
if err != nil { if err != nil {
return nil, err return nil, err
@ -106,6 +104,18 @@ func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.Au
return AuthRequestFromBusiness(authReq) return AuthRequestFromBusiness(authReq)
} }
func (s *Server) getAuthRequestV1ByID(ctx context.Context, id string) (*AuthRequest, error) {
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-TiTu7", "no user agent id")
}
resp, err := s.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(resp)
}
func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker { func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker {
return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error { return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error {
if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone { if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone {

View File

@ -42,15 +42,15 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
session, err := s.command.CreateOIDCSession(ctx, session, err := s.command.CreateOIDCSession(ctx,
user.ID, user.ID,
user.ResourceOwner, user.ResourceOwner,
jwtReq.Subject, "",
scope, scope,
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope), domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
nil, []domain.UserAuthMethodType{domain.UserAuthMethodTypePrivateKey},
time.Now(), time.Now(),
"", "",
nil, nil,
nil, nil,
domain.TokenReasonClientCredentials, domain.TokenReasonJWTProfile,
nil, nil,
false, false,
) )

View File

@ -269,6 +269,7 @@ func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReq
return err return err
} }
request.IDPLoginChecked = true
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info)) err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
if err != nil { if err != nil {
return err return err
@ -516,6 +517,7 @@ func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, u
return err return err
} }
request.LinkingUsers = nil request.LinkingUsers = nil
request.IDPLoginChecked = true
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
@ -560,6 +562,7 @@ func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, regis
request.SetUserInfo(human.ID, human.Username, human.Username, human.DisplayName, "", resourceOwner) request.SetUserInfo(human.ID, human.Username, human.Username, human.DisplayName, "", resourceOwner)
request.SelectedIDPConfigID = externalIDP.IDPConfigID request.SelectedIDPConfigID = externalIDP.IDPConfigID
request.LinkingUsers = nil request.LinkingUsers = nil
request.IDPLoginChecked = true
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info)) err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
if err != nil { if err != nil {
return err return err

View File

@ -180,11 +180,18 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth
if domain.HasMFA(authMethods) { if domain.HasMFA(authMethods) {
return nil return nil
} }
availableAuthMethods, forceMFA, forceMFALocalOnly, err := repo.Query.ListUserAuthMethodTypesRequired(setCallerCtx(ctx, userID), userID) requirements, err := repo.Query.ListUserAuthMethodTypesRequired(setCallerCtx(ctx, userID), userID)
if err != nil { if err != nil {
return err return err
} }
if domain.RequiresMFA(forceMFA, forceMFALocalOnly, hasIDPAuthentication(authMethods)) || domain.HasMFA(availableAuthMethods) { if requirements.UserType == domain.UserTypeMachine {
return nil
}
if domain.RequiresMFA(
requirements.ForceMFA,
requirements.ForceMFALocalOnly,
!hasIDPAuthentication(authMethods)) ||
domain.HasMFA(requirements.AuthMethods) {
return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required")
} }
return nil return nil

View File

@ -45,6 +45,7 @@ type AuthRequest struct {
LinkingUsers []*ExternalUser LinkingUsers []*ExternalUser
PossibleSteps []NextStep `json:"-"` PossibleSteps []NextStep `json:"-"`
PasswordVerified bool PasswordVerified bool
IDPLoginChecked bool
MFAsVerified []MFAType MFAsVerified []MFAType
Audience []string Audience []string
AuthTime time.Time AuthTime time.Time
@ -69,6 +70,20 @@ func (a *AuthRequest) PolicyOrgID() string {
return a.policyOrgID return a.policyOrgID
} }
func (a *AuthRequest) AuthMethods() []UserAuthMethodType {
list := make([]UserAuthMethodType, 0, len(a.MFAsVerified)+2)
if a.PasswordVerified {
list = append(list, UserAuthMethodTypePassword)
}
if a.IDPLoginChecked {
list = append(list, UserAuthMethodTypeIDP)
}
for _, mfa := range a.MFAsVerified {
list = append(list, mfa.UserAuthMethodType())
}
return list
}
type ExternalUser struct { type ExternalUser struct {
IDPConfigID string IDPConfigID string
ExternalUserID string ExternalUserID string

View File

@ -43,6 +43,7 @@ const (
UserAuthMethodTypeOTPSMS UserAuthMethodTypeOTPSMS
UserAuthMethodTypeOTPEmail UserAuthMethodTypeOTPEmail
UserAuthMethodTypeOTP // generic OTP when parsing AMR from OIDC UserAuthMethodTypeOTP // generic OTP when parsing AMR from OIDC
UserAuthMethodTypePrivateKey
userAuthMethodTypeCount userAuthMethodTypeCount
) )

View File

@ -172,11 +172,18 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
return userAuthMethodTypes, err return userAuthMethodTypes, err
} }
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error) { type UserAuthMethodRequirements struct {
UserType domain.UserType
AuthMethods []domain.UserAuthMethodType
ForceMFA bool
ForceMFALocalOnly bool
}
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
ctxData := authz.GetCtxData(ctx) ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID { if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil { if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
return nil, false, false, err return nil, err
} }
} }
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
@ -189,17 +196,17 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
} }
stmt, args, err := query.Where(eq).ToSql() stmt, args, err := query.Where(eq).ToSql()
if err != nil { if err != nil {
return nil, false, false, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
} }
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
userAuthMethodTypes, forceMFA, forceMFALocalOnly, err = scan(rows) requirements, err = scan(rows)
return err return err
}, stmt, args...) }, stmt, args...)
if err != nil { if err != nil {
return nil, false, false, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
} }
return userAuthMethodTypes, forceMFA, forceMFALocalOnly, nil return requirements, nil
} }
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) { func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
@ -404,7 +411,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
} }
} }
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (_ []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error)) { func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil { if err != nil {
return sq.SelectBuilder{}, nil return sq.SelectBuilder{}, nil
@ -421,6 +428,7 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
NotifyPasswordSetCol.identifier(), NotifyPasswordSetCol.identifier(),
authMethodTypeTypes.identifier(), authMethodTypeTypes.identifier(),
userIDPsCountCount.identifier(), userIDPsCountCount.identifier(),
UserTypeCol.identifier(),
forceMFAForce.identifier(), forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier()). forceMFAForceLocalOnly.identifier()).
From(userTable.identifier()). From(userTable.identifier()).
@ -436,10 +444,11 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))). forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar), PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, bool, error) { func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0) userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
var passwordSet sql.NullBool var passwordSet sql.NullBool
var idp sql.NullInt64 var idp sql.NullInt64
var userType sql.NullInt32
var forceMFA sql.NullBool var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool var forceMFALocalOnly sql.NullBool
for rows.Next() { for rows.Next() {
@ -448,11 +457,12 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
&passwordSet, &passwordSet,
&authMethodType, &authMethodType,
&idp, &idp,
&userType,
&forceMFA, &forceMFA,
&forceMFALocalOnly, &forceMFALocalOnly,
) )
if err != nil { if err != nil {
return nil, false, false, err return nil, err
} }
if authMethodType.Valid { if authMethodType.Valid {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16)) userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
@ -467,10 +477,15 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
} }
if err := rows.Close(); err != nil { if err := rows.Close(); err != nil {
return nil, false, false, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows") return nil, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
} }
return userAuthMethodTypes, forceMFA.Bool, forceMFALocalOnly.Bool, nil return &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32),
AuthMethods: userAuthMethodTypes,
ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool,
}, nil
} }
} }

View File

@ -59,6 +59,7 @@ var (
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` +
` auth_method_types.method_type,` + ` auth_method_types.method_type,` +
` user_idps_count.count,` + ` user_idps_count.count,` +
` projections.users12.type,` +
` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa,` +
` auth_methods_force_mfa.force_mfa_local_only` + ` auth_methods_force_mfa.force_mfa_local_only` +
` FROM projections.users12` + ` FROM projections.users12` +
@ -75,6 +76,7 @@ var (
` `
prepareAuthMethodTypesRequiredCols = []string{ prepareAuthMethodTypesRequiredCols = []string{
"password_set", "password_set",
"type",
"method_type", "method_type",
"idps_count", "idps_count",
"force_mfa", "force_mfa",
@ -317,14 +319,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery no result", name: "prepareUserAuthMethodTypesRequiredQuery no result",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) { return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows) return scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
} }
}, },
want: want{ want: want{
@ -334,18 +332,14 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
nil, nil,
), ),
}, },
object: &testUserAuthMethodTypesRequired{authMethods: []domain.UserAuthMethodType{}, forceMFA: false}, object: &UserAuthMethodRequirements{AuthMethods: []domain.UserAuthMethodType{}, ForceMFA: false},
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery one second factor", name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) { return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows) return scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
} }
}, },
want: want{ want: want{
@ -357,32 +351,30 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
true, true,
domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypePasswordless,
1, 1,
domain.UserTypeHuman,
true, true,
true, true,
}, },
}, },
), ),
}, },
object: &testUserAuthMethodTypesRequired{ object: &UserAuthMethodRequirements{
authMethods: []domain.UserAuthMethodType{ UserType: domain.UserTypeHuman,
AuthMethods: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypePassword, domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP, domain.UserAuthMethodTypeIDP,
}, },
forceMFA: true, ForceMFA: true,
forceMFALocalOnly: true, ForceMFALocalOnly: true,
}, },
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) { return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows) return scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
} }
}, },
want: want{ want: want{
@ -394,6 +386,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
true, true,
domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypePasswordless,
1, 1,
domain.UserTypeHuman,
true, true,
true, true,
}, },
@ -401,6 +394,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
true, true,
domain.UserAuthMethodTypeTOTP, domain.UserAuthMethodTypeTOTP,
1, 1,
domain.UserTypeHuman,
true, true,
true, true,
}, },
@ -408,27 +402,24 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
), ),
}, },
object: &testUserAuthMethodTypesRequired{ object: &UserAuthMethodRequirements{
authMethods: []domain.UserAuthMethodType{ UserType: domain.UserTypeHuman,
AuthMethods: []domain.UserAuthMethodType{
domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypePasswordless,
domain.UserAuthMethodTypeTOTP, domain.UserAuthMethodTypeTOTP,
domain.UserAuthMethodTypePassword, domain.UserAuthMethodTypePassword,
domain.UserAuthMethodTypeIDP, domain.UserAuthMethodTypeIDP,
}, },
forceMFA: true, ForceMFA: true,
forceMFALocalOnly: true, ForceMFALocalOnly: true,
}, },
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery sql err", name: "prepareUserAuthMethodTypesRequiredQuery sql err",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) { return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows) return scan(rows)
if err != nil {
return nil, err
}
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
} }
}, },
want: want{ want: want{
@ -452,10 +443,3 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
}) })
} }
} }
// testUserAuthMethodTypesRequired is required as assetPrepare is only able to return a single object from scan
type testUserAuthMethodTypesRequired struct {
authMethods []domain.UserAuthMethodType
forceMFA bool
forceMFALocalOnly bool
}