mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-14 20:08:02 +00:00
267 lines
8.0 KiB
Go
267 lines
8.0 KiB
Go
|
package handlers
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"slices"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/zitadel/logging"
|
||
|
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||
|
|
||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||
|
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||
|
zoidc "github.com/zitadel/zitadel/internal/api/oidc"
|
||
|
"github.com/zitadel/zitadel/internal/command"
|
||
|
zcrypto "github.com/zitadel/zitadel/internal/crypto"
|
||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||
|
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||
|
"github.com/zitadel/zitadel/internal/id"
|
||
|
"github.com/zitadel/zitadel/internal/notification/channels/set"
|
||
|
_ "github.com/zitadel/zitadel/internal/notification/statik"
|
||
|
"github.com/zitadel/zitadel/internal/notification/types"
|
||
|
"github.com/zitadel/zitadel/internal/repository/session"
|
||
|
"github.com/zitadel/zitadel/internal/repository/sessionlogout"
|
||
|
"github.com/zitadel/zitadel/internal/repository/user"
|
||
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||
|
)
|
||
|
|
||
|
// BackChannelLogoutNotificationsProjectionTable is the name under which the
// back-channel logout notifier is registered as a projection handler.
const BackChannelLogoutNotificationsProjectionTable = "projections.notifications_back_channel_logout"
|
||
|
|
||
|
type backChannelLogoutNotifier struct {
|
||
|
commands *command.Commands
|
||
|
queries *NotificationQueries
|
||
|
eventstore *eventstore.Eventstore
|
||
|
keyEncryptionAlg zcrypto.EncryptionAlgorithm
|
||
|
channels types.ChannelChains
|
||
|
idGenerator id.Generator
|
||
|
tokenLifetime time.Duration
|
||
|
}
|
||
|
|
||
|
// NewBackChannelLogoutNotifier builds the projection handler that sends
// back-channel logout tokens. A SonyFlake generator is used for token IDs.
func NewBackChannelLogoutNotifier(
	ctx context.Context,
	config handler.Config,
	commands *command.Commands,
	queries *NotificationQueries,
	es *eventstore.Eventstore,
	keyEncryptionAlg zcrypto.EncryptionAlgorithm,
	channels types.ChannelChains,
	tokenLifetime time.Duration,
) *handler.Handler {
	notifier := &backChannelLogoutNotifier{
		commands:         commands,
		queries:          queries,
		eventstore:       es,
		keyEncryptionAlg: keyEncryptionAlg,
		channels:         channels,
		tokenLifetime:    tokenLifetime,
		idGenerator:      id.SonyFlakeGenerator(),
	}
	return handler.NewHandler(ctx, &config, notifier)
}
|
||
|
|
||
|
func (*backChannelLogoutNotifier) Name() string {
|
||
|
return BackChannelLogoutNotificationsProjectionTable
|
||
|
}
|
||
|
|
||
|
func (u *backChannelLogoutNotifier) Reducers() []handler.AggregateReducer {
|
||
|
return []handler.AggregateReducer{
|
||
|
{
|
||
|
Aggregate: session.AggregateType,
|
||
|
EventReducers: []handler.EventReducer{
|
||
|
{
|
||
|
Event: session.TerminateType,
|
||
|
Reduce: u.reduceSessionTerminated,
|
||
|
},
|
||
|
},
|
||
|
}, {
|
||
|
Aggregate: user.AggregateType,
|
||
|
EventReducers: []handler.EventReducer{
|
||
|
{
|
||
|
Event: user.HumanSignedOutType,
|
||
|
Reduce: u.reduceUserSignedOut,
|
||
|
},
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (u *backChannelLogoutNotifier) reduceUserSignedOut(event eventstore.Event) (*handler.Statement, error) {
|
||
|
e, ok := event.(*user.HumanSignedOutEvent)
|
||
|
if !ok {
|
||
|
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Gr63h", "reduce.wrong.event.type %s", user.HumanSignedOutType)
|
||
|
}
|
||
|
|
||
|
return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error {
|
||
|
ctx, err := u.queries.HandlerContext(event.Aggregate())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !authz.GetFeatures(ctx).EnableBackChannelLogout {
|
||
|
return nil
|
||
|
}
|
||
|
if e.SessionID == "" {
|
||
|
return nil
|
||
|
}
|
||
|
return u.terminateSession(ctx, e.SessionID, e)
|
||
|
}), nil
|
||
|
}
|
||
|
|
||
|
func (u *backChannelLogoutNotifier) reduceSessionTerminated(event eventstore.Event) (*handler.Statement, error) {
|
||
|
e, ok := event.(*session.TerminateEvent)
|
||
|
if !ok {
|
||
|
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-D6H2h", "reduce.wrong.event.type %s", session.TerminateType)
|
||
|
}
|
||
|
|
||
|
return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error {
|
||
|
ctx, err := u.queries.HandlerContext(event.Aggregate())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if !authz.GetFeatures(ctx).EnableBackChannelLogout {
|
||
|
return nil
|
||
|
}
|
||
|
return u.terminateSession(ctx, e.Aggregate().ID, e)
|
||
|
}), nil
|
||
|
}
|
||
|
|
||
|
type backChannelLogoutSession struct {
|
||
|
sessionID string
|
||
|
|
||
|
// sessions contain a map of oidc session IDs and their corresponding clientID
|
||
|
sessions []backChannelLogoutOIDCSessions
|
||
|
}
|
||
|
|
||
|
// terminateSession loads all OIDC sessions registered for back-channel logout on
// the session with the given id. For each one it concurrently sends a logout
// token and then records the delivery via BackChannelLogoutSent.
// The request origin is derived from the triggering event e. All per-session
// failures are collected and returned together via errors.Join.
func (u *backChannelLogoutNotifier) terminateSession(ctx context.Context, id string, e eventstore.Event) error {
	sessions := &backChannelLogoutSession{sessionID: id}
	err := u.eventstore.FilterToQueryReducer(ctx, sessions)
	if err != nil {
		return err
	}

	ctx, err = u.queries.Origin(ctx, e)
	if err != nil {
		return err
	}

	getSigner := zoidc.GetSignerOnce(u.queries.GetActiveSigningWebKey, u.signingKey)

	var (
		wg   sync.WaitGroup
		mu   sync.Mutex // guards errs; goroutines append concurrently
		errs = make([]error, 0, len(sessions.sessions))
	)
	addErr := func(err error) {
		mu.Lock()
		defer mu.Unlock()
		errs = append(errs, err)
	}

	wg.Add(len(sessions.sessions))
	for i := range sessions.sessions {
		// Point into the backing slice so each goroutine gets its own element
		// regardless of loop-variable semantics.
		go func(oidcSession *backChannelLogoutOIDCSessions) {
			defer wg.Done()
			if err := u.sendLogoutToken(ctx, oidcSession, e, getSigner); err != nil {
				addErr(err)
				return
			}
			if err := u.commands.BackChannelLogoutSent(ctx, oidcSession.SessionID, oidcSession.OIDCSessionID, e.Aggregate().InstanceID); err != nil {
				addErr(err)
			}
		}(&sessions.sessions[i])
	}
	wg.Wait()
	return errors.Join(errs...)
}
|
||
|
|
||
|
// signingKey loads the private signing keys that are active at the current time
// and converts the one chosen by zoidc.SelectSigningKey into an op.SigningKey,
// decrypting it with keyEncryptionAlg. terminateSession passes it to
// zoidc.GetSignerOnce next to the web key lookup.
// If no key is active, it logs an informational hint for the instance and
// returns a precondition-failed error.
func (u *backChannelLogoutNotifier) signingKey(ctx context.Context) (op.SigningKey, error) {
	keys, err := u.queries.ActivePrivateSigningKey(ctx, time.Now())
	if err != nil {
		return nil, err
	}
	if len(keys.Keys) == 0 {
		logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID()).
			// The trailing space keeps the two sentences from running together.
			Info("There's no active signing key and automatic rotation is not supported for back channel logout. " +
				"Please enable the webkey management feature on your instance")
		return nil, zerrors.ThrowPreconditionFailed(nil, "HANDL-DF3nf", "no active signing key")
	}
	return zoidc.PrivateKeyToSigningKey(zoidc.SelectSigningKey(keys.Keys), u.keyEncryptionAlg)
}
|
||
|
|
||
|
// sendLogoutToken creates a signed logout token for oidcSession and delivers it
// as a security token event to the session's BackChannelLogoutURI through the
// configured channels, without a message template.
func (u *backChannelLogoutNotifier) sendLogoutToken(ctx context.Context, oidcSession *backChannelLogoutOIDCSessions, e eventstore.Event, getSigner zoidc.SignerFunc) error {
	token, tokenErr := u.logoutToken(ctx, oidcSession, getSigner)
	if tokenErr != nil {
		return tokenErr
	}
	target := set.Config{CallURL: oidcSession.BackChannelLogoutURI}
	message := &LogoutTokenMessage{LogoutToken: token}
	if sendErr := types.SendSecurityTokenEvent(ctx, target, u.channels, message, e).WithoutTemplate(); sendErr != nil {
		return sendErr
	}
	return nil
}
|
||
|
|
||
|
// logoutToken builds and signs the logout token claims for oidcSession. The
// issuer is the request origin, the audience is the session's client, the jti
// comes from idGenerator, and the token expires tokenLifetime from now.
func (u *backChannelLogoutNotifier) logoutToken(ctx context.Context, oidcSession *backChannelLogoutOIDCSessions, getSigner zoidc.SignerFunc) (string, error) {
	tokenID, idErr := u.idGenerator.Next()
	if idErr != nil {
		return "", idErr
	}

	issuer := http_utils.DomainContext(ctx).Origin()
	expiration := time.Now().Add(u.tokenLifetime)
	claims := oidc.NewLogoutTokenClaims(
		issuer,
		oidcSession.UserID,
		oidc.Audience{oidcSession.ClientID},
		expiration,
		tokenID,
		oidcSession.SessionID,
		time.Second,
	)

	signer, _, signerErr := getSigner(ctx)
	if signerErr != nil {
		return "", signerErr
	}
	return crypto.Sign(claims, signer)
}
|
||
|
|
||
|
// LogoutTokenMessage is the payload delivered to a back-channel logout URI.
// The signed token is encoded under the "logout_token" schema key.
type LogoutTokenMessage struct {
	LogoutToken string `schema:"logout_token"` // signed logout token
}
|
||
|
|
||
|
// backChannelLogoutOIDCSessions describes one OIDC session that registered for
// back-channel logout: the owning session, the OIDC session itself, the user,
// the client (token audience) and the URI that receives the logout token.
type backChannelLogoutOIDCSessions struct {
	SessionID, OIDCSessionID, UserID, ClientID, BackChannelLogoutURI string
}
|
||
|
|
||
|
func (b *backChannelLogoutSession) Reduce() error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// AppendEvents applies sessionlogout events in order. A registration adds an
// OIDC session to b.sessions, and a sent event removes every entry with the
// matching OIDC session ID. This leaves only sessions that still need a logout
// token. Other event types are ignored.
func (b *backChannelLogoutSession) AppendEvents(events ...eventstore.Event) {
	for _, event := range events {
		switch e := event.(type) {
		case *sessionlogout.BackChannelLogoutRegisteredEvent:
			b.sessions = append(b.sessions, backChannelLogoutOIDCSessions{
				SessionID:            b.sessionID,
				OIDCSessionID:        e.OIDCSessionID,
				UserID:               e.UserID,
				ClientID:             e.ClientID,
				BackChannelLogoutURI: e.BackChannelLogoutURI,
			})
		case *sessionlogout.BackChannelLogoutSentEvent:
			// slices.DeleteFunc returns the shortened slice; the original slice
			// header keeps its old length (with a zeroed tail on Go >= 1.22),
			// so the result must be assigned back.
			b.sessions = slices.DeleteFunc(b.sessions, func(session backChannelLogoutOIDCSessions) bool {
				return session.OIDCSessionID == e.OIDCSessionID
			})
		}
	}
}
|
||
|
|
||
|
func (b *backChannelLogoutSession) Query() *eventstore.SearchQueryBuilder {
|
||
|
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||
|
AddQuery().
|
||
|
AggregateTypes(sessionlogout.AggregateType).
|
||
|
AggregateIDs(b.sessionID).
|
||
|
EventTypes(
|
||
|
sessionlogout.BackChannelLogoutRegisteredType,
|
||
|
sessionlogout.BackChannelLogoutSentType).
|
||
|
Builder()
|
||
|
}
|