zitadel/internal/command/saml_request_test.go
Stefan Benz c3b97a91a2
feat: add saml request to link to sessions (#9001)
# Which Problems Are Solved

It is currently not possible to use SAML with the Session API.

# How the Problems Are Solved

Add a SAML service to get and resolve SAML requests.
Add SAML session and SAML request aggregates, which can be linked to the
Session to get back a SAMLResponse from the API directly.

# Additional Changes

Update of dependency zitadel/saml to provide all functionality for
handling of SAML requests and responses.

# Additional Context

Closes #6053

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
2024-12-19 11:11:40 +00:00

677 lines
18 KiB
Go

package command
import (
"context"
"net"
"net/http"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/zerrors"
)
// TestCommands_AddSAMLRequest exercises Commands.AddSAMLRequest against a
// mocked eventstore and a mocked ID generator.
//
// Both cases let the generator hand out "id" while the fixtures and the
// expected result use the aggregate ID "V2_id".
func TestCommands_AddSAMLRequest(t *testing.T) {
	mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
	type fields struct {
		eventstore  func(t *testing.T) *eventstore.Eventstore
		idGenerator id.Generator
	}
	type args struct {
		ctx     context.Context
		request *SAMLRequest
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *CurrentSAMLRequest
		wantErr error
	}{
		{
			// The filter already returns an added event for the aggregate,
			// so the command must refuse to add it a second time.
			"already exists error",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
				),
				idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
			},
			args{
				ctx:     mockCtx,
				request: &SAMLRequest{},
			},
			nil,
			zerrors.ThrowPreconditionFailed(nil, "COMMAND-MO3vmsMLUt", "Errors.SAMLRequest.AlreadyExisting"),
		},
		{
			// No prior events: the added event must be pushed and the request
			// is returned with the generated aggregate ID filled in.
			"added",
			fields{
				eventstore: expectEventstore(
					expectFilter(),
					expectPush(
						samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
							"login",
							"application",
							"acs",
							"relaystate",
							"request",
							"binding",
							"issuer",
							"destination",
						),
					),
				),
				idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
			},
			args{
				ctx: mockCtx,
				request: &SAMLRequest{
					LoginClient:   "login",
					ApplicationID: "application",
					ACSURL:        "acs",
					RelayState:    "relaystate",
					RequestID:     "request",
					Binding:       "binding",
					Issuer:        "issuer",
					Destination:   "destination",
				},
			},
			&CurrentSAMLRequest{
				SAMLRequest: &SAMLRequest{
					ID:            "V2_id",
					LoginClient:   "login",
					ApplicationID: "application",
					ACSURL:        "acs",
					RelayState:    "relaystate",
					RequestID:     "request",
					Binding:       "binding",
					Issuer:        "issuer",
					Destination:   "destination",
				},
			},
			nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore:  tt.fields.eventstore(t),
				idGenerator: tt.fields.idGenerator,
			}
			got, err := c.AddSAMLRequest(tt.args.ctx, tt.args.request)
			// testify expects (t, err, target): the returned error comes
			// first, the expected error second.
			require.ErrorIs(t, err, tt.wantErr)
			assert.Equal(t, tt.want, got)
		})
	}
}
// TestCommands_LinkSessionToSAMLRequest exercises
// Commands.LinkSessionToSAMLRequest against a mocked eventstore and a mocked
// session token verifier.
//
// Each case lists, in order, the filter results the command receives (first
// the SAML request events, then the session events) and, for the successful
// cases, the session-linked event that must be pushed.
func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
	mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
	type fields struct {
		eventstore    func(t *testing.T) *eventstore.Eventstore
		tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
	}
	type args struct {
		ctx              context.Context
		id               string
		sessionID        string
		sessionToken     string
		checkLoginClient bool
	}
	type res struct {
		details *domain.ObjectDetails
		samlReq *CurrentSAMLRequest
		wantErr error
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			// The filter yields no events for the request at all.
			"samlRequest not found",
			fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:       mockCtx,
				id:        "id",
				sessionID: "sessionID",
			},
			res{
				wantErr: zerrors.ThrowNotFound(nil, "COMMAND-GH3PVLSfXC", "Errors.SAMLRequest.NotExisting"),
			},
		},
		{
			// The request was added and has since failed, so it can no
			// longer be linked to a session.
			"samlRequest already handled",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
						eventFromEventPusher(
							samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("id", "instanceID").Aggregate,
								domain.SAMLErrorReasonUnspecified,
							),
						),
					),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:       mockCtx,
				id:        "id",
				sessionID: "sessionID",
			},
			res{
				wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-ttPKNdAIFT", "Errors.SAMLRequest.AlreadyHandled"),
			},
		},
		{
			// The request was created by "login" but the caller's context
			// carries "wrongLoginClient" and the login client check is on.
			"wrong login client",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:              authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
				id:               "id",
				sessionID:        "sessionID",
				sessionToken:     "token",
				checkLoginClient: true,
			},
			res{
				wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient"),
			},
		},
		{
			// The request exists, but the second filter returns no session
			// events.
			"session not existing",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectFilter(),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:       mockCtx,
				id:        "V2_id",
				sessionID: "sessionID",
			},
			res{
				wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
			},
		},
		{
			// The session checks are dated five minutes before testNow and
			// the lifetime is two minutes. No token verifier is configured
			// for this case.
			"session expired",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectFilter(
						eventFromEventPusher(
							session.NewAddedEvent(mockCtx,
								&session.NewAggregate("sessionID", "instance1").Aggregate,
								&domain.UserAgent{
									FingerprintID: gu.Ptr("fp1"),
									IP:            net.ParseIP("1.2.3.4"),
									Description:   gu.Ptr("firefox"),
									Header:        http.Header{"foo": []string{"bar"}},
								},
							)),
						eventFromEventPusher(
							session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
						),
						eventFromEventPusher(
							session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								testNow.Add(-5*time.Minute)),
						),
						eventFromEventPusher(
							session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								2*time.Minute),
						),
					),
				),
			},
			args{
				ctx:          mockCtx,
				id:           "V2_id",
				sessionID:    "sessionID",
				sessionToken: "token",
			},
			res{
				wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"),
			},
		},
		{
			// Request and session exist, but the verifier rejects the token.
			"invalid session token",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectFilter(
						eventFromEventPusher(
							session.NewAddedEvent(mockCtx,
								&session.NewAggregate("sessionID", "instance1").Aggregate,
								&domain.UserAgent{
									FingerprintID: gu.Ptr("fp1"),
									IP:            net.ParseIP("1.2.3.4"),
									Description:   gu.Ptr("firefox"),
									Header:        http.Header{"foo": []string{"bar"}},
								},
							)),
					),
				),
				tokenVerifier: newMockTokenVerifierInvalid(),
			},
			args{
				ctx:          mockCtx,
				id:           "V2_id",
				sessionID:    "sessionID",
				sessionToken: "invalid",
			},
			res{
				wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
			},
		},
		{
			// Happy path without the login client check: the session-linked
			// event is pushed with the user, auth time and password method.
			"linked",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectFilter(
						eventFromEventPusher(
							session.NewAddedEvent(mockCtx,
								&session.NewAggregate("sessionID", "instance1").Aggregate,
								&domain.UserAgent{
									FingerprintID: gu.Ptr("fp1"),
									IP:            net.ParseIP("1.2.3.4"),
									Description:   gu.Ptr("firefox"),
									Header:        http.Header{"foo": []string{"bar"}},
								},
							)),
						eventFromEventPusher(
							session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								"userID", "org1", testNow, &language.Afrikaans),
						),
						eventFromEventPusher(
							session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								testNow),
						),
						eventFromEventPusherWithCreationDateNow(
							session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								2*time.Minute),
						),
					),
					expectPush(
						samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
							"sessionID",
							"userID",
							testNow,
							[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
						),
					),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:          mockCtx,
				id:           "V2_id",
				sessionID:    "sessionID",
				sessionToken: "token",
			},
			res{
				details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
				samlReq: &CurrentSAMLRequest{
					SAMLRequest: &SAMLRequest{
						ID:            "V2_id",
						LoginClient:   "login",
						ApplicationID: "application",
						ACSURL:        "acs",
						RelayState:    "relaystate",
						RequestID:     "request",
						Binding:       "binding",
						Issuer:        "issuer",
						Destination:   "destination",
					},
					SessionID:   "sessionID",
					UserID:      "userID",
					AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
				},
			},
		},
		{
			// Happy path with the login client check enabled: the request's
			// login client matches the "loginClient" in the caller's context.
			"linked with login client check",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"loginClient",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectFilter(
						eventFromEventPusher(
							session.NewAddedEvent(mockCtx,
								&session.NewAggregate("sessionID", "instance1").Aggregate,
								&domain.UserAgent{
									FingerprintID: gu.Ptr("fp1"),
									IP:            net.ParseIP("1.2.3.4"),
									Description:   gu.Ptr("firefox"),
									Header:        http.Header{"foo": []string{"bar"}},
								},
							)),
						eventFromEventPusher(
							session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								"userID", "org1", testNow, &language.Afrikaans),
						),
						eventFromEventPusher(
							session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								testNow),
						),
						eventFromEventPusherWithCreationDateNow(
							session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
								2*time.Minute),
						),
					),
					expectPush(
						samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
							"sessionID",
							"userID",
							testNow,
							[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
						),
					),
				),
				tokenVerifier: newMockTokenVerifierValid(),
			},
			args{
				ctx:              authz.NewMockContext("instanceID", "orgID", "loginClient"),
				id:               "V2_id",
				sessionID:        "sessionID",
				sessionToken:     "token",
				checkLoginClient: true,
			},
			res{
				details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
				samlReq: &CurrentSAMLRequest{
					SAMLRequest: &SAMLRequest{
						ID:            "V2_id",
						LoginClient:   "loginClient",
						ApplicationID: "application",
						ACSURL:        "acs",
						RelayState:    "relaystate",
						RequestID:     "request",
						Binding:       "binding",
						Issuer:        "issuer",
						Destination:   "destination",
					},
					SessionID:   "sessionID",
					UserID:      "userID",
					AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore:           tt.fields.eventstore(t),
				sessionTokenVerifier: tt.fields.tokenVerifier,
			}
			details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient)
			require.ErrorIs(t, err, tt.res.wantErr)
			assertObjectDetails(t, tt.res.details, details)
			if err == nil {
				// AuthTime is checked on its own and then cleared, because
				// the expected request fixtures do not set it.
				assert.WithinRange(t, got.AuthTime, testNow, testNow)
				got.AuthTime = time.Time{}
			}
			assert.Equal(t, tt.res.samlReq, got)
		})
	}
}
// TestCommands_FailSAMLRequest exercises Commands.FailSAMLRequest against a
// mocked eventstore.
//
// A request can only be failed while it is still open: an unknown request and
// an already failed one are both rejected with AlreadyHandled, while an open
// request gets a failed event pushed and is returned to the caller.
func TestCommands_FailSAMLRequest(t *testing.T) {
	mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
	type fields struct {
		eventstore func(t *testing.T) *eventstore.Eventstore
	}
	type args struct {
		ctx    context.Context
		id     string
		reason domain.SAMLErrorReason
	}
	type res struct {
		details *domain.ObjectDetails
		samlReq *CurrentSAMLRequest
		wantErr error
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			// The filter yields no events for the request.
			"samlRequest not existing",
			fields{
				eventstore: expectEventstore(
					expectFilter(),
				),
			},
			args{
				ctx:    mockCtx,
				id:     "foo",
				reason: domain.SAMLErrorReasonAuthNFailed,
			},
			res{
				wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-32lGj1Fhjt", "Errors.SAMLRequest.AlreadyHandled"),
			},
		}, {
			// The request was added and has already been failed once.
			"already failed",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
						// Wrapped like every other filter fixture in this
						// file rather than passed as a raw command.
						eventFromEventPusher(
							samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								domain.SAMLErrorReasonAuthNFailed,
							),
						),
					),
				),
			},
			args{
				ctx:    mockCtx,
				id:     "V2_id",
				reason: domain.SAMLErrorReasonAuthNFailed,
			},
			res{
				wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-32lGj1Fhjt", "Errors.SAMLRequest.AlreadyHandled"),
			},
		},
		{
			// The request is open: a failed event with the given reason must
			// be pushed and the request is returned.
			"failed",
			fields{
				eventstore: expectEventstore(
					expectFilter(
						eventFromEventPusher(
							samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
								"login",
								"application",
								"acs",
								"relaystate",
								"request",
								"binding",
								"issuer",
								"destination",
							),
						),
					),
					expectPush(
						samlrequest.NewFailedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
							domain.SAMLErrorReasonAuthNFailed,
						),
					),
				),
			},
			args{
				ctx:    mockCtx,
				id:     "V2_id",
				reason: domain.SAMLErrorReasonAuthNFailed,
			},
			res{
				details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
				samlReq: &CurrentSAMLRequest{
					SAMLRequest: &SAMLRequest{
						ID:            "V2_id",
						LoginClient:   "login",
						ApplicationID: "application",
						ACSURL:        "acs",
						RelayState:    "relaystate",
						RequestID:     "request",
						Binding:       "binding",
						Issuer:        "issuer",
						Destination:   "destination",
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Commands{
				eventstore: tt.fields.eventstore(t),
			}
			details, got, err := c.FailSAMLRequest(tt.args.ctx, tt.args.id, tt.args.reason)
			require.ErrorIs(t, err, tt.res.wantErr)
			assertObjectDetails(t, tt.res.details, details)
			assert.Equal(t, tt.res.samlReq, got)
		})
	}
}