// Source: zitadel/internal/command/saml_session.go

package command
import (
"context"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/repository/samlsession"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
// SAMLSession is the outcome of pushing SAML session events: PushEvents fills
// it from the SAML session write model after reducing the pushed events.
type SAMLSession struct {
	// SessionID is copied from the write model's SessionID. When created via
	// CreateSAMLSessionFromSAMLRequest this is presumably the ID of the login
	// session the SAML request was bound to — confirm in SAMLSessionWriteModel.
	SessionID string

	// SAMLResponseID is the ID of the SAML response recorded on the session.
	SAMLResponseID string

	// EntityID is the entity ID stored with the session;
	// CreateSAMLSessionFromSAMLRequest passes the SAML request's Issuer.
	EntityID string

	// UserID is the user the session belongs to.
	UserID string

	// Audience is the audience stored with the session;
	// CreateSAMLSessionFromSAMLRequest sets it to just the request's Issuer.
	Audience []string

	// Expiration is taken from the write model's SAMLResponseExpiration.
	// NOTE(review): presumably derived from the lifetime given to
	// AddSAMLResponse — confirm in SAMLSessionWriteModel.
	Expiration time.Time

	// AuthMethods and AuthTime describe the user's authentication;
	// CreateSAMLSessionFromSAMLRequest takes both from the SAML request.
	AuthMethods []domain.UserAuthMethodType

	AuthTime time.Time

	// PreferredLanguage and UserAgent come from the login session in
	// CreateSAMLSessionFromSAMLRequest (pointers; presumably nil when unknown).
	PreferredLanguage *language.Tag

	UserAgent *domain.UserAgent
}
// SAMLRequestComplianceChecker validates a loaded SAML request before a SAML
// session is created from it; a non-nil error aborts
// CreateSAMLSessionFromSAMLRequest before any event is pushed.
type SAMLRequestComplianceChecker func(context.Context, *SAMLRequestWriteModel) error
// CreateSAMLSessionFromSAMLRequest turns the SAML request identified by
// samlReqId into a SAML session.
//
// It loads the request and the login session it references, requires that
// session to be active and its user to be enabled, and runs complianceCheck
// against the request. It then pushes, in one batch, the session-added event,
// a SAML-response event (samlResponseID valid for samlResponseLifetime) and
// the request's succeeded event.
//
// An empty samlReqId yields a precondition error. Any error from loading,
// checking or pushing is returned unchanged, and in that case nothing is pushed.
func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReqId string, complianceCheck SAMLRequestComplianceChecker, samlResponseID string, samlResponseLifetime time.Duration) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	if samlReqId == "" {
		return zerrors.ThrowPreconditionFailed(nil, "COMMAND-0LxK6O31wH", "Errors.SAMLRequest.InvalidCode")
	}

	request, err := c.getSAMLRequestWriteModel(ctx, samlReqId)
	if err != nil {
		return err
	}

	// The request is bound to a login session; it must still be usable.
	loginSession := NewSessionWriteModel(request.SessionID, authz.GetInstance(ctx).InstanceID())
	if err = c.eventstore.FilterToQueryReducer(ctx, loginSession); err != nil {
		return err
	}
	if err = loginSession.CheckIsActive(); err != nil {
		return err
	}

	builder, err := c.newSAMLSessionAddEvents(ctx, loginSession.UserID, loginSession.UserResourceOwner)
	if err != nil {
		return err
	}
	if err = complianceCheck(ctx, request); err != nil {
		return err
	}

	// The issuer doubles as entity ID and as the sole audience entry.
	issuer := request.Issuer
	builder.AddSession(ctx,
		loginSession.UserID,
		loginSession.UserResourceOwner,
		loginSession.AggregateID,
		issuer,
		[]string{issuer},
		request.AuthMethods,
		request.AuthTime,
		loginSession.PreferredLanguage,
		loginSession.UserAgent,
	)
	if err = builder.AddSAMLResponse(ctx, samlResponseID, samlResponseLifetime); err != nil {
		return err
	}
	builder.SetSAMLRequestSuccessful(ctx, request.aggregate)

	// The resulting *SAMLSession is not needed by this command.
	_, err = builder.PushEvents(ctx)
	return err
}
// newSAMLSessionAddEvents prepares a SAMLSessionEvents builder for a new SAML
// session of the given user.
//
// It loads the user's state and returns a precondition error if the user is
// not enabled. A fresh session ID is generated, prefixed with IDPrefixV2, and
// used for a new SAML session write model owned by resourceOwner. Any pending
// commands are queued ahead of the events added later through the builder.
func (c *Commands) newSAMLSessionAddEvents(ctx context.Context, userID, resourceOwner string, pending ...eventstore.Command) (*SAMLSessionEvents, error) {
	userStateModel, err := c.userStateWriteModel(ctx, userID)
	if err != nil {
		return nil, err
	}
	if !userStateModel.UserState.IsEnabled() {
		return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-1768ZQpmcP", "Errors.User.NotActive")
	}
	sessionID, err := c.idGenerator.Next()
	if err != nil {
		return nil, err
	}
	sessionID = IDPrefixV2 + sessionID

	// Copy pending: a call like f(xs...) hands over the caller's slice itself,
	// so appending to the builder's events could overwrite elements in the
	// caller's backing array (and vice versa). With no pending commands this
	// stays nil, as before.
	queued := append([]eventstore.Command(nil), pending...)

	return &SAMLSessionEvents{
		commands:              c,
		idGenerator:           c.idGenerator,
		encryptionAlg:         c.keyAlgorithm,
		events:                queued,
		samlSessionWriteModel: NewSAMLSessionWriteModel(sessionID, resourceOwner),
		userStateModel:        userStateModel,
	}, nil
}
// SAMLSessionEvents collects the events needed to create a SAML session so that
// PushEvents can write them to the eventstore in a single push.
type SAMLSessionEvents struct {
	// commands gives access to the eventstore used by PushEvents.
	commands *Commands

	// idGenerator and encryptionAlg are copied from Commands by
	// newSAMLSessionAddEvents.
	idGenerator id.Generator

	encryptionAlg crypto.EncryptionAlgorithm

	// events are queued by the Add*/Set* methods and pushed by PushEvents.
	events []eventstore.Command

	// samlSessionWriteModel is the new SAML session; its aggregate is the target
	// of the session events and it is reduced with the pushed events.
	samlSessionWriteModel *SAMLSessionWriteModel

	// userStateModel is the session user's state, loaded (and required to be
	// enabled) by newSAMLSessionAddEvents.
	userStateModel *UserV2WriteModel

	// samlResponseID is set by AddSAMLResponse.
	samlResponseID string
}
// AddSession queues the samlsession "added" event on the builder's SAML
// session aggregate, carrying the user, the originating session, the entity ID
// and audience, and the authentication and client details passed in.
func (c *SAMLSessionEvents) AddSession(
	ctx context.Context,
	userID,
	userResourceOwner,
	sessionID,
	entityID string,
	audience []string,
	authMethods []domain.UserAuthMethodType,
	authTime time.Time,
	preferredLanguage *language.Tag,
	userAgent *domain.UserAgent,
) {
	target := c.samlSessionWriteModel.aggregate
	added := samlsession.NewAddedEvent(
		ctx, target,
		userID, userResourceOwner, sessionID,
		entityID, audience,
		authMethods, authTime,
		preferredLanguage, userAgent,
	)
	c.events = append(c.events, added)
}
// SetSAMLRequestSuccessful queues a "succeeded" event for the given SAML
// request aggregate.
func (c *SAMLSessionEvents) SetSAMLRequestSuccessful(ctx context.Context, samlRequestAggregate *eventstore.Aggregate) {
	succeeded := samlrequest.NewSucceededEvent(ctx, samlRequestAggregate)
	c.events = append(c.events, succeeded)
}
// SetSAMLRequestFailed queues a "failed" event for the given SAML request
// aggregate, recording reason as the cause.
//
// reason is a domain.SAMLErrorReason, not a Go error, so it is deliberately not
// named err.
func (c *SAMLSessionEvents) SetSAMLRequestFailed(ctx context.Context, samlRequestAggregate *eventstore.Aggregate, reason domain.SAMLErrorReason) {
	c.events = append(c.events, samlrequest.NewFailedEvent(ctx, samlRequestAggregate, reason))
}
// AddSAMLResponse remembers id as the builder's SAML response ID and queues a
// "SAML response added" event for it with the given lifetime. It currently
// always returns nil.
func (c *SAMLSessionEvents) AddSAMLResponse(ctx context.Context, id string, lifetime time.Duration) error {
	c.samlResponseID = id
	responseAdded := samlsession.NewSAMLResponseAddedEvent(ctx, c.samlSessionWriteModel.aggregate, id, lifetime)
	c.events = append(c.events, responseAdded)
	return nil
}
// PushEvents writes all queued events to the eventstore and reduces the pushed
// events into the SAML session write model.
//
// It returns a SAMLSession snapshot of that model and triggers a
// SAMLResponse activity for the session's user. On a push or reduce error it
// returns (nil, err), and in that case the activity is not triggered.
func (c *SAMLSessionEvents) PushEvents(ctx context.Context) (*SAMLSession, error) {
	pushed, err := c.commands.eventstore.Push(ctx, c.events...)
	if err != nil {
		return nil, err
	}

	model := c.samlSessionWriteModel
	if err = AppendAndReduce(model, pushed...); err != nil {
		return nil, err
	}

	session := &SAMLSession{
		SessionID:         model.SessionID,
		SAMLResponseID:    model.SAMLResponseID,
		EntityID:          model.EntityID,
		UserID:            model.UserID,
		Audience:          model.Audience,
		Expiration:        model.SAMLResponseExpiration,
		AuthMethods:       model.AuthMethods,
		AuthTime:          model.AuthTime,
		PreferredLanguage: model.PreferredLanguage,
		UserAgent:         model.UserAgent,
	}
	activity.Trigger(ctx, model.UserResourceOwner, model.UserID, activity.SAMLResponse, c.commands.eventstore.FilterToQueryReducer)
	return session, nil
}