Merge branch 'main' into next

This commit is contained in:
Livio Spring 2024-05-31 12:12:02 +02:00
commit 50e0e7d564
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
33 changed files with 1010 additions and 450 deletions

View File

@ -6,9 +6,11 @@
<mat-icon class="icon">info_outline</mat-icon>
</a>
</div>
<p class="sub cnsl-secondary-text max-width-description">
{{ 'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate }}
</p>
<p
class="sub cnsl-secondary-text max-width-description"
[innerHTML]="'DESCRIPTIONS.PROJECTS.DESCRIPTION' | translate"
></p>
<div class="projects-controls">
<div class="project-toggle-group">
<cnsl-nav-toggle

View File

@ -14,12 +14,15 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/integration"
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
@ -27,6 +30,7 @@ import (
var (
CTX context.Context
IAMOwnerCTX context.Context
Tester *integration.Tester
Client session.SessionServiceClient
User *user.AddHumanUserResponse
@ -44,6 +48,7 @@ func TestMain(m *testing.M) {
Client = Tester.Client.SessionV2
CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
IAMOwnerCTX = Tester.WithAuthorization(ctx, integration.IAMOwner)
User = createFullUser(CTX)
DeactivatedUser = createDeactivatedUser(CTX)
LockedUser = createLockedUser(CTX)
@ -341,6 +346,48 @@ func TestServer_CreateSession(t *testing.T) {
}
}
func TestServer_CreateSession_lock_user(t *testing.T) {
// create a separate org so we don't interfere with any other test
org := Tester.CreateOrganization(IAMOwnerCTX,
fmt.Sprintf("TestServer_CreateSession_lock_user_%d", time.Now().UnixNano()),
fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()),
)
userID := org.CreatedAdmins[0].GetUserId()
Tester.SetUserPassword(IAMOwnerCTX, userID, integration.UserPassword, false)
// enable password lockout
maxAttempts := 2
ctxOrg := metadata.AppendToOutgoingContext(IAMOwnerCTX, "x-zitadel-orgid", org.GetOrganizationId())
_, err := Tester.Client.Mgmt.AddCustomLockoutPolicy(ctxOrg, &mgmt.AddCustomLockoutPolicyRequest{
MaxPasswordAttempts: uint32(maxAttempts),
})
require.NoError(t, err)
for i := 0; i <= maxAttempts; i++ {
_, err := Client.CreateSession(CTX, &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: userID,
},
},
Password: &session.CheckPassword{
Password: "invalid",
},
},
})
assert.Error(t, err)
statusCode := status.Code(err)
expectedCode := codes.InvalidArgument
// as soon as we hit the limit the user is locked and following request will
// already deny any check with a precondition failed since the user is locked
if i >= maxAttempts {
expectedCode = codes.FailedPrecondition
}
assert.Equal(t, expectedCode, statusCode)
}
}
func TestServer_CreateSession_webauthn(t *testing.T) {
// create new session with user and request the webauthn challenge
createResp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{

View File

@ -471,7 +471,7 @@ func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest)
if err != nil {
return "", err
}
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion)
resp, err := s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil {
return "", err
}
@ -563,7 +563,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize
op.AuthRequestError(w, r, authReq, err, authorizer)
return err
}
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion)
resp, err := s.accessTokenResponseFromSession(ctx, client, session, authReq.GetState(), client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion)
if err != nil {
op.AuthRequestError(w, r, authReq, err, authorizer)
return err

View File

@ -54,7 +54,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -108,7 +108,7 @@ func TestOPStorage_CreateAccessToken_implicit(t *testing.T) {
require.NoError(t, err)
claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier())
require.NoError(t, err)
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, claims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// callback on a succeeded request must fail
linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
@ -143,7 +143,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
}
func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
@ -168,14 +168,14 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual refresh grant
newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken)
require.NoError(t, err)
assertTokens(t, newTokens, true)
// auth time must still be the initial
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, newTokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// refresh with an old refresh_token must fail
_, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "")
@ -204,7 +204,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "access_token")
@ -247,7 +247,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke access token
err = rp.RevokeToken(CTX, provider, tokens.AccessToken, "refresh_token")
@ -284,7 +284,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token -> invalidates also access token
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "refresh_token")
@ -327,7 +327,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// revoke refresh token even with a wrong hint
err = rp.RevokeToken(CTX, provider, tokens.RefreshToken, "access_token")
@ -362,7 +362,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// simulate second client (not part of the audience) trying to revoke the token
otherClientID, _ := createClient(t)
@ -394,7 +394,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -431,7 +431,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// userinfo must not fail
_, err = rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, tokens.IDTokenClaims.Subject, provider)
@ -475,7 +475,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
postLogoutRedirect, err := rp.EndSession(CTX, provider, "", logoutRedirectURI, "state")
require.NoError(t, err)
@ -530,8 +530,13 @@ func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requir
}
}
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time) {
func assertIDTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, userID string, arm []string, sessionStart, sessionChange time.Time, sessionID string) {
assert.Equal(t, userID, claims.Subject)
assert.Equal(t, arm, claims.AuthenticationMethodsReferences)
assertOIDCTimeRange(t, claims.AuthTime, sessionStart, sessionChange)
assert.Equal(t, sessionID, claims.SessionID)
assert.Empty(t, claims.Name)
assert.Empty(t, claims.GivenName)
assert.Empty(t, claims.FamilyName)
assert.Empty(t, claims.PreferredUsername)
}

View File

@ -122,7 +122,7 @@ func TestServer_Introspect(t *testing.T) {
tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// test actual introspection
introspection, err := rs.Introspect[*oidc.IntrospectionResponse](context.Background(), resourceServer, tokens.AccessToken)
@ -317,7 +317,7 @@ func TestServer_VerifyClient(t *testing.T) {
}
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
})
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
@ -80,7 +81,11 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
// with active: false
defer func() {
if err != nil {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
if zerrors.IsInternal(err) {
s.getLogger(ctx).ErrorContext(ctx, "oidc introspection", "err", err)
} else {
s.getLogger(ctx).InfoContext(ctx, "oidc introspection", "err", err)
}
resp, err = op.NewResponse(new(oidc.IntrospectionResponse)), nil
}
}()
@ -99,7 +104,14 @@ func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionR
if err = validateIntrospectionAudience(token.audience, client.clientID, client.projectID); err != nil {
return nil, err
}
userInfo, err := s.userInfo(ctx, token.userID, token.scope, client.projectID, client.projectRoleAssertion, true)
userInfo, err := s.userInfo(
token.userID,
token.scope,
client.projectID,
client.projectRoleAssertion,
true,
true,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return nil, err
}

View File

@ -77,7 +77,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -142,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -173,7 +173,7 @@ func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -227,7 +227,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -261,7 +261,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, false)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -290,7 +290,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -332,7 +332,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -402,7 +402,7 @@ func Test_ZITADEL_API_terminated_session_user_disabled(t *testing.T) {
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, disabledUser.GetUserId(), armPassword, startTime, changeTime, sessionID)
// make sure token works
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))

View File

@ -29,8 +29,8 @@ In some cases step 1 till 3 are completely implemented in the command package,
for example the v2 code exchange and refresh token.
*/
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfoOnce(session.UserID, projectID, projectRoleAssertion, session.Scope)
func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.Client, session *command.OIDCSession, state, projectID string, projectRoleAssertion, accessTokenRoleAssertion, idTokenRoleAssertion, userInfoAssertion bool) (_ *oidc.AccessTokenResponse, err error) {
getUserInfo := s.getUserInfo(session.UserID, projectID, projectRoleAssertion, userInfoAssertion, session.Scope)
getSigner := s.getSignerOnce()
resp := &oidc.AccessTokenResponse{
@ -43,7 +43,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
// If the session does not have a token ID, it is an implicit ID-Token only response.
if session.TokenID != "" {
if client.AccessTokenType() == op.AccessTokenTypeJWT {
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
resp.AccessToken, err = s.createJWT(ctx, client, session, getUserInfo, accessTokenRoleAssertion, getSigner)
} else {
resp.AccessToken, err = op.CreateBearerToken(session.TokenID, session.UserID, s.opCrypto)
}
@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
}
if slices.Contains(session.Scope, oidc.ScopeOpenID) {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, idTokenRoleAssertion, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
}
return resp, err
}
@ -92,31 +92,22 @@ func (s *Server) getSignerOnce() signerFunc {
}
// userInfoFunc is a getter function that allows add-hoc retrieval of a user.
type userInfoFunc func(ctx context.Context) (*oidc.UserInfo, error)
type userInfoFunc func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error)
// getUserInfoOnce returns a function which retrieves userinfo from the database once.
// Repeated calls of the returned function return the same results.
func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion bool, scope []string) userInfoFunc {
var (
once sync.Once
userInfo *oidc.UserInfo
err error
)
return func(ctx context.Context) (*oidc.UserInfo, error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err = s.userInfo(ctx, userID, scope, projectID, projectRoleAssertion, false)
})
return userInfo, err
// getUserInfo returns a function which retrieves userinfo from the database once.
// However, each time, role claims are asserted and also action flows will trigger.
func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, userInfoAssertion bool, scope []string) userInfoFunc {
userInfo := s.userInfo(userID, scope, projectID, projectRoleAssertion, userInfoAssertion, false)
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (*oidc.UserInfo, error) {
return userInfo(ctx, roleAssertion, triggerType)
}
}
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
userInfo, err := getUserInfo(ctx, roleAssertion, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return "", 0, err
}
@ -156,11 +147,11 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 {
return uint64(time.Until(exp) / time.Second)
}
func (*Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, getSigner signerFunc) (_ string, err error) {
func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userInfo, err := getUserInfo(ctx)
userInfo, err := getUserInfo(ctx, assertRoles, domain.TriggerTypePreAccessTokenCreation)
if err != nil {
return "", err
}

View File

@ -47,5 +47,5 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
}

View File

@ -4,6 +4,7 @@ package oidc_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
@ -18,10 +19,13 @@ import (
)
func TestServer_ClientCredentialsExchange(t *testing.T) {
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
machine, name, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
type claims struct {
name string
username string
updated time.Time
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
@ -78,6 +82,17 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID},
},
{
name: "openid, profile, email",
clientID: clientID,
clientSecret: clientSecret,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: machine.GetDetails().GetChangeDate().AsTime(),
},
},
{
name: "org id and domain scope",
clientID: clientID,
@ -132,12 +147,20 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
}
require.NoError(t, err)
require.NotNil(t, tokens)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, machine.GetUserId(), provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
})
}
}

View File

@ -49,7 +49,7 @@ func (s *Server) CodeExchange(ctx context.Context, r *op.ClientRequest[oidc.Acce
if err != nil {
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, state, client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
// codeExchangeV1 creates a v2 token from a v1 auth request.

View File

@ -26,7 +26,7 @@ func (s *Server) DeviceToken(ctx context.Context, r *op.ClientRequest[oidc.Devic
}
session, err := s.command.CreateOIDCSessionFromDeviceAuth(ctx, r.Data.DeviceCode)
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
if errors.Is(err, context.DeadlineExceeded) {
return nil, oidc.ErrSlowDown().WithParent(err)

View File

@ -218,7 +218,7 @@ func validateTokenExchangeAudience(requestedAudience, subjectAudience, actorAudi
// Both tokens may point to the same object (subjectToken) in case of a regular Token Exchange.
// When the subject and actor Tokens point to different objects, the new tokens will be for impersonation / delegation.
func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenType, client *Client, subjectToken, actorToken *exchangeToken, audience, scopes []string) (_ *oidc.TokenExchangeResponse, err error) {
getUserInfo := s.getUserInfoOnce(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, scopes)
getUserInfo := s.getUserInfo(subjectToken.userID, client.client.ProjectID, client.client.ProjectRoleAssertion, client.IDTokenUserinfoClaimsAssertion(), scopes)
getSigner := s.getSignerOnce()
resp := &oidc.TokenExchangeResponse{
@ -240,12 +240,12 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
resp.IssuedTokenType = oidc.AccessTokenType
case oidc.JWTTokenType:
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeJWT(ctx, client, getUserInfo, client.client.AccessTokenRoleAssertion, getSigner, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, subjectToken.preferredLanguage, reason, actor)
resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.JWTTokenType
case oidc.IDTokenType:
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.TokenType = TokenTypeNA
resp.IssuedTokenType = oidc.IDTokenType
@ -259,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
}
if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, client.client.IDTokenRoleAssertion, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
if err != nil {
return nil, err
}
@ -313,6 +313,7 @@ func (s *Server) createExchangeJWT(
ctx context.Context,
client *Client,
getUserInfo userInfoFunc,
roleAssertion bool,
getSigner signerFunc,
userID,
resourceOwner string,
@ -342,7 +343,7 @@ func (s *Server) createExchangeJWT(
actor,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, getSigner)
accessToken, err = s.createJWT(ctx, client, session, getUserInfo, roleAssertion, getSigner)
if err != nil {
return "", "", 0, err
}

View File

@ -54,7 +54,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr
nil,
false,
)
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", "", false, true, false, false))
}
func (s *Server) verifyJWTProfile(ctx context.Context, req *oidc.JWTProfileGrantRequest) (user *query.User, tokenRequest *oidc.JWTTokenRequest, err error) {

View File

@ -4,6 +4,7 @@ package oidc_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -16,10 +17,13 @@ import (
)
func TestServer_JWTProfile(t *testing.T) {
userID, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
user, name, keyData, err := Tester.CreateOIDCJWTProfileClient(CTX)
require.NoError(t, err)
type claims struct {
name string
username string
updated time.Time
resourceOwnerID any
resourceOwnerName any
resourceOwnerPrimaryDomain any
@ -37,6 +41,16 @@ func TestServer_JWTProfile(t *testing.T) {
keyData: keyData,
scope: []string{oidc.ScopeOpenID},
},
{
name: "openid, profile, email",
keyData: keyData,
scope: []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail},
wantClaims: claims{
name: name,
username: name,
updated: user.GetDetails().GetChangeDate().AsTime(),
},
},
{
name: "org id and domain scope",
keyData: keyData,
@ -92,12 +106,20 @@ func TestServer_JWTProfile(t *testing.T) {
provider, err := rp.NewRelyingPartyOIDC(CTX, Tester.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, oidc.BearerToken, user.GetUserId(), provider)
require.NoError(t, err)
assert.Equal(t, tt.wantClaims.resourceOwnerID, userinfo.Claims[oidc_api.ClaimResourceOwnerID])
assert.Equal(t, tt.wantClaims.resourceOwnerName, userinfo.Claims[oidc_api.ClaimResourceOwnerName])
assert.Equal(t, tt.wantClaims.resourceOwnerPrimaryDomain, userinfo.Claims[oidc_api.ClaimResourceOwnerPrimaryDomain])
assert.Equal(t, tt.wantClaims.orgDomain, userinfo.Claims[domain.OrgDomainPrimaryClaim])
assert.Equal(t, tt.wantClaims.name, userinfo.Name)
assert.Equal(t, tt.wantClaims.username, userinfo.PreferredUsername)
assertOIDCTime(t, userinfo.UpdatedAt, tt.wantClaims.updated)
assert.Empty(t, userinfo.UserInfoProfile.FamilyName)
assert.Empty(t, userinfo.UserInfoProfile.GivenName)
assert.Empty(t, userinfo.UserInfoEmail)
assert.Empty(t, userinfo.UserInfoPhone)
assert.Empty(t, userinfo.Address)
})
}
}

View File

@ -28,7 +28,7 @@ func (s *Server) RefreshToken(ctx context.Context, r *op.ClientRequest[oidc.Refr
session, err := s.command.ExchangeOIDCSessionRefreshAndAccessToken(ctx, r.Data.RefreshToken, r.Data.Scopes, refreshTokenComplianceChecker())
if err == nil {
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
} else if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid")) {
// We try again for v1 tokens when we encountered specific parsing error
return s.refreshTokenV1(ctx, client, r)
@ -78,7 +78,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien
return nil, err
}
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion))
return response(s.accessTokenResponseFromSession(ctx, client, session, "", client.client.ProjectID, client.client.ProjectRoleAssertion, client.client.AccessTokenRoleAssertion, client.client.IDTokenRoleAssertion, client.client.IDTokenUserinfoAssertion))
}
// refreshTokenComplianceChecker validates that the requested scope is a subset of the original auth request scope.

View File

@ -8,6 +8,7 @@ import (
"net/http"
"slices"
"strings"
"sync"
"github.com/dop251/goja"
"github.com/zitadel/logging"
@ -55,7 +56,14 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
}
}
userInfo, err := s.userInfo(ctx, token.userID, token.scope, projectID, assertion, false)
userInfo, err := s.userInfo(
token.userID,
token.scope,
projectID,
assertion,
true,
false,
)(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil {
return nil, err
}
@ -66,24 +74,44 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
// The returned UserInfo contains standard and reserved claims, documented
// here: https://zitadel.com/docs/apis/openidoauth/claims.
//
// User information is only retrieved once from the database.
// However, each time, role claims are asserted and also action flows will trigger.
//
// projectID is an optional parameter which defines the default audience when there are any (or all) role claims requested.
// projectRoleAssertion sets the default of returning all project roles, only if no specific roles were requested in the scope.
// roleAssertion decides whether the roles will be returned (in the token or response)
// userInfoAssertion decides whether the user information (profile data like name, email, ...) are returned
//
// currentProjectOnly can be set to use the current project ID only and ignore the audience from the scope.
// It should be set in cases where the client doesn't need to know roles outside its own project,
// for example an introspection client.
func (s *Server) userInfo(ctx context.Context, userID string, scope []string, projectID string, projectRoleAssertion, currentProjectOnly bool) (_ *oidc.UserInfo, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func (s *Server) userInfo(
userID string,
scope []string,
projectID string,
projectRoleAssertion, userInfoAssertion, currentProjectOnly bool,
) func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
var (
once sync.Once
userInfo *oidc.UserInfo
qu *query.OIDCUserInfo
roleAudience, requestedRoles []string
)
return func(ctx context.Context, roleAssertion bool, triggerType domain.TriggerType) (_ *oidc.UserInfo, err error) {
once.Do(func() {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
roleAudience, requestedRoles := prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
qu, err := s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return nil, err
roleAudience, requestedRoles = prepareRoles(ctx, scope, projectID, projectRoleAssertion, currentProjectOnly)
qu, err = s.query.GetOIDCUserInfo(ctx, userID, roleAudience)
if err != nil {
return
}
userInfo = userInfoToOIDC(qu, userInfoAssertion, scope, s.assetAPIPrefix(ctx))
})
userInfoWithRoles := assertRoles(projectID, qu, roleAudience, requestedRoles, roleAssertion, userInfo)
return userInfoWithRoles, s.userinfoFlows(ctx, qu, userInfoWithRoles, triggerType)
}
userInfo := userInfoToOIDC(projectID, qu, scope, roleAudience, requestedRoles, s.assetAPIPrefix(ctx))
return userInfo, s.userinfoFlows(ctx, qu, userInfo)
}
// prepareRoles scans the requested scopes and builds the requested roles
@ -120,20 +148,32 @@ func prepareRoles(ctx context.Context, scope []string, projectID string, project
return roleAudience, requestedRoles
}
func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudience, requestedRoles []string, assetPrefix string) *oidc.UserInfo {
func userInfoToOIDC(user *query.OIDCUserInfo, userInfoAssertion bool, scope []string, assetPrefix string) *oidc.UserInfo {
out := new(oidc.UserInfo)
for _, s := range scope {
switch s {
case oidc.ScopeOpenID:
out.Subject = user.User.ID
case oidc.ScopeEmail:
if !userInfoAssertion {
continue
}
out.UserInfoEmail = userInfoEmailToOIDC(user.User)
case oidc.ScopeProfile:
if !userInfoAssertion {
continue
}
out.UserInfoProfile = userInfoProfileToOidc(user.User, assetPrefix)
case oidc.ScopePhone:
if !userInfoAssertion {
continue
}
out.UserInfoPhone = userInfoPhoneToOIDC(user.User)
case oidc.ScopeAddress:
//TODO: handle address for human users as soon as implemented
if !userInfoAssertion {
continue
}
// TODO: handle address for human users as soon as implemented
case ScopeUserMetaData:
setUserInfoMetadata(user.Metadata, out)
case ScopeResourceOwner:
@ -148,12 +188,19 @@ func userInfoToOIDC(projectID string, user *query.OIDCUserInfo, scope, roleAudie
}
}
}
return out
}
func assertRoles(projectID string, user *query.OIDCUserInfo, roleAudience, requestedRoles []string, assertion bool, info *oidc.UserInfo) *oidc.UserInfo {
if !assertion {
return info
}
userInfo := *info
// prevent returning obtained grants if none where requested
if (projectID != "" && len(requestedRoles) > 0) || len(roleAudience) > 0 {
setUserInfoRoleClaims(out, newProjectRoles(projectID, user.UserGrants, requestedRoles))
setUserInfoRoleClaims(&userInfo, newProjectRoles(projectID, user.UserGrants, requestedRoles))
}
return out
return &userInfo
}
func userInfoEmailToOIDC(user *query.User) oidc.UserInfoEmail {
@ -230,11 +277,11 @@ func setUserInfoRoleClaims(userInfo *oidc.UserInfo, roles *projectsRoles) {
}
}
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo) (err error) {
func (s *Server) userinfoFlows(ctx context.Context, qu *query.OIDCUserInfo, userInfo *oidc.UserInfo, triggerType domain.TriggerType) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, domain.TriggerTypePreUserinfoCreation, qu.User.ResourceOwner)
queriedActions, err := s.query.GetActiveActionsByFlowAndTriggerType(ctx, domain.FlowTypeCustomiseToken, triggerType, qu.User.ResourceOwner)
if err != nil {
return err
}

View File

@ -231,9 +231,9 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
project, err := Tester.CreateProject(CTX)
projectID := project.GetId()
require.NoError(t, err)
userID, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
user, _, clientID, clientSecret, err := Tester.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err)
addProjectRolesGrants(t, userID, projectID, roleFoo, roleBar)
addProjectRolesGrants(t, user.GetUserId(), projectID, roleFoo, roleBar)
scope := []string{oidc.ScopeProfile, oidc.ScopeOpenID, oidc.ScopeEmail, oidc.ScopeOfflineAccess,
oidc_api.ScopeProjectRolePrefix + roleFoo,
@ -245,7 +245,7 @@ func TestServer_UserInfo_Issue6662(t *testing.T) {
tokens, err := rp.ClientCredentials(CTX, provider, nil)
require.NoError(t, err)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, userID, provider)
userinfo, err := rp.Userinfo[*oidc.UserInfo](CTX, tokens.AccessToken, tokens.TokenType, user.GetUserId(), provider)
require.NoError(t, err)
assertProjectRoleClaims(t, projectID, userinfo.Claims, false, roleFoo)
}
@ -291,7 +291,7 @@ func getTokens(t *testing.T, clientID string, scope []string) *oidc.Tokens[*oidc
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertTokens(t, tokens, true)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPasskey, startTime, changeTime, sessionID)
return tokens
}

View File

@ -3,7 +3,6 @@ package oidc
import (
"context"
"encoding/base64"
"fmt"
"testing"
"time"
@ -267,11 +266,9 @@ func Test_userInfoToOIDC(t *testing.T) {
}
type args struct {
projectID string
user *query.OIDCUserInfo
scope []string
roleAudience []string
requestedRoles []string
user *query.OIDCUserInfo
userInfoAssertion bool
scope []string
}
tests := []struct {
name string
@ -281,25 +278,22 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, empty",
args: args{
projectID: "project1",
user: humanUserInfo,
user: humanUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "machine, empty",
args: args{
projectID: "project1",
user: machineUserInfo,
user: machineUserInfo,
},
want: &oidc.UserInfo{},
},
{
name: "human, scope openid",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeOpenID},
user: humanUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "human1",
@ -308,20 +302,19 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope openid",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeOpenID},
user: machineUserInfo,
scope: []string{oidc.ScopeOpenID},
},
want: &oidc.UserInfo{
Subject: "machine1",
},
},
{
name: "human, scope email",
name: "human, scope email, profileInfoAssertion",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeEmail},
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{
@ -331,22 +324,29 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "machine, scope email",
name: "human, scope email",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
user: humanUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope email, profileInfoAssertion",
args: args{
user: machineUserInfo,
scope: []string{oidc.ScopeEmail},
},
want: &oidc.UserInfo{
UserInfoEmail: oidc.UserInfoEmail{},
},
},
{
name: "human, scope profile",
name: "human, scope profile, profileInfoAssertion",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopeProfile},
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
@ -363,11 +363,11 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "machine, scope profile",
name: "machine, scope profile, profileInfoAssertion",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopeProfile},
user: machineUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{
UserInfoProfile: oidc.UserInfoProfile{
@ -378,11 +378,19 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
{
name: "human, scope phone",
name: "machine, scope profile",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
user: machineUserInfo,
scope: []string{oidc.ScopeProfile},
},
want: &oidc.UserInfo{},
},
{
name: "human, scope phone, profileInfoAssertion",
args: args{
user: humanUserInfo,
userInfoAssertion: true,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{
@ -391,12 +399,19 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
},
{
name: "human, scope phone",
args: args{
user: humanUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope phone",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{oidc.ScopePhone},
user: machineUserInfo,
scope: []string{oidc.ScopePhone},
},
want: &oidc.UserInfo{
UserInfoPhone: oidc.UserInfoPhone{},
@ -405,9 +420,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, scope metadata",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{ScopeUserMetaData},
user: humanUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -421,18 +435,16 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope metadata, none found",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeUserMetaData},
user: machineUserInfo,
scope: []string{ScopeUserMetaData},
},
want: &oidc.UserInfo{},
},
{
name: "machine, scope resource owner",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{ScopeResourceOwner},
user: machineUserInfo,
scope: []string{ScopeResourceOwner},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -445,9 +457,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "human, scope org primary domain prefix",
args: args{
projectID: "project1",
user: humanUserInfo,
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
user: humanUserInfo,
scope: []string{domain.OrgDomainPrimaryScope + "foo.com"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -458,9 +469,8 @@ func Test_userInfoToOIDC(t *testing.T) {
{
name: "machine, scope org id",
args: args{
projectID: "project1",
user: machineUserInfo,
scope: []string{domain.OrgIDScope + "orgID"},
user: machineUserInfo,
scope: []string{domain.OrgIDScope + "orgID"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
@ -471,50 +481,11 @@ func Test_userInfoToOIDC(t *testing.T) {
},
},
},
{
name: "human, roleAudience",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role1": {"orgID": "orgDomain"},
"role2": {"orgID": "orgDomain"},
},
},
},
},
{
name: "human, requested roles",
args: args{
projectID: "project1",
user: humanUserInfo,
roleAudience: []string{"project1"},
requestedRoles: []string{"role2"},
},
want: &oidc.UserInfo{
Claims: map[string]any{
ClaimProjectRoles: projectRoles{
"role2": {"orgID": "orgDomain"},
},
fmt.Sprintf(ClaimProjectRolesFormat, "project1"): projectRoles{
"role2": {"orgID": "orgDomain"},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assetPrefix := "https://foo.com/assets"
got := userInfoToOIDC(tt.args.projectID, tt.args.user, tt.args.scope, tt.args.roleAudience, tt.args.requestedRoles, assetPrefix)
got := userInfoToOIDC(tt.args.user, tt.args.userInfoAssertion, tt.args.scope, assetPrefix)
assert.Equal(t, tt.want, got)
})
}

View File

@ -342,7 +342,11 @@ func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, auth
if authReq != nil {
log = log.WithField("auth_req_id", authReq.ID)
}
log.Error()
if zerrors.IsInternal(err) {
log.Error()
} else {
log.Info()
}
_, msg = l.getErrorMessage(r, err)
}

View File

@ -340,11 +340,7 @@ func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, authReqID, user
}
return err
}
policy, err := repo.getLockoutPolicy(ctx, resourceOwner)
if err != nil {
return err
}
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info), lockoutPolicyToDomain(policy))
err = repo.Command.HumanCheckPassword(ctx, resourceOwner, userID, password, request.WithCurrentInfo(info))
if isIgnoreUserInvalidPasswordError(err, request) {
return zerrors.ThrowInvalidArgument(nil, "EVENT-Jsf32", "Errors.User.UsernameOrPassword.Invalid")
}

View File

@ -32,7 +32,7 @@ func (c *Commands) AddDefaultLockoutPolicy(ctx context.Context, maxPasswordAttem
}
func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domain.LockoutPolicy) (*domain.LockoutPolicy, error) {
existingPolicy, err := c.defaultLockoutPolicyWriteModelByID(ctx)
existingPolicy, err := defaultLockoutPolicyWriteModelByID(ctx, c.eventstore.FilterToQueryReducer)
if err != nil {
return nil, err
}
@ -63,12 +63,12 @@ func (c *Commands) ChangeDefaultLockoutPolicy(ctx context.Context, policy *domai
return writeModelToLockoutPolicy(&existingPolicy.LockoutPolicyWriteModel), nil
}
func (c *Commands) defaultLockoutPolicyWriteModelByID(ctx context.Context) (policy *InstanceLockoutPolicyWriteModel, err error) {
func defaultLockoutPolicyWriteModelByID(ctx context.Context, reducer func(ctx context.Context, r eventstore.QueryReducer) error) (policy *InstanceLockoutPolicyWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel := NewInstanceLockoutPolicyWriteModel(ctx)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
err = reducer(ctx, writeModel)
if err != nil {
return nil, err
}

View File

@ -4,6 +4,7 @@ import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -12,7 +13,7 @@ func (c *Commands) AddLockoutPolicy(ctx context.Context, resourceOwner string, p
if resourceOwner == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-8fJif", "Errors.ResourceOwnerMissing")
}
addedPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner)
addedPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
if err != nil {
return nil, err
}
@ -42,7 +43,7 @@ func (c *Commands) ChangeLockoutPolicy(ctx context.Context, resourceOwner string
if resourceOwner == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-3J9fs", "Errors.ResourceOwnerMissing")
}
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, resourceOwner)
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, resourceOwner, c.eventstore.FilterToQueryReducer)
if err != nil {
return nil, err
}
@ -71,7 +72,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
if orgID == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "Org-4J9fs", "Errors.ResourceOwnerMissing")
}
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
if err != nil {
return nil, err
}
@ -93,7 +94,7 @@ func (c *Commands) RemoveLockoutPolicy(ctx context.Context, orgID string) (*doma
}
func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string) (*org.LockoutPolicyRemovedEvent, error) {
existingPolicy, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
existingPolicy, err := orgLockoutPolicyWriteModelByID(ctx, orgID, c.eventstore.FilterToQueryReducer)
if err != nil {
return nil, err
}
@ -104,24 +105,24 @@ func (c *Commands) removeLockoutPolicyIfExists(ctx context.Context, orgID string
return org.NewLockoutPolicyRemovedEvent(ctx, orgAgg), nil
}
func (c *Commands) orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string) (*OrgLockoutPolicyWriteModel, error) {
func orgLockoutPolicyWriteModelByID(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*OrgLockoutPolicyWriteModel, error) {
policy := NewOrgLockoutPolicyWriteModel(orgID)
err := c.eventstore.FilterToQueryReducer(ctx, policy)
err := queryReducer(ctx, policy)
if err != nil {
return nil, err
}
return policy, nil
}
func (c *Commands) getLockoutPolicy(ctx context.Context, orgID string) (*domain.LockoutPolicy, error) {
orgWm, err := c.orgLockoutPolicyWriteModelByID(ctx, orgID)
func getLockoutPolicy(ctx context.Context, orgID string, queryReducer func(ctx context.Context, r eventstore.QueryReducer) error) (*domain.LockoutPolicy, error) {
orgWm, err := orgLockoutPolicyWriteModelByID(ctx, orgID, queryReducer)
if err != nil {
return nil, err
}
if orgWm.State == domain.PolicyStateActive {
return writeModelToLockoutPolicy(&orgWm.LockoutPolicyWriteModel), nil
}
instanceWm, err := c.defaultLockoutPolicyWriteModelByID(ctx)
instanceWm, err := defaultLockoutPolicyWriteModelByID(ctx, queryReducer)
if err != nil {
return nil, err
}

View File

@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity"
@ -17,21 +18,18 @@ import (
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
type SessionCommand func(ctx context.Context, cmd *SessionCommands) error
type SessionCommand func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error)
type SessionCommands struct {
sessionCommands []SessionCommand
sessionWriteModel *SessionWriteModel
passwordWriteModel *HumanPasswordWriteModel
intentWriteModel *IDPIntentWriteModel
totpWriteModel *HumanTOTPWriteModel
eventstore *eventstore.Eventstore
eventCommands []eventstore.Command
sessionWriteModel *SessionWriteModel
intentWriteModel *IDPIntentWriteModel
eventstore *eventstore.Eventstore
eventCommands []eventstore.Command
hasher *crypto.Hasher
intentAlg crypto.EncryptionAlgorithm
@ -59,114 +57,92 @@ func (c *Commands) NewSessionCommands(cmds []SessionCommand, session *SessionWri
// CheckUser defines a user check to be executed for a session update
func CheckUser(id string, resourceOwner string, preferredLanguage *language.Tag) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID != "" && id != "" && cmd.sessionWriteModel.UserID != id {
return zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
return nil, zerrors.ThrowInvalidArgument(nil, "", "user change not possible")
}
return cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
return nil, cmd.UserChecked(ctx, id, resourceOwner, cmd.now(), preferredLanguage)
}
}
// CheckPassword defines a password check to be executed for a session update
func CheckPassword(password string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
}
cmd.passwordWriteModel = NewHumanPasswordWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.passwordWriteModel)
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
commands, err := checkPassword(ctx, cmd.sessionWriteModel.UserID, password, cmd.eventstore, cmd.hasher, nil)
if err != nil {
return err
return commands, err
}
if cmd.passwordWriteModel.UserState == domain.UserStateUnspecified || cmd.passwordWriteModel.UserState == domain.UserStateDeleted {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4b3", "Errors.User.NotFound")
}
if cmd.passwordWriteModel.EncodedHash == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-WEf3t", "Errors.User.Password.NotSet")
}
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
updated, err := cmd.hasher.Verify(cmd.passwordWriteModel.EncodedHash, password)
spanPasswordComparison.EndWithError(err)
if err != nil {
//TODO: maybe we want to reset the session in the future https://github.com/zitadel/zitadel/issues/5807
return zerrors.ThrowInvalidArgument(err, "COMMAND-SAF3g", "Errors.User.Password.Invalid")
}
if updated != "" {
cmd.eventCommands = append(cmd.eventCommands, user.NewHumanPasswordHashUpdatedEvent(ctx, UserAggregateFromWriteModel(&cmd.passwordWriteModel.WriteModel), updated))
}
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.PasswordChecked(ctx, cmd.now())
return nil
return nil, nil
}
}
// CheckIntent defines a check for a succeeded intent to be executed for a session update
func CheckIntent(intentID, token string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3r", "Errors.User.UserIDMissing")
}
if err := crypto.CheckToken(cmd.intentAlg, token, intentID); err != nil {
return err
return nil, err
}
cmd.intentWriteModel = NewIDPIntentWriteModel(intentID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, cmd.intentWriteModel)
if err != nil {
return err
return nil, err
}
if cmd.intentWriteModel.State != domain.IDPIntentStateSucceeded {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Df4bw", "Errors.Intent.NotSucceeded")
}
if cmd.intentWriteModel.UserID != "" {
if cmd.intentWriteModel.UserID != cmd.sessionWriteModel.UserID {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
}
} else {
linkWriteModel := NewUserIDPLinkWriteModel(cmd.sessionWriteModel.UserID, cmd.intentWriteModel.IDPID, cmd.intentWriteModel.IDPUserID, cmd.sessionWriteModel.UserResourceOwner)
err := cmd.eventstore.FilterToQueryReducer(ctx, linkWriteModel)
if err != nil {
return err
return nil, err
}
if linkWriteModel.State != domain.UserIDPLinkStateActive {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-O8xk3w", "Errors.Intent.OtherUser")
}
}
cmd.IntentChecked(ctx, cmd.now())
return nil
return nil, nil
}
}
func CheckTOTP(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing")
}
cmd.totpWriteModel = NewHumanTOTPWriteModel(cmd.sessionWriteModel.UserID, "")
err = cmd.eventstore.FilterToQueryReducer(ctx, cmd.totpWriteModel)
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
commands, err := checkTOTP(
ctx,
cmd.sessionWriteModel.UserID,
"",
code,
cmd.eventstore.FilterToQueryReducer,
cmd.totpAlg,
nil,
)
if err != nil {
return err
}
if cmd.totpWriteModel.State != domain.MFAStateReady {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady")
}
err = domain.VerifyTOTP(code, cmd.totpWriteModel.Secret, cmd.totpAlg)
if err != nil {
return err
return commands, err
}
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.TOTPChecked(ctx, cmd.now())
return nil
return nil, nil
}
}
// Exec will execute the commands specified and returns an error on the first occurrence
func (s *SessionCommands) Exec(ctx context.Context) error {
// Exec will execute the commands specified and returns an error on the first occurrence.
// In case of an error there might be specific commands returned, e.g. a failed pw check will have to be stored.
func (s *SessionCommands) Exec(ctx context.Context) ([]eventstore.Command, error) {
for _, cmd := range s.sessionCommands {
if err := cmd(ctx, s); err != nil {
return err
if cmds, err := cmd(ctx, s); err != nil {
return cmds, err
}
}
return nil
return nil, nil
}
func (s *SessionCommands) Start(ctx context.Context, userAgent *domain.UserAgent) {
@ -360,8 +336,11 @@ func (c *Commands) updateSession(ctx context.Context, checks *SessionCommands, m
if err = checks.sessionWriteModel.CheckNotInvalidated(); err != nil {
return nil, err
}
if err := checks.Exec(ctx); err != nil {
// TODO: how to handle failed checks (e.g. pw wrong) https://github.com/zitadel/zitadel/issues/5807
if cmds, err := checks.Exec(ctx); err != nil {
if len(cmds) > 0 {
_, pushErr := c.eventstore.Push(ctx, cmds...)
logging.OnError(pushErr).Error("unable to store check failures")
}
return nil, err
}
checks.ChangeMetadata(ctx, metadata)

View File

@ -6,9 +6,10 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -21,26 +22,26 @@ func (c *Commands) CreateOTPSMSChallenge() SessionCommand {
}
func (c *Commands) createOTPSMSChallenge(returnCode bool, dst *string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKL3g", "Errors.User.UserIDMissing")
}
writeModel := NewHumanOTPSMSWriteModel(cmd.sessionWriteModel.UserID, "")
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
return err
return nil, err
}
if !writeModel.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-BJ2g3", "Errors.User.MFA.OTP.NotReady")
}
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPSMS, cmd.otpAlg, c.defaultSecretGenerators.OTPSMS)
if err != nil {
return err
return nil, err
}
if returnCode {
*dst = code.Plain
}
cmd.OTPSMSChallenged(ctx, code.Crypted, code.Expiry, returnCode)
return nil
return nil, nil
}
}
@ -74,26 +75,26 @@ func (c *Commands) CreateOTPEmailChallenge() SessionCommand {
}
func (c *Commands) createOTPEmailChallenge(returnCode bool, urlTmpl string, dst *string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JK3gp", "Errors.User.UserIDMissing")
}
writeModel := NewHumanOTPEmailWriteModel(cmd.sessionWriteModel.UserID, "")
if err := cmd.eventstore.FilterToQueryReducer(ctx, writeModel); err != nil {
return err
return nil, err
}
if !writeModel.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JKLJ3", "Errors.User.MFA.OTP.NotReady")
}
code, err := cmd.createCode(ctx, cmd.eventstore.Filter, domain.SecretGeneratorTypeOTPEmail, cmd.otpAlg, c.defaultSecretGenerators.OTPEmail)
if err != nil {
return err
return nil, err
}
if returnCode {
*dst = code.Plain
}
cmd.OTPEmailChallenged(ctx, code.Crypted, code.Expiry, returnCode, urlTmpl)
return nil
return nil, nil
}
}
@ -112,37 +113,57 @@ func (c *Commands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner st
}
func CheckOTPSMS(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing")
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
otpWriteModel := NewHumanOTPSMSCodeWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
if err != nil {
return nil, err
}
// explicitly set the challenge from the session write model since the code write model will only check user events
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPSMSCodeChallenge
return otpWriteModel, nil
}
challenge := cmd.sessionWriteModel.OTPSMSCodeChallenge
if challenge == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound")
succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPSMSCheckSucceededEvent(ctx, aggregate, nil)
}
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, nil)
}
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
if err != nil {
return err
return commands, err
}
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.OTPSMSChecked(ctx, cmd.now())
return nil
return nil, nil
}
}
func CheckOTPEmail(code string) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) (err error) {
if cmd.sessionWriteModel.UserID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing")
return func(ctx context.Context, cmd *SessionCommands) (_ []eventstore.Command, err error) {
writeModel := func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error) {
otpWriteModel := NewHumanOTPEmailCodeWriteModel(cmd.sessionWriteModel.UserID, "")
err := cmd.eventstore.FilterToQueryReducer(ctx, otpWriteModel)
if err != nil {
return nil, err
}
// explicitly set the challenge from the session write model since the code write model will only check user events
otpWriteModel.otpCode = cmd.sessionWriteModel.OTPEmailCodeChallenge
return otpWriteModel, nil
}
challenge := cmd.sessionWriteModel.OTPEmailCodeChallenge
if challenge == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound")
succeededEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPEmailCheckSucceededEvent(ctx, aggregate, nil)
}
err = crypto.VerifyCode(challenge.CreationDate, challenge.Expiry, challenge.Code, code, cmd.otpAlg)
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, nil)
}
commands, err := checkOTP(ctx, cmd.sessionWriteModel.UserID, code, "", nil, writeModel, cmd.eventstore.FilterToQueryReducer, cmd.otpAlg, succeededEvent, failedEvent)
if err != nil {
return err
return commands, err
}
cmd.eventCommands = append(cmd.eventCommands, commands...)
cmd.OTPEmailChecked(ctx, cmd.now())
return nil
return nil, nil
}
}

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
@ -110,8 +111,9 @@ func TestCommands_CreateOTPSMSChallengeReturnCode(t *testing.T) {
now: time.Now,
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.returnCode, dst)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
@ -210,8 +212,9 @@ func TestCommands_CreateOTPSMSChallenge(t *testing.T) {
now: time.Now,
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
}
@ -410,8 +413,9 @@ func TestCommands_CreateOTPEmailChallengeURLTemplate(t *testing.T) {
now: time.Now,
}
err = cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
}
@ -511,8 +515,9 @@ func TestCommands_CreateOTPEmailChallengeReturnCode(t *testing.T) {
now: time.Now,
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.returnCode, dst)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
@ -611,8 +616,9 @@ func TestCommands_CreateOTPEmailChallenge(t *testing.T) {
now: time.Now,
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Empty(t, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
}
@ -701,8 +707,9 @@ func TestCheckOTPSMS(t *testing.T) {
code string
}
type res struct {
err error
commands []eventstore.Command
err error
commands []eventstore.Command
errorCommands []eventstore.Command
}
tests := []struct {
name string
@ -720,13 +727,43 @@ func TestCheckOTPSMS(t *testing.T) {
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-VDrh3", "Errors.User.UserIDMissing"),
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
},
},
{
name: "missing code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
},
args: args{},
res: res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
},
},
{
name: "not set up",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
userID: "userID",
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
},
},
{
name: "missing challenge",
fields: fields{
eventstore: expectEventstore(),
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
),
userID: "userID",
otpCodeChallenge: nil,
},
@ -734,14 +771,26 @@ func TestCheckOTPSMS(t *testing.T) {
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3tv", "Errors.User.Code.NotFound"),
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
},
},
{
name: "invalid code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 0, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
@ -759,13 +808,61 @@ func TestCheckOTPSMS(t *testing.T) {
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
},
},
},
{
name: "invalid code, locked",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 1, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow.Add(-10 * time.Minute),
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPSMSCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
},
},
},
{
name: "check ok",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
@ -783,12 +880,44 @@ func TestCheckOTPSMS(t *testing.T) {
},
res: res{
commands: []eventstore.Command{
user.NewHumanOTPSMSCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewOTPSMSCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow,
),
},
},
},
{
name: "check ok, locked in the meantime",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPSMSAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow,
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -811,8 +940,9 @@ func TestCheckOTPSMS(t *testing.T) {
},
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.errorCommands, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
}
@ -829,8 +959,9 @@ func TestCheckOTPEmail(t *testing.T) {
code string
}
type res struct {
err error
commands []eventstore.Command
err error
commands []eventstore.Command
errorCommands []eventstore.Command
}
tests := []struct {
name string
@ -848,13 +979,43 @@ func TestCheckOTPEmail(t *testing.T) {
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-ejo2w", "Errors.User.UserIDMissing"),
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing"),
},
},
{
name: "missing code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
},
args: args{},
res: res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty"),
},
},
{
name: "not set up",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
),
userID: "userID",
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady"),
},
},
{
name: "missing challenge",
fields: fields{
eventstore: expectEventstore(),
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
),
userID: "userID",
otpCodeChallenge: nil,
},
@ -862,14 +1023,26 @@ func TestCheckOTPEmail(t *testing.T) {
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-zF3g3", "Errors.User.Code.NotFound"),
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound"),
},
},
{
name: "invalid code",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 0, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
@ -887,13 +1060,61 @@ func TestCheckOTPEmail(t *testing.T) {
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
},
},
},
{
name: "invalid code, locked",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
0, 1, false,
),
),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow.Add(-10 * time.Minute),
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "CODE-QvUQ4P", "Errors.User.Code.Expired"),
errorCommands: []eventstore.Command{
user.NewHumanOTPEmailCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
},
},
},
{
name: "check ok",
fields: fields{
eventstore: expectEventstore(),
userID: "userID",
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(), // recheck
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
@ -911,12 +1132,44 @@ func TestCheckOTPEmail(t *testing.T) {
},
res: res{
commands: []eventstore.Command{
user.NewHumanOTPEmailCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewOTPEmailCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow,
),
},
},
},
{
name: "check ok, locked in the meantime",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(user.NewHumanOTPEmailAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate)),
),
expectFilter(
user.NewUserLockedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate),
),
),
userID: "userID",
otpCodeChallenge: &OTPCode{
Code: &crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("code"),
},
Expiry: 5 * time.Minute,
CreationDate: testNow,
},
otpAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
code: "code",
},
res: res{
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -939,8 +1192,9 @@ func TestCheckOTPEmail(t *testing.T) {
},
}
err := cmd(context.Background(), cmds)
gotCmds, err := cmd(context.Background(), cmds)
assert.ErrorIs(t, err, tt.res.err)
assert.Equal(t, tt.res.errorCommands, gotCmds)
assert.Equal(t, tt.res.commands, cmds.eventCommands)
})
}

View File

@ -22,6 +22,7 @@ import (
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/idpintent"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
@ -430,8 +431,8 @@ func TestCommands_updateSession(t *testing.T) {
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
func(ctx context.Context, cmd *SessionCommands) error {
return zerrors.ThrowInternal(nil, "id", "check failed")
func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
return nil, zerrors.ThrowInternal(nil, "id", "check failed")
},
},
},
@ -525,6 +526,55 @@ func TestCommands_updateSession(t *testing.T) {
},
},
},
{
"set user, invalid password",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
"username", "", "", "", "", language.English, domain.GenderUnspecified, "", false),
),
eventFromEventPusher(
user.NewHumanPasswordChangedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate,
"$plain$x$password", false, ""),
),
),
expectFilter(), // recheck
expectFilter(
org.NewLockoutPolicyAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate, 0, 0, false),
),
expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
),
),
},
args{
ctx: authz.NewMockContext("instance1", "", ""),
checks: &SessionCommands{
sessionWriteModel: NewSessionWriteModel("sessionID", "instance1"),
sessionCommands: []SessionCommand{
CheckUser("userID", "org1", &language.Afrikaans),
CheckPassword("invalid password"),
},
createToken: func(sessionID string) (string, string, error) {
return "tokenID",
"token",
nil
},
hasher: mockPasswordHasher("x"),
now: func() time.Time {
return testNow
},
},
metadata: map[string][]byte{
"key": []byte("value"),
},
},
res{
err: zerrors.ThrowInvalidArgument(nil, "COMMAND-3M0fs", "Errors.User.Password.Invalid"),
},
},
{
"set user, password, metadata and token",
fields{
@ -539,10 +589,12 @@ func TestCommands_updateSession(t *testing.T) {
"$plain$x$password", false, ""),
),
),
expectFilter(), // recheck
expectPush(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow, &language.Afrikaans,
),
user.NewHumanPasswordCheckSucceededEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, nil),
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow,
),
@ -872,6 +924,7 @@ func TestCheckTOTP(t *testing.T) {
sessAgg := &session.NewAggregate("session1", "instance1").Aggregate
userAgg := &user.NewAggregate("user1", "org1").Aggregate
orgAgg := &org.NewAggregate("org1").Aggregate
code, err := totp.GenerateCode(key.Secret(), testNow)
require.NoError(t, err)
@ -886,6 +939,7 @@ func TestCheckTOTP(t *testing.T) {
code string
fields fields
wantEventCommands []eventstore.Command
wantErrorCommands []eventstore.Command
wantErr error
}{
{
@ -897,7 +951,7 @@ func TestCheckTOTP(t *testing.T) {
},
eventstore: expectEventstore(),
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Neil7", "Errors.User.UserIDMissing"),
wantErr: zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing"),
},
{
name: "filter error",
@ -931,7 +985,7 @@ func TestCheckTOTP(t *testing.T) {
),
),
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-eej1U", "Errors.User.MFA.OTP.NotReady"),
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady"),
},
{
name: "otp verify error",
@ -951,8 +1005,45 @@ func TestCheckTOTP(t *testing.T) {
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 0, 0, false)),
),
),
},
wantErrorCommands: []eventstore.Command{
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
},
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
},
{
name: "otp verify error, locked",
code: "foobar",
fields: fields{
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
UserCheckedAt: testNow,
aggregate: sessAgg,
},
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
),
eventFromEventPusher(
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(), // recheck
expectFilter(
eventFromEventPusher(org.NewLockoutPolicyAddedEvent(ctx, orgAgg, 1, 1, false)),
),
),
},
wantErrorCommands: []eventstore.Command{
user.NewHumanOTPCheckFailedEvent(ctx, userAgg, nil),
user.NewUserLockedEvent(ctx, userAgg),
},
wantErr: zerrors.ThrowInvalidArgument(nil, "EVENT-8isk2", "Errors.User.MFA.OTP.InvalidCode"),
},
{
@ -973,12 +1064,39 @@ func TestCheckTOTP(t *testing.T) {
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(), // recheck
),
},
wantEventCommands: []eventstore.Command{
user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, nil),
session.NewTOTPCheckedEvent(ctx, sessAgg, testNow),
},
},
{
name: "ok, but locked in the meantime",
code: code,
fields: fields{
sessionWriteModel: &SessionWriteModel{
UserID: "user1",
UserCheckedAt: testNow,
aggregate: sessAgg,
},
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanOTPAddedEvent(ctx, userAgg, secret),
),
eventFromEventPusher(
user.NewHumanOTPVerifiedEvent(ctx, userAgg, "agent1"),
),
),
expectFilter(
user.NewUserLockedEvent(ctx, userAgg),
),
),
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -988,8 +1106,9 @@ func TestCheckTOTP(t *testing.T) {
totpAlg: cryptoAlg,
now: func() time.Time { return testNow },
}
err := CheckTOTP(tt.code)(ctx, cmd)
gotCmds, err := CheckTOTP(tt.code)(ctx, cmd)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantErrorCommands, gotCmds)
assert.Equal(t, tt.wantEventCommands, cmd.eventCommands)
})
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -41,49 +42,49 @@ func (s *SessionCommands) getHumanWebAuthNTokenReadModel(ctx context.Context, us
}
func (c *Commands) CreateWebAuthNChallenge(userVerification domain.UserVerificationRequirement, rpid string, dst json.Unmarshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
humanPasskeys, err := cmd.getHumanWebAuthNTokens(ctx, userVerification)
if err != nil {
return err
return nil, err
}
webAuthNLogin, err := c.webauthnConfig.BeginLogin(ctx, humanPasskeys.human, userVerification, rpid, humanPasskeys.tokens...)
if err != nil {
return err
return nil, err
}
if err = json.Unmarshal(webAuthNLogin.CredentialAssertionData, dst); err != nil {
return zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
return nil, zerrors.ThrowInternal(err, "COMMAND-Yah6A", "Errors.Internal")
}
cmd.WebAuthNChallenged(ctx, webAuthNLogin.Challenge, webAuthNLogin.AllowedCredentialIDs, webAuthNLogin.UserVerification, rpid)
return nil
return nil, nil
}
}
func (c *Commands) CheckWebAuthN(credentialAssertionData json.Marshaler) SessionCommand {
return func(ctx context.Context, cmd *SessionCommands) error {
return func(ctx context.Context, cmd *SessionCommands) ([]eventstore.Command, error) {
credentialAssertionData, err := json.Marshal(credentialAssertionData)
if err != nil {
return zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
return nil, zerrors.ThrowInternal(err, "COMMAND-ohG2o", "Errors.Internal")
}
challenge := cmd.sessionWriteModel.WebAuthNChallenge
if challenge == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Ioqu5", "Errors.Session.WebAuthN.NoChallenge")
}
webAuthNTokens, err := cmd.getHumanWebAuthNTokens(ctx, challenge.UserVerification)
if err != nil {
return err
return nil, err
}
webAuthN := challenge.WebAuthNLogin(webAuthNTokens.human, credentialAssertionData)
credential, err := c.webauthnConfig.FinishLogin(ctx, webAuthNTokens.human, webAuthN, credentialAssertionData, webAuthNTokens.tokens...)
if err != nil && (credential == nil || credential.ID == nil) {
return err
return nil, err
}
_, token := domain.GetTokenByKeyID(webAuthNTokens.tokens, credential.ID)
if token == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Aej7i", "Errors.User.WebAuthN.NotFound")
}
cmd.WebAuthNChecked(ctx, cmd.now(), token.WebAuthNTokenID, credential.Authenticator.SignCount, credential.Flags.UserVerified)
return nil
return nil, nil
}
}

View File

@ -161,48 +161,67 @@ func (c *Commands) HumanCheckMFATOTPSetup(ctx context.Context, userID, code, use
}
func (c *Commands) HumanCheckMFATOTP(ctx context.Context, userID, code, resourceOwner string, authRequest *domain.AuthRequest) error {
commands, err := checkTOTP(
ctx,
userID,
resourceOwner,
code,
c.eventstore.FilterToQueryReducer,
c.multifactors.OTP.CryptoMFA,
authRequestDomainToAuthRequestInfo(authRequest),
)
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return err
}
func checkTOTP(
ctx context.Context,
userID, resourceOwner, code string,
queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
alg crypto.EncryptionAlgorithm,
optionalAuthRequestInfo *user.AuthRequestInfo,
) ([]eventstore.Command, error) {
if userID == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing")
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-8N9ds", "Errors.User.UserIDMissing")
}
existingOTP, err := c.totpWriteModelByID(ctx, userID, resourceOwner)
existingOTP := NewHumanTOTPWriteModel(userID, resourceOwner)
err := queryReducer(ctx, existingOTP)
if err != nil {
return err
return nil, err
}
if existingOTP.State != domain.MFAStateReady {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotReady")
}
userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel)
verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, c.multifactors.OTP.CryptoMFA)
verifyErr := domain.VerifyTOTP(code, existingOTP.Secret, alg)
// recheck for additional events (failed OTP checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP)
recheckErr := queryReducer(ctx, existingOTP)
if recheckErr != nil {
return recheckErr
return nil, recheckErr
}
if existingOTP.UserLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SF3fg", "Errors.User.Locked")
}
// the OTP check succeeded and the user was not locked in the meantime
if verifyErr == nil {
_, err = c.eventstore.Push(ctx, user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
return err
return []eventstore.Command{user.NewHumanOTPCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo)}, nil
}
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
commands := make([]eventstore.Command, 0, 2)
commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner)
commands = append(commands, user.NewHumanOTPCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
lockoutPolicy, err := getLockoutPolicy(ctx, existingOTP.ResourceOwner, queryReducer)
if err != nil {
return err
return nil, err
}
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount+1 >= lockoutPolicy.MaxOTPAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
}
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return verifyErr
return commands, verifyErr
}
func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner string) (*domain.ObjectDetails, error) {
@ -342,16 +361,23 @@ func (c *Commands) HumanCheckOTPSMS(ctx context.Context, userID, code, resourceO
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPSMSCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
}
return c.humanCheckOTP(
commands, err := checkOTP(
ctx,
userID,
code,
resourceOwner,
authRequest,
writeModel,
c.eventstore.FilterToQueryReducer,
c.userEncryption,
succeededEvent,
failedEvent,
)
if len(commands) > 0 {
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
}
return err
}
// AddHumanOTPEmail adds the OTP Email factor to a user.
@ -467,16 +493,23 @@ func (c *Commands) HumanCheckOTPEmail(ctx context.Context, userID, code, resourc
failedEvent := func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command {
return user.NewHumanOTPEmailCheckFailedEvent(ctx, aggregate, authRequestDomainToAuthRequestInfo(authRequest))
}
return c.humanCheckOTP(
commands, err := checkOTP(
ctx,
userID,
code,
resourceOwner,
authRequest,
writeModel,
c.eventstore.FilterToQueryReducer,
c.userEncryption,
succeededEvent,
failedEvent,
)
if len(commands) > 0 {
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
}
return err
}
// sendHumanOTP creates a code for a registered mechanism (sms / email), which is used for a check (during login)
@ -534,62 +567,57 @@ func (c *Commands) humanOTPSent(
return err
}
func (c *Commands) humanCheckOTP(
func checkOTP(
ctx context.Context,
userID, code, resourceOwner string,
authRequest *domain.AuthRequest,
writeModelByID func(ctx context.Context, userID string, resourceOwner string) (OTPCodeWriteModel, error),
checkSucceededEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
) error {
queryReducer func(ctx context.Context, r eventstore.QueryReducer) error,
alg crypto.EncryptionAlgorithm,
checkSucceededEvent, checkFailedEvent func(ctx context.Context, aggregate *eventstore.Aggregate, info *user.AuthRequestInfo) eventstore.Command,
) ([]eventstore.Command, error) {
if userID == "" {
return zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing")
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-S453v", "Errors.User.UserIDMissing")
}
if code == "" {
return zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty")
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-SJl2g", "Errors.User.Code.Empty")
}
existingOTP, err := writeModelByID(ctx, userID, resourceOwner)
if err != nil {
return err
return nil, err
}
if !existingOTP.OTPAdded() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-d2r52", "Errors.User.MFA.OTP.NotReady")
}
if existingOTP.Code() == nil {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S34gh", "Errors.User.Code.NotFound")
}
userAgg := &user.NewAggregate(userID, existingOTP.ResourceOwner()).Aggregate
verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, c.userEncryption)
verifyErr := crypto.VerifyCode(existingOTP.CodeCreationDate(), existingOTP.CodeExpiry(), existingOTP.Code(), code, alg)
// recheck for additional events (failed OTP checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, existingOTP)
recheckErr := queryReducer(ctx, existingOTP)
if recheckErr != nil {
return recheckErr
return nil, recheckErr
}
if existingOTP.UserLocked() {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-S6h4R", "Errors.User.Locked")
}
// the OTP check succeeded and the user was not locked in the meantime
if verifyErr == nil {
_, err = c.eventstore.Push(ctx, checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
return err
return []eventstore.Command{checkSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest))}, nil
}
// the OTP check failed, therefore check if the limit was reached and the user must additionally be locked
commands := make([]eventstore.Command, 0, 2)
commands = append(commands, checkFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
lockoutPolicy, err := c.getLockoutPolicy(ctx, resourceOwner)
if err != nil {
return err
}
if lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, existingOTP.ResourceOwner(), queryReducer)
logging.OnError(lockoutErr).Error("unable to get lockout policy")
if lockoutPolicy != nil && lockoutPolicy.MaxOTPAttempts > 0 && existingOTP.CheckFailedCount()+1 >= lockoutPolicy.MaxOTPAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
}
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.WithFields("userID", userID).OnError(pushErr).Error("otp failure check push failed")
return verifyErr
return commands, verifyErr
}
func (c *Commands) totpWriteModelByID(ctx context.Context, userID, resourceOwner string) (writeModel *HumanTOTPWriteModel, err error) {

View File

@ -157,24 +157,31 @@ func (wm *HumanOTPSMSWriteModel) Query() *eventstore.SearchQueryBuilder {
type HumanOTPSMSCodeWriteModel struct {
*HumanOTPSMSWriteModel
code *crypto.CryptoValue
codeCreationDate time.Time
codeExpiry time.Duration
otpCode *OTPCode
checkFailedCount uint64
userLocked bool
}
func (wm *HumanOTPSMSCodeWriteModel) CodeCreationDate() time.Time {
return wm.codeCreationDate
if wm.otpCode == nil {
return time.Time{}
}
return wm.otpCode.CreationDate
}
func (wm *HumanOTPSMSCodeWriteModel) CodeExpiry() time.Duration {
return wm.codeExpiry
if wm.otpCode == nil {
return 0
}
return wm.otpCode.Expiry
}
func (wm *HumanOTPSMSCodeWriteModel) Code() *crypto.CryptoValue {
return wm.code
if wm.otpCode == nil {
return nil
}
return wm.otpCode.Code
}
func (wm *HumanOTPSMSCodeWriteModel) CheckFailedCount() uint64 {
@ -195,9 +202,11 @@ func (wm *HumanOTPSMSCodeWriteModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *user.HumanOTPSMSCodeAddedEvent:
wm.code = e.Code
wm.codeCreationDate = e.CreationDate()
wm.codeExpiry = e.Expiry
wm.otpCode = &OTPCode{
Code: e.Code,
CreationDate: e.CreationDate(),
Expiry: e.Expiry,
}
case *user.HumanOTPSMSCheckSucceededEvent:
wm.checkFailedCount = 0
case *user.HumanOTPSMSCheckFailedEvent:
@ -300,24 +309,31 @@ func (wm *HumanOTPEmailWriteModel) Query() *eventstore.SearchQueryBuilder {
type HumanOTPEmailCodeWriteModel struct {
*HumanOTPEmailWriteModel
code *crypto.CryptoValue
codeCreationDate time.Time
codeExpiry time.Duration
otpCode *OTPCode
checkFailedCount uint64
userLocked bool
}
func (wm *HumanOTPEmailCodeWriteModel) CodeCreationDate() time.Time {
return wm.codeCreationDate
if wm.otpCode == nil {
return time.Time{}
}
return wm.otpCode.CreationDate
}
func (wm *HumanOTPEmailCodeWriteModel) CodeExpiry() time.Duration {
return wm.codeExpiry
if wm.otpCode == nil {
return 0
}
return wm.otpCode.Expiry
}
func (wm *HumanOTPEmailCodeWriteModel) Code() *crypto.CryptoValue {
return wm.code
if wm.otpCode == nil {
return nil
}
return wm.otpCode.Code
}
func (wm *HumanOTPEmailCodeWriteModel) CheckFailedCount() uint64 {
@ -338,9 +354,11 @@ func (wm *HumanOTPEmailCodeWriteModel) Reduce() error {
for _, event := range wm.Events {
switch e := event.(type) {
case *user.HumanOTPEmailCodeAddedEvent:
wm.code = e.Code
wm.codeCreationDate = e.CreationDate()
wm.codeExpiry = e.Expiry
wm.otpCode = &OTPCode{
Code: e.Code,
CreationDate: e.CreationDate(),
Expiry: e.Expiry,
}
case *user.HumanOTPEmailCheckSucceededEvent:
wm.checkFailedCount = 0
case *user.HumanOTPEmailCheckFailedEvent:

View File

@ -295,8 +295,8 @@ func (c *Commands) PasswordChangeSent(ctx context.Context, orgID, userID string)
return err
}
// HumanCheckPassword check password for user with additional informations from authRequest
func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest, lockoutPolicy *domain.LockoutPolicy) (err error) {
// HumanCheckPassword check password for user with additional information from authRequest
func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, password string, authRequest *domain.AuthRequest) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@ -314,56 +314,66 @@ func (c *Commands) HumanCheckPassword(ctx context.Context, orgID, userID, passwo
if !loginPolicy.AllowUsernamePassword {
return zerrors.ThrowPreconditionFailed(err, "COMMAND-Dft32", "Errors.Org.LoginPolicy.UsernamePasswordNotAllowed")
}
wm, err := c.passwordWriteModel(ctx, userID, orgID)
if err != nil {
commands, err := checkPassword(ctx, userID, password, c.eventstore, c.userPasswordHasher, authRequestDomainToAuthRequestInfo(authRequest))
if len(commands) == 0 {
return err
}
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return err
}
if !isUserStateExists(wm.UserState) {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound")
func checkPassword(ctx context.Context, userID, password string, es *eventstore.Eventstore, hasher *crypto.Hasher, optionalAuthRequestInfo *user.AuthRequestInfo) ([]eventstore.Command, error) {
if userID == "" {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfw3f", "Errors.User.UserIDMissing")
}
wm := NewHumanPasswordWriteModel(userID, "")
err := es.FilterToQueryReducer(ctx, wm)
if err != nil {
return nil, err
}
if !wm.UserState.Exists() {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3n77z", "Errors.User.NotFound")
}
if wm.UserState == domain.UserStateLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-JLK35", "Errors.User.Locked")
}
if wm.EncodedHash == "" {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-3nJ4t", "Errors.User.Password.NotSet")
}
userAgg := UserAggregateFromWriteModel(&wm.WriteModel)
ctx, spanPasswordComparison := tracing.NewNamedSpan(ctx, "passwap.Verify")
updated, err := c.userPasswordHasher.Verify(wm.EncodedHash, password)
updated, err := hasher.Verify(wm.EncodedHash, password)
spanPasswordComparison.EndWithError(err)
err = convertPasswapErr(err)
commands := make([]eventstore.Command, 0, 2)
// recheck for additional events (failed password checks or locks)
recheckErr := c.eventstore.FilterToQueryReducer(ctx, wm)
recheckErr := es.FilterToQueryReducer(ctx, wm)
if recheckErr != nil {
return recheckErr
return nil, recheckErr
}
if wm.UserState == domain.UserStateLocked {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked")
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-SFA3t", "Errors.User.Locked")
}
if err == nil {
commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
commands = append(commands, user.NewHumanPasswordCheckSucceededEvent(ctx, userAgg, optionalAuthRequestInfo))
if updated != "" {
commands = append(commands, user.NewHumanPasswordHashUpdatedEvent(ctx, userAgg, updated))
}
_, err = c.eventstore.Push(ctx, commands...)
return err
return commands, nil
}
commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, authRequestDomainToAuthRequestInfo(authRequest)))
if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 {
if wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
}
commands = append(commands, user.NewHumanPasswordCheckFailedEvent(ctx, userAgg, optionalAuthRequestInfo))
lockoutPolicy, lockoutErr := getLockoutPolicy(ctx, wm.ResourceOwner, es.FilterToQueryReducer)
logging.OnError(lockoutErr).Error("unable to get lockout policy")
if lockoutPolicy != nil && lockoutPolicy.MaxPasswordAttempts > 0 && wm.PasswordCheckFailedCount+1 >= lockoutPolicy.MaxPasswordAttempts {
commands = append(commands, user.NewUserLockedEvent(ctx, userAgg))
}
_, pushErr := c.eventstore.Push(ctx, commands...)
logging.OnError(pushErr).Error("error create password check failed event")
return err
return commands, err
}
func (c *Commands) passwordWriteModel(ctx context.Context, userID, resourceOwner string) (writeModel *HumanPasswordWriteModel, err error) {

View File

@ -1456,7 +1456,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
resourceOwner string
password string
authReq *domain.AuthRequest
lockoutPolicy *domain.LockoutPolicy
}
type res struct {
err func(error) bool
@ -1768,6 +1767,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
"")),
),
expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(),
&org.NewAggregate("org1").Aggregate,
0, 0, false,
)),
),
expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
@ -1789,7 +1795,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
ID: "request1",
AgentID: "agent1",
},
lockoutPolicy: &domain.LockoutPolicy{},
},
res: res{
err: zerrors.IsErrorInvalidArgument,
@ -1852,6 +1857,13 @@ func TestCommandSide_CheckPassword(t *testing.T) {
),
),
expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewLockoutPolicyAddedEvent(context.Background(),
&org.NewAggregate("org1").Aggregate,
1, 1, false,
)),
),
expectPush(
user.NewHumanPasswordCheckFailedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
@ -1876,10 +1888,6 @@ func TestCommandSide_CheckPassword(t *testing.T) {
ID: "request1",
AgentID: "agent1",
},
lockoutPolicy: &domain.LockoutPolicy{
MaxPasswordAttempts: 1,
MaxOTPAttempts: 1,
},
},
res: res{
err: zerrors.IsErrorInvalidArgument,
@ -2230,7 +2238,7 @@ func TestCommandSide_CheckPassword(t *testing.T) {
eventstore: tt.fields.eventstore(t),
userPasswordHasher: tt.fields.userPasswordHasher,
}
err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq, tt.args.lockoutPolicy)
err := r.HumanCheckPassword(tt.args.ctx, tt.args.resourceOwner, tt.args.userID, tt.args.password, tt.args.authReq)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@ -281,42 +281,42 @@ func CheckRedirect(req *http.Request) (*url.URL, error) {
return resp.Location()
}
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (userID, clientID, clientSecret string, err error) {
name := gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
func (s *Tester) CreateOIDCCredentialsClient(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) {
name = gofakeit.Username()
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
if err != nil {
return "", "", "", err
return nil, "", "", "", err
}
secret, err := s.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{
UserId: user.GetUserId(),
UserId: machine.GetUserId(),
})
if err != nil {
return "", "", "", err
return nil, "", "", "", err
}
return user.GetUserId(), secret.GetClientId(), secret.GetClientSecret(), nil
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
}
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (userID string, keyData []byte, err error) {
name := gofakeit.Username()
user, err := s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
func (s *Tester) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
name = gofakeit.Username()
machine, err = s.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
if err != nil {
return "", nil, err
return nil, "", nil, err
}
keyResp, err := s.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
UserId: user.GetUserId(),
UserId: machine.GetUserId(),
Type: authn.KeyType_KEY_TYPE_JSON,
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
})
if err != nil {
return "", nil, err
return nil, "", nil, err
}
return user.GetUserId(), keyResp.GetKeyDetails(), nil
return machine, name, keyResp.GetKeyDetails(), nil
}