mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-12 13:03:40 +00:00
fix(oidc): IDP and machine user auth methods (#7992)
# Which Problems Are Solved After https://github.com/zitadel/zitadel/pull/7822 was merged we discovered that v2 tokens that where obtained through an IDP using the v1 login, can't be used for zitadel API calls. - Because we used to store the AMR claim on the auth request, but internally use the domain.UserAuthMethod type. AMR has no notion of an IDP login, so that "factor" was lost during conversion. Rendering those v2 tokens invalid on the zitadel API. - A wrong check on machine user tokens falsly allowed some tokens to be valid - The client ID was set to tokens from client credentials and JWT profile, which made client queries fail in the validation middleware. The middleware expects client ID unset for machine users. # How the Problems Are Solved Store the domain.AuthMethods directly in the auth requests and session, instead of using AMR claims with lossy conversion. - IDPs have seperate auth method, which is not an AMR claim - Machine users are treated specialy, eg auth methods are not required. - Do not set the client ID for client credentials and JWT profile # Additional Changes Cleaned up mostly unused `oidc.getInfoFromRequest()`. # Additional Context - Bugs were introduced in https://github.com/zitadel/zitadel/pull/7822 and not yet part of a release. - Reported internally.
This commit is contained in:
parent
e57a9b57c8
commit
f5e9d4f57f
@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/zitadel/logging"
|
"github.com/zitadel/logging"
|
||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
"golang.org/x/text/language"
|
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||||
@ -199,23 +198,6 @@ func (*OPStorage) panicErr(method string) error {
|
|||||||
return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method)
|
return fmt.Errorf("OPStorage.%s should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues", method)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInfoFromRequest(req op.TokenRequest) (agentID, clientID, userOrgID string, authTime time.Time, amr []string, preferredLanguage *language.Tag, reason domain.TokenReason, actor *domain.TokenActor) {
|
|
||||||
switch r := req.(type) {
|
|
||||||
case *AuthRequest:
|
|
||||||
return r.AgentID, r.ApplicationID, r.UserOrgID, r.AuthTime, r.GetAMR(), r.PreferredLanguage, domain.TokenReasonAuthRequest, nil
|
|
||||||
case *RefreshTokenRequest:
|
|
||||||
return r.UserAgentID, r.ClientID, "", r.AuthTime, r.AuthMethodsReferences, nil, domain.TokenReasonRefresh, r.Actor
|
|
||||||
case op.IDTokenRequest:
|
|
||||||
return "", r.GetClientID(), "", r.GetAuthTime(), r.GetAMR(), nil, domain.TokenReasonAuthRequest, nil
|
|
||||||
case *oidc.JWTTokenRequest:
|
|
||||||
return "", "", "", r.GetAuthTime(), nil, nil, domain.TokenReasonJWTProfile, nil
|
|
||||||
case *clientCredentialsRequest:
|
|
||||||
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonClientCredentials, nil
|
|
||||||
default:
|
|
||||||
return "", "", "", time.Time{}, nil, nil, domain.TokenReasonAuthRequest, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
||||||
panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues")
|
panic("TokenRequestByRefreshToken should not be called anymore. This is a bug. Please report https://github.com/zitadel/zitadel/issues")
|
||||||
}
|
}
|
||||||
@ -511,8 +493,8 @@ func implicitFlowComplianceChecker() command.AuthRequestComplianceChecker {
|
|||||||
|
|
||||||
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
authorizer := s.Provider()
|
authorizer := s.Provider()
|
||||||
authReq, err := func() (authReq op.AuthRequest, err error) {
|
authReq, err := func(ctx context.Context) (authReq *AuthRequest, err error) {
|
||||||
ctx, span := tracing.NewSpan(r.Context())
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
@ -520,7 +502,7 @@ func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
authReq, err = authorizer.Storage().AuthRequestByID(r.Context(), id)
|
authReq, err = s.getAuthRequestV1ByID(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -528,13 +510,13 @@ func (s *Server) authorizeCallbackHandler(w http.ResponseWriter, r *http.Request
|
|||||||
return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
|
return authReq, oidc.ErrInteractionRequired().WithDescription("Unfortunately, the user may be not logged in and/or additional interaction is required.")
|
||||||
}
|
}
|
||||||
return authReq, s.authResponse(authReq, authorizer, w, r)
|
return authReq, s.authResponse(authReq, authorizer, w, r)
|
||||||
}()
|
}(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
|
func (s *Server) authResponse(authReq *AuthRequest, authorizer op.Authorizer, w http.ResponseWriter, r *http.Request) (err error) {
|
||||||
ctx, span := tracing.NewSpan(r.Context())
|
ctx, span := tracing.NewSpan(r.Context())
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
@ -551,7 +533,7 @@ func (s *Server) authResponse(authReq op.AuthRequest, authorizer op.Authorizer,
|
|||||||
return s.authResponseToken(authReq, authorizer, client, w, r)
|
return s.authResponseToken(authReq, authorizer, client, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
|
func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorizer, opClient op.Client, w http.ResponseWriter, r *http.Request) (err error) {
|
||||||
ctx, span := tracing.NewSpan(r.Context())
|
ctx, span := tracing.NewSpan(r.Context())
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
@ -561,23 +543,20 @@ func (s *Server) authResponseToken(authReq op.AuthRequest, authorizer op.Authori
|
|||||||
return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
|
return zerrors.ThrowInternal(nil, "OIDC-waeN6", "Error.Internal")
|
||||||
}
|
}
|
||||||
|
|
||||||
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
|
|
||||||
scope := authReq.GetScopes()
|
scope := authReq.GetScopes()
|
||||||
session, err := s.command.CreateOIDCSession(ctx,
|
session, err := s.command.CreateOIDCSession(ctx,
|
||||||
authReq.GetSubject(),
|
authReq.UserID,
|
||||||
userOrgID,
|
authReq.UserOrgID,
|
||||||
client.client.ClientID,
|
client.client.ClientID,
|
||||||
scope,
|
scope,
|
||||||
authReq.GetAudience(),
|
authReq.Audience,
|
||||||
AMRToAuthMethodTypes(authMethodsReferences),
|
authReq.AuthMethods(),
|
||||||
authTime,
|
authReq.AuthTime,
|
||||||
authReq.GetNonce(),
|
authReq.GetNonce(),
|
||||||
preferredLanguage,
|
authReq.PreferredLanguage,
|
||||||
&domain.UserAgent{
|
authReq.BrowserInfo.ToUserAgent(),
|
||||||
FingerprintID: &userAgentID,
|
domain.TokenReasonAuthRequest,
|
||||||
},
|
nil,
|
||||||
reason,
|
|
||||||
actor,
|
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -94,7 +94,7 @@ func (a *AuthRequest) oidc() *domain.AuthRequestOIDC {
|
|||||||
return a.Request.(*domain.AuthRequestOIDC)
|
return a.Request.(*domain.AuthRequestOIDC)
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ op.AuthRequest, err error) {
|
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ *AuthRequest, err error) {
|
||||||
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); !ok {
|
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); !ok {
|
||||||
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Haz7A", "auth request is not of type oidc")
|
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Haz7A", "auth request is not of type oidc")
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
|
|||||||
session, err := s.command.CreateOIDCSession(ctx,
|
session, err := s.command.CreateOIDCSession(ctx,
|
||||||
client.user.ID,
|
client.user.ID,
|
||||||
client.user.ResourceOwner,
|
client.user.ResourceOwner,
|
||||||
r.Data.ClientID,
|
"",
|
||||||
scope,
|
scope,
|
||||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
"github.com/zitadel/oidc/v3/pkg/op"
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||||
"github.com/zitadel/zitadel/internal/command"
|
"github.com/zitadel/zitadel/internal/command"
|
||||||
"github.com/zitadel/zitadel/internal/domain"
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
@ -69,36 +70,33 @@ func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.A
|
|||||||
if req.RedirectURI != authReq.GetRedirectURI() {
|
if req.RedirectURI != authReq.GetRedirectURI() {
|
||||||
return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
return nil, "", oidc.ErrInvalidGrant().WithDescription("redirect_uri does not correspond")
|
||||||
}
|
}
|
||||||
userAgentID, _, userOrgID, authTime, authMethodsReferences, preferredLanguage, reason, actor := getInfoFromRequest(authReq)
|
|
||||||
|
|
||||||
scope := authReq.GetScopes()
|
scope := authReq.GetScopes()
|
||||||
session, err = s.command.CreateOIDCSession(ctx,
|
session, err = s.command.CreateOIDCSession(ctx,
|
||||||
authReq.GetSubject(),
|
authReq.UserID,
|
||||||
userOrgID,
|
authReq.UserOrgID,
|
||||||
client.client.ClientID,
|
client.client.ClientID,
|
||||||
scope,
|
scope,
|
||||||
authReq.GetAudience(),
|
authReq.Audience,
|
||||||
AMRToAuthMethodTypes(authMethodsReferences),
|
authReq.AuthMethods(),
|
||||||
authTime,
|
authReq.AuthTime,
|
||||||
authReq.GetNonce(),
|
authReq.GetNonce(),
|
||||||
preferredLanguage,
|
authReq.PreferredLanguage,
|
||||||
&domain.UserAgent{
|
authReq.BrowserInfo.ToUserAgent(),
|
||||||
FingerprintID: &userAgentID,
|
domain.TokenReasonAuthRequest,
|
||||||
},
|
nil,
|
||||||
reason,
|
|
||||||
actor,
|
|
||||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
return session, authReq.GetState(), s.repo.DeleteAuthRequest(ctx, authReq.GetID())
|
return session, authReq.TransferState, s.repo.DeleteAuthRequest(ctx, authReq.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAuthRequestV1ByCode finds the v1 auth request by code.
|
// getAuthRequestV1ByCode finds the v1 auth request by code.
|
||||||
// code needs to be the encrypted version of the ID,
|
// code needs to be the encrypted version of the ID,
|
||||||
// this is required by the underlying repo.
|
// this is required by the underlying repo.
|
||||||
func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.AuthRequest, error) {
|
func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (*AuthRequest, error) {
|
||||||
authReq, err := s.repo.AuthRequestByCode(ctx, code)
|
authReq, err := s.repo.AuthRequestByCode(ctx, code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -106,6 +104,18 @@ func (s *Server) getAuthRequestV1ByCode(ctx context.Context, code string) (op.Au
|
|||||||
return AuthRequestFromBusiness(authReq)
|
return AuthRequestFromBusiness(authReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) getAuthRequestV1ByID(ctx context.Context, id string) (*AuthRequest, error) {
|
||||||
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, zerrors.ThrowPreconditionFailed(nil, "OIDC-TiTu7", "no user agent id")
|
||||||
|
}
|
||||||
|
resp, err := s.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return AuthRequestFromBusiness(resp)
|
||||||
|
}
|
||||||
|
|
||||||
func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker {
|
func codeExchangeComplianceChecker(client *Client, req *oidc.AccessTokenRequest) command.AuthRequestComplianceChecker {
|
||||||
return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error {
|
return func(ctx context.Context, authReq *command.AuthRequestWriteModel) error {
|
||||||
if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
|
if authReq.CodeChallenge != nil || client.AuthMethod() == oidc.AuthMethodNone {
|
||||||
|
@ -42,15 +42,15 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
|
|||||||
session, err := s.command.CreateOIDCSession(ctx,
|
session, err := s.command.CreateOIDCSession(ctx,
|
||||||
user.ID,
|
user.ID,
|
||||||
user.ResourceOwner,
|
user.ResourceOwner,
|
||||||
jwtReq.Subject,
|
"",
|
||||||
scope,
|
scope,
|
||||||
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope),
|
||||||
nil,
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePrivateKey},
|
||||||
time.Now(),
|
time.Now(),
|
||||||
"",
|
"",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
domain.TokenReasonClientCredentials,
|
domain.TokenReasonJWTProfile,
|
||||||
nil,
|
nil,
|
||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
|
@ -269,6 +269,7 @@ func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReq
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request.IDPLoginChecked = true
|
||||||
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
|
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -516,6 +517,7 @@ func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, u
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
request.LinkingUsers = nil
|
request.LinkingUsers = nil
|
||||||
|
request.IDPLoginChecked = true
|
||||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -560,6 +562,7 @@ func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, regis
|
|||||||
request.SetUserInfo(human.ID, human.Username, human.Username, human.DisplayName, "", resourceOwner)
|
request.SetUserInfo(human.ID, human.Username, human.Username, human.DisplayName, "", resourceOwner)
|
||||||
request.SelectedIDPConfigID = externalIDP.IDPConfigID
|
request.SelectedIDPConfigID = externalIDP.IDPConfigID
|
||||||
request.LinkingUsers = nil
|
request.LinkingUsers = nil
|
||||||
|
request.IDPLoginChecked = true
|
||||||
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
|
err = repo.Command.UserIDPLoginChecked(ctx, request.UserOrgID, request.UserID, request.WithCurrentInfo(info))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -180,11 +180,18 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth
|
|||||||
if domain.HasMFA(authMethods) {
|
if domain.HasMFA(authMethods) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
availableAuthMethods, forceMFA, forceMFALocalOnly, err := repo.Query.ListUserAuthMethodTypesRequired(setCallerCtx(ctx, userID), userID)
|
requirements, err := repo.Query.ListUserAuthMethodTypesRequired(setCallerCtx(ctx, userID), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if domain.RequiresMFA(forceMFA, forceMFALocalOnly, hasIDPAuthentication(authMethods)) || domain.HasMFA(availableAuthMethods) {
|
if requirements.UserType == domain.UserTypeMachine {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if domain.RequiresMFA(
|
||||||
|
requirements.ForceMFA,
|
||||||
|
requirements.ForceMFALocalOnly,
|
||||||
|
!hasIDPAuthentication(authMethods)) ||
|
||||||
|
domain.HasMFA(requirements.AuthMethods) {
|
||||||
return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required")
|
return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -45,6 +45,7 @@ type AuthRequest struct {
|
|||||||
LinkingUsers []*ExternalUser
|
LinkingUsers []*ExternalUser
|
||||||
PossibleSteps []NextStep `json:"-"`
|
PossibleSteps []NextStep `json:"-"`
|
||||||
PasswordVerified bool
|
PasswordVerified bool
|
||||||
|
IDPLoginChecked bool
|
||||||
MFAsVerified []MFAType
|
MFAsVerified []MFAType
|
||||||
Audience []string
|
Audience []string
|
||||||
AuthTime time.Time
|
AuthTime time.Time
|
||||||
@ -69,6 +70,20 @@ func (a *AuthRequest) PolicyOrgID() string {
|
|||||||
return a.policyOrgID
|
return a.policyOrgID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequest) AuthMethods() []UserAuthMethodType {
|
||||||
|
list := make([]UserAuthMethodType, 0, len(a.MFAsVerified)+2)
|
||||||
|
if a.PasswordVerified {
|
||||||
|
list = append(list, UserAuthMethodTypePassword)
|
||||||
|
}
|
||||||
|
if a.IDPLoginChecked {
|
||||||
|
list = append(list, UserAuthMethodTypeIDP)
|
||||||
|
}
|
||||||
|
for _, mfa := range a.MFAsVerified {
|
||||||
|
list = append(list, mfa.UserAuthMethodType())
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
type ExternalUser struct {
|
type ExternalUser struct {
|
||||||
IDPConfigID string
|
IDPConfigID string
|
||||||
ExternalUserID string
|
ExternalUserID string
|
||||||
|
@ -43,6 +43,7 @@ const (
|
|||||||
UserAuthMethodTypeOTPSMS
|
UserAuthMethodTypeOTPSMS
|
||||||
UserAuthMethodTypeOTPEmail
|
UserAuthMethodTypeOTPEmail
|
||||||
UserAuthMethodTypeOTP // generic OTP when parsing AMR from OIDC
|
UserAuthMethodTypeOTP // generic OTP when parsing AMR from OIDC
|
||||||
|
UserAuthMethodTypePrivateKey
|
||||||
userAuthMethodTypeCount
|
userAuthMethodTypeCount
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -172,11 +172,18 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
|
|||||||
return userAuthMethodTypes, err
|
return userAuthMethodTypes, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error) {
|
type UserAuthMethodRequirements struct {
|
||||||
|
UserType domain.UserType
|
||||||
|
AuthMethods []domain.UserAuthMethodType
|
||||||
|
ForceMFA bool
|
||||||
|
ForceMFALocalOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
|
||||||
ctxData := authz.GetCtxData(ctx)
|
ctxData := authz.GetCtxData(ctx)
|
||||||
if ctxData.UserID != userID {
|
if ctxData.UserID != userID {
|
||||||
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
|
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
|
||||||
return nil, false, false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
@ -189,17 +196,17 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
|
|||||||
}
|
}
|
||||||
stmt, args, err := query.Where(eq).ToSql()
|
stmt, args, err := query.Where(eq).ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, false, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
|
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
|
||||||
userAuthMethodTypes, forceMFA, forceMFALocalOnly, err = scan(rows)
|
requirements, err = scan(rows)
|
||||||
return err
|
return err
|
||||||
}, stmt, args...)
|
}, stmt, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, false, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
||||||
}
|
}
|
||||||
return userAuthMethodTypes, forceMFA, forceMFALocalOnly, nil
|
return requirements, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
|
func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) {
|
||||||
@ -404,7 +411,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (_ []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error)) {
|
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
|
||||||
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
|
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return sq.SelectBuilder{}, nil
|
return sq.SelectBuilder{}, nil
|
||||||
@ -421,6 +428,7 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
|
|||||||
NotifyPasswordSetCol.identifier(),
|
NotifyPasswordSetCol.identifier(),
|
||||||
authMethodTypeTypes.identifier(),
|
authMethodTypeTypes.identifier(),
|
||||||
userIDPsCountCount.identifier(),
|
userIDPsCountCount.identifier(),
|
||||||
|
UserTypeCol.identifier(),
|
||||||
forceMFAForce.identifier(),
|
forceMFAForce.identifier(),
|
||||||
forceMFAForceLocalOnly.identifier()).
|
forceMFAForceLocalOnly.identifier()).
|
||||||
From(userTable.identifier()).
|
From(userTable.identifier()).
|
||||||
@ -436,10 +444,11 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
|
|||||||
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
|
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
|
||||||
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
|
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
|
||||||
PlaceholderFormat(sq.Dollar),
|
PlaceholderFormat(sq.Dollar),
|
||||||
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, bool, error) {
|
func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
|
||||||
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
|
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
|
||||||
var passwordSet sql.NullBool
|
var passwordSet sql.NullBool
|
||||||
var idp sql.NullInt64
|
var idp sql.NullInt64
|
||||||
|
var userType sql.NullInt32
|
||||||
var forceMFA sql.NullBool
|
var forceMFA sql.NullBool
|
||||||
var forceMFALocalOnly sql.NullBool
|
var forceMFALocalOnly sql.NullBool
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
@ -448,11 +457,12 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
|
|||||||
&passwordSet,
|
&passwordSet,
|
||||||
&authMethodType,
|
&authMethodType,
|
||||||
&idp,
|
&idp,
|
||||||
|
&userType,
|
||||||
&forceMFA,
|
&forceMFA,
|
||||||
&forceMFALocalOnly,
|
&forceMFALocalOnly,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, false, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if authMethodType.Valid {
|
if authMethodType.Valid {
|
||||||
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
|
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
|
||||||
@ -467,10 +477,15 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := rows.Close(); err != nil {
|
if err := rows.Close(); err != nil {
|
||||||
return nil, false, false, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
|
return nil, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
|
||||||
}
|
}
|
||||||
|
|
||||||
return userAuthMethodTypes, forceMFA.Bool, forceMFALocalOnly.Bool, nil
|
return &UserAuthMethodRequirements{
|
||||||
|
UserType: domain.UserType(userType.Int32),
|
||||||
|
AuthMethods: userAuthMethodTypes,
|
||||||
|
ForceMFA: forceMFA.Bool,
|
||||||
|
ForceMFALocalOnly: forceMFALocalOnly.Bool,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ var (
|
|||||||
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` +
|
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` +
|
||||||
` auth_method_types.method_type,` +
|
` auth_method_types.method_type,` +
|
||||||
` user_idps_count.count,` +
|
` user_idps_count.count,` +
|
||||||
|
` projections.users12.type,` +
|
||||||
` auth_methods_force_mfa.force_mfa,` +
|
` auth_methods_force_mfa.force_mfa,` +
|
||||||
` auth_methods_force_mfa.force_mfa_local_only` +
|
` auth_methods_force_mfa.force_mfa_local_only` +
|
||||||
` FROM projections.users12` +
|
` FROM projections.users12` +
|
||||||
@ -75,6 +76,7 @@ var (
|
|||||||
`
|
`
|
||||||
prepareAuthMethodTypesRequiredCols = []string{
|
prepareAuthMethodTypesRequiredCols = []string{
|
||||||
"password_set",
|
"password_set",
|
||||||
|
"type",
|
||||||
"method_type",
|
"method_type",
|
||||||
"idps_count",
|
"idps_count",
|
||||||
"force_mfa",
|
"force_mfa",
|
||||||
@ -317,14 +319,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery no result",
|
name: "prepareUserAuthMethodTypesRequiredQuery no result",
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
|
||||||
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows)
|
return scan(rows)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
@ -334,18 +332,14 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
nil,
|
nil,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
object: &testUserAuthMethodTypesRequired{authMethods: []domain.UserAuthMethodType{}, forceMFA: false},
|
object: &UserAuthMethodRequirements{AuthMethods: []domain.UserAuthMethodType{}, ForceMFA: false},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
|
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
|
||||||
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows)
|
return scan(rows)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
@ -357,32 +351,30 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
true,
|
true,
|
||||||
domain.UserAuthMethodTypePasswordless,
|
domain.UserAuthMethodTypePasswordless,
|
||||||
1,
|
1,
|
||||||
|
domain.UserTypeHuman,
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
object: &testUserAuthMethodTypesRequired{
|
object: &UserAuthMethodRequirements{
|
||||||
authMethods: []domain.UserAuthMethodType{
|
UserType: domain.UserTypeHuman,
|
||||||
|
AuthMethods: []domain.UserAuthMethodType{
|
||||||
domain.UserAuthMethodTypePasswordless,
|
domain.UserAuthMethodTypePasswordless,
|
||||||
domain.UserAuthMethodTypePassword,
|
domain.UserAuthMethodTypePassword,
|
||||||
domain.UserAuthMethodTypeIDP,
|
domain.UserAuthMethodTypeIDP,
|
||||||
},
|
},
|
||||||
forceMFA: true,
|
ForceMFA: true,
|
||||||
forceMFALocalOnly: true,
|
ForceMFALocalOnly: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
|
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
|
||||||
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows)
|
return scan(rows)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
@ -394,6 +386,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
true,
|
true,
|
||||||
domain.UserAuthMethodTypePasswordless,
|
domain.UserAuthMethodTypePasswordless,
|
||||||
1,
|
1,
|
||||||
|
domain.UserTypeHuman,
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
@ -401,6 +394,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
true,
|
true,
|
||||||
domain.UserAuthMethodTypeTOTP,
|
domain.UserAuthMethodTypeTOTP,
|
||||||
1,
|
1,
|
||||||
|
domain.UserTypeHuman,
|
||||||
true,
|
true,
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
@ -408,27 +402,24 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
||||||
object: &testUserAuthMethodTypesRequired{
|
object: &UserAuthMethodRequirements{
|
||||||
authMethods: []domain.UserAuthMethodType{
|
UserType: domain.UserTypeHuman,
|
||||||
|
AuthMethods: []domain.UserAuthMethodType{
|
||||||
domain.UserAuthMethodTypePasswordless,
|
domain.UserAuthMethodTypePasswordless,
|
||||||
domain.UserAuthMethodTypeTOTP,
|
domain.UserAuthMethodTypeTOTP,
|
||||||
domain.UserAuthMethodTypePassword,
|
domain.UserAuthMethodTypePassword,
|
||||||
domain.UserAuthMethodTypeIDP,
|
domain.UserAuthMethodTypeIDP,
|
||||||
},
|
},
|
||||||
forceMFA: true,
|
ForceMFA: true,
|
||||||
forceMFALocalOnly: true,
|
ForceMFALocalOnly: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
|
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*testUserAuthMethodTypesRequired, error)) {
|
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) {
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||||
return builder, func(rows *sql.Rows) (*testUserAuthMethodTypesRequired, error) {
|
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) {
|
||||||
authMethods, forceMFA, forceMFALocalOnly, err := scan(rows)
|
return scan(rows)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &testUserAuthMethodTypesRequired{authMethods: authMethods, forceMFA: forceMFA, forceMFALocalOnly: forceMFALocalOnly}, nil
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
want: want{
|
want: want{
|
||||||
@ -452,10 +443,3 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// testUserAuthMethodTypesRequired is required as assetPrepare is only able to return a single object from scan
|
|
||||||
type testUserAuthMethodTypesRequired struct {
|
|
||||||
authMethods []domain.UserAuthMethodType
|
|
||||||
forceMFA bool
|
|
||||||
forceMFALocalOnly bool
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user