From 5c6df06a7c3a23bcdaa47c551901d9d17fe6dfea Mon Sep 17 00:00:00 2001 From: Fabi <38692350+fgerschwiler@users.noreply.github.com> Date: Thu, 20 Jan 2022 13:21:59 +0100 Subject: [PATCH] feat: auth method query side (#3068) * feat: queries for searching mfas and passwordless * feat: tests for user auth method queries * Update internal/api/grpc/auth/multi_factor.go Co-authored-by: Livio Amstutz * Update internal/api/grpc/auth/passwordless.go Co-authored-by: Livio Amstutz * Update internal/api/grpc/management/user.go Co-authored-by: Livio Amstutz * Update internal/api/grpc/management/user.go Co-authored-by: Livio Amstutz Co-authored-by: Livio Amstutz --- internal/api/grpc/auth/multi_factor.go | 15 +- internal/api/grpc/auth/passwordless.go | 14 +- internal/api/grpc/management/user.go | 26 +- internal/api/grpc/user/converter.go | 40 +-- .../eventsourcing/eventstore/user.go | 29 -- internal/management/repository/user.go | 4 - internal/query/projection/user_auth_method.go | 6 +- internal/query/user_auth_method.go | 274 +++++++++++++++ internal/query/user_auth_method_test.go | 331 ++++++++++++++++++ 9 files changed, 672 insertions(+), 67 deletions(-) create mode 100644 internal/query/user_auth_method.go create mode 100644 internal/query/user_auth_method_test.go diff --git a/internal/api/grpc/auth/multi_factor.go b/internal/api/grpc/auth/multi_factor.go index 6ceeaa2f8f..dfce63c37c 100644 --- a/internal/api/grpc/auth/multi_factor.go +++ b/internal/api/grpc/auth/multi_factor.go @@ -6,17 +6,28 @@ import ( "github.com/caos/zitadel/internal/api/authz" "github.com/caos/zitadel/internal/api/grpc/object" user_grpc "github.com/caos/zitadel/internal/api/grpc/user" + "github.com/caos/zitadel/internal/domain" + "github.com/caos/zitadel/internal/query" auth_pb "github.com/caos/zitadel/pkg/grpc/auth" user_pb "github.com/caos/zitadel/pkg/grpc/user" ) func (s *Server) ListMyAuthFactors(ctx context.Context, _ *auth_pb.ListMyAuthFactorsRequest) (*auth_pb.ListMyAuthFactorsResponse, error) { - mfas, err := s.repo.MyUserMFAs(ctx) + query := new(query.UserAuthMethodSearchQueries) + err := query.AppendUserIDQuery(authz.GetCtxData(ctx).UserID) + if err != nil { + return nil, err + } + err = query.AppendAuthMethodsQuery(domain.UserAuthMethodTypeU2F, domain.UserAuthMethodTypeOTP) + if err != nil { + return nil, err + } + authMethods, err := s.query.SearchUserAuthMethods(ctx, query) if err != nil { return nil, err } return &auth_pb.ListMyAuthFactorsResponse{ - Result: user_grpc.AuthFactorsToPb(mfas), + Result: user_grpc.AuthMethodsToPb(authMethods), }, nil } diff --git a/internal/api/grpc/auth/passwordless.go b/internal/api/grpc/auth/passwordless.go index 543a5a4f25..915999539f 100644 --- a/internal/api/grpc/auth/passwordless.go +++ b/internal/api/grpc/auth/passwordless.go @@ -3,6 +3,7 @@ package auth import ( "context" + "github.com/caos/zitadel/internal/query" "google.golang.org/protobuf/types/known/durationpb" "github.com/caos/zitadel/internal/api/authz" @@ -14,12 +15,21 @@ import ( ) func (s *Server) ListMyPasswordless(ctx context.Context, _ *auth_pb.ListMyPasswordlessRequest) (*auth_pb.ListMyPasswordlessResponse, error) { - tokens, err := s.repo.GetMyPasswordless(ctx) + query := new(query.UserAuthMethodSearchQueries) + err := query.AppendUserIDQuery(authz.GetCtxData(ctx).UserID) + if err != nil { + return nil, err + } + err = query.AppendAuthMethodQuery(domain.UserAuthMethodTypePasswordless) + if err != nil { + return nil, err + } + authMethods, err := s.query.SearchUserAuthMethods(ctx, query) if err != nil { return nil, err } return &auth_pb.ListMyPasswordlessResponse{ - Result: user_grpc.WebAuthNTokensViewToPb(tokens), + Result: user_grpc.UserAuthMethodsToWebAuthNTokenPb(authMethods), }, nil } diff --git a/internal/api/grpc/management/user.go b/internal/api/grpc/management/user.go index a7112e12b3..9f04a8cda3 100644 --- a/internal/api/grpc/management/user.go +++ b/internal/api/grpc/management/user.go @@ -473,12 +473,21 @@ func (s *Server) SendHumanResetPasswordNotification(ctx context.Context, req *mg } func (s *Server) ListHumanAuthFactors(ctx context.Context, req *mgmt_pb.ListHumanAuthFactorsRequest) (*mgmt_pb.ListHumanAuthFactorsResponse, error) { - mfas, err := s.user.UserMFAs(ctx, req.UserId) + query := new(query.UserAuthMethodSearchQueries) + err := query.AppendUserIDQuery(req.UserId) + if err != nil { + return nil, err + } + err = query.AppendAuthMethodsQuery(domain.UserAuthMethodTypeU2F, domain.UserAuthMethodTypeOTP) + if err != nil { + return nil, err + } + authMethods, err := s.query.SearchUserAuthMethods(ctx, query) if err != nil { return nil, err } return &mgmt_pb.ListHumanAuthFactorsResponse{ - Result: user_grpc.AuthFactorsToPb(mfas), + Result: user_grpc.AuthMethodsToPb(authMethods), }, nil } @@ -503,12 +512,21 @@ func (s *Server) RemoveHumanAuthFactorU2F(ctx context.Context, req *mgmt_pb.Remo } func (s *Server) ListHumanPasswordless(ctx context.Context, req *mgmt_pb.ListHumanPasswordlessRequest) (*mgmt_pb.ListHumanPasswordlessResponse, error) { - tokens, err := s.user.GetPasswordless(ctx, req.UserId) + query := new(query.UserAuthMethodSearchQueries) + err := query.AppendUserIDQuery(req.UserId) +if err != nil { + return nil, err + } + err = query.AppendAuthMethodQuery(domain.UserAuthMethodTypePasswordless) + if err != nil { + return nil, err + } + authMethods, err := s.query.SearchUserAuthMethods(ctx, query) if err != nil { return nil, err } return &mgmt_pb.ListHumanPasswordlessResponse{ - Result: user_grpc.WebAuthNTokensViewToPb(tokens), + Result: user_grpc.UserAuthMethodsToWebAuthNTokenPb(authMethods), }, nil } diff --git a/internal/api/grpc/user/converter.go b/internal/api/grpc/user/converter.go index a40c78f16b..e4ea7684ca 100644 --- a/internal/api/grpc/user/converter.go +++ b/internal/api/grpc/user/converter.go @@ -173,54 +173,54 @@ func GenderToPb(gender model.Gender) user_pb.Gender { } } -func AuthFactorsToPb(mfas []*model.MultiFactor) []*user_pb.AuthFactor { - factors := make([]*user_pb.AuthFactor, len(mfas)) - for i, mfa := range mfas { - factors[i] = AuthFactorToPb(mfa) +func AuthMethodsToPb(mfas *query.AuthMethods) []*user_pb.AuthFactor { + factors := make([]*user_pb.AuthFactor, len(mfas.AuthMethods)) + for i, mfa := range mfas.AuthMethods { + factors[i] = AuthMethodToPb(mfa) } return factors } -func AuthFactorToPb(mfa *model.MultiFactor) *user_pb.AuthFactor { +func AuthMethodToPb(mfa *query.AuthMethod) *user_pb.AuthFactor { factor := &user_pb.AuthFactor{ State: MFAStateToPb(mfa.State), } switch mfa.Type { - case model.MFATypeOTP: + case domain.UserAuthMethodTypeOTP: factor.Type = &user_pb.AuthFactor_Otp{ Otp: &user_pb.AuthFactorOTP{}, } - case model.MFATypeU2F: + case domain.UserAuthMethodTypeU2F: factor.Type = &user_pb.AuthFactor_U2F{ U2F: &user_pb.AuthFactorU2F{ - Id: mfa.ID, - Name: mfa.Attribute, + Id: mfa.TokenID, + Name: mfa.Name, }, } } return factor } -func MFAStateToPb(state model.MFAState) user_pb.AuthFactorState { +func MFAStateToPb(state domain.MFAState) user_pb.AuthFactorState { switch state { - case model.MFAStateNotReady: + case domain.MFAStateNotReady: return user_pb.AuthFactorState_AUTH_FACTOR_STATE_NOT_READY - case model.MFAStateReady: + case domain.MFAStateReady: return user_pb.AuthFactorState_AUTH_FACTOR_STATE_READY default: return user_pb.AuthFactorState_AUTH_FACTOR_STATE_UNSPECIFIED } } -func WebAuthNTokensViewToPb(tokens []*model.WebAuthNView) []*user_pb.WebAuthNToken { - t := make([]*user_pb.WebAuthNToken, len(tokens)) - for i, token := range tokens { - t[i] = WebAuthNTokenViewToPb(token) +func UserAuthMethodsToWebAuthNTokenPb(methods *query.AuthMethods) []*user_pb.WebAuthNToken { + t := make([]*user_pb.WebAuthNToken, len(methods.AuthMethods)) + for i, token := range methods.AuthMethods { + t[i] = UserAuthMethodToWebAuthNTokenPb(token) } return t } -func WebAuthNTokenViewToPb(token *model.WebAuthNView) *user_pb.WebAuthNToken { +func UserAuthMethodToWebAuthNTokenPb(token *query.AuthMethod) *user_pb.WebAuthNToken { return &user_pb.WebAuthNToken{ Id: token.TokenID, State: MFAStateToPb(token.State), @@ -228,12 +228,6 @@ func WebAuthNTokenViewToPb(token *model.WebAuthNView) *user_pb.WebAuthNToken { } } -func WebAuthNTokenToWebAuthNKeyPb(token *domain.WebAuthNToken) *user_pb.WebAuthNKey { - return &user_pb.WebAuthNKey{ - PublicKey: token.PublicKey, - } -} - func ExternalIDPViewsToExternalIDPs(externalIDPs []*query.IDPUserLink) []*domain.UserIDPLink { idps := make([]*domain.UserIDPLink, len(externalIDPs)) for i, idp := range externalIDPs { diff --git a/internal/management/repository/eventsourcing/eventstore/user.go b/internal/management/repository/eventsourcing/eventstore/user.go index 3f78122ada..0094beaebc 100644 --- a/internal/management/repository/eventsourcing/eventstore/user.go +++ b/internal/management/repository/eventsourcing/eventstore/user.go @@ -187,35 +187,6 @@ func (repo *UserRepo) SearchMetadata(ctx context.Context, userID, resourceOwner return result, nil } -func (repo *UserRepo) UserMFAs(ctx context.Context, userID string) ([]*usr_model.MultiFactor, error) { - user, err := repo.UserByID(ctx, userID) - if err != nil { - return nil, err - } - if user.HumanView == nil { - return nil, errors.ThrowPreconditionFailed(nil, "EVENT-xx0hV", "Errors.User.NotHuman") - } - mfas := make([]*usr_model.MultiFactor, 0) - if user.OTPState != usr_model.MFAStateUnspecified { - mfas = append(mfas, &usr_model.MultiFactor{Type: usr_model.MFATypeOTP, State: user.OTPState}) - } - for _, u2f := range user.U2FTokens { - mfas = append(mfas, &usr_model.MultiFactor{Type: usr_model.MFATypeU2F, State: u2f.State, Attribute: u2f.Name, ID: u2f.TokenID}) - } - return mfas, nil -} - -func (repo *UserRepo) GetPasswordless(ctx context.Context, userID string) ([]*usr_model.WebAuthNView, error) { - user, err := repo.UserByID(ctx, userID) - if err != nil { - return nil, err - } - if user.HumanView == nil { - return nil, errors.ThrowPreconditionFailed(nil, "EVENT-9anf8", "Errors.User.NotHuman") - } - return user.HumanView.PasswordlessTokens, nil -} - func (repo *UserRepo) ProfileByID(ctx context.Context, userID string) (*usr_model.Profile, error) { user, err := repo.UserByID(ctx, userID) if err != nil { diff --git a/internal/management/repository/user.go b/internal/management/repository/user.go index 046f24c14c..d0933b3a9f 100644 --- a/internal/management/repository/user.go +++ b/internal/management/repository/user.go @@ -24,10 +24,6 @@ type UserRepository interface { ProfileByID(ctx context.Context, userID string) (*model.Profile, error) - UserMFAs(ctx context.Context, userID string) ([]*model.MultiFactor, error) - - GetPasswordless(ctx context.Context, userID string) ([]*model.WebAuthNView, error) - EmailByID(ctx context.Context, userID string) (*model.Email, error) PhoneByID(ctx context.Context, userID string) (*model.Phone, error) diff --git a/internal/query/projection/user_auth_method.go b/internal/query/projection/user_auth_method.go index d5c5c0a2a3..a11ebc6c95 100644 --- a/internal/query/projection/user_auth_method.go +++ b/internal/query/projection/user_auth_method.go @@ -31,7 +31,7 @@ func NewUserAuthMethodProjection(ctx context.Context, config crdb.StatementHandl const ( UserAuthMethodTokenIDCol = "token_id" UserAuthMethodCreationDateCol = "creation_date" - UserAuthMethodChangeUseCol = "change_date" + UserAuthMethodChangeDateCol = "change_date" UserAuthMethodResourceOwnerCol = "resource_owner" UserAuthMethodUserIDCol = "user_id" UserAuthMethodSequenceCol = "sequence" @@ -108,7 +108,7 @@ func (p *UserAuthMethodProjection) reduceInitAuthMethod(event eventstore.Event) []handler.Column{ handler.NewCol(UserAuthMethodTokenIDCol, tokenID), handler.NewCol(UserAuthMethodCreationDateCol, event.CreationDate()), - handler.NewCol(UserAuthMethodChangeUseCol, event.CreationDate()), + handler.NewCol(UserAuthMethodChangeDateCol, event.CreationDate()), handler.NewCol(UserAuthMethodResourceOwnerCol, event.Aggregate().ResourceOwner), handler.NewCol(UserAuthMethodUserIDCol, event.Aggregate().ID), handler.NewCol(UserAuthMethodSequenceCol, event.Sequence()), @@ -144,7 +144,7 @@ func (p *UserAuthMethodProjection) reduceActivateEvent(event eventstore.Event) ( return crdb.NewUpdateStatement( event, []handler.Column{ - handler.NewCol(UserAuthMethodChangeUseCol, event.CreationDate()), + handler.NewCol(UserAuthMethodChangeDateCol, event.CreationDate()), handler.NewCol(UserAuthMethodSequenceCol, event.Sequence()), handler.NewCol(UserAuthMethodNameCol, name), handler.NewCol(UserAuthMethodStateCol, domain.MFAStateReady), diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go new file mode 100644 index 0000000000..92d9acdd8f --- /dev/null +++ b/internal/query/user_auth_method.go @@ -0,0 +1,274 @@ +package query + +import ( + "context" + "database/sql" + errs "errors" + "time" + + sq "github.com/Masterminds/squirrel" + "github.com/caos/zitadel/internal/query/projection" + + "github.com/caos/zitadel/internal/domain" + "github.com/caos/zitadel/internal/errors" +) + +var ( + userAuthMethodTable = table{ + name: projection.UserAuthMethodTable, + } + UserAuthMethodColumnTokenID = Column{ + name: projection.UserAuthMethodTokenIDCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnCreationDate = Column{ + name: projection.UserAuthMethodCreationDateCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnChangeDate = Column{ + name: projection.UserAuthMethodChangeDateCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnResourceOwner = Column{ + name: projection.UserAuthMethodResourceOwnerCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnUserID = Column{ + name: projection.UserAuthMethodUserIDCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnSequence = Column{ + name: projection.UserAuthMethodSequenceCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnName = Column{ + name: projection.UserAuthMethodNameCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnState = Column{ + name: projection.UserAuthMethodStateCol, + table: userAuthMethodTable, + } + UserAuthMethodColumnMethodType = Column{ + name: projection.UserAuthMethodTypeCol, + table: userAuthMethodTable, + } +) + +type AuthMethods struct { + SearchResponse + AuthMethods []*AuthMethod +} +type AuthMethod struct { + UserID string + CreationDate time.Time + ChangeDate time.Time + ResourceOwner string + State domain.MFAState + Sequence uint64 + + TokenID string + Name string + Type domain.UserAuthMethodType +} + +type UserAuthMethodSearchQueries struct { + SearchRequest + Queries []SearchQuery +} + +func (q *Queries) UserAuthMethodByIDs(ctx context.Context, userID, tokenID, resourceOwner string, methodType domain.UserAuthMethodType) (*AuthMethod, error) { + stmt, scan := prepareUserAuthMethodQuery() + query, args, err := stmt.Where(sq.Eq{ + UserAuthMethodColumnUserID.identifier(): userID, + UserAuthMethodColumnTokenID.identifier(): tokenID, + UserAuthMethodColumnResourceOwner.identifier(): resourceOwner, + UserAuthMethodColumnMethodType.identifier(): methodType, + }).ToSql() + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-2m00Q", "Errors.Query.SQLStatment") + } + + row := q.client.QueryRowContext(ctx, query, args...) + return scan(row) +} + +func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries) (userAuthMethods *AuthMethods, err error) { + query, scan := prepareUserAuthMethodsQuery() + stmt, args, err := queries.toQuery(query).ToSql() + if err != nil { + return nil, errors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest") + } + + rows, err := q.client.QueryContext(ctx, stmt, args...) + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-3n99f", "Errors.Internal") + } + userAuthMethods, err = scan(rows) + if err != nil { + return nil, err + } + userAuthMethods.LatestSequence, err = q.latestSequence(ctx, userAuthMethodTable) + return userAuthMethods, err +} + +func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) { + return NewTextQuery(UserAuthMethodColumnUserID, value, TextEquals) +} + +func NewUserAuthMethodTokenIDSearchQuery(value string) (SearchQuery, error) { + return NewTextQuery(UserAuthMethodColumnTokenID, value, TextEquals) +} + +func NewUserAuthMethodResourceOwnerSearchQuery(value string) (SearchQuery, error) { + return NewTextQuery(UserAuthMethodColumnResourceOwner, value, TextEquals) +} + +func NewUserAuthMethodTypeSearchQuery(value domain.UserAuthMethodType) (SearchQuery, error) { + return NewNumberQuery(UserAuthMethodColumnMethodType, value, NumberEquals) +} + +func NewUserAuthMethodTypesSearchQuery(values ...domain.UserAuthMethodType) (SearchQuery, error) { + list := make([]interface{}, len(values)) + for i, value := range values { + list[i] = value + } + return NewListQuery(UserAuthMethodColumnMethodType, list, ListIn) +} + +func (r *UserAuthMethodSearchQueries) AppendResourceOwnerQuery(orgID string) error { + query, err := NewUserAuthMethodResourceOwnerSearchQuery(orgID) + if err != nil { + return err + } + r.Queries = append(r.Queries, query) + return nil +} + +func (r *UserAuthMethodSearchQueries) AppendUserIDQuery(userID string) error { + query, err := NewUserAuthMethodUserIDSearchQuery(userID) + if err != nil { + return err + } + r.Queries = append(r.Queries, query) + return nil +} + +func (r *UserAuthMethodSearchQueries) AppendTokenIDQuery(tokenID string) error { + query, err := NewUserAuthMethodTokenIDSearchQuery(tokenID) + if err != nil { + return err + } + r.Queries = append(r.Queries, query) + return nil +} + +func (r *UserAuthMethodSearchQueries) AppendAuthMethodQuery(authMethod domain.UserAuthMethodType) error { + query, err := NewUserAuthMethodTypeSearchQuery(authMethod) + if err != nil { + return err + } + r.Queries = append(r.Queries, query) + return nil +} + +func (r *UserAuthMethodSearchQueries) AppendAuthMethodsQuery(authMethod ...domain.UserAuthMethodType) error { + query, err := NewUserAuthMethodTypesSearchQuery(authMethod...) + if err != nil { + return err + } + r.Queries = append(r.Queries, query) + return nil +} + +func (q *UserAuthMethodSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { + query = q.SearchRequest.toQuery(query) + for _, q := range q.Queries { + query = q.toQuery(query) + } + return query +} + +func prepareUserAuthMethodQuery() (sq.SelectBuilder, func(*sql.Row) (*AuthMethod, error)) { + return sq.Select( + UserAuthMethodColumnTokenID.identifier(), + UserAuthMethodColumnCreationDate.identifier(), + UserAuthMethodColumnChangeDate.identifier(), + UserAuthMethodColumnResourceOwner.identifier(), + UserAuthMethodColumnUserID.identifier(), + UserAuthMethodColumnSequence.identifier(), + UserAuthMethodColumnName.identifier(), + UserAuthMethodColumnState.identifier(), + UserAuthMethodColumnMethodType.identifier()). + From(userAuthMethodTable.identifier()).PlaceholderFormat(sq.Dollar), + func(row *sql.Row) (*AuthMethod, error) { + authMethod := new(AuthMethod) + err := row.Scan( + &authMethod.TokenID, + &authMethod.CreationDate, + &authMethod.ChangeDate, + &authMethod.ResourceOwner, + &authMethod.UserID, + &authMethod.Sequence, + &authMethod.Name, + &authMethod.State, + &authMethod.Type, + ) + if err != nil { + if errs.Is(err, sql.ErrNoRows) { + return nil, errors.ThrowNotFound(err, "QUERY-dniiF", "Errors.AuthMethod.NotFound") + } + return nil, errors.ThrowInternal(err, "QUERY-3n9Fs", "Errors.Internal") + } + return authMethod, nil + } +} + +func prepareUserAuthMethodsQuery() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethods, error)) { + return sq.Select( + UserAuthMethodColumnTokenID.identifier(), + UserAuthMethodColumnCreationDate.identifier(), + UserAuthMethodColumnChangeDate.identifier(), + UserAuthMethodColumnResourceOwner.identifier(), + UserAuthMethodColumnUserID.identifier(), + UserAuthMethodColumnSequence.identifier(), + UserAuthMethodColumnName.identifier(), + UserAuthMethodColumnState.identifier(), + UserAuthMethodColumnMethodType.identifier(), + countColumn.identifier()). + From(userAuthMethodTable.identifier()).PlaceholderFormat(sq.Dollar), + func(rows *sql.Rows) (*AuthMethods, error) { + userAuthMethods := make([]*AuthMethod, 0) + var count uint64 + for rows.Next() { + authMethod := new(AuthMethod) + err := rows.Scan( + &authMethod.TokenID, + &authMethod.CreationDate, + &authMethod.ChangeDate, + &authMethod.ResourceOwner, + &authMethod.UserID, + &authMethod.Sequence, + &authMethod.Name, + &authMethod.State, + &authMethod.Type, + &count, + ) + if err != nil { + return nil, err + } + userAuthMethods = append(userAuthMethods, authMethod) + } + + if err := rows.Close(); err != nil { + return nil, errors.ThrowInternal(err, "QUERY-3n9fl", "Errors.Query.CloseRows") + } + + return &AuthMethods{ + AuthMethods: userAuthMethods, + SearchResponse: SearchResponse{ + Count: count, + }, + }, nil + } +} diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go new file mode 100644 index 0000000000..c1ccf042fd --- /dev/null +++ b/internal/query/user_auth_method_test.go @@ -0,0 +1,331 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + + "github.com/caos/zitadel/internal/domain" + errs "github.com/caos/zitadel/internal/errors" +) + +func Test_UserAuthMethodPrepares(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareUserAuthMethodsQuery no result", + prepare: prepareUserAuthMethodsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.user_auth_methods`), + nil, + nil, + ), + }, + object: &AuthMethods{AuthMethods: []*AuthMethod{}}, + }, + { + name: "prepareUserAuthMethodsQuery one result", + prepare: prepareUserAuthMethodsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.user_auth_methods`), + []string{ + "token_id", + "creation_date", + "change_date", + "resource_owner", + "user_id", + "sequence", + "name", + "state", + "method_type", + "count", + }, + [][]driver.Value{ + { + "token_id", + testNow, + testNow, + "ro", + "user_id", + uint64(20211108), + "name", + domain.MFAStateReady, + domain.UserAuthMethodTypeU2F, + }, + }, + ), + }, + object: &AuthMethods{ + SearchResponse: SearchResponse{ + Count: 1, + }, + AuthMethods: []*AuthMethod{ + { + TokenID: "token_id", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "ro", + UserID: "user_id", + Sequence: 20211108, + Name: "name", + State: domain.MFAStateReady, + Type: domain.UserAuthMethodTypeU2F, + }, + }, + }, + }, + { + name: "prepareUserAuthMethodsQuery multiple result", + prepare: prepareUserAuthMethodsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.user_auth_methods`), + []string{ + "token_id", + "creation_date", + "change_date", + "resource_owner", + "user_id", + "sequence", + "name", + "state", + "method_type", + "count", + }, + [][]driver.Value{ + { + "token_id", + testNow, + testNow, + "ro", + "user_id", + uint64(20211108), + "name", + domain.MFAStateReady, + domain.UserAuthMethodTypeU2F, + }, + { + "token_id-2", + testNow, + testNow, + "ro", + "user_id", + uint64(20211108), + "name-2", + domain.MFAStateReady, + domain.UserAuthMethodTypePasswordless, + }, + }, + ), + }, + object: &AuthMethods{ + SearchResponse: SearchResponse{ + Count: 2, + }, + AuthMethods: []*AuthMethod{ + { + TokenID: "token_id", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "ro", + UserID: "user_id", + Sequence: 20211108, + Name: "name", + State: domain.MFAStateReady, + Type: domain.UserAuthMethodTypeU2F, + }, + { + TokenID: "token_id-2", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "ro", + UserID: "user_id", + Sequence: 20211108, + Name: "name-2", + State: domain.MFAStateReady, + Type: domain.UserAuthMethodTypePasswordless, + }, + }, + }, + }, + { + name: "prepareUserAuthMethodsQuery sql err", + prepare: prepareUserAuthMethodsQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type,`+ + ` COUNT(*) OVER ()`+ + ` FROM zitadel.projections.user_auth_methods`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + { + name: "prepareUserAuthMethodQuery no result", + prepare: prepareUserAuthMethodQuery, + want: want{ + sqlExpectations: mockQueries( + `SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type`+ + ` FROM zitadel.projections.user_auth_methods`, + nil, + nil, + ), + err: func(err error) (error, bool) { + if !errs.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: (*AuthMethod)(nil), + }, + { + name: "prepareUserAuthMethodQuery found", + prepare: prepareUserAuthMethodQuery, + want: want{ + sqlExpectations: mockQuery( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type`+ + ` FROM zitadel.projections.user_auth_methods`), + []string{ + "token_id", + "creation_date", + "change_date", + "resource_owner", + "user_id", + "sequence", + "name", + "state", + "method_type", + }, + []driver.Value{ + "token_id", + testNow, + testNow, + "ro", + "user_id", + uint64(20211108), + "name", + domain.MFAStateReady, + domain.UserAuthMethodTypeU2F, + }, + ), + }, + object: &AuthMethod{ + TokenID: "token_id", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "ro", + UserID: "user_id", + Sequence: 20211108, + Name: "name", + State: domain.MFAStateReady, + Type: domain.UserAuthMethodTypeU2F, + }, + }, + { + name: "prepareUserAuthMethodQuery sql err", + prepare: prepareUserAuthMethodQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(`SELECT zitadel.projections.user_auth_methods.token_id,`+ + ` zitadel.projections.user_auth_methods.creation_date,`+ + ` zitadel.projections.user_auth_methods.change_date,`+ + ` zitadel.projections.user_auth_methods.resource_owner,`+ + ` zitadel.projections.user_auth_methods.user_id,`+ + ` zitadel.projections.user_auth_methods.sequence,`+ + ` zitadel.projections.user_auth_methods.name,`+ + ` zitadel.projections.user_auth_methods.state,`+ + ` zitadel.projections.user_auth_methods.method_type`+ + ` FROM zitadel.projections.user_auth_methods`), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) + }) + } +}