mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:17:32 +00:00
feat: add saml request to link to sessions (#9001)
# Which Problems Are Solved It is currently not possible to use SAML with the Session API. # How the Problems Are Solved Add SAML service, to get and resolve SAML requests. Add SAML session and SAML request aggregate, which can be linked to the Session to get back a SAMLResponse from the API directly. # Additional Changes Update of dependency zitadel/saml to provide all functionality for handling of SAML requests and responses. # Additional Context Closes #6053 --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
161
internal/command/saml_request.go
Normal file
161
internal/command/saml_request.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SAMLRequest struct {
|
||||
ID string
|
||||
LoginClient string
|
||||
|
||||
ApplicationID string
|
||||
ACSURL string
|
||||
RelayState string
|
||||
RequestID string
|
||||
Binding string
|
||||
Issuer string
|
||||
Destination string
|
||||
}
|
||||
|
||||
type CurrentSAMLRequest struct {
|
||||
*SAMLRequest
|
||||
SessionID string
|
||||
UserID string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
}
|
||||
|
||||
func (c *Commands) AddSAMLRequest(ctx context.Context, samlRequest *SAMLRequest) (_ *CurrentSAMLRequest, err error) {
|
||||
id, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
samlRequest.ID = IDPrefixV2 + id
|
||||
writeModel, err := c.getSAMLRequestWriteModel(ctx, samlRequest.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if writeModel.SAMLRequestState != domain.SAMLRequestStateUnspecified {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-MO3vmsMLUt", "Errors.SAMLRequest.AlreadyExisting")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewAddedEvent(
|
||||
ctx,
|
||||
&samlrequest.NewAggregate(samlRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
samlRequest.LoginClient,
|
||||
samlRequest.ApplicationID,
|
||||
samlRequest.ACSURL,
|
||||
samlRequest.RelayState,
|
||||
samlRequest.RequestID,
|
||||
samlRequest.Binding,
|
||||
samlRequest.Issuer,
|
||||
samlRequest.Destination,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
|
||||
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.SAMLRequestState == domain.SAMLRequestStateUnspecified {
|
||||
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-GH3PVLSfXC", "Errors.SAMLRequest.NotExisting")
|
||||
}
|
||||
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
|
||||
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-ttPKNdAIFT", "Errors.SAMLRequest.AlreadyHandled")
|
||||
}
|
||||
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
|
||||
return nil, nil, zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient")
|
||||
}
|
||||
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID())
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = sessionWriteModel.CheckIsActive(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := c.sessionTokenVerifier(ctx, sessionToken, sessionWriteModel.AggregateID, sessionWriteModel.TokenID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewSessionLinkedEvent(
|
||||
ctx, &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
sessionID,
|
||||
sessionWriteModel.UserID,
|
||||
sessionWriteModel.AuthenticationTime(),
|
||||
sessionWriteModel.AuthMethodTypes(),
|
||||
)); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func (c *Commands) FailSAMLRequest(ctx context.Context, id string, reason domain.SAMLErrorReason) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
|
||||
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
|
||||
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-32lGj1Fhjt", "Errors.SAMLRequest.AlreadyHandled")
|
||||
}
|
||||
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewFailedEvent(
|
||||
ctx,
|
||||
&samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
reason,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||
}
|
||||
|
||||
func samlRequestWriteModelToCurrentSAMLRequest(writeModel *SAMLRequestWriteModel) (_ *CurrentSAMLRequest) {
|
||||
return &CurrentSAMLRequest{
|
||||
SAMLRequest: &SAMLRequest{
|
||||
ID: writeModel.AggregateID,
|
||||
LoginClient: writeModel.LoginClient,
|
||||
ApplicationID: writeModel.ApplicationID,
|
||||
ACSURL: writeModel.ACSURL,
|
||||
RelayState: writeModel.RelayState,
|
||||
RequestID: writeModel.RequestID,
|
||||
Binding: writeModel.Binding,
|
||||
Issuer: writeModel.Issuer,
|
||||
Destination: writeModel.Destination,
|
||||
},
|
||||
SessionID: writeModel.SessionID,
|
||||
UserID: writeModel.UserID,
|
||||
AuthMethods: writeModel.AuthMethods,
|
||||
AuthTime: writeModel.AuthTime,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) GetCurrentSAMLRequest(ctx context.Context, id string) (_ *CurrentSAMLRequest, err error) {
|
||||
wm, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return samlRequestWriteModelToCurrentSAMLRequest(wm), nil
|
||||
}
|
||||
|
||||
func (c *Commands) getSAMLRequestWriteModel(ctx context.Context, id string) (writeModel *SAMLRequestWriteModel, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
writeModel = NewSAMLRequestWriteModel(ctx, id)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return writeModel, nil
|
||||
}
|
88
internal/command/saml_request_model.go
Normal file
88
internal/command/saml_request_model.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SAMLRequestWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
aggregate *eventstore.Aggregate
|
||||
|
||||
LoginClient string
|
||||
ApplicationID string
|
||||
ACSURL string
|
||||
RelayState string
|
||||
RequestID string
|
||||
Binding string
|
||||
Issuer string
|
||||
Destination string
|
||||
|
||||
SessionID string
|
||||
UserID string
|
||||
AuthTime time.Time
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
SAMLRequestState domain.SAMLRequestState
|
||||
}
|
||||
|
||||
func NewSAMLRequestWriteModel(ctx context.Context, id string) *SAMLRequestWriteModel {
|
||||
return &SAMLRequestWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
},
|
||||
aggregate: &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SAMLRequestWriteModel) Reduce() error {
|
||||
for _, event := range m.Events {
|
||||
switch e := event.(type) {
|
||||
case *samlrequest.AddedEvent:
|
||||
m.LoginClient = e.LoginClient
|
||||
m.ApplicationID = e.ApplicationID
|
||||
m.ACSURL = e.ACSURL
|
||||
m.RelayState = e.RelayState
|
||||
m.RequestID = e.RequestID
|
||||
m.Binding = e.Binding
|
||||
m.Issuer = e.Issuer
|
||||
m.Destination = e.Destination
|
||||
m.SAMLRequestState = domain.SAMLRequestStateAdded
|
||||
case *samlrequest.SessionLinkedEvent:
|
||||
m.SessionID = e.SessionID
|
||||
m.UserID = e.UserID
|
||||
m.AuthTime = e.AuthTime
|
||||
m.AuthMethods = e.AuthMethods
|
||||
case *samlrequest.FailedEvent:
|
||||
m.SAMLRequestState = domain.SAMLRequestStateFailed
|
||||
case *samlrequest.SucceededEvent:
|
||||
m.SAMLRequestState = domain.SAMLRequestStateSucceeded
|
||||
}
|
||||
}
|
||||
return m.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (m *SAMLRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(samlrequest.AggregateType).
|
||||
AggregateIDs(m.AggregateID).
|
||||
Builder()
|
||||
}
|
||||
|
||||
// CheckAuthenticated checks that the auth request exists, a session must have been linked
|
||||
func (m *SAMLRequestWriteModel) CheckAuthenticated() error {
|
||||
if m.SessionID == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-3dNRNwSYeC", "Errors.SAMLRequest.NotAuthenticated")
|
||||
}
|
||||
// check that the requests exists, but has not succeeded yet
|
||||
if m.SAMLRequestState == domain.SAMLRequestStateAdded {
|
||||
return nil
|
||||
}
|
||||
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-krQV50AlnJ", "Errors.SAMLRequest.NotAuthenticated")
|
||||
}
|
676
internal/command/saml_request_test.go
Normal file
676
internal/command/saml_request_test.go
Normal file
@@ -0,0 +1,676 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestCommands_AddSAMLRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
request *SAMLRequest
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *CurrentSAMLRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"already exists error",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &SAMLRequest{},
|
||||
},
|
||||
nil,
|
||||
zerrors.ThrowPreconditionFailed(nil, "COMMAND-MO3vmsMLUt", "Errors.SAMLRequest.AlreadyExisting"),
|
||||
},
|
||||
{
|
||||
"added",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
request: &SAMLRequest{
|
||||
LoginClient: "login",
|
||||
ApplicationID: "application",
|
||||
ACSURL: "acs",
|
||||
RelayState: "relaystate",
|
||||
RequestID: "request",
|
||||
Binding: "binding",
|
||||
Issuer: "issuer",
|
||||
Destination: "destination",
|
||||
},
|
||||
},
|
||||
&CurrentSAMLRequest{
|
||||
SAMLRequest: &SAMLRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "login",
|
||||
ApplicationID: "application",
|
||||
ACSURL: "acs",
|
||||
RelayState: "relaystate",
|
||||
RequestID: "request",
|
||||
Binding: "binding",
|
||||
Issuer: "issuer",
|
||||
Destination: "destination",
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
}
|
||||
got, err := c.AddSAMLRequest(tt.args.ctx, tt.args.request)
|
||||
require.ErrorIs(t, tt.wantErr, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
sessionID string
|
||||
sessionToken string
|
||||
checkLoginClient bool
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
authReq *CurrentSAMLRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"samlRequest not found",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-GH3PVLSfXC", "Errors.SAMLRequest.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"samlRequest not existing",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||
domain.SAMLErrorReasonUnspecified,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-ttPKNdAIFT", "Errors.SAMLRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"wrong login client",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
|
||||
id: "id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"session not existing",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"session expired",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx,
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow.Add(-5*time.Minute)),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
2*time.Minute),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid session token",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx,
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
)),
|
||||
),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierInvalid(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "invalid",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx,
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
2*time.Minute),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentSAMLRequest{
|
||||
SAMLRequest: &SAMLRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "login",
|
||||
ApplicationID: "application",
|
||||
ACSURL: "acs",
|
||||
RelayState: "relaystate",
|
||||
RequestID: "request",
|
||||
Binding: "binding",
|
||||
Issuer: "issuer",
|
||||
Destination: "destination",
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"linked with login client check",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(mockCtx,
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
eventFromEventPusherWithCreationDateNow(
|
||||
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
2*time.Minute),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
),
|
||||
tokenVerifier: newMockTokenVerifierValid(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
|
||||
id: "V2_id",
|
||||
sessionID: "sessionID",
|
||||
sessionToken: "token",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
authReq: &CurrentSAMLRequest{
|
||||
SAMLRequest: &SAMLRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "loginClient",
|
||||
ApplicationID: "application",
|
||||
ACSURL: "acs",
|
||||
RelayState: "relaystate",
|
||||
RequestID: "request",
|
||||
Binding: "binding",
|
||||
Issuer: "issuer",
|
||||
Destination: "destination",
|
||||
},
|
||||
SessionID: "sessionID",
|
||||
UserID: "userID",
|
||||
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
sessionTokenVerifier: tt.fields.tokenVerifier,
|
||||
}
|
||||
details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assertObjectDetails(t, tt.res.details, details)
|
||||
if err == nil {
|
||||
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||
got.AuthTime = time.Time{}
|
||||
}
|
||||
assert.Equal(t, tt.res.authReq, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_FailSAMLRequest(t *testing.T) {
|
||||
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
type fields struct {
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
reason domain.SAMLErrorReason
|
||||
description string
|
||||
}
|
||||
type res struct {
|
||||
details *domain.ObjectDetails
|
||||
samlReq *CurrentSAMLRequest
|
||||
wantErr error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"authRequest not existing",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "foo",
|
||||
reason: domain.SAMLErrorReasonAuthNFailed,
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-32lGj1Fhjt", "Errors.SAMLRequest.AlreadyHandled"),
|
||||
},
|
||||
}, {
|
||||
"already failed",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
domain.SAMLErrorReasonAuthNFailed,
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
reason: domain.SAMLErrorReasonAuthNFailed,
|
||||
description: "desc",
|
||||
},
|
||||
res{
|
||||
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-32lGj1Fhjt", "Errors.SAMLRequest.AlreadyHandled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"failed",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
"login",
|
||||
"application",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||
domain.SAMLErrorReasonAuthNFailed,
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: mockCtx,
|
||||
id: "V2_id",
|
||||
reason: domain.SAMLErrorReasonAuthNFailed,
|
||||
description: "desc",
|
||||
},
|
||||
res{
|
||||
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||
samlReq: &CurrentSAMLRequest{
|
||||
SAMLRequest: &SAMLRequest{
|
||||
ID: "V2_id",
|
||||
LoginClient: "login",
|
||||
ApplicationID: "application",
|
||||
ACSURL: "acs",
|
||||
RelayState: "relaystate",
|
||||
RequestID: "request",
|
||||
Binding: "binding",
|
||||
Issuer: "issuer",
|
||||
Destination: "destination",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
}
|
||||
details, got, err := c.FailSAMLRequest(tt.args.ctx, tt.args.id, tt.args.reason)
|
||||
require.ErrorIs(t, err, tt.res.wantErr)
|
||||
assertObjectDetails(t, tt.res.details, details)
|
||||
assert.Equal(t, tt.res.samlReq, got)
|
||||
})
|
||||
}
|
||||
}
|
186
internal/command/saml_session.go
Normal file
186
internal/command/saml_session.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlsession"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SAMLSession struct {
|
||||
SessionID string
|
||||
SAMLResponseID string
|
||||
EntityID string
|
||||
UserID string
|
||||
Audience []string
|
||||
Expiration time.Time
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
PreferredLanguage *language.Tag
|
||||
UserAgent *domain.UserAgent
|
||||
}
|
||||
|
||||
type SAMLRequestComplianceChecker func(context.Context, *SAMLRequestWriteModel) error
|
||||
|
||||
func (c *Commands) CreateSAMLSessionFromSAMLRequest(ctx context.Context, samlReqId string, complianceCheck SAMLRequestComplianceChecker, samlResponseID string, samlResponseLifetime time.Duration) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if samlReqId == "" {
|
||||
return zerrors.ThrowPreconditionFailed(nil, "COMMAND-0LxK6O31wH", "Errors.SAMLRequest.InvalidCode")
|
||||
}
|
||||
|
||||
samlReqModel, err := c.getSAMLRequestWriteModel(ctx, samlReqId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
sessionModel := NewSessionWriteModel(samlReqModel.SessionID, instanceID)
|
||||
err = c.eventstore.FilterToQueryReducer(ctx, sessionModel)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = sessionModel.CheckIsActive(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd, err := c.newSAMLSessionAddEvents(ctx, sessionModel.UserID, sessionModel.UserResourceOwner)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = complianceCheck(ctx, samlReqModel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.AddSession(ctx,
|
||||
sessionModel.UserID,
|
||||
sessionModel.UserResourceOwner,
|
||||
sessionModel.AggregateID,
|
||||
samlReqModel.Issuer,
|
||||
[]string{samlReqModel.Issuer},
|
||||
samlReqModel.AuthMethods,
|
||||
samlReqModel.AuthTime,
|
||||
sessionModel.PreferredLanguage,
|
||||
sessionModel.UserAgent,
|
||||
)
|
||||
|
||||
if err = cmd.AddSAMLResponse(ctx, samlResponseID, samlResponseLifetime); err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.SetSAMLRequestSuccessful(ctx, samlReqModel.aggregate)
|
||||
_, err = cmd.PushEvents(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) newSAMLSessionAddEvents(ctx context.Context, userID, resourceOwner string, pending ...eventstore.Command) (*SAMLSessionEvents, error) {
|
||||
userStateModel, err := c.userStateWriteModel(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !userStateModel.UserState.IsEnabled() {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-1768ZQpmcP", "Errors.User.NotActive")
|
||||
}
|
||||
sessionID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessionID = IDPrefixV2 + sessionID
|
||||
return &SAMLSessionEvents{
|
||||
commands: c,
|
||||
idGenerator: c.idGenerator,
|
||||
encryptionAlg: c.keyAlgorithm,
|
||||
events: pending,
|
||||
samlSessionWriteModel: NewSAMLSessionWriteModel(sessionID, resourceOwner),
|
||||
userStateModel: userStateModel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type SAMLSessionEvents struct {
|
||||
commands *Commands
|
||||
idGenerator id.Generator
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
events []eventstore.Command
|
||||
samlSessionWriteModel *SAMLSessionWriteModel
|
||||
userStateModel *UserV2WriteModel
|
||||
|
||||
// samlResponseID is set by the command
|
||||
samlResponseID string
|
||||
}
|
||||
|
||||
func (c *SAMLSessionEvents) AddSession(
|
||||
ctx context.Context,
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
entityID string,
|
||||
audience []string,
|
||||
authMethods []domain.UserAuthMethodType,
|
||||
authTime time.Time,
|
||||
preferredLanguage *language.Tag,
|
||||
userAgent *domain.UserAgent,
|
||||
) {
|
||||
c.events = append(c.events, samlsession.NewAddedEvent(
|
||||
ctx,
|
||||
c.samlSessionWriteModel.aggregate,
|
||||
userID,
|
||||
userResourceOwner,
|
||||
sessionID,
|
||||
entityID,
|
||||
audience,
|
||||
authMethods,
|
||||
authTime,
|
||||
preferredLanguage,
|
||||
userAgent,
|
||||
))
|
||||
}
|
||||
|
||||
func (c *SAMLSessionEvents) SetSAMLRequestSuccessful(ctx context.Context, samlRequestAggregate *eventstore.Aggregate) {
|
||||
c.events = append(c.events, samlrequest.NewSucceededEvent(ctx, samlRequestAggregate))
|
||||
}
|
||||
|
||||
func (c *SAMLSessionEvents) SetSAMLRequestFailed(ctx context.Context, samlRequestAggregate *eventstore.Aggregate, err domain.SAMLErrorReason) {
|
||||
c.events = append(c.events, samlrequest.NewFailedEvent(ctx, samlRequestAggregate, err))
|
||||
}
|
||||
|
||||
func (c *SAMLSessionEvents) AddSAMLResponse(ctx context.Context, id string, lifetime time.Duration) error {
|
||||
c.samlResponseID = id
|
||||
c.events = append(c.events, samlsession.NewSAMLResponseAddedEvent(ctx, c.samlSessionWriteModel.aggregate, id, lifetime))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SAMLSessionEvents) PushEvents(ctx context.Context) (*SAMLSession, error) {
|
||||
pushedEvents, err := c.commands.eventstore.Push(ctx, c.events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = AppendAndReduce(c.samlSessionWriteModel, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session := &SAMLSession{
|
||||
SessionID: c.samlSessionWriteModel.SessionID,
|
||||
EntityID: c.samlSessionWriteModel.EntityID,
|
||||
UserID: c.samlSessionWriteModel.UserID,
|
||||
Audience: c.samlSessionWriteModel.Audience,
|
||||
Expiration: c.samlSessionWriteModel.SAMLResponseExpiration,
|
||||
AuthMethods: c.samlSessionWriteModel.AuthMethods,
|
||||
AuthTime: c.samlSessionWriteModel.AuthTime,
|
||||
PreferredLanguage: c.samlSessionWriteModel.PreferredLanguage,
|
||||
UserAgent: c.samlSessionWriteModel.UserAgent,
|
||||
SAMLResponseID: c.samlSessionWriteModel.SAMLResponseID,
|
||||
}
|
||||
activity.Trigger(ctx, c.samlSessionWriteModel.UserResourceOwner, c.samlSessionWriteModel.UserID, activity.SAMLResponse, c.commands.eventstore.FilterToQueryReducer)
|
||||
return session, nil
|
||||
}
|
102
internal/command/saml_session_model.go
Normal file
102
internal/command/saml_session_model.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlsession"
|
||||
)
|
||||
|
||||
type SAMLSessionWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
UserID string
|
||||
UserResourceOwner string
|
||||
PreferredLanguage *language.Tag
|
||||
SessionID string
|
||||
EntityID string
|
||||
Audience []string
|
||||
AuthMethods []domain.UserAuthMethodType
|
||||
AuthTime time.Time
|
||||
UserAgent *domain.UserAgent
|
||||
State domain.SAMLSessionState
|
||||
SAMLResponseID string
|
||||
SAMLResponseCreation time.Time
|
||||
SAMLResponseExpiration time.Time
|
||||
|
||||
aggregate *eventstore.Aggregate
|
||||
}
|
||||
|
||||
func NewSAMLSessionWriteModel(id string, resourceOwner string) *SAMLSessionWriteModel {
|
||||
return &SAMLSessionWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: id,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
aggregate: &samlsession.NewAggregate(id, resourceOwner).Aggregate,
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLSessionWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *samlsession.AddedEvent:
|
||||
wm.reduceAdded(e)
|
||||
case *samlsession.SAMLResponseAddedEvent:
|
||||
wm.reduceSAMLResponseAdded(e)
|
||||
case *samlsession.SAMLResponseRevokedEvent:
|
||||
wm.reduceSAMLResponseRevoked(e)
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (wm *SAMLSessionWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(samlsession.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
samlsession.AddedType,
|
||||
samlsession.SAMLResponseAddedType,
|
||||
samlsession.SAMLResponseRevokedType,
|
||||
).
|
||||
Builder()
|
||||
|
||||
if wm.ResourceOwner != "" {
|
||||
query.ResourceOwner(wm.ResourceOwner)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func (wm *SAMLSessionWriteModel) reduceAdded(e *samlsession.AddedEvent) {
|
||||
wm.UserID = e.UserID
|
||||
wm.UserResourceOwner = e.UserResourceOwner
|
||||
wm.SessionID = e.SessionID
|
||||
wm.EntityID = e.EntityID
|
||||
wm.Audience = e.Audience
|
||||
wm.AuthMethods = e.AuthMethods
|
||||
wm.AuthTime = e.AuthTime
|
||||
wm.PreferredLanguage = e.PreferredLanguage
|
||||
wm.UserAgent = e.UserAgent
|
||||
wm.State = domain.SAMLSessionStateActive
|
||||
// the write model might be initialized without resource owner,
|
||||
// so update the aggregate
|
||||
if wm.ResourceOwner == "" {
|
||||
wm.aggregate = &samlsession.NewAggregate(wm.AggregateID, e.Aggregate().ResourceOwner).Aggregate
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLSessionWriteModel) reduceSAMLResponseAdded(e *samlsession.SAMLResponseAddedEvent) {
|
||||
wm.SAMLResponseID = e.ID
|
||||
wm.SAMLResponseCreation = e.CreationDate()
|
||||
wm.SAMLResponseExpiration = e.CreationDate().Add(e.Lifetime)
|
||||
}
|
||||
|
||||
func (wm *SAMLSessionWriteModel) reduceSAMLResponseRevoked(e *samlsession.SAMLResponseRevokedEvent) {
|
||||
wm.SAMLResponseID = ""
|
||||
wm.SAMLResponseExpiration = e.CreationDate()
|
||||
}
|
337
internal/command/saml_session_test.go
Normal file
337
internal/command/saml_session_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
"github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlsession"
|
||||
"github.com/zitadel/zitadel/internal/repository/session"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func mockSAMLRequestComplianceChecker(returnErr error) SAMLRequestComplianceChecker {
|
||||
return func(context.Context, *SAMLRequestWriteModel) error {
|
||||
return returnErr
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommands_CreateSAMLSessionFromSAMLRequest(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
samlRequestID string
|
||||
samlResponseID string
|
||||
complianceCheck SAMLRequestComplianceChecker
|
||||
samlResponseLifetime time.Duration
|
||||
}
|
||||
type res struct {
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"missing code",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "",
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-0LxK6O31wH", "Errors.SAMLRequest.InvalidCode"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"filter error",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "V2_samlRequestID",
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
"session filter error",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"applicationId",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "V2_samlRequestID",
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{
|
||||
err: io.ErrClosedPipe,
|
||||
},
|
||||
},
|
||||
{
|
||||
"inactive session error",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"applicationId",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewSessionLinkedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(), // inactive session
|
||||
),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "V2_samlRequestID",
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"user not active",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"applicationId",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewSessionLinkedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(),
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
user.NewHumanAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "org1").Aggregate,
|
||||
"username",
|
||||
"firstname",
|
||||
"lastname",
|
||||
"nickname",
|
||||
"displayname",
|
||||
language.Afrikaans,
|
||||
domain.GenderUnspecified,
|
||||
"email",
|
||||
false,
|
||||
),
|
||||
user.NewUserDeactivatedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "org1").Aggregate,
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "V2_samlRequestID",
|
||||
samlResponseID: "samlResponseID",
|
||||
samlResponseLifetime: time.Minute * 5,
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{
|
||||
err: zerrors.ThrowPreconditionFailed(nil, "SAML-1768ZQpmcP", "Errors.User.NotActive"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"add successful",
|
||||
fields{
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewAddedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"loginClient",
|
||||
"applicationId",
|
||||
"acs",
|
||||
"relaystate",
|
||||
"request",
|
||||
"binding",
|
||||
"issuer",
|
||||
"destination",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
samlrequest.NewSessionLinkedEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate,
|
||||
"sessionID",
|
||||
"userID",
|
||||
testNow,
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
session.NewAddedEvent(context.Background(),
|
||||
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
"userID", "org1", testNow, &language.Afrikaans),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
|
||||
testNow),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
user.NewHumanAddedEvent(
|
||||
context.Background(),
|
||||
&user.NewAggregate("userID", "org1").Aggregate,
|
||||
"username",
|
||||
"firstname",
|
||||
"lastname",
|
||||
"nickname",
|
||||
"displayname",
|
||||
language.Afrikaans,
|
||||
domain.GenderUnspecified,
|
||||
"email",
|
||||
false,
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
samlsession.NewAddedEvent(context.Background(), &samlsession.NewAggregate("V2_samlSessionID", "org1").Aggregate,
|
||||
"userID", "org1", "sessionID", "issuer", []string{"issuer"},
|
||||
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, &language.Afrikaans,
|
||||
&domain.UserAgent{
|
||||
FingerprintID: gu.Ptr("fp1"),
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Description: gu.Ptr("firefox"),
|
||||
Header: http.Header{"foo": []string{"bar"}},
|
||||
},
|
||||
),
|
||||
samlsession.NewSAMLResponseAddedEvent(context.Background(), &samlsession.NewAggregate("V2_samlSessionID", "org1").Aggregate, "samlResponseID", time.Minute*5),
|
||||
samlrequest.NewSucceededEvent(context.Background(), &samlrequest.NewAggregate("V2_samlRequestID", "instanceID").Aggregate),
|
||||
),
|
||||
),
|
||||
idGenerator: mock.NewIDGeneratorExpectIDs(t, "samlSessionID"),
|
||||
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
|
||||
samlRequestID: "V2_samlRequestID",
|
||||
samlResponseID: "samlResponseID",
|
||||
samlResponseLifetime: time.Minute * 5,
|
||||
complianceCheck: mockSAMLRequestComplianceChecker(nil),
|
||||
},
|
||||
res{},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
keyAlgorithm: tt.fields.keyAlgorithm,
|
||||
}
|
||||
err := c.CreateSAMLSessionFromSAMLRequest(tt.args.ctx, tt.args.samlRequestID, tt.args.complianceCheck, tt.args.samlResponseID, tt.args.samlResponseLifetime)
|
||||
require.ErrorIs(t, err, tt.res.err)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user