mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-12 12:43:40 +00:00
1c59d18fee
* add csrf * caching * caching * caching * caching * security headers * csp and security headers * error handler csp * select user with display name * csp * user selection styling * username to loginname * regenerate grpc * regenerate * change to login name
355 lines
11 KiB
Go
355 lines
11 KiB
Go
package eventstore
|
|
|
|
import (
|
|
"context"
|
|
"time"
|
|
|
|
"github.com/caos/logging"
|
|
|
|
"github.com/caos/zitadel/internal/auth/repository/eventsourcing/view"
|
|
"github.com/caos/zitadel/internal/auth_request/model"
|
|
cache "github.com/caos/zitadel/internal/auth_request/repository"
|
|
"github.com/caos/zitadel/internal/errors"
|
|
es_models "github.com/caos/zitadel/internal/eventstore/models"
|
|
"github.com/caos/zitadel/internal/id"
|
|
user_model "github.com/caos/zitadel/internal/user/model"
|
|
user_event "github.com/caos/zitadel/internal/user/repository/eventsourcing"
|
|
es_model "github.com/caos/zitadel/internal/user/repository/eventsourcing/model"
|
|
view_model "github.com/caos/zitadel/internal/user/repository/view/model"
|
|
)
|
|
|
|
type AuthRequestRepo struct {
|
|
UserEvents *user_event.UserEventstore
|
|
AuthRequests cache.AuthRequestCache
|
|
View *view.View
|
|
|
|
UserSessionViewProvider userSessionViewProvider
|
|
UserViewProvider userViewProvider
|
|
UserEventProvider userEventProvider
|
|
|
|
IdGenerator id.Generator
|
|
|
|
PasswordCheckLifeTime time.Duration
|
|
MfaInitSkippedLifeTime time.Duration
|
|
MfaSoftwareCheckLifeTime time.Duration
|
|
MfaHardwareCheckLifeTime time.Duration
|
|
}
|
|
|
|
type userSessionViewProvider interface {
|
|
UserSessionByIDs(string, string) (*view_model.UserSessionView, error)
|
|
UserSessionsByAgentID(string) ([]*view_model.UserSessionView, error)
|
|
}
|
|
type userViewProvider interface {
|
|
UserByID(string) (*view_model.UserView, error)
|
|
}
|
|
|
|
type userEventProvider interface {
|
|
UserEventsByID(ctx context.Context, id string, sequence uint64) ([]*es_models.Event, error)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
|
|
if err := repo.UserEvents.Health(ctx); err != nil {
|
|
return err
|
|
}
|
|
return repo.AuthRequests.Health(ctx)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *model.AuthRequest) (*model.AuthRequest, error) {
|
|
reqID, err := repo.IdGenerator.Next()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.ID = reqID
|
|
ids, err := repo.View.AppIDsFromProjectByClientID(ctx, request.ApplicationID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.Audience = ids
|
|
err = repo.AuthRequests.SaveAuthRequest(ctx, request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return request, nil
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) AuthRequestByID(ctx context.Context, id string) (*model.AuthRequest, error) {
|
|
return repo.getAuthRequest(ctx, id, false)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) AuthRequestByIDCheckLoggedIn(ctx context.Context, id string) (*model.AuthRequest, error) {
|
|
return repo.getAuthRequest(ctx, id, true)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) SaveAuthCode(ctx context.Context, id, code string) error {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
request.Code = code
|
|
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) AuthRequestByCode(ctx context.Context, code string) (*model.AuthRequest, error) {
|
|
request, err := repo.AuthRequests.GetAuthRequestByCode(ctx, code)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
steps, err := repo.nextSteps(ctx, request, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.PossibleSteps = steps
|
|
return request, nil
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) DeleteAuthRequest(ctx context.Context, id string) error {
|
|
return repo.AuthRequests.DeleteAuthRequest(ctx, id)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) CheckLoginName(ctx context.Context, id, loginName string) error {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user, err := repo.View.UserByLoginName(loginName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
request.SetUserInfo(user.ID, loginName, user.ResourceOwner)
|
|
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, id, userID string) error {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user, err := repo.View.UserByID(userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
request.SetUserInfo(user.ID, user.PreferredLoginName, user.ResourceOwner)
|
|
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) VerifyPassword(ctx context.Context, id, userID, password string, info *model.BrowserInfo) error {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if request.UserID != userID {
|
|
return errors.ThrowPreconditionFailed(nil, "EVENT-ds35D", "Errors.User.NotMatchingUserID")
|
|
}
|
|
return repo.UserEvents.CheckPassword(ctx, userID, password, request.WithCurrentInfo(info))
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) VerifyMfaOTP(ctx context.Context, authRequestID, userID string, code string, info *model.BrowserInfo) error {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, authRequestID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if request.UserID != userID {
|
|
return errors.ThrowPreconditionFailed(nil, "EVENT-ADJ26", "Errors.User.NotMatchingUserID")
|
|
}
|
|
return repo.UserEvents.CheckMfaOTP(ctx, userID, code, request.WithCurrentInfo(info))
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) getAuthRequest(ctx context.Context, id string, checkLoggedIn bool) (*model.AuthRequest, error) {
|
|
request, err := repo.AuthRequests.GetAuthRequestByID(ctx, id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
steps, err := repo.nextSteps(ctx, request, checkLoggedIn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
request.PossibleSteps = steps
|
|
return request, nil
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *model.AuthRequest, checkLoggedIn bool) ([]model.NextStep, error) {
|
|
if request == nil {
|
|
return nil, errors.ThrowInvalidArgument(nil, "EVENT-ds27a", "request must not be nil")
|
|
}
|
|
steps := make([]model.NextStep, 0)
|
|
if !checkLoggedIn && request.Prompt == model.PromptNone {
|
|
return append(steps, &model.RedirectToCallbackStep{}), nil
|
|
}
|
|
if request.UserID == "" {
|
|
steps = append(steps, &model.LoginStep{})
|
|
if request.Prompt == model.PromptSelectAccount {
|
|
users, err := repo.usersForUserSelection(request)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
steps = append(steps, &model.SelectUserStep{Users: users})
|
|
}
|
|
return steps, nil
|
|
}
|
|
user, err := userByID(ctx, repo.UserViewProvider, repo.UserEventProvider, request.UserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if user.InitRequired {
|
|
return append(steps, &model.InitUserStep{PasswordSet: user.PasswordSet}), nil
|
|
}
|
|
if !user.PasswordSet {
|
|
return append(steps, &model.InitPasswordStep{}), nil
|
|
}
|
|
|
|
if !checkVerificationTime(userSession.PasswordVerification, repo.PasswordCheckLifeTime) {
|
|
return append(steps, &model.PasswordStep{}), nil
|
|
}
|
|
request.PasswordVerified = true
|
|
request.AuthTime = userSession.PasswordVerification
|
|
|
|
if step, ok := repo.mfaChecked(userSession, request, user); !ok {
|
|
return append(steps, step), nil
|
|
}
|
|
|
|
if user.PasswordChangeRequired {
|
|
steps = append(steps, &model.ChangePasswordStep{})
|
|
}
|
|
if !user.IsEmailVerified {
|
|
steps = append(steps, &model.VerifyEMailStep{})
|
|
}
|
|
|
|
if user.PasswordChangeRequired || !user.IsEmailVerified {
|
|
return steps, nil
|
|
}
|
|
|
|
//PLANNED: consent step
|
|
return append(steps, &model.RedirectToCallbackStep{}), nil
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) usersForUserSelection(request *model.AuthRequest) ([]model.UserSelection, error) {
|
|
userSessions, err := userSessionsByUserAgentID(repo.UserSessionViewProvider, request.AgentID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users := make([]model.UserSelection, len(userSessions))
|
|
for i, session := range userSessions {
|
|
users[i] = model.UserSelection{
|
|
UserID: session.UserID,
|
|
DisplayName: session.DisplayName,
|
|
LoginName: session.LoginName,
|
|
UserSessionState: session.State,
|
|
}
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *model.AuthRequest, user *user_model.UserView) (model.NextStep, bool) {
|
|
mfaLevel := request.MfaLevel()
|
|
promptRequired := user.MfaMaxSetUp < mfaLevel
|
|
if promptRequired || !repo.mfaSkippedOrSetUp(user) {
|
|
return &model.MfaPromptStep{
|
|
Required: promptRequired,
|
|
MfaProviders: user.MfaTypesSetupPossible(mfaLevel),
|
|
}, false
|
|
}
|
|
switch mfaLevel {
|
|
default:
|
|
fallthrough
|
|
case model.MfaLevelNotSetUp:
|
|
if user.MfaMaxSetUp == model.MfaLevelNotSetUp {
|
|
return nil, true
|
|
}
|
|
fallthrough
|
|
case model.MfaLevelSoftware:
|
|
if checkVerificationTime(userSession.MfaSoftwareVerification, repo.MfaSoftwareCheckLifeTime) {
|
|
request.MfasVerified = append(request.MfasVerified, userSession.MfaSoftwareVerificationType)
|
|
request.AuthTime = userSession.MfaSoftwareVerification
|
|
return nil, true
|
|
}
|
|
fallthrough
|
|
case model.MfaLevelHardware:
|
|
if checkVerificationTime(userSession.MfaHardwareVerification, repo.MfaHardwareCheckLifeTime) {
|
|
request.MfasVerified = append(request.MfasVerified, userSession.MfaHardwareVerificationType)
|
|
request.AuthTime = userSession.MfaHardwareVerification
|
|
return nil, true
|
|
}
|
|
}
|
|
return &model.MfaVerificationStep{
|
|
MfaProviders: user.MfaTypesAllowed(mfaLevel),
|
|
}, false
|
|
}
|
|
|
|
func (repo *AuthRequestRepo) mfaSkippedOrSetUp(user *user_model.UserView) bool {
|
|
if user.MfaMaxSetUp > model.MfaLevelNotSetUp {
|
|
return true
|
|
}
|
|
return checkVerificationTime(user.MfaInitSkipped, repo.MfaInitSkippedLifeTime)
|
|
}
|
|
|
|
// checkVerificationTime reports whether a verification made at
// verificationTime is still valid after lifetime. That is, it reports whether
// verificationTime+lifetime lies strictly after the current time. The
// comparison uses wall-clock time, because UTC() drops the monotonic reading.
// A zero verificationTime is never valid.
func checkVerificationTime(verificationTime time.Time, lifetime time.Duration) bool {
	validUntil := verificationTime.Add(lifetime)
	return time.Now().UTC().Before(validUntil)
}
|
|
|
|
func userSessionsByUserAgentID(provider userSessionViewProvider, agentID string) ([]*user_model.UserSessionView, error) {
|
|
session, err := provider.UserSessionsByAgentID(agentID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return view_model.UserSessionsToModel(session), nil
|
|
}
|
|
|
|
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
|
|
session, err := provider.UserSessionByIDs(agentID, user.ID)
|
|
if err != nil {
|
|
if !errors.IsNotFound(err) {
|
|
return nil, err
|
|
}
|
|
session = &view_model.UserSessionView{}
|
|
}
|
|
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.Sequence)
|
|
if err != nil {
|
|
logging.Log("EVENT-Hse6s").WithError(err).Debug("error retrieving new events")
|
|
return view_model.UserSessionToModel(session), nil
|
|
}
|
|
sessionCopy := *session
|
|
for _, event := range events {
|
|
switch event.Type {
|
|
case es_model.UserPasswordCheckSucceeded,
|
|
es_model.UserPasswordCheckFailed,
|
|
es_model.MfaOtpCheckSucceeded,
|
|
es_model.MfaOtpCheckFailed,
|
|
es_model.SignedOut:
|
|
eventData, err := view_model.UserSessionFromEvent(event)
|
|
if err != nil {
|
|
logging.Log("EVENT-sdgT3").WithError(err).Debug("error getting event data")
|
|
return view_model.UserSessionToModel(session), nil
|
|
}
|
|
if eventData.UserAgentID != agentID {
|
|
continue
|
|
}
|
|
}
|
|
sessionCopy.AppendEvent(event)
|
|
}
|
|
return view_model.UserSessionToModel(&sessionCopy), nil
|
|
}
|
|
|
|
func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (*user_model.UserView, error) {
|
|
user, err := viewProvider.UserByID(userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
events, err := eventProvider.UserEventsByID(ctx, userID, user.Sequence)
|
|
if err != nil {
|
|
logging.Log("EVENT-dfg42").WithError(err).Debug("error retrieving new events")
|
|
return view_model.UserToModel(user), nil
|
|
}
|
|
userCopy := *user
|
|
for _, event := range events {
|
|
if err := userCopy.AppendEvent(event); err != nil {
|
|
return view_model.UserToModel(user), nil
|
|
}
|
|
}
|
|
return view_model.UserToModel(&userCopy), nil
|
|
}
|