feat(saml): implementation of saml for ZITADEL v2 (#3618)

This commit is contained in:
Stefan Benz
2022-09-12 17:18:08 +01:00
committed by GitHub
parent 01a92ba5d9
commit 7a5f7f82cf
134 changed files with 5570 additions and 1293 deletions

View File

@@ -33,6 +33,7 @@ type App struct {
Name string
OIDCConfig *OIDCApp
SAMLConfig *SAMLApp
APIConfig *APIApp
}
@@ -56,6 +57,12 @@ type OIDCApp struct {
AllowedOrigins database.StringArray
}
type SAMLApp struct {
Metadata []byte
MetadataURL string
EntityID string
}
type APIApp struct {
ClientID string
AuthMethodType domain.APIAuthMethodType
@@ -116,6 +123,28 @@ var (
}
)
var (
appSAMLConfigsTable = table{
name: projection.AppSAMLTable,
}
AppSAMLConfigColumnAppID = Column{
name: projection.AppSAMLConfigColumnAppID,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnEntityID = Column{
name: projection.AppSAMLConfigColumnEntityID,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnMetadata = Column{
name: projection.AppSAMLConfigColumnMetadata,
table: appSAMLConfigsTable,
}
AppSAMLConfigColumnMetadataURL = Column{
name: projection.AppSAMLConfigColumnMetadataURL,
table: appSAMLConfigsTable,
}
)
var (
appAPIConfigsTable = table{
name: projection.AppAPITable,
@@ -225,6 +254,54 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo
return scan(row)
}
func (q *Queries) AppByID(ctx context.Context, appID string) (*App, error) {
stmt, scan := prepareAppQuery()
query, args, err := stmt.Where(
sq.Eq{
AppColumnID.identifier(): appID,
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement")
}
row := q.client.QueryRowContext(ctx, query, args...)
return scan(row)
}
func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string) (*App, error) {
stmt, scan := prepareAppQuery()
query, args, err := stmt.Where(
sq.Eq{
AppSAMLConfigColumnEntityID.identifier(): entityID,
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement")
}
row := q.client.QueryRowContext(ctx, query, args...)
return scan(row)
}
func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (*Project, error) {
stmt, scan := prepareProjectByAppQuery()
query, args, err := stmt.Where(
sq.Or{
sq.Eq{AppOIDCConfigColumnClientID.identifier(): appID},
sq.Eq{AppAPIConfigColumnClientID.identifier(): appID},
sq.Eq{AppSAMLConfigColumnAppID.identifier(): appID},
},
).ToSql()
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-XhJi3", "Errors.Query.SQLStatement")
}
row := q.client.QueryRowContext(ctx, query, args...)
return scan(row)
}
func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string) (string, error) {
stmt, scan := prepareProjectIDByAppQuery()
query, args, err := stmt.Where(
@@ -249,6 +326,7 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string) (stri
sq.Or{
sq.Eq{AppOIDCConfigColumnClientID.identifier(): appID},
sq.Eq{AppAPIConfigColumnClientID.identifier(): appID},
sq.Eq{AppSAMLConfigColumnAppID.identifier(): appID},
},
},
).ToSql()
@@ -389,15 +467,22 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(),
AppOIDCConfigColumnClockSkew.identifier(),
AppOIDCConfigColumnAdditionalOrigins.identifier(),
AppSAMLConfigColumnAppID.identifier(),
AppSAMLConfigColumnEntityID.identifier(),
AppSAMLConfigColumnMetadata.identifier(),
AppSAMLConfigColumnMetadataURL.identifier(),
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) {
app := new(App)
var (
apiConfig = sqlAPIConfig{}
oidcConfig = sqlOIDCConfig{}
samlConfig = sqlSAMLConfig{}
)
err := row.Scan(
@@ -430,6 +515,11 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
&oidcConfig.iDTokenUserinfoAssertion,
&oidcConfig.clockSkew,
&oidcConfig.additionalOrigins,
&samlConfig.appID,
&samlConfig.entityID,
&samlConfig.metadata,
&samlConfig.metadataURL,
)
if err != nil {
@@ -441,6 +531,7 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
apiConfig.set(app)
oidcConfig.set(app)
samlConfig.set(app)
return app, nil
}
@@ -452,6 +543,7 @@ func prepareProjectIDByAppQuery() (sq.SelectBuilder, func(*sql.Row) (projectID s
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (projectID string, err error) {
err = row.Scan(
&projectID,
@@ -485,6 +577,7 @@ func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, err
Join(join(AppColumnProjectID, ProjectColumnID)).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*Project, error) {
p := new(Project)
@@ -542,10 +635,16 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(),
AppOIDCConfigColumnClockSkew.identifier(),
AppOIDCConfigColumnAdditionalOrigins.identifier(),
AppSAMLConfigColumnAppID.identifier(),
AppSAMLConfigColumnEntityID.identifier(),
AppSAMLConfigColumnMetadata.identifier(),
AppSAMLConfigColumnMetadataURL.identifier(),
countColumn.identifier(),
).From(appsTable.identifier()).
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
PlaceholderFormat(sq.Dollar), func(row *sql.Rows) (*Apps, error) {
apps := &Apps{Apps: []*App{}}
@@ -554,6 +653,7 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
var (
apiConfig = sqlAPIConfig{}
oidcConfig = sqlOIDCConfig{}
samlConfig = sqlSAMLConfig{}
)
err := row.Scan(
@@ -586,6 +686,12 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
&oidcConfig.iDTokenUserinfoAssertion,
&oidcConfig.clockSkew,
&oidcConfig.additionalOrigins,
&samlConfig.appID,
&samlConfig.entityID,
&samlConfig.metadata,
&samlConfig.metadataURL,
&apps.Count,
)
@@ -595,6 +701,7 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
apiConfig.set(app)
oidcConfig.set(app)
samlConfig.set(app)
apps.Apps = append(apps.Apps, app)
}
@@ -681,6 +788,24 @@ func (c sqlOIDCConfig) set(app *App) {
logging.LogWithFields("app", app.ID).OnError(err).Warn("unable to set allowed origins")
}
type sqlSAMLConfig struct {
appID sql.NullString
entityID sql.NullString
metadataURL sql.NullString
metadata []byte
}
func (c sqlSAMLConfig) set(app *App) {
if !c.appID.Valid {
return
}
app.SAMLConfig = &SAMLApp{
MetadataURL: c.metadataURL.String,
Metadata: c.metadata,
EntityID: c.entityID.String,
}
}
type sqlAPIConfig struct {
appID sql.NullString
clientID sql.NullString