feat: saml application configuration for login version (#9351)

# Which Problems Are Solved

OIDC applications can configure the used login version, which is
currently not possible for SAML applications.

# How the Problems Are Solved

Add the same functionality dependent on the feature-flag for SAML
applications.

# Additional Changes

None

# Additional Context

Closes #9267
Follow up issue for frontend changes #9354

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Stefan Benz
2025-02-13 17:03:05 +01:00
committed by GitHub
parent 66296db971
commit 49de5c61b2
40 changed files with 1051 additions and 240 deletions

View File

@@ -98,7 +98,11 @@ func (s *Server) AddOIDCApp(ctx context.Context, req *mgmt_pb.AddOIDCAppRequest)
}, nil
}
func (s *Server) AddSAMLApp(ctx context.Context, req *mgmt_pb.AddSAMLAppRequest) (*mgmt_pb.AddSAMLAppResponse, error) {
app, err := s.command.AddSAMLApplication(ctx, AddSAMLAppRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
samlApp, err := AddSAMLAppRequestToDomain(req)
if err != nil {
return nil, err
}
app, err := s.command.AddSAMLApplication(ctx, samlApp, authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}
@@ -150,7 +154,11 @@ func (s *Server) UpdateOIDCAppConfig(ctx context.Context, req *mgmt_pb.UpdateOID
}
func (s *Server) UpdateSAMLAppConfig(ctx context.Context, req *mgmt_pb.UpdateSAMLAppConfigRequest) (*mgmt_pb.UpdateSAMLAppConfigResponse, error) {
config, err := s.command.ChangeSAMLApplication(ctx, UpdateSAMLAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
samlApp, err := UpdateSAMLAppConfigRequestToDomain(req)
if err != nil {
return nil, err
}
config, err := s.command.ChangeSAMLApplication(ctx, samlApp, authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}

View File

@@ -67,15 +67,21 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) (*domain.OIDCApp,
}, nil
}
func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) *domain.SAMLApp {
func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) (*domain.SAMLApp, error) {
loginVersion, loginBaseURI, err := app_grpc.LoginVersionToDomain(req.GetLoginVersion())
if err != nil {
return nil, err
}
return &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: req.ProjectId,
},
AppName: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
}
AppName: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
LoginVersion: loginVersion,
LoginBaseURI: loginBaseURI,
}, nil
}
func AddAPIAppRequestToDomain(app *mgmt_pb.AddAPIAppRequest) *domain.APIApp {
@@ -125,15 +131,21 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest)
}, nil
}
func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) *domain.SAMLApp {
func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) (*domain.SAMLApp, error) {
loginVersion, loginBaseURI, err := app_grpc.LoginVersionToDomain(app.GetLoginVersion())
if err != nil {
return nil, err
}
return &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: app.ProjectId,
},
AppID: app.AppId,
Metadata: app.GetMetadataXml(),
MetadataURL: app.GetMetadataUrl(),
}
AppID: app.AppId,
Metadata: app.GetMetadataXml(),
MetadataURL: app.GetMetadataUrl(),
LoginVersion: loginVersion,
LoginBaseURI: loginBaseURI,
}, nil
}
func UpdateAPIAppConfigRequestToDomain(app *mgmt_pb.UpdateAPIAppConfigRequest) *domain.APIApp {

View File

@@ -85,7 +85,8 @@ func loginVersionToPb(version domain.LoginVersion, baseURI *string) *app_pb.Logi
func AppSAMLConfigToPb(app *query.SAMLApp) app_pb.AppConfig {
return &app_pb.App_SamlConfig{
SamlConfig: &app_pb.SAMLConfig{
Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata},
Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata},
LoginVersion: loginVersionToPb(app.LoginVersion, app.LoginBaseURI),
},
}
}

View File

@@ -0,0 +1,53 @@
package saml
import (
"strings"
"github.com/zitadel/saml/pkg/provider/serviceprovider"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/query"
)
const (
LoginSamlRequestParam = "samlRequest"
LoginPath = "/login"
)
type ServiceProvider struct {
SP *query.SAMLServiceProvider
defaultLoginURL string
defaultLoginURLV2 string
}
func ServiceProviderFromBusiness(spQuery *query.SAMLServiceProvider, defaultLoginURL, defaultLoginURLV2 string) (*serviceprovider.ServiceProvider, error) {
sp := &ServiceProvider{
SP: spQuery,
defaultLoginURL: defaultLoginURL,
defaultLoginURLV2: defaultLoginURLV2,
}
return serviceprovider.NewServiceProvider(
spQuery.AppID,
&serviceprovider.Config{Metadata: spQuery.Metadata},
sp.LoginURL,
)
}
func (s *ServiceProvider) LoginURL(id string) string {
// if the authRequest does not have the v2 prefix, it was created for login V1
if !strings.HasPrefix(id, command.IDPrefixV2) {
return s.defaultLoginURL + id
}
// any v2 login without a specific base uri will be sent to the configured login v2 UI
// this way we're also backwards compatible
if s.SP.LoginBaseURI == nil || s.SP.LoginBaseURI.String() == "" {
return s.defaultLoginURLV2 + id
}
// for clients with a specific URI (internal or external) we only need to add the auth request id
uri := s.SP.LoginBaseURI.JoinPath(LoginPath)
q := uri.Query()
q.Set(LoginSamlRequestParam, id)
uri.RawQuery = q.Encode()
return uri.String()
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/zitadel/zitadel/internal/actions"
"github.com/zitadel/zitadel/internal/actions/object"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/auth/repository"
@@ -62,22 +63,12 @@ type Storage struct {
}
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
app, err := p.query.ActiveAppBySAMLEntityID(ctx, entityID)
sp, err := p.query.ActiveSAMLServiceProviderByID(ctx, entityID)
if err != nil {
return nil, err
}
return serviceprovider.NewServiceProvider(
app.ID,
&serviceprovider.Config{
Metadata: app.SAMLConfig.Metadata,
},
func(id string) string {
if strings.HasPrefix(id, command.IDPrefixV2) {
return p.defaultLoginURLv2 + id
}
return p.defaultLoginURL + id
},
)
return ServiceProviderFromBusiness(sp, p.defaultLoginURL, p.defaultLoginURLv2)
}
func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) {
@@ -108,11 +99,34 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// for backwards compatibility we pass the login client if set
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
loginClient := headers.Get(LoginClientHeader)
// for backwards compatibility we'll use the new login if the header is set (no matter the other configs)
if loginClient != "" {
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient)
}
return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
// if the instance requires the v2 login, use it no matter what the application configured
if authz.GetFeatures(ctx).LoginV2.Required {
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient)
}
version, err := p.query.SAMLAppLoginVersion(ctx, applicationID)
if err != nil {
return nil, err
}
switch version {
case domain.LoginVersion1:
return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
case domain.LoginVersion2:
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient)
case domain.LoginVersionUnspecified:
fallthrough
default:
// since we already checked for a login header, we can fall back to the v1 login
return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
}
}
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID, loginClient string) (_ models.AuthRequestInt, err error) {

View File

@@ -1264,10 +1264,10 @@ func TestCommandSide_RemoveOrg(t *testing.T) {
),
expectFilter(
eventFromEventPusher(
project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project1", "org1").Aggregate, "app1", "entity1", []byte{}, ""),
project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project1", "org1").Aggregate, "app1", "entity1", []byte{}, "", domain.LoginVersionUnspecified, ""),
),
eventFromEventPusher(
project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project2", "org1").Aggregate, "app2", "entity2", []byte{}, ""),
project.NewSAMLConfigAddedEvent(context.Background(), &project.NewAggregate("project2", "org1").Aggregate, "app2", "entity2", []byte{}, "", domain.LoginVersionUnspecified, ""),
),
),
expectPush(

View File

@@ -325,10 +325,10 @@ func (wm *OIDCApplicationWriteModel) NewChangedEvent(
changes = append(changes, project.ChangeBackChannelLogoutURI(backChannelLogoutURI))
}
if wm.LoginVersion != loginVersion {
changes = append(changes, project.ChangeLoginVersion(loginVersion))
changes = append(changes, project.ChangeOIDCLoginVersion(loginVersion))
}
if wm.LoginBaseURI != loginBaseURI {
changes = append(changes, project.ChangeLoginBaseURI(loginBaseURI))
changes = append(changes, project.ChangeOIDCLoginBaseURI(loginBaseURI))
}
if len(changes) == 0 {

View File

@@ -1297,8 +1297,8 @@ func newOIDCAppChangedEvent(ctx context.Context, appID, projectID, resourceOwner
project.ChangeIDTokenRoleAssertion(false),
project.ChangeIDTokenUserinfoAssertion(false),
project.ChangeClockSkew(time.Second * 2),
project.ChangeLoginVersion(domain.LoginVersion2),
project.ChangeLoginBaseURI("https://login.test.ch"),
project.ChangeOIDCLoginVersion(domain.LoginVersion2),
project.ChangeOIDCLoginBaseURI("https://login.test.ch"),
}
event, _ := project.NewOIDCConfigChangedEvent(ctx,
&project.NewAggregate(projectID, resourceOwner).Aggregate,

View File

@@ -79,6 +79,8 @@ func (c *Commands) addSAMLApplication(ctx context.Context, projectAgg *eventstor
string(entity.EntityID),
samlApp.Metadata,
samlApp.MetadataURL,
samlApp.LoginVersion,
samlApp.LoginBaseURI,
),
}, nil
}
@@ -119,7 +121,10 @@ func (c *Commands) ChangeSAMLApplication(ctx context.Context, samlApp *domain.SA
samlApp.AppID,
string(entity.EntityID),
samlApp.Metadata,
samlApp.MetadataURL)
samlApp.MetadataURL,
samlApp.LoginVersion,
samlApp.LoginBaseURI,
)
if err != nil {
return nil, err
}

View File

@@ -12,11 +12,13 @@ import (
type SAMLApplicationWriteModel struct {
eventstore.WriteModel
AppID string
AppName string
EntityID string
Metadata []byte
MetadataURL string
AppID string
AppName string
EntityID string
Metadata []byte
MetadataURL string
LoginVersion domain.LoginVersion
LoginBaseURI string
State domain.AppState
saml bool
@@ -121,6 +123,8 @@ func (wm *SAMLApplicationWriteModel) appendAddSAMLEvent(e *project.SAMLConfigAdd
wm.Metadata = e.Metadata
wm.MetadataURL = e.MetadataURL
wm.EntityID = e.EntityID
wm.LoginVersion = e.LoginVersion
wm.LoginBaseURI = e.LoginBaseURI
}
func (wm *SAMLApplicationWriteModel) appendChangeSAMLEvent(e *project.SAMLConfigChangedEvent) {
@@ -134,6 +138,12 @@ func (wm *SAMLApplicationWriteModel) appendChangeSAMLEvent(e *project.SAMLConfig
if e.EntityID != "" {
wm.EntityID = e.EntityID
}
if e.LoginVersion != nil {
wm.LoginVersion = *e.LoginVersion
}
if e.LoginBaseURI != nil {
wm.LoginBaseURI = *e.LoginBaseURI
}
}
func (wm *SAMLApplicationWriteModel) Query() *eventstore.SearchQueryBuilder {
@@ -161,6 +171,8 @@ func (wm *SAMLApplicationWriteModel) NewChangedEvent(
entityID string,
metadata []byte,
metadataURL string,
loginVersion domain.LoginVersion,
loginBaseURI string,
) (*project.SAMLConfigChangedEvent, bool, error) {
changes := make([]project.SAMLConfigChanges, 0)
var err error
@@ -173,6 +185,12 @@ func (wm *SAMLApplicationWriteModel) NewChangedEvent(
if wm.EntityID != entityID {
changes = append(changes, project.ChangeEntityID(entityID))
}
if wm.LoginVersion != loginVersion {
changes = append(changes, project.ChangeSAMLLoginVersion(loginVersion))
}
if wm.LoginBaseURI != loginBaseURI {
changes = append(changes, project.ChangeSAMLLoginBaseURI(loginBaseURI))
}
if len(changes) == 0 {
return nil, false, nil

View File

@@ -50,7 +50,7 @@ var testMetadataChangedEntityID = []byte(`<?xml version="1.0"?>
func TestCommandSide_AddSAMLApplication(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
eventstore func(t *testing.T) *eventstore.Eventstore
idGenerator id.Generator
httpClient *http.Client
}
@@ -72,9 +72,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "no aggregate id, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
),
eventstore: expectEventstore(),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
@@ -88,8 +86,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "project not existing, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(),
),
},
@@ -111,8 +108,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "invalid app, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
@@ -141,8 +137,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "create saml app, metadata not parsable",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
@@ -174,8 +169,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "create saml app, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
@@ -196,6 +190,8 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -229,11 +225,73 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
},
},
},
{
name: "create saml app, loginversion, ok",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
&project.NewAggregate("project1", "org1").Aggregate,
"project", true, true, true,
domain.PrivateLabelingSettingUnspecified),
),
),
expectPush(
project.NewApplicationAddedEvent(context.Background(),
&project.NewAggregate("project1", "org1").Aggregate,
"app1",
"app",
),
project.NewSAMLConfigAddedEvent(context.Background(),
&project.NewAggregate("project1", "org1").Aggregate,
"app1",
"https://test.com/saml/metadata",
testMetadata,
"",
domain.LoginVersion2,
"https://test.com/login",
),
),
),
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
},
AppName: "app",
EntityID: "https://test.com/saml/metadata",
Metadata: testMetadata,
MetadataURL: "",
LoginVersion: domain.LoginVersion2,
LoginBaseURI: "https://test.com/login",
},
resourceOwner: "org1",
},
res: res{
want: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
ResourceOwner: "org1",
},
AppID: "app1",
AppName: "app",
EntityID: "https://test.com/saml/metadata",
Metadata: testMetadata,
MetadataURL: "",
State: domain.AppStateActive,
LoginVersion: domain.LoginVersion2,
LoginBaseURI: "https://test.com/login",
},
},
},
{
name: "create saml app metadataURL, ok",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
@@ -254,6 +312,8 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"http://localhost:8080/saml/metadata",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -291,8 +351,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
{
name: "create saml app metadataURL, http error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewProjectAddedEvent(context.Background(),
@@ -327,7 +386,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore,
eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator,
httpClient: tt.fields.httpClient,
}
@@ -348,7 +407,7 @@ func TestCommandSide_AddSAMLApplication(t *testing.T) {
func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
type fields struct {
eventstore *eventstore.Eventstore
eventstore func(t *testing.T) *eventstore.Eventstore
httpClient *http.Client
}
type args struct {
@@ -369,9 +428,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "invalid app, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
),
eventstore: expectEventstore(),
},
args: args{
ctx: context.Background(),
@@ -390,9 +447,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "missing appid, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
),
eventstore: expectEventstore(),
},
args: args{
ctx: context.Background(),
@@ -412,9 +467,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "missing aggregateid, invalid argument error",
fields: fields{
eventstore: eventstoreExpect(
t,
),
eventstore: expectEventstore(),
},
args: args{
ctx: context.Background(),
@@ -434,8 +487,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "app not existing, not found error",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(),
),
},
@@ -457,8 +509,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "no changes, precondition error, metadataURL",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewApplicationAddedEvent(context.Background(),
@@ -474,6 +525,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"http://localhost:8080/saml/metadata",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -502,8 +555,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "no changes, precondition error, metadata",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewApplicationAddedEvent(context.Background(),
@@ -519,6 +571,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -547,8 +601,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "change saml app, ok, metadataURL",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewApplicationAddedEvent(context.Background(),
@@ -564,6 +617,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"http://localhost:8080/saml/metadata",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -613,8 +668,7 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
{
name: "change saml app, ok, metadata",
fields: fields{
eventstore: eventstoreExpect(
t,
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewApplicationAddedEvent(context.Background(),
@@ -630,6 +684,8 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
"https://test.com/saml/metadata",
testMetadata,
"",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -675,13 +731,85 @@ func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
State: domain.AppStateActive,
},
},
}, {
name: "change saml app, ok, loginversion",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
project.NewApplicationAddedEvent(context.Background(),
&project.NewAggregate("project1", "org1").Aggregate,
"app1",
"app",
),
),
eventFromEventPusher(
project.NewSAMLConfigAddedEvent(context.Background(),
&project.NewAggregate("project1", "org1").Aggregate,
"app1",
"https://test.com/saml/metadata",
testMetadata,
"",
domain.LoginVersionUnspecified,
"",
),
),
),
expectPush(
newSAMLAppChangedEventLoginVersion(context.Background(),
"app1",
"project1",
"org1",
"https://test.com/saml/metadata",
"https://test2.com/saml/metadata",
testMetadataChangedEntityID,
domain.LoginVersion2,
"https://test.com/login",
),
),
),
httpClient: nil,
},
args: args{
ctx: context.Background(),
samlApp: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
ResourceOwner: "org1",
},
AppID: "app1",
AppName: "app",
EntityID: "https://test2.com/saml/metadata",
Metadata: testMetadataChangedEntityID,
MetadataURL: "",
LoginVersion: domain.LoginVersion2,
LoginBaseURI: "https://test.com/login",
},
resourceOwner: "org1",
},
res: res{
want: &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: "project1",
ResourceOwner: "org1",
},
AppID: "app1",
AppName: "app",
EntityID: "https://test2.com/saml/metadata",
Metadata: testMetadataChangedEntityID,
MetadataURL: "",
State: domain.AppStateActive,
LoginVersion: domain.LoginVersion2,
LoginBaseURI: "https://test.com/login",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &Commands{
eventstore: tt.fields.eventstore,
eventstore: tt.fields.eventstore(t),
httpClient: tt.fields.httpClient,
}
got, err := r.ChangeSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
@@ -726,6 +854,22 @@ func newSAMLAppChangedEventMetadataURL(ctx context.Context, appID, projectID, re
return event
}
func newSAMLAppChangedEventLoginVersion(ctx context.Context, appID, projectID, resourceOwner, oldEntityID, entityID string, metadata []byte, loginVersion domain.LoginVersion, loginURI string) *project.SAMLConfigChangedEvent {
changes := []project.SAMLConfigChanges{
project.ChangeEntityID(entityID),
project.ChangeMetadata(metadata),
project.ChangeSAMLLoginVersion(loginVersion),
project.ChangeSAMLLoginBaseURI(loginURI),
}
event, _ := project.NewSAMLConfigChangedEvent(ctx,
&project.NewAggregate(projectID, resourceOwner).Aggregate,
appID,
oldEntityID,
changes,
)
return event
}
type roundTripperFunc func(*http.Request) *http.Response
// RoundTrip implements the http.RoundTripper interface.

View File

@@ -596,6 +596,8 @@ func TestCommandSide_RemoveApplication(t *testing.T) {
"https://test.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"",
domain.LoginVersionUnspecified,
"",
)),
),
expectPush(

View File

@@ -55,13 +55,15 @@ func oidcWriteModelToOIDCConfig(writeModel *OIDCApplicationWriteModel) *domain.O
func samlWriteModelToSAMLConfig(writeModel *SAMLApplicationWriteModel) *domain.SAMLApp {
return &domain.SAMLApp{
ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel),
AppID: writeModel.AppID,
AppName: writeModel.AppName,
State: writeModel.State,
Metadata: writeModel.Metadata,
MetadataURL: writeModel.MetadataURL,
EntityID: writeModel.EntityID,
ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel),
AppID: writeModel.AppID,
AppName: writeModel.AppName,
State: writeModel.State,
Metadata: writeModel.Metadata,
MetadataURL: writeModel.MetadataURL,
EntityID: writeModel.EntityID,
LoginVersion: writeModel.LoginVersion,
LoginBaseURI: writeModel.LoginBaseURI,
}
}

View File

@@ -988,6 +988,8 @@ func TestCommandSide_RemoveProject(t *testing.T) {
"https://test.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"http://localhost:8080/saml/metadata",
domain.LoginVersionUnspecified,
"",
),
),
),
@@ -1039,6 +1041,8 @@ func TestCommandSide_RemoveProject(t *testing.T) {
"https://test1.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"",
domain.LoginVersionUnspecified,
"",
),
),
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
@@ -1053,6 +1057,8 @@ func TestCommandSide_RemoveProject(t *testing.T) {
"https://test2.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"",
domain.LoginVersionUnspecified,
"",
),
),
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
@@ -1067,6 +1073,8 @@ func TestCommandSide_RemoveProject(t *testing.T) {
"https://test3.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"",
domain.LoginVersionUnspecified,
"",
),
),
),

View File

@@ -75,7 +75,9 @@ func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID,
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-ttPKNdAIFT", "Errors.SAMLRequest.AlreadyHandled")
}
if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient {
return nil, nil, zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient")
if err := c.checkPermission(ctx, domain.PermissionSessionLink, writeModel.ResourceOwner, ""); err != nil {
return nil, nil, err
}
}
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID())
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)

View File

@@ -132,8 +132,9 @@ func TestCommands_AddSAMLRequest(t *testing.T) {
func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore func(t *testing.T) *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
eventstore func(t *testing.T) *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
checkPermission domain.PermissionCheck
}
type args struct {
ctx context.Context
@@ -207,7 +208,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
},
},
{
"wrong login client",
"wrong login client / not permitted",
fields{
eventstore: expectEventstore(
expectFilter(
@@ -225,7 +226,8 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
tokenVerifier: newMockTokenVerifierValid(),
checkPermission: newMockPermissionCheckNotAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
@@ -235,7 +237,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
checkLoginClient: true,
},
res{
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-KCd48Rxt7x", "Errors.SAMLRequest.WrongLoginClient"),
wantErr: zerrors.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
},
},
{
@@ -524,6 +526,86 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
}, {
"linked with permission",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"destination",
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow),
),
eventFromEventPusherWithCreationDateNow(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
2*time.Minute),
),
),
expectPush(
samlrequest.NewSessionLinkedEvent(mockCtx, &samlrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
checkPermission: newMockPermissionCheckAllowed(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentSAMLRequest{
SAMLRequest: &SAMLRequest{
ID: "V2_id",
LoginClient: "loginClient",
ApplicationID: "application",
ACSURL: "acs",
RelayState: "relaystate",
RequestID: "request",
Binding: "binding",
Issuer: "issuer",
Destination: "destination",
},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
{
"linked with login client check, application permission check",
@@ -669,6 +751,7 @@ func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
sessionTokenVerifier: tt.fields.tokenVerifier,
checkPermission: tt.fields.checkPermission,
}
details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient, tt.args.checkPermission)
require.ErrorIs(t, err, tt.res.wantErr)

View File

@@ -7,11 +7,13 @@ import (
type SAMLApp struct {
models.ObjectRoot
AppID string
AppName string
EntityID string
Metadata []byte
MetadataURL string
AppID string
AppName string
EntityID string
Metadata []byte
MetadataURL string
LoginVersion LoginVersion
LoginBaseURI string
State AppState
}

View File

@@ -20,6 +20,7 @@ import (
http_util "github.com/zitadel/zitadel/internal/api/http"
oidc_internal "github.com/zitadel/zitadel/internal/api/oidc"
app_pb "github.com/zitadel/zitadel/pkg/grpc/app"
"github.com/zitadel/zitadel/pkg/grpc/management"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
session_pb "github.com/zitadel/zitadel/pkg/grpc/session/v2"
@@ -102,7 +103,7 @@ func CreateSAMLSP(root string, idpMetadata *saml.EntityDescriptor, binding strin
return sp, nil
}
func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *samlsp.Middleware) (*management.AddSAMLAppResponse, error) {
func (i *Instance) CreateSAMLClientLoginVersion(ctx context.Context, projectID string, m *samlsp.Middleware, loginVersion *app_pb.LoginVersion) (*management.AddSAMLAppResponse, error) {
spMetadata, err := xml.MarshalIndent(m.ServiceProvider.Metadata(), "", " ")
if err != nil {
return nil, err
@@ -114,9 +115,10 @@ func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *sa
}
resp, err := i.Client.Mgmt.AddSAMLApp(ctx, &management.AddSAMLAppRequest{
ProjectId: projectID,
Name: fmt.Sprintf("app-%s", gofakeit.AppName()),
Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: spMetadata},
ProjectId: projectID,
Name: fmt.Sprintf("app-%s", gofakeit.AppName()),
Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: spMetadata},
LoginVersion: loginVersion,
})
if err != nil {
return nil, err
@@ -136,7 +138,19 @@ func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *sa
})
}
func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient string, acs saml.Endpoint, relayState string, responseBinding string) (now time.Time, authRequestID string, err error) {
func (i *Instance) CreateSAMLClient(ctx context.Context, projectID string, m *samlsp.Middleware) (*management.AddSAMLAppResponse, error) {
return i.CreateSAMLClientLoginVersion(ctx, projectID, m, nil)
}
func (i *Instance) CreateSAMLAuthRequestWithoutLoginClientHeader(m *samlsp.Middleware, loginBaseURI string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) {
return i.createSAMLAuthRequest(m, "", loginBaseURI, acs, relayState, responseBinding)
}
func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) {
return i.createSAMLAuthRequest(m, loginClient, "", acs, relayState, responseBinding)
}
func (i *Instance) createSAMLAuthRequest(m *samlsp.Middleware, loginClient, loginBaseURI string, acs saml.Endpoint, relayState, responseBinding string) (now time.Time, authRequestID string, err error) {
authReq, err := m.ServiceProvider.MakeAuthenticationRequest(acs.Location, acs.Binding, responseBinding)
if err != nil {
return now, "", err
@@ -147,7 +161,11 @@ func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient strin
return now, "", err
}
req, err := GetRequest(redirectURL.String(), map[string]string{oidc_internal.LoginClientHeader: loginClient})
var headers map[string]string
if loginClient != "" {
headers = map[string]string{oidc_internal.LoginClientHeader: loginClient}
}
req, err := GetRequest(redirectURL.String(), headers)
if err != nil {
return now, "", fmt.Errorf("get request: %w", err)
}
@@ -158,11 +176,13 @@ func (i *Instance) CreateSAMLAuthRequest(m *samlsp.Middleware, loginClient strin
return now, "", fmt.Errorf("check redirect: %w", err)
}
prefixWithHost := i.Issuer() + i.Config.LoginURLV2
if !strings.HasPrefix(loc.String(), prefixWithHost) {
return now, "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String())
if loginBaseURI == "" {
loginBaseURI = i.Issuer() + i.Config.LoginURLV2
}
return now, strings.TrimPrefix(loc.String(), prefixWithHost), nil
if !strings.HasPrefix(loc.String(), loginBaseURI) {
return now, "", fmt.Errorf("login location has not prefix %s, but is %s", loginBaseURI, loc.String())
}
return now, strings.TrimPrefix(loc.String(), loginBaseURI), nil
}
func (i *Instance) FailSAMLAuthRequest(ctx context.Context, id string, reason saml_pb.ErrorReason) *saml_pb.CreateResponseResponse {

View File

@@ -66,9 +66,11 @@ type OIDCApp struct {
}
type SAMLApp struct {
Metadata []byte
MetadataURL string
EntityID string
Metadata []byte
MetadataURL string
EntityID string
LoginVersion domain.LoginVersion
LoginBaseURI *string
}
type APIApp struct {
@@ -137,6 +139,10 @@ var (
name: projection.AppSAMLTable,
instanceIDCol: projection.AppSAMLConfigColumnInstanceID,
}
AppSAMLConfigColumnInstanceID = Column{
name: projection.AppSAMLConfigColumnInstanceID,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnAppID = Column{
name: projection.AppSAMLConfigColumnAppID,
table: appSAMLConfigsTable,
@@ -153,6 +159,14 @@ var (
name: projection.AppSAMLConfigColumnMetadataURL,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnLoginVersion = Column{
name: projection.AppSAMLConfigColumnLoginVersion,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnLoginBaseURI = Column{
name: projection.AppSAMLConfigColumnLoginBaseURI,
table: appSAMLConfigsTable,
}
)
var (
@@ -320,30 +334,6 @@ func (q *Queries) AppByID(ctx context.Context, appID string, activeOnly bool) (a
return app, err
}
func (q *Queries) ActiveAppBySAMLEntityID(ctx context.Context, entityID string) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
stmt, scan := prepareSAMLAppQuery(ctx, q.client)
eq := sq.Eq{
AppSAMLConfigColumnEntityID.identifier(): entityID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
AppColumnState.identifier(): domain.AppStateActive,
ProjectColumnState.identifier(): domain.ProjectStateActive,
OrgColumnState.identifier(): domain.OrgStateActive,
}
query, args, err := stmt.Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
app, err = scan(row)
return err
}, query, args...)
return app, err
}
func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (project *Project, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -591,7 +581,7 @@ func (q *Queries) OIDCClientLoginVersion(ctx context.Context, clientID string) (
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginVersionByClientID(ctx, q.client)
query, scan := prepareLoginVersionByOIDCClientID(ctx, q.client)
eq := sq.Eq{
AppOIDCConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
AppOIDCConfigColumnClientID.identifier(): clientID,
@@ -611,6 +601,30 @@ func (q *Queries) OIDCClientLoginVersion(ctx context.Context, clientID string) (
return loginVersion, nil
}
func (q *Queries) SAMLAppLoginVersion(ctx context.Context, appID string) (loginVersion domain.LoginVersion, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareLoginVersionBySAMLAppID(ctx, q.client)
eq := sq.Eq{
AppSAMLConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
AppSAMLConfigColumnAppID.identifier(): appID,
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return domain.LoginVersionUnspecified, zerrors.ThrowInvalidArgument(err, "QUERY-TnaciwZfp3", "Errors.Query.InvalidRequest")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
loginVersion, err = scan(row)
return err
}, stmt, args...)
if err != nil {
return domain.LoginVersionUnspecified, zerrors.ThrowInternal(err, "QUERY-lvDDwRzIoP", "Errors.Internal")
}
return loginVersion, nil
}
func NewAppNameSearchQuery(method TextComparison, value string) (SearchQuery, error) {
return NewTextQuery(AppColumnName, value, method)
}
@@ -659,6 +673,8 @@ func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) (
AppSAMLConfigColumnEntityID.identifier(),
AppSAMLConfigColumnMetadata.identifier(),
AppSAMLConfigColumnMetadataURL.identifier(),
AppSAMLConfigColumnLoginVersion.identifier(),
AppSAMLConfigColumnLoginBaseURI.identifier(),
).From(appsTable.identifier()).
PlaceholderFormat(sq.Dollar)
@@ -726,6 +742,8 @@ func scanApp(row *sql.Row) (*App, error) {
&samlConfig.entityID,
&samlConfig.metadata,
&samlConfig.metadataURL,
&samlConfig.loginVersion,
&samlConfig.loginBaseURI,
)
if err != nil {
@@ -827,61 +845,6 @@ func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
}
}
func prepareSAMLAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
return sq.Select(
AppColumnID.identifier(),
AppColumnName.identifier(),
AppColumnProjectID.identifier(),
AppColumnCreationDate.identifier(),
AppColumnChangeDate.identifier(),
AppColumnResourceOwner.identifier(),
AppColumnState.identifier(),
AppColumnSequence.identifier(),
AppSAMLConfigColumnAppID.identifier(),
AppSAMLConfigColumnEntityID.identifier(),
AppSAMLConfigColumnMetadata.identifier(),
AppSAMLConfigColumnMetadataURL.identifier(),
).From(appsTable.identifier()).
Join(join(AppSAMLConfigColumnAppID, AppColumnID)).
Join(join(ProjectColumnID, AppColumnProjectID)).
Join(join(OrgColumnID, AppColumnResourceOwner)).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) {
app := new(App)
var (
samlConfig = sqlSAMLConfig{}
)
err := row.Scan(
&app.ID,
&app.Name,
&app.ProjectID,
&app.CreationDate,
&app.ChangeDate,
&app.ResourceOwner,
&app.State,
&app.Sequence,
&samlConfig.appID,
&samlConfig.entityID,
&samlConfig.metadata,
&samlConfig.metadataURL,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-d6TO1", "Errors.App.NotExisting")
}
return nil, zerrors.ThrowInternal(err, "QUERY-NAtPg", "Errors.Internal")
}
samlConfig.set(app)
return app, nil
}
}
func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) {
return sq.Select(
AppColumnProjectID.identifier(),
@@ -1031,6 +994,8 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder
AppSAMLConfigColumnEntityID.identifier(),
AppSAMLConfigColumnMetadata.identifier(),
AppSAMLConfigColumnMetadataURL.identifier(),
AppSAMLConfigColumnLoginVersion.identifier(),
AppSAMLConfigColumnLoginBaseURI.identifier(),
countColumn.identifier(),
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
@@ -1086,6 +1051,8 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder
&samlConfig.entityID,
&samlConfig.metadata,
&samlConfig.metadataURL,
&samlConfig.loginVersion,
&samlConfig.loginBaseURI,
&apps.Count,
)
@@ -1135,7 +1102,7 @@ func prepareClientIDsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu
}
}
func prepareLoginVersionByClientID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) {
func prepareLoginVersionByOIDCClientID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) {
return sq.Select(
AppOIDCConfigColumnLoginVersion.identifier(),
).From(appOIDCConfigsTable.identifier()).
@@ -1150,6 +1117,21 @@ func prepareLoginVersionByClientID(ctx context.Context, db prepareDatabase) (sq.
}
}
func prepareLoginVersionBySAMLAppID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) {
return sq.Select(
AppSAMLConfigColumnLoginVersion.identifier(),
).From(appSAMLConfigsTable.identifier()).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (domain.LoginVersion, error) {
var loginVersion sql.NullInt16
if err := row.Scan(
&loginVersion,
); err != nil {
return domain.LoginVersionUnspecified, zerrors.ThrowInternal(err, "QUERY-KbzaCnaziI", "Errors.Internal")
}
return domain.LoginVersion(loginVersion.Int16), nil
}
}
type sqlOIDCConfig struct {
appID sql.NullString
version sql.NullInt32
@@ -1209,10 +1191,12 @@ func (c sqlOIDCConfig) set(app *App) {
}
type sqlSAMLConfig struct {
appID sql.NullString
entityID sql.NullString
metadataURL sql.NullString
metadata []byte
appID sql.NullString
entityID sql.NullString
metadataURL sql.NullString
metadata []byte
loginVersion sql.NullInt16
loginBaseURI sql.NullString
}
func (c sqlSAMLConfig) set(app *App) {
@@ -1220,9 +1204,13 @@ func (c sqlSAMLConfig) set(app *App) {
return
}
app.SAMLConfig = &SAMLApp{
MetadataURL: c.metadataURL.String,
Metadata: c.metadata,
EntityID: c.entityID.String,
EntityID: c.entityID.String,
MetadataURL: c.metadataURL.String,
Metadata: c.metadata,
LoginVersion: domain.LoginVersion(c.loginVersion.Int16),
}
if c.loginBaseURI.Valid {
app.SAMLConfig.LoginBaseURI = &c.loginBaseURI.String
}
}

View File

@@ -56,7 +56,9 @@ var (
` projections.apps7_saml_configs.app_id,` +
` projections.apps7_saml_configs.entity_id,` +
` projections.apps7_saml_configs.metadata,` +
` projections.apps7_saml_configs.metadata_url` +
` projections.apps7_saml_configs.metadata_url,` +
` projections.apps7_saml_configs.login_version,` +
` projections.apps7_saml_configs.login_base_uri` +
` FROM projections.apps7` +
` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` +
` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` +
@@ -103,6 +105,8 @@ var (
` projections.apps7_saml_configs.entity_id,` +
` projections.apps7_saml_configs.metadata,` +
` projections.apps7_saml_configs.metadata_url,` +
` projections.apps7_saml_configs.login_version,` +
` projections.apps7_saml_configs.login_base_uri,` +
` COUNT(*) OVER ()` +
` FROM projections.apps7` +
` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` +
@@ -178,6 +182,8 @@ var (
"entity_id",
"metadata",
"metadata_url",
"login_version",
"login_base_uri",
}
appsCols = append(appCols, "count")
)
@@ -252,6 +258,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -321,6 +329,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -393,6 +403,8 @@ func Test_AppsPrepare(t *testing.T) {
"https://test.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"https://test.com/saml/metadata",
domain.LoginVersionUnspecified,
nil,
},
},
),
@@ -467,6 +479,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -559,6 +573,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -651,6 +667,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -743,6 +761,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -835,6 +855,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -927,6 +949,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1019,6 +1043,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
{
"api-app-id",
@@ -1059,6 +1085,8 @@ func Test_AppsPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
{
"saml-app-id",
@@ -1099,6 +1127,8 @@ func Test_AppsPrepare(t *testing.T) {
"https://test.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"https://test.com/saml/metadata",
domain.LoginVersion2,
"https://login.ch/",
},
},
),
@@ -1165,9 +1195,11 @@ func Test_AppsPrepare(t *testing.T) {
Name: "app-name",
ProjectID: "project-id",
SAMLConfig: &SAMLApp{
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
MetadataURL: "https://test.com/saml/metadata",
EntityID: "https://test.com/saml/metadata",
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
MetadataURL: "https://test.com/saml/metadata",
EntityID: "https://test.com/saml/metadata",
LoginVersion: domain.LoginVersion2,
LoginBaseURI: gu.Ptr("https://login.ch/"),
},
},
},
@@ -1280,6 +1312,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
),
},
@@ -1343,6 +1377,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1411,6 +1447,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1498,6 +1536,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1585,6 +1625,8 @@ func Test_AppPrepare(t *testing.T) {
"https://test.com/saml/metadata",
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
"https://test.com/saml/metadata",
domain.LoginVersionUnspecified,
nil,
},
},
),
@@ -1599,9 +1641,11 @@ func Test_AppPrepare(t *testing.T) {
Name: "app-name",
ProjectID: "project-id",
SAMLConfig: &SAMLApp{
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
MetadataURL: "https://test.com/saml/metadata",
EntityID: "https://test.com/saml/metadata",
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
MetadataURL: "https://test.com/saml/metadata",
EntityID: "https://test.com/saml/metadata",
LoginVersion: domain.LoginVersionUnspecified,
LoginBaseURI: nil,
},
},
},
@@ -1654,6 +1698,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1741,6 +1787,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1828,6 +1876,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),
@@ -1915,6 +1965,8 @@ func Test_AppPrepare(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
},
},
),

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
_ "embed"
"net/url"
"regexp"
"testing"
@@ -19,6 +20,8 @@ import (
var (
//go:embed testdata/oidc_client_jwt.json
testdataOidcClientJWT string
//go:embed testdata/oidc_client_jwt_loginversion.json
testdataOidcClientJWTLoginVersion string
//go:embed testdata/oidc_client_public.json
testdataOidcClientPublic string
//go:embed testdata/oidc_client_public_old_id.json
@@ -91,6 +94,44 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx
},
},
},
{
name: "jwt client, login version",
mock: mockQuery(expQuery, cols, []driver.Value{testdataOidcClientJWTLoginVersion}, "instanceID", "clientID", true),
want: &OIDCClient{
InstanceID: "230690539048009730",
AppID: "236647088211886082",
State: domain.AppStateActive,
ClientID: "236647088211951618",
HashedSecret: "",
RedirectURIs: []string{"http://localhost:9999/auth/callback"},
ResponseTypes: []domain.OIDCResponseType{domain.OIDCResponseTypeCode},
GrantTypes: []domain.OIDCGrantType{domain.OIDCGrantTypeAuthorizationCode, domain.OIDCGrantTypeRefreshToken},
ApplicationType: domain.OIDCApplicationTypeWeb,
AuthMethodType: domain.OIDCAuthMethodTypePrivateKeyJWT,
PostLogoutRedirectURIs: []string{"https://example.com/logout"},
IsDevMode: true,
AccessTokenType: domain.OIDCTokenTypeJWT,
AccessTokenRoleAssertion: true,
IDTokenRoleAssertion: true,
IDTokenUserinfoAssertion: true,
ClockSkew: 1000000000,
AdditionalOrigins: []string{"https://example.com"},
ProjectID: "236645808328409090",
ProjectRoleAssertion: true,
PublicKeys: map[string][]byte{"236647201860747266": []byte(pubkey)},
ProjectRoleKeys: []string{"role1", "role2"},
Settings: &OIDCSettings{
AccessTokenLifetime: 43200000000000,
IdTokenLifetime: 43200000000000,
},
LoginVersion: domain.LoginVersion1,
LoginBaseURI: func() *URL {
ret, _ := url.Parse("https://test.com/login")
retURL := URL(*ret)
return &retURL
}(),
},
},
{
name: "public client",
mock: mockQuery(expQuery, cols, []driver.Value{testdataOidcClientPublic}, "instanceID", "clientID", true),

View File

@@ -62,12 +62,14 @@ const (
AppOIDCConfigColumnLoginVersion = "login_version"
AppOIDCConfigColumnLoginBaseURI = "login_base_uri"
appSAMLTableSuffix = "saml_configs"
AppSAMLConfigColumnAppID = "app_id"
AppSAMLConfigColumnInstanceID = "instance_id"
AppSAMLConfigColumnEntityID = "entity_id"
AppSAMLConfigColumnMetadata = "metadata"
AppSAMLConfigColumnMetadataURL = "metadata_url"
appSAMLTableSuffix = "saml_configs"
AppSAMLConfigColumnAppID = "app_id"
AppSAMLConfigColumnInstanceID = "instance_id"
AppSAMLConfigColumnEntityID = "entity_id"
AppSAMLConfigColumnMetadata = "metadata"
AppSAMLConfigColumnMetadataURL = "metadata_url"
AppSAMLConfigColumnLoginVersion = "login_version"
AppSAMLConfigColumnLoginBaseURI = "login_base_uri"
)
type appProjection struct{}
@@ -143,6 +145,8 @@ func (*appProjection) Init() *old_handler.Check {
handler.NewColumn(AppSAMLConfigColumnEntityID, handler.ColumnTypeText),
handler.NewColumn(AppSAMLConfigColumnMetadata, handler.ColumnTypeBytes),
handler.NewColumn(AppSAMLConfigColumnMetadataURL, handler.ColumnTypeText),
handler.NewColumn(AppSAMLConfigColumnLoginVersion, handler.ColumnTypeEnum, handler.Nullable()),
handler.NewColumn(AppSAMLConfigColumnLoginBaseURI, handler.ColumnTypeText, handler.Nullable()),
},
handler.NewPrimaryKey(AppSAMLConfigColumnInstanceID, AppSAMLConfigColumnAppID),
appSAMLTableSuffix,
@@ -703,6 +707,8 @@ func (p *appProjection) reduceSAMLConfigAdded(event eventstore.Event) (*handler.
handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID),
handler.NewCol(AppSAMLConfigColumnMetadata, e.Metadata),
handler.NewCol(AppSAMLConfigColumnMetadataURL, e.MetadataURL),
handler.NewCol(AppSAMLConfigColumnLoginVersion, e.LoginVersion),
handler.NewCol(AppSAMLConfigColumnLoginBaseURI, e.LoginBaseURI),
},
handler.WithTableSuffix(appSAMLTableSuffix),
),
@@ -735,6 +741,12 @@ func (p *appProjection) reduceSAMLConfigChanged(event eventstore.Event) (*handle
if e.EntityID != "" {
cols = append(cols, handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID))
}
if e.LoginVersion != nil {
cols = append(cols, handler.NewCol(AppSAMLConfigColumnLoginVersion, *e.LoginVersion))
}
if e.LoginBaseURI != nil {
cols = append(cols, handler.NewCol(AppSAMLConfigColumnLoginBaseURI, *e.LoginBaseURI))
}
if len(cols) == 0 {
return handler.NewNoOpStatement(e), nil

104
internal/query/saml_sp.go Normal file
View File

@@ -0,0 +1,104 @@
package query
import (
"context"
"database/sql"
_ "embed"
"errors"
"net/url"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
type SAMLServiceProvider struct {
InstanceID string `json:"instance_id,omitempty"`
AppID string `json:"app_id,omitempty"`
State domain.AppState `json:"state,omitempty"`
EntityID string `json:"entity_id,omitempty"`
Metadata []byte `json:"metadata,omitempty"`
MetadataURL string `json:"metadata_url,omitempty"`
ProjectID string `json:"project_id,omitempty"`
ProjectRoleAssertion bool `json:"project_role_assertion,omitempty"`
LoginVersion domain.LoginVersion `json:"login_version,omitempty"`
LoginBaseURI *url.URL `json:"login_base_uri,omitempty"`
}
//go:embed saml_sp_by_id.sql
var samlSPQuery string
func (q *Queries) ActiveSAMLServiceProviderByID(ctx context.Context, entityID string) (sp *SAMLServiceProvider, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
sp, err = scanSAMLServiceProviderByID(row)
return err
}, samlSPQuery,
authz.GetInstance(ctx).InstanceID(),
entityID,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-HeOcis2511", "Errors.App.NotFound")
}
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-OyJx1Rp30z", "Errors.Internal")
}
instance := authz.GetInstance(ctx)
loginV2 := instance.Features().LoginV2
if loginV2.Required {
sp.LoginVersion = domain.LoginVersion2
sp.LoginBaseURI = loginV2.BaseURI
}
return sp, err
}
func scanSAMLServiceProviderByID(row *sql.Row) (*SAMLServiceProvider, error) {
var instanceID, appID, entityID, metadataURL, projectID sql.NullString
var projectRoleAssertion sql.NullBool
var metadata []byte
var state, loginVersion sql.NullInt16
var loginBaseURI sql.NullString
err := row.Scan(
&instanceID,
&appID,
&state,
&entityID,
&metadata,
&metadataURL,
&projectID,
&projectRoleAssertion,
&loginVersion,
&loginBaseURI,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-8cjj8ao6yY", "Errors.App.NotFound")
}
return nil, zerrors.ThrowInternal(err, "QUERY-1xzFD209Bp", "Errors.Internal")
}
sp := &SAMLServiceProvider{
InstanceID: instanceID.String,
AppID: appID.String,
State: domain.AppState(state.Int16),
EntityID: entityID.String,
Metadata: metadata,
MetadataURL: metadataURL.String,
ProjectID: projectID.String,
ProjectRoleAssertion: projectRoleAssertion.Bool,
}
if loginVersion.Valid {
sp.LoginVersion = domain.LoginVersion(loginVersion.Int16)
}
if loginBaseURI.Valid && loginBaseURI.String != "" {
url, err := url.Parse(loginBaseURI.String)
if err != nil {
return nil, err
}
sp.LoginBaseURI = url
}
return sp, nil
}

View File

@@ -0,0 +1,19 @@
select c.instance_id,
c.app_id,
a.state,
c.entity_id,
c.metadata,
c.metadata_url,
a.project_id,
p.project_role_assertion,
c.login_version,
c.login_base_uri
from projections.apps7_saml_configs c
join projections.apps7 a
on a.id = c.app_id and a.instance_id = c.instance_id and a.state = 1
join projections.projects4 p
on p.id = a.project_id and p.instance_id = a.instance_id and p.state = 1
join projections.orgs1 o
on o.id = p.resource_owner and o.instance_id = c.instance_id and o.org_state = 1
where c.instance_id = $1
and c.entity_id = $2

View File

@@ -0,0 +1,123 @@
package query
import (
"database/sql"
"database/sql/driver"
_ "embed"
"net/url"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestQueries_ActiveSAMLServiceProviderByID(t *testing.T) {
expQuery := regexp.QuoteMeta(samlSPQuery)
cols := []string{
"instance_id",
"app_id",
"state",
"entity_id",
"metadata",
"metadata_url",
"project_id",
"project_role_assertion",
"login_version",
"login_base_uri",
}
tests := []struct {
name string
mock sqlExpectation
want *SAMLServiceProvider
wantErr error
}{
{
name: "no rows",
mock: mockQueryErr(expQuery, sql.ErrNoRows, "instanceID", "entityID"),
wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-HeOcis2511", "Errors.App.NotFound"),
},
{
name: "internal error",
mock: mockQueryErr(expQuery, sql.ErrConnDone, "instanceID", "entityID"),
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-OyJx1Rp30z", "Errors.Internal"),
},
{
name: "sp",
mock: mockQuery(expQuery, cols, []driver.Value{
"230690539048009730",
"236647088211886082",
domain.AppStateActive,
"https://test.com/metadata",
"metadata",
"https://test.com/metadata",
"236645808328409090",
true,
domain.LoginVersionUnspecified,
"",
}, "instanceID", "entityID"),
want: &SAMLServiceProvider{
InstanceID: "230690539048009730",
AppID: "236647088211886082",
State: domain.AppStateActive,
EntityID: "https://test.com/metadata",
Metadata: []byte("metadata"),
MetadataURL: "https://test.com/metadata",
ProjectID: "236645808328409090",
ProjectRoleAssertion: true,
},
},
{
name: "sp with loginversion",
mock: mockQuery(expQuery, cols, []driver.Value{
"230690539048009730",
"236647088211886082",
domain.AppStateActive,
"https://test.com/metadata",
"metadata",
"https://test.com/metadata",
"236645808328409090",
true,
domain.LoginVersion2,
"https://test.com/login",
}, "instanceID", "entityID"),
want: &SAMLServiceProvider{
InstanceID: "230690539048009730",
AppID: "236647088211886082",
State: domain.AppStateActive,
EntityID: "https://test.com/metadata",
Metadata: []byte("metadata"),
MetadataURL: "https://test.com/metadata",
ProjectID: "236645808328409090",
ProjectRoleAssertion: true,
LoginVersion: domain.LoginVersion2,
LoginBaseURI: func() *url.URL {
ret, _ := url.Parse("https://test.com/login")
return ret
}(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
execMock(t, tt.mock, func(db *sql.DB) {
q := &Queries{
client: &database.DB{
DB: db,
Database: &prepareDB{},
},
}
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
got, err := q.ActiveSAMLServiceProviderByID(ctx, "entityID")
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
})
}
}

View File

@@ -0,0 +1,32 @@
{
"instance_id": "230690539048009730",
"app_id": "236647088211886082",
"state": 1,
"client_id": "236647088211951618",
"client_secret": null,
"redirect_uris": ["http://localhost:9999/auth/callback"],
"response_types": [0],
"grant_types": [0, 2],
"application_type": 0,
"auth_method_type": 3,
"post_logout_redirect_uris": ["https://example.com/logout"],
"is_dev_mode": true,
"access_token_type": 1,
"access_token_role_assertion": true,
"id_token_role_assertion": true,
"id_token_userinfo_assertion": true,
"clock_skew": 1000000000,
"additional_origins": ["https://example.com"],
"project_id": "236645808328409090",
"project_role_assertion": true,
"project_role_keys": ["role1", "role2"],
"public_keys": {
"236647201860747266": "LS0tLS1CRUdJTiBSU0EgUFVCTElDIEtFWS0tLS0tCk1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFB\nT0NBUThBTUlJQkNnS0NBUUVBMnVmQUwxYjcyYkl5MWFyK1dzNmIKR29oSkpRRkI3ZGZSYXBEcWVx\nTThVa3A2Q1ZkUHpxL3BPejF2aUFxNTB5eldaSnJ5Risyd3NoRkFLR0Y5QTIvQgoyWWY5YkpYUFov\nS2JrRnJZVDNOVHZZRGt2bGFTVGw5bU1uenJVMjlzNDhGMVBUV0tmQitDM2FNc09FRzFCdWZWCnM2\nM3FGNG5yRVBqU2JobGpJY285RlpxNFhwcEl6aE1RMGZEZEEvK1h5Z0NKcXZ1YUwwTGliTTFLcmxV\nZG51NzEKWWVraFNKakVQbnZPaXNYSWs0SVh5d29HSU93dGp4a0R2Tkl0UXZhTVZsZHI0L2tiNnV2\nYmdkV3dxNUV3QlpYcQpsb3cya3lKb3YzOFY0VWsySThrdVhwTGNucnB3NVRpbzJvb2lVRTI3YjB2\nSFpxQktPZWk5VW84OHFDcm4zRUt4CjZRSURBUUFCCi0tLS0tRU5EIFJTQSBQVUJMSUMgS0VZLS0t\nLS0K"
},
"settings": {
"access_token_lifetime": 43200000000000,
"id_token_lifetime": 43200000000000
},
"login_version": 1,
"login_base_uri": "https://test.com/login"
}

View File

@@ -384,13 +384,13 @@ func ChangeBackChannelLogoutURI(backChannelLogoutURI string) func(event *OIDCCon
}
}
func ChangeLoginVersion(loginVersion domain.LoginVersion) func(event *OIDCConfigChangedEvent) {
func ChangeOIDCLoginVersion(loginVersion domain.LoginVersion) func(event *OIDCConfigChangedEvent) {
return func(e *OIDCConfigChangedEvent) {
e.LoginVersion = &loginVersion
}
}
func ChangeLoginBaseURI(loginBaseURI string) func(event *OIDCConfigChangedEvent) {
func ChangeOIDCLoginBaseURI(loginBaseURI string) func(event *OIDCConfigChangedEvent) {
return func(e *OIDCConfigChangedEvent) {
e.LoginBaseURI = &loginBaseURI
}

View File

@@ -3,6 +3,7 @@ package project
import (
"context"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -16,10 +17,12 @@ const (
type SAMLConfigAddedEvent struct {
eventstore.BaseEvent `json:"-"`
AppID string `json:"appId"`
EntityID string `json:"entityId"`
Metadata []byte `json:"metadata,omitempty"`
MetadataURL string `json:"metadata_url,omitempty"`
AppID string `json:"appId"`
EntityID string `json:"entityId"`
Metadata []byte `json:"metadata,omitempty"`
MetadataURL string `json:"metadata_url,omitempty"`
LoginVersion domain.LoginVersion `json:"loginVersion,omitempty"`
LoginBaseURI string `json:"loginBaseURI,omitempty"`
}
func (e *SAMLConfigAddedEvent) Payload() interface{} {
@@ -50,6 +53,8 @@ func NewSAMLConfigAddedEvent(
entityID string,
metadata []byte,
metadataURL string,
loginVersion domain.LoginVersion,
loginBaseURI string,
) *SAMLConfigAddedEvent {
return &SAMLConfigAddedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
@@ -57,10 +62,12 @@ func NewSAMLConfigAddedEvent(
aggregate,
SAMLConfigAddedType,
),
AppID: appID,
EntityID: entityID,
Metadata: metadata,
MetadataURL: metadataURL,
AppID: appID,
EntityID: entityID,
Metadata: metadata,
MetadataURL: metadataURL,
LoginVersion: loginVersion,
LoginBaseURI: loginBaseURI,
}
}
@@ -80,11 +87,13 @@ func SAMLConfigAddedEventMapper(event eventstore.Event) (eventstore.Event, error
type SAMLConfigChangedEvent struct {
eventstore.BaseEvent `json:"-"`
AppID string `json:"appId"`
EntityID string `json:"entityId"`
Metadata []byte `json:"metadata,omitempty"`
MetadataURL *string `json:"metadata_url,omitempty"`
oldEntityID string
AppID string `json:"appId"`
EntityID string `json:"entityId"`
Metadata []byte `json:"metadata,omitempty"`
MetadataURL *string `json:"metadata_url,omitempty"`
LoginVersion *domain.LoginVersion `json:"loginVersion,omitempty"`
LoginBaseURI *string `json:"loginBaseURI,omitempty"`
oldEntityID string
}
func (e *SAMLConfigChangedEvent) Payload() interface{} {
@@ -147,6 +156,17 @@ func ChangeEntityID(entityID string) func(event *SAMLConfigChangedEvent) {
}
}
func ChangeSAMLLoginVersion(loginVersion domain.LoginVersion) func(event *SAMLConfigChangedEvent) {
return func(e *SAMLConfigChangedEvent) {
e.LoginVersion = &loginVersion
}
}
func ChangeSAMLLoginBaseURI(loginBaseURI string) func(event *SAMLConfigChangedEvent) {
return func(e *SAMLConfigChangedEvent) {
e.LoginBaseURI = &loginBaseURI
}
}
func SAMLConfigChangedEventMapper(event eventstore.Event) (eventstore.Event, error) {
e := &SAMLConfigChangedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(event),