From 041af26917e6cf66ee3022fccc3288eccbfdeb4b Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Thu, 31 Oct 2024 15:57:17 +0100 Subject: [PATCH] feat(OIDC): add back channel logout (#8837) # Which Problems Are Solved Currently ZITADEL supports RP-initiated logout for clients. Back-channel logout ensures that user sessions are terminated across all connected applications, even if the user closes their browser or loses connectivity providing a more secure alternative for certain use cases. # How the Problems Are Solved If the feature is activated and the client used for the authentication has a back_channel_logout_uri configured, a `session_logout.back_channel` will be registered. Once a user terminates their session, a (notification) handler will send a SET (form POST) to the registered uri containing a logout_token (with the user's ID and session ID). - A new feature "back_channel_logout" is added on system and instance level - A `back_channel_logout_uri` can be managed on OIDC applications - Added a `session_logout` aggregate to register and inform about sent `back_channel` notifications - Added a `SecurityEventToken` channel and `Form`message type in the notification handlers - Added `TriggeredAtOrigin` fields to `HumanSignedOut` and `TerminateSession` events for notification handling - Exported various functions and types in the `oidc` package to be able to reuse for token signing in the back_channel notifier. - To prevent that current existing session termination events will be handled, a setup step is added to set the `current_states` for the `projections.notifications_back_channel_logout` to the current position - [x] requires https://github.com/zitadel/oidc/pull/671 # Additional Changes - Updated all OTEL dependencies to v1.29.0, since OIDC already updated some of them to that version. - Single Session Termination feature is correctly checked (fixed feature mapping) # Additional Context - closes https://github.com/zitadel/zitadel/issues/8467 - TODO: - Documentation - UI to be done: https://github.com/zitadel/zitadel/issues/8469 --------- Co-authored-by: Hidde Wieringa --- cmd/defaults.yaml | 1 + cmd/mirror/projections.go | 3 + cmd/setup/37.go | 27 ++ cmd/setup/37.sql | 1 + cmd/setup/38.go | 28 ++ cmd/setup/38.sql | 20 + cmd/setup/config.go | 2 + cmd/setup/setup.go | 7 + cmd/start/start.go | 3 + go.mod | 34 +- go.sum | 68 ++-- internal/api/grpc/feature/v2/converter.go | 4 + .../api/grpc/feature/v2/converter_test.go | 18 + .../project_application_converter.go | 2 + internal/api/grpc/project/application.go | 1 + internal/api/oidc/auth_request.go | 15 +- internal/api/oidc/key.go | 8 +- internal/api/oidc/op.go | 1 + internal/api/oidc/server.go | 3 + internal/api/oidc/token.go | 23 +- internal/api/oidc/token_client_credentials.go | 1 + internal/api/oidc/token_code.go | 1 + internal/api/oidc/token_exchange.go | 4 +- internal/api/oidc/token_jwt_profile.go | 1 + internal/api/oidc/token_refresh.go | 1 + .../eventsourcing/eventstore/user.go | 33 +- .../eventsourcing/view/user_session.go | 4 +- internal/auth/repository/user.go | 6 +- internal/command/instance_domain_test.go | 1 + internal/command/instance_features.go | 4 +- internal/command/instance_features_model.go | 5 + internal/command/instance_test.go | 2 + internal/command/logout_session.go | 24 ++ internal/command/logout_session_model.go | 74 ++++ internal/command/oidc_session.go | 25 +- internal/command/oidc_session_test.go | 356 ++++++++++++++++-- internal/command/project_application_oidc.go | 4 + .../command/project_application_oidc_model.go | 9 + .../command/project_application_oidc_test.go | 22 ++ internal/command/project_converter.go | 1 + internal/command/system_features.go | 4 +- internal/command/system_features_model.go | 5 + internal/command/user_human.go | 17 +- internal/command/user_human_test.go | 41 +- internal/domain/application_oidc.go | 1 + internal/feature/feature.go | 4 +- internal/feature/key_enumer.go | 12 +- internal/notification/channels.go | 12 + internal/notification/channels/set/channel.go | 75 ++++ internal/notification/channels/set/config.go | 14 + .../handlers/back_channel_logout.go | 266 +++++++++++++ internal/notification/handlers/ctx.go | 10 + .../handlers/mock/commands.mock.go | 113 +++--- .../handlers/mock/queries.mock.go | 153 +++++--- internal/notification/handlers/queries.go | 6 + .../handlers/user_notifier_test.go | 5 + internal/notification/messages/form.go | 27 ++ internal/notification/projections.go | 19 +- .../senders/security_event_token.go | 49 +++ internal/notification/types/notification.go | 20 + .../types/security_token_event.go | 27 ++ internal/query/app.go | 13 + internal/query/app_test.go | 37 ++ internal/query/instance_features.go | 1 + internal/query/instance_features_model.go | 4 + internal/query/oidc_client.go | 1 + internal/query/oidc_client_by_id.sql | 4 +- internal/query/projection/app.go | 8 +- internal/query/projection/app_test.go | 19 +- .../query/projection/instance_features.go | 4 + internal/query/projection/system_features.go | 4 + internal/query/system_features.go | 1 + internal/query/system_features_model.go | 3 + .../feature/feature_v2/eventstore.go | 2 + .../repository/feature/feature_v2/feature.go | 2 + internal/repository/project/oidc_config.go | 15 +- internal/repository/session/session.go | 2 + .../repository/sessionlogout/aggregate.go | 26 ++ internal/repository/sessionlogout/events.go | 79 ++++ .../repository/sessionlogout/eventstore.go | 15 + internal/repository/user/human.go | 15 +- ...=> active_user_sessions_by_session_id.sql} | 3 +- .../user/repository/view/user_session_view.go | 27 +- proto/zitadel/app.proto | 6 + proto/zitadel/feature/v2/instance.proto | 14 + proto/zitadel/feature/v2/system.proto | 14 + proto/zitadel/management.proto | 12 + 87 files changed, 1778 insertions(+), 280 deletions(-) create mode 100644 cmd/setup/37.go create mode 100644 cmd/setup/37.sql create mode 100644 cmd/setup/38.go create mode 100644 cmd/setup/38.sql create mode 100644 internal/command/logout_session.go create mode 100644 internal/command/logout_session_model.go create mode 100644 internal/notification/channels/set/channel.go create mode 100644 internal/notification/channels/set/config.go create mode 100644 internal/notification/handlers/back_channel_logout.go create mode 100644 internal/notification/messages/form.go create mode 100644 internal/notification/senders/security_event_token.go create mode 100644 internal/notification/types/security_token_event.go create mode 100644 internal/repository/sessionlogout/aggregate.go create mode 100644 internal/repository/sessionlogout/events.go create mode 100644 internal/repository/sessionlogout/eventstore.go rename internal/user/repository/view/{active_user_ids_by_session_id.sql => active_user_sessions_by_session_id.sql} (91%) diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index a12fe474ba..f691fd2af2 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -411,6 +411,7 @@ OIDC: DefaultLoginURLV2: "/login?authRequest=" # ZITADEL_OIDC_DEFAULTLOGINURLV2 DefaultLogoutURLV2: "/logout?post_logout_redirect=" # ZITADEL_OIDC_DEFAULTLOGOUTURLV2 PublicKeyCacheMaxAge: 24h # ZITADEL_OIDC_PUBLICKEYCACHEMAXAGE + DefaultBackChannelLogoutLifetime: 15m # ZITADEL_OIDC_DEFAULTBACKCHANNELLOGOUTLIFETIME SAML: ProviderConfig: diff --git a/cmd/mirror/projections.go b/cmd/mirror/projections.go index 442609d12a..9b7ec02cb8 100644 --- a/cmd/mirror/projections.go +++ b/cmd/mirror/projections.go @@ -200,6 +200,7 @@ func projections( ctx, config.Projections.Customizations["notifications"], config.Projections.Customizations["notificationsquotas"], + config.Projections.Customizations["backchannel"], config.Projections.Customizations["telemetry"], *config.Telemetry, config.ExternalDomain, @@ -213,6 +214,8 @@ func projections( keys.User, keys.SMTP, keys.SMS, + keys.OIDC, + config.OIDC.DefaultBackChannelLogoutLifetime, ) config.Auth.Spooler.Client = client diff --git a/cmd/setup/37.go b/cmd/setup/37.go new file mode 100644 index 0000000000..1587b5c793 --- /dev/null +++ b/cmd/setup/37.go @@ -0,0 +1,27 @@ +package setup + +import ( + "context" + _ "embed" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + //go:embed 37.sql + addBackChannelLogoutURI string +) + +type Apps7OIDConfigsBackChannelLogoutURI struct { + dbClient *database.DB +} + +func (mig *Apps7OIDConfigsBackChannelLogoutURI) Execute(ctx context.Context, _ eventstore.Event) error { + _, err := mig.dbClient.ExecContext(ctx, addBackChannelLogoutURI) + return err +} + +func (mig *Apps7OIDConfigsBackChannelLogoutURI) String() string { + return "37_apps7_oidc_configs_add_back_channel_logout_uri" +} diff --git a/cmd/setup/37.sql b/cmd/setup/37.sql new file mode 100644 index 0000000000..6c3fdf0dda --- /dev/null +++ b/cmd/setup/37.sql @@ -0,0 +1 @@ +ALTER TABLE IF EXISTS projections.apps7_oidc_configs ADD COLUMN IF NOT EXISTS back_channel_logout_uri TEXT; diff --git a/cmd/setup/38.go b/cmd/setup/38.go new file mode 100644 index 0000000000..0a102c9d12 --- /dev/null +++ b/cmd/setup/38.go @@ -0,0 +1,28 @@ +package setup + +import ( + "context" + _ "embed" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + //go:embed 38.sql + backChannelLogoutCurrentState string +) + +type BackChannelLogoutNotificationStart struct { + dbClient *database.DB + esClient *eventstore.Eventstore +} + +func (mig *BackChannelLogoutNotificationStart) Execute(ctx context.Context, e eventstore.Event) error { + _, err := mig.dbClient.ExecContext(ctx, backChannelLogoutCurrentState, e.Sequence(), e.CreatedAt(), e.Position()) + return err +} + +func (mig *BackChannelLogoutNotificationStart) String() string { + return "38_back_channel_logout_notification_start_" +} diff --git a/cmd/setup/38.sql b/cmd/setup/38.sql new file mode 100644 index 0000000000..d8915fee4f --- /dev/null +++ b/cmd/setup/38.sql @@ -0,0 +1,20 @@ +INSERT INTO projections.current_states ( + instance_id + , projection_name + , last_updated + , sequence + , event_date + , position + , filter_offset +) + SELECT instance_id + , 'projections.notifications_back_channel_logout' + , now() + , $1 + , $2 + , $3 + , 0 + FROM eventstore.events2 + WHERE aggregate_type = 'instance' + AND event_type = 'instance.added' + ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/config.go b/cmd/setup/config.go index 7a5beebcfe..09044456ea 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -123,6 +123,8 @@ type Steps struct { s34AddCacheSchema *AddCacheSchema s35AddPositionToIndexEsWm *AddPositionToIndexEsWm s36FillV2Milestones *FillV2Milestones + s37Apps7OIDConfigsBackChannelLogoutURI *Apps7OIDConfigsBackChannelLogoutURI + s38BackChannelLogoutNotificationStart *BackChannelLogoutNotificationStart } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index e24b69d5b6..7ffef5e853 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -166,6 +166,8 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s34AddCacheSchema = &AddCacheSchema{dbClient: queryDBClient} steps.s35AddPositionToIndexEsWm = &AddPositionToIndexEsWm{dbClient: esPusherDBClient} steps.s36FillV2Milestones = &FillV2Milestones{dbClient: queryDBClient, eventstore: eventstoreClient} + steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient} + steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient} err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") @@ -211,6 +213,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s34AddCacheSchema, steps.s35AddPositionToIndexEsWm, steps.s36FillV2Milestones, + steps.s38BackChannelLogoutNotificationStart, } { mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") } @@ -227,6 +230,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s27IDPTemplate6SAMLNameIDFormat, steps.s32AddAuthSessionID, steps.s33SMSConfigs3TwilioAddVerifyServiceSid, + steps.s37Apps7OIDConfigsBackChannelLogoutURI, } { mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") } @@ -424,6 +428,7 @@ func initProjections( ctx, config.Projections.Customizations["notifications"], config.Projections.Customizations["notificationsquotas"], + config.Projections.Customizations["backchannel"], config.Projections.Customizations["telemetry"], *config.Telemetry, config.ExternalDomain, @@ -437,6 +442,8 @@ func initProjections( keys.User, keys.SMTP, keys.SMS, + keys.OIDC, + config.OIDC.DefaultBackChannelLogoutLifetime, ) for _, p := range notify_handler.Projections() { err := migration.Migrate(ctx, eventstoreClient, p) diff --git a/cmd/start/start.go b/cmd/start/start.go index 97a38ba50d..8de1105307 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -270,6 +270,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server ctx, config.Projections.Customizations["notifications"], config.Projections.Customizations["notificationsquotas"], + config.Projections.Customizations["backchannel"], config.Projections.Customizations["telemetry"], *config.Telemetry, config.ExternalDomain, @@ -283,6 +284,8 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server keys.User, keys.SMTP, keys.SMS, + keys.OIDC, + config.OIDC.DefaultBackChannelLogoutLifetime, ) notification.Start(ctx) diff --git a/go.mod b/go.mod index 7e34525b6d..1e4f67eb7d 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/gorilla/websocket v1.4.1 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/grpc-ecosystem/grpc-gateway v1.16.0 - github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 + github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 github.com/h2non/gock v1.2.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/improbable-eng/grpc-web v0.15.0 @@ -52,7 +52,7 @@ require ( github.com/pashagolub/pgxmock/v4 v4.3.0 github.com/pquerna/otp v1.4.0 github.com/rakyll/statik v0.1.7 - github.com/rs/cors v1.11.0 + github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/sony/sonyflake v1.2.0 github.com/spf13/cobra v1.8.1 @@ -62,29 +62,29 @@ require ( github.com/ttacon/libphonenumber v1.2.1 github.com/twilio/twilio-go v1.22.2 github.com/zitadel/logging v0.6.1 - github.com/zitadel/oidc/v3 v3.28.1 + github.com/zitadel/oidc/v3 v3.32.0 github.com/zitadel/passwap v0.6.0 github.com/zitadel/saml v0.2.0 github.com/zitadel/schema v1.3.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 - go.opentelemetry.io/otel v1.28.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 + go.opentelemetry.io/otel v1.29.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0 go.opentelemetry.io/otel/exporters/prometheus v0.50.0 - go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.28.0 - go.opentelemetry.io/otel/metric v1.28.0 - go.opentelemetry.io/otel/sdk v1.28.0 - go.opentelemetry.io/otel/sdk/metric v1.28.0 - go.opentelemetry.io/otel/trace v1.28.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.29.0 + go.opentelemetry.io/otel/metric v1.29.0 + go.opentelemetry.io/otel/sdk v1.29.0 + go.opentelemetry.io/otel/sdk/metric v1.29.0 + go.opentelemetry.io/otel/trace v1.29.0 go.uber.org/mock v0.4.0 golang.org/x/crypto v0.27.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 - golang.org/x/net v0.26.0 - golang.org/x/oauth2 v0.22.0 + golang.org/x/net v0.28.0 + golang.org/x/oauth2 v0.23.0 golang.org/x/sync v0.8.0 - golang.org/x/text v0.18.0 + golang.org/x/text v0.19.0 google.golang.org/api v0.187.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 + google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 sigs.k8s.io/yaml v1.4.0 @@ -94,7 +94,7 @@ require ( cloud.google.com/go/auth v0.6.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.48.0 // indirect - github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect + github.com/bmatcuk/doublestar/v4 v4.7.1 // indirect github.com/crewjam/httperr v0.2.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.2 // indirect @@ -125,7 +125,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd // indirect ) require ( @@ -197,7 +197,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect go.opencensus.io v0.24.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0 // indirect go.opentelemetry.io/proto/otlp v1.3.1 // indirect golang.org/x/sys v0.25.0 gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index dc1e9fb1e8..8645aa8417 100644 --- a/go.sum +++ b/go.sum @@ -80,8 +80,8 @@ github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+Ce github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I= -github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/bmatcuk/doublestar/v4 v4.7.1 h1:fdDeAqgT47acgwd9bd9HxJRDmc9UAmPpc+2m0CXv75Q= +github.com/bmatcuk/doublestar/v4 v4.7.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.2 h1:79yrbttoZrLGkL/oOI8hBrUKucwOL0oOjUgEguGMcJ4= github.com/boombuler/barcode v1.0.2/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= @@ -354,8 +354,8 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgf github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= github.com/h2non/filetype v1.1.1/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= @@ -628,8 +628,8 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= -github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= -github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russellhaering/goxmldsig v1.4.0 h1:8UcDh/xGyQiyrW+Fq5t8f+l2DLB1+zlhYzkPUJ7Qhys= @@ -723,8 +723,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zitadel/logging v0.6.1 h1:Vyzk1rl9Kq9RCevcpX6ujUaTYFX43aa4LkvV1TvUk+Y= github.com/zitadel/logging v0.6.1/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow= -github.com/zitadel/oidc/v3 v3.28.1 h1:PsbFm5CzEMQq9HBXUNJ8yvnWmtVYxpwV5Cinj7TTsHo= -github.com/zitadel/oidc/v3 v3.28.1/go.mod h1:WmDFu3dZ9YNKrIoZkmxjGG8QyUR4PbbhsVVSY+rpojM= +github.com/zitadel/oidc/v3 v3.32.0 h1:Mw0EPZRC6h+OXAuT0Uk2BZIjJQNHLqUpaJCm6c3IByc= +github.com/zitadel/oidc/v3 v3.32.0/go.mod h1:DyE/XClysRK/ozFaZSqlYamKVnTh4l6Ln25ihSNI03w= github.com/zitadel/passwap v0.6.0 h1:m9F3epFC0VkBXu25rihSLGyHvWiNlCzU5kk8RoI+SXQ= github.com/zitadel/passwap v0.6.0/go.mod h1:kqAiJ4I4eZvm3Y6oAk6hlEqlZZOkjMHraGXF90GG7LI= github.com/zitadel/saml v0.2.0 h1:vv7r+Xz43eAPCb+fImMaospD+TWRZQDkb78AbSJRcL4= @@ -742,24 +742,24 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.5 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0/go.mod h1:azvtTADFQJA8mX80jIH/akaE7h+dbm/sVuaHqN13w74= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 h1:4K4tsIXefpVJtvA/8srF4V4y0akAoPHkIslgAkjixJA= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0/go.mod h1:jjdQuTGVsXV4vSs+CJ2qYDeDPf9yIJV23qlIzBm73Vg= -go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= -go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 h1:R3X6ZXmNPRR8ul6i3WgFURCHzaXjHdm0karRG/+dj3s= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0/go.mod h1:QWFXnDavXWwMx2EEcZsf3yxgEKAqsxQ+Syjp+seyInw= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0 h1:dIIDULZJpgdiHz5tXrTgKIMLkus6jEFa7x5SOKcyR7E= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0/go.mod h1:jlRVBe7+Z1wyxFSUs48L6OBQZ5JwH2Hg/Vbl+t9rAgI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0 h1:nSiV3s7wiCam610XcLbYOmMfJxB9gO4uK3Xgv5gmTgg= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0/go.mod h1:hKn/e/Nmd19/x1gvIHwtOwVWM+VhuITSWip3JUDghj0= go.opentelemetry.io/otel/exporters/prometheus v0.50.0 h1:2Ewsda6hejmbhGFyUvWZjUThC98Cf8Zy6g0zkIimOng= go.opentelemetry.io/otel/exporters/prometheus v0.50.0/go.mod h1:pMm5PkUo5YwbLiuEf7t2xg4wbP0/eSJrMxIMxKosynY= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.28.0 h1:EVSnY9JbEEW92bEkIYOVMw4q1WJxIAGoFTrtYOzWuRQ= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.28.0/go.mod h1:Ea1N1QQryNXpCD0I1fdLibBAIpQuBkznMmkdKrapk1Y= -go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= -go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= -go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= -go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= -go.opentelemetry.io/otel/sdk/metric v1.28.0 h1:OkuaKgKrgAbYrrY0t92c+cC+2F6hsFNnCQArXCKlg08= -go.opentelemetry.io/otel/sdk/metric v1.28.0/go.mod h1:cWPjykihLAPvXKi4iZc1dpER3Jdq2Z0YLse3moQUCpg= -go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= -go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.29.0 h1:X3ZjNp36/WlkSYx0ul2jw4PtbNEDDeLskw3VPsrpYM0= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.29.0/go.mod h1:2uL/xnOXh0CHOBFCWXz5u1A4GXLiW+0IQIzVbeOEQ0U= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= +go.opentelemetry.io/otel/sdk v1.29.0/go.mod h1:pM8Dx5WKnvxLCb+8lG1PRNIDxu9g9b9g59Qr7hfAAok= +go.opentelemetry.io/otel/sdk/metric v1.29.0 h1:K2CfmJohnRgvZ9UAj2/FhIf/okdWcNdBwe1m8xFXiSY= +go.opentelemetry.io/otel/sdk/metric v1.29.0/go.mod h1:6zZLdCl2fkauYoZIOn/soQIDSWFmNSRcICarHfuhNJQ= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= @@ -857,13 +857,13 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.22.0 h1:BzDx2FehcG7jJwgWLELCdmLuxk2i+x9UDpSiss2u0ZA= -golang.org/x/oauth2 v0.22.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -934,8 +934,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= @@ -983,10 +983,10 @@ google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEY google.golang.org/genproto v0.0.0-20210126160654-44e461bb6506/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d h1:PksQg4dV6Sem3/HkBX+Ltq8T0ke0PKIRBNBatoDTVls= google.golang.org/genproto v0.0.0-20240624140628-dc46fd24d27d/go.mod h1:s7iA721uChleev562UJO2OYB0PPT9CMFjV+Ce7VJH5M= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= -google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094/go.mod h1:fJ/e3If/Q67Mj99hin0hMhiNyCRmt6BQ2aWIJshUSJw= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094 h1:BwIjyKYGsK9dMCBOorzRri8MQwmi7mT9rGHsCEinZkA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240701130421-f6361c86f094/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= +google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd h1:BBOTEWLuuEGQy9n1y9MhVJ9Qt0BDu21X8qZs71/uPZo= +google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:fO8wJzT2zbQbAjbIoos1285VfEIYKDDY+Dt+WpTkh6g= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd h1:6TEm2ZxXoQmFWFlt1vNxvVOa1Q0dXFQD1m/rYjXmS0E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.0/go.mod h1:chYK+tFQF0nDUGJgXMSgLCQk3phJEuONr2DCgLDdAQM= diff --git a/internal/api/grpc/feature/v2/converter.go b/internal/api/grpc/feature/v2/converter.go index e8b57a2885..7d951f789a 100644 --- a/internal/api/grpc/feature/v2/converter.go +++ b/internal/api/grpc/feature/v2/converter.go @@ -19,6 +19,7 @@ func systemFeaturesToCommand(req *feature_pb.SetSystemFeaturesRequest) *command. ImprovedPerformance: improvedPerformanceListToDomain(req.ImprovedPerformance), OIDCSingleV1SessionTermination: req.OidcSingleV1SessionTermination, DisableUserTokenEvent: req.DisableUserTokenEvent, + EnableBackChannelLogout: req.EnableBackChannelLogout, } } @@ -34,6 +35,7 @@ func systemFeaturesToPb(f *query.SystemFeatures) *feature_pb.GetSystemFeaturesRe ImprovedPerformance: featureSourceToImprovedPerformanceFlagPb(&f.ImprovedPerformance), OidcSingleV1SessionTermination: featureSourceToFlagPb(&f.OIDCSingleV1SessionTermination), DisableUserTokenEvent: featureSourceToFlagPb(&f.DisableUserTokenEvent), + EnableBackChannelLogout: featureSourceToFlagPb(&f.EnableBackChannelLogout), } } @@ -50,6 +52,7 @@ func instanceFeaturesToCommand(req *feature_pb.SetInstanceFeaturesRequest) *comm DebugOIDCParentError: req.DebugOidcParentError, OIDCSingleV1SessionTermination: req.OidcSingleV1SessionTermination, DisableUserTokenEvent: req.DisableUserTokenEvent, + EnableBackChannelLogout: req.EnableBackChannelLogout, } } @@ -67,6 +70,7 @@ func instanceFeaturesToPb(f *query.InstanceFeatures) *feature_pb.GetInstanceFeat DebugOidcParentError: featureSourceToFlagPb(&f.DebugOIDCParentError), OidcSingleV1SessionTermination: featureSourceToFlagPb(&f.OIDCSingleV1SessionTermination), DisableUserTokenEvent: featureSourceToFlagPb(&f.DisableUserTokenEvent), + EnableBackChannelLogout: featureSourceToFlagPb(&f.EnableBackChannelLogout), } } diff --git a/internal/api/grpc/feature/v2/converter_test.go b/internal/api/grpc/feature/v2/converter_test.go index 79bfa34839..43a848e3a6 100644 --- a/internal/api/grpc/feature/v2/converter_test.go +++ b/internal/api/grpc/feature/v2/converter_test.go @@ -80,6 +80,10 @@ func Test_systemFeaturesToPb(t *testing.T) { Level: feature.LevelSystem, Value: true, }, + EnableBackChannelLogout: query.FeatureSource[bool]{ + Level: feature.LevelSystem, + Value: true, + }, } want := &feature_pb.GetSystemFeaturesResponse{ Details: &object.Details{ @@ -123,6 +127,10 @@ func Test_systemFeaturesToPb(t *testing.T) { Enabled: false, Source: feature_pb.Source_SOURCE_UNSPECIFIED, }, + EnableBackChannelLogout: &feature_pb.FeatureFlag{ + Enabled: true, + Source: feature_pb.Source_SOURCE_SYSTEM, + }, } got := systemFeaturesToPb(arg) assert.Equal(t, want, got) @@ -140,6 +148,7 @@ func Test_instanceFeaturesToCommand(t *testing.T) { WebKey: gu.Ptr(true), DebugOidcParentError: gu.Ptr(true), OidcSingleV1SessionTermination: gu.Ptr(true), + EnableBackChannelLogout: gu.Ptr(true), } want := &command.InstanceFeatures{ LoginDefaultOrg: gu.Ptr(true), @@ -152,6 +161,7 @@ func Test_instanceFeaturesToCommand(t *testing.T) { WebKey: gu.Ptr(true), DebugOIDCParentError: gu.Ptr(true), OIDCSingleV1SessionTermination: gu.Ptr(true), + EnableBackChannelLogout: gu.Ptr(true), } got := instanceFeaturesToCommand(arg) assert.Equal(t, want, got) @@ -200,6 +210,10 @@ func Test_instanceFeaturesToPb(t *testing.T) { Level: feature.LevelInstance, Value: true, }, + EnableBackChannelLogout: query.FeatureSource[bool]{ + Level: feature.LevelInstance, + Value: true, + }, } want := &feature_pb.GetInstanceFeaturesResponse{ Details: &object.Details{ @@ -251,6 +265,10 @@ func Test_instanceFeaturesToPb(t *testing.T) { Enabled: false, Source: feature_pb.Source_SOURCE_UNSPECIFIED, }, + EnableBackChannelLogout: &feature_pb.FeatureFlag{ + Enabled: true, + Source: feature_pb.Source_SOURCE_INSTANCE, + }, } got := instanceFeaturesToPb(arg) assert.Equal(t, want, got) diff --git a/internal/api/grpc/management/project_application_converter.go b/internal/api/grpc/management/project_application_converter.go index fcfe8089ee..ea2f45fd0d 100644 --- a/internal/api/grpc/management/project_application_converter.go +++ b/internal/api/grpc/management/project_application_converter.go @@ -57,6 +57,7 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) *domain.OIDCApp { ClockSkew: req.ClockSkew.AsDuration(), AdditionalOrigins: req.AdditionalOrigins, SkipNativeAppSuccessPage: req.SkipNativeAppSuccessPage, + BackChannelLogoutURI: req.GetBackChannelLogoutUri(), } } @@ -108,6 +109,7 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest) ClockSkew: app.ClockSkew.AsDuration(), AdditionalOrigins: app.AdditionalOrigins, SkipNativeAppSuccessPage: app.SkipNativeAppSuccessPage, + BackChannelLogoutURI: app.BackChannelLogoutUri, } } diff --git a/internal/api/grpc/project/application.go b/internal/api/grpc/project/application.go index 25274eeb1d..e70554ce64 100644 --- a/internal/api/grpc/project/application.go +++ b/internal/api/grpc/project/application.go @@ -61,6 +61,7 @@ func AppOIDCConfigToPb(app *query.OIDCApp) *app_pb.App_OidcConfig { AdditionalOrigins: app.AdditionalOrigins, AllowedOrigins: app.AllowedOrigins, SkipNativeAppSuccessPage: app.SkipNativeAppSuccessPage, + BackChannelLogoutUri: app.BackChannelLogoutURI, }, } } diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index dc402036fb..173585ff13 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -215,18 +215,18 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin logging.Error("no user agent id") return zerrors.ThrowPreconditionFailed(nil, "OIDC-fso7F", "no user agent id") } - userIDs, err := o.repo.UserSessionUserIDsByAgentID(ctx, userAgentID) + sessions, err := o.repo.UserSessionsByAgentID(ctx, userAgentID) if err != nil { logging.WithError(err).Error("error retrieving user sessions") return err } - if len(userIDs) == 0 { + if len(sessions) == 0 { return nil } data := authz.CtxData{ UserID: userID, } - err = o.command.HumansSignOut(authz.SetCtxData(ctx, data), userAgentID, userIDs) + err = o.command.HumansSignOut(authz.SetCtxData(ctx, data), userAgentID, sessions) logging.OnError(err).Error("error signing out") return err } @@ -278,18 +278,18 @@ func (o *OPStorage) terminateV1Session(ctx context.Context, userID, sessionID st if err != nil { return err } - return o.command.HumansSignOut(ctx, userAgentID, []string{userID}) + return o.command.HumansSignOut(ctx, userAgentID, []command.HumanSignOutSession{{ID: sessionID, UserID: userID}}) } // otherwise we search for all active sessions within the same user agent of the current session id - userAgentID, userIDs, err := o.repo.ActiveUserIDsBySessionID(ctx, sessionID) + userAgentID, sessions, err := o.repo.ActiveUserSessionsBySessionID(ctx, sessionID) if err != nil { logging.WithError(err).Error("error retrieving user sessions") return err } - if len(userIDs) == 0 { + if len(sessions) == 0 { return nil } - return o.command.HumansSignOut(ctx, userAgentID, userIDs) + return o.command.HumansSignOut(ctx, userAgentID, sessions) } func (o *OPStorage) RevokeToken(ctx context.Context, token, userID, clientID string) (err *oidc.Error) { @@ -588,6 +588,7 @@ func (s *Server) authResponseToken(authReq *AuthRequest, authorizer op.Authorize authReq.UserID, authReq.UserOrgID, client.client.ClientID, + client.client.BackChannelLogoutURI, scope, authReq.Audience, authReq.AuthMethods(), diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index a7e156fe78..535aa846b4 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -348,7 +348,7 @@ func (o *OPStorage) getSigningKey(ctx context.Context) (op.SigningKey, error) { return nil, err } if len(keys.Keys) > 0 { - return o.privateKeyToSigningKey(selectSigningKey(keys.Keys)) + return PrivateKeyToSigningKey(SelectSigningKey(keys.Keys), o.encAlg) } var position float64 if keys.State != nil { @@ -377,8 +377,8 @@ func (o *OPStorage) ensureIsLatestKey(ctx context.Context, position float64) (bo return position >= maxSequence, nil } -func (o *OPStorage) privateKeyToSigningKey(key query.PrivateKey) (_ op.SigningKey, err error) { - keyData, err := crypto.Decrypt(key.Key(), o.encAlg) +func PrivateKeyToSigningKey(key query.PrivateKey, algorithm crypto.EncryptionAlgorithm) (_ op.SigningKey, err error) { + keyData, err := crypto.Decrypt(key.Key(), algorithm) if err != nil { return nil, err } @@ -430,7 +430,7 @@ func (o *OPStorage) getMaxKeySequence(ctx context.Context) (float64, error) { ) } -func selectSigningKey(keys []query.PrivateKey) query.PrivateKey { +func SelectSigningKey(keys []query.PrivateKey) query.PrivateKey { return keys[len(keys)-1] } diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index c8dafb50f3..86b89690bf 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -42,6 +42,7 @@ type Config struct { DefaultLoginURLV2 string DefaultLogoutURLV2 string PublicKeyCacheMaxAge time.Duration + DefaultBackChannelLogoutLifetime time.Duration } type EndpointConfig struct { diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index 07bc4706be..1a0854e2a6 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -167,6 +167,7 @@ func (s *Server) EndSession(ctx context.Context, r *op.Request[oidc.EndSessionRe func (s *Server) createDiscoveryConfig(ctx context.Context, supportedUILocales oidc.Locales) *oidc.DiscoveryConfiguration { issuer := op.IssuerFromContext(ctx) + backChannelLogoutSupported := authz.GetInstance(ctx).Features().EnableBackChannelLogout return &oidc.DiscoveryConfiguration{ Issuer: issuer, @@ -199,6 +200,8 @@ func (s *Server) createDiscoveryConfig(ctx context.Context, supportedUILocales o CodeChallengeMethodsSupported: op.CodeChallengeMethods(s.Provider()), UILocalesSupported: supportedUILocales, RequestParameterSupported: s.Provider().RequestObjectSupported(), + BackChannelLogoutSupported: backChannelLogoutSupported, + BackChannelLogoutSessionSupported: backChannelLogoutSupported, } } diff --git a/internal/api/oidc/token.go b/internal/api/oidc/token.go index 56ed225902..485f455784 100644 --- a/internal/api/oidc/token.go +++ b/internal/api/oidc/token.go @@ -60,12 +60,19 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C return resp, err } -// signerFunc is a getter function that allows add-hoc retrieval of the instance's signer. -type signerFunc func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error) +// SignerFunc is a getter function that allows add-hoc retrieval of the instance's signer. +type SignerFunc func(ctx context.Context) (jose.Signer, jose.SignatureAlgorithm, error) -// getSignerOnce returns a function which retrieves the instance's signer from the database once. +func (s *Server) getSignerOnce() SignerFunc { + return GetSignerOnce(s.query.GetActiveSigningWebKey, s.Provider().Storage().SigningKey) +} + +// GetSignerOnce returns a function which retrieves the instance's signer from the database once. // Repeated calls of the returned function return the same results. -func (s *Server) getSignerOnce() signerFunc { +func GetSignerOnce( + getActiveSigningWebKey func(ctx context.Context) (*jose.JSONWebKey, error), + getSigningKey func(ctx context.Context) (op.SigningKey, error), +) SignerFunc { var ( once sync.Once signer jose.Signer @@ -79,7 +86,7 @@ func (s *Server) getSignerOnce() signerFunc { if authz.GetFeatures(ctx).WebKey { var webKey *jose.JSONWebKey - webKey, err = s.query.GetActiveSigningWebKey(ctx) + webKey, err = getActiveSigningWebKey(ctx) if err != nil { return } @@ -88,7 +95,7 @@ func (s *Server) getSignerOnce() signerFunc { } var signingKey op.SigningKey - signingKey, err = s.Provider().Storage().SigningKey(ctx) + signingKey, err = getSigningKey(ctx) if err != nil { return } @@ -126,7 +133,7 @@ func (s *Server) getUserInfo(userID, projectID string, projectRoleAssertion, use } } -func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { +func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, roleAssertion bool, getSigningKey SignerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -170,7 +177,7 @@ func timeToOIDCExpiresIn(exp time.Time) uint64 { return uint64(time.Until(exp) / time.Second) } -func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner signerFunc) (_ string, err error) { +func (s *Server) createJWT(ctx context.Context, client op.Client, session *command.OIDCSession, getUserInfo userInfoFunc, assertRoles bool, getSigner SignerFunc) (_ string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/api/oidc/token_client_credentials.go b/internal/api/oidc/token_client_credentials.go index 2fedd71c44..0b836a03cc 100644 --- a/internal/api/oidc/token_client_credentials.go +++ b/internal/api/oidc/token_client_credentials.go @@ -35,6 +35,7 @@ func (s *Server) ClientCredentialsExchange(ctx context.Context, r *op.ClientRequ client.userID, client.resourceOwner, client.clientID, + "", // backChannelLogoutURI not needed for service user session scope, domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope), []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, diff --git a/internal/api/oidc/token_code.go b/internal/api/oidc/token_code.go index e6899bed01..3aa53e629e 100644 --- a/internal/api/oidc/token_code.go +++ b/internal/api/oidc/token_code.go @@ -75,6 +75,7 @@ func (s *Server) codeExchangeV1(ctx context.Context, client *Client, req *oidc.A authReq.UserID, authReq.UserOrgID, client.client.ClientID, + client.client.BackChannelLogoutURI, scope, authReq.Audience, authReq.AuthMethods(), diff --git a/internal/api/oidc/token_exchange.go b/internal/api/oidc/token_exchange.go index ec43729692..63a594b940 100644 --- a/internal/api/oidc/token_exchange.go +++ b/internal/api/oidc/token_exchange.go @@ -288,6 +288,7 @@ func (s *Server) createExchangeAccessToken( userID, resourceOwner, client.client.ClientID, + client.client.BackChannelLogoutURI, scope, audience, authMethods, @@ -315,7 +316,7 @@ func (s *Server) createExchangeJWT( client *Client, getUserInfo userInfoFunc, roleAssertion bool, - getSigner signerFunc, + getSigner SignerFunc, userID, resourceOwner string, audience, @@ -333,6 +334,7 @@ func (s *Server) createExchangeJWT( userID, resourceOwner, client.client.ClientID, + client.client.BackChannelLogoutURI, scope, audience, authMethods, diff --git a/internal/api/oidc/token_jwt_profile.go b/internal/api/oidc/token_jwt_profile.go index 253432cc83..4717d29f9c 100644 --- a/internal/api/oidc/token_jwt_profile.go +++ b/internal/api/oidc/token_jwt_profile.go @@ -45,6 +45,7 @@ func (s *Server) JWTProfile(ctx context.Context, r *op.Request[oidc.JWTProfileGr client.userID, client.resourceOwner, client.clientID, + "", // backChannelLogoutURI not needed for service user session scope, domain.AddAudScopeToAudience(ctx, nil, r.Data.Scope), []domain.UserAuthMethodType{domain.UserAuthMethodTypePrivateKey}, diff --git a/internal/api/oidc/token_refresh.go b/internal/api/oidc/token_refresh.go index 62f3c6dd3f..f0d92fa521 100644 --- a/internal/api/oidc/token_refresh.go +++ b/internal/api/oidc/token_refresh.go @@ -54,6 +54,7 @@ func (s *Server) refreshTokenV1(ctx context.Context, client *Client, r *op.Clien refreshToken.UserID, refreshToken.ResourceOwner, refreshToken.ClientID, + "", // backChannelLogoutURI is not in refresh token view scope, refreshToken.Audience, AMRToAuthMethodTypes(refreshToken.AuthMethodsReferences), diff --git a/internal/auth/repository/eventsourcing/eventstore/user.go b/internal/auth/repository/eventsourcing/eventstore/user.go index b11f770d77..61895c263d 100644 --- a/internal/auth/repository/eventsourcing/eventstore/user.go +++ b/internal/auth/repository/eventsourcing/eventstore/user.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" + "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" @@ -27,26 +28,40 @@ func (repo *UserRepo) Health(ctx context.Context) error { return repo.Eventstore.Health(ctx) } -func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID string) ([]string, error) { - userSessions, err := repo.View.UserSessionsByAgentID(ctx, agentID, authz.GetInstance(ctx).InstanceID()) +func (repo *UserRepo) UserSessionsByAgentID(ctx context.Context, agentID string) ([]command.HumanSignOutSession, error) { + sessions, err := repo.View.UserSessionsByAgentID(ctx, agentID, authz.GetInstance(ctx).InstanceID()) if err != nil { return nil, err } - userIDs := make([]string, 0, len(userSessions)) - for _, session := range userSessions { - if session.State.V == domain.UserSessionStateActive { - userIDs = append(userIDs, session.UserID) + signoutSessions := make([]command.HumanSignOutSession, 0, len(sessions)) + for _, session := range sessions { + if session.State.V == domain.UserSessionStateActive && session.ID.Valid { + signoutSessions = append(signoutSessions, command.HumanSignOutSession{ + ID: session.ID.String, + UserID: session.UserID, + }) } } - return userIDs, nil + return signoutSessions, nil } func (repo *UserRepo) UserAgentIDBySessionID(ctx context.Context, sessionID string) (string, error) { return repo.View.UserAgentIDBySessionID(ctx, sessionID, authz.GetInstance(ctx).InstanceID()) } -func (repo *UserRepo) ActiveUserIDsBySessionID(ctx context.Context, sessionID string) (userAgentID string, userIDs []string, err error) { - return repo.View.ActiveUserIDsBySessionID(ctx, sessionID, authz.GetInstance(ctx).InstanceID()) +func (repo *UserRepo) ActiveUserSessionsBySessionID(ctx context.Context, sessionID string) (userAgentID string, signoutSessions []command.HumanSignOutSession, err error) { + userAgentID, sessions, err := repo.View.ActiveUserSessionsBySessionID(ctx, sessionID, authz.GetInstance(ctx).InstanceID()) + if err != nil { + return "", nil, err + } + signoutSessions = make([]command.HumanSignOutSession, 0, len(sessions)) + for sessionID, userID := range sessions { + signoutSessions = append(signoutSessions, command.HumanSignOutSession{ + ID: sessionID, + UserID: userID, + }) + } + return userAgentID, signoutSessions, nil } func (repo *UserRepo) UserEventsByID(ctx context.Context, id string, changeDate time.Time, eventTypes []eventstore.EventType) ([]eventstore.Event, error) { diff --git a/internal/auth/repository/eventsourcing/view/user_session.go b/internal/auth/repository/eventsourcing/view/user_session.go index f25deb99e6..a4618e11fb 100644 --- a/internal/auth/repository/eventsourcing/view/user_session.go +++ b/internal/auth/repository/eventsourcing/view/user_session.go @@ -24,8 +24,8 @@ func (v *View) UserAgentIDBySessionID(ctx context.Context, sessionID, instanceID return view.UserAgentIDBySessionID(ctx, v.client, sessionID, instanceID) } -func (v *View) ActiveUserIDsBySessionID(ctx context.Context, sessionID, instanceID string) (userAgentID string, userIDs []string, err error) { - return view.ActiveUserIDsBySessionID(ctx, v.client, sessionID, instanceID) +func (v *View) ActiveUserSessionsBySessionID(ctx context.Context, sessionID, instanceID string) (userAgentID string, sessions map[string]string, err error) { + return view.ActiveUserSessionsBySessionID(ctx, v.client, sessionID, instanceID) } func (v *View) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (_ *query.CurrentState, err error) { diff --git a/internal/auth/repository/user.go b/internal/auth/repository/user.go index 6f373ec12e..f09581b32e 100644 --- a/internal/auth/repository/user.go +++ b/internal/auth/repository/user.go @@ -2,10 +2,12 @@ package repository import ( "context" + + "github.com/zitadel/zitadel/internal/command" ) type UserRepository interface { - UserSessionUserIDsByAgentID(ctx context.Context, agentID string) ([]string, error) + UserSessionsByAgentID(ctx context.Context, agentID string) (sessions []command.HumanSignOutSession, err error) UserAgentIDBySessionID(ctx context.Context, sessionID string) (string, error) - ActiveUserIDsBySessionID(ctx context.Context, sessionID string) (userAgentID string, userIDs []string, err error) + ActiveUserSessionsBySessionID(ctx context.Context, sessionID string) (userAgentID string, sessions []command.HumanSignOutSession, err error) } diff --git a/internal/command/instance_domain_test.go b/internal/command/instance_domain_test.go index 3f5e73aedd..adaa59ec05 100644 --- a/internal/command/instance_domain_test.go +++ b/internal/command/instance_domain_test.go @@ -155,6 +155,7 @@ func TestCommandSide_AddInstanceDomain(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, false, + "", ), ), ), diff --git a/internal/command/instance_features.go b/internal/command/instance_features.go index 79d3d25ffe..e4509ae130 100644 --- a/internal/command/instance_features.go +++ b/internal/command/instance_features.go @@ -27,6 +27,7 @@ type InstanceFeatures struct { DebugOIDCParentError *bool OIDCSingleV1SessionTermination *bool DisableUserTokenEvent *bool + EnableBackChannelLogout *bool } func (m *InstanceFeatures) isEmpty() bool { @@ -41,7 +42,8 @@ func (m *InstanceFeatures) isEmpty() bool { m.WebKey == nil && m.DebugOIDCParentError == nil && m.OIDCSingleV1SessionTermination == nil && - m.DisableUserTokenEvent == nil + m.DisableUserTokenEvent == nil && + m.EnableBackChannelLogout == nil } func (c *Commands) SetInstanceFeatures(ctx context.Context, f *InstanceFeatures) (*domain.ObjectDetails, error) { diff --git a/internal/command/instance_features_model.go b/internal/command/instance_features_model.go index 5ed0b9c24b..f6c5f39898 100644 --- a/internal/command/instance_features_model.go +++ b/internal/command/instance_features_model.go @@ -71,6 +71,7 @@ func (m *InstanceFeaturesWriteModel) Query() *eventstore.SearchQueryBuilder { feature_v2.InstanceDebugOIDCParentErrorEventType, feature_v2.InstanceOIDCSingleV1SessionTerminationEventType, feature_v2.InstanceDisableUserTokenEvent, + feature_v2.InstanceEnableBackChannelLogout, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -116,6 +117,9 @@ func reduceInstanceFeature(features *InstanceFeatures, key feature.Key, value an case feature.KeyDisableUserTokenEvent: v := value.(bool) features.DisableUserTokenEvent = &v + case feature.KeyEnableBackChannelLogout: + v := value.(bool) + features.EnableBackChannelLogout = &v } } @@ -133,5 +137,6 @@ func (wm *InstanceFeaturesWriteModel) setCommands(ctx context.Context, f *Instan cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.DebugOIDCParentError, f.DebugOIDCParentError, feature_v2.InstanceDebugOIDCParentErrorEventType) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.OIDCSingleV1SessionTermination, f.OIDCSingleV1SessionTermination, feature_v2.InstanceOIDCSingleV1SessionTerminationEventType) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.DisableUserTokenEvent, f.DisableUserTokenEvent, feature_v2.InstanceDisableUserTokenEvent) + cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.EnableBackChannelLogout, f.EnableBackChannelLogout, feature_v2.InstanceEnableBackChannelLogout) return cmds } diff --git a/internal/command/instance_test.go b/internal/command/instance_test.go index af32ea538d..c60b2763b3 100644 --- a/internal/command/instance_test.go +++ b/internal/command/instance_test.go @@ -127,6 +127,7 @@ func oidcAppEvents(ctx context.Context, orgID, projectID, id, name, clientID str 0, nil, false, + "", ), } } @@ -439,6 +440,7 @@ func generatedDomainFilters(instanceID, orgID, projectID, appID, generatedDomain 0, nil, false, + "", ), ), expectFilter( diff --git a/internal/command/logout_session.go b/internal/command/logout_session.go new file mode 100644 index 0000000000..fd52c0f970 --- /dev/null +++ b/internal/command/logout_session.go @@ -0,0 +1,24 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/internal/repository/sessionlogout" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +func (c *Commands) BackChannelLogoutSent(ctx context.Context, id, oidcSessionID, instanceID string) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + sessionWriteModel := NewSessionLogoutWriteModel(id, instanceID, oidcSessionID) + if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil { + return err + } + + return c.pushAppendAndReduce( + ctx, + sessionWriteModel, + sessionlogout.NewBackChannelLogoutSentEvent(ctx, sessionWriteModel.aggregate, oidcSessionID), + ) +} diff --git a/internal/command/logout_session_model.go b/internal/command/logout_session_model.go new file mode 100644 index 0000000000..ed31a87012 --- /dev/null +++ b/internal/command/logout_session_model.go @@ -0,0 +1,74 @@ +package command + +import ( + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/sessionlogout" +) + +type SessionLogoutWriteModel struct { + eventstore.WriteModel + + UserID string + OIDCSessionID string + ClientID string + BackChannelLogoutURI string + BackChannelLogoutSent bool + + aggregate *eventstore.Aggregate +} + +func NewSessionLogoutWriteModel(id string, instanceID string, sessionID string) *SessionLogoutWriteModel { + return &SessionLogoutWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: id, + ResourceOwner: instanceID, + InstanceID: instanceID, + }, + aggregate: &sessionlogout.NewAggregate(id, instanceID).Aggregate, + OIDCSessionID: sessionID, + } +} + +func (wm *SessionLogoutWriteModel) Reduce() error { + for _, event := range wm.Events { + switch e := event.(type) { + case *sessionlogout.BackChannelLogoutRegisteredEvent: + wm.reduceRegistered(e) + case *sessionlogout.BackChannelLogoutSentEvent: + wm.reduceSent(e) + } + } + return wm.WriteModel.Reduce() +} + +func (wm *SessionLogoutWriteModel) Query() *eventstore.SearchQueryBuilder { + query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(sessionlogout.AggregateType). + AggregateIDs(wm.AggregateID). + EventTypes( + sessionlogout.BackChannelLogoutRegisteredType, + sessionlogout.BackChannelLogoutSentType, + ). + EventData(map[string]interface{}{ + "oidc_session_id": wm.OIDCSessionID, + }). + Builder() + return query +} + +func (wm *SessionLogoutWriteModel) reduceRegistered(e *sessionlogout.BackChannelLogoutRegisteredEvent) { + if wm.OIDCSessionID != e.OIDCSessionID { + return + } + wm.UserID = e.UserID + wm.ClientID = e.ClientID + wm.BackChannelLogoutURI = e.BackChannelLogoutURI +} + +func (wm *SessionLogoutWriteModel) reduceSent(e *sessionlogout.BackChannelLogoutSentEvent) { + if wm.OIDCSessionID != e.OIDCSessionID { + return + } + wm.BackChannelLogoutSent = true +} diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index f7bb9b4cb6..c2922f5194 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/oidcsession" + "github.com/zitadel/zitadel/internal/repository/sessionlogout" "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -133,7 +134,8 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq func (c *Commands) CreateOIDCSession(ctx context.Context, userID, resourceOwner, - clientID string, + clientID, + backChannelLogoutURI string, scope, audience []string, authMethods []domain.UserAuthMethodType, @@ -161,6 +163,7 @@ func (c *Commands) CreateOIDCSession(ctx context.Context, } cmd.AddSession(ctx, userID, resourceOwner, sessionID, clientID, audience, scope, authMethods, authTime, nonce, preferredLanguage, userAgent) + cmd.RegisterLogout(ctx, sessionID, userID, clientID, backChannelLogoutURI) if err = cmd.AddAccessToken(ctx, scope, userID, resourceOwner, reason, actor); err != nil { return nil, err } @@ -433,6 +436,26 @@ func (c *OIDCSessionEvents) SetAuthRequestFailed(ctx context.Context, authReques c.events = append(c.events, authrequest.NewFailedEvent(ctx, authRequestAggregate, domain.OIDCErrorReasonFromError(err))) } +func (c *OIDCSessionEvents) RegisterLogout(ctx context.Context, sessionID, userID, clientID, backChannelLogoutURI string) { + // If there's no SSO session (e.g. service accounts) we do not need to register a logout handler. + // Also, if the client did not register a backchannel_logout_uri it will not support it (https://openid.net/specs/openid-connect-backchannel-1_0.html#BCRegistration) + if sessionID == "" || backChannelLogoutURI == "" { + return + } + if !authz.GetFeatures(ctx).EnableBackChannelLogout { + return + } + + c.events = append(c.events, sessionlogout.NewBackChannelLogoutRegisteredEvent( + ctx, + &sessionlogout.NewAggregate(sessionID, authz.GetInstance(ctx).InstanceID()).Aggregate, + c.oidcSessionWriteModel.AggregateID, + userID, + clientID, + backChannelLogoutURI, + )) +} + func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string, userID, resourceOwner string, reason domain.TokenReason, actor *domain.TokenActor) error { accessTokenID, err := c.idGenerator.Next() if err != nil { diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go index 86d6bd9033..43ca622a29 100644 --- a/internal/command/oidc_session_test.go +++ b/internal/command/oidc_session_test.go @@ -24,6 +24,7 @@ import ( "github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/repository/sessionlogout" "github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -732,21 +733,22 @@ func TestCommands_CreateOIDCSession(t *testing.T) { checkPermission domain.PermissionCheck } type args struct { - ctx context.Context - userID string - resourceOwner string - clientID string - audience []string - scope []string - authMethods []domain.UserAuthMethodType - authTime time.Time - nonce string - preferredLanguage *language.Tag - userAgent *domain.UserAgent - reason domain.TokenReason - actor *domain.TokenActor - needRefreshToken bool - sessionID string + ctx context.Context + userID string + resourceOwner string + clientID string + backChannelLogoutURI string + audience []string + scope []string + authMethods []domain.UserAuthMethodType + authTime time.Time + nonce string + preferredLanguage *language.Tag + userAgent *domain.UserAgent + reason domain.TokenReason + actor *domain.TokenActor + needRefreshToken bool + sessionID string } tests := []struct { name string @@ -763,16 +765,17 @@ func TestCommands_CreateOIDCSession(t *testing.T) { ), }, args: args{ - ctx: authz.WithInstanceID(context.Background(), "instanceID"), - userID: "userID", - resourceOwner: "orgID", - clientID: "clientID", - audience: []string{"audience"}, - scope: []string{"openid", "offline_access"}, - authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, - authTime: testNow, - nonce: "nonce", - preferredLanguage: &language.Afrikaans, + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + userID: "userID", + resourceOwner: "orgID", + clientID: "clientID", + backChannelLogoutURI: "backChannelLogoutURI", + audience: []string{"audience"}, + scope: []string{"openid", "offline_access"}, + authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + authTime: testNow, + nonce: "nonce", + preferredLanguage: &language.Afrikaans, userAgent: &domain.UserAgent{ FingerprintID: gu.Ptr("fp1"), IP: net.ParseIP("1.2.3.4"), @@ -1236,6 +1239,308 @@ func TestCommands_CreateOIDCSession(t *testing.T) { SessionID: "sessionID", }, }, + { + name: "with backChannelLogoutURI", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + user.NewHumanAddedEvent( + context.Background(), + &user.NewAggregate("userID", "org1").Aggregate, + "username", + "firstname", + "lastname", + "nickname", + "displayname", + language.Afrikaans, + domain.GenderUnspecified, + "email", + false, + ), + ), + expectFilter(), // token lifetime + expectPush( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "org1", "", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + oidcsession.NewAccessTokenAddedEvent(context.Background(), + &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, + &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + ), + user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + userID: "userID", + resourceOwner: "org1", + clientID: "clientID", + backChannelLogoutURI: "backChannelLogoutURI", + audience: []string{"audience"}, + scope: []string{"openid", "offline_access"}, + authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + authTime: testNow, + nonce: "nonce", + preferredLanguage: &language.Afrikaans, + userAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + reason: domain.TokenReasonAuthRequest, + actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + needRefreshToken: false, + }, + want: &OIDCSession{ + TokenID: "V2_oidcSessionID-at_accessTokenID", + ClientID: "clientID", + UserID: "userID", + Audience: []string{"audience"}, + Expiration: time.Time{}.Add(time.Hour), + Scope: []string{"openid", "offline_access"}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + AuthTime: testNow, + Nonce: "nonce", + PreferredLanguage: &language.Afrikaans, + UserAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + Reason: domain.TokenReasonAuthRequest, + Actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + }, + }, + { + name: "with backChannelLogoutURI and sessionID", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + user.NewHumanAddedEvent( + context.Background(), + &user.NewAggregate("userID", "org1").Aggregate, + "username", + "firstname", + "lastname", + "nickname", + "displayname", + language.Afrikaans, + domain.GenderUnspecified, + "email", + false, + ), + ), + expectFilter(), // token lifetime + expectPush( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + oidcsession.NewAccessTokenAddedEvent(context.Background(), + &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, + &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + ), + user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + userID: "userID", + resourceOwner: "org1", + clientID: "clientID", + backChannelLogoutURI: "backChannelLogoutURI", + audience: []string{"audience"}, + scope: []string{"openid", "offline_access"}, + authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + authTime: testNow, + nonce: "nonce", + preferredLanguage: &language.Afrikaans, + userAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + reason: domain.TokenReasonAuthRequest, + actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + needRefreshToken: false, + sessionID: "sessionID", + }, + want: &OIDCSession{ + TokenID: "V2_oidcSessionID-at_accessTokenID", + ClientID: "clientID", + UserID: "userID", + Audience: []string{"audience"}, + Expiration: time.Time{}.Add(time.Hour), + Scope: []string{"openid", "offline_access"}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + AuthTime: testNow, + Nonce: "nonce", + PreferredLanguage: &language.Afrikaans, + UserAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + Reason: domain.TokenReasonAuthRequest, + Actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + SessionID: "sessionID", + }, + }, + { + name: "with backChannelLogoutURI and sessionID, backchannel logout enabled", + fields: fields{ + eventstore: expectEventstore( + expectFilter( + user.NewHumanAddedEvent( + context.Background(), + &user.NewAggregate("userID", "org1").Aggregate, + "username", + "firstname", + "lastname", + "nickname", + "displayname", + language.Afrikaans, + domain.GenderUnspecified, + "email", + false, + ), + ), + expectFilter(), // token lifetime + expectPush( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + ), + sessionlogout.NewBackChannelLogoutRegisteredEvent(context.Background(), + &sessionlogout.NewAggregate("sessionID", "instanceID").Aggregate, + "V2_oidcSessionID", + "userID", + "clientID", + "backChannelLogoutURI", + ), + oidcsession.NewAccessTokenAddedEvent(context.Background(), + &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, + "at_accessTokenID", []string{"openid", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, + &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + ), + user.NewUserTokenV2AddedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "at_accessTokenID"), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args: args{ + ctx: authz.WithFeatures(authz.WithInstanceID(context.Background(), "instanceID"), feature.Features{EnableBackChannelLogout: true}), + userID: "userID", + resourceOwner: "org1", + clientID: "clientID", + backChannelLogoutURI: "backChannelLogoutURI", + audience: []string{"audience"}, + scope: []string{"openid", "offline_access"}, + authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + authTime: testNow, + nonce: "nonce", + preferredLanguage: &language.Afrikaans, + userAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + reason: domain.TokenReasonAuthRequest, + actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + needRefreshToken: false, + sessionID: "sessionID", + }, + want: &OIDCSession{ + TokenID: "V2_oidcSessionID-at_accessTokenID", + ClientID: "clientID", + UserID: "userID", + Audience: []string{"audience"}, + Expiration: time.Time{}.Add(time.Hour), + Scope: []string{"openid", "offline_access"}, + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + AuthTime: testNow, + Nonce: "nonce", + PreferredLanguage: &language.Afrikaans, + UserAgent: &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + Reason: domain.TokenReasonAuthRequest, + Actor: &domain.TokenActor{ + UserID: "user2", + Issuer: "foo.com", + }, + SessionID: "sessionID", + }, + }, { name: "impersonation not allowed", fields: fields{ @@ -1412,6 +1717,7 @@ func TestCommands_CreateOIDCSession(t *testing.T) { tt.args.userID, tt.args.resourceOwner, tt.args.clientID, + tt.args.backChannelLogoutURI, tt.args.scope, tt.args.audience, tt.args.authMethods, diff --git a/internal/command/project_application_oidc.go b/internal/command/project_application_oidc.go index 9852bea23b..ac486b2e18 100644 --- a/internal/command/project_application_oidc.go +++ b/internal/command/project_application_oidc.go @@ -31,6 +31,7 @@ type addOIDCApp struct { ClockSkew time.Duration AdditionalOrigins []string SkipSuccessPageForNativeApp bool + BackChannelLogoutURI string ClientID string ClientSecret string @@ -108,6 +109,7 @@ func (c *Commands) AddOIDCAppCommand(app *addOIDCApp) preparation.Validation { app.ClockSkew, trimStringSliceWhiteSpaces(app.AdditionalOrigins), app.SkipSuccessPageForNativeApp, + app.BackChannelLogoutURI, ), }, nil }, nil @@ -199,6 +201,7 @@ func (c *Commands) addOIDCApplicationWithID(ctx context.Context, oidcApp *domain oidcApp.ClockSkew, trimStringSliceWhiteSpaces(oidcApp.AdditionalOrigins), oidcApp.SkipNativeAppSuccessPage, + strings.TrimSpace(oidcApp.BackChannelLogoutURI), )) addedApplication.AppID = oidcApp.AppID @@ -256,6 +259,7 @@ func (c *Commands) ChangeOIDCApplication(ctx context.Context, oidc *domain.OIDCA oidc.ClockSkew, trimStringSliceWhiteSpaces(oidc.AdditionalOrigins), oidc.SkipNativeAppSuccessPage, + strings.TrimSpace(oidc.BackChannelLogoutURI), ) if err != nil { return nil, err diff --git a/internal/command/project_application_oidc_model.go b/internal/command/project_application_oidc_model.go index 585fdf6c1d..1ab0ad00d6 100644 --- a/internal/command/project_application_oidc_model.go +++ b/internal/command/project_application_oidc_model.go @@ -36,6 +36,7 @@ type OIDCApplicationWriteModel struct { State domain.AppState AdditionalOrigins []string SkipNativeAppSuccessPage bool + BackChannelLogoutURI string oidc bool } @@ -165,6 +166,7 @@ func (wm *OIDCApplicationWriteModel) appendAddOIDCEvent(e *project.OIDCConfigAdd wm.ClockSkew = e.ClockSkew wm.AdditionalOrigins = e.AdditionalOrigins wm.SkipNativeAppSuccessPage = e.SkipNativeAppSuccessPage + wm.BackChannelLogoutURI = e.BackChannelLogoutURI } func (wm *OIDCApplicationWriteModel) appendChangeOIDCEvent(e *project.OIDCConfigChangedEvent) { @@ -213,6 +215,9 @@ func (wm *OIDCApplicationWriteModel) appendChangeOIDCEvent(e *project.OIDCConfig if e.SkipNativeAppSuccessPage != nil { wm.SkipNativeAppSuccessPage = *e.SkipNativeAppSuccessPage } + if e.BackChannelLogoutURI != nil { + wm.BackChannelLogoutURI = *e.BackChannelLogoutURI + } } func (wm *OIDCApplicationWriteModel) Query() *eventstore.SearchQueryBuilder { @@ -254,6 +259,7 @@ func (wm *OIDCApplicationWriteModel) NewChangedEvent( clockSkew time.Duration, additionalOrigins []string, skipNativeAppSuccessPage bool, + backChannelLogoutURI string, ) (*project.OIDCConfigChangedEvent, bool, error) { changes := make([]project.OIDCConfigChanges, 0) var err error @@ -303,6 +309,9 @@ func (wm *OIDCApplicationWriteModel) NewChangedEvent( if wm.SkipNativeAppSuccessPage != skipNativeAppSuccessPage { changes = append(changes, project.ChangeSkipNativeAppSuccessPage(skipNativeAppSuccessPage)) } + if wm.BackChannelLogoutURI != backChannelLogoutURI { + changes = append(changes, project.ChangeBackChannelLogoutURI(backChannelLogoutURI)) + } if len(changes) == 0 { return nil, false, nil diff --git a/internal/command/project_application_oidc_test.go b/internal/command/project_application_oidc_test.go index 01c848cd2e..8c79d03f82 100644 --- a/internal/command/project_application_oidc_test.go +++ b/internal/command/project_application_oidc_test.go @@ -175,6 +175,7 @@ func TestAddOIDCApp(t *testing.T) { 0, []string{"https://sub.test.ch"}, false, + "", ), }, }, @@ -240,6 +241,7 @@ func TestAddOIDCApp(t *testing.T) { 0, nil, false, + "", ), }, }, @@ -305,6 +307,7 @@ func TestAddOIDCApp(t *testing.T) { 0, nil, false, + "", ), }, }, @@ -370,6 +373,7 @@ func TestAddOIDCApp(t *testing.T) { 0, nil, false, + "", ), }, }, @@ -516,6 +520,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, true, + "https://test.ch/backchannel", ), ), ), @@ -543,6 +548,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{" https://sub.test.ch "}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: " https://test.ch/backchannel ", }, resourceOwner: "org1", }, @@ -571,6 +577,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", State: domain.AppStateActive, Compliance: &domain.Compliance{}, }, @@ -614,6 +621,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, true, + "https://test.ch/backchannel", ), ), ), @@ -641,6 +649,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", }, resourceOwner: "org1", }, @@ -669,6 +678,7 @@ func TestCommandSide_AddOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", State: domain.AppStateActive, Compliance: &domain.Compliance{}, }, @@ -847,6 +857,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, true, + "https://test.ch/backchannel", ), ), ), @@ -875,6 +886,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", }, resourceOwner: "org1", }, @@ -916,6 +928,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, true, + "https://test.ch/backchannel", ), ), ), @@ -944,6 +957,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{" https://sub.test.ch "}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: " https://test.ch/backchannel ", }, resourceOwner: "org1", }, @@ -985,6 +999,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, true, + "https://test.ch/backchannel", ), ), ), @@ -1019,6 +1034,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { ClockSkew: time.Second * 2, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", }, resourceOwner: "org1", }, @@ -1046,6 +1062,7 @@ func TestCommandSide_ChangeOIDCApplication(t *testing.T) { ClockSkew: time.Second * 2, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "https://test.ch/backchannel", Compliance: &domain.Compliance{}, State: domain.AppStateActive, }, @@ -1170,6 +1187,7 @@ func TestCommandSide_ChangeOIDCApplicationSecret(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, false, + "", ), ), ), @@ -1213,6 +1231,7 @@ func TestCommandSide_ChangeOIDCApplicationSecret(t *testing.T) { ClockSkew: time.Second * 1, AdditionalOrigins: []string{"https://sub.test.ch"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "", State: domain.AppStateActive, }, }, @@ -1327,6 +1346,7 @@ func TestCommands_VerifyOIDCClientSecret(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, false, + "", ), ), ), @@ -1362,6 +1382,7 @@ func TestCommands_VerifyOIDCClientSecret(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, false, + "", ), ), ), @@ -1396,6 +1417,7 @@ func TestCommands_VerifyOIDCClientSecret(t *testing.T) { time.Second*1, []string{"https://sub.test.ch"}, false, + "", ), ), ), diff --git a/internal/command/project_converter.go b/internal/command/project_converter.go index 35679d8a14..079dc85654 100644 --- a/internal/command/project_converter.go +++ b/internal/command/project_converter.go @@ -47,6 +47,7 @@ func oidcWriteModelToOIDCConfig(writeModel *OIDCApplicationWriteModel) *domain.O ClockSkew: writeModel.ClockSkew, AdditionalOrigins: writeModel.AdditionalOrigins, SkipNativeAppSuccessPage: writeModel.SkipNativeAppSuccessPage, + BackChannelLogoutURI: writeModel.BackChannelLogoutURI, } } diff --git a/internal/command/system_features.go b/internal/command/system_features.go index e024a6dd18..f089ada207 100644 --- a/internal/command/system_features.go +++ b/internal/command/system_features.go @@ -19,6 +19,7 @@ type SystemFeatures struct { ImprovedPerformance []feature.ImprovedPerformanceType OIDCSingleV1SessionTermination *bool DisableUserTokenEvent *bool + EnableBackChannelLogout *bool } func (m *SystemFeatures) isEmpty() bool { @@ -31,7 +32,8 @@ func (m *SystemFeatures) isEmpty() bool { // nil check to allow unset improvements m.ImprovedPerformance == nil && m.OIDCSingleV1SessionTermination == nil && - m.DisableUserTokenEvent == nil + m.DisableUserTokenEvent == nil && + m.EnableBackChannelLogout == nil } func (c *Commands) SetSystemFeatures(ctx context.Context, f *SystemFeatures) (*domain.ObjectDetails, error) { diff --git a/internal/command/system_features_model.go b/internal/command/system_features_model.go index 5cc70338bb..d3fca66fea 100644 --- a/internal/command/system_features_model.go +++ b/internal/command/system_features_model.go @@ -62,6 +62,7 @@ func (m *SystemFeaturesWriteModel) Query() *eventstore.SearchQueryBuilder { feature_v2.SystemImprovedPerformanceEventType, feature_v2.SystemOIDCSingleV1SessionTerminationEventType, feature_v2.SystemDisableUserTokenEvent, + feature_v2.SystemEnableBackChannelLogout, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -100,6 +101,9 @@ func reduceSystemFeature(features *SystemFeatures, key feature.Key, value any) { case feature.KeyDisableUserTokenEvent: v := value.(bool) features.DisableUserTokenEvent = &v + case feature.KeyEnableBackChannelLogout: + v := value.(bool) + features.EnableBackChannelLogout = &v } } @@ -115,6 +119,7 @@ func (wm *SystemFeaturesWriteModel) setCommands(ctx context.Context, f *SystemFe cmds = appendFeatureSliceUpdate(ctx, cmds, aggregate, wm.ImprovedPerformance, f.ImprovedPerformance, feature_v2.SystemImprovedPerformanceEventType) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.OIDCSingleV1SessionTermination, f.OIDCSingleV1SessionTermination, feature_v2.SystemOIDCSingleV1SessionTerminationEventType) cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.DisableUserTokenEvent, f.DisableUserTokenEvent, feature_v2.SystemDisableUserTokenEvent) + cmds = appendFeatureUpdate(ctx, cmds, aggregate, wm.EnableBackChannelLogout, f.EnableBackChannelLogout, feature_v2.SystemEnableBackChannelLogout) return cmds } diff --git a/internal/command/user_human.go b/internal/command/user_human.go index 825ae50f9c..91739e0d6d 100644 --- a/internal/command/user_human.go +++ b/internal/command/user_human.go @@ -628,16 +628,21 @@ func createAddHumanEvent(ctx context.Context, aggregate *eventstore.Aggregate, h return addEvent } -func (c *Commands) HumansSignOut(ctx context.Context, agentID string, userIDs []string) error { +type HumanSignOutSession struct { + ID string + UserID string +} + +func (c *Commands) HumansSignOut(ctx context.Context, agentID string, sessions []HumanSignOutSession) error { if agentID == "" { return zerrors.ThrowInvalidArgument(nil, "COMMAND-2M0ds", "Errors.User.UserIDMissing") } - if len(userIDs) == 0 { + if len(sessions) == 0 { return zerrors.ThrowInvalidArgument(nil, "COMMAND-M0od3", "Errors.User.UserIDMissing") } events := make([]eventstore.Command, 0) - for _, userID := range userIDs { - existingUser, err := c.getHumanWriteModelByID(ctx, userID, "") + for _, session := range sessions { + existingUser, err := c.getHumanWriteModelByID(ctx, session.UserID, "") if err != nil { return err } @@ -647,7 +652,9 @@ func (c *Commands) HumansSignOut(ctx context.Context, agentID string, userIDs [] events = append(events, user.NewHumanSignedOutEvent( ctx, UserAggregateFromWriteModel(&existingUser.WriteModel), - agentID)) + agentID, + session.ID, + )) } if len(events) == 0 { return nil diff --git a/internal/command/user_human_test.go b/internal/command/user_human_test.go index fbf3523fc9..78d7248516 100644 --- a/internal/command/user_human_test.go +++ b/internal/command/user_human_test.go @@ -3123,9 +3123,9 @@ func TestCommandSide_HumanSignOut(t *testing.T) { } type ( args struct { - ctx context.Context - agentID string - userIDs []string + ctx context.Context + agentID string + sessions []HumanSignOutSession } ) type res struct { @@ -3144,9 +3144,9 @@ func TestCommandSide_HumanSignOut(t *testing.T) { eventstore: expectEventstore(), }, args: args{ - ctx: context.Background(), - agentID: "", - userIDs: []string{"user1"}, + ctx: context.Background(), + agentID: "", + sessions: []HumanSignOutSession{{ID: "session1", UserID: "user1"}}, }, res: res{ err: zerrors.IsErrorInvalidArgument, @@ -3158,9 +3158,9 @@ func TestCommandSide_HumanSignOut(t *testing.T) { eventstore: expectEventstore(), }, args: args{ - ctx: context.Background(), - agentID: "agent1", - userIDs: []string{}, + ctx: context.Background(), + agentID: "agent1", + sessions: []HumanSignOutSession{}, }, res: res{ err: zerrors.IsErrorInvalidArgument, @@ -3174,9 +3174,9 @@ func TestCommandSide_HumanSignOut(t *testing.T) { ), }, args: args{ - ctx: context.Background(), - agentID: "agent1", - userIDs: []string{"user1"}, + ctx: context.Background(), + agentID: "agent1", + sessions: []HumanSignOutSession{{ID: "session1", UserID: "user1"}}, }, res: res{}, }, @@ -3204,14 +3204,15 @@ func TestCommandSide_HumanSignOut(t *testing.T) { user.NewHumanSignedOutEvent(context.Background(), &user.NewAggregate("user1", "org1").Aggregate, "agent1", + "session1", ), ), ), }, args: args{ - ctx: context.Background(), - agentID: "agent1", - userIDs: []string{"user1"}, + ctx: context.Background(), + agentID: "agent1", + sessions: []HumanSignOutSession{{ID: "session1", UserID: "user1"}}, }, res: res{ want: &domain.ObjectDetails{ @@ -3259,18 +3260,20 @@ func TestCommandSide_HumanSignOut(t *testing.T) { user.NewHumanSignedOutEvent(context.Background(), &user.NewAggregate("user1", "org1").Aggregate, "agent1", + "session1", ), user.NewHumanSignedOutEvent(context.Background(), &user.NewAggregate("user2", "org1").Aggregate, "agent1", + "session2", ), ), ), }, args: args{ - ctx: context.Background(), - agentID: "agent1", - userIDs: []string{"user1", "user2"}, + ctx: context.Background(), + agentID: "agent1", + sessions: []HumanSignOutSession{{ID: "session1", UserID: "user1"}, {ID: "session2", UserID: "user2"}}, }, res: res{ want: &domain.ObjectDetails{ @@ -3284,7 +3287,7 @@ func TestCommandSide_HumanSignOut(t *testing.T) { r := &Commands{ eventstore: tt.fields.eventstore(t), } - err := r.HumansSignOut(tt.args.ctx, tt.args.agentID, tt.args.userIDs) + err := r.HumansSignOut(tt.args.ctx, tt.args.agentID, tt.args.sessions) if tt.res.err == nil { assert.NoError(t, err) } diff --git a/internal/domain/application_oidc.go b/internal/domain/application_oidc.go index 5fe7b1f698..1ffb61f538 100644 --- a/internal/domain/application_oidc.go +++ b/internal/domain/application_oidc.go @@ -45,6 +45,7 @@ type OIDCApp struct { ClockSkew time.Duration AdditionalOrigins []string SkipNativeAppSuccessPage bool + BackChannelLogoutURI string State AppState } diff --git a/internal/feature/feature.go b/internal/feature/feature.go index 3104f6ed59..1d619b25d8 100644 --- a/internal/feature/feature.go +++ b/internal/feature/feature.go @@ -18,6 +18,7 @@ const ( KeyDebugOIDCParentError KeyOIDCSingleV1SessionTermination KeyDisableUserTokenEvent + KeyEnableBackChannelLogout ) //go:generate enumer -type Level -transform snake -trimprefix Level @@ -43,8 +44,9 @@ type Features struct { ImprovedPerformance []ImprovedPerformanceType `json:"improved_performance,omitempty"` WebKey bool `json:"web_key,omitempty"` DebugOIDCParentError bool `json:"debug_oidc_parent_error,omitempty"` - OIDCSingleV1SessionTermination bool `json:"terminate_single_v1_session,omitempty"` + OIDCSingleV1SessionTermination bool `json:"oidc_single_v1_session_termination,omitempty"` DisableUserTokenEvent bool `json:"disable_user_token_event,omitempty"` + EnableBackChannelLogout bool `json:"enable_back_channel_logout,omitempty"` } type ImprovedPerformanceType int32 diff --git a/internal/feature/key_enumer.go b/internal/feature/key_enumer.go index 46d8613fbc..db3cf4161e 100644 --- a/internal/feature/key_enumer.go +++ b/internal/feature/key_enumer.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _KeyName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_event" +const _KeyName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logout" -var _KeyIndex = [...]uint8{0, 11, 28, 61, 81, 92, 106, 113, 133, 140, 163, 197, 221} +var _KeyIndex = [...]uint8{0, 11, 28, 61, 81, 92, 106, 113, 133, 140, 163, 197, 221, 247} -const _KeyLowerName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_event" +const _KeyLowerName = "unspecifiedlogin_default_orgtrigger_introspection_projectionslegacy_introspectionuser_schematoken_exchangeactionsimproved_performanceweb_keydebug_oidc_parent_erroroidc_single_v1_session_terminationdisable_user_token_eventenable_back_channel_logout" func (i Key) String() string { if i < 0 || i >= Key(len(_KeyIndex)-1) { @@ -36,9 +36,10 @@ func _KeyNoOp() { _ = x[KeyDebugOIDCParentError-(9)] _ = x[KeyOIDCSingleV1SessionTermination-(10)] _ = x[KeyDisableUserTokenEvent-(11)] + _ = x[KeyEnableBackChannelLogout-(12)] } -var _KeyValues = []Key{KeyUnspecified, KeyLoginDefaultOrg, KeyTriggerIntrospectionProjections, KeyLegacyIntrospection, KeyUserSchema, KeyTokenExchange, KeyActions, KeyImprovedPerformance, KeyWebKey, KeyDebugOIDCParentError, KeyOIDCSingleV1SessionTermination, KeyDisableUserTokenEvent} +var _KeyValues = []Key{KeyUnspecified, KeyLoginDefaultOrg, KeyTriggerIntrospectionProjections, KeyLegacyIntrospection, KeyUserSchema, KeyTokenExchange, KeyActions, KeyImprovedPerformance, KeyWebKey, KeyDebugOIDCParentError, KeyOIDCSingleV1SessionTermination, KeyDisableUserTokenEvent, KeyEnableBackChannelLogout} var _KeyNameToValueMap = map[string]Key{ _KeyName[0:11]: KeyUnspecified, @@ -65,6 +66,8 @@ var _KeyNameToValueMap = map[string]Key{ _KeyLowerName[163:197]: KeyOIDCSingleV1SessionTermination, _KeyName[197:221]: KeyDisableUserTokenEvent, _KeyLowerName[197:221]: KeyDisableUserTokenEvent, + _KeyName[221:247]: KeyEnableBackChannelLogout, + _KeyLowerName[221:247]: KeyEnableBackChannelLogout, } var _KeyNames = []string{ @@ -80,6 +83,7 @@ var _KeyNames = []string{ _KeyName[140:163], _KeyName[163:197], _KeyName[197:221], + _KeyName[221:247], } // KeyString retrieves an enum value from the enum constants string name. diff --git a/internal/notification/channels.go b/internal/notification/channels.go index c70eaecbcc..ba9bcb9d7d 100644 --- a/internal/notification/channels.go +++ b/internal/notification/channels.go @@ -6,6 +6,7 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/notification/channels/email" + "github.com/zitadel/zitadel/internal/notification/channels/set" "github.com/zitadel/zitadel/internal/notification/channels/sms" "github.com/zitadel/zitadel/internal/notification/channels/webhook" "github.com/zitadel/zitadel/internal/notification/handlers" @@ -104,3 +105,14 @@ func (c *channels) Webhook(ctx context.Context, cfg webhook.Config) (*senders.Ch c.counters.failed.json, ) } + +func (c *channels) SecurityTokenEvent(ctx context.Context, cfg set.Config) (*senders.Chain, error) { + return senders.SecurityEventTokenChannels( + ctx, + cfg, + c.q.GetFileSystemProvider, + c.q.GetLogProvider, + c.counters.success.json, + c.counters.failed.json, + ) +} diff --git a/internal/notification/channels/set/channel.go b/internal/notification/channels/set/channel.go new file mode 100644 index 0000000000..fbd4065739 --- /dev/null +++ b/internal/notification/channels/set/channel.go @@ -0,0 +1,75 @@ +package set + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/notification/channels" + "github.com/zitadel/zitadel/internal/notification/messages" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func InitChannel(ctx context.Context, cfg Config) (channels.NotificationChannel, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + + logging.Debug("successfully initialized security event token json channel") + return channels.HandleMessageFunc(func(message channels.Message) error { + requestCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + msg, ok := message.(*messages.Form) + if !ok { + return zerrors.ThrowInternal(nil, "SET-K686U", "message is not SET") + } + payload, err := msg.GetContent() + if err != nil { + return err + } + req, err := http.NewRequestWithContext(requestCtx, http.MethodPost, cfg.CallURL, strings.NewReader(payload)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID(), "calling_url", cfg.CallURL).Debug("security event token called") + if resp.StatusCode == http.StatusOK || + resp.StatusCode == http.StatusAccepted || + resp.StatusCode == http.StatusNoContent { + return nil + } + body, err := mapResponse(resp) + logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID(), "callURL", cfg.CallURL). + OnError(err).Debug("error mapping response") + if resp.StatusCode == http.StatusBadRequest { + logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID(), "callURL", cfg.CallURL, "status", resp.Status, "body", body). + Error("security event token didn't return a success status") + return nil + } + return zerrors.ThrowInternalf(err, "SET-DF3dq", "security event token to %s didn't return a success status: %s (%v)", cfg.CallURL, resp.Status, body) + }), nil +} + +func mapResponse(resp *http.Response) (map[string]any, error) { + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + requestError := make(map[string]any) + err = json.Unmarshal(body, &requestError) + if err != nil { + return nil, err + } + return requestError, nil +} diff --git a/internal/notification/channels/set/config.go b/internal/notification/channels/set/config.go new file mode 100644 index 0000000000..5d28b3d110 --- /dev/null +++ b/internal/notification/channels/set/config.go @@ -0,0 +1,14 @@ +package set + +import ( + "net/url" +) + +type Config struct { + CallURL string +} + +func (w *Config) Validate() error { + _, err := url.Parse(w.CallURL) + return err +} diff --git a/internal/notification/handlers/back_channel_logout.go b/internal/notification/handlers/back_channel_logout.go new file mode 100644 index 0000000000..43d98ada11 --- /dev/null +++ b/internal/notification/handlers/back_channel_logout.go @@ -0,0 +1,266 @@ +package handlers + +import ( + "context" + "errors" + "slices" + "sync" + "time" + + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v3/pkg/crypto" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + + "github.com/zitadel/zitadel/internal/api/authz" + http_utils "github.com/zitadel/zitadel/internal/api/http" + zoidc "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/command" + zcrypto "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/id" + "github.com/zitadel/zitadel/internal/notification/channels/set" + _ "github.com/zitadel/zitadel/internal/notification/statik" + "github.com/zitadel/zitadel/internal/notification/types" + "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/repository/sessionlogout" + "github.com/zitadel/zitadel/internal/repository/user" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + BackChannelLogoutNotificationsProjectionTable = "projections.notifications_back_channel_logout" +) + +type backChannelLogoutNotifier struct { + commands *command.Commands + queries *NotificationQueries + eventstore *eventstore.Eventstore + keyEncryptionAlg zcrypto.EncryptionAlgorithm + channels types.ChannelChains + idGenerator id.Generator + tokenLifetime time.Duration +} + +func NewBackChannelLogoutNotifier( + ctx context.Context, + config handler.Config, + commands *command.Commands, + queries *NotificationQueries, + es *eventstore.Eventstore, + keyEncryptionAlg zcrypto.EncryptionAlgorithm, + channels types.ChannelChains, + tokenLifetime time.Duration, +) *handler.Handler { + return handler.NewHandler(ctx, &config, &backChannelLogoutNotifier{ + commands: commands, + queries: queries, + eventstore: es, + keyEncryptionAlg: keyEncryptionAlg, + channels: channels, + tokenLifetime: tokenLifetime, + idGenerator: id.SonyFlakeGenerator(), + }) + +} + +func (*backChannelLogoutNotifier) Name() string { + return BackChannelLogoutNotificationsProjectionTable +} + +func (u *backChannelLogoutNotifier) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: session.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: session.TerminateType, + Reduce: u.reduceSessionTerminated, + }, + }, + }, { + Aggregate: user.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: user.HumanSignedOutType, + Reduce: u.reduceUserSignedOut, + }, + }, + }, + } +} + +func (u *backChannelLogoutNotifier) reduceUserSignedOut(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*user.HumanSignedOutEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Gr63h", "reduce.wrong.event.type %s", user.HumanSignedOutType) + } + + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + ctx, err := u.queries.HandlerContext(event.Aggregate()) + if err != nil { + return err + } + if !authz.GetFeatures(ctx).EnableBackChannelLogout { + return nil + } + if e.SessionID == "" { + return nil + } + return u.terminateSession(ctx, e.SessionID, e) + }), nil +} + +func (u *backChannelLogoutNotifier) reduceSessionTerminated(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*session.TerminateEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-D6H2h", "reduce.wrong.event.type %s", session.TerminateType) + } + + return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { + ctx, err := u.queries.HandlerContext(event.Aggregate()) + if err != nil { + return err + } + if !authz.GetFeatures(ctx).EnableBackChannelLogout { + return nil + } + return u.terminateSession(ctx, e.Aggregate().ID, e) + }), nil +} + +type backChannelLogoutSession struct { + sessionID string + + // sessions contain a map of oidc session IDs and their corresponding clientID + sessions []backChannelLogoutOIDCSessions +} + +func (u *backChannelLogoutNotifier) terminateSession(ctx context.Context, id string, e eventstore.Event) error { + sessions := &backChannelLogoutSession{sessionID: id} + err := u.eventstore.FilterToQueryReducer(ctx, sessions) + if err != nil { + return err + } + + ctx, err = u.queries.Origin(ctx, e) + if err != nil { + return err + } + + getSigner := zoidc.GetSignerOnce(u.queries.GetActiveSigningWebKey, u.signingKey) + + var wg sync.WaitGroup + wg.Add(len(sessions.sessions)) + errs := make([]error, 0, len(sessions.sessions)) + for _, oidcSession := range sessions.sessions { + go func(oidcSession *backChannelLogoutOIDCSessions) { + defer wg.Done() + err := u.sendLogoutToken(ctx, oidcSession, e, getSigner) + if err != nil { + errs = append(errs, err) + return + } + err = u.commands.BackChannelLogoutSent(ctx, oidcSession.SessionID, oidcSession.OIDCSessionID, e.Aggregate().InstanceID) + if err != nil { + errs = append(errs, err) + } + }(&oidcSession) + } + wg.Wait() + return errors.Join(errs...) +} + +func (u *backChannelLogoutNotifier) signingKey(ctx context.Context) (op.SigningKey, error) { + keys, err := u.queries.ActivePrivateSigningKey(ctx, time.Now()) + if err != nil { + return nil, err + } + if len(keys.Keys) == 0 { + logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID()). + Info("There's no active signing key and automatic rotation is not supported for back channel logout." + + "Please enable the webkey management feature on your instance") + return nil, zerrors.ThrowPreconditionFailed(nil, "HANDL-DF3nf", "no active signing key") + } + return zoidc.PrivateKeyToSigningKey(zoidc.SelectSigningKey(keys.Keys), u.keyEncryptionAlg) +} + +func (u *backChannelLogoutNotifier) sendLogoutToken(ctx context.Context, oidcSession *backChannelLogoutOIDCSessions, e eventstore.Event, getSigner zoidc.SignerFunc) error { + token, err := u.logoutToken(ctx, oidcSession, getSigner) + if err != nil { + return err + } + err = types.SendSecurityTokenEvent(ctx, set.Config{CallURL: oidcSession.BackChannelLogoutURI}, u.channels, &LogoutTokenMessage{LogoutToken: token}, e).WithoutTemplate() + if err != nil { + return err + } + return nil +} + +func (u *backChannelLogoutNotifier) logoutToken(ctx context.Context, oidcSession *backChannelLogoutOIDCSessions, getSigner zoidc.SignerFunc) (string, error) { + jwtID, err := u.idGenerator.Next() + if err != nil { + return "", err + } + token := oidc.NewLogoutTokenClaims( + http_utils.DomainContext(ctx).Origin(), + oidcSession.UserID, + oidc.Audience{oidcSession.ClientID}, + time.Now().Add(u.tokenLifetime), + jwtID, + oidcSession.SessionID, + time.Second, + ) + signer, _, err := getSigner(ctx) + if err != nil { + return "", err + } + return crypto.Sign(token, signer) +} + +type LogoutTokenMessage struct { + LogoutToken string `schema:"logout_token"` +} + +type backChannelLogoutOIDCSessions struct { + SessionID string + OIDCSessionID string + UserID string + ClientID string + BackChannelLogoutURI string +} + +func (b *backChannelLogoutSession) Reduce() error { + return nil +} + +func (b *backChannelLogoutSession) AppendEvents(events ...eventstore.Event) { + for _, event := range events { + switch e := event.(type) { + case *sessionlogout.BackChannelLogoutRegisteredEvent: + b.sessions = append(b.sessions, backChannelLogoutOIDCSessions{ + SessionID: b.sessionID, + OIDCSessionID: e.OIDCSessionID, + UserID: e.UserID, + ClientID: e.ClientID, + BackChannelLogoutURI: e.BackChannelLogoutURI, + }) + case *sessionlogout.BackChannelLogoutSentEvent: + slices.DeleteFunc(b.sessions, func(session backChannelLogoutOIDCSessions) bool { + return session.OIDCSessionID == e.OIDCSessionID + }) + } + } +} + +func (b *backChannelLogoutSession) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(sessionlogout.AggregateType). + AggregateIDs(b.sessionID). + EventTypes( + sessionlogout.BackChannelLogoutRegisteredType, + sessionlogout.BackChannelLogoutSentType). + Builder() +} diff --git a/internal/notification/handlers/ctx.go b/internal/notification/handlers/ctx.go index 9dc6d87b1b..b8fc45da68 100644 --- a/internal/notification/handlers/ctx.go +++ b/internal/notification/handlers/ctx.go @@ -13,3 +13,13 @@ func HandlerContext(event *eventstore.Aggregate) context.Context { ctx := authz.WithInstanceID(context.Background(), event.InstanceID) return authz.SetCtxData(ctx, authz.CtxData{UserID: NotifyUserID, OrgID: event.ResourceOwner}) } + +func (n *NotificationQueries) HandlerContext(event *eventstore.Aggregate) (context.Context, error) { + ctx := context.Background() + instance, err := n.InstanceByID(ctx, event.InstanceID) + if err != nil { + return nil, err + } + ctx = authz.WithInstance(ctx, instance) + return authz.SetCtxData(ctx, authz.CtxData{UserID: NotifyUserID, OrgID: event.ResourceOwner}), nil +} diff --git a/internal/notification/handlers/mock/commands.mock.go b/internal/notification/handlers/mock/commands.mock.go index 7d41c30f30..ec327de8e8 100644 --- a/internal/notification/handlers/mock/commands.mock.go +++ b/internal/notification/handlers/mock/commands.mock.go @@ -23,6 +23,7 @@ import ( type MockCommands struct { ctrl *gomock.Controller recorder *MockCommandsMockRecorder + isgomock struct{} } // MockCommandsMockRecorder is the mock recorder for MockCommands. @@ -43,197 +44,197 @@ func (m *MockCommands) EXPECT() *MockCommandsMockRecorder { } // HumanEmailVerificationCodeSent mocks base method. -func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", ctx, orgID, userID) ret0, _ := ret[0].(error) return ret0 } // HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(ctx, orgID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), ctx, orgID, userID) } // HumanInitCodeSent mocks base method. -func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) HumanInitCodeSent(ctx context.Context, orgID, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanInitCodeSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "HumanInitCodeSent", ctx, orgID, userID) ret0, _ := ret[0].(error) return ret0 } // HumanInitCodeSent indicates an expected call of HumanInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanInitCodeSent(ctx, orgID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), ctx, orgID, userID) } // HumanOTPEmailCodeSent mocks base method. -func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) HumanOTPEmailCodeSent(ctx context.Context, userID, resourceOwner string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", ctx, userID, resourceOwner) ret0, _ := ret[0].(error) return ret0 } // HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(ctx, userID, resourceOwner any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), ctx, userID, resourceOwner) } // HumanOTPSMSCodeSent mocks base method. -func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", ctx, userID, resourceOwner, generatorInfo) ret0, _ := ret[0].(error) return ret0 } // HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent. -func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(ctx, userID, resourceOwner, generatorInfo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), ctx, userID, resourceOwner, generatorInfo) } // HumanPasswordlessInitCodeSent mocks base method. -func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1, arg2, arg3 string) error { +func (m *MockCommands) HumanPasswordlessInitCodeSent(ctx context.Context, userID, resourceOwner, codeID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", ctx, userID, resourceOwner, codeID) ret0, _ := ret[0].(error) return ret0 } // HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent. -func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(ctx, userID, resourceOwner, codeID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), ctx, userID, resourceOwner, codeID) } // HumanPhoneVerificationCodeSent mocks base method. -func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { +func (m *MockCommands) HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", ctx, orgID, userID, generatorInfo) ret0, _ := ret[0].(error) return ret0 } // HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent. -func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), ctx, orgID, userID, generatorInfo) } // InviteCodeSent mocks base method. -func (m *MockCommands) InviteCodeSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) InviteCodeSent(ctx context.Context, orgID, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InviteCodeSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "InviteCodeSent", ctx, orgID, userID) ret0, _ := ret[0].(error) return ret0 } // InviteCodeSent indicates an expected call of InviteCodeSent. -func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) InviteCodeSent(ctx, orgID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), ctx, orgID, userID) } // MilestonePushed mocks base method. -func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 string, arg2 milestone.Type, arg3 []string) error { +func (m *MockCommands) MilestonePushed(ctx context.Context, instanceID string, msType milestone.Type, endpoints []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "MilestonePushed", ctx, instanceID, msType, endpoints) ret0, _ := ret[0].(error) return ret0 } // MilestonePushed indicates an expected call of MilestonePushed. -func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) MilestonePushed(ctx, instanceID, msType, endpoints any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), ctx, instanceID, msType, endpoints) } // OTPEmailSent mocks base method. -func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPEmailSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "OTPEmailSent", ctx, sessionID, resourceOwner) ret0, _ := ret[0].(error) return ret0 } // OTPEmailSent indicates an expected call of OTPEmailSent. -func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPEmailSent(ctx, sessionID, resourceOwner any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), ctx, sessionID, resourceOwner) } // OTPSMSSent mocks base method. -func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { +func (m *MockCommands) OTPSMSSent(ctx context.Context, sessionID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "OTPSMSSent", ctx, sessionID, resourceOwner, generatorInfo) ret0, _ := ret[0].(error) return ret0 } // OTPSMSSent indicates an expected call of OTPSMSSent. -func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) OTPSMSSent(ctx, sessionID, resourceOwner, generatorInfo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), ctx, sessionID, resourceOwner, generatorInfo) } // PasswordChangeSent mocks base method. -func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) PasswordChangeSent(ctx context.Context, orgID, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordChangeSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "PasswordChangeSent", ctx, orgID, userID) ret0, _ := ret[0].(error) return ret0 } // PasswordChangeSent indicates an expected call of PasswordChangeSent. -func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordChangeSent(ctx, orgID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), ctx, orgID, userID) } // PasswordCodeSent mocks base method. -func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error { +func (m *MockCommands) PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "PasswordCodeSent", ctx, orgID, userID, generatorInfo) ret0, _ := ret[0].(error) return ret0 } // PasswordCodeSent indicates an expected call of PasswordCodeSent. -func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) PasswordCodeSent(ctx, orgID, userID, generatorInfo any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), ctx, orgID, userID, generatorInfo) } // UsageNotificationSent mocks base method. -func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error { +func (m *MockCommands) UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UsageNotificationSent", arg0, arg1) + ret := m.ctrl.Call(m, "UsageNotificationSent", ctx, dueEvent) ret0, _ := ret[0].(error) return ret0 } // UsageNotificationSent indicates an expected call of UsageNotificationSent. -func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UsageNotificationSent(ctx, dueEvent any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), ctx, dueEvent) } // UserDomainClaimedSent mocks base method. -func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 string) error { +func (m *MockCommands) UserDomainClaimedSent(ctx context.Context, orgID, userID string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UserDomainClaimedSent", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "UserDomainClaimedSent", ctx, orgID, userID) ret0, _ := ret[0].(error) return ret0 } // UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent. -func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(ctx, orgID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), ctx, orgID, userID) } diff --git a/internal/notification/handlers/mock/queries.mock.go b/internal/notification/handlers/mock/queries.mock.go index 48d7ec21ec..5669444d4f 100644 --- a/internal/notification/handlers/mock/queries.mock.go +++ b/internal/notification/handlers/mock/queries.mock.go @@ -12,7 +12,10 @@ package mock import ( context "context" reflect "reflect" + time "time" + jose "github.com/go-jose/go-jose/v4" + authz "github.com/zitadel/zitadel/internal/api/authz" domain "github.com/zitadel/zitadel/internal/domain" query "github.com/zitadel/zitadel/internal/query" gomock "go.uber.org/mock/gomock" @@ -23,6 +26,7 @@ import ( type MockQueries struct { ctrl *gomock.Controller recorder *MockQueriesMockRecorder + isgomock struct{} } // MockQueriesMockRecorder is the mock recorder for MockQueries. @@ -43,195 +47,240 @@ func (m *MockQueries) EXPECT() *MockQueriesMockRecorder { } // ActiveLabelPolicyByOrg mocks base method. -func (m *MockQueries) ActiveLabelPolicyByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.LabelPolicy, error) { +func (m *MockQueries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.LabelPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", ctx, orgID, withOwnerRemoved) ret0, _ := ret[0].(*query.LabelPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // ActiveLabelPolicyByOrg indicates an expected call of ActiveLabelPolicyByOrg. -func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), ctx, orgID, withOwnerRemoved) +} + +// ActivePrivateSigningKey mocks base method. +func (m *MockQueries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (*query.PrivateKeys, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ActivePrivateSigningKey", ctx, t) + ret0, _ := ret[0].(*query.PrivateKeys) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ActivePrivateSigningKey indicates an expected call of ActivePrivateSigningKey. +func (mr *MockQueriesMockRecorder) ActivePrivateSigningKey(ctx, t any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActivePrivateSigningKey", reflect.TypeOf((*MockQueries)(nil).ActivePrivateSigningKey), ctx, t) } // CustomTextListByTemplate mocks base method. -func (m *MockQueries) CustomTextListByTemplate(arg0 context.Context, arg1, arg2 string, arg3 bool) (*query.CustomTexts, error) { +func (m *MockQueries) CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CustomTextListByTemplate", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "CustomTextListByTemplate", ctx, aggregateID, template, withOwnerRemoved) ret0, _ := ret[0].(*query.CustomTexts) ret1, _ := ret[1].(error) return ret0, ret1 } // CustomTextListByTemplate indicates an expected call of CustomTextListByTemplate. -func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(ctx, aggregateID, template, withOwnerRemoved any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), ctx, aggregateID, template, withOwnerRemoved) +} + +// GetActiveSigningWebKey mocks base method. +func (m *MockQueries) GetActiveSigningWebKey(ctx context.Context) (*jose.JSONWebKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveSigningWebKey", ctx) + ret0, _ := ret[0].(*jose.JSONWebKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveSigningWebKey indicates an expected call of GetActiveSigningWebKey. +func (mr *MockQueriesMockRecorder) GetActiveSigningWebKey(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveSigningWebKey", reflect.TypeOf((*MockQueries)(nil).GetActiveSigningWebKey), ctx) } // GetDefaultLanguage mocks base method. -func (m *MockQueries) GetDefaultLanguage(arg0 context.Context) language.Tag { +func (m *MockQueries) GetDefaultLanguage(ctx context.Context) language.Tag { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetDefaultLanguage", arg0) + ret := m.ctrl.Call(m, "GetDefaultLanguage", ctx) ret0, _ := ret[0].(language.Tag) return ret0 } // GetDefaultLanguage indicates an expected call of GetDefaultLanguage. -func (mr *MockQueriesMockRecorder) GetDefaultLanguage(arg0 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetDefaultLanguage(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), ctx) } // GetInstanceRestrictions mocks base method. -func (m *MockQueries) GetInstanceRestrictions(arg0 context.Context) (query.Restrictions, error) { +func (m *MockQueries) GetInstanceRestrictions(ctx context.Context) (query.Restrictions, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetInstanceRestrictions", arg0) + ret := m.ctrl.Call(m, "GetInstanceRestrictions", ctx) ret0, _ := ret[0].(query.Restrictions) ret1, _ := ret[1].(error) return ret0, ret1 } // GetInstanceRestrictions indicates an expected call of GetInstanceRestrictions. -func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(arg0 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetInstanceRestrictions(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstanceRestrictions", reflect.TypeOf((*MockQueries)(nil).GetInstanceRestrictions), ctx) } // GetNotifyUserByID mocks base method. -func (m *MockQueries) GetNotifyUserByID(arg0 context.Context, arg1 bool, arg2 string) (*query.NotifyUser, error) { +func (m *MockQueries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string) (*query.NotifyUser, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetNotifyUserByID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetNotifyUserByID", ctx, shouldTriggered, userID) ret0, _ := ret[0].(*query.NotifyUser) ret1, _ := ret[1].(error) return ret0, ret1 } // GetNotifyUserByID indicates an expected call of GetNotifyUserByID. -func (mr *MockQueriesMockRecorder) GetNotifyUserByID(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) GetNotifyUserByID(ctx, shouldTriggered, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), ctx, shouldTriggered, userID) +} + +// InstanceByID mocks base method. +func (m *MockQueries) InstanceByID(ctx context.Context, id string) (authz.Instance, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InstanceByID", ctx, id) + ret0, _ := ret[0].(authz.Instance) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InstanceByID indicates an expected call of InstanceByID. +func (mr *MockQueriesMockRecorder) InstanceByID(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceByID", reflect.TypeOf((*MockQueries)(nil).InstanceByID), ctx, id) } // MailTemplateByOrg mocks base method. -func (m *MockQueries) MailTemplateByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.MailTemplate, error) { +func (m *MockQueries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.MailTemplate, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MailTemplateByOrg", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "MailTemplateByOrg", ctx, orgID, withOwnerRemoved) ret0, _ := ret[0].(*query.MailTemplate) ret1, _ := ret[1].(error) return ret0, ret1 } // MailTemplateByOrg indicates an expected call of MailTemplateByOrg. -func (mr *MockQueriesMockRecorder) MailTemplateByOrg(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) MailTemplateByOrg(ctx, orgID, withOwnerRemoved any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), ctx, orgID, withOwnerRemoved) } // NotificationPolicyByOrg mocks base method. -func (m *MockQueries) NotificationPolicyByOrg(arg0 context.Context, arg1 bool, arg2 string, arg3 bool) (*query.NotificationPolicy, error) { +func (m *MockQueries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationPolicyByOrg", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "NotificationPolicyByOrg", ctx, shouldTriggerBulk, orgID, withOwnerRemoved) ret0, _ := ret[0].(*query.NotificationPolicy) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationPolicyByOrg indicates an expected call of NotificationPolicyByOrg. -func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(ctx, shouldTriggerBulk, orgID, withOwnerRemoved any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), ctx, shouldTriggerBulk, orgID, withOwnerRemoved) } // NotificationProviderByIDAndType mocks base method. -func (m *MockQueries) NotificationProviderByIDAndType(arg0 context.Context, arg1 string, arg2 domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { +func (m *MockQueries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", ctx, aggID, providerType) ret0, _ := ret[0].(*query.DebugNotificationProvider) ret1, _ := ret[1].(error) return ret0, ret1 } // NotificationProviderByIDAndType indicates an expected call of NotificationProviderByIDAndType. -func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(ctx, aggID, providerType any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), ctx, aggID, providerType) } // SMSProviderConfigActive mocks base method. -func (m *MockQueries) SMSProviderConfigActive(arg0 context.Context, arg1 string) (*query.SMSConfig, error) { +func (m *MockQueries) SMSProviderConfigActive(ctx context.Context, resourceOwner string) (*query.SMSConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMSProviderConfigActive", arg0, arg1) + ret := m.ctrl.Call(m, "SMSProviderConfigActive", ctx, resourceOwner) ret0, _ := ret[0].(*query.SMSConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMSProviderConfigActive indicates an expected call of SMSProviderConfigActive. -func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(arg0, arg1 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMSProviderConfigActive(ctx, resourceOwner any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfigActive", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfigActive), ctx, resourceOwner) } // SMTPConfigActive mocks base method. -func (m *MockQueries) SMTPConfigActive(arg0 context.Context, arg1 string) (*query.SMTPConfig, error) { +func (m *MockQueries) SMTPConfigActive(ctx context.Context, resourceOwner string) (*query.SMTPConfig, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SMTPConfigActive", arg0, arg1) + ret := m.ctrl.Call(m, "SMTPConfigActive", ctx, resourceOwner) ret0, _ := ret[0].(*query.SMTPConfig) ret1, _ := ret[1].(error) return ret0, ret1 } // SMTPConfigActive indicates an expected call of SMTPConfigActive. -func (mr *MockQueriesMockRecorder) SMTPConfigActive(arg0, arg1 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SMTPConfigActive(ctx, resourceOwner any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigActive", reflect.TypeOf((*MockQueries)(nil).SMTPConfigActive), ctx, resourceOwner) } // SearchInstanceDomains mocks base method. -func (m *MockQueries) SearchInstanceDomains(arg0 context.Context, arg1 *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { +func (m *MockQueries) SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchInstanceDomains", arg0, arg1) + ret := m.ctrl.Call(m, "SearchInstanceDomains", ctx, queries) ret0, _ := ret[0].(*query.InstanceDomains) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchInstanceDomains indicates an expected call of SearchInstanceDomains. -func (mr *MockQueriesMockRecorder) SearchInstanceDomains(arg0, arg1 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchInstanceDomains(ctx, queries any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), ctx, queries) } // SearchMilestones mocks base method. -func (m *MockQueries) SearchMilestones(arg0 context.Context, arg1 []string, arg2 *query.MilestonesSearchQueries) (*query.Milestones, error) { +func (m *MockQueries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchMilestones", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "SearchMilestones", ctx, instanceIDs, queries) ret0, _ := ret[0].(*query.Milestones) ret1, _ := ret[1].(error) return ret0, ret1 } // SearchMilestones indicates an expected call of SearchMilestones. -func (mr *MockQueriesMockRecorder) SearchMilestones(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SearchMilestones(ctx, instanceIDs, queries any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), ctx, instanceIDs, queries) } // SessionByID mocks base method. -func (m *MockQueries) SessionByID(arg0 context.Context, arg1 bool, arg2, arg3 string) (*query.Session, error) { +func (m *MockQueries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SessionByID", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "SessionByID", ctx, shouldTriggerBulk, id, sessionToken) ret0, _ := ret[0].(*query.Session) ret1, _ := ret[1].(error) return ret0, ret1 } // SessionByID indicates an expected call of SessionByID. -func (mr *MockQueriesMockRecorder) SessionByID(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockQueriesMockRecorder) SessionByID(ctx, shouldTriggerBulk, id, sessionToken any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), ctx, shouldTriggerBulk, id, sessionToken) } diff --git a/internal/notification/handlers/queries.go b/internal/notification/handlers/queries.go index 49cffc5e49..1c00460531 100644 --- a/internal/notification/handlers/queries.go +++ b/internal/notification/handlers/queries.go @@ -2,9 +2,12 @@ package handlers import ( "context" + "time" + "github.com/go-jose/go-jose/v4" "golang.org/x/text/language" + "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" @@ -25,6 +28,9 @@ type Queries interface { SMTPConfigActive(ctx context.Context, resourceOwner string) (*query.SMTPConfig, error) GetDefaultLanguage(ctx context.Context) language.Tag GetInstanceRestrictions(ctx context.Context) (restrictions query.Restrictions, err error) + InstanceByID(ctx context.Context, id string) (instance authz.Instance, err error) + GetActiveSigningWebKey(ctx context.Context) (*jose.JSONWebKey, error) + ActivePrivateSigningKey(ctx context.Context, t time.Time) (keys *query.PrivateKeys, err error) } type NotificationQueries struct { diff --git a/internal/notification/handlers/user_notifier_test.go b/internal/notification/handlers/user_notifier_test.go index 991eb0531d..9692832787 100644 --- a/internal/notification/handlers/user_notifier_test.go +++ b/internal/notification/handlers/user_notifier_test.go @@ -19,6 +19,7 @@ import ( es_repo_mock "github.com/zitadel/zitadel/internal/eventstore/repository/mock" "github.com/zitadel/zitadel/internal/notification/channels/email" channel_mock "github.com/zitadel/zitadel/internal/notification/channels/mock" + "github.com/zitadel/zitadel/internal/notification/channels/set" "github.com/zitadel/zitadel/internal/notification/channels/sms" "github.com/zitadel/zitadel/internal/notification/channels/smtp" "github.com/zitadel/zitadel/internal/notification/channels/twilio" @@ -1663,6 +1664,10 @@ func (c *channels) Webhook(context.Context, webhook.Config) (*senders.Chain, err return &c.Chain, nil } +func (c *channels) SecurityTokenEvent(context.Context, set.Config) (*senders.Chain, error) { + return &c.Chain, nil +} + func expectTemplateQueries(queries *mock.MockQueries, template string) { queries.EXPECT().GetInstanceRestrictions(gomock.Any()).Return(query.Restrictions{ AllowedLanguages: []language.Tag{language.English}, diff --git a/internal/notification/messages/form.go b/internal/notification/messages/form.go new file mode 100644 index 0000000000..5e9a97ca68 --- /dev/null +++ b/internal/notification/messages/form.go @@ -0,0 +1,27 @@ +package messages + +import ( + "net/url" + + "github.com/zitadel/schema" + + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/notification/channels" +) + +var _ channels.Message = (*Form)(nil) + +type Form struct { + Serializable any + TriggeringEvent eventstore.Event +} + +func (msg *Form) GetContent() (string, error) { + values := make(url.Values) + err := schema.NewEncoder().Encode(msg.Serializable, values) + return values.Encode(), err +} + +func (msg *Form) GetTriggeringEvent() eventstore.Event { + return msg.TriggeringEvent +} diff --git a/internal/notification/projections.go b/internal/notification/projections.go index 46434536c2..2be95f1490 100644 --- a/internal/notification/projections.go +++ b/internal/notification/projections.go @@ -2,6 +2,7 @@ package notification import ( "context" + "time" "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/crypto" @@ -17,7 +18,7 @@ var projections []*handler.Handler func Register( ctx context.Context, - userHandlerCustomConfig, quotaHandlerCustomConfig, telemetryHandlerCustomConfig projection.CustomConfig, + userHandlerCustomConfig, quotaHandlerCustomConfig, telemetryHandlerCustomConfig, backChannelLogoutHandlerCustomConfig projection.CustomConfig, telemetryCfg handlers.TelemetryPusherConfig, externalDomain string, externalPort uint16, @@ -25,14 +26,24 @@ func Register( commands *command.Commands, queries *query.Queries, es *eventstore.Eventstore, - otpEmailTmpl string, - fileSystemPath string, - userEncryption, smtpEncryption, smsEncryption crypto.EncryptionAlgorithm, + otpEmailTmpl, fileSystemPath string, + userEncryption, smtpEncryption, smsEncryption, keysEncryptionAlg crypto.EncryptionAlgorithm, + tokenLifetime time.Duration, ) { q := handlers.NewNotificationQueries(queries, es, externalDomain, externalPort, externalSecure, fileSystemPath, userEncryption, smtpEncryption, smsEncryption) c := newChannels(q) projections = append(projections, handlers.NewUserNotifier(ctx, projection.ApplyCustomConfig(userHandlerCustomConfig), commands, q, c, otpEmailTmpl)) projections = append(projections, handlers.NewQuotaNotifier(ctx, projection.ApplyCustomConfig(quotaHandlerCustomConfig), commands, q, c)) + projections = append(projections, handlers.NewBackChannelLogoutNotifier( + ctx, + projection.ApplyCustomConfig(backChannelLogoutHandlerCustomConfig), + commands, + q, + es, + keysEncryptionAlg, + c, + tokenLifetime, + )) if telemetryCfg.Enabled { projections = append(projections, handlers.NewTelemetryPusher(ctx, telemetryCfg, projection.ApplyCustomConfig(telemetryHandlerCustomConfig), commands, q, c)) } diff --git a/internal/notification/senders/security_event_token.go b/internal/notification/senders/security_event_token.go new file mode 100644 index 0000000000..8fb21a6557 --- /dev/null +++ b/internal/notification/senders/security_event_token.go @@ -0,0 +1,49 @@ +package senders + +import ( + "context" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/notification/channels" + "github.com/zitadel/zitadel/internal/notification/channels/fs" + "github.com/zitadel/zitadel/internal/notification/channels/instrumenting" + "github.com/zitadel/zitadel/internal/notification/channels/log" + "github.com/zitadel/zitadel/internal/notification/channels/set" +) + +const setSpanName = "security_event_token.NotificationChannel" + +func SecurityEventTokenChannels( + ctx context.Context, + setConfig set.Config, + getFileSystemProvider func(ctx context.Context) (*fs.Config, error), + getLogProvider func(ctx context.Context) (*log.Config, error), + successMetricName, + failureMetricName string, +) (*Chain, error) { + if err := setConfig.Validate(); err != nil { + return nil, err + } + channels := make([]channels.NotificationChannel, 0, 3) + setChannel, err := set.InitChannel(ctx, setConfig) + logging.WithFields( + "instance", authz.GetInstance(ctx).InstanceID(), + "callurl", setConfig.CallURL, + ).OnError(err).Debug("initializing SET channel failed") + if err == nil { + channels = append( + channels, + instrumenting.Wrap( + ctx, + setChannel, + setSpanName, + successMetricName, + failureMetricName, + ), + ) + } + channels = append(channels, debugChannels(ctx, getFileSystemProvider, getLogProvider)...) + return ChainChannels(channels...), nil +} diff --git a/internal/notification/types/notification.go b/internal/notification/types/notification.go index 49a437ff18..61c4cf70de 100644 --- a/internal/notification/types/notification.go +++ b/internal/notification/types/notification.go @@ -8,6 +8,7 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/i18n" "github.com/zitadel/zitadel/internal/notification/channels/email" + "github.com/zitadel/zitadel/internal/notification/channels/set" "github.com/zitadel/zitadel/internal/notification/channels/sms" "github.com/zitadel/zitadel/internal/notification/channels/webhook" "github.com/zitadel/zitadel/internal/notification/senders" @@ -26,6 +27,7 @@ type ChannelChains interface { Email(context.Context) (*senders.Chain, *email.Config, error) SMS(context.Context) (*senders.Chain, *sms.Config, error) Webhook(context.Context, webhook.Config) (*senders.Chain, error) + SecurityTokenEvent(context.Context, set.Config) (*senders.Chain, error) } func SendEmail( @@ -127,3 +129,21 @@ func SendJSON( ) } } + +func SendSecurityTokenEvent( + ctx context.Context, + setConfig set.Config, + channels ChannelChains, + token any, + triggeringEvent eventstore.Event, +) Notify { + return func(_ string, _ map[string]interface{}, _ string, _ bool) error { + return handleSecurityTokenEvent( + ctx, + setConfig, + channels, + token, + triggeringEvent, + ) + } +} diff --git a/internal/notification/types/security_token_event.go b/internal/notification/types/security_token_event.go new file mode 100644 index 0000000000..d8a1d26006 --- /dev/null +++ b/internal/notification/types/security_token_event.go @@ -0,0 +1,27 @@ +package types + +import ( + "context" + + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/notification/channels/set" + "github.com/zitadel/zitadel/internal/notification/messages" +) + +func handleSecurityTokenEvent( + ctx context.Context, + setConfig set.Config, + channels ChannelChains, + token any, + triggeringEvent eventstore.Event, +) error { + message := &messages.Form{ + Serializable: token, + TriggeringEvent: triggeringEvent, + } + setChannels, err := channels.SecurityTokenEvent(ctx, setConfig) + if err != nil { + return err + } + return setChannels.HandleMessage(message) +} diff --git a/internal/query/app.go b/internal/query/app.go index b94cb9cdaf..fc0101bf06 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -59,6 +59,7 @@ type OIDCApp struct { AdditionalOrigins database.TextArray[string] AllowedOrigins database.TextArray[string] SkipNativeAppSuccessPage bool + BackChannelLogoutURI string } type SAMLApp struct { @@ -243,6 +244,10 @@ var ( name: projection.AppOIDCConfigColumnSkipNativeAppSuccessPage, table: appOIDCConfigsTable, } + AppOIDCConfigColumnBackChannelLogoutURI = Column{ + name: projection.AppOIDCConfigColumnBackChannelLogoutURI, + table: appOIDCConfigsTable, + } ) func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string) (app *App, err error) { @@ -536,6 +541,7 @@ func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) ( AppOIDCConfigColumnClockSkew.identifier(), AppOIDCConfigColumnAdditionalOrigins.identifier(), AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), + AppOIDCConfigColumnBackChannelLogoutURI.identifier(), AppSAMLConfigColumnAppID.identifier(), AppSAMLConfigColumnEntityID.identifier(), @@ -600,6 +606,7 @@ func scanApp(row *sql.Row) (*App, error) { &oidcConfig.clockSkew, &oidcConfig.additionalOrigins, &oidcConfig.skipNativeAppSuccessPage, + &oidcConfig.backChannelLogoutURI, &samlConfig.appID, &samlConfig.entityID, @@ -649,6 +656,7 @@ func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { AppOIDCConfigColumnClockSkew.identifier(), AppOIDCConfigColumnAdditionalOrigins.identifier(), AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), + AppOIDCConfigColumnBackChannelLogoutURI.identifier(), ).From(appsTable.identifier()). Join(join(AppOIDCConfigColumnAppID, AppColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) { @@ -685,6 +693,7 @@ func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { &oidcConfig.clockSkew, &oidcConfig.additionalOrigins, &oidcConfig.skipNativeAppSuccessPage, + &oidcConfig.backChannelLogoutURI, ) if err != nil { @@ -896,6 +905,7 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder AppOIDCConfigColumnClockSkew.identifier(), AppOIDCConfigColumnAdditionalOrigins.identifier(), AppOIDCConfigColumnSkipNativeAppSuccessPage.identifier(), + AppOIDCConfigColumnBackChannelLogoutURI.identifier(), AppSAMLConfigColumnAppID.identifier(), AppSAMLConfigColumnEntityID.identifier(), @@ -948,6 +958,7 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder &oidcConfig.clockSkew, &oidcConfig.additionalOrigins, &oidcConfig.skipNativeAppSuccessPage, + &oidcConfig.backChannelLogoutURI, &samlConfig.appID, &samlConfig.entityID, @@ -1020,6 +1031,7 @@ type sqlOIDCConfig struct { responseTypes database.NumberArray[domain.OIDCResponseType] grantTypes database.NumberArray[domain.OIDCGrantType] skipNativeAppSuccessPage sql.NullBool + backChannelLogoutURI sql.NullString } func (c sqlOIDCConfig) set(app *App) { @@ -1043,6 +1055,7 @@ func (c sqlOIDCConfig) set(app *App) { ResponseTypes: c.responseTypes, GrantTypes: c.grantTypes, SkipNativeAppSuccessPage: c.skipNativeAppSuccessPage.Bool, + BackChannelLogoutURI: c.backChannelLogoutURI.String, } compliance := domain.GetOIDCCompliance(app.OIDCConfig.Version, app.OIDCConfig.AppType, app.OIDCConfig.GrantTypes, app.OIDCConfig.ResponseTypes, app.OIDCConfig.AuthMethodType, app.OIDCConfig.RedirectURIs) app.OIDCConfig.ComplianceProblems = compliance.Problems diff --git a/internal/query/app_test.go b/internal/query/app_test.go index 9a9c613868..990ff943f0 100644 --- a/internal/query/app_test.go +++ b/internal/query/app_test.go @@ -48,6 +48,7 @@ var ( ` projections.apps7_oidc_configs.clock_skew,` + ` projections.apps7_oidc_configs.additional_origins,` + ` projections.apps7_oidc_configs.skip_native_app_success_page,` + + ` projections.apps7_oidc_configs.back_channel_logout_uri,` + //saml config ` projections.apps7_saml_configs.app_id,` + ` projections.apps7_saml_configs.entity_id,` + @@ -91,6 +92,7 @@ var ( ` projections.apps7_oidc_configs.clock_skew,` + ` projections.apps7_oidc_configs.additional_origins,` + ` projections.apps7_oidc_configs.skip_native_app_success_page,` + + ` projections.apps7_oidc_configs.back_channel_logout_uri,` + //saml config ` projections.apps7_saml_configs.app_id,` + ` projections.apps7_saml_configs.entity_id,` + @@ -163,6 +165,7 @@ var ( "clock_skew", "additional_origins", "skip_native_app_success_page", + "back_channel_logout_uri", //saml config "app_id", "entity_id", @@ -234,6 +237,7 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config nil, nil, @@ -300,6 +304,7 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config nil, nil, @@ -369,6 +374,7 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config "app-id", "https://test.com/saml/metadata", @@ -440,6 +446,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -482,6 +489,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -526,6 +534,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -568,6 +577,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -612,6 +622,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -654,6 +665,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -698,6 +710,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -740,6 +753,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -784,6 +798,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -826,6 +841,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -870,6 +886,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, true, + "back.channel.logout.ch", // saml config nil, nil, @@ -912,6 +929,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: true, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -956,6 +974,7 @@ func Test_AppsPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -993,6 +1012,7 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config nil, nil, @@ -1030,6 +1050,7 @@ func Test_AppsPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config "saml-app-id", "https://test.com/saml/metadata", @@ -1072,6 +1093,7 @@ func Test_AppsPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, { @@ -1205,6 +1227,7 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config nil, nil, @@ -1265,6 +1288,7 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config nil, nil, @@ -1330,6 +1354,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1367,6 +1392,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -1411,6 +1437,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1448,6 +1475,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -1492,6 +1520,7 @@ func Test_AppPrepare(t *testing.T) { nil, nil, nil, + nil, // saml config "app-id", "https://test.com/saml/metadata", @@ -1558,6 +1587,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1595,6 +1625,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -1639,6 +1670,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1676,6 +1708,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -1720,6 +1753,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1757,6 +1791,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, @@ -1801,6 +1836,7 @@ func Test_AppPrepare(t *testing.T) { 1 * time.Second, database.TextArray[string]{"additional.origin"}, false, + "back.channel.logout.ch", // saml config nil, nil, @@ -1838,6 +1874,7 @@ func Test_AppPrepare(t *testing.T) { ComplianceProblems: nil, AllowedOrigins: database.TextArray[string]{"https://redirect.to", "additional.origin"}, SkipNativeAppSuccessPage: false, + BackChannelLogoutURI: "back.channel.logout.ch", }, }, }, diff --git a/internal/query/instance_features.go b/internal/query/instance_features.go index 1616d9b366..fed6d851df 100644 --- a/internal/query/instance_features.go +++ b/internal/query/instance_features.go @@ -20,6 +20,7 @@ type InstanceFeatures struct { DebugOIDCParentError FeatureSource[bool] OIDCSingleV1SessionTermination FeatureSource[bool] DisableUserTokenEvent FeatureSource[bool] + EnableBackChannelLogout FeatureSource[bool] } func (q *Queries) GetInstanceFeatures(ctx context.Context, cascade bool) (_ *InstanceFeatures, err error) { diff --git a/internal/query/instance_features_model.go b/internal/query/instance_features_model.go index 80515b4773..5192f7dfc5 100644 --- a/internal/query/instance_features_model.go +++ b/internal/query/instance_features_model.go @@ -71,6 +71,7 @@ func (m *InstanceFeaturesReadModel) Query() *eventstore.SearchQueryBuilder { feature_v2.InstanceDebugOIDCParentErrorEventType, feature_v2.InstanceOIDCSingleV1SessionTerminationEventType, feature_v2.InstanceDisableUserTokenEvent, + feature_v2.InstanceEnableBackChannelLogout, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -96,6 +97,7 @@ func (m *InstanceFeaturesReadModel) populateFromSystem() bool { m.instance.ImprovedPerformance = m.system.ImprovedPerformance m.instance.OIDCSingleV1SessionTermination = m.system.OIDCSingleV1SessionTermination m.instance.DisableUserTokenEvent = m.system.DisableUserTokenEvent + m.instance.EnableBackChannelLogout = m.system.EnableBackChannelLogout return true } @@ -129,6 +131,8 @@ func reduceInstanceFeatureSet[T any](features *InstanceFeatures, event *feature_ features.OIDCSingleV1SessionTermination.set(level, event.Value) case feature.KeyDisableUserTokenEvent: features.DisableUserTokenEvent.set(level, event.Value) + case feature.KeyEnableBackChannelLogout: + features.EnableBackChannelLogout.set(level, event.Value) } return nil } diff --git a/internal/query/oidc_client.go b/internal/query/oidc_client.go index 89d74d4ff8..8790c9737a 100644 --- a/internal/query/oidc_client.go +++ b/internal/query/oidc_client.go @@ -21,6 +21,7 @@ type OIDCClient struct { AppID string `json:"app_id,omitempty"` State domain.AppState `json:"state,omitempty"` ClientID string `json:"client_id,omitempty"` + BackChannelLogoutURI string `json:"back_channel_logout_uri,omitempty"` HashedSecret string `json:"client_secret,omitempty"` RedirectURIs []string `json:"redirect_uris,omitempty"` ResponseTypes []domain.OIDCResponseType `json:"response_types,omitempty"` diff --git a/internal/query/oidc_client_by_id.sql b/internal/query/oidc_client_by_id.sql index ef471387b3..201705c6bf 100644 --- a/internal/query/oidc_client_by_id.sql +++ b/internal/query/oidc_client_by_id.sql @@ -1,7 +1,7 @@ with client as ( select c.instance_id, - c.app_id, a.state, c.client_id, c.client_secret, c.redirect_uris, c.response_types, c.grant_types, - c.application_type, c.auth_method_type, c.post_logout_redirect_uris, c.is_dev_mode, + c.app_id, a.state, c.client_id, c.back_channel_logout_uri, c.client_secret, c.redirect_uris, c.response_types, + c.grant_types, c.application_type, c.auth_method_type, c.post_logout_redirect_uris, c.is_dev_mode, c.access_token_type, c.access_token_role_assertion, c.id_token_role_assertion, c.id_token_userinfo_assertion, c.clock_skew, c.additional_origins, a.project_id, p.project_role_assertion from projections.apps7_oidc_configs c diff --git a/internal/query/projection/app.go b/internal/query/projection/app.go index a162548dd4..7b810c3a97 100644 --- a/internal/query/projection/app.go +++ b/internal/query/projection/app.go @@ -58,6 +58,7 @@ const ( AppOIDCConfigColumnClockSkew = "clock_skew" AppOIDCConfigColumnAdditionalOrigins = "additional_origins" AppOIDCConfigColumnSkipNativeAppSuccessPage = "skip_native_app_success_page" + AppOIDCConfigColumnBackChannelLogoutURI = "back_channel_logout_uri" appSAMLTableSuffix = "saml_configs" AppSAMLConfigColumnAppID = "app_id" @@ -125,6 +126,7 @@ func (*appProjection) Init() *old_handler.Check { handler.NewColumn(AppOIDCConfigColumnClockSkew, handler.ColumnTypeInt64, handler.Default(0)), handler.NewColumn(AppOIDCConfigColumnAdditionalOrigins, handler.ColumnTypeTextArray, handler.Nullable()), handler.NewColumn(AppOIDCConfigColumnSkipNativeAppSuccessPage, handler.ColumnTypeBool, handler.Default(false)), + handler.NewColumn(AppOIDCConfigColumnBackChannelLogoutURI, handler.ColumnTypeText, handler.Nullable()), }, handler.NewPrimaryKey(AppOIDCConfigColumnInstanceID, AppOIDCConfigColumnAppID), appOIDCTableSuffix, @@ -500,6 +502,7 @@ func (p *appProjection) reduceOIDCConfigAdded(event eventstore.Event) (*handler. handler.NewCol(AppOIDCConfigColumnClockSkew, e.ClockSkew), handler.NewCol(AppOIDCConfigColumnAdditionalOrigins, database.TextArray[string](e.AdditionalOrigins)), handler.NewCol(AppOIDCConfigColumnSkipNativeAppSuccessPage, e.SkipNativeAppSuccessPage), + handler.NewCol(AppOIDCConfigColumnBackChannelLogoutURI, e.BackChannelLogoutURI), }, handler.WithTableSuffix(appOIDCTableSuffix), ), @@ -522,7 +525,7 @@ func (p *appProjection) reduceOIDCConfigChanged(event eventstore.Event) (*handle return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-GNHU1", "reduce.wrong.event.type %s", project.OIDCConfigChangedType) } - cols := make([]handler.Column, 0, 15) + cols := make([]handler.Column, 0, 16) if e.Version != nil { cols = append(cols, handler.NewCol(AppOIDCConfigColumnVersion, *e.Version)) } @@ -568,6 +571,9 @@ func (p *appProjection) reduceOIDCConfigChanged(event eventstore.Event) (*handle if e.SkipNativeAppSuccessPage != nil { cols = append(cols, handler.NewCol(AppOIDCConfigColumnSkipNativeAppSuccessPage, *e.SkipNativeAppSuccessPage)) } + if e.BackChannelLogoutURI != nil { + cols = append(cols, handler.NewCol(AppOIDCConfigColumnBackChannelLogoutURI, *e.BackChannelLogoutURI)) + } if len(cols) == 0 { return handler.NewNoOpStatement(e), nil diff --git a/internal/query/projection/app_test.go b/internal/query/projection/app_test.go index 49979c4698..74e4e39847 100644 --- a/internal/query/projection/app_test.go +++ b/internal/query/projection/app_test.go @@ -558,7 +558,8 @@ func TestAppProjection_reduces(t *testing.T) { "idTokenUserinfoAssertion": true, "clockSkew": 1000, "additionalOrigins": ["origin.one.ch", "origin.two.ch"], - "skipNativeAppSuccessPage": true + "skipNativeAppSuccessPage": true, + "backChannelLogoutURI": "back.channel.one.ch" }`), ), project.OIDCConfigAddedEventMapper), }, @@ -569,7 +570,7 @@ func TestAppProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "INSERT INTO projections.apps7_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)", + expectedStmt: "INSERT INTO projections.apps7_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page, back_channel_logout_uri) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)", expectedArgs: []interface{}{ "app-id", "instance-id", @@ -590,6 +591,7 @@ func TestAppProjection_reduces(t *testing.T) { 1 * time.Microsecond, database.TextArray[string]{"origin.one.ch", "origin.two.ch"}, true, + "back.channel.one.ch", }, }, { @@ -630,7 +632,8 @@ func TestAppProjection_reduces(t *testing.T) { "idTokenUserinfoAssertion": true, "clockSkew": 1000, "additionalOrigins": ["origin.one.ch", "origin.two.ch"], - "skipNativeAppSuccessPage": true + "skipNativeAppSuccessPage": true, + "backChannelLogoutURI": "back.channel.one.ch" }`), ), project.OIDCConfigAddedEventMapper), }, @@ -641,7 +644,7 @@ func TestAppProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "INSERT INTO projections.apps7_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)", + expectedStmt: "INSERT INTO projections.apps7_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page, back_channel_logout_uri) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20)", expectedArgs: []interface{}{ "app-id", "instance-id", @@ -662,6 +665,7 @@ func TestAppProjection_reduces(t *testing.T) { 1 * time.Microsecond, database.TextArray[string]{"origin.one.ch", "origin.two.ch"}, true, + "back.channel.one.ch", }, }, { @@ -700,8 +704,8 @@ func TestAppProjection_reduces(t *testing.T) { "idTokenUserinfoAssertion": true, "clockSkew": 1000, "additionalOrigins": ["origin.one.ch", "origin.two.ch"], - "skipNativeAppSuccessPage": true - + "skipNativeAppSuccessPage": true, + "backChannelLogoutURI": "back.channel.one.ch" }`), ), project.OIDCConfigChangedEventMapper), }, @@ -712,7 +716,7 @@ func TestAppProjection_reduces(t *testing.T) { executer: &testExecuter{ executions: []execution{ { - expectedStmt: "UPDATE projections.apps7_oidc_configs SET (version, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) WHERE (app_id = $16) AND (instance_id = $17)", + expectedStmt: "UPDATE projections.apps7_oidc_configs SET (version, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins, skip_native_app_success_page, back_channel_logout_uri) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) WHERE (app_id = $17) AND (instance_id = $18)", expectedArgs: []interface{}{ domain.OIDCVersionV1, database.TextArray[string]{"redirect.one.ch", "redirect.two.ch"}, @@ -729,6 +733,7 @@ func TestAppProjection_reduces(t *testing.T) { 1 * time.Microsecond, database.TextArray[string]{"origin.one.ch", "origin.two.ch"}, true, + "back.channel.one.ch", "app-id", "instance-id", }, diff --git a/internal/query/projection/instance_features.go b/internal/query/projection/instance_features.go index 1b18e42e76..45e360c6db 100644 --- a/internal/query/projection/instance_features.go +++ b/internal/query/projection/instance_features.go @@ -104,6 +104,10 @@ func (*instanceFeatureProjection) Reducers() []handler.AggregateReducer { Event: feature_v2.InstanceDisableUserTokenEvent, Reduce: reduceInstanceSetFeature[bool], }, + { + Event: feature_v2.InstanceEnableBackChannelLogout, + Reduce: reduceInstanceSetFeature[bool], + }, { Event: instance.InstanceRemovedEventType, Reduce: reduceInstanceRemovedHelper(InstanceDomainInstanceIDCol), diff --git a/internal/query/projection/system_features.go b/internal/query/projection/system_features.go index cf3013e57c..65e72fa394 100644 --- a/internal/query/projection/system_features.go +++ b/internal/query/projection/system_features.go @@ -84,6 +84,10 @@ func (*systemFeatureProjection) Reducers() []handler.AggregateReducer { Event: feature_v2.SystemDisableUserTokenEvent, Reduce: reduceSystemSetFeature[bool], }, + { + Event: feature_v2.SystemEnableBackChannelLogout, + Reduce: reduceSystemSetFeature[bool], + }, }, }} } diff --git a/internal/query/system_features.go b/internal/query/system_features.go index ddbd0a08ea..cae68d5fbb 100644 --- a/internal/query/system_features.go +++ b/internal/query/system_features.go @@ -29,6 +29,7 @@ type SystemFeatures struct { ImprovedPerformance FeatureSource[[]feature.ImprovedPerformanceType] OIDCSingleV1SessionTermination FeatureSource[bool] DisableUserTokenEvent FeatureSource[bool] + EnableBackChannelLogout FeatureSource[bool] } func (q *Queries) GetSystemFeatures(ctx context.Context) (_ *SystemFeatures, err error) { diff --git a/internal/query/system_features_model.go b/internal/query/system_features_model.go index f8670c87fe..119aaa4ea1 100644 --- a/internal/query/system_features_model.go +++ b/internal/query/system_features_model.go @@ -59,6 +59,7 @@ func (m *SystemFeaturesReadModel) Query() *eventstore.SearchQueryBuilder { feature_v2.SystemImprovedPerformanceEventType, feature_v2.SystemOIDCSingleV1SessionTerminationEventType, feature_v2.SystemDisableUserTokenEvent, + feature_v2.SystemEnableBackChannelLogout, ). Builder().ResourceOwner(m.ResourceOwner) } @@ -94,6 +95,8 @@ func reduceSystemFeatureSet[T any](features *SystemFeatures, event *feature_v2.S features.OIDCSingleV1SessionTermination.set(level, event.Value) case feature.KeyDisableUserTokenEvent: features.DisableUserTokenEvent.set(level, event.Value) + case feature.KeyEnableBackChannelLogout: + features.EnableBackChannelLogout.set(level, event.Value) } return nil } diff --git a/internal/repository/feature/feature_v2/eventstore.go b/internal/repository/feature/feature_v2/eventstore.go index 866d331db4..9288f0a675 100644 --- a/internal/repository/feature/feature_v2/eventstore.go +++ b/internal/repository/feature/feature_v2/eventstore.go @@ -16,6 +16,7 @@ func init() { eventstore.RegisterFilterEventMapper(AggregateType, SystemImprovedPerformanceEventType, eventstore.GenericEventMapper[SetEvent[[]feature.ImprovedPerformanceType]]) eventstore.RegisterFilterEventMapper(AggregateType, SystemOIDCSingleV1SessionTerminationEventType, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, SystemDisableUserTokenEvent, eventstore.GenericEventMapper[SetEvent[bool]]) + eventstore.RegisterFilterEventMapper(AggregateType, SystemEnableBackChannelLogout, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceResetEventType, eventstore.GenericEventMapper[ResetEvent]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceLoginDefaultOrgEventType, eventstore.GenericEventMapper[SetEvent[bool]]) @@ -29,4 +30,5 @@ func init() { eventstore.RegisterFilterEventMapper(AggregateType, InstanceDebugOIDCParentErrorEventType, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceOIDCSingleV1SessionTerminationEventType, eventstore.GenericEventMapper[SetEvent[bool]]) eventstore.RegisterFilterEventMapper(AggregateType, InstanceDisableUserTokenEvent, eventstore.GenericEventMapper[SetEvent[bool]]) + eventstore.RegisterFilterEventMapper(AggregateType, InstanceEnableBackChannelLogout, eventstore.GenericEventMapper[SetEvent[bool]]) } diff --git a/internal/repository/feature/feature_v2/feature.go b/internal/repository/feature/feature_v2/feature.go index 95f7e44360..3fc180a814 100644 --- a/internal/repository/feature/feature_v2/feature.go +++ b/internal/repository/feature/feature_v2/feature.go @@ -21,6 +21,7 @@ var ( SystemImprovedPerformanceEventType = setEventTypeFromFeature(feature.LevelSystem, feature.KeyImprovedPerformance) SystemOIDCSingleV1SessionTerminationEventType = setEventTypeFromFeature(feature.LevelSystem, feature.KeyOIDCSingleV1SessionTermination) SystemDisableUserTokenEvent = setEventTypeFromFeature(feature.LevelSystem, feature.KeyDisableUserTokenEvent) + SystemEnableBackChannelLogout = setEventTypeFromFeature(feature.LevelSystem, feature.KeyEnableBackChannelLogout) InstanceResetEventType = resetEventTypeFromFeature(feature.LevelInstance) InstanceLoginDefaultOrgEventType = setEventTypeFromFeature(feature.LevelInstance, feature.KeyLoginDefaultOrg) @@ -34,6 +35,7 @@ var ( InstanceDebugOIDCParentErrorEventType = setEventTypeFromFeature(feature.LevelInstance, feature.KeyDebugOIDCParentError) InstanceOIDCSingleV1SessionTerminationEventType = setEventTypeFromFeature(feature.LevelInstance, feature.KeyOIDCSingleV1SessionTermination) InstanceDisableUserTokenEvent = setEventTypeFromFeature(feature.LevelInstance, feature.KeyDisableUserTokenEvent) + InstanceEnableBackChannelLogout = setEventTypeFromFeature(feature.LevelInstance, feature.KeyEnableBackChannelLogout) ) const ( diff --git a/internal/repository/project/oidc_config.go b/internal/repository/project/oidc_config.go index 5ea20c220a..498f3233e2 100644 --- a/internal/repository/project/oidc_config.go +++ b/internal/repository/project/oidc_config.go @@ -43,6 +43,7 @@ type OIDCConfigAddedEvent struct { ClockSkew time.Duration `json:"clockSkew,omitempty"` AdditionalOrigins []string `json:"additionalOrigins,omitempty"` SkipNativeAppSuccessPage bool `json:"skipNativeAppSuccessPage,omitempty"` + BackChannelLogoutURI string `json:"backChannelLogoutURI,omitempty"` } func (e *OIDCConfigAddedEvent) Payload() interface{} { @@ -74,6 +75,7 @@ func NewOIDCConfigAddedEvent( clockSkew time.Duration, additionalOrigins []string, skipNativeAppSuccessPage bool, + backChannelLogoutURI string, ) *OIDCConfigAddedEvent { return &OIDCConfigAddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -99,6 +101,7 @@ func NewOIDCConfigAddedEvent( ClockSkew: clockSkew, AdditionalOrigins: additionalOrigins, SkipNativeAppSuccessPage: skipNativeAppSuccessPage, + BackChannelLogoutURI: backChannelLogoutURI, } } @@ -184,7 +187,10 @@ func (e *OIDCConfigAddedEvent) Validate(cmd eventstore.Command) bool { return false } } - return e.SkipNativeAppSuccessPage == c.SkipNativeAppSuccessPage + if e.SkipNativeAppSuccessPage != c.SkipNativeAppSuccessPage { + return false + } + return e.BackChannelLogoutURI == c.BackChannelLogoutURI } func OIDCConfigAddedEventMapper(event eventstore.Event) (eventstore.Event, error) { @@ -219,6 +225,7 @@ type OIDCConfigChangedEvent struct { ClockSkew *time.Duration `json:"clockSkew,omitempty"` AdditionalOrigins *[]string `json:"additionalOrigins,omitempty"` SkipNativeAppSuccessPage *bool `json:"skipNativeAppSuccessPage,omitempty"` + BackChannelLogoutURI *string `json:"backChannelLogoutURI,omitempty"` } func (e *OIDCConfigChangedEvent) Payload() interface{} { @@ -345,6 +352,12 @@ func ChangeSkipNativeAppSuccessPage(skipNativeAppSuccessPage bool) func(event *O } } +func ChangeBackChannelLogoutURI(backChannelLogoutURI string) func(event *OIDCConfigChangedEvent) { + return func(e *OIDCConfigChangedEvent) { + e.BackChannelLogoutURI = &backChannelLogoutURI + } +} + func OIDCConfigChangedEventMapper(event eventstore.Event) (eventstore.Event, error) { e := &OIDCConfigChangedEvent{ BaseEvent: *eventstore.BaseEventFromRepo(event), diff --git a/internal/repository/session/session.go b/internal/repository/session/session.go index f5622fd4b4..42304aca8e 100644 --- a/internal/repository/session/session.go +++ b/internal/repository/session/session.go @@ -659,6 +659,8 @@ func NewLifetimeSetEvent( type TerminateEvent struct { eventstore.BaseEvent `json:"-"` + + TriggerOrigin string `json:"triggerOrigin,omitempty"` } func (e *TerminateEvent) Payload() interface{} { diff --git a/internal/repository/sessionlogout/aggregate.go b/internal/repository/sessionlogout/aggregate.go new file mode 100644 index 0000000000..dcdc7a581a --- /dev/null +++ b/internal/repository/sessionlogout/aggregate.go @@ -0,0 +1,26 @@ +package sessionlogout + +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + AggregateType = "session_logout" + AggregateVersion = "v1" +) + +type Aggregate struct { + eventstore.Aggregate +} + +func NewAggregate(id, instanceID string) *Aggregate { + return &Aggregate{ + Aggregate: eventstore.Aggregate{ + Type: AggregateType, + Version: AggregateVersion, + ID: id, + ResourceOwner: instanceID, + InstanceID: instanceID, + }, + } +} diff --git a/internal/repository/sessionlogout/events.go b/internal/repository/sessionlogout/events.go new file mode 100644 index 0000000000..df7c39accf --- /dev/null +++ b/internal/repository/sessionlogout/events.go @@ -0,0 +1,79 @@ +package sessionlogout + +import ( + "context" + + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + eventTypePrefix = "session_logout." + backChannelEventTypePrefix = eventTypePrefix + "back_channel." + BackChannelLogoutRegisteredType = backChannelEventTypePrefix + "registered" + BackChannelLogoutSentType = backChannelEventTypePrefix + "sent" +) + +type BackChannelLogoutRegisteredEvent struct { + *eventstore.BaseEvent `json:"-"` + + OIDCSessionID string `json:"oidc_session_id"` + UserID string `json:"user_id"` + ClientID string `json:"client_id"` + BackChannelLogoutURI string `json:"back_channel_logout_uri"` +} + +// Payload implements eventstore.Command. +func (e *BackChannelLogoutRegisteredEvent) Payload() any { + return e +} + +func (e *BackChannelLogoutRegisteredEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func (e *BackChannelLogoutRegisteredEvent) SetBaseEvent(b *eventstore.BaseEvent) { + e.BaseEvent = b +} + +func NewBackChannelLogoutRegisteredEvent(ctx context.Context, aggregate *eventstore.Aggregate, oidcSessionID, userID, clientID, backChannelLogoutURI string) *BackChannelLogoutRegisteredEvent { + return &BackChannelLogoutRegisteredEvent{ + BaseEvent: eventstore.NewBaseEventForPush( + ctx, + aggregate, + BackChannelLogoutRegisteredType, + ), + OIDCSessionID: oidcSessionID, + UserID: userID, + ClientID: clientID, + BackChannelLogoutURI: backChannelLogoutURI, + } +} + +type BackChannelLogoutSentEvent struct { + eventstore.BaseEvent `json:"-"` + + OIDCSessionID string `json:"oidc_session_id"` +} + +func (e *BackChannelLogoutSentEvent) Payload() interface{} { + return e +} + +func (e *BackChannelLogoutSentEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func (e *BackChannelLogoutSentEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = *event +} + +func NewBackChannelLogoutSentEvent(ctx context.Context, aggregate *eventstore.Aggregate, oidcSessionID string) *BackChannelLogoutSentEvent { + return &BackChannelLogoutSentEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + BackChannelLogoutSentType, + ), + OIDCSessionID: oidcSessionID, + } +} diff --git a/internal/repository/sessionlogout/eventstore.go b/internal/repository/sessionlogout/eventstore.go new file mode 100644 index 0000000000..0aa36dd8a8 --- /dev/null +++ b/internal/repository/sessionlogout/eventstore.go @@ -0,0 +1,15 @@ +package sessionlogout + +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) + +var ( + BackChannelLogoutRegisteredEventMapper = eventstore.GenericEventMapper[BackChannelLogoutRegisteredEvent] + BackChannelLogoutSentEventMapper = eventstore.GenericEventMapper[BackChannelLogoutSentEvent] +) + +func init() { + eventstore.RegisterFilterEventMapper(AggregateType, BackChannelLogoutRegisteredType, BackChannelLogoutRegisteredEventMapper) + eventstore.RegisterFilterEventMapper(AggregateType, BackChannelLogoutSentType, BackChannelLogoutSentEventMapper) +} diff --git a/internal/repository/user/human.go b/internal/repository/user/human.go index ae1e9672ef..d503ecc899 100644 --- a/internal/repository/user/human.go +++ b/internal/repository/user/human.go @@ -517,7 +517,9 @@ func NewHumanInviteCheckFailedEvent(ctx context.Context, aggregate *eventstore.A type HumanSignedOutEvent struct { eventstore.BaseEvent `json:"-"` - UserAgentID string `json:"userAgentID"` + UserAgentID string `json:"userAgentID"` + SessionID string `json:"sessionID,omitempty"` + TriggeredAtOrigin string `json:"triggerOrigin,omitempty"` } func (e *HumanSignedOutEvent) Payload() interface{} { @@ -528,10 +530,15 @@ func (e *HumanSignedOutEvent) UniqueConstraints() []*eventstore.UniqueConstraint return nil } +func (e *HumanSignedOutEvent) TriggerOrigin() string { + return e.TriggeredAtOrigin +} + func NewHumanSignedOutEvent( ctx context.Context, aggregate *eventstore.Aggregate, - userAgentID string, + userAgentID, + sessionID string, ) *HumanSignedOutEvent { return &HumanSignedOutEvent{ BaseEvent: *eventstore.NewBaseEventForPush( @@ -539,7 +546,9 @@ func NewHumanSignedOutEvent( aggregate, HumanSignedOutType, ), - UserAgentID: userAgentID, + UserAgentID: userAgentID, + SessionID: sessionID, + TriggeredAtOrigin: http.DomainContext(ctx).Origin(), } } diff --git a/internal/user/repository/view/active_user_ids_by_session_id.sql b/internal/user/repository/view/active_user_sessions_by_session_id.sql similarity index 91% rename from internal/user/repository/view/active_user_ids_by_session_id.sql rename to internal/user/repository/view/active_user_sessions_by_session_id.sql index b7c5aaebb0..d5f4754c3f 100644 --- a/internal/user/repository/view/active_user_ids_by_session_id.sql +++ b/internal/user/repository/view/active_user_sessions_by_session_id.sql @@ -1,6 +1,7 @@ SELECT s.user_agent_id, - s.user_id + s.user_id, + s.id FROM auth.user_sessions s JOIN auth.user_sessions s2 ON s.instance_id = s2.instance_id diff --git a/internal/user/repository/view/user_session_view.go b/internal/user/repository/view/user_session_view.go index f0b956e057..b3d155f1ec 100644 --- a/internal/user/repository/view/user_session_view.go +++ b/internal/user/repository/view/user_session_view.go @@ -20,8 +20,8 @@ var userSessionsByUserAgentQuery string //go:embed user_agent_by_user_session_id.sql var userAgentByUserSessionIDQuery string -//go:embed active_user_ids_by_session_id.sql -var activeUserIDsBySessionIDQuery string +//go:embed active_user_sessions_by_session_id.sql +var activeUserSessionsBySessionIDQuery string func UserSessionByIDs(ctx context.Context, db *database.DB, agentID, userID, instanceID string) (userSession *model.UserSessionView, err error) { err = db.QueryRowContext( @@ -65,36 +65,39 @@ func UserAgentIDBySessionID(ctx context.Context, db *database.DB, sessionID, ins return userAgentID, err } -// ActiveUserIDsBySessionID returns all userIDs with an active session on the same user agent (its id is also returned) based on a sessionID -func ActiveUserIDsBySessionID(ctx context.Context, db *database.DB, sessionID, instanceID string) (userAgentID string, userIDs []string, err error) { +// ActiveUserSessionsBySessionID returns all sessions (sessionID:userID map) with an active session on the same user agent (its id is also returned) based on a sessionID +func ActiveUserSessionsBySessionID(ctx context.Context, db *database.DB, sessionID, instanceID string) (userAgentID string, sessions map[string]string, err error) { err = db.QueryContext( ctx, func(rows *sql.Rows) error { - userAgentID, userIDs, err = scanActiveUserAgentUserIDs(rows) + userAgentID, sessions, err = scanActiveUserAgentUserIDs(rows) return err }, - activeUserIDsBySessionIDQuery, + activeUserSessionsBySessionIDQuery, sessionID, instanceID, ) - return userAgentID, userIDs, err + return userAgentID, sessions, err } -func scanActiveUserAgentUserIDs(rows *sql.Rows) (userAgentID string, userIDs []string, err error) { +func scanActiveUserAgentUserIDs(rows *sql.Rows) (userAgentID string, sessions map[string]string, err error) { + sessions = make(map[string]string) for rows.Next() { - var userID string + var userID, sessionID string err := rows.Scan( &userAgentID, - &userID) + &userID, + &sessionID, + ) if err != nil { return "", nil, err } - userIDs = append(userIDs, userID) + sessions[sessionID] = userID } if err := rows.Close(); err != nil { return "", nil, zerrors.ThrowInternal(err, "VIEW-Sbrws", "Errors.Query.CloseRows") } - return userAgentID, userIDs, nil + return userAgentID, sessions, nil } func scanUserSession(row *sql.Row) (*model.UserSessionView, error) { diff --git a/proto/zitadel/app.proto b/proto/zitadel/app.proto index 58735e7c9e..d18168f2b9 100644 --- a/proto/zitadel/app.proto +++ b/proto/zitadel/app.proto @@ -168,6 +168,12 @@ message OIDCConfig { description: "Skip the successful login page on native apps and directly redirect the user to the callback."; } ]; + string back_channel_logout_uri = 21 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "[\"https://example.com/auth/backchannel\"]"; + description: "ZITADEL will use this URI to notify the application about terminated session according to the OIDC Back-Channel Logout (https://openid.net/specs/openid-connect-backchannel-1_0.html)"; + } + ]; } enum OIDCResponseType { diff --git a/proto/zitadel/feature/v2/instance.proto b/proto/zitadel/feature/v2/instance.proto index ee41c313f2..6717e397ea 100644 --- a/proto/zitadel/feature/v2/instance.proto +++ b/proto/zitadel/feature/v2/instance.proto @@ -86,6 +86,13 @@ message SetInstanceFeaturesRequest{ description: "Do not push user token meta-event user.token.v2.added to improve performance on many concurrent single (machine-)user logins"; } ]; + + optional bool enable_back_channel_logout = 12 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "If the flag is enabled, you'll be able to use the OIDC Back-Channel Logout to be notified in your application about terminated user sessions."; + } + ]; } message SetInstanceFeaturesResponse { @@ -185,4 +192,11 @@ message GetInstanceFeaturesResponse { description: "Do not push user token meta-event user.token.v2.added to improve performance on many concurrent single (machine-)user logins"; } ]; + + FeatureFlag enable_back_channel_logout = 13 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "If the flag is enabled, you'll be able to use the OIDC Back-Channel Logout to be notified in your application about terminated user sessions."; + } + ]; } diff --git a/proto/zitadel/feature/v2/system.proto b/proto/zitadel/feature/v2/system.proto index 70ff3c6506..cd8d7cc201 100644 --- a/proto/zitadel/feature/v2/system.proto +++ b/proto/zitadel/feature/v2/system.proto @@ -75,6 +75,13 @@ message SetSystemFeaturesRequest{ description: "Do not push user token meta-event user.token.v2.added to improve performance on many concurrent single (machine-)user logins"; } ]; + + optional bool enable_back_channel_logout = 10 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "If the flag is enabled, you'll be able to use the OIDC Back-Channel Logout to be notified in your application about terminated user sessions."; + } + ]; } message SetSystemFeaturesResponse { @@ -153,4 +160,11 @@ message GetSystemFeaturesResponse { description: "Do not push user token meta-event user.token.v2.added to improve performance on many concurrent single (machine-)user logins"; } ]; + + FeatureFlag enable_back_channel_logout = 11 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "true"; + description: "If the flag is enabled, you'll be able to use the OIDC Back-Channel Logout to be notified in your application about terminated user sessions."; + } + ]; } diff --git a/proto/zitadel/management.proto b/proto/zitadel/management.proto index cb5bfb1389..0df07ffd4c 100644 --- a/proto/zitadel/management.proto +++ b/proto/zitadel/management.proto @@ -9802,6 +9802,12 @@ message AddOIDCAppRequest { description: "Skip the successful login page on native apps and directly redirect the user to the callback."; } ]; + string back_channel_logout_uri = 18 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "[\"https://example.com/auth/backchannel\"]"; + description: "ZITADEL will use this URI to notify the application about terminated session according to the OIDC Back-Channel Logout (https://openid.net/specs/openid-connect-backchannel-1_0.html)"; + } + ]; } message AddOIDCAppResponse { @@ -9977,6 +9983,12 @@ message UpdateOIDCAppConfigRequest { description: "Skip the successful login page on native apps and directly redirect the user to the callback."; } ]; + string back_channel_logout_uri = 17 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + example: "[\"https://example.com/auth/backchannel\"]"; + description: "ZITADEL will use this URI to notify the application about terminated session according to the OIDC Back-Channel Logout (https://openid.net/specs/openid-connect-backchannel-1_0.html)"; + } + ]; } message UpdateOIDCAppConfigResponse {