mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:37:32 +00:00
feat: add saml request to link to sessions (#9001)
# Which Problems Are Solved It is currently not possible to use SAML with the Session API. # How the Problems Are Solved Add SAML service, to get and resolve SAML requests. Add SAML session and SAML request aggregate, which can be linked to the Session to get back a SAMLResponse from the API directly. # Additional Changes Update of dependency zitadel/saml to provide all functionality for handling of SAML requests and responses. # Additional Context Closes #6053 --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
@@ -69,6 +69,7 @@ var (
|
||||
DeviceAuthProjection *handler.Handler
|
||||
SessionProjection *handler.Handler
|
||||
AuthRequestProjection *handler.Handler
|
||||
SamlRequestProjection *handler.Handler
|
||||
MilestoneProjection *handler.Handler
|
||||
QuotaProjection *quotaProjection
|
||||
LimitsProjection *handler.Handler
|
||||
@@ -157,6 +158,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
|
||||
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
|
||||
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
|
||||
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
|
||||
SamlRequestProjection = newSamlRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["saml_requests"]))
|
||||
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
|
||||
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
|
||||
LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"]))
|
||||
@@ -286,6 +288,7 @@ func newProjectionsList() {
|
||||
DeviceAuthProjection,
|
||||
SessionProjection,
|
||||
AuthRequestProjection,
|
||||
SamlRequestProjection,
|
||||
MilestoneProjection,
|
||||
QuotaProjection.handler,
|
||||
LimitsProjection,
|
||||
|
132
internal/query/projection/saml_request.go
Normal file
132
internal/query/projection/saml_request.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
SamlRequestsProjectionTable = "projections.saml_requests"
|
||||
|
||||
SamlRequestColumnID = "id"
|
||||
SamlRequestColumnCreationDate = "creation_date"
|
||||
SamlRequestColumnChangeDate = "change_date"
|
||||
SamlRequestColumnSequence = "sequence"
|
||||
SamlRequestColumnResourceOwner = "resource_owner"
|
||||
SamlRequestColumnInstanceID = "instance_id"
|
||||
SamlRequestColumnLoginClient = "login_client"
|
||||
SamlRequestColumnIssuer = "issuer"
|
||||
SamlRequestColumnACS = "acs"
|
||||
SamlRequestColumnRelayState = "relay_state"
|
||||
SamlRequestColumnBinding = "binding"
|
||||
)
|
||||
|
||||
type samlRequestProjection struct{}
|
||||
|
||||
// Name implements handler.Projection.
|
||||
func (*samlRequestProjection) Name() string {
|
||||
return SamlRequestsProjectionTable
|
||||
}
|
||||
|
||||
func newSamlRequestProjection(ctx context.Context, config handler.Config) *handler.Handler {
|
||||
return handler.NewHandler(ctx, &config, new(samlRequestProjection))
|
||||
}
|
||||
|
||||
func (*samlRequestProjection) Init() *old_handler.Check {
|
||||
return handler.NewMultiTableCheck(
|
||||
handler.NewTable([]*handler.InitColumn{
|
||||
handler.NewColumn(SamlRequestColumnID, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnCreationDate, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(SamlRequestColumnChangeDate, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(SamlRequestColumnSequence, handler.ColumnTypeInt64),
|
||||
handler.NewColumn(SamlRequestColumnResourceOwner, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnInstanceID, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnLoginClient, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnIssuer, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnACS, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnRelayState, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnBinding, handler.ColumnTypeText),
|
||||
},
|
||||
handler.NewPrimaryKey(SamlRequestColumnInstanceID, SamlRequestColumnID),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) Reducers() []handler.AggregateReducer {
|
||||
return []handler.AggregateReducer{
|
||||
{
|
||||
Aggregate: samlrequest.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: samlrequest.AddedType,
|
||||
Reduce: p.reduceSamlRequestAdded,
|
||||
},
|
||||
{
|
||||
Event: samlrequest.SucceededType,
|
||||
Reduce: p.reduceSamlRequestEnded,
|
||||
},
|
||||
{
|
||||
Event: samlrequest.FailedType,
|
||||
Reduce: p.reduceSamlRequestEnded,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Aggregate: instance.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: instance.InstanceRemovedEventType,
|
||||
Reduce: reduceInstanceRemovedHelper(SamlRequestColumnInstanceID),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) reduceSamlRequestAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*samlrequest.AddedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", samlrequest.AddedType)
|
||||
}
|
||||
|
||||
return handler.NewCreateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(SamlRequestColumnID, e.Aggregate().ID),
|
||||
handler.NewCol(SamlRequestColumnInstanceID, e.Aggregate().InstanceID),
|
||||
handler.NewCol(SamlRequestColumnCreationDate, e.CreationDate()),
|
||||
handler.NewCol(SamlRequestColumnChangeDate, e.CreationDate()),
|
||||
handler.NewCol(SamlRequestColumnResourceOwner, e.Aggregate().ResourceOwner),
|
||||
handler.NewCol(SamlRequestColumnSequence, e.Sequence()),
|
||||
handler.NewCol(SamlRequestColumnLoginClient, e.LoginClient),
|
||||
handler.NewCol(SamlRequestColumnIssuer, e.Issuer),
|
||||
handler.NewCol(SamlRequestColumnACS, e.ACSURL),
|
||||
handler.NewCol(SamlRequestColumnRelayState, e.RelayState),
|
||||
handler.NewCol(SamlRequestColumnBinding, e.Binding),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) reduceSamlRequestEnded(event eventstore.Event) (*handler.Statement, error) {
|
||||
switch event.(type) {
|
||||
case *samlrequest.SucceededEvent,
|
||||
*samlrequest.FailedEvent:
|
||||
break
|
||||
default:
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{samlrequest.SucceededType, samlrequest.FailedType})
|
||||
}
|
||||
|
||||
return handler.NewDeleteStatement(
|
||||
event,
|
||||
[]handler.Condition{
|
||||
handler.NewCond(SamlRequestColumnID, event.Aggregate().ID),
|
||||
handler.NewCond(SamlRequestColumnInstanceID, event.Aggregate().InstanceID),
|
||||
},
|
||||
), nil
|
||||
}
|
123
internal/query/projection/saml_request_test.go
Normal file
123
internal/query/projection/saml_request_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestSamlRequestProjection_reduces(t *testing.T) {
|
||||
type args struct {
|
||||
event func(t *testing.T) eventstore.Event
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
reduce func(event eventstore.Event) (*handler.Statement, error)
|
||||
want wantReduce
|
||||
}{
|
||||
{
|
||||
name: "reduceSamlRequestAdded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.AddedType,
|
||||
samlrequest.AggregateType,
|
||||
[]byte(`{"login_client": "loginClient", "issuer": "issuer", "acs_url": "acs", "relay_state": "relayState", "binding": "binding"}`),
|
||||
), eventstore.GenericEventMapper[samlrequest.AddedEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestAdded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.saml_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, issuer, acs, relay_state, binding) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
anyArg{},
|
||||
anyArg{},
|
||||
"ro-id",
|
||||
uint64(15),
|
||||
"loginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceSamlRequestFailed",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.FailedType,
|
||||
samlrequest.AggregateType,
|
||||
[]byte(`{"reason": 0}`),
|
||||
), eventstore.GenericEventMapper[samlrequest.FailedEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceSamlRequestSucceeded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.SucceededType,
|
||||
samlrequest.AggregateType,
|
||||
nil,
|
||||
), eventstore.GenericEventMapper[samlrequest.SucceededEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event := baseEvent(t)
|
||||
got, err := tt.reduce(event)
|
||||
if !zerrors.IsErrorInvalidArgument(err) {
|
||||
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
|
||||
}
|
||||
|
||||
event = tt.args.event(t)
|
||||
got, err = tt.reduce(event)
|
||||
assertReduce(t, got, err, SamlRequestsProjectionTable, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
81
internal/query/saml_request.go
Normal file
81
internal/query/saml_request.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SamlRequest struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
LoginClient string
|
||||
Issuer string
|
||||
ACS string
|
||||
RelayState string
|
||||
Binding string
|
||||
}
|
||||
|
||||
func (a *SamlRequest) checkLoginClient(ctx context.Context) error {
|
||||
if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient {
|
||||
return zerrors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//go:embed saml_request_by_id.sql
|
||||
var samlRequestByIDQuery string
|
||||
|
||||
func (q *Queries) samlRequestByIDQuery(ctx context.Context) string {
|
||||
return fmt.Sprintf(samlRequestByIDQuery, q.client.Timetravel(call.Took(ctx)))
|
||||
}
|
||||
|
||||
func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *SamlRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerSamlRequestProjection")
|
||||
ctx, err = projection.SamlRequestProjection.Trigger(ctx, handler.WithAwaitRunning())
|
||||
logging.OnError(err).Debug("trigger failed")
|
||||
traceSpan.EndWithError(err)
|
||||
}
|
||||
|
||||
dst := new(SamlRequest)
|
||||
err = q.client.QueryRowContext(
|
||||
ctx,
|
||||
func(row *sql.Row) error {
|
||||
return row.Scan(
|
||||
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.Issuer, &dst.ACS, &dst.RelayState, &dst.Binding,
|
||||
)
|
||||
},
|
||||
q.samlRequestByIDQuery(ctx),
|
||||
id, authz.GetInstance(ctx).InstanceID(),
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, zerrors.ThrowNotFound(err, "QUERY-Thee9", "Errors.SamlRequest.NotExisting")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal")
|
||||
}
|
||||
|
||||
if checkLoginClient {
|
||||
if err = dst.checkLoginClient(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
11
internal/query/saml_request_by_id.sql
Normal file
11
internal/query/saml_request_by_id.sql
Normal file
@@ -0,0 +1,11 @@
|
||||
select
|
||||
id,
|
||||
creation_date,
|
||||
login_client,
|
||||
issuer,
|
||||
acs,
|
||||
relay_state,
|
||||
binding
|
||||
from projections.saml_requests %s
|
||||
where id = $1 and instance_id = $2
|
||||
limit 1;
|
127
internal/query/saml_request_test.go
Normal file
127
internal/query/saml_request_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestQueries_SamlRequestByID(t *testing.T) {
|
||||
expQuery := regexp.QuoteMeta(fmt.Sprintf(
|
||||
samlRequestByIDQuery,
|
||||
asOfSystemTime,
|
||||
))
|
||||
|
||||
cols := []string{
|
||||
projection.SamlRequestColumnID,
|
||||
projection.SamlRequestColumnCreationDate,
|
||||
projection.SamlRequestColumnLoginClient,
|
||||
projection.SamlRequestColumnIssuer,
|
||||
projection.SamlRequestColumnACS,
|
||||
projection.SamlRequestColumnRelayState,
|
||||
projection.SamlRequestColumnBinding,
|
||||
}
|
||||
type args struct {
|
||||
shouldTriggerBulk bool
|
||||
id string
|
||||
checkLoginClient bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expect sqlExpectation
|
||||
want *SamlRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "success, all values",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"loginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
}, "123", "instanceID"),
|
||||
want: &SamlRequest{
|
||||
ID: "id",
|
||||
CreationDate: testNow,
|
||||
LoginClient: "loginClient",
|
||||
Issuer: "issuer",
|
||||
ACS: "acs",
|
||||
RelayState: "relayState",
|
||||
Binding: "binding",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.SamlRequest.NotExisting"),
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "wrong login client",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"wrongLoginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
}, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
execMock(t, tt.expect, func(db *sql.DB) {
|
||||
q := &Queries{
|
||||
client: &database.DB{
|
||||
DB: db,
|
||||
Database: &prepareDB{},
|
||||
},
|
||||
}
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
|
||||
got, err := q.SamlRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user