Merge branch 'main' into next

This commit is contained in:
Livio Spring 2024-05-31 12:12:02 +02:00
commit 50e0e7d564
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
33 changed files with 1010 additions and 450 deletions

View File

@ -6,9 +6,11 @@
<mat-icon class="icon">info_outline</mat-icon> <mat-icon class="icon">info_outline</mat-icon>
</a> </a>
</div> </div>
<p class="sub cnsl-secondary-text max-width-description"> <p
{{ 'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate }} class="sub cnsl-secondary-text max-width-description"
</p> [innerHTML]="'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate"
></p>
<div class="projects-controls"> <div class="projects-controls">
<div class="project-toggle-group"> <div class="project-toggle-group">
<cnsl-nav-toggle <cnsl-nav-toggle

View File

@ -14,12 +14,15 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration"
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta" object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
@ -27,6 +30,7 @@ import (
var ( var (
CTX context.Context CTX context.Context
IAMOwnerCTX context.Context
Tester *integration.Tester Tester *integration.Tester
Client session.SessionServiceClient Client session.SessionServiceClient
User *user.AddHumanUserResponse User *user.AddHumanUserResponse
@ -44,6 +48,7 @@ func TestMain(m *testing.M) {
Client = Tester.Client.SessionV2 Client = Tester.Client.SessionV2
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
IAMOwnerCTX = Tester.WithAuthorization(ctx, integration.IAMOwner)
User = createFullUser(CTX) User = createFullUser(CTX)
DeactivatedUser = createDeactivatedUser(CTX) DeactivatedUser = createDeactivatedUser(CTX)
LockedUser = createLockedUser(CTX) LockedUser = createLockedUser(CTX)
@ -341,6 +346,48 @@ func TestServer_CreateSession(t *testing.T) {
} }
} }
func TestServer_CreateSession_lock_user(t *testing.T) {
// create a separate org so we don't interfere with any other test
org := Tester.CreateOrganization(IAMOwnerCTX,
fmt.Sprintf("TestServer_CreateSession_lock_user_%d", time.Now().UnixNano()),
fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()),
)
userID := org.CreatedAdmins[0].GetUserId()
Tester.SetUserPassword(IAMOwnerCTX, userID, integration.UserPassword, false)
// enable password lockout
maxAttempts := 2
ctxOrg := metadata.AppendToOutgoingContext(IAMOwnerCTX, "x-zitadel-orgid", org.GetOrganizationId())
_, err := Tester.Client.Mgmt.AddCustomLockoutPolicy(ctxOrg, &mgmt.AddCustomLockoutPolicyRequest{
MaxPasswordAttempts: uint32(maxAttempts),
})
require.NoError(t, err)
for i := 0; i <= maxAttempts; i++ {
_, err := Client.CreateSession(CTX, &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: userID,
},
},
Password: &session.CheckPassword{
Password: "invalid",
},
},
})
assert.Error(t, err)
statusCode := status.Code(err)
expectedCode := codes.InvalidArgument
// as soon as we hit the limit the user is locked and following request will
// already deny any check with a precondition failed since the user is locked
if i >= maxAttempts {
expectedCode = codes.FailedPrecondition
}
assert.Equal(t, expectedCode, statusCode)
}
}
func TestServer_CreateSession_webauthn(t *testing.T) { func TestServer_CreateSession_webauthn(t *testing.T) {
// create new session with user and request the webauthn challenge // create new session with user and request the webauthn challenge
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{

View File

@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest)
if err != nil { if err != nil {
return "", err return "", err
} }
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion) resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
op.AuthRequestError(w, r, authReq, err, authorizer) op.AuthRequestError(w, r, authReq, err, authorizer)
return err return err
} }
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion) resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil { if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer) op.AuthRequestError(w, r, authReq, err, authorizer)
return err return err

View File

@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail // callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err) require.NoError(t, err)
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail // callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
} }
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual refresh grant // test actual refresh grant
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken) newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, newTokens, true) assertTokens(t, newTokens, true)
// auth time must still be the initial // auth time must still be the initial
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// refresh with an old refresh_token must fail // refresh with an old refresh_token must fail
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token // revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token") err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token")
@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token // revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token") err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token")
@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token -> invalidates also access token // revoke refresh token -> invalidates also access token
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token") err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token")
@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token even with a wrong hint // revoke refresh token even with a wrong hint
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token") err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token")
@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// simulate second client (not part of the audience) trying to revoke the token // simulate second client (not part of the audience) trying to revoke the token
otherClientID, _ := createClient(t) otherClientID, _ := createClient(t)
@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail // userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail // userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider) _, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state") postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state")
require.NoError(t, err) require.NoError(t, err)
@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir
} }
} }
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) { func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) {
assert.Equal(t, userID, claims.Subject) assert.Equal(t, userID, claims.Subject)
assert.Equal(t, arm, claims.AuthenticationMethodsReferences) assert.Equal(t, arm, claims.AuthenticationMethodsReferences)
assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange) assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange)
assert.Equal(t, sessionID, claims.SessionID)
assert.Empty(t, claims.Name)
assert.Empty(t, claims.GivenName)
assert.Empty(t, claims.FamilyName)
assert.Empty(t, claims.PreferredUsername)
} }

View File

@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) {
tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI) tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual introspection // test actual introspection
introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken) introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken)
@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
}) })
} }
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -80,7 +81,11 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
// with active: false // with active: false
defer func() { defer func() {
if err != nil { if err != nil {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err) if zerrors.IsInternal(err) {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
} else {
s.getLogger(ctx).InfoContext(ctx, "oidc introspection", "err", err)
}
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
} }
}() }()
@ -99,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil { if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
return nil, err return nil, err
} }
userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true) userInfo, err := s.userInfo(
token.userID,
token.scope,
client.projectID,
client.projectRoleAssertion,
true,
true,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, false) assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works // make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works // make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID)
// make sure token works // make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))

View File

@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package,
for example the v2 code exchange and refresh token. for example the v2 code exchange and refresh token.
*/ */
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) { func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope) getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope)
getSigner := s.getSignerOnce() getSigner := s.getSignerOnce()
resp := &oidc.AccessTokenResponse{ resp := &oidc.AccessTokenResponse{
@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
// If the session does not have a token ID, it is an implicit ID-Token only response. // If the session does not have a token ID, it is an implicit ID-Token only response.
if session.TokenID != "" { if session.TokenID != "" {
if client.AccessTokenType() == op.AccessTokenTypeJWT { if client.AccessTokenType() == op.AccessTokenTypeJWT {
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner)
} else { } else {
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto) resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
} }
@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
} }
if slices.Contains(session.Scope, oidc.ScopeOpenID) { if slices.Contains(session.Scope, oidc.ScopeOpenID) {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor) resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
} }
return resp, err return resp, err
} }
@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc {
} }
// userInfoFunc is a getter function that allows add-hoc retrieval of a user. // userInfoFunc is a getter function that allows add-hoc retrieval of a user.
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error) type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error)
// getUserInfoOnce returns a function which retrieves userinfo from the database once. // getUserInfo returns a function which retrieves userinfo from the database once.
// Repeated calls of the returned function return the same results. // However, each time, role claims are asserted and also action flows will trigger.
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc { func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc {
var ( userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false)
once sync.Once return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) {
userInfo *oidc.UserInfo return userInfo(ctx, roleAssertion, triggerType)
err error
)
return func(ctx context.Context) (*oidc.UserInfo, error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
})
return userInfo, err
} }
} }
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx) userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation)
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 {
return uint64(time.Until(exp) / time.Second) return uint64(time.Until(exp) / time.Second)
} }
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) { func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx) userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
false, false,
) )
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
} }

View File

@ -4,6 +4,7 @@ package oidc_test
import ( import (
"testing" "testing"
"time"
"github.com/brianvoe/gofakeit/v6" "github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -18,10 +19,13 @@ import (
) )
func TestServer_ClientCredentialsExchange(t *testing.T) { func TestServer_ClientCredentialsExchange(t *testing.T) {
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err) require.NoError(t, err)
type claims struct { type claims struct {
name string
username string
updated time.Time
resourceOwnerID any resourceOwnerID any
resourceOwnerName any resourceOwnerName any
resourceOwnerPrimaryDomain any resourceOwnerPrimaryDomain any
@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
clientSecret: clientSecret, clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID}, scope: []string{oidc.ScopeOpenID},
}, },
{
name: "openid, profile, email",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: machine.GetDetails().GetChangeDate().AsTime(),
},
},
{ {
name: "org id and domain scope", name: "org id and domain scope",
clientID: clientID, clientID: clientID,
@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, tokens) require.NotNil(t, tokens)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
}) })
} }
} }

View File

@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce
if err != nil { if err != nil {
return nil, err return nil, err
} }
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)) return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} }
// codeExchangeV1 creates a v2 token from a v1 auth request. // codeExchangeV1 creates a v2 token from a v1 auth request.

View File

@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
} }
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode) session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
if err == nil { if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} }
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err) return nil, oidc.ErrSlowDown().WithParent(err)

View File

@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange. // Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation. // When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) { func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes) getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes)
getSigner := s.getSignerOnce() getSigner := s.getSignerOnce()
resp := &oidc.TokenExchangeResponse{ resp := &oidc.TokenExchangeResponse{
@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
resp.IssuedTokenType = oidc.AccessTokenType resp.IssuedTokenType = oidc.AccessTokenType
case oidc.JWTTokenType: case oidc.JWTTokenType:
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor) resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.TokenType = oidc.BearerToken resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.JWTTokenType resp.IssuedTokenType = oidc.JWTTokenType
case oidc.IDTokenType: case oidc.IDTokenType:
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.TokenType = TokenTypeNA resp.TokenType = TokenTypeNA
resp.IssuedTokenType = oidc.IDTokenType resp.IssuedTokenType = oidc.IDTokenType
@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
} }
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType { if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor) resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT(
ctx context.Context, ctx context.Context,
client *Client, client *Client,
getUserInfo userInfoFunc, getUserInfo userInfoFunc,
roleAssertion bool,
getSigner signerFunc, getSigner signerFunc,
userID, userID,
resourceOwner string, resourceOwner string,
@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT(
actor, actor,
slices.Contains(scope, oidc.ScopeOfflineAccess), slices.Contains(scope, oidc.ScopeOfflineAccess),
) )
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner) accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
if err != nil { if err != nil {
return "", "", 0, err return "", "", 0, err
} }

View File

@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
nil, nil,
false, false,
) )
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false)) return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
} }
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) { func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {

View File

@ -4,6 +4,7 @@ package oidc_test
import ( import (
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -16,10 +17,13 @@ import (
) )
func TestServer_JWTProfile(t *testing.T) { func TestServer_JWTProfile(t *testing.T) {
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX) user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
require.NoError(t, err) require.NoError(t, err)
type claims struct { type claims struct {
name string
username string
updated time.Time
resourceOwnerID any resourceOwnerID any
resourceOwnerName any resourceOwnerName any
resourceOwnerPrimaryDomain any resourceOwnerPrimaryDomain any
@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) {
keyData: keyData, keyData: keyData,
scope: []string{oidc.ScopeOpenID}, scope: []string{oidc.ScopeOpenID},
}, },
{
name: "openid, profile, email",
keyData: keyData,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: user.GetDetails().GetChangeDate().AsTime(),
},
},
{ {
name: "org id and domain scope", name: "org id and domain scope",
keyData: keyData, keyData: keyData,
@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) {
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope) provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err) require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider) userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID]) assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName]) assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain]) assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim]) assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
}) })
} }
} }

View File

@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker()) session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
if err == nil { if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) { } else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
// We try again for v1 tokens when we encountered specific parsing error // We try again for v1 tokens when we encountered specific parsing error
return s.refreshTokenV1(ctx, client, r) return s.refreshTokenV1(ctx, client, r)
@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
return nil, err return nil, err
} }
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion)) return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} }
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope. // refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.

View File

@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"slices" "slices"
"strings" "strings"
"sync"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/zitadel/logging" "github.com/zitadel/logging"
@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
} }
} }
userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false) userInfo, err := s.userInfo(
token.userID,
token.scope,
projectID,
assertion,
true,
false,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
// The returned UserInfo contains standard and reserved claims, documented // The returned UserInfo contains standard and reserved claims, documented
// here: https://zitadel.com/docs/apis/openidoauth/claims. // here: https://zitadel.com/docs/apis/openidoauth/claims.
// //
// User information is only retrieved once from the database.
// However, each time, role claims are asserted and also action flows will trigger.
//
// projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested. // projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested.
// projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope. // projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope.
// roleAssertion decides whether the roles will be returned (in the token or response)
// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned
// //
// currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope. // currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope.
// It should be set in cases where the client doesn't need to know roles outside its own project, // It should be set in cases where the client doesn't need to know roles outside its own project,
// for example an introspection client. // for example an introspection client.
func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) { func (s *Server) userInfo(
ctx, span := tracing.NewSpan(ctx) userID string,
defer func() { span.EndWithError(err) }() scope []string,
projectID string,
projectRoleAssertion, userInfoAssertion, currentProjectOnly bool,
) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
var (
once sync.Once
userInfo *oidc.UserInfo
qu *query.OIDCUserInfo
roleAudience, requestedRoles []string
)
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly) roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience) qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil { if err != nil {
return nil, err return
}
userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx))
})
userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo)
return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType)
} }
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
} }
// prepareRoles scans the requested scopes and builds the requested roles // prepareRoles scans the requested scopes and builds the requested roles
@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project
return roleAudience, requestedRoles return roleAudience, requestedRoles
} }
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo { func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo {
out := new(oidc.UserInfo) out := new(oidc.UserInfo)
for _, s := range scope { for _, s := range scope {
switch s { switch s {
case oidc.ScopeOpenID: case oidc.ScopeOpenID:
out.Subject = user.User.ID out.Subject = user.User.ID
case oidc.ScopeEmail: case oidc.ScopeEmail:
if !userInfoAssertion {
continue
}
out.UserInfoEmail = userInfoEmailToOIDC(user.User) out.UserInfoEmail = userInfoEmailToOIDC(user.User)
case oidc.ScopeProfile: case oidc.ScopeProfile:
if !userInfoAssertion {
continue
}
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix) out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
case oidc.ScopePhone: case oidc.ScopePhone:
if !userInfoAssertion {
continue
}
out.UserInfoPhone = userInfoPhoneToOIDC(user.User) out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
case oidc.ScopeAddress: case oidc.ScopeAddress:
//TODO: handle address for human users as soon as implemented if !userInfoAssertion {
continue
}
// TODO: handle address for human users as soon as implemented
case ScopeUserMetaData: case ScopeUserMetaData:
setUserInfoMetadata(user.Metadata, out) setUserInfoMetadata(user.Metadata, out)
case ScopeResourceOwner: case ScopeResourceOwner:
@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie
} }
} }
} }
return out
}
func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo {
if !assertion {
return info
}
userInfo := *info
// prevent returning obtained grants if none where requested // prevent returning obtained grants if none where requested
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 { if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles)) setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles))
} }
return out return &userInfo
} }
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail { func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
} }
} }
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) { func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner) queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner)
if err != nil { if err != nil {
return err return err
} }

View File

@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
project, err := Tester.CreateProject(CTX) project, err := Tester.CreateProject(CTX)
projectID := project.GetId() projectID := project.GetId()
require.NoError(t, err) require.NoError(t, err)
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX) user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err) require.NoError(t, err)
addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar) addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar)
scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess, scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess,
oidc_api.ScopeProjectRolePrefix + roleFoo, oidc_api.ScopeProjectRolePrefix + roleFoo,
@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
tokens, err := rp.ClientCredentials(CTX, provider, nil) tokens, err := rp.ClientCredentials(CTX, provider, nil)
require.NoError(t, err) require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider) userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider)
require.NoError(t, err) require.NoError(t, err)
assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo) assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo)
} }
@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertTokens(t, tokens, true) assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
return tokens return tokens
} }

View File

@ -3,7 +3,6 @@ package oidc
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt"
"testing" "testing"
"time" "time"
@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) {
} }
type args struct { type args struct {
projectID string user *query.OIDCUserInfo
user *query.OIDCUserInfo userInfoAssertion bool
scope []string scope []string
roleAudience []string
requestedRoles []string
} }
tests := []struct { tests := []struct {
name string name string
@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "human, empty", name: "human, empty",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo,
}, },
want: &oidc.UserInfo{}, want: &oidc.UserInfo{},
}, },
{ {
name: "machine, empty", name: "machine, empty",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo,
}, },
want: &oidc.UserInfo{}, want: &oidc.UserInfo{},
}, },
{ {
name: "human, scope openid", name: "human, scope openid",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo, scope: []string{oidc.ScopeOpenID},
scope: []string{oidc.ScopeOpenID},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Subject: "human1", Subject: "human1",
@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "machine, scope openid", name: "machine, scope openid",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, scope: []string{oidc.ScopeOpenID},
scope: []string{oidc.ScopeOpenID},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Subject: "machine1", Subject: "machine1",
}, },
}, },
{ {
name: "human, scope email", name: "human, scope email, profileInfoAssertion",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo, userInfoAssertion: true,
scope: []string{oidc.ScopeEmail}, scope: []string{oidc.ScopeEmail},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{ UserInfoEmail: oidc.UserInfoEmail{
@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) {
}, },
}, },
{ {
name: "machine, scope email", name: "human, scope email",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: machineUserInfo, scope: []string{oidc.ScopeEmail},
scope: []string{oidc.ScopeEmail}, },
want: &oidc.UserInfo{},
},
{
name: "machine, scope email, profileInfoAssertion",
args: args{
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{}, UserInfoEmail: oidc.UserInfoEmail{},
}, },
}, },
{ {
name: "human, scope profile", name: "human, scope profile, profileInfoAssertion",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo, userInfoAssertion: true,
scope: []string{oidc.ScopeProfile}, scope: []string{oidc.ScopeProfile},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{ UserInfoProfile: oidc.UserInfoProfile{
@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) {
}, },
}, },
{ {
name: "machine, scope profile", name: "machine, scope profile, profileInfoAssertion",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, userInfoAssertion: true,
scope: []string{oidc.ScopeProfile}, scope: []string{oidc.ScopeProfile},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{ UserInfoProfile: oidc.UserInfoProfile{
@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) {
}, },
}, },
{ {
name: "human, scope phone", name: "machine, scope profile",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: humanUserInfo, scope: []string{oidc.ScopeProfile},
scope: []string{oidc.ScopePhone}, },
want: &oidc.UserInfo{},
},
{
name: "human, scope phone, profileInfoAssertion",
args: args{
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopePhone},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{ UserInfoPhone: oidc.UserInfoPhone{
@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) {
}, },
}, },
}, },
{
name: "human, scope phone",
args: args{
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{},
},
{ {
name: "machine, scope phone", name: "machine, scope phone",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, scope: []string{oidc.ScopePhone},
scope: []string{oidc.ScopePhone},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{}, UserInfoPhone: oidc.UserInfoPhone{},
@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "human, scope metadata", name: "human, scope metadata",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo, scope: []string{ScopeUserMetaData},
scope: []string{ScopeUserMetaData},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Claims: map[string]any{ Claims: map[string]any{
@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "machine, scope metadata, none found", name: "machine, scope metadata, none found",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, scope: []string{ScopeUserMetaData},
scope: []string{ScopeUserMetaData},
}, },
want: &oidc.UserInfo{}, want: &oidc.UserInfo{},
}, },
{ {
name: "machine, scope resource owner", name: "machine, scope resource owner",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, scope: []string{ScopeResourceOwner},
scope: []string{ScopeResourceOwner},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Claims: map[string]any{ Claims: map[string]any{
@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "human, scope org primary domain prefix", name: "human, scope org primary domain prefix",
args: args{ args: args{
projectID: "project1", user: humanUserInfo,
user: humanUserInfo, scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Claims: map[string]any{ Claims: map[string]any{
@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{ {
name: "machine, scope org id", name: "machine, scope org id",
args: args{ args: args{
projectID: "project1", user: machineUserInfo,
user: machineUserInfo, scope: []string{domain.OrgIDScope + "orgID"},
scope: []string{domain.OrgIDScope + "orgID"},
}, },
want: &oidc.UserInfo{ want: &oidc.UserInfo{
Claims: map[string]any{ Claims: map[string]any{
@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) {
}, },
}, },
}, },
{
name: "human, roleAudience",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
},
},
},
{
name: "human, requested roles",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
requestedRoles: []string{"role2"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role2": {"orgID": "orgDomain"},
},
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
assetPrefix := "https://foo.com/assets" assetPrefix := "https://foo.com/assets"
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix) got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }

View File

@ -342,7 +342,11 @@ func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, auth
if authReq != nil { if authReq != nil {
log = log.WithField("auth_req_id", authReq.ID) log = log.WithField("auth_req_id", authReq.ID)
} }
log.Error() if zerrors.IsInternal(err) {
log.Error()
} else {
log.Info()
}
_, msg = l.getErrorMessage(r, err) _, msg = l.getErrorMessage(r, err)
} }

View File

@ -340,11 +340,7 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
} }
return err return err
} }
policy, err := repo.getLockoutPolicy(ctx, resourceOwner) err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info))
if err != nil {
return err
}
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info), lockoutPolicyToDomain(policy))
if isIgnoreUserInvalidPasswordError(err, request) { if isIgnoreUserInvalidPasswordError(err, request) {
return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid") return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid")
} }

View File

@ -32,7 +32,7 @@ func (c *Commands) AddDefaultLockoutPolicy(ctx context.Context, maxPasswordAttem
} }
func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domain.LockoutPolicy) (*domain.LockoutPolicy, error) { func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domain.LockoutPolicy) (*domain.LockoutPolicy, error) {
existingPolicy, err := c.defaultLockoutPolicyWriteModelByID(ctx) existingPolicy, err := defaultLockoutPolicyWriteModelByID(ctx, c.eventstore.FilterToQueryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,12 +63,12 @@ func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domai
return writeModelToLockoutPolicy(&existingPolicy.LockoutPolicyWriteModel), nil return writeModelToLockoutPolicy(&existingPolicy.LockoutPolicyWriteModel), nil
} }
func (c *Commands) defaultLockoutPolicyWriteModelByID(ctx context.Context) (policy *InstanceLockoutPolicyWriteModel, err error) { func defaultLockoutPolicyWriteModelByID(ctx context.Context, reducer func(ctx context.Context, r eventstore.QueryReducer) error) (policy *InstanceLockoutPolicyWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
writeModel := NewInstanceLockoutPolicyWriteModel(ctx) writeModel := NewInstanceLockoutPolicyWriteModel(ctx)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel) err = reducer(ctx, writeModel)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@ -12,7 +13,7 @@ func (c *Commands) AddLockoutPolicy(ctx context.Context, resourceOwner string, p
if resourceOwner == "" { if resourceOwner == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-8fJif", "Errors.ResourceOwnerMissing") return nil, zerrors.ThrowInvalidArgument(nil, "Org-8fJif", "Errors.ResourceOwnerMissing")
} }
addedPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner) addedPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -42,7 +43,7 @@ func (c *Commands) ChangeLockoutPolicy(ctx context.Context, resourceOwner string
if resourceOwner == "" { if resourceOwner == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-3J9fs", "Errors.ResourceOwnerMissing") return nil, zerrors.ThrowInvalidArgument(nil, "Org-3J9fs", "Errors.ResourceOwnerMissing")
} }
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner) existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -71,7 +72,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
if orgID == "" { if orgID == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-4J9fs", "Errors.ResourceOwnerMissing") return nil, zerrors.ThrowInvalidArgument(nil, "Org-4J9fs", "Errors.ResourceOwnerMissing")
} }
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID) existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -93,7 +94,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
} }
func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string) (*org.LockoutPolicyRemovedEvent, error) { func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string) (*org.LockoutPolicyRemovedEvent, error) {
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID) existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -104,24 +105,24 @@ func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string
return org.NewLockoutPolicyRemovedEvent(ctx, orgAgg), nil return org.NewLockoutPolicyRemovedEvent(ctx, orgAgg), nil
} }
func (c *Commands) orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string) (*OrgLockoutPolicyWriteModel, error) { func orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*OrgLockoutPolicyWriteModel, error) {
policy := NewOrgLockoutPolicyWriteModel(orgID) policy := NewOrgLockoutPolicyWriteModel(orgID)
err := c.eventstore.FilterToQueryReducer(ctx, policy) err := queryReducer(ctx, policy)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return policy, nil return policy, nil
} }
func (c *Commands) getLockoutPolicy(ctx context.Context, orgID string) (*domain.LockoutPolicy, error) { func getLockoutPolicy(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*domain.LockoutPolicy, error) {
orgWm, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID) orgWm, err := orgLockoutPolicyWriteModelByID(ctx, orgID, queryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if orgWm.State == domain.PolicyStateActive { if orgWm.State == domain.PolicyStateActive {
return writeModelToLockoutPolicy(&orgWm.LockoutPolicyWriteModel), nil return writeModelToLockoutPolicy(&orgWm.LockoutPolicyWriteModel), nil
} }
instanceWm, err := c.defaultLockoutPolicyWriteModelByID(ctx) instanceWm, err := defaultLockoutPolicyWriteModelByID(ctx, queryReducer)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/zitadel/logging"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity" "github.com/zitadel/zitadel/internal/activity"
@ -17,21 +18,18 @@ import (
"github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
type SessionCommand func(ctx context.Context, cmd *SessionCommands) error type SessionCommand func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error)
type SessionCommands struct { type SessionCommands struct {
sessionCommands []SessionCommand sessionCommands []SessionCommand
sessionWriteModel *SessionWriteModel sessionWriteModel *SessionWriteModel
passwordWriteModel *HumanPasswordWriteModel intentWriteModel *IDPIntentWriteModel
intentWriteModel *IDPIntentWriteModel eventstore *eventstore.Eventstore
totpWriteModel *HumanTOTPWriteModel eventCommands []eventstore.Command
eventstore *eventstore.Eventstore
eventCommands []eventstore.Command
hasher *crypto.Hasher hasher *crypto.Hasher
intentAlg crypto.EncryptionAlgorithm intentAlg crypto.EncryptionAlgorithm
@ -59,114 +57,92 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
// CheckUser defines a user check to be executed for a session update // CheckUser defines a user check to be executed for a session update
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand { func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id { if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible") return nil, zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
} }
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage) return nil, cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
} }
} }
// CheckPassword defines a password check to be executed for a session update // CheckPassword defines a password check to be executed for a session update
func CheckPassword(password string) SessionCommand { func CheckPassword(password string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" { commands, err := checkPassword(ctx, cmd.sessionWriteModel.UserID, password, cmd.eventstore, cmd.hasher, nil)
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
}
cmd.passwordWriteModel = NewHumanPasswordWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.passwordWriteModel)
if err != nil { if err != nil {
return err return commands, err
} }
if cmd.passwordWriteModel.UserState == domain.UserStateUnspecified || cmd.passwordWriteModel.UserState == domain.UserStateDeleted { cmd.eventCommands = append(cmd.eventCommands, commands...)
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4b3", "Errors.User.NotFound")
}
if cmd.passwordWriteModel.EncodedHash == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-WEf3t", "Errors.User.Password.NotSet")
}
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
updated, err := cmd.hasher.Verify(cmd.passwordWriteModel.EncodedHash, password)
spanPasswordComparison.EndWithError(err)
if err != nil {
//TODO: maybe we want to reset the session in the future https://github.com/zitadel/zitadel/issues/5807
return zerrors.ThrowInvalidArgument(err, "COMMAND-SAF3g", "Errors.User.Password.Invalid")
}
if updated != "" {
cmd.eventCommands = append(cmd.eventCommands, user.NewHumanPasswordHashUpdatedEvent(ctx, UserAggregateFromWriteModel(&cmd.passwordWriteModel.WriteModel), updated))
}
cmd.PasswordChecked(ctx, cmd.now()) cmd.PasswordChecked(ctx, cmd.now())
return nil return nil, nil
} }
} }
// CheckIntent defines a check for a succeeded intent to be executed for a session update // CheckIntent defines a check for a succeeded intent to be executed for a session update
func CheckIntent(intentID, token string) SessionCommand { func CheckIntent(intentID, token string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" { if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing")
} }
if err := crypto.CheckToken(cmd.intentAlg, token, intentID); err != nil { if err := crypto.CheckToken(cmd.intentAlg, token, intentID); err != nil {
return err return nil, err
} }
cmd.intentWriteModel = NewIDPIntentWriteModel(intentID, "") cmd.intentWriteModel = NewIDPIntentWriteModel(intentID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.intentWriteModel) err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.intentWriteModel)
if err != nil { if err != nil {
return err return nil, err
} }
if cmd.intentWriteModel.State != domain.IDPIntentStateSucceeded { if cmd.intentWriteModel.State != domain.IDPIntentStateSucceeded {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded")
} }
if cmd.intentWriteModel.UserID != "" { if cmd.intentWriteModel.UserID != "" {
if cmd.intentWriteModel.UserID != cmd.sessionWriteModel.UserID { if cmd.intentWriteModel.UserID != cmd.sessionWriteModel.UserID {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
} }
} else { } else {
linkWriteModel := NewUserIDPLinkWriteModel(cmd.sessionWriteModel.UserID, cmd.intentWriteModel.IDPID, cmd.intentWriteModel.IDPUserID, cmd.sessionWriteModel.UserResourceOwner) linkWriteModel := NewUserIDPLinkWriteModel(cmd.sessionWriteModel.UserID, cmd.intentWriteModel.IDPID, cmd.intentWriteModel.IDPUserID, cmd.sessionWriteModel.UserResourceOwner)
err := cmd.eventstore.FilterToQueryReducer(ctx, linkWriteModel) err := cmd.eventstore.FilterToQueryReducer(ctx, linkWriteModel)
if err != nil { if err != nil {
return err return nil, err
} }
if linkWriteModel.State != domain.UserIDPLinkStateActive { if linkWriteModel.State != domain.UserIDPLinkStateActive {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
} }
} }
cmd.IntentChecked(ctx, cmd.now()) cmd.IntentChecked(ctx, cmd.now())
return nil return nil, nil
} }
} }
func CheckTOTP(code string) SessionCommand { func CheckTOTP(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) { return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
if cmd.sessionWriteModel.UserID == "" { commands, err := checkTOTP(
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing") ctx,
} cmd.sessionWriteModel.UserID,
cmd.totpWriteModel = NewHumanTOTPWriteModel(cmd.sessionWriteModel.UserID, "") "",
err = cmd.eventstore.FilterToQueryReducer(ctx, cmd.totpWriteModel) code,
cmd.eventstore.FilterToQueryReducer,
cmd.totpAlg,
nil,
)
if err != nil { if err != nil {
return err return commands, err
}
if cmd.totpWriteModel.State != domain.MFAStateReady {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady")
}
err = domain.VerifyTOTP(code, cmd.totpWriteModel.Secret, cmd.totpAlg)
if err != nil {
return err
} }
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.TOTPChecked(ctx, cmd.now()) cmd.TOTPChecked(ctx, cmd.now())
return nil return nil, nil
} }
} }
// Exec will execute the commands specified and returns an error on the first occurrence // Exec will execute the commands specified and returns an error on the first occurrence.
func (s *SessionCommands) Exec(ctx context.Context) error { // In case of an error there might be specific commands returned, e.g. a failed pw check will have to be stored.
func (s *SessionCommands) Exec(ctx context.Context) ([]eventstore.Command, error) {
for _, cmd := range s.sessionCommands { for _, cmd := range s.sessionCommands {
if err := cmd(ctx, s); err != nil { if cmds, err := cmd(ctx, s); err != nil {
return err return cmds, err
} }
} }
return nil return nil, nil
} }
func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent) { func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent) {
@ -360,8 +336,11 @@ func (c *Commands) updateSession(ctx context.Context, checks *SessionCommands, m
if err = checks.sessionWriteModel.CheckNotInvalidated(); err != nil { if err = checks.sessionWriteModel.CheckNotInvalidated(); err != nil {
return nil, err return nil, err
} }
if err := checks.Exec(ctx); err != nil { if cmds, err := checks.Exec(ctx); err != nil {
// TODO: how to handle failed checks (e.g. pw wrong) https://github.com/zitadel/zitadel/issues/5807 if len(cmds) > 0 {
_, pushErr := c.eventstore.Push(ctx, cmds...)
logging.OnError(pushErr).Error("unable to store check failures")
}
return nil, err return nil, err
} }
checks.ChangeMetadata(ctx, metadata) checks.ChangeMetadata(ctx, metadata)

View File

@ -6,9 +6,10 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@ -21,26 +22,26 @@ func (c *Commands) CreateOTPSMSChallenge() SessionCommand {
} }
func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand { func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" { if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
} }
writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "") writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "")
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil { if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
return err return nil, err
} }
if !writeModel.OTPAdded() { if !writeModel.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
} }
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS) code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS)
if err != nil { if err != nil {
return err return nil, err
} }
if returnCode { if returnCode {
*dst = code.Plain *dst = code.Plain
} }
cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode) cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode)
return nil return nil, nil
} }
} }
@ -74,26 +75,26 @@ func (c *Commands) CreateOTPEmailChallenge() SessionCommand {
} }
func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand { func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" { if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
} }
writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "") writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "")
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil { if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
return err return nil, err
} }
if !writeModel.OTPAdded() { if !writeModel.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
} }
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail) code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail)
if err != nil { if err != nil {
return err return nil, err
} }
if returnCode { if returnCode {
*dst = code.Plain *dst = code.Plain
} }
cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl) cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl)
return nil return nil, nil
} }
} }
@ -112,37 +113,57 @@ func (c *Commands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner st
} }
func CheckOTPSMS(code string) SessionCommand { func CheckOTPSMS(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) { return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
if cmd.sessionWriteModel.UserID == "" { writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing") otpWriteModel := NewHumanOTPSMSCodeWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
if err != nil {
return nil, err
}
// explicitly set the challenge from the session write model since the code write model will only check user events
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPSMSCodeChallenge
return otpWriteModel, nil
} }
challenge := cmd.sessionWriteModel.OTPSMSCodeChallenge succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
if challenge == nil { return user.NewHumanOTPSMSCheckSucceededEvent(ctx, aggregate, nil)
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound")
} }
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg) failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, nil)
}
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
if err != nil { if err != nil {
return err return commands, err
} }
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.OTPSMSChecked(ctx, cmd.now()) cmd.OTPSMSChecked(ctx, cmd.now())
return nil return nil, nil
} }
} }
func CheckOTPEmail(code string) SessionCommand { func CheckOTPEmail(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) { return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
if cmd.sessionWriteModel.UserID == "" { writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing") otpWriteModel := NewHumanOTPEmailCodeWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
if err != nil {
return nil, err
}
// explicitly set the challenge from the session write model since the code write model will only check user events
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPEmailCodeChallenge
return otpWriteModel, nil
} }
challenge := cmd.sessionWriteModel.OTPEmailCodeChallenge succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
if challenge == nil { return user.NewHumanOTPEmailCheckSucceededEvent(ctx, aggregate, nil)
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound")
} }
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg) failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, nil)
}
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
if err != nil { if err != nil {
return err return commands, err
} }
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.OTPEmailChecked(ctx, cmd.now()) cmd.OTPEmailChecked(ctx, cmd.now())
return nil return nil, nil
} }
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -110,8 +111,9 @@ func TestCommands_CreateOTPSMSChallengeReturnCode(t *testing.T) {
now: time.Now, now: time.Now,
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.returnCode, dst) assert.Equal(t, tt.res.returnCode, dst)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
@ -210,8 +212,9 @@ func TestCommands_CreateOTPSMSChallenge(t *testing.T) {
now: time.Now, now: time.Now,
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
} }
@ -410,8 +413,9 @@ func TestCommands_CreateOTPEmailChallengeURLTemplate(t *testing.T) {
now: time.Now, now: time.Now,
} }
err = cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
} }
@ -511,8 +515,9 @@ func TestCommands_CreateOTPEmailChallengeReturnCode(t *testing.T) {
now: time.Now, now: time.Now,
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.returnCode, dst) assert.Equal(t, tt.res.returnCode, dst)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
@ -611,8 +616,9 @@ func TestCommands_CreateOTPEmailChallenge(t *testing.T) {
now: time.Now, now: time.Now,
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
} }
@ -701,8 +707,9 @@ func TestCheckOTPSMS(t *testing.T) {
code string code string
} }
type res struct { type res struct {
err error err error
commands []eventstore.Command commands []eventstore.Command
errorCommands []eventstore.Command
} }
tests := []struct { tests := []struct {
name string name string
@ -720,13 +727,43 @@ func TestCheckOTPSMS(t *testing.T) {
code: "code", code: "code",
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing"), err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
},
},
{
name: "missing code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
},
args: args{},
res: res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
},
},
{
name: "not set up",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
userID: "userID",
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
}, },
}, },
{ {
name: "missing challenge", name: "missing challenge",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
),
userID: "userID", userID: "userID",
otpCodeChallenge: nil, otpCodeChallenge: nil,
}, },
@ -734,14 +771,26 @@ func TestCheckOTPSMS(t *testing.T) {
code: "code", code: "code",
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound"), err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
}, },
}, },
{ {
name: "invalid code", name: "invalid code",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
userID: "userID", expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 0, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{ otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{ Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption, CryptoType: crypto.TypeEncryption,
@ -759,13 +808,61 @@ func TestCheckOTPSMS(t *testing.T) {
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"), err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
},
},
},
{
name: "invalid code, locked",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 1, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow.Add(-10 * time.Minute),
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
},
}, },
}, },
{ {
name: "check ok", name: "check ok",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
userID: "userID", expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
),
userID: "userID",
otpCodeChallenge: &OTPCode{ otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{ Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption, CryptoType: crypto.TypeEncryption,
@ -783,12 +880,44 @@ func TestCheckOTPSMS(t *testing.T) {
}, },
res: res{ res: res{
commands: []eventstore.Command{ commands: []eventstore.Command{
user.NewHumanOTPSMSCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow, testNow,
), ),
}, },
}, },
}, },
{
name: "check ok, locked in the meantime",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow,
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -811,8 +940,9 @@ func TestCheckOTPSMS(t *testing.T) {
}, },
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.errorCommands, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
} }
@ -829,8 +959,9 @@ func TestCheckOTPEmail(t *testing.T) {
code string code string
} }
type res struct { type res struct {
err error err error
commands []eventstore.Command commands []eventstore.Command
errorCommands []eventstore.Command
} }
tests := []struct { tests := []struct {
name string name string
@ -848,13 +979,43 @@ func TestCheckOTPEmail(t *testing.T) {
code: "code", code: "code",
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing"), err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
},
},
{
name: "missing code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
},
args: args{},
res: res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
},
},
{
name: "not set up",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
userID: "userID",
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
}, },
}, },
{ {
name: "missing challenge", name: "missing challenge",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
),
userID: "userID", userID: "userID",
otpCodeChallenge: nil, otpCodeChallenge: nil,
}, },
@ -862,14 +1023,26 @@ func TestCheckOTPEmail(t *testing.T) {
code: "code", code: "code",
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound"), err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
}, },
}, },
{ {
name: "invalid code", name: "invalid code",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
userID: "userID", expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 0, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{ otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{ Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption, CryptoType: crypto.TypeEncryption,
@ -887,13 +1060,61 @@ func TestCheckOTPEmail(t *testing.T) {
}, },
res: res{ res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"), err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
},
},
},
{
name: "invalid code, locked",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 1, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow.Add(-10 * time.Minute),
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
},
}, },
}, },
{ {
name: "check ok", name: "check ok",
fields: fields{ fields: fields{
eventstore: expectEventstore(), eventstore: expectEventstore(
userID: "userID", expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
),
userID: "userID",
otpCodeChallenge: &OTPCode{ otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{ Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption, CryptoType: crypto.TypeEncryption,
@ -911,12 +1132,44 @@ func TestCheckOTPEmail(t *testing.T) {
}, },
res: res{ res: res{
commands: []eventstore.Command{ commands: []eventstore.Command{
user.NewHumanOTPEmailCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow, testNow,
), ),
}, },
}, },
}, },
{
name: "check ok, locked in the meantime",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow,
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -939,8 +1192,9 @@ func TestCheckOTPEmail(t *testing.T) {
}, },
} }
err := cmd(context.Background(), cmds) gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err) assert.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.errorCommands, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands) assert.Equal(t, tt.res.commands, cmds.eventCommands)
}) })
} }

View File

@ -22,6 +22,7 @@ import (
"github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock" "github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/idpintent" "github.com/zitadel/zitadel/internal/repository/idpintent"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -430,8 +431,8 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{ checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"), sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{ sessionCommands: []SessionCommand{
func(ctx context.Context, cmd *SessionCommands) error { func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
return zerrors.ThrowInternal(nil, "id", "check failed") return nil, zerrors.ThrowInternal(nil, "id", "check failed")
}, },
}, },
}, },
@ -525,6 +526,55 @@ func TestCommands_updateSession(t *testing.T) {
}, },
}, },
}, },
{
"set user, invalid password",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
"username", "", "", "", "", language.English, domain.GenderUnspecified, "", false),
),
eventFromEventPusher(
user.NewHumanPasswordChangedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
"$plain$x$password", false, ""),
),
),
expectFilter(), // recheck
expectFilter(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate, 0, 0, false),
),
expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
),
),
},
args{
ctx: authz.NewMockContext("instance1", "", ""),
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1", &language.Afrikaans),
CheckPassword("invalid password"),
},
createToken: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
hasher: mockPasswordHasher("x"),
now: func() time.Time {
return testNow
},
},
metadata: map[string][]byte{
"key": []byte("value"),
},
},
res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-3M0fs", "Errors.User.Password.Invalid"),
},
},
{ {
"set user, password, metadata and token", "set user, password, metadata and token",
fields{ fields{
@ -539,10 +589,12 @@ func TestCommands_updateSession(t *testing.T) {
"$plain$x$password", false, ""), "$plain$x$password", false, ""),
), ),
), ),
expectFilter(), // recheck
expectPush( expectPush(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate, session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow, &language.Afrikaans, "userID", "org1", testNow, &language.Afrikaans,
), ),
user.NewHumanPasswordCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate, session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow, testNow,
), ),
@ -872,6 +924,7 @@ func TestCheckTOTP(t *testing.T) {
sessAgg := &session.NewAggregate("session1", "instance1").Aggregate sessAgg := &session.NewAggregate("session1", "instance1").Aggregate
userAgg := &user.NewAggregate("user1", "org1").Aggregate userAgg := &user.NewAggregate("user1", "org1").Aggregate
orgAgg := &org.NewAggregate("org1").Aggregate
code, err := totp.GenerateCode(key.Secret(), testNow) code, err := totp.GenerateCode(key.Secret(), testNow)
require.NoError(t, err) require.NoError(t, err)
@ -886,6 +939,7 @@ func TestCheckTOTP(t *testing.T) {
code string code string
fields fields fields fields
wantEventCommands []eventstore.Command wantEventCommands []eventstore.Command
wantErrorCommands []eventstore.Command
wantErr error wantErr error
}{ }{
{ {
@ -897,7 +951,7 @@ func TestCheckTOTP(t *testing.T) {
}, },
eventstore: expectEventstore(), eventstore: expectEventstore(),
}, },
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing"), wantErr: zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing"),
}, },
{ {
name: "filter error", name: "filter error",
@ -931,7 +985,7 @@ func TestCheckTOTP(t *testing.T) {
), ),
), ),
}, },
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady"), wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady"),
}, },
{ {
name: "otp verify error", name: "otp verify error",
@ -951,8 +1005,45 @@ func TestCheckTOTP(t *testing.T) {
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"), user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
), ),
), ),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 0, 0, false)),
),
), ),
}, },
wantErrorCommands: []eventstore.Command{
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
},
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
},
{
name: "otp verify error, locked",
code: "foobar",
fields: fields{
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
UserCheckedAt: testNow,
aggregate: sessAgg,
},
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
),
eventFromEventPusher(
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 1, 1, false)),
),
),
},
wantErrorCommands: []eventstore.Command{
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
user.NewUserLockedEvent(ctx, userAgg),
},
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"), wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
}, },
{ {
@ -973,12 +1064,39 @@ func TestCheckTOTP(t *testing.T) {
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"), user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
), ),
), ),
expectFilter(), // recheck
), ),
}, },
wantEventCommands: []eventstore.Command{ wantEventCommands: []eventstore.Command{
user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, nil),
session.NewTOTPCheckedEvent(ctx, sessAgg, testNow), session.NewTOTPCheckedEvent(ctx, sessAgg, testNow),
}, },
}, },
{
name: "ok, but locked in the meantime",
code: code,
fields: fields{
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
UserCheckedAt: testNow,
aggregate: sessAgg,
},
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
),
eventFromEventPusher(
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(
user.NewUserLockedEvent(ctx, userAgg),
),
),
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked"),
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -988,8 +1106,9 @@ func TestCheckTOTP(t *testing.T) {
totpAlg: cryptoAlg, totpAlg: cryptoAlg,
now: func() time.Time { return testNow }, now: func() time.Time { return testNow },
} }
err := CheckTOTP(tt.code)(ctx, cmd) gotCmds, err := CheckTOTP(tt.code)(ctx, cmd)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantErrorCommands, gotCmds)
assert.Equal(t, tt.wantEventCommands, cmd.eventCommands) assert.Equal(t, tt.wantEventCommands, cmd.eventCommands)
}) })
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@ -41,49 +42,49 @@ func (s *SessionCommands) getHumanWebAuthNTokenReadModel(ctx context.Context, us
} }
func (c *Commands) CreateWebAuthNChallenge(userVerification domain.UserVerificationRequirement, rpid string, dst json.Unmarshaler) SessionCommand { func (c *Commands) CreateWebAuthNChallenge(userVerification domain.UserVerificationRequirement, rpid string, dst json.Unmarshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
humanPasskeys, err := cmd.getHumanWebAuthNTokens(ctx, userVerification) humanPasskeys, err := cmd.getHumanWebAuthNTokens(ctx, userVerification)
if err != nil { if err != nil {
return err return nil, err
} }
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, rpid, humanPasskeys.tokens...) webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, rpid, humanPasskeys.tokens...)
if err != nil { if err != nil {
return err return nil, err
} }
if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil { if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil {
return zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
} }
cmd.WebAuthNChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification, rpid) cmd.WebAuthNChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification, rpid)
return nil return nil, nil
} }
} }
func (c *Commands) CheckWebAuthN(credentialAssertionData json.Marshaler) SessionCommand { func (c *Commands) CheckWebAuthN(credentialAssertionData json.Marshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error { return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
credentialAssertionData, err := json.Marshal(credentialAssertionData) credentialAssertionData, err := json.Marshal(credentialAssertionData)
if err != nil { if err != nil {
return zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
} }
challenge := cmd.sessionWriteModel.WebAuthNChallenge challenge := cmd.sessionWriteModel.WebAuthNChallenge
if challenge == nil { if challenge == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
} }
webAuthNTokens, err := cmd.getHumanWebAuthNTokens(ctx, challenge.UserVerification) webAuthNTokens, err := cmd.getHumanWebAuthNTokens(ctx, challenge.UserVerification)
if err != nil { if err != nil {
return err return nil, err
} }
webAuthN := challenge.WebAuthNLogin(webAuthNTokens.human, credentialAssertionData) webAuthN := challenge.WebAuthNLogin(webAuthNTokens.human, credentialAssertionData)
credential, err := c.webauthnConfig.FinishLogin(ctx, webAuthNTokens.human, webAuthN, credentialAssertionData, webAuthNTokens.tokens...) credential, err := c.webauthnConfig.FinishLogin(ctx, webAuthNTokens.human, webAuthN, credentialAssertionData, webAuthNTokens.tokens...)
if err != nil && (credential == nil || credential.ID == nil) { if err != nil && (credential == nil || credential.ID == nil) {
return err return nil, err
} }
_, token := domain.GetTokenByKeyID(webAuthNTokens.tokens, credential.ID) _, token := domain.GetTokenByKeyID(webAuthNTokens.tokens, credential.ID)
if token == nil { if token == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
} }
cmd.WebAuthNChecked(ctx, cmd.now(), token.WebAuthNTokenID, credential.Authenticator.SignCount, credential.Flags.UserVerified) cmd.WebAuthNChecked(ctx, cmd.now(), token.WebAuthNTokenID, credential.Authenticator.SignCount, credential.Flags.UserVerified)
return nil return nil, nil
} }
} }

View File

@ -161,48 +161,67 @@ func (c *Commands) HumanCheckMFATOTPSetup(ctx context.Context, userID, code, use
} }
func (c *Commands) HumanCheckMFATOTP(ctx context.Context, userID, code, resourceOwner string, authRequest *domain.AuthRequest) error { func (c *Commands) HumanCheckMFATOTP(ctx context.Context, userID, code, resourceOwner string, authRequest *domain.AuthRequest) error {
commands, err := checkTOTP(
ctx,
userID,
resourceOwner,
code,
c.eventstore.FilterToQueryReducer,
c.multifactors.OTP.CryptoMFA,
authRequestDomainToAuthRequestInfo(authRequest),
)
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return err
}
func checkTOTP(
ctx context.Context,
userID, resourceOwner, code string,
queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
alg crypto.EncryptionAlgorithm,
optionalAuthRequestInfo *user.AuthRequestInfo,
) ([]eventstore.Command, error) {
if userID == "" { if userID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing") return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing")
} }
existingOTP, err := c.totpWriteModelByID(ctx, userID, resourceOwner) existingOTP := NewHumanTOTPWriteModel(userID, resourceOwner)
err := queryReducer(ctx, existingOTP)
if err != nil { if err != nil {
return err return nil, err
} }
if existingOTP.State != domain.MFAStateReady { if existingOTP.State != domain.MFAStateReady {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady")
} }
userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel) userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel)
verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, c.multifactors.OTP.CryptoMFA) verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, alg)
// recheck for additional events (failed OTP checks or locks) // recheck for additional events (failed OTP checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP) recheckErr := queryReducer(ctx, existingOTP)
if recheckErr != nil { if recheckErr != nil {
return recheckErr return nil, recheckErr
} }
if existingOTP.UserLocked { if existingOTP.UserLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked")
} }
// the OTP check succeeded and the user was not locked in the meantime // the OTP check succeeded and the user was not locked in the meantime
if verifyErr == nil { if verifyErr == nil {
_, err = c.eventstore.Push(ctx, user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) return []eventstore.Command{user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo)}, nil
return err
} }
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked // the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
commands := make([]eventstore.Command, 0, 2) commands := make([]eventstore.Command, 0, 2)
commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner) lockoutPolicy, err := getLockoutPolicy(ctx, existingOTP.ResourceOwner, queryReducer)
if err != nil { if err != nil {
return err return nil, err
} }
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount+1 >= lockoutPolicy.MaxOTPAttempts { if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount+1 >= lockoutPolicy.MaxOTPAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg)) commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
} }
return commands, verifyErr
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return verifyErr
} }
func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner string) (*domain.ObjectDetails, error) { func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner string) (*domain.ObjectDetails, error) {
@ -342,16 +361,23 @@ func (c *Commands) HumanCheckOTPSMS(ctx context.Context, userID, code, resourceO
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command { failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest)) return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
} }
return c.humanCheckOTP( commands, err := checkOTP(
ctx, ctx,
userID, userID,
code, code,
resourceOwner, resourceOwner,
authRequest, authRequest,
writeModel, writeModel,
c.eventstore.FilterToQueryReducer,
c.userEncryption,
succeededEvent, succeededEvent,
failedEvent, failedEvent,
) )
if len(commands) > 0 {
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
}
return err
} }
// AddHumanOTPEmail adds the OTP Email factor to a user. // AddHumanOTPEmail adds the OTP Email factor to a user.
@ -467,16 +493,23 @@ func (c *Commands) HumanCheckOTPEmail(ctx context.Context, userID, code, resourc
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command { failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest)) return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
} }
return c.humanCheckOTP( commands, err := checkOTP(
ctx, ctx,
userID, userID,
code, code,
resourceOwner, resourceOwner,
authRequest, authRequest,
writeModel, writeModel,
c.eventstore.FilterToQueryReducer,
c.userEncryption,
succeededEvent, succeededEvent,
failedEvent, failedEvent,
) )
if len(commands) > 0 {
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
}
return err
} }
// sendHumanOTP creates a code for a registered mechanism (sms / email), which is used for a check (during login) // sendHumanOTP creates a code for a registered mechanism (sms / email), which is used for a check (during login)
@ -534,62 +567,57 @@ func (c *Commands) humanOTPSent(
return err return err
} }
func (c *Commands) humanCheckOTP( func checkOTP(
ctx context.Context, ctx context.Context,
userID, code, resourceOwner string, userID, code, resourceOwner string,
authRequest *domain.AuthRequest, authRequest *domain.AuthRequest,
writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error), writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error),
checkSucceededEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command, alg crypto.EncryptionAlgorithm,
) error { checkSucceededEvent, checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
) ([]eventstore.Command, error) {
if userID == "" { if userID == "" {
return zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing") return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing")
} }
if code == "" { if code == "" {
return zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty") return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty")
} }
existingOTP, err := writeModelByID(ctx, userID, resourceOwner) existingOTP, err := writeModelByID(ctx, userID, resourceOwner)
if err != nil { if err != nil {
return err return nil, err
} }
if !existingOTP.OTPAdded() { if !existingOTP.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady")
} }
if existingOTP.Code() == nil { if existingOTP.Code() == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound")
} }
userAgg := &user.NewAggregate(userID, existingOTP.ResourceOwner()).Aggregate userAgg := &user.NewAggregate(userID, existingOTP.ResourceOwner()).Aggregate
verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, c.userEncryption) verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, alg)
// recheck for additional events (failed OTP checks or locks) // recheck for additional events (failed OTP checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP) recheckErr := queryReducer(ctx, existingOTP)
if recheckErr != nil { if recheckErr != nil {
return recheckErr return nil, recheckErr
} }
if existingOTP.UserLocked() { if existingOTP.UserLocked() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked")
} }
// the OTP check succeeded and the user was not locked in the meantime // the OTP check succeeded and the user was not locked in the meantime
if verifyErr == nil { if verifyErr == nil {
_, err = c.eventstore.Push(ctx, checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) return []eventstore.Command{checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))}, nil
return err
} }
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked // the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
commands := make([]eventstore.Command, 0, 2) commands := make([]eventstore.Command, 0, 2)
commands = append(commands, checkFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) commands = append(commands, checkFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner) lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, existingOTP.ResourceOwner(), queryReducer)
if err != nil { logging.OnError(lockoutErr).Error("unable to get lockout policy")
return err if lockoutPolicy != nil && lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
}
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg)) commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
} }
return commands, verifyErr
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
return verifyErr
} }
func (c *Commands) totpWriteModelByID(ctx context.Context, userID, resourceOwner string) (writeModel *HumanTOTPWriteModel, err error) { func (c *Commands) totpWriteModelByID(ctx context.Context, userID, resourceOwner string) (writeModel *HumanTOTPWriteModel, err error) {

View File

@ -157,24 +157,31 @@ func (wm *HumanOTPSMSWriteModel) Query() *eventstore.SearchQueryBuilder {
type HumanOTPSMSCodeWriteModel struct { type HumanOTPSMSCodeWriteModel struct {
*HumanOTPSMSWriteModel *HumanOTPSMSWriteModel
code *crypto.CryptoValue otpCode *OTPCode
codeCreationDate time.Time
codeExpiry time.Duration
checkFailedCount uint64 checkFailedCount uint64
userLocked bool userLocked bool
} }
func (wm *HumanOTPSMSCodeWriteModel) CodeCreationDate() time.Time { func (wm *HumanOTPSMSCodeWriteModel) CodeCreationDate() time.Time {
return wm.codeCreationDate if wm.otpCode == nil {
return time.Time{}
}
return wm.otpCode.CreationDate
} }
func (wm *HumanOTPSMSCodeWriteModel) CodeExpiry() time.Duration { func (wm *HumanOTPSMSCodeWriteModel) CodeExpiry() time.Duration {
return wm.codeExpiry if wm.otpCode == nil {
return 0
}
return wm.otpCode.Expiry
} }
func (wm *HumanOTPSMSCodeWriteModel) Code() *crypto.CryptoValue { func (wm *HumanOTPSMSCodeWriteModel) Code() *crypto.CryptoValue {
return wm.code if wm.otpCode == nil {
return nil
}
return wm.otpCode.Code
} }
func (wm *HumanOTPSMSCodeWriteModel) CheckFailedCount() uint64 { func (wm *HumanOTPSMSCodeWriteModel) CheckFailedCount() uint64 {
@ -195,9 +202,11 @@ func (wm *HumanOTPSMSCodeWriteModel) Reduce() error {
for _, event := range wm.Events { for _, event := range wm.Events {
switch e := event.(type) { switch e := event.(type) {
case *user.HumanOTPSMSCodeAddedEvent: case *user.HumanOTPSMSCodeAddedEvent:
wm.code = e.Code wm.otpCode = &OTPCode{
wm.codeCreationDate = e.CreationDate() Code: e.Code,
wm.codeExpiry = e.Expiry CreationDate: e.CreationDate(),
Expiry: e.Expiry,
}
case *user.HumanOTPSMSCheckSucceededEvent: case *user.HumanOTPSMSCheckSucceededEvent:
wm.checkFailedCount = 0 wm.checkFailedCount = 0
case *user.HumanOTPSMSCheckFailedEvent: case *user.HumanOTPSMSCheckFailedEvent:
@ -300,24 +309,31 @@ func (wm *HumanOTPEmailWriteModel) Query() *eventstore.SearchQueryBuilder {
type HumanOTPEmailCodeWriteModel struct { type HumanOTPEmailCodeWriteModel struct {
*HumanOTPEmailWriteModel *HumanOTPEmailWriteModel
code *crypto.CryptoValue otpCode *OTPCode
codeCreationDate time.Time
codeExpiry time.Duration
checkFailedCount uint64 checkFailedCount uint64
userLocked bool userLocked bool
} }
func (wm *HumanOTPEmailCodeWriteModel) CodeCreationDate() time.Time { func (wm *HumanOTPEmailCodeWriteModel) CodeCreationDate() time.Time {
return wm.codeCreationDate if wm.otpCode == nil {
return time.Time{}
}
return wm.otpCode.CreationDate
} }
func (wm *HumanOTPEmailCodeWriteModel) CodeExpiry() time.Duration { func (wm *HumanOTPEmailCodeWriteModel) CodeExpiry() time.Duration {
return wm.codeExpiry if wm.otpCode == nil {
return 0
}
return wm.otpCode.Expiry
} }
func (wm *HumanOTPEmailCodeWriteModel) Code() *crypto.CryptoValue { func (wm *HumanOTPEmailCodeWriteModel) Code() *crypto.CryptoValue {
return wm.code if wm.otpCode == nil {
return nil
}
return wm.otpCode.Code
} }
func (wm *HumanOTPEmailCodeWriteModel) CheckFailedCount() uint64 { func (wm *HumanOTPEmailCodeWriteModel) CheckFailedCount() uint64 {
@ -338,9 +354,11 @@ func (wm *HumanOTPEmailCodeWriteModel) Reduce() error {
for _, event := range wm.Events { for _, event := range wm.Events {
switch e := event.(type) { switch e := event.(type) {
case *user.HumanOTPEmailCodeAddedEvent: case *user.HumanOTPEmailCodeAddedEvent:
wm.code = e.Code wm.otpCode = &OTPCode{
wm.codeCreationDate = e.CreationDate() Code: e.Code,
wm.codeExpiry = e.Expiry CreationDate: e.CreationDate(),
Expiry: e.Expiry,
}
case *user.HumanOTPEmailCheckSucceededEvent: case *user.HumanOTPEmailCheckSucceededEvent:
wm.checkFailedCount = 0 wm.checkFailedCount = 0
case *user.HumanOTPEmailCheckFailedEvent: case *user.HumanOTPEmailCheckFailedEvent:

View File

@ -295,8 +295,8 @@ func (c *Commands) PasswordChangeSent(ctx context.Context, orgID, userID string)
return err return err
} }
// HumanCheckPassword check password for user with additional informations from authRequest // HumanCheckPassword check password for user with additional information from authRequest
func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest, lockoutPolicy *domain.LockoutPolicy) (err error) { func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -314,56 +314,66 @@ func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, passwo
if !loginPolicy.AllowUsernamePassword { if !loginPolicy.AllowUsernamePassword {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Dft32", "Errors.Org.LoginPolicy.UsernamePasswordNotAllowed") return zerrors.ThrowPreconditionFailed(err, "COMMAND-Dft32", "Errors.Org.LoginPolicy.UsernamePasswordNotAllowed")
} }
commands, err := checkPassword(ctx, userID, password, c.eventstore, c.userPasswordHasher, authRequestDomainToAuthRequestInfo(authRequest))
wm, err := c.passwordWriteModel(ctx, userID, orgID) if len(commands) == 0 {
if err != nil {
return err return err
} }
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return err
}
if !isUserStateExists(wm.UserState) { func checkPassword(ctx context.Context, userID, password string, es *eventstore.Eventstore, hasher *crypto.Hasher, optionalAuthRequestInfo *user.AuthRequestInfo) ([]eventstore.Command, error) {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound") if userID == "" {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
}
wm := NewHumanPasswordWriteModel(userID, "")
err := es.FilterToQueryReducer(ctx, wm)
if err != nil {
return nil, err
}
if !wm.UserState.Exists() {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound")
} }
if wm.UserState == domain.UserStateLocked { if wm.UserState == domain.UserStateLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked")
} }
if wm.EncodedHash == "" { if wm.EncodedHash == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet")
} }
userAgg := UserAggregateFromWriteModel(&wm.WriteModel) userAgg := UserAggregateFromWriteModel(&wm.WriteModel)
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify") ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
updated, err := c.userPasswordHasher.Verify(wm.EncodedHash, password) updated, err := hasher.Verify(wm.EncodedHash, password)
spanPasswordComparison.EndWithError(err) spanPasswordComparison.EndWithError(err)
err = convertPasswapErr(err) err = convertPasswapErr(err)
commands := make([]eventstore.Command, 0, 2) commands := make([]eventstore.Command, 0, 2)
// recheck for additional events (failed password checks or locks) // recheck for additional events (failed password checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, wm) recheckErr := es.FilterToQueryReducer(ctx, wm)
if recheckErr != nil { if recheckErr != nil {
return recheckErr return nil, recheckErr
} }
if wm.UserState == domain.UserStateLocked { if wm.UserState == domain.UserStateLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked") return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked")
} }
if err == nil { if err == nil {
commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo))
if updated != "" { if updated != "" {
commands = append(commands, user.NewHumanPasswordHashUpdatedEvent(ctx, userAgg, updated)) commands = append(commands, user.NewHumanPasswordHashUpdatedEvent(ctx, userAgg, updated))
} }
_, err = c.eventstore.Push(ctx, commands...) return commands, nil
return err
} }
commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))) commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 {
if wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts { lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, wm.ResourceOwner, es.FilterToQueryReducer)
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg)) logging.OnError(lockoutErr).Error("unable to get lockout policy")
} if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 && wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
} }
_, pushErr := c.eventstore.Push(ctx, commands...) return commands, err
logging.OnError(pushErr).Error("error create password check failed event")
return err
} }
func (c *Commands) passwordWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *HumanPasswordWriteModel, err error) { func (c *Commands) passwordWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *HumanPasswordWriteModel, err error) {

View File

@ -1456,7 +1456,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
resourceOwner string resourceOwner string
password string password string
authReq *domain.AuthRequest authReq *domain.AuthRequest
lockoutPolicy *domain.LockoutPolicy
} }
type res struct { type res struct {
err func(error) bool err func(error) bool
@ -1768,6 +1767,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
"")), "")),
), ),
expectFilter(), expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(),
&org.NewAggregate("org1").Aggregate,
0, 0, false,
)),
),
expectPush( expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(), user.NewHumanPasswordCheckFailedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate, &user.NewAggregate("user1", "org1").Aggregate,
@ -1789,7 +1795,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
ID: "request1", ID: "request1",
AgentID: "agent1", AgentID: "agent1",
}, },
lockoutPolicy: &domain.LockoutPolicy{},
}, },
res: res{ res: res{
err: zerrors.IsErrorInvalidArgument, err: zerrors.IsErrorInvalidArgument,
@ -1852,6 +1857,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
), ),
), ),
expectFilter(), expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(),
&org.NewAggregate("org1").Aggregate,
1, 1, false,
)),
),
expectPush( expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(), user.NewHumanPasswordCheckFailedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate, &user.NewAggregate("user1", "org1").Aggregate,
@ -1876,10 +1888,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
ID: "request1", ID: "request1",
AgentID: "agent1", AgentID: "agent1",
}, },
lockoutPolicy: &domain.LockoutPolicy{
MaxPasswordAttempts: 1,
MaxOTPAttempts: 1,
},
}, },
res: res{ res: res{
err: zerrors.IsErrorInvalidArgument, err: zerrors.IsErrorInvalidArgument,
@ -2230,7 +2238,7 @@ func TestCommandSide_CheckPassword(t *testing.T) {
eventstore: tt.fields.eventstore(t), eventstore: tt.fields.eventstore(t),
userPasswordHasher: tt.fields.userPasswordHasher, userPasswordHasher: tt.fields.userPasswordHasher,
} }
err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq, tt.args.lockoutPolicy) err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq)
if tt.res.err == nil { if tt.res.err == nil {
assert.NoError(t, err) assert.NoError(t, err)
} }

View File

@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) {
return resp.Location() return resp.Location()
} }
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) { func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) {
name := gofakeit.Username() name = gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name, Name: name,
UserName: name, UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
}) })
if err != nil { if err != nil {
return "", "", "", err return nil, "", "", "", err
} }
secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{ secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{
UserId: user.GetUserId(), UserId: machine.GetUserId(),
}) })
if err != nil { if err != nil {
return "", "", "", err return nil, "", "", "", err
} }
return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
} }
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) { func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
name := gofakeit.Username() name = gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name, Name: name,
UserName: name, UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT, AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
}) })
if err != nil { if err != nil {
return "", nil, err return nil, "", nil, err
} }
keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{ keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
UserId: user.GetUserId(), UserId: machine.GetUserId(),
Type: authn.KeyType_KEY_TYPE_JSON, Type: authn.KeyType_KEY_TYPE_JSON,
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)), ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
}) })
if err != nil { if err != nil {
return "", nil, err return nil, "", nil, err
} }
return user.GetUserId(), keyResp.GetKeyDetails(), nil return machine, name, keyResp.GetKeyDetails(), nil
} }