mirror of
https://github.com/zitadel/zitadel.git
synced 2025-06-12 01:18:33 +00:00
Merge branch 'main' into next
This commit is contained in:
commit
50e0e7d564
@ -6,9 +6,11 @@
|
||||
<mat-icon class="icon">info_outline</mat-icon>
|
||||
</a>
|
||||
</div>
|
||||
<p class="sub cnsl-secondary-text max-width-description">
|
||||
{{ 'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate }}
|
||||
</p>
|
||||
<p
|
||||
class="sub cnsl-secondary-text max-width-description"
|
||||
[innerHTML]="'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate"
|
||||
></p>
|
||||
|
||||
<div class="projects-controls">
|
||||
<div class="project-toggle-group">
|
||||
<cnsl-nav-toggle
|
||||
|
@ -14,12 +14,15 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta"
|
||||
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
|
||||
@ -27,6 +30,7 @@ import (
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
IAMOwnerCTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client session.SessionServiceClient
|
||||
User *user.AddHumanUserResponse
|
||||
@ -44,6 +48,7 @@ func TestMain(m *testing.M) {
|
||||
Client = Tester.Client.SessionV2
|
||||
|
||||
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
|
||||
IAMOwnerCTX = Tester.WithAuthorization(ctx, integration.IAMOwner)
|
||||
User = createFullUser(CTX)
|
||||
DeactivatedUser = createDeactivatedUser(CTX)
|
||||
LockedUser = createLockedUser(CTX)
|
||||
@ -341,6 +346,48 @@ func TestServer_CreateSession(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateSession_lock_user(t *testing.T) {
|
||||
// create a separate org so we don't interfere with any other test
|
||||
org := Tester.CreateOrganization(IAMOwnerCTX,
|
||||
fmt.Sprintf("TestServer_CreateSession_lock_user_%d", time.Now().UnixNano()),
|
||||
fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()),
|
||||
)
|
||||
userID := org.CreatedAdmins[0].GetUserId()
|
||||
Tester.SetUserPassword(IAMOwnerCTX, userID, integration.UserPassword, false)
|
||||
|
||||
// enable password lockout
|
||||
maxAttempts := 2
|
||||
ctxOrg := metadata.AppendToOutgoingContext(IAMOwnerCTX, "x-zitadel-orgid", org.GetOrganizationId())
|
||||
_, err := Tester.Client.Mgmt.AddCustomLockoutPolicy(ctxOrg, &mgmt.AddCustomLockoutPolicyRequest{
|
||||
MaxPasswordAttempts: uint32(maxAttempts),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i <= maxAttempts; i++ {
|
||||
_, err := Client.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
Password: &session.CheckPassword{
|
||||
Password: "invalid",
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
statusCode := status.Code(err)
|
||||
expectedCode := codes.InvalidArgument
|
||||
// as soon as we hit the limit the user is locked and following request will
|
||||
// already deny any check with a precondition failed since the user is locked
|
||||
if i >= maxAttempts {
|
||||
expectedCode = codes.FailedPrecondition
|
||||
}
|
||||
assert.Equal(t, expectedCode, statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateSession_webauthn(t *testing.T) {
|
||||
// create new session with user and request the webauthn challenge
|
||||
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
|
@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
}
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion)
|
||||
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
|
||||
if err != nil {
|
||||
op.AuthRequestError(w, r, authReq, err, authorizer)
|
||||
return err
|
||||
|
@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
|
||||
require.NoError(t, err)
|
||||
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// callback on a succeeded request must fail
|
||||
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
|
||||
@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
}
|
||||
|
||||
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// test actual refresh grant
|
||||
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, newTokens, true)
|
||||
// auth time must still be the initial
|
||||
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// refresh with an old refresh_token must fail
|
||||
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
|
||||
@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// revoke access token
|
||||
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token")
|
||||
@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// revoke access token
|
||||
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token")
|
||||
@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// revoke refresh token -> invalidates also access token
|
||||
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token")
|
||||
@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// revoke refresh token even with a wrong hint
|
||||
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token")
|
||||
@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// simulate second client (not part of the audience) trying to revoke the token
|
||||
otherClientID, _ := createClient(t)
|
||||
@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// userinfo must not fail
|
||||
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// userinfo must not fail
|
||||
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
|
||||
@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state")
|
||||
require.NoError(t, err)
|
||||
@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir
|
||||
}
|
||||
}
|
||||
|
||||
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) {
|
||||
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) {
|
||||
assert.Equal(t, userID, claims.Subject)
|
||||
assert.Equal(t, arm, claims.AuthenticationMethodsReferences)
|
||||
assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange)
|
||||
assert.Equal(t, sessionID, claims.SessionID)
|
||||
assert.Empty(t, claims.Name)
|
||||
assert.Empty(t, claims.GivenName)
|
||||
assert.Empty(t, claims.FamilyName)
|
||||
assert.Empty(t, claims.PreferredUsername)
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// test actual introspection
|
||||
introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken)
|
||||
@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -80,7 +81,11 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
|
||||
// with active: false
|
||||
defer func() {
|
||||
if err != nil {
|
||||
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
|
||||
if zerrors.IsInternal(err) {
|
||||
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
|
||||
} else {
|
||||
s.getLogger(ctx).InfoContext(ctx, "oidc introspection", "err", err)
|
||||
}
|
||||
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
|
||||
}
|
||||
}()
|
||||
@ -99,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
|
||||
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true)
|
||||
userInfo, err := s.userInfo(
|
||||
token.userID,
|
||||
token.scope,
|
||||
client.projectID,
|
||||
client.projectRoleAssertion,
|
||||
true,
|
||||
true,
|
||||
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
||||
@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) {
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
||||
@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) {
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
||||
@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
||||
@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
||||
@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// make sure token works
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
// make sure token works
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID)
|
||||
|
||||
// make sure token works
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
|
@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package,
|
||||
for example the v2 code exchange and refresh token.
|
||||
*/
|
||||
|
||||
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
|
||||
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
|
||||
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
|
||||
getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope)
|
||||
getSigner := s.getSignerOnce()
|
||||
|
||||
resp := &oidc.AccessTokenResponse{
|
||||
@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
|
||||
// If the session does not have a token ID, it is an implicit ID-Token only response.
|
||||
if session.TokenID != "" {
|
||||
if client.AccessTokenType() == op.AccessTokenTypeJWT {
|
||||
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
|
||||
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner)
|
||||
} else {
|
||||
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
|
||||
}
|
||||
@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
|
||||
}
|
||||
|
||||
if slices.Contains(session.Scope, oidc.ScopeOpenID) {
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc {
|
||||
}
|
||||
|
||||
// userInfoFunc is a getter function that allows add-hoc retrieval of a user.
|
||||
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
|
||||
type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error)
|
||||
|
||||
// getUserInfoOnce returns a function which retrieves userinfo from the database once.
|
||||
// Repeated calls of the returned function return the same results.
|
||||
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
|
||||
var (
|
||||
once sync.Once
|
||||
userInfo *oidc.UserInfo
|
||||
err error
|
||||
)
|
||||
return func(ctx context.Context) (*oidc.UserInfo, error) {
|
||||
once.Do(func() {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
|
||||
})
|
||||
return userInfo, err
|
||||
// getUserInfo returns a function which retrieves userinfo from the database once.
|
||||
// However, each time, role claims are asserted and also action flows will trigger.
|
||||
func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc {
|
||||
userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false)
|
||||
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) {
|
||||
return userInfo(ctx, roleAssertion, triggerType)
|
||||
}
|
||||
}
|
||||
|
||||
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
|
||||
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
userInfo, err := getUserInfo(ctx)
|
||||
userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 {
|
||||
return uint64(time.Until(exp) / time.Second)
|
||||
}
|
||||
|
||||
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
|
||||
func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
userInfo, err := getUserInfo(ctx)
|
||||
userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
|
||||
false,
|
||||
)
|
||||
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ package oidc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -18,10 +19,13 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_ClientCredentialsExchange(t *testing.T) {
|
||||
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
require.NoError(t, err)
|
||||
|
||||
type claims struct {
|
||||
name string
|
||||
username string
|
||||
updated time.Time
|
||||
resourceOwnerID any
|
||||
resourceOwnerName any
|
||||
resourceOwnerPrimaryDomain any
|
||||
@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
{
|
||||
name: "openid, profile, email",
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
|
||||
wantClaims: claims{
|
||||
name: name,
|
||||
username: name,
|
||||
updated: machine.GetDetails().GetChangeDate().AsTime(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "org id and domain scope",
|
||||
clientID: clientID,
|
||||
@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, tokens)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
|
||||
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
|
||||
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
|
||||
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
|
||||
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
|
||||
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
|
||||
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
|
||||
assert.Empty(t, userinfo.UserInfoEmail)
|
||||
assert.Empty(t, userinfo.UserInfoPhone)
|
||||
assert.Empty(t, userinfo.Address)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
|
||||
}
|
||||
|
||||
// codeExchangeV1 creates a v2 token from a v1 auth request.
|
||||
|
@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
|
||||
}
|
||||
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
|
||||
if err == nil {
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, oidc.ErrSlowDown().WithParent(err)
|
||||
|
@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
|
||||
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
|
||||
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
|
||||
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
|
||||
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes)
|
||||
getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes)
|
||||
getSigner := s.getSignerOnce()
|
||||
|
||||
resp := &oidc.TokenExchangeResponse{
|
||||
@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
|
||||
resp.IssuedTokenType = oidc.AccessTokenType
|
||||
|
||||
case oidc.JWTTokenType:
|
||||
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
|
||||
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
|
||||
resp.TokenType = oidc.BearerToken
|
||||
resp.IssuedTokenType = oidc.JWTTokenType
|
||||
|
||||
case oidc.IDTokenType:
|
||||
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
resp.TokenType = TokenTypeNA
|
||||
resp.IssuedTokenType = oidc.IDTokenType
|
||||
|
||||
@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
|
||||
}
|
||||
|
||||
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT(
|
||||
ctx context.Context,
|
||||
client *Client,
|
||||
getUserInfo userInfoFunc,
|
||||
roleAssertion bool,
|
||||
getSigner signerFunc,
|
||||
userID,
|
||||
resourceOwner string,
|
||||
@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT(
|
||||
actor,
|
||||
slices.Contains(scope, oidc.ScopeOfflineAccess),
|
||||
)
|
||||
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
|
||||
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
|
||||
if err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
|
||||
}
|
||||
|
||||
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {
|
||||
|
@ -4,6 +4,7 @@ package oidc_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -16,10 +17,13 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_JWTProfile(t *testing.T) {
|
||||
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
|
||||
user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
|
||||
require.NoError(t, err)
|
||||
|
||||
type claims struct {
|
||||
name string
|
||||
username string
|
||||
updated time.Time
|
||||
resourceOwnerID any
|
||||
resourceOwnerName any
|
||||
resourceOwnerPrimaryDomain any
|
||||
@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) {
|
||||
keyData: keyData,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
{
|
||||
name: "openid, profile, email",
|
||||
keyData: keyData,
|
||||
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
|
||||
wantClaims: claims{
|
||||
name: name,
|
||||
username: name,
|
||||
updated: user.GetDetails().GetChangeDate().AsTime(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "org id and domain scope",
|
||||
keyData: keyData,
|
||||
@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) {
|
||||
|
||||
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
|
||||
require.NoError(t, err)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
|
||||
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
|
||||
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
|
||||
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
|
||||
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
|
||||
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
|
||||
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
|
||||
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
|
||||
assert.Empty(t, userinfo.UserInfoEmail)
|
||||
assert.Empty(t, userinfo.UserInfoPhone)
|
||||
assert.Empty(t, userinfo.Address)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr
|
||||
|
||||
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
|
||||
if err == nil {
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
|
||||
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
|
||||
// We try again for v1 tokens when we encountered specific parsing error
|
||||
return s.refreshTokenV1(ctx, client, r)
|
||||
@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
|
||||
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
|
||||
}
|
||||
|
||||
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/zitadel/logging"
|
||||
@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
|
||||
}
|
||||
}
|
||||
|
||||
userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false)
|
||||
userInfo, err := s.userInfo(
|
||||
token.userID,
|
||||
token.scope,
|
||||
projectID,
|
||||
assertion,
|
||||
true,
|
||||
false,
|
||||
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
|
||||
// The returned UserInfo contains standard and reserved claims, documented
|
||||
// here: https://zitadel.com/docs/apis/openidoauth/claims.
|
||||
//
|
||||
// User information is only retrieved once from the database.
|
||||
// However, each time, role claims are asserted and also action flows will trigger.
|
||||
//
|
||||
// projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested.
|
||||
// projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope.
|
||||
// roleAssertion decides whether the roles will be returned (in the token or response)
|
||||
// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned
|
||||
//
|
||||
// currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope.
|
||||
// It should be set in cases where the client doesn't need to know roles outside its own project,
|
||||
// for example an introspection client.
|
||||
func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
func (s *Server) userInfo(
|
||||
userID string,
|
||||
scope []string,
|
||||
projectID string,
|
||||
projectRoleAssertion, userInfoAssertion, currentProjectOnly bool,
|
||||
) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
|
||||
var (
|
||||
once sync.Once
|
||||
userInfo *oidc.UserInfo
|
||||
qu *query.OIDCUserInfo
|
||||
roleAudience, requestedRoles []string
|
||||
)
|
||||
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
|
||||
once.Do(func() {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
|
||||
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
|
||||
qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx))
|
||||
})
|
||||
userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo)
|
||||
return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType)
|
||||
}
|
||||
|
||||
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
|
||||
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
|
||||
}
|
||||
|
||||
// prepareRoles scans the requested scopes and builds the requested roles
|
||||
@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project
|
||||
return roleAudience, requestedRoles
|
||||
}
|
||||
|
||||
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo {
|
||||
func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo {
|
||||
out := new(oidc.UserInfo)
|
||||
for _, s := range scope {
|
||||
switch s {
|
||||
case oidc.ScopeOpenID:
|
||||
out.Subject = user.User.ID
|
||||
case oidc.ScopeEmail:
|
||||
if !userInfoAssertion {
|
||||
continue
|
||||
}
|
||||
out.UserInfoEmail = userInfoEmailToOIDC(user.User)
|
||||
case oidc.ScopeProfile:
|
||||
if !userInfoAssertion {
|
||||
continue
|
||||
}
|
||||
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
|
||||
case oidc.ScopePhone:
|
||||
if !userInfoAssertion {
|
||||
continue
|
||||
}
|
||||
out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
|
||||
case oidc.ScopeAddress:
|
||||
//TODO: handle address for human users as soon as implemented
|
||||
if !userInfoAssertion {
|
||||
continue
|
||||
}
|
||||
// TODO: handle address for human users as soon as implemented
|
||||
case ScopeUserMetaData:
|
||||
setUserInfoMetadata(user.Metadata, out)
|
||||
case ScopeResourceOwner:
|
||||
@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo {
|
||||
if !assertion {
|
||||
return info
|
||||
}
|
||||
userInfo := *info
|
||||
// prevent returning obtained grants if none where requested
|
||||
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
|
||||
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles))
|
||||
setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles))
|
||||
}
|
||||
return out
|
||||
return &userInfo
|
||||
}
|
||||
|
||||
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
|
||||
@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) {
|
||||
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner)
|
||||
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
|
||||
project, err := Tester.CreateProject(CTX)
|
||||
projectID := project.GetId()
|
||||
require.NoError(t, err)
|
||||
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
|
||||
require.NoError(t, err)
|
||||
addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar)
|
||||
addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar)
|
||||
|
||||
scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess,
|
||||
oidc_api.ScopeProjectRolePrefix + roleFoo,
|
||||
@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
|
||||
tokens, err := rp.ClientCredentials(CTX, provider, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider)
|
||||
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider)
|
||||
require.NoError(t, err)
|
||||
assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo)
|
||||
}
|
||||
@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package oidc
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
}
|
||||
|
||||
type args struct {
|
||||
projectID string
|
||||
user *query.OIDCUserInfo
|
||||
scope []string
|
||||
roleAudience []string
|
||||
requestedRoles []string
|
||||
user *query.OIDCUserInfo
|
||||
userInfoAssertion bool
|
||||
scope []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "human, empty",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
user: humanUserInfo,
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, empty",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
user: machineUserInfo,
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "human, scope openid",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Subject: "human1",
|
||||
@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "machine, scope openid",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeOpenID},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Subject: "machine1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope email",
|
||||
name: "human, scope email, profileInfoAssertion",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
user: humanUserInfo,
|
||||
userInfoAssertion: true,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoEmail: oidc.UserInfoEmail{
|
||||
@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope email",
|
||||
name: "human, scope email",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, scope email, profileInfoAssertion",
|
||||
args: args{
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeEmail},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoEmail: oidc.UserInfoEmail{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope profile",
|
||||
name: "human, scope profile, profileInfoAssertion",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
user: humanUserInfo,
|
||||
userInfoAssertion: true,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "machine, scope profile",
|
||||
name: "machine, scope profile, profileInfoAssertion",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
user: machineUserInfo,
|
||||
userInfoAssertion: true,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope phone",
|
||||
name: "machine, scope profile",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopeProfile},
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "human, scope phone, profileInfoAssertion",
|
||||
args: args{
|
||||
user: humanUserInfo,
|
||||
userInfoAssertion: true,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoPhone: oidc.UserInfoPhone{
|
||||
@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, scope phone",
|
||||
args: args{
|
||||
user: humanUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, scope phone",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
user: machineUserInfo,
|
||||
scope: []string{oidc.ScopePhone},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
UserInfoPhone: oidc.UserInfoPhone{},
|
||||
@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "human, scope metadata",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
user: humanUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "machine, scope metadata, none found",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeUserMetaData},
|
||||
},
|
||||
want: &oidc.UserInfo{},
|
||||
},
|
||||
{
|
||||
name: "machine, scope resource owner",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeResourceOwner},
|
||||
user: machineUserInfo,
|
||||
scope: []string{ScopeResourceOwner},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "human, scope org primary domain prefix",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
|
||||
user: humanUserInfo,
|
||||
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
{
|
||||
name: "machine, scope org id",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: machineUserInfo,
|
||||
scope: []string{domain.OrgIDScope + "orgID"},
|
||||
user: machineUserInfo,
|
||||
scope: []string{domain.OrgIDScope + "orgID"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, roleAudience",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
roleAudience: []string{"project1"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimProjectRoles: projectRoles{
|
||||
"role1": {"orgID": "orgDomain"},
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
|
||||
"role1": {"orgID": "orgDomain"},
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "human, requested roles",
|
||||
args: args{
|
||||
projectID: "project1",
|
||||
user: humanUserInfo,
|
||||
roleAudience: []string{"project1"},
|
||||
requestedRoles: []string{"role2"},
|
||||
},
|
||||
want: &oidc.UserInfo{
|
||||
Claims: map[string]any{
|
||||
ClaimProjectRoles: projectRoles{
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
|
||||
"role2": {"orgID": "orgDomain"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assetPrefix := "https://foo.com/assets"
|
||||
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix)
|
||||
got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
@ -342,7 +342,11 @@ func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, auth
|
||||
if authReq != nil {
|
||||
log = log.WithField("auth_req_id", authReq.ID)
|
||||
}
|
||||
log.Error()
|
||||
if zerrors.IsInternal(err) {
|
||||
log.Error()
|
||||
} else {
|
||||
log.Info()
|
||||
}
|
||||
|
||||
_, msg = l.getErrorMessage(r, err)
|
||||
}
|
||||
|
@ -340,11 +340,7 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
|
||||
}
|
||||
return err
|
||||
}
|
||||
policy, err := repo.getLockoutPolicy(ctx, resourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info), lockoutPolicyToDomain(policy))
|
||||
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info))
|
||||
if isIgnoreUserInvalidPasswordError(err, request) {
|
||||
return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid")
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ func (c *Commands) AddDefaultLockoutPolicy(ctx context.Context, maxPasswordAttem
|
||||
}
|
||||
|
||||
func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domain.LockoutPolicy) (*domain.LockoutPolicy, error) {
|
||||
existingPolicy, err := c.defaultLockoutPolicyWriteModelByID(ctx)
|
||||
existingPolicy, err := defaultLockoutPolicyWriteModelByID(ctx, c.eventstore.FilterToQueryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -63,12 +63,12 @@ func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domai
|
||||
return writeModelToLockoutPolicy(&existingPolicy.LockoutPolicyWriteModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) defaultLockoutPolicyWriteModelByID(ctx context.Context) (policy *InstanceLockoutPolicyWriteModel, err error) {
|
||||
func defaultLockoutPolicyWriteModelByID(ctx context.Context, reducer func(ctx context.Context, r eventstore.QueryReducer) error) (policy *InstanceLockoutPolicyWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel := NewInstanceLockoutPolicyWriteModel(ctx)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
err = reducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -12,7 +13,7 @@ func (c *Commands) AddLockoutPolicy(ctx context.Context, resourceOwner string, p
|
||||
if resourceOwner == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "Org-8fJif", "Errors.ResourceOwnerMissing")
|
||||
}
|
||||
addedPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner)
|
||||
addedPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -42,7 +43,7 @@ func (c *Commands) ChangeLockoutPolicy(ctx context.Context, resourceOwner string
|
||||
if resourceOwner == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "Org-3J9fs", "Errors.ResourceOwnerMissing")
|
||||
}
|
||||
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner)
|
||||
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -71,7 +72,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
|
||||
if orgID == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "Org-4J9fs", "Errors.ResourceOwnerMissing")
|
||||
}
|
||||
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
|
||||
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -93,7 +94,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
|
||||
}
|
||||
|
||||
func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string) (*org.LockoutPolicyRemovedEvent, error) {
|
||||
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
|
||||
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -104,24 +105,24 @@ func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string
|
||||
return org.NewLockoutPolicyRemovedEvent(ctx, orgAgg), nil
|
||||
}
|
||||
|
||||
func (c *Commands) orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string) (*OrgLockoutPolicyWriteModel, error) {
|
||||
func orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*OrgLockoutPolicyWriteModel, error) {
|
||||
policy := NewOrgLockoutPolicyWriteModel(orgID)
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, policy)
|
||||
err := queryReducer(ctx, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (c *Commands) getLockoutPolicy(ctx context.Context, orgID string) (*domain.LockoutPolicy, error) {
|
||||
orgWm, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
|
||||
func getLockoutPolicy(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*domain.LockoutPolicy, error) {
|
||||
orgWm, err := orgLockoutPolicyWriteModelByID(ctx, orgID, queryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if orgWm.State == domain.PolicyStateActive {
|
||||
return writeModelToLockoutPolicy(&orgWm.LockoutPolicyWriteModel), nil
|
||||
}
|
||||
instanceWm, err := c.defaultLockoutPolicyWriteModelByID(ctx)
|
||||
instanceWm, err := defaultLockoutPolicyWriteModelByID(ctx, queryReducer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
@ -17,21 +18,18 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SessionCommand func(ctx context.Context, cmd *SessionCommands) error
|
||||
type SessionCommand func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error)
|
||||
|
||||
type SessionCommands struct {
|
||||
sessionCommands []SessionCommand
|
||||
|
||||
sessionWriteModel *SessionWriteModel
|
||||
passwordWriteModel *HumanPasswordWriteModel
|
||||
intentWriteModel *IDPIntentWriteModel
|
||||
totpWriteModel *HumanTOTPWriteModel
|
||||
eventstore *eventstore.Eventstore
|
||||
eventCommands []eventstore.Command
|
||||
sessionWriteModel *SessionWriteModel
|
||||
intentWriteModel *IDPIntentWriteModel
|
||||
eventstore *eventstore.Eventstore
|
||||
eventCommands []eventstore.Command
|
||||
|
||||
hasher *crypto.Hasher
|
||||
intentAlg crypto.EncryptionAlgorithm
|
||||
@ -59,114 +57,92 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
|
||||
|
||||
// CheckUser defines a user check to be executed for a session update
|
||||
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
|
||||
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
|
||||
}
|
||||
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
|
||||
return nil, cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
|
||||
}
|
||||
}
|
||||
|
||||
// CheckPassword defines a password check to be executed for a session update
|
||||
func CheckPassword(password string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
|
||||
}
|
||||
cmd.passwordWriteModel = NewHumanPasswordWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.passwordWriteModel)
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
commands, err := checkPassword(ctx, cmd.sessionWriteModel.UserID, password, cmd.eventstore, cmd.hasher, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
return commands, err
|
||||
}
|
||||
if cmd.passwordWriteModel.UserState == domain.UserStateUnspecified || cmd.passwordWriteModel.UserState == domain.UserStateDeleted {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4b3", "Errors.User.NotFound")
|
||||
}
|
||||
|
||||
if cmd.passwordWriteModel.EncodedHash == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-WEf3t", "Errors.User.Password.NotSet")
|
||||
}
|
||||
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
|
||||
updated, err := cmd.hasher.Verify(cmd.passwordWriteModel.EncodedHash, password)
|
||||
spanPasswordComparison.EndWithError(err)
|
||||
if err != nil {
|
||||
//TODO: maybe we want to reset the session in the future https://github.com/zitadel/zitadel/issues/5807
|
||||
return zerrors.ThrowInvalidArgument(err, "COMMAND-SAF3g", "Errors.User.Password.Invalid")
|
||||
}
|
||||
if updated != "" {
|
||||
cmd.eventCommands = append(cmd.eventCommands, user.NewHumanPasswordHashUpdatedEvent(ctx, UserAggregateFromWriteModel(&cmd.passwordWriteModel.WriteModel), updated))
|
||||
}
|
||||
|
||||
cmd.eventCommands = append(cmd.eventCommands, commands...)
|
||||
cmd.PasswordChecked(ctx, cmd.now())
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CheckIntent defines a check for a succeeded intent to be executed for a session update
|
||||
func CheckIntent(intentID, token string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing")
|
||||
}
|
||||
if err := crypto.CheckToken(cmd.intentAlg, token, intentID); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
cmd.intentWriteModel = NewIDPIntentWriteModel(intentID, "")
|
||||
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.intentWriteModel)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if cmd.intentWriteModel.State != domain.IDPIntentStateSucceeded {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded")
|
||||
}
|
||||
if cmd.intentWriteModel.UserID != "" {
|
||||
if cmd.intentWriteModel.UserID != cmd.sessionWriteModel.UserID {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
|
||||
}
|
||||
} else {
|
||||
linkWriteModel := NewUserIDPLinkWriteModel(cmd.sessionWriteModel.UserID, cmd.intentWriteModel.IDPID, cmd.intentWriteModel.IDPUserID, cmd.sessionWriteModel.UserResourceOwner)
|
||||
err := cmd.eventstore.FilterToQueryReducer(ctx, linkWriteModel)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if linkWriteModel.State != domain.UserIDPLinkStateActive {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
|
||||
}
|
||||
}
|
||||
cmd.IntentChecked(ctx, cmd.now())
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func CheckTOTP(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing")
|
||||
}
|
||||
cmd.totpWriteModel = NewHumanTOTPWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
err = cmd.eventstore.FilterToQueryReducer(ctx, cmd.totpWriteModel)
|
||||
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
|
||||
commands, err := checkTOTP(
|
||||
ctx,
|
||||
cmd.sessionWriteModel.UserID,
|
||||
"",
|
||||
code,
|
||||
cmd.eventstore.FilterToQueryReducer,
|
||||
cmd.totpAlg,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.totpWriteModel.State != domain.MFAStateReady {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
err = domain.VerifyTOTP(code, cmd.totpWriteModel.Secret, cmd.totpAlg)
|
||||
if err != nil {
|
||||
return err
|
||||
return commands, err
|
||||
}
|
||||
cmd.eventCommands = append(cmd.eventCommands, commands...)
|
||||
cmd.TOTPChecked(ctx, cmd.now())
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Exec will execute the commands specified and returns an error on the first occurrence
|
||||
func (s *SessionCommands) Exec(ctx context.Context) error {
|
||||
// Exec will execute the commands specified and returns an error on the first occurrence.
|
||||
// In case of an error there might be specific commands returned, e.g. a failed pw check will have to be stored.
|
||||
func (s *SessionCommands) Exec(ctx context.Context) ([]eventstore.Command, error) {
|
||||
for _, cmd := range s.sessionCommands {
|
||||
if err := cmd(ctx, s); err != nil {
|
||||
return err
|
||||
if cmds, err := cmd(ctx, s); err != nil {
|
||||
return cmds, err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent) {
|
||||
@ -360,8 +336,11 @@ func (c *Commands) updateSession(ctx context.Context, checks *SessionCommands, m
|
||||
if err = checks.sessionWriteModel.CheckNotInvalidated(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := checks.Exec(ctx); err != nil {
|
||||
// TODO: how to handle failed checks (e.g. pw wrong) https://github.com/zitadel/zitadel/issues/5807
|
||||
if cmds, err := checks.Exec(ctx); err != nil {
|
||||
if len(cmds) > 0 {
|
||||
_, pushErr := c.eventstore.Push(ctx, cmds...)
|
||||
logging.OnError(pushErr).Error("unable to store check failures")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
checks.ChangeMetadata(ctx, metadata)
|
||||
|
@ -6,9 +6,10 @@ import (
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@ -21,26 +22,26 @@ func (c *Commands) CreateOTPSMSChallenge() SessionCommand {
|
||||
}
|
||||
|
||||
func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
|
||||
}
|
||||
writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if !writeModel.OTPAdded() {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if returnCode {
|
||||
*dst = code.Plain
|
||||
}
|
||||
cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode)
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,26 +75,26 @@ func (c *Commands) CreateOTPEmailChallenge() SessionCommand {
|
||||
}
|
||||
|
||||
func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
|
||||
}
|
||||
writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if !writeModel.OTPAdded() {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if returnCode {
|
||||
*dst = code.Plain
|
||||
}
|
||||
cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl)
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,37 +113,57 @@ func (c *Commands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner st
|
||||
}
|
||||
|
||||
func CheckOTPSMS(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing")
|
||||
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
|
||||
writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
|
||||
otpWriteModel := NewHumanOTPSMSCodeWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// explicitly set the challenge from the session write model since the code write model will only check user events
|
||||
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPSMSCodeChallenge
|
||||
return otpWriteModel, nil
|
||||
}
|
||||
challenge := cmd.sessionWriteModel.OTPSMSCodeChallenge
|
||||
if challenge == nil {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound")
|
||||
succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPSMSCheckSucceededEvent(ctx, aggregate, nil)
|
||||
}
|
||||
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
|
||||
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, nil)
|
||||
}
|
||||
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
|
||||
if err != nil {
|
||||
return err
|
||||
return commands, err
|
||||
}
|
||||
cmd.eventCommands = append(cmd.eventCommands, commands...)
|
||||
cmd.OTPSMSChecked(ctx, cmd.now())
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func CheckOTPEmail(code string) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) (err error) {
|
||||
if cmd.sessionWriteModel.UserID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing")
|
||||
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
|
||||
writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
|
||||
otpWriteModel := NewHumanOTPEmailCodeWriteModel(cmd.sessionWriteModel.UserID, "")
|
||||
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// explicitly set the challenge from the session write model since the code write model will only check user events
|
||||
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPEmailCodeChallenge
|
||||
return otpWriteModel, nil
|
||||
}
|
||||
challenge := cmd.sessionWriteModel.OTPEmailCodeChallenge
|
||||
if challenge == nil {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound")
|
||||
succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPEmailCheckSucceededEvent(ctx, aggregate, nil)
|
||||
}
|
||||
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
|
||||
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, nil)
|
||||
}
|
||||
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
|
||||
if err != nil {
|
||||
return err
|
||||
return commands, err
|
||||
}
|
||||
cmd.eventCommands = append(cmd.eventCommands, commands...)
|
||||
cmd.OTPEmailChecked(ctx, cmd.now())
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -110,8 +111,9 @@ func TestCommands_CreateOTPSMSChallengeReturnCode(t *testing.T) {
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Empty(t, gotCmds)
|
||||
assert.Equal(t, tt.res.returnCode, dst)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
@ -210,8 +212,9 @@ func TestCommands_CreateOTPSMSChallenge(t *testing.T) {
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Empty(t, gotCmds)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
@ -410,8 +413,9 @@ func TestCommands_CreateOTPEmailChallengeURLTemplate(t *testing.T) {
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err = cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Empty(t, gotCmds)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
@ -511,8 +515,9 @@ func TestCommands_CreateOTPEmailChallengeReturnCode(t *testing.T) {
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Empty(t, gotCmds)
|
||||
assert.Equal(t, tt.res.returnCode, dst)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
@ -611,8 +616,9 @@ func TestCommands_CreateOTPEmailChallenge(t *testing.T) {
|
||||
now: time.Now,
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Empty(t, gotCmds)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
@ -701,8 +707,9 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
errorCommands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -720,13 +727,43 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing"),
|
||||
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
},
|
||||
args: args{},
|
||||
res: res{
|
||||
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not set up",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
userID: "userID",
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing challenge",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: nil,
|
||||
},
|
||||
@ -734,14 +771,26 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound"),
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
0, 0, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
@ -759,13 +808,61 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
errorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code, locked",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
0, 1, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow.Add(-10 * time.Minute),
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
errorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
@ -783,12 +880,44 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
user.NewHumanOTPSMSCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok, locked in the meantime",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(
|
||||
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow,
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -811,8 +940,9 @@ func TestCheckOTPSMS(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.errorCommands, gotCmds)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
@ -829,8 +959,9 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
err error
|
||||
commands []eventstore.Command
|
||||
errorCommands []eventstore.Command
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -848,13 +979,43 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing"),
|
||||
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
},
|
||||
args: args{},
|
||||
res: res{
|
||||
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not set up",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
userID: "userID",
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing challenge",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: nil,
|
||||
},
|
||||
@ -862,14 +1023,26 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound"),
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
0, 0, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
@ -887,13 +1060,61 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
errorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code, locked",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
0, 1, false,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow.Add(-10 * time.Minute),
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
|
||||
errorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
userID: "userID",
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
@ -911,12 +1132,44 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
commands: []eventstore.Command{
|
||||
user.NewHumanOTPEmailCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check ok, locked in the meantime",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
|
||||
),
|
||||
expectFilter(
|
||||
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
|
||||
),
|
||||
),
|
||||
userID: "userID",
|
||||
otpCodeChallenge: &OTPCode{
|
||||
Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
Expiry: 5 * time.Minute,
|
||||
CreationDate: testNow,
|
||||
},
|
||||
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args: args{
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -939,8 +1192,9 @@ func TestCheckOTPEmail(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err := cmd(context.Background(), cmds)
|
||||
gotCmds, err := cmd(context.Background(), cmds)
|
||||
assert.ErrorIs(t, err, tt.res.err)
|
||||
assert.Equal(t, tt.res.errorCommands, gotCmds)
|
||||
assert.Equal(t, tt.res.commands, cmds.eventCommands)
|
||||
})
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpintent"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -430,8 +431,8 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return zerrors.ThrowInternal(nil, "id", "check failed")
|
||||
func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
return nil, zerrors.ThrowInternal(nil, "id", "check failed")
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -525,6 +526,55 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"set user, invalid password",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
|
||||
"username", "", "", "", "", language.English, domain.GenderUnspecified, "", false),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanPasswordChangedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
|
||||
"$plain$x$password", false, ""),
|
||||
),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate, 0, 0, false),
|
||||
),
|
||||
expectPush(
|
||||
user.NewHumanPasswordCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instance1", "", ""),
|
||||
checks: &SessionCommands{
|
||||
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
|
||||
sessionCommands: []SessionCommand{
|
||||
CheckUser("userID", "org1", &language.Afrikaans),
|
||||
CheckPassword("invalid password"),
|
||||
},
|
||||
createToken: func(sessionID string) (string, string, error) {
|
||||
return "tokenID",
|
||||
"token",
|
||||
nil
|
||||
},
|
||||
hasher: mockPasswordHasher("x"),
|
||||
now: func() time.Time {
|
||||
return testNow
|
||||
},
|
||||
},
|
||||
metadata: map[string][]byte{
|
||||
"key": []byte("value"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-3M0fs", "Errors.User.Password.Invalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"set user, password, metadata and token",
|
||||
fields{
|
||||
@ -539,10 +589,12 @@ func TestCommands_updateSession(t *testing.T) {
|
||||
"$plain$x$password", false, ""),
|
||||
),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectPush(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow, &language.Afrikaans,
|
||||
),
|
||||
user.NewHumanPasswordCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow,
|
||||
),
|
||||
@ -872,6 +924,7 @@ func TestCheckTOTP(t *testing.T) {
|
||||
|
||||
sessAgg := &session.NewAggregate("session1", "instance1").Aggregate
|
||||
userAgg := &user.NewAggregate("user1", "org1").Aggregate
|
||||
orgAgg := &org.NewAggregate("org1").Aggregate
|
||||
|
||||
code, err := totp.GenerateCode(key.Secret(), testNow)
|
||||
require.NoError(t, err)
|
||||
@ -886,6 +939,7 @@ func TestCheckTOTP(t *testing.T) {
|
||||
code string
|
||||
fields fields
|
||||
wantEventCommands []eventstore.Command
|
||||
wantErrorCommands []eventstore.Command
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
@ -897,7 +951,7 @@ func TestCheckTOTP(t *testing.T) {
|
||||
},
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing"),
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing"),
|
||||
},
|
||||
{
|
||||
name: "filter error",
|
||||
@ -931,7 +985,7 @@ func TestCheckTOTP(t *testing.T) {
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady"),
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady"),
|
||||
},
|
||||
{
|
||||
name: "otp verify error",
|
||||
@ -951,8 +1005,45 @@ func TestCheckTOTP(t *testing.T) {
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 0, 0, false)),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErrorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
|
||||
},
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
|
||||
},
|
||||
{
|
||||
name: "otp verify error, locked",
|
||||
code: "foobar",
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
expectFilter(
|
||||
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 1, 1, false)),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErrorCommands: []eventstore.Command{
|
||||
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
|
||||
user.NewUserLockedEvent(ctx, userAgg),
|
||||
},
|
||||
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
|
||||
},
|
||||
{
|
||||
@ -973,12 +1064,39 @@ func TestCheckTOTP(t *testing.T) {
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
expectFilter(), // recheck
|
||||
),
|
||||
},
|
||||
wantEventCommands: []eventstore.Command{
|
||||
user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, nil),
|
||||
session.NewTOTPCheckedEvent(ctx, sessAgg, testNow),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok, but locked in the meantime",
|
||||
code: code,
|
||||
fields: fields{
|
||||
sessionWriteModel: &SessionWriteModel{
|
||||
UserID: "user1",
|
||||
UserCheckedAt: testNow,
|
||||
aggregate: sessAgg,
|
||||
},
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
user.NewUserLockedEvent(ctx, userAgg),
|
||||
),
|
||||
),
|
||||
},
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@ -988,8 +1106,9 @@ func TestCheckTOTP(t *testing.T) {
|
||||
totpAlg: cryptoAlg,
|
||||
now: func() time.Time { return testNow },
|
||||
}
|
||||
err := CheckTOTP(tt.code)(ctx, cmd)
|
||||
gotCmds, err := CheckTOTP(tt.code)(ctx, cmd)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantErrorCommands, gotCmds)
|
||||
assert.Equal(t, tt.wantEventCommands, cmd.eventCommands)
|
||||
})
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@ -41,49 +42,49 @@ func (s *SessionCommands) getHumanWebAuthNTokenReadModel(ctx context.Context, us
|
||||
}
|
||||
|
||||
func (c *Commands) CreateWebAuthNChallenge(userVerification domain.UserVerificationRequirement, rpid string, dst json.Unmarshaler) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
humanPasskeys, err := cmd.getHumanWebAuthNTokens(ctx, userVerification)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, rpid, humanPasskeys.tokens...)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil {
|
||||
return zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
|
||||
return nil, zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
|
||||
}
|
||||
|
||||
cmd.WebAuthNChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification, rpid)
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) CheckWebAuthN(credentialAssertionData json.Marshaler) SessionCommand {
|
||||
return func(ctx context.Context, cmd *SessionCommands) error {
|
||||
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
|
||||
credentialAssertionData, err := json.Marshal(credentialAssertionData)
|
||||
if err != nil {
|
||||
return zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
|
||||
return nil, zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
|
||||
}
|
||||
challenge := cmd.sessionWriteModel.WebAuthNChallenge
|
||||
if challenge == nil {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
|
||||
}
|
||||
webAuthNTokens, err := cmd.getHumanWebAuthNTokens(ctx, challenge.UserVerification)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
webAuthN := challenge.WebAuthNLogin(webAuthNTokens.human, credentialAssertionData)
|
||||
|
||||
credential, err := c.webauthnConfig.FinishLogin(ctx, webAuthNTokens.human, webAuthN, credentialAssertionData, webAuthNTokens.tokens...)
|
||||
if err != nil && (credential == nil || credential.ID == nil) {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
_, token := domain.GetTokenByKeyID(webAuthNTokens.tokens, credential.ID)
|
||||
if token == nil {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
|
||||
}
|
||||
cmd.WebAuthNChecked(ctx, cmd.now(), token.WebAuthNTokenID, credential.Authenticator.SignCount, credential.Flags.UserVerified)
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
@ -161,48 +161,67 @@ func (c *Commands) HumanCheckMFATOTPSetup(ctx context.Context, userID, code, use
|
||||
}
|
||||
|
||||
func (c *Commands) HumanCheckMFATOTP(ctx context.Context, userID, code, resourceOwner string, authRequest *domain.AuthRequest) error {
|
||||
commands, err := checkTOTP(
|
||||
ctx,
|
||||
userID,
|
||||
resourceOwner,
|
||||
code,
|
||||
c.eventstore.FilterToQueryReducer,
|
||||
c.multifactors.OTP.CryptoMFA,
|
||||
authRequestDomainToAuthRequestInfo(authRequest),
|
||||
)
|
||||
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.OnError(pushErr).Error("error create password check failed event")
|
||||
return err
|
||||
}
|
||||
|
||||
func checkTOTP(
|
||||
ctx context.Context,
|
||||
userID, resourceOwner, code string,
|
||||
queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
|
||||
alg crypto.EncryptionAlgorithm,
|
||||
optionalAuthRequestInfo *user.AuthRequestInfo,
|
||||
) ([]eventstore.Command, error) {
|
||||
if userID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing")
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing")
|
||||
}
|
||||
existingOTP, err := c.totpWriteModelByID(ctx, userID, resourceOwner)
|
||||
existingOTP := NewHumanTOTPWriteModel(userID, resourceOwner)
|
||||
err := queryReducer(ctx, existingOTP)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if existingOTP.State != domain.MFAStateReady {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel)
|
||||
verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, c.multifactors.OTP.CryptoMFA)
|
||||
verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, alg)
|
||||
|
||||
// recheck for additional events (failed OTP checks or locks)
|
||||
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP)
|
||||
recheckErr := queryReducer(ctx, existingOTP)
|
||||
if recheckErr != nil {
|
||||
return recheckErr
|
||||
return nil, recheckErr
|
||||
}
|
||||
if existingOTP.UserLocked {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked")
|
||||
}
|
||||
|
||||
// the OTP check succeeded and the user was not locked in the meantime
|
||||
if verifyErr == nil {
|
||||
_, err = c.eventstore.Push(ctx, user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
return err
|
||||
return []eventstore.Command{user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo)}, nil
|
||||
}
|
||||
|
||||
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
|
||||
commands := make([]eventstore.Command, 0, 2)
|
||||
commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner)
|
||||
commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
|
||||
lockoutPolicy, err := getLockoutPolicy(ctx, existingOTP.ResourceOwner, queryReducer)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount+1 >= lockoutPolicy.MaxOTPAttempts {
|
||||
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
|
||||
}
|
||||
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.OnError(pushErr).Error("error create password check failed event")
|
||||
return verifyErr
|
||||
return commands, verifyErr
|
||||
}
|
||||
|
||||
func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner string) (*domain.ObjectDetails, error) {
|
||||
@ -342,16 +361,23 @@ func (c *Commands) HumanCheckOTPSMS(ctx context.Context, userID, code, resourceO
|
||||
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
|
||||
}
|
||||
return c.humanCheckOTP(
|
||||
commands, err := checkOTP(
|
||||
ctx,
|
||||
userID,
|
||||
code,
|
||||
resourceOwner,
|
||||
authRequest,
|
||||
writeModel,
|
||||
c.eventstore.FilterToQueryReducer,
|
||||
c.userEncryption,
|
||||
succeededEvent,
|
||||
failedEvent,
|
||||
)
|
||||
if len(commands) > 0 {
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// AddHumanOTPEmail adds the OTP Email factor to a user.
|
||||
@ -467,16 +493,23 @@ func (c *Commands) HumanCheckOTPEmail(ctx context.Context, userID, code, resourc
|
||||
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
|
||||
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
|
||||
}
|
||||
return c.humanCheckOTP(
|
||||
commands, err := checkOTP(
|
||||
ctx,
|
||||
userID,
|
||||
code,
|
||||
resourceOwner,
|
||||
authRequest,
|
||||
writeModel,
|
||||
c.eventstore.FilterToQueryReducer,
|
||||
c.userEncryption,
|
||||
succeededEvent,
|
||||
failedEvent,
|
||||
)
|
||||
if len(commands) > 0 {
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// sendHumanOTP creates a code for a registered mechanism (sms / email), which is used for a check (during login)
|
||||
@ -534,62 +567,57 @@ func (c *Commands) humanOTPSent(
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) humanCheckOTP(
|
||||
func checkOTP(
|
||||
ctx context.Context,
|
||||
userID, code, resourceOwner string,
|
||||
authRequest *domain.AuthRequest,
|
||||
writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error),
|
||||
checkSucceededEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
|
||||
checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
|
||||
) error {
|
||||
queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
|
||||
alg crypto.EncryptionAlgorithm,
|
||||
checkSucceededEvent, checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
|
||||
) ([]eventstore.Command, error) {
|
||||
if userID == "" {
|
||||
return zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing")
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing")
|
||||
}
|
||||
if code == "" {
|
||||
return zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty")
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty")
|
||||
}
|
||||
existingOTP, err := writeModelByID(ctx, userID, resourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if !existingOTP.OTPAdded() {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady")
|
||||
}
|
||||
if existingOTP.Code() == nil {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound")
|
||||
}
|
||||
userAgg := &user.NewAggregate(userID, existingOTP.ResourceOwner()).Aggregate
|
||||
verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, c.userEncryption)
|
||||
verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, alg)
|
||||
|
||||
// recheck for additional events (failed OTP checks or locks)
|
||||
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP)
|
||||
recheckErr := queryReducer(ctx, existingOTP)
|
||||
if recheckErr != nil {
|
||||
return recheckErr
|
||||
return nil, recheckErr
|
||||
}
|
||||
if existingOTP.UserLocked() {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked")
|
||||
}
|
||||
|
||||
// the OTP check succeeded and the user was not locked in the meantime
|
||||
if verifyErr == nil {
|
||||
_, err = c.eventstore.Push(ctx, checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
return err
|
||||
return []eventstore.Command{checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))}, nil
|
||||
}
|
||||
|
||||
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
|
||||
commands := make([]eventstore.Command, 0, 2)
|
||||
commands = append(commands, checkFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
|
||||
lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, existingOTP.ResourceOwner(), queryReducer)
|
||||
logging.OnError(lockoutErr).Error("unable to get lockout policy")
|
||||
if lockoutPolicy != nil && lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
|
||||
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
|
||||
}
|
||||
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
|
||||
return verifyErr
|
||||
return commands, verifyErr
|
||||
}
|
||||
|
||||
func (c *Commands) totpWriteModelByID(ctx context.Context, userID, resourceOwner string) (writeModel *HumanTOTPWriteModel, err error) {
|
||||
|
@ -157,24 +157,31 @@ func (wm *HumanOTPSMSWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
type HumanOTPSMSCodeWriteModel struct {
|
||||
*HumanOTPSMSWriteModel
|
||||
|
||||
code *crypto.CryptoValue
|
||||
codeCreationDate time.Time
|
||||
codeExpiry time.Duration
|
||||
otpCode *OTPCode
|
||||
|
||||
checkFailedCount uint64
|
||||
userLocked bool
|
||||
}
|
||||
|
||||
func (wm *HumanOTPSMSCodeWriteModel) CodeCreationDate() time.Time {
|
||||
return wm.codeCreationDate
|
||||
if wm.otpCode == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return wm.otpCode.CreationDate
|
||||
}
|
||||
|
||||
func (wm *HumanOTPSMSCodeWriteModel) CodeExpiry() time.Duration {
|
||||
return wm.codeExpiry
|
||||
if wm.otpCode == nil {
|
||||
return 0
|
||||
}
|
||||
return wm.otpCode.Expiry
|
||||
}
|
||||
|
||||
func (wm *HumanOTPSMSCodeWriteModel) Code() *crypto.CryptoValue {
|
||||
return wm.code
|
||||
if wm.otpCode == nil {
|
||||
return nil
|
||||
}
|
||||
return wm.otpCode.Code
|
||||
}
|
||||
|
||||
func (wm *HumanOTPSMSCodeWriteModel) CheckFailedCount() uint64 {
|
||||
@ -195,9 +202,11 @@ func (wm *HumanOTPSMSCodeWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *user.HumanOTPSMSCodeAddedEvent:
|
||||
wm.code = e.Code
|
||||
wm.codeCreationDate = e.CreationDate()
|
||||
wm.codeExpiry = e.Expiry
|
||||
wm.otpCode = &OTPCode{
|
||||
Code: e.Code,
|
||||
CreationDate: e.CreationDate(),
|
||||
Expiry: e.Expiry,
|
||||
}
|
||||
case *user.HumanOTPSMSCheckSucceededEvent:
|
||||
wm.checkFailedCount = 0
|
||||
case *user.HumanOTPSMSCheckFailedEvent:
|
||||
@ -300,24 +309,31 @@ func (wm *HumanOTPEmailWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
type HumanOTPEmailCodeWriteModel struct {
|
||||
*HumanOTPEmailWriteModel
|
||||
|
||||
code *crypto.CryptoValue
|
||||
codeCreationDate time.Time
|
||||
codeExpiry time.Duration
|
||||
otpCode *OTPCode
|
||||
|
||||
checkFailedCount uint64
|
||||
userLocked bool
|
||||
}
|
||||
|
||||
func (wm *HumanOTPEmailCodeWriteModel) CodeCreationDate() time.Time {
|
||||
return wm.codeCreationDate
|
||||
if wm.otpCode == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return wm.otpCode.CreationDate
|
||||
}
|
||||
|
||||
func (wm *HumanOTPEmailCodeWriteModel) CodeExpiry() time.Duration {
|
||||
return wm.codeExpiry
|
||||
if wm.otpCode == nil {
|
||||
return 0
|
||||
}
|
||||
return wm.otpCode.Expiry
|
||||
}
|
||||
|
||||
func (wm *HumanOTPEmailCodeWriteModel) Code() *crypto.CryptoValue {
|
||||
return wm.code
|
||||
if wm.otpCode == nil {
|
||||
return nil
|
||||
}
|
||||
return wm.otpCode.Code
|
||||
}
|
||||
|
||||
func (wm *HumanOTPEmailCodeWriteModel) CheckFailedCount() uint64 {
|
||||
@ -338,9 +354,11 @@ func (wm *HumanOTPEmailCodeWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *user.HumanOTPEmailCodeAddedEvent:
|
||||
wm.code = e.Code
|
||||
wm.codeCreationDate = e.CreationDate()
|
||||
wm.codeExpiry = e.Expiry
|
||||
wm.otpCode = &OTPCode{
|
||||
Code: e.Code,
|
||||
CreationDate: e.CreationDate(),
|
||||
Expiry: e.Expiry,
|
||||
}
|
||||
case *user.HumanOTPEmailCheckSucceededEvent:
|
||||
wm.checkFailedCount = 0
|
||||
case *user.HumanOTPEmailCheckFailedEvent:
|
||||
|
@ -295,8 +295,8 @@ func (c *Commands) PasswordChangeSent(ctx context.Context, orgID, userID string)
|
||||
return err
|
||||
}
|
||||
|
||||
// HumanCheckPassword check password for user with additional informations from authRequest
|
||||
func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest, lockoutPolicy *domain.LockoutPolicy) (err error) {
|
||||
// HumanCheckPassword check password for user with additional information from authRequest
|
||||
func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -314,56 +314,66 @@ func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, passwo
|
||||
if !loginPolicy.AllowUsernamePassword {
|
||||
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Dft32", "Errors.Org.LoginPolicy.UsernamePasswordNotAllowed")
|
||||
}
|
||||
|
||||
wm, err := c.passwordWriteModel(ctx, userID, orgID)
|
||||
if err != nil {
|
||||
commands, err := checkPassword(ctx, userID, password, c.eventstore, c.userPasswordHasher, authRequestDomainToAuthRequestInfo(authRequest))
|
||||
if len(commands) == 0 {
|
||||
return err
|
||||
}
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.OnError(pushErr).Error("error create password check failed event")
|
||||
return err
|
||||
}
|
||||
|
||||
if !isUserStateExists(wm.UserState) {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound")
|
||||
func checkPassword(ctx context.Context, userID, password string, es *eventstore.Eventstore, hasher *crypto.Hasher, optionalAuthRequestInfo *user.AuthRequestInfo) ([]eventstore.Command, error) {
|
||||
if userID == "" {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
|
||||
}
|
||||
wm := NewHumanPasswordWriteModel(userID, "")
|
||||
err := es.FilterToQueryReducer(ctx, wm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !wm.UserState.Exists() {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound")
|
||||
}
|
||||
if wm.UserState == domain.UserStateLocked {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked")
|
||||
}
|
||||
if wm.EncodedHash == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet")
|
||||
}
|
||||
|
||||
userAgg := UserAggregateFromWriteModel(&wm.WriteModel)
|
||||
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
|
||||
updated, err := c.userPasswordHasher.Verify(wm.EncodedHash, password)
|
||||
updated, err := hasher.Verify(wm.EncodedHash, password)
|
||||
spanPasswordComparison.EndWithError(err)
|
||||
err = convertPasswapErr(err)
|
||||
commands := make([]eventstore.Command, 0, 2)
|
||||
|
||||
// recheck for additional events (failed password checks or locks)
|
||||
recheckErr := c.eventstore.FilterToQueryReducer(ctx, wm)
|
||||
recheckErr := es.FilterToQueryReducer(ctx, wm)
|
||||
if recheckErr != nil {
|
||||
return recheckErr
|
||||
return nil, recheckErr
|
||||
}
|
||||
if wm.UserState == domain.UserStateLocked {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked")
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked")
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo))
|
||||
if updated != "" {
|
||||
commands = append(commands, user.NewHumanPasswordHashUpdatedEvent(ctx, userAgg, updated))
|
||||
}
|
||||
_, err = c.eventstore.Push(ctx, commands...)
|
||||
return err
|
||||
return commands, nil
|
||||
}
|
||||
|
||||
commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
|
||||
if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 {
|
||||
if wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts {
|
||||
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
|
||||
}
|
||||
commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
|
||||
|
||||
lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, wm.ResourceOwner, es.FilterToQueryReducer)
|
||||
logging.OnError(lockoutErr).Error("unable to get lockout policy")
|
||||
if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 && wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts {
|
||||
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
|
||||
}
|
||||
_, pushErr := c.eventstore.Push(ctx, commands...)
|
||||
logging.OnError(pushErr).Error("error create password check failed event")
|
||||
return err
|
||||
return commands, err
|
||||
}
|
||||
|
||||
func (c *Commands) passwordWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *HumanPasswordWriteModel, err error) {
|
||||
|
@ -1456,7 +1456,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
resourceOwner string
|
||||
password string
|
||||
authReq *domain.AuthRequest
|
||||
lockoutPolicy *domain.LockoutPolicy
|
||||
}
|
||||
type res struct {
|
||||
err func(error) bool
|
||||
@ -1768,6 +1767,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
"")),
|
||||
),
|
||||
expectFilter(),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(),
|
||||
&org.NewAggregate("org1").Aggregate,
|
||||
0, 0, false,
|
||||
)),
|
||||
),
|
||||
expectPush(
|
||||
user.NewHumanPasswordCheckFailedEvent(context.Background(),
|
||||
&user.NewAggregate("user1", "org1").Aggregate,
|
||||
@ -1789,7 +1795,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
ID: "request1",
|
||||
AgentID: "agent1",
|
||||
},
|
||||
lockoutPolicy: &domain.LockoutPolicy{},
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
@ -1852,6 +1857,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
org.NewLockoutPolicyAddedEvent(context.Background(),
|
||||
&org.NewAggregate("org1").Aggregate,
|
||||
1, 1, false,
|
||||
)),
|
||||
),
|
||||
expectPush(
|
||||
user.NewHumanPasswordCheckFailedEvent(context.Background(),
|
||||
&user.NewAggregate("user1", "org1").Aggregate,
|
||||
@ -1876,10 +1888,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
ID: "request1",
|
||||
AgentID: "agent1",
|
||||
},
|
||||
lockoutPolicy: &domain.LockoutPolicy{
|
||||
MaxPasswordAttempts: 1,
|
||||
MaxOTPAttempts: 1,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsErrorInvalidArgument,
|
||||
@ -2230,7 +2238,7 @@ func TestCommandSide_CheckPassword(t *testing.T) {
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
userPasswordHasher: tt.fields.userPasswordHasher,
|
||||
}
|
||||
err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq, tt.args.lockoutPolicy)
|
||||
err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) {
|
||||
return resp.Location()
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) {
|
||||
name := gofakeit.Username()
|
||||
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
|
||||
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) {
|
||||
name = gofakeit.Username()
|
||||
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
|
||||
Name: name,
|
||||
UserName: name,
|
||||
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
return nil, "", "", "", err
|
||||
}
|
||||
secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{
|
||||
UserId: user.GetUserId(),
|
||||
UserId: machine.GetUserId(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
return nil, "", "", "", err
|
||||
}
|
||||
return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil
|
||||
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
|
||||
}
|
||||
|
||||
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) {
|
||||
name := gofakeit.Username()
|
||||
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
|
||||
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
|
||||
name = gofakeit.Username()
|
||||
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
|
||||
Name: name,
|
||||
UserName: name,
|
||||
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return nil, "", nil, err
|
||||
}
|
||||
keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
|
||||
UserId: user.GetUserId(),
|
||||
UserId: machine.GetUserId(),
|
||||
Type: authn.KeyType_KEY_TYPE_JSON,
|
||||
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
return nil, "", nil, err
|
||||
}
|
||||
return user.GetUserId(), keyResp.GetKeyDetails(), nil
|
||||
return machine, name, keyResp.GetKeyDetails(), nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user