mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:47:33 +00:00
feat: add auth command side (#107)
* fix: query tests * fix: use prepare funcs * fix: go mod * fix: generate files * fix(eventstore): tests * fix(eventstore): rename modifier to editor * fix(migrations): add cluster migration, fix(migrations): fix typo of host in clean clsuter * fix(eventstore): move health * fix(eventstore): AggregateTypeFilter aggregateType as param * code quality * fix: go tests * feat: add member funcs * feat: add member model * feat: add member events * feat: add member repo model * fix: better error func testing * fix: project member funcs * fix: add tests * fix: add tests * feat: implement member requests * fix: merge master * fix: merge master * fix: read existing in project repo * fix: fix tests * feat: add internal cache * feat: add cache mock * fix: return values of cache mock * feat: add project role * fix: add cache config * fix: add role to eventstore * fix: use eventstore sdk * fix: use eventstore sdk * fix: add project role grpc requests * fix: fix getby id * fix: changes for mr * fix: change value to interface * feat: add app event creations * fix: searchmethods * Update internal/project/model/project_member.go Co-Authored-By: Silvan <silvan.reusser@gmail.com> * fix: use get project func * fix: append events * fix: check if value is string on equal ignore case * fix: add changes test * fix: add go mod * fix: add some tests * fix: return err not nil * fix: return err not nil * fix: add aggregate funcs and tests * fix: add oidc aggregate funcs and tests * fix: add oidc * fix: add some tests * fix: tests * feat: eventstore repository * fix: remove gorm * version * feat: pkg * feat: eventstore without eventstore-lib * rename files * gnueg * fix: global model * feat: add global view functions * feat(eventstore): sdk * fix(eventstore): rename app to eventstore * delete empty test * fix(models): delete unused struct * feat(eventstore): overwrite context data * fix: use global sql config * fix: oidc validation * fix: generate client secret * fix: generate client id * fix: test change app * fix: deactivate/reactivate application * fix: change oidc config * fix: change oidc config secret * begin models * begin repo * fix: implement grpc app funcs * fix: add application requests * fix: converter * fix: converter * fix: converter and generate clientid * fix: tests * feat: project grant aggregate * feat: project grant * fix: project grant check if role existing * fix: project grant requests * fix: project grant fixes * fix: project grant member model * fix: project grant member aggregate * fix: project grant member eventstore * fix: project grant member requests * feat: user model * begin repo * repo models and more * feat: user command side * lots of functions * user command side * profile requests * commit before rebase on user * save * local config with gopass and more * begin new auth command (user centric) * Update internal/user/model/user.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/address.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/address.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/email.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/email.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/email.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/mfa.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/mfa.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/password.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/password.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/password.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/phone.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/phone.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/phone.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/user.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/user.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/model/user.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/usergrant/repository/eventsourcing/model/user_grant.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/usergrant/repository/eventsourcing/model/user_grant.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/usergrant/repository/eventsourcing/user_grant.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/user_test.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * Update internal/user/repository/eventsourcing/eventstore_mock_test.go Co-Authored-By: Livio Amstutz <livio.a@gmail.com> * changes from mr review * save files into basedir * changes from mr review * changes from mr review * move to auth request * Update internal/usergrant/repository/eventsourcing/cache.go Co-authored-by: Silvan <silvan.reusser@gmail.com> * Update internal/usergrant/repository/eventsourcing/cache.go Co-authored-by: Silvan <silvan.reusser@gmail.com> * changes requested on mr * fix generate codes * fix return if no events * password code * email verification step * more steps * lot of mfa * begin tests * more next steps * auth api * auth api (user) * auth api (user) * auth api (user) * differ requests * merge * tests * fix compilation error * mock for id generator * Update internal/user/repository/eventsourcing/model/password.go Co-authored-by: Silvan <silvan.reusser@gmail.com> * Update internal/user/repository/eventsourcing/model/user.go Co-authored-by: Silvan <silvan.reusser@gmail.com> * requests of mr * check email * begin separation of command and query * otp * change packages * some cleanup and fixes * tests for auth request / next steps * add VerificationLifetimes to config and make it run * tests * fix code challenge validation * cleanup * fix merge * begin view * repackaging tests and configs * fix startup config for auth * add migration * add PromptSelectAccount * fix copy / paste * remove user_agent files * fixes * fix sequences in user_session * token commands * token queries and signout * fix * fix set password test * add token handler and table * handle session init * add session state * add user view test cases * change VerifyMyMfaOTP * some fixes * fix user repo in auth api * cleanup * add user session view test * fix merge * fixes * Update internal/auth/repository/eventsourcing/eventstore/auth_request.go Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com> * Update internal/auth/repository/eventsourcing/eventstore/auth_request.go Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com> * Update internal/auth/repository/eventsourcing/eventstore/auth_request.go Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com> * Update internal/auth/repository/eventsourcing/eventstore/auth_request.go Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com> * extract method usersForUserSelection * add todo for policy check * id on auth req * fix enum name Co-authored-by: Fabiennne <fabienne.gerschwiler@gmail.com> Co-authored-by: adlerhurst <silvan.reusser@gmail.com> Co-authored-by: Fabi <38692350+fgerschwiler@users.noreply.github.com>
This commit is contained in:
25
internal/auth/auth/token_verifier.go
Normal file
25
internal/auth/auth/token_verifier.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/caos/zitadel/internal/api/auth"
|
||||
)
|
||||
|
||||
type TokenVerifier struct {
|
||||
}
|
||||
|
||||
func Start() (v *TokenVerifier) {
|
||||
return new(TokenVerifier)
|
||||
}
|
||||
|
||||
func (v *TokenVerifier) VerifyAccessToken(ctx context.Context, token string) (string, string, string, error) {
|
||||
return "", "", "", nil
|
||||
}
|
||||
|
||||
func (v *TokenVerifier) ResolveGrants(ctx context.Context, userID, orgID string) ([]*auth.Grant, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (v *TokenVerifier) GetProjectIDByClientID(ctx context.Context, clientID string) (string, error) {
|
||||
return "", nil
|
||||
}
|
15
internal/auth/repository/auth_request.go
Normal file
15
internal/auth/repository/auth_request.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
)
|
||||
|
||||
type AuthRequestRepository interface {
|
||||
CreateAuthRequest(ctx context.Context, request *model.AuthRequest) (*model.AuthRequest, error)
|
||||
AuthRequestByID(ctx context.Context, id string) (*model.AuthRequest, error)
|
||||
CheckUsername(ctx context.Context, id, username string) error
|
||||
VerifyPassword(ctx context.Context, id, userID, password string, info *model.BrowserInfo) error
|
||||
VerifyMfaOTP(ctx context.Context, agentID, authRequestID string, code string, info *model.BrowserInfo) error
|
||||
}
|
@@ -0,0 +1,239 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/auth_request/repository/cache"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/id"
|
||||
user_model "github.com/caos/zitadel/internal/user/model"
|
||||
user_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
|
||||
type AuthRequestRepo struct {
|
||||
UserEvents *user_event.UserEventstore
|
||||
AuthRequests *cache.AuthRequestCache
|
||||
View *view.View
|
||||
|
||||
UserSessionViewProvider userSessionViewProvider
|
||||
UserViewProvider userViewProvider
|
||||
|
||||
IdGenerator id.Generator
|
||||
|
||||
PasswordCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
}
|
||||
|
||||
type userSessionViewProvider interface {
|
||||
UserSessionByIDs(string, string) (*view_model.UserSessionView, error)
|
||||
UserSessionsByAgentID(string) ([]*view_model.UserSessionView, error)
|
||||
}
|
||||
type userViewProvider interface {
|
||||
UserByID(string) (*view_model.UserView, error)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
|
||||
if err := repo.UserEvents.Health(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return repo.AuthRequests.Health(ctx)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *model.AuthRequest) (*model.AuthRequest, error) {
|
||||
reqID, err := repo.IdGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.ID = reqID
|
||||
err = repo.AuthRequests.SaveAuthRequest(ctx, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id string) (*model.AuthRequest, error) {
|
||||
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
steps, err := repo.nextSteps(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.PossibleSteps = steps
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) CheckUsername(ctx context.Context, id, username string) error {
|
||||
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := repo.View.UserByUsername(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
request.UserID = user.ID
|
||||
return repo.AuthRequests.SaveAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, password string, info *model.BrowserInfo) error {
|
||||
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.UserID == userID {
|
||||
return errors.ThrowPreconditionFailed(nil, "EVENT-ds35D", "user id does not match request id ")
|
||||
}
|
||||
return repo.UserEvents.CheckPassword(ctx, userID, password, request.WithCurrentInfo(info))
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) VerifyMfaOTP(ctx context.Context, authRequestID, userID string, code string, info *model.BrowserInfo) error {
|
||||
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, authRequestID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.UserID != userID {
|
||||
return errors.ThrowPreconditionFailed(nil, "EVENT-ADJ26", "user id does not match request id")
|
||||
}
|
||||
return repo.UserEvents.CheckMfaOTP(ctx, userID, code, request.WithCurrentInfo(info))
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) nextSteps(request *model.AuthRequest) ([]model.NextStep, error) {
|
||||
if request == nil {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "EVENT-ds27a", "request must not be nil")
|
||||
}
|
||||
steps := make([]model.NextStep, 0)
|
||||
if request.UserID == "" {
|
||||
if request.Prompt != model.PromptNone {
|
||||
steps = append(steps, &model.LoginStep{})
|
||||
}
|
||||
if request.Prompt == model.PromptSelectAccount {
|
||||
users, err := repo.usersForUserSelection(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
steps = append(steps, &model.SelectUserStep{Users: users})
|
||||
}
|
||||
return steps, nil
|
||||
}
|
||||
userSession, err := userSessionByIDs(repo.UserSessionViewProvider, request.AgentID, request.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user, err := userByID(repo.UserViewProvider, request.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.PasswordSet {
|
||||
return append(steps, &model.InitPasswordStep{}), nil
|
||||
}
|
||||
|
||||
if !checkVerificationTime(userSession.PasswordVerification, repo.PasswordCheckLifeTime) {
|
||||
return append(steps, &model.PasswordStep{}), nil
|
||||
}
|
||||
|
||||
if step, ok := repo.mfaChecked(userSession, request, user); !ok {
|
||||
return append(steps, step), nil
|
||||
}
|
||||
|
||||
if user.PasswordChangeRequired {
|
||||
steps = append(steps, &model.ChangePasswordStep{})
|
||||
}
|
||||
if !user.IsEmailVerified {
|
||||
steps = append(steps, &model.VerifyEMailStep{})
|
||||
}
|
||||
|
||||
if user.PasswordChangeRequired || !user.IsEmailVerified {
|
||||
return steps, nil
|
||||
}
|
||||
|
||||
//PLANNED: consent step
|
||||
return append(steps, &model.RedirectToCallbackStep{}), nil
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) usersForUserSelection(request *model.AuthRequest) ([]model.UserSelection, error) {
|
||||
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users := make([]model.UserSelection, len(userSessions))
|
||||
for i, session := range userSessions {
|
||||
users[i] = model.UserSelection{
|
||||
UserID: session.UserID,
|
||||
UserName: session.UserName,
|
||||
UserSessionState: session.State,
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *model.AuthRequest, user *user_model.UserView) (model.NextStep, bool) {
|
||||
mfaLevel := request.MfaLevel()
|
||||
required := user.MfaMaxSetUp < mfaLevel
|
||||
if required || !repo.mfaSkippedOrSetUp(user) {
|
||||
return &model.MfaPromptStep{
|
||||
Required: required,
|
||||
MfaProviders: user.MfaTypesSetupPossible(mfaLevel),
|
||||
}, false
|
||||
}
|
||||
switch mfaLevel {
|
||||
default:
|
||||
fallthrough
|
||||
case model.MfaLevelSoftware:
|
||||
if checkVerificationTime(userSession.MfaSoftwareVerification, repo.MfaSoftwareCheckLifeTime) {
|
||||
return nil, true
|
||||
}
|
||||
fallthrough
|
||||
case model.MfaLevelHardware:
|
||||
if checkVerificationTime(userSession.MfaHardwareVerification, repo.MfaHardwareCheckLifeTime) {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
return &model.MfaVerificationStep{
|
||||
MfaProviders: user.MfaTypesAllowed(mfaLevel),
|
||||
}, false
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) mfaSkippedOrSetUp(user *user_model.UserView) bool {
|
||||
if user.MfaMaxSetUp >= 0 {
|
||||
return true
|
||||
}
|
||||
return checkVerificationTime(user.MfaInitSkipped, repo.MfaInitSkippedLifeTime)
|
||||
}
|
||||
|
||||
func checkVerificationTime(verificationTime time.Time, lifetime time.Duration) bool {
|
||||
return verificationTime.Add(lifetime).After(time.Now().UTC())
|
||||
}
|
||||
|
||||
func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string) ([]*user_model.UserSessionView, error) {
|
||||
session, err := provider.UserSessionsByAgentID(agentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return view_model.UserSessionsToModel(session), nil
|
||||
}
|
||||
|
||||
func userSessionByIDs(provider userSessionViewProvider, agentID, userID string) (*user_model.UserSessionView, error) {
|
||||
session, err := provider.UserSessionByIDs(agentID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return view_model.UserSessionToModel(session), nil
|
||||
}
|
||||
|
||||
func userByID(provider userViewProvider, userID string) (*user_model.UserView, error) {
|
||||
user, err := provider.UserByID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return view_model.UserToModel(user), nil
|
||||
}
|
@@ -0,0 +1,475 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/auth_request/repository/cache"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
user_model "github.com/caos/zitadel/internal/user/model"
|
||||
user_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
|
||||
type mockViewNoUserSession struct{}
|
||||
|
||||
func (m *mockViewNoUserSession) UserSessionByIDs(string, string) (*view_model.UserSessionView, error) {
|
||||
return nil, errors.ThrowNotFound(nil, "id", "user session not found")
|
||||
}
|
||||
|
||||
func (m *mockViewNoUserSession) UserSessionsByAgentID(string) ([]*view_model.UserSessionView, error) {
|
||||
return nil, errors.ThrowInternal(nil, "id", "internal error")
|
||||
}
|
||||
|
||||
type mockViewUserSession struct {
|
||||
PasswordVerification time.Time
|
||||
MfaSoftwareVerification time.Time
|
||||
Users []mockUser
|
||||
}
|
||||
|
||||
type mockUser struct {
|
||||
UserID string
|
||||
UserName string
|
||||
}
|
||||
|
||||
func (m *mockViewUserSession) UserSessionByIDs(string, string) (*view_model.UserSessionView, error) {
|
||||
return &view_model.UserSessionView{
|
||||
PasswordVerification: m.PasswordVerification,
|
||||
MfaSoftwareVerification: m.MfaSoftwareVerification,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockViewUserSession) UserSessionsByAgentID(string) ([]*view_model.UserSessionView, error) {
|
||||
sessions := make([]*view_model.UserSessionView, len(m.Users))
|
||||
for i, user := range m.Users {
|
||||
sessions[i] = &view_model.UserSessionView{
|
||||
UserID: user.UserID,
|
||||
UserName: user.UserName,
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
type mockViewNoUser struct{}
|
||||
|
||||
func (m *mockViewNoUser) UserByID(string) (*view_model.UserView, error) {
|
||||
return nil, errors.ThrowNotFound(nil, "id", "user not found")
|
||||
}
|
||||
|
||||
type mockViewUser struct {
|
||||
PasswordSet bool
|
||||
PasswordChangeRequired bool
|
||||
IsEmailVerified bool
|
||||
OTPState int32
|
||||
MfaMaxSetUp int32
|
||||
MfaInitSkipped time.Time
|
||||
}
|
||||
|
||||
func (m *mockViewUser) UserByID(string) (*view_model.UserView, error) {
|
||||
return &view_model.UserView{
|
||||
PasswordSet: m.PasswordSet,
|
||||
PasswordChangeRequired: m.PasswordChangeRequired,
|
||||
IsEmailVerified: m.IsEmailVerified,
|
||||
OTPState: m.OTPState,
|
||||
MfaMaxSetUp: m.MfaMaxSetUp,
|
||||
MfaInitSkipped: m.MfaInitSkipped,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
type fields struct {
|
||||
UserEvents *user_event.UserEventstore
|
||||
AuthRequests *cache.AuthRequestCache
|
||||
View *view.View
|
||||
userSessionViewProvider userSessionViewProvider
|
||||
userViewProvider userViewProvider
|
||||
PasswordCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
}
|
||||
type args struct {
|
||||
request *model.AuthRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want []model.NextStep
|
||||
wantErr func(error) bool
|
||||
}{
|
||||
{
|
||||
"request nil, error",
|
||||
fields{},
|
||||
args{nil},
|
||||
nil,
|
||||
errors.IsErrorInvalidArgument,
|
||||
},
|
||||
{
|
||||
"user not set, login step",
|
||||
fields{},
|
||||
args{&model.AuthRequest{}},
|
||||
[]model.NextStep{&model.LoginStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"user not set and prompt none, no step",
|
||||
fields{},
|
||||
args{&model.AuthRequest{Prompt: model.PromptNone}},
|
||||
[]model.NextStep{},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"user not set, prompt select account and internal error, internal error",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewNoUserSession{},
|
||||
},
|
||||
args{&model.AuthRequest{Prompt: model.PromptSelectAccount}},
|
||||
nil,
|
||||
errors.IsInternal,
|
||||
},
|
||||
{
|
||||
"user not set, prompt select account, login and select account steps",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
Users: []mockUser{
|
||||
{
|
||||
"id1",
|
||||
"username1",
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"username2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args{&model.AuthRequest{Prompt: model.PromptSelectAccount}},
|
||||
[]model.NextStep{
|
||||
&model.LoginStep{},
|
||||
&model.SelectUserStep{
|
||||
Users: []model.UserSelection{
|
||||
{
|
||||
UserID: "id1",
|
||||
UserName: "username1",
|
||||
},
|
||||
{
|
||||
UserID: "id2",
|
||||
UserName: "username2",
|
||||
},
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"usersession not found, not found error",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewNoUserSession{},
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
nil,
|
||||
errors.IsNotFound,
|
||||
},
|
||||
{
|
||||
"user not not found, not found error",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{},
|
||||
userViewProvider: &mockViewNoUser{},
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
nil,
|
||||
errors.IsNotFound,
|
||||
},
|
||||
{
|
||||
"password not set, init password step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{},
|
||||
userViewProvider: &mockViewUser{},
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.InitPasswordStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"password not verified, password check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.PasswordStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"mfa not verified, mfa check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
OTPState: int32(user_model.MFASTATE_READY),
|
||||
MfaMaxSetUp: int32(model.MfaLevelSoftware),
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.MfaVerificationStep{
|
||||
MfaProviders: []model.MfaType{model.MfaTypeOTP},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"password change required and email verified, password change step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
PasswordChangeRequired: true,
|
||||
IsEmailVerified: true,
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.ChangePasswordStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"email not verified and no password change required, mail verification step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.VerifyEMailStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"email not verified and password change required, mail verification step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
PasswordChangeRequired: true,
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.ChangePasswordStep{}, &model.VerifyEMailStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"email verified and no password change required, redirect to callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
IsEmailVerified: true,
|
||||
},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID"}},
|
||||
[]model.NextStep{&model.RedirectToCallbackStep{}},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &AuthRequestRepo{
|
||||
UserEvents: tt.fields.UserEvents,
|
||||
AuthRequests: tt.fields.AuthRequests,
|
||||
View: tt.fields.View,
|
||||
UserSessionViewProvider: tt.fields.userSessionViewProvider,
|
||||
UserViewProvider: tt.fields.userViewProvider,
|
||||
PasswordCheckLifeTime: tt.fields.PasswordCheckLifeTime,
|
||||
MfaInitSkippedLifeTime: tt.fields.MfaInitSkippedLifeTime,
|
||||
MfaSoftwareCheckLifeTime: tt.fields.MfaSoftwareCheckLifeTime,
|
||||
MfaHardwareCheckLifeTime: tt.fields.MfaHardwareCheckLifeTime,
|
||||
}
|
||||
got, err := repo.nextSteps(tt.args.request)
|
||||
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
|
||||
t.Errorf("nextSteps() wrong error = %v", err)
|
||||
return
|
||||
}
|
||||
assert.ElementsMatch(t, got, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequestRepo_mfaChecked(t *testing.T) {
|
||||
type fields struct {
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
}
|
||||
type args struct {
|
||||
userSession *user_model.UserSessionView
|
||||
request *model.AuthRequest
|
||||
user *user_model.UserView
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want model.NextStep
|
||||
wantChecked bool
|
||||
}{
|
||||
//{
|
||||
// "required, prompt and false", //TODO: enable when LevelsOfAssurance is checked
|
||||
// fields{},
|
||||
// args{
|
||||
// request: &model.AuthRequest{PossibleLOAs: []model.LevelOfAssurance{}},
|
||||
// user: &user_model.UserView{
|
||||
// OTPState: user_model.MFASTATE_READY,
|
||||
// },
|
||||
// },
|
||||
// false,
|
||||
//},
|
||||
{
|
||||
"not set up, prompt and false",
|
||||
fields{
|
||||
MfaInitSkippedLifeTime: 30 * 24 * time.Hour,
|
||||
},
|
||||
args{
|
||||
request: &model.AuthRequest{},
|
||||
user: &user_model.UserView{
|
||||
MfaMaxSetUp: -1,
|
||||
},
|
||||
},
|
||||
&model.MfaPromptStep{
|
||||
MfaProviders: []model.MfaType{},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"checked mfa software, true",
|
||||
fields{
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{
|
||||
request: &model.AuthRequest{},
|
||||
user: &user_model.UserView{
|
||||
OTPState: user_model.MFASTATE_READY,
|
||||
},
|
||||
userSession: &user_model.UserSessionView{MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Hour)},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"not checked, check and false",
|
||||
fields{
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{
|
||||
request: &model.AuthRequest{},
|
||||
user: &user_model.UserView{
|
||||
OTPState: user_model.MFASTATE_READY,
|
||||
},
|
||||
userSession: &user_model.UserSessionView{},
|
||||
},
|
||||
|
||||
&model.MfaVerificationStep{
|
||||
MfaProviders: []model.MfaType{model.MfaTypeOTP},
|
||||
},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &AuthRequestRepo{
|
||||
MfaInitSkippedLifeTime: tt.fields.MfaInitSkippedLifeTime,
|
||||
MfaSoftwareCheckLifeTime: tt.fields.MfaSoftwareCheckLifeTime,
|
||||
MfaHardwareCheckLifeTime: tt.fields.MfaHardwareCheckLifeTime,
|
||||
}
|
||||
got, ok := repo.mfaChecked(tt.args.userSession, tt.args.request, tt.args.user)
|
||||
if ok != tt.wantChecked {
|
||||
t.Errorf("mfaChecked() checked = %v, want %v", ok, tt.wantChecked)
|
||||
}
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequestRepo_mfaSkippedOrSetUp(t *testing.T) {
|
||||
type fields struct {
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
}
|
||||
type args struct {
|
||||
user *user_model.UserView
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"mfa set up, true",
|
||||
fields{},
|
||||
args{&user_model.UserView{
|
||||
MfaMaxSetUp: model.MfaLevelSoftware,
|
||||
}},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"mfa skipped active, true",
|
||||
fields{
|
||||
MfaInitSkippedLifeTime: 30 * 24 * time.Hour,
|
||||
},
|
||||
args{&user_model.UserView{
|
||||
MfaMaxSetUp: -1,
|
||||
MfaInitSkipped: time.Now().UTC().Add(-10 * time.Hour),
|
||||
}},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"mfa skipped inactive, false",
|
||||
fields{
|
||||
MfaInitSkippedLifeTime: 30 * 24 * time.Hour,
|
||||
},
|
||||
args{&user_model.UserView{
|
||||
MfaMaxSetUp: -1,
|
||||
MfaInitSkipped: time.Now().UTC().Add(-40 * 24 * time.Hour),
|
||||
}},
|
||||
false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &AuthRequestRepo{
|
||||
MfaInitSkippedLifeTime: tt.fields.MfaInitSkippedLifeTime,
|
||||
}
|
||||
if got := repo.mfaSkippedOrSetUp(tt.args.user); got != tt.want {
|
||||
t.Errorf("mfaSkippedOrSetUp() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
26
internal/auth/repository/eventsourcing/eventstore/token.go
Normal file
26
internal/auth/repository/eventsourcing/eventstore/token.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
token_model "github.com/caos/zitadel/internal/token/model"
|
||||
token_view_model "github.com/caos/zitadel/internal/token/repository/view/model"
|
||||
)
|
||||
|
||||
type TokenRepo struct {
|
||||
View *view.View
|
||||
}
|
||||
|
||||
func (repo *TokenRepo) CreateToken(ctx context.Context, agentID, applicationID, userID string, lifetime time.Duration) (*token_model.Token, error) {
|
||||
token, err := repo.View.CreateToken(agentID, applicationID, userID, lifetime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token_view_model.TokenToModel(token), nil
|
||||
}
|
||||
|
||||
func (repo *TokenRepo) IsTokenValid(ctx context.Context, tokenID string) (bool, error) {
|
||||
return repo.View.IsTokenValid(tokenID)
|
||||
}
|
129
internal/auth/repository/eventsourcing/eventstore/user.go
Normal file
129
internal/auth/repository/eventsourcing/eventstore/user.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/zitadel/internal/api/auth"
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
user_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type UserRepo struct {
|
||||
UserEvents *user_event.UserEventstore
|
||||
View *view.View
|
||||
}
|
||||
|
||||
func (repo *UserRepo) Health(ctx context.Context) error {
|
||||
return repo.UserEvents.Health(ctx)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) Register(ctx context.Context, user *model.User, resourceOwner string) (*model.User, error) {
|
||||
return repo.UserEvents.RegisterUser(ctx, user, resourceOwner)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) MyProfile(ctx context.Context) (*model.Profile, error) {
|
||||
return repo.UserEvents.ProfileByID(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeMyProfile(ctx context.Context, profile *model.Profile) (*model.Profile, error) {
|
||||
if err := checkIDs(ctx, profile.ObjectRoot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo.UserEvents.ChangeProfile(ctx, profile)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) MyEmail(ctx context.Context) (*model.Email, error) {
|
||||
return repo.UserEvents.EmailByID(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeMyEmail(ctx context.Context, email *model.Email) (*model.Email, error) {
|
||||
if err := checkIDs(ctx, email.ObjectRoot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo.UserEvents.ChangeEmail(ctx, email)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) VerifyMyEmail(ctx context.Context, code string) error {
|
||||
return repo.UserEvents.VerifyEmail(ctx, auth.GetCtxData(ctx).UserID, code)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ResendMyEmailVerificationMail(ctx context.Context) error {
|
||||
return repo.UserEvents.CreateEmailVerificationCode(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) MyPhone(ctx context.Context) (*model.Phone, error) {
|
||||
return repo.UserEvents.PhoneByID(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeMyPhone(ctx context.Context, phone *model.Phone) (*model.Phone, error) {
|
||||
if err := checkIDs(ctx, phone.ObjectRoot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo.UserEvents.ChangePhone(ctx, phone)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) VerifyMyPhone(ctx context.Context, code string) error {
|
||||
return repo.UserEvents.VerifyPhone(ctx, auth.GetCtxData(ctx).UserID, code)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ResendMyPhoneVerificationCode(ctx context.Context) error {
|
||||
return repo.UserEvents.CreatePhoneVerificationCode(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) MyAddress(ctx context.Context) (*model.Address, error) {
|
||||
return repo.UserEvents.AddressByID(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeMyAddress(ctx context.Context, address *model.Address) (*model.Address, error) {
|
||||
if err := checkIDs(ctx, address.ObjectRoot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return repo.UserEvents.ChangeAddress(ctx, address)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeMyPassword(ctx context.Context, old, new string) error {
|
||||
_, err := repo.UserEvents.ChangePassword(ctx, auth.GetCtxData(ctx).UserID, old, new)
|
||||
return err
|
||||
}
|
||||
|
||||
func (repo *UserRepo) AddMyMfaOTP(ctx context.Context) (*model.OTP, error) {
|
||||
return repo.UserEvents.AddOTP(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) VerifyMyMfaOTP(ctx context.Context, code string) error {
|
||||
return repo.UserEvents.CheckMfaOTPSetup(ctx, auth.GetCtxData(ctx).UserID, code)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) RemoveMyMfaOTP(ctx context.Context) error {
|
||||
return repo.UserEvents.RemoveOTP(ctx, auth.GetCtxData(ctx).UserID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) SkipMfaInit(ctx context.Context, userID string) error {
|
||||
return repo.UserEvents.SkipMfaInit(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) RequestPasswordReset(ctx context.Context, username string) error {
|
||||
user, err := repo.View.UserByUsername(username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return repo.UserEvents.RequestSetPassword(ctx, user.ID, model.NOTIFICATIONTYPE_EMAIL)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) SetPassword(ctx context.Context, userID, code, password string) error {
|
||||
return repo.UserEvents.SetPassword(ctx, userID, code, password)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) SignOut(ctx context.Context, agentID, userID string) error {
|
||||
return repo.UserEvents.SignOut(ctx, agentID, userID)
|
||||
}
|
||||
|
||||
func checkIDs(ctx context.Context, obj es_models.ObjectRoot) error {
|
||||
if obj.AggregateID != auth.GetCtxData(ctx).UserID {
|
||||
return errors.ThrowPermissionDenied(nil, "EVENT-kFi9w", "object does not belong to user")
|
||||
}
|
||||
return nil
|
||||
}
|
42
internal/auth/repository/eventsourcing/handler/handler.go
Normal file
42
internal/auth/repository/eventsourcing/handler/handler.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
usr_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type Configs map[string]*Config
|
||||
|
||||
type Config struct {
|
||||
MinimumCycleDurationMillisecond int
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
view *view.View
|
||||
bulkLimit uint64
|
||||
cycleDuration time.Duration
|
||||
errorCountUntilSkip uint64
|
||||
}
|
||||
|
||||
type EventstoreRepos struct {
|
||||
UserEvents *usr_event.UserEventstore
|
||||
}
|
||||
|
||||
func Register(configs Configs, bulkLimit, errorCount uint64, view *view.View, repos EventstoreRepos) []spooler.Handler {
|
||||
return []spooler.Handler{
|
||||
&User{handler: handler{view, bulkLimit, configs.cycleDuration("User"), errorCount}},
|
||||
&UserSession{handler: handler{view, bulkLimit, configs.cycleDuration("UserSession"), errorCount}, userEvents: repos.UserEvents},
|
||||
&Token{handler: handler{view, bulkLimit, configs.cycleDuration("Token"), errorCount}},
|
||||
}
|
||||
}
|
||||
|
||||
func (configs Configs) cycleDuration(viewModel string) time.Duration {
|
||||
c, ok := configs[viewModel]
|
||||
if !ok {
|
||||
return 1 * time.Second
|
||||
}
|
||||
return time.Duration(c.MinimumCycleDurationMillisecond) * time.Millisecond
|
||||
}
|
69
internal/auth/repository/eventsourcing/handler/token.go
Normal file
69
internal/auth/repository/eventsourcing/handler/token.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
handler
|
||||
}
|
||||
|
||||
const (
|
||||
tokenTable = "auth.tokens"
|
||||
)
|
||||
|
||||
func (u *Token) MinimumCycleDuration() time.Duration { return u.cycleDuration }
|
||||
|
||||
func (u *Token) ViewModel() string {
|
||||
return tokenTable
|
||||
}
|
||||
|
||||
func (u *Token) EventQuery() (*models.SearchQuery, error) {
|
||||
sequence, err := u.view.GetLatestTokenSequence()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return eventsourcing.UserQuery(sequence), nil
|
||||
}
|
||||
|
||||
func (u *Token) Process(event *models.Event) (err error) {
|
||||
switch event.Type {
|
||||
case es_model.SignedOut:
|
||||
id, err := agentIDFromSession(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = u.view.DeleteSessionTokens(id, event.AggregateID, event.Sequence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return u.view.ProcessedTokenSequence(event.Sequence)
|
||||
default:
|
||||
return u.view.ProcessedTokenSequence(event.Sequence)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *Token) OnError(event *models.Event, err error) error {
|
||||
logging.LogWithFields("SPOOL-3jkl4", "id", event.AggregateID).WithError(err).Warn("something went wrong in token handler")
|
||||
return spooler.HandleError(event, err, u.view.GetLatestTokenFailedEvent, u.view.ProcessedTokenFailedEvent, u.view.ProcessedTokenSequence, u.errorCountUntilSkip)
|
||||
}
|
||||
|
||||
func agentIDFromSession(event *models.Event) (string, error) {
|
||||
session := make(map[string]interface{})
|
||||
if err := json.Unmarshal(event.Data, session); err != nil {
|
||||
logging.Log("EVEN-s3bq9").WithError(err).Error("could not unmarshal event data")
|
||||
return "", caos_errs.ThrowInternal(nil, "MODEL-sd325", "could not unmarshal data")
|
||||
}
|
||||
return session["agentID"].(string), nil
|
||||
}
|
77
internal/auth/repository/eventsourcing/handler/user.go
Normal file
77
internal/auth/repository/eventsourcing/handler/user.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
handler
|
||||
eventstore eventstore.Eventstore
|
||||
}
|
||||
|
||||
const (
|
||||
userTable = "auth.users"
|
||||
)
|
||||
|
||||
func (p *User) MinimumCycleDuration() time.Duration { return p.cycleDuration }
|
||||
|
||||
func (p *User) ViewModel() string {
|
||||
return userTable
|
||||
}
|
||||
|
||||
func (p *User) EventQuery() (*models.SearchQuery, error) {
|
||||
sequence, err := p.view.GetLatestUserSequence()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return eventsourcing.UserQuery(sequence), nil
|
||||
}
|
||||
|
||||
func (p *User) Process(event *models.Event) (err error) {
|
||||
user := new(view_model.UserView)
|
||||
switch event.Type {
|
||||
case es_model.UserAdded,
|
||||
es_model.UserRegistered:
|
||||
user.AppendEvent(event)
|
||||
case es_model.UserProfileChanged,
|
||||
es_model.UserEmailChanged,
|
||||
es_model.UserEmailVerified,
|
||||
es_model.UserPhoneChanged,
|
||||
es_model.UserPhoneVerified,
|
||||
es_model.UserAddressChanged,
|
||||
es_model.UserDeactivated,
|
||||
es_model.UserReactivated,
|
||||
es_model.UserLocked,
|
||||
es_model.UserUnlocked,
|
||||
es_model.MfaOtpAdded,
|
||||
es_model.MfaOtpVerified,
|
||||
es_model.MfaOtpRemoved:
|
||||
user, err = p.view.UserByID(event.AggregateID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = user.AppendEvent(event)
|
||||
case es_model.UserDeleted:
|
||||
err = p.view.DeleteUser(event.AggregateID, event.Sequence)
|
||||
default:
|
||||
return p.view.ProcessedUserSequence(event.Sequence)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.view.PutUser(user)
|
||||
}
|
||||
|
||||
func (p *User) OnError(event *models.Event, err error) error {
|
||||
logging.LogWithFields("SPOOL-is8wa", "id", event.AggregateID).WithError(err).Warn("something went wrong in user handler")
|
||||
return spooler.HandleError(event, err, p.view.GetLatestUserFailedEvent, p.view.ProcessedUserFailedEvent, p.view.ProcessedUserSequence, p.errorCountUntilSkip)
|
||||
}
|
@@ -0,0 +1,90 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
user_events "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
|
||||
type UserSession struct {
|
||||
handler
|
||||
userEvents *user_events.UserEventstore
|
||||
}
|
||||
|
||||
const (
|
||||
userSessionTable = "auth.user_sessions"
|
||||
)
|
||||
|
||||
func (u *UserSession) MinimumCycleDuration() time.Duration { return u.cycleDuration }
|
||||
|
||||
func (u *UserSession) ViewModel() string {
|
||||
return userSessionTable
|
||||
}
|
||||
|
||||
func (u *UserSession) EventQuery() (*models.SearchQuery, error) {
|
||||
sequence, err := u.view.GetLatestUserSessionSequence()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return eventsourcing.UserQuery(sequence), nil
|
||||
}
|
||||
|
||||
func (u *UserSession) Process(event *models.Event) (err error) {
|
||||
eventData, err := view_model.UserSessionFromEvent(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session, err := u.view.UserSessionByIDs(eventData.UserAgentID, event.AggregateID)
|
||||
if err != nil {
|
||||
if !errors.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
session = &view_model.UserSessionView{
|
||||
CreationDate: event.CreationDate,
|
||||
ResourceOwner: event.ResourceOwner,
|
||||
UserAgentID: eventData.UserAgentID,
|
||||
UserID: event.AggregateID,
|
||||
State: int32(req_model.UserSessionStateActive),
|
||||
}
|
||||
}
|
||||
switch event.Type {
|
||||
case es_model.UserPasswordCheckSucceeded,
|
||||
es_model.UserPasswordCheckFailed,
|
||||
es_model.UserPasswordChanged,
|
||||
es_model.MfaOtpCheckSucceeded,
|
||||
es_model.MfaOtpCheckFailed,
|
||||
es_model.MfaOtpRemoved:
|
||||
session.AppendEvent(event)
|
||||
default:
|
||||
return u.view.ProcessedUserSessionSequence(event.Sequence)
|
||||
}
|
||||
if err := u.FillUserInfo(session, event.AggregateID); err != nil {
|
||||
return err
|
||||
}
|
||||
return u.view.PutUserSession(session)
|
||||
}
|
||||
|
||||
func (u *UserSession) OnError(event *models.Event, err error) error {
|
||||
logging.LogWithFields("SPOOL-sdfw3s", "id", event.AggregateID).WithError(err).Warn("something went wrong in user session handler")
|
||||
return spooler.HandleError(event, err, u.view.GetLatestUserSessionFailedEvent, u.view.ProcessedUserSessionFailedEvent, u.view.ProcessedUserSessionSequence, u.errorCountUntilSkip)
|
||||
}
|
||||
|
||||
func (u *UserSession) FillUserInfo(session *view_model.UserSessionView, id string) error {
|
||||
user, err := u.userEvents.UserByID(context.Background(), id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.UserName = user.UserName
|
||||
return nil
|
||||
}
|
93
internal/auth/repository/eventsourcing/repository.go
Normal file
93
internal/auth/repository/eventsourcing/repository.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package eventsourcing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/eventstore"
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/handler"
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/spooler"
|
||||
auth_view "github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/caos/zitadel/internal/auth_request/repository/cache"
|
||||
sd "github.com/caos/zitadel/internal/config/systemdefaults"
|
||||
"github.com/caos/zitadel/internal/config/types"
|
||||
es_int "github.com/caos/zitadel/internal/eventstore"
|
||||
es_spol "github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
"github.com/caos/zitadel/internal/id"
|
||||
es_user "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Eventstore es_int.Config
|
||||
AuthRequest cache.Config
|
||||
View types.SQL
|
||||
Spooler spooler.SpoolerConfig
|
||||
}
|
||||
|
||||
type EsRepository struct {
|
||||
spooler *es_spol.Spooler
|
||||
eventstore.UserRepo
|
||||
eventstore.AuthRequestRepo
|
||||
eventstore.TokenRepo
|
||||
}
|
||||
|
||||
func Start(conf Config, systemDefaults sd.SystemDefaults) (*EsRepository, error) {
|
||||
es, err := es_int.Start(conf.Eventstore)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sqlClient, err := conf.View.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
view, err := auth_view.StartView(sqlClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := es_user.StartUser(
|
||||
es_user.UserConfig{
|
||||
Eventstore: es,
|
||||
Cache: conf.Eventstore.Cache,
|
||||
},
|
||||
systemDefaults,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq, err := cache.Start(conf.AuthRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repos := handler.EventstoreRepos{UserEvents: user}
|
||||
spool := spooler.StartSpooler(conf.Spooler, es, view, sqlClient, repos)
|
||||
|
||||
return &EsRepository{
|
||||
spool,
|
||||
eventstore.UserRepo{
|
||||
UserEvents: user,
|
||||
View: view,
|
||||
},
|
||||
eventstore.AuthRequestRepo{
|
||||
UserEvents: user,
|
||||
AuthRequests: authReq,
|
||||
View: view,
|
||||
UserSessionViewProvider: view,
|
||||
UserViewProvider: view,
|
||||
IdGenerator: id.SonyFlakeGenerator,
|
||||
PasswordCheckLifeTime: systemDefaults.VerificationLifetimes.PasswordCheck.Duration,
|
||||
MfaInitSkippedLifeTime: systemDefaults.VerificationLifetimes.MfaInitSkip.Duration,
|
||||
MfaSoftwareCheckLifeTime: systemDefaults.VerificationLifetimes.MfaSoftwareCheck.Duration,
|
||||
MfaHardwareCheckLifeTime: systemDefaults.VerificationLifetimes.MfaHardwareCheck.Duration,
|
||||
},
|
||||
eventstore.TokenRepo{View: view},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (repo *EsRepository) Health(ctx context.Context) error {
|
||||
if err := repo.UserRepo.Health(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return repo.AuthRequestRepo.Health(ctx)
|
||||
}
|
46
internal/auth/repository/eventsourcing/spooler/lock.go
Normal file
46
internal/auth/repository/eventsourcing/spooler/lock.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/crdb"
|
||||
)
|
||||
|
||||
const (
|
||||
lockTable = "auth.locks"
|
||||
lockedUntilKey = "locked_until"
|
||||
lockerIDKey = "locker_id"
|
||||
objectTypeKey = "object_type"
|
||||
)
|
||||
|
||||
type locker struct {
|
||||
dbClient *sql.DB
|
||||
}
|
||||
|
||||
type lock struct {
|
||||
LockerID string `gorm:"column:locker_id;primary_key"`
|
||||
LockedUntil time.Time `gorm:"column:locked_until"`
|
||||
ViewName string `gorm:"column:object_type;primary_key"`
|
||||
}
|
||||
|
||||
func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
|
||||
return crdb.ExecuteTx(context.Background(), l.dbClient, nil, func(tx *sql.Tx) error {
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s, %s, %s) VALUES ($1, $2, now()+$3) ON CONFLICT (%s) DO UPDATE SET %s = now()+$4, %s = $5 WHERE (locks.%s < now() OR locks.%s = $6) AND locks.%s = $7",
|
||||
lockTable, objectTypeKey, lockerIDKey, lockedUntilKey, objectTypeKey, lockedUntilKey, lockerIDKey, lockedUntilKey, lockerIDKey, objectTypeKey)
|
||||
|
||||
rs, err := tx.Exec(query, viewModel, lockerID, waitTime.Seconds(), waitTime.Seconds(), lockerID, lockerID, viewModel)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if rows, _ := rs.RowsAffected(); rows == 0 {
|
||||
tx.Rollback()
|
||||
return caos_errs.ThrowAlreadyExists(nil, "SPOOL-lso0e", "view already locked")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
127
internal/auth/repository/eventsourcing/spooler/lock_test.go
Normal file
127
internal/auth/repository/eventsourcing/spooler/lock_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
type dbMock struct {
|
||||
db *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
func mockDB(t *testing.T) *dbMock {
|
||||
mockDB := dbMock{}
|
||||
var err error
|
||||
mockDB.db, mockDB.mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occured while creating stub db %v", err)
|
||||
}
|
||||
|
||||
mockDB.mock.MatchExpectationsInOrder(true)
|
||||
|
||||
return &mockDB
|
||||
}
|
||||
|
||||
func (db *dbMock) expectCommit() *dbMock {
|
||||
db.mock.ExpectCommit()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRollback() *dbMock {
|
||||
db.mock.ExpectRollback()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectBegin() *dbMock {
|
||||
db.mock.ExpectBegin()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectSavepoint() *dbMock {
|
||||
db.mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectReleaseSavepoint() *dbMock {
|
||||
db.mock.ExpectExec("RELEASE SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRenew(lockerID, view string, affectedRows int64) *dbMock {
|
||||
query := db.mock.
|
||||
ExpectExec(`INSERT INTO auth\.locks \(object_type, locker_id, locked_until\) VALUES \(\$1, \$2, now\(\)\+\$3\) ON CONFLICT \(object_type\) DO UPDATE SET locked_until = now\(\)\+\$4, locker_id = \$5 WHERE \(locks\.locked_until < now\(\) OR locks\.locker_id = \$6\) AND locks\.object_type = \$7`).
|
||||
WithArgs(view, lockerID, sqlmock.AnyArg(), sqlmock.AnyArg(), lockerID, lockerID, view).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
if affectedRows == 0 {
|
||||
query.WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
} else {
|
||||
query.WillReturnResult(sqlmock.NewResult(1, affectedRows))
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func Test_locker_Renew(t *testing.T) {
|
||||
type fields struct {
|
||||
db *dbMock
|
||||
}
|
||||
type args struct {
|
||||
lockerID string
|
||||
viewModel string
|
||||
waitTime time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "renew succeeded",
|
||||
fields: fields{
|
||||
db: mockDB(t).
|
||||
expectBegin().
|
||||
expectSavepoint().
|
||||
expectRenew("locker", "view", 1).
|
||||
expectReleaseSavepoint().
|
||||
expectCommit(),
|
||||
},
|
||||
args: args{lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "renew now rows updated",
|
||||
fields: fields{
|
||||
db: mockDB(t).
|
||||
expectBegin().
|
||||
expectSavepoint().
|
||||
expectRenew("locker", "view", 0).
|
||||
expectRollback(),
|
||||
},
|
||||
args: args{lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
l := &locker{
|
||||
dbClient: tt.fields.db.db,
|
||||
}
|
||||
if err := l.Renew(tt.args.lockerID, tt.args.viewModel, tt.args.waitTime); (err != nil) != tt.wantErr {
|
||||
t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("not all database expectations met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
30
internal/auth/repository/eventsourcing/spooler/spooler.go
Normal file
30
internal/auth/repository/eventsourcing/spooler/spooler.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/handler"
|
||||
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
|
||||
"github.com/caos/zitadel/internal/eventstore"
|
||||
"github.com/caos/zitadel/internal/eventstore/spooler"
|
||||
)
|
||||
|
||||
type SpoolerConfig struct {
|
||||
BulkLimit uint64
|
||||
FailureCountUntilSkip uint64
|
||||
ConcurrentTasks int
|
||||
Handlers handler.Configs
|
||||
}
|
||||
|
||||
func StartSpooler(c SpoolerConfig, es eventstore.Eventstore, view *view.View, sql *sql.DB, repos handler.EventstoreRepos) *spooler.Spooler {
|
||||
spoolerConfig := spooler.Config{
|
||||
Eventstore: es,
|
||||
Locker: &locker{dbClient: sql},
|
||||
ConcurrentTasks: c.ConcurrentTasks,
|
||||
ViewHandlers: handler.Register(c.Handlers, c.BulkLimit, c.FailureCountUntilSkip, view, repos),
|
||||
}
|
||||
spool := spoolerConfig.New()
|
||||
spool.Start()
|
||||
return spool
|
||||
}
|
17
internal/auth/repository/eventsourcing/view/error_event.go
Normal file
17
internal/auth/repository/eventsourcing/view/error_event.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
const (
|
||||
errTable = "auth.failed_event"
|
||||
)
|
||||
|
||||
func (v *View) saveFailedEvent(failedEvent *view.FailedEvent) error {
|
||||
return view.SaveFailedEvent(v.Db, errTable, failedEvent)
|
||||
}
|
||||
|
||||
func (v *View) latestFailedEvent(viewName string, sequence uint64) (*view.FailedEvent, error) {
|
||||
return view.LatestFailedEvent(v.Db, errTable, viewName, sequence)
|
||||
}
|
17
internal/auth/repository/eventsourcing/view/sequence.go
Normal file
17
internal/auth/repository/eventsourcing/view/sequence.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
const (
|
||||
sequencesTable = "auth.current_sequences"
|
||||
)
|
||||
|
||||
func (v *View) saveCurrentSequence(viewName string, sequence uint64) error {
|
||||
return view.SaveCurrentSequence(v.Db, sequencesTable, viewName, sequence)
|
||||
}
|
||||
|
||||
func (v *View) latestSequence(viewName string) (uint64, error) {
|
||||
return view.LatestSequence(v.Db, sequencesTable, viewName)
|
||||
}
|
77
internal/auth/repository/eventsourcing/view/token.go
Normal file
77
internal/auth/repository/eventsourcing/view/token.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/token/repository/view"
|
||||
"github.com/caos/zitadel/internal/token/repository/view/model"
|
||||
global_view "github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenTable = "auth.tokens"
|
||||
)
|
||||
|
||||
func (v *View) TokenByID(tokenID string) (*model.Token, error) {
|
||||
return view.TokenByID(v.Db, tokenTable, tokenID)
|
||||
}
|
||||
|
||||
func (v *View) IsTokenValid(tokenID string) (bool, error) {
|
||||
return view.IsTokenValid(v.Db, tokenTable, tokenID)
|
||||
}
|
||||
|
||||
func (v *View) CreateToken(agentID, applicationID, userID string, lifetime time.Duration) (*model.Token, error) {
|
||||
now := time.Now().UTC()
|
||||
token := &model.Token{
|
||||
CreationDate: now,
|
||||
UserID: userID,
|
||||
ApplicationID: applicationID,
|
||||
UserAgentID: agentID,
|
||||
Expiration: now.Add(lifetime),
|
||||
}
|
||||
err := view.PutToken(v.Db, tokenTable, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (v *View) PutToken(token *model.Token) error {
|
||||
err := view.PutToken(v.Db, tokenTable, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return v.ProcessedTokenSequence(token.Sequence)
|
||||
}
|
||||
|
||||
func (v *View) DeleteToken(tokenID string, eventSequence uint64) error {
|
||||
err := view.DeleteToken(v.Db, tokenTable, tokenID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v.ProcessedTokenSequence(eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) DeleteSessionTokens(agentID, userID string, eventSequence uint64) error {
|
||||
err := view.DeleteTokens(v.Db, tokenTable, agentID, userID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v.ProcessedTokenSequence(eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestTokenSequence() (uint64, error) {
|
||||
return v.latestSequence(tokenTable)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedTokenSequence(eventSequence uint64) error {
|
||||
return v.saveCurrentSequence(tokenTable, eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestTokenFailedEvent(sequence uint64) (*global_view.FailedEvent, error) {
|
||||
return v.latestFailedEvent(tokenTable, sequence)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedTokenFailedEvent(failedEvent *global_view.FailedEvent) error {
|
||||
return v.saveFailedEvent(failedEvent)
|
||||
}
|
68
internal/auth/repository/eventsourcing/view/user.go
Normal file
68
internal/auth/repository/eventsourcing/view/user.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
usr_model "github.com/caos/zitadel/internal/user/model"
|
||||
"github.com/caos/zitadel/internal/user/repository/view"
|
||||
"github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
global_view "github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
const (
|
||||
userTable = "auth.users"
|
||||
)
|
||||
|
||||
func (v *View) UserByID(userID string) (*model.UserView, error) {
|
||||
return view.UserByID(v.Db, userTable, userID)
|
||||
}
|
||||
|
||||
func (v *View) UserByUsername(userName string) (*model.UserView, error) {
|
||||
return view.UserByUserName(v.Db, userTable, userName)
|
||||
}
|
||||
|
||||
func (v *View) SearchUsers(request *usr_model.UserSearchRequest) ([]*model.UserView, int, error) {
|
||||
return view.SearchUsers(v.Db, userTable, request)
|
||||
}
|
||||
|
||||
func (v *View) GetGlobalUserByEmail(email string) (*model.UserView, error) {
|
||||
return view.GetGlobalUserByEmail(v.Db, userTable, email)
|
||||
}
|
||||
|
||||
func (v *View) IsUserUnique(userName, email string) (bool, error) {
|
||||
return view.IsUserUnique(v.Db, userTable, userName, email)
|
||||
}
|
||||
|
||||
func (v *View) UserMfas(userID string) ([]*usr_model.MultiFactor, error) {
|
||||
return view.UserMfas(v.Db, userTable, userID)
|
||||
}
|
||||
|
||||
func (v *View) PutUser(user *model.UserView) error {
|
||||
err := view.PutUser(v.Db, userTable, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return v.ProcessedUserSequence(user.Sequence)
|
||||
}
|
||||
|
||||
func (v *View) DeleteUser(userID string, eventSequence uint64) error {
|
||||
err := view.DeleteUser(v.Db, userTable, userID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v.ProcessedUserSequence(eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestUserSequence() (uint64, error) {
|
||||
return v.latestSequence(userTable)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedUserSequence(eventSequence uint64) error {
|
||||
return v.saveCurrentSequence(userTable, eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestUserFailedEvent(sequence uint64) (*global_view.FailedEvent, error) {
|
||||
return v.latestFailedEvent(userTable, sequence)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedUserFailedEvent(failedEvent *global_view.FailedEvent) error {
|
||||
return v.saveFailedEvent(failedEvent)
|
||||
}
|
55
internal/auth/repository/eventsourcing/view/user_session.go
Normal file
55
internal/auth/repository/eventsourcing/view/user_session.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"github.com/caos/zitadel/internal/user/repository/view"
|
||||
"github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
global_view "github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
const (
|
||||
userSessionTable = "auth.user_sessions"
|
||||
)
|
||||
|
||||
func (v *View) UserSessionByID(sessionID string) (*model.UserSessionView, error) {
|
||||
return view.UserSessionByID(v.Db, userSessionTable, sessionID)
|
||||
}
|
||||
|
||||
func (v *View) UserSessionByIDs(agentID, userID string) (*model.UserSessionView, error) {
|
||||
return view.UserSessionByIDs(v.Db, userSessionTable, agentID, userID)
|
||||
}
|
||||
|
||||
func (v *View) UserSessionsByAgentID(agentID string) ([]*model.UserSessionView, error) {
|
||||
return view.UserSessionsByAgentID(v.Db, userSessionTable, agentID)
|
||||
}
|
||||
|
||||
func (v *View) PutUserSession(userSession *model.UserSessionView) error {
|
||||
err := view.PutUserSession(v.Db, userSessionTable, userSession)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return v.ProcessedUserSessionSequence(userSession.Sequence)
|
||||
}
|
||||
|
||||
func (v *View) DeleteUserSession(sessionID string, eventSequence uint64) error {
|
||||
err := view.DeleteUserSession(v.Db, userSessionTable, sessionID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return v.ProcessedUserSessionSequence(eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestUserSessionSequence() (uint64, error) {
|
||||
return v.latestSequence(userSessionTable)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedUserSessionSequence(eventSequence uint64) error {
|
||||
return v.saveCurrentSequence(userSessionTable, eventSequence)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestUserSessionFailedEvent(sequence uint64) (*global_view.FailedEvent, error) {
|
||||
return v.latestFailedEvent(userSessionTable, sequence)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedUserSessionFailedEvent(failedEvent *global_view.FailedEvent) error {
|
||||
return v.saveFailedEvent(failedEvent)
|
||||
}
|
25
internal/auth/repository/eventsourcing/view/view.go
Normal file
25
internal/auth/repository/eventsourcing/view/view.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
type View struct {
|
||||
Db *gorm.DB
|
||||
}
|
||||
|
||||
func StartView(sqlClient *sql.DB) (*View, error) {
|
||||
gorm, err := gorm.Open("postgres", sqlClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &View{
|
||||
Db: gorm,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (v *View) Health() (err error) {
|
||||
return v.Db.DB().Ping()
|
||||
}
|
12
internal/auth/repository/repository.go
Normal file
12
internal/auth/repository/repository.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
Health(context.Context) error
|
||||
UserRepository
|
||||
AuthRequestRepository
|
||||
TokenRepository
|
||||
}
|
13
internal/auth/repository/token.go
Normal file
13
internal/auth/repository/token.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/token/model"
|
||||
)
|
||||
|
||||
type TokenRepository interface {
|
||||
CreateToken(ctx context.Context, agentID, applicationID, userID string, lifetime time.Duration) (*model.Token, error)
|
||||
IsTokenValid(ctx context.Context, tokenID string) (bool, error)
|
||||
}
|
42
internal/auth/repository/user.go
Normal file
42
internal/auth/repository/user.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Register(ctx context.Context, user *model.User, resourceOwner string) (*model.User, error)
|
||||
|
||||
myUserRepo
|
||||
SkipMfaInit(ctx context.Context, userID string) error
|
||||
RequestPasswordReset(ctx context.Context, username string) error
|
||||
SetPassword(ctx context.Context, userID, code, password string) error
|
||||
|
||||
SignOut(ctx context.Context, agentID, userID string) error
|
||||
}
|
||||
|
||||
type myUserRepo interface {
|
||||
MyProfile(ctx context.Context) (*model.Profile, error)
|
||||
ChangeMyProfile(ctx context.Context, profile *model.Profile) (*model.Profile, error)
|
||||
|
||||
MyEmail(ctx context.Context) (*model.Email, error)
|
||||
ChangeMyEmail(ctx context.Context, email *model.Email) (*model.Email, error)
|
||||
VerifyMyEmail(ctx context.Context, code string) error
|
||||
ResendMyEmailVerificationMail(ctx context.Context) error
|
||||
|
||||
MyPhone(ctx context.Context) (*model.Phone, error)
|
||||
ChangeMyPhone(ctx context.Context, phone *model.Phone) (*model.Phone, error)
|
||||
VerifyMyPhone(ctx context.Context, code string) error
|
||||
ResendMyPhoneVerificationCode(ctx context.Context) error
|
||||
|
||||
MyAddress(ctx context.Context) (*model.Address, error)
|
||||
ChangeMyAddress(ctx context.Context, address *model.Address) (*model.Address, error)
|
||||
|
||||
ChangeMyPassword(ctx context.Context, old, new string) error
|
||||
|
||||
AddMyMfaOTP(ctx context.Context) (*model.OTP, error)
|
||||
VerifyMyMfaOTP(ctx context.Context, code string) error
|
||||
RemoveMyMfaOTP(ctx context.Context) error
|
||||
}
|
82
internal/auth_request/model/auth_request.go
Normal file
82
internal/auth_request/model/auth_request.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string
|
||||
AgentID string
|
||||
CreationDate time.Time
|
||||
ChangeDate time.Time
|
||||
BrowserInfo *BrowserInfo
|
||||
ApplicationID string
|
||||
CallbackURI string
|
||||
TransferState string
|
||||
Prompt Prompt
|
||||
PossibleLOAs []LevelOfAssurance
|
||||
UiLocales []string
|
||||
LoginHint string
|
||||
PreselectedUserID string
|
||||
MaxAuthAge uint32
|
||||
Request Request
|
||||
|
||||
levelOfAssurance LevelOfAssurance
|
||||
projectApplicationIDs []string
|
||||
UserID string
|
||||
PossibleSteps []NextStep
|
||||
}
|
||||
|
||||
type Prompt int32
|
||||
|
||||
const (
|
||||
PromptUnspecified Prompt = iota
|
||||
PromptNone
|
||||
PromptLogin
|
||||
PromptConsent
|
||||
PromptSelectAccount
|
||||
)
|
||||
|
||||
type LevelOfAssurance int
|
||||
|
||||
const (
|
||||
LevelOfAssuranceNone LevelOfAssurance = iota
|
||||
)
|
||||
|
||||
func NewAuthRequest(id, agentID string, info *BrowserInfo, applicationID, callbackURI, transferState string,
|
||||
prompt Prompt, possibleLOAs []LevelOfAssurance, uiLocales []string, loginHint, preselectedUserID string, maxAuthAge uint32, request Request) *AuthRequest {
|
||||
return &AuthRequest{
|
||||
ID: id,
|
||||
AgentID: agentID,
|
||||
BrowserInfo: info,
|
||||
ApplicationID: applicationID,
|
||||
CallbackURI: callbackURI,
|
||||
TransferState: transferState,
|
||||
Prompt: prompt,
|
||||
PossibleLOAs: possibleLOAs,
|
||||
UiLocales: uiLocales,
|
||||
LoginHint: loginHint,
|
||||
PreselectedUserID: preselectedUserID,
|
||||
MaxAuthAge: maxAuthAge,
|
||||
Request: request,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AuthRequest) IsValid() bool {
|
||||
return a.ID != "" &&
|
||||
a.AgentID != "" &&
|
||||
a.BrowserInfo != nil && a.BrowserInfo.IsValid() &&
|
||||
a.ApplicationID != "" &&
|
||||
a.CallbackURI != "" &&
|
||||
a.Request != nil && a.Request.IsValid()
|
||||
}
|
||||
|
||||
func (a *AuthRequest) MfaLevel() MfaLevel {
|
||||
return -1
|
||||
//PLANNED: check a.PossibleLOAs (and Prompt Login?)
|
||||
}
|
||||
|
||||
func (a *AuthRequest) WithCurrentInfo(info *BrowserInfo) *AuthRequest {
|
||||
a.BrowserInfo = info
|
||||
return a
|
||||
}
|
263
internal/auth_request/model/auth_request_test.go
Normal file
263
internal/auth_request/model/auth_request_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAuthRequest_IsValid(t *testing.T) {
|
||||
type fields struct {
|
||||
ID string
|
||||
AgentID string
|
||||
BrowserInfo *BrowserInfo
|
||||
ApplicationID string
|
||||
CallbackURI string
|
||||
Request Request
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"missing id, false",
|
||||
fields{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing agent id, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing browser info, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"browser info invalid, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing application id, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept language",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing callback uri, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept language",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
ApplicationID: "appID",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"missing request, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept language",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
ApplicationID: "appID",
|
||||
CallbackURI: "schema://callback",
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"request invalid, false",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept language",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
ApplicationID: "appID",
|
||||
CallbackURI: "schema://callback",
|
||||
Request: &AuthRequestOIDC{},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"valid auth request, true",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept language",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
ApplicationID: "appID",
|
||||
CallbackURI: "schema://callback",
|
||||
Request: &AuthRequestOIDC{
|
||||
Scopes: []string{"openid"},
|
||||
CodeChallenge: &OIDCCodeChallenge{
|
||||
Challenge: "challenge",
|
||||
Method: CodeChallengeMethodS256,
|
||||
},
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &AuthRequest{
|
||||
ID: tt.fields.ID,
|
||||
AgentID: tt.fields.AgentID,
|
||||
BrowserInfo: tt.fields.BrowserInfo,
|
||||
ApplicationID: tt.fields.ApplicationID,
|
||||
CallbackURI: tt.fields.CallbackURI,
|
||||
Request: tt.fields.Request,
|
||||
}
|
||||
if got := a.IsValid(); got != tt.want {
|
||||
t.Errorf("IsValid() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequest_MfaLevel(t *testing.T) {
|
||||
type fields struct {
|
||||
Prompt Prompt
|
||||
PossibleLOAs []LevelOfAssurance
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want MfaLevel
|
||||
}{
|
||||
//PLANNED: Add / replace test cases when LOA is set
|
||||
{"-1",
|
||||
fields{},
|
||||
-1,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &AuthRequest{
|
||||
Prompt: tt.fields.Prompt,
|
||||
PossibleLOAs: tt.fields.PossibleLOAs,
|
||||
}
|
||||
if got := a.MfaLevel(); got != tt.want {
|
||||
t.Errorf("MfaLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRequest_WithCurrentInfo(t *testing.T) {
|
||||
type fields struct {
|
||||
ID string
|
||||
AgentID string
|
||||
BrowserInfo *BrowserInfo
|
||||
}
|
||||
type args struct {
|
||||
info *BrowserInfo
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *AuthRequest
|
||||
}{
|
||||
{
|
||||
"unchanged",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
args{
|
||||
&BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
&AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"changed",
|
||||
fields{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
args{
|
||||
&BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(16, 12, 20, 19),
|
||||
},
|
||||
},
|
||||
&AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &BrowserInfo{
|
||||
UserAgent: "ua",
|
||||
AcceptLanguage: "de",
|
||||
RemoteIP: net.IPv4(16, 12, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &AuthRequest{
|
||||
ID: tt.fields.ID,
|
||||
AgentID: tt.fields.AgentID,
|
||||
BrowserInfo: tt.fields.BrowserInfo,
|
||||
}
|
||||
if got := a.WithCurrentInfo(tt.args.info); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("WithCurrentInfo() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
15
internal/auth_request/model/browser_info.go
Normal file
15
internal/auth_request/model/browser_info.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package model
|
||||
|
||||
import "net"
|
||||
|
||||
type BrowserInfo struct {
|
||||
UserAgent string
|
||||
AcceptLanguage string
|
||||
RemoteIP net.IP
|
||||
}
|
||||
|
||||
func (i *BrowserInfo) IsValid() bool {
|
||||
return i.UserAgent != "" &&
|
||||
i.AcceptLanguage != "" &&
|
||||
i.RemoteIP != nil && !i.RemoteIP.IsUnspecified()
|
||||
}
|
17
internal/auth_request/model/code_challenge.go
Normal file
17
internal/auth_request/model/code_challenge.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package model
|
||||
|
||||
type OIDCCodeChallenge struct {
|
||||
Challenge string
|
||||
Method OIDCCodeChallengeMethod
|
||||
}
|
||||
|
||||
func (c *OIDCCodeChallenge) IsValid() bool {
|
||||
return c.Challenge != ""
|
||||
}
|
||||
|
||||
type OIDCCodeChallengeMethod int32
|
||||
|
||||
const (
|
||||
CodeChallengeMethodPlain OIDCCodeChallengeMethod = iota
|
||||
CodeChallengeMethodS256
|
||||
)
|
117
internal/auth_request/model/next_step.go
Normal file
117
internal/auth_request/model/next_step.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package model
|
||||
|
||||
type NextStep interface {
|
||||
Type() NextStepType
|
||||
}
|
||||
|
||||
type NextStepType int32
|
||||
|
||||
const (
|
||||
NextStepUnspecified NextStepType = iota
|
||||
NextStepLogin
|
||||
NextStepUserSelection
|
||||
NextStepPassword
|
||||
NextStepChangePassword
|
||||
NextStepInitPassword
|
||||
NextStepVerifyEmail
|
||||
NextStepMfaPrompt
|
||||
NextStepMfaVerify
|
||||
NextStepRedirectToCallback
|
||||
)
|
||||
|
||||
type UserSessionState int32
|
||||
|
||||
const (
|
||||
UserSessionStateActive UserSessionState = iota
|
||||
UserSessionStateTerminated
|
||||
)
|
||||
|
||||
type LoginStep struct {
|
||||
NotFound bool
|
||||
}
|
||||
|
||||
func (s *LoginStep) Type() NextStepType {
|
||||
return NextStepLogin
|
||||
}
|
||||
|
||||
type SelectUserStep struct {
|
||||
Users []UserSelection
|
||||
}
|
||||
|
||||
func (s *SelectUserStep) Type() NextStepType {
|
||||
return NextStepUserSelection
|
||||
}
|
||||
|
||||
type UserSelection struct {
|
||||
UserID string
|
||||
UserName string
|
||||
UserSessionState UserSessionState
|
||||
}
|
||||
|
||||
type PasswordStep struct {
|
||||
FailureCount uint16
|
||||
}
|
||||
|
||||
func (s *PasswordStep) Type() NextStepType {
|
||||
return NextStepPassword
|
||||
}
|
||||
|
||||
type ChangePasswordStep struct {
|
||||
}
|
||||
|
||||
func (s *ChangePasswordStep) Type() NextStepType {
|
||||
return NextStepChangePassword
|
||||
}
|
||||
|
||||
type InitPasswordStep struct {
|
||||
}
|
||||
|
||||
func (s *InitPasswordStep) Type() NextStepType {
|
||||
return NextStepInitPassword
|
||||
}
|
||||
|
||||
type VerifyEMailStep struct {
|
||||
}
|
||||
|
||||
func (s *VerifyEMailStep) Type() NextStepType {
|
||||
return NextStepVerifyEmail
|
||||
}
|
||||
|
||||
type MfaPromptStep struct {
|
||||
Required bool
|
||||
MfaProviders []MfaType
|
||||
}
|
||||
|
||||
func (s *MfaPromptStep) Type() NextStepType {
|
||||
return NextStepMfaPrompt
|
||||
}
|
||||
|
||||
type MfaVerificationStep struct {
|
||||
FailureCount uint16
|
||||
MfaProviders []MfaType
|
||||
}
|
||||
|
||||
func (s *MfaVerificationStep) Type() NextStepType {
|
||||
return NextStepMfaVerify
|
||||
}
|
||||
|
||||
type RedirectToCallbackStep struct {
|
||||
}
|
||||
|
||||
func (s *RedirectToCallbackStep) Type() NextStepType {
|
||||
return NextStepRedirectToCallback
|
||||
}
|
||||
|
||||
type MfaType int
|
||||
|
||||
const (
|
||||
MfaTypeOTP MfaType = iota
|
||||
)
|
||||
|
||||
type MfaLevel int
|
||||
|
||||
const (
|
||||
MfaLevelSoftware MfaLevel = iota
|
||||
MfaLevelHardware
|
||||
MfaLevelHardwareCertified
|
||||
)
|
48
internal/auth_request/model/request.go
Normal file
48
internal/auth_request/model/request.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package model
|
||||
|
||||
type Request interface {
|
||||
Type() AuthRequestType
|
||||
IsValid() bool
|
||||
}
|
||||
|
||||
type AuthRequestType int32
|
||||
|
||||
const (
|
||||
AuthRequestTypeOIDC AuthRequestType = iota
|
||||
AuthRequestTypeSAML
|
||||
)
|
||||
|
||||
type AuthRequestOIDC struct {
|
||||
Scopes []string
|
||||
ResponseType OIDCResponseType
|
||||
Nonce string
|
||||
CodeChallenge *OIDCCodeChallenge
|
||||
}
|
||||
|
||||
func (a *AuthRequestOIDC) Type() AuthRequestType {
|
||||
return AuthRequestTypeOIDC
|
||||
}
|
||||
|
||||
func (a *AuthRequestOIDC) IsValid() bool {
|
||||
return len(a.Scopes) > 0 &&
|
||||
a.CodeChallenge == nil || a.CodeChallenge != nil && a.CodeChallenge.IsValid()
|
||||
}
|
||||
|
||||
type AuthRequestSAML struct {
|
||||
}
|
||||
|
||||
func (a *AuthRequestSAML) Type() AuthRequestType {
|
||||
return AuthRequestTypeSAML
|
||||
}
|
||||
|
||||
func (a *AuthRequestSAML) IsValid() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type OIDCResponseType int32
|
||||
|
||||
const (
|
||||
OIDCResponseTypeCode OIDCResponseType = iota
|
||||
OIDCResponseTypeIdToken
|
||||
OIDCResponseTypeToken
|
||||
)
|
67
internal/auth_request/repository/cache/cache.go
vendored
Normal file
67
internal/auth_request/repository/cache/cache.go
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/config/types"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Connection types.SQL
|
||||
}
|
||||
|
||||
type AuthRequestCache struct {
|
||||
client *sql.DB
|
||||
}
|
||||
|
||||
func Start(conf Config) (*AuthRequestCache, error) {
|
||||
client, err := sql.Open("postgres", conf.Connection.ConnectionString())
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(err, "SQL-9qBtr", "unable to open database connection")
|
||||
}
|
||||
return &AuthRequestCache{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *AuthRequestCache) Health(ctx context.Context) error {
|
||||
return c.client.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *AuthRequestCache) GetAuthRequestByID(_ context.Context, id string) (*model.AuthRequest, error) {
|
||||
var b []byte
|
||||
err := c.client.QueryRow("SELECT request FROM auth.authrequests WHERE id = ?", id).Scan(&b)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "auth request not found")
|
||||
}
|
||||
return nil, caos_errs.ThrowInternal(err, "CACHE-as3kj", "unable to get auth request from database")
|
||||
}
|
||||
request := new(model.AuthRequest)
|
||||
err = json.Unmarshal(b, &request)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "CACHE-2wshg", "unable to unmarshal auth request")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (c *AuthRequestCache) SaveAuthRequest(_ context.Context, request *model.AuthRequest) error {
|
||||
b, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "CACHE-32FH9", "unable to marshal auth request")
|
||||
}
|
||||
stmt, err := c.client.Prepare("INSERT INTO auth.authrequests (id, request) VALUES($1, $2)")
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "CACHE-dswfF", "sql prepare failed")
|
||||
}
|
||||
_, err = stmt.Exec(request.ID, b)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "CACHE-sw4af", "unable to save auth request")
|
||||
}
|
||||
return nil
|
||||
}
|
3
internal/auth_request/repository/gen_mock.go
Normal file
3
internal/auth_request/repository/gen_mock.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package repository
|
||||
|
||||
//go:generate mockgen -package mock -destination ./mock/repository.mock.go github.com/caos/zitadel/internal/auth_request/repository Repository
|
12
internal/auth_request/repository/mock/repository.go
Normal file
12
internal/auth_request/repository/mock/repository.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth_request/repository"
|
||||
)
|
||||
|
||||
func NewMockAuthRequestRepository(ctrl *gomock.Controller) repository.Repository {
|
||||
repo := NewMockRepository(ctrl)
|
||||
return repo
|
||||
}
|
79
internal/auth_request/repository/mock/repository.mock.go
Normal file
79
internal/auth_request/repository/mock/repository.mock.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/zitadel/internal/auth_request/repository (interfaces: Repository)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockRepository is a mock of Repository interface
|
||||
type MockRepository struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRepositoryMockRecorder
|
||||
}
|
||||
|
||||
// MockRepositoryMockRecorder is the mock recorder for MockRepository
|
||||
type MockRepositoryMockRecorder struct {
|
||||
mock *MockRepository
|
||||
}
|
||||
|
||||
// NewMockRepository creates a new mock instance
|
||||
func NewMockRepository(ctrl *gomock.Controller) *MockRepository {
|
||||
mock := &MockRepository{ctrl: ctrl}
|
||||
mock.recorder = &MockRepositoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockRepository) EXPECT() *MockRepositoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetAuthRequestByID mocks base method
|
||||
func (m *MockRepository) GetAuthRequestByID(arg0 context.Context, arg1 string) (*model.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1)
|
||||
ret0, _ := ret[0].(*model.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID
|
||||
func (mr *MockRepositoryMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockRepository)(nil).GetAuthRequestByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// Health mocks base method
|
||||
func (m *MockRepository) Health(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Health", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Health indicates an expected call of Health
|
||||
func (mr *MockRepositoryMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockRepository)(nil).Health), arg0)
|
||||
}
|
||||
|
||||
// SaveAuthRequest mocks base method
|
||||
func (m *MockRepository) SaveAuthRequest(arg0 context.Context, arg1 string) (*model.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(*model.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SaveAuthRequest indicates an expected call of SaveAuthRequest
|
||||
func (mr *MockRepositoryMockRecorder) SaveAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveAuthRequest", reflect.TypeOf((*MockRepository)(nil).SaveAuthRequest), arg0, arg1)
|
||||
}
|
14
internal/auth_request/repository/repository.go
Normal file
14
internal/auth_request/repository/repository.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
Health(ctx context.Context) error
|
||||
|
||||
GetAuthRequestByID(ctx context.Context, id string) (*model.AuthRequest, error)
|
||||
SaveAuthRequest(ctx context.Context, id string) (*model.AuthRequest, error)
|
||||
}
|
@@ -7,12 +7,13 @@ import (
|
||||
)
|
||||
|
||||
type SystemDefaults struct {
|
||||
SecretGenerators SecretGenerators
|
||||
UserVerificationKey *crypto.KeyConfig
|
||||
Multifactors MultifactorConfig
|
||||
DefaultPolicies DefaultPolicies
|
||||
IamID string
|
||||
SetUp types.IAMSetUp
|
||||
SecretGenerators SecretGenerators
|
||||
UserVerificationKey *crypto.KeyConfig
|
||||
Multifactors MultifactorConfig
|
||||
VerificationLifetimes VerificationLifetimes
|
||||
DefaultPolicies DefaultPolicies
|
||||
IamID string
|
||||
SetUp types.IAMSetUp
|
||||
}
|
||||
|
||||
type SecretGenerators struct {
|
||||
@@ -33,6 +34,13 @@ type OTPConfig struct {
|
||||
VerificationKey *crypto.KeyConfig
|
||||
}
|
||||
|
||||
type VerificationLifetimes struct {
|
||||
PasswordCheck types.Duration
|
||||
MfaInitSkip types.Duration
|
||||
MfaSoftwareCheck types.Duration
|
||||
MfaHardwareCheck types.Duration
|
||||
}
|
||||
|
||||
type DefaultPolicies struct {
|
||||
Age pol.PasswordAgePolicyDefault
|
||||
Complexity pol.PasswordComplexityPolicyDefault
|
||||
|
@@ -27,8 +27,8 @@ func CreateMockEncryptionAlg(ctrl *gomock.Controller) EncryptionAlgorithm {
|
||||
return mCrypto
|
||||
}
|
||||
|
||||
func createMockHashAlg(t *testing.T) HashAlgorithm {
|
||||
mCrypto := NewMockHashAlgorithm(gomock.NewController(t))
|
||||
func CreateMockHashAlg(ctrl *gomock.Controller) HashAlgorithm {
|
||||
mCrypto := NewMockHashAlgorithm(ctrl)
|
||||
mCrypto.EXPECT().Algorithm().AnyTimes().Return("hash")
|
||||
mCrypto.EXPECT().Hash(gomock.Any()).DoAndReturn(
|
||||
func(code []byte) ([]byte, error) {
|
||||
|
@@ -113,7 +113,7 @@ func TestVerifyCode(t *testing.T) {
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "code",
|
||||
g: createMockGenerator(t, createMockHashAlg(t)),
|
||||
g: createMockGenerator(t, CreateMockHashAlg(gomock.NewController(t))),
|
||||
},
|
||||
false,
|
||||
},
|
||||
@@ -240,7 +240,7 @@ func Test_verifyHashedCode(t *testing.T) {
|
||||
args{
|
||||
cryptoCode: nil,
|
||||
verificationCode: "",
|
||||
alg: createMockHashAlg(t),
|
||||
alg: CreateMockHashAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -252,7 +252,7 @@ func Test_verifyHashedCode(t *testing.T) {
|
||||
Crypted: nil,
|
||||
},
|
||||
verificationCode: "",
|
||||
alg: createMockHashAlg(t),
|
||||
alg: CreateMockHashAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -265,7 +265,7 @@ func Test_verifyHashedCode(t *testing.T) {
|
||||
Crypted: nil,
|
||||
},
|
||||
verificationCode: "",
|
||||
alg: createMockHashAlg(t),
|
||||
alg: CreateMockHashAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -278,7 +278,7 @@ func Test_verifyHashedCode(t *testing.T) {
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "wrong",
|
||||
alg: createMockHashAlg(t),
|
||||
alg: CreateMockHashAlg(gomock.NewController(t)),
|
||||
},
|
||||
true,
|
||||
},
|
||||
@@ -291,7 +291,7 @@ func Test_verifyHashedCode(t *testing.T) {
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
verificationCode: "code",
|
||||
alg: createMockHashAlg(t),
|
||||
alg: CreateMockHashAlg(gomock.NewController(t)),
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
@@ -100,7 +100,7 @@ func Hash(value []byte, alg HashAlgorithm) (*CryptoValue, error) {
|
||||
|
||||
func CompareHash(value *CryptoValue, comparer []byte, alg HashAlgorithm) error {
|
||||
if value.Algorithm != alg.Algorithm() {
|
||||
return errors.ThrowInvalidArgument(nil, "CRYPT-HF32f", "value was hash with a different algorithm")
|
||||
return errors.ThrowInvalidArgument(nil, "CRYPT-HF32f", "value was hashed with a different algorithm")
|
||||
}
|
||||
return alg.CompareHash(value.Crypted, comparer)
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
type filterFunc func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error)
|
||||
type appendFunc func(...*es_models.Event) error
|
||||
type aggregateFunc func(context.Context) (*es_models.Aggregate, error)
|
||||
type AggregateFunc func(context.Context) (*es_models.Aggregate, error)
|
||||
type pushFunc func(context.Context, ...*es_models.Aggregate) error
|
||||
|
||||
func Filter(ctx context.Context, filter filterFunc, appender appendFunc, query *es_models.SearchQuery) error {
|
||||
@@ -32,7 +32,7 @@ func Filter(ctx context.Context, filter filterFunc, appender appendFunc, query *
|
||||
// Push creates the aggregates from aggregater
|
||||
// and pushes the aggregates to the given pushFunc
|
||||
// the given events are appended by the appender
|
||||
func Push(ctx context.Context, push pushFunc, appender appendFunc, aggregaters ...aggregateFunc) (err error) {
|
||||
func Push(ctx context.Context, push pushFunc, appender appendFunc, aggregaters ...AggregateFunc) (err error) {
|
||||
if len(aggregaters) < 1 {
|
||||
return errors.ThrowPreconditionFailed(nil, "SDK-q9wjp", "no aggregaters passed")
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func appendAggregates(appender appendFunc, aggregates []*models.Aggregate) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeAggregates(ctx context.Context, aggregaters []aggregateFunc) (aggregates []*models.Aggregate, err error) {
|
||||
func makeAggregates(ctx context.Context, aggregaters []AggregateFunc) (aggregates []*models.Aggregate, err error) {
|
||||
aggregates = make([]*models.Aggregate, len(aggregaters))
|
||||
for i, aggregater := range aggregaters {
|
||||
aggregates[i], err = aggregater(ctx)
|
||||
|
@@ -80,7 +80,7 @@ func TestPush(t *testing.T) {
|
||||
type args struct {
|
||||
push pushFunc
|
||||
appender appendFunc
|
||||
aggregaters []aggregateFunc
|
||||
aggregaters []AggregateFunc
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -101,7 +101,7 @@ func TestPush(t *testing.T) {
|
||||
args: args{
|
||||
push: nil,
|
||||
appender: nil,
|
||||
aggregaters: []aggregateFunc{
|
||||
aggregaters: []AggregateFunc{
|
||||
func(context.Context) (*es_models.Aggregate, error) {
|
||||
return nil, errors.ThrowInternal(nil, "SDK-Ec5x2", "test err")
|
||||
},
|
||||
@@ -116,7 +116,7 @@ func TestPush(t *testing.T) {
|
||||
return errors.ThrowInternal(nil, "SDK-0g4gW", "test error")
|
||||
},
|
||||
appender: nil,
|
||||
aggregaters: []aggregateFunc{
|
||||
aggregaters: []AggregateFunc{
|
||||
func(context.Context) (*es_models.Aggregate, error) {
|
||||
return &es_models.Aggregate{}, nil
|
||||
},
|
||||
@@ -133,7 +133,7 @@ func TestPush(t *testing.T) {
|
||||
appender: func(...*es_models.Event) error {
|
||||
return errors.ThrowInvalidArgument(nil, "SDK-BDhcT", "test err")
|
||||
},
|
||||
aggregaters: []aggregateFunc{
|
||||
aggregaters: []AggregateFunc{
|
||||
func(context.Context) (*es_models.Aggregate, error) {
|
||||
return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
|
||||
},
|
||||
@@ -150,7 +150,7 @@ func TestPush(t *testing.T) {
|
||||
appender: func(...*es_models.Event) error {
|
||||
return nil
|
||||
},
|
||||
aggregaters: []aggregateFunc{
|
||||
aggregaters: []AggregateFunc{
|
||||
func(context.Context) (*es_models.Aggregate, error) {
|
||||
return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
|
||||
},
|
||||
@@ -167,7 +167,7 @@ func TestPush(t *testing.T) {
|
||||
appender: func(...*es_models.Event) error {
|
||||
return nil
|
||||
},
|
||||
aggregaters: []aggregateFunc{
|
||||
aggregaters: []AggregateFunc{
|
||||
func(context.Context) (*es_models.Aggregate, error) {
|
||||
return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
|
||||
},
|
||||
|
3
internal/id/gen_mock.go
Normal file
3
internal/id/gen_mock.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package id
|
||||
|
||||
//go:generate mockgen -package mock -destination ./mock/generator.mock.go github.com/caos/zitadel/internal/id Generator
|
5
internal/id/id_generator.go
Normal file
5
internal/id/id_generator.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package id
|
||||
|
||||
type Generator interface {
|
||||
Next() (string, error)
|
||||
}
|
48
internal/id/mock/generator.mock.go
Normal file
48
internal/id/mock/generator.mock.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/caos/zitadel/internal/id (interfaces: Generator)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockGenerator is a mock of Generator interface
|
||||
type MockGenerator struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockGeneratorMockRecorder
|
||||
}
|
||||
|
||||
// MockGeneratorMockRecorder is the mock recorder for MockGenerator
|
||||
type MockGeneratorMockRecorder struct {
|
||||
mock *MockGenerator
|
||||
}
|
||||
|
||||
// NewMockGenerator creates a new mock instance
|
||||
func NewMockGenerator(ctrl *gomock.Controller) *MockGenerator {
|
||||
mock := &MockGenerator{ctrl: ctrl}
|
||||
mock.recorder = &MockGeneratorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockGenerator) EXPECT() *MockGeneratorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Next mocks base method
|
||||
func (m *MockGenerator) Next() (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Next")
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Next indicates an expected call of Next
|
||||
func (mr *MockGeneratorMockRecorder) Next() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockGenerator)(nil).Next))
|
||||
}
|
28
internal/id/sonyflake.go
Normal file
28
internal/id/sonyflake.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package id
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sony/sonyflake"
|
||||
)
|
||||
|
||||
type sonyflakeGenerator struct {
|
||||
*sonyflake.Sonyflake
|
||||
}
|
||||
|
||||
func (s *sonyflakeGenerator) Next() (string, error) {
|
||||
id, err := s.NextID()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strconv.FormatUint(id, 10), nil
|
||||
}
|
||||
|
||||
var (
|
||||
SonyFlakeGenerator = Generator(&sonyflakeGenerator{
|
||||
sonyflake.NewSonyflake(sonyflake.Settings{
|
||||
StartTime: time.Date(2019, 4, 29, 0, 0, 0, 0, time.UTC),
|
||||
}),
|
||||
})
|
||||
)
|
87
internal/management/repository/eventsourcing/user.go
Normal file
87
internal/management/repository/eventsourcing/user.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package eventsourcing
|
||||
|
||||
import (
|
||||
"context"
|
||||
usr_model "github.com/caos/zitadel/internal/user/model"
|
||||
usr_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type UserRepo struct {
|
||||
UserEvents *usr_event.UserEventstore
|
||||
}
|
||||
|
||||
func (repo *UserRepo) UserByID(ctx context.Context, id string) (project *usr_model.User, err error) {
|
||||
return repo.UserEvents.UserByID(ctx, id)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) CreateUser(ctx context.Context, user *usr_model.User) (*usr_model.User, error) {
|
||||
return repo.UserEvents.CreateUser(ctx, user)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) RegisterUser(ctx context.Context, user *usr_model.User, resourceOwner string) (*usr_model.User, error) {
|
||||
return repo.UserEvents.RegisterUser(ctx, user, resourceOwner)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) DeactivateUser(ctx context.Context, id string) (*usr_model.User, error) {
|
||||
return repo.UserEvents.DeactivateUser(ctx, id)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ReactivateUser(ctx context.Context, id string) (*usr_model.User, error) {
|
||||
return repo.UserEvents.ReactivateUser(ctx, id)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) LockUser(ctx context.Context, id string) (*usr_model.User, error) {
|
||||
return repo.UserEvents.LockUser(ctx, id)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) UnlockUser(ctx context.Context, id string) (*usr_model.User, error) {
|
||||
return repo.UserEvents.UnlockUser(ctx, id)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) SetOneTimePassword(ctx context.Context, password *usr_model.Password) (*usr_model.Password, error) {
|
||||
return repo.UserEvents.SetOneTimePassword(ctx, password)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) RequestSetPassword(ctx context.Context, id string, notifyType usr_model.NotificationType) error {
|
||||
return repo.UserEvents.RequestSetPassword(ctx, id, notifyType)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ProfileByID(ctx context.Context, userID string) (*usr_model.Profile, error) {
|
||||
return repo.UserEvents.ProfileByID(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeProfile(ctx context.Context, profile *usr_model.Profile) (*usr_model.Profile, error) {
|
||||
return repo.UserEvents.ChangeProfile(ctx, profile)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) EmailByID(ctx context.Context, userID string) (*usr_model.Email, error) {
|
||||
return repo.UserEvents.EmailByID(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeEmail(ctx context.Context, email *usr_model.Email) (*usr_model.Email, error) {
|
||||
return repo.UserEvents.ChangeEmail(ctx, email)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) CreateEmailVerificationCode(ctx context.Context, userID string) error {
|
||||
return repo.UserEvents.CreateEmailVerificationCode(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) PhoneByID(ctx context.Context, userID string) (*usr_model.Phone, error) {
|
||||
return repo.UserEvents.PhoneByID(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangePhone(ctx context.Context, email *usr_model.Phone) (*usr_model.Phone, error) {
|
||||
return repo.UserEvents.ChangePhone(ctx, email)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) CreatePhoneVerificationCode(ctx context.Context, userID string) error {
|
||||
return repo.UserEvents.CreatePhoneVerificationCode(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) AddressByID(ctx context.Context, userID string) (*usr_model.Address, error) {
|
||||
return repo.UserEvents.AddressByID(ctx, userID)
|
||||
}
|
||||
|
||||
func (repo *UserRepo) ChangeAddress(ctx context.Context, address *usr_model.Address) (*usr_model.Address, error) {
|
||||
return repo.UserEvents.ChangeAddress(ctx, address)
|
||||
}
|
35
internal/management/repository/eventsourcing/user_grant.go
Normal file
35
internal/management/repository/eventsourcing/user_grant.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package eventsourcing
|
||||
|
||||
import (
|
||||
"context"
|
||||
grant_model "github.com/caos/zitadel/internal/usergrant/model"
|
||||
grant_event "github.com/caos/zitadel/internal/usergrant/repository/eventsourcing"
|
||||
)
|
||||
|
||||
type UserGrantRepo struct {
|
||||
UserGrantEvents *grant_event.UserGrantEventStore
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) UserGrantByID(ctx context.Context, grantID string) (*grant_model.UserGrant, error) {
|
||||
return repo.UserGrantEvents.UserGrantByID(ctx, grantID)
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) AddUserGrant(ctx context.Context, grant *grant_model.UserGrant) (*grant_model.UserGrant, error) {
|
||||
return repo.UserGrantEvents.AddUserGrant(ctx, grant)
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) ChangeUserGrant(ctx context.Context, grant *grant_model.UserGrant) (*grant_model.UserGrant, error) {
|
||||
return repo.UserGrantEvents.ChangeUserGrant(ctx, grant)
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) DeactivateUserGrant(ctx context.Context, grantID string) (*grant_model.UserGrant, error) {
|
||||
return repo.UserGrantEvents.DeactivateUserGrant(ctx, grantID)
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) ReactivateUserGrant(ctx context.Context, grantID string) (*grant_model.UserGrant, error) {
|
||||
return repo.UserGrantEvents.ReactivateUserGrant(ctx, grantID)
|
||||
}
|
||||
|
||||
func (repo *UserGrantRepo) RemoveUserGrant(ctx context.Context, grantID string) error {
|
||||
return repo.UserGrantEvents.RemoveUserGrant(ctx, grantID)
|
||||
}
|
58
internal/token/model/token.go
Normal file
58
internal/token/model/token.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
ChangeDate time.Time
|
||||
ResourceOwner string
|
||||
UserID string
|
||||
ApplicationID string
|
||||
UserAgentID string
|
||||
Expiration time.Time
|
||||
Sequence uint64
|
||||
}
|
||||
|
||||
type TokenSearchRequest struct {
|
||||
Offset uint64
|
||||
Limit uint64
|
||||
SortingColumn TokenSearchKey
|
||||
Asc bool
|
||||
Queries []*TokenSearchQuery
|
||||
}
|
||||
|
||||
type TokenSearchKey int32
|
||||
|
||||
const (
|
||||
TOKENSEARCHKEY_UNSPECIFIED TokenSearchKey = iota
|
||||
TOKENSEARCHKEY_TOKEN_ID
|
||||
TOKENSEARCHKEY_USER_ID
|
||||
TOKENSEARCHKEY_APPLICATION_ID
|
||||
TOKENSEARCHKEY_USER_AGENT_ID
|
||||
TOKENSEARCHKEY_EXPIRATION
|
||||
TOKENSEARCHKEY_RESOURCEOWNER
|
||||
)
|
||||
|
||||
type TokenSearchQuery struct {
|
||||
Key TokenSearchKey
|
||||
Method model.SearchMethod
|
||||
Value string
|
||||
}
|
||||
|
||||
type TokenSearchResponse struct {
|
||||
Offset uint64
|
||||
Limit uint64
|
||||
TotalResult uint64
|
||||
Result []*Token
|
||||
}
|
||||
|
||||
func (r *TokenSearchRequest) EnsureLimit(limit uint64) {
|
||||
if r.Limit == 0 || r.Limit > limit {
|
||||
r.Limit = limit
|
||||
}
|
||||
}
|
56
internal/token/repository/view/model/token.go
Normal file
56
internal/token/repository/view/model/token.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/token/model"
|
||||
)
|
||||
|
||||
const (
|
||||
TokenKeyTokenID = "id"
|
||||
TokenKeyUserID = "user_id"
|
||||
TokenKeyApplicationID = "application_id"
|
||||
TokenKeyUserAgentID = "user_agent_id"
|
||||
TokenKeyExpiration = "expiration"
|
||||
TokenKeyResourceOwner = "resource_owner"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
ID string `json:"-" gorm:"column:id;primary_key"`
|
||||
CreationDate time.Time `json:"-" gorm:"column:creation_date"`
|
||||
ChangeDate time.Time `json:"-" gorm:"column:change_date"`
|
||||
ResourceOwner string `json:"-" gorm:"column:resource_owner"`
|
||||
UserID string `json:"-" gorm:"column:user_id"`
|
||||
ApplicationID string `json:"-" gorm:"column:application_id"`
|
||||
UserAgentID string `json:"-" gorm:"column:user_agent_id"`
|
||||
Expiration time.Time `json:"-" gorm:"column:expiration"`
|
||||
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
||||
}
|
||||
|
||||
func TokenFromModel(token *model.Token) *Token {
|
||||
return &Token{
|
||||
ID: token.ID,
|
||||
CreationDate: token.CreationDate,
|
||||
ChangeDate: token.ChangeDate,
|
||||
ResourceOwner: token.ResourceOwner,
|
||||
UserID: token.UserID,
|
||||
ApplicationID: token.ApplicationID,
|
||||
UserAgentID: token.UserAgentID,
|
||||
Expiration: token.Expiration,
|
||||
Sequence: token.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func TokenToModel(token *Token) *model.Token {
|
||||
return &model.Token{
|
||||
ID: token.ID,
|
||||
CreationDate: token.CreationDate,
|
||||
ChangeDate: token.ChangeDate,
|
||||
ResourceOwner: token.ResourceOwner,
|
||||
UserID: token.UserID,
|
||||
ApplicationID: token.ApplicationID,
|
||||
UserAgentID: token.UserAgentID,
|
||||
Expiration: token.Expiration,
|
||||
Sequence: token.Sequence,
|
||||
}
|
||||
}
|
69
internal/token/repository/view/model/token_query.go
Normal file
69
internal/token/repository/view/model/token_query.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
global_model "github.com/caos/zitadel/internal/model"
|
||||
token_model "github.com/caos/zitadel/internal/token/model"
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
type TokenSearchRequest token_model.TokenSearchRequest
|
||||
type TokenSearchQuery token_model.TokenSearchQuery
|
||||
type TokenSearchKey token_model.TokenSearchKey
|
||||
|
||||
func (req TokenSearchRequest) GetLimit() uint64 {
|
||||
return req.Limit
|
||||
}
|
||||
|
||||
func (req TokenSearchRequest) GetOffset() uint64 {
|
||||
return req.Offset
|
||||
}
|
||||
|
||||
func (req TokenSearchRequest) GetSortingColumn() view.ColumnKey {
|
||||
if req.SortingColumn == token_model.TOKENSEARCHKEY_UNSPECIFIED {
|
||||
return nil
|
||||
}
|
||||
return TokenSearchKey(req.SortingColumn)
|
||||
}
|
||||
|
||||
func (req TokenSearchRequest) GetAsc() bool {
|
||||
return req.Asc
|
||||
}
|
||||
|
||||
func (req TokenSearchRequest) GetQueries() []view.SearchQuery {
|
||||
result := make([]view.SearchQuery, len(req.Queries))
|
||||
for i, q := range req.Queries {
|
||||
result[i] = TokenSearchQuery{Key: q.Key, Value: q.Value, Method: q.Method}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (req TokenSearchQuery) GetKey() view.ColumnKey {
|
||||
return TokenSearchKey(req.Key)
|
||||
}
|
||||
|
||||
func (req TokenSearchQuery) GetMethod() global_model.SearchMethod {
|
||||
return req.Method
|
||||
}
|
||||
|
||||
func (req TokenSearchQuery) GetValue() interface{} {
|
||||
return req.Value
|
||||
}
|
||||
|
||||
func (key TokenSearchKey) ToColumnName() string {
|
||||
switch token_model.TokenSearchKey(key) {
|
||||
case token_model.TOKENSEARCHKEY_TOKEN_ID:
|
||||
return TokenKeyTokenID
|
||||
case token_model.TOKENSEARCHKEY_USER_AGENT_ID:
|
||||
return TokenKeyUserAgentID
|
||||
case token_model.TOKENSEARCHKEY_USER_ID:
|
||||
return TokenKeyUserID
|
||||
case token_model.TOKENSEARCHKEY_APPLICATION_ID:
|
||||
return TokenKeyApplicationID
|
||||
case token_model.TOKENSEARCHKEY_EXPIRATION:
|
||||
return TokenKeyExpiration
|
||||
case token_model.TOKENSEARCHKEY_RESOURCEOWNER:
|
||||
return TokenKeyResourceOwner
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
48
internal/token/repository/view/token.go
Normal file
48
internal/token/repository/view/token.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
token_model "github.com/caos/zitadel/internal/token/model"
|
||||
"github.com/caos/zitadel/internal/token/repository/view/model"
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
func TokenByID(db *gorm.DB, table, tokenID string) (*model.Token, error) {
|
||||
token := new(model.Token)
|
||||
query := view.PrepareGetByKey(table, model.TokenSearchKey(token_model.TOKENSEARCHKEY_TOKEN_ID), tokenID)
|
||||
err := query(db, token)
|
||||
return token, err
|
||||
}
|
||||
|
||||
func IsTokenValid(db *gorm.DB, table, tokenID string) (bool, error) {
|
||||
token, err := TokenByID(db, table, tokenID)
|
||||
if err == nil {
|
||||
return token.Expiration.After(time.Now().UTC()), nil
|
||||
}
|
||||
if errors.IsNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func PutToken(db *gorm.DB, table string, token *model.Token) error {
|
||||
save := view.PrepareSave(table)
|
||||
return save(db, token)
|
||||
}
|
||||
|
||||
func DeleteToken(db *gorm.DB, table, tokenID string) error {
|
||||
delete := view.PrepareDeleteByKey(table, model.TokenSearchKey(token_model.TOKENSEARCHKEY_TOKEN_ID), tokenID)
|
||||
return delete(db)
|
||||
}
|
||||
|
||||
func DeleteTokens(db *gorm.DB, table, agentID, userID string) error {
|
||||
delete := view.PrepareDeleteByKeys(table,
|
||||
view.Key{Key: model.TokenSearchKey(token_model.TOKENSEARCHKEY_USER_AGENT_ID), Value: agentID},
|
||||
view.Key{Key: model.TokenSearchKey(token_model.TOKENSEARCHKEY_USER_ID), Value: userID},
|
||||
)
|
||||
return delete(db)
|
||||
}
|
61
internal/user/model/user_session_view.go
Normal file
61
internal/user/model/user_session_view.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
)
|
||||
|
||||
type UserSessionView struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
ChangeDate time.Time
|
||||
State req_model.UserSessionState
|
||||
ResourceOwner string
|
||||
UserAgentID string
|
||||
UserID string
|
||||
UserName string
|
||||
PasswordVerification time.Time
|
||||
MfaSoftwareVerification time.Time
|
||||
MfaHardwareVerification time.Time
|
||||
Sequence uint64
|
||||
}
|
||||
|
||||
type UserSessionSearchRequest struct {
|
||||
Offset uint64
|
||||
Limit uint64
|
||||
SortingColumn UserSessionSearchKey
|
||||
Asc bool
|
||||
Queries []*UserSessionSearchQuery
|
||||
}
|
||||
|
||||
type UserSessionSearchKey int32
|
||||
|
||||
const (
|
||||
USERSESSIONSEARCHKEY_UNSPECIFIED UserSessionSearchKey = iota
|
||||
USERSESSIONSEARCHKEY_SESSION_ID
|
||||
USERSESSIONSEARCHKEY_USER_AGENT_ID
|
||||
USERSESSIONSEARCHKEY_USER_ID
|
||||
USERSESSIONSEARCHKEY_STATE
|
||||
USERSESSIONSEARCHKEY_RESOURCEOWNER
|
||||
)
|
||||
|
||||
type UserSessionSearchQuery struct {
|
||||
Key UserSessionSearchKey
|
||||
Method model.SearchMethod
|
||||
Value string
|
||||
}
|
||||
|
||||
type UserSessionSearchResponse struct {
|
||||
Offset uint64
|
||||
Limit uint64
|
||||
TotalResult uint64
|
||||
Result []*UserSessionView
|
||||
}
|
||||
|
||||
func (r *UserSessionSearchRequest) EnsureLimit(limit uint64) {
|
||||
if r.Limit == 0 || r.Limit > limit {
|
||||
r.Limit = limit
|
||||
}
|
||||
}
|
@@ -1,36 +1,42 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
"time"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
)
|
||||
|
||||
type UserView struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
ChangeDate time.Time
|
||||
State UserState
|
||||
ResourceOwner string
|
||||
PasswordChanged time.Time
|
||||
LastLogin time.Time
|
||||
UserName string
|
||||
FirstName string
|
||||
LastName string
|
||||
NickName string
|
||||
DisplayName string
|
||||
PreferredLanguage string
|
||||
Gender Gender
|
||||
Email string
|
||||
IsEmailVerified bool
|
||||
Phone string
|
||||
IsPhoneVerified bool
|
||||
Country string
|
||||
Locality string
|
||||
PostalCode string
|
||||
Region string
|
||||
StreetAddress string
|
||||
OTPState MfaState
|
||||
Sequence uint64
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
ChangeDate time.Time
|
||||
State UserState
|
||||
ResourceOwner string
|
||||
PasswordSet bool
|
||||
PasswordChangeRequired bool
|
||||
PasswordChanged time.Time
|
||||
LastLogin time.Time
|
||||
UserName string
|
||||
FirstName string
|
||||
LastName string
|
||||
NickName string
|
||||
DisplayName string
|
||||
PreferredLanguage string
|
||||
Gender Gender
|
||||
Email string
|
||||
IsEmailVerified bool
|
||||
Phone string
|
||||
IsPhoneVerified bool
|
||||
Country string
|
||||
Locality string
|
||||
PostalCode string
|
||||
Region string
|
||||
StreetAddress string
|
||||
OTPState MfaState
|
||||
MfaMaxSetUp req_model.MfaLevel
|
||||
MfaInitSkipped time.Time
|
||||
Sequence uint64
|
||||
}
|
||||
|
||||
type UserSearchRequest struct {
|
||||
@@ -78,3 +84,35 @@ func (r *UserSearchRequest) EnsureLimit(limit uint64) {
|
||||
func (r *UserSearchRequest) AppendMyOrgQuery(orgID string) {
|
||||
r.Queries = append(r.Queries, &UserSearchQuery{Key: USERSEARCHKEY_RESOURCEOWNER, Method: model.SEARCHMETHOD_EQUALS, Value: orgID})
|
||||
}
|
||||
|
||||
func (u *UserView) MfaTypesSetupPossible(level req_model.MfaLevel) []req_model.MfaType {
|
||||
types := make([]req_model.MfaType, 0)
|
||||
switch level {
|
||||
case req_model.MfaLevelSoftware:
|
||||
if u.OTPState != MFASTATE_READY {
|
||||
types = append(types, req_model.MfaTypeOTP)
|
||||
}
|
||||
//PLANNED: add sms
|
||||
fallthrough
|
||||
case req_model.MfaLevelHardware:
|
||||
//PLANNED: add token
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
||||
func (u *UserView) MfaTypesAllowed(level req_model.MfaLevel) []req_model.MfaType {
|
||||
types := make([]req_model.MfaType, 0)
|
||||
switch level {
|
||||
default:
|
||||
fallthrough
|
||||
case req_model.MfaLevelSoftware:
|
||||
if u.OTPState == MFASTATE_READY {
|
||||
types = append(types, req_model.MfaTypeOTP)
|
||||
}
|
||||
//PLANNED: add sms
|
||||
fallthrough
|
||||
case req_model.MfaLevelHardware:
|
||||
//PLANNED: add token
|
||||
}
|
||||
return types
|
||||
}
|
||||
|
@@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"strconv"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/sony/sonyflake"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/cache/config"
|
||||
sd "github.com/caos/zitadel/internal/config/systemdefaults"
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
@@ -14,8 +18,6 @@ import (
|
||||
global_model "github.com/caos/zitadel/internal/model"
|
||||
usr_model "github.com/caos/zitadel/internal/user/model"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/sony/sonyflake"
|
||||
)
|
||||
|
||||
type UserEventstore struct {
|
||||
@@ -28,6 +30,7 @@ type UserEventstore struct {
|
||||
PhoneVerificationCode crypto.Generator
|
||||
PasswordVerificationCode crypto.Generator
|
||||
Multifactors global_model.Multifactors
|
||||
validateTOTP func(string, string) bool
|
||||
}
|
||||
|
||||
type UserConfig struct {
|
||||
@@ -71,6 +74,7 @@ func StartUser(conf UserConfig, systemDefaults sd.SystemDefaults) (*UserEventsto
|
||||
PasswordVerificationCode: passwordVerificationCode,
|
||||
Multifactors: mfa,
|
||||
PasswordAlg: passwordAlg,
|
||||
validateTOTP: totp.Validate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -333,31 +337,80 @@ func (es *UserEventstore) UserPasswordByID(ctx context.Context, userID string) (
|
||||
return nil, caos_errs.ThrowNotFound(nil, "EVENT-d8e2", "password not found")
|
||||
}
|
||||
|
||||
func (es *UserEventstore) SetOneTimePassword(ctx context.Context, password *usr_model.Password) (*usr_model.Password, error) {
|
||||
return es.changedPassword(ctx, password, true)
|
||||
}
|
||||
|
||||
func (es *UserEventstore) SetPassword(ctx context.Context, password *usr_model.Password) (*usr_model.Password, error) {
|
||||
return es.changedPassword(ctx, password, false)
|
||||
}
|
||||
|
||||
func (es *UserEventstore) changedPassword(ctx context.Context, password *usr_model.Password, onetime bool) (*usr_model.Password, error) {
|
||||
if !password.IsValid() {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "EVENT-dosi3", "password invalid")
|
||||
func (es *UserEventstore) CheckPassword(ctx context.Context, userID, password string, authRequest *req_model.AuthRequest) error {
|
||||
existing, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing.Password == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "EVENT-s35Fa", "no password set")
|
||||
}
|
||||
if err := crypto.CompareHash(existing.Password.SecretCrypto, []byte(password), es.PasswordAlg); err == nil {
|
||||
return es.setPasswordCheckResult(ctx, existing, authRequest, PasswordCheckSucceededAggregate)
|
||||
}
|
||||
if err := es.setPasswordCheckResult(ctx, existing, authRequest, PasswordCheckFailedAggregate); err != nil {
|
||||
return err
|
||||
}
|
||||
return caos_errs.ThrowInvalidArgument(nil, "EVENT-452ad", "invalid password")
|
||||
}
|
||||
|
||||
func (es *UserEventstore) setPasswordCheckResult(ctx context.Context, user *usr_model.User, authRequest *req_model.AuthRequest, check func(*es_models.AggregateCreator, *model.User, *model.AuthRequest) es_sdk.AggregateFunc) error {
|
||||
repoUser := model.UserFromModel(user)
|
||||
repoAuthRequest := model.AuthRequestFromModel(authRequest)
|
||||
agg := check(es.AggregateCreator(), repoUser, repoAuthRequest)
|
||||
err := es_sdk.Push(ctx, es.PushAggregates, repoUser.AppendEvents, agg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
es.userCache.cacheUser(repoUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) SetOneTimePassword(ctx context.Context, password *usr_model.Password) (*usr_model.Password, error) {
|
||||
user, err := es.UserByID(ctx, password.AggregateID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return es.changedPassword(ctx, user, password.SecretString, true)
|
||||
}
|
||||
|
||||
err = password.HashPasswordIfExisting(es.PasswordAlg, onetime)
|
||||
func (es *UserEventstore) SetPassword(ctx context.Context, userID, code, password string) error {
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if user.PasswordCode == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "EVENT-65sdr", "reset code not found")
|
||||
}
|
||||
if err := crypto.VerifyCode(user.PasswordCode.CreationDate, user.PasswordCode.Expiry, user.PasswordCode.Code, code, es.PasswordVerificationCode); err != nil {
|
||||
return caos_errs.ThrowPreconditionFailed(err, "EVENT-sd6DF", "code invalid")
|
||||
}
|
||||
_, err = es.changedPassword(ctx, user, password, false)
|
||||
return err
|
||||
}
|
||||
|
||||
func (es *UserEventstore) ChangePassword(ctx context.Context, userID, old, new string) (*usr_model.Password, error) {
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if user.Password == nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "EVENT-Fds3s", "user has no password")
|
||||
}
|
||||
if err := crypto.CompareHash(user.Password.SecretCrypto, []byte(old), es.PasswordAlg); err != nil {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "EVENT-s56a3", "invalid password")
|
||||
}
|
||||
return es.changedPassword(ctx, user, new, false)
|
||||
}
|
||||
|
||||
func (es *UserEventstore) changedPassword(ctx context.Context, user *usr_model.User, password string, onetime bool) (*usr_model.Password, error) {
|
||||
//TODO: check password policy
|
||||
secret, err := crypto.Hash([]byte(password), es.PasswordAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
repoPassword := &model.Password{Secret: secret, ChangeRequired: onetime}
|
||||
repoUser := model.UserFromModel(user)
|
||||
repoPassword := model.PasswordFromModel(password)
|
||||
|
||||
agg := PasswordChangeAggregate(es.AggregateCreator(), repoUser, repoPassword)
|
||||
err = es_sdk.Push(ctx, es.PushAggregates, repoUser.AppendEvents, agg)
|
||||
if err != nil {
|
||||
@@ -666,21 +719,6 @@ func (es *UserEventstore) ChangeAddress(ctx context.Context, address *usr_model.
|
||||
return model.AddressToModel(repoExisting.Address), nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) OTPByID(ctx context.Context, userID string) (*usr_model.OTP, error) {
|
||||
if userID == "" {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "EVENT-do9se", "userID missing")
|
||||
}
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.OTP != nil {
|
||||
return user.OTP, nil
|
||||
}
|
||||
return nil, caos_errs.ThrowNotFound(nil, "EVENT-dps09", "otp not found")
|
||||
}
|
||||
|
||||
func (es *UserEventstore) AddOTP(ctx context.Context, userID string) (*usr_model.OTP, error) {
|
||||
existing, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -731,22 +769,77 @@ func (es *UserEventstore) RemoveOTP(ctx context.Context, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) CheckMfaOTP(ctx context.Context, userID, code string) error {
|
||||
existing, err := es.UserByID(ctx, userID)
|
||||
func (es *UserEventstore) CheckMfaOTPSetup(ctx context.Context, userID, code string) error {
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing.OTP == nil {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "EVENT-sp0de", "no otp existing")
|
||||
if user.OTP == nil || user.IsOTPReady() {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "EVENT-sd5NJ", "otp not existing or already set up")
|
||||
}
|
||||
decrypt, err := crypto.DecryptString(existing.OTP.Secret, es.Multifactors.OTP.CryptoMFA)
|
||||
if err := es.verifyMfaOTP(user.OTP, code); err != nil {
|
||||
return err
|
||||
}
|
||||
repoUser := model.UserFromModel(user)
|
||||
err = es_sdk.Push(ctx, es.PushAggregates, repoUser.AppendEvents, MfaOTPVerifyAggregate(es.AggregateCreator(), repoUser))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valid := totp.Validate(code, decrypt)
|
||||
es.userCache.cacheUser(repoUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) CheckMfaOTP(ctx context.Context, userID, code string, authRequest *req_model.AuthRequest) error {
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !user.IsOTPReady() {
|
||||
return caos_errs.ThrowPreconditionFailed(nil, "EVENT-sd5NJ", "opt not ready")
|
||||
}
|
||||
|
||||
repoUser := model.UserFromModel(user)
|
||||
repoAuthReq := model.AuthRequestFromModel(authRequest)
|
||||
var aggregate func(*es_models.AggregateCreator, *model.User, *model.AuthRequest) es_sdk.AggregateFunc
|
||||
if err := es.verifyMfaOTP(user.OTP, code); err != nil {
|
||||
aggregate = MfaOTPCheckFailedAggregate
|
||||
} else {
|
||||
aggregate = MfaOTPCheckSucceededAggregate
|
||||
}
|
||||
err = es_sdk.Push(ctx, es.PushAggregates, repoUser.AppendEvents, aggregate(es.AggregateCreator(), repoUser, repoAuthReq))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
es.userCache.cacheUser(repoUser)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) verifyMfaOTP(otp *usr_model.OTP, code string) error {
|
||||
decrypt, err := crypto.DecryptString(otp.Secret, es.Multifactors.OTP.CryptoMFA)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valid := es.validateTOTP(code, decrypt)
|
||||
if !valid {
|
||||
return caos_errs.ThrowInvalidArgument(nil, "EVENT-8isk2", "Invalid code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *UserEventstore) SignOut(ctx context.Context, agentID, userID string) error {
|
||||
user, err := es.UserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
repoUser := model.UserFromModel(user)
|
||||
err = es_sdk.Push(ctx, es.PushAggregates, repoUser.AppendEvents, SignOutAggregate(es.AggregateCreator(), repoUser, agentID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
es.userCache.cacheUser(repoUser)
|
||||
return nil
|
||||
}
|
||||
|
@@ -2,15 +2,17 @@ package eventsourcing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sony/sonyflake"
|
||||
|
||||
mock_cache "github.com/caos/zitadel/internal/cache/mock"
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
"github.com/caos/zitadel/internal/eventstore/mock"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
global_model "github.com/caos/zitadel/internal/model"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/sony/sonyflake"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GetMockedEventstore(ctrl *gomock.Controller, mockEs *mock.MockEventstore) *UserEventstore {
|
||||
@@ -38,10 +40,7 @@ func GetMockedEventstoreWithPw(ctrl *gomock.Controller, mockEs *mock.MockEventst
|
||||
}
|
||||
if password {
|
||||
es.PasswordVerificationCode = GetMockPwGenerator(ctrl)
|
||||
hash := crypto.NewMockHashAlgorithm(ctrl)
|
||||
hash.EXPECT().Hash(gomock.Any()).Return(nil, nil)
|
||||
hash.EXPECT().Algorithm().Return("bcrypt")
|
||||
es.PasswordAlg = hash
|
||||
es.PasswordAlg = crypto.CreateMockHashAlg(ctrl)
|
||||
}
|
||||
return es
|
||||
}
|
||||
@@ -174,8 +173,10 @@ func GetMockManipulateUserWithPhoneCodeGen(ctrl *gomock.Controller, user model.U
|
||||
|
||||
func GetMockManipulateUserWithPasswordCodeGen(ctrl *gomock.Controller, user model.User) *UserEventstore {
|
||||
data, _ := json.Marshal(user)
|
||||
code, _ := json.Marshal(user.PasswordCode)
|
||||
events := []*es_models.Event{
|
||||
&es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: model.UserAdded, Data: data},
|
||||
&es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: model.UserPasswordCodeAdded, Data: code},
|
||||
}
|
||||
mockEs := mock.NewMockEventstore(ctrl)
|
||||
mockEs.EXPECT().FilterEvents(gomock.Any(), gomock.Any()).Return(events, nil)
|
||||
@@ -394,29 +395,48 @@ func GetMockManipulateUserFull(ctrl *gomock.Controller) *UserEventstore {
|
||||
return GetMockedEventstore(ctrl, mockEs)
|
||||
}
|
||||
|
||||
func GetMockManipulateUserWithOTP(ctrl *gomock.Controller) *UserEventstore {
|
||||
func GetMockManipulateUserWithOTP(ctrl *gomock.Controller, decrypt, verified bool) *UserEventstore {
|
||||
user := model.User{
|
||||
Profile: &model.Profile{
|
||||
UserName: "UserName",
|
||||
},
|
||||
}
|
||||
otp := model.OTP{Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
}}
|
||||
otp := model.OTP{
|
||||
Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
},
|
||||
}
|
||||
dataUser, _ := json.Marshal(user)
|
||||
dataOtp, _ := json.Marshal(otp)
|
||||
events := []*es_models.Event{
|
||||
&es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: model.UserAdded, Data: dataUser},
|
||||
&es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: model.MfaOtpAdded, Data: dataOtp},
|
||||
}
|
||||
if verified {
|
||||
events = append(events, &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: model.MfaOtpVerified})
|
||||
}
|
||||
mockEs := mock.NewMockEventstore(ctrl)
|
||||
mockEs.EXPECT().FilterEvents(gomock.Any(), gomock.Any()).Return(events, nil)
|
||||
mockEs.EXPECT().AggregateCreator().Return(es_models.NewAggregateCreator("TEST"))
|
||||
mockEs.EXPECT().PushAggregates(gomock.Any(), gomock.Any()).Return(nil)
|
||||
return GetMockedEventstore(ctrl, mockEs)
|
||||
es := GetMockedEventstore(ctrl, mockEs)
|
||||
if !decrypt {
|
||||
return es
|
||||
}
|
||||
enc := crypto.NewMockEncryptionAlgorithm(ctrl)
|
||||
enc.EXPECT().Algorithm().Return("enc")
|
||||
enc.EXPECT().Encrypt(gomock.Any()).Return(nil, nil)
|
||||
enc.EXPECT().EncryptionKeyID().Return("id")
|
||||
enc.EXPECT().DecryptionKeyIDs().Return([]string{"id"})
|
||||
enc.EXPECT().DecryptString(gomock.Any(), gomock.Any()).Return("code", nil)
|
||||
es.Multifactors = global_model.Multifactors{OTP: global_model.OTP{
|
||||
Issuer: "Issuer",
|
||||
CryptoMFA: enc,
|
||||
}}
|
||||
return es
|
||||
}
|
||||
|
||||
func GetMockManipulateUserNoEvents(ctrl *gomock.Controller) *UserEventstore {
|
||||
|
@@ -2,14 +2,19 @@ package eventsourcing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/caos/zitadel/internal/api/auth"
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
repo_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"github.com/golang/mock/gomock"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserByID(t *testing.T) {
|
||||
@@ -1025,16 +1030,17 @@ func TestSetOneTimePassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPassword(t *testing.T) {
|
||||
func TestCheckPassword(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
password *model.Password
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
password string
|
||||
authRequest *req_model.AuthRequest
|
||||
}
|
||||
type res struct {
|
||||
password *model.Password
|
||||
errFunc func(err error) bool
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -1042,22 +1048,40 @@ func TestSetPassword(t *testing.T) {
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "create pw",
|
||||
name: "check pw ok",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl, repo_model.User{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"}}),
|
||||
es: GetMockManipulateUserWithPasswordAndEmailCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
Password: &repo_model.Password{Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeHash,
|
||||
Algorithm: "hash",
|
||||
Crypted: []byte("password"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
password: &model.Password{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"}, SecretString: "Password"},
|
||||
},
|
||||
res: res{
|
||||
password: &model.Password{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID", Sequence: 1}, ChangeRequired: false},
|
||||
userID: "userID",
|
||||
password: "password",
|
||||
authRequest: &req_model.AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &req_model.BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept langugage",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
password: &model.Password{ObjectRoot: es_models.ObjectRoot{AggregateID: ""}, SecretString: "Password"},
|
||||
userID: "",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
@@ -1068,25 +1092,311 @@ func TestSetPassword(t *testing.T) {
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
password: &model.Password{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"}, SecretString: "Password"},
|
||||
userID: "userID",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no password",
|
||||
args: args{
|
||||
es: GetMockUserByIDOK(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
Password: &repo_model.Password{Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeHash,
|
||||
Algorithm: "hash",
|
||||
Crypted: []byte("password"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
password: "wrong password",
|
||||
authRequest: &req_model.AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &req_model.BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept langugage",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := tt.args.es.SetPassword(tt.args.ctx, tt.args.password)
|
||||
err := tt.args.es.CheckPassword(tt.args.ctx, tt.args.userID, tt.args.password, tt.args.authRequest)
|
||||
|
||||
if tt.res.errFunc == nil && err != nil {
|
||||
t.Errorf("result has error: %v", err)
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPassword(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
code string
|
||||
password string
|
||||
}
|
||||
type res struct {
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "create pw",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
PasswordCode: &repo_model.PasswordCode{Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
password: "password",
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "",
|
||||
code: "code",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no passcode",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid passcode",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
PasswordCode: &repo_model.PasswordCode{Code: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc2",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("code2"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
password: "password",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.es.SetPassword(tt.args.ctx, tt.args.userID, tt.args.code, tt.args.password)
|
||||
|
||||
if tt.res.errFunc == nil && err != nil {
|
||||
t.Errorf("result has error: %v", err)
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePassword(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
old string
|
||||
new string
|
||||
}
|
||||
type res struct {
|
||||
password string
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "change pw",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordAndEmailCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
Password: &repo_model.Password{Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeHash,
|
||||
Algorithm: "hash",
|
||||
Crypted: []byte("old"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
old: "old",
|
||||
new: "new",
|
||||
},
|
||||
res: res{
|
||||
password: "new",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "",
|
||||
old: "old",
|
||||
new: "new",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
old: "old",
|
||||
new: "new",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no password",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
old: "old",
|
||||
new: "new",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid password",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithPasswordCodeGen(ctrl,
|
||||
repo_model.User{
|
||||
ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID"},
|
||||
Password: &repo_model.Password{Secret: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeHash,
|
||||
Algorithm: "hash",
|
||||
Crypted: []byte("older"),
|
||||
}},
|
||||
},
|
||||
),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
old: "old",
|
||||
new: "new",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := tt.args.es.ChangePassword(tt.args.ctx, tt.args.userID, tt.args.old, tt.args.new)
|
||||
|
||||
if tt.res.errFunc == nil && result.AggregateID == "" {
|
||||
t.Errorf("result has no id")
|
||||
}
|
||||
if tt.res.errFunc == nil && result.ChangeRequired != false {
|
||||
t.Errorf("should not be one time")
|
||||
if tt.res.errFunc == nil && string(result.SecretCrypto.Crypted) != tt.res.password {
|
||||
t.Errorf("got wrong result crypted: expected: %v, actual: %v ", tt.res.password, result.SecretString)
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
t.Errorf("got wrong err: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2036,69 +2346,6 @@ func TestChangeAddress(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOTPByID(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
existing *model.User
|
||||
}
|
||||
type res struct {
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "get by id, ok",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithOTP(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
existing: &model.User{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID", Sequence: 1}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
existing: &model.User{ObjectRoot: es_models.ObjectRoot{AggregateID: "", Sequence: 1}},
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
existing: &model.User{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID", Sequence: 1}},
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := tt.args.es.OTPByID(tt.args.ctx, tt.args.existing.AggregateID)
|
||||
|
||||
if tt.res.errFunc == nil && result.AggregateID == "" {
|
||||
t.Errorf("result has no id")
|
||||
}
|
||||
if tt.res.errFunc == nil && result == nil {
|
||||
t.Errorf("got wrong result change required: actual: %v ", result)
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddOTP(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
@@ -2168,6 +2415,245 @@ func TestAddOTP(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMfaOTPSetup(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "setup ok",
|
||||
args: args{
|
||||
es: func() *UserEventstore {
|
||||
es := GetMockManipulateUserWithOTP(ctrl, true, false)
|
||||
es.validateTOTP = func(string, string) bool {
|
||||
return true
|
||||
}
|
||||
return es
|
||||
}(),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "id",
|
||||
code: "code",
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "wrong code",
|
||||
args: args{
|
||||
es: func() *UserEventstore {
|
||||
es := GetMockManipulateUserWithOTP(ctrl, true, false)
|
||||
es.validateTOTP = func(string, string) bool {
|
||||
return false
|
||||
}
|
||||
return es
|
||||
}(),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "id",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty code",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user has no otp",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.es.CheckMfaOTPSetup(tt.args.ctx, tt.args.userID, tt.args.code)
|
||||
|
||||
if tt.res.errFunc == nil && err != nil {
|
||||
t.Errorf("result should not get err")
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckMfaOTP(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
code string
|
||||
authRequest *req_model.AuthRequest
|
||||
}
|
||||
type res struct {
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "check ok",
|
||||
args: args{
|
||||
es: func() *UserEventstore {
|
||||
es := GetMockManipulateUserWithOTP(ctrl, true, true)
|
||||
es.validateTOTP = func(string, string) bool {
|
||||
return true
|
||||
}
|
||||
return es
|
||||
}(),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "id",
|
||||
code: "code",
|
||||
authRequest: &req_model.AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &req_model.BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept langugage",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "wrong code",
|
||||
args: args{
|
||||
es: func() *UserEventstore {
|
||||
es := GetMockManipulateUserWithOTP(ctrl, true, true)
|
||||
es.validateTOTP = func(string, string) bool {
|
||||
return false
|
||||
}
|
||||
return es
|
||||
}(),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "id",
|
||||
code: "code",
|
||||
authRequest: &req_model.AuthRequest{
|
||||
ID: "id",
|
||||
AgentID: "agentID",
|
||||
BrowserInfo: &req_model.BrowserInfo{
|
||||
UserAgent: "user agent",
|
||||
AcceptLanguage: "accept langugage",
|
||||
RemoteIP: net.IPv4(29, 4, 20, 19),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{},
|
||||
},
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty code",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user has no otp",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.es.CheckMfaOTP(tt.args.ctx, tt.args.userID, tt.args.code, tt.args.authRequest)
|
||||
|
||||
if tt.res.errFunc == nil && err != nil {
|
||||
t.Errorf("result should not get err, got : %v", err)
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveOTP(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
@@ -2186,7 +2672,7 @@ func TestRemoveOTP(t *testing.T) {
|
||||
{
|
||||
name: "remove ok",
|
||||
args: args{
|
||||
es: GetMockManipulateUserWithOTP(ctrl),
|
||||
es: GetMockManipulateUserWithOTP(ctrl, false, true),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
existing: &model.User{ObjectRoot: es_models.ObjectRoot{AggregateID: "AggregateID", Sequence: 1}},
|
||||
},
|
||||
@@ -2238,80 +2724,3 @@ func TestRemoveOTP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckOTP(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
type args struct {
|
||||
es *UserEventstore
|
||||
ctx context.Context
|
||||
userID string
|
||||
code string
|
||||
}
|
||||
type res struct {
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "empty userid",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty code",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing user not found",
|
||||
args: args{
|
||||
es: GetMockManipulateUserNoEvents(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user has no otp",
|
||||
args: args{
|
||||
es: GetMockManipulateUser(ctrl),
|
||||
ctx: auth.NewMockContext("orgID", "userID"),
|
||||
userID: "userID",
|
||||
code: "code",
|
||||
},
|
||||
res: res{
|
||||
errFunc: caos_errs.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.args.es.CheckMfaOTP(tt.args.ctx, tt.args.userID, tt.args.code)
|
||||
|
||||
if tt.res.errFunc == nil && err != nil {
|
||||
t.Errorf("result should not get err")
|
||||
}
|
||||
if tt.res.errFunc != nil && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
35
internal/user/repository/eventsourcing/model/auth_request.go
Normal file
35
internal/user/repository/eventsourcing/model/auth_request.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/caos/zitadel/internal/auth_request/model"
|
||||
)
|
||||
|
||||
type AuthRequest struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
UserAgentID string `json:"userAgentID,omitempty"`
|
||||
*BrowserInfo
|
||||
}
|
||||
|
||||
func AuthRequestFromModel(request *model.AuthRequest) *AuthRequest {
|
||||
return &AuthRequest{
|
||||
ID: request.ID,
|
||||
UserAgentID: request.AgentID,
|
||||
BrowserInfo: BrowserInfoFromModel(request.BrowserInfo),
|
||||
}
|
||||
}
|
||||
|
||||
type BrowserInfo struct {
|
||||
UserAgent string `json:"userAgent,omitempty"`
|
||||
AcceptLanguage string `json:"acceptLanguage,omitempty"`
|
||||
RemoteIP net.IP `json:"remoteIP,omitempty"`
|
||||
}
|
||||
|
||||
func BrowserInfoFromModel(info *model.BrowserInfo) *BrowserInfo {
|
||||
return &BrowserInfo{
|
||||
UserAgent: info.UserAgent,
|
||||
AcceptLanguage: info.AcceptLanguage,
|
||||
RemoteIP: info.RemoteIP,
|
||||
}
|
||||
}
|
@@ -23,9 +23,11 @@ const (
|
||||
UserReactivated models.EventType = "user.reactivated"
|
||||
UserDeleted models.EventType = "user.deleted"
|
||||
|
||||
UserPasswordChanged models.EventType = "user.password.changed"
|
||||
UserPasswordCodeAdded models.EventType = "user.password.code.added"
|
||||
UserPasswordCodeSent models.EventType = "user.password.code.sent"
|
||||
UserPasswordChanged models.EventType = "user.password.changed"
|
||||
UserPasswordCodeAdded models.EventType = "user.password.code.added"
|
||||
UserPasswordCodeSent models.EventType = "user.password.code.sent"
|
||||
UserPasswordCheckSucceeded models.EventType = "user.password.check.succeeded"
|
||||
UserPasswordCheckFailed models.EventType = "user.password.check.failed"
|
||||
|
||||
UserEmailChanged models.EventType = "user.email.changed"
|
||||
UserEmailVerified models.EventType = "user.email.verified"
|
||||
@@ -40,8 +42,12 @@ const (
|
||||
UserProfileChanged models.EventType = "user.profile.changed"
|
||||
UserAddressChanged models.EventType = "user.address.changed"
|
||||
|
||||
MfaOtpAdded models.EventType = "user.mfa.otp.added"
|
||||
MfaOtpVerified models.EventType = "user.mfa.otp.verified"
|
||||
MfaOtpRemoved models.EventType = "user.mfa.otp.removed"
|
||||
MfaInitSkipped models.EventType = "user.mfa.init.skipped"
|
||||
MfaOtpAdded models.EventType = "user.mfa.otp.added"
|
||||
MfaOtpVerified models.EventType = "user.mfa.otp.verified"
|
||||
MfaOtpRemoved models.EventType = "user.mfa.otp.removed"
|
||||
MfaOtpCheckSucceeded models.EventType = "user.mfa.otp.check.succeeded"
|
||||
MfaOtpCheckFailed models.EventType = "user.mfa.otp.check.failed"
|
||||
MfaInitSkipped models.EventType = "user.mfa.init.skipped"
|
||||
|
||||
SignedOut models.EventType = "user.signed.out"
|
||||
)
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/caos/zitadel/internal/errors"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
es_sdk "github.com/caos/zitadel/internal/eventstore/sdk"
|
||||
"github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
)
|
||||
|
||||
@@ -167,6 +168,25 @@ func PasswordChangeAggregate(aggCreator *es_models.AggregateCreator, existing *m
|
||||
}
|
||||
}
|
||||
|
||||
func PasswordCheckSucceededAggregate(aggCreator *es_models.AggregateCreator, existing *model.User, check *model.AuthRequest) es_sdk.AggregateFunc {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agg.AppendEvent(model.UserPasswordCheckSucceeded, check)
|
||||
}
|
||||
}
|
||||
func PasswordCheckFailedAggregate(aggCreator *es_models.AggregateCreator, existing *model.User, check *model.AuthRequest) es_sdk.AggregateFunc {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agg.AppendEvent(model.UserPasswordCheckFailed, check)
|
||||
}
|
||||
}
|
||||
|
||||
func RequestSetPassword(aggCreator *es_models.AggregateCreator, existing *model.User, request *model.PasswordCode) func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
if request == nil {
|
||||
@@ -338,6 +358,32 @@ func MfaOTPVerifyAggregate(aggCreator *es_models.AggregateCreator, existing *mod
|
||||
}
|
||||
}
|
||||
|
||||
func MfaOTPCheckSucceededAggregate(aggCreator *es_models.AggregateCreator, existing *model.User, authReq *model.AuthRequest) es_sdk.AggregateFunc {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
if authReq == nil {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "EVENT-sd5DA", "authReq must not be nil")
|
||||
}
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agg.AppendEvent(model.MfaOtpCheckSucceeded, authReq)
|
||||
}
|
||||
}
|
||||
|
||||
func MfaOTPCheckFailedAggregate(aggCreator *es_models.AggregateCreator, existing *model.User, authReq *model.AuthRequest) es_sdk.AggregateFunc {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
if authReq == nil {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "EVENT-64sd6", "authReq must not be nil")
|
||||
}
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agg.AppendEvent(model.MfaOtpCheckFailed, authReq)
|
||||
}
|
||||
}
|
||||
|
||||
func MfaOTPRemoveAggregate(aggCreator *es_models.AggregateCreator, existing *model.User) func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
@@ -347,3 +393,13 @@ func MfaOTPRemoveAggregate(aggCreator *es_models.AggregateCreator, existing *mod
|
||||
return agg.AppendEvent(model.MfaOtpRemoved, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func SignOutAggregate(aggCreator *es_models.AggregateCreator, existing *model.User, agentID string) func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
return func(ctx context.Context) (*es_models.Aggregate, error) {
|
||||
agg, err := UserAggregate(ctx, aggCreator, existing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return agg.AppendEvent(model.SignedOut, map[string]interface{}{"agentID": agentID})
|
||||
}
|
||||
}
|
||||
|
@@ -2,12 +2,15 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -23,90 +26,102 @@ const (
|
||||
)
|
||||
|
||||
type UserView struct {
|
||||
ID string `json:"-" gorm:"column:id;primary_key"`
|
||||
CreationDate time.Time `json:"-" gorm:"column:creation_date"`
|
||||
ChangeDate time.Time `json:"-" gorm:"column:change_date"`
|
||||
ResourceOwner string `json:"-" gorm:"column:resource_owner"`
|
||||
State int32 `json:"-" gorm:"column:user_state"`
|
||||
PasswordChanged time.Time `json:"-" gorm:"column:password_change"`
|
||||
LastLogin time.Time `json:"-" gorm:"column:last_login"`
|
||||
UserName string `json:"userName" gorm:"column:user_name"`
|
||||
FirstName string `json:"firstName" gorm:"column:first_name"`
|
||||
LastName string `json:"lastName" gorm:"column:last_name"`
|
||||
NickName string `json:"nickName" gorm:"column:nick_name"`
|
||||
DisplayName string `json:"displayName" gorm:"column:display_name"`
|
||||
PreferredLanguage string `json:"preferredLanguage" gorm:"column:preferred_language"`
|
||||
Gender int32 `json:"gender" gorm:"column:gender"`
|
||||
Email string `json:"email" gorm:"column:email"`
|
||||
IsEmailVerified bool `json:"-" gorm:"column:is_email_verified"`
|
||||
Phone string `json:"phone" gorm:"column:phone"`
|
||||
IsPhoneVerified bool `json:"-" gorm:"column:is_phone_verified"`
|
||||
Country string `json:"country" gorm:"column:country"`
|
||||
Locality string `json:"locality" gorm:"column:locality"`
|
||||
PostalCode string `json:"postalCode" gorm:"column:postal_code"`
|
||||
Region string `json:"region" gorm:"column:region"`
|
||||
StreetAddress string `json:"streetAddress" gorm:"column:street_address"`
|
||||
OTPState int32 `json:"-" gorm:"column:otp_state"`
|
||||
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
||||
ID string `json:"-" gorm:"column:id;primary_key"`
|
||||
CreationDate time.Time `json:"-" gorm:"column:creation_date"`
|
||||
ChangeDate time.Time `json:"-" gorm:"column:change_date"`
|
||||
ResourceOwner string `json:"-" gorm:"column:resource_owner"`
|
||||
State int32 `json:"-" gorm:"column:user_state"`
|
||||
PasswordSet bool `json:"-" gorm:"column:password_set"`
|
||||
PasswordChangeRequired bool `json:"-" gorm:"column:password_change_required"`
|
||||
PasswordChanged time.Time `json:"-" gorm:"column:password_change"`
|
||||
LastLogin time.Time `json:"-" gorm:"column:last_login"`
|
||||
UserName string `json:"userName" gorm:"column:user_name"`
|
||||
FirstName string `json:"firstName" gorm:"column:first_name"`
|
||||
LastName string `json:"lastName" gorm:"column:last_name"`
|
||||
NickName string `json:"nickName" gorm:"column:nick_name"`
|
||||
DisplayName string `json:"displayName" gorm:"column:display_name"`
|
||||
PreferredLanguage string `json:"preferredLanguage" gorm:"column:preferred_language"`
|
||||
Gender int32 `json:"gender" gorm:"column:gender"`
|
||||
Email string `json:"email" gorm:"column:email"`
|
||||
IsEmailVerified bool `json:"-" gorm:"column:is_email_verified"`
|
||||
Phone string `json:"phone" gorm:"column:phone"`
|
||||
IsPhoneVerified bool `json:"-" gorm:"column:is_phone_verified"`
|
||||
Country string `json:"country" gorm:"column:country"`
|
||||
Locality string `json:"locality" gorm:"column:locality"`
|
||||
PostalCode string `json:"postalCode" gorm:"column:postal_code"`
|
||||
Region string `json:"region" gorm:"column:region"`
|
||||
StreetAddress string `json:"streetAddress" gorm:"column:street_address"`
|
||||
OTPState int32 `json:"-" gorm:"column:otp_state"`
|
||||
MfaMaxSetUp int32 `json:"-" gorm:"column:mfa_max_set_up"`
|
||||
MfaInitSkipped time.Time `json:"-" gorm:"column:mfa_init_skipped"`
|
||||
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
||||
}
|
||||
|
||||
func UserFromModel(user *model.UserView) *UserView {
|
||||
return &UserView{
|
||||
ID: user.ID,
|
||||
ChangeDate: user.ChangeDate,
|
||||
CreationDate: user.CreationDate,
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
State: int32(user.State),
|
||||
PasswordChanged: user.PasswordChanged,
|
||||
LastLogin: user.LastLogin,
|
||||
UserName: user.UserName,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
NickName: user.NickName,
|
||||
DisplayName: user.DisplayName,
|
||||
PreferredLanguage: user.PreferredLanguage,
|
||||
Gender: int32(user.Gender),
|
||||
Email: user.Email,
|
||||
IsEmailVerified: user.IsEmailVerified,
|
||||
Phone: user.Phone,
|
||||
IsPhoneVerified: user.IsPhoneVerified,
|
||||
Country: user.Country,
|
||||
Locality: user.Locality,
|
||||
PostalCode: user.PostalCode,
|
||||
Region: user.Region,
|
||||
StreetAddress: user.StreetAddress,
|
||||
OTPState: int32(user.OTPState),
|
||||
Sequence: user.Sequence,
|
||||
ID: user.ID,
|
||||
ChangeDate: user.ChangeDate,
|
||||
CreationDate: user.CreationDate,
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
State: int32(user.State),
|
||||
PasswordSet: user.PasswordSet,
|
||||
PasswordChangeRequired: user.PasswordChangeRequired,
|
||||
PasswordChanged: user.PasswordChanged,
|
||||
LastLogin: user.LastLogin,
|
||||
UserName: user.UserName,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
NickName: user.NickName,
|
||||
DisplayName: user.DisplayName,
|
||||
PreferredLanguage: user.PreferredLanguage,
|
||||
Gender: int32(user.Gender),
|
||||
Email: user.Email,
|
||||
IsEmailVerified: user.IsEmailVerified,
|
||||
Phone: user.Phone,
|
||||
IsPhoneVerified: user.IsPhoneVerified,
|
||||
Country: user.Country,
|
||||
Locality: user.Locality,
|
||||
PostalCode: user.PostalCode,
|
||||
Region: user.Region,
|
||||
StreetAddress: user.StreetAddress,
|
||||
OTPState: int32(user.OTPState),
|
||||
MfaMaxSetUp: int32(user.MfaMaxSetUp),
|
||||
MfaInitSkipped: user.MfaInitSkipped,
|
||||
Sequence: user.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func UserToModel(user *UserView) *model.UserView {
|
||||
return &model.UserView{
|
||||
ID: user.ID,
|
||||
ChangeDate: user.ChangeDate,
|
||||
CreationDate: user.CreationDate,
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
State: model.UserState(user.State),
|
||||
PasswordChanged: user.PasswordChanged,
|
||||
LastLogin: user.LastLogin,
|
||||
UserName: user.UserName,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
NickName: user.NickName,
|
||||
DisplayName: user.DisplayName,
|
||||
PreferredLanguage: user.PreferredLanguage,
|
||||
Gender: model.Gender(user.Gender),
|
||||
Email: user.Email,
|
||||
IsEmailVerified: user.IsEmailVerified,
|
||||
Phone: user.Phone,
|
||||
IsPhoneVerified: user.IsPhoneVerified,
|
||||
Country: user.Country,
|
||||
Locality: user.Locality,
|
||||
PostalCode: user.PostalCode,
|
||||
Region: user.Region,
|
||||
StreetAddress: user.StreetAddress,
|
||||
OTPState: model.MfaState(user.OTPState),
|
||||
Sequence: user.Sequence,
|
||||
ID: user.ID,
|
||||
ChangeDate: user.ChangeDate,
|
||||
CreationDate: user.CreationDate,
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
State: model.UserState(user.State),
|
||||
PasswordSet: user.PasswordSet,
|
||||
PasswordChangeRequired: user.PasswordChangeRequired,
|
||||
PasswordChanged: user.PasswordChanged,
|
||||
LastLogin: user.LastLogin,
|
||||
UserName: user.UserName,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
NickName: user.NickName,
|
||||
DisplayName: user.DisplayName,
|
||||
PreferredLanguage: user.PreferredLanguage,
|
||||
Gender: model.Gender(user.Gender),
|
||||
Email: user.Email,
|
||||
IsEmailVerified: user.IsEmailVerified,
|
||||
Phone: user.Phone,
|
||||
IsPhoneVerified: user.IsPhoneVerified,
|
||||
Country: user.Country,
|
||||
Locality: user.Locality,
|
||||
PostalCode: user.PostalCode,
|
||||
Region: user.Region,
|
||||
StreetAddress: user.StreetAddress,
|
||||
OTPState: model.MfaState(user.OTPState),
|
||||
MfaMaxSetUp: req_model.MfaLevel(user.MfaMaxSetUp),
|
||||
MfaInitSkipped: user.MfaInitSkipped,
|
||||
Sequence: user.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,43 +133,52 @@ func UsersToModel(users []*UserView) []*model.UserView {
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *UserView) AppendEvent(event *models.Event) (err error) {
|
||||
p.ChangeDate = event.CreationDate
|
||||
p.Sequence = event.Sequence
|
||||
func (u *UserView) AppendEvent(event *models.Event) (err error) {
|
||||
u.ChangeDate = event.CreationDate
|
||||
u.Sequence = event.Sequence
|
||||
switch event.Type {
|
||||
case es_model.UserAdded,
|
||||
es_model.UserRegistered:
|
||||
p.CreationDate = event.CreationDate
|
||||
p.setRootData(event)
|
||||
err = p.setData(event)
|
||||
u.CreationDate = event.CreationDate
|
||||
u.setRootData(event)
|
||||
err = u.setData(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = u.setPasswordData(event)
|
||||
case es_model.UserPasswordChanged:
|
||||
err = u.setPasswordData(event)
|
||||
case es_model.UserProfileChanged,
|
||||
es_model.UserAddressChanged:
|
||||
err = p.setData(event)
|
||||
err = u.setData(event)
|
||||
case es_model.UserEmailChanged:
|
||||
p.IsEmailVerified = false
|
||||
err = p.setData(event)
|
||||
u.IsEmailVerified = false
|
||||
err = u.setData(event)
|
||||
case es_model.UserEmailVerified:
|
||||
p.IsEmailVerified = true
|
||||
u.IsEmailVerified = true
|
||||
case es_model.UserPhoneChanged:
|
||||
p.IsPhoneVerified = false
|
||||
err = p.setData(event)
|
||||
u.IsPhoneVerified = false
|
||||
err = u.setData(event)
|
||||
case es_model.UserPhoneVerified:
|
||||
p.IsPhoneVerified = true
|
||||
u.IsPhoneVerified = true
|
||||
case es_model.UserDeactivated:
|
||||
p.State = int32(model.USERSTATE_INACTIVE)
|
||||
u.State = int32(model.USERSTATE_INACTIVE)
|
||||
case es_model.UserReactivated,
|
||||
es_model.UserUnlocked:
|
||||
p.State = int32(model.USERSTATE_ACTIVE)
|
||||
u.State = int32(model.USERSTATE_ACTIVE)
|
||||
case es_model.UserLocked:
|
||||
p.State = int32(model.USERSTATE_LOCKED)
|
||||
u.State = int32(model.USERSTATE_LOCKED)
|
||||
case es_model.MfaOtpAdded:
|
||||
p.OTPState = int32(model.MFASTATE_NOTREADY)
|
||||
u.OTPState = int32(model.MFASTATE_NOTREADY)
|
||||
case es_model.MfaOtpVerified:
|
||||
p.OTPState = int32(model.MFASTATE_READY)
|
||||
u.OTPState = int32(model.MFASTATE_READY)
|
||||
u.MfaInitSkipped = time.Time{}
|
||||
case es_model.MfaOtpRemoved:
|
||||
p.OTPState = int32(model.MFASTATE_UNSPECIFIED)
|
||||
u.OTPState = int32(model.MFASTATE_UNSPECIFIED)
|
||||
case es_model.MfaInitSkipped:
|
||||
u.MfaInitSkipped = event.CreationDate
|
||||
}
|
||||
p.ComputeObject()
|
||||
u.ComputeObject()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -165,12 +189,23 @@ func (u *UserView) setRootData(event *models.Event) {
|
||||
|
||||
func (u *UserView) setData(event *models.Event) error {
|
||||
if err := json.Unmarshal(event.Data, u); err != nil {
|
||||
logging.Log("EVEN-lso9e").WithError(err).Error("could not unmarshal event data")
|
||||
logging.Log("MODEL-lso9e").WithError(err).Error("could not unmarshal event data")
|
||||
return caos_errs.ThrowInternal(nil, "MODEL-8iows", "could not unmarshal data")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserView) setPasswordData(event *models.Event) error {
|
||||
password := new(es_model.Password)
|
||||
if err := json.Unmarshal(event.Data, password); err != nil {
|
||||
logging.Log("MODEL-sdw4r").WithError(err).Error("could not unmarshal event data")
|
||||
return caos_errs.ThrowInternal(nil, "MODEL-6jhsw", "could not unmarshal data")
|
||||
}
|
||||
u.PasswordSet = password.Secret != nil
|
||||
u.PasswordChangeRequired = password.ChangeRequired
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *UserView) ComputeObject() {
|
||||
if u.State == int32(model.USERSTATE_UNSPECIFIED) || u.State == int32(model.USERSTATE_INITIAL) {
|
||||
if u.IsEmailVerified {
|
||||
@@ -179,4 +214,7 @@ func (u *UserView) ComputeObject() {
|
||||
u.State = int32(model.USERSTATE_INITIAL)
|
||||
}
|
||||
}
|
||||
if u.OTPState == int32(model.MFASTATE_READY) {
|
||||
u.MfaMaxSetUp = int32(req_model.MfaLevelSoftware)
|
||||
}
|
||||
}
|
||||
|
91
internal/user/repository/view/model/user_session.go
Normal file
91
internal/user/repository/view/model/user_session.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/caos/logging"
|
||||
|
||||
req_model "github.com/caos/zitadel/internal/auth_request/model"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
)
|
||||
|
||||
const (
|
||||
UserSessionKeySessionID = "id"
|
||||
UserSessionKeyUserAgentID = "user_agent_id"
|
||||
UserSessionKeyUserID = "user_id"
|
||||
UserSessionKeyState = "state"
|
||||
UserSessionKeyResourceOwner = "resource_owner"
|
||||
)
|
||||
|
||||
type UserSessionView struct {
|
||||
ID string `json:"-" gorm:"column:id;primary_key"`
|
||||
CreationDate time.Time `json:"-" gorm:"column:creation_date"`
|
||||
ChangeDate time.Time `json:"-" gorm:"column:change_date"`
|
||||
ResourceOwner string `json:"-" gorm:"column:resource_owner"`
|
||||
State int32 `json:"-" gorm:"column:state"`
|
||||
UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id"`
|
||||
UserID string `json:"userID" gorm:"column:user_id"`
|
||||
UserName string `json:"userName" gorm:"column:user_name"`
|
||||
PasswordVerification time.Time `json:"-" gorm:"column:password_verification"`
|
||||
MfaSoftwareVerification time.Time `json:"-" gorm:"column:mfa_software_verification"`
|
||||
MfaHardwareVerification time.Time `json:"-" gorm:"column:mfa_hardware_verification"`
|
||||
Sequence uint64 `json:"-" gorm:"column:sequence"`
|
||||
}
|
||||
|
||||
func UserSessionFromEvent(event *models.Event) (*UserSessionView, error) {
|
||||
v := new(UserSessionView)
|
||||
if err := json.Unmarshal(event.Data, v); err != nil {
|
||||
logging.Log("EVEN-lso9e").WithError(err).Error("could not unmarshal event data")
|
||||
return nil, caos_errs.ThrowInternal(nil, "MODEL-sd325", "could not unmarshal data")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func UserSessionToModel(userSession *UserSessionView) *model.UserSessionView {
|
||||
return &model.UserSessionView{
|
||||
ID: userSession.ID,
|
||||
ChangeDate: userSession.ChangeDate,
|
||||
CreationDate: userSession.CreationDate,
|
||||
ResourceOwner: userSession.ResourceOwner,
|
||||
State: req_model.UserSessionState(userSession.State),
|
||||
UserAgentID: userSession.UserAgentID,
|
||||
UserID: userSession.UserID,
|
||||
UserName: userSession.UserName,
|
||||
PasswordVerification: userSession.PasswordVerification,
|
||||
MfaSoftwareVerification: userSession.MfaSoftwareVerification,
|
||||
MfaHardwareVerification: userSession.MfaHardwareVerification,
|
||||
Sequence: userSession.Sequence,
|
||||
}
|
||||
}
|
||||
|
||||
func UserSessionsToModel(userSessions []*UserSessionView) []*model.UserSessionView {
|
||||
result := make([]*model.UserSessionView, len(userSessions))
|
||||
for i, s := range userSessions {
|
||||
result[i] = UserSessionToModel(s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (v *UserSessionView) AppendEvent(event *models.Event) {
|
||||
v.ChangeDate = event.CreationDate
|
||||
switch event.Type {
|
||||
case es_model.UserPasswordCheckSucceeded:
|
||||
v.PasswordVerification = event.CreationDate
|
||||
case es_model.UserPasswordCheckFailed,
|
||||
es_model.UserPasswordChanged:
|
||||
v.PasswordVerification = time.Time{}
|
||||
case es_model.MfaOtpCheckSucceeded:
|
||||
v.MfaSoftwareVerification = event.CreationDate
|
||||
case es_model.MfaOtpCheckFailed,
|
||||
es_model.MfaOtpRemoved:
|
||||
v.MfaSoftwareVerification = time.Time{}
|
||||
case es_model.SignedOut:
|
||||
v.PasswordVerification = time.Time{}
|
||||
v.MfaSoftwareVerification = time.Time{}
|
||||
v.State = int32(req_model.UserSessionStateTerminated)
|
||||
}
|
||||
}
|
67
internal/user/repository/view/model/user_session_query.go
Normal file
67
internal/user/repository/view/model/user_session_query.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
global_model "github.com/caos/zitadel/internal/model"
|
||||
usr_model "github.com/caos/zitadel/internal/user/model"
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
type UserSessionSearchRequest usr_model.UserSessionSearchRequest
|
||||
type UserSessionSearchQuery usr_model.UserSessionSearchQuery
|
||||
type UserSessionSearchKey usr_model.UserSessionSearchKey
|
||||
|
||||
func (req UserSessionSearchRequest) GetLimit() uint64 {
|
||||
return req.Limit
|
||||
}
|
||||
|
||||
func (req UserSessionSearchRequest) GetOffset() uint64 {
|
||||
return req.Offset
|
||||
}
|
||||
|
||||
func (req UserSessionSearchRequest) GetSortingColumn() view.ColumnKey {
|
||||
if req.SortingColumn == usr_model.USERSESSIONSEARCHKEY_UNSPECIFIED {
|
||||
return nil
|
||||
}
|
||||
return UserSessionSearchKey(req.SortingColumn)
|
||||
}
|
||||
|
||||
func (req UserSessionSearchRequest) GetAsc() bool {
|
||||
return req.Asc
|
||||
}
|
||||
|
||||
func (req UserSessionSearchRequest) GetQueries() []view.SearchQuery {
|
||||
result := make([]view.SearchQuery, len(req.Queries))
|
||||
for i, q := range req.Queries {
|
||||
result[i] = UserSessionSearchQuery{Key: q.Key, Value: q.Value, Method: q.Method}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (req UserSessionSearchQuery) GetKey() view.ColumnKey {
|
||||
return UserSessionSearchKey(req.Key)
|
||||
}
|
||||
|
||||
func (req UserSessionSearchQuery) GetMethod() global_model.SearchMethod {
|
||||
return req.Method
|
||||
}
|
||||
|
||||
func (req UserSessionSearchQuery) GetValue() interface{} {
|
||||
return req.Value
|
||||
}
|
||||
|
||||
func (key UserSessionSearchKey) ToColumnName() string {
|
||||
switch usr_model.UserSessionSearchKey(key) {
|
||||
case usr_model.USERSESSIONSEARCHKEY_SESSION_ID:
|
||||
return UserSessionKeySessionID
|
||||
case usr_model.USERSESSIONSEARCHKEY_USER_AGENT_ID:
|
||||
return UserSessionKeyUserAgentID
|
||||
case usr_model.USERSESSIONSEARCHKEY_USER_ID:
|
||||
return UserSessionKeyUserID
|
||||
case usr_model.USERSESSIONSEARCHKEY_STATE:
|
||||
return UserSessionKeyState
|
||||
case usr_model.USERSESSIONSEARCHKEY_RESOURCEOWNER:
|
||||
return UserSessionKeyResourceOwner
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
90
internal/user/repository/view/model/user_session_test.go
Normal file
90
internal/user/repository/view/model/user_session_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
)
|
||||
|
||||
func now() time.Time {
|
||||
return time.Now().UTC().Round(1 * time.Second)
|
||||
}
|
||||
|
||||
func TestAppendEvent(t *testing.T) {
|
||||
type args struct {
|
||||
event *es_models.Event
|
||||
userView *UserSessionView
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
result *UserSessionView
|
||||
}{
|
||||
{
|
||||
name: "append password check succeeded event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.UserPasswordCheckSucceeded},
|
||||
userView: &UserSessionView{},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), PasswordVerification: now()},
|
||||
},
|
||||
{
|
||||
name: "append password check failed event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.UserPasswordCheckFailed},
|
||||
userView: &UserSessionView{PasswordVerification: now()},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), PasswordVerification: time.Time{}},
|
||||
},
|
||||
{
|
||||
name: "append password changed event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.UserPasswordChanged},
|
||||
userView: &UserSessionView{PasswordVerification: now()},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), PasswordVerification: time.Time{}},
|
||||
},
|
||||
{
|
||||
name: "append otp check succeeded event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.MfaOtpCheckSucceeded},
|
||||
userView: &UserSessionView{},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), MfaSoftwareVerification: now()},
|
||||
},
|
||||
{
|
||||
name: "append otp check failed event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.MfaOtpCheckFailed},
|
||||
userView: &UserSessionView{MfaSoftwareVerification: now()},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), MfaSoftwareVerification: time.Time{}},
|
||||
},
|
||||
{
|
||||
name: "append otp removed event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.MfaOtpCheckFailed},
|
||||
userView: &UserSessionView{MfaSoftwareVerification: now()},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), MfaSoftwareVerification: time.Time{}},
|
||||
},
|
||||
{
|
||||
name: "append otp removed event",
|
||||
args: args{
|
||||
event: &es_models.Event{CreationDate: now(), Type: es_model.SignedOut},
|
||||
userView: &UserSessionView{PasswordVerification: now(), MfaSoftwareVerification: now()},
|
||||
},
|
||||
result: &UserSessionView{ChangeDate: now(), PasswordVerification: time.Time{}, MfaSoftwareVerification: time.Time{}, State: 1},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.args.userView.AppendEvent(tt.args.event)
|
||||
assert.Equal(t, tt.result, tt.args.userView)
|
||||
})
|
||||
}
|
||||
}
|
@@ -2,10 +2,13 @@ package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caos/zitadel/internal/crypto"
|
||||
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
||||
"github.com/caos/zitadel/internal/user/model"
|
||||
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mockUserData(user *es_model.User) []byte {
|
||||
@@ -13,6 +16,11 @@ func mockUserData(user *es_model.User) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func mockPasswordData(password *es_model.Password) []byte {
|
||||
data, _ := json.Marshal(password)
|
||||
return data
|
||||
}
|
||||
|
||||
func mockProfileData(profile *es_model.Profile) []byte {
|
||||
data, _ := json.Marshal(profile)
|
||||
return data
|
||||
@@ -33,7 +41,7 @@ func mockAddressData(address *es_model.Address) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func getFullUser() *es_model.User {
|
||||
func getFullUser(password *es_model.Password) *es_model.User {
|
||||
return &es_model.User{
|
||||
Profile: &es_model.Profile{
|
||||
UserName: "UserName",
|
||||
@@ -49,6 +57,7 @@ func getFullUser() *es_model.User {
|
||||
Address: &es_model.Address{
|
||||
Country: "Country",
|
||||
},
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,11 +74,43 @@ func TestUserAppendEvent(t *testing.T) {
|
||||
{
|
||||
name: "append added user event",
|
||||
args: args{
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserAdded, ResourceOwner: "OrgID", Data: mockUserData(getFullUser())},
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserAdded, ResourceOwner: "OrgID", Data: mockUserData(getFullUser(nil))},
|
||||
user: &UserView{},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_INITIAL)},
|
||||
},
|
||||
{
|
||||
name: "append added user with password event",
|
||||
args: args{
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserAdded, ResourceOwner: "OrgID", Data: mockUserData(getFullUser(&es_model.Password{Secret: &crypto.CryptoValue{}}))},
|
||||
user: &UserView{},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_INITIAL), PasswordSet: true},
|
||||
},
|
||||
{
|
||||
name: "append added user with password but change required event",
|
||||
args: args{
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserAdded, ResourceOwner: "OrgID", Data: mockUserData(getFullUser(&es_model.Password{ChangeRequired: true, Secret: &crypto.CryptoValue{}}))},
|
||||
user: &UserView{},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_INITIAL), PasswordSet: true, PasswordChangeRequired: true},
|
||||
},
|
||||
{
|
||||
name: "append password change event",
|
||||
args: args{
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserPasswordChanged, ResourceOwner: "OrgID", Data: mockPasswordData(&es_model.Password{Secret: &crypto.CryptoValue{}})},
|
||||
user: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", IsEmailVerified: true, Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE)},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", IsEmailVerified: true, Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE), PasswordSet: true},
|
||||
},
|
||||
{
|
||||
name: "append password change with change required event",
|
||||
args: args{
|
||||
event: &es_models.Event{AggregateID: "AggregateID", Sequence: 1, Type: es_model.UserPasswordChanged, ResourceOwner: "OrgID", Data: mockPasswordData(&es_model.Password{ChangeRequired: true, Secret: &crypto.CryptoValue{}})},
|
||||
user: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", IsEmailVerified: true, Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE)},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", IsEmailVerified: true, Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE), PasswordSet: true, PasswordChangeRequired: true},
|
||||
},
|
||||
{
|
||||
name: "append change user profile event",
|
||||
args: args{
|
||||
@@ -174,6 +215,14 @@ func TestUserAppendEvent(t *testing.T) {
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE), OTPState: int32(model.MFASTATE_UNSPECIFIED)},
|
||||
},
|
||||
{
|
||||
name: "append mfa init skipped event",
|
||||
args: args{
|
||||
event: &es_models.Event{Sequence: 1, CreationDate: time.Now().UTC(), Type: es_model.MfaInitSkipped, AggregateID: "AggregateID", ResourceOwner: "OrgID"},
|
||||
user: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE)},
|
||||
},
|
||||
result: &UserView{ID: "AggregateID", ResourceOwner: "OrgID", UserName: "UserName", FirstName: "FirstName", LastName: "LastName", Email: "Email", Phone: "Phone", Country: "Country", State: int32(model.USERSTATE_ACTIVE), MfaInitSkipped: time.Now().UTC()},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -211,6 +260,15 @@ func TestUserAppendEvent(t *testing.T) {
|
||||
if tt.args.user.OTPState != tt.result.OTPState {
|
||||
t.Errorf("got wrong result OTPState: expected: %v, actual: %v ", tt.result.OTPState, tt.args.user.OTPState)
|
||||
}
|
||||
if tt.args.user.MfaInitSkipped.Round(1*time.Second) != tt.result.MfaInitSkipped.Round(1*time.Second) {
|
||||
t.Errorf("got wrong result MfaInitSkipped: expected: %v, actual: %v ", tt.result.MfaInitSkipped.Round(1*time.Second), tt.args.user.MfaInitSkipped.Round(1*time.Second))
|
||||
}
|
||||
if tt.args.user.PasswordSet != tt.result.PasswordSet {
|
||||
t.Errorf("got wrong result PasswordSet: expected: %v, actual: %v ", tt.result.PasswordSet, tt.args.user.PasswordSet)
|
||||
}
|
||||
if tt.args.user.PasswordChangeRequired != tt.result.PasswordChangeRequired {
|
||||
t.Errorf("got wrong result PasswordChangeRequired: expected: %v, actual: %v ", tt.result.PasswordChangeRequired, tt.args.user.PasswordChangeRequired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
58
internal/user/repository/view/user_session_view.go
Normal file
58
internal/user/repository/view/user_session_view.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
global_model "github.com/caos/zitadel/internal/model"
|
||||
usr_model "github.com/caos/zitadel/internal/user/model"
|
||||
"github.com/caos/zitadel/internal/user/repository/view/model"
|
||||
"github.com/caos/zitadel/internal/view"
|
||||
)
|
||||
|
||||
func UserSessionByID(db *gorm.DB, table, sessionID string) (*model.UserSessionView, error) {
|
||||
userSession := new(model.UserSessionView)
|
||||
query := view.PrepareGetByKey(table, model.UserSessionSearchKey(usr_model.USERSESSIONSEARCHKEY_SESSION_ID), sessionID)
|
||||
err := query(db, userSession)
|
||||
return userSession, err
|
||||
}
|
||||
|
||||
func UserSessionByIDs(db *gorm.DB, table, agentID, userID string) (*model.UserSessionView, error) {
|
||||
userSession := new(model.UserSessionView)
|
||||
userAgentQuery := model.UserSessionSearchQuery{
|
||||
Key: usr_model.USERSESSIONSEARCHKEY_USER_AGENT_ID,
|
||||
Method: global_model.SEARCHMETHOD_EQUALS,
|
||||
Value: agentID,
|
||||
}
|
||||
userQuery := model.UserSessionSearchQuery{
|
||||
Key: usr_model.USERSESSIONSEARCHKEY_USER_ID,
|
||||
Method: global_model.SEARCHMETHOD_EQUALS,
|
||||
Value: userID,
|
||||
}
|
||||
query := view.PrepareGetByQuery(table, userAgentQuery, userQuery)
|
||||
err := query(db, userSession)
|
||||
return userSession, err
|
||||
}
|
||||
|
||||
func UserSessionsByAgentID(db *gorm.DB, table, agentID string) ([]*model.UserSessionView, error) {
|
||||
userSessions := make([]*model.UserSessionView, 0)
|
||||
userAgentQuery := &usr_model.UserSessionSearchQuery{
|
||||
Key: usr_model.USERSESSIONSEARCHKEY_USER_AGENT_ID,
|
||||
Method: global_model.SEARCHMETHOD_EQUALS,
|
||||
Value: agentID,
|
||||
}
|
||||
query := view.PrepareSearchQuery(table, model.UserSessionSearchRequest{
|
||||
Queries: []*usr_model.UserSessionSearchQuery{userAgentQuery},
|
||||
})
|
||||
_, err := query(db, userSessions)
|
||||
return userSessions, err
|
||||
}
|
||||
|
||||
func PutUserSession(db *gorm.DB, table string, session *model.UserSessionView) error {
|
||||
save := view.PrepareSave(table)
|
||||
return save(db, session)
|
||||
}
|
||||
|
||||
func DeleteUserSession(db *gorm.DB, table, sessionID string) error {
|
||||
delete := view.PrepareDeleteByKey(table, model.UserSessionSearchKey(usr_model.USERSESSIONSEARCHKEY_USER_ID), sessionID)
|
||||
return delete(db)
|
||||
}
|
@@ -1,19 +1,31 @@
|
||||
package view
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
"github.com/jinzhu/gorm"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
"github.com/caos/zitadel/internal/model"
|
||||
)
|
||||
|
||||
var (
|
||||
expectedGetByID = `SELECT \* FROM "%s" WHERE \(%s = \$1\) LIMIT 1`
|
||||
expectedGetByQuery = `SELECT \* FROM "%s" WHERE \(LOWER\(%s\) %s LOWER\(\$1\)\) LIMIT 1`
|
||||
expectedGetByQueryCaseSensitive = `SELECT \* FROM "%s" WHERE \(%s %s \$1\) LIMIT 1`
|
||||
expectedSave = `UPDATE "%s" SET "test" = \$1 WHERE "%s"."%s" = \$2`
|
||||
expectedRemove = `DELETE FROM "%s" WHERE \(%s = \$1\)`
|
||||
expectedGetByID = `SELECT \* FROM "%s" WHERE \(%s = \$1\) LIMIT 1`
|
||||
expectedGetByQuery = `SELECT \* FROM "%s" WHERE \(LOWER\(%s\) %s LOWER\(\$1\)\) LIMIT 1`
|
||||
expectedGetByQueryCaseSensitive = `SELECT \* FROM "%s" WHERE \(%s %s \$1\) LIMIT 1`
|
||||
expectedSave = `UPDATE "%s" SET "test" = \$1 WHERE "%s"."%s" = \$2`
|
||||
expectedRemove = `DELETE FROM "%s" WHERE \(%s = \$1\)`
|
||||
expectedRemoveByKeys = func(i int, table string) string {
|
||||
sql := fmt.Sprintf(`DELETE FROM "%s"`, table)
|
||||
sql += ` WHERE \(%s = \$1\)`
|
||||
for j := 1; j < i; j++ {
|
||||
sql = sql + ` AND \(%s = \$` + strconv.Itoa(j+1) + `\)`
|
||||
}
|
||||
return sql
|
||||
}
|
||||
expectedRemoveByObject = `DELETE FROM "%s" WHERE "%s"."%s" = \$1`
|
||||
expectedRemoveByObjectMultiplePK = `DELETE FROM "%s" WHERE "%s"."%s" = \$1 AND "%s"."%s" = \$2`
|
||||
expectedSearch = `SELECT \* FROM "%s" OFFSET 0`
|
||||
@@ -235,6 +247,21 @@ func (db *dbMock) expectRemove(table, key, value string) *dbMock {
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRemoveKeys(table string, keys ...Key) *dbMock {
|
||||
keynames := make([]interface{}, len(keys))
|
||||
keyvalues := make([]driver.Value, len(keys))
|
||||
for i, key := range keys {
|
||||
keynames[i] = key.Key.ToColumnName()
|
||||
keyvalues[i] = key.Value
|
||||
}
|
||||
query := fmt.Sprintf(expectedRemoveByKeys(len(keys), table), keynames...)
|
||||
db.mock.ExpectExec(query).
|
||||
WithArgs(keyvalues...).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRemoveByObject(table string, object Test) *dbMock {
|
||||
query := fmt.Sprintf(expectedRemoveByObject, table, table, "primary_id")
|
||||
db.mock.ExpectExec(query).
|
||||
|
@@ -3,9 +3,11 @@ package view
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/caos/logging"
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
caos_errs "github.com/caos/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func PrepareGetByKey(table string, key ColumnKey, id string) func(db *gorm.DB, res interface{}) error {
|
||||
@@ -71,6 +73,27 @@ func PrepareDeleteByKey(table string, key ColumnKey, id string) func(db *gorm.DB
|
||||
}
|
||||
}
|
||||
|
||||
type Key struct {
|
||||
Key ColumnKey
|
||||
Value string
|
||||
}
|
||||
|
||||
func PrepareDeleteByKeys(table string, keys ...Key) func(db *gorm.DB) error {
|
||||
return func(db *gorm.DB) error {
|
||||
for _, key := range keys {
|
||||
db = db.Table(table).
|
||||
Where(fmt.Sprintf("%s = ?", key.Key.ToColumnName()), key.Value)
|
||||
}
|
||||
err := db.
|
||||
Delete(nil).
|
||||
Error
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "VIEW-die73", "could not delete object")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func PrepareDeleteByObject(table string, object interface{}) func(db *gorm.DB) error {
|
||||
return func(db *gorm.DB) error {
|
||||
err := db.Table(table).
|
||||
|
@@ -391,6 +391,97 @@ func TestPrepareDelete(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareDeleteByKeys(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
keys []Key
|
||||
}
|
||||
type res struct {
|
||||
result Test
|
||||
wantErr bool
|
||||
errFunc func(err error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
db *dbMock
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"delete single key",
|
||||
mockDB(t).
|
||||
expectBegin(nil).
|
||||
expectRemoveKeys("TESTTABLE", Key{Key: TestSearchKey_ID, Value: "VALUE"}).
|
||||
expectCommit(nil),
|
||||
args{
|
||||
table: "TESTTABLE",
|
||||
keys: []Key{
|
||||
{Key: TestSearchKey_ID, Value: "VALUE"},
|
||||
},
|
||||
},
|
||||
res{
|
||||
result: Test{ID: "VALUE"},
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"delete multiple keys",
|
||||
mockDB(t).
|
||||
expectBegin(nil).
|
||||
expectRemoveKeys("TESTTABLE", Key{Key: TestSearchKey_ID, Value: "VALUE"}, Key{Key: TestSearchKey_TEST, Value: "VALUE2"}).
|
||||
expectCommit(nil),
|
||||
args{
|
||||
table: "TESTTABLE",
|
||||
keys: []Key{
|
||||
{Key: TestSearchKey_ID, Value: "VALUE"},
|
||||
{Key: TestSearchKey_TEST, Value: "VALUE2"},
|
||||
},
|
||||
},
|
||||
res{
|
||||
result: Test{ID: "VALUE"},
|
||||
wantErr: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"db error",
|
||||
mockDB(t).
|
||||
expectBegin(nil).
|
||||
expectRemoveErr("TESTTABLE", "id", "VALUE", gorm.ErrUnaddressable).
|
||||
expectCommit(nil),
|
||||
args{
|
||||
table: "TESTTABLE",
|
||||
keys: []Key{
|
||||
{Key: TestSearchKey_ID, Value: "VALUE"},
|
||||
},
|
||||
},
|
||||
res{
|
||||
result: Test{ID: "VALUE"},
|
||||
wantErr: true,
|
||||
errFunc: caos_errs.IsInternal,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
getDelete := PrepareDeleteByKeys(tt.args.table, tt.args.keys...)
|
||||
err := getDelete(tt.db.db)
|
||||
|
||||
if !tt.res.wantErr && err != nil {
|
||||
t.Errorf("got wrong err should be nil: %v ", err)
|
||||
}
|
||||
|
||||
if tt.res.wantErr && !tt.res.errFunc(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if err := tt.db.mock.ExpectationsWereMet(); !tt.res.wantErr && err != nil {
|
||||
t.Errorf("there were unfulfilled expectations: %s", err)
|
||||
}
|
||||
|
||||
tt.db.close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareDeleteByObject(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
|
Reference in New Issue
Block a user