diff --git a/ADOPTERS.md b/ADOPTERS.md index da37b9ffb0..0573099cf9 100644 --- a/ADOPTERS.md +++ b/ADOPTERS.md @@ -16,6 +16,7 @@ If you are using Zitadel, please consider adding yourself as a user with a quick | devOS: Sanity Edition | [@devOS-Sanity-Edition](https://github.com/devOS-Sanity-Edition) | Uses SSO Auth for every piece of our internal and external infrastructure | | CNAP.tech | [@cnap-tech](https://github.com/cnap-tech) | Using Zitadel for authentication and authorization in cloud-native applications | | Minekube | [@minekube](https://github.com/minekube) | Leveraging Zitadel for secure user authentication in gaming infrastructure | +| Dribdat | [@dribdat](https://github.com/dribdat) | Educating people about strong auth and resilient identity at hackathons | | Micromate | [@sschoeb](https://github.com/sschoeb) | Using Zitadel for authentication and authorization for learners and managers in our digital learning assistant as well as in the Micromate manage platform | | Smat.io | [@smatio](https://github.com/smatio) - [@lukasver](https://github.com/lukasver) | Zitadel for authentication in cloud applications while offering B2B portfolio management solutions for professional investors | |hirschengraben | [hirschengraben.io](hirschengraben.io) | Using Zitadel as IDP for a multitenant B2B dispatch app for bike messengers | diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 45ee7381c2..326dcc69a8 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -110,24 +110,13 @@ PublicHostHeaders: # ZITADEL_PUBLICHOSTHEADERS WebAuthNName: ZITADEL # ZITADEL_WEBAUTHNNAME Database: - # ZITADEL manages three database connection pools. - # The *ConnRatio settings define the ratio of how many connections from - # MaxOpenConns and MaxIdleConns are used to push events and spool projections. - # Remaining connection are used for queries (search). - # Values may not be negative and the sum of the ratios must always be less than 1. - # For example this defaults define 15 MaxOpenConns overall. - # - 15*0.2=3 connections are allocated to the event pusher; - # - 15*0.135=2 connections are allocated to the projection spooler; - # - 15-(3+2)=10 connections are remaining for queries; - EventPushConnRatio: 0.2 # ZITADEL_DATABASE_COCKROACH_EVENTPUSHCONNRATIO - ProjectionSpoolerConnRatio: 0.135 # ZITADEL_DATABASE_COCKROACH_PROJECTIONSPOOLERCONNRATIO # CockroachDB is the default database of ZITADEL cockroach: Host: localhost # ZITADEL_DATABASE_COCKROACH_HOST Port: 26257 # ZITADEL_DATABASE_COCKROACH_PORT Database: zitadel # ZITADEL_DATABASE_COCKROACH_DATABASE - MaxOpenConns: 15 # ZITADEL_DATABASE_COCKROACH_MAXOPENCONNS - MaxIdleConns: 12 # ZITADEL_DATABASE_COCKROACH_MAXIDLECONNS + MaxOpenConns: 5 # ZITADEL_DATABASE_COCKROACH_MAXOPENCONNS + MaxIdleConns: 2 # ZITADEL_DATABASE_COCKROACH_MAXIDLECONNS MaxConnLifetime: 30m # ZITADEL_DATABASE_COCKROACH_MAXCONNLIFETIME MaxConnIdleTime: 5m # ZITADEL_DATABASE_COCKROACH_MAXCONNIDLETIME Options: "" # ZITADEL_DATABASE_COCKROACH_OPTIONS @@ -590,6 +579,11 @@ SAML: # Company: ZITADEL # ZITADEL_SAML_PROVIDERCONFIG_CONTACTPERSON_COMPANY # EmailAddress: hi@zitadel.com # ZITADEL_SAML_PROVIDERCONFIG_CONTACTPERSON_EMAILADDRESS +SCIM: + # default values whether an email/phone is considered verified when a users email/phone is created or updated + EmailVerified: true # ZITADEL_SCIM_EMAILVERIFIED + PhoneVerified: true # ZITADEL_SCIM_PHONEVERIFIED + Login: LanguageCookieName: zitadel.login.lang # ZITADEL_LOGIN_LANGUAGECOOKIENAME CSRFCookieName: zitadel.login.csrf # ZITADEL_LOGIN_CSRFCOOKIENAME @@ -608,6 +602,9 @@ Console: # 168h is 7 days, one week SharedMaxAge: 168h # ZITADEL_CONSOLE_LONGCACHE_SHAREDMAXAGE InstanceManagementURL: "" # ZITADEL_CONSOLE_INSTANCEMANAGEMENTURL + PostHog: + URL: "" # ZITADEL_CONSOLE_POSTHOG_URL + Token: "" # ZITADEL_CONSOLE_POSTHOG_TOKEN EncryptionKeys: DomainVerification: @@ -1124,6 +1121,7 @@ DefaultInstance: LoginDefaultOrg: true # ZITADEL_DEFAULTINSTANCE_FEATURES_LOGINDEFAULTORG # TriggerIntrospectionProjections: false # ZITADEL_DEFAULTINSTANCE_FEATURES_TRIGGERINTROSPECTIONPROJECTIONS # LegacyIntrospection: false # ZITADEL_DEFAULTINSTANCE_FEATURES_LEGACYINTROSPECTION + # PermissionCheckV2: false # ZITADEL_DEFAULTINSTANCE_FEATURES_PERMISSIONCHECKV2 Limits: # AuditLogRetention limits the number of events that can be queried via the events API by their age. # A value of "0s" means that all events are available. @@ -1187,6 +1185,9 @@ InternalAuthZ: # Configure the RolePermissionMappings by environment variable using JSON notation: # ZITADEL_INTERNALAUTHZ_ROLEPERMISSIONMAPPINGS='[{"role": "IAM_OWNER", "permissions": ["iam.write"]}, {"role": "ORG_OWNER", "permissions": ["org.write"]}]' # Beware that if you configure the RolePermissionMappings by environment variable, all the default RolePermissionMappings are lost. + # + # Warning: RolePermissionMappings are synhronized to the database. + # Changes here will only be applied after running `zitadel setup` or `zitadel start-from-setup`. RolePermissionMappings: - Role: "SYSTEM_OWNER" Permissions: diff --git a/cmd/initialise/init.go b/cmd/initialise/init.go index fba5098fa2..02fd481eab 100644 --- a/cmd/initialise/init.go +++ b/cmd/initialise/init.go @@ -9,7 +9,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" ) var ( @@ -79,7 +78,7 @@ func initialise(ctx context.Context, config database.Config, steps ...func(conte return err } - db, err := database.Connect(config, true, dialect.DBPurposeQuery) + db, err := database.Connect(config, true) if err != nil { return err } diff --git a/cmd/initialise/verify_zitadel.go b/cmd/initialise/verify_zitadel.go index a5ce1fd57c..1ae85a21fa 100644 --- a/cmd/initialise/verify_zitadel.go +++ b/cmd/initialise/verify_zitadel.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" es_v3 "github.com/zitadel/zitadel/internal/eventstore/v3" ) @@ -85,7 +84,7 @@ func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) func verifyZitadel(ctx context.Context, config database.Config) error { logging.WithFields("database", config.DatabaseName()).Info("verify zitadel") - db, err := database.Connect(config, false, dialect.DBPurposeQuery) + db, err := database.Connect(config, false) if err != nil { return err } diff --git a/cmd/key/key.go b/cmd/key/key.go index 2691932784..1dba8fd969 100644 --- a/cmd/key/key.go +++ b/cmd/key/key.go @@ -12,7 +12,6 @@ import ( "github.com/zitadel/zitadel/internal/crypto" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -124,7 +123,7 @@ func openFile(fileName string) (io.Reader, error) { } func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) { - db, err := database.Connect(config, false, dialect.DBPurposeQuery) + db, err := database.Connect(config, false) if err != nil { return nil, err } diff --git a/cmd/mirror/auth.go b/cmd/mirror/auth.go index df94708e71..0eba10d05f 100644 --- a/cmd/mirror/auth.go +++ b/cmd/mirror/auth.go @@ -12,7 +12,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" ) func authCmd() *cobra.Command { @@ -34,11 +33,11 @@ Only auth requests are mirrored`, } func copyAuth(ctx context.Context, config *Migration) { - sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) + sourceClient, err := database.Connect(config.Source, false) logging.OnError(err).Fatal("unable to connect to source database") defer sourceClient.Close() - destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) + destClient, err := database.Connect(config.Destination, false) logging.OnError(err).Fatal("unable to connect to destination database") defer destClient.Close() diff --git a/cmd/mirror/event_store.go b/cmd/mirror/event_store.go index 23145bdc37..3825462126 100644 --- a/cmd/mirror/event_store.go +++ b/cmd/mirror/event_store.go @@ -14,7 +14,6 @@ import ( "github.com/zitadel/logging" db "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/v2/database" "github.com/zitadel/zitadel/internal/v2/eventstore" @@ -44,11 +43,11 @@ Migrate only copies events2 and unique constraints`, } func copyEventstore(ctx context.Context, config *Migration) { - sourceClient, err := db.Connect(config.Source, false, dialect.DBPurposeEventPusher) + sourceClient, err := db.Connect(config.Source, false) logging.OnError(err).Fatal("unable to connect to source database") defer sourceClient.Close() - destClient, err := db.Connect(config.Destination, false, dialect.DBPurposeEventPusher) + destClient, err := db.Connect(config.Destination, false) logging.OnError(err).Fatal("unable to connect to destination database") defer destClient.Close() diff --git a/cmd/mirror/projections.go b/cmd/mirror/projections.go index ae903d90c5..a4987a48f6 100644 --- a/cmd/mirror/projections.go +++ b/cmd/mirror/projections.go @@ -30,7 +30,6 @@ import ( "github.com/zitadel/zitadel/internal/config/systemdefaults" crypto_db "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" @@ -106,7 +105,7 @@ func projections( ) { start := time.Now() - client, err := database.Connect(config.Destination, false, dialect.DBPurposeQuery) + client, err := database.Connect(config.Destination, false) logging.OnError(err).Fatal("unable to connect to database") keyStorage, err := crypto_db.NewKeyStorage(client, masterKey) @@ -119,9 +118,7 @@ func projections( logging.OnError(err).Fatal("unable create static storage") config.Eventstore.Querier = old_es.NewCRDB(client) - esPusherDBClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect eventstore push client") - config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) + config.Eventstore.Pusher = new_es.NewEventstore(client) es := eventstore.NewEventstore(config.Eventstore) esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{ MaxRetries: config.Eventstore.MaxRetries, diff --git a/cmd/mirror/system.go b/cmd/mirror/system.go index e16836aa8c..00b48eb491 100644 --- a/cmd/mirror/system.go +++ b/cmd/mirror/system.go @@ -12,7 +12,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" ) func systemCmd() *cobra.Command { @@ -34,11 +33,11 @@ Only keys and assets are mirrored`, } func copySystem(ctx context.Context, config *Migration) { - sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) + sourceClient, err := database.Connect(config.Source, false) logging.OnError(err).Fatal("unable to connect to source database") defer sourceClient.Close() - destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) + destClient, err := database.Connect(config.Destination, false) logging.OnError(err).Fatal("unable to connect to destination database") defer destClient.Close() diff --git a/cmd/mirror/verify.go b/cmd/mirror/verify.go index 68c927d091..e1a507d9fe 100644 --- a/cmd/mirror/verify.go +++ b/cmd/mirror/verify.go @@ -13,7 +13,6 @@ import ( cryptoDatabase "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/query/projection" ) @@ -37,11 +36,11 @@ var schemas = []string{ } func verifyMigration(ctx context.Context, config *Migration) { - sourceClient, err := database.Connect(config.Source, false, dialect.DBPurposeQuery) + sourceClient, err := database.Connect(config.Source, false) logging.OnError(err).Fatal("unable to connect to source database") defer sourceClient.Close() - destClient, err := database.Connect(config.Destination, false, dialect.DBPurposeEventPusher) + destClient, err := database.Connect(config.Destination, false) logging.OnError(err).Fatal("unable to connect to destination database") defer destClient.Close() diff --git a/cmd/setup/41.go b/cmd/setup/41.go deleted file mode 100644 index fa4a1d5a4b..0000000000 --- a/cmd/setup/41.go +++ /dev/null @@ -1,44 +0,0 @@ -package setup - -import ( - "context" - - "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/eventstore" - "github.com/zitadel/zitadel/internal/query/projection" - "github.com/zitadel/zitadel/internal/repository/instance" -) - -type FillFieldsForInstanceDomains struct { - eventstore *eventstore.Eventstore -} - -func (mig *FillFieldsForInstanceDomains) Execute(ctx context.Context, _ eventstore.Event) error { - instances, err := mig.eventstore.InstanceIDs( - ctx, - eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs). - OrderDesc(). - AddQuery(). - AggregateTypes("instance"). - EventTypes(instance.InstanceAddedEventType). - Builder(), - ) - if err != nil { - return err - } - for _, instance := range instances { - ctx := authz.WithInstanceID(ctx, instance) - if err := projection.InstanceDomainFields.Trigger(ctx); err != nil { - return err - } - } - return nil -} - -func (mig *FillFieldsForInstanceDomains) String() string { - return "repeatable_fill_fields_for_instance_domains" -} - -func (f *FillFieldsForInstanceDomains) Check(lastRun map[string]interface{}) bool { - return true -} diff --git a/cmd/setup/45.go b/cmd/setup/45.go new file mode 100644 index 0000000000..d8318a6d59 --- /dev/null +++ b/cmd/setup/45.go @@ -0,0 +1,111 @@ +package setup + +import ( + "context" + _ "embed" + "encoding/json" + "fmt" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/owner" + "github.com/zitadel/zitadel/internal/repository/project" +) + +var ( + //go:embed 45.sql + correctProjectOwnerEvents string +) + +type CorrectProjectOwners struct { + eventstore *eventstore.Eventstore +} + +func (mig *CorrectProjectOwners) Execute(ctx context.Context, _ eventstore.Event) error { + instances, err := mig.eventstore.InstanceIDs( + ctx, + eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs). + OrderDesc(). + AddQuery(). + AggregateTypes("instance"). + EventTypes(instance.InstanceAddedEventType). + Builder(), + ) + if err != nil { + return err + } + + ctx = authz.SetCtxData(ctx, authz.CtxData{UserID: "SETUP"}) + for i, instance := range instances { + ctx = authz.WithInstanceID(ctx, instance) + logging.WithFields("instance_id", instance, "migration", mig.String(), "progress", fmt.Sprintf("%d/%d", i+1, len(instances))).Info("correct owners of projects") + didCorrect, err := mig.correctInstanceProjects(ctx, instance) + if err != nil { + return err + } + if !didCorrect { + continue + } + _, err = projection.ProjectGrantProjection.Trigger(ctx) + logging.OnError(err).Debug("failed triggering project grant projection to update owners") + } + return nil +} + +func (mig *CorrectProjectOwners) correctInstanceProjects(ctx context.Context, instance string) (didCorrect bool, err error) { + var correctedOwners []eventstore.Command + + tx, err := mig.eventstore.Client().BeginTx(ctx, nil) + if err != nil { + return false, err + } + defer func() { + if err != nil { + _ = tx.Rollback() + return + } + err = tx.Commit() + }() + + rows, err := tx.QueryContext(ctx, correctProjectOwnerEvents, instance) + if err != nil { + return false, err + } + defer rows.Close() + + for rows.Next() { + aggregate := &eventstore.Aggregate{ + InstanceID: instance, + Type: project.AggregateType, + Version: project.AggregateVersion, + } + var payload json.RawMessage + err := rows.Scan( + &aggregate.ID, + &aggregate.ResourceOwner, + &payload, + ) + if err != nil { + return false, err + } + previousOwners := make(map[uint32]string) + if err := json.Unmarshal(payload, &previousOwners); err != nil { + return false, err + } + correctedOwners = append(correctedOwners, owner.NewCorrected(ctx, aggregate, previousOwners)) + } + if rows.Err() != nil { + return false, rows.Err() + } + + _, err = mig.eventstore.PushWithClient(ctx, tx, correctedOwners...) + return len(correctedOwners) > 0, err +} + +func (*CorrectProjectOwners) String() string { + return "43_correct_project_owners" +} diff --git a/cmd/setup/45.sql b/cmd/setup/45.sql new file mode 100644 index 0000000000..0e90a2683d --- /dev/null +++ b/cmd/setup/45.sql @@ -0,0 +1,79 @@ +WITH corrupt_streams AS ( + select + e.instance_id + , e.aggregate_type + , e.aggregate_id + , min(e.sequence) as min_sequence + , count(distinct e.owner) as owner_count + from + eventstore.events2 e + where + e.instance_id = $1 + and aggregate_type = 'project' + group by + e.instance_id + , e.aggregate_type + , e.aggregate_id + having + count(distinct e.owner) > 1 +), correct_owners AS ( + select + e.instance_id + , e.aggregate_type + , e.aggregate_id + , e.owner + from + eventstore.events2 e + join + corrupt_streams cs + on + e.instance_id = cs.instance_id + and e.aggregate_type = cs.aggregate_type + and e.aggregate_id = cs.aggregate_id + and e.sequence = cs.min_sequence +), wrong_events AS ( + select + e.instance_id + , e.aggregate_type + , e.aggregate_id + , e.sequence + , e.owner wrong_owner + , co.owner correct_owner + from + eventstore.events2 e + join + correct_owners co + on + e.instance_id = co.instance_id + and e.aggregate_type = co.aggregate_type + and e.aggregate_id = co.aggregate_id + and e.owner <> co.owner +), updated_events AS ( + UPDATE eventstore.events2 e + SET owner = we.correct_owner + FROM + wrong_events we + WHERE + e.instance_id = we.instance_id + and e.aggregate_type = we.aggregate_type + and e.aggregate_id = we.aggregate_id + and e.sequence = we.sequence + RETURNING + we.aggregate_id + , we.correct_owner + , we.sequence + , we.wrong_owner +) +SELECT + ue.aggregate_id + , ue.correct_owner + , jsonb_object_agg( + ue.sequence::TEXT --formant to string because crdb is not able to handle int + , ue.wrong_owner + ) payload +FROM + updated_events ue +GROUP BY + ue.aggregate_id + , ue.correct_owner +; diff --git a/cmd/setup/46.go b/cmd/setup/46.go new file mode 100644 index 0000000000..e48b16e4b0 --- /dev/null +++ b/cmd/setup/46.go @@ -0,0 +1,39 @@ +package setup + +import ( + "context" + "embed" + "fmt" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +type InitPermissionFunctions struct { + eventstoreClient *database.DB +} + +var ( + //go:embed 46/*.sql + permissionFunctions embed.FS +) + +func (mig *InitPermissionFunctions) Execute(ctx context.Context, _ eventstore.Event) error { + statements, err := readStatements(permissionFunctions, "46", "") + if err != nil { + return err + } + for _, stmt := range statements { + logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement") + if _, err := mig.eventstoreClient.ExecContext(ctx, stmt.query); err != nil { + return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err) + } + } + return nil +} + +func (*InitPermissionFunctions) String() string { + return "46_init_permission_functions" +} diff --git a/cmd/setup/46/01-role_permissions_view.sql b/cmd/setup/46/01-role_permissions_view.sql new file mode 100644 index 0000000000..f0a8413125 --- /dev/null +++ b/cmd/setup/46/01-role_permissions_view.sql @@ -0,0 +1,6 @@ +CREATE OR REPLACE VIEW eventstore.role_permissions AS +SELECT instance_id, aggregate_id, object_id as role, text_value as permission +FROM eventstore.fields +WHERE aggregate_type = 'permission' +AND object_type = 'role_permission' +AND field_name = 'permission'; diff --git a/cmd/setup/46/02-instance_orgs_view.sql b/cmd/setup/46/02-instance_orgs_view.sql new file mode 100644 index 0000000000..aa59fcde6a --- /dev/null +++ b/cmd/setup/46/02-instance_orgs_view.sql @@ -0,0 +1,6 @@ +CREATE OR REPLACE VIEW eventstore.instance_orgs AS +SELECT instance_id, aggregate_id as org_id +FROM eventstore.fields +WHERE aggregate_type = 'org' +AND object_type = 'org' +AND field_name = 'state'; diff --git a/cmd/setup/46/03-instance_members_view.sql b/cmd/setup/46/03-instance_members_view.sql new file mode 100644 index 0000000000..cf47610f42 --- /dev/null +++ b/cmd/setup/46/03-instance_members_view.sql @@ -0,0 +1,6 @@ +CREATE OR REPLACE VIEW eventstore.instance_members AS +SELECT instance_id, object_id as user_id, text_value as role +FROM eventstore.fields +WHERE aggregate_type = 'instance' +AND object_type = 'instance_member_role' +AND field_name = 'instance_role'; diff --git a/cmd/setup/46/04-org_members_view.sql b/cmd/setup/46/04-org_members_view.sql new file mode 100644 index 0000000000..7477d9a816 --- /dev/null +++ b/cmd/setup/46/04-org_members_view.sql @@ -0,0 +1,6 @@ +CREATE OR REPLACE VIEW eventstore.org_members AS +SELECT instance_id, aggregate_id as org_id, object_id as user_id, text_value as role +FROM eventstore.fields +WHERE aggregate_type = 'org' +AND object_type = 'org_member_role' +AND field_name = 'org_role'; diff --git a/cmd/setup/46/05-project_members_view.sql b/cmd/setup/46/05-project_members_view.sql new file mode 100644 index 0000000000..0eed48cec3 --- /dev/null +++ b/cmd/setup/46/05-project_members_view.sql @@ -0,0 +1,6 @@ +CREATE OR REPLACE VIEW eventstore.project_members AS +SELECT instance_id, aggregate_id as project_id, object_id as user_id, text_value as role +FROM eventstore.fields +WHERE aggregate_type = 'project' +AND object_type = 'project_member_role' +AND field_name = 'project_role'; diff --git a/cmd/setup/46/06-permitted_orgs_function.sql b/cmd/setup/46/06-permitted_orgs_function.sql new file mode 100644 index 0000000000..0c8c0fc673 --- /dev/null +++ b/cmd/setup/46/06-permitted_orgs_function.sql @@ -0,0 +1,50 @@ +CREATE OR REPLACE FUNCTION eventstore.permitted_orgs( + instanceId TEXT + , userId TEXT + , perm TEXT + + , org_ids OUT TEXT[] +) + LANGUAGE 'plpgsql' + STABLE +AS $$ +DECLARE + matched_roles TEXT[]; -- roles containing permission +BEGIN + SELECT array_agg(rp.role) INTO matched_roles + FROM eventstore.role_permissions rp + WHERE rp.instance_id = instanceId + AND rp.permission = perm; + + -- First try if the permission was granted thru an instance-level role + DECLARE + has_instance_permission bool; + BEGIN + SELECT true INTO has_instance_permission + FROM eventstore.instance_members im + WHERE im.role = ANY(matched_roles) + AND im.instance_id = instanceId + AND im.user_id = userId + LIMIT 1; + + IF has_instance_permission THEN + -- Return all organizations + SELECT array_agg(o.org_id) INTO org_ids + FROM eventstore.instance_orgs o + WHERE o.instance_id = instanceId; + RETURN; + END IF; + END; + + -- Return the organizations where permission were granted thru org-level roles + SELECT array_agg(org_id) INTO org_ids + FROM ( + SELECT DISTINCT om.org_id + FROM eventstore.org_members om + WHERE om.role = ANY(matched_roles) + AND om.instance_id = instanceID + AND om.user_id = userId + ) AS orgs; + RETURN; +END; +$$; diff --git a/cmd/setup/cleanup.go b/cmd/setup/cleanup.go index e9bc832d21..943ac164ea 100644 --- a/cmd/setup/cleanup.go +++ b/cmd/setup/cleanup.go @@ -8,7 +8,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/eventstore" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" new_es "github.com/zitadel/zitadel/internal/eventstore/v3" @@ -32,13 +31,11 @@ func Cleanup(config *Config) { logging.Info("cleanup started") - queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to database") - esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher) + dbClient, err := database.Connect(config.Database, false) logging.OnError(err).Fatal("unable to connect to database") - config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) - config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) + config.Eventstore.Pusher = new_es.NewEventstore(dbClient) + config.Eventstore.Querier = old_es.NewCRDB(dbClient) es := eventstore.NewEventstore(config.Eventstore) step, err := migration.LastStuckStep(ctx, es) diff --git a/cmd/setup/config.go b/cmd/setup/config.go index 9f34c2baa5..6d9443fae0 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -87,6 +87,9 @@ func MustNewConfig(v *viper.Viper) *Config { id.Configure(config.Machine) + // Copy the global role permissions mappings to the instance until we allow instance-level configuration over the API. + config.DefaultInstance.RolePermissionMappings = config.InternalAuthZ.RolePermissionMappings + return config } @@ -130,6 +133,8 @@ type Steps struct { s42Apps7OIDCConfigsLoginVersion *Apps7OIDCConfigsLoginVersion s43CreateFieldsDomainIndex *CreateFieldsDomainIndex s44ReplaceCurrentSequencesIndex *ReplaceCurrentSequencesIndex + s45CorrectProjectOwners *CorrectProjectOwners + s46InitPermissionFunctions *InitPermissionFunctions } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/fill_fields.go b/cmd/setup/fill_fields.go new file mode 100644 index 0000000000..9dbb2fed7e --- /dev/null +++ b/cmd/setup/fill_fields.go @@ -0,0 +1,51 @@ +package setup + +import ( + "context" + "fmt" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" +) + +type RepeatableFillFields struct { + eventstore *eventstore.Eventstore + handlers []*handler.FieldHandler +} + +func (mig *RepeatableFillFields) Execute(ctx context.Context, _ eventstore.Event) error { + instances, err := mig.eventstore.InstanceIDs( + ctx, + eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs). + OrderDesc(). + AddQuery(). + AggregateTypes(instance.AggregateType). + EventTypes(instance.InstanceAddedEventType). + Builder(), + ) + if err != nil { + return err + } + for _, instance := range instances { + ctx := authz.WithInstanceID(ctx, instance) + for _, handler := range mig.handlers { + logging.WithFields("migration", mig.String(), "instance_id", instance, "handler", handler.String()).Info("run fields trigger") + if err := handler.Trigger(ctx); err != nil { + return fmt.Errorf("%s: %s: %w", mig.String(), handler.String(), err) + } + } + } + return nil +} + +func (mig *RepeatableFillFields) String() string { + return "repeatable_fill_fields" +} + +func (f *RepeatableFillFields) Check(lastRun map[string]interface{}) bool { + return true +} diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 4ffef441af..a48b74acb8 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -26,9 +26,9 @@ import ( "github.com/zitadel/zitadel/internal/command" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" new_es "github.com/zitadel/zitadel/internal/eventstore/v3" "github.com/zitadel/zitadel/internal/i18n" @@ -102,26 +102,22 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) i18n.MustLoadSupportedLanguagesFromDir() - queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) - logging.OnError(err).Fatal("unable to connect to database") - esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher) - logging.OnError(err).Fatal("unable to connect to database") - projectionDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeProjectionSpooler) + dbClient, err := database.Connect(config.Database, false) logging.OnError(err).Fatal("unable to connect to database") - config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) - esV3 := new_es.NewEventstore(esPusherDBClient) + config.Eventstore.Querier = old_es.NewCRDB(dbClient) + esV3 := new_es.NewEventstore(dbClient) config.Eventstore.Pusher = esV3 config.Eventstore.Searcher = esV3 eventstoreClient := eventstore.NewEventstore(config.Eventstore) logging.OnError(err).Fatal("unable to start eventstore") - eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(queryDBClient, &es_v4_pg.Config{ + eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(dbClient, &es_v4_pg.Config{ MaxRetries: config.Eventstore.MaxRetries, })) - steps.s1ProjectionTable = &ProjectionTable{dbClient: queryDBClient.DB} - steps.s2AssetsTable = &AssetTable{dbClient: queryDBClient.DB} + steps.s1ProjectionTable = &ProjectionTable{dbClient: dbClient.DB} + steps.s2AssetsTable = &AssetTable{dbClient: dbClient.DB} steps.FirstInstance.Skip = config.ForMirror || steps.FirstInstance.Skip steps.FirstInstance.instanceSetup = config.DefaultInstance @@ -129,7 +125,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC steps.FirstInstance.masterKey = masterKey - steps.FirstInstance.db = queryDBClient + steps.FirstInstance.db = dbClient steps.FirstInstance.es = eventstoreClient steps.FirstInstance.defaults = config.SystemDefaults steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings @@ -137,44 +133,46 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.FirstInstance.externalSecure = config.ExternalSecure steps.FirstInstance.externalPort = config.ExternalPort - steps.s5LastFailed = &LastFailed{dbClient: queryDBClient.DB} - steps.s6OwnerRemoveColumns = &OwnerRemoveColumns{dbClient: queryDBClient.DB} - steps.s7LogstoreTables = &LogstoreTables{dbClient: queryDBClient.DB, username: config.Database.Username(), dbType: config.Database.Type()} - steps.s8AuthTokens = &AuthTokenIndexes{dbClient: queryDBClient} - steps.CorrectCreationDate.dbClient = esPusherDBClient - steps.s12AddOTPColumns = &AddOTPColumns{dbClient: queryDBClient} - steps.s13FixQuotaProjection = &FixQuotaConstraints{dbClient: queryDBClient} - steps.s14NewEventsTable = &NewEventsTable{dbClient: esPusherDBClient} - steps.s15CurrentStates = &CurrentProjectionState{dbClient: queryDBClient} - steps.s16UniqueConstraintsLower = &UniqueConstraintToLower{dbClient: queryDBClient} - steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: queryDBClient} - steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: queryDBClient} - steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: queryDBClient} - steps.s20AddByUserSessionIndex = &AddByUserIndexToSession{dbClient: queryDBClient} - steps.s21AddBlockFieldToLimits = &AddBlockFieldToLimits{dbClient: queryDBClient} - steps.s22ActiveInstancesIndex = &ActiveInstanceEvents{dbClient: queryDBClient} - steps.s23CorrectGlobalUniqueConstraints = &CorrectGlobalUniqueConstraints{dbClient: esPusherDBClient} - steps.s24AddActorToAuthTokens = &AddActorToAuthTokens{dbClient: queryDBClient} - steps.s25User11AddLowerFieldsToVerifiedEmail = &User11AddLowerFieldsToVerifiedEmail{dbClient: esPusherDBClient} - steps.s26AuthUsers3 = &AuthUsers3{dbClient: esPusherDBClient} - steps.s27IDPTemplate6SAMLNameIDFormat = &IDPTemplate6SAMLNameIDFormat{dbClient: esPusherDBClient} - steps.s28AddFieldTable = &AddFieldTable{dbClient: esPusherDBClient} + steps.s5LastFailed = &LastFailed{dbClient: dbClient.DB} + steps.s6OwnerRemoveColumns = &OwnerRemoveColumns{dbClient: dbClient.DB} + steps.s7LogstoreTables = &LogstoreTables{dbClient: dbClient.DB, username: config.Database.Username(), dbType: config.Database.Type()} + steps.s8AuthTokens = &AuthTokenIndexes{dbClient: dbClient} + steps.CorrectCreationDate.dbClient = dbClient + steps.s12AddOTPColumns = &AddOTPColumns{dbClient: dbClient} + steps.s13FixQuotaProjection = &FixQuotaConstraints{dbClient: dbClient} + steps.s14NewEventsTable = &NewEventsTable{dbClient: dbClient} + steps.s15CurrentStates = &CurrentProjectionState{dbClient: dbClient} + steps.s16UniqueConstraintsLower = &UniqueConstraintToLower{dbClient: dbClient} + steps.s17AddOffsetToUniqueConstraints = &AddOffsetToCurrentStates{dbClient: dbClient} + steps.s18AddLowerFieldsToLoginNames = &AddLowerFieldsToLoginNames{dbClient: dbClient} + steps.s19AddCurrentStatesIndex = &AddCurrentSequencesIndex{dbClient: dbClient} + steps.s20AddByUserSessionIndex = &AddByUserIndexToSession{dbClient: dbClient} + steps.s21AddBlockFieldToLimits = &AddBlockFieldToLimits{dbClient: dbClient} + steps.s22ActiveInstancesIndex = &ActiveInstanceEvents{dbClient: dbClient} + steps.s23CorrectGlobalUniqueConstraints = &CorrectGlobalUniqueConstraints{dbClient: dbClient} + steps.s24AddActorToAuthTokens = &AddActorToAuthTokens{dbClient: dbClient} + steps.s25User11AddLowerFieldsToVerifiedEmail = &User11AddLowerFieldsToVerifiedEmail{dbClient: dbClient} + steps.s26AuthUsers3 = &AuthUsers3{dbClient: dbClient} + steps.s27IDPTemplate6SAMLNameIDFormat = &IDPTemplate6SAMLNameIDFormat{dbClient: dbClient} + steps.s28AddFieldTable = &AddFieldTable{dbClient: dbClient} steps.s29FillFieldsForProjectGrant = &FillFieldsForProjectGrant{eventstore: eventstoreClient} steps.s30FillFieldsForOrgDomainVerified = &FillFieldsForOrgDomainVerified{eventstore: eventstoreClient} - steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: esPusherDBClient} - steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: esPusherDBClient} - steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: esPusherDBClient} - steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient} - steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient} - steps.s36FillV2Milestones = &FillV3Milestones{dbClient: queryDBClient, eventstore: eventstoreClient} - steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient} - steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient} - steps.s40InitPushFunc = &InitPushFunc{dbClient: esPusherDBClient} - steps.s42Apps7OIDCConfigsLoginVersion = &Apps7OIDCConfigsLoginVersion{dbClient: esPusherDBClient} - steps.s43CreateFieldsDomainIndex = &CreateFieldsDomainIndex{dbClient: queryDBClient} - steps.s44ReplaceCurrentSequencesIndex = &ReplaceCurrentSequencesIndex{dbClient: esPusherDBClient} + steps.s31AddAggregateIndexToFields = &AddAggregateIndexToFields{dbClient: dbClient} + steps.s32AddAuthSessionID = &AddAuthSessionID{dbClient: dbClient} + steps.s33SMSConfigs3TwilioAddVerifyServiceSid = &SMSConfigs3TwilioAddVerifyServiceSid{dbClient: dbClient} + steps.s34AddCacheSchema = &AddCacheSchema{dbClient: dbClient} + steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: dbClient} + steps.s36FillV2Milestones = &FillV3Milestones{dbClient: dbClient, eventstore: eventstoreClient} + steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: dbClient} + steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: dbClient, esClient: eventstoreClient} + steps.s40InitPushFunc = &InitPushFunc{dbClient: dbClient} + steps.s42Apps7OIDCConfigsLoginVersion = &Apps7OIDCConfigsLoginVersion{dbClient: dbClient} + steps.s43CreateFieldsDomainIndex = &CreateFieldsDomainIndex{dbClient: dbClient} + steps.s44ReplaceCurrentSequencesIndex = &ReplaceCurrentSequencesIndex{dbClient: dbClient} + steps.s45CorrectProjectOwners = &CorrectProjectOwners{eventstore: eventstoreClient} + steps.s46InitPermissionFunctions = &InitPermissionFunctions{eventstoreClient: dbClient} - err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) + err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") repeatableSteps := []migration.RepeatableMigration{ @@ -192,8 +190,16 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) &DeleteStaleOrgFields{ eventstore: eventstoreClient, }, - &FillFieldsForInstanceDomains{ + &RepeatableFillFields{ eventstore: eventstoreClient, + handlers: []*handler.FieldHandler{ + projection.InstanceDomainFields, + projection.MembershipFields, + }, + }, + &SyncRolePermissions{ + eventstore: eventstoreClient, + rolePermissionMappings: config.InternalAuthZ.RolePermissionMappings, }, } @@ -227,6 +233,8 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s36FillV2Milestones, steps.s38BackChannelLogoutNotificationStart, steps.s44ReplaceCurrentSequencesIndex, + steps.s45CorrectProjectOwners, + steps.s46InitPermissionFunctions, } { mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") } @@ -256,8 +264,8 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) ctx, eventstoreClient, eventstoreV4, - queryDBClient, - projectionDBClient, + dbClient, + dbClient, masterKey, config, ) diff --git a/cmd/setup/sync_role_permissions.go b/cmd/setup/sync_role_permissions.go new file mode 100644 index 0000000000..b38b075d82 --- /dev/null +++ b/cmd/setup/sync_role_permissions.go @@ -0,0 +1,134 @@ +package setup + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "strings" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/permission" +) + +var ( + //go:embed sync_role_permissions.sql + getRolePermissionOperationsQuery string +) + +// SyncRolePermissions is a repeatable step which synchronizes the InternalAuthZ +// RolePermissionMappings from the configuration to the database. +// This is needed until role permissions are manageable over the API. +type SyncRolePermissions struct { + eventstore *eventstore.Eventstore + rolePermissionMappings []authz.RoleMapping +} + +func (mig *SyncRolePermissions) Execute(ctx context.Context, _ eventstore.Event) error { + if err := mig.executeSystem(ctx); err != nil { + return err + } + return mig.executeInstances(ctx) +} + +func (mig *SyncRolePermissions) executeSystem(ctx context.Context) error { + logging.WithFields("migration", mig.String()).Info("prepare system role permission sync events") + + target := rolePermissionMappingsToDatabaseMap(mig.rolePermissionMappings, true) + cmds, err := mig.synchronizeCommands(ctx, "SYSTEM", target) + if err != nil { + return err + } + events, err := mig.eventstore.Push(ctx, cmds...) + if err != nil { + return err + } + + logging.WithFields("migration", mig.String(), "pushed_events", len(events)).Info("pushed system role permission sync events") + return nil +} + +func (mig *SyncRolePermissions) executeInstances(ctx context.Context) error { + instances, err := mig.eventstore.InstanceIDs( + ctx, + eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs). + OrderDesc(). + AddQuery(). + AggregateTypes(instance.AggregateType). + EventTypes(instance.InstanceAddedEventType). + Builder(). + ExcludeAggregateIDs(). + AggregateTypes(instance.AggregateType). + EventTypes(instance.InstanceRemovedEventType). + Builder(), + ) + if err != nil { + return err + } + target := rolePermissionMappingsToDatabaseMap(mig.rolePermissionMappings, false) + for i, instanceID := range instances { + logging.WithFields("instance_id", instanceID, "migration", mig.String(), "progress", fmt.Sprintf("%d/%d", i+1, len(instances))).Info("prepare instance role permission sync events") + cmds, err := mig.synchronizeCommands(ctx, instanceID, target) + if err != nil { + return err + } + events, err := mig.eventstore.Push(ctx, cmds...) + if err != nil { + return err + } + logging.WithFields("instance_id", instanceID, "migration", mig.String(), "pushed_events", len(events)).Info("pushed instance role permission sync events") + } + return nil +} + +// synchronizeCommands checks the current state of role permissions in the eventstore for the aggregate. +// It returns the commands required to reach the desired state passed in target. +// For system level permissions aggregateID must be set to `SYSTEM`, +// else it is the instance ID. +func (mig *SyncRolePermissions) synchronizeCommands(ctx context.Context, aggregateID string, target database.Map[[]string]) (cmds []eventstore.Command, err error) { + aggregate := permission.NewAggregate(aggregateID) + err = mig.eventstore.Client().QueryContext(ctx, func(rows *sql.Rows) error { + for rows.Next() { + var operation, role, perm string + if err := rows.Scan(&operation, &role, &perm); err != nil { + return err + } + logging.WithFields("aggregate_id", aggregateID, "migration", mig.String(), "operation", operation, "role", role, "permission", perm).Debug("sync role permission") + switch operation { + case "add": + cmds = append(cmds, permission.NewAddedEvent(ctx, aggregate, role, perm)) + case "remove": + cmds = append(cmds, permission.NewRemovedEvent(ctx, aggregate, role, perm)) + } + } + return rows.Close() + + }, getRolePermissionOperationsQuery, aggregateID, target) + if err != nil { + return nil, err + } + return cmds, err +} + +func (*SyncRolePermissions) String() string { + return "repeatable_sync_role_permissions" +} + +func (*SyncRolePermissions) Check(lastRun map[string]interface{}) bool { + return true +} + +func rolePermissionMappingsToDatabaseMap(mappings []authz.RoleMapping, system bool) database.Map[[]string] { + out := make(database.Map[[]string], len(mappings)) + for _, m := range mappings { + if system == strings.HasPrefix(m.Role, "SYSTEM") { + out[m.Role] = m.Permissions + } + } + return out +} diff --git a/cmd/setup/sync_role_permissions.sql b/cmd/setup/sync_role_permissions.sql new file mode 100644 index 0000000000..e7ce21cee7 --- /dev/null +++ b/cmd/setup/sync_role_permissions.sql @@ -0,0 +1,52 @@ +/* +This query creates a change set of permissions that need to be added or removed. +It compares the current state in the fields table (thru the role_permissions view) +against a passed role permission mapping as JSON, created from Zitadel's config: + +{ + "IAM_ADMIN_IMPERSONATOR": ["admin.impersonation", "impersonation"], + "IAM_END_USER_IMPERSONATOR": ["impersonation"], + "FOO_BAR": ["foo.bar", "bar.foo"] + } + +It uses an aggregate_id as first argument which may be an instance_id or 'SYSTEM' +for system level permissions. +*/ +WITH target AS ( + -- unmarshal JSON representation into flattened tabular data + SELECT + key AS role, + jsonb_array_elements_text(value) AS permission + FROM jsonb_each($2::jsonb) +), add AS ( + -- find all role permissions that exist in `target` and not in `role_permissions` + SELECT t.role, t.permission + FROM eventstore.role_permissions p + RIGHT JOIN target t + ON p.aggregate_id = $1::text + AND p.role = t.role + AND p.permission = t.permission + WHERE p.role IS NULL +), remove AS ( + -- find all role permissions that exist `role_permissions` and not in `target` + SELECT p.role, p.permission + FROM eventstore.role_permissions p + LEFT JOIN target t + ON p.role = t.role + AND p.permission = t.permission + WHERE p.aggregate_id = $1::text + AND t.role IS NULL +) +-- return the required operations +SELECT + 'add' AS operation, + role, + permission +FROM add +UNION ALL +SELECT + 'remove' AS operation, + role, + permission +FROM remove +; diff --git a/cmd/start/config.go b/cmd/start/config.go index 6182342592..910759b653 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -15,6 +15,7 @@ import ( "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/api/saml" + scim_config "github.com/zitadel/zitadel/internal/api/scim/config" "github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/login" auth_es "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing" @@ -60,6 +61,7 @@ type Config struct { UserAgentCookie *middleware.UserAgentCookieConfig OIDC oidc.Config SAML saml.Config + SCIM scim_config.Config Login login.Config Console console.Config AssetStorage static_config.AssetStorageConfig @@ -125,5 +127,8 @@ func MustNewConfig(v *viper.Viper) *Config { id.Configure(config.Machine) actions.SetHTTPConfig(&config.Actions.HTTP) + // Copy the global role permissions mappings to the instance until we allow instance-level configuration over the API. + config.DefaultInstance.RolePermissionMappings = config.InternalAuthZ.RolePermissionMappings + return config } diff --git a/cmd/start/start.go b/cmd/start/start.go index 154c683481..4091213d2d 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -63,6 +63,8 @@ import ( "github.com/zitadel/zitadel/internal/api/oidc" "github.com/zitadel/zitadel/internal/api/robots_txt" "github.com/zitadel/zitadel/internal/api/saml" + "github.com/zitadel/zitadel/internal/api/scim" + "github.com/zitadel/zitadel/internal/api/scim/schemas" "github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/console/path" "github.com/zitadel/zitadel/internal/api/ui/login" @@ -75,7 +77,6 @@ import ( "github.com/zitadel/zitadel/internal/crypto" cryptoDB "github.com/zitadel/zitadel/internal/crypto/database" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" old_es "github.com/zitadel/zitadel/internal/eventstore/repository/sql" @@ -148,20 +149,12 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server i18n.MustLoadSupportedLanguagesFromDir() - queryDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeQuery) + dbClient, err := database.Connect(config.Database, false) if err != nil { return fmt.Errorf("cannot start DB client for queries: %w", err) } - esPusherDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeEventPusher) - if err != nil { - return fmt.Errorf("cannot start client for event store pusher: %w", err) - } - projectionDBClient, err := database.Connect(config.Database, false, dialect.DBPurposeProjectionSpooler) - if err != nil { - return fmt.Errorf("cannot start client for projection spooler: %w", err) - } - keyStorage, err := cryptoDB.NewKeyStorage(queryDBClient, masterKey) + keyStorage, err := cryptoDB.NewKeyStorage(dbClient, masterKey) if err != nil { return fmt.Errorf("cannot start key storage: %w", err) } @@ -170,16 +163,16 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server return err } - config.Eventstore.Pusher = new_es.NewEventstore(esPusherDBClient) - config.Eventstore.Searcher = new_es.NewEventstore(queryDBClient) - config.Eventstore.Querier = old_es.NewCRDB(queryDBClient) + config.Eventstore.Pusher = new_es.NewEventstore(dbClient) + config.Eventstore.Searcher = new_es.NewEventstore(dbClient) + config.Eventstore.Querier = old_es.NewCRDB(dbClient) eventstoreClient := eventstore.NewEventstore(config.Eventstore) - eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(queryDBClient, &es_v4_pg.Config{ + eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(dbClient, &es_v4_pg.Config{ MaxRetries: config.Eventstore.MaxRetries, })) sessionTokenVerifier := internal_authz.SessionTokenVerifier(keys.OIDC) - cacheConnectors, err := connector.StartConnectors(config.Caches, queryDBClient) + cacheConnectors, err := connector.StartConnectors(config.Caches, dbClient) if err != nil { return fmt.Errorf("unable to start caches: %w", err) } @@ -188,8 +181,8 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server ctx, eventstoreClient, eventstoreV4.Querier, - queryDBClient, - projectionDBClient, + dbClient, + dbClient, cacheConnectors, config.Projections, config.SystemDefaults, @@ -213,7 +206,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server return fmt.Errorf("cannot start queries: %w", err) } - authZRepo, err := authz.Start(queries, eventstoreClient, queryDBClient, keys.OIDC, config.ExternalSecure) + authZRepo, err := authz.Start(queries, eventstoreClient, dbClient, keys.OIDC, config.ExternalSecure) if err != nil { return fmt.Errorf("error starting authz repo: %w", err) } @@ -221,7 +214,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server return internal_authz.CheckPermission(ctx, authZRepo, config.InternalAuthZ.RolePermissionMappings, permission, orgID, resourceID) } - storage, err := config.AssetStorage.NewStorage(queryDBClient.DB) + storage, err := config.AssetStorage.NewStorage(dbClient.DB) if err != nil { return fmt.Errorf("cannot start asset storage client: %w", err) } @@ -266,7 +259,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server if err != nil { return err } - actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(queryDBClient, commands, queries)) + actionsExecutionDBEmitter, err := logstore.NewEmitter[*record.ExecutionLog](ctx, clock, config.Quotas.Execution, execution.NewDatabaseLogStorage(dbClient, commands, queries)) if err != nil { return err } @@ -295,7 +288,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server keys.SMS, keys.OIDC, config.OIDC.DefaultBackChannelLogoutLifetime, - queryDBClient, + dbClient, ) notification.Start(ctx) @@ -311,7 +304,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server commands, queries, eventstoreClient, - queryDBClient, + dbClient, config, storage, authZRepo, @@ -331,7 +324,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server if server != nil { server <- &Server{ Config: config, - DB: queryDBClient, + DB: dbClient, KeyStorage: keyStorage, Keys: keys, Eventstore: eventstoreClient, @@ -440,7 +433,7 @@ func startAPIs( if err := apis.RegisterService(ctx, user_v2.CreateServer(commands, queries, keys.User, keys.IDPConfig, idp.CallbackURL(), idp.SAMLRootURL(), assets.AssetAPI(), permissionCheck)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2beta.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2beta.CreateServer(commands, queries)); err != nil { @@ -452,7 +445,7 @@ func startAPIs( if err := apis.RegisterService(ctx, feature_v2beta.CreateServer(commands, queries)); err != nil { return nil, err } - if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries)); err != nil { + if err := apis.RegisterService(ctx, session_v2.CreateServer(commands, queries, permissionCheck)); err != nil { return nil, err } if err := apis.RegisterService(ctx, settings_v2.CreateServer(commands, queries)); err != nil { @@ -519,6 +512,17 @@ func startAPIs( } apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler()) + apis.RegisterHandlerOnPrefix( + schemas.HandlerPrefix, + scim.NewServer( + commands, + queries, + verifier, + keys.User, + &config.SCIM, + instanceInterceptor.HandlerFuncWithError, + middleware.AuthorizationInterceptor(verifier, config.InternalAuthZ).HandlerFuncWithError)) + c, err := console.Start(config.Console, config.ExternalSecure, oidcServer.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, limitingAccessInterceptor, config.CustomerPortal) if err != nil { return nil, fmt.Errorf("unable to start console: %w", err) diff --git a/docs/docs/apis/openidoauth/claims.md b/docs/docs/apis/openidoauth/claims.md index e82f0b4059..4129806aef 100644 --- a/docs/docs/apis/openidoauth/claims.md +++ b/docs/docs/apis/openidoauth/claims.md @@ -26,7 +26,7 @@ Please check below the matrix for an overview where which scope is asserted. | jti | No | Yes | No | When JWT | | locale | When requested | When requested | When requested and response_type `id_token` | No | | name | When requested | When requested | When requested and response_type `id_token` | No | -| nbf | No | Yes | Yes | When JWT | +| nbf | No | Yes | No | When JWT | | nonce | No | No | When provided in the authorization request [^1] | No | | phone | When requested | When requested | When requested and response_type `id_token` | No | | phone_verified | When requested | When requested | When requested and response_type `id_token` | No | diff --git a/docs/docs/guides/integrate/login-ui/external-login.mdx b/docs/docs/guides/integrate/login-ui/external-login.mdx index dde966d388..3b3c47cf18 100644 --- a/docs/docs/guides/integrate/login-ui/external-login.mdx +++ b/docs/docs/guides/integrate/login-ui/external-login.mdx @@ -152,7 +152,7 @@ curl --request POST \ If you didn't get a user ID in the parameters of your success page, you know that there is no existing user in ZITADEL with that provider, and you can register a new user or link it to an existing account (read the next section). Fill the IdP links in the create user request to add a user with an external login provider. -The idpId is the ID of the provider in ZITADEL, the idpExternalId is the ID of the user in the external identity provider; usually, this is sent in the “sub”. +The idpId is the ID of the provider in ZITADEL, the userId is the ID of the user in the external identity provider; usually, this is sent in the “sub”. The display name is used to list the linkings on the users. [Create User API Documentation](/docs/apis/resources/user_service_v2/user-service-add-human-user) @@ -181,8 +181,8 @@ curl --request POST \ "idpLinks": [ { "idpId": "218528353504723201", - "idpExternalId": "111392805975715856637", - "displayName": "Minnie Mouse" + "userId": "111392805975715856637", + "userName": "Minnie Mouse" } ] }' @@ -205,8 +205,8 @@ curl --request POST \ --data '{ "idpLink": { "idpId": "218528353504723201", - "idpExternalId": "1113928059757158566371", - "displayName": "Minnie Mouse" + "userId": "1113928059757158566371", + "userName": "Minnie Mouse" } }' ``` diff --git a/docs/docs/guides/integrate/login-ui/typescript-repo.mdx b/docs/docs/guides/integrate/login-ui/typescript-repo.mdx index d1a3f1d877..d5fd6d9e4d 100644 --- a/docs/docs/guides/integrate/login-ui/typescript-repo.mdx +++ b/docs/docs/guides/integrate/login-ui/typescript-repo.mdx @@ -146,4 +146,6 @@ Then create a personal access token (PAT), copy and set it as `ZITADEL_SERVICE_U Finally set your instance url as `ZITADEL_API_URL`. Make sure to set it without trailing slash. Also ensure your login domain is registered on your instance by adding it as a [trusted domain](/docs/apis/resources/admin/admin-service-add-instance-trusted-domain). +If you want to enforce users to have their email verified, you can set the optional `EMAIL_VERIFICATION` variable to `true` in your environment and your users will be enforced to verify their email address before they can log in. + ![Deploy to Vercel](/img/deploy-to-vercel.png) diff --git a/docs/docs/guides/integrate/login/hosted-login.mdx b/docs/docs/guides/integrate/login/hosted-login.mdx new file mode 100644 index 0000000000..fcb7729314 --- /dev/null +++ b/docs/docs/guides/integrate/login/hosted-login.mdx @@ -0,0 +1,207 @@ +--- +title: Login users into your application with a hosted login UI +sidebar_label: Hosted Login UI +--- + +ZITADEL provides a hosted single-sign-on page to securely sign-in users to your applications. +ZITADEL's hosted login page serves as a centralized authentication interface provided for applications that integrate ZITADEL. +As a developer, understanding the hosted login page is essential for seamlessly integrating authentication into your application. + +## Centralized authentication endpoint + +ZITADEL's hosted login page acts as a centralized authentication endpoint where users are redirected to authenticate themselves. +When users attempt to access a protected resource within your application, you can redirect them to the hosted login page to authenticate using their login methods and credentials or through Single-sign-on (SSO). +After successful authentication, the user will be redirected back to the originating application. + +## Security and compliance + +ZITADEL's hosted login page prioritizes security and compliance with industry standards and regulations. +It employs best practices for securing authentication processes, such as encryption, token-based authentication, and adherence to protocols like OAuth 2.0, [OpenID Connect](/docs/guides/integrate/login/oidc), and [SAML](/docs/guides/integrate/login/). + +We make sure to harden the login UI and minimize the attack surface. +One of the measures we apply is setting the necessary security heads thus minimizing the risk of common vulnerabilities in login pages, such as XSS vulnerabilities. +Put your current login to the test and compare the results with our hosted login page. +Tools like [Mozilla's Observatory](https://observatory.mozilla.org/) can give you a good first impression about the security posture. + +## Developer-friendly integration + +Integrating the hosted login page into your application is straightforward, thanks to ZITADEL's developer-friendly documentation, SDKs, and APIs. Developers can easily implement authentication flows, handle authentication callbacks, and customize the user experience to seamlessly integrate authentication with their application's workflow. + +Overall, ZITADEL's hosted login page simplifies the authentication process for developers by providing a secure, customizable, and developer-friendly authentication interface. By leveraging this centralized authentication endpoint, developers can enhance their application's security, user experience, and compliance with industry standards and regulations. + +## Key features of the hosted login + +### Flexible usernames + +Different login name formats can be used on ZITADEL's hosted login page to select a user. +Login methods can be a user's username, containing the username and an [organization domain](/docs/guides/manage/console/organizations#domain-verification-and-primary-domain), their email addresses, or their phone numbers. +By default, all of these login methods are allowed and can be adjusted by [Managers](/docs/concepts/structure/managers) to meet their requirements. + +### Support for multiple authentication methods + +The hosted login page supports various authentication methods, including traditional username/password authentication, social login options, multi-factor authentication (MFA), and passwordless authentication methods like [passkeys](/docs/concepts/features/passkeys.md). +The second factor (2FA) and multi-factor authentication methods (MFA) available in ZITADEL include OTP via an authenticator app, TOTP via SMS, OTP via email, and U2F. + +Developers can configure the authentication methods offered on the login page based on their application's security and usability requirements. + +### Enterprise single-sign-on + +![Screenshot of ZITADEL console showing different identity provider templates](/img/guides/integrate/login/login-external-idp-templates.png) + +With the hosted login page from ZITADEL developers will get the best support for multi-tenancy single-sign-on with third-party identity providers. +ZITADEL acts as an [identity broker](/docs/concepts/features/identity-brokering) between your applications and different external identity providers, reducing the implementation effort for developers. +External Identity providers can be configured for the whole instance or for each organization that represents a group of users such as a B2B customer or organizational unit. + +ZITADEL offers various [identity provider templates](/docs/guides/integrate/identity-providers/introduction) to integrate providers such as [Okta](/docs/guides/integrate/identity-providers/okta-oidc), [Entra ID](/docs/guides/integrate/identity-providers/azure-ad-oidc) or on-premise [LDAP](/docs/guides/integrate/identity-providers/ldap). + +### Multi-tenancy authentication + +ZITADEL simplifies multi-tenancy authentication by securely managing authentication for multiple tenants, called [Organizations](/docs/concepts/structure/organizations), within a single [instance](/docs/concepts/structure/instance). + +Key features include: + +1. **Secure Tenant Isolation**: Ensures robust security measures to prevent unauthorized access between tenants, maintaining data privacy and compliance. [Managers](/docs/concepts/structure/managers) for an organization have only access to data and configuration within their Organization. +2. **Custom Authentication Configurations**: Allows tailored [authentication settings](/docs/guides/manage/console/default-settings#login-behavior-and-access), [branding](/docs/guides/manage/customize/branding), and policies for each tenant. +3. **Centralized Management**: Provides [centralized administration](/docs/guides/manage/console/managers) for efficient management across all tenants. +4. **Scalability and Flexibility**: Scales seamlessly to accommodate growing organizations of all sizes. +5. **Domain Discovery**: Starting on a central login page, route users to their tenant based on their email address or other user attributes. Authentication settings will be applied automatically based on the organization's policies, this includes routing users seamlessly to third party identity providers like [Entra ID](/docs/guides/integrate/identity-providers/azure-ad-oidc). + +### Customization options + +While the hosted login page provides a default authentication interface out-of-the-box, ZITADEL offers [customization options](/docs/guides/manage/customize/branding) to tailor the login page to match your application's branding and user experience requirements. +Developers can customize elements such as logos, colors, and messaging to ensure a seamless integration with their application's user interface. + +:::info Customization and Branding +The login page can be changed by customizing different branding aspects and you can define a custom domain for the login (eg, login.acme.com). + +By default, the displayed branding is defined [based on the user's domain](/docs/guides/solution-scenarios/domain-discovery). In case you want to show the branding of a specific organization by default, you need to either pass a primary domain scope (`urn:zitadel:iam:org:domain:primary:{domainname}`) with the authorization request, or define the behavior on your Project's settings. +::: + +### Fast account switching + +The hosted login page remembers users who have previously authenticated. +In case a user has used multiple accounts, for example, a private account and a work account, to authenticate, then all accounts will be shown on the Account Picker. +Users can still login with a different user that is not on the list. +This allows users to quickly switch between users and provide a better user experience. + +:::info +This behavior can be changed with the authorization request. Please refer to our [guide](/guides/integrate/login/oidc/login-users). +::: + +### Self-service for users + +ZITADEL's hosted login page offers [many self-service flows](/docs/concepts/features/selfservice) that allow users to set up authentication methods or recover their login information. +Developers use the self-service functionalities to reduce manual tasks and improve user experience. +Key features include: + +### Password reset + +Unauthenticated users can request a password reset after providing the loginname during the login flow. + +- User selects reset password +- An email will be sent to the verified email address +- User opens a link and has to provide a new password + +#### Prompt users to set up multifactor authentication + +Users are automatically prompted to provide a second factor, when + +- Instance or organization [login policy](/concepts/structure/policies#login-policy) is set +- Requested by the client +- A multi-factor is set up for the user + +When a multi-factor is required, but not set up, then the user is requested to set up an additional factor. + +:::info Disabling multifactor prompt +You can disable the prompt, in case multifactor authentication is not enforced by setting the [**Multifactor Init Lifetime**](/docs/guides/manage/console/default-settings#login-lifetimes) to 0. +::: + +#### Enroll passkeys + +Users can select a button to initiate passwordless login or use a fall-back method (ie. login with username/password), if available. + +The passwordless with [passkeys](/docs/concepts/features/passkeys.md) login flow follows the FIDO2 / WebAuthN standard. +With the introduction of passkeys the gesture can be provided on ANY of the user's devices. +This is not strictly the device where the login flow is being executed (e.g., on a mobile device). +The user experience depends mainly on the operating system and browser. + +## Hosted Login Version 2 (Beta) + +We have worked on a new, self-hostable implementation of our hosted login built with Next.js and leveraging our [Session API](/docs/guides/integrate/login/login-users#zitadels-session-api). +This solution empowers you to easily fork and customize the login experience to perfectly match your brand and needs. + +In this initial release, the new login is available for self-hosting only. We'll be progressively replacing the built-in login with this improved version, built with [TypeScript](https://github.com/zitadel/typescript). + +### Current State + +Our primary goal for the TypeScript login system is to replace the existing login functionality within Zitadel Core, which is shipped with Zitadel automatically. This will allow us to leverage the benefits of the new system, including its modular architecture and enhanced security features. + +To achieve this, we are actively working on implementing the core features currently available in Zitadel Core, such as: + +- **Authentication Methods:** + - Username and Password + - Passkeys + - Multi-Factor Authentication (MFA) + - External Identity Providers (OIDC, SAML, etc.) +- **OpenID Connect (OIDC) Compliance:** Adherence to the OIDC standard for seamless integration with various identity providers. +- **Customization**: + - Branding options to match your organization's identity. + - Flexible configuration settings to tailor the login experience. + +The full feature list can be found [here](https://github.com/zitadel/typescript?tab=readme-ov-file#features-list). + +As we continue to develop the TypeScript login system, we will provide regular updates on its progress and new capabilities. + +### Limitations + +For the first implementation we have excluded the following features: + +- SAML (SP & OP) +- Generic JWT IDP +- LDAP IDP +- Device Authorization Grants +- Timebased features + - Lockout Settings + - Password Expiry Settings + - Login Settings - Multifactor init prompt + - Force MFA on external authenticated users +- Passkey/U2F Setup + - As passkey and u2f is bound to a domain, it is important to notice, that setting up the authentication possibility in the ZITADEL management console (Self-service), will not work if the login runs on a different domain +- Custom Login Texts + +### Beta Testing + +The TypeScript login system is currently in beta testing. Your feedback is invaluable in helping us refine and improve this new solution. +At your convenience please open any issues faced on our Typescript Login GitHub repository to report bugs or suggest enhancements while more general feedback can be shared directly to fabienne@zitadel.com. +Your contributions will play a crucial role in shaping the future of our login system. Thank you for your support! + +#### Step-by-step Guide + +The simplest way to deploy the new login for yourself is by using the [“Deploy” button in our repository](https://github.com/zitadel/typescript?tab=readme-ov-file#deploy-to-vercel) to deploy the login directly to your Vercel. + +1. [Create a service user](https://zitadel.com/docs/guides/integrate/service-users/personal-access-token#create-a-service-user-with-a-pat) (ZITADEL_SERVICE_USER_ID) with a PAT in your instance +2. Give the user IAM_LOGIN_CLIENT Permissions in the default settings (YOUR_DOMAIN/ui/console/instance?id=organizations) + Note: [Zitadel Manager Guide](https://zitadel.com/docs/guides/manage/console/managers) +3. Deploy login to Vercel: You can do so, be directly clicking the [“Deploy” button](https://github.com/zitadel/typescript?tab=readme-ov-file#deploy-to-vercel) at the bottom of the readme in our [repository](https://github.com/zitadel/typescript) +4. If you have used the deploy button in the steps before, you will automatically be asked for this step. Enter the environment variables in Vercel + - ZITADEL_SERVICE_USER_ID + - PAT + - ZITADEL_API_URL (Example: https://my-domain.zitadel.cloud, no trailing slash) +5. Add the domain where your login UI is hosted to the [trusted domains](https://zitadel.com/docs/apis/resources/admin/admin-service-add-instance-trusted-domain) in Zitadel. (Example: my-new-zitadel-login.vercel.app) +6. Use the new login in your application. You have three different options on how to achieve this + 1. Enable the new login on your application configuration and add the URL of your login UI, with that settings Zitadel will automatically redirect you to the new login if you call the old one. + ![Login V2 Application Configuration](/img/guides/integrate/login/login-v2-app-config.png) + 2. Enable the [loginV2 feature](https://zitadel.com/docs/apis/resources/feature_service_v2/feature-service-set-instance-features) on the instance and add the URL of your login. If you enable this feature, the login will be used for every application configured in your Zitadel instance. (Example: https://my-new-zitadel-login.vercel.app) + 3. Change the issuer in the code of your application to the new domain of your login +7. Enforce users to have their email verified. By setting `EMAIL_VERIFICATION` to `true` in your environment variables, your users will be enforced to verify their email address before they can log in. + +### Important Notes + +As this feature is currently in Beta, please be aware of some potential workarounds and important considerations before implementation. + +- **Create Users:** The new typescript login is built with the session and the user V2 API, the users V2 API does have some differences to the v1 API, so make sure you create users through the new API. +- **External IDPs:** If you want to use external identity provider login, such as Login with Google or Apple. You can follow our existing setup guides, just make sure to use the following redirect url: $YOUR-DOMAIN/idps/callback +- **Passkey/U2F:** Those authentication methods are bound to a domain. As your new login runs on a different domain than the previous login, existing passwordless authentication and u2f (fingerprint, face id, etc.) can’t be used. Also when they are managed through the management console of ZITADEL, they are added on a different domain. +
+ *Note: If you run the login on a subdomain of your current instance, this problem + can be avoided. E.g myinstance.zitadel.cloud and login.myinstance.zitadel.cloud* diff --git a/docs/docs/guides/integrate/login/login-users.mdx b/docs/docs/guides/integrate/login/login-users.mdx index 82c9dbe5a4..e80e70c5b7 100644 --- a/docs/docs/guides/integrate/login/login-users.mdx +++ b/docs/docs/guides/integrate/login/login-users.mdx @@ -1,6 +1,6 @@ --- -title: Login users into your application with a hosted or custom login UI -sidebar_label: Hosted vs. Custom Login UI +title: Log users into your application with different authentication options +sidebar_label: Authentication Options --- ZITADEL is a comprehensive identity and access management platform designed to streamline user authentication, authorization, and management processes for your application. It offers a range of features, including single sign-on (SSO), multi-factor authentication (MFA), and centralized user management. @@ -25,6 +25,8 @@ The identity provider is not part of the original application, but a standalone The user will authenticate using their credentials. After successful authentication, the user will be redirected back to the original application. +If you want to read more about authenticating with OIDC, head over to our comprehensive [OpenID Connect Guide](/docs/integrate/login/oidc). + ### Authenticate users with SAML SAML (Security Assertion Markup Language) is a widely adopted standard for exchanging authentication and authorization data between identity providers and service providers. @@ -52,13 +54,14 @@ Note that SAML might not be suitable for mobile applications. In case you want to integrate a mobile application, use OpenID Connect or our Session API. There are more [differences between SAML and OIDC](https://zitadel.com/blog/saml-vs-oidc) that you might want to consider. +If you want to read more about authenticating with SAML, head over to our comprehensive [SAML Guide](/docs/integrate/login/saml). -### ZITADEL's Session API +## ZITADEL's Session API ZITADEL's [Session API](/docs/apis/resources/session_service_v2) provides developers with a straightforward method to manage user sessions within their applications. The Session API is not an industry-standard and can be used instead of OpenID Connect or SAML to authenticate users by [building your own custom login user interface](/docs/guides/integrate/login-ui). -#### Tokens in the Session API +### Tokens in the Session API The session API will return a session token that can be used to authenticate users from your application. This token should not be confused with am access or id tokens in opaque or JWT form that is issued during OpenID connect flows. @@ -67,7 +70,7 @@ This token should not be confused with am access or id tokens in opaque or JWT f Token exchange between Session API and OIDC / SAML tokens is not possible at this moment. ::: -#### Key features of the Session API +### Key features of the Session API These are some key features of the API: @@ -85,127 +88,16 @@ Overall, ZITADEL's Session API simplifies session management within your applica ## Use the Hosted Login to sign-in users -ZITADEL provides a hosted single-sign-on page to securely sign-in users to your applications. -ZITADEL's hosted login page serves as a centralized authentication interface provided for applications that integrate ZITADEL. -As a developer, understanding the hosted login page is essential for seamlessly integrating authentication into your application. +ZITADEL provides a hosted single-sign-on page for secure user authentication within your applications. +This centralized authentication interface simplifies application integration by offering a ready-to-use login experience. +For a comprehensive understanding of the hosted login page and its capabilities, please refer to our [dedicated guide](/docs/guides/integrate/login/hosted-login) -### Centralized authentication endpoint - -ZITADEL's hosted login page acts as a centralized authentication endpoint where users are redirected to authenticate themselves. -When users attempt to access a protected resource within your application, you can redirect them to the hosted login page to authenticate using their login methods and credentials or through Single-sign-on (SSO). -After successful authentication, the user will be redirected back to the originating application. - -### Security and compliance - -ZITADEL's hosted login page prioritizes security and compliance with industry standards and regulations. -It employs best practices for securing authentication processes, such as encryption, token-based authentication, and adherence to protocols like OAuth 2.0, [OpenID Connect](/docs/guides/integrate/login/oidc), and [SAML](/docs/guides/integrate/login/). - -We make sure to harden the login UI and minimize the attack surface. -One of the measures we apply is setting the necessary security heads thus minimizing the risk of common vulnerabilities in login pages, such as XSS vulnerabilities. -Put your current login to the test and compare the results with our hosted login page. -Tools like [Mozilla's Observatory](https://observatory.mozilla.org/) can give you a good first impression about the security posture. - -### Developer-friendly integration - -Integrating the hosted login page into your application is straightforward, thanks to ZITADEL's developer-friendly documentation, SDKs, and APIs. Developers can easily implement authentication flows, handle authentication callbacks, and customize the user experience to seamlessly integrate authentication with their application's workflow. - -Overall, ZITADEL's hosted login page simplifies the authentication process for developers by providing a secure, customizable, and developer-friendly authentication interface. By leveraging this centralized authentication endpoint, developers can enhance their application's security, user experience, and compliance with industry standards and regulations. - -## Key features of the hosted login - -### Flexible usernames - -Different login name formats can be used on ZITADEL's hosted login page to select a user. -Login methods can be a user's username, containing the username and an [organization domain](/docs/guides/manage/console/organizations#domain-verification-and-primary-domain), their email addresses, or their phone numbers. -By default, all of these login methods are allowed and can be adjusted by [Managers](/docs/concepts/structure/managers) to meet their requirements. - -### Support for multiple authentication methods - -The hosted login page supports various authentication methods, including traditional username/password authentication, social login options, multi-factor authentication (MFA), and passwordless authentication methods like [passkeys](/docs/concepts/features/passkeys.md). -The second factor (2FA) and multi-factor authentication methods (MFA) available in ZITADEL include OTP via an authenticator app, TOTP via SMS, OTP via email, and U2F. - -Developers can configure the authentication methods offered on the login page based on their application's security and usability requirements. - -### Enterprise single-sign-on - -![Screenshot of ZITADEL console showing different identity provider templates](/img/guides/integrate/login/login-external-idp-templates.png) - -With the hosted login page from ZITADEL developers will get the best support for multi-tenancy single-sign-on with third-party identity providers. -ZITADEL acts as an [identity broker](/docs/concepts/features/identity-brokering) between your applications and different external identity providers, reducing the implementation effort for developers. -External Identity providers can be configured for the whole instance or for each organization that represents a group of users such as a B2B customer or organizational unit. - -ZITADEL offers various [identity provider templates](/docs/guides/integrate/identity-providers/introduction) to integrate providers such as [Okta](/docs/guides/integrate/identity-providers/okta-oidc), [Entra ID](/docs/guides/integrate/identity-providers/azure-ad-oidc) or on-premise [LDAP](/docs/guides/integrate/identity-providers/ldap). - -### Multi-tenancy authentication - -ZITADEL simplifies multi-tenancy authentication by securely managing authentication for multiple tenants, called [Organizations](/docs/concepts/structure/organizations), within a single [instance](/docs/concepts/structure/instance). - -Key features include: - -1. **Secure Tenant Isolation**: Ensures robust security measures to prevent unauthorized access between tenants, maintaining data privacy and compliance. [Managers](/docs/concepts/structure/managers) for an organization have only access to data and configuration within their Organization. -2. **Custom Authentication Configurations**: Allows tailored [authentication settings](/docs/guides/manage/console/default-settings#login-behavior-and-access), [branding](/docs/guides/manage/customize/branding), and policies for each tenant. -3. **Centralized Management**: Provides [centralized administration](/docs/guides/manage/console/managers) for efficient management across all tenants. -4. **Scalability and Flexibility**: Scales seamlessly to accommodate growing organizations of all sizes. -5. **Domain Discovery**: Starting on a central login page, route users to their tenant based on their email address or other user attributes. Authentication settings will be applied automatically based on the organization's policies, this includes routing users seamlessly to third party identity providers like [Entra ID](/docs/guides/integrate/identity-providers/azure-ad-oidc). - -### Customization options - -While the hosted login page provides a default authentication interface out-of-the-box, ZITADEL offers [customization options](/docs/guides/manage/customize/branding) to tailor the login page to match your application's branding and user experience requirements. -Developers can customize elements such as logos, colors, and messaging to ensure a seamless integration with their application's user interface. - -:::info Customization and Branding -The login page can be changed by customizing different branding aspects and you can define a custom domain for the login (eg, login.acme.com). - -By default, the displayed branding is defined [based on the user's domain](/docs/guides/solution-scenarios/domain-discovery). In case you want to show the branding of a specific organization by default, you need to either pass a primary domain scope (`urn:zitadel:iam:org:domain:primary:{domainname}`) with the authorization request, or define the behavior on your Project's settings. -::: - -### Fast account switching - -The hosted login page remembers users who have previously authenticated. -In case a user has used multiple accounts, for example, a private account and a work account, to authenticate, then all accounts will be shown on the Account Picker. -Users can still login with a different user that is not on the list. -This allows users to quickly switch between users and provide a better user experience. - -:::info -This behavior can be changed with the authorization request. Please refer to our [guide](/guides/integrate/login/oidc/login-users). -::: - -### Self-service for users - -ZITADEL's hosted login page offers [many self-service flows](/docs/concepts/features/selfservice) that allow users to set up authentication methods or recover their login information. -Developers use the self-service functionalities to reduce manual tasks and improve user experience. -Key features include: - -### Password reset - -Unauthenticated users can request a password reset after providing the loginname during the login flow. - -- User selects reset password -- An email will be sent to the verified email address -- User opens a link and has to provide a new password - -#### Prompt users to set up multifactor authentication - -Users are automatically prompted to provide a second factor, when - -- Instance or organization [login policy](/concepts/structure/policies#login-policy) is set -- Requested by the client -- A multi-factor is set up for the user - -When a multi-factor is required, but not set up, then the user is requested to set up an additional factor. - -:::info Disabling multifactor prompt -You can disable the prompt, in case multifactor authentication is not enforced by setting the [**Multifactor Init Lifetime**](/docs/guides/manage/console/default-settings#login-lifetimes) to 0. -::: - -#### Enroll passkeys - -Users can select a button to initiate passwordless login or use a fall-back method (ie. login with username/password), if available. - -The passwordless with [passkeys](/docs/concepts/features/passkeys.md) login flow follows the FIDO2 / WebAuthN standard. -With the introduction of passkeys the gesture can be provided on ANY of the user's devices. -This is not strictly the device where the login flow is being executed (e.g., on a mobile device). -The user experience depends mainly on the operating system and browser. +The hosted login is particularly well-suited for scenarios where: +- **Minimal branding is required:** If your primary focus is on functionality over a highly customized look and feel. +- **Standard authentication flows suffice:** Your application doesn't necessitate complex or unique authentication processes. +- **OIDC or SAML are suitable:** Your application integrates seamlessly with industry-standard protocols. +- **Time-to-market is critical:** You need a rapid and efficient authentication solution to accelerate your development timeline. +- **Embedding the login UI is unnecessary:** You prefer a separate, hosted login page for user authentication. ## Build a custom Login UI to authenticate users diff --git a/docs/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-dot-net.md b/docs/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-dot-net.md index fe1b5a2f2c..7a012f799e 100644 --- a/docs/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-dot-net.md +++ b/docs/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-dot-net.md @@ -43,11 +43,10 @@ dotnet add package Zitadel.Api ### Create example client Change the program.cs file to the content below. This will create a client for the management api and call its `GetMyUsers` function. -The SDK will make sure you will have access to the API by retrieving a Bearer Token using JWT Profile with the provided scopes (`openid` and `urn:zitadel:iam:org:project:id:{projectID}:aud`). +The SDK will make sure you will have access to the API by retrieving a Bearer Token using JWT Profile with the provided scopes (`openid` and `urn:zitadel:iam:org:project:id:zitadel:aud`). -Make sure to fill the const `apiUrl`, `apiProject` and `personalAccessToken` with your own instance data. The used vars below are from a test instance, to show you how it should look. +Make sure to fill the const `apiUrl`, and `personalAccessToken` with your own instance data. The used vars below are from a test instance, to show you how it should look. The apiURL is the domain of your instance you can find it on the instance detail in the Customer Portal or in the Console -The apiProject you will find in the ZITADEL project in the first organization of your instance. ```csharp // This file contains two examples: @@ -66,7 +65,8 @@ var client = Clients.AuthService(new(apiUrl, ITokenProvider.Static(personalAcces var result = await client.GetMyUserAsync(new()); Console.WriteLine($"User: {result.User}"); -const string apiProject = "170078979166961921"; +// This adds the urn:zitadel:iam:org:project:id:zitadel:aud scope to the authorization request, enabling access to ZITADEL APIs. +const string apiProject = "zitadel"; var serviceAccount = ServiceAccount.LoadFromJsonString( @" { diff --git a/docs/docs/support/advisory/a10014.md b/docs/docs/support/advisory/a10014.md new file mode 100644 index 0000000000..be19dd2cbf --- /dev/null +++ b/docs/docs/support/advisory/a10014.md @@ -0,0 +1,26 @@ +--- +title: Technical Advisory 10014 +--- + +## Date + +Versions: >= v2.67.3, v2.66 >= v2.66.6 + +Date: 2025-01-17 + +## Description + +Prior to version [v2.66.0](https://github.com/zitadel/zitadel/releases/tag/v2.66.0), some project grants were incorrectly created under the granted organization instead of the project owner's organization. To find these grants, users had to set the `x-zitadel-orgid` header to the granted organization ID when using the [`ListAllProjectGrants`](/apis/resources/mgmt/management-service-add-project-grant) gRPC method. + +Zitadel [v2.66.0](https://github.com/zitadel/zitadel/releases/tag/v2.66.0) corrected this behavior for new grants. However, existing grants were not automatically updated. Version v2.66.6 corrects the owner of these existing grants. + +## Impact + +After the release of v2.66.6, if your application uses the [`ListAllProjectGrants`](/apis/resources/mgmt/management-service-add-project-grant) method with the `x-zitadel-orgid` header set to the granted organization ID, you will not retrieve any results. + +## Mitigation + +To ensure your application continues to function correctly after the release of v2.66.6, implement the following changes: + +1. **Conditional Header:** Only set the `x-zitadel-orgid` header to the project owner's organization ID if the user executing the [`ListAllProjectGrants`](/apis/resources/mgmt/management-service-add-project-grant) method belongs to a different organization than the project. +2. **Use `grantedOrgIdQuery`:** Utilize the `grantedOrgIdQuery` parameter to filter grants for the specific granted organization. \ No newline at end of file diff --git a/docs/docs/support/technical_advisory.mdx b/docs/docs/support/technical_advisory.mdx index 7562ff3870..8805e2e1d8 100644 --- a/docs/docs/support/technical_advisory.mdx +++ b/docs/docs/support/technical_advisory.mdx @@ -214,6 +214,18 @@ We understand that these advisories may include breaking changes, and we aim to - 2024-12-09 + + + A-10014 + + Correction of project grant owner + Breaking Behavior Change + + Correct project grant owners, ensuring they are correctly associated with the projects organization. + + - + 2025-01-10 + ## Subscribe to our Mailing List diff --git a/docs/sidebars.js b/docs/sidebars.js index 05a2c42342..3b379d57ee 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -206,6 +206,31 @@ module.exports = { }, items: [ "guides/integrate/login/login-users", + { + type: "link", + href: "/docs/guides/integrate/login/login-users#zitadels-session-api", + label: "Session API" + }, + { + type: "category", + label: "Hosted Login", + link: { + type: "doc", + id: "guides/integrate/login/hosted-login" + }, + items: [ + { + type: "link", + href: "/docs/guides/integrate/login/hosted-login#hosted-login-version-2-beta", + label: "Login V2 [Beta]" + }, + ] + }, + { + type: "link", + href: "/docs/guides/integrate/login/login-users#build-a-custom-login-ui-to-authenticate-users", + label: "Custom Login UI", + }, { type: "category", label: "OpenID Connect", diff --git a/docs/static/img/guides/integrate/login/login-v2-app-config.png b/docs/static/img/guides/integrate/login/login-v2-app-config.png new file mode 100644 index 0000000000..36ddedcbf1 Binary files /dev/null and b/docs/static/img/guides/integrate/login/login-v2-app-config.png differ diff --git a/docs/vercel.json b/docs/vercel.json index 15bd2499e8..58a13e4b8c 100644 --- a/docs/vercel.json +++ b/docs/vercel.json @@ -46,6 +46,19 @@ { "source": "/docs/examples/call-zitadel-api/go", "destination": "/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-go", "permanent": true }, { "source": "/docs/examples/call-zitadel-api/dot-net", "destination": "/docs/guides/integrate/zitadel-apis/example-zitadel-api-with-dot-net", "permanent": true }, { "source": "/docs/guides/manage/terraform/basics", "destination": "/docs/guides/manage/terraform-provider", "permanent": true }, - { "source": "/docs/guides/integrate/identity-providers", "destination": "/docs/guides/integrate/identity-providers/introduction", "permanent": true } + { "source": "/docs/guides/integrate/identity-providers", "destination": "/docs/guides/integrate/identity-providers/introduction", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#centralized-authentication-endpoint", "destination": "/docs/guides/integrate/login/hosted-login#centralized-authentication-endpoint", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#security-and-compliance", "destination": "/docs/guides/integrate/login/hosted-login#security-and-compliance", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#developer-friendly-integration", "destination": "/docs/guides/integrate/login/hosted-login#developer-friendly-integration", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#key-features-of-the-hosted-login", "destination": "/docs/guides/integrate/login/hosted-login#key-features-of-the-hosted-login", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#flexible-usernames", "destination": "/docs/guides/integrate/login/hosted-login#flexible-usernames", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#support-for-multiple-authentication-methods", "destination": "/docs/guides/integrate/login/hosted-login#support-for-multiple-authentication-methods", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#enterprise-single-sign-on", "destination": "/docs/guides/integrate/login/hosted-login#enterprise-single-sign-on", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#multi-tenancy-authentication", "destination": "/docs/guides/integrate/login/hosted-login#multi-tenancy-authentication", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#customization-options", "destination": "/docs/guides/integrate/login/hosted-login#customization-options", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#fast-account-switching", "destination": "/docs/guides/integrate/login/hosted-login#fast-account-switching", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#self-service-for-users", "destination": "/docs/guides/integrate/login/hosted-login#self-service-for-users", "permanent": true }, + { "source": "/docs/guides/integrate/login/login-users#password-reset", "destination": "/docs/guides/integrate/login/hosted-login#password-reset", "permanent": true } ] } + diff --git a/go.mod b/go.mod index aa9fbb64a2..20d7322124 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,8 @@ require ( github.com/go-jose/go-jose/v4 v4.0.4 github.com/go-ldap/ldap/v3 v3.4.8 github.com/go-webauthn/webauthn v0.10.2 + github.com/goccy/go-json v0.10.3 + github.com/golang/protobuf v1.5.4 github.com/gorilla/csrf v1.7.2 github.com/gorilla/mux v1.8.1 github.com/gorilla/schema v1.4.1 @@ -106,11 +108,9 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/go-webauthn/x v0.1.9 // indirect - github.com/goccy/go-json v0.10.3 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/mock v1.6.0 // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-tpm v0.9.0 // indirect github.com/google/pprof v0.0.0-20240528025155-186aa0362fba // indirect github.com/google/s2a-go v0.1.7 // indirect diff --git a/internal/api/authz/context_mock.go b/internal/api/authz/context_mock.go index 6badf15862..6891030bd3 100644 --- a/internal/api/authz/context_mock.go +++ b/internal/api/authz/context_mock.go @@ -7,6 +7,11 @@ func NewMockContext(instanceID, orgID, userID string) context.Context { return context.WithValue(ctx, instanceKey, &instance{id: instanceID}) } +func NewMockContextWithAgent(instanceID, orgID, userID, agentID string) context.Context { + ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID, AgentID: agentID}) + return context.WithValue(ctx, instanceKey, &instance{id: instanceID}) +} + func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context { ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID}) ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID}) diff --git a/internal/api/grpc/feature/v2/converter.go b/internal/api/grpc/feature/v2/converter.go index 109d2d1e53..fee4450ce2 100644 --- a/internal/api/grpc/feature/v2/converter.go +++ b/internal/api/grpc/feature/v2/converter.go @@ -29,6 +29,7 @@ func systemFeaturesToCommand(req *feature_pb.SetSystemFeaturesRequest) (*command DisableUserTokenEvent: req.DisableUserTokenEvent, EnableBackChannelLogout: req.EnableBackChannelLogout, LoginV2: loginV2, + PermissionCheckV2: req.PermissionCheckV2, }, nil } @@ -46,6 +47,7 @@ func systemFeaturesToPb(f *query.SystemFeatures) *feature_pb.GetSystemFeaturesRe DisableUserTokenEvent: featureSourceToFlagPb(&f.DisableUserTokenEvent), EnableBackChannelLogout: featureSourceToFlagPb(&f.EnableBackChannelLogout), LoginV2: loginV2ToLoginV2FlagPb(f.LoginV2), + PermissionCheckV2: featureSourceToFlagPb(&f.PermissionCheckV2), } } @@ -68,6 +70,7 @@ func instanceFeaturesToCommand(req *feature_pb.SetInstanceFeaturesRequest) (*com DisableUserTokenEvent: req.DisableUserTokenEvent, EnableBackChannelLogout: req.EnableBackChannelLogout, LoginV2: loginV2, + PermissionCheckV2: req.PermissionCheckV2, }, nil } @@ -87,6 +90,7 @@ func instanceFeaturesToPb(f *query.InstanceFeatures) *feature_pb.GetInstanceFeat DisableUserTokenEvent: featureSourceToFlagPb(&f.DisableUserTokenEvent), EnableBackChannelLogout: featureSourceToFlagPb(&f.EnableBackChannelLogout), LoginV2: loginV2ToLoginV2FlagPb(f.LoginV2), + PermissionCheckV2: featureSourceToFlagPb(&f.PermissionCheckV2), } } diff --git a/internal/api/grpc/feature/v2/converter_test.go b/internal/api/grpc/feature/v2/converter_test.go index f8b2c0006f..bf87dc959b 100644 --- a/internal/api/grpc/feature/v2/converter_test.go +++ b/internal/api/grpc/feature/v2/converter_test.go @@ -101,6 +101,10 @@ func Test_systemFeaturesToPb(t *testing.T) { BaseURI: &url.URL{Scheme: "https", Host: "login.com"}, }, }, + PermissionCheckV2: query.FeatureSource[bool]{ + Level: feature.LevelSystem, + Value: true, + }, } want := &feature_pb.GetSystemFeaturesResponse{ Details: &object.Details{ @@ -153,6 +157,10 @@ func Test_systemFeaturesToPb(t *testing.T) { BaseUri: gu.Ptr("https://login.com"), Source: feature_pb.Source_SOURCE_SYSTEM, }, + PermissionCheckV2: &feature_pb.FeatureFlag{ + Enabled: true, + Source: feature_pb.Source_SOURCE_SYSTEM, + }, } got := systemFeaturesToPb(arg) assert.Equal(t, want, got) @@ -252,6 +260,10 @@ func Test_instanceFeaturesToPb(t *testing.T) { BaseURI: &url.URL{Scheme: "https", Host: "login.com"}, }, }, + PermissionCheckV2: query.FeatureSource[bool]{ + Level: feature.LevelInstance, + Value: true, + }, } want := &feature_pb.GetInstanceFeaturesResponse{ Details: &object.Details{ @@ -312,6 +324,10 @@ func Test_instanceFeaturesToPb(t *testing.T) { BaseUri: gu.Ptr("https://login.com"), Source: feature_pb.Source_SOURCE_INSTANCE, }, + PermissionCheckV2: &feature_pb.FeatureFlag{ + Enabled: true, + Source: feature_pb.Source_SOURCE_INSTANCE, + }, } got := instanceFeaturesToPb(arg) assert.Equal(t, want, got) diff --git a/internal/api/grpc/member/converter.go b/internal/api/grpc/member/converter.go index 0e5c87ceb1..af81d8ea45 100644 --- a/internal/api/grpc/member/converter.go +++ b/internal/api/grpc/member/converter.go @@ -34,6 +34,7 @@ func MemberToPb(assetAPIPrefix string, m *query.Member) *member_pb.Member { m.ChangeDate, m.ResourceOwner, ), + UserResourceOwner: m.UserResourceOwner, } } diff --git a/internal/api/grpc/session/v2/integration_test/query_test.go b/internal/api/grpc/session/v2/integration_test/query_test.go new file mode 100644 index 0000000000..36e412be23 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/query_test.go @@ -0,0 +1,714 @@ +//go:build integration + +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/object/v2" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +func TestServer_GetSession(t *testing.T) { + type args struct { + ctx context.Context + req *session.GetSessionRequest + dep func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 + } + tests := []struct { + name string + args args + want *session.GetSessionResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "get session, no id provided", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, not found", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "unknown", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, no permission", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + wantErr: true, + }, + { + name: "get session, permission, ok", + args: args{ + CTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, token, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, user agent, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + { + name: "get session, lifetime, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Minute), + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantExpirationWindow: 5 * time.Minute, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, metadata, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + }, + }, + { + name: "get session, user, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(ctx, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sequence uint64 + if tt.args.dep != nil { + sequence = tt.args.dep(CTX, t, tt.args.req) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.GetSession(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + tt.want.Session.Id = tt.args.req.SessionId + tt.want.Session.Sequence = sequence + verifySession(ttt, got.GetSession(), tt.want.GetSession(), time.Minute, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) + }, retryDuration, tick) + }) + } +} + +type sessionAttr struct { + ID string + UserID string + UserAgent string + CreationDate *timestamp.Timestamp + ChangeDate *timestamppb.Timestamp + Details *object.Details +} + +type sessionAttrs []*sessionAttr + +func (u sessionAttrs) ids() []string { + ids := make([]string, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return ids +} + +func createSessions(ctx context.Context, t *testing.T, count int, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) sessionAttrs { + infos := make([]*sessionAttr, count) + for i := 0; i < count; i++ { + infos[i] = createSession(ctx, t, userID, userAgent, lifetime, metadata) + } + return infos +} + +func createSession(ctx context.Context, t *testing.T, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) *sessionAttr { + req := &session.CreateSessionRequest{} + if userID != "" { + req.Checks = &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: userID, + }, + }, + } + } + if userAgent != "" { + req.UserAgent = &session.UserAgent{ + FingerprintId: gu.Ptr(userAgent), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + } + } + if lifetime != nil { + req.Lifetime = lifetime + } + if metadata != nil { + req.Metadata = metadata + } + resp, err := Client.CreateSession(ctx, req) + require.NoError(t, err) + return &sessionAttr{ + resp.GetSessionId(), + userID, + userAgent, + resp.GetDetails().GetChangeDate(), + resp.GetDetails().GetChangeDate(), + resp.GetDetails(), + } +} + +func TestServer_ListSessions(t *testing.T) { + type args struct { + ctx context.Context + req *session.ListSessionsRequest + dep func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr + } + tests := []struct { + name string + args args + want *session.ListSessionsResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "list sessions, not found", + args: args{ + CTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + {Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{"unknown"}}}}, + }, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, no permission", + args: args{ + UserCTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{}, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, permission, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{{}}, + }, + }, + { + name: "list sessions, full, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, multiple, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + infos := createSessions(ctx, t, 3, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: infos.ids()}}}) + return infos + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 3, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, userid, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + createdUser := createFullUser(ctx) + info := createSession(ctx, t, createdUser.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_UserIdQuery{UserIdQuery: &session.UserIDQuery{Id: createdUser.GetUserId()}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, own creator, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, creator, ok", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{Id: gu.Ptr(Instance.Users.Get(integration.UserTypeOrgOwner).ID)}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, wrong creator", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, empty creator", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_CreatorQuery{CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + wantErr: true, + }, + { + name: "list sessions, useragent, ok", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("useragent")}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("useragent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, wrong useragent", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "useragent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("wronguseragent")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, empty useragent", + args: args{ + IAMOwnerCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + request.Queries = append(request.Queries, + &session.SearchQuery{Query: &session.SearchQuery_UserAgentQuery{UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("")}}}) + return []*sessionAttr{} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + infos := tt.args.dep(CTX, t, tt.args.req) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.ListSessions(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { + return + } + + for i := range infos { + tt.want.Sessions[i].Id = infos[i].ID + tt.want.Sessions[i].Sequence = infos[i].Details.GetSequence() + tt.want.Sessions[i].CreationDate = infos[i].Details.GetChangeDate() + tt.want.Sessions[i].ChangeDate = infos[i].Details.GetChangeDate() + + verifySession(ttt, got.Sessions[i], tt.want.Sessions[i], time.Minute, tt.wantExpirationWindow, infos[i].UserID, tt.wantFactors...) + } + integration.AssertListDetails(ttt, tt.want, got) + }, retryDuration, tick) + }) + } +} diff --git a/internal/api/grpc/session/v2/integration_test/server_test.go b/internal/api/grpc/session/v2/integration_test/server_test.go new file mode 100644 index 0000000000..70e2146069 --- /dev/null +++ b/internal/api/grpc/session/v2/integration_test/server_test.go @@ -0,0 +1,74 @@ +//go:build integration + +package session_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + CTX context.Context + IAMOwnerCTX context.Context + UserCTX context.Context + Instance *integration.Instance + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SessionV2 + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + UserCTX = Instance.WithAuthorization(ctx, integration.UserTypeNoPermission) + User = createFullUser(CTX) + DeactivatedUser = createDeactivatedUser(CTX) + LockedUser = createLockedUser(CTX) + return m.Run() + }()) +} + +func createFullUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetEmailCode(), + }) + Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetPhoneCode(), + }) + Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) + Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) + return userResp +} + +func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("deactivate human user") + return userResp +} + +func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("lock human user") + return userResp +} diff --git a/internal/api/grpc/session/v2/integration_test/session_test.go b/internal/api/grpc/session/v2/integration_test/session_test.go index ccd08f3471..7622550b15 100644 --- a/internal/api/grpc/session/v2/integration_test/session_test.go +++ b/internal/api/grpc/session/v2/integration_test/session_test.go @@ -5,7 +5,6 @@ package session_test import ( "context" "fmt" - "os" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,63 +27,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -var ( - CTX context.Context - IAMOwnerCTX context.Context - Instance *integration.Instance - Client session.SessionServiceClient - User *user.AddHumanUserResponse - DeactivatedUser *user.AddHumanUserResponse - LockedUser *user.AddHumanUserResponse -) - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - Client = Instance.Client.SessionV2 - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) - IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - User = createFullUser(CTX) - DeactivatedUser = createDeactivatedUser(CTX) - LockedUser = createLockedUser(CTX) - return m.Run() - }()) -} - -func createFullUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetEmailCode(), - }) - Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetPhoneCode(), - }) - Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) - Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) - return userResp -} - -func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("deactivate human user") - return userResp -} - -func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("lock human user") - return userResp -} - -func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { +func verifyCurrentSession(t *testing.T, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) require.NotEmpty(t, token) @@ -96,15 +38,25 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo }) require.NoError(t, err) s := resp.GetSession() + want := &session.Session{ + Id: id, + Sequence: sequence, + Metadata: metadata, + UserAgent: userAgent, + } + verifySession(t, s, want, window, expirationWindow, userID, factors...) + return s +} - assert.Equal(t, id, s.GetId()) +func verifySession(t assert.TestingT, s *session.Session, want *session.Session, window time.Duration, expirationWindow time.Duration, userID string, factors ...wantFactor) { + assert.Equal(t, want.Id, s.GetId()) assert.WithinRange(t, s.GetCreationDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) - assert.Equal(t, sequence, s.GetSequence()) - assert.Equal(t, metadata, s.GetMetadata()) + assert.Equal(t, want.Sequence, s.GetSequence()) + assert.Equal(t, want.Metadata, s.GetMetadata()) - if !proto.Equal(userAgent, s.GetUserAgent()) { - t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + if !proto.Equal(want.UserAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), want.UserAgent) } if expirationWindow == 0 { assert.Nil(t, s.GetExpirationDate()) @@ -113,7 +65,6 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo } verifyFactors(t, s.GetFactors(), window, userID, factors) - return s } type wantFactor int @@ -129,7 +80,7 @@ const ( wantOTPEmailFactor ) -func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { +func verifyFactors(t assert.TestingT, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { for _, w := range want { switch w { case wantUserFactor: @@ -194,8 +145,15 @@ func TestServer_CreateSession(t *testing.T) { }, }, { - name: "user agent", + name: "full session", req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, Metadata: map[string][]byte{"foo": []byte("bar")}, UserAgent: &session.UserAgent{ FingerprintId: gu.Ptr("fingerPrintID"), @@ -205,6 +163,7 @@ func TestServer_CreateSession(t *testing.T) { "foo": {Values: []string{"foo", "bar"}}, }, }, + Lifetime: durationpb.New(5 * time.Minute), }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -212,14 +171,6 @@ func TestServer_CreateSession(t *testing.T) { ResourceOwner: Instance.ID(), }, }, - wantUserAgent: &session.UserAgent{ - FingerprintId: gu.Ptr("fingerPrintID"), - Ip: gu.Ptr("1.2.3.4"), - Description: gu.Ptr("Description"), - Header: map[string]*session.UserAgent_HeaderValues{ - "foo": {Values: []string{"foo", "bar"}}, - }, - }, }, { name: "negative lifetime", @@ -229,40 +180,6 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, - { - name: "lifetime", - req: &session.CreateSessionRequest{ - Metadata: map[string][]byte{"foo": []byte("bar")}, - Lifetime: durationpb.New(5 * time.Minute), - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantExpirationWindow: 5 * time.Minute, - }, - { - name: "with user", - req: &session.CreateSessionRequest{ - Checks: &session.Checks{ - User: &session.CheckUser{ - Search: &session.CheckUser_UserId{ - UserId: User.GetUserId(), - }, - }, - }, - Metadata: map[string][]byte{"foo": []byte("bar")}, - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantFactors: []wantFactor{wantUserFactor}, - }, { name: "deactivated user", req: &session.CreateSessionRequest{ @@ -340,8 +257,6 @@ func TestServer_CreateSession(t *testing.T) { } require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) }) } } @@ -946,21 +861,30 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.NoError(t, err) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("Bearer %s", createResp.GetSessionToken())) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) - require.Error(t, err) - require.Nil(t, sessionResp) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) + if !assert.Error(tt, err) { + return + } + assert.Nil(tt, sessionResp) + }, retryDuration, tick) } func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) - ctx := integration.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) - webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() - require.NotNil(t, id, webAuthN.GetVerifiedAt().AsTime()) - require.True(t, webAuthN.GetUserVerified()) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() + assert.NotNil(tt, id, webAuthN.GetVerifiedAt().AsTime()) + assert.True(tt, webAuthN.GetUserVerified()) + }, retryDuration, tick) } func Test_ZITADEL_API_session_not_found(t *testing.T) { @@ -968,18 +892,30 @@ func Test_ZITADEL_API_session_not_found(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) //terminate the session and test it does not work anymore - _, err = Client.DeleteSession(CTX, &session.DeleteSessionRequest{ + _, err := Client.DeleteSession(CTX, &session.DeleteSessionRequest{ SessionId: id, SessionToken: gu.Ptr(token), }) require.NoError(t, err) + ctx = integration.WithAuthorizationToken(context.Background(), token) - _, err = Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.Error(tt, err) { + return + } + }, retryDuration, tick) } func Test_ZITADEL_API_session_expired(t *testing.T) { @@ -987,8 +923,13 @@ func Test_ZITADEL_API_session_expired(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) // ensure session expires and does not work anymore time.Sleep(20 * time.Second) diff --git a/internal/api/grpc/session/v2/query.go b/internal/api/grpc/session/v2/query.go new file mode 100644 index 0000000000..73303dd9e8 --- /dev/null +++ b/internal/api/grpc/session/v2/query.go @@ -0,0 +1,262 @@ +package session + +import ( + "context" + "time" + + "github.com/muhlemmer/gu" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + objpb "github.com/zitadel/zitadel/pkg/grpc/object" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +var ( + timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, + objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, + } +) + +func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) + if err != nil { + return nil, err + } + return &session.GetSessionResponse{ + Session: sessionToPb(res), + }, nil +} + +func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { + queries, err := listSessionsRequestToQuery(ctx, req) + if err != nil { + return nil, err + } + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) + if err != nil { + return nil, err + } + return &session.ListSessionsResponse{ + Details: object.ToListDetails(sessions.SearchResponse), + Sessions: sessionsToPb(sessions.Sessions), + }, nil +} + +func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { + offset, limit, asc := object.ListQueryToQuery(req.Query) + queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) + if err != nil { + return nil, err + } + return &query.SessionsSearchQueries{ + SearchRequest: query.SearchRequest{ + Offset: offset, + Limit: limit, + Asc: asc, + SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), + }, + Queries: queries, + }, nil +} + +func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { + q := make([]query.SearchQuery, len(queries)) + for i, v := range queries { + q[i], err = sessionQueryToQuery(ctx, v) + if err != nil { + return nil, err + } + } + return q, nil +} + +func sessionQueryToQuery(ctx context.Context, sq *session.SearchQuery) (query.SearchQuery, error) { + switch q := sq.Query.(type) { + case *session.SearchQuery_IdsQuery: + return idsQueryToQuery(q.IdsQuery) + case *session.SearchQuery_UserIdQuery: + return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) + case *session.SearchQuery_CreationDateQuery: + return creationDateQueryToQuery(q.CreationDateQuery) + case *session.SearchQuery_CreatorQuery: + if q.CreatorQuery != nil && q.CreatorQuery.Id != nil { + if q.CreatorQuery.GetId() != "" { + return query.NewSessionCreatorSearchQuery(q.CreatorQuery.GetId()) + } + } else { + if userID := authz.GetCtxData(ctx).UserID; userID != "" { + return query.NewSessionCreatorSearchQuery(userID) + } + } + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid") + case *session.SearchQuery_UserAgentQuery: + if q.UserAgentQuery != nil && q.UserAgentQuery.FingerprintId != nil { + if *q.UserAgentQuery.FingerprintId != "" { + return query.NewSessionUserAgentFingerprintIDSearchQuery(q.UserAgentQuery.GetFingerprintId()) + } + } else { + if agentID := authz.GetCtxData(ctx).AgentID; agentID != "" { + return query.NewSessionUserAgentFingerprintIDSearchQuery(agentID) + } + } + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid") + default: + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") + } +} + +func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { + return query.NewSessionIDsSearchQuery(q.Ids) +} + +func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { + comparison := timestampComparisons[q.GetMethod()] + return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) +} + +func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { + switch field { + case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: + return query.SessionColumnCreationDate + case session.SessionFieldName_SESSION_FIELD_NAME_UNSPECIFIED: + return query.Column{} + default: + return query.Column{} + } +} + +func sessionsToPb(sessions []*query.Session) []*session.Session { + s := make([]*session.Session, len(sessions)) + for i, session := range sessions { + s[i] = sessionToPb(session) + } + return s +} + +func sessionToPb(s *query.Session) *session.Session { + return &session.Session{ + Id: s.ID, + CreationDate: timestamppb.New(s.CreationDate), + ChangeDate: timestamppb.New(s.ChangeDate), + Sequence: s.Sequence, + Factors: factorsToPb(s), + Metadata: s.Metadata, + UserAgent: userAgentToPb(s.UserAgent), + ExpirationDate: expirationToPb(s.Expiration), + } +} + +func userAgentToPb(ua domain.UserAgent) *session.UserAgent { + if ua.IsEmpty() { + return nil + } + + out := &session.UserAgent{ + FingerprintId: ua.FingerprintID, + Description: ua.Description, + } + if ua.IP != nil { + out.Ip = gu.Ptr(ua.IP.String()) + } + if ua.Header == nil { + return out + } + out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) + for k, v := range ua.Header { + out.Header[k] = &session.UserAgent_HeaderValues{ + Values: v, + } + } + return out +} + +func expirationToPb(expiration time.Time) *timestamppb.Timestamp { + if expiration.IsZero() { + return nil + } + return timestamppb.New(expiration) +} + +func factorsToPb(s *query.Session) *session.Factors { + user := userFactorToPb(s.UserFactor) + if user == nil { + return nil + } + return &session.Factors{ + User: user, + Password: passwordFactorToPb(s.PasswordFactor), + WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), + Intent: intentFactorToPb(s.IntentFactor), + Totp: totpFactorToPb(s.TOTPFactor), + OtpSms: otpFactorToPb(s.OTPSMSFactor), + OtpEmail: otpFactorToPb(s.OTPEmailFactor), + } +} + +func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { + if factor.PasswordCheckedAt.IsZero() { + return nil + } + return &session.PasswordFactor{ + VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), + } +} + +func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { + if factor.IntentCheckedAt.IsZero() { + return nil + } + return &session.IntentFactor{ + VerifiedAt: timestamppb.New(factor.IntentCheckedAt), + } +} + +func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { + if factor.WebAuthNCheckedAt.IsZero() { + return nil + } + return &session.WebAuthNFactor{ + VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), + UserVerified: factor.UserVerified, + } +} + +func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { + if factor.TOTPCheckedAt.IsZero() { + return nil + } + return &session.TOTPFactor{ + VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), + } +} + +func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { + if factor.OTPCheckedAt.IsZero() { + return nil + } + return &session.OTPFactor{ + VerifiedAt: timestamppb.New(factor.OTPCheckedAt), + } +} + +func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { + if factor.UserID == "" || factor.UserCheckedAt.IsZero() { + return nil + } + return &session.UserFactor{ + VerifiedAt: timestamppb.New(factor.UserCheckedAt), + Id: factor.UserID, + LoginName: factor.LoginName, + DisplayName: factor.DisplayName, + OrganizationId: factor.ResourceOwner, + } +} diff --git a/internal/api/grpc/session/v2/server.go b/internal/api/grpc/session/v2/server.go index e94336bf47..ee534cb26c 100644 --- a/internal/api/grpc/session/v2/server.go +++ b/internal/api/grpc/session/v2/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2/session.go b/internal/api/grpc/session/v2/session.go index aa25fa0ae3..7562d64350 100644 --- a/internal/api/grpc/session/v2/session.go +++ b/internal/api/grpc/session/v2/session.go @@ -6,56 +6,17 @@ import ( "net/http" "time" - "github.com/muhlemmer/gu" "golang.org/x/text/language" "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/object/v2" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/zerrors" - objpb "github.com/zitadel/zitadel/pkg/grpc/object" "github.com/zitadel/zitadel/pkg/grpc/session/v2" ) -var ( - timestampComparisons = map[objpb.TimestampQueryMethod]query.TimestampComparison{ - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_EQUALS: query.TimestampEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER: query.TimestampGreater, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER_OR_EQUALS: query.TimestampGreaterOrEquals, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS: query.TimestampLess, - objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS_OR_EQUALS: query.TimestampLessOrEquals, - } -) - -func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) - if err != nil { - return nil, err - } - return &session.GetSessionResponse{ - Session: sessionToPb(res), - }, nil -} - -func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequest) (*session.ListSessionsResponse, error) { - queries, err := listSessionsRequestToQuery(ctx, req) - if err != nil { - return nil, err - } - sessions, err := s.query.SearchSessions(ctx, queries) - if err != nil { - return nil, err - } - return &session.ListSessionsResponse{ - Details: object.ToListDetails(sessions.SearchResponse), - Sessions: sessionsToPb(sessions.Sessions), - }, nil -} - func (s *Server) CreateSession(ctx context.Context, req *session.CreateSessionRequest) (*session.CreateSessionResponse, error) { checks, metadata, userAgent, lifetime, err := s.createSessionRequestToCommand(ctx, req) if err != nil { @@ -110,197 +71,6 @@ func (s *Server) DeleteSession(ctx context.Context, req *session.DeleteSessionRe }, nil } -func sessionsToPb(sessions []*query.Session) []*session.Session { - s := make([]*session.Session, len(sessions)) - for i, session := range sessions { - s[i] = sessionToPb(session) - } - return s -} - -func sessionToPb(s *query.Session) *session.Session { - return &session.Session{ - Id: s.ID, - CreationDate: timestamppb.New(s.CreationDate), - ChangeDate: timestamppb.New(s.ChangeDate), - Sequence: s.Sequence, - Factors: factorsToPb(s), - Metadata: s.Metadata, - UserAgent: userAgentToPb(s.UserAgent), - ExpirationDate: expirationToPb(s.Expiration), - } -} - -func userAgentToPb(ua domain.UserAgent) *session.UserAgent { - if ua.IsEmpty() { - return nil - } - - out := &session.UserAgent{ - FingerprintId: ua.FingerprintID, - Description: ua.Description, - } - if ua.IP != nil { - out.Ip = gu.Ptr(ua.IP.String()) - } - if ua.Header == nil { - return out - } - out.Header = make(map[string]*session.UserAgent_HeaderValues, len(ua.Header)) - for k, v := range ua.Header { - out.Header[k] = &session.UserAgent_HeaderValues{ - Values: v, - } - } - return out -} - -func expirationToPb(expiration time.Time) *timestamppb.Timestamp { - if expiration.IsZero() { - return nil - } - return timestamppb.New(expiration) -} - -func factorsToPb(s *query.Session) *session.Factors { - user := userFactorToPb(s.UserFactor) - if user == nil { - return nil - } - return &session.Factors{ - User: user, - Password: passwordFactorToPb(s.PasswordFactor), - WebAuthN: webAuthNFactorToPb(s.WebAuthNFactor), - Intent: intentFactorToPb(s.IntentFactor), - Totp: totpFactorToPb(s.TOTPFactor), - OtpSms: otpFactorToPb(s.OTPSMSFactor), - OtpEmail: otpFactorToPb(s.OTPEmailFactor), - } -} - -func passwordFactorToPb(factor query.SessionPasswordFactor) *session.PasswordFactor { - if factor.PasswordCheckedAt.IsZero() { - return nil - } - return &session.PasswordFactor{ - VerifiedAt: timestamppb.New(factor.PasswordCheckedAt), - } -} - -func intentFactorToPb(factor query.SessionIntentFactor) *session.IntentFactor { - if factor.IntentCheckedAt.IsZero() { - return nil - } - return &session.IntentFactor{ - VerifiedAt: timestamppb.New(factor.IntentCheckedAt), - } -} - -func webAuthNFactorToPb(factor query.SessionWebAuthNFactor) *session.WebAuthNFactor { - if factor.WebAuthNCheckedAt.IsZero() { - return nil - } - return &session.WebAuthNFactor{ - VerifiedAt: timestamppb.New(factor.WebAuthNCheckedAt), - UserVerified: factor.UserVerified, - } -} - -func totpFactorToPb(factor query.SessionTOTPFactor) *session.TOTPFactor { - if factor.TOTPCheckedAt.IsZero() { - return nil - } - return &session.TOTPFactor{ - VerifiedAt: timestamppb.New(factor.TOTPCheckedAt), - } -} - -func otpFactorToPb(factor query.SessionOTPFactor) *session.OTPFactor { - if factor.OTPCheckedAt.IsZero() { - return nil - } - return &session.OTPFactor{ - VerifiedAt: timestamppb.New(factor.OTPCheckedAt), - } -} - -func userFactorToPb(factor query.SessionUserFactor) *session.UserFactor { - if factor.UserID == "" || factor.UserCheckedAt.IsZero() { - return nil - } - return &session.UserFactor{ - VerifiedAt: timestamppb.New(factor.UserCheckedAt), - Id: factor.UserID, - LoginName: factor.LoginName, - DisplayName: factor.DisplayName, - OrganizationId: factor.ResourceOwner, - } -} - -func listSessionsRequestToQuery(ctx context.Context, req *session.ListSessionsRequest) (*query.SessionsSearchQueries, error) { - offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := sessionQueriesToQuery(ctx, req.GetQueries()) - if err != nil { - return nil, err - } - return &query.SessionsSearchQueries{ - SearchRequest: query.SearchRequest{ - Offset: offset, - Limit: limit, - Asc: asc, - SortingColumn: fieldNameToSessionColumn(req.GetSortingColumn()), - }, - Queries: queries, - }, nil -} - -func sessionQueriesToQuery(ctx context.Context, queries []*session.SearchQuery) (_ []query.SearchQuery, err error) { - q := make([]query.SearchQuery, len(queries)+1) - for i, v := range queries { - q[i], err = sessionQueryToQuery(v) - if err != nil { - return nil, err - } - } - creatorQuery, err := query.NewSessionCreatorSearchQuery(authz.GetCtxData(ctx).UserID) - if err != nil { - return nil, err - } - q[len(queries)] = creatorQuery - return q, nil -} - -func sessionQueryToQuery(sq *session.SearchQuery) (query.SearchQuery, error) { - switch q := sq.Query.(type) { - case *session.SearchQuery_IdsQuery: - return idsQueryToQuery(q.IdsQuery) - case *session.SearchQuery_UserIdQuery: - return query.NewUserIDSearchQuery(q.UserIdQuery.GetId()) - case *session.SearchQuery_CreationDateQuery: - return creationDateQueryToQuery(q.CreationDateQuery) - default: - return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid") - } -} - -func idsQueryToQuery(q *session.IDsQuery) (query.SearchQuery, error) { - return query.NewSessionIDsSearchQuery(q.Ids) -} - -func creationDateQueryToQuery(q *session.CreationDateQuery) (query.SearchQuery, error) { - comparison := timestampComparisons[q.GetMethod()] - return query.NewCreationDateQuery(q.GetCreationDate().AsTime(), comparison) -} - -func fieldNameToSessionColumn(field session.SessionFieldName) query.Column { - switch field { - case session.SessionFieldName_SESSION_FIELD_NAME_CREATION_DATE: - return query.SessionColumnCreationDate - default: - return query.Column{} - } -} - func (s *Server) createSessionRequestToCommand(ctx context.Context, req *session.CreateSessionRequest) ([]command.SessionCommand, map[string][]byte, *domain.UserAgent, time.Duration, error) { checks, err := s.checksToCommand(ctx, req.Checks) if err != nil { diff --git a/internal/api/grpc/session/v2/session_test.go b/internal/api/grpc/session/v2/session_test.go index 917be882f8..ce4f5115f2 100644 --- a/internal/api/grpc/session/v2/session_test.go +++ b/internal/api/grpc/session/v2/session_test.go @@ -339,9 +339,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Limit: 0, Asc: false, }, - Queries: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + Queries: []query.SearchQuery{}, }, }, { @@ -359,15 +357,13 @@ func Test_listSessionsRequestToQuery(t *testing.T) { SortingColumn: query.SessionColumnCreationDate, Asc: false, }, - Queries: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + Queries: []query.SearchQuery{}, }, }, { name: "with list query and sessions", args: args{ - ctx: authz.NewMockContext("123", "456", "789"), + ctx: authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent", UserID: "789"}), req: &session.ListSessionsRequest{ Query: &object.ListQuery{ Offset: 10, @@ -396,6 +392,12 @@ func Test_listSessionsRequestToQuery(t *testing.T) { Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_GREATER, }, }}, + {Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }}, + {Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }}, }, }, }, @@ -411,6 +413,7 @@ func Test_listSessionsRequestToQuery(t *testing.T) { mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampGreater), mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), + mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), }, }, }, @@ -458,13 +461,11 @@ func Test_sessionQueriesToQuery(t *testing.T) { wantErr error }{ { - name: "creator only", + name: "no queries", args: args{ ctx: authz.NewMockContext("123", "456", "789"), }, - want: []query.SearchQuery{ - mustNewTextQuery(t, query.SessionColumnCreator, "789", query.TextEquals), - }, + want: []query.SearchQuery{}, }, { name: "invalid argument", @@ -491,6 +492,9 @@ func Test_sessionQueriesToQuery(t *testing.T) { Ids: []string{"4", "5", "6"}, }, }}, + {Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }}, }, }, want: []query.SearchQuery{ @@ -511,6 +515,7 @@ func Test_sessionQueriesToQuery(t *testing.T) { func Test_sessionQueryToQuery(t *testing.T) { type args struct { + ctx context.Context query *session.SearchQuery } tests := []struct { @@ -521,60 +526,158 @@ func Test_sessionQueryToQuery(t *testing.T) { }{ { name: "invalid argument", - args: args{&session.SearchQuery{ - Query: nil, - }}, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: nil, + }}, wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-Sfefs", "List.Query.Invalid"), }, { name: "ids query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_IdsQuery{ - IdsQuery: &session.IDsQuery{ - Ids: []string{"1", "2", "3"}, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_IdsQuery{ + IdsQuery: &session.IDsQuery{ + Ids: []string{"1", "2", "3"}, + }, }, - }, - }}, + }}, want: mustNewListQuery(t, query.SessionColumnID, []interface{}{"1", "2", "3"}, query.ListIn), }, { name: "user id query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_UserIdQuery{ - UserIdQuery: &session.UserIDQuery{ - Id: "10", + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_UserIdQuery{ + UserIdQuery: &session.UserIDQuery{ + Id: "10", + }, }, - }, - }}, + }}, want: mustNewTextQuery(t, query.SessionColumnUserID, "10", query.TextEquals), }, { name: "creation date query", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_CreationDateQuery{ - CreationDateQuery: &session.CreationDateQuery{ - CreationDate: timestamppb.New(creationDate), - Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + Method: objpb.TimestampQueryMethod_TIMESTAMP_QUERY_METHOD_LESS, + }, }, - }, - }}, + }}, want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampLess), }, { name: "creation date query with default method", - args: args{&session.SearchQuery{ - Query: &session.SearchQuery_CreationDateQuery{ - CreationDateQuery: &session.CreationDateQuery{ - CreationDate: timestamppb.New(creationDate), + args: args{ + context.Background(), + &session.SearchQuery{ + Query: &session.SearchQuery_CreationDateQuery{ + CreationDateQuery: &session.CreationDateQuery{ + CreationDate: timestamppb.New(creationDate), + }, }, - }, - }}, + }}, want: mustNewTimestampQuery(t, query.SessionColumnCreationDate, creationDate, query.TimestampEquals), }, + { + name: "own creator", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnCreator, "creator", query.TextEquals), + }, + { + name: "empty own creator, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: ""}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid"), + }, + { + name: "creator", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("creator2")}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnCreator, "creator2", query.TextEquals), + }, + { + name: "empty creator, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{UserID: "creator1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_CreatorQuery{ + CreatorQuery: &session.CreatorQuery{Id: gu.Ptr("")}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n24uh", "List.Query.Invalid"), + }, + { + name: "empty own useragent, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: ""}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid"), + }, + { + name: "own useragent", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent", query.TextEquals), + }, + { + name: "empty useragent, error", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("")}, + }, + }}, + wantErr: zerrors.ThrowInvalidArgument(nil, "GRPC-x8n23uh", "List.Query.Invalid"), + }, + { + name: "useragent", + args: args{ + authz.SetCtxData(context.Background(), authz.CtxData{AgentID: "agent1"}), + &session.SearchQuery{ + Query: &session.SearchQuery_UserAgentQuery{ + UserAgentQuery: &session.UserAgentQuery{FingerprintId: gu.Ptr("agent2")}, + }, + }}, + want: mustNewTextQuery(t, query.SessionColumnUserAgentFingerprintID, "agent2", query.TextEquals), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := sessionQueryToQuery(tt.args.query) + got, err := sessionQueryToQuery(tt.args.ctx, tt.args.query) require.ErrorIs(t, err, tt.wantErr) assert.Equal(t, tt.want, got) }) diff --git a/internal/api/grpc/session/v2beta/integration_test/query_test.go b/internal/api/grpc/session/v2beta/integration_test/query_test.go new file mode 100644 index 0000000000..b347ba8224 --- /dev/null +++ b/internal/api/grpc/session/v2beta/integration_test/query_test.go @@ -0,0 +1,512 @@ +//go:build integration + +package session_test + +import ( + "context" + "testing" + "time" + + "github.com/golang/protobuf/ptypes/timestamp" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + object "github.com/zitadel/zitadel/pkg/grpc/object/v2beta" + session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" +) + +func TestServer_GetSession(t *testing.T) { + type args struct { + ctx context.Context + req *session.GetSessionRequest + dep func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 + } + tests := []struct { + name string + args args + want *session.GetSessionResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "get session, no id provided", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, not found", + args: args{ + CTX, + &session.GetSessionRequest{ + SessionId: "unknown", + }, + nil, + }, + wantErr: true, + }, + { + name: "get session, no permission", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + wantErr: true, + }, + { + name: "get session, permission, ok", + args: args{ + CTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, token, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{}) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, user agent, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("fingerPrintID"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + { + name: "get session, lifetime, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Lifetime: durationpb.New(5 * time.Minute), + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantExpirationWindow: 5 * time.Minute, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + { + name: "get session, metadata, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + want: &session.GetSessionResponse{ + Session: &session.Session{ + Metadata: map[string][]byte{"foo": []byte("bar")}, + }, + }, + }, + { + name: "get session, user, ok", + args: args{ + UserCTX, + &session.GetSessionRequest{}, + func(ctx context.Context, t *testing.T, request *session.GetSessionRequest) uint64 { + resp, err := Client.CreateSession(CTX, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, + }, + ) + require.NoError(t, err) + request.SessionId = resp.SessionId + request.SessionToken = gu.Ptr(resp.SessionToken) + return resp.GetDetails().GetSequence() + }, + }, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.GetSessionResponse{ + Session: &session.Session{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var sequence uint64 + if tt.args.dep != nil { + sequence = tt.args.dep(tt.args.ctx, t, tt.args.req) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.GetSession(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + tt.want.Session.Id = tt.args.req.SessionId + tt.want.Session.Sequence = sequence + verifySession(ttt, got.GetSession(), tt.want.GetSession(), time.Minute, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) + }, retryDuration, tick) + }) + } +} + +type sessionAttr struct { + ID string + UserID string + UserAgent string + CreationDate *timestamp.Timestamp + ChangeDate *timestamppb.Timestamp + Details *object.Details +} + +type sessionAttrs []*sessionAttr + +func (u sessionAttrs) ids() []string { + ids := make([]string, len(u)) + for i := range u { + ids[i] = u[i].ID + } + return ids +} + +func createSessions(ctx context.Context, t *testing.T, count int, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) sessionAttrs { + infos := make([]*sessionAttr, count) + for i := 0; i < count; i++ { + infos[i] = createSession(ctx, t, userID, userAgent, lifetime, metadata) + } + return infos +} + +func createSession(ctx context.Context, t *testing.T, userID string, userAgent string, lifetime *durationpb.Duration, metadata map[string][]byte) *sessionAttr { + req := &session.CreateSessionRequest{} + if userID != "" { + req.Checks = &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: userID, + }, + }, + } + } + if userAgent != "" { + req.UserAgent = &session.UserAgent{ + FingerprintId: gu.Ptr(userAgent), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + } + } + if lifetime != nil { + req.Lifetime = lifetime + } + if metadata != nil { + req.Metadata = metadata + } + resp, err := Client.CreateSession(ctx, req) + require.NoError(t, err) + return &sessionAttr{ + resp.GetSessionId(), + userID, + userAgent, + resp.GetDetails().GetChangeDate(), + resp.GetDetails().GetChangeDate(), + resp.GetDetails(), + } +} + +func TestServer_ListSessions(t *testing.T) { + type args struct { + ctx context.Context + req *session.ListSessionsRequest + dep func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr + } + tests := []struct { + name string + args args + want *session.ListSessionsResponse + wantFactors []wantFactor + wantExpirationWindow time.Duration + wantErr bool + }{ + { + name: "list sessions, not found", + args: args{ + CTX, + &session.ListSessionsRequest{ + Queries: []*session.SearchQuery{ + {Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{"unknown"}}}}, + }, + }, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, wrong creator", + args: args{ + UserCTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, "", "", nil, nil) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{} + }, + }, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 0, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{}, + }, + }, + { + name: "list sessions, full, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + info := createSession(ctx, t, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: []string{info.ID}}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, multiple, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + infos := createSessions(ctx, t, 3, User.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_IdsQuery{IdsQuery: &session.IDsQuery{Ids: infos.ids()}}}) + return infos + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 3, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + { + name: "list sessions, userid, ok", + args: args{ + CTX, + &session.ListSessionsRequest{}, + func(ctx context.Context, t *testing.T, request *session.ListSessionsRequest) []*sessionAttr { + createdUser := createFullUser(ctx) + info := createSession(ctx, t, createdUser.GetUserId(), "agent", durationpb.New(time.Minute*5), map[string][]byte{"key": []byte("value")}) + request.Queries = append(request.Queries, &session.SearchQuery{Query: &session.SearchQuery_UserIdQuery{UserIdQuery: &session.UserIDQuery{Id: createdUser.GetUserId()}}}) + return []*sessionAttr{info} + }, + }, + wantExpirationWindow: time.Minute * 5, + wantFactors: []wantFactor{wantUserFactor}, + want: &session.ListSessionsResponse{ + Details: &object.ListDetails{ + TotalResult: 1, + Timestamp: timestamppb.Now(), + }, + Sessions: []*session.Session{ + { + Metadata: map[string][]byte{"key": []byte("value")}, + UserAgent: &session.UserAgent{ + FingerprintId: gu.Ptr("agent"), + Ip: gu.Ptr("1.2.3.4"), + Description: gu.Ptr("Description"), + Header: map[string]*session.UserAgent_HeaderValues{ + "foo": {Values: []string{"foo", "bar"}}, + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + infos := tt.args.dep(CTX, t, tt.args.req) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(tt.args.ctx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := Client.ListSessions(tt.args.ctx, tt.args.req) + if tt.wantErr { + assert.Error(ttt, err) + return + } + if !assert.NoError(ttt, err) { + return + } + + if !assert.Equal(ttt, got.Details.TotalResult, tt.want.Details.TotalResult) || !assert.Len(ttt, got.Sessions, len(tt.want.Sessions)) { + return + } + + for i := range infos { + tt.want.Sessions[i].Id = infos[i].ID + tt.want.Sessions[i].Sequence = infos[i].Details.GetSequence() + tt.want.Sessions[i].CreationDate = infos[i].Details.GetChangeDate() + tt.want.Sessions[i].ChangeDate = infos[i].Details.GetChangeDate() + + verifySession(ttt, got.Sessions[i], tt.want.Sessions[i], time.Minute, tt.wantExpirationWindow, infos[i].UserID, tt.wantFactors...) + } + integration.AssertListDetails(ttt, tt.want, got) + }, retryDuration, tick) + }) + } +} diff --git a/internal/api/grpc/session/v2beta/integration_test/server_test.go b/internal/api/grpc/session/v2beta/integration_test/server_test.go new file mode 100644 index 0000000000..4920e6ec35 --- /dev/null +++ b/internal/api/grpc/session/v2beta/integration_test/server_test.go @@ -0,0 +1,74 @@ +//go:build integration + +package session_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/integration" + session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + CTX context.Context + IAMOwnerCTX context.Context + UserCTX context.Context + Instance *integration.Instance + Client session.SessionServiceClient + User *user.AddHumanUserResponse + DeactivatedUser *user.AddHumanUserResponse + LockedUser *user.AddHumanUserResponse +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SessionV2beta + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) + UserCTX = Instance.WithAuthorization(ctx, integration.UserTypeNoPermission) + User = createFullUser(CTX) + DeactivatedUser = createDeactivatedUser(CTX) + LockedUser = createLockedUser(CTX) + return m.Run() + }()) +} + +func createFullUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetEmailCode(), + }) + Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ + UserId: userResp.GetUserId(), + VerificationCode: userResp.GetPhoneCode(), + }) + Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) + Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) + return userResp +} + +func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("deactivate human user") + return userResp +} + +func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { + userResp := Instance.CreateHumanUser(ctx) + _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) + logging.OnError(err).Fatal("lock human user") + return userResp +} diff --git a/internal/api/grpc/session/v2beta/integration_test/session_test.go b/internal/api/grpc/session/v2beta/integration_test/session_test.go index 52e355204d..26d2291629 100644 --- a/internal/api/grpc/session/v2beta/integration_test/session_test.go +++ b/internal/api/grpc/session/v2beta/integration_test/session_test.go @@ -5,7 +5,6 @@ package session_test import ( "context" "fmt" - "os" "testing" "time" @@ -14,7 +13,6 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -29,62 +27,6 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/user/v2" ) -var ( - CTX context.Context - IAMOwnerCTX context.Context - Instance *integration.Instance - Client session.SessionServiceClient - User *user.AddHumanUserResponse - DeactivatedUser *user.AddHumanUserResponse - LockedUser *user.AddHumanUserResponse -) - -func TestMain(m *testing.M) { - os.Exit(func() int { - ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) - defer cancel() - - Instance = integration.NewInstance(ctx) - Client = Instance.Client.SessionV2beta - - CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) - IAMOwnerCTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner) - User = createFullUser(CTX) - DeactivatedUser = createDeactivatedUser(CTX) - LockedUser = createLockedUser(CTX) - return m.Run() - }()) -} - -func createFullUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - Instance.Client.UserV2.VerifyEmail(ctx, &user.VerifyEmailRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetEmailCode(), - }) - Instance.Client.UserV2.VerifyPhone(ctx, &user.VerifyPhoneRequest{ - UserId: userResp.GetUserId(), - VerificationCode: userResp.GetPhoneCode(), - }) - Instance.SetUserPassword(ctx, userResp.GetUserId(), integration.UserPassword, false) - Instance.RegisterUserPasskey(ctx, userResp.GetUserId()) - return userResp -} - -func createDeactivatedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.DeactivateUser(ctx, &user.DeactivateUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("deactivate human user") - return userResp -} - -func createLockedUser(ctx context.Context) *user.AddHumanUserResponse { - userResp := Instance.CreateHumanUser(ctx) - _, err := Instance.Client.UserV2.LockUser(ctx, &user.LockUserRequest{UserId: userResp.GetUserId()}) - logging.OnError(err).Fatal("lock human user") - return userResp -} - func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, window time.Duration, metadata map[string][]byte, userAgent *session.UserAgent, expirationWindow time.Duration, userID string, factors ...wantFactor) *session.Session { t.Helper() require.NotEmpty(t, id) @@ -96,15 +38,25 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo }) require.NoError(t, err) s := resp.GetSession() + want := &session.Session{ + Id: id, + Sequence: sequence, + Metadata: metadata, + UserAgent: userAgent, + } + verifySession(t, s, want, window, expirationWindow, userID, factors...) + return s +} - assert.Equal(t, id, s.GetId()) +func verifySession(t assert.TestingT, s *session.Session, want *session.Session, window time.Duration, expirationWindow time.Duration, userID string, factors ...wantFactor) { + assert.Equal(t, want.Id, s.GetId()) assert.WithinRange(t, s.GetCreationDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) assert.WithinRange(t, s.GetChangeDate().AsTime(), time.Now().Add(-window), time.Now().Add(window)) - assert.Equal(t, sequence, s.GetSequence()) - assert.Equal(t, metadata, s.GetMetadata()) + assert.Equal(t, want.Sequence, s.GetSequence()) + assert.Equal(t, want.Metadata, s.GetMetadata()) - if !proto.Equal(userAgent, s.GetUserAgent()) { - t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), userAgent) + if !proto.Equal(want.UserAgent, s.GetUserAgent()) { + t.Errorf("user agent =\n%v\nwant\n%v", s.GetUserAgent(), want.UserAgent) } if expirationWindow == 0 { assert.Nil(t, s.GetExpirationDate()) @@ -113,7 +65,6 @@ func verifyCurrentSession(t testing.TB, id, token string, sequence uint64, windo } verifyFactors(t, s.GetFactors(), window, userID, factors) - return s } type wantFactor int @@ -129,7 +80,7 @@ const ( wantOTPEmailFactor ) -func verifyFactors(t testing.TB, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { +func verifyFactors(t assert.TestingT, factors *session.Factors, window time.Duration, userID string, want []wantFactor) { for _, w := range want { switch w { case wantUserFactor: @@ -194,8 +145,15 @@ func TestServer_CreateSession(t *testing.T) { }, }, { - name: "user agent", + name: "full session", req: &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: User.GetUserId(), + }, + }, + }, Metadata: map[string][]byte{"foo": []byte("bar")}, UserAgent: &session.UserAgent{ FingerprintId: gu.Ptr("fingerPrintID"), @@ -205,6 +163,7 @@ func TestServer_CreateSession(t *testing.T) { "foo": {Values: []string{"foo", "bar"}}, }, }, + Lifetime: durationpb.New(5 * time.Minute), }, want: &session.CreateSessionResponse{ Details: &object.Details{ @@ -212,14 +171,6 @@ func TestServer_CreateSession(t *testing.T) { ResourceOwner: Instance.ID(), }, }, - wantUserAgent: &session.UserAgent{ - FingerprintId: gu.Ptr("fingerPrintID"), - Ip: gu.Ptr("1.2.3.4"), - Description: gu.Ptr("Description"), - Header: map[string]*session.UserAgent_HeaderValues{ - "foo": {Values: []string{"foo", "bar"}}, - }, - }, }, { name: "negative lifetime", @@ -229,40 +180,6 @@ func TestServer_CreateSession(t *testing.T) { }, wantErr: true, }, - { - name: "lifetime", - req: &session.CreateSessionRequest{ - Metadata: map[string][]byte{"foo": []byte("bar")}, - Lifetime: durationpb.New(5 * time.Minute), - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantExpirationWindow: 5 * time.Minute, - }, - { - name: "with user", - req: &session.CreateSessionRequest{ - Checks: &session.Checks{ - User: &session.CheckUser{ - Search: &session.CheckUser_UserId{ - UserId: User.GetUserId(), - }, - }, - }, - Metadata: map[string][]byte{"foo": []byte("bar")}, - }, - want: &session.CreateSessionResponse{ - Details: &object.Details{ - ChangeDate: timestamppb.Now(), - ResourceOwner: Instance.ID(), - }, - }, - wantFactors: []wantFactor{wantUserFactor}, - }, { name: "deactivated user", req: &session.CreateSessionRequest{ @@ -340,8 +257,6 @@ func TestServer_CreateSession(t *testing.T) { } require.NoError(t, err) integration.AssertDetails(t, tt.want, got) - - verifyCurrentSession(t, got.GetSessionId(), got.GetSessionToken(), got.GetDetails().GetSequence(), time.Minute, tt.req.GetMetadata(), tt.wantUserAgent, tt.wantExpirationWindow, User.GetUserId(), tt.wantFactors...) }) } } @@ -946,21 +861,30 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.NoError(t, err) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("Bearer %s", createResp.GetSessionToken())) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) - require.Error(t, err) - require.Nil(t, sessionResp) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: createResp.GetSessionId()}) + if !assert.Error(tt, err) { + return + } + assert.Nil(tt, sessionResp) + }, retryDuration, tick) } func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) - ctx := integration.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) - webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() - require.NotNil(t, id, webAuthN.GetVerifiedAt().AsTime()) - require.True(t, webAuthN.GetUserVerified()) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + sessionResp, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + webAuthN := sessionResp.GetSession().GetFactors().GetWebAuthN() + assert.NotNil(tt, id, webAuthN.GetVerifiedAt().AsTime()) + assert.True(tt, webAuthN.GetUserVerified()) + }, retryDuration, tick) } func Test_ZITADEL_API_session_not_found(t *testing.T) { @@ -968,18 +892,30 @@ func Test_ZITADEL_API_session_not_found(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) //terminate the session and test it does not work anymore - _, err = Client.DeleteSession(CTX, &session.DeleteSessionRequest{ + _, err := Client.DeleteSession(CTX, &session.DeleteSessionRequest{ SessionId: id, SessionToken: gu.Ptr(token), }) require.NoError(t, err) + ctx = integration.WithAuthorizationToken(context.Background(), token) - _, err = Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.Error(tt, err) { + return + } + }, retryDuration, tick) } func Test_ZITADEL_API_session_expired(t *testing.T) { @@ -987,8 +923,13 @@ func Test_ZITADEL_API_session_expired(t *testing.T) { // test session token works ctx := integration.WithAuthorizationToken(context.Background(), token) - _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.NoError(t, err) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + _, err := Client.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + if !assert.NoError(tt, err) { + return + } + }, retryDuration, tick) // ensure session expires and does not work anymore time.Sleep(20 * time.Second) diff --git a/internal/api/grpc/session/v2beta/server.go b/internal/api/grpc/session/v2beta/server.go index 550d013ad5..cf0d0c27f0 100644 --- a/internal/api/grpc/session/v2beta/server.go +++ b/internal/api/grpc/session/v2beta/server.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/grpc/server" "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" ) @@ -16,6 +17,8 @@ type Server struct { session.UnimplementedSessionServiceServer command *command.Commands query *query.Queries + + checkPermission domain.PermissionCheck } type Config struct{} @@ -23,10 +26,12 @@ type Config struct{} func CreateServer( command *command.Commands, query *query.Queries, + checkPermission domain.PermissionCheck, ) *Server { return &Server{ - command: command, - query: query, + command: command, + query: query, + checkPermission: checkPermission, } } diff --git a/internal/api/grpc/session/v2beta/session.go b/internal/api/grpc/session/v2beta/session.go index 7e67a4b3ff..3b36b8ba83 100644 --- a/internal/api/grpc/session/v2beta/session.go +++ b/internal/api/grpc/session/v2beta/session.go @@ -32,7 +32,7 @@ var ( ) func (s *Server) GetSession(ctx context.Context, req *session.GetSessionRequest) (*session.GetSessionResponse, error) { - res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken()) + res, err := s.query.SessionByID(ctx, true, req.GetSessionId(), req.GetSessionToken(), s.checkPermission) if err != nil { return nil, err } @@ -46,7 +46,7 @@ func (s *Server) ListSessions(ctx context.Context, req *session.ListSessionsRequ if err != nil { return nil, err } - sessions, err := s.query.SearchSessions(ctx, queries) + sessions, err := s.query.SearchSessions(ctx, queries, s.checkPermission) if err != nil { return nil, err } diff --git a/internal/api/grpc/user/v2/integration_test/user_test.go b/internal/api/grpc/user/v2/integration_test/user_test.go index 8d4c254c6b..1d6d12241a 100644 --- a/internal/api/grpc/user/v2/integration_test/user_test.go +++ b/internal/api/grpc/user/v2/integration_test/user_test.go @@ -10,12 +10,11 @@ import ( "testing" "time" - "github.com/zitadel/logging" - "github.com/brianvoe/gofakeit/v6" "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/logging" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" diff --git a/internal/api/http/header.go b/internal/api/http/header.go index 982684c77c..a6c2818728 100644 --- a/internal/api/http/header.go +++ b/internal/api/http/header.go @@ -5,6 +5,8 @@ import ( "net" "net/http" "strings" + + "github.com/gorilla/mux" ) const ( @@ -14,6 +16,7 @@ const ( CacheControl = "cache-control" ContentType = "content-type" ContentLength = "content-length" + ContentLocation = "content-location" Expires = "expires" Location = "location" Origin = "origin" @@ -42,6 +45,9 @@ const ( PermissionsPolicy = "permissions-policy" ZitadelOrgID = "x-zitadel-orgid" + + OrgIdInPathVariableName = "orgId" + OrgIdInPathVariable = "{" + OrgIdInPathVariableName + "}" ) type key int @@ -104,6 +110,12 @@ func GetAuthorization(r *http.Request) string { } func GetOrgID(r *http.Request) string { + // path variable takes precedence over header + orgID, ok := mux.Vars(r)[OrgIdInPathVariableName] + if ok { + return orgID + } + return r.Header.Get(ZitadelOrgID) } diff --git a/internal/api/http/middleware/auth_interceptor.go b/internal/api/http/middleware/auth_interceptor.go index c327d8c846..1581d401b4 100644 --- a/internal/api/http/middleware/auth_interceptor.go +++ b/internal/api/http/middleware/auth_interceptor.go @@ -2,12 +2,15 @@ package middleware import ( "context" - "errors" "net/http" + "strings" + + "github.com/gorilla/mux" "github.com/zitadel/zitadel/internal/api/authz" http_util "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" ) type AuthInterceptor struct { @@ -23,34 +26,40 @@ func AuthorizationInterceptor(verifier authz.APITokenVerifier, authConfig authz. } func (a *AuthInterceptor) Handler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, err := authorize(r, a.verifier, a.authConfig) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - }) + return a.HandlerFunc(next) } -func (a *AuthInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc { +func (a *AuthInterceptor) HandlerFunc(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx, err := authorize(r, a.verifier, a.authConfig) if err != nil { http.Error(w, err.Error(), http.StatusUnauthorized) return } + r = r.WithContext(ctx) next.ServeHTTP(w, r) } } +func (a *AuthInterceptor) HandlerFuncWithError(next HandlerFuncWithError) HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := authorize(r, a.verifier, a.authConfig) + if err != nil { + return err + } + + r = r.WithContext(ctx) + return next(w, r) + } +} + type httpReq struct{} func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig authz.Config) (_ context.Context, err error) { ctx := r.Context() - authOpt, needsToken := verifier.CheckAuthMethod(r.Method + ":" + r.RequestURI) + + authOpt, needsToken := checkAuthMethod(r, verifier) if !needsToken { return ctx, nil } @@ -59,7 +68,7 @@ func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig auth authToken := http_util.GetAuthorization(r) if authToken == "" { - return nil, errors.New("auth header missing") + return nil, zerrors.ThrowUnauthenticated(nil, "AUT-1179", "auth header missing") } ctxSetter, err := authz.CheckUserAuthorization(authCtx, &httpReq{}, authToken, http_util.GetOrgID(r), "", verifier, authConfig, authOpt, r.RequestURI) @@ -69,3 +78,30 @@ func authorize(r *http.Request, verifier authz.APITokenVerifier, authConfig auth span.End() return ctxSetter(ctx), nil } + +func checkAuthMethod(r *http.Request, verifier authz.APITokenVerifier) (authz.Option, bool) { + authOpt, needsToken := verifier.CheckAuthMethod(r.Method + ":" + r.RequestURI) + if needsToken { + return authOpt, true + } + + route := mux.CurrentRoute(r) + if route == nil { + return authOpt, false + } + + pathTemplate, err := route.GetPathTemplate() + if err != nil || pathTemplate == "" { + return authOpt, false + } + + // the path prefix is usually handled in a router in upper layer + // trim the query and the path of the url to get the correct path prefix + pathPrefix := r.RequestURI + if i := strings.Index(pathPrefix, "?"); i != -1 { + pathPrefix = pathPrefix[0:i] + } + pathPrefix = strings.TrimSuffix(pathPrefix, r.URL.Path) + + return verifier.CheckAuthMethod(r.Method + ":" + pathPrefix + pathTemplate) +} diff --git a/internal/api/http/middleware/handler.go b/internal/api/http/middleware/handler.go new file mode 100644 index 0000000000..2c79b6227a --- /dev/null +++ b/internal/api/http/middleware/handler.go @@ -0,0 +1,26 @@ +package middleware + +import "net/http" + +// HandlerFuncWithError is a http handler func which can return an error +// the error should then get handled later on in the pipeline by an error handler +// the error handler can be dependent on the interface standard (e.g. SCIM, Problem Details, ...) +type HandlerFuncWithError = func(w http.ResponseWriter, r *http.Request) error + +// MiddlewareWithErrorFunc is a http middleware which can return an error +// the error should then get handled later on in the pipeline by an error handler +// the error handler can be dependent on the interface standard (e.g. SCIM, Problem Details, ...) +type MiddlewareWithErrorFunc = func(HandlerFuncWithError) HandlerFuncWithError + +// ErrorHandlerFunc handles errors and returns a regular http handler +type ErrorHandlerFunc = func(HandlerFuncWithError) http.Handler + +func ChainedWithErrorHandler(errorHandler ErrorHandlerFunc, middlewares ...MiddlewareWithErrorFunc) func(HandlerFuncWithError) http.Handler { + return func(next HandlerFuncWithError) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + next = middlewares[i](next) + } + + return errorHandler(next) + } +} diff --git a/internal/api/http/middleware/instance_interceptor.go b/internal/api/http/middleware/instance_interceptor.go index facb2ceec0..3ae5dfbb88 100644 --- a/internal/api/http/middleware/instance_interceptor.go +++ b/internal/api/http/middleware/instance_interceptor.go @@ -34,43 +34,57 @@ func InstanceInterceptor(verifier authz.InstanceVerifier, externalDomain string, } func (a *instanceInterceptor) Handler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - a.handleInstance(w, r, next) - }) + return a.HandlerFunc(next) } -func (a *instanceInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc { +func (a *instanceInterceptor) HandlerFunc(next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - a.handleInstance(w, r, next) - } -} - -func (a *instanceInterceptor) handleInstance(w http.ResponseWriter, r *http.Request, next http.Handler) { - for _, prefix := range a.ignoredPrefixes { - if strings.HasPrefix(r.URL.Path, prefix) { + ctx, err := a.setInstanceIfNeeded(r.Context(), r) + if err == nil { + r = r.WithContext(ctx) next.ServeHTTP(w, r) return } - } - ctx, err := setInstance(r, a.verifier) - if err != nil { + origin := zitadel_http.DomainContext(r.Context()) logging.WithFields("origin", origin.Origin(), "externalDomain", a.externalDomain).WithError(err).Error("unable to set instance") + zErr := new(zerrors.ZitadelError) if errors.As(err, &zErr) { zErr.SetMessage(a.translator.LocalizeFromRequest(r, zErr.GetMessage(), nil)) http.Error(w, fmt.Sprintf("unable to set instance using origin %s (ExternalDomain is %s): %s", origin, a.externalDomain, zErr), http.StatusNotFound) return } + http.Error(w, fmt.Sprintf("unable to set instance using origin %s (ExternalDomain is %s)", origin, a.externalDomain), http.StatusNotFound) - return } - r = r.WithContext(ctx) - next.ServeHTTP(w, r) } -func setInstance(r *http.Request, verifier authz.InstanceVerifier) (_ context.Context, err error) { - ctx := r.Context() +func (a *instanceInterceptor) HandlerFuncWithError(next HandlerFuncWithError) HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := a.setInstanceIfNeeded(r.Context(), r) + if err != nil { + origin := zitadel_http.DomainContext(r.Context()) + logging.WithFields("origin", origin.Origin(), "externalDomain", a.externalDomain).WithError(err).Error("unable to set instance") + return err + } + + r = r.WithContext(ctx) + return next(w, r) + } +} + +func (a *instanceInterceptor) setInstanceIfNeeded(ctx context.Context, r *http.Request) (context.Context, error) { + for _, prefix := range a.ignoredPrefixes { + if strings.HasPrefix(r.URL.Path, prefix) { + return ctx, nil + } + } + + return setInstance(ctx, a.verifier) +} + +func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ context.Context, err error) { authCtx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/api/http/middleware/instance_interceptor_test.go b/internal/api/http/middleware/instance_interceptor_test.go index 51c0fb9a10..da831dff65 100644 --- a/internal/api/http/middleware/instance_interceptor_test.go +++ b/internal/api/http/middleware/instance_interceptor_test.go @@ -72,7 +72,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) { translator: newZitadelTranslator(), } next := &testHandler{} - got := a.HandlerFunc(next.ServeHTTP) + got := a.HandlerFunc(next) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) @@ -136,7 +136,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) { translator: newZitadelTranslator(), } next := &testHandler{} - got := a.HandlerFunc(next.ServeHTTP) + got := a.HandlerFunc(next) rr := httptest.NewRecorder() got.ServeHTTP(rr, tt.args.request) assert.Equal(t, tt.res.statusCode, rr.Code) @@ -145,9 +145,78 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) { } } +func Test_instanceInterceptor_HandlerFuncWithError(t *testing.T) { + type fields struct { + verifier authz.InstanceVerifier + } + type args struct { + request *http.Request + } + type res struct { + wantErr bool + context context.Context + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "setInstance error", + fields{ + verifier: &mockInstanceVerifier{}, + }, + args{ + request: httptest.NewRequest("", "/url", nil), + }, + res{ + wantErr: true, + context: nil, + }, + }, + { + "setInstance ok", + fields{ + verifier: &mockInstanceVerifier{instanceHost: "host"}, + }, + args{ + request: func() *http.Request { + r := httptest.NewRequest("", "/url", nil) + r = r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host"})) + return r + }(), + }, + res{ + context: authz.WithInstance(zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host"}), &mockInstance{}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &instanceInterceptor{ + verifier: tt.fields.verifier, + translator: newZitadelTranslator(), + } + var ctx context.Context + got := a.HandlerFuncWithError(func(w http.ResponseWriter, r *http.Request) error { + ctx = r.Context() + return nil + }) + rr := httptest.NewRecorder() + err := got(rr, tt.args.request) + if (err != nil) != tt.res.wantErr { + t.Errorf("got error %v, want %v", err, tt.res.wantErr) + } + + assert.Equal(t, tt.res.context, ctx) + }) + } +} + func Test_setInstance(t *testing.T) { type args struct { - r *http.Request + ctx context.Context verifier authz.InstanceVerifier } type res struct { @@ -162,10 +231,7 @@ func Test_setInstance(t *testing.T) { { "no domain context, not found error", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r - }(), + ctx: context.Background(), verifier: &mockInstanceVerifier{}, }, res{ @@ -176,10 +242,7 @@ func Test_setInstance(t *testing.T) { { "instanceHost found, ok", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"})) - }(), + ctx: zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "host", Protocol: "https"}), verifier: &mockInstanceVerifier{instanceHost: "host"}, }, res{ @@ -190,10 +253,7 @@ func Test_setInstance(t *testing.T) { { "instanceHost not found, error", args{ - r: func() *http.Request { - r := httptest.NewRequest("", "/url", nil) - return r.WithContext(zitadel_http.WithDomainContext(r.Context(), &zitadel_http.DomainCtx{InstanceHost: "fromorigin:9999", Protocol: "https"})) - }(), + ctx: zitadel_http.WithDomainContext(context.Background(), &zitadel_http.DomainCtx{InstanceHost: "fromorigin:9999", Protocol: "https"}), verifier: &mockInstanceVerifier{instanceHost: "unknowndomain"}, }, res{ @@ -204,7 +264,7 @@ func Test_setInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := setInstance(tt.args.r, tt.args.verifier) + got, err := setInstance(tt.args.ctx, tt.args.verifier) if (err != nil) != tt.res.err { t.Errorf("setInstance() error = %v, wantErr %v", err, tt.res.err) return diff --git a/internal/api/scim/authz.go b/internal/api/scim/authz.go new file mode 100644 index 0000000000..1ab174e7b3 --- /dev/null +++ b/internal/api/scim/authz.go @@ -0,0 +1,22 @@ +package scim + +import ( + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/domain" +) + +var AuthMapping = authz.MethodMapping{ + "POST:/scim/v2/" + http.OrgIdInPathVariable + "/Users": { + Permission: domain.PermissionUserWrite, + }, + "GET:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": { + Permission: domain.PermissionUserRead, + }, + "PUT:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": { + Permission: domain.PermissionUserWrite, + }, + "DELETE:/scim/v2/" + http.OrgIdInPathVariable + "/Users/{id}": { + Permission: domain.PermissionUserDelete, + }, +} diff --git a/internal/api/scim/config/config.go b/internal/api/scim/config/config.go new file mode 100644 index 0000000000..6199f0a2ea --- /dev/null +++ b/internal/api/scim/config/config.go @@ -0,0 +1,6 @@ +package config + +type Config struct { + EmailVerified bool + PhoneVerified bool +} diff --git a/internal/api/scim/integration_test/scim_test.go b/internal/api/scim/integration_test/scim_test.go new file mode 100644 index 0000000000..84c4d96bec --- /dev/null +++ b/internal/api/scim/integration_test/scim_test.go @@ -0,0 +1,29 @@ +//go:build integration + +package integration_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/zitadel/zitadel/internal/integration" +) + +var ( + Instance *integration.Instance + CTX context.Context +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + return m.Run() + }()) +} diff --git a/internal/api/scim/integration_test/testdata/users_create_test_full.json b/internal/api/scim/integration_test/testdata/users_create_test_full.json new file mode 100644 index 0000000000..7879ecf160 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_full.json @@ -0,0 +1,116 @@ +{ + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "externalId": "701984", + "userName": "bjensen@example.com", + "name": { + "formatted": "Ms. Barbara J Jensen, III", + "familyName": "Jensen", + "givenName": "Barbara", + "middleName": "Jane", + "honorificPrefix": "Ms.", + "honorificSuffix": "III" + }, + "displayName": "Babs Jensen", + "nickName": "Babs", + "profileUrl": "http://login.example.com/bjensen", + "emails": [ + { + "value": "bjensen@example.com", + "type": "work", + "primary": true + }, + { + "value": "babs@jensen.org", + "type": "home" + } + ], + "addresses": [ + { + "type": "work", + "streetAddress": "100 Universal City Plaza", + "locality": "Hollywood", + "region": "CA", + "postalCode": "91608", + "country": "USA", + "formatted": "100 Universal City Plaza\nHollywood, CA 91608 USA", + "primary": true + }, + { + "type": "home", + "streetAddress": "456 Hollywood Blvd", + "locality": "Hollywood", + "region": "CA", + "postalCode": "91608", + "country": "USA", + "formatted": "456 Hollywood Blvd\nHollywood, CA 91608 USA" + } + ], + "phoneNumbers": [ + { + "value": "555-555-5555", + "type": "work", + "primary": true + }, + { + "value": "555-555-4444", + "type": "mobile" + } + ], + "ims": [ + { + "value": "someaimhandle", + "type": "aim" + }, + { + "value": "twitterhandle", + "type": "X" + } + ], + "photos": [ + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/F", + "type": "photo" + }, + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/T", + "type": "thumbnail" + } + ], + "roles": [ + { + "value": "my-role-1", + "display": "Rolle 1", + "type": "main-role", + "primary": true + }, + { + "value": "my-role-2", + "display": "Rolle 2", + "type": "secondary-role", + "primary": false + } + ], + "entitlements": [ + { + "value": "my-entitlement-1", + "display": "Entitlement 1", + "type": "main-entitlement", + "primary": true + }, + { + "value": "my-entitlement-2", + "display": "Entitlement 2", + "type": "secondary-entitlement", + "primary": false + } + ], + "userType": "Employee", + "title": "Tour Guide", + "preferredLanguage": "en-US", + "locale": "en-US", + "timezone": "America/Los_Angeles", + "active":true, + "password": "Password1!" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json new file mode 100644 index 0000000000..eaadac8b90 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_locale.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "locale": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json new file mode 100644 index 0000000000..7a3d71cbed --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_password.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "password": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json new file mode 100644 index 0000000000..3bc8fee87b --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_profile_url.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "profileUrl": "ftp://login.example.com/bjensen" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json b/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json new file mode 100644 index 0000000000..d4ac9aa0a5 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_invalid_timezone.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "timezone": "fooBar" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_minimal.json b/internal/api/scim/integration_test/testdata/users_create_test_minimal.json new file mode 100644 index 0000000000..c51f416bc7 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_minimal.json @@ -0,0 +1,16 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_minimal_inactive.json b/internal/api/scim/integration_test/testdata/users_create_test_minimal_inactive.json new file mode 100644 index 0000000000..11650674a6 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_minimal_inactive.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ], + "active": false +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json new file mode 100644 index 0000000000..c68ebf98a0 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_email.json @@ -0,0 +1,10 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + } +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json new file mode 100644 index 0000000000..d1d3375f89 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_name.json @@ -0,0 +1,15 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1", + "name": { + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json b/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json new file mode 100644 index 0000000000..9446665226 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_create_test_missing_username.json @@ -0,0 +1,15 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_replace_test_full.json b/internal/api/scim/integration_test/testdata/users_replace_test_full.json new file mode 100644 index 0000000000..83ff72b697 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_replace_test_full.json @@ -0,0 +1,116 @@ +{ + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "externalId": "701984-updated", + "userName": "bjensen-replaced-full@example.com", + "name": { + "formatted": "Ms. Barbara J Jensen, III-updated", + "familyName": "Jensen-updated", + "givenName": "Barbara-updated", + "middleName": "Jane-updated", + "honorificPrefix": "Ms.-updated", + "honorificSuffix": "III" + }, + "displayName": "Babs Jensen-updated", + "nickName": "Babs-updated", + "profileUrl": "http://login.example.com/bjensen-updated", + "emails": [ + { + "value": "bjensen-replaced-full@example.com", + "type": "work-updated", + "primary": true + }, + { + "value": "babs-replaced-full@jensen.org", + "type": "home-updated" + } + ], + "addresses": [ + { + "type": "work-updated", + "streetAddress": "100 Universal City Plaza-updated", + "locality": "Hollywood-updated", + "region": "CA-updated", + "postalCode": "91608-updated", + "country": "USA-updated", + "formatted": "100 Universal City Plaza\nHollywood, CA 91608 USA-updated", + "primary": true + }, + { + "type": "home-updated", + "streetAddress": "456 Hollywood Blvd-updated", + "locality": "Hollywood-updated", + "region": "CA-updated", + "postalCode": "91608-updated", + "country": "USA-updated", + "formatted": "456 Hollywood Blvd\nHollywood, CA 91608 USA-updated" + } + ], + "phoneNumbers": [ + { + "value": "555-555-5555-updated", + "type": "work-updated", + "primary": true + }, + { + "value": "555-555-4444-updated", + "type": "mobile-updated" + } + ], + "ims": [ + { + "value": "someaimhandle-updated", + "type": "aim-updated" + }, + { + "value": "twitterhandle-updated", + "type": "X-updated" + } + ], + "photos": [ + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/F-updated", + "type": "photo-updated" + }, + { + "value": + "https://photos.example.com/profilephoto/72930000000Ccne/T-updated", + "type": "thumbnail-updated" + } + ], + "roles": [ + { + "value": "my-role-1-updated", + "display": "Rolle 1-updated", + "type": "main-role-updated", + "primary": true + }, + { + "value": "my-role-2-updated", + "display": "Rolle 2-updated", + "type": "secondary-role-updated", + "primary": false + } + ], + "entitlements": [ + { + "value": "my-entitlement-1-updated", + "display": "Entitlement 1-updated", + "type": "main-entitlement-updated", + "primary": true + }, + { + "value": "my-entitlement-2-updated", + "display": "Entitlement 2-updated", + "type": "secondary-entitlement-updated", + "primary": false + } + ], + "userType": "Employee-updated", + "title": "Tour Guide-updated", + "preferredLanguage": "en-CH", + "locale": "en-CH", + "timezone": "Europe/Zurich", + "active": false, + "password": "Password1!-updated" +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_replace_test_minimal.json b/internal/api/scim/integration_test/testdata/users_replace_test_minimal.json new file mode 100644 index 0000000000..f8756bf4a4 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_replace_test_minimal.json @@ -0,0 +1,16 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "userName": "acmeUser1-minimal-replaced", + "name": { + "familyName": "Ross-replaced", + "givenName": "Bethany-replaced" + }, + "emails": [ + { + "value": "user1-minimal-replaced@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/testdata/users_replace_test_minimal_with_external_id.json b/internal/api/scim/integration_test/testdata/users_replace_test_minimal_with_external_id.json new file mode 100644 index 0000000000..d02e605976 --- /dev/null +++ b/internal/api/scim/integration_test/testdata/users_replace_test_minimal_with_external_id.json @@ -0,0 +1,17 @@ +{ + "schemas": [ + "urn:ietf:params:scim:schemas:core:2.0:User" + ], + "externalID": "replaced-external-id", + "userName": "acmeUser1-replaced-with-external-id", + "name": { + "familyName": "Ross", + "givenName": "Bethany" + }, + "emails": [ + { + "value": "user1-minimal-replaced-with-external-id@example.com", + "primary": true + } + ] +} \ No newline at end of file diff --git a/internal/api/scim/integration_test/users_create_test.go b/internal/api/scim/integration_test/users_create_test.go new file mode 100644 index 0000000000..b9bc708d95 --- /dev/null +++ b/internal/api/scim/integration_test/users_create_test.go @@ -0,0 +1,405 @@ +//go:build integration + +package integration_test + +import ( + "context" + _ "embed" + "net/http" + "path" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + "google.golang.org/grpc/codes" + + "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/management" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + //go:embed testdata/users_create_test_minimal.json + minimalUserJson []byte + + //go:embed testdata/users_create_test_minimal_inactive.json + minimalInactiveUserJson []byte + + //go:embed testdata/users_create_test_full.json + fullUserJson []byte + + //go:embed testdata/users_create_test_missing_username.json + missingUserNameUserJson []byte + + //go:embed testdata/users_create_test_missing_name.json + missingNameUserJson []byte + + //go:embed testdata/users_create_test_missing_email.json + missingEmailUserJson []byte + + //go:embed testdata/users_create_test_invalid_password.json + invalidPasswordUserJson []byte + + //go:embed testdata/users_create_test_invalid_profile_url.json + invalidProfileUrlUserJson []byte + + //go:embed testdata/users_create_test_invalid_locale.json + invalidLocaleUserJson []byte + + //go:embed testdata/users_create_test_invalid_timezone.json + invalidTimeZoneUserJson []byte +) + +func TestCreateUser(t *testing.T) { + tests := []struct { + name string + body []byte + ctx context.Context + want *resources.ScimUser + wantErr bool + scimErrorType string + errorStatus int + zitadelErrID string + }{ + { + name: "minimal user", + body: minimalUserJson, + want: &resources.ScimUser{ + UserName: "acmeUser1", + Name: &resources.ScimUserName{ + FamilyName: "Ross", + GivenName: "Bethany", + }, + Emails: []*resources.ScimEmail{ + { + Value: "user1@example.com", + Primary: true, + }, + }, + }, + }, + { + name: "minimal inactive user", + body: minimalInactiveUserJson, + want: &resources.ScimUser{ + Active: gu.Ptr(false), + }, + }, + { + name: "full user", + body: fullUserJson, + want: &resources.ScimUser{ + ExternalID: "701984", + UserName: "bjensen@example.com", + Name: &resources.ScimUserName{ + Formatted: "Babs Jensen", // DisplayName takes precedence in Zitadel + FamilyName: "Jensen", + GivenName: "Barbara", + MiddleName: "Jane", + HonorificPrefix: "Ms.", + HonorificSuffix: "III", + }, + DisplayName: "Babs Jensen", + NickName: "Babs", + ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")), + Emails: []*resources.ScimEmail{ + { + Value: "bjensen@example.com", + Primary: true, + }, + }, + Addresses: []*resources.ScimAddress{ + { + Type: "work", + StreetAddress: "100 Universal City Plaza", + Locality: "Hollywood", + Region: "CA", + PostalCode: "91608", + Country: "USA", + Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA", + Primary: true, + }, + { + Type: "home", + StreetAddress: "456 Hollywood Blvd", + Locality: "Hollywood", + Region: "CA", + PostalCode: "91608", + Country: "USA", + Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA", + }, + }, + PhoneNumbers: []*resources.ScimPhoneNumber{ + { + Value: "+415555555555", + Primary: true, + }, + }, + Ims: []*resources.ScimIms{ + { + Value: "someaimhandle", + Type: "aim", + }, + { + Value: "twitterhandle", + Type: "X", + }, + }, + Photos: []*resources.ScimPhoto{ + { + Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")), + Type: "photo", + }, + }, + Roles: []*resources.ScimRole{ + { + Value: "my-role-1", + Display: "Rolle 1", + Type: "main-role", + Primary: true, + }, + { + Value: "my-role-2", + Display: "Rolle 2", + Type: "secondary-role", + Primary: false, + }, + }, + Entitlements: []*resources.ScimEntitlement{ + { + Value: "my-entitlement-1", + Display: "Entitlement 1", + Type: "main-entitlement", + Primary: true, + }, + { + Value: "my-entitlement-2", + Display: "Entitlement 2", + Type: "secondary-entitlement", + Primary: false, + }, + }, + Title: "Tour Guide", + PreferredLanguage: language.MustParse("en-US"), + Locale: "en-US", + Timezone: "America/Los_Angeles", + Active: gu.Ptr(true), + }, + }, + { + name: "missing userName", + wantErr: true, + scimErrorType: "invalidValue", + body: missingUserNameUserJson, + }, + { + // this is an expected schema violation + name: "missing name", + wantErr: true, + scimErrorType: "invalidValue", + body: missingNameUserJson, + }, + { + name: "missing email", + wantErr: true, + scimErrorType: "invalidValue", + body: missingEmailUserJson, + }, + { + name: "password complexity violation", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidPasswordUserJson, + }, + { + name: "invalid profile url", + wantErr: true, + scimErrorType: "invalidValue", + zitadelErrID: "SCIM-htturl1", + body: invalidProfileUrlUserJson, + }, + { + name: "invalid time zone", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidTimeZoneUserJson, + }, + { + name: "invalid locale", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidLocaleUserJson, + }, + { + name: "not authenticated", + body: minimalUserJson, + ctx: context.Background(), + wantErr: true, + errorStatus: http.StatusUnauthorized, + }, + { + name: "no permissions", + body: minimalUserJson, + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + wantErr: true, + errorStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + createdUser, err := Instance.Client.SCIM.Users.Create(ctx, Instance.DefaultOrg.Id, tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("CreateUser() error = %v, wantErr %v", err, tt.wantErr) + } + + if err != nil { + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + scimErr := scim.RequireScimError(t, statusCode, err) + assert.Equal(t, tt.scimErrorType, scimErr.Error.ScimType) + if tt.zitadelErrID != "" { + assert.Equal(t, tt.zitadelErrID, scimErr.Error.ZitadelDetail.ID) + } + + return + } + + assert.NotEmpty(t, createdUser.ID) + defer func() { + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + assert.NoError(t, err) + }() + + assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, createdUser.Resource.Schemas) + assert.Equal(t, schemas.ScimResourceTypeSingular("User"), createdUser.Resource.Meta.ResourceType) + assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), createdUser.Resource.Meta.Location) + assert.Nil(t, createdUser.Password) + + if tt.want != nil { + if !integration.PartiallyDeepEqual(tt.want, createdUser) { + t.Errorf("CreateUser() got = %v, want %v", createdUser, tt.want) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + // ensure the user is really stored and not just returned to the caller + fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, createdUser.ID) + require.NoError(ttt, err) + if !integration.PartiallyDeepEqual(tt.want, fetchedUser) { + ttt.Errorf("GetUser() got = %v, want %v", fetchedUser, tt.want) + } + }, retryDuration, tick) + } + }) + } +} + +func TestCreateUser_duplicate(t *testing.T) { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson) + require.NoError(t, err) + + _, err = Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, minimalUserJson) + scimErr := scim.RequireScimError(t, http.StatusConflict, err) + assert.Equal(t, "User already exists", scimErr.Error.Detail) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) +} + +func TestCreateUser_metadata(t *testing.T) { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + defer func() { + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) + }() + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + md, err := Instance.Client.Mgmt.ListUserMetadata(CTX, &management.ListUserMetadataRequest{ + Id: createdUser.ID, + }) + require.NoError(tt, err) + + mdMap := make(map[string]string) + for i := range md.Result { + mdMap[md.Result[i].Key] = string(md.Result[i].Value) + } + + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificPrefix", "Ms.") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:timezone", "America/Los_Angeles") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:photos", `[{"value":"https://photos.example.com/profilephoto/72930000000Ccne/F","type":"photo"},{"value":"https://photos.example.com/profilephoto/72930000000Ccne/T","type":"thumbnail"}]`) + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:addresses", `[{"type":"work","streetAddress":"100 Universal City Plaza","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"100 Universal City Plaza\nHollywood, CA 91608 USA","primary":true},{"type":"home","streetAddress":"456 Hollywood Blvd","locality":"Hollywood","region":"CA","postalCode":"91608","country":"USA","formatted":"456 Hollywood Blvd\nHollywood, CA 91608 USA"}]`) + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:entitlements", `[{"value":"my-entitlement-1","display":"Entitlement 1","type":"main-entitlement","primary":true},{"value":"my-entitlement-2","display":"Entitlement 2","type":"secondary-entitlement"}]`) + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.middleName", "Jane") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:name.honorificSuffix", "III") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:profileURL", "http://login.example.com/bjensen") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:title", "Tour Guide") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:locale", "en-US") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:ims", `[{"value":"someaimhandle","type":"aim"},{"value":"twitterhandle","type":"X"}]`) + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:roles", `[{"value":"my-role-1","display":"Rolle 1","type":"main-role","primary":true},{"value":"my-role-2","display":"Rolle 2","type":"secondary-role"}]`) + }, retryDuration, tick) +} + +func TestCreateUser_scopedExternalID(t *testing.T) { + _, err := Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + Value: []byte("fooBar"), + }) + require.NoError(t, err) + + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + defer func() { + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) + + _, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + }) + require.NoError(t, err) + }() + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + // unscoped externalID should not exist + _, err = Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{ + Id: createdUser.ID, + Key: "urn:zitadel:scim:externalId", + }) + integration.AssertGrpcStatus(tt, codes.NotFound, err) + + // scoped externalID should exist + md, err := Instance.Client.Mgmt.GetUserMetadata(CTX, &management.GetUserMetadataRequest{ + Id: createdUser.ID, + Key: "urn:zitadel:scim:fooBar:externalId", + }) + require.NoError(tt, err) + assert.Equal(tt, "701984", string(md.Metadata.Value)) + }, retryDuration, tick) +} + +func TestCreateUser_anotherOrg(t *testing.T) { + org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email()) + _, err := Instance.Client.SCIM.Users.Create(CTX, org.OrganizationId, fullUserJson) + scim.RequireScimError(t, http.StatusNotFound, err) +} diff --git a/internal/api/scim/integration_test/users_delete_test.go b/internal/api/scim/integration_test/users_delete_test.go new file mode 100644 index 0000000000..bfdd0eae88 --- /dev/null +++ b/internal/api/scim/integration_test/users_delete_test.go @@ -0,0 +1,90 @@ +//go:build integration + +package integration_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +func TestDeleteUser_errors(t *testing.T) { + tests := []struct { + name string + ctx context.Context + errorStatus int + }{ + { + name: "not authenticated", + ctx: context.Background(), + errorStatus: http.StatusUnauthorized, + }, + { + name: "no permissions", + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + errorStatus: http.StatusNotFound, + }, + { + name: "unknown user id", + errorStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + err := Instance.Client.SCIM.Users.Delete(ctx, Instance.DefaultOrg.Id, "1") + + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + + scim.RequireScimError(t, statusCode, err) + }) + } +} + +func TestDeleteUser_ensureReallyDeleted(t *testing.T) { + // create user and dependencies + createUserResp := Instance.CreateHumanUser(CTX) + proj, err := Instance.CreateProject(CTX) + require.NoError(t, err) + + Instance.CreateProjectUserGrant(t, CTX, proj.Id, createUserResp.UserId) + + // delete user via scim + err = Instance.Client.SCIM.Users.Delete(CTX, Instance.DefaultOrg.Id, createUserResp.UserId) + assert.NoError(t, err) + + // ensure it is really deleted => try to delete again => should 404 + err = Instance.Client.SCIM.Users.Delete(CTX, Instance.DefaultOrg.Id, createUserResp.UserId) + scim.RequireScimError(t, http.StatusNotFound, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + // try to get user via api => should 404 + _, err = Instance.Client.UserV2.GetUserByID(CTX, &user.GetUserByIDRequest{UserId: createUserResp.UserId}) + integration.AssertGrpcStatus(tt, codes.NotFound, err) + }, retryDuration, tick) +} + +func TestDeleteUser_anotherOrg(t *testing.T) { + createUserResp := Instance.CreateHumanUser(CTX) + org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email()) + err := Instance.Client.SCIM.Users.Delete(CTX, org.OrganizationId, createUserResp.UserId) + scim.RequireScimError(t, http.StatusNotFound, err) +} diff --git a/internal/api/scim/integration_test/users_get_test.go b/internal/api/scim/integration_test/users_get_test.go new file mode 100644 index 0000000000..a8055db600 --- /dev/null +++ b/internal/api/scim/integration_test/users_get_test.go @@ -0,0 +1,276 @@ +//go:build integration + +package integration_test + +import ( + "context" + "net/http" + "path" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/management" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +func TestGetUser(t *testing.T) { + tests := []struct { + name string + buildUserID func() string + cleanup func(userID string) + ctx context.Context + want *resources.ScimUser + wantErr bool + errorStatus int + }{ + { + name: "not authenticated", + ctx: context.Background(), + errorStatus: http.StatusUnauthorized, + wantErr: true, + }, + { + name: "no permissions", + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + errorStatus: http.StatusNotFound, + wantErr: true, + }, + { + name: "unknown user id", + buildUserID: func() string { + return "unknown" + }, + errorStatus: http.StatusNotFound, + wantErr: true, + }, + { + name: "created via grpc", + want: &resources.ScimUser{ + Name: &resources.ScimUserName{ + FamilyName: "Mouse", + GivenName: "Mickey", + }, + PreferredLanguage: language.MustParse("nl"), + PhoneNumbers: []*resources.ScimPhoneNumber{ + { + Value: "+41791234567", + Primary: true, + }, + }, + }, + }, + { + name: "created via scim", + buildUserID: func() string { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + return createdUser.ID + }, + cleanup: func(userID string) { + _, err := Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: userID}) + require.NoError(t, err) + }, + want: &resources.ScimUser{ + ExternalID: "701984", + UserName: "bjensen@example.com", + Name: &resources.ScimUserName{ + Formatted: "Babs Jensen", // DisplayName takes precedence + FamilyName: "Jensen", + GivenName: "Barbara", + MiddleName: "Jane", + HonorificPrefix: "Ms.", + HonorificSuffix: "III", + }, + DisplayName: "Babs Jensen", + NickName: "Babs", + ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen")), + Title: "Tour Guide", + PreferredLanguage: language.Make("en-US"), + Locale: "en-US", + Timezone: "America/Los_Angeles", + Active: gu.Ptr(true), + Emails: []*resources.ScimEmail{ + { + Value: "bjensen@example.com", + Primary: true, + }, + }, + PhoneNumbers: []*resources.ScimPhoneNumber{ + { + Value: "+415555555555", + Primary: true, + }, + }, + Ims: []*resources.ScimIms{ + { + Value: "someaimhandle", + Type: "aim", + }, + { + Value: "twitterhandle", + Type: "X", + }, + }, + Addresses: []*resources.ScimAddress{ + { + Type: "work", + StreetAddress: "100 Universal City Plaza", + Locality: "Hollywood", + Region: "CA", + PostalCode: "91608", + Country: "USA", + Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA", + Primary: true, + }, + { + Type: "home", + StreetAddress: "456 Hollywood Blvd", + Locality: "Hollywood", + Region: "CA", + PostalCode: "91608", + Country: "USA", + Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA", + }, + }, + Photos: []*resources.ScimPhoto{ + { + Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F")), + Type: "photo", + }, + { + Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T")), + Type: "thumbnail", + }, + }, + Roles: []*resources.ScimRole{ + { + Value: "my-role-1", + Display: "Rolle 1", + Type: "main-role", + Primary: true, + }, + { + Value: "my-role-2", + Display: "Rolle 2", + Type: "secondary-role", + Primary: false, + }, + }, + Entitlements: []*resources.ScimEntitlement{ + { + Value: "my-entitlement-1", + Display: "Entitlement 1", + Type: "main-entitlement", + Primary: true, + }, + { + Value: "my-entitlement-2", + Display: "Entitlement 2", + Type: "secondary-entitlement", + Primary: false, + }, + }, + }, + }, + { + name: "scoped externalID", + buildUserID: func() string { + // create user without provisioning domain + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + // set provisioning domain of service user + _, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + Value: []byte("fooBar"), + }) + require.NoError(t, err) + + // set externalID for provisioning domain + _, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{ + Id: createdUser.ID, + Key: "urn:zitadel:scim:fooBar:externalId", + Value: []byte("100-scopedExternalId"), + }) + require.NoError(t, err) + return createdUser.ID + }, + cleanup: func(userID string) { + _, err := Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: userID}) + require.NoError(t, err) + + _, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + }) + require.NoError(t, err) + }, + want: &resources.ScimUser{ + ExternalID: "100-scopedExternalId", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + var userID string + if tt.buildUserID != nil { + userID = tt.buildUserID() + } else { + createUserResp := Instance.CreateHumanUser(CTX) + userID = createUserResp.UserId + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + var fetchedUser *resources.ScimUser + var err error + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + fetchedUser, err = Instance.Client.SCIM.Users.Get(ctx, Instance.DefaultOrg.Id, userID) + if tt.wantErr { + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + + scim.RequireScimError(ttt, statusCode, err) + return + } + + assert.Equal(ttt, userID, fetchedUser.ID) + assert.EqualValues(ttt, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, fetchedUser.Schemas) + assert.Equal(ttt, schemas.ScimResourceTypeSingular("User"), fetchedUser.Resource.Meta.ResourceType) + assert.Equal(ttt, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", fetchedUser.ID), fetchedUser.Resource.Meta.Location) + assert.Nil(ttt, fetchedUser.Password) + if !integration.PartiallyDeepEqual(tt.want, fetchedUser) { + ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want) + } + }, retryDuration, tick) + + if tt.cleanup != nil { + tt.cleanup(fetchedUser.ID) + } + }) + } +} + +func TestGetUser_anotherOrg(t *testing.T) { + createUserResp := Instance.CreateHumanUser(CTX) + org := Instance.CreateOrganization(Instance.WithAuthorization(CTX, integration.UserTypeIAMOwner), gofakeit.Name(), gofakeit.Email()) + _, err := Instance.Client.SCIM.Users.Get(CTX, org.OrganizationId, createUserResp.UserId) + scim.RequireScimError(t, http.StatusNotFound, err) +} diff --git a/internal/api/scim/integration_test/users_replace_test.go b/internal/api/scim/integration_test/users_replace_test.go new file mode 100644 index 0000000000..b43dd3acf0 --- /dev/null +++ b/internal/api/scim/integration_test/users_replace_test.go @@ -0,0 +1,331 @@ +//go:build integration + +package integration_test + +import ( + "context" + _ "embed" + "net/http" + "path" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/internal/integration/scim" + "github.com/zitadel/zitadel/pkg/grpc/management" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +var ( + //go:embed testdata/users_replace_test_minimal_with_external_id.json + minimalUserWithExternalIDJson []byte + + //go:embed testdata/users_replace_test_minimal.json + minimalUserReplaceJson []byte + + //go:embed testdata/users_replace_test_full.json + fullUserReplaceJson []byte +) + +func TestReplaceUser(t *testing.T) { + tests := []struct { + name string + body []byte + ctx context.Context + want *resources.ScimUser + wantErr bool + scimErrorType string + errorStatus int + zitadelErrID string + }{ + { + name: "minimal user", + body: minimalUserReplaceJson, + want: &resources.ScimUser{ + UserName: "acmeUser1-minimal-replaced", + Name: &resources.ScimUserName{ + FamilyName: "Ross-replaced", + GivenName: "Bethany-replaced", + }, + Emails: []*resources.ScimEmail{ + { + Value: "user1-minimal-replaced@example.com", + Primary: true, + }, + }, + }, + }, + { + name: "full user", + body: fullUserReplaceJson, + want: &resources.ScimUser{ + ExternalID: "701984-updated", + UserName: "bjensen-replaced-full@example.com", + Name: &resources.ScimUserName{ + Formatted: "Babs Jensen-updated", // display name takes precedence + FamilyName: "Jensen-updated", + GivenName: "Barbara-updated", + MiddleName: "Jane-updated", + HonorificPrefix: "Ms.-updated", + HonorificSuffix: "III", + }, + DisplayName: "Babs Jensen-updated", + NickName: "Babs-updated", + ProfileUrl: integration.Must(schemas.ParseHTTPURL("http://login.example.com/bjensen-updated")), + Emails: []*resources.ScimEmail{ + { + Value: "bjensen-replaced-full@example.com", + Primary: true, + }, + }, + Addresses: []*resources.ScimAddress{ + { + Type: "work-updated", + StreetAddress: "100 Universal City Plaza-updated", + Locality: "Hollywood-updated", + Region: "CA-updated", + PostalCode: "91608-updated", + Country: "USA-updated", + Formatted: "100 Universal City Plaza\nHollywood, CA 91608 USA-updated", + Primary: true, + }, + { + Type: "home-updated", + StreetAddress: "456 Hollywood Blvd-updated", + Locality: "Hollywood-updated", + Region: "CA-updated", + PostalCode: "91608-updated", + Country: "USA-updated", + Formatted: "456 Hollywood Blvd\nHollywood, CA 91608 USA-updated", + }, + }, + PhoneNumbers: []*resources.ScimPhoneNumber{ + { + Value: "+4155555555558732833", + Primary: true, + }, + }, + Ims: []*resources.ScimIms{ + { + Value: "someaimhandle-updated", + Type: "aim-updated", + }, + { + Value: "twitterhandle-updated", + Type: "X-updated", + }, + }, + Photos: []*resources.ScimPhoto{ + { + Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/F-updated")), + Type: "photo-updated", + }, + { + Value: *integration.Must(schemas.ParseHTTPURL("https://photos.example.com/profilephoto/72930000000Ccne/T-updated")), + Type: "thumbnail-updated", + }, + }, + Roles: []*resources.ScimRole{ + { + Value: "my-role-1-updated", + Display: "Rolle 1-updated", + Type: "main-role-updated", + Primary: true, + }, + { + Value: "my-role-2-updated", + Display: "Rolle 2-updated", + Type: "secondary-role-updated", + Primary: false, + }, + }, + Entitlements: []*resources.ScimEntitlement{ + { + Value: "my-entitlement-1-updated", + Display: "Entitlement 1-updated", + Type: "main-entitlement-updated", + Primary: true, + }, + { + Value: "my-entitlement-2-updated", + Display: "Entitlement 2-updated", + Type: "secondary-entitlement-updated", + Primary: false, + }, + }, + Title: "Tour Guide-updated", + PreferredLanguage: language.MustParse("en-CH"), + Locale: "en-CH", + Timezone: "Europe/Zurich", + Active: gu.Ptr(false), + }, + }, + { + name: "password complexity violation", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidPasswordUserJson, + }, + { + name: "invalid profile url", + wantErr: true, + scimErrorType: "invalidValue", + zitadelErrID: "SCIM-htturl1", + body: invalidProfileUrlUserJson, + }, + { + name: "invalid time zone", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidTimeZoneUserJson, + }, + { + name: "invalid locale", + wantErr: true, + scimErrorType: "invalidValue", + body: invalidLocaleUserJson, + }, + { + name: "not authenticated", + body: minimalUserJson, + ctx: context.Background(), + wantErr: true, + errorStatus: http.StatusUnauthorized, + }, + { + name: "no permissions", + body: minimalUserJson, + ctx: Instance.WithAuthorization(CTX, integration.UserTypeNoPermission), + wantErr: true, + errorStatus: http.StatusNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + defer func() { + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + assert.NoError(t, err) + }() + + ctx := tt.ctx + if ctx == nil { + ctx = CTX + } + + replacedUser, err := Instance.Client.SCIM.Users.Replace(ctx, Instance.DefaultOrg.Id, createdUser.ID, tt.body) + if (err != nil) != tt.wantErr { + t.Errorf("ReplaceUser() error = %v, wantErr %v", err, tt.wantErr) + } + + if err != nil { + statusCode := tt.errorStatus + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + scimErr := scim.RequireScimError(t, statusCode, err) + assert.Equal(t, tt.scimErrorType, scimErr.Error.ScimType) + if tt.zitadelErrID != "" { + assert.Equal(t, tt.zitadelErrID, scimErr.Error.ZitadelDetail.ID) + } + + return + } + + assert.NotEmpty(t, replacedUser.ID) + assert.EqualValues(t, []schemas.ScimSchemaType{"urn:ietf:params:scim:schemas:core:2.0:User"}, replacedUser.Resource.Schemas) + assert.Equal(t, schemas.ScimResourceTypeSingular("User"), replacedUser.Resource.Meta.ResourceType) + assert.Equal(t, "http://"+Instance.Host()+path.Join(schemas.HandlerPrefix, Instance.DefaultOrg.Id, "Users", createdUser.ID), replacedUser.Resource.Meta.Location) + assert.Nil(t, createdUser.Password) + + if !integration.PartiallyDeepEqual(tt.want, replacedUser) { + t.Errorf("ReplaceUser() got = %#v, want %#v", replacedUser, tt.want) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + // ensure the user is really stored and not just returned to the caller + fetchedUser, err := Instance.Client.SCIM.Users.Get(CTX, Instance.DefaultOrg.Id, replacedUser.ID) + require.NoError(ttt, err) + if !integration.PartiallyDeepEqual(tt.want, fetchedUser) { + ttt.Errorf("GetUser() got = %#v, want %#v", fetchedUser, tt.want) + } + }, retryDuration, tick) + }) + } + +} + +func TestReplaceUser_removeOldMetadata(t *testing.T) { + // ensure old metadata is removed correctly + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + _, err = Instance.Client.SCIM.Users.Replace(CTX, Instance.DefaultOrg.Id, createdUser.ID, minimalUserJson) + require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + md, err := Instance.Client.Mgmt.ListUserMetadata(CTX, &management.ListUserMetadataRequest{ + Id: createdUser.ID, + }) + require.NoError(tt, err) + require.Equal(tt, 0, len(md.Result)) + }, retryDuration, tick) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) +} + +func TestReplaceUser_scopedExternalID(t *testing.T) { + // create user without provisioning domain set + createdUser, err := Instance.Client.SCIM.Users.Create(CTX, Instance.DefaultOrg.Id, fullUserJson) + require.NoError(t, err) + + // set provisioning domain of service user + _, err = Instance.Client.Mgmt.SetUserMetadata(CTX, &management.SetUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + Value: []byte("fooBazz"), + }) + require.NoError(t, err) + + // replace the user with provisioning domain set + _, err = Instance.Client.SCIM.Users.Replace(CTX, Instance.DefaultOrg.Id, createdUser.ID, minimalUserWithExternalIDJson) + require.NoError(t, err) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + require.EventuallyWithT(t, func(tt *assert.CollectT) { + md, err := Instance.Client.Mgmt.ListUserMetadata(CTX, &management.ListUserMetadataRequest{ + Id: createdUser.ID, + }) + require.NoError(tt, err) + + mdMap := make(map[string]string) + for i := range md.Result { + mdMap[md.Result[i].Key] = string(md.Result[i].Value) + } + + // both external IDs should be present on the user + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:externalId", "701984") + integration.AssertMapContains(tt, mdMap, "urn:zitadel:scim:fooBazz:externalId", "replaced-external-id") + }, retryDuration, tick) + + _, err = Instance.Client.UserV2.DeleteUser(CTX, &user.DeleteUserRequest{UserId: createdUser.ID}) + require.NoError(t, err) + + _, err = Instance.Client.Mgmt.RemoveUserMetadata(CTX, &management.RemoveUserMetadataRequest{ + Id: Instance.Users.Get(integration.UserTypeOrgOwner).ID, + Key: "urn:zitadel:scim:provisioning_domain", + }) + require.NoError(t, err) +} diff --git a/internal/api/scim/metadata/context.go b/internal/api/scim/metadata/context.go new file mode 100644 index 0000000000..5be54d7123 --- /dev/null +++ b/internal/api/scim/metadata/context.go @@ -0,0 +1,23 @@ +package metadata + +import ( + "context" +) + +type provisioningDomainKeyType struct{} + +var provisioningDomainKey provisioningDomainKeyType + +type ScimContextData struct { + ProvisioningDomain string + ExternalIDScopedMetadataKey ScopedKey +} + +func SetScimContextData(ctx context.Context, data ScimContextData) context.Context { + return context.WithValue(ctx, provisioningDomainKey, data) +} + +func GetScimContextData(ctx context.Context) ScimContextData { + data, _ := ctx.Value(provisioningDomainKey).(ScimContextData) + return data +} diff --git a/internal/api/scim/metadata/metadata.go b/internal/api/scim/metadata/metadata.go new file mode 100644 index 0000000000..626d938234 --- /dev/null +++ b/internal/api/scim/metadata/metadata.go @@ -0,0 +1,60 @@ +package metadata + +import ( + "context" + "strings" +) + +type Key string +type ScopedKey string + +const ( + externalIdProvisioningDomainPlaceholder = "{provisioningDomain}" + + KeyPrefix = "urn:zitadel:scim:" + KeyProvisioningDomain Key = KeyPrefix + "provisioning_domain" + + KeyExternalId Key = KeyPrefix + "externalId" + keyScopedExternalIdTemplate = KeyPrefix + externalIdProvisioningDomainPlaceholder + ":externalId" + KeyMiddleName Key = KeyPrefix + "name.middleName" + KeyHonorificPrefix Key = KeyPrefix + "name.honorificPrefix" + KeyHonorificSuffix Key = KeyPrefix + "name.honorificSuffix" + KeyProfileUrl Key = KeyPrefix + "profileURL" + KeyTitle Key = KeyPrefix + "title" + KeyLocale Key = KeyPrefix + "locale" + KeyTimezone Key = KeyPrefix + "timezone" + KeyIms Key = KeyPrefix + "ims" + KeyPhotos Key = KeyPrefix + "photos" + KeyAddresses Key = KeyPrefix + "addresses" + KeyEntitlements Key = KeyPrefix + "entitlements" + KeyRoles Key = KeyPrefix + "roles" +) + +var ScimUserRelevantMetadataKeys = []Key{ + KeyExternalId, + KeyMiddleName, + KeyHonorificPrefix, + KeyHonorificSuffix, + KeyProfileUrl, + KeyTitle, + KeyLocale, + KeyTimezone, + KeyIms, + KeyPhotos, + KeyAddresses, + KeyEntitlements, + KeyRoles, +} + +func ScopeExternalIdKey(provisioningDomain string) ScopedKey { + return ScopedKey(strings.Replace(keyScopedExternalIdTemplate, externalIdProvisioningDomainPlaceholder, provisioningDomain, 1)) +} + +func ScopeKey(ctx context.Context, key Key) ScopedKey { + // only the externalID is scoped + if key == KeyExternalId { + return GetScimContextData(ctx).ExternalIDScopedMetadataKey + } + + return ScopedKey(key) +} diff --git a/internal/api/scim/middleware/content_type_middleware.go b/internal/api/scim/middleware/content_type_middleware.go new file mode 100644 index 0000000000..9b456bb141 --- /dev/null +++ b/internal/api/scim/middleware/content_type_middleware.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "mime" + "net/http" + "strings" + + "github.com/zitadel/logging" + + zhttp "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + ContentTypeScim = "application/scim+json" + ContentTypeJson = "application/json" +) + +func ContentTypeMiddleware(next middleware.HandlerFuncWithError) middleware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set(zhttp.ContentType, ContentTypeScim) + + if !validateContentType(r.Header.Get(zhttp.ContentType)) { + return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x4", "Invalid content type header") + } + + if !validateContentType(r.Header.Get(zhttp.Accept)) { + return zerrors.ThrowInvalidArgumentf(nil, "SMCM-12x5", "Invalid accept header") + } + + return next(w, r) + } +} + +func validateContentType(contentType string) bool { + if contentType == "" { + return true + } + + mediaType, params, err := mime.ParseMediaType(contentType) + if err != nil { + logging.OnError(err).Warn("failed to parse content type header") + return false + } + + if mediaType != "" && !strings.EqualFold(mediaType, ContentTypeJson) && !strings.EqualFold(mediaType, ContentTypeScim) { + return false + } + + charset, ok := params["charset"] + return !ok || strings.EqualFold(charset, "utf-8") +} diff --git a/internal/api/scim/middleware/content_type_middleware_test.go b/internal/api/scim/middleware/content_type_middleware_test.go new file mode 100644 index 0000000000..918d4618ae --- /dev/null +++ b/internal/api/scim/middleware/content_type_middleware_test.go @@ -0,0 +1,107 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + zhttp "github.com/zitadel/zitadel/internal/api/http" +) + +func TestContentTypeMiddleware(t *testing.T) { + tests := []struct { + name string + contentTypeHeader string + acceptHeader string + wantErr bool + }{ + { + name: "valid", + contentTypeHeader: "application/scim+json", + acceptHeader: "application/scim+json", + wantErr: false, + }, + { + name: "invalid content type", + contentTypeHeader: "application/octet-stream", + acceptHeader: "application/json", + wantErr: true, + }, + { + name: "invalid accept", + contentTypeHeader: "application/json", + acceptHeader: "application/octet-stream", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + if tt.acceptHeader != "" { + req.Header.Set(zhttp.Accept, tt.acceptHeader) + } + + if tt.contentTypeHeader != "" { + req.Header.Set(zhttp.ContentType, tt.contentTypeHeader) + } + + err := ContentTypeMiddleware(func(w http.ResponseWriter, r *http.Request) error { + return nil + })(httptest.NewRecorder(), req) + if (err != nil) != tt.wantErr { + t.Errorf("ContentTypeMiddleware() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validateContentType(t *testing.T) { + tests := []struct { + name string + contentType string + want bool + }{ + { + name: "empty", + contentType: "", + want: true, + }, + { + name: "json", + contentType: "application/json", + want: true, + }, + { + name: "scim", + contentType: "application/scim+json", + want: true, + }, + { + name: "json utf-8", + contentType: "application/json; charset=utf-8", + want: true, + }, + { + name: "scim utf-8", + contentType: "application/scim+json; charset=utf-8", + want: true, + }, + { + name: "unknown content type", + contentType: "application/octet-stream", + want: false, + }, + { + name: "unknown charset", + contentType: "application/scim+json; charset=utf-16", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := validateContentType(tt.contentType); got != tt.want { + t.Errorf("validateContentType() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/api/scim/middleware/scim_context_middleware.go b/internal/api/scim/middleware/scim_context_middleware.go new file mode 100644 index 0000000000..c52f6f13f6 --- /dev/null +++ b/internal/api/scim/middleware/scim_context_middleware.go @@ -0,0 +1,54 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/zitadel/zitadel/internal/api/authz" + zhttp "github.com/zitadel/zitadel/internal/api/http/middleware" + smetadata "github.com/zitadel/zitadel/internal/api/scim/metadata" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func ScimContextMiddleware(q *query.Queries) func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError { + return func(next zhttp.HandlerFuncWithError) zhttp.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + ctx, err := initScimContext(r.Context(), q) + if err != nil { + return err + } + + return next(w, r.WithContext(ctx)) + } + } +} + +func initScimContext(ctx context.Context, q *query.Queries) (context.Context, error) { + data := smetadata.ScimContextData{ + ProvisioningDomain: "", + ExternalIDScopedMetadataKey: smetadata.ScopedKey(smetadata.KeyExternalId), + } + + ctx = smetadata.SetScimContextData(ctx, data) + + userID := authz.GetCtxData(ctx).UserID + metadata, err := q.GetUserMetadataByKey(ctx, false, userID, string(smetadata.KeyProvisioningDomain), false) + if err != nil { + if zerrors.IsNotFound(err) { + return ctx, nil + } + + return ctx, err + } + + if metadata == nil { + return ctx, nil + } + + data.ProvisioningDomain = string(metadata.Value) + if data.ProvisioningDomain != "" { + data.ExternalIDScopedMetadataKey = smetadata.ScopeExternalIdKey(data.ProvisioningDomain) + } + return smetadata.SetScimContextData(ctx, data), nil +} diff --git a/internal/api/scim/resources/resource_handler.go b/internal/api/scim/resources/resource_handler.go new file mode 100644 index 0000000000..4e1d9c1d4a --- /dev/null +++ b/internal/api/scim/resources/resource_handler.go @@ -0,0 +1,64 @@ +package resources + +import ( + "context" + "path" + "strconv" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/domain" +) + +type ResourceHandler[T ResourceHolder] interface { + ResourceNameSingular() schemas.ScimResourceTypeSingular + ResourceNamePlural() schemas.ScimResourceTypePlural + SchemaType() schemas.ScimSchemaType + NewResource() T + + Create(ctx context.Context, resource T) (T, error) + Replace(ctx context.Context, id string, resource T) (T, error) + Delete(ctx context.Context, id string) error + Get(ctx context.Context, id string) (T, error) +} + +type Resource struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + Meta *ResourceMeta `json:"meta"` +} + +type ResourceMeta struct { + ResourceType schemas.ScimResourceTypeSingular `json:"resourceType"` + Created time.Time `json:"created"` + LastModified time.Time `json:"lastModified"` + Version string `json:"version"` + Location string `json:"location"` +} + +type ResourceHolder interface { + GetResource() *Resource +} + +func buildResource[T ResourceHolder](ctx context.Context, handler ResourceHandler[T], details *domain.ObjectDetails) *Resource { + created := details.CreationDate.UTC() + if created.IsZero() { + created = details.EventDate.UTC() + } + + return &Resource{ + Schemas: []schemas.ScimSchemaType{handler.SchemaType()}, + Meta: &ResourceMeta{ + ResourceType: handler.ResourceNameSingular(), + Created: created, + LastModified: details.EventDate.UTC(), + Version: strconv.FormatUint(details.Sequence, 10), + Location: buildLocation(ctx, handler, details.ID), + }, + } +} + +func buildLocation[T ResourceHolder](ctx context.Context, handler ResourceHandler[T], id string) string { + return http.DomainContext(ctx).Origin() + path.Join(schemas.HandlerPrefix, authz.GetCtxData(ctx).OrgID, string(handler.ResourceNamePlural()), id) +} diff --git a/internal/api/scim/resources/resource_handler_adapter.go b/internal/api/scim/resources/resource_handler_adapter.go new file mode 100644 index 0000000000..5a346911af --- /dev/null +++ b/internal/api/scim/resources/resource_handler_adapter.go @@ -0,0 +1,91 @@ +package resources + +import ( + "encoding/json" + "net/http" + "slices" + + "github.com/gorilla/mux" + + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type ResourceHandlerAdapter[T ResourceHolder] struct { + handler ResourceHandler[T] +} + +type ListRequest struct { + // Count An integer indicating the desired maximum number of query results per page. OPTIONAL. + Count uint64 `json:"count" schema:"count"` + + // StartIndex An integer indicating the 1-based index of the first query result. Optional. + StartIndex uint64 `json:"startIndex" schema:"startIndex"` +} + +type ListResponse[T any] struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + ItemsPerPage uint64 `json:"itemsPerPage"` + TotalResults uint64 `json:"totalResults"` + StartIndex uint64 `json:"startIndex"` + Resources []T `json:"Resources"` // according to the rfc this is the only field in PascalCase... +} + +func NewResourceHandlerAdapter[T ResourceHolder](handler ResourceHandler[T]) *ResourceHandlerAdapter[T] { + return &ResourceHandlerAdapter[T]{ + handler, + } +} + +func (adapter *ResourceHandlerAdapter[T]) Create(r *http.Request) (T, error) { + entity, err := adapter.readEntityFromBody(r) + if err != nil { + return entity, err + } + + return adapter.handler.Create(r.Context(), entity) +} + +func (adapter *ResourceHandlerAdapter[T]) Replace(r *http.Request) (T, error) { + entity, err := adapter.readEntityFromBody(r) + if err != nil { + return entity, err + } + + id := mux.Vars(r)["id"] + return adapter.handler.Replace(r.Context(), id, entity) +} + +func (adapter *ResourceHandlerAdapter[T]) Delete(r *http.Request) error { + id := mux.Vars(r)["id"] + return adapter.handler.Delete(r.Context(), id) +} + +func (adapter *ResourceHandlerAdapter[T]) Get(r *http.Request) (T, error) { + id := mux.Vars(r)["id"] + return adapter.handler.Get(r.Context(), id) +} + +func (adapter *ResourceHandlerAdapter[T]) readEntityFromBody(r *http.Request) (T, error) { + entity := adapter.handler.NewResource() + err := json.NewDecoder(r.Body).Decode(entity) + if err != nil { + if zerrors.IsZitadelError(err) { + return entity, err + } + + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgumentf(nil, "SCIM-ucrjson", "Could not deserialize json: %v", err.Error())) + } + + resource := entity.GetResource() + if resource == nil { + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgument(nil, "SCIM-xxrjson", "Could not get resource, is the schema correct?")) + } + + if !slices.Contains(resource.Schemas, adapter.handler.SchemaType()) { + return entity, serrors.ThrowInvalidSyntax(zerrors.ThrowInvalidArgumentf(nil, "SCIM-xxrschema", "Expected schema %v is not provided", adapter.handler.SchemaType())) + } + + return entity, nil +} diff --git a/internal/api/scim/resources/user.go b/internal/api/scim/resources/user.go new file mode 100644 index 0000000000..defe849538 --- /dev/null +++ b/internal/api/scim/resources/user.go @@ -0,0 +1,212 @@ +package resources + +import ( + "context" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/authz" + scim_config "github.com/zitadel/zitadel/internal/api/scim/config" + scim_schemas "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/query" +) + +type UsersHandler struct { + command *command.Commands + query *query.Queries + userCodeAlg crypto.EncryptionAlgorithm + config *scim_config.Config +} + +type ScimUser struct { + *Resource + ID string `json:"id"` + ExternalID string `json:"externalId,omitempty"` + UserName string `json:"userName,omitempty"` + Name *ScimUserName `json:"name,omitempty"` + DisplayName string `json:"displayName,omitempty"` + NickName string `json:"nickName,omitempty"` + ProfileUrl *scim_schemas.HttpURL `json:"profileUrl,omitempty"` + Title string `json:"title,omitempty"` + PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"` + Locale string `json:"locale,omitempty"` + Timezone string `json:"timezone,omitempty"` + Active *bool `json:"active,omitempty"` + Emails []*ScimEmail `json:"emails,omitempty"` + PhoneNumbers []*ScimPhoneNumber `json:"phoneNumbers,omitempty"` + Password *scim_schemas.WriteOnlyString `json:"password,omitempty"` + Ims []*ScimIms `json:"ims,omitempty"` + Addresses []*ScimAddress `json:"addresses,omitempty"` + Photos []*ScimPhoto `json:"photos,omitempty"` + Entitlements []*ScimEntitlement `json:"entitlements,omitempty"` + Roles []*ScimRole `json:"roles,omitempty"` +} + +type ScimEntitlement struct { + Value string `json:"value,omitempty"` + Display string `json:"display,omitempty"` + Type string `json:"type,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimRole struct { + Value string `json:"value,omitempty"` + Display string `json:"display,omitempty"` + Type string `json:"type,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimPhoto struct { + Value scim_schemas.HttpURL `json:"value"` + Display string `json:"display,omitempty"` + Type string `json:"type"` + Primary bool `json:"primary,omitempty"` +} + +type ScimAddress struct { + Type string `json:"type,omitempty"` + StreetAddress string `json:"streetAddress,omitempty"` + Locality string `json:"locality,omitempty"` + Region string `json:"region,omitempty"` + PostalCode string `json:"postalCode,omitempty"` + Country string `json:"country,omitempty"` + Formatted string `json:"formatted,omitempty"` + Primary bool `json:"primary,omitempty"` +} + +type ScimIms struct { + Value string `json:"value"` + Type string `json:"type"` +} + +type ScimEmail struct { + Value string `json:"value"` + Primary bool `json:"primary"` +} + +type ScimPhoneNumber struct { + Value string `json:"value"` + Primary bool `json:"primary"` +} + +type ScimUserName struct { + Formatted string `json:"formatted,omitempty"` + FamilyName string `json:"familyName,omitempty"` + GivenName string `json:"givenName,omitempty"` + MiddleName string `json:"middleName,omitempty"` + HonorificPrefix string `json:"honorificPrefix,omitempty"` + HonorificSuffix string `json:"honorificSuffix,omitempty"` +} + +func NewUsersHandler( + command *command.Commands, + query *query.Queries, + userCodeAlg crypto.EncryptionAlgorithm, + config *scim_config.Config) ResourceHandler[*ScimUser] { + return &UsersHandler{command, query, userCodeAlg, config} +} + +func (h *UsersHandler) ResourceNameSingular() scim_schemas.ScimResourceTypeSingular { + return scim_schemas.UserResourceType +} + +func (h *UsersHandler) ResourceNamePlural() scim_schemas.ScimResourceTypePlural { + return scim_schemas.UsersResourceType +} + +func (u *ScimUser) GetResource() *Resource { + return u.Resource +} + +func (h *UsersHandler) NewResource() *ScimUser { + return new(ScimUser) +} + +func (h *UsersHandler) SchemaType() scim_schemas.ScimSchemaType { + return scim_schemas.IdUser +} + +func (h *UsersHandler) Create(ctx context.Context, user *ScimUser) (*ScimUser, error) { + orgID := authz.GetCtxData(ctx).OrgID + addHuman, err := h.mapToAddHuman(ctx, user) + if err != nil { + return nil, err + } + + err = h.command.AddUserHuman(ctx, orgID, addHuman, true, h.userCodeAlg) + if err != nil { + return nil, err + } + + h.mapAddCommandToScimUser(ctx, user, addHuman) + return user, nil +} + +func (h *UsersHandler) Replace(ctx context.Context, id string, user *ScimUser) (*ScimUser, error) { + user.ID = id + changeHuman, err := h.mapToChangeHuman(ctx, user) + if err != nil { + return nil, err + } + + err = h.command.ChangeUserHuman(ctx, changeHuman, h.userCodeAlg) + if err != nil { + return nil, err + } + + h.mapChangeCommandToScimUser(ctx, user, changeHuman) + return user, nil +} + +func (h *UsersHandler) Delete(ctx context.Context, id string) error { + memberships, grants, err := h.queryUserDependencies(ctx, id) + if err != nil { + return err + } + + _, err = h.command.RemoveUserV2(ctx, id, memberships, grants...) + return err +} + +func (h *UsersHandler) Get(ctx context.Context, id string) (*ScimUser, error) { + user, err := h.query.GetUserByID(ctx, false, id) + if err != nil { + return nil, err + } + + metadata, err := h.queryMetadataForUser(ctx, id) + if err != nil { + return nil, err + } + return h.mapToScimUser(ctx, user, metadata), nil +} + +func (h *UsersHandler) queryUserDependencies(ctx context.Context, userID string) ([]*command.CascadingMembership, []string, error) { + userGrantUserQuery, err := query.NewUserGrantUserIDSearchQuery(userID) + if err != nil { + return nil, nil, err + } + + grants, err := h.query.UserGrants(ctx, &query.UserGrantsQueries{ + Queries: []query.SearchQuery{userGrantUserQuery}, + }, true) + if err != nil { + return nil, nil, err + } + + membershipsUserQuery, err := query.NewMembershipUserIDQuery(userID) + if err != nil { + return nil, nil, err + } + + memberships, err := h.query.Memberships(ctx, &query.MembershipSearchQuery{ + Queries: []query.SearchQuery{membershipsUserQuery}, + }, false) + + if err != nil { + return nil, nil, err + } + return cascadingMemberships(memberships.Memberships), userGrantsToIDs(grants.UserGrants), nil +} diff --git a/internal/api/scim/resources/user_mapping.go b/internal/api/scim/resources/user_mapping.go new file mode 100644 index 0000000000..4de826ca69 --- /dev/null +++ b/internal/api/scim/resources/user_mapping.go @@ -0,0 +1,366 @@ +package resources + +import ( + "context" + "strconv" + "time" + + "github.com/muhlemmer/gu" + "github.com/zitadel/logging" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/scim/metadata" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" +) + +func (h *UsersHandler) mapToAddHuman(ctx context.Context, scimUser *ScimUser) (*command.AddHuman, error) { + human := &command.AddHuman{ + Username: scimUser.UserName, + NickName: scimUser.NickName, + DisplayName: scimUser.DisplayName, + } + + if scimUser.Active != nil && !*scimUser.Active { + human.SetInactive = true + } + + if email := h.mapPrimaryEmail(scimUser); email != nil { + human.Email = *email + } + + if phone := h.mapPrimaryPhone(scimUser); phone != nil { + human.Phone = *phone + } + + md, err := h.mapMetadataToCommands(ctx, scimUser) + if err != nil { + return nil, err + } + human.Metadata = md + + if scimUser.Password != nil { + human.Password = scimUser.Password.String() + scimUser.Password = nil + } + + if scimUser.Name != nil { + human.FirstName = scimUser.Name.GivenName + human.LastName = scimUser.Name.FamilyName + + // the direct mapping displayName => displayName has priority + // over the formatted name assignment + if human.DisplayName == "" { + human.DisplayName = scimUser.Name.Formatted + } else { + // update user to match the actual stored value + scimUser.Name.Formatted = human.DisplayName + } + } + + if err := domain.LanguageIsDefined(scimUser.PreferredLanguage); err != nil { + human.PreferredLanguage = language.English + scimUser.PreferredLanguage = language.English + } + + return human, nil +} + +func (h *UsersHandler) mapToChangeHuman(ctx context.Context, scimUser *ScimUser) (*command.ChangeHuman, error) { + human := &command.ChangeHuman{ + ID: scimUser.ID, + Username: &scimUser.UserName, + Profile: &command.Profile{ + NickName: &scimUser.NickName, + DisplayName: &scimUser.DisplayName, + }, + Email: h.mapPrimaryEmail(scimUser), + Phone: h.mapPrimaryPhone(scimUser), + } + + if scimUser.Active != nil { + if *scimUser.Active { + human.State = gu.Ptr(domain.UserStateActive) + } else { + human.State = gu.Ptr(domain.UserStateInactive) + } + } + + md, mdRemovedKeys, err := h.mapMetadataToDomain(ctx, scimUser) + if err != nil { + return nil, err + } + human.Metadata = md + human.MetadataKeysToRemove = mdRemovedKeys + + if scimUser.Password != nil { + human.Password = &command.Password{ + Password: scimUser.Password.String(), + } + scimUser.Password = nil + } + + if scimUser.Name != nil { + human.Profile.FirstName = &scimUser.Name.GivenName + human.Profile.LastName = &scimUser.Name.FamilyName + + // the direct mapping displayName => displayName has priority + // over the formatted name assignment + if *human.Profile.DisplayName == "" { + human.Profile.DisplayName = &scimUser.Name.Formatted + } else { + // update user to match the actual stored value + scimUser.Name.Formatted = *human.Profile.DisplayName + } + } + + if err := domain.LanguageIsDefined(scimUser.PreferredLanguage); err != nil { + human.Profile.PreferredLanguage = &language.English + scimUser.PreferredLanguage = language.English + } + + return human, nil +} + +func (h *UsersHandler) mapPrimaryEmail(scimUser *ScimUser) *command.Email { + for _, email := range scimUser.Emails { + if !email.Primary { + continue + } + + return &command.Email{ + Address: domain.EmailAddress(email.Value), + Verified: h.config.EmailVerified, + } + } + + return nil +} + +func (h *UsersHandler) mapPrimaryPhone(scimUser *ScimUser) *command.Phone { + for _, phone := range scimUser.PhoneNumbers { + if !phone.Primary { + continue + } + + return &command.Phone{ + Number: domain.PhoneNumber(phone.Value), + Verified: h.config.PhoneVerified, + } + } + + return nil +} + +func (h *UsersHandler) mapAddCommandToScimUser(ctx context.Context, user *ScimUser, addHuman *command.AddHuman) { + user.ID = addHuman.Details.ID + user.Resource = buildResource(ctx, h, addHuman.Details) + user.Password = nil + + // ZITADEL supports only one (primary) phone number or email. + // Therefore, only the primary one should be returned. + // Note that the phone number might also be reformatted. + if addHuman.Phone.Number != "" { + user.PhoneNumbers = []*ScimPhoneNumber{ + { + Value: string(addHuman.Phone.Number), + Primary: true, + }, + } + } + + if addHuman.Email.Address != "" { + user.Emails = []*ScimEmail{ + { + Value: string(addHuman.Email.Address), + Primary: true, + }, + } + } +} + +func (h *UsersHandler) mapChangeCommandToScimUser(ctx context.Context, user *ScimUser, changeHuman *command.ChangeHuman) { + user.ID = changeHuman.Details.ID + user.Resource = buildResource(ctx, h, changeHuman.Details) + user.Password = nil + + // ZITADEL supports only one (primary) phone number or email. + // Therefore, only the primary one should be returned. + // Note that the phone number might also be reformatted. + if changeHuman.Phone != nil { + user.PhoneNumbers = []*ScimPhoneNumber{ + { + Value: string(changeHuman.Phone.Number), + Primary: true, + }, + } + } + + if changeHuman.Email != nil { + user.Emails = []*ScimEmail{ + { + Value: string(changeHuman.Email.Address), + Primary: true, + }, + } + } +} + +func (h *UsersHandler) mapToScimUser(ctx context.Context, user *query.User, md map[metadata.ScopedKey][]byte) *ScimUser { + scimUser := &ScimUser{ + Resource: h.buildResourceForQuery(ctx, user), + ID: user.ID, + ExternalID: extractScalarMetadata(ctx, md, metadata.KeyExternalId), + UserName: user.Username, + ProfileUrl: extractHttpURLMetadata(ctx, md, metadata.KeyProfileUrl), + Title: extractScalarMetadata(ctx, md, metadata.KeyTitle), + Locale: extractScalarMetadata(ctx, md, metadata.KeyLocale), + Timezone: extractScalarMetadata(ctx, md, metadata.KeyTimezone), + Active: gu.Ptr(user.State.IsEnabled()), + Ims: make([]*ScimIms, 0), + Addresses: make([]*ScimAddress, 0), + Photos: make([]*ScimPhoto, 0), + Entitlements: make([]*ScimEntitlement, 0), + Roles: make([]*ScimRole, 0), + } + + if scimUser.Locale != "" { + _, err := language.Parse(scimUser.Locale) + if err != nil { + logging.OnError(err).Warn("Failed to load locale of scim user") + scimUser.Locale = "" + } + } + + if scimUser.Timezone != "" { + _, err := time.LoadLocation(scimUser.Timezone) + if err != nil { + logging.OnError(err).Warn("Failed to load timezone of scim user") + scimUser.Timezone = "" + } + } + + if err := extractJsonMetadata(ctx, md, metadata.KeyIms, &scimUser.Ims); err != nil { + logging.OnError(err).Warn("Could not deserialize scim ims metadata") + } + + if err := extractJsonMetadata(ctx, md, metadata.KeyAddresses, &scimUser.Addresses); err != nil { + logging.OnError(err).Warn("Could not deserialize scim addresses metadata") + } + + if err := extractJsonMetadata(ctx, md, metadata.KeyPhotos, &scimUser.Photos); err != nil { + logging.OnError(err).Warn("Could not deserialize scim photos metadata") + } + + if err := extractJsonMetadata(ctx, md, metadata.KeyEntitlements, &scimUser.Entitlements); err != nil { + logging.OnError(err).Warn("Could not deserialize scim entitlements metadata") + } + + if err := extractJsonMetadata(ctx, md, metadata.KeyRoles, &scimUser.Roles); err != nil { + logging.OnError(err).Warn("Could not deserialize scim roles metadata") + } + + if user.Human != nil { + mapHumanToScimUser(ctx, user.Human, scimUser, md) + } + + return scimUser +} + +func mapHumanToScimUser(ctx context.Context, human *query.Human, user *ScimUser, md map[metadata.ScopedKey][]byte) { + user.DisplayName = human.DisplayName + user.NickName = human.NickName + user.PreferredLanguage = human.PreferredLanguage + user.Name = &ScimUserName{ + Formatted: human.DisplayName, + FamilyName: human.LastName, + GivenName: human.FirstName, + MiddleName: extractScalarMetadata(ctx, md, metadata.KeyMiddleName), + HonorificPrefix: extractScalarMetadata(ctx, md, metadata.KeyHonorificPrefix), + HonorificSuffix: extractScalarMetadata(ctx, md, metadata.KeyHonorificSuffix), + } + + if string(human.Email) != "" { + user.Emails = []*ScimEmail{ + { + Value: string(human.Email), + Primary: true, + }, + } + } + + if string(human.Phone) != "" { + user.PhoneNumbers = []*ScimPhoneNumber{ + { + Value: string(human.Phone), + Primary: true, + }, + } + } +} + +func (h *UsersHandler) buildResourceForQuery(ctx context.Context, user *query.User) *Resource { + return &Resource{ + Schemas: []schemas.ScimSchemaType{schemas.IdUser}, + Meta: &ResourceMeta{ + ResourceType: schemas.UserResourceType, + Created: user.CreationDate.UTC(), + LastModified: user.ChangeDate.UTC(), + Version: strconv.FormatUint(user.Sequence, 10), + Location: buildLocation(ctx, h, user.ID), + }, + } +} + +func cascadingMemberships(memberships []*query.Membership) []*command.CascadingMembership { + cascades := make([]*command.CascadingMembership, len(memberships)) + for i, membership := range memberships { + cascades[i] = &command.CascadingMembership{ + UserID: membership.UserID, + ResourceOwner: membership.ResourceOwner, + IAM: cascadingIAMMembership(membership.IAM), + Org: cascadingOrgMembership(membership.Org), + Project: cascadingProjectMembership(membership.Project), + ProjectGrant: cascadingProjectGrantMembership(membership.ProjectGrant), + } + } + return cascades +} + +func cascadingIAMMembership(membership *query.IAMMembership) *command.CascadingIAMMembership { + if membership == nil { + return nil + } + return &command.CascadingIAMMembership{IAMID: membership.IAMID} +} + +func cascadingOrgMembership(membership *query.OrgMembership) *command.CascadingOrgMembership { + if membership == nil { + return nil + } + return &command.CascadingOrgMembership{OrgID: membership.OrgID} +} + +func cascadingProjectMembership(membership *query.ProjectMembership) *command.CascadingProjectMembership { + if membership == nil { + return nil + } + return &command.CascadingProjectMembership{ProjectID: membership.ProjectID} +} + +func cascadingProjectGrantMembership(membership *query.ProjectGrantMembership) *command.CascadingProjectGrantMembership { + if membership == nil { + return nil + } + return &command.CascadingProjectGrantMembership{ProjectID: membership.ProjectID, GrantID: membership.GrantID} +} + +func userGrantsToIDs(userGrants []*query.UserGrant) []string { + converted := make([]string, len(userGrants)) + for i, grant := range userGrants { + converted[i] = grant.ID + } + return converted +} diff --git a/internal/api/scim/resources/user_metadata.go b/internal/api/scim/resources/user_metadata.go new file mode 100644 index 0000000000..d08594c3cf --- /dev/null +++ b/internal/api/scim/resources/user_metadata.go @@ -0,0 +1,259 @@ +package resources + +import ( + "context" + "encoding/json" + "time" + // import timezone database to ensure it is available at runtime + // data is required to validate time zones. + _ "time/tzdata" + + "github.com/zitadel/logging" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/scim/metadata" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func (h *UsersHandler) queryMetadataForUser(ctx context.Context, id string) (map[metadata.ScopedKey][]byte, error) { + queries := h.buildMetadataQueries(ctx) + + md, err := h.query.SearchUserMetadata(ctx, false, id, queries, false) + if err != nil { + return nil, err + } + + metadataMap := make(map[metadata.ScopedKey][]byte, len(md.Metadata)) + for _, entry := range md.Metadata { + metadataMap[metadata.ScopedKey(entry.Key)] = entry.Value + } + + return metadataMap, nil +} + +func (h *UsersHandler) buildMetadataQueries(ctx context.Context) *query.UserMetadataSearchQueries { + keyQueries := make([]query.SearchQuery, len(metadata.ScimUserRelevantMetadataKeys)) + for i, key := range metadata.ScimUserRelevantMetadataKeys { + keyQueries[i] = buildMetadataKeyQuery(ctx, key) + } + + queries := &query.UserMetadataSearchQueries{ + SearchRequest: query.SearchRequest{}, + Queries: []query.SearchQuery{query.Or(keyQueries...)}, + } + return queries +} + +func buildMetadataKeyQuery(ctx context.Context, key metadata.Key) query.SearchQuery { + scopedKey := metadata.ScopeKey(ctx, key) + q, err := query.NewUserMetadataKeySearchQuery(string(scopedKey), query.TextEquals) + if err != nil { + logging.Panic("Error build user metadata query for key " + key) + } + + return q +} + +func (h *UsersHandler) mapMetadataToDomain(ctx context.Context, user *ScimUser) (md []*domain.Metadata, skippedMetadata []string, err error) { + md = make([]*domain.Metadata, 0, len(metadata.ScimUserRelevantMetadataKeys)) + for _, key := range metadata.ScimUserRelevantMetadataKeys { + var value []byte + value, err = getValueForMetadataKey(user, key) + if err != nil { + return + } + + if len(value) > 0 { + md = append(md, &domain.Metadata{ + Key: string(metadata.ScopeKey(ctx, key)), + Value: value, + }) + } else { + skippedMetadata = append(skippedMetadata, string(metadata.ScopeKey(ctx, key))) + } + } + + return +} + +func (h *UsersHandler) mapMetadataToCommands(ctx context.Context, user *ScimUser) ([]*command.AddMetadataEntry, error) { + md := make([]*command.AddMetadataEntry, 0, len(metadata.ScimUserRelevantMetadataKeys)) + for _, key := range metadata.ScimUserRelevantMetadataKeys { + value, err := getValueForMetadataKey(user, key) + if err != nil { + return nil, err + } + + if len(value) > 0 { + md = append(md, &command.AddMetadataEntry{ + Key: string(metadata.ScopeKey(ctx, key)), + Value: value, + }) + } + } + + return md, nil +} + +func getValueForMetadataKey(user *ScimUser, key metadata.Key) ([]byte, error) { + value := getRawValueForMetadataKey(user, key) + if value == nil { + return nil, nil + } + + switch key { + // json values + case metadata.KeyEntitlements: + fallthrough + case metadata.KeyIms: + fallthrough + case metadata.KeyPhotos: + fallthrough + case metadata.KeyAddresses: + fallthrough + case metadata.KeyRoles: + val, err := json.Marshal(value) + if err != nil { + return nil, err + } + + // null is considered no value + if len(val) == 4 && string(val) == "null" { + return nil, nil + } + + return val, nil + + // http url values + case metadata.KeyProfileUrl: + return []byte(value.(*schemas.HttpURL).String()), nil + + // raw values + case metadata.KeyProvisioningDomain: + fallthrough + case metadata.KeyExternalId: + fallthrough + case metadata.KeyMiddleName: + fallthrough + case metadata.KeyHonorificSuffix: + fallthrough + case metadata.KeyHonorificPrefix: + fallthrough + case metadata.KeyTitle: + fallthrough + case metadata.KeyLocale: + fallthrough + case metadata.KeyTimezone: + valueStr := value.(string) + if valueStr == "" { + return nil, nil + } + + return []byte(valueStr), validateValueForMetadataKey(valueStr, key) + } + + logging.Panicf("Unknown metadata key %s", key) + return nil, nil +} + +func validateValueForMetadataKey(v string, key metadata.Key) error { + //nolint:exhaustive + switch key { + case metadata.KeyLocale: + if _, err := language.Parse(v); err != nil { + return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-MD11", "Could not parse locale")) + } + return nil + case metadata.KeyTimezone: + if _, err := time.LoadLocation(v); err != nil { + return serrors.ThrowInvalidValue(zerrors.ThrowInvalidArgument(err, "SCIM-MD12", "Could not parse timezone")) + } + + return nil + } + + return nil +} + +func getRawValueForMetadataKey(user *ScimUser, key metadata.Key) interface{} { + switch key { + case metadata.KeyIms: + return user.Ims + case metadata.KeyPhotos: + return user.Photos + case metadata.KeyAddresses: + return user.Addresses + case metadata.KeyEntitlements: + return user.Entitlements + case metadata.KeyRoles: + return user.Roles + case metadata.KeyMiddleName: + if user.Name == nil { + return "" + } + return user.Name.MiddleName + case metadata.KeyHonorificPrefix: + if user.Name == nil { + return "" + } + return user.Name.HonorificPrefix + case metadata.KeyHonorificSuffix: + if user.Name == nil { + return "" + } + return user.Name.HonorificSuffix + case metadata.KeyExternalId: + return user.ExternalID + case metadata.KeyProfileUrl: + return user.ProfileUrl + case metadata.KeyTitle: + return user.Title + case metadata.KeyLocale: + return user.Locale + case metadata.KeyTimezone: + return user.Timezone + case metadata.KeyProvisioningDomain: + break + } + + logging.Panicf("Unknown or unsupported metadata key %s", key) + return nil +} + +func extractScalarMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key) string { + val, ok := md[metadata.ScopeKey(ctx, key)] + if !ok { + return "" + } + + return string(val) +} + +func extractHttpURLMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key) *schemas.HttpURL { + val, ok := md[metadata.ScopeKey(ctx, key)] + if !ok { + return nil + } + + url, err := schemas.ParseHTTPURL(string(val)) + if err != nil { + logging.OnError(err).Warn("Failed to parse scim url metadata for " + key) + return nil + } + + return url +} + +func extractJsonMetadata(ctx context.Context, md map[metadata.ScopedKey][]byte, key metadata.Key, v interface{}) error { + val, ok := md[metadata.ScopeKey(ctx, key)] + if !ok { + return nil + } + + return json.Unmarshal(val, v) +} diff --git a/internal/api/scim/schemas/schemas.go b/internal/api/scim/schemas/schemas.go new file mode 100644 index 0000000000..662a31f46f --- /dev/null +++ b/internal/api/scim/schemas/schemas.go @@ -0,0 +1,20 @@ +package schemas + +type ScimSchemaType string +type ScimResourceTypeSingular string +type ScimResourceTypePlural string + +const ( + idPrefixMessages = "urn:ietf:params:scim:api:messages:2.0:" + idPrefixCore = "urn:ietf:params:scim:schemas:core:2.0:" + idPrefixZitadelMessages = "urn:ietf:params:scim:api:zitadel:messages:2.0:" + + IdUser ScimSchemaType = idPrefixCore + "User" + IdError ScimSchemaType = idPrefixMessages + "Error" + IdZitadelErrorDetail ScimSchemaType = idPrefixZitadelMessages + "ErrorDetail" + + UserResourceType ScimResourceTypeSingular = "User" + UsersResourceType ScimResourceTypePlural = "Users" + + HandlerPrefix = "/scim/v2" +) diff --git a/internal/api/scim/schemas/string.go b/internal/api/scim/schemas/string.go new file mode 100644 index 0000000000..b62e50893d --- /dev/null +++ b/internal/api/scim/schemas/string.go @@ -0,0 +1,28 @@ +package schemas + +import "encoding/json" + +// WriteOnlyString a write only string is not serializable to json. +// in the SCIM RFC it has a mutability of writeOnly. +// This increases security to really ensure this is never sent to a client. +type WriteOnlyString string + +func NewWriteOnlyString(s string) *WriteOnlyString { + wos := WriteOnlyString(s) + return &wos +} + +func (s *WriteOnlyString) MarshalJSON() ([]byte, error) { + return []byte("null"), nil +} + +func (s *WriteOnlyString) UnmarshalJSON(bytes []byte) error { + var str string + err := json.Unmarshal(bytes, &str) + *s = WriteOnlyString(str) + return err +} + +func (s *WriteOnlyString) String() string { + return string(*s) +} diff --git a/internal/api/scim/schemas/string_test.go b/internal/api/scim/schemas/string_test.go new file mode 100644 index 0000000000..c48130a5d1 --- /dev/null +++ b/internal/api/scim/schemas/string_test.go @@ -0,0 +1,70 @@ +package schemas + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteOnlyString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + s WriteOnlyString + }{ + { + name: "always returns null", + s: "foo bar", + }, + { + name: "empty string returns null", + s: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(&tt.s) + assert.NoError(t, err) + assert.Equal(t, "null", string(got)) + }) + } +} + +func TestWriteOnlyString_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input []byte + want WriteOnlyString + wantErr bool + }{ + { + name: "string", + input: []byte(`"fooBar"`), + want: "fooBar", + wantErr: false, + }, + { + name: "empty string", + input: []byte(`""`), + want: "", + wantErr: false, + }, + { + name: "bad format", + input: []byte(`"bad "format"`), + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got WriteOnlyString + err := json.Unmarshal(tt.input, &got) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/scim/schemas/url.go b/internal/api/scim/schemas/url.go new file mode 100644 index 0000000000..343803bc04 --- /dev/null +++ b/internal/api/scim/schemas/url.go @@ -0,0 +1,50 @@ +package schemas + +import ( + "encoding/json" + "net/url" + + "github.com/zitadel/zitadel/internal/zerrors" +) + +type HttpURL url.URL + +func ParseHTTPURL(rawURL string) (*HttpURL, error) { + parsedURL, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return nil, zerrors.ThrowInvalidArgumentf(nil, "SCIM-htturl1", "HTTP URL expected, got %v", parsedURL.Scheme) + } + + return (*HttpURL)(parsedURL), nil +} + +func (u *HttpURL) UnmarshalJSON(data []byte) error { + var urlStr string + if err := json.Unmarshal(data, &urlStr); err != nil { + return err + } + + parsedURL, err := ParseHTTPURL(urlStr) + if err != nil { + return err + } + + *u = *parsedURL + return nil +} + +func (u *HttpURL) MarshalJSON() ([]byte, error) { + return json.Marshal(u.String()) +} + +func (u *HttpURL) String() string { + if u == nil { + return "" + } + + return (*url.URL)(u).String() +} diff --git a/internal/api/scim/schemas/url_test.go b/internal/api/scim/schemas/url_test.go new file mode 100644 index 0000000000..a6a60322e0 --- /dev/null +++ b/internal/api/scim/schemas/url_test.go @@ -0,0 +1,182 @@ +package schemas + +import ( + "reflect" + "testing" + + "github.com/goccy/go-json" + "github.com/stretchr/testify/assert" + "github.com/zitadel/logging" +) + +func TestHttpURL_MarshalJSON(t *testing.T) { + tests := []struct { + name string + u *HttpURL + want []byte + wantErr bool + }{ + { + name: "http url", + u: mustParseURL("http://example.com"), + want: []byte(`"http://example.com"`), + wantErr: false, + }, + { + name: "https url", + u: mustParseURL("https://example.com"), + want: []byte(`"https://example.com"`), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.u) + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + + assert.Equal(t, string(got), string(tt.want)) + }) + } +} + +func TestHttpURL_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + data []byte + want *HttpURL + wantErr bool + }{ + { + name: "http url", + data: []byte(`"http://example.com"`), + want: mustParseURL("http://example.com"), + wantErr: false, + }, + { + name: "https url", + data: []byte(`"https://example.com"`), + want: mustParseURL("https://example.com"), + wantErr: false, + }, + { + name: "ftp url should fail", + data: []byte(`"ftp://example.com"`), + want: nil, + wantErr: true, + }, + { + name: "no url should fail", + data: []byte(`"test"`), + want: nil, + wantErr: true, + }, + { + name: "number should fail", + data: []byte(`120`), + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url := new(HttpURL) + err := json.Unmarshal(tt.data, url) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantErr { + return + } + + assert.Equal(t, tt.want.String(), url.String()) + }) + } +} + +func TestHttpURL_String(t *testing.T) { + tests := []struct { + name string + u *HttpURL + want string + }{ + { + name: "http url", + u: mustParseURL("http://example.com"), + want: "http://example.com", + }, + { + name: "https url", + u: mustParseURL("https://example.com"), + want: "https://example.com", + }, + { + name: "nil", + u: nil, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.u.String(); got != tt.want { + t.Errorf("String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseHTTPURL(t *testing.T) { + tests := []struct { + name string + rawURL string + want *HttpURL + wantErr bool + }{ + { + name: "http url", + rawURL: "http://example.com", + want: mustParseURL("http://example.com"), + wantErr: false, + }, + { + name: "https url", + rawURL: "https://example.com", + want: mustParseURL("https://example.com"), + wantErr: false, + }, + { + name: "ftp url should fail", + rawURL: "ftp://example.com", + want: nil, + wantErr: true, + }, + { + name: "no url should fail", + rawURL: "test", + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseHTTPURL(tt.rawURL) + if (err != nil) != tt.wantErr { + t.Errorf("ParseHTTPURL() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseHTTPURL() got = %v, want %v", got, tt.want) + } + }) + } +} + +func mustParseURL(rawURL string) *HttpURL { + url, err := ParseHTTPURL(rawURL) + logging.OnError(err).Fatal("failed to parse URL") + return url +} diff --git a/internal/api/scim/serrors/errors.go b/internal/api/scim/serrors/errors.go new file mode 100644 index 0000000000..fffd598b27 --- /dev/null +++ b/internal/api/scim/serrors/errors.go @@ -0,0 +1,140 @@ +package serrors + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/zitadel/logging" + "golang.org/x/text/language" + + http_util "github.com/zitadel/zitadel/internal/api/http" + zhttp_middleware "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type scimErrorType string + +type wrappedScimError struct { + Parent error + ScimType scimErrorType +} + +type scimError struct { + Schemas []schemas.ScimSchemaType `json:"schemas"` + ScimType scimErrorType `json:"scimType,omitempty"` + Detail string `json:"detail,omitempty"` + StatusCode int `json:"-"` + Status string `json:"status"` + ZitadelDetail *errorDetail `json:"urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail,omitempty"` +} + +type errorDetail struct { + ID string `json:"id"` + Message string `json:"message"` +} + +const ( + // ScimTypeInvalidValue A required value was missing, + // or the value specified was not compatible with the operation, + // or attribute type (see Section 2.2 of RFC7643), + // or resource schema (see Section 4 of RFC7643). + ScimTypeInvalidValue scimErrorType = "invalidValue" + + // ScimTypeInvalidSyntax The request body message structure was invalid or did + // not conform to the request schema. + ScimTypeInvalidSyntax scimErrorType = "invalidSyntax" +) + +var translator *i18n.Translator + +func ErrorHandler(next zhttp_middleware.HandlerFuncWithError) http.Handler { + var err error + translator, err = i18n.NewZitadelTranslator(language.English) + logging.OnError(err).Panic("unable to get translator") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err = next(w, r); err == nil { + return + } + + scimErr := mapToScimJsonError(r.Context(), err) + w.WriteHeader(scimErr.StatusCode) + + jsonErr := json.NewEncoder(w).Encode(scimErr) + logging.OnError(jsonErr).Warn("Failed to marshal scim error response") + }) +} + +func ThrowInvalidValue(parent error) error { + return &wrappedScimError{ + Parent: parent, + ScimType: ScimTypeInvalidValue, + } +} + +func ThrowInvalidSyntax(parent error) error { + return &wrappedScimError{ + Parent: parent, + ScimType: ScimTypeInvalidSyntax, + } +} + +func (err *scimError) Error() string { + return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Detail) +} + +func (err *wrappedScimError) Error() string { + return fmt.Sprintf("SCIM Error: %s: %s", err.ScimType, err.Parent.Error()) +} + +func mapToScimJsonError(ctx context.Context, err error) *scimError { + scimErr := new(wrappedScimError) + if ok := errors.As(err, &scimErr); ok { + mappedErr := mapToScimJsonError(ctx, scimErr.Parent) + mappedErr.ScimType = scimErr.ScimType + return mappedErr + } + + zitadelErr := new(zerrors.ZitadelError) + if ok := errors.As(err, &zitadelErr); !ok { + return &scimError{ + Schemas: []schemas.ScimSchemaType{schemas.IdError}, + Detail: "Unknown internal server error", + Status: strconv.Itoa(http.StatusInternalServerError), + StatusCode: http.StatusInternalServerError, + } + } + + statusCode, ok := http_util.ZitadelErrorToHTTPStatusCode(err) + if !ok { + statusCode = http.StatusInternalServerError + } + + localizedMsg := translator.LocalizeFromCtx(ctx, zitadelErr.GetMessage(), nil) + return &scimError{ + Schemas: []schemas.ScimSchemaType{schemas.IdError, schemas.IdZitadelErrorDetail}, + ScimType: mapErrorToScimErrorType(err), + Detail: localizedMsg, + StatusCode: statusCode, + Status: strconv.Itoa(statusCode), + ZitadelDetail: &errorDetail{ + ID: zitadelErr.GetID(), + Message: zitadelErr.GetMessage(), + }, + } +} + +func mapErrorToScimErrorType(err error) scimErrorType { + switch { + case zerrors.IsErrorInvalidArgument(err): + return ScimTypeInvalidValue + default: + return "" + } +} diff --git a/internal/api/scim/serrors/errors_test.go b/internal/api/scim/serrors/errors_test.go new file mode 100644 index 0000000000..71d8018355 --- /dev/null +++ b/internal/api/scim/serrors/errors_test.go @@ -0,0 +1,110 @@ +package serrors + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/i18n" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestErrorHandler(t *testing.T) { + i18n.MustLoadSupportedLanguagesFromDir() + + tests := []struct { + name string + err error + wantStatus int + wantBody string + }{ + { + name: "scim error", + err: ThrowInvalidSyntax(zerrors.ThrowInvalidArgument(nil, "FOO", "Invalid syntax")), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "scimType":"invalidSyntax", + "detail":"Invalid syntax", + "status":"400", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Invalid syntax" + } + }`, + }, + { + name: "zitadel error", + err: zerrors.ThrowInvalidArgument(nil, "FOO", "Invalid syntax"), + wantStatus: http.StatusBadRequest, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "scimType":"invalidValue", + "detail":"Invalid syntax", + "status":"400", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Invalid syntax" + } + }`, + }, + { + name: "zitadel internal error", + err: zerrors.ThrowInternal(nil, "FOO", "Internal error"), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail" + ], + "detail":"Internal error", + "status":"500", + "urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail": { + "id":"FOO", + "message":"Internal error" + } + }`, + }, + { + name: "unknown error", + err: errors.New("FOO"), + wantStatus: http.StatusInternalServerError, + wantBody: `{ + "schemas":[ + "urn:ietf:params:scim:api:messages:2.0:Error" + ], + "detail":"Unknown internal server error", + "status":"500" + }`, + }, + { + name: "no error", + err: nil, + wantStatus: http.StatusOK, + wantBody: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + recorder := httptest.NewRecorder() + ErrorHandler(func(http.ResponseWriter, *http.Request) error { + return tt.err + }).ServeHTTP(recorder, req) + assert.Equal(t, tt.wantStatus, recorder.Code) + + if tt.wantBody != "" { + assert.JSONEq(t, tt.wantBody, recorder.Body.String()) + } + }) + } +} diff --git a/internal/api/scim/server.go b/internal/api/scim/server.go new file mode 100644 index 0000000000..d5d739bdc9 --- /dev/null +++ b/internal/api/scim/server.go @@ -0,0 +1,105 @@ +package scim + +import ( + "encoding/json" + "net/http" + "path" + + "github.com/gorilla/mux" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + zhttp "github.com/zitadel/zitadel/internal/api/http" + zhttp_middlware "github.com/zitadel/zitadel/internal/api/http/middleware" + sconfig "github.com/zitadel/zitadel/internal/api/scim/config" + smiddleware "github.com/zitadel/zitadel/internal/api/scim/middleware" + sresources "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" + "github.com/zitadel/zitadel/internal/api/scim/serrors" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/query" +) + +func NewServer( + command *command.Commands, + query *query.Queries, + verifier *authz.ApiTokenVerifier, + userCodeAlg crypto.EncryptionAlgorithm, + config *sconfig.Config, + middlewares ...zhttp_middlware.MiddlewareWithErrorFunc) http.Handler { + verifier.RegisterServer("SCIM-V2", schemas.HandlerPrefix, AuthMapping) + return buildHandler(command, query, userCodeAlg, config, middlewares...) +} + +func buildHandler( + command *command.Commands, + query *query.Queries, + userCodeAlg crypto.EncryptionAlgorithm, + cfg *sconfig.Config, + middlewares ...zhttp_middlware.MiddlewareWithErrorFunc) http.Handler { + + router := mux.NewRouter() + + // content type middleware needs to run at the very beginning to correctly set content types of errors + middlewares = append([]zhttp_middlware.MiddlewareWithErrorFunc{smiddleware.ContentTypeMiddleware}, middlewares...) + middlewares = append(middlewares, smiddleware.ScimContextMiddleware(query)) + scimMiddleware := zhttp_middlware.ChainedWithErrorHandler(serrors.ErrorHandler, middlewares...) + mapResource(router, scimMiddleware, sresources.NewUsersHandler(command, query, userCodeAlg, cfg)) + return router +} + +func mapResource[T sresources.ResourceHolder](router *mux.Router, mw zhttp_middlware.ErrorHandlerFunc, handler sresources.ResourceHandler[T]) { + adapter := sresources.NewResourceHandlerAdapter[T](handler) + resourceRouter := router.PathPrefix("/" + path.Join(zhttp.OrgIdInPathVariable, string(handler.ResourceNamePlural()))).Subrouter() + + resourceRouter.Handle("", mw(handleResourceCreatedResponse(adapter.Create))).Methods(http.MethodPost) + resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Get))).Methods(http.MethodGet) + resourceRouter.Handle("/{id}", mw(handleResourceResponse(adapter.Replace))).Methods(http.MethodPut) + resourceRouter.Handle("/{id}", mw(handleEmptyResponse(adapter.Delete))).Methods(http.MethodDelete) +} + +func handleResourceCreatedResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + entity, err := next(r) + if err != nil { + return err + } + + resource := entity.GetResource() + w.Header().Set(zhttp.Location, resource.Meta.Location) + w.WriteHeader(http.StatusCreated) + + err = json.NewEncoder(w).Encode(entity) + logging.OnError(err).Warn("scim json response encoding failed") + return nil + } +} + +func handleResourceResponse[T sresources.ResourceHolder](next func(*http.Request) (T, error)) zhttp_middlware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + entity, err := next(r) + if err != nil { + return err + } + + resource := entity.GetResource() + w.Header().Set(zhttp.ContentLocation, resource.Meta.Location) + + err = json.NewEncoder(w).Encode(entity) + logging.OnError(err).Warn("scim json response encoding failed") + return nil + } +} + +func handleEmptyResponse(next func(*http.Request) error) zhttp_middlware.HandlerFuncWithError { + return func(w http.ResponseWriter, r *http.Request) error { + err := next(r) + if err != nil { + return err + } + + w.WriteHeader(http.StatusNoContent) + return nil + } +} diff --git a/internal/api/ui/console/console.go b/internal/api/ui/console/console.go index 515f26db9b..fffbc00d5b 100644 --- a/internal/api/ui/console/console.go +++ b/internal/api/ui/console/console.go @@ -28,6 +28,10 @@ type Config struct { ShortCache middleware.CacheConfig LongCache middleware.CacheConfig InstanceManagementURL string + PostHog struct { + Token string + URL string + } } type spaHandler struct { @@ -117,7 +121,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call return } limited := limitingAccessInterceptor.Limit(w, r) - environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL, limited) + environmentJSON, err := createEnvironmentJSON(url, issuer(r), instance.ConsoleClientID(), customerPortal, instanceMgmtURL, config.PostHog.URL, config.PostHog.Token, limited) if err != nil { http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError) return @@ -150,13 +154,15 @@ func csp() *middleware.CSP { return &csp } -func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl string, exhausted bool) ([]byte, error) { +func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUrl, postHogURL, postHogToken string, exhausted bool) ([]byte, error) { environment := struct { API string `json:"api,omitempty"` Issuer string `json:"issuer,omitempty"` ClientID string `json:"clientid,omitempty"` CustomerPortal string `json:"customer_portal,omitempty"` InstanceManagementURL string `json:"instance_management_url,omitempty"` + PostHogURL string `json:"posthog_url,omitempty"` + PostHogToken string `json:"posthog_token,omitempty"` Exhausted bool `json:"exhausted,omitempty"` }{ API: api, @@ -164,6 +170,8 @@ func createEnvironmentJSON(api, issuer, clientID, customerPortal, instanceMgmtUr ClientID: clientID, CustomerPortal: customerPortal, InstanceManagementURL: instanceMgmtUrl, + PostHogURL: postHogURL, + PostHogToken: postHogToken, Exhausted: exhausted, } return json.Marshal(environment) diff --git a/internal/api/ui/login/change_password_handler.go b/internal/api/ui/login/change_password_handler.go index 19f2404c94..7f2b83b80f 100644 --- a/internal/api/ui/login/change_password_handler.go +++ b/internal/api/ui/login/change_password_handler.go @@ -35,10 +35,6 @@ func (l *Login) handleChangePassword(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderChangePassword(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errType, errMessage string - if err != nil { - errType, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) if authReq == nil || len(authReq.PossibleSteps) < 1 { l.renderError(w, r, authReq, err) @@ -50,7 +46,7 @@ func (l *Login) renderChangePassword(w http.ResponseWriter, r *http.Request, aut return } data := passwordData{ - baseData: l.getBaseData(r, authReq, translator, "PasswordChange.Title", "PasswordChange.Description", errType, errMessage), + baseData: l.getBaseData(r, authReq, translator, "PasswordChange.Title", "PasswordChange.Description", err), profileData: l.getProfileData(authReq), Expired: step.Expired, } @@ -75,6 +71,6 @@ func (l *Login) renderChangePassword(w http.ResponseWriter, r *http.Request, aut func (l *Login) renderChangePasswordDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) { translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "PasswordChange.Title", "PasswordChange.Description", "", "") + data := l.getUserData(r, authReq, translator, "PasswordChange.Title", "PasswordChange.Description", nil) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplChangePasswordDone], data, nil) } diff --git a/internal/api/ui/login/device_auth.go b/internal/api/ui/login/device_auth.go index e7ddaccd08..fb61102d76 100644 --- a/internal/api/ui/login/device_auth.go +++ b/internal/api/ui/login/device_auth.go @@ -22,13 +22,11 @@ const ( ) func (l *Login) renderDeviceAuthUserCode(w http.ResponseWriter, r *http.Request, err error) { - var errID, errMessage string if err != nil { logging.WithError(err).Error() - errID, errMessage = l.getErrorMessage(r, err) } translator := l.getTranslator(r.Context(), nil) - data := l.getBaseData(r, nil, translator, "DeviceAuth.Title", "DeviceAuth.UserCode.Description", errID, errMessage) + data := l.getBaseData(r, nil, translator, "DeviceAuth.Title", "DeviceAuth.UserCode.Description", err) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplDeviceAuthUserCode], data, nil) } @@ -41,7 +39,7 @@ func (l *Login) renderDeviceAuthAction(w http.ResponseWriter, r *http.Request, a ClientID string Scopes []string }{ - baseData: l.getBaseData(r, authReq, translator, "DeviceAuth.Title", "DeviceAuth.Action.Description", "", ""), + baseData: l.getBaseData(r, authReq, translator, "DeviceAuth.Title", "DeviceAuth.Action.Description", nil), AuthRequestID: authReq.ID, Username: authReq.UserName, ClientID: authReq.ApplicationID, @@ -63,7 +61,7 @@ func (l *Login) renderDeviceAuthDone(w http.ResponseWriter, r *http.Request, aut baseData Message string }{ - baseData: l.getBaseData(r, authReq, translator, "DeviceAuth.Title", "DeviceAuth.Done.Description", "", ""), + baseData: l.getBaseData(r, authReq, translator, "DeviceAuth.Title", "DeviceAuth.Done.Description", nil), } switch action { case deviceAuthAllowed: diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index c60e0eb0bb..5481c6aed1 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -2,8 +2,10 @@ package login import ( "context" + "errors" "net/http" "net/url" + "slices" "strings" "github.com/crewjam/saml/samlsp" @@ -150,7 +152,7 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) err = l.authRepo.SelectExternalIDP(r.Context(), authReq.ID, identityProvider.ID, userAgentID) if err != nil { - l.renderLogin(w, r, authReq, err) + l.externalAuthFailed(w, r, authReq, err) return } var provider idp.Provider @@ -183,17 +185,17 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai case domain.IDPTypeUnspecified: fallthrough default: - l.renderLogin(w, r, authReq, zerrors.ThrowInvalidArgument(nil, "LOGIN-AShek", "Errors.ExternalIDP.IDPTypeNotImplemented")) + l.externalAuthFailed(w, r, authReq, zerrors.ThrowInvalidArgument(nil, "LOGIN-AShek", "Errors.ExternalIDP.IDPTypeNotImplemented")) return } if err != nil { - l.renderLogin(w, r, authReq, err) + l.externalAuthFailed(w, r, authReq, err) return } params := l.sessionParamsFromAuthRequest(r.Context(), authReq, identityProvider.ID) session, err := provider.BeginAuth(r.Context(), authReq.ID, params...) if err != nil { - l.renderLogin(w, r, authReq, err) + l.externalAuthFailed(w, r, authReq, err) return } @@ -215,7 +217,7 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai func (l *Login) handleExternalLoginCallbackForm(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { - l.renderLogin(w, r, nil, err) + l.externalAuthFailed(w, r, nil, err) return } state := r.Form.Get(queryState) @@ -223,7 +225,7 @@ func (l *Login) handleExternalLoginCallbackForm(w http.ResponseWriter, r *http.R state = r.Form.Get(queryRelayState) } if state == "" { - l.renderLogin(w, r, nil, zerrors.ThrowInvalidArgument(nil, "LOGIN-dsg3f", "Errors.AuthRequest.NotFound")) + l.externalAuthFailed(w, r, nil, zerrors.ThrowInvalidArgument(nil, "LOGIN-dsg3f", "Errors.AuthRequest.NotFound")) return } l.caches.idpFormCallbacks.Set(r.Context(), &idpFormCallback{ @@ -243,7 +245,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque // workaround because of CSRF on external identity provider flows using form_post if r.URL.Query().Get(queryMethod) == http.MethodPost { if err := l.setDataFromFormCallback(r, r.URL.Query().Get(queryState)); err != nil { - l.renderLogin(w, r, nil, err) + l.externalAuthFailed(w, r, nil, err) return } } @@ -251,7 +253,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque data := new(externalIDPCallbackData) err := l.getParseData(r, data) if err != nil { - l.renderLogin(w, r, nil, err) + l.externalAuthFailed(w, r, nil, err) return } if data.State == "" { @@ -261,12 +263,12 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } identityProvider, err := l.getIDPByID(r, authReq.SelectedIDPConfigID) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } var provider idp.Provider @@ -275,75 +277,75 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque case domain.IDPTypeOAuth: provider, err = l.oauthProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &oauth.Session{Provider: provider.(*oauth.Provider), Code: data.Code} case domain.IDPTypeOIDC: provider, err = l.oidcProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*openid.Provider), Code: data.Code} case domain.IDPTypeAzureAD: provider, err = l.azureProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &azuread.Session{Provider: provider.(*azuread.Provider), Code: data.Code} case domain.IDPTypeGitHub: provider, err = l.githubProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &oauth.Session{Provider: provider.(*github.Provider).Provider, Code: data.Code} case domain.IDPTypeGitHubEnterprise: provider, err = l.githubEnterpriseProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &oauth.Session{Provider: provider.(*github.Provider).Provider, Code: data.Code} case domain.IDPTypeGitLab: provider, err = l.gitlabProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*gitlab.Provider).Provider, Code: data.Code} case domain.IDPTypeGitLabSelfHosted: provider, err = l.gitlabSelfHostedProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*gitlab.Provider).Provider, Code: data.Code} case domain.IDPTypeGoogle: provider, err = l.googleProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &openid.Session{Provider: provider.(*google.Provider).Provider, Code: data.Code} case domain.IDPTypeApple: provider, err = l.appleProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session = &apple.Session{Session: &openid.Session{Provider: provider.(*apple.Provider).Provider, Code: data.Code}, UserFormValue: data.User} case domain.IDPTypeSAML: provider, err = l.samlProvider(r.Context(), identityProvider) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } session, err = saml.NewSession(provider.(*saml.Provider), authReq.SAMLRequestID, r) if err != nil { - l.externalAuthFailed(w, r, authReq, nil, nil, err) + l.externalAuthCallbackFailed(w, r, authReq, nil, nil, err) return } case domain.IDPTypeJWT, @@ -351,7 +353,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque domain.IDPTypeUnspecified: fallthrough default: - l.renderLogin(w, r, authReq, zerrors.ThrowInvalidArgument(nil, "LOGIN-SFefg", "Errors.ExternalIDP.IDPTypeNotImplemented")) + l.externalAuthFailed(w, r, authReq, zerrors.ThrowInvalidArgument(nil, "LOGIN-SFefg", "Errors.ExternalIDP.IDPTypeNotImplemented")) return } @@ -361,7 +363,7 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque "instance", authz.GetInstance(r.Context()).InstanceID(), "providerID", identityProvider.ID, ).WithError(err).Info("external authentication failed") - l.externalAuthFailed(w, r, authReq, tokens(session), user, err) + l.externalAuthCallbackFailed(w, r, authReq, tokens(session), user, err) return } l.handleExternalUserAuthenticated(w, r, authReq, identityProvider, session, user, l.renderNextStep) @@ -619,10 +621,6 @@ func (l *Login) autoCreateExternalUser(w http.ResponseWriter, r *http.Request, a // renderExternalNotFoundOption renders a page, where the user is able to edit the IDP data, // create a new externalUser of link to existing on (based on the IDP template) func (l *Login) renderExternalNotFoundOption(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, orgIAMPolicy *query.DomainPolicy, human *domain.Human, idpLink *domain.UserIDPLink, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } resourceOwner := determineResourceOwner(r.Context(), authReq) if orgIAMPolicy == nil { orgIAMPolicy, err = l.getOrgDomainPolicy(r, resourceOwner) @@ -656,7 +654,7 @@ func (l *Login) renderExternalNotFoundOption(w http.ResponseWriter, r *http.Requ translator := l.getTranslator(r.Context(), authReq) data := externalNotFoundOptionData{ - baseData: l.getBaseData(r, authReq, translator, "ExternalNotFound.Title", "ExternalNotFound.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "ExternalNotFound.Title", "ExternalNotFound.Description", err), externalNotFoundOptionFormData: externalNotFoundOptionFormData{ externalRegisterFormData: externalRegisterFormData{ Email: human.EmailAddress, @@ -1215,7 +1213,7 @@ func (l *Login) appendUserGrants(ctx context.Context, userGrants []*domain.UserG return nil } -func (l *Login) externalAuthFailed(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, tokens *oidc.Tokens[*oidc.IDTokenClaims], user idp.User, err error) { +func (l *Login) externalAuthCallbackFailed(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, tokens *oidc.Tokens[*oidc.IDTokenClaims], user idp.User, err error) { if authReq == nil { l.renderLogin(w, r, authReq, err) return @@ -1223,7 +1221,37 @@ func (l *Login) externalAuthFailed(w http.ResponseWriter, r *http.Request, authR if _, _, actionErr := l.runPostExternalAuthenticationActions(&domain.ExternalUser{}, tokens, authReq, r, user, err); actionErr != nil { logging.WithError(err).Error("both external user authentication and action post authentication failed") } - l.renderLogin(w, r, authReq, err) + l.externalAuthFailed(w, r, authReq, err) +} + +func (l *Login) externalAuthFailed(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { + if authReq == nil || authReq.LoginPolicy == nil || !authReq.LoginPolicy.AllowUsernamePassword || authReq.UserID == "" { + l.renderLogin(w, r, authReq, err) + return + } + authMethods, authMethodsError := l.query.ListUserAuthMethodTypes(setUserContext(r.Context(), authReq.UserID, ""), authReq.UserID, true, false, "") + if authMethodsError != nil { + logging.WithFields("userID", authReq.UserID).WithError(authMethodsError).Warn("unable to load user's auth methods for idp login error") + l.renderLogin(w, r, authReq, err) + return + } + passwordless := slices.Contains(authMethods.AuthMethodTypes, domain.UserAuthMethodTypePasswordless) + password := slices.Contains(authMethods.AuthMethodTypes, domain.UserAuthMethodTypePassword) + if !passwordless && !password { + l.renderLogin(w, r, authReq, err) + return + } + localAuthError := l.authRepo.RequestLocalAuth(setContext(r.Context(), authReq.UserOrgID), authReq.ID, authReq.AgentID) + if localAuthError != nil { + l.renderLogin(w, r, authReq, err) + return + } + err = WrapIdPError(err) + if passwordless { + l.renderPasswordlessVerification(w, r, authReq, password, err) + return + } + l.renderPassword(w, r, authReq, err) } // tokens extracts the oidc.Tokens for backwards compatibility of PostExternalAuthenticationActions @@ -1359,3 +1387,34 @@ func (l *Login) getUserLinks(ctx context.Context, userID, idpID string) (*query. }, nil, ) } + +// IdPError wraps an error from an external IDP to be able to distinguish it from other errors and to display it +// more prominent (popup style) . +// It's used if an error occurs during the login process with an external IDP and local authentication is allowed, +// respectively used as fallback. +type IdPError struct { + err *zerrors.ZitadelError +} + +func (e *IdPError) Error() string { + return e.err.Error() +} + +func (e *IdPError) Unwrap() error { + return e.err +} + +func (e *IdPError) Is(target error) bool { + _, ok := target.(*IdPError) + return ok +} + +func WrapIdPError(err error) *IdPError { + zErr := new(zerrors.ZitadelError) + id := "LOGIN-JWo3f" + // keep the original error id if there is one + if errors.As(err, &zErr) { + id = zErr.ID + } + return &IdPError{err: zerrors.CreateZitadelError(err, id, "Errors.User.ExternalIDP.LoginFailedSwitchLocal")} +} diff --git a/internal/api/ui/login/init_password_handler.go b/internal/api/ui/login/init_password_handler.go index b8c6d401c5..17ac13ff31 100644 --- a/internal/api/ui/login/init_password_handler.go +++ b/internal/api/ui/login/init_password_handler.go @@ -112,10 +112,6 @@ func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authRe } func (l *Login) renderInitPassword(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, code string, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if userID == "" && authReq != nil { userID = authReq.UserID } @@ -123,7 +119,7 @@ func (l *Login) renderInitPassword(w http.ResponseWriter, r *http.Request, authR translator := l.getTranslator(r.Context(), authReq) data := initPasswordData{ - baseData: l.getBaseData(r, authReq, translator, "InitPassword.Title", "InitPassword.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "InitPassword.Title", "InitPassword.Description", err), profileData: l.getProfileData(authReq), UserID: userID, Code: code, @@ -155,7 +151,7 @@ func (l *Login) renderInitPassword(w http.ResponseWriter, r *http.Request, authR func (l *Login) renderInitPasswordDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, orgID string) { translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "InitPasswordDone.Title", "InitPasswordDone.Description", "", "") + data := l.getUserData(r, authReq, translator, "InitPasswordDone.Title", "InitPasswordDone.Description", nil) if authReq == nil { l.customTexts(r.Context(), translator, orgID) } diff --git a/internal/api/ui/login/init_user_handler.go b/internal/api/ui/login/init_user_handler.go index 9a6d052dcd..fa4854a473 100644 --- a/internal/api/ui/login/init_user_handler.go +++ b/internal/api/ui/login/init_user_handler.go @@ -131,17 +131,13 @@ func (l *Login) resendUserInit(w http.ResponseWriter, r *http.Request, authReq * } func (l *Login) renderInitUser(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, loginName string, code string, passwordSet bool, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if authReq != nil { userID = authReq.UserID } translator := l.getTranslator(r.Context(), authReq) data := initUserData{ - baseData: l.getBaseData(r, authReq, translator, "InitUser.Title", "InitUser.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "InitUser.Title", "InitUser.Description", err), profileData: l.getProfileData(authReq), UserID: userID, Code: code, @@ -179,7 +175,7 @@ func (l *Login) renderInitUser(w http.ResponseWriter, r *http.Request, authReq * func (l *Login) renderInitUserDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, orgID string) { translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "InitUserDone.Title", "InitUserDone.Description", "", "") + data := l.getUserData(r, authReq, translator, "InitUserDone.Title", "InitUserDone.Description", nil) if authReq == nil { l.customTexts(r.Context(), translator, orgID) } diff --git a/internal/api/ui/login/invite_user_handler.go b/internal/api/ui/login/invite_user_handler.go index e083277c93..18ba502483 100644 --- a/internal/api/ui/login/invite_user_handler.go +++ b/internal/api/ui/login/invite_user_handler.go @@ -119,10 +119,6 @@ func (l *Login) resendUserInvite(w http.ResponseWriter, r *http.Request, authReq } func (l *Login) renderInviteUser(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, orgID, loginName string, code string, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if authReq != nil { userID = authReq.UserID orgID = authReq.UserOrgID @@ -130,7 +126,7 @@ func (l *Login) renderInviteUser(w http.ResponseWriter, r *http.Request, authReq translator := l.getTranslator(r.Context(), authReq) data := inviteUserData{ - baseData: l.getBaseData(r, authReq, translator, "InviteUser.Title", "InviteUser.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "InviteUser.Title", "InviteUser.Description", err), profileData: l.getProfileData(authReq), UserID: userID, Code: code, diff --git a/internal/api/ui/login/ldap_handler.go b/internal/api/ui/login/ldap_handler.go index 0fd47c5a6a..147a319523 100644 --- a/internal/api/ui/login/ldap_handler.go +++ b/internal/api/ui/login/ldap_handler.go @@ -30,13 +30,9 @@ func (l *Login) handleLDAP(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderLDAPLogin(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } temp := l.renderer.Templates[tmplLDAPLogin] translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "Login.Title", "Login.Description", errID, errMessage) + data := l.getUserData(r, authReq, translator, "Login.Title", "Login.Description", err) l.renderer.RenderTemplate(w, r, translator, temp, data, nil) } diff --git a/internal/api/ui/login/link_users_handler.go b/internal/api/ui/login/link_users_handler.go index c720559084..0b0803f8a1 100644 --- a/internal/api/ui/login/link_users_handler.go +++ b/internal/api/ui/login/link_users_handler.go @@ -18,11 +18,7 @@ func (l *Login) linkUsers(w http.ResponseWriter, r *http.Request, authReq *domai } func (l *Login) renderLinkUsersDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errType, errMessage string - if err != nil { - errType, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "LinkingUsersDone.Title", "LinkingUsersDone.Description", errType, errMessage) + data := l.getUserData(r, authReq, translator, "LinkingUsersDone.Title", "LinkingUsersDone.Description", err) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplLinkUsersDone], data, nil) } diff --git a/internal/api/ui/login/login_handler.go b/internal/api/ui/login/login_handler.go index 059048eecb..729bd1955b 100644 --- a/internal/api/ui/login/login_handler.go +++ b/internal/api/ui/login/login_handler.go @@ -91,16 +91,12 @@ func (l *Login) handleLoginNameCheck(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderLogin(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if err == nil && singleIDPAllowed(authReq) { l.handleIDP(w, r, authReq, authReq.AllowedExternalIDPs[0].IDPConfigID) return } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "Login.Title", "Login.Description", errID, errMessage) + data := l.getUserData(r, authReq, translator, "Login.Title", "Login.Description", err) funcs := map[string]interface{}{ "hasUsernamePasswordLogin": func() bool { return authReq != nil && authReq.LoginPolicy != nil && authReq.LoginPolicy.AllowUsernamePassword diff --git a/internal/api/ui/login/login_success_handler.go b/internal/api/ui/login/login_success_handler.go index 00f29becfd..a18a3a2d5c 100644 --- a/internal/api/ui/login/login_success_handler.go +++ b/internal/api/ui/login/login_success_handler.go @@ -37,13 +37,9 @@ func (l *Login) handleLoginSuccess(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderSuccessAndCallback(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) data := loginSuccessData{ - userData: l.getUserData(r, authReq, translator, "LoginSuccess.Title", "", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "LoginSuccess.Title", "", err), } if authReq != nil { data.RedirectURI, err = l.authRequestCallback(r.Context(), authReq) diff --git a/internal/api/ui/login/logout_handler.go b/internal/api/ui/login/logout_handler.go index e270cd5541..9596f477af 100644 --- a/internal/api/ui/login/logout_handler.go +++ b/internal/api/ui/login/logout_handler.go @@ -14,6 +14,6 @@ func (l *Login) handleLogoutDone(w http.ResponseWriter, r *http.Request) { func (l *Login) renderLogoutDone(w http.ResponseWriter, r *http.Request) { translator := l.getTranslator(r.Context(), nil) - data := l.getUserData(r, nil, translator, "LogoutDone.Title", "LogoutDone.Description", "", "") + data := l.getUserData(r, nil, translator, "LogoutDone.Title", "LogoutDone.Description", nil) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplLogoutDone], data, nil) } diff --git a/internal/api/ui/login/mail_verify_handler.go b/internal/api/ui/login/mail_verify_handler.go index 864ff76dd2..071fe6539d 100644 --- a/internal/api/ui/login/mail_verify_handler.go +++ b/internal/api/ui/login/mail_verify_handler.go @@ -145,17 +145,13 @@ func (l *Login) checkMailCode(w http.ResponseWriter, r *http.Request, authReq *d } func (l *Login) renderMailVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, code string, passwordInit bool, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if userID == "" && authReq != nil { userID = authReq.UserID } translator := l.getTranslator(r.Context(), authReq) data := mailVerificationData{ - baseData: l.getBaseData(r, authReq, translator, "EmailVerification.Title", "EmailVerification.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "EmailVerification.Title", "EmailVerification.Description", err), UserID: userID, profileData: l.getProfileData(authReq), Code: code, @@ -191,7 +187,7 @@ func (l *Login) renderMailVerification(w http.ResponseWriter, r *http.Request, a func (l *Login) renderMailVerified(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, orgID string) { translator := l.getTranslator(r.Context(), authReq) data := mailVerificationData{ - baseData: l.getBaseData(r, authReq, translator, "EmailVerificationDone.Title", "EmailVerificationDone.Description", "", ""), + baseData: l.getBaseData(r, authReq, translator, "EmailVerificationDone.Title", "EmailVerificationDone.Description", nil), profileData: l.getProfileData(authReq), } if authReq == nil { diff --git a/internal/api/ui/login/mfa_init_done_handler.go b/internal/api/ui/login/mfa_init_done_handler.go index 437fde29f4..ae4bab69ea 100644 --- a/internal/api/ui/login/mfa_init_done_handler.go +++ b/internal/api/ui/login/mfa_init_done_handler.go @@ -14,9 +14,8 @@ type mfaInitDoneData struct { } func (l *Login) renderMFAInitDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *mfaDoneData) { - var errType, errMessage string translator := l.getTranslator(r.Context(), authReq) - data.baseData = l.getBaseData(r, authReq, translator, "InitMFADone.Title", "InitMFADone.Description", errType, errMessage) + data.baseData = l.getBaseData(r, authReq, translator, "InitMFADone.Title", "InitMFADone.Description", nil) data.profileData = l.getProfileData(authReq) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplMFAInitDone], data, nil) } diff --git a/internal/api/ui/login/mfa_init_sms.go b/internal/api/ui/login/mfa_init_sms.go index 03f2c32014..048677f0f4 100644 --- a/internal/api/ui/login/mfa_init_sms.go +++ b/internal/api/ui/login/mfa_init_sms.go @@ -53,12 +53,8 @@ func (l *Login) handleRegisterOTPSMS(w http.ResponseWriter, r *http.Request, aut } func (l *Login) renderRegisterSMS(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *smsInitData, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data.baseData = l.getBaseData(r, authReq, translator, "InitMFAOTP.Title", "InitMFAOTP.Description", errID, errMessage) + data.baseData = l.getBaseData(r, authReq, translator, "InitMFAOTP.Title", "InitMFAOTP.Description", err) data.profileData = l.getProfileData(authReq) data.MFAType = domain.MFATypeOTPSMS l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplMFASMSInit], data, nil) diff --git a/internal/api/ui/login/mfa_init_u2f.go b/internal/api/ui/login/mfa_init_u2f.go index c84948796c..0e75bd1b69 100644 --- a/internal/api/ui/login/mfa_init_u2f.go +++ b/internal/api/ui/login/mfa_init_u2f.go @@ -18,21 +18,18 @@ type u2fInitData struct { } func (l *Login) renderRegisterU2F(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage, credentialData string + var credentialData string var u2f *domain.WebAuthNToken if err == nil { u2f, err = l.command.HumanAddU2FSetup(setUserContext(r.Context(), authReq.UserID, authReq.UserOrgID), authReq.UserID, authReq.UserOrgID) } - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if u2f != nil { credentialData = base64.RawURLEncoding.EncodeToString(u2f.CredentialCreationData) } translator := l.getTranslator(r.Context(), authReq) data := &u2fInitData{ webAuthNData: webAuthNData{ - userData: l.getUserData(r, authReq, translator, "InitMFAU2F.Title", "InitMFAU2F.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "InitMFAU2F.Title", "InitMFAU2F.Description", err), CredentialCreationData: credentialData, }, MFAType: domain.MFATypeU2F, diff --git a/internal/api/ui/login/mfa_init_verify_handler.go b/internal/api/ui/login/mfa_init_verify_handler.go index cd6a9091e2..e3488d391c 100644 --- a/internal/api/ui/login/mfa_init_verify_handler.go +++ b/internal/api/ui/login/mfa_init_verify_handler.go @@ -66,12 +66,8 @@ func (l *Login) handleOTPVerify(w http.ResponseWriter, r *http.Request, authReq } func (l *Login) renderMFAInitVerify(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, data *mfaVerifyData, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data.baseData = l.getBaseData(r, authReq, translator, "InitMFAOTP.Title", "InitMFAOTP.Description", errID, errMessage) + data.baseData = l.getBaseData(r, authReq, translator, "InitMFAOTP.Title", "InitMFAOTP.Description", err) data.profileData = l.getProfileData(authReq) if data.MFAType == domain.MFATypeTOTP { code, err := generateQrCode(data.totpData.Url) diff --git a/internal/api/ui/login/mfa_prompt_handler.go b/internal/api/ui/login/mfa_prompt_handler.go index ce1b7240ec..ca318741b7 100644 --- a/internal/api/ui/login/mfa_prompt_handler.go +++ b/internal/api/ui/login/mfa_prompt_handler.go @@ -49,13 +49,9 @@ func (l *Login) handleMFAPromptSelection(w http.ResponseWriter, r *http.Request) } func (l *Login) renderMFAPrompt(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, mfaPromptData *domain.MFAPromptStep, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) data := mfaData{ - baseData: l.getBaseData(r, authReq, translator, "InitMFAPrompt.Title", "InitMFAPrompt.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "InitMFAPrompt.Title", "InitMFAPrompt.Description", err), profileData: l.getProfileData(authReq), } diff --git a/internal/api/ui/login/mfa_verify_handler.go b/internal/api/ui/login/mfa_verify_handler.go index cfffc6fced..34832413cf 100644 --- a/internal/api/ui/login/mfa_verify_handler.go +++ b/internal/api/ui/login/mfa_verify_handler.go @@ -62,12 +62,8 @@ func (l *Login) renderMFAVerify(w http.ResponseWriter, r *http.Request, authReq } func (l *Login) renderMFAVerifySelected(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, verificationStep *domain.MFAVerificationStep, selectedProvider domain.MFAType, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "", "", errID, errMessage) + data := l.getUserData(r, authReq, translator, "", "", err) if verificationStep == nil { l.renderError(w, r, authReq, err) return diff --git a/internal/api/ui/login/mfa_verify_otp_handler.go b/internal/api/ui/login/mfa_verify_otp_handler.go index fb77bbcba9..bd09a7652b 100644 --- a/internal/api/ui/login/mfa_verify_otp_handler.go +++ b/internal/api/ui/login/mfa_verify_otp_handler.go @@ -61,13 +61,9 @@ func (l *Login) handleOTPVerification(w http.ResponseWriter, r *http.Request, au } func (l *Login) renderOTPVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, providers []domain.MFAType, selectedProvider domain.MFAType, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) data := &mfaOTPData{ - userData: l.getUserData(r, authReq, translator, "VerifyMFAU2F.Title", "VerifyMFAU2F.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "VerifyMFAU2F.Title", "VerifyMFAU2F.Description", err), MFAProviders: removeSelectedProviderFromList(providers, selectedProvider), SelectedProvider: selectedProvider, } diff --git a/internal/api/ui/login/mfa_verify_u2f_handler.go b/internal/api/ui/login/mfa_verify_u2f_handler.go index 7873468616..8541c043e4 100644 --- a/internal/api/ui/login/mfa_verify_u2f_handler.go +++ b/internal/api/ui/login/mfa_verify_u2f_handler.go @@ -24,22 +24,19 @@ type mfaU2FFormData struct { } func (l *Login) renderU2FVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, providers []domain.MFAType, err error) { - var errID, errMessage, credentialData string + var credentialData string var webAuthNLogin *domain.WebAuthNLogin if err == nil { userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) webAuthNLogin, err = l.authRepo.BeginMFAU2FLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, userAgentID) } - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if webAuthNLogin != nil { credentialData = base64.RawURLEncoding.EncodeToString(webAuthNLogin.CredentialAssertionData) } translator := l.getTranslator(r.Context(), authReq) data := &mfaU2FData{ webAuthNData: webAuthNData{ - userData: l.getUserData(r, authReq, translator, "VerifyMFAU2F.Title", "VerifyMFAU2F.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "VerifyMFAU2F.Title", "VerifyMFAU2F.Description", err), CredentialCreationData: credentialData, }, MFAProviders: providers, diff --git a/internal/api/ui/login/password_handler.go b/internal/api/ui/login/password_handler.go index 026963bbde..a6e9199ff7 100644 --- a/internal/api/ui/login/password_handler.go +++ b/internal/api/ui/login/password_handler.go @@ -15,12 +15,8 @@ type passwordFormData struct { } func (l *Login) renderPassword(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "Password.Title", "Password.Description", errID, errMessage) + data := l.getUserData(r, authReq, translator, "Password.Title", "Password.Description", err) funcs := map[string]interface{}{ "showPasswordReset": func() bool { if authReq.LoginPolicy != nil { diff --git a/internal/api/ui/login/password_reset_handler.go b/internal/api/ui/login/password_reset_handler.go index f4f98806c7..5bdee7904c 100644 --- a/internal/api/ui/login/password_reset_handler.go +++ b/internal/api/ui/login/password_reset_handler.go @@ -30,11 +30,7 @@ func (l *Login) handlePasswordReset(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderPasswordResetDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "PasswordResetDone.Title", "PasswordResetDone.Description", errID, errMessage) + data := l.getUserData(r, authReq, translator, "PasswordResetDone.Title", "PasswordResetDone.Description", err) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplPasswordResetDone], data, nil) } diff --git a/internal/api/ui/login/passwordless_login_handler.go b/internal/api/ui/login/passwordless_login_handler.go index 52b9d06fed..d64ad2c3c1 100644 --- a/internal/api/ui/login/passwordless_login_handler.go +++ b/internal/api/ui/login/passwordless_login_handler.go @@ -2,6 +2,7 @@ package login import ( "encoding/base64" + "errors" "net/http" "github.com/zitadel/zitadel/internal/domain" @@ -22,13 +23,15 @@ type passwordlessFormData struct { } func (l *Login) renderPasswordlessVerification(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, passwordSet bool, err error) { - var errID, errMessage, credentialData string + var credentialData string var webAuthNLogin *domain.WebAuthNLogin - if err == nil { - webAuthNLogin, err = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID) - } - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) + if err == nil || errors.Is(err, &IdPError{}) { // make sure we still proceed with the webauthn login even if the idp login failed + var creationErr error + webAuthNLogin, creationErr = l.authRepo.BeginPasswordlessLogin(setContext(r.Context(), authReq.UserOrgID), authReq.UserID, authReq.UserOrgID, authReq.ID, authReq.AgentID) + // and only overwrite the error if the webauthn creation failed + if creationErr != nil { + err = creationErr + } } if webAuthNLogin != nil { credentialData = base64.RawURLEncoding.EncodeToString(webAuthNLogin.CredentialAssertionData) @@ -39,7 +42,7 @@ func (l *Login) renderPasswordlessVerification(w http.ResponseWriter, r *http.Re translator := l.getTranslator(r.Context(), authReq) data := &passwordlessData{ webAuthNData{ - userData: l.getUserData(r, authReq, translator, "Passwordless.Title", "Passwordless.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "Passwordless.Title", "Passwordless.Description", err), CredentialCreationData: credentialData, }, passwordSet, diff --git a/internal/api/ui/login/passwordless_prompt_handler.go b/internal/api/ui/login/passwordless_prompt_handler.go index ee70b76126..36a3ede71e 100644 --- a/internal/api/ui/login/passwordless_prompt_handler.go +++ b/internal/api/ui/login/passwordless_prompt_handler.go @@ -27,13 +27,9 @@ func (l *Login) handlePasswordlessPrompt(w http.ResponseWriter, r *http.Request) } func (l *Login) renderPasswordlessPrompt(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) data := &passwordlessPromptData{ - userData: l.getUserData(r, authReq, translator, "PasswordlessPrompt.Title", "PasswordlessPrompt.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "PasswordlessPrompt.Title", "PasswordlessPrompt.Description", err), } l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplPasswordlessPrompt], data, nil) } diff --git a/internal/api/ui/login/passwordless_registration_handler.go b/internal/api/ui/login/passwordless_registration_handler.go index 976a9277b2..782d62f1fe 100644 --- a/internal/api/ui/login/passwordless_registration_handler.go +++ b/internal/api/ui/login/passwordless_registration_handler.go @@ -78,7 +78,7 @@ func (l *Login) handlePasswordlessRegistration(w http.ResponseWriter, r *http.Re } func (l *Login) renderPasswordlessRegistration(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, userID, orgID, codeID, code string, requestedPlatformType authPlatform, err error) { - var errID, errMessage, credentialData string + var credentialData string var disabled bool if authReq != nil { userID = authReq.UserID @@ -93,7 +93,6 @@ func (l *Login) renderPasswordlessRegistration(w http.ResponseWriter, r *http.Re } } if err != nil { - errID, errMessage = l.getErrorMessage(r, err) disabled = true } if webAuthNToken != nil { @@ -102,7 +101,7 @@ func (l *Login) renderPasswordlessRegistration(w http.ResponseWriter, r *http.Re translator := l.getTranslator(r.Context(), authReq) data := &passwordlessRegistrationData{ webAuthNData{ - userData: l.getUserData(r, authReq, translator, "PasswordlessRegistration.Title", "PasswordlessRegistration.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "PasswordlessRegistration.Title", "PasswordlessRegistration.Description", err), CredentialCreationData: credentialData, }, code, @@ -185,13 +184,9 @@ func (l *Login) checkPasswordlessRegistration(w http.ResponseWriter, r *http.Req } func (l *Login) renderPasswordlessRegistrationDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, orgID string, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) data := passwordlessRegistrationDoneDate{ - userData: l.getUserData(r, authReq, translator, "PasswordlessRegistrationDone.Title", "PasswordlessRegistrationDone.Description", errID, errMessage), + userData: l.getUserData(r, authReq, translator, "PasswordlessRegistrationDone.Title", "PasswordlessRegistrationDone.Description", err), HideNextButton: authReq == nil, } if authReq == nil { diff --git a/internal/api/ui/login/register_handler.go b/internal/api/ui/login/register_handler.go index 89e0eec7b3..bd5629c432 100644 --- a/internal/api/ui/login/register_handler.go +++ b/internal/api/ui/login/register_handler.go @@ -142,10 +142,6 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderRegister(w http.ResponseWriter, r *http.Request, authRequest *domain.AuthRequest, formData *registerFormData, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authRequest) if formData == nil { formData = new(registerFormData) @@ -156,7 +152,7 @@ func (l *Login) renderRegister(w http.ResponseWriter, r *http.Request, authReque resourceOwner := determineResourceOwner(r.Context(), authRequest) data := registerData{ - baseData: l.getBaseData(r, authRequest, translator, "RegistrationUser.Title", "RegistrationUser.Description", errID, errMessage), + baseData: l.getBaseData(r, authRequest, translator, "RegistrationUser.Title", "RegistrationUser.Description", err), registerFormData: *formData, } diff --git a/internal/api/ui/login/register_option_handler.go b/internal/api/ui/login/register_option_handler.go index 7d88f76c6c..31270c0442 100644 --- a/internal/api/ui/login/register_option_handler.go +++ b/internal/api/ui/login/register_option_handler.go @@ -33,10 +33,6 @@ func (l *Login) handleRegisterOption(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderRegisterOption(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } allowed := registrationAllowed(authReq) externalAllowed := externalRegistrationAllowed(authReq) if err == nil { @@ -54,7 +50,7 @@ func (l *Login) renderRegisterOption(w http.ResponseWriter, r *http.Request, aut } translator := l.getTranslator(r.Context(), authReq) data := registerOptionData{ - baseData: l.getBaseData(r, authReq, translator, "RegisterOption.Title", "RegisterOption.Description", errID, errMessage), + baseData: l.getBaseData(r, authReq, translator, "RegisterOption.Title", "RegisterOption.Description", err), } funcs := map[string]interface{}{ "hasRegistration": func() bool { diff --git a/internal/api/ui/login/register_org_handler.go b/internal/api/ui/login/register_org_handler.go index acb032d8f1..58a49f1d08 100644 --- a/internal/api/ui/login/register_org_handler.go +++ b/internal/api/ui/login/register_org_handler.go @@ -97,16 +97,12 @@ func (l *Login) handleRegisterOrgCheck(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderRegisterOrg(w http.ResponseWriter, r *http.Request, authRequest *domain.AuthRequest, formData *registerOrgFormData, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } if formData == nil { formData = new(registerOrgFormData) } translator := l.getTranslator(r.Context(), authRequest) data := registerOrgData{ - baseData: l.getBaseData(r, authRequest, translator, "RegistrationOrg.Title", "RegistrationOrg.Description", errID, errMessage), + baseData: l.getBaseData(r, authRequest, translator, "RegistrationOrg.Title", "RegistrationOrg.Description", err), registerOrgFormData: *formData, } pwPolicy := l.getPasswordComplexityPolicy(r, "0") diff --git a/internal/api/ui/login/renderer.go b/internal/api/ui/login/renderer.go index cb05f78323..79fc2dcf0d 100644 --- a/internal/api/ui/login/renderer.go +++ b/internal/api/ui/login/renderer.go @@ -341,7 +341,6 @@ func (l *Login) chooseNextStep(w http.ResponseWriter, r *http.Request, authReq * } func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var msg string if err != nil { log := logging.WithError(err) if authReq != nil { @@ -352,17 +351,15 @@ func (l *Login) renderInternalError(w http.ResponseWriter, r *http.Request, auth } else { log.Info() } - - _, msg = l.getErrorMessage(r, err) } translator := l.getTranslator(r.Context(), authReq) - data := l.getBaseData(r, authReq, translator, "Errors.Internal", "", "Internal", msg) + data := l.getBaseData(r, authReq, translator, "Errors.Internal", "", err) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplError], data, nil) } -func (l *Login) getUserData(r *http.Request, authReq *domain.AuthRequest, translator *i18n.Translator, titleI18nKey string, descriptionI18nKey string, errType, errMessage string) userData { +func (l *Login) getUserData(r *http.Request, authReq *domain.AuthRequest, translator *i18n.Translator, titleI18nKey string, descriptionI18nKey string, err error) userData { userData := userData{ - baseData: l.getBaseData(r, authReq, translator, titleI18nKey, descriptionI18nKey, errType, errMessage), + baseData: l.getBaseData(r, authReq, translator, titleI18nKey, descriptionI18nKey, err), profileData: l.getProfileData(authReq), } if authReq != nil && authReq.LinkingUsers != nil { @@ -371,7 +368,7 @@ func (l *Login) getUserData(r *http.Request, authReq *domain.AuthRequest, transl return userData } -func (l *Login) getBaseData(r *http.Request, authReq *domain.AuthRequest, translator *i18n.Translator, titleI18nKey string, descriptionI18nKey string, errType, errMessage string) baseData { +func (l *Login) getBaseData(r *http.Request, authReq *domain.AuthRequest, translator *i18n.Translator, titleI18nKey string, descriptionI18nKey string, err error) baseData { title := "" if titleI18nKey != "" { title = translator.LocalizeWithoutArgs(titleI18nKey) @@ -383,10 +380,16 @@ func (l *Login) getBaseData(r *http.Request, authReq *domain.AuthRequest, transl } lang, _ := l.renderer.ReqLang(translator, r).Base() + var errID, errMessage string + var errPopup bool + if err != nil { + errID, errMessage, errPopup = l.getErrorMessage(r, err) + } baseData := baseData{ errorData: errorData{ - ErrID: errType, + ErrID: errID, ErrMessage: errMessage, + ErrPopup: errPopup, }, Lang: lang.String(), Title: title, @@ -482,14 +485,17 @@ func (l *Login) setLinksOnBaseData(baseData baseData, privacyPolicy *domain.Priv return baseData } -func (l *Login) getErrorMessage(r *http.Request, err error) (errID, errMsg string) { +func (l *Login) getErrorMessage(r *http.Request, err error) (errID, errMsg string, popup bool) { + idpErr := new(IdPError) + if errors.Is(err, idpErr) { + popup = true + } caosErr := new(zerrors.ZitadelError) if errors.As(err, &caosErr) { - localized := l.renderer.LocalizeFromRequest(l.getTranslator(r.Context(), nil), r, caosErr.Message, nil) - return caosErr.ID, localized - + localized := l.renderer.LocalizeFromRequest(l.getTranslator(r.Context(), nil), r, caosErr.Message, map[string]interface{}{"Details": caosErr.Parent}) + return caosErr.ID, localized, popup } - return "", err.Error() + return "", err.Error(), popup } func (l *Login) getTheme(r *http.Request) string { @@ -662,6 +668,7 @@ type baseData struct { type errorData struct { ErrID string ErrMessage string + ErrPopup bool } type userData struct { diff --git a/internal/api/ui/login/select_user_handler.go b/internal/api/ui/login/select_user_handler.go index 98c3993376..b15366baa1 100644 --- a/internal/api/ui/login/select_user_handler.go +++ b/internal/api/ui/login/select_user_handler.go @@ -27,7 +27,7 @@ func (l *Login) renderUserSelection(w http.ResponseWriter, r *http.Request, auth descriptionI18nKey = "SelectAccount.DescriptionLinking" } data := userSelectionData{ - baseData: l.getBaseData(r, authReq, translator, titleI18nKey, descriptionI18nKey, "", ""), + baseData: l.getBaseData(r, authReq, translator, titleI18nKey, descriptionI18nKey, nil), Users: selectionData.Users, Linking: linking, } diff --git a/internal/api/ui/login/static/i18n/bg.yaml b/internal/api/ui/login/static/i18n/bg.yaml index ad308b859d..be0b1d7f14 100644 --- a/internal/api/ui/login/static/i18n/bg.yaml +++ b/internal/api/ui/login/static/i18n/bg.yaml @@ -493,6 +493,10 @@ Errors: CreationNotAllowed: Създаването на нов потребител не е разрешено на този доставчик LinkingNotAllowed: Свързването на потребител не е разрешено на този доставчик NoOptionAllowed: Нито създаване, нито свързване е разрешено за този доставчик. Моля, свържете се с администратора. + LoginFailedSwitchLocal: | + Вход в външен доставчик на идентификация е неуспешен. Връщане към локален вход. + + Подробности за грешката: {{.Details}} GrantRequired: 'Влизането не е възможно. ' ProjectRequired: 'Влизането не е възможно. ' IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/cs.yaml b/internal/api/ui/login/static/i18n/cs.yaml index 032302e3b8..f362add6f3 100644 --- a/internal/api/ui/login/static/i18n/cs.yaml +++ b/internal/api/ui/login/static/i18n/cs.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: Vytvoření nového uživatele není na tomto poskytovateli povoleno LinkingNotAllowed: Propojení uživatele není na tomto poskytovateli povoleno NoOptionAllowed: Ani vytvoření, ani propojení není povoleno pro tohoto poskytovatele. Obraťte se na svého správce. + LoginFailedSwitchLocal: | + Přihlášení u externího poskytovatele identit selhalo. Vracíme se k místnímu přihlášení. + + Podrobnosti o chybě: {{.Details}} GrantRequired: Přihlášení není možné. Uživatel musí mít alespoň jeden oprávnění na aplikaci. Prosím, kontaktujte svého správce. ProjectRequired: Přihlášení není možné. Organizace uživatele musí být přidělena k projektu. Prosím, kontaktujte svého správce. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/de.yaml b/internal/api/ui/login/static/i18n/de.yaml index 28f4d00a88..4e6782fcb8 100644 --- a/internal/api/ui/login/static/i18n/de.yaml +++ b/internal/api/ui/login/static/i18n/de.yaml @@ -504,6 +504,10 @@ Errors: CreationNotAllowed: Erstellen eines neuen Benutzers mit diesem Provider ist nicht erlaubt LinkingNotAllowed: Verknüpfen eines Benutzers mit diesem Provider ist nicht erlaubt NoOptionAllowed: Weder Erstellung noch Verknüpfung ist für diesen Provider erlaubt. Bitte wenden Sie sich an Ihren Administrator. + LoginFailedSwitchLocal: | + Anmeldung beim externen Identitätsanbieter fehlgeschlagen. Zurück zur lokalen Anmeldung. + + Fehlerdetails: {{.Details}} GrantRequired: Die Anmeldung an diese Applikation ist nicht möglich. Der Benutzer benötigt mindestens eine Berechtigung an der Applikation. Bitte wende dich an deinen Administrator. ProjectRequired: Die Anmeldung an dieser Applikation ist nicht möglich. Die Organisation des Benutzer benötigt Berechtigung auf das Projekt. Bitte wende dich an deinen Administrator. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/en.yaml b/internal/api/ui/login/static/i18n/en.yaml index 6c58b11257..bdf42ae57f 100644 --- a/internal/api/ui/login/static/i18n/en.yaml +++ b/internal/api/ui/login/static/i18n/en.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: Creation of a new user is not allowed on this provider LinkingNotAllowed: Linking of a user is not allowed on this provider NoOptionAllowed: Neither creation of linking is allowed on this provider. Please contact your administrator. + LoginFailedSwitchLocal: | + Login at External IDP failed. Falling back to local login. + + Error details: {{.Details}} GrantRequired: Login not possible. The user is required to have at least one grant on the application. Please contact your administrator. ProjectRequired: Login not possible. The organization of the user must be granted to the project. Please contact your administrator. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/es.yaml b/internal/api/ui/login/static/i18n/es.yaml index de57fdcd85..c6aaac6bf0 100644 --- a/internal/api/ui/login/static/i18n/es.yaml +++ b/internal/api/ui/login/static/i18n/es.yaml @@ -488,6 +488,10 @@ Errors: CreationNotAllowed: La creación de un nuevo usuario no está permitida para este proveedor LinkingNotAllowed: La vinculación de un usuario no está permitida para este proveedor NoOptionAllowed: Ni la creación ni la vinculación están permitidas en este proveedor. Póngase en contacto con su administrador. + LoginFailedSwitchLocal: | + Error al iniciar sesión en el proveedor de identidad externo. Volviendo al inicio de sesión local. + + Detalles del error: {{.Details}} GrantRequired: El inicio de sesión no es posible. Se requiere que el usuario tenga al menos una concesión sobre la aplicación. Por favor contacta con tu administrador. ProjectRequired: El inicio de sesión no es posible. La organización del usuario debe tener el acceso concedido para el proyecto. Por favor contacta con tu administrador. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/fr.yaml b/internal/api/ui/login/static/i18n/fr.yaml index 8534085ae9..83dd64d147 100644 --- a/internal/api/ui/login/static/i18n/fr.yaml +++ b/internal/api/ui/login/static/i18n/fr.yaml @@ -506,6 +506,10 @@ Errors: CreationNotAllowed: La création d'un nouvel utilisateur n'est pas autorisée sur ce fournisseur. LinkingNotAllowed: La création d'un lien vers un utilisateur n'est pas autorisée pour ce fournisseur. NoOptionAllowed: Ni la création ni la liaison sont autorisées pour ce fournisseur. Veuillez contacter votre administrateur. + LoginFailedSwitchLocal: | + Échec de la connexion au fournisseur d'identité externe. Retour à la connexion locale. + + Détails de l'erreur: {{.Details}} GrantRequired: Connexion impossible. L'utilisateur doit avoir au moins une subvention sur l'application. Veuillez contacter votre administrateur. ProjectRequired: Connexion impossible. L'organisation de l'utilisateur doit être accordée au projet. Veuillez contacter votre administrateur. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/hu.yaml b/internal/api/ui/login/static/i18n/hu.yaml index 80ed98945c..ef2a2acab4 100644 --- a/internal/api/ui/login/static/i18n/hu.yaml +++ b/internal/api/ui/login/static/i18n/hu.yaml @@ -465,6 +465,10 @@ Errors: CreationNotAllowed: Új felhasználó létrehozása nem engedélyezett ezen a szolgáltatón LinkingNotAllowed: A felhasználó összekapcsolása nem engedélyezett ezen a szolgáltatón NoOptionAllowed: Sem új felhasználó létrehozása, sem összekapcsolás nem engedélyezett ezen a szolgáltatón. Kérjük, lépj kapcsolatba az adminisztrátoroddal. + LoginFailedSwitchLocal: | + Az egyéni azonosító szolgáltatóhoz való bejelentkezés sikertelen volt. Visszatérés a helyi bejelentkezéshez. + + Hiba részletei: {{.Details}} GrantRequired: Bejelentkezés nem lehetséges. A felhasználónak legalább egy jogosultsággal kell rendelkeznie az alkalmazáson. Kérlek, lépj kapcsolatba az adminisztrátoroddal. ProjectRequired: Bejelentkezés nem lehetséges. A felhasználó szervezetének engedélyezve kell lennie a projektre. Kérlek, lépj kapcsolatba az adminisztrátoroddal. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/id.yaml b/internal/api/ui/login/static/i18n/id.yaml index 63deb41229..7fdd1bee1a 100644 --- a/internal/api/ui/login/static/i18n/id.yaml +++ b/internal/api/ui/login/static/i18n/id.yaml @@ -464,6 +464,10 @@ Errors: CreationNotAllowed: Pembuatan pengguna baru tidak diperbolehkan pada penyedia ini LinkingNotAllowed: Menautkan pengguna tidak diperbolehkan di penyedia ini NoOptionAllowed: 'Pembuatan tautan tidak diperbolehkan pada penyedia ini. ' + LoginFailedSwitchLocal: | + Gagal masuk ke Penyedia ID Eksternal. Kembali ke login lokal. + + Detail kesalahan: {{.Details}} GrantRequired: 'Masuk tidak dapat dilakukan. ' ProjectRequired: 'Masuk tidak dapat dilakukan. ' IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/it.yaml b/internal/api/ui/login/static/i18n/it.yaml index 46e74d3b13..ca681b6e82 100644 --- a/internal/api/ui/login/static/i18n/it.yaml +++ b/internal/api/ui/login/static/i18n/it.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: La creazione di un nuovo utente non è consentita su questo provider. LinkingNotAllowed: Il collegamento di un utente non è consentito su questo provider. NoOptionAllowed: Né la creazione né il collegamento sono consentiti per questo provider. Contattare l'amministratore. + LoginFailedSwitchLocal: | + Accesso al provider di identità esterno non riuscito. Ritorno all'accesso locale. + + Dettagli dell'errore: {{.Details}} GrantRequired: Accesso non possibile. L'utente deve avere almeno una sovvenzione sull'applicazione. Contatta il tuo amministratore. ProjectRequired: Accesso non possibile. L'organizzazione dell'utente deve essere concessa al progetto. Contatta il tuo amministratore. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/ja.yaml b/internal/api/ui/login/static/i18n/ja.yaml index 9ec99eb912..8d725785c6 100644 --- a/internal/api/ui/login/static/i18n/ja.yaml +++ b/internal/api/ui/login/static/i18n/ja.yaml @@ -469,6 +469,10 @@ Errors: CreationNotAllowed: このプロバイダーでは、新しいユーザーの作成は許可されていません LinkingNotAllowed: このプロバイダーでは、ユーザーのリンクが許可されていません NoOptionAllowed: このプロバイダーでは作成もリンクも許可されていません。 管理者にお問い合わせください。 + LoginFailedSwitchLocal: | + 外部IDプロバイダーへのログインに失敗しました。ローカルログインに戻ります。 + + エラーの詳細: {{.Details}} GrantRequired: ログインできません。このユーザーは、アプリケーションに少なくとも1つの権限を付与されていることが必要です。管理者にお問い合わせください。 ProjectRequired: ログインできません。ユーザーの組織がプロジェクトに権限を付与されている必要があります。管理者にお問い合わせください。 IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/ko.yaml b/internal/api/ui/login/static/i18n/ko.yaml index e62cfcb8b5..bbe7a403a0 100644 --- a/internal/api/ui/login/static/i18n/ko.yaml +++ b/internal/api/ui/login/static/i18n/ko.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: 이 제공자에서는 새 사용자 생성을 허용하지 않습니다 LinkingNotAllowed: 이 제공자에서는 사용자를 연결할 수 없습니다 NoOptionAllowed: 이 제공자에서는 생성과 연결이 모두 허용되지 않습니다. 관리자에게 문의하세요. + LoginFailedSwitchLocal: | + 외부 IDP에서 로그인에 실패했습니다. 로컬 로그인으로 돌아갑니다. + + 오류 세부 정보: {{.Details}} GrantRequired: 로그인 불가. 사용자는 애플리케이션에서 최소한 하나의 권한이 필요합니다. 관리자에게 문의하세요. ProjectRequired: 로그인 불가. 사용자의 조직이 프로젝트에 허가되어야 합니다. 관리자에게 문의하세요. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/mk.yaml b/internal/api/ui/login/static/i18n/mk.yaml index dbb988a0a6..2465c935b2 100644 --- a/internal/api/ui/login/static/i18n/mk.yaml +++ b/internal/api/ui/login/static/i18n/mk.yaml @@ -506,6 +506,10 @@ Errors: CreationNotAllowed: Креирањето на нов корисник не е дозволено на овој провајдер LinkingNotAllowed: Поврзувањето на корисник не е дозволено на овој провајдер NoOptionAllowed: NНиту создавање, ниту поврзување е дозволено за овој провајдер. Ве молиме контактирајте го вашиот администратор. + LoginFailedSwitchLocal: | + Најавата во надворешен провајдер на идентитет не успеа. Враќање на локална најава. + + Детали за грешката: {{.Details}} GrantRequired: Не е можно најавување. Корисникот мора да има барем едно овластување за апликацијата. Ве молиме контактирајте го вашиот администратор. ProjectRequired: Не е можно најавување. Организацијата на корисникот мора да биде доделена на проектот. Ве молиме контактирајте го вашиот администратор. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/nl.yaml b/internal/api/ui/login/static/i18n/nl.yaml index 3bbcee94b6..2bc1137154 100644 --- a/internal/api/ui/login/static/i18n/nl.yaml +++ b/internal/api/ui/login/static/i18n/nl.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: Creatie van een nieuwe gebruiker is niet toegestaan op deze Provider LinkingNotAllowed: Koppeling van een gebruiker is niet toegestaan op deze Provider NoOptionAllowed: Noch aanmaak noch koppeling is toegestaan voor deze provider. Neem contact op met uw beheerder. + LoginFailedSwitchLocal: | + Aanmelding bij externe identiteitsprovider is mislukt. Terug naar lokale aanmelding. + + Foutdetails: {{.Details}} GrantRequired: Inloggen niet mogelijk. De gebruiker moet minimaal één grant hebben op de applicatie. Neem contact op met uw beheerder. ProjectRequired: Inloggen niet mogelijk. De organisatie van de gebruiker moet toegekend zijn aan het project. Neem contact op met uw beheerder. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/pl.yaml b/internal/api/ui/login/static/i18n/pl.yaml index 2c8b4fddf0..912af49a74 100644 --- a/internal/api/ui/login/static/i18n/pl.yaml +++ b/internal/api/ui/login/static/i18n/pl.yaml @@ -506,6 +506,10 @@ Errors: CreationNotAllowed: Tworzenie nowego użytkownika nie jest dozwolone w tym Providencie LinkingNotAllowed: Linkowanie użytkownika nie jest dozwolone na tym Providencie NoOptionAllowed: Ani tworzenie, ani łączenie nie jest dozwolone dla tego dostawcy. Skontaktuj się z administratorem. + LoginFailedSwitchLocal: | + Logowanie w zewnętrznym dostawcy tożsamości nie powiodło się. Powrót do logowania lokalnego. + + Szczegóły błędu: {{.Details}} GrantRequired: Logowanie nie jest możliwe. Użytkownik musi posiadać przynajmniej jedno uprawnienie w aplikacji. Skontaktuj się z administratorem. ProjectRequired: Logowanie nie jest możliwe. Organizacja użytkownika musi zostać udzielona projektowi. Skontaktuj się z administratorem. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/pt.yaml b/internal/api/ui/login/static/i18n/pt.yaml index f03f120ed8..5f18157e67 100644 --- a/internal/api/ui/login/static/i18n/pt.yaml +++ b/internal/api/ui/login/static/i18n/pt.yaml @@ -502,6 +502,10 @@ Errors: CreationNotAllowed: A criação de um novo usuário não é permitida neste provedor LinkingNotAllowed: A vinculação de um usuário não é permitida neste provedor NoOptionAllowed: Nem criação nem vinculação são permitidas neste fornecedor. Contate o seu administrador. + LoginFailedSwitchLocal: | + Falha no login no provedor de identidade externo. Retornando ao login local. + + Detalhes do erro: {{.Details}} GrantRequired: Login não é possível. O usuário precisa ter pelo menos uma permissão no aplicativo. Entre em contato com o administrador. ProjectRequired: Login não é possível. A organização do usuário precisa ser concedida ao projeto. Entre em contato com o administrador. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/ru.yaml b/internal/api/ui/login/static/i18n/ru.yaml index 221c20a2e9..8afd3a31b6 100644 --- a/internal/api/ui/login/static/i18n/ru.yaml +++ b/internal/api/ui/login/static/i18n/ru.yaml @@ -506,6 +506,10 @@ Errors: CreationNotAllowed: Создание нового пользователя для этого провайдера запрещено LinkingNotAllowed: Привязка к этому провайдеру запрещена NoOptionAllowed: Ни создание, ни привязка пользователя к этому провайдеру невозможны. Обратитесь к администратору. + LoginFailedSwitchLocal: | + Вход в внешний поставщик идентификации не удался. Возвращаемся к локальному входу. + + Подробности об ошибке: {{.Details}} GrantRequired: Вход невозможен. Пользователь должен иметь хотя бы один допуск к приложению. Обратитесь к администратору. ProjectRequired: Вход невозможен. Организация пользователя должна иметь доступ к проекту. Обратитесь к администратору. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/sv.yaml b/internal/api/ui/login/static/i18n/sv.yaml index 26fee23551..e6c1245503 100644 --- a/internal/api/ui/login/static/i18n/sv.yaml +++ b/internal/api/ui/login/static/i18n/sv.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: Det är inte tillåtet att skapa nya konton från den här externa leverantören LinkingNotAllowed: Det är inte tillåtet att koppla ihop konton från den här externa leverantören NoOptionAllowed: Varken skapande eller länkande är tillåtet för denna leverantör. Kontakta administratören. + LoginFailedSwitchLocal: | + Inloggning vid extern identitetsprovider misslyckades. Återgår till lokal inloggning. + + Felaktighetsdetaljer: {{.Details}} GrantRequired: Det går inte att logga in just nu. Användarkontot har inte tillgång till någonting i tjänsten. Ta kontakt med systemansvarig. ProjectRequired: Det går inte att logga in just nu. Användarkontots organisation har inte tillgång till tjänsten. Ta kontakt med systemansvarig. IdentityProvider: diff --git a/internal/api/ui/login/static/i18n/zh.yaml b/internal/api/ui/login/static/i18n/zh.yaml index 79db3c020e..4fcb469831 100644 --- a/internal/api/ui/login/static/i18n/zh.yaml +++ b/internal/api/ui/login/static/i18n/zh.yaml @@ -505,6 +505,10 @@ Errors: CreationNotAllowed: 不允许在该供应商上创建新用户 LinkingNotAllowed: 在此提供者上不允许链接一个用户 NoOptionAllowed: 此提供商不允许创建或链接。请联系您的管理员。 + LoginFailedSwitchLocal: | + 外部身份提供商的登录失败。返回到本地登录。 + + 错误详情: {{.Details}} GrantRequired: 无法登录,用户需要在应用程序上拥有至少一项授权,请联系您的管理员。 ProjectRequired: 无法登录,用户的组织必须授予项目,请联系您的管理员。 IdentityProvider: diff --git a/internal/api/ui/login/static/resources/scripts/error_popup.js b/internal/api/ui/login/static/resources/scripts/error_popup.js new file mode 100644 index 0000000000..e817f81ac3 --- /dev/null +++ b/internal/api/ui/login/static/resources/scripts/error_popup.js @@ -0,0 +1,30 @@ +function removeOverlay(overlay) { + if (overlay.classList.contains("show")) { + overlay.classList.remove("show"); + document.removeEventListener("mousemove", onMouseMove); + document.removeEventListener("click", onClick); + } +} + +function onMouseMove() { + const overlay = document.getElementById("dialog_overlay"); + if (overlay) { + removeOverlay(overlay); + } +} + +function onClick() { + const overlay = document.getElementById("dialog_overlay"); + if (overlay) { + removeOverlay(overlay); + } +} + +window.addEventListener('DOMContentLoaded', () => { + const overlay = document.getElementById("dialog_overlay"); + if (overlay && overlay.classList.contains("show")) { + setTimeout(() => removeOverlay(overlay), 5000); + document.addEventListener("mousemove", onMouseMove); + document.addEventListener("click", onClick); + } +}); diff --git a/internal/api/ui/login/static/resources/themes/scss/main.scss b/internal/api/ui/login/static/resources/themes/scss/main.scss index 3b0ddc0da2..e9a97df2ca 100644 --- a/internal/api/ui/login/static/resources/themes/scss/main.scss +++ b/internal/api/ui/login/static/resources/themes/scss/main.scss @@ -20,3 +20,26 @@ body { .text-align-center { text-align: center; } + +.dialog_overlay { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background-color: black; + z-index: 1001; + -moz-opacity: 0.8; + opacity: .80; + filter: alpha(opacity=80); +} + +.dialog_overlay.show { + display: block; +} + +.dialog_content { + position: relative; + z-index: 1002; +} \ No newline at end of file diff --git a/internal/api/ui/login/static/resources/themes/scss/styles/error/error.scss b/internal/api/ui/login/static/resources/themes/scss/styles/error/error.scss index 39982e94ba..05131ec071 100644 --- a/internal/api/ui/login/static/resources/themes/scss/styles/error/error.scss +++ b/internal/api/ui/login/static/resources/themes/scss/styles/error/error.scss @@ -6,6 +6,10 @@ margin-right: .5rem; font-size: 1.5rem; } + + .lgn-error-message { + white-space: pre-line; + } } #wa-error { diff --git a/internal/api/ui/login/static/templates/error-message.html b/internal/api/ui/login/static/templates/error-message.html index 6c56caad1c..3de5d81bf4 100644 --- a/internal/api/ui/login/static/templates/error-message.html +++ b/internal/api/ui/login/static/templates/error-message.html @@ -1,10 +1,9 @@ {{ define "error-message" }} {{if .ErrMessage }} -
+
-

- {{ .ErrMessage }} -

+

{{ .ErrMessage }}

+
{{end}} -{{ end }} \ No newline at end of file +{{end}} \ No newline at end of file diff --git a/internal/api/ui/login/static/templates/password.html b/internal/api/ui/login/static/templates/password.html index c036e3c51b..98d94f3ef8 100644 --- a/internal/api/ui/login/static/templates/password.html +++ b/internal/api/ui/login/static/templates/password.html @@ -41,4 +41,5 @@ + diff --git a/internal/api/ui/login/static/templates/passwordless.html b/internal/api/ui/login/static/templates/passwordless.html index 6a95b54079..2dc6d544a0 100644 --- a/internal/api/ui/login/static/templates/passwordless.html +++ b/internal/api/ui/login/static/templates/passwordless.html @@ -40,5 +40,6 @@ + {{template "main-bottom" .}} diff --git a/internal/api/ui/login/username_change_handler.go b/internal/api/ui/login/username_change_handler.go index f11fe43a72..b932079dd0 100644 --- a/internal/api/ui/login/username_change_handler.go +++ b/internal/api/ui/login/username_change_handler.go @@ -16,12 +16,8 @@ type changeUsernameData struct { } func (l *Login) renderChangeUsername(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest, err error) { - var errID, errMessage string - if err != nil { - errID, errMessage = l.getErrorMessage(r, err) - } translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "UsernameChange.Title", "UsernameChange.Description", errID, errMessage) + data := l.getUserData(r, authReq, translator, "UsernameChange.Title", "UsernameChange.Description", err) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplChangeUsername], data, nil) } @@ -41,8 +37,7 @@ func (l *Login) handleChangeUsername(w http.ResponseWriter, r *http.Request) { } func (l *Login) renderChangeUsernameDone(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) { - var errType, errMessage string translator := l.getTranslator(r.Context(), authReq) - data := l.getUserData(r, authReq, translator, "UsernameChangeDone.Title", "UsernameChangeDone.Description", errType, errMessage) + data := l.getUserData(r, authReq, translator, "UsernameChangeDone.Title", "UsernameChangeDone.Description", nil) l.renderer.RenderTemplate(w, r, translator, l.renderer.Templates[tmplChangeUsernameDone], data, nil) } diff --git a/internal/auth/repository/auth_request.go b/internal/auth/repository/auth_request.go index d89eb35a8b..c16a757a01 100644 --- a/internal/auth/repository/auth_request.go +++ b/internal/auth/repository/auth_request.go @@ -41,4 +41,5 @@ type AuthRequestRepository interface { AutoRegisterExternalUser(ctx context.Context, user *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) error ResetLinkingUsers(ctx context.Context, authReqID, userAgentID string) error ResetSelectedIDP(ctx context.Context, authReqID, userAgentID string) error + RequestLocalAuth(ctx context.Context, authReqID, userAgentID string) error } diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 813c5668f4..60486b66f9 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -563,6 +563,15 @@ func (repo *AuthRequestRepo) ResetSelectedIDP(ctx context.Context, authReqID, us return repo.AuthRequests.UpdateAuthRequest(ctx, request) } +func (repo *AuthRequestRepo) RequestLocalAuth(ctx context.Context, authReqID, userAgentID string) error { + request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) + if err != nil { + return err + } + request.RequestLocalAuth = true + return repo.AuthRequests.UpdateAuthRequest(ctx, request) +} + func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *domain.Human, externalIDP *domain.UserIDPLink, orgMemberRoles []string, authReqID, userAgentID, resourceOwner string, metadatas []*domain.Metadata, info *domain.BrowserInfo) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -1059,7 +1068,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth request.PreferredLanguage = gu.Ptr(language.Make(user.HumanView.PreferredLanguage)) } - isInternalLogin := request.SelectedIDPConfigID == "" && userSession.SelectedIDPConfigID == "" + isInternalLogin := (request.SelectedIDPConfigID == "" && userSession.SelectedIDPConfigID == "") || request.RequestLocalAuth idps, err := checkExternalIDPsOfUser(ctx, repo.IDPUserLinksProvider, user.ID) if err != nil { return nil, err @@ -1067,7 +1076,9 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth noLocalAuth := request.LoginPolicy != nil && !request.LoginPolicy.AllowUsernamePassword allowedLinkedIDPs := checkForAllowedIDPs(request.AllowedExternalIDPs, idps.Links) - if (!isInternalLogin || len(allowedLinkedIDPs) > 0 || noLocalAuth) && len(request.LinkingUsers) == 0 { + if (!isInternalLogin || len(allowedLinkedIDPs) > 0 || noLocalAuth) && + len(request.LinkingUsers) == 0 && + !request.RequestLocalAuth { step, err := repo.idpChecked(request, allowedLinkedIDPs, userSession) if err != nil { return nil, err diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 976ae8d8a9..7d71ddecd9 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -2263,6 +2263,86 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { []domain.NextStep{&domain.LinkUsersStep{}}, nil, }, + { + "local auth requested (passwordless and password set up), passwordless step", + fields{ + userSessionViewProvider: &mockViewUserSession{}, + userViewProvider: &mockViewUser{ + PasswordSet: true, + IsEmailVerified: true, + MFAMaxSetUp: int32(domain.MFALevelMultiFactor), + PasswordlessTokens: user_view_model.WebAuthNTokens{&user_view_model.WebAuthNView{ID: "id", State: int32(user_model.MFAStateReady)}}, + }, + userEventProvider: &mockEventUser{}, + orgViewProvider: &mockViewOrg{State: domain.OrgStateActive}, + lockoutPolicyProvider: &mockLockoutPolicy{ + policy: &query.LockoutPolicy{ + ShowFailures: true, + }, + }, + idpUserLinksProvider: &mockIDPUserLinks{ + idps: []*query.IDPUserLink{{IDPID: "IDPConfigID"}}, + }, + }, + args{ + &domain.AuthRequest{ + UserID: "UserID", + SelectedIDPConfigID: "IDPConfigID", + LoginPolicy: &domain.LoginPolicy{ + PasswordlessType: domain.PasswordlessTypeAllowed, + }, + AllowedExternalIDPs: []*domain.IDPProvider{ + { + IDPConfigID: "IDPConfigID", + }, + }, + RequestLocalAuth: true, + }, false}, + []domain.NextStep{ + &domain.PasswordlessStep{ + PasswordSet: true, + }, + }, + nil, + }, + { + "local auth requested (password set up), password step", + fields{ + userSessionViewProvider: &mockViewUserSession{}, + userViewProvider: &mockViewUser{ + PasswordSet: true, + IsEmailVerified: true, + }, + userEventProvider: &mockEventUser{}, + orgViewProvider: &mockViewOrg{State: domain.OrgStateActive}, + lockoutPolicyProvider: &mockLockoutPolicy{ + policy: &query.LockoutPolicy{ + ShowFailures: true, + }, + }, + idpUserLinksProvider: &mockIDPUserLinks{ + idps: []*query.IDPUserLink{{IDPID: "IDPConfigID"}}, + }, + }, + args{ + &domain.AuthRequest{ + UserID: "UserID", + SelectedIDPConfigID: "IDPConfigID", + LoginPolicy: &domain.LoginPolicy{ + PasswordlessType: domain.PasswordlessTypeAllowed, + }, + AllowedExternalIDPs: []*domain.IDPProvider{ + { + IDPConfigID: "IDPConfigID", + }, + }, + RequestLocalAuth: true, + }, false}, + []domain.NextStep{ + &domain.PasswordStep{}, + }, + nil, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 9dec3fcf00..b707631c22 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -159,7 +159,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - session, err := repo.Query.SessionByID(ctx, true, sessionID, token) + session, err := repo.Query.SessionByID(ctx, true, sessionID, token, nil) if err != nil { return "", "", "", err } diff --git a/internal/command/instance.go b/internal/command/instance.go index c5ac4d8472..99075ccfad 100644 --- a/internal/command/instance.go +++ b/internal/command/instance.go @@ -4,9 +4,9 @@ import ( "context" "time" + "github.com/zitadel/logging" "golang.org/x/text/language" - "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/command/preparation" @@ -116,14 +116,15 @@ type InstanceSetup struct { MaxOTPAttempts uint64 ShouldShowLockoutFailure bool } - EmailTemplate []byte - MessageTexts []*domain.CustomMessageText - SMTPConfiguration *SMTPConfiguration - OIDCSettings *OIDCSettings - Quotas *SetQuotas - Features *InstanceFeatures - Limits *SetLimits - Restrictions *SetRestrictions + EmailTemplate []byte + MessageTexts []*domain.CustomMessageText + SMTPConfiguration *SMTPConfiguration + OIDCSettings *OIDCSettings + Quotas *SetQuotas + Features *InstanceFeatures + Limits *SetLimits + Restrictions *SetRestrictions + RolePermissionMappings []authz.RoleMapping } type SMTPConfiguration struct { @@ -379,6 +380,7 @@ func setupInstanceElements(instanceAgg *instance.Aggregate, setup *InstanceSetup setup.LabelPolicy.ThemeMode, ), prepareAddDefaultEmailTemplate(instanceAgg, setup.EmailTemplate), + prepareAddRolePermissions(instanceAgg, setup.RolePermissionMappings), } } diff --git a/internal/command/instance_features.go b/internal/command/instance_features.go index 44f122e98f..1f714671bd 100644 --- a/internal/command/instance_features.go +++ b/internal/command/instance_features.go @@ -29,6 +29,7 @@ type InstanceFeatures struct { DisableUserTokenEvent *bool EnableBackChannelLogout *bool LoginV2 *feature.LoginV2 + PermissionCheckV2 *bool } func (m *InstanceFeatures) isEmpty() bool { @@ -45,7 +46,8 @@ func (m *InstanceFeatures) isEmpty() bool { m.OIDCSingleV1SessionTermination == nil && m.DisableUserTokenEvent == nil && m.EnableBackChannelLogout == nil && - m.LoginV2 == nil + m.LoginV2 == nil && + m.PermissionCheckV2 == nil } func (c *Commands) SetInstanceFeatures(ctx context.Context, f *InstanceFeatures) (*domain.ObjectDetails, error) { diff --git a/internal/command/instance_features_model.go b/internal/command/instance_features_model.go index 8fa52318db..aaa8b2e53a 100644 --- a/internal/command/instance_features_model.go +++ b/internal/command/instance_features_model.go @@ -79,6 +79,7 @@ func (m *InstanceFeaturesWriteModel) Query() *eventstore.SearchQueryBuilder { feature_v2.InstanceDisableUserTokenEvent, feature_v2.InstanceEnableBackChannelLogout, feature_v2.InstanceLoginVersion, + feature_v2.InstancePermissionCheckV2, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -129,6 +130,9 @@ func reduceInstanceFeature(features *InstanceFeatures, key feature.Key, value an features.EnableBackChannelLogout = &v case feature.KeyLoginV2: features.LoginV2 = value.(*feature.LoginV2) + case feature.KeyPermissionCheckV2: + v := value.(bool) + features.PermissionCheckV2 = &v } } @@ -148,5 +152,6 @@ func (wm *InstanceFeaturesWriteModel) setCommands(ctx context.Context, f *Instan cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.DisableUserTokenEvent, f.DisableUserTokenEvent, feature_v2.InstanceDisableUserTokenEvent) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.EnableBackChannelLogout, f.EnableBackChannelLogout, feature_v2.InstanceEnableBackChannelLogout) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.LoginV2, f.LoginV2, feature_v2.InstanceLoginVersion) + cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.PermissionCheckV2, f.PermissionCheckV2, feature_v2.InstancePermissionCheckV2) return cmds } diff --git a/internal/command/instance_permissions.go b/internal/command/instance_permissions.go new file mode 100644 index 0000000000..c46c8f7c4a --- /dev/null +++ b/internal/command/instance_permissions.go @@ -0,0 +1,29 @@ +package command + +import ( + "context" + "strings" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/command/preparation" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/permission" +) + +func prepareAddRolePermissions(a *instance.Aggregate, roles []authz.RoleMapping) preparation.Validation { + return func() (preparation.CreateCommands, error) { + return func(ctx context.Context, _ preparation.FilterToQueryReducer) (cmds []eventstore.Command, _ error) { + aggregate := permission.NewAggregate(a.InstanceID) + for _, r := range roles { + if strings.HasPrefix(r.Role, "SYSTEM") { + continue + } + for _, p := range r.Permissions { + cmds = append(cmds, permission.NewAddedEvent(ctx, aggregate, r.Role, p)) + } + } + return cmds, nil + }, nil + } +} diff --git a/internal/command/milestone.go b/internal/command/milestone.go index 11e6e5ab7f..e2f4fdc9de 100644 --- a/internal/command/milestone.go +++ b/internal/command/milestone.go @@ -4,6 +4,7 @@ import ( "context" "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/command/preparation" "github.com/zitadel/zitadel/internal/eventstore" diff --git a/internal/command/project_member_model.go b/internal/command/project_member_model.go index 4e78fb4f52..8e743e0a46 100644 --- a/internal/command/project_member_model.go +++ b/internal/command/project_member_model.go @@ -58,9 +58,9 @@ func (wm *ProjectMemberWriteModel) Query() *eventstore.SearchQueryBuilder { AddQuery(). AggregateTypes(project.AggregateType). AggregateIDs(wm.MemberWriteModel.AggregateID). - EventTypes(project.MemberAddedType, - project.MemberChangedType, - project.MemberRemovedType, - project.MemberCascadeRemovedType). + EventTypes(project.MemberAddedEventType, + project.MemberChangedEventType, + project.MemberRemovedEventType, + project.MemberCascadeRemovedEventType). Builder() } diff --git a/internal/command/system_features.go b/internal/command/system_features.go index eb10bba553..dc886de318 100644 --- a/internal/command/system_features.go +++ b/internal/command/system_features.go @@ -21,6 +21,7 @@ type SystemFeatures struct { DisableUserTokenEvent *bool EnableBackChannelLogout *bool LoginV2 *feature.LoginV2 + PermissionCheckV2 *bool } func (m *SystemFeatures) isEmpty() bool { @@ -35,7 +36,8 @@ func (m *SystemFeatures) isEmpty() bool { m.OIDCSingleV1SessionTermination == nil && m.DisableUserTokenEvent == nil && m.EnableBackChannelLogout == nil && - m.LoginV2 == nil + m.LoginV2 == nil && + m.PermissionCheckV2 == nil } func (c *Commands) SetSystemFeatures(ctx context.Context, f *SystemFeatures) (*domain.ObjectDetails, error) { diff --git a/internal/command/system_features_model.go b/internal/command/system_features_model.go index d656a6e266..15fc3e0bf0 100644 --- a/internal/command/system_features_model.go +++ b/internal/command/system_features_model.go @@ -70,6 +70,7 @@ func (m *SystemFeaturesWriteModel) Query() *eventstore.SearchQueryBuilder { feature_v2.SystemDisableUserTokenEvent, feature_v2.SystemEnableBackChannelLogout, feature_v2.SystemLoginVersion, + feature_v2.SystemPermissionCheckV2, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -113,6 +114,9 @@ func reduceSystemFeature(features *SystemFeatures, key feature.Key, value any) { features.EnableBackChannelLogout = &v case feature.KeyLoginV2: features.LoginV2 = value.(*feature.LoginV2) + case feature.KeyPermissionCheckV2: + v := value.(bool) + features.PermissionCheckV2 = &v } } @@ -130,6 +134,7 @@ func (wm *SystemFeaturesWriteModel) setCommands(ctx context.Context, f *SystemFe cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.DisableUserTokenEvent, f.DisableUserTokenEvent, feature_v2.SystemDisableUserTokenEvent) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.EnableBackChannelLogout, f.EnableBackChannelLogout, feature_v2.SystemEnableBackChannelLogout) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.LoginV2, f.LoginV2, feature_v2.SystemLoginVersion) + cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.PermissionCheckV2, f.PermissionCheckV2, feature_v2.SystemPermissionCheckV2) return cmds } diff --git a/internal/command/user_human.go b/internal/command/user_human.go index ab2617c276..9e6ba43629 100644 --- a/internal/command/user_human.go +++ b/internal/command/user_human.go @@ -59,6 +59,8 @@ type AddHuman struct { Passwordless bool ExternalIDP bool Register bool + // SetInactive whether the user initially should be set as inactive + SetInactive bool // UserAgentID is optional and can be passed in case the user registered themselves. // This will be used in the login UI to handle authentication automatically. UserAgentID string diff --git a/internal/command/user_model.go b/internal/command/user_model.go index 0e68f0812c..d600f9d98a 100644 --- a/internal/command/user_model.go +++ b/internal/command/user_model.go @@ -137,6 +137,10 @@ func isUserStateInactive(state domain.UserState) bool { return hasUserState(state, domain.UserStateInactive) } +func isUserStateActive(state domain.UserState) bool { + return hasUserState(state, domain.UserStateActive) +} + func isUserStateInitial(state domain.UserState) bool { return hasUserState(state, domain.UserStateInitial) } diff --git a/internal/command/user_v2_human.go b/internal/command/user_v2_human.go index a85f905e05..fa627ec66e 100644 --- a/internal/command/user_v2_human.go +++ b/internal/command/user_v2_human.go @@ -14,11 +14,14 @@ import ( ) type ChangeHuman struct { - ID string - Username *string - Profile *Profile - Email *Email - Phone *Phone + ID string + State *domain.UserState + Username *string + Profile *Profile + Email *Email + Phone *Phone + Metadata []*domain.Metadata + MetadataKeysToRemove []string Password *Password @@ -100,6 +103,15 @@ func (h *ChangeHuman) Changed() bool { if h.Password != nil { return true } + if h.State != nil { + return true + } + if len(h.Metadata) > 0 { + return true + } + if len(h.MetadataKeysToRemove) > 0 { + return true + } return false } @@ -229,6 +241,10 @@ func (c *Commands) AddUserHuman(ctx context.Context, resourceOwner string, human ) } + if human.SetInactive { + cmds = append(cmds, user.NewUserDeactivatedEvent(ctx, &existingHuman.Aggregate().Aggregate)) + } + if len(cmds) == 0 { human.Details = writeModelToObjectDetails(&existingHuman.WriteModel) return nil @@ -270,6 +286,7 @@ func (c *Commands) ChangeUserHuman(ctx context.Context, human *ChangeHuman, alg } } + userAgg := UserAggregateFromWriteModelCtx(ctx, &existingHuman.WriteModel) cmds := make([]eventstore.Command, 0) if human.Username != nil { cmds, err = c.changeUsername(ctx, cmds, existingHuman, *human.Username) @@ -302,6 +319,58 @@ func (c *Commands) ChangeUserHuman(ctx context.Context, human *ChangeHuman, alg } } + for _, md := range human.Metadata { + cmd, err := c.setUserMetadata(ctx, userAgg, md) + if err != nil { + return err + } + + cmds = append(cmds, cmd) + } + + for _, mdKey := range human.MetadataKeysToRemove { + cmd, err := c.removeUserMetadata(ctx, userAgg, mdKey) + if err != nil { + return err + } + + cmds = append(cmds, cmd) + } + + if human.State != nil { + // only allow toggling between active and inactive + // any other target state is not supported + // the existing human's state has to be the + switch { + case isUserStateActive(*human.State): + if isUserStateActive(existingHuman.UserState) { + // user is already active => no change needed + break + } + + // do not allow switching from other states than active (e.g. locked) + if !isUserStateInactive(existingHuman.UserState) { + return zerrors.ThrowInvalidArgumentf(nil, "USER2-statex1", "Errors.User.State.Invalid") + } + + cmds = append(cmds, user.NewUserReactivatedEvent(ctx, &existingHuman.Aggregate().Aggregate)) + case isUserStateInactive(*human.State): + if isUserStateInactive(existingHuman.UserState) { + // user is already inactive => no change needed + break + } + + // do not allow switching from other states than active (e.g. locked) + if !isUserStateActive(existingHuman.UserState) { + return zerrors.ThrowInvalidArgumentf(nil, "USER2-statex2", "Errors.User.State.Invalid") + } + + cmds = append(cmds, user.NewUserDeactivatedEvent(ctx, &existingHuman.Aggregate().Aggregate)) + default: + return zerrors.ThrowInvalidArgumentf(nil, "USER2-statex3", "Errors.User.State.Invalid") + } + } + if len(cmds) == 0 { human.Details = writeModelToObjectDetails(&existingHuman.WriteModel) return nil diff --git a/internal/database/cockroach/crdb.go b/internal/database/cockroach/crdb.go index cc89be8687..48e912b5f5 100644 --- a/internal/database/cockroach/crdb.go +++ b/internal/database/cockroach/crdb.go @@ -3,7 +3,6 @@ package cockroach import ( "context" "database/sql" - "fmt" "strconv" "strings" "time" @@ -14,7 +13,6 @@ import ( "github.com/mitchellh/mapstructure" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/database/dialect" ) @@ -74,19 +72,16 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) { return connector, nil } -func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) { +func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) { dialect.RegisterAfterConnect(func(ctx context.Context, c *pgx.Conn) error { // CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT // This is needed to fill the fields table of the eventstore during eventstore.Push. _, err := c.Exec(ctx, "SET enable_multiple_modifications_of_table = on") return err }) - connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose) - if err != nil { - return nil, nil, err - } + connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns) - config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName())) + config, err := pgxpool.ParseConfig(c.String(useAdmin)) if err != nil { return nil, nil, err } @@ -102,18 +97,6 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo } } - // For the pusher we set the app name with the instance ID - if purpose == dialect.DBPurposeEventPusher { - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID()) - } - config.AfterRelease = func(conn *pgx.Conn) bool { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - return setAppNameWithID(ctx, conn, purpose, "IDLE") - } - } - if connConfig.MaxOpenConns != 0 { config.MaxConns = int32(connConfig.MaxOpenConns) } @@ -195,7 +178,7 @@ func (c *Config) checkSSL(user User) { } } -func (c Config) String(useAdmin bool, appName string) string { +func (c Config) String(useAdmin bool) string { user := c.User if useAdmin { user = c.Admin.User @@ -206,7 +189,7 @@ func (c Config) String(useAdmin bool, appName string) string { "port=" + strconv.Itoa(int(c.Port)), "user=" + user.Username, "dbname=" + c.Database, - "application_name=" + appName, + "application_name=" + dialect.DefaultAppName, "sslmode=" + user.SSL.Mode, } if c.Options != "" { @@ -232,11 +215,3 @@ func (c Config) String(useAdmin bool, appName string) string { return strings.Join(fields, " ") } - -func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool { - // needs to be set like this because psql complains about parameters in the SET statement - query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id) - _, err := conn.Exec(ctx, query) - logging.OnError(err).Warn("failed to set application name") - return err == nil -} diff --git a/internal/database/database.go b/internal/database/database.go index b86a9f247c..e254edadc1 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -65,10 +65,8 @@ func CloseTransaction(tx Tx, err error) error { } type Config struct { - Dialects map[string]interface{} `mapstructure:",remain"` - EventPushConnRatio float64 - ProjectionSpoolerConnRatio float64 - connector dialect.Connector + Dialects map[string]interface{} `mapstructure:",remain"` + connector dialect.Connector } func (c *Config) SetConnector(connector dialect.Connector) { @@ -134,8 +132,8 @@ func QueryJSONObject[T any](ctx context.Context, db *DB, query string, args ...a return obj, nil } -func Connect(config Config, useAdmin bool, purpose dialect.DBPurpose) (*DB, error) { - client, pool, err := config.connector.Connect(useAdmin, config.EventPushConnRatio, config.ProjectionSpoolerConnRatio, purpose) +func Connect(config Config, useAdmin bool) (*DB, error) { + client, pool, err := config.connector.Connect(useAdmin) if err != nil { return nil, err } diff --git a/internal/database/dialect/config.go b/internal/database/dialect/config.go index 8ca4e7f748..71fb477ea1 100644 --- a/internal/database/dialect/config.go +++ b/internal/database/dialect/config.go @@ -26,36 +26,11 @@ type Matcher interface { } const ( - QueryAppName = "zitadel_queries" - EventstorePusherAppName = "zitadel_es_pusher" - ProjectionSpoolerAppName = "zitadel_projection_spooler" - defaultAppName = "zitadel" + DefaultAppName = "zitadel" ) -// DBPurpose is what the resulting connection pool is used for. -type DBPurpose int - -const ( - DBPurposeQuery DBPurpose = iota - DBPurposeEventPusher - DBPurposeProjectionSpooler -) - -func (p DBPurpose) AppName() string { - switch p { - case DBPurposeQuery: - return QueryAppName - case DBPurposeEventPusher: - return EventstorePusherAppName - case DBPurposeProjectionSpooler: - return ProjectionSpoolerAppName - default: - return defaultAppName - } -} - type Connector interface { - Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose DBPurpose) (*sql.DB, *pgxpool.Pool, error) + Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) Password() string Database } diff --git a/internal/database/dialect/config_test.go b/internal/database/dialect/config_test.go deleted file mode 100644 index d7297f8b67..0000000000 --- a/internal/database/dialect/config_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package dialect - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDBPurpose_AppName(t *testing.T) { - tests := []struct { - p DBPurpose - want string - }{ - { - p: DBPurposeQuery, - want: QueryAppName, - }, - { - p: DBPurposeEventPusher, - want: EventstorePusherAppName, - }, - { - p: DBPurposeProjectionSpooler, - want: ProjectionSpoolerAppName, - }, - { - p: 99, - want: defaultAppName, - }, - } - for _, tt := range tests { - t.Run(tt.want, func(t *testing.T) { - assert.Equal(t, tt.want, tt.p.AppName()) - }) - } -} diff --git a/internal/database/dialect/connections.go b/internal/database/dialect/connections.go index f957870df0..13a4d657c3 100644 --- a/internal/database/dialect/connections.go +++ b/internal/database/dialect/connections.go @@ -3,7 +3,6 @@ package dialect import ( "context" "errors" - "fmt" "reflect" "github.com/jackc/pgx/v5" @@ -11,11 +10,8 @@ import ( ) var ( - ErrNegativeRatio = errors.New("ratio cannot be negative") - ErrHighSumRatio = errors.New("sum of pusher and projection ratios must be < 1") ErrIllegalMaxOpenConns = errors.New("MaxOpenConns of the database must be higher than 3 or 0 for unlimited") ErrIllegalMaxIdleConns = errors.New("MaxIdleConns of the database must be higher than 3 or 0 for unlimited") - ErrInvalidPurpose = errors.New("DBPurpose out of range") ) // ConnectionConfig defines the Max Open and Idle connections for a DB connection pool. @@ -25,28 +21,6 @@ type ConnectionConfig struct { AfterConnect []func(ctx context.Context, c *pgx.Conn) error } -// takeRatio of MaxOpenConns and MaxIdleConns from config and returns -// a new ConnectionConfig with the resulting values. -func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) { - if ratio < 0 { - return nil, ErrNegativeRatio - } - - out := &ConnectionConfig{ - MaxOpenConns: uint32(ratio * float64(c.MaxOpenConns)), - MaxIdleConns: uint32(ratio * float64(c.MaxIdleConns)), - AfterConnect: c.AfterConnect, - } - if c.MaxOpenConns != 0 && out.MaxOpenConns < 1 && ratio > 0 { - out.MaxOpenConns = 1 - } - if c.MaxIdleConns != 0 && out.MaxIdleConns < 1 && ratio > 0 { - out.MaxIdleConns = 1 - } - - return out, nil -} - var afterConnectFuncs []func(ctx context.Context, c *pgx.Conn) error func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) { @@ -82,48 +56,10 @@ func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string) // // openConns and idleConns must be at least 3 or 0, which means no limit. // The pusherRatio and spoolerRatio must be between 0 and 1. -func NewConnectionConfig(openConns, idleConns uint32, pusherRatio, projectionRatio float64, purpose DBPurpose) (*ConnectionConfig, error) { - if openConns != 0 && openConns < 3 { - return nil, ErrIllegalMaxOpenConns - } - if idleConns != 0 && idleConns < 3 { - return nil, ErrIllegalMaxIdleConns - } - if pusherRatio+projectionRatio >= 1 { - return nil, ErrHighSumRatio - } - - queryConfig := &ConnectionConfig{ +func NewConnectionConfig(openConns, idleConns uint32) *ConnectionConfig { + return &ConnectionConfig{ MaxOpenConns: openConns, MaxIdleConns: idleConns, AfterConnect: afterConnectFuncs, } - pusherConfig, err := queryConfig.takeRatio(pusherRatio) - if err != nil { - return nil, fmt.Errorf("event pusher: %w", err) - } - - spoolerConfig, err := queryConfig.takeRatio(projectionRatio) - if err != nil { - return nil, fmt.Errorf("projection spooler: %w", err) - } - - // subtract the claimed amount - if queryConfig.MaxOpenConns > 0 { - queryConfig.MaxOpenConns -= pusherConfig.MaxOpenConns + spoolerConfig.MaxOpenConns - } - if queryConfig.MaxIdleConns > 0 { - queryConfig.MaxIdleConns -= pusherConfig.MaxIdleConns + spoolerConfig.MaxIdleConns - } - - switch purpose { - case DBPurposeQuery: - return queryConfig, nil - case DBPurposeEventPusher: - return pusherConfig, nil - case DBPurposeProjectionSpooler: - return spoolerConfig, nil - default: - return nil, fmt.Errorf("%w: %v", ErrInvalidPurpose, purpose) - } } diff --git a/internal/database/dialect/connections_test.go b/internal/database/dialect/connections_test.go deleted file mode 100644 index 6256658d0a..0000000000 --- a/internal/database/dialect/connections_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package dialect - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestConnectionConfig_takeRatio(t *testing.T) { - type fields struct { - MaxOpenConns uint32 - MaxIdleConns uint32 - } - tests := []struct { - name string - fields fields - ratio float64 - wantOut *ConnectionConfig - wantErr error - }{ - { - name: "ratio less than 0 error", - ratio: -0.1, - wantErr: ErrNegativeRatio, - }, - { - name: "zero values", - fields: fields{ - MaxOpenConns: 0, - MaxIdleConns: 0, - }, - ratio: 0, - wantOut: &ConnectionConfig{ - MaxOpenConns: 0, - MaxIdleConns: 0, - }, - }, - { - name: "max conns, ratio 0", - fields: fields{ - MaxOpenConns: 10, - MaxIdleConns: 5, - }, - ratio: 0, - wantOut: &ConnectionConfig{ - MaxOpenConns: 0, - MaxIdleConns: 0, - }, - }, - { - name: "half ratio", - fields: fields{ - MaxOpenConns: 10, - MaxIdleConns: 5, - }, - ratio: 0.5, - wantOut: &ConnectionConfig{ - MaxOpenConns: 5, - MaxIdleConns: 2, - }, - }, - { - name: "minimal 1", - fields: fields{ - MaxOpenConns: 2, - MaxIdleConns: 2, - }, - ratio: 0.1, - wantOut: &ConnectionConfig{ - MaxOpenConns: 1, - MaxIdleConns: 1, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - in := &ConnectionConfig{ - MaxOpenConns: tt.fields.MaxOpenConns, - MaxIdleConns: tt.fields.MaxIdleConns, - } - got, err := in.takeRatio(tt.ratio) - require.ErrorIs(t, err, tt.wantErr) - assert.Equal(t, tt.wantOut, got) - }) - } -} - -func TestNewConnectionConfig(t *testing.T) { - type args struct { - openConns uint32 - idleConns uint32 - pusherRatio float64 - projectionRatio float64 - purpose DBPurpose - } - tests := []struct { - name string - args args - want *ConnectionConfig - wantErr error - }{ - { - name: "illegal open conns error", - args: args{ - openConns: 2, - idleConns: 3, - }, - wantErr: ErrIllegalMaxOpenConns, - }, - { - name: "illegal idle conns error", - args: args{ - openConns: 3, - idleConns: 2, - }, - wantErr: ErrIllegalMaxIdleConns, - }, - { - name: "high ration sum error", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.5, - projectionRatio: 0.5, - }, - wantErr: ErrHighSumRatio, - }, - { - name: "illegal pusher ratio error", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: -0.1, - projectionRatio: 0.5, - }, - wantErr: ErrNegativeRatio, - }, - { - name: "illegal projection ratio error", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.5, - projectionRatio: -0.1, - }, - wantErr: ErrNegativeRatio, - }, - { - name: "invalid purpose error", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.4, - projectionRatio: 0.4, - purpose: 99, - }, - wantErr: ErrInvalidPurpose, - }, - { - name: "min values, query purpose", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeQuery, - }, - want: &ConnectionConfig{ - MaxOpenConns: 1, - MaxIdleConns: 1, - }, - }, - { - name: "min values, pusher purpose", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeEventPusher, - }, - want: &ConnectionConfig{ - MaxOpenConns: 1, - MaxIdleConns: 1, - }, - }, - { - name: "min values, projection purpose", - args: args{ - openConns: 3, - idleConns: 3, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeProjectionSpooler, - }, - want: &ConnectionConfig{ - MaxOpenConns: 1, - MaxIdleConns: 1, - }, - }, - { - name: "high values, query purpose", - args: args{ - openConns: 10, - idleConns: 5, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeQuery, - }, - want: &ConnectionConfig{ - MaxOpenConns: 6, - MaxIdleConns: 3, - }, - }, - { - name: "high values, pusher purpose", - args: args{ - openConns: 10, - idleConns: 5, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeEventPusher, - }, - want: &ConnectionConfig{ - MaxOpenConns: 2, - MaxIdleConns: 1, - }, - }, - { - name: "high values, projection purpose", - args: args{ - openConns: 10, - idleConns: 5, - pusherRatio: 0.2, - projectionRatio: 0.2, - purpose: DBPurposeProjectionSpooler, - }, - want: &ConnectionConfig{ - MaxOpenConns: 2, - MaxIdleConns: 1, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := NewConnectionConfig(tt.args.openConns, tt.args.idleConns, tt.args.pusherRatio, tt.args.projectionRatio, tt.args.purpose) - require.ErrorIs(t, err, tt.wantErr) - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/internal/database/postgres/pg.go b/internal/database/postgres/pg.go index c12e122437..5f4d9a6c9b 100644 --- a/internal/database/postgres/pg.go +++ b/internal/database/postgres/pg.go @@ -3,7 +3,6 @@ package postgres import ( "context" "database/sql" - "fmt" "strconv" "strings" "time" @@ -14,7 +13,6 @@ import ( "github.com/mitchellh/mapstructure" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/database/dialect" ) @@ -75,13 +73,10 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) { return connector, nil } -func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) { - connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose) - if err != nil { - return nil, nil, err - } +func (c *Config) Connect(useAdmin bool) (*sql.DB, *pgxpool.Pool, error) { + connConfig := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns) - config, err := pgxpool.ParseConfig(c.String(useAdmin, purpose.AppName())) + config, err := pgxpool.ParseConfig(c.String(useAdmin)) if err != nil { return nil, nil, err } @@ -95,18 +90,6 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo return nil } - // For the pusher we set the app name with the instance ID - if purpose == dialect.DBPurposeEventPusher { - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID()) - } - config.AfterRelease = func(conn *pgx.Conn) bool { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - return setAppNameWithID(ctx, conn, purpose, "IDLE") - } - } - if connConfig.MaxOpenConns != 0 { config.MaxConns = int32(connConfig.MaxOpenConns) } @@ -191,7 +174,7 @@ func (s *Config) checkSSL(user User) { } } -func (c Config) String(useAdmin bool, appName string) string { +func (c Config) String(useAdmin bool) string { user := c.User if useAdmin { user = c.Admin.User @@ -201,7 +184,7 @@ func (c Config) String(useAdmin bool, appName string) string { "host=" + c.Host, "port=" + strconv.Itoa(int(c.Port)), "user=" + user.Username, - "application_name=" + appName, + "application_name=" + dialect.DefaultAppName, "sslmode=" + user.SSL.Mode, } if c.Options != "" { @@ -233,11 +216,3 @@ func (c Config) String(useAdmin bool, appName string) string { return strings.Join(fields, " ") } - -func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool { - // needs to be set like this because psql complains about parameters in the SET statement - query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id) - _, err := conn.Exec(ctx, query) - logging.OnError(err).Warn("failed to set application name") - return err == nil -} diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index 01b6ae25da..85ec340f67 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -60,6 +60,7 @@ type AuthRequest struct { DefaultTranslations []*CustomText OrgTranslations []*CustomText SAMLRequestID string + RequestLocalAuth bool // orgID the policies were last loaded with policyOrgID string // SessionID is set to the computed sessionID of the login session table diff --git a/internal/eventstore/handler/v2/field_handler.go b/internal/eventstore/handler/v2/field_handler.go index bbe40ed465..ad309ac790 100644 --- a/internal/eventstore/handler/v2/field_handler.go +++ b/internal/eventstore/handler/v2/field_handler.go @@ -32,6 +32,9 @@ func (f *fieldProjection) Reducers() []AggregateReducer { var _ Projection = (*fieldProjection)(nil) +// NewFieldHandler returns a projection handler which backfills the `eventstore.fields` table with historic events which +// might have existed before they had and Field Operations defined. +// The events are filtered by the mapped aggregate types and each event type for that aggregate. func NewFieldHandler(config *Config, name string, eventTypes map[eventstore.AggregateType][]eventstore.EventType) *FieldHandler { return &FieldHandler{ Handler: Handler{ @@ -51,6 +54,7 @@ func NewFieldHandler(config *Config, name string, eventTypes map[eventstore.Aggr } } +// Trigger executes the backfill job of events for the instance currently in the context. func (h *FieldHandler) Trigger(ctx context.Context, opts ...TriggerOpt) (err error) { config := new(triggerConfig) for _, opt := range opts { diff --git a/internal/eventstore/handler/v2/statement.go b/internal/eventstore/handler/v2/statement.go index 961881d24b..a02e5d3580 100644 --- a/internal/eventstore/handler/v2/statement.go +++ b/internal/eventstore/handler/v2/statement.go @@ -601,6 +601,12 @@ func NewCond(name string, value interface{}) Condition { } } +func NewUnequalCond(name string, value any) Condition { + return func(param string) (string, []any) { + return name + " <> " + param, []any{value} + } +} + func NewNamespacedCondition(name string, value interface{}) NamespacedCondition { return func(namespace string) Condition { return NewCond(namespace+"."+name, value) diff --git a/internal/eventstore/repository/sql/query.go b/internal/eventstore/repository/sql/query.go index b93e663b17..4e1cc87aff 100644 --- a/internal/eventstore/repository/sql/query.go +++ b/internal/eventstore/repository/sql/query.go @@ -309,7 +309,7 @@ func prepareConditions(criteria querier, query *repository.SearchQuery, useV1 bo } for i := range instanceIDs { - instanceIDs[i] = dialect.DBPurposeEventPusher.AppName() + "_" + instanceIDs[i] + instanceIDs[i] = "zitadel_es_pusher_" + instanceIDs[i] } clauses += awaitOpenTransactions(useV1) diff --git a/internal/eventstore/v3/event.go b/internal/eventstore/v3/event.go index da4e7a0383..1141a9eacf 100644 --- a/internal/eventstore/v3/event.go +++ b/internal/eventstore/v3/event.go @@ -7,6 +7,7 @@ import ( "time" "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/zerrors" diff --git a/internal/eventstore/v3/push.go b/internal/eventstore/v3/push.go index fb597021e2..6497b96ed8 100644 --- a/internal/eventstore/v3/push.go +++ b/internal/eventstore/v3/push.go @@ -4,9 +4,11 @@ import ( "context" "database/sql" _ "embed" + "fmt" "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -55,6 +57,11 @@ func (es *Eventstore) writeCommands(ctx context.Context, client database.Context }() } + _, err = tx.ExecContext(ctx, fmt.Sprintf("SET LOCAL application_name = '%s'", fmt.Sprintf("zitadel_es_pusher_%s", authz.GetInstance(ctx).InstanceID()))) + if err != nil { + return nil, err + } + events, err := writeEvents(ctx, tx, commands) if err != nil { return nil, err diff --git a/internal/eventstore/v3/sequence.go b/internal/eventstore/v3/sequence.go index 7d97e1080d..1976af4093 100644 --- a/internal/eventstore/v3/sequence.go +++ b/internal/eventstore/v3/sequence.go @@ -125,7 +125,18 @@ func scanToSequence(rows *sql.Rows, sequences []*latestSequence) error { return nil } sequence.sequence = currentSequence - if sequence.aggregate.ResourceOwner == "" { + if resourceOwner != "" && sequence.aggregate.ResourceOwner != "" && sequence.aggregate.ResourceOwner != resourceOwner { + logging.WithFields( + "current_sequence", sequence.sequence, + "instance_id", sequence.aggregate.InstanceID, + "agg_type", sequence.aggregate.Type, + "agg_id", sequence.aggregate.ID, + "current_owner", resourceOwner, + "provided_owner", sequence.aggregate.ResourceOwner, + ).Info("would have set wrong resource owner") + } + // set resource owner from previous events + if resourceOwner != "" { sequence.aggregate.ResourceOwner = resourceOwner } diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 09fdf2ff52..d9a2d6352d 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -23,6 +23,7 @@ const ( KeyDisableUserTokenEvent KeyEnableBackChannelLogout KeyLoginV2 + KeyPermissionCheckV2 ) //go:generate enumer -type Level -transform snake -trimprefix Level @@ -52,6 +53,7 @@ type Features struct { DisableUserTokenEvent bool `json:"disable_user_token_event,omitempty"` EnableBackChannelLogout bool `json:"enable_back_channel_logout,omitempty"` LoginV2 LoginV2 `json:"login_v2,omitempty"` + PermissionCheckV2 bool `json:"permission_check_v2,omitempty"` } type ImprovedPerformanceType int32 diff --git a/internal/feature/key_enumer.go b/internal/feature/key_enumer.go index 462b751e6c..3a805df807 100644 --- a/internal/feature/key_enumer.go +++ b/internal/feature/key_enumer.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _KeyName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logoutlogin_v2" +const _KeyName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logoutlogin_v2permission_check_v2" -var _KeyIndex = [...]uint8{0, 11, 28, 61, 81, 92, 106, 113, 133, 140, 163, 197, 221, 247, 255} +var _KeyIndex = [...]uint16{0, 11, 28, 61, 81, 92, 106, 113, 133, 140, 163, 197, 221, 247, 255, 274} -const _KeyLowerName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logoutlogin_v2" +const _KeyLowerName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logoutlogin_v2permission_check_v2" func (i Key) String() string { if i < 0 || i >= Key(len(_KeyIndex)-1) { @@ -38,9 +38,10 @@ func _KeyNoOp() { _ = x[KeyDisableUserTokenEvent-(11)] _ = x[KeyEnableBackChannelLogout-(12)] _ = x[KeyLoginV2-(13)] + _ = x[KeyPermissionCheckV2-(14)] } -var _KeyValues = []Key{KeyUnspecified, KeyLoginDefaultOrg, KeyTriggerIntrospectionProjections, KeyLegacyIntrospection, KeyUserSchema, KeyTokenExchange, KeyActions, KeyImprovedPerformance, KeyWebKey, KeyDebugOIDCParentError, KeyOIDCSingleV1SessionTermination, KeyDisableUserTokenEvent, KeyEnableBackChannelLogout, KeyLoginV2} +var _KeyValues = []Key{KeyUnspecified, KeyLoginDefaultOrg, KeyTriggerIntrospectionProjections, KeyLegacyIntrospection, KeyUserSchema, KeyTokenExchange, KeyActions, KeyImprovedPerformance, KeyWebKey, KeyDebugOIDCParentError, KeyOIDCSingleV1SessionTermination, KeyDisableUserTokenEvent, KeyEnableBackChannelLogout, KeyLoginV2, KeyPermissionCheckV2} var _KeyNameToValueMap = map[string]Key{ _KeyName[0:11]: KeyUnspecified, @@ -71,6 +72,8 @@ var _KeyNameToValueMap = map[string]Key{ _KeyLowerName[221:247]: KeyEnableBackChannelLogout, _KeyName[247:255]: KeyLoginV2, _KeyLowerName[247:255]: KeyLoginV2, + _KeyName[255:274]: KeyPermissionCheckV2, + _KeyLowerName[255:274]: KeyPermissionCheckV2, } var _KeyNames = []string{ @@ -88,6 +91,7 @@ var _KeyNames = []string{ _KeyName[197:221], _KeyName[221:247], _KeyName[247:255], + _KeyName[255:274], } // KeyString retrieves an enum value from the enum constants string name. diff --git a/internal/integration/assert.go b/internal/integration/assert.go index 6743c8297e..de35357dd7 100644 --- a/internal/integration/assert.go +++ b/internal/integration/assert.go @@ -1,11 +1,14 @@ package integration import ( + "reflect" "testing" "time" "github.com/pmezard/go-difflib/difflib" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" @@ -128,6 +131,13 @@ func AssertResourceListDetails[D ResourceListDetailsMsg](t assert.TestingT, expe } } +func AssertGrpcStatus(t assert.TestingT, expected codes.Code, err error) { + assert.Error(t, err) + statusErr, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, expected, statusErr.Code()) +} + // EqualProto is inspired by [assert.Equal], only that it tests equality of a proto message. // A message diff is printed on the error test log if the messages are not equal. // @@ -160,3 +170,99 @@ func diffProto(expected, actual proto.Message) string { } return "\n\nDiff:\n" + diff } + +func AssertMapContains[M ~map[K]V, K comparable, V any](t assert.TestingT, m M, key K, expectedValue V) { + val, exists := m[key] + assert.True(t, exists, "Key '%s' should exist in the map", key) + if !exists { + return + } + + assert.Equal(t, expectedValue, val, "Key '%s' should have value '%d'", key, expectedValue) +} + +// PartiallyDeepEqual is similar to reflect.DeepEqual, +// but only compares exported non-zero fields of the expectedValue +func PartiallyDeepEqual(expected, actual interface{}) bool { + if expected == nil { + return actual == nil + } + + if actual == nil { + return false + } + + return partiallyDeepEqual(reflect.ValueOf(expected), reflect.ValueOf(actual)) +} + +func partiallyDeepEqual(expected, actual reflect.Value) bool { + // Dereference pointers if needed + if expected.Kind() == reflect.Ptr { + if expected.IsNil() { + return true + } + + expected = expected.Elem() + } + + if actual.Kind() == reflect.Ptr { + if actual.IsNil() { + return false + } + + actual = actual.Elem() + } + + if expected.Type() != actual.Type() { + return false + } + + switch expected.Kind() { //nolint:exhaustive + case reflect.Struct: + for i := 0; i < expected.NumField(); i++ { + field := expected.Type().Field(i) + if field.PkgPath != "" { // Skip unexported fields + continue + } + + expectedField := expected.Field(i) + actualField := actual.Field(i) + + // Skip zero-value fields in expected + if reflect.DeepEqual(expectedField.Interface(), reflect.Zero(expectedField.Type()).Interface()) { + continue + } + + // Compare fields recursively + if !partiallyDeepEqual(expectedField, actualField) { + return false + } + } + return true + + case reflect.Slice, reflect.Array: + if expected.Len() > actual.Len() { + return false + } + + for i := 0; i < expected.Len(); i++ { + if !partiallyDeepEqual(expected.Index(i), actual.Index(i)) { + return false + } + } + + return true + + default: + // Compare primitive types + return reflect.DeepEqual(expected.Interface(), actual.Interface()) + } +} + +func Must[T any](result T, error error) T { + if error != nil { + panic(error) + } + + return result +} diff --git a/internal/integration/assert_test.go b/internal/integration/assert_test.go index 0355ffec98..191078ffd1 100644 --- a/internal/integration/assert_test.go +++ b/internal/integration/assert_test.go @@ -50,3 +50,153 @@ func TestAssertDetails(t *testing.T) { }) } } + +func TestPartiallyDeepEqual(t *testing.T) { + type SecondaryNestedType struct { + Value int + } + type NestedType struct { + Value int + ValueSlice []int + Nested SecondaryNestedType + NestedPointer *SecondaryNestedType + } + + type args struct { + expected interface{} + actual interface{} + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "nil", + args: args{ + expected: nil, + actual: nil, + }, + want: true, + }, + { + name: "scalar value", + args: args{ + expected: 10, + actual: 10, + }, + want: true, + }, + { + name: "different scalar value", + args: args{ + expected: 11, + actual: 10, + }, + want: false, + }, + { + name: "string value", + args: args{ + expected: "foo", + actual: "foo", + }, + want: true, + }, + { + name: "different string value", + args: args{ + expected: "foo2", + actual: "foo", + }, + want: false, + }, + { + name: "scalar only set in actual", + args: args{ + expected: &SecondaryNestedType{}, + actual: &SecondaryNestedType{Value: 10}, + }, + want: true, + }, + { + name: "scalar equal", + args: args{ + expected: &SecondaryNestedType{Value: 10}, + actual: &SecondaryNestedType{Value: 10}, + }, + want: true, + }, + { + name: "scalar only set in expected", + args: args{ + expected: &SecondaryNestedType{Value: 10}, + actual: &SecondaryNestedType{}, + }, + want: false, + }, + { + name: "ptr only set in expected", + args: args{ + expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}}, + actual: &NestedType{}, + }, + want: false, + }, + { + name: "ptr only set in actual", + args: args{ + expected: &NestedType{}, + actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}}, + }, + want: true, + }, + { + name: "ptr equal", + args: args{ + expected: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}}, + actual: &NestedType{NestedPointer: &SecondaryNestedType{Value: 10}}, + }, + want: true, + }, + { + name: "nested equal", + args: args{ + expected: &NestedType{Nested: SecondaryNestedType{Value: 10}}, + actual: &NestedType{Nested: SecondaryNestedType{Value: 10}}, + }, + want: true, + }, + { + name: "slice equal", + args: args{ + expected: &NestedType{ValueSlice: []int{10, 20}}, + actual: &NestedType{ValueSlice: []int{10, 20}}, + }, + want: true, + }, + { + name: "slice additional in expected", + args: args{ + expected: &NestedType{ValueSlice: []int{10, 20, 30}}, + actual: &NestedType{ValueSlice: []int{10, 20}}, + }, + want: false, + }, + { + name: "slice additional in actual", + args: args{ + expected: &NestedType{ValueSlice: []int{10, 20}}, + actual: &NestedType{ValueSlice: []int{10, 20, 30}}, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := PartiallyDeepEqual(tt.args.expected, tt.args.actual); got != tt.want { + t.Errorf("PartiallyDeepEqual() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/integration/client.go b/internal/integration/client.go index af30f0e642..d18c2d9b12 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -17,6 +17,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/integration/scim" "github.com/zitadel/zitadel/pkg/grpc/admin" "github.com/zitadel/zitadel/pkg/grpc/auth" "github.com/zitadel/zitadel/pkg/grpc/feature/v2" @@ -67,6 +68,7 @@ type Client struct { IDPv2 idp_pb.IdentityProviderServiceClient UserV3Alpha user_v3alpha.ZITADELUsersClient SAMLv2 saml_pb.SAMLServiceClient + SCIM *scim.Client } func newClient(ctx context.Context, target string) (*Client, error) { @@ -99,6 +101,7 @@ func newClient(ctx context.Context, target string) (*Client, error) { IDPv2: idp_pb.NewIdentityProviderServiceClient(cc), UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc), SAMLv2: saml_pb.NewSAMLServiceClient(cc), + SCIM: scim.NewScimClient(target), } return client, client.pollHealth(ctx) } diff --git a/internal/integration/scim/assertions.go b/internal/integration/scim/assertions.go new file mode 100644 index 0000000000..a91c33da82 --- /dev/null +++ b/internal/integration/scim/assertions.go @@ -0,0 +1,22 @@ +package scim + +import ( + "errors" + "strconv" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type AssertedScimError struct { + Error *ScimError +} + +func RequireScimError(t require.TestingT, httpStatus int, err error) AssertedScimError { + require.Error(t, err) + + var scimErr *ScimError + assert.True(t, errors.As(err, &scimErr)) + assert.Equal(t, strconv.Itoa(httpStatus), scimErr.Status) + return AssertedScimError{scimErr} // wrap it, otherwise error handling is enforced +} diff --git a/internal/integration/scim/client.go b/internal/integration/scim/client.go new file mode 100644 index 0000000000..262835a827 --- /dev/null +++ b/internal/integration/scim/client.go @@ -0,0 +1,153 @@ +package scim + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "path" + + "github.com/zitadel/logging" + "google.golang.org/grpc/metadata" + + zhttp "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/scim/middleware" + "github.com/zitadel/zitadel/internal/api/scim/resources" + "github.com/zitadel/zitadel/internal/api/scim/schemas" +) + +type Client struct { + Users *ResourceClient[resources.ScimUser] +} + +type ResourceClient[T any] struct { + client *http.Client + baseUrl string + resourceName string +} + +type ScimError struct { + Schemas []string `json:"schemas"` + ScimType string `json:"scimType"` + Detail string `json:"detail"` + Status string `json:"status"` + ZitadelDetail *ZitadelErrorDetail `json:"urn:ietf:params:scim:api:zitadel:messages:2.0:ErrorDetail,omitempty"` +} + +type ZitadelErrorDetail struct { + ID string `json:"id"` + Message string `json:"message"` +} + +func NewScimClient(target string) *Client { + target = "http://" + target + schemas.HandlerPrefix + client := &http.Client{} + return &Client{ + Users: &ResourceClient[resources.ScimUser]{ + client: client, + baseUrl: target, + resourceName: "Users", + }, + } +} + +func (c *ResourceClient[T]) Create(ctx context.Context, orgID string, body []byte) (*T, error) { + return c.doWithBody(ctx, http.MethodPost, orgID, "", bytes.NewReader(body)) +} + +func (c *ResourceClient[T]) Replace(ctx context.Context, orgID, id string, body []byte) (*T, error) { + return c.doWithBody(ctx, http.MethodPut, orgID, id, bytes.NewReader(body)) +} + +func (c *ResourceClient[T]) Get(ctx context.Context, orgID, resourceID string) (*T, error) { + return c.doWithBody(ctx, http.MethodGet, orgID, resourceID, nil) +} + +func (c *ResourceClient[T]) Delete(ctx context.Context, orgID, id string) error { + return c.do(ctx, http.MethodDelete, orgID, id) +} + +func (c *ResourceClient[T]) do(ctx context.Context, method, orgID, url string) error { + req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), nil) + if err != nil { + return err + } + + return c.doReq(req, nil) +} + +func (c *ResourceClient[T]) doWithBody(ctx context.Context, method, orgID, url string, body io.Reader) (*T, error) { + req, err := http.NewRequestWithContext(ctx, method, c.buildURL(orgID, url), body) + if err != nil { + return nil, err + } + + req.Header.Set(zhttp.ContentType, middleware.ContentTypeScim) + responseEntity := new(T) + return responseEntity, c.doReq(req, responseEntity) +} + +func (c *ResourceClient[T]) doReq(req *http.Request, responseEntity *T) error { + addTokenAsHeader(req) + + resp, err := c.client.Do(req) + defer func() { + err := resp.Body.Close() + logging.OnError(err).Error("Failed to close response body") + }() + + if err != nil { + return err + } + + if (resp.StatusCode / 100) != 2 { + return readScimError(resp) + } + + if responseEntity == nil { + return nil + } + + err = readJson(responseEntity, resp) + return err +} + +func addTokenAsHeader(req *http.Request) { + md, ok := metadata.FromOutgoingContext(req.Context()) + if !ok { + return + } + + req.Header.Set("Authorization", md.Get("Authorization")[0]) +} + +func readJson(entity interface{}, resp *http.Response) error { + defer func(body io.ReadCloser) { + err := body.Close() + logging.OnError(err).Panic("Failed to close response body") + }(resp.Body) + + err := json.NewDecoder(resp.Body).Decode(entity) + logging.OnError(err).Panic("Failed decoding entity") + return err +} + +func readScimError(resp *http.Response) error { + scimErr := new(ScimError) + readErr := readJson(scimErr, resp) + logging.OnError(readErr).Panic("Failed reading scim error") + return scimErr +} + +func (c *ResourceClient[T]) buildURL(orgID, segment string) string { + if segment == "" { + return c.baseUrl + "/" + path.Join(orgID, c.resourceName) + } + + return c.baseUrl + "/" + path.Join(orgID, c.resourceName, segment) +} + +func (err *ScimError) Error() string { + return "scim error: " + err.Detail +} diff --git a/internal/notification/handlers/mock/commands.mock.go b/internal/notification/handlers/mock/commands.mock.go index ee6eb3c6b1..de32ce067c 100644 --- a/internal/notification/handlers/mock/commands.mock.go +++ b/internal/notification/handlers/mock/commands.mock.go @@ -25,7 +25,6 @@ import ( type MockCommands struct { ctrl *gomock.Controller recorder *MockCommandsMockRecorder - isgomock struct{} } // MockCommandsMockRecorder is the mock recorder for MockCommands. @@ -46,253 +45,253 @@ func (m *MockCommands) EXPECT() *MockCommandsMockRecorder { } // HumanEmailVerificationCodeSent mocks base method. -func (m *MockCommands) HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), arg0, arg1, arg2) } // HumanInitCodeSent mocks base method. -func (m *MockCommands) HumanInitCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanInitCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "HumanInitCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanInitCodeSent indicates an expected call of HumanInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanInitCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), arg0, arg1, arg2) } // HumanOTPEmailCodeSent mocks base method. -func (m *MockCommands) HumanOTPEmailCodeSent(ctx context.Context, userID, resourceOwner string) error { +func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", ctx, userID, resourceOwner) + ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(ctx, userID, resourceOwner any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), ctx, userID, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), arg0, arg1, arg2) } // HumanOTPSMSCodeSent mocks base method. -func (m *MockCommands) HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", ctx, userID, resourceOwner, generatorInfo) + ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(ctx, userID, resourceOwner, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), ctx, userID, resourceOwner, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2, arg3) } // HumanPasswordlessInitCodeSent mocks base method. -func (m *MockCommands) HumanPasswordlessInitCodeSent(ctx context.Context, userID, resourceOwner, codeID string) error { +func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1, arg2, arg3 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", ctx, userID, resourceOwner, codeID) + ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(ctx, userID, resourceOwner, codeID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), ctx, userID, resourceOwner, codeID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), arg0, arg1, arg2, arg3) } // HumanPhoneVerificationCodeSent mocks base method. -func (m *MockCommands) HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", ctx, orgID, userID, generatorInfo) + ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), ctx, orgID, userID, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2, arg3) } // InviteCodeSent mocks base method. -func (m *MockCommands) InviteCodeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) InviteCodeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InviteCodeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "InviteCodeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // InviteCodeSent indicates an expected call of InviteCodeSent. -func (mr *MockCommandsMockRecorder) InviteCodeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), arg0, arg1, arg2) } // MilestonePushed mocks base method. -func (m *MockCommands) MilestonePushed(ctx context.Context, instanceID string, msType milestone.Type, endpoints []string) error { +func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 string, arg2 milestone.Type, arg3 []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MilestonePushed", ctx, instanceID, msType, endpoints) + ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // MilestonePushed indicates an expected call of MilestonePushed. -func (mr *MockCommandsMockRecorder) MilestonePushed(ctx, instanceID, msType, endpoints any) *gomock.Call { +func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), ctx, instanceID, msType, endpoints) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3) } // NotificationCanceled mocks base method. -func (m *MockCommands) NotificationCanceled(ctx context.Context, tx *sql.Tx, id, resourceOwner string, err error) error { +func (m *MockCommands) NotificationCanceled(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationCanceled", ctx, tx, id, resourceOwner, err) + ret := m.ctrl.Call(m, "NotificationCanceled", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) return ret0 } // NotificationCanceled indicates an expected call of NotificationCanceled. -func (mr *MockCommandsMockRecorder) NotificationCanceled(ctx, tx, id, resourceOwner, err any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationCanceled(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationCanceled", reflect.TypeOf((*MockCommands)(nil).NotificationCanceled), ctx, tx, id, resourceOwner, err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationCanceled", reflect.TypeOf((*MockCommands)(nil).NotificationCanceled), arg0, arg1, arg2, arg3, arg4) } // NotificationRetryRequested mocks base method. -func (m *MockCommands) NotificationRetryRequested(ctx context.Context, tx *sql.Tx, id, resourceOwner string, request *command.NotificationRetryRequest, err error) error { +func (m *MockCommands) NotificationRetryRequested(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 *command.NotificationRetryRequest, arg5 error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationRetryRequested", ctx, tx, id, resourceOwner, request, err) + ret := m.ctrl.Call(m, "NotificationRetryRequested", arg0, arg1, arg2, arg3, arg4, arg5) ret0, _ := ret[0].(error) return ret0 } // NotificationRetryRequested indicates an expected call of NotificationRetryRequested. -func (mr *MockCommandsMockRecorder) NotificationRetryRequested(ctx, tx, id, resourceOwner, request, err any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationRetryRequested(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRetryRequested", reflect.TypeOf((*MockCommands)(nil).NotificationRetryRequested), ctx, tx, id, resourceOwner, request, err) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRetryRequested", reflect.TypeOf((*MockCommands)(nil).NotificationRetryRequested), arg0, arg1, arg2, arg3, arg4, arg5) } // NotificationSent mocks base method. -func (m *MockCommands) NotificationSent(ctx context.Context, tx *sql.Tx, id, instanceID string) error { +func (m *MockCommands) NotificationSent(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationSent", ctx, tx, id, instanceID) + ret := m.ctrl.Call(m, "NotificationSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // NotificationSent indicates an expected call of NotificationSent. -func (mr *MockCommandsMockRecorder) NotificationSent(ctx, tx, id, instanceID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) NotificationSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationSent", reflect.TypeOf((*MockCommands)(nil).NotificationSent), ctx, tx, id, instanceID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationSent", reflect.TypeOf((*MockCommands)(nil).NotificationSent), arg0, arg1, arg2, arg3) } // OTPEmailSent mocks base method. -func (m *MockCommands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error { +func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPEmailSent", ctx, sessionID, resourceOwner) + ret := m.ctrl.Call(m, "OTPEmailSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // OTPEmailSent indicates an expected call of OTPEmailSent. -func (mr *MockCommandsMockRecorder) OTPEmailSent(ctx, sessionID, resourceOwner any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), ctx, sessionID, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), arg0, arg1, arg2) } // OTPSMSSent mocks base method. -func (m *MockCommands) OTPSMSSent(ctx context.Context, sessionID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPSMSSent", ctx, sessionID, resourceOwner, generatorInfo) + ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // OTPSMSSent indicates an expected call of OTPSMSSent. -func (mr *MockCommandsMockRecorder) OTPSMSSent(ctx, sessionID, resourceOwner, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), ctx, sessionID, resourceOwner, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2, arg3) } // PasswordChangeSent mocks base method. -func (m *MockCommands) PasswordChangeSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordChangeSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "PasswordChangeSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // PasswordChangeSent indicates an expected call of PasswordChangeSent. -func (mr *MockCommandsMockRecorder) PasswordChangeSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), arg0, arg1, arg2) } // PasswordCodeSent mocks base method. -func (m *MockCommands) PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { +func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordCodeSent", ctx, orgID, userID, generatorInfo) + ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) return ret0 } // PasswordCodeSent indicates an expected call of PasswordCodeSent. -func (mr *MockCommandsMockRecorder) PasswordCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), ctx, orgID, userID, generatorInfo) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3) } // RequestNotification mocks base method. -func (m *MockCommands) RequestNotification(ctx context.Context, instanceID string, request *command.NotificationRequest) error { +func (m *MockCommands) RequestNotification(arg0 context.Context, arg1 string, arg2 *command.NotificationRequest) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RequestNotification", ctx, instanceID, request) + ret := m.ctrl.Call(m, "RequestNotification", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // RequestNotification indicates an expected call of RequestNotification. -func (mr *MockCommandsMockRecorder) RequestNotification(ctx, instanceID, request any) *gomock.Call { +func (mr *MockCommandsMockRecorder) RequestNotification(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNotification", reflect.TypeOf((*MockCommands)(nil).RequestNotification), ctx, instanceID, request) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNotification", reflect.TypeOf((*MockCommands)(nil).RequestNotification), arg0, arg1, arg2) } // UsageNotificationSent mocks base method. -func (m *MockCommands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error { +func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UsageNotificationSent", ctx, dueEvent) + ret := m.ctrl.Call(m, "UsageNotificationSent", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // UsageNotificationSent indicates an expected call of UsageNotificationSent. -func (mr *MockCommandsMockRecorder) UsageNotificationSent(ctx, dueEvent any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), ctx, dueEvent) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), arg0, arg1) } // UserDomainClaimedSent mocks base method. -func (m *MockCommands) UserDomainClaimedSent(ctx context.Context, orgID, userID string) error { +func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UserDomainClaimedSent", ctx, orgID, userID) + ret := m.ctrl.Call(m, "UserDomainClaimedSent", arg0, arg1, arg2) ret0, _ := ret[0].(error) return ret0 } // UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent. -func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(ctx, orgID, userID any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), ctx, orgID, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), arg0, arg1, arg2) } diff --git a/internal/notification/handlers/mock/queries.mock.go b/internal/notification/handlers/mock/queries.mock.go index 5ead216437..670d3f3896 100644 --- a/internal/notification/handlers/mock/queries.mock.go +++ b/internal/notification/handlers/mock/queries.mock.go @@ -26,7 +26,6 @@ import ( type MockQueries struct { ctrl *gomock.Controller recorder *MockQueriesMockRecorder - isgomock struct{} } // MockQueriesMockRecorder is the mock recorder for MockQueries. @@ -61,240 +60,240 @@ func (mr *MockQueriesMockRecorder) ActiveInstances() *gomock.Call { } // ActiveLabelPolicyByOrg mocks base method. -func (m *MockQueries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.LabelPolicy, error) { +func (m *MockQueries) ActiveLabelPolicyByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.LabelPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", ctx, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", arg0, arg1, arg2) ret0, _ := ret[0].(*query.LabelPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // ActiveLabelPolicyByOrg indicates an expected call of ActiveLabelPolicyByOrg. -func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), ctx, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), arg0, arg1, arg2) } // ActivePrivateSigningKey mocks base method. -func (m *MockQueries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (*query.PrivateKeys, error) { +func (m *MockQueries) ActivePrivateSigningKey(arg0 context.Context, arg1 time.Time) (*query.PrivateKeys, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ActivePrivateSigningKey", ctx, t) + ret := m.ctrl.Call(m, "ActivePrivateSigningKey", arg0, arg1) ret0, _ := ret[0].(*query.PrivateKeys) ret1, _ := ret[1].(error) return ret0, ret1 } // ActivePrivateSigningKey indicates an expected call of ActivePrivateSigningKey. -func (mr *MockQueriesMockRecorder) ActivePrivateSigningKey(ctx, t any) *gomock.Call { +func (mr *MockQueriesMockRecorder) ActivePrivateSigningKey(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivePrivateSigningKey", reflect.TypeOf((*MockQueries)(nil).ActivePrivateSigningKey), ctx, t) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivePrivateSigningKey", reflect.TypeOf((*MockQueries)(nil).ActivePrivateSigningKey), arg0, arg1) } // CustomTextListByTemplate mocks base method. -func (m *MockQueries) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) { +func (m *MockQueries) CustomTextListByTemplate(arg0 context.Context, arg1, arg2 string, arg3 bool) (*query.CustomTexts, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CustomTextListByTemplate", ctx, aggregateID, template, withOwnerRemoved) + ret := m.ctrl.Call(m, "CustomTextListByTemplate", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*query.CustomTexts) ret1, _ := ret[1].(error) return ret0, ret1 } // CustomTextListByTemplate indicates an expected call of CustomTextListByTemplate. -func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(ctx, aggregateID, template, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), ctx, aggregateID, template, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), arg0, arg1, arg2, arg3) } // GetActiveSigningWebKey mocks base method. -func (m *MockQueries) GetActiveSigningWebKey(ctx context.Context) (*jose.JSONWebKey, error) { +func (m *MockQueries) GetActiveSigningWebKey(arg0 context.Context) (*jose.JSONWebKey, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveSigningWebKey", ctx) + ret := m.ctrl.Call(m, "GetActiveSigningWebKey", arg0) ret0, _ := ret[0].(*jose.JSONWebKey) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveSigningWebKey indicates an expected call of GetActiveSigningWebKey. -func (mr *MockQueriesMockRecorder) GetActiveSigningWebKey(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetActiveSigningWebKey(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSigningWebKey", reflect.TypeOf((*MockQueries)(nil).GetActiveSigningWebKey), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSigningWebKey", reflect.TypeOf((*MockQueries)(nil).GetActiveSigningWebKey), arg0) } // GetDefaultLanguage mocks base method. -func (m *MockQueries) GetDefaultLanguage(ctx context.Context) language.Tag { +func (m *MockQueries) GetDefaultLanguage(arg0 context.Context) language.Tag { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDefaultLanguage", ctx) + ret := m.ctrl.Call(m, "GetDefaultLanguage", arg0) ret0, _ := ret[0].(language.Tag) return ret0 } // GetDefaultLanguage indicates an expected call of GetDefaultLanguage. -func (mr *MockQueriesMockRecorder) GetDefaultLanguage(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetDefaultLanguage(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), arg0) } // GetInstanceRestrictions mocks base method. -func (m *MockQueries) GetInstanceRestrictions(ctx context.Context) (query.Restrictions, error) { +func (m *MockQueries) GetInstanceRestrictions(arg0 context.Context) (query.Restrictions, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInstanceRestrictions", ctx) + ret := m.ctrl.Call(m, "GetInstanceRestrictions", arg0) ret0, _ := ret[0].(query.Restrictions) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInstanceRestrictions indicates an expected call of GetInstanceRestrictions. -func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(ctx any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), arg0) } // GetNotifyUserByID mocks base method. -func (m *MockQueries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) { +func (m *MockQueries) GetNotifyUserByID(arg0 context.Context, arg1 bool, arg2 string) (*query.NotifyUser, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetNotifyUserByID", ctx, shouldTriggered, userID) + ret := m.ctrl.Call(m, "GetNotifyUserByID", arg0, arg1, arg2) ret0, _ := ret[0].(*query.NotifyUser) ret1, _ := ret[1].(error) return ret0, ret1 } // GetNotifyUserByID indicates an expected call of GetNotifyUserByID. -func (mr *MockQueriesMockRecorder) GetNotifyUserByID(ctx, shouldTriggered, userID any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetNotifyUserByID(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), ctx, shouldTriggered, userID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), arg0, arg1, arg2) } // InstanceByID mocks base method. -func (m *MockQueries) InstanceByID(ctx context.Context, id string) (authz.Instance, error) { +func (m *MockQueries) InstanceByID(arg0 context.Context, arg1 string) (authz.Instance, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InstanceByID", ctx, id) + ret := m.ctrl.Call(m, "InstanceByID", arg0, arg1) ret0, _ := ret[0].(authz.Instance) ret1, _ := ret[1].(error) return ret0, ret1 } // InstanceByID indicates an expected call of InstanceByID. -func (mr *MockQueriesMockRecorder) InstanceByID(ctx, id any) *gomock.Call { +func (mr *MockQueriesMockRecorder) InstanceByID(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceByID", reflect.TypeOf((*MockQueries)(nil).InstanceByID), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceByID", reflect.TypeOf((*MockQueries)(nil).InstanceByID), arg0, arg1) } // MailTemplateByOrg mocks base method. -func (m *MockQueries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.MailTemplate, error) { +func (m *MockQueries) MailTemplateByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.MailTemplate, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MailTemplateByOrg", ctx, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "MailTemplateByOrg", arg0, arg1, arg2) ret0, _ := ret[0].(*query.MailTemplate) ret1, _ := ret[1].(error) return ret0, ret1 } // MailTemplateByOrg indicates an expected call of MailTemplateByOrg. -func (mr *MockQueriesMockRecorder) MailTemplateByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) MailTemplateByOrg(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), ctx, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), arg0, arg1, arg2) } // NotificationPolicyByOrg mocks base method. -func (m *MockQueries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) { +func (m *MockQueries) NotificationPolicyByOrg(arg0 context.Context, arg1 bool, arg2 string, arg3 bool) (*query.NotificationPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationPolicyByOrg", ctx, shouldTriggerBulk, orgID, withOwnerRemoved) + ret := m.ctrl.Call(m, "NotificationPolicyByOrg", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(*query.NotificationPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationPolicyByOrg indicates an expected call of NotificationPolicyByOrg. -func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(ctx, shouldTriggerBulk, orgID, withOwnerRemoved any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(arg0, arg1, arg2, arg3 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), ctx, shouldTriggerBulk, orgID, withOwnerRemoved) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), arg0, arg1, arg2, arg3) } // NotificationProviderByIDAndType mocks base method. -func (m *MockQueries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { +func (m *MockQueries) NotificationProviderByIDAndType(arg0 context.Context, arg1 string, arg2 domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", ctx, aggID, providerType) + ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", arg0, arg1, arg2) ret0, _ := ret[0].(*query.DebugNotificationProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationProviderByIDAndType indicates an expected call of NotificationProviderByIDAndType. -func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(ctx, aggID, providerType any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), ctx, aggID, providerType) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), arg0, arg1, arg2) } // SMSProviderConfigActive mocks base method. -func (m *MockQueries) SMSProviderConfigActive(ctx context.Context, resourceOwner string) (*query.SMSConfig, error) { +func (m *MockQueries) SMSProviderConfigActive(arg0 context.Context, arg1 string) (*query.SMSConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMSProviderConfigActive", ctx, resourceOwner) + ret := m.ctrl.Call(m, "SMSProviderConfigActive", arg0, arg1) ret0, _ := ret[0].(*query.SMSConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMSProviderConfigActive indicates an expected call of SMSProviderConfigActive. -func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(ctx, resourceOwner any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), ctx, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), arg0, arg1) } // SMTPConfigActive mocks base method. -func (m *MockQueries) SMTPConfigActive(ctx context.Context, resourceOwner string) (*query.SMTPConfig, error) { +func (m *MockQueries) SMTPConfigActive(arg0 context.Context, arg1 string) (*query.SMTPConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMTPConfigActive", ctx, resourceOwner) + ret := m.ctrl.Call(m, "SMTPConfigActive", arg0, arg1) ret0, _ := ret[0].(*query.SMTPConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMTPConfigActive indicates an expected call of SMTPConfigActive. -func (mr *MockQueriesMockRecorder) SMTPConfigActive(ctx, resourceOwner any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMTPConfigActive(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), ctx, resourceOwner) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), arg0, arg1) } // SearchInstanceDomains mocks base method. -func (m *MockQueries) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { +func (m *MockQueries) SearchInstanceDomains(arg0 context.Context, arg1 *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchInstanceDomains", ctx, queries) + ret := m.ctrl.Call(m, "SearchInstanceDomains", arg0, arg1) ret0, _ := ret[0].(*query.InstanceDomains) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchInstanceDomains indicates an expected call of SearchInstanceDomains. -func (mr *MockQueriesMockRecorder) SearchInstanceDomains(ctx, queries any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchInstanceDomains(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), ctx, queries) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), arg0, arg1) } // SearchMilestones mocks base method. -func (m *MockQueries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) { +func (m *MockQueries) SearchMilestones(arg0 context.Context, arg1 []string, arg2 *query.MilestonesSearchQueries) (*query.Milestones, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchMilestones", ctx, instanceIDs, queries) + ret := m.ctrl.Call(m, "SearchMilestones", arg0, arg1, arg2) ret0, _ := ret[0].(*query.Milestones) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchMilestones indicates an expected call of SearchMilestones. -func (mr *MockQueriesMockRecorder) SearchMilestones(ctx, instanceIDs, queries any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchMilestones(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), ctx, instanceIDs, queries) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), arg0, arg1, arg2) } // SessionByID mocks base method. -func (m *MockQueries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) { +func (m *MockQueries) SessionByID(arg0 context.Context, arg1 bool, arg2, arg3 string, arg4 domain.PermissionCheck) (*query.Session, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SessionByID", ctx, shouldTriggerBulk, id, sessionToken) + ret := m.ctrl.Call(m, "SessionByID", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(*query.Session) ret1, _ := ret[1].(error) return ret0, ret1 } // SessionByID indicates an expected call of SessionByID. -func (mr *MockQueriesMockRecorder) SessionByID(ctx, shouldTriggerBulk, id, sessionToken any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SessionByID(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), ctx, shouldTriggerBulk, id, sessionToken) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/notification/handlers/queries.go b/internal/notification/handlers/queries.go index 1c8d37598e..a3d68e4797 100644 --- a/internal/notification/handlers/queries.go +++ b/internal/notification/handlers/queries.go @@ -20,7 +20,7 @@ type Queries interface { GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) - SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) + SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, check domain.PermissionCheck) (*query.Session, error) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) diff --git a/internal/notification/handlers/user_notifier.go b/internal/notification/handlers/user_notifier.go index ec30ab476f..c24b87c2f6 100644 --- a/internal/notification/handlers/user_notifier.go +++ b/internal/notification/handlers/user_notifier.go @@ -400,7 +400,7 @@ func (u *userNotifier) reduceSessionOTPSMSChallenged(event eventstore.Event) (*h if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } @@ -496,7 +496,7 @@ func (u *userNotifier) reduceSessionOTPEmailChallenged(event eventstore.Event) ( if alreadyHandled { return nil } - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return err } diff --git a/internal/notification/handlers/user_notifier_legacy.go b/internal/notification/handlers/user_notifier_legacy.go index 7df31cdf91..4bfa1a796e 100644 --- a/internal/notification/handlers/user_notifier_legacy.go +++ b/internal/notification/handlers/user_notifier_legacy.go @@ -324,7 +324,7 @@ func (u *userNotifierLegacy) reduceSessionOTPSMSChallenged(event eventstore.Even return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } @@ -428,7 +428,7 @@ func (u *userNotifierLegacy) reduceSessionOTPEmailChallenged(event eventstore.Ev return handler.NewNoOpStatement(e), nil } ctx := HandlerContext(event.Aggregate()) - s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "") + s, err := u.queries.SessionByID(ctx, true, e.Aggregate().ID, "", nil) if err != nil { return nil, err } diff --git a/internal/notification/handlers/user_notifier_legacy_test.go b/internal/notification/handlers/user_notifier_legacy_test.go index fe99eaa572..02f21670f5 100644 --- a/internal/notification/handlers/user_notifier_legacy_test.go +++ b/internal/notification/handlers/user_notifier_legacy_test.go @@ -1228,7 +1228,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, "testcode") expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1264,7 +1264,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, "testcode") expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{ Domains: []*query.InstanceDomain{{ Domain: instancePrimaryDomain, @@ -1306,7 +1306,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, testCode) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1350,7 +1350,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { }}, }, nil) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1386,7 +1386,7 @@ func Test_userNotifierLegacy_reduceOTPEmailChallenged(t *testing.T) { } codeAlg, code := cryptoValue(t, ctrl, testCode) expectTemplateWithNotifyUserQueries(queries, givenTemplate) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPEmailSent(gomock.Any(), userID, orgID).Return(nil) return fields{ queries: queries, @@ -1445,7 +1445,7 @@ func Test_userNotifierLegacy_reduceOTPSMSChallenged(t *testing.T) { Content: expectContent, } expectTemplateWithNotifyUserQueriesSMS(queries) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil) return fields{ queries: queries, @@ -1481,7 +1481,7 @@ func Test_userNotifierLegacy_reduceOTPSMSChallenged(t *testing.T) { Content: expectContent, } expectTemplateWithNotifyUserQueriesSMS(queries) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil) + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil) queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{ Domains: []*query.InstanceDomain{{ Domain: instancePrimaryDomain, diff --git a/internal/notification/handlers/user_notifier_test.go b/internal/notification/handlers/user_notifier_test.go index b57edcc57c..b7b7ceb446 100644 --- a/internal/notification/handlers/user_notifier_test.go +++ b/internal/notification/handlers/user_notifier_test.go @@ -980,7 +980,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { name: "url with event trigger", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { _, code := cryptoValue(t, ctrl, "testCode") - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1044,7 +1044,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { IsPrimary: true, }}, }, nil) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1129,7 +1129,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) { name: "url template", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { _, code := cryptoValue(t, ctrl, "testCode") - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1220,7 +1220,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { testCode := "testcode" _, code := cryptoValue(t, ctrl, testCode) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1284,7 +1284,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { IsPrimary: true, }}, }, nil) - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ @@ -1339,7 +1339,7 @@ func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) { { name: "external code", test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) { - queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any()).Return(&query.Session{ + queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), sessionID, gomock.Any(), nil).Return(&query.Session{ ID: sessionID, ResourceOwner: instanceID, UserFactor: query.SessionUserFactor{ diff --git a/internal/query/iam_member.go b/internal/query/iam_member.go index 9f1c5521c9..87b906aa51 100644 --- a/internal/query/iam_member.go +++ b/internal/query/iam_member.go @@ -44,6 +44,10 @@ var ( name: projection.MemberResourceOwner, table: instanceMemberTable, } + InstanceMemberUserResourceOwner = Column{ + name: projection.MemberUserResourceOwner, + table: instanceMemberTable, + } InstanceMemberInstanceID = Column{ name: projection.MemberInstanceID, table: instanceMemberTable, @@ -96,6 +100,7 @@ func prepareInstanceMembersQuery(ctx context.Context, db prepareDatabase) (sq.Se InstanceMemberChangeDate.identifier(), InstanceMemberSequence.identifier(), InstanceMemberResourceOwner.identifier(), + InstanceMemberUserResourceOwner.identifier(), InstanceMemberUserID.identifier(), InstanceMemberRoles.identifier(), LoginNameNameCol.identifier(), @@ -138,6 +143,7 @@ func prepareInstanceMembersQuery(ctx context.Context, db prepareDatabase) (sq.Se &member.ChangeDate, &member.Sequence, &member.ResourceOwner, + &member.UserResourceOwner, &member.UserID, &member.Roles, &preferredLoginName, diff --git a/internal/query/iam_member_test.go b/internal/query/iam_member_test.go index 2ab62d3244..38b9bbc8bc 100644 --- a/internal/query/iam_member_test.go +++ b/internal/query/iam_member_test.go @@ -18,6 +18,7 @@ var ( ", members.change_date" + ", members.sequence" + ", members.resource_owner" + + ", members.user_resource_owner" + ", members.user_id" + ", members.roles" + ", projections.login_names3.login_name" + @@ -45,6 +46,7 @@ var ( "change_date", "sequence", "resource_owner", + "user_resource_owner", "user_id", "roles", "login_name", @@ -97,6 +99,7 @@ func Test_IAMMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -121,6 +124,7 @@ func Test_IAMMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -147,6 +151,7 @@ func Test_IAMMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -171,6 +176,7 @@ func Test_IAMMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", @@ -197,6 +203,7 @@ func Test_IAMMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-1", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -213,6 +220,7 @@ func Test_IAMMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-2", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -237,6 +245,7 @@ func Test_IAMMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-1", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -252,6 +261,7 @@ func Test_IAMMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-2", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", diff --git a/internal/query/instance_features.go b/internal/query/instance_features.go index 4f06577a6d..646404ce6c 100644 --- a/internal/query/instance_features.go +++ b/internal/query/instance_features.go @@ -22,6 +22,7 @@ type InstanceFeatures struct { DisableUserTokenEvent FeatureSource[bool] EnableBackChannelLogout FeatureSource[bool] LoginV2 FeatureSource[*feature.LoginV2] + PermissionCheckV2 FeatureSource[bool] } func (q *Queries) GetInstanceFeatures(ctx context.Context, cascade bool) (_ *InstanceFeatures, err error) { diff --git a/internal/query/instance_features_model.go b/internal/query/instance_features_model.go index c7f273a24a..b9839bf359 100644 --- a/internal/query/instance_features_model.go +++ b/internal/query/instance_features_model.go @@ -75,6 +75,7 @@ func (m *InstanceFeaturesReadModel) Query() *eventstore.SearchQueryBuilder { feature_v2.InstanceDisableUserTokenEvent, feature_v2.InstanceEnableBackChannelLogout, feature_v2.InstanceLoginVersion, + feature_v2.InstancePermissionCheckV2, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -139,6 +140,8 @@ func reduceInstanceFeatureSet[T any](features *InstanceFeatures, event *feature_ features.EnableBackChannelLogout.set(level, event.Value) case feature.KeyLoginV2: features.LoginV2.set(level, event.Value) + case feature.KeyPermissionCheckV2: + features.PermissionCheckV2.set(level, event.Value) } return nil } diff --git a/internal/query/member.go b/internal/query/member.go index 2c4b4db5fe..584ae15d1c 100644 --- a/internal/query/member.go +++ b/internal/query/member.go @@ -47,11 +47,11 @@ type Members struct { } type Member struct { - CreationDate time.Time - ChangeDate time.Time - Sequence uint64 - ResourceOwner string - + CreationDate time.Time + ChangeDate time.Time + Sequence uint64 + ResourceOwner string + UserResourceOwner string UserID string Roles database.TextArray[string] PreferredLoginName string diff --git a/internal/query/org_member.go b/internal/query/org_member.go index ea452fe357..4daa31d341 100644 --- a/internal/query/org_member.go +++ b/internal/query/org_member.go @@ -44,6 +44,10 @@ var ( name: projection.MemberResourceOwner, table: orgMemberTable, } + OrgMemberUserResourceOwner = Column{ + name: projection.MemberUserResourceOwner, + table: orgMemberTable, + } OrgMemberInstanceID = Column{ name: projection.MemberInstanceID, table: orgMemberTable, @@ -99,6 +103,7 @@ func prepareOrgMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectB OrgMemberChangeDate.identifier(), OrgMemberSequence.identifier(), OrgMemberResourceOwner.identifier(), + OrgMemberUserResourceOwner.identifier(), OrgMemberUserID.identifier(), OrgMemberRoles.identifier(), LoginNameNameCol.identifier(), @@ -141,6 +146,7 @@ func prepareOrgMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectB &member.ChangeDate, &member.Sequence, &member.ResourceOwner, + &member.UserResourceOwner, &member.UserID, &member.Roles, &preferredLoginName, diff --git a/internal/query/org_member_test.go b/internal/query/org_member_test.go index d0247c39d3..d42c9b4317 100644 --- a/internal/query/org_member_test.go +++ b/internal/query/org_member_test.go @@ -18,6 +18,7 @@ var ( ", members.change_date" + ", members.sequence" + ", members.resource_owner" + + ", members.user_resource_owner" + ", members.user_id" + ", members.roles" + ", projections.login_names3.login_name" + @@ -49,6 +50,7 @@ var ( "change_date", "sequence", "resource_owner", + "user_resource_owner", "user_id", "roles", "login_name", @@ -101,6 +103,7 @@ func Test_OrgMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -125,6 +128,7 @@ func Test_OrgMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -151,6 +155,7 @@ func Test_OrgMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -175,6 +180,7 @@ func Test_OrgMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", @@ -201,6 +207,7 @@ func Test_OrgMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-1", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -217,6 +224,7 @@ func Test_OrgMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-2", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -241,6 +249,7 @@ func Test_OrgMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-1", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -256,6 +265,7 @@ func Test_OrgMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-2", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", diff --git a/internal/query/permission.go b/internal/query/permission.go new file mode 100644 index 0000000000..96d7db6c6a --- /dev/null +++ b/internal/query/permission.go @@ -0,0 +1,35 @@ +package query + +import ( + "context" + "fmt" + + sq "github.com/Masterminds/squirrel" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" +) + +const ( + // eventstore.permitted_orgs(instanceid text, userid text, perm text) + wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?))" +) + +// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs +// for which the authenticated user has the requested permission for. +// The user ID is taken from the context. +// +// The `orgIDColumn` specifies the table column to which this filter must be applied, +// and is typically the `resource_owner` column in ZITADEL. +// We use full identifiers in the query builder so this function should be +// called with something like `UserResourceOwnerCol.identifier()` for example. +func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, orgIDColumn, permission string) sq.SelectBuilder { + userID := authz.GetCtxData(ctx).UserID + logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used") + return query.Where( + fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn), + authz.GetInstance(ctx).InstanceID(), + userID, + permission, + ) +} diff --git a/internal/query/project_grant_member.go b/internal/query/project_grant_member.go index c13300713f..0820ada826 100644 --- a/internal/query/project_grant_member.go +++ b/internal/query/project_grant_member.go @@ -43,6 +43,10 @@ var ( name: projection.MemberResourceOwner, table: projectGrantMemberTable, } + ProjectGrantMemberUserResourceOwner = Column{ + name: projection.MemberUserResourceOwner, + table: projectGrantMemberTable, + } ProjectGrantMemberInstanceID = Column{ name: projection.MemberInstanceID, table: projectGrantMemberTable, @@ -108,6 +112,7 @@ func prepareProjectGrantMembersQuery(ctx context.Context, db prepareDatabase) (s ProjectGrantMemberChangeDate.identifier(), ProjectGrantMemberSequence.identifier(), ProjectGrantMemberResourceOwner.identifier(), + ProjectGrantMemberUserResourceOwner.identifier(), ProjectGrantMemberUserID.identifier(), ProjectGrantMemberRoles.identifier(), LoginNameNameCol.identifier(), @@ -151,6 +156,7 @@ func prepareProjectGrantMembersQuery(ctx context.Context, db prepareDatabase) (s &member.ChangeDate, &member.Sequence, &member.ResourceOwner, + &member.UserResourceOwner, &member.UserID, &member.Roles, &preferredLoginName, diff --git a/internal/query/project_grant_member_test.go b/internal/query/project_grant_member_test.go index 839a1f2c1b..f55841ff76 100644 --- a/internal/query/project_grant_member_test.go +++ b/internal/query/project_grant_member_test.go @@ -18,6 +18,7 @@ var ( ", members.change_date" + ", members.sequence" + ", members.resource_owner" + + ", members.user_resource_owner" + ", members.user_id" + ", members.roles" + ", projections.login_names3.login_name" + @@ -52,6 +53,7 @@ var ( "change_date", "sequence", "resource_owner", + "user_resource_owner", "user_id", "roles", "login_name", @@ -104,6 +106,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -128,6 +131,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -154,6 +158,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -178,6 +183,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", @@ -204,6 +210,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-1", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -220,6 +227,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-2", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -244,6 +252,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-1", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -259,6 +268,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-2", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", diff --git a/internal/query/project_member.go b/internal/query/project_member.go index a86246bdd7..347eac12b9 100644 --- a/internal/query/project_member.go +++ b/internal/query/project_member.go @@ -44,6 +44,10 @@ var ( name: projection.MemberResourceOwner, table: projectMemberTable, } + ProjectMemberUserResourceOwner = Column{ + name: projection.MemberUserResourceOwner, + table: projectMemberTable, + } ProjectMemberInstanceID = Column{ name: projection.MemberInstanceID, table: projectMemberTable, @@ -99,6 +103,7 @@ func prepareProjectMembersQuery(ctx context.Context, db prepareDatabase) (sq.Sel ProjectMemberChangeDate.identifier(), ProjectMemberSequence.identifier(), ProjectMemberResourceOwner.identifier(), + ProjectMemberUserResourceOwner.identifier(), ProjectMemberUserID.identifier(), ProjectMemberRoles.identifier(), LoginNameNameCol.identifier(), @@ -141,6 +146,7 @@ func prepareProjectMembersQuery(ctx context.Context, db prepareDatabase) (sq.Sel &member.ChangeDate, &member.Sequence, &member.ResourceOwner, + &member.UserResourceOwner, &member.UserID, &member.Roles, &preferredLoginName, diff --git a/internal/query/project_member_test.go b/internal/query/project_member_test.go index 74f35ef6ee..21be454f43 100644 --- a/internal/query/project_member_test.go +++ b/internal/query/project_member_test.go @@ -18,6 +18,7 @@ var ( ", members.change_date" + ", members.sequence" + ", members.resource_owner" + + ", members.user_resource_owner" + ", members.user_id" + ", members.roles" + ", projections.login_names3.login_name" + @@ -49,6 +50,7 @@ var ( "change_date", "sequence", "resource_owner", + "user_resource_owner", "user_id", "roles", "login_name", @@ -101,6 +103,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -125,6 +128,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -151,6 +155,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -175,6 +180,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", @@ -201,6 +207,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-1", database.TextArray[string]{"role-1", "role-2"}, "gigi@caos-ag.zitadel.ch", @@ -217,6 +224,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { testNow, uint64(20211206), "ro", + "uro", "user-id-2", database.TextArray[string]{"role-1", "role-2"}, "machine@caos-ag.zitadel.ch", @@ -241,6 +249,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-1", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "gigi@caos-ag.zitadel.ch", @@ -256,6 +265,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { ChangeDate: testNow, Sequence: 20211206, ResourceOwner: "ro", + UserResourceOwner: "uro", UserID: "user-id-2", Roles: database.TextArray[string]{"role-1", "role-2"}, PreferredLoginName: "machine@caos-ag.zitadel.ch", diff --git a/internal/query/projection/eventstore_field.go b/internal/query/projection/eventstore_field.go index 59dde7507d..5dbdad717a 100644 --- a/internal/query/projection/eventstore_field.go +++ b/internal/query/projection/eventstore_field.go @@ -12,6 +12,7 @@ const ( fieldsProjectGrant = "project_grant_fields" fieldsOrgDomainVerified = "org_domain_verified_fields" fieldsInstanceDomain = "instance_domain_fields" + fieldsMemberships = "membership_fields" ) func newFillProjectGrantFields(config handler.Config) *handler.FieldHandler { @@ -52,3 +53,33 @@ func newFillInstanceDomainFields(config handler.Config) *handler.FieldHandler { }, ) } + +func newFillMembershipFields(config handler.Config) *handler.FieldHandler { + return handler.NewFieldHandler( + &config, + fieldsMemberships, + map[eventstore.AggregateType][]eventstore.EventType{ + instance.AggregateType: { + instance.MemberAddedEventType, + instance.MemberChangedEventType, + instance.MemberRemovedEventType, + instance.MemberCascadeRemovedEventType, + instance.InstanceRemovedEventType, + }, + org.AggregateType: { + org.MemberAddedEventType, + org.MemberChangedEventType, + org.MemberRemovedEventType, + org.MemberCascadeRemovedEventType, + org.OrgRemovedEventType, + }, + project.AggregateType: { + project.MemberAddedEventType, + project.MemberChangedEventType, + project.MemberRemovedEventType, + project.MemberCascadeRemovedEventType, + project.ProjectRemovedType, + }, + }, + ) +} diff --git a/internal/query/projection/instance_features.go b/internal/query/projection/instance_features.go index 2479203d09..2cd846bf2e 100644 --- a/internal/query/projection/instance_features.go +++ b/internal/query/projection/instance_features.go @@ -112,6 +112,10 @@ func (*instanceFeatureProjection) Reducers() []handler.AggregateReducer { Event: feature_v2.InstanceLoginVersion, Reduce: reduceInstanceSetFeature[*feature.LoginV2], }, + { + Event: feature_v2.InstancePermissionCheckV2, + Reduce: reduceInstanceSetFeature[bool], + }, { Event: instance.InstanceRemovedEventType, Reduce: reduceInstanceRemovedHelper(InstanceDomainInstanceIDCol), diff --git a/internal/query/projection/project_grant.go b/internal/query/projection/project_grant.go index d6fbde8556..d5a075c486 100644 --- a/internal/query/projection/project_grant.go +++ b/internal/query/projection/project_grant.go @@ -93,6 +93,10 @@ func (p *projectGrantProjection) Reducers() []handler.AggregateReducer { Event: project.ProjectRemovedType, Reduce: p.reduceProjectRemoved, }, + { + Event: project.ProjectOwnerCorrected, + Reduce: p.reduceOwnerCorrected, + }, }, }, { @@ -269,3 +273,17 @@ func (p *projectGrantProjection) reduceOwnerRemoved(event eventstore.Event) (*ha ), ), nil } + +func (p *projectGrantProjection) reduceOwnerCorrected(event eventstore.Event) (*handler.Statement, error) { + return handler.NewUpdateStatement( + event, + []handler.Column{ + handler.NewCol(ProjectGrantColumnResourceOwner, event.Aggregate().ResourceOwner), + }, + []handler.Condition{ + handler.NewCond(ProjectGrantColumnInstanceID, event.Aggregate().InstanceID), + handler.NewCond(ProjectGrantColumnProjectID, event.Aggregate().ID), + handler.NewUnequalCond(ProjectGrantColumnResourceOwner, event.Aggregate().ResourceOwner), + }, + ), nil +} diff --git a/internal/query/projection/project_member.go b/internal/query/projection/project_member.go index 822e2e8d7e..8f03192019 100644 --- a/internal/query/projection/project_member.go +++ b/internal/query/projection/project_member.go @@ -60,19 +60,19 @@ func (p *projectMemberProjection) Reducers() []handler.AggregateReducer { Aggregate: project.AggregateType, EventReducers: []handler.EventReducer{ { - Event: project.MemberAddedType, + Event: project.MemberAddedEventType, Reduce: p.reduceAdded, }, { - Event: project.MemberChangedType, + Event: project.MemberChangedEventType, Reduce: p.reduceChanged, }, { - Event: project.MemberCascadeRemovedType, + Event: project.MemberCascadeRemovedEventType, Reduce: p.reduceCascadeRemoved, }, { - Event: project.MemberRemovedType, + Event: project.MemberRemovedEventType, Reduce: p.reduceRemoved, }, { @@ -114,7 +114,7 @@ func (p *projectMemberProjection) Reducers() []handler.AggregateReducer { func (p *projectMemberProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*project.MemberAddedEvent) if !ok { - return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-bgx5Q", "reduce.wrong.event.type %s", project.MemberAddedType) + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-bgx5Q", "reduce.wrong.event.type %s", project.MemberAddedEventType) } ctx := setMemberContext(e.Aggregate()) userOwner, err := getUserResourceOwner(ctx, p.es, e.Aggregate().InstanceID, e.UserID) @@ -131,7 +131,7 @@ func (p *projectMemberProjection) reduceAdded(event eventstore.Event) (*handler. func (p *projectMemberProjection) reduceChanged(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*project.MemberChangedEvent) if !ok { - return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-90WJ1", "reduce.wrong.event.type %s", project.MemberChangedType) + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-90WJ1", "reduce.wrong.event.type %s", project.MemberChangedEventType) } return reduceMemberChanged( *member.NewMemberChangedEvent(&e.BaseEvent, e.UserID, e.Roles...), @@ -142,7 +142,7 @@ func (p *projectMemberProjection) reduceChanged(event eventstore.Event) (*handle func (p *projectMemberProjection) reduceCascadeRemoved(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*project.MemberCascadeRemovedEvent) if !ok { - return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-aGd43", "reduce.wrong.event.type %s", project.MemberCascadeRemovedType) + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-aGd43", "reduce.wrong.event.type %s", project.MemberCascadeRemovedEventType) } return reduceMemberCascadeRemoved( *member.NewCascadeRemovedEvent(&e.BaseEvent, e.UserID), @@ -153,7 +153,7 @@ func (p *projectMemberProjection) reduceCascadeRemoved(event eventstore.Event) ( func (p *projectMemberProjection) reduceRemoved(event eventstore.Event) (*handler.Statement, error) { e, ok := event.(*project.MemberRemovedEvent) if !ok { - return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-eJZPh", "reduce.wrong.event.type %s", project.MemberRemovedType) + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-eJZPh", "reduce.wrong.event.type %s", project.MemberRemovedEventType) } return reduceMemberRemoved(e, withMemberCond(MemberUserIDCol, e.UserID), diff --git a/internal/query/projection/project_member_test.go b/internal/query/projection/project_member_test.go index bd7e1049cf..c33a319524 100644 --- a/internal/query/projection/project_member_test.go +++ b/internal/query/projection/project_member_test.go @@ -32,7 +32,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberAddedType, + project.MemberAddedEventType, project.AggregateType, []byte(`{ "userId": "user-id", @@ -56,7 +56,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberAddedType, + project.MemberAddedEventType, project.AggregateType, []byte(`{ "userId": "user-id", @@ -110,7 +110,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberAddedType, + project.MemberAddedEventType, project.AggregateType, []byte(`{ "userId": "user-id", @@ -176,7 +176,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberChangedType, + project.MemberChangedEventType, project.AggregateType, []byte(`{ "userId": "user-id", @@ -210,7 +210,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberCascadeRemovedType, + project.MemberCascadeRemovedEventType, project.AggregateType, []byte(`{ "userId": "user-id" @@ -240,7 +240,7 @@ func TestProjectMemberProjection_reduces(t *testing.T) { args: args{ event: getEvent( testEvent( - project.MemberRemovedType, + project.MemberRemovedEventType, project.AggregateType, []byte(`{ "userId": "user-id" diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index ebe7454b58..d6647d0961 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -85,6 +85,7 @@ var ( ProjectGrantFields *handler.FieldHandler OrgDomainVerifiedFields *handler.FieldHandler InstanceDomainFields *handler.FieldHandler + MembershipFields *handler.FieldHandler ) type projection interface { @@ -174,6 +175,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, ProjectGrantFields = newFillProjectGrantFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsProjectGrant])) OrgDomainVerifiedFields = newFillOrgDomainVerifiedFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsOrgDomainVerified])) InstanceDomainFields = newFillInstanceDomainFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsInstanceDomain])) + MembershipFields = newFillMembershipFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsMemberships])) newProjectionsList() return nil diff --git a/internal/query/projection/system_features.go b/internal/query/projection/system_features.go index 410234c27c..f6f0a36d56 100644 --- a/internal/query/projection/system_features.go +++ b/internal/query/projection/system_features.go @@ -92,6 +92,10 @@ func (*systemFeatureProjection) Reducers() []handler.AggregateReducer { Event: feature_v2.SystemLoginVersion, Reduce: reduceSystemSetFeature[*feature.LoginV2], }, + { + Event: feature_v2.SystemPermissionCheckV2, + Reduce: reduceSystemSetFeature[bool], + }, }, }} } diff --git a/internal/query/session.go b/internal/query/session.go index 54afbde064..d30fe4cda9 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -6,6 +6,7 @@ import ( "errors" "net" "net/http" + "slices" "time" sq "github.com/Masterminds/squirrel" @@ -80,6 +81,39 @@ type SessionsSearchQueries struct { Queries []SearchQuery } +func sessionsCheckPermission(ctx context.Context, sessions *Sessions, permissionCheck domain.PermissionCheck) { + sessions.Sessions = slices.DeleteFunc(sessions.Sessions, + func(session *Session) bool { + return sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, session.UserAgent, session.UserFactor, permissionCheck) != nil + }, + ) +} + +func sessionCheckPermission(ctx context.Context, resourceOwner string, creator string, useragent domain.UserAgent, userFactor SessionUserFactor, permissionCheck domain.PermissionCheck) error { + data := authz.GetCtxData(ctx) + // no permission check necessary if user is creator + if data.UserID == creator { + return nil + } + // no permission check necessary if session belongs to the user + if userFactor.UserID != "" && data.UserID == userFactor.UserID { + return nil + } + // no permission check necessary if session belongs to the same useragent as used + if data.AgentID != "" && useragent.FingerprintID != nil && *useragent.FingerprintID != "" && data.AgentID == *useragent.FingerprintID { + return nil + } + // if session belongs to a user, check for permission on the user resource + if userFactor.ResourceOwner != "" { + if err := permissionCheck(ctx, domain.PermissionSessionRead, userFactor.ResourceOwner, userFactor.UserID); err != nil { + return err + } + return nil + } + // default, check for permission on instance + return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "") +} + func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { query = q.SearchRequest.toQuery(query) for _, q := range q.Queries { @@ -195,7 +229,24 @@ var ( } ) -func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (session *Session, err error) { +func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string, permissionCheck domain.PermissionCheck) (session *Session, err error) { + session, tokenID, err := q.sessionByID(ctx, shouldTriggerBulk, id) + if err != nil { + return nil, err + } + if sessionToken == "" { + if err := sessionCheckPermission(ctx, session.ResourceOwner, session.Creator, session.UserAgent, session.UserFactor, permissionCheck); err != nil { + return nil, err + } + return session, nil + } + if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { + return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") + } + return session, nil +} + +func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id string) (session *Session, tokenID string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -214,27 +265,31 @@ func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, s }, ).ToSql() if err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") + return nil, "", zerrors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - var tokenID string err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { session, tokenID, err = scan(row) return err }, stmt, args...) if err != nil { - return nil, err + return nil, "", err } - if sessionToken == "" { - return session, nil - } - if err := q.sessionTokenVerifier(ctx, sessionToken, session.ID, tokenID); err != nil { - return nil, zerrors.ThrowPermissionDenied(nil, "QUERY-dsfr3", "Errors.PermissionDenied") - } - return session, nil + return session, tokenID, nil } -func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { +func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) { + sessions, err := q.searchSessions(ctx, queries) + if err != nil { + return nil, err + } + if permissionCheck != nil { + sessionsCheckPermission(ctx, sessions, permissionCheck) + } + return sessions, nil +} + +func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -272,6 +327,10 @@ func NewSessionCreatorSearchQuery(creator string) (SearchQuery, error) { return NewTextQuery(SessionColumnCreator, creator, TextEquals) } +func NewSessionUserAgentFingerprintIDSearchQuery(fingerprintID string) (SearchQuery, error) { + return NewTextQuery(SessionColumnUserAgentFingerprintID, fingerprintID, TextEquals) +} + func NewUserIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(SessionColumnUserID, id, TextEquals) } @@ -415,6 +474,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui SessionColumnOTPSMSCheckedAt.identifier(), SessionColumnOTPEmailCheckedAt.identifier(), SessionColumnMetadata.identifier(), + SessionColumnUserAgentFingerprintID.identifier(), + SessionColumnUserAgentIP.identifier(), + SessionColumnUserAgentDescription.identifier(), + SessionColumnUserAgentHeader.identifier(), SessionColumnExpiration.identifier(), countColumn.identifier(), ).From(sessionsTable.identifier()). @@ -441,6 +504,8 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui otpSMSCheckedAt sql.NullTime otpEmailCheckedAt sql.NullTime metadata database.Map[[]byte] + userAgentIP sql.NullString + userAgentHeader database.Map[[]string] expiration sql.NullTime ) @@ -465,6 +530,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui &otpSMSCheckedAt, &otpEmailCheckedAt, &metadata, + &session.UserAgent.FingerprintID, + &userAgentIP, + &session.UserAgent.Description, + &userAgentHeader, &expiration, &sessions.Count, ) @@ -485,6 +554,10 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui session.OTPSMSFactor.OTPCheckedAt = otpSMSCheckedAt.Time session.OTPEmailFactor.OTPCheckedAt = otpEmailCheckedAt.Time session.Metadata = metadata + session.UserAgent.Header = http.Header(userAgentHeader) + if userAgentIP.Valid { + session.UserAgent.IP = net.ParseIP(userAgentIP.String) + } session.Expiration = expiration.Time sessions.Sessions = append(sessions.Sessions, session) diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index c7929a98a8..4109969262 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -15,6 +15,7 @@ import ( "github.com/muhlemmer/gu" "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -71,6 +72,10 @@ var ( ` projections.sessions8.otp_sms_checked_at,` + ` projections.sessions8.otp_email_checked_at,` + ` projections.sessions8.metadata,` + + ` projections.sessions8.user_agent_fingerprint_id,` + + ` projections.sessions8.user_agent_ip,` + + ` projections.sessions8.user_agent_description,` + + ` projections.sessions8.user_agent_header,` + ` projections.sessions8.expiration,` + ` COUNT(*) OVER ()` + ` FROM projections.sessions8` + @@ -129,6 +134,10 @@ var ( "otp_sms_checked_at", "otp_email_checked_at", "metadata", + "user_agent_fingerprint_id", + "user_agent_ip", + "user_agent_description", + "user_agent_header", "expiration", "count", } @@ -186,6 +195,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, }, @@ -233,6 +246,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, }, @@ -267,6 +286,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, { @@ -290,6 +313,10 @@ func Test_SessionsPrepare(t *testing.T) { testNow, testNow, []byte(`{"key": "dmFsdWU="}`), + "fingerPrintID", + "1.2.3.4", + "agentDescription", + []byte(`{"foo":["foo","bar"]}`), testNow, }, }, @@ -337,6 +364,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, { @@ -376,6 +409,12 @@ func Test_SessionsPrepare(t *testing.T) { Metadata: map[string][]byte{ "key": []byte("value"), }, + UserAgent: domain.UserAgent{ + FingerprintID: gu.Ptr("fingerPrintID"), + IP: net.IPv4(1, 2, 3, 4), + Description: gu.Ptr("agentDescription"), + Header: http.Header{"foo": []string{"foo", "bar"}}, + }, Expiration: testNow, }, }, @@ -553,3 +592,157 @@ func prepareSessionQueryTesting(t *testing.T, token string) func(context.Context } } } + +func Test_sessionCheckPermission(t *testing.T) { + type args struct { + ctx context.Context + resourceOwner string + creator string + useragent domain.UserAgent + userFactor SessionUserFactor + permissionCheck domain.PermissionCheck + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "permission check, no user in context", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "", ""), + resourceOwner: "instance", + creator: "creator", + permissionCheck: expectedFailedPermissionCheck("instance", ""), + }, + wantErr: true, + }, + { + name: "permission check, factor, no user in context", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "", ""), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{ResourceOwner: "resourceowner", UserID: "user"}, + permissionCheck: expectedFailedPermissionCheck("resourceowner", "user"), + }, + wantErr: true, + }, + { + name: "no permission check, creator", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "user", + }, + wantErr: false, + }, + { + name: "no permission check, same user", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{UserID: "user"}, + }, + wantErr: false, + }, + { + name: "no permission check, same useragent", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user1", "agent"), + resourceOwner: "instance", + creator: "creator", + userFactor: SessionUserFactor{UserID: "user2"}, + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("agent"), + }, + }, + wantErr: false, + }, + { + name: "permission check, factor", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{UserID: "user2", ResourceOwner: "resourceowner2"}, + permissionCheck: expectedSuccessfulPermissionCheck("resourceowner2", "user2"), + }, + wantErr: false, + }, + { + name: "permission check, factor, error", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{UserID: "user2", ResourceOwner: "resourceowner2"}, + permissionCheck: expectedFailedPermissionCheck("resourceowner2", "user2"), + }, + wantErr: true, + }, + { + name: "permission check", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{}, + permissionCheck: expectedSuccessfulPermissionCheck("instance", ""), + }, + wantErr: false, + }, + { + name: "permission check, error", + args: args{ + ctx: authz.NewMockContextWithAgent("instance", "org", "user", "agent"), + resourceOwner: "instance", + creator: "not-user", + useragent: domain.UserAgent{ + FingerprintID: gu.Ptr("not-agent"), + }, + userFactor: SessionUserFactor{}, + permissionCheck: expectedFailedPermissionCheck("instance", ""), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := sessionCheckPermission(tt.args.ctx, tt.args.resourceOwner, tt.args.creator, tt.args.useragent, tt.args.userFactor, tt.args.permissionCheck) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } +} + +func expectedSuccessfulPermissionCheck(resourceOwner, userID string) func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return func(ctx context.Context, permission, orgID, resourceID string) (err error) { + if orgID == resourceOwner && resourceID == userID { + return nil + } + return fmt.Errorf("permission check failed: %s %s", orgID, resourceID) + } +} + +func expectedFailedPermissionCheck(resourceOwner, userID string) func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return func(ctx context.Context, permission, orgID, resourceID string) (err error) { + if orgID == resourceOwner && resourceID == userID { + return fmt.Errorf("permission check failed: %s %s", orgID, resourceID) + } + return nil + } +} diff --git a/internal/query/system_features.go b/internal/query/system_features.go index e696f6bf6f..31ad402d12 100644 --- a/internal/query/system_features.go +++ b/internal/query/system_features.go @@ -31,6 +31,7 @@ type SystemFeatures struct { DisableUserTokenEvent FeatureSource[bool] EnableBackChannelLogout FeatureSource[bool] LoginV2 FeatureSource[*feature.LoginV2] + PermissionCheckV2 FeatureSource[bool] } func (q *Queries) GetSystemFeatures(ctx context.Context) (_ *SystemFeatures, err error) { diff --git a/internal/query/system_features_model.go b/internal/query/system_features_model.go index f486e1ba4a..217154e3ed 100644 --- a/internal/query/system_features_model.go +++ b/internal/query/system_features_model.go @@ -66,6 +66,7 @@ func (m *SystemFeaturesReadModel) Query() *eventstore.SearchQueryBuilder { feature_v2.SystemDisableUserTokenEvent, feature_v2.SystemEnableBackChannelLogout, feature_v2.SystemLoginVersion, + feature_v2.SystemPermissionCheckV2, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -105,6 +106,8 @@ func reduceSystemFeatureSet[T any](features *SystemFeatures, event *feature_v2.S features.EnableBackChannelLogout.set(level, event.Value) case feature.KeyLoginV2: features.LoginV2.set(level, event.Value) + case feature.KeyPermissionCheckV2: + features.PermissionCheckV2.set(level, event.Value) } return nil } diff --git a/internal/query/user.go b/internal/query/user.go index 415e50aae5..9f29ec77b3 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -605,24 +605,29 @@ func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, queri } func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) { - users, err := q.searchUsers(ctx, queries) + users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2) if err != nil { return nil, err } - if permissionCheck != nil { + if permissionCheck != nil && !authz.GetFeatures(ctx).PermissionCheckV2 { usersCheckPermission(ctx, users, permissionCheck) } return users, nil } -func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries) (users *Users, err error) { +func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareUsersQuery(ctx, q.client) - eq := sq.Eq{UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()} - stmt, args, err := queries.toQuery(query).Where(eq). - ToSql() + query = queries.toQuery(query).Where(sq.Eq{ + UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), + }) + if permissionCheckV2 { + query = wherePermittedOrgs(ctx, query, UserResourceOwnerCol.identifier(), domain.PermissionUserRead) + } + + stmt, args, err := query.ToSql() if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment") } diff --git a/internal/repository/feature/feature_v2/eventstore.go b/internal/repository/feature/feature_v2/eventstore.go index d4d2617aea..f5e033af1c 100644 --- a/internal/repository/feature/feature_v2/eventstore.go +++ b/internal/repository/feature/feature_v2/eventstore.go @@ -18,6 +18,7 @@ func init() { eventstore.RegisterFilterEventMapper(AggregateType, SystemDisableUserTokenEvent, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, SystemEnableBackChannelLogout, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, SystemLoginVersion, eventstore.GenericEventMapper[SetEvent[*feature.LoginV2]]) + eventstore.RegisterFilterEventMapper(AggregateType, SystemPermissionCheckV2, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceResetEventType, eventstore.GenericEventMapper[ResetEvent]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceLoginDefaultOrgEventType, eventstore.GenericEventMapper[SetEvent[bool]]) @@ -33,4 +34,5 @@ func init() { eventstore.RegisterFilterEventMapper(AggregateType, InstanceDisableUserTokenEvent, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceEnableBackChannelLogout, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceLoginVersion, eventstore.GenericEventMapper[SetEvent[*feature.LoginV2]]) + eventstore.RegisterFilterEventMapper(AggregateType, InstancePermissionCheckV2, eventstore.GenericEventMapper[SetEvent[bool]]) } diff --git a/internal/repository/feature/feature_v2/feature.go b/internal/repository/feature/feature_v2/feature.go index 0255203bdd..331a5143f9 100644 --- a/internal/repository/feature/feature_v2/feature.go +++ b/internal/repository/feature/feature_v2/feature.go @@ -23,6 +23,7 @@ var ( SystemDisableUserTokenEvent = setEventTypeFromFeature(feature.LevelSystem, feature.KeyDisableUserTokenEvent) SystemEnableBackChannelLogout = setEventTypeFromFeature(feature.LevelSystem, feature.KeyEnableBackChannelLogout) SystemLoginVersion = setEventTypeFromFeature(feature.LevelSystem, feature.KeyLoginV2) + SystemPermissionCheckV2 = setEventTypeFromFeature(feature.LevelSystem, feature.KeyPermissionCheckV2) InstanceResetEventType = resetEventTypeFromFeature(feature.LevelInstance) InstanceLoginDefaultOrgEventType = setEventTypeFromFeature(feature.LevelInstance, feature.KeyLoginDefaultOrg) @@ -38,6 +39,7 @@ var ( InstanceDisableUserTokenEvent = setEventTypeFromFeature(feature.LevelInstance, feature.KeyDisableUserTokenEvent) InstanceEnableBackChannelLogout = setEventTypeFromFeature(feature.LevelInstance, feature.KeyEnableBackChannelLogout) InstanceLoginVersion = setEventTypeFromFeature(feature.LevelInstance, feature.KeyLoginV2) + InstancePermissionCheckV2 = setEventTypeFromFeature(feature.LevelInstance, feature.KeyPermissionCheckV2) ) const ( diff --git a/internal/repository/instance/member.go b/internal/repository/instance/member.go index 0518aab47f..161bdcdaec 100644 --- a/internal/repository/instance/member.go +++ b/internal/repository/instance/member.go @@ -7,17 +7,25 @@ import ( "github.com/zitadel/zitadel/internal/repository/member" ) -var ( +const ( MemberAddedEventType = instanceEventTypePrefix + member.AddedEventType MemberChangedEventType = instanceEventTypePrefix + member.ChangedEventType MemberRemovedEventType = instanceEventTypePrefix + member.RemovedEventType MemberCascadeRemovedEventType = instanceEventTypePrefix + member.CascadeRemovedEventType ) +const ( + fieldPrefix = "instance" +) + type MemberAddedEvent struct { member.MemberAddedEvent } +func (e *MemberAddedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberAddedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -51,6 +59,10 @@ type MemberChangedEvent struct { member.MemberChangedEvent } +func (e *MemberChangedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberChangedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -83,6 +95,10 @@ type MemberRemovedEvent struct { member.MemberRemovedEvent } +func (e *MemberRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -113,6 +129,10 @@ type MemberCascadeRemovedEvent struct { member.MemberCascadeRemovedEvent } +func (e *MemberCascadeRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberCascadeRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, diff --git a/internal/repository/member/events.go b/internal/repository/member/events.go index 0c98b46a41..5d0a28c243 100644 --- a/internal/repository/member/events.go +++ b/internal/repository/member/events.go @@ -7,6 +7,7 @@ import ( "github.com/zitadel/zitadel/internal/zerrors" ) +// Event types const ( UniqueMember = "member" AddedEventType = "member.added" @@ -15,6 +16,13 @@ const ( CascadeRemovedEventType = "member.cascade.removed" ) +// Field table and unique types +const ( + memberRoleTypeSuffix string = "_member_role" + MemberRoleRevision uint8 = 1 + roleSearchFieldSuffix string = "_role" +) + func NewAddMemberUniqueConstraint(aggregateID, userID string) *eventstore.UniqueConstraint { return eventstore.NewAddEventUniqueConstraint( UniqueMember, @@ -44,6 +52,32 @@ func (e *MemberAddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { return []*eventstore.UniqueConstraint{NewAddMemberUniqueConstraint(e.Aggregate().ID, e.UserID)} } +func (e *MemberAddedEvent) FieldOperations(prefix string) []*eventstore.FieldOperation { + ops := make([]*eventstore.FieldOperation, len(e.Roles)) + for i, role := range e.Roles { + ops[i] = eventstore.SetField( + e.Aggregate(), + memberSearchObject(prefix, e.UserID), + prefix+roleSearchFieldSuffix, + &eventstore.Value{ + Value: role, + MustBeUnique: false, + ShouldIndex: true, + }, + + eventstore.FieldTypeInstanceID, + eventstore.FieldTypeResourceOwner, + eventstore.FieldTypeAggregateType, + eventstore.FieldTypeAggregateID, + eventstore.FieldTypeObjectType, + eventstore.FieldTypeObjectID, + eventstore.FieldTypeFieldName, + eventstore.FieldTypeValue, + ) + } + return ops +} + func NewMemberAddedEvent( base *eventstore.BaseEvent, userID string, @@ -85,6 +119,38 @@ func (e *MemberChangedEvent) UniqueConstraints() []*eventstore.UniqueConstraint return nil } +// FieldOperations removes the existing membership role fields first and sets the new roles after. +func (e *MemberChangedEvent) FieldOperations(prefix string) []*eventstore.FieldOperation { + ops := make([]*eventstore.FieldOperation, len(e.Roles)+1) + ops[0] = eventstore.RemoveSearchFieldsByAggregateAndObject( + e.Aggregate(), + memberSearchObject(prefix, e.UserID), + ) + + for i, role := range e.Roles { + ops[i+1] = eventstore.SetField( + e.Aggregate(), + memberSearchObject(prefix, e.UserID), + prefix+roleSearchFieldSuffix, + &eventstore.Value{ + Value: role, + MustBeUnique: false, + ShouldIndex: true, + }, + + eventstore.FieldTypeInstanceID, + eventstore.FieldTypeResourceOwner, + eventstore.FieldTypeAggregateType, + eventstore.FieldTypeAggregateID, + eventstore.FieldTypeObjectType, + eventstore.FieldTypeObjectID, + eventstore.FieldTypeFieldName, + eventstore.FieldTypeValue, + ) + } + return ops +} + func NewMemberChangedEvent( base *eventstore.BaseEvent, userID string, @@ -124,6 +190,15 @@ func (e *MemberRemovedEvent) UniqueConstraints() []*eventstore.UniqueConstraint return []*eventstore.UniqueConstraint{NewRemoveMemberUniqueConstraint(e.Aggregate().ID, e.UserID)} } +func (e *MemberRemovedEvent) FieldOperations(prefix string) []*eventstore.FieldOperation { + return []*eventstore.FieldOperation{ + eventstore.RemoveSearchFieldsByAggregateAndObject( + e.Aggregate(), + memberSearchObject(prefix, e.UserID), + ), + } +} + func NewRemovedEvent( base *eventstore.BaseEvent, userID string, @@ -162,6 +237,15 @@ func (e *MemberCascadeRemovedEvent) UniqueConstraints() []*eventstore.UniqueCons return []*eventstore.UniqueConstraint{NewRemoveMemberUniqueConstraint(e.Aggregate().ID, e.UserID)} } +func (e *MemberCascadeRemovedEvent) FieldOperations(prefix string) []*eventstore.FieldOperation { + return []*eventstore.FieldOperation{ + eventstore.RemoveSearchFieldsByAggregateAndObject( + e.Aggregate(), + memberSearchObject(prefix, e.UserID), + ), + } +} + func NewCascadeRemovedEvent( base *eventstore.BaseEvent, userID string, @@ -185,3 +269,11 @@ func CascadeRemovedEventMapper(event eventstore.Event) (eventstore.Event, error) return e, nil } + +func memberSearchObject(prefix, userID string) eventstore.Object { + return eventstore.Object{ + Type: prefix + memberRoleTypeSuffix, + ID: userID, + Revision: MemberRoleRevision, + } +} diff --git a/internal/repository/org/member.go b/internal/repository/org/member.go index 81a4d5850f..5068a274b8 100644 --- a/internal/repository/org/member.go +++ b/internal/repository/org/member.go @@ -7,17 +7,25 @@ import ( "github.com/zitadel/zitadel/internal/repository/member" ) -var ( +const ( MemberAddedEventType = orgEventTypePrefix + member.AddedEventType MemberChangedEventType = orgEventTypePrefix + member.ChangedEventType MemberRemovedEventType = orgEventTypePrefix + member.RemovedEventType MemberCascadeRemovedEventType = orgEventTypePrefix + member.CascadeRemovedEventType ) +const ( + fieldPrefix = "org" +) + type MemberAddedEvent struct { member.MemberAddedEvent } +func (e *MemberAddedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberAddedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -50,6 +58,10 @@ type MemberChangedEvent struct { member.MemberChangedEvent } +func (e *MemberChangedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberChangedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -83,6 +95,10 @@ type MemberRemovedEvent struct { member.MemberRemovedEvent } +func (e *MemberRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -113,6 +129,10 @@ type MemberCascadeRemovedEvent struct { member.MemberCascadeRemovedEvent } +func (e *MemberCascadeRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewMemberCascadeRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, diff --git a/internal/repository/owner/owner_corrected.go b/internal/repository/owner/owner_corrected.go new file mode 100644 index 0000000000..29bb4842d4 --- /dev/null +++ b/internal/repository/owner/owner_corrected.go @@ -0,0 +1,40 @@ +package owner + +import ( + "context" + + "github.com/zitadel/zitadel/internal/eventstore" +) + +const OwnerCorrectedType = ".owner.corrected" + +type Corrected struct { + eventstore.BaseEvent `json:"-"` + + PreviousOwners map[uint32]string `json:"previousOwners,omitempty"` +} + +var _ eventstore.Command = (*Corrected)(nil) + +func (e *Corrected) Payload() interface{} { + return e +} + +func (e *Corrected) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func NewCorrected( + ctx context.Context, + aggregate *eventstore.Aggregate, + previousOwners map[uint32]string, +) *Corrected { + return &Corrected{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + eventstore.EventType(aggregate.Type+OwnerCorrectedType), + ), + PreviousOwners: previousOwners, + } +} diff --git a/internal/repository/permission/aggregate.go b/internal/repository/permission/aggregate.go new file mode 100644 index 0000000000..a0ac199102 --- /dev/null +++ b/internal/repository/permission/aggregate.go @@ -0,0 +1,22 @@ +package permission + +import "github.com/zitadel/zitadel/internal/eventstore" + +const ( + AggregateType eventstore.AggregateType = "permission" + AggregateVersion eventstore.Version = "v1" +) + +func NewAggregate(aggregateID string) *eventstore.Aggregate { + var instanceID string + if aggregateID != "SYSTEM" { + instanceID = aggregateID + } + return &eventstore.Aggregate{ + ID: aggregateID, + Type: AggregateType, + ResourceOwner: aggregateID, + InstanceID: instanceID, + Version: AggregateVersion, + } +} diff --git a/internal/repository/permission/permission.go b/internal/repository/permission/permission.go new file mode 100644 index 0000000000..a02a4dca0a --- /dev/null +++ b/internal/repository/permission/permission.go @@ -0,0 +1,114 @@ +package permission + +import ( + "context" + + "github.com/zitadel/zitadel/internal/eventstore" +) + +// Event types +const ( + permissionEventPrefix eventstore.EventType = "permission." + AddedType = permissionEventPrefix + "added" + RemovedType = permissionEventPrefix + "removed" +) + +// Field table and unique types +const ( + RolePermissionType string = "role_permission" + RolePermissionRevision uint8 = 1 + PermissionSearchField string = "permission" +) + +type AddedEvent struct { + *eventstore.BaseEvent `json:"-"` + Role string `json:"role"` + Permission string `json:"permission"` +} + +func (e *AddedEvent) Payload() interface{} { + return e +} + +func (e *AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} + +func (e *AddedEvent) Fields() []*eventstore.FieldOperation { + return []*eventstore.FieldOperation{ + eventstore.SetField( + e.Aggregate(), + roleSearchObject(e.Role), + PermissionSearchField, + &eventstore.Value{ + Value: e.Permission, + MustBeUnique: false, + ShouldIndex: true, + }, + + eventstore.FieldTypeInstanceID, + eventstore.FieldTypeResourceOwner, + eventstore.FieldTypeAggregateType, + eventstore.FieldTypeAggregateID, + eventstore.FieldTypeObjectType, + eventstore.FieldTypeObjectID, + eventstore.FieldTypeFieldName, + eventstore.FieldTypeValue, + ), + } +} + +func NewAddedEvent(ctx context.Context, aggregate *eventstore.Aggregate, role, permission string) *AddedEvent { + return &AddedEvent{ + BaseEvent: eventstore.NewBaseEventForPush(ctx, aggregate, AddedType), + Role: role, + Permission: permission, + } +} + +type RemovedEvent struct { + *eventstore.BaseEvent `json:"-"` + Role string `json:"role"` + Permission string `json:"permission"` +} + +func (e *RemovedEvent) Payload() interface{} { + return e +} + +func (e *RemovedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func (e *RemovedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} + +func (e *RemovedEvent) Fields() []*eventstore.FieldOperation { + return []*eventstore.FieldOperation{ + eventstore.RemoveSearchFieldsByAggregateAndObject( + e.Aggregate(), + roleSearchObject(e.Role), + ), + } +} + +func NewRemovedEvent(ctx context.Context, aggregate *eventstore.Aggregate, role, permission string) *RemovedEvent { + return &RemovedEvent{ + BaseEvent: eventstore.NewBaseEventForPush(ctx, aggregate, AddedType), + Role: role, + Permission: permission, + } +} + +func roleSearchObject(role string) eventstore.Object { + return eventstore.Object{ + Type: RolePermissionType, + ID: role, + Revision: RolePermissionRevision, + } +} diff --git a/internal/repository/project/eventstore.go b/internal/repository/project/eventstore.go index 5705649739..2648737d3b 100644 --- a/internal/repository/project/eventstore.go +++ b/internal/repository/project/eventstore.go @@ -10,10 +10,10 @@ func init() { eventstore.RegisterFilterEventMapper(AggregateType, ProjectDeactivatedType, ProjectDeactivatedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, ProjectReactivatedType, ProjectReactivatedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, ProjectRemovedType, ProjectRemovedEventMapper) - eventstore.RegisterFilterEventMapper(AggregateType, MemberAddedType, MemberAddedEventMapper) - eventstore.RegisterFilterEventMapper(AggregateType, MemberChangedType, MemberChangedEventMapper) - eventstore.RegisterFilterEventMapper(AggregateType, MemberRemovedType, MemberRemovedEventMapper) - eventstore.RegisterFilterEventMapper(AggregateType, MemberCascadeRemovedType, MemberCascadeRemovedEventMapper) + eventstore.RegisterFilterEventMapper(AggregateType, MemberAddedEventType, MemberAddedEventMapper) + eventstore.RegisterFilterEventMapper(AggregateType, MemberChangedEventType, MemberChangedEventMapper) + eventstore.RegisterFilterEventMapper(AggregateType, MemberRemovedEventType, MemberRemovedEventMapper) + eventstore.RegisterFilterEventMapper(AggregateType, MemberCascadeRemovedEventType, MemberCascadeRemovedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, RoleAddedType, RoleAddedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, RoleChangedType, RoleChangedEventMapper) eventstore.RegisterFilterEventMapper(AggregateType, RoleRemovedType, RoleRemovedEventMapper) diff --git a/internal/repository/project/member.go b/internal/repository/project/member.go index d2928bfdc2..6fb3ceddfe 100644 --- a/internal/repository/project/member.go +++ b/internal/repository/project/member.go @@ -8,16 +8,24 @@ import ( ) var ( - MemberAddedType = projectEventTypePrefix + member.AddedEventType - MemberChangedType = projectEventTypePrefix + member.ChangedEventType - MemberRemovedType = projectEventTypePrefix + member.RemovedEventType - MemberCascadeRemovedType = projectEventTypePrefix + member.CascadeRemovedEventType + MemberAddedEventType = projectEventTypePrefix + member.AddedEventType + MemberChangedEventType = projectEventTypePrefix + member.ChangedEventType + MemberRemovedEventType = projectEventTypePrefix + member.RemovedEventType + MemberCascadeRemovedEventType = projectEventTypePrefix + member.CascadeRemovedEventType +) + +const ( + fieldPrefix = "project" ) type MemberAddedEvent struct { member.MemberAddedEvent } +func (e *MemberAddedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewProjectMemberAddedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -29,7 +37,7 @@ func NewProjectMemberAddedEvent( eventstore.NewBaseEventForPush( ctx, aggregate, - MemberAddedType, + MemberAddedEventType, ), userID, roles..., @@ -50,6 +58,10 @@ type MemberChangedEvent struct { member.MemberChangedEvent } +func (e *MemberChangedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewProjectMemberChangedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -62,7 +74,7 @@ func NewProjectMemberChangedEvent( eventstore.NewBaseEventForPush( ctx, aggregate, - MemberChangedType, + MemberChangedEventType, ), userID, roles..., @@ -83,6 +95,10 @@ type MemberRemovedEvent struct { member.MemberRemovedEvent } +func (e *MemberRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewProjectMemberRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -94,7 +110,7 @@ func NewProjectMemberRemovedEvent( eventstore.NewBaseEventForPush( ctx, aggregate, - MemberRemovedType, + MemberRemovedEventType, ), userID, ), @@ -114,6 +130,10 @@ type MemberCascadeRemovedEvent struct { member.MemberCascadeRemovedEvent } +func (e *MemberCascadeRemovedEvent) Fields() []*eventstore.FieldOperation { + return e.FieldOperations(fieldPrefix) +} + func NewProjectMemberCascadeRemovedEvent( ctx context.Context, aggregate *eventstore.Aggregate, @@ -125,7 +145,7 @@ func NewProjectMemberCascadeRemovedEvent( eventstore.NewBaseEventForPush( ctx, aggregate, - MemberCascadeRemovedType, + MemberCascadeRemovedEventType, ), userID, ), diff --git a/internal/repository/project/project.go b/internal/repository/project/project.go index 6147a632eb..44f882b3e1 100644 --- a/internal/repository/project/project.go +++ b/internal/repository/project/project.go @@ -16,6 +16,7 @@ const ( ProjectDeactivatedType = projectEventTypePrefix + "deactivated" ProjectReactivatedType = projectEventTypePrefix + "reactivated" ProjectRemovedType = projectEventTypePrefix + "removed" + ProjectOwnerCorrected = projectEventTypePrefix + "owner.corrected" ProjectSearchType = "project" ProjectObjectRevision = uint8(1) diff --git a/internal/zerrors/zerror.go b/internal/zerrors/zerror.go index d7b85b84a7..996f67ce29 100644 --- a/internal/zerrors/zerror.go +++ b/internal/zerrors/zerror.go @@ -79,3 +79,8 @@ func (err *ZitadelError) As(target interface{}) bool { reflect.Indirect(reflect.ValueOf(target)).Set(reflect.ValueOf(err)) return true } + +func IsZitadelError(err error) bool { + zitadelErr := new(ZitadelError) + return errors.As(err, &zitadelErr) +} diff --git a/internal/zerrors/zerror_test.go b/internal/zerrors/zerror_test.go index 3a11a8e78e..517f938ee4 100644 --- a/internal/zerrors/zerror_test.go +++ b/internal/zerrors/zerror_test.go @@ -1,6 +1,7 @@ package zerrors_test import ( + "errors" "testing" "github.com/stretchr/testify/assert" @@ -17,3 +18,27 @@ func TestErrorMethod(t *testing.T) { subExptected := "ID=subID Message=subMsg Parent=(ID=id Message=msg)" assert.Equal(t, subExptected, err.Error()) } + +func TestIsZitadelError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "zitadel error", + err: zerrors.ThrowInvalidArgument(nil, "id", "msg"), + want: true, + }, + { + name: "other error", + err: errors.New("just a random error"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, zerrors.IsZitadelError(tt.err), "IsZitadelError(%v)", tt.err) + }) + } +} diff --git a/proto/zitadel/feature/v2/instance.proto b/proto/zitadel/feature/v2/instance.proto index 385ce5a4d0..3d2280fc0c 100644 --- a/proto/zitadel/feature/v2/instance.proto +++ b/proto/zitadel/feature/v2/instance.proto @@ -99,6 +99,13 @@ message SetInstanceFeaturesRequest{ description: "Specify the login UI for all users and applications regardless of their preference."; } ]; + + optional bool permission_check_v2 = 14 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "Enable a newer, more performant, permission check used for v2 and v3 resource based APIs."; + } + ]; } message SetInstanceFeaturesResponse { @@ -212,4 +219,10 @@ message GetInstanceFeaturesResponse { description: "If the flag is set, all users will be redirected to the login V2 regardless of the application's preference."; } ]; + + FeatureFlag permission_check_v2 = 15 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Enable a newer, more performant, permission check used for v2 and v3 resource based APIs."; + } + ]; } diff --git a/proto/zitadel/feature/v2/system.proto b/proto/zitadel/feature/v2/system.proto index cac8fe774f..c734905fb2 100644 --- a/proto/zitadel/feature/v2/system.proto +++ b/proto/zitadel/feature/v2/system.proto @@ -88,6 +88,13 @@ message SetSystemFeaturesRequest{ description: "Specify the login UI for all users and applications regardless of their preference."; } ]; + + optional bool permission_check_v2 = 12 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "Enable a newer, more performant, permission check used for v2 and v3 resource based APIs."; + } + ]; } message SetSystemFeaturesResponse { @@ -180,4 +187,10 @@ message GetSystemFeaturesResponse { description: "If the flag is set, all users will be redirected to the login V2 regardless of the application's preference."; } ]; + + FeatureFlag permission_check_v2 = 13 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Enable a newer, more performant, permission check used for v2 and v3 resource based APIs."; + } + ]; } diff --git a/proto/zitadel/member.proto b/proto/zitadel/member.proto index 07091e195e..c3351a99d3 100644 --- a/proto/zitadel/member.proto +++ b/proto/zitadel/member.proto @@ -63,6 +63,14 @@ message Member { description: "type of the user (human / machine)" } ]; + + // The organization the user belong to. + string user_resource_owner = 11 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "\"69629023906488334\""; + } + ]; + } message SearchQuery { diff --git a/proto/zitadel/session/v2/session.proto b/proto/zitadel/session/v2/session.proto index 2c17d81f99..7ab6b77610 100644 --- a/proto/zitadel/session/v2/session.proto +++ b/proto/zitadel/session/v2/session.proto @@ -136,6 +136,8 @@ message SearchQuery { IDsQuery ids_query = 1; UserIDQuery user_id_query = 2; CreationDateQuery creation_date_query = 3; + CreatorQuery creator_query = 4; + UserAgentQuery user_agent_query = 5; } } @@ -157,9 +159,33 @@ message CreationDateQuery { ]; } +message CreatorQuery { + // ID of the user who created the session. If empty, the calling user's ID is used. + optional string id = 1 [ + (validate.rules).string = {max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + max_length: 200; + example: "\"69629023906488334\""; + } + ]; +} + +message UserAgentQuery { + // Finger print id of the user agent used for the session. + // Set an empty fingerprint_id to use the user agent from the call. + // If the user agent is not available from the current token, an error will be returned. + optional string fingerprint_id = 1 [ + (validate.rules).string = {max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + max_length: 200; + example: "\"69629023906488334\""; + } + ]; +} + message UserAgent { optional string fingerprint_id = 1; - optional string ip = 2; + optional string ip = 2; optional string description = 3; // A header may have multiple values. @@ -169,7 +195,7 @@ message UserAgent { message HeaderValues { repeated string values = 1; } - map header = 4; + map header = 4; } enum SessionFieldName {