Merge branch 'refs/heads/main' into next

This commit is contained in:
Livio Spring
2024-07-24 15:47:02 +02:00
119 changed files with 2458 additions and 875 deletions

View File

@@ -50,7 +50,7 @@ func (s *Server) VerifyMyEmail(ctx context.Context, req *auth_pb.VerifyMyEmailRe
return nil, err
}
ctxData := authz.GetCtxData(ctx)
objectDetails, err := s.command.VerifyHumanEmail(ctx, ctxData.UserID, req.Code, ctxData.ResourceOwner, emailCodeGenerator)
objectDetails, err := s.command.VerifyHumanEmail(ctx, ctxData.UserID, req.Code, ctxData.ResourceOwner, "", "", emailCodeGenerator)
if err != nil {
return nil, err
}

View File

@@ -260,7 +260,7 @@ func appendIfNotExists(array []string, value string) []string {
func ListMyProjectOrgsRequestToQuery(req *auth_pb.ListMyProjectOrgsRequest) (*query.OrgSearchQueries, error) {
offset, limit, asc := obj_grpc.ListQueryToModel(req.Query)
queries, err := org.OrgQueriesToQuery(req.Queries)
queries, err := org.OrgQueriesToModel(req.Queries)
if err != nil {
return nil, err
}

View File

@@ -586,11 +586,7 @@ func (s *Server) SetHumanPassword(ctx context.Context, req *mgmt_pb.SetHumanPass
}
func (s *Server) SendHumanResetPasswordNotification(ctx context.Context, req *mgmt_pb.SendHumanResetPasswordNotificationRequest) (*mgmt_pb.SendHumanResetPasswordNotificationResponse, error) {
passwordCodeGenerator, err := s.query.InitEncryptionGenerator(ctx, domain.SecretGeneratorTypePasswordResetCode, s.userCodeAlg)
if err != nil {
return nil, err
}
objectDetails, err := s.command.RequestSetPassword(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, notifyTypeToDomain(req.Type), passwordCodeGenerator, "")
objectDetails, err := s.command.RequestSetPassword(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, notifyTypeToDomain(req.Type), "")
if err != nil {
return nil, err
}

View File

@@ -133,8 +133,14 @@ func ImportHumanUserRequestToDomain(req *mgmt_pb.ImportHumanUserRequest) (human
}
func AddMachineUserRequestToCommand(req *mgmt_pb.AddMachineUserRequest, resourceowner string) *command.Machine {
userId := ""
if req.UserId != nil {
userId = *req.UserId
}
return &command.Machine{
ObjectRoot: models.ObjectRoot{
AggregateID: userId,
ResourceOwner: resourceowner,
},
Username: req.UserName,

View File

@@ -78,3 +78,36 @@ func TestImport_UnparsablePreferredLanguage(t *testing.T) {
})
require.NoError(t, err)
}
func TestAdd_MachineUser(t *testing.T) {
random := integration.RandString(5)
res, err := Client.AddMachineUser(OrgCTX, &management.AddMachineUserRequest{
UserName: random,
Name: "testMachineName1",
Description: "testMachineDescription1",
AccessTokenType: 0,
})
require.NoError(t, err)
_, err = Client.GetUserByID(OrgCTX, &management.GetUserByIDRequest{Id: res.GetUserId()})
require.NoError(t, err)
}
func TestAdd_MachineUserCustomID(t *testing.T) {
id := integration.RandString(5)
random := integration.RandString(5)
res, err := Client.AddMachineUser(OrgCTX, &management.AddMachineUserRequest{
UserId: &id,
UserName: random,
Name: "testMachineName1",
Description: "testMachineDescription1",
AccessTokenType: 0,
})
require.NoError(t, err)
_, err = Client.GetUserByID(OrgCTX, &management.GetUserByIDRequest{Id: id})
require.NoError(t, err)
require.Equal(t, id, res.GetUserId())
}

View File

@@ -27,35 +27,13 @@ func OrgQueryToModel(apiQuery *org_pb.OrgQuery) (query.SearchQuery, error) {
return query.NewOrgNameSearchQuery(object.TextMethodToQuery(q.NameQuery.Method), q.NameQuery.Name)
case *org_pb.OrgQuery_StateQuery:
return query.NewOrgStateSearchQuery(OrgStateToDomain(q.StateQuery.State))
case *org_pb.OrgQuery_IdQuery:
return query.NewOrgIDSearchQuery(q.IdQuery.Id)
default:
return nil, zerrors.ThrowInvalidArgument(nil, "ORG-vR9nC", "List.Query.Invalid")
}
}
func OrgQueriesToQuery(queries []*org_pb.OrgQuery) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
q[i], err = OrgQueryToQuery(query)
if err != nil {
return nil, err
}
}
return q, nil
}
func OrgQueryToQuery(search *org_pb.OrgQuery) (query.SearchQuery, error) {
switch q := search.Query.(type) {
case *org_pb.OrgQuery_DomainQuery:
return query.NewOrgDomainSearchQuery(object.TextMethodToQuery(q.DomainQuery.Method), q.DomainQuery.Domain)
case *org_pb.OrgQuery_NameQuery:
return query.NewOrgNameSearchQuery(object.TextMethodToQuery(q.NameQuery.Method), q.NameQuery.Name)
case *org_pb.OrgQuery_StateQuery:
return query.NewOrgStateSearchQuery(OrgStateToDomain(q.StateQuery.State))
default:
return nil, zerrors.ThrowInvalidArgument(nil, "ADMIN-ADvsd", "List.Query.Invalid")
}
}
func OrgViewsToPb(orgs []*query.Org) []*org_pb.Org {
o := make([]*org_pb.Org, len(orgs))
for i, org := range orgs {

View File

@@ -40,6 +40,23 @@ func (s *Server) SetPhone(ctx context.Context, req *user.SetPhoneRequest) (resp
}, nil
}
func (s *Server) RemovePhone(ctx context.Context, req *user.RemovePhoneRequest) (resp *user.RemovePhoneResponse, err error) {
details, err := s.command.RemoveUserPhone(ctx,
req.GetUserId(),
)
if err != nil {
return nil, err
}
return &user.RemovePhoneResponse{
Details: &object.Details{
Sequence: details.Sequence,
ChangeDate: timestamppb.New(details.EventDate),
ResourceOwner: details.ResourceOwner,
},
}, nil
}
func (s *Server) ResendPhoneCode(ctx context.Context, req *user.ResendPhoneCodeRequest) (resp *user.ResendPhoneCodeResponse, err error) {
var phone *domain.Phone
switch v := req.GetVerification().(type) {

View File

@@ -3,6 +3,7 @@
package user_test
import (
"context"
"fmt"
"testing"
"time"
@@ -245,3 +246,99 @@ func TestServer_VerifyPhone(t *testing.T) {
})
}
}
func TestServer_RemovePhone(t *testing.T) {
userResp := Tester.CreateHumanUser(CTX)
failResp := Tester.CreateHumanUserNoPhone(CTX)
otherUser := Tester.CreateHumanUser(CTX).GetUserId()
doubleRemoveUser := Tester.CreateHumanUser(CTX)
Tester.RegisterUserPasskey(CTX, otherUser)
_, sessionTokenOtherUser, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, otherUser)
tests := []struct {
name string
ctx context.Context
req *user.RemovePhoneRequest
want *user.RemovePhoneResponse
wantErr bool
dep func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error)
}{
{
name: "remove phone",
ctx: CTX,
req: &user.RemovePhoneRequest{
UserId: userResp.GetUserId(),
},
want: &user.RemovePhoneResponse{
Details: &object.Details{
Sequence: 1,
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
},
dep: func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error) {
return nil, nil
},
},
{
name: "user without phone",
ctx: CTX,
req: &user.RemovePhoneRequest{
UserId: failResp.GetUserId(),
},
wantErr: true,
dep: func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error) {
return nil, nil
},
},
{
name: "remove previously deleted phone",
ctx: CTX,
req: &user.RemovePhoneRequest{
UserId: doubleRemoveUser.GetUserId(),
},
wantErr: true,
dep: func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error) {
return Client.RemovePhone(ctx, &user.RemovePhoneRequest{
UserId: doubleRemoveUser.GetUserId(),
});
},
},
{
name: "no user id",
ctx: CTX,
req: &user.RemovePhoneRequest{},
wantErr: true,
dep: func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error) {
return nil, nil
},
},
{
name: "other user, no permission",
ctx: Tester.WithAuthorizationToken(CTX, sessionTokenOtherUser),
req: &user.RemovePhoneRequest{
UserId: userResp.GetUserId(),
},
wantErr: true,
dep: func(ctx context.Context, userID string) (*user.RemovePhoneResponse, error) {
return nil, nil
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, depErr := tt.dep(tt.ctx, tt.req.UserId)
require.NoError(t, depErr)
got, err := Client.RemovePhone(tt.ctx, tt.req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
integration.AssertDetails(t, tt.want, got)
})
}
}

View File

@@ -60,7 +60,12 @@ func UsersToPb(users []*query.User, assetPrefix string) []*user.User {
func userToPb(userQ *query.User, assetPrefix string) *user.User {
return &user.User{
UserId: userQ.ID,
UserId: userQ.ID,
Details: object.DomainToDetailsPb(&domain.ObjectDetails{
Sequence: userQ.Sequence,
EventDate: userQ.ChangeDate,
ResourceOwner: userQ.ResourceOwner,
}),
State: userStateToPb(userQ.State),
Username: userQ.Username,
LoginNames: userQ.LoginNames,

View File

@@ -23,7 +23,7 @@ func TestServer_GetUserByID(t *testing.T) {
type args struct {
ctx context.Context
req *user.GetUserByIDRequest
dep func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*timestamppb.Timestamp, error)
dep func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*userAttr, error)
}
tests := []struct {
name string
@@ -38,7 +38,7 @@ func TestServer_GetUserByID(t *testing.T) {
&user.GetUserByIDRequest{
UserId: "",
},
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*timestamppb.Timestamp, error) {
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*userAttr, error) {
return nil, nil
},
},
@@ -51,7 +51,7 @@ func TestServer_GetUserByID(t *testing.T) {
&user.GetUserByIDRequest{
UserId: "unknown",
},
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*timestamppb.Timestamp, error) {
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*userAttr, error) {
return nil, nil
},
},
@@ -62,10 +62,10 @@ func TestServer_GetUserByID(t *testing.T) {
args: args{
IamCTX,
&user.GetUserByIDRequest{},
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*timestamppb.Timestamp, error) {
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*userAttr, error) {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
request.UserId = resp.GetUserId()
return nil, nil
return &userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}, nil
},
},
want: &user.GetUserByIDResponse{
@@ -106,11 +106,11 @@ func TestServer_GetUserByID(t *testing.T) {
args: args{
IamCTX,
&user.GetUserByIDRequest{},
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*timestamppb.Timestamp, error) {
func(ctx context.Context, username string, request *user.GetUserByIDRequest) (*userAttr, error) {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
request.UserId = resp.GetUserId()
changed := Tester.SetUserPassword(ctx, resp.GetUserId(), integration.UserPassword, true)
return changed, nil
details := Tester.SetUserPassword(ctx, resp.GetUserId(), integration.UserPassword, true)
return &userAttr{resp.GetUserId(), username, details.GetChangeDate(), resp.GetDetails()}, nil
},
},
want: &user.GetUserByIDResponse{
@@ -152,7 +152,7 @@ func TestServer_GetUserByID(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
username := fmt.Sprintf("%d@mouse.com", time.Now().UnixNano())
changed, err := tt.args.dep(tt.args.ctx, username, tt.args.req)
userAttr, err := tt.args.dep(tt.args.ctx, username, tt.args.req)
require.NoError(t, err)
retryDuration := time.Minute
if ctxDeadline, ok := CTX.Deadline(); ok {
@@ -168,14 +168,15 @@ func TestServer_GetUserByID(t *testing.T) {
if getErr != nil {
return
}
tt.want.User.UserId = tt.args.req.GetUserId()
tt.want.User.Username = username
tt.want.User.PreferredLoginName = username
tt.want.User.LoginNames = []string{username}
tt.want.User.Details = userAttr.Details
tt.want.User.UserId = userAttr.UserID
tt.want.User.Username = userAttr.Username
tt.want.User.PreferredLoginName = userAttr.Username
tt.want.User.LoginNames = []string{userAttr.Username}
if human := tt.want.User.GetHuman(); human != nil {
human.Email.Email = username
human.Email.Email = userAttr.Username
if tt.want.User.GetHuman().GetPasswordChanged() != nil {
human.PasswordChanged = changed
human.PasswordChanged = userAttr.Changed
}
}
assert.Equal(ttt, tt.want.User, got.User)
@@ -311,6 +312,9 @@ func TestServer_GetUserByID_Permission(t *testing.T) {
if human := tt.want.User.GetHuman(); human != nil {
human.Email.Email = newOrgOwnerEmail
}
// details tested in GetUserByID
tt.want.User.Details = got.User.GetDetails()
assert.Equal(t, tt.want.User, got.User)
}
})
@@ -321,6 +325,7 @@ type userAttr struct {
UserID string
Username string
Changed *timestamppb.Timestamp
Details *object.Details
}
func TestServer_ListUsers(t *testing.T) {
@@ -374,7 +379,7 @@ func TestServer_ListUsers(t *testing.T) {
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
userIDs[i] = resp.GetUserId()
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
}
request.Queries = append(request.Queries, InUserIDsQuery(userIDs))
return infos, nil
@@ -428,8 +433,8 @@ func TestServer_ListUsers(t *testing.T) {
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
userIDs[i] = resp.GetUserId()
changed := Tester.SetUserPassword(ctx, resp.GetUserId(), integration.UserPassword, true)
infos[i] = userAttr{resp.GetUserId(), username, changed}
details := Tester.SetUserPassword(ctx, resp.GetUserId(), integration.UserPassword, true)
infos[i] = userAttr{resp.GetUserId(), username, details.GetChangeDate(), resp.GetDetails()}
}
request.Queries = append(request.Queries, InUserIDsQuery(userIDs))
return infos, nil
@@ -485,7 +490,7 @@ func TestServer_ListUsers(t *testing.T) {
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
userIDs[i] = resp.GetUserId()
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
}
request.Queries = append(request.Queries, InUserIDsQuery(userIDs))
return infos, nil
@@ -581,7 +586,7 @@ func TestServer_ListUsers(t *testing.T) {
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
userIDs[i] = resp.GetUserId()
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
request.Queries = append(request.Queries, UsernameQuery(username))
}
return infos, nil
@@ -633,7 +638,7 @@ func TestServer_ListUsers(t *testing.T) {
infos := make([]userAttr, len(usernames))
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
}
request.Queries = append(request.Queries, InUserEmailsQuery(usernames))
return infos, nil
@@ -685,7 +690,7 @@ func TestServer_ListUsers(t *testing.T) {
infos := make([]userAttr, len(usernames))
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
}
request.Queries = append(request.Queries, InUserEmailsQuery(usernames))
return infos, nil
@@ -800,7 +805,7 @@ func TestServer_ListUsers(t *testing.T) {
infos := make([]userAttr, len(usernames))
for i, username := range usernames {
resp := Tester.CreateHumanUserVerified(ctx, orgResp.OrganizationId, username)
infos[i] = userAttr{resp.GetUserId(), username, nil}
infos[i] = userAttr{resp.GetUserId(), username, nil, resp.GetDetails()}
}
request.Queries = append(request.Queries, OrganizationIdQuery(orgResp.OrganizationId))
request.Queries = append(request.Queries, InUserEmailsQuery(usernames))
@@ -920,6 +925,7 @@ func TestServer_ListUsers(t *testing.T) {
human.PasswordChanged = infos[i].Changed
}
}
tt.want.Result[i].Details = infos[i].Details
}
for i := range tt.want.Result {
assert.Contains(ttt, got.Result, tt.want.Result[i])

View File

@@ -12,7 +12,6 @@ func (s *Server) RegisterTOTP(ctx context.Context, req *user.RegisterTOTPRequest
return totpDetailsToPb(
s.command.AddUserTOTP(ctx, req.GetUserId(), ""),
)
}
func totpDetailsToPb(totp *domain.TOTP, err error) (*user.RegisterTOTPResponse, error) {
@@ -35,3 +34,11 @@ func (s *Server) VerifyTOTPRegistration(ctx context.Context, req *user.VerifyTOT
Details: object.DomainToDetailsPb(objectDetails),
}, nil
}
func (s *Server) RemoveTOTP(ctx context.Context, req *user.RemoveTOTPRequest) (*user.RemoveTOTPResponse, error) {
objectDetails, err := s.command.HumanRemoveTOTP(ctx, req.GetUserId(), "")
if err != nil {
return nil, err
}
return &user.RemoveTOTPResponse{Details: object.DomainToDetailsPb(objectDetails)}, nil
}

View File

@@ -205,3 +205,80 @@ func TestServer_VerifyTOTPRegistration(t *testing.T) {
})
}
}
func TestServer_RemoveTOTP(t *testing.T) {
userID := Tester.CreateHumanUser(CTX).GetUserId()
Tester.RegisterUserPasskey(CTX, userID)
_, sessionToken, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userID)
userVerified := Tester.CreateHumanUser(CTX)
Tester.RegisterUserPasskey(CTX, userVerified.GetUserId())
_, sessionTokenVerified, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
userVerifiedCtx := Tester.WithAuthorizationToken(context.Background(), sessionTokenVerified)
_, err := Tester.Client.UserV2.VerifyPhone(userVerifiedCtx, &user.VerifyPhoneRequest{
UserId: userVerified.GetUserId(),
VerificationCode: userVerified.GetPhoneCode(),
})
require.NoError(t, err)
regOtherUser, err := Client.RegisterTOTP(CTX, &user.RegisterTOTPRequest{
UserId: userVerified.GetUserId(),
})
require.NoError(t, err)
codeOtherUser, err := totp.GenerateCode(regOtherUser.Secret, time.Now())
require.NoError(t, err)
_, err = Client.VerifyTOTPRegistration(userVerifiedCtx, &user.VerifyTOTPRegistrationRequest{
UserId: userVerified.GetUserId(),
Code: codeOtherUser,
},
)
require.NoError(t, err)
type args struct {
ctx context.Context
req *user.RemoveTOTPRequest
}
tests := []struct {
name string
args args
want *user.RemoveTOTPResponse
wantErr bool
}{
{
name: "not added",
args: args{
ctx: Tester.WithAuthorizationToken(context.Background(), sessionToken),
req: &user.RemoveTOTPRequest{
UserId: userID,
},
},
wantErr: true,
},
{
name: "success",
args: args{
ctx: userVerifiedCtx,
req: &user.RemoveTOTPRequest{
UserId: userVerified.GetUserId(),
},
},
want: &user.RemoveTOTPResponse{
Details: &object.Details{
ResourceOwner: Tester.Organisation.ResourceOwner,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.RemoveTOTP(tt.args.ctx, tt.args.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, got)
integration.AssertDetails(t, tt.want, got)
})
}
}

View File

@@ -590,7 +590,7 @@ func (s *Server) checkIntentToken(token string, intentID string) error {
}
func (s *Server) ListAuthenticationMethodTypes(ctx context.Context, req *user.ListAuthenticationMethodTypesRequest) (*user.ListAuthenticationMethodTypesResponse, error) {
authMethods, err := s.query.ListActiveUserAuthMethodTypes(ctx, req.GetUserId())
authMethods, err := s.query.ListUserAuthMethodTypes(ctx, req.GetUserId(), true)
if err != nil {
return nil, err
}

View File

@@ -97,12 +97,7 @@ func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authRe
userID = authReq.UserID
authReqID = authReq.ID
}
passwordCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypePasswordResetCode, l.userCodeAlg)
if err != nil {
l.renderInitPassword(w, r, authReq, userID, "", err)
return
}
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReqID)
_, err := l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, authReqID)
l.renderInitPassword(w, r, authReq, userID, "", err)
}

View File

@@ -1,10 +1,16 @@
package login
import (
"context"
"net/http"
"net/url"
"slices"
"github.com/zitadel/logging"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
@@ -16,15 +22,25 @@ const (
)
type mailVerificationFormData struct {
Code string `schema:"code"`
UserID string `schema:"userID"`
Resend bool `schema:"resend"`
Code string `schema:"code"`
UserID string `schema:"userID"`
Resend bool `schema:"resend"`
PasswordInit bool `schema:"passwordInit"`
Password string `schema:"password"`
PasswordConfirm string `schema:"passwordconfirm"`
}
type mailVerificationData struct {
baseData
profileData
UserID string
UserID string
Code string
PasswordInit bool
MinLength uint64
HasUppercase string
HasLowercase string
HasNumber string
HasSymbol string
}
func MailVerificationLink(origin, userID, code, orgID, authRequestID string) string {
@@ -40,11 +56,32 @@ func (l *Login) handleMailVerification(w http.ResponseWriter, r *http.Request) {
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
userID := r.FormValue(queryUserID)
code := r.FormValue(queryCode)
if code != "" {
l.checkMailCode(w, r, authReq, userID, code)
if userID == "" && authReq == nil {
l.renderError(w, r, authReq, nil)
return
}
l.renderMailVerification(w, r, authReq, userID, nil)
if userID == "" {
userID = authReq.UserID
}
passwordInit := l.checkUserNoFirstFactor(r.Context(), userID)
if code != "" && !passwordInit {
l.checkMailCode(w, r, authReq, userID, code, "")
return
}
l.renderMailVerification(w, r, authReq, userID, code, passwordInit, nil)
}
func (l *Login) checkUserNoFirstFactor(ctx context.Context, userID string) bool {
authMethods, err := l.query.ListUserAuthMethodTypes(setUserContext(ctx, userID, ""), userID, false)
if err != nil {
logging.WithFields("userID", userID).OnError(err).Warn("unable to load user's auth methods for mail verification")
return false
}
return !slices.ContainsFunc(authMethods.AuthMethodTypes, func(m domain.UserAuthMethodType) bool {
return m == domain.UserAuthMethodTypeIDP ||
m == domain.UserAuthMethodTypePassword ||
m == domain.UserAuthMethodTypePasswordless
})
}
func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Request) {
@@ -55,7 +92,12 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque
return
}
if !data.Resend {
l.checkMailCode(w, r, authReq, data.UserID, data.Code)
if data.PasswordInit && data.Password != data.PasswordConfirm {
err := zerrors.ThrowInvalidArgument(nil, "VIEW-fsdfd", "Errors.User.Password.ConfirmationWrong")
l.renderMailVerification(w, r, authReq, data.UserID, data.Code, data.PasswordInit, err)
return
}
l.checkMailCode(w, r, authReq, data.UserID, data.Code, data.Password)
return
}
var userOrg, authReqID string
@@ -65,14 +107,14 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque
}
emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg)
if err != nil {
l.checkMailCode(w, r, authReq, data.UserID, data.Code)
l.renderMailVerification(w, r, authReq, data.UserID, "", data.PasswordInit, err)
return
}
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReqID)
l.renderMailVerification(w, r, authReq, data.UserID, err)
l.renderMailVerification(w, r, authReq, data.UserID, "", data.PasswordInit, err)
}
func (l *Login) checkMailCode(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, code string) {
func (l *Login) checkMailCode(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, code, password string) {
userOrg := ""
if authReq != nil {
userID = authReq.UserID
@@ -80,31 +122,52 @@ func (l *Login) checkMailCode(w http.ResponseWriter, r *http.Request, authReq *d
}
emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg)
if err != nil {
l.renderMailVerification(w, r, authReq, userID, err)
l.renderMailVerification(w, r, authReq, userID, "", password != "", err)
return
}
_, err = l.command.VerifyHumanEmail(setContext(r.Context(), userOrg), userID, code, userOrg, emailCodeGenerator)
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
_, err = l.command.VerifyHumanEmail(setContext(r.Context(), userOrg), userID, code, userOrg, password, userAgentID, emailCodeGenerator)
if err != nil {
l.renderMailVerification(w, r, authReq, userID, err)
l.renderMailVerification(w, r, authReq, userID, "", password != "", err)
return
}
l.renderMailVerified(w, r, authReq, userOrg)
}
func (l *Login) renderMailVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID string, err error) {
func (l *Login) renderMailVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, code string, passwordInit bool, err error) {
var errID, errMessage string
if err != nil {
errID, errMessage = l.getErrorMessage(r, err)
}
if userID == "" {
if userID == "" && authReq != nil {
userID = authReq.UserID
}
translator := l.getTranslator(r.Context(), authReq)
data := mailVerificationData{
baseData: l.getBaseData(r, authReq, translator, "EmailVerification.Title", "EmailVerification.Description", errID, errMessage),
UserID: userID,
profileData: l.getProfileData(authReq),
baseData: l.getBaseData(r, authReq, translator, "EmailVerification.Title", "EmailVerification.Description", errID, errMessage),
UserID: userID,
profileData: l.getProfileData(authReq),
Code: code,
PasswordInit: passwordInit,
}
if passwordInit {
policy := l.getPasswordComplexityPolicyByUserID(r, userID)
if policy != nil {
data.MinLength = policy.MinLength
if policy.HasUppercase {
data.HasUppercase = UpperCaseRegex
}
if policy.HasLowercase {
data.HasLowercase = LowerCaseRegex
}
if policy.HasSymbol {
data.HasSymbol = SymbolRegex
}
if policy.HasNumber {
data.HasNumber = NumberRegex
}
}
}
if authReq == nil {
user, err := l.query.GetUserByID(r.Context(), false, userID)

View File

@@ -25,15 +25,7 @@ func (l *Login) handlePasswordReset(w http.ResponseWriter, r *http.Request) {
l.renderPasswordResetDone(w, r, authReq, err)
return
}
passwordCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypePasswordResetCode, l.userCodeAlg)
if err != nil {
if authReq.LoginPolicy.IgnoreUnknownUsernames && zerrors.IsNotFound(err) {
err = nil
}
l.renderPasswordResetDone(w, r, authReq, err)
return
}
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID)
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, authReq.ID)
l.renderPasswordResetDone(w, r, authReq, err)
}

View File

@@ -313,7 +313,7 @@ func (l *Login) chooseNextStep(w http.ResponseWriter, r *http.Request, authReq *
case *domain.ChangePasswordStep:
l.renderChangePassword(w, r, authReq, err)
case *domain.VerifyEMailStep:
l.renderMailVerification(w, r, authReq, "", err)
l.renderMailVerification(w, r, authReq, authReq.UserID, "", step.InitPassword, err)
case *domain.MFAPromptStep:
l.renderMFAPrompt(w, r, authReq, step, err)
case *domain.InitUserStep:

View File

@@ -89,7 +89,7 @@ InitUserDone:
InitMFAPrompt:
Title: 两步验证设置
Description: 两步验证为您的账户提供了额外的安全保障。这确保只有你能访问你的账户。
Provider0: 软件应用(如 Google/Migrosoft Authenticator、Authy
Provider0: 软件应用(如 Google/Microsoft Authenticator、Authy
Provider1: 硬件设备(如 Face ID、Windows Hello、指纹
Provider3: 一次性密码短信
Provider4: 一次性密码电子邮件

View File

@@ -17,13 +17,29 @@
<div class="fields">
<label class="lgn-label" for="code">{{t "EmailVerification.CodeLabel"}}</label>
<input class="lgn-input" type="text" id="code" name="code" autocomplete="off" autofocus required>
<input class="lgn-input" type="text" id="code" name="code" autocomplete="off" value="{{ .Code }}" {{if not .Code}}autofocus{{end}} required>
</div>
{{ if .PasswordInit }}
<div class="field">
<label class="lgn-label" for="password">{{t "InitUser.NewPasswordLabel"}}</label>
<input data-minlength="{{ .MinLength }}" data-has-uppercase="{{ .HasUppercase }}"
data-has-lowercase="{{ .HasLowercase }}" data-has-number="{{ .HasNumber }}"
data-has-symbol="{{ .HasSymbol }}" class="lgn-input" type="password" id="password" name="password"
autocomplete="new-password" autofocus required>
</div>
<div class="field">
<label class="lgn-label" for="passwordconfirm">{{t "InitUser.NewPasswordConfirm"}}</label>
<input class="lgn-input" type="password" id="passwordconfirm" name="passwordconfirm"
autocomplete="new-password" autofocus required>
{{ template "password-complexity-policy-description" . }}
</div>
{{ end }}
{{ template "error-message" .}}
<div class="lgn-actions lgn-reverse-order">
<button type="submit" id="submit-button" name="resend" value="false"
<button type="submit" id="{{if.PasswordInit}}init-button{{else}}submit-button{{end}}" name="resend" value="false"
class="lgn-primary lgn-raised-button">{{t "EmailVerification.NextButtonText"}}
</button>
@@ -40,6 +56,11 @@
</div>
</form>
<script src="{{ resourceUrl "scripts/form_submit.js" }}"></script>
{{ if .PasswordInit }}
<script src="{{ resourceUrl "scripts/password_policy_check.js" }}"></script>
<script src="{{ resourceUrl "scripts/init_password_check.js" }}"></script>
{{ else }}
<script src="{{ resourceUrl "scripts/default_form_validation.js" }}"></script>
{{ end }}
{{template "main-bottom" .}}

View File

@@ -52,6 +52,7 @@ type AuthRequestRepo struct {
ProjectProvider projectProvider
ApplicationProvider applicationProvider
CustomTextProvider customTextProvider
PasswordReset passwordReset
IdGenerator id.Generator
}
@@ -96,6 +97,7 @@ type idpUserLinksProvider interface {
type userEventProvider interface {
UserEventsByID(ctx context.Context, id string, changeDate time.Time, eventTypes []eventstore.EventType) ([]eventstore.Event, error)
PasswordCodeExists(ctx context.Context, userID string) (exists bool, err error)
}
type userCommandProvider interface {
@@ -125,6 +127,10 @@ type customTextProvider interface {
CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error)
}
type passwordReset interface {
RequestSetPassword(ctx context.Context, userID, resourceOwner string, notifyType domain.NotificationType, authRequestID string) (objectDetails *domain.ObjectDetails, err error)
}
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
return repo.AuthRequests.Health(ctx)
}
@@ -1046,7 +1052,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
}
}
if isInternalLogin || (!isInternalLogin && len(request.LinkingUsers) > 0) {
step := repo.firstFactorChecked(request, user, userSession)
step := repo.firstFactorChecked(ctx, request, user, userSession)
if step != nil {
return append(steps, step), nil
}
@@ -1065,7 +1071,9 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
steps = append(steps, &domain.ChangePasswordStep{Expired: expired})
}
if !user.IsEmailVerified {
steps = append(steps, &domain.VerifyEMailStep{})
steps = append(steps, &domain.VerifyEMailStep{
InitPassword: !user.PasswordSet,
})
}
if user.UsernameChangeRequired {
steps = append(steps, &domain.ChangeUsernameStep{})
@@ -1204,7 +1212,7 @@ func (repo *AuthRequestRepo) usersForUserSelection(ctx context.Context, request
return users, nil
}
func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, user *user_model.UserView, userSession *user_model.UserSessionView) domain.NextStep {
func (repo *AuthRequestRepo) firstFactorChecked(ctx context.Context, request *domain.AuthRequest, user *user_model.UserView, userSession *user_model.UserSessionView) domain.NextStep {
if user.InitRequired {
return &domain.InitUserStep{PasswordSet: user.PasswordSet}
}
@@ -1226,6 +1234,15 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use
}
if user.PasswordInitRequired {
if !user.IsEmailVerified {
return &domain.VerifyEMailStep{InitPassword: true}
}
exists, err := repo.UserEventProvider.PasswordCodeExists(ctx, user.ID)
logging.WithFields("userID", user.ID).OnError(err).Error("unable to check if password code exists")
if err == nil && !exists {
_, err = repo.PasswordReset.RequestSetPassword(ctx, user.ID, user.ResourceOwner, domain.NotificationTypeEmail, request.ID)
logging.WithFields("userID", user.ID).OnError(err).Error("unable to create password code")
}
return &domain.InitPasswordStep{}
}

View File

@@ -108,7 +108,8 @@ func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, er
}
type mockEventUser struct {
Event eventstore.Event
Event eventstore.Event
CodeExists bool
}
func (m *mockEventUser) UserEventsByID(ctx context.Context, id string, changeDate time.Time, types []eventstore.EventType) ([]eventstore.Event, error) {
@@ -118,6 +119,10 @@ func (m *mockEventUser) UserEventsByID(ctx context.Context, id string, changeDat
return nil, nil
}
func (m *mockEventUser) PasswordCodeExists(ctx context.Context, userID string) (bool, error) {
return m.CodeExists, nil
}
func (m *mockEventUser) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*query.CurrentState, error) {
return &query.CurrentState{State: query.State{Sequence: 0}}, nil
}
@@ -132,8 +137,8 @@ func (m *mockEventErrUser) UserEventsByID(ctx context.Context, id string, change
return nil, zerrors.ThrowInternal(nil, "id", "internal error")
}
func (m *mockEventErrUser) BulkAddExternalIDPs(ctx context.Context, userID string, externalIDPs []*user_model.ExternalIDP) error {
return zerrors.ThrowInternal(nil, "id", "internal error")
func (m *mockEventErrUser) PasswordCodeExists(ctx context.Context, userID string) (bool, error) {
return false, zerrors.ThrowInternal(nil, "id", "internal error")
}
type mockViewUser struct {
@@ -298,6 +303,28 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU
return &query.IDPUserLinks{Links: m.idps}, nil
}
type mockPasswordReset struct {
t *testing.T
expectCall bool
}
func newMockPasswordReset(expectCall bool) func(*testing.T) passwordReset {
return func(t *testing.T) passwordReset {
return &mockPasswordReset{
t: t,
expectCall: expectCall,
}
}
}
func (m *mockPasswordReset) RequestSetPassword(ctx context.Context, userID, resourceOwner string, notifyType domain.NotificationType, authRequestID string) (objectDetails *domain.ObjectDetails, err error) {
if !m.expectCall {
m.t.Error("unexpected call to RequestSetPassword")
return nil, nil
}
return nil, err
}
func TestAuthRequestRepo_nextSteps(t *testing.T) {
type fields struct {
AuthRequests cache.AuthRequestCache
@@ -316,6 +343,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
labelPolicyProvider labelPolicyProvider
passwordAgePolicyProvider passwordAgePolicyProvider
customTextProvider customTextProvider
passwordReset func(t *testing.T) passwordReset
}
type args struct {
request *domain.AuthRequest
@@ -687,7 +715,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
fields{
userViewProvider: &mockViewUser{},
userEventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserDeactivatedType,
},
@@ -709,7 +737,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
fields{
userViewProvider: &mockViewUser{},
userEventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserLockedType,
},
@@ -929,7 +957,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
nil,
},
{
"password not set, init password step",
"password not set (email not verified), init password step",
fields{
userSessionViewProvider: &mockViewUserSession{},
userViewProvider: &mockViewUser{
@@ -945,6 +973,54 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
idpUserLinksProvider: &mockIDPUserLinks{},
},
args{&domain.AuthRequest{UserID: "UserID", LoginPolicy: &domain.LoginPolicy{}}, false},
[]domain.NextStep{&domain.VerifyEMailStep{InitPassword: true}},
nil,
},
{
"password not set (email verified), init password step",
fields{
userSessionViewProvider: &mockViewUserSession{},
userViewProvider: &mockViewUser{
PasswordInitRequired: true,
IsEmailVerified: true,
},
userEventProvider: &mockEventUser{
CodeExists: true,
},
lockoutPolicyProvider: &mockLockoutPolicy{
policy: &query.LockoutPolicy{
ShowFailures: true,
},
},
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
idpUserLinksProvider: &mockIDPUserLinks{},
passwordReset: newMockPasswordReset(false),
},
args{&domain.AuthRequest{UserID: "UserID", LoginPolicy: &domain.LoginPolicy{}}, false},
[]domain.NextStep{&domain.InitPasswordStep{}},
nil,
},
{
"password not set (email verified, password code not exists), create code, init password step",
fields{
userSessionViewProvider: &mockViewUserSession{},
userViewProvider: &mockViewUser{
PasswordInitRequired: true,
IsEmailVerified: true,
},
userEventProvider: &mockEventUser{
CodeExists: false,
},
lockoutPolicyProvider: &mockLockoutPolicy{
policy: &query.LockoutPolicy{
ShowFailures: true,
},
},
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
idpUserLinksProvider: &mockIDPUserLinks{},
passwordReset: newMockPasswordReset(true),
},
args{&domain.AuthRequest{UserID: "UserID", LoginPolicy: &domain.LoginPolicy{}}, false},
[]domain.NextStep{&domain.InitPasswordStep{}},
nil,
},
@@ -1720,6 +1796,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
PasswordAgePolicyProvider: tt.fields.passwordAgePolicyProvider,
CustomTextProvider: tt.fields.customTextProvider,
}
if tt.fields.passwordReset != nil {
repo.PasswordReset = tt.fields.passwordReset(t)
}
got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
t.Errorf("nextSteps() wrong error = %v", err)
@@ -2201,7 +2280,7 @@ func Test_userSessionByIDs(t *testing.T) {
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
eventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserV1MFAOTPCheckSucceededType,
CreationDate: testNow,
@@ -2224,7 +2303,7 @@ func Test_userSessionByIDs(t *testing.T) {
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
eventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserV1MFAOTPCheckSucceededType,
CreationDate: testNow,
@@ -2251,7 +2330,7 @@ func Test_userSessionByIDs(t *testing.T) {
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
eventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserV1MFAOTPCheckSucceededType,
CreationDate: testNow,
@@ -2278,7 +2357,7 @@ func Test_userSessionByIDs(t *testing.T) {
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
eventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserRemovedType,
},
@@ -2367,7 +2446,7 @@ func Test_userByID(t *testing.T) {
PasswordChangeRequired: true,
},
eventProvider: &mockEventUser{
&es_models.Event{
Event: &es_models.Event{
AggregateType: user_repo.AggregateType,
Typ: user_repo.UserV1PasswordChangedType,
CreationDate: testNow,

View File

@@ -10,7 +10,9 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/user"
usr_view "github.com/zitadel/zitadel/internal/user/repository/view"
"github.com/zitadel/zitadel/internal/zerrors"
)
type UserRepo struct {
@@ -46,3 +48,40 @@ func (repo *UserRepo) UserEventsByID(ctx context.Context, id string, changeDate
}
return repo.Eventstore.Filter(ctx, query) //nolint:staticcheck
}
type passwordCodeCheck struct {
userID string
exists bool
events int
}
func (p *passwordCodeCheck) Reduce() error {
p.exists = p.events > 0
return nil
}
func (p *passwordCodeCheck) AppendEvents(events ...eventstore.Event) {
p.events += len(events)
}
func (p *passwordCodeCheck) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(user.AggregateType).
AggregateIDs(p.userID).
EventTypes(user.UserV1PasswordCodeAddedType, user.UserV1PasswordCodeSentType,
user.HumanPasswordCodeAddedType, user.HumanPasswordCodeSentType).
Builder()
}
func (repo *UserRepo) PasswordCodeExists(ctx context.Context, userID string) (exists bool, err error) {
model := &passwordCodeCheck{
userID: userID,
}
err = repo.Eventstore.FilterToQueryReducer(ctx, model)
if err != nil {
return false, zerrors.ThrowPermissionDenied(err, "EVENT-SJ642", "Errors.Internal")
}
return model.exists, nil
}

View File

@@ -78,6 +78,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
ProjectProvider: queryView,
ApplicationProvider: queries,
CustomTextProvider: queries,
PasswordReset: command,
IdGenerator: id.SonyFlakeGenerator(),
},
eventstore.TokenRepo{

View File

@@ -64,7 +64,7 @@ func (c *Commands) ChangeHumanEmail(ctx context.Context, email *domain.Email, em
return writeModelToEmail(existingEmail), nil
}
func (c *Commands) VerifyHumanEmail(ctx context.Context, userID, code, resourceowner string, emailCodeGenerator crypto.Generator) (*domain.ObjectDetails, error) {
func (c *Commands) VerifyHumanEmail(ctx context.Context, userID, code, resourceowner, optionalPassword, optionalUserAgentID string, emailCodeGenerator crypto.Generator) (*domain.ObjectDetails, error) {
if userID == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-4M0ds", "Errors.User.UserIDMissing")
}
@@ -82,21 +82,30 @@ func (c *Commands) VerifyHumanEmail(ctx context.Context, userID, code, resourceo
userAgg := UserAggregateFromWriteModel(&existingCode.WriteModel)
err = crypto.VerifyCode(existingCode.CodeCreationDate, existingCode.CodeExpiry, existingCode.Code, code, emailCodeGenerator.Alg())
if err == nil {
pushedEvents, err := c.eventstore.Push(ctx, user.NewHumanEmailVerifiedEvent(ctx, userAgg))
if err != nil {
return nil, err
}
err = AppendAndReduce(existingCode, pushedEvents...)
if err != nil {
return nil, err
}
return writeModelToObjectDetails(&existingCode.WriteModel), nil
if err != nil {
_, err = c.eventstore.Push(ctx, user.NewHumanEmailVerificationFailedEvent(ctx, userAgg))
logging.WithFields("userID", userAgg.ID).OnError(err).Error("NewHumanEmailVerificationFailedEvent push failed")
return nil, zerrors.ThrowInvalidArgument(err, "COMMAND-Gdsgs", "Errors.User.Code.Invalid")
}
_, err = c.eventstore.Push(ctx, user.NewHumanEmailVerificationFailedEvent(ctx, userAgg))
logging.LogWithFields("COMMAND-Dg2z5", "userID", userAgg.ID).OnError(err).Error("NewHumanEmailVerificationFailedEvent push failed")
return nil, zerrors.ThrowInvalidArgument(err, "COMMAND-Gdsgs", "Errors.User.Code.Invalid")
commands := []eventstore.Command{
user.NewHumanEmailVerifiedEvent(ctx, userAgg),
}
if optionalPassword != "" {
passwordCommand, err := c.setPasswordCommand(ctx, userAgg, domain.UserStateActive, optionalPassword, "", optionalUserAgentID, false, nil)
if err != nil {
return nil, err
}
commands = append(commands, passwordCommand)
}
pushedEvents, err := c.eventstore.Push(ctx, commands...)
if err != nil {
return nil, err
}
err = AppendAndReduce(existingCode, pushedEvents...)
if err != nil {
return nil, err
}
return writeModelToObjectDetails(&existingCode.WriteModel), nil
}
func (c *Commands) CreateHumanEmailVerificationCode(ctx context.Context, userID, resourceOwner string, emailCodeGenerator crypto.Generator, authRequestID string) (*domain.ObjectDetails, error) {

View File

@@ -12,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/repository/org"
"github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -387,14 +388,17 @@ func TestCommandSide_ChangeHumanEmail(t *testing.T) {
func TestCommandSide_VerifyHumanEmail(t *testing.T) {
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
eventstore func(*testing.T) *eventstore.Eventstore
userPasswordHasher *crypto.Hasher
}
type args struct {
ctx context.Context
userID string
code string
resourceOwner string
secretGenerator crypto.Generator
ctx context.Context
userID string
code string
resourceOwner string
optionalUserAgentID string
optionalPassword string
secretGenerator crypto.Generator
}
type res struct {
want *domain.ObjectDetails
@@ -587,13 +591,96 @@ func TestCommandSide_VerifyHumanEmail(t *testing.T) {
},
},
},
{
name: "valid code (with password and user agent), ok",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanAddedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.German,
domain.GenderUnspecified,
"email@test.ch",
true,
),
),
eventFromEventPusherWithCreationDateNow(
user.NewHumanEmailCodeAddedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("a"),
},
time.Hour*1,
"",
),
),
),
expectFilter(
eventFromEventPusher(
org.NewPasswordComplexityPolicyAddedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
1,
false,
false,
false,
false,
),
),
),
expectPush(
user.NewHumanEmailVerifiedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
),
user.NewHumanPasswordChangedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
"$plain$x$password",
false,
"userAgentID",
),
),
),
userPasswordHasher: mockPasswordHasher("x"),
},
args: args{
ctx: context.Background(),
userID: "user1",
code: "a",
resourceOwner: "org1",
optionalPassword: "password",
optionalUserAgentID: "userAgentID",
secretGenerator: GetMockSecretGenerator(t),
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
eventstore: tt.fields.eventstore(t),
userPasswordHasher: tt.fields.userPasswordHasher,
}
got, err := r.VerifyHumanEmail(tt.args.ctx, tt.args.userID, tt.args.code, tt.args.resourceOwner, tt.args.secretGenerator)
got, err := r.VerifyHumanEmail(
tt.args.ctx,
tt.args.userID,
tt.args.code,
tt.args.resourceOwner,
tt.args.optionalPassword,
tt.args.optionalUserAgentID,
tt.args.secretGenerator,
)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@@ -239,6 +239,11 @@ func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner st
if existingOTP.State == domain.MFAStateUnspecified || existingOTP.State == domain.MFAStateRemoved {
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hd9sd", "Errors.User.MFA.OTP.NotExisting")
}
if userID != authz.GetCtxData(ctx).UserID {
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingOTP.ResourceOwner, userID); err != nil {
return nil, err
}
}
userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel)
pushedEvents, err := c.eventstore.Push(ctx, user.NewHumanOTPRemovedEvent(ctx, userAgg))
if err != nil {

View File

@@ -841,7 +841,8 @@ func TestCommands_HumanCheckMFATOTPSetup(t *testing.T) {
func TestCommandSide_RemoveHumanTOTP(t *testing.T) {
type fields struct {
eventstore func(t *testing.T) *eventstore.Eventstore
eventstore func(t *testing.T) *eventstore.Eventstore
checkPermission domain.PermissionCheck
}
type (
args struct {
@@ -891,7 +892,31 @@ func TestCommandSide_RemoveHumanTOTP(t *testing.T) {
},
},
{
name: "otp not existing, not found error",
name: "otp, no permission error",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
user.NewHumanOTPAddedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
nil,
),
),
),
),
checkPermission: newMockPermissionCheckNotAllowed(),
},
args: args{
ctx: context.Background(),
orgID: "org1",
userID: "user1",
},
res: res{
err: zerrors.IsPermissionDenied,
},
},
{
name: "otp remove, ok",
fields: fields{
eventstore: expectEventstore(
expectFilter(
@@ -908,6 +933,7 @@ func TestCommandSide_RemoveHumanTOTP(t *testing.T) {
),
),
),
checkPermission: newMockPermissionCheckAllowed(),
},
args: args{
ctx: context.Background(),
@@ -924,7 +950,8 @@ func TestCommandSide_RemoveHumanTOTP(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
eventstore: tt.fields.eventstore(t),
checkPermission: tt.fields.checkPermission,
}
got, err := r.HumanRemoveTOTP(tt.args.ctx, tt.args.userID, tt.args.orgID)
if tt.res.err == nil {

View File

@@ -228,7 +228,7 @@ func (c *Commands) checkPasswordComplexity(ctx context.Context, newPassword stri
}
// RequestSetPassword generate and send out new code to change password for a specific user
func (c *Commands) RequestSetPassword(ctx context.Context, userID, resourceOwner string, notifyType domain.NotificationType, passwordVerificationCode crypto.Generator, authRequestID string) (objectDetails *domain.ObjectDetails, err error) {
func (c *Commands) RequestSetPassword(ctx context.Context, userID, resourceOwner string, notifyType domain.NotificationType, authRequestID string) (objectDetails *domain.ObjectDetails, err error) {
if userID == "" {
return nil, zerrors.ThrowInvalidArgument(nil, "COMMAND-M00oL", "Errors.User.UserIDMissing")
}
@@ -244,11 +244,11 @@ func (c *Commands) RequestSetPassword(ctx context.Context, userID, resourceOwner
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-2M9sd", "Errors.User.NotInitialised")
}
userAgg := UserAggregateFromWriteModel(&existingHuman.WriteModel)
passwordCode, err := domain.NewPasswordCode(passwordVerificationCode)
passwordCode, err := c.newEncryptedCode(ctx, c.eventstore.Filter, domain.SecretGeneratorTypePasswordResetCode, c.userEncryption) //nolint:staticcheck
if err != nil {
return nil, err
}
pushedEvents, err := c.eventstore.Push(ctx, user.NewHumanPasswordCodeAddedEvent(ctx, userAgg, passwordCode.Code, passwordCode.Expiry, notifyType, authRequestID))
pushedEvents, err := c.eventstore.Push(ctx, user.NewHumanPasswordCodeAddedEvent(ctx, userAgg, passwordCode.Crypted, passwordCode.Expiry, notifyType, authRequestID))
if err != nil {
return nil, err
}

View File

@@ -1111,14 +1111,14 @@ func TestCommandSide_ChangePassword(t *testing.T) {
func TestCommandSide_RequestSetPassword(t *testing.T) {
type fields struct {
eventstore func(*testing.T) *eventstore.Eventstore
newCode encrypedCodeFunc
}
type args struct {
ctx context.Context
userID string
resourceOwner string
notifyType domain.NotificationType
secretGenerator crypto.Generator
authRequestID string
ctx context.Context
userID string
resourceOwner string
notifyType domain.NotificationType
authRequestID string
}
type res struct {
want *domain.ObjectDetails
@@ -1251,12 +1251,12 @@ func TestCommandSide_RequestSetPassword(t *testing.T) {
),
),
),
newCode: mockEncryptedCode("a", 1*time.Hour),
},
args: args{
ctx: context.Background(),
userID: "user1",
resourceOwner: "org1",
secretGenerator: GetMockSecretGenerator(t),
ctx: context.Background(),
userID: "user1",
resourceOwner: "org1",
},
res: res{
want: &domain.ObjectDetails{
@@ -1307,13 +1307,13 @@ func TestCommandSide_RequestSetPassword(t *testing.T) {
),
),
),
newCode: mockEncryptedCode("a", 1*time.Hour),
},
args: args{
ctx: context.Background(),
userID: "user1",
resourceOwner: "org1",
secretGenerator: GetMockSecretGenerator(t),
authRequestID: "authRequestID",
ctx: context.Background(),
userID: "user1",
resourceOwner: "org1",
authRequestID: "authRequestID",
},
res: res{
want: &domain.ObjectDetails{
@@ -1325,9 +1325,10 @@ func TestCommandSide_RequestSetPassword(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore(t),
eventstore: tt.fields.eventstore(t),
newEncryptedCode: tt.fields.newCode,
}
got, err := r.RequestSetPassword(tt.args.ctx, tt.args.userID, tt.args.resourceOwner, tt.args.notifyType, tt.args.secretGenerator, tt.args.authRequestID)
got, err := r.RequestSetPassword(tt.args.ctx, tt.args.userID, tt.args.resourceOwner, tt.args.notifyType, tt.args.authRequestID)
if tt.res.err == nil {
assert.NoError(t, err)
}

View File

@@ -148,6 +148,52 @@ func TestCommandSide_AddMachine(t *testing.T) {
},
},
},
{
name: "add machine - custom id, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewDomainPolicyAddedEvent(context.Background(),
&user.NewAggregate("optionalID1", "org1").Aggregate,
true,
true,
true,
),
),
),
expectPush(
user.NewMachineAddedEvent(context.Background(),
&user.NewAggregate("optionalID1", "org1").Aggregate,
"username",
"name",
"description",
true,
domain.OIDCTokenTypeBearer,
),
),
),
},
args: args{
ctx: context.Background(),
machine: &Machine{
ObjectRoot: models.ObjectRoot{
AggregateID: "optionalID1",
ResourceOwner: "org1",
},
Description: "description",
Name: "name",
Username: "username",
},
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -319,7 +319,7 @@ func TestCommandSide_AddUserHuman(t *testing.T) {
},
},
{
name: "add human (with initial code), ok",
name: "add human (email not verified, no password), ok (init code)",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
@@ -389,7 +389,7 @@ func TestCommandSide_AddUserHuman(t *testing.T) {
},
},
{
name: "add human (with password and initial code), ok",
name: "add human (email not verified, with password), ok (init code)",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
@@ -459,6 +459,65 @@ func TestCommandSide_AddUserHuman(t *testing.T) {
wantID: "user1",
},
},
{
name: "add human (email not verified, no password, no allowInitMail), ok (email verification with passwordInit)",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
expectFilter(
eventFromEventPusher(
org.NewDomainPolicyAddedEvent(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
true,
true,
true,
),
),
),
expectPush(
newAddHumanEvent("", false, true, "", language.English),
user.NewHumanEmailCodeAddedEventV2(context.Background(),
&user.NewAggregate("user1", "org1").Aggregate,
&crypto.CryptoValue{
CryptoType: crypto.TypeEncryption,
Algorithm: "enc",
KeyID: "id",
Crypted: []byte("emailverify"),
},
1*time.Hour,
"",
false,
"",
),
),
),
checkPermission: newMockPermissionCheckAllowed(),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "user1"),
newCode: mockEncryptedCode("emailverify", time.Hour),
},
args: args{
ctx: context.Background(),
orgID: "org1",
human: &AddHuman{
Username: "username",
FirstName: "firstname",
LastName: "lastname",
Email: Email{
Address: "email@test.ch",
},
PreferredLanguage: language.English,
},
secretGenerator: GetMockSecretGenerator(t),
allowInitMail: false,
codeAlg: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
res: res{
want: &domain.ObjectDetails{
ResourceOwner: "org1",
},
wantID: "user1",
},
},
{
name: "add human (with password and email code custom template), ok",
fields: fields{
@@ -609,7 +668,7 @@ func TestCommandSide_AddUserHuman(t *testing.T) {
},
},
{
name: "add human email verified, ok",
name: "add human email verified and password, ok",
fields: fields{
eventstore: expectEventstore(
expectFilter(),
@@ -1084,8 +1143,9 @@ func TestCommandSide_AddUserHuman(t *testing.T) {
},
wantID: "user1",
},
}, {
name: "add human (with return code), ok",
},
{
name: "add human (with phone return code), ok",
fields: fields{
eventstore: expectEventstore(
expectFilter(),

View File

@@ -140,6 +140,29 @@ func (c *Commands) verifyUserPhoneWithGenerator(ctx context.Context, userID, cod
return writeModelToObjectDetails(&cmd.model.WriteModel), nil
}
func (c *Commands) RemoveUserPhone(ctx context.Context, userID string) (*domain.ObjectDetails, error) {
return c.removeUserPhone(ctx, userID)
}
func (c *Commands) removeUserPhone(ctx context.Context, userID string) (*domain.ObjectDetails, error) {
cmd, err := c.NewUserPhoneEvents(ctx, userID)
if err != nil {
return nil, err
}
if authz.GetCtxData(ctx).UserID != userID {
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
return nil, err
}
}
if err = cmd.Remove(ctx); err != nil {
return nil, err
}
if _, err = cmd.Push(ctx); err != nil {
return nil, err
}
return writeModelToObjectDetails(&cmd.model.WriteModel), nil
}
// UserPhoneEvents allows step-by-step additions of events,
// operating on the Human Phone Model.
type UserPhoneEvents struct {
@@ -191,6 +214,14 @@ func (c *UserPhoneEvents) Change(ctx context.Context, phone domain.PhoneNumber)
return nil
}
func (c *UserPhoneEvents) Remove(ctx context.Context) error {
if c.model.State == domain.PhoneStateRemoved || c.model.State == domain.PhoneStateUnspecified {
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-ieJ2e", "Errors.User.Phone.NotFound")
}
c.events = append(c.events, user.NewHumanPhoneRemovedEvent(ctx, c.aggregate))
return nil
}
// SetVerified sets the phone number to verified.
func (c *UserPhoneEvents) SetVerified(ctx context.Context) {
c.events = append(c.events, user.NewHumanPhoneVerifiedEvent(ctx, c.aggregate))

View File

@@ -137,7 +137,9 @@ func (s *ChangeUsernameStep) Type() NextStepType {
return NextStepChangeUsername
}
type VerifyEMailStep struct{}
type VerifyEMailStep struct {
InitPassword bool
}
func (s *VerifyEMailStep) Type() NextStepType {
return NextStepVerifyEmail

View File

@@ -89,6 +89,22 @@ func isPrivateIPv4(ip net.IP) bool {
(ip[0] == 10 || ip[0] == 172 && (ip[1] >= 16 && ip[1] < 32) || ip[0] == 192 && ip[1] == 168)
}
func MachineIdentificationMethod() string {
if GeneratorConfig.Identification.PrivateIp.Enabled {
return "Private Ip"
}
if GeneratorConfig.Identification.Hostname.Enabled {
return "Hostname"
}
if GeneratorConfig.Identification.Webhook.Enabled {
return "Webhook"
}
return "No machine identification method has been enabled"
}
func machineID() (uint16, error) {
if GeneratorConfig == nil {
logging.Panic("cannot create a unique id for the machine, generator has not been configured")

View File

@@ -17,7 +17,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
@@ -154,6 +153,30 @@ func (s *Tester) CreateHumanUser(ctx context.Context) *user.AddHumanUserResponse
return resp
}
func (s *Tester) CreateHumanUserNoPhone(ctx context.Context) *user.AddHumanUserResponse {
resp, err := s.Client.UserV2.AddHumanUser(ctx, &user.AddHumanUserRequest{
Organization: &object.Organization{
Org: &object.Organization_OrgId{
OrgId: s.Organisation.ID,
},
},
Profile: &user.SetHumanProfile{
GivenName: "Mickey",
FamilyName: "Mouse",
PreferredLanguage: gu.Ptr("nl"),
Gender: gu.Ptr(user.Gender_GENDER_MALE),
},
Email: &user.SetHumanEmail{
Email: fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()),
Verification: &user.SetHumanEmail_ReturnCode{
ReturnCode: &user.ReturnEmailVerificationCode{},
},
},
})
logging.OnError(err).Fatal("create human user")
return resp
}
func (s *Tester) CreateHumanUserWithTOTP(ctx context.Context, secret string) *user.AddHumanUserResponse {
resp, err := s.Client.UserV2.AddHumanUser(ctx, &user.AddHumanUserRequest{
Organization: &object.Organization{
@@ -312,7 +335,7 @@ func (s *Tester) RegisterUserU2F(ctx context.Context, userID string) {
logging.OnError(err).Fatal("create user u2f")
}
func (s *Tester) SetUserPassword(ctx context.Context, userID, password string, changeRequired bool) *timestamppb.Timestamp {
func (s *Tester) SetUserPassword(ctx context.Context, userID, password string, changeRequired bool) *object.Details {
resp, err := s.Client.UserV2.SetPassword(ctx, &user.SetPasswordRequest{
UserId: userID,
NewPassword: &user.Password{
@@ -321,7 +344,7 @@ func (s *Tester) SetUserPassword(ctx context.Context, userID, password string, c
},
})
logging.OnError(err).Fatal("set user password")
return resp.GetDetails().GetChangeDate()
return resp.GetDetails()
}
func (s *Tester) AddGenericOAuthProvider(t *testing.T, ctx context.Context) string {

View File

@@ -281,6 +281,10 @@ func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries) (or
return orgs, err
}
func NewOrgIDSearchQuery(value string) (SearchQuery, error) {
return NewTextQuery(OrgColumnID, value, TextEquals)
}
func NewOrgDomainSearchQuery(method TextComparison, value string) (SearchQuery, error) {
return NewTextQuery(OrgColumnDomain, value, method)
}

View File

@@ -146,7 +146,7 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe
return userAuthMethods, err
}
func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID string) (userAuthMethodTypes *AuthMethodTypes, err error) {
func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, activeOnly bool) (userAuthMethodTypes *AuthMethodTypes, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
@@ -156,7 +156,7 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareActiveUserAuthMethodTypesQuery(ctx, q.client)
query, scan := prepareUserAuthMethodTypesQuery(ctx, q.client, activeOnly)
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
@@ -353,8 +353,8 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se
}
}
func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery()
func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, activeOnly bool) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery(activeOnly)
if err != nil {
return sq.SelectBuilder{}, nil
}
@@ -468,14 +468,16 @@ func prepareAuthMethodsIDPsQuery() (string, error) {
return idpsQuery, err
}
func prepareAuthMethodQuery() (string, []interface{}, error) {
return sq.Select(
func prepareAuthMethodQuery(activeOnly bool) (string, []interface{}, error) {
q := sq.Select(
"DISTINCT("+authMethodTypeType.identifier()+")",
authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()).
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
ToSql()
From(authMethodTypeTable.identifier())
if activeOnly {
q = q.Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady})
}
return q.ToSql()
}
func prepareAuthMethodsForceMFAQuery() (string, error) {

View File

@@ -217,8 +217,13 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
object: (*AuthMethodTypes)(nil),
},
{
name: "prepareActiveUserAuthMethodTypesQuery no result",
prepare: prepareActiveUserAuthMethodTypesQuery,
name: "prepareUserAuthMethodTypesQuery no result",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true)
return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) {
return scan(rows)
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
@@ -229,8 +234,13 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
object: &AuthMethodTypes{AuthMethodTypes: []domain.UserAuthMethodType{}},
},
{
name: "prepareActiveUserAuthMethodTypesQuery one second factor",
prepare: prepareActiveUserAuthMethodTypesQuery,
name: "prepareUserAuthMethodTypesQuery one second factor",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true)
return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) {
return scan(rows)
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
@@ -256,8 +266,13 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
},
},
{
name: "prepareActiveUserAuthMethodTypesQuery multiple second factors",
prepare: prepareActiveUserAuthMethodTypesQuery,
name: "prepareUserAuthMethodTypesQuery multiple second factors",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true)
return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) {
return scan(rows)
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),
@@ -289,8 +304,13 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
},
},
{
name: "prepareActiveUserAuthMethodTypesQuery sql err",
prepare: prepareActiveUserAuthMethodTypesQuery,
name: "prepareUserAuthMethodTypesQuery sql err",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) {
builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true)
return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) {
return scan(rows)
}
},
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(prepareActiveAuthMethodTypesStmt),