mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-14 02:17:34 +00:00
feat(saml): implementation of saml for ZITADEL v2 (#3618)
This commit is contained in:
@@ -67,6 +67,16 @@ func (s *Server) AddOIDCApp(ctx context.Context, req *mgmt_pb.AddOIDCAppRequest)
|
||||
ComplianceProblems: project_grpc.ComplianceProblemsToLocalizedMessages(app.Compliance.Problems),
|
||||
}, nil
|
||||
}
|
||||
func (s *Server) AddSAMLApp(ctx context.Context, req *mgmt_pb.AddSAMLAppRequest) (*mgmt_pb.AddSAMLAppResponse, error) {
|
||||
app, err := s.command.AddSAMLApplication(ctx, AddSAMLAppRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mgmt_pb.AddSAMLAppResponse{
|
||||
AppId: app.AppID,
|
||||
Details: object_grpc.AddToDetailsPb(app.Sequence, app.ChangeDate, app.ResourceOwner),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) AddAPIApp(ctx context.Context, req *mgmt_pb.AddAPIAppRequest) (*mgmt_pb.AddAPIAppResponse, error) {
|
||||
appSecretGenerator, err := s.query.InitHashGenerator(ctx, domain.SecretGeneratorTypeAppSecret, s.passwordHashAlg)
|
||||
@@ -109,6 +119,20 @@ func (s *Server) UpdateOIDCAppConfig(ctx context.Context, req *mgmt_pb.UpdateOID
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpdateSAMLAppConfig(ctx context.Context, req *mgmt_pb.UpdateSAMLAppConfigRequest) (*mgmt_pb.UpdateSAMLAppConfigResponse, error) {
|
||||
config, err := s.command.ChangeSAMLApplication(ctx, UpdateSAMLAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mgmt_pb.UpdateSAMLAppConfigResponse{
|
||||
Details: object_grpc.ChangeToDetailsPb(
|
||||
config.Sequence,
|
||||
config.ChangeDate,
|
||||
config.ResourceOwner,
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpdateAPIAppConfig(ctx context.Context, req *mgmt_pb.UpdateAPIAppConfigRequest) (*mgmt_pb.UpdateAPIAppConfigResponse, error) {
|
||||
config, err := s.command.ChangeAPIApplication(ctx, UpdateAPIAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
|
@@ -59,6 +59,17 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) *domain.OIDCApp {
|
||||
}
|
||||
}
|
||||
|
||||
func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) *domain.SAMLApp {
|
||||
return &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: req.ProjectId,
|
||||
},
|
||||
AppName: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
}
|
||||
}
|
||||
|
||||
func AddAPIAppRequestToDomain(app *mgmt_pb.AddAPIAppRequest) *domain.APIApp {
|
||||
return &domain.APIApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
@@ -98,6 +109,17 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) *domain.SAMLApp {
|
||||
return &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: app.ProjectId,
|
||||
},
|
||||
AppID: app.AppId,
|
||||
Metadata: app.GetMetadataXml(),
|
||||
MetadataURL: app.GetMetadataUrl(),
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateAPIAppConfigRequestToDomain(app *mgmt_pb.UpdateAPIAppConfigRequest) *domain.APIApp {
|
||||
return &domain.APIApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
|
@@ -33,6 +33,9 @@ func AppConfigToPb(app *query.App) app_pb.AppConfig {
|
||||
if app.OIDCConfig != nil {
|
||||
return AppOIDCConfigToPb(app.OIDCConfig)
|
||||
}
|
||||
if app.SAMLConfig != nil {
|
||||
return AppSAMLConfigToPb(app.SAMLConfig)
|
||||
}
|
||||
return AppAPIConfigToPb(app.APIConfig)
|
||||
}
|
||||
|
||||
@@ -61,6 +64,14 @@ func AppOIDCConfigToPb(app *query.OIDCApp) *app_pb.App_OidcConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func AppSAMLConfigToPb(app *query.SAMLApp) app_pb.AppConfig {
|
||||
return &app_pb.App_SamlConfig{
|
||||
SamlConfig: &app_pb.SAMLConfig{
|
||||
Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func AppAPIConfigToPb(app *query.APIApp) app_pb.AppConfig {
|
||||
return &app_pb.App_ApiConfig{
|
||||
ApiConfig: &app_pb.APIConfig{
|
||||
|
99
internal/api/saml/auth_request_converter.go
Normal file
99
internal/api/saml/auth_request_converter.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
var _ models.AuthRequestInt = &AuthRequest{}
|
||||
|
||||
type AuthRequest struct {
|
||||
*domain.AuthRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetApplicationID() string {
|
||||
return a.ApplicationID
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
func (a *AuthRequest) GetRelayState() string {
|
||||
return a.TransferState
|
||||
}
|
||||
func (a *AuthRequest) GetAccessConsumerServiceURL() string {
|
||||
return a.CallbackURI
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetNameID() string {
|
||||
return a.UserName
|
||||
}
|
||||
|
||||
func (a *AuthRequest) saml() *domain.AuthRequestSAML {
|
||||
return a.Request.(*domain.AuthRequestSAML)
|
||||
}
|
||||
func (a *AuthRequest) GetAuthRequestID() string {
|
||||
return a.saml().ID
|
||||
}
|
||||
func (a *AuthRequest) GetBindingType() string {
|
||||
return a.saml().BindingType
|
||||
}
|
||||
func (a *AuthRequest) GetIssuer() string {
|
||||
return a.saml().Issuer
|
||||
}
|
||||
func (a *AuthRequest) GetIssuerName() string {
|
||||
return a.saml().IssuerName
|
||||
}
|
||||
func (a *AuthRequest) GetDestination() string {
|
||||
return a.saml().Destination
|
||||
}
|
||||
func (a *AuthRequest) GetCode() string {
|
||||
return a.saml().Code
|
||||
}
|
||||
func (a *AuthRequest) GetUserID() string {
|
||||
return a.UserID
|
||||
}
|
||||
func (a *AuthRequest) GetUserName() string {
|
||||
return a.UserName
|
||||
}
|
||||
func (a *AuthRequest) Done() bool {
|
||||
for _, step := range a.PossibleSteps {
|
||||
if step.Type() == domain.NextStepRedirectToCallback {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ models.AuthRequestInt, err error) {
|
||||
if _, ok := authReq.Request.(*domain.AuthRequestSAML); !ok {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "SAML-Hbz7A", "auth request is not of type saml")
|
||||
}
|
||||
return &AuthRequest{authReq}, nil
|
||||
}
|
||||
|
||||
func CreateAuthRequestToBusiness(ctx context.Context, authReq *samlp.AuthnRequestType, acsUrl, protocolBinding, applicationID, relayState, userAgentID string) *domain.AuthRequest {
|
||||
return &domain.AuthRequest{
|
||||
CreationDate: time.Now(),
|
||||
AgentID: userAgentID,
|
||||
ApplicationID: applicationID,
|
||||
CallbackURI: acsUrl,
|
||||
TransferState: relayState,
|
||||
InstanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
Request: &domain.AuthRequestSAML{
|
||||
ID: authReq.Id,
|
||||
BindingType: protocolBinding,
|
||||
Code: "",
|
||||
Issuer: authReq.Issuer.Text,
|
||||
IssuerName: authReq.Issuer.SPProvidedID,
|
||||
Destination: authReq.Destination,
|
||||
},
|
||||
}
|
||||
}
|
202
internal/api/saml/certificate.go
Normal file
202
internal/api/saml/certificate.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/saml/pkg/provider/key"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
)
|
||||
|
||||
const (
|
||||
locksTable = "projections.locks"
|
||||
signingKey = "signing_key"
|
||||
samlUser = "SAML"
|
||||
|
||||
retryBackoff = 500 * time.Millisecond
|
||||
retryCount = 3
|
||||
lockDuration = retryCount * retryBackoff * 5
|
||||
gracefulPeriod = 10 * time.Minute
|
||||
)
|
||||
|
||||
type CertificateAndKey struct {
|
||||
algorithm jose.SignatureAlgorithm
|
||||
id string
|
||||
key interface{}
|
||||
certificate interface{}
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return c.algorithm
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) Key() interface{} {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) Certificate() interface{} {
|
||||
return c.certificate
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (p *Storage) GetCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (certAndKey *key.CertificateAndKey, err error) {
|
||||
err = retry(func() error {
|
||||
certAndKey, err = p.getCertificateAndKey(ctx, usage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if certAndKey == nil {
|
||||
return errors.ThrowInternal(err, "SAML-8u01nks", "no certificate found")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return certAndKey, err
|
||||
}
|
||||
|
||||
func (p *Storage) getCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (*key.CertificateAndKey, error) {
|
||||
certs, err := p.query.ActiveCertificates(ctx, time.Now().Add(gracefulPeriod), usage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(certs.Certificates) > 0 {
|
||||
return p.certificateToCertificateAndKey(selectCertificate(certs.Certificates))
|
||||
}
|
||||
|
||||
var sequence uint64
|
||||
if certs.LatestSequence != nil {
|
||||
sequence = certs.LatestSequence.Sequence
|
||||
}
|
||||
|
||||
return nil, p.refreshCertificate(ctx, usage, sequence)
|
||||
}
|
||||
|
||||
func (p *Storage) refreshCertificate(
|
||||
ctx context.Context,
|
||||
usage domain.KeyUsage,
|
||||
sequence uint64,
|
||||
) error {
|
||||
ok, err := p.ensureIsLatestCertificate(ctx, sequence)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("could not ensure latest key")
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
logging.Warn("view not up to date, retrying later")
|
||||
return err
|
||||
}
|
||||
err = p.lockAndGenerateCertificateAndKey(ctx, usage, sequence)
|
||||
logging.OnError(err).Warn("could not create signing key")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) ensureIsLatestCertificate(ctx context.Context, sequence uint64) (bool, error) {
|
||||
maxSequence, err := p.getMaxKeySequence(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error retrieving new events: %w", err)
|
||||
}
|
||||
return sequence == maxSequence, nil
|
||||
}
|
||||
|
||||
func (p *Storage) lockAndGenerateCertificateAndKey(ctx context.Context, usage domain.KeyUsage, sequence uint64) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
ctx = setSAMLCtx(ctx)
|
||||
|
||||
errs := p.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
|
||||
err, ok := <-errs
|
||||
if err != nil || !ok {
|
||||
if errors.IsErrorAlreadyExists(err) {
|
||||
return nil
|
||||
}
|
||||
logging.OnError(err).Warn("initial lock failed")
|
||||
return err
|
||||
}
|
||||
|
||||
switch usage {
|
||||
case domain.KeyUsageSAMLMetadataSigning, domain.KeyUsageSAMLResponseSinging:
|
||||
certAndKey, err := p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca certificate: %w", err)
|
||||
}
|
||||
if certAndKey.Key == nil || certAndKey.Certificate == nil {
|
||||
return fmt.Errorf("has no ca certificate")
|
||||
}
|
||||
|
||||
switch usage {
|
||||
case domain.KeyUsageSAMLMetadataSigning:
|
||||
return p.command.GenerateSAMLMetadataCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
|
||||
case domain.KeyUsageSAMLResponseSinging:
|
||||
return p.command.GenerateSAMLResponseCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
|
||||
default:
|
||||
return fmt.Errorf("unknown usage")
|
||||
}
|
||||
case domain.KeyUsageSAMLCA:
|
||||
return p.command.GenerateSAMLCACertificate(setSAMLCtx(ctx), p.certificateAlgorithm)
|
||||
default:
|
||||
return fmt.Errorf("unknown certificate usage")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Storage) getMaxKeySequence(ctx context.Context) (uint64, error) {
|
||||
return p.eventstore.LatestSequence(ctx,
|
||||
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
|
||||
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
|
||||
AddQuery().
|
||||
AggregateTypes(keypair.AggregateType).
|
||||
Builder(),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Storage) certificateToCertificateAndKey(certificate query.Certificate) (_ *key.CertificateAndKey, err error) {
|
||||
keyData, err := crypto.Decrypt(certificate.Key(), p.encAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privateKey, err := crypto.BytesToPrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, err := crypto.BytesToCertificate(certificate.Certificate())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &key.CertificateAndKey{
|
||||
Key: privateKey,
|
||||
Certificate: cert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func selectCertificate(certs []query.Certificate) query.Certificate {
|
||||
return certs[len(certs)-1]
|
||||
}
|
||||
|
||||
func setSAMLCtx(ctx context.Context) context.Context {
|
||||
return authz.SetCtxData(ctx, authz.CtxData{UserID: samlUser, OrgID: authz.GetInstance(ctx).InstanceID()})
|
||||
}
|
||||
|
||||
func retry(retryable func() error) (err error) {
|
||||
for i := 0; i < retryCount; i++ {
|
||||
err = retryable()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(retryBackoff)
|
||||
}
|
||||
return err
|
||||
}
|
102
internal/api/saml/provider.go
Normal file
102
internal/api/saml/provider.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
HandlerPrefix = "/saml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ProviderConfig *provider.Config
|
||||
}
|
||||
|
||||
func NewProvider(
|
||||
ctx context.Context,
|
||||
conf Config,
|
||||
externalSecure bool,
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
repo repository.Repository,
|
||||
encAlg crypto.EncryptionAlgorithm,
|
||||
certEncAlg crypto.EncryptionAlgorithm,
|
||||
es *eventstore.Eventstore,
|
||||
projections *sql.DB,
|
||||
instanceHandler,
|
||||
userAgentCookie func(http.Handler) http.Handler,
|
||||
) (*provider.Provider, error) {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
|
||||
provStorage, err := newStorage(
|
||||
command,
|
||||
query,
|
||||
repo,
|
||||
encAlg,
|
||||
certEncAlg,
|
||||
es,
|
||||
projections,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options := []provider.Option{
|
||||
provider.WithHttpInterceptors(
|
||||
middleware.MetricsHandler(metricTypes),
|
||||
middleware.TelemetryHandler(),
|
||||
middleware.NoCacheInterceptor().Handler,
|
||||
instanceHandler,
|
||||
userAgentCookie,
|
||||
http_utils.CopyHeadersToContext,
|
||||
),
|
||||
}
|
||||
if !externalSecure {
|
||||
options = append(options, provider.WithAllowInsecure())
|
||||
}
|
||||
|
||||
return provider.NewProvider(
|
||||
ctx,
|
||||
provStorage,
|
||||
HandlerPrefix,
|
||||
conf.ProviderConfig,
|
||||
options...,
|
||||
)
|
||||
}
|
||||
|
||||
func newStorage(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
repo repository.Repository,
|
||||
encAlg crypto.EncryptionAlgorithm,
|
||||
certEncAlg crypto.EncryptionAlgorithm,
|
||||
es *eventstore.Eventstore,
|
||||
projections *sql.DB,
|
||||
) (*Storage, error) {
|
||||
return &Storage{
|
||||
encAlg: encAlg,
|
||||
certEncAlg: certEncAlg,
|
||||
locker: crdb.NewLocker(projections, locksTable, signingKey),
|
||||
eventstore: es,
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: query,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
}, nil
|
||||
}
|
187
internal/api/saml/storage.go
Normal file
187
internal/api/saml/storage.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
"github.com/zitadel/saml/pkg/provider/key"
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/serviceprovider"
|
||||
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
var _ provider.EntityStorage = &Storage{}
|
||||
var _ provider.IdentityProviderStorage = &Storage{}
|
||||
var _ provider.AuthStorage = &Storage{}
|
||||
var _ provider.UserStorage = &Storage{}
|
||||
|
||||
type Storage struct {
|
||||
certChan <-chan interface{}
|
||||
defaultCertificateLifetime time.Duration
|
||||
|
||||
currentCACertificate query.Certificate
|
||||
currentMetadataCertificate query.Certificate
|
||||
currentResponseCertificate query.Certificate
|
||||
|
||||
locker crdb.Locker
|
||||
certificateAlgorithm string
|
||||
encAlg crypto.EncryptionAlgorithm
|
||||
certEncAlg crypto.EncryptionAlgorithm
|
||||
|
||||
eventstore *eventstore.Eventstore
|
||||
repo repository.Repository
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
defaultLoginURL string
|
||||
}
|
||||
|
||||
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
|
||||
app, err := p.query.AppBySAMLEntityID(ctx, entityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return serviceprovider.NewServiceProvider(
|
||||
app.ID,
|
||||
&serviceprovider.Config{
|
||||
Metadata: app.SAMLConfig.Metadata,
|
||||
},
|
||||
p.defaultLoginURL,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) {
|
||||
app, err := p.query.AppByID(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return app.SAMLConfig.EntityID, nil
|
||||
}
|
||||
|
||||
func (p *Storage) Health(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) GetCA(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
|
||||
}
|
||||
|
||||
func (p *Storage) GetMetadataSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLMetadataSigning)
|
||||
}
|
||||
|
||||
func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLResponseSinging)
|
||||
}
|
||||
|
||||
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||
}
|
||||
|
||||
authRequest := CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID)
|
||||
|
||||
resp, err := p.repo.CreateAuthRequest(ctx, authRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||||
}
|
||||
resp, err := p.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
user, err := p.query.GetUserByID(ctx, true, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
loginNameSQ, err := query.NewUserLoginNamesSearchQuery(loginName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := p.query.GetUser(ctx, true, loginNameSQ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int) {
|
||||
if len(attributes) == 0 {
|
||||
userinfo.SetUsername(user.PreferredLoginName)
|
||||
userinfo.SetUserID(user.ID)
|
||||
if user.Human == nil {
|
||||
return
|
||||
}
|
||||
userinfo.SetEmail(user.Human.Email)
|
||||
userinfo.SetSurname(user.Human.LastName)
|
||||
userinfo.SetGivenName(user.Human.FirstName)
|
||||
userinfo.SetFullName(user.Human.DisplayName)
|
||||
return
|
||||
}
|
||||
for _, attribute := range attributes {
|
||||
switch attribute {
|
||||
case provider.AttributeEmail:
|
||||
if user.Human != nil {
|
||||
userinfo.SetEmail(user.Human.Email)
|
||||
}
|
||||
case provider.AttributeSurname:
|
||||
if user.Human != nil {
|
||||
userinfo.SetSurname(user.Human.LastName)
|
||||
}
|
||||
case provider.AttributeFullName:
|
||||
if user.Human != nil {
|
||||
userinfo.SetFullName(user.Human.DisplayName)
|
||||
}
|
||||
case provider.AttributeGivenName:
|
||||
if user.Human != nil {
|
||||
userinfo.SetGivenName(user.Human.FirstName)
|
||||
}
|
||||
case provider.AttributeUsername:
|
||||
userinfo.SetUsername(user.PreferredLoginName)
|
||||
case provider.AttributeUserID:
|
||||
userinfo.SetUserID(user.ID)
|
||||
}
|
||||
}
|
||||
}
|
@@ -374,7 +374,7 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
|
||||
return
|
||||
}
|
||||
linkingUser = l.mapExternalNotFoundOptionFormDataToLoginUser(data)
|
||||
}
|
||||
}
|
||||
|
||||
user, externalIDP, metadata := l.mapExternalUserToLoginUser(orgIamPolicy, linkingUser, idpConfig)
|
||||
|
||||
|
@@ -37,6 +37,7 @@ type Login struct {
|
||||
externalSecure bool
|
||||
consolePath string
|
||||
oidcAuthCallbackURL func(context.Context, string) string
|
||||
samlAuthCallbackURL func(context.Context, string) string
|
||||
idpConfigAlg crypto.EncryptionAlgorithm
|
||||
userCodeAlg crypto.EncryptionAlgorithm
|
||||
}
|
||||
@@ -61,10 +62,12 @@ func CreateLogin(config Config,
|
||||
staticStorage static.Storage,
|
||||
consolePath string,
|
||||
oidcAuthCallbackURL func(context.Context, string) string,
|
||||
samlAuthCallbackURL func(context.Context, string) string,
|
||||
externalSecure bool,
|
||||
userAgentCookie,
|
||||
issuerInterceptor,
|
||||
instanceHandler,
|
||||
oidcInstanceHandler,
|
||||
samlInstanceHandler mux.MiddlewareFunc,
|
||||
assetCache mux.MiddlewareFunc,
|
||||
userCodeAlg crypto.EncryptionAlgorithm,
|
||||
idpConfigAlg crypto.EncryptionAlgorithm,
|
||||
@@ -73,6 +76,7 @@ func CreateLogin(config Config,
|
||||
|
||||
login := &Login{
|
||||
oidcAuthCallbackURL: oidcAuthCallbackURL,
|
||||
samlAuthCallbackURL: samlAuthCallbackURL,
|
||||
externalSecure: externalSecure,
|
||||
consolePath: consolePath,
|
||||
command: command,
|
||||
@@ -91,7 +95,7 @@ func CreateLogin(config Config,
|
||||
cacheInterceptor := createCacheInterceptor(config.Cache.MaxAge, config.Cache.SharedMaxAge, assetCache)
|
||||
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
|
||||
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), instanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
|
||||
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
|
||||
login.parser = form.NewParser()
|
||||
return login, nil
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,11 +44,26 @@ func (l *Login) renderSuccessAndCallback(w http.ResponseWriter, r *http.Request,
|
||||
userData: l.getUserData(r, authReq, "Login Successful", errID, errMessage),
|
||||
}
|
||||
if authReq != nil {
|
||||
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "") //the id will be set via the html (maybe change this with the login refactoring)
|
||||
//the id will be set via the html (maybe change this with the login refactoring)
|
||||
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); ok {
|
||||
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "")
|
||||
} else if _, ok := authReq.Request.(*domain.AuthRequestSAML); ok {
|
||||
data.RedirectURI = l.samlAuthCallbackURL(r.Context(), "")
|
||||
}
|
||||
}
|
||||
l.renderer.RenderTemplate(w, r, l.getTranslator(r.Context(), authReq), l.renderer.Templates[tmplLoginSuccess], data, nil)
|
||||
}
|
||||
|
||||
func (l *Login) redirectToCallback(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) {
|
||||
http.Redirect(w, r, l.oidcAuthCallbackURL(r.Context(), authReq.ID), http.StatusFound)
|
||||
var callback string
|
||||
switch authReq.Request.(type) {
|
||||
case *domain.AuthRequestOIDC:
|
||||
callback = l.oidcAuthCallbackURL(r.Context(), authReq.ID)
|
||||
case *domain.AuthRequestSAML:
|
||||
callback = l.samlAuthCallbackURL(r.Context(), authReq.ID)
|
||||
default:
|
||||
l.renderInternalError(w, r, authReq, caos_errs.ThrowInternal(nil, "LOGIN-rhjQF", "Errors.AuthRequest.RequestTypeNotSupported"))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
@@ -97,12 +97,12 @@ type orgViewProvider interface {
|
||||
}
|
||||
|
||||
type userGrantProvider interface {
|
||||
ProjectByOIDCClientID(context.Context, string) (*query.Project, error)
|
||||
ProjectByClientID(context.Context, string) (*query.Project, error)
|
||||
UserGrantsByProjectAndUserID(context.Context, string, string) ([]*query.UserGrant, error)
|
||||
}
|
||||
|
||||
type projectProvider interface {
|
||||
ProjectByOIDCClientID(context.Context, string) (*query.Project, error)
|
||||
ProjectByClientID(context.Context, string) (*query.Project, error)
|
||||
OrgProjectMappingByIDs(orgID, projectID, instanceID string) (*project_view_model.OrgProjectMapping, error)
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
|
||||
return nil, err
|
||||
}
|
||||
request.ID = reqID
|
||||
project, err := repo.ProjectProvider.ProjectByOIDCClientID(ctx, request.ApplicationID)
|
||||
project, err := repo.ProjectProvider.ProjectByClientID(ctx, request.ApplicationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1048,6 +1048,9 @@ func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID stri
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) hasSucceededPage(ctx context.Context, request *domain.AuthRequest, provider applicationProvider) (bool, error) {
|
||||
if _, ok := request.Request.(*domain.AuthRequestOIDC); !ok {
|
||||
return false, nil
|
||||
}
|
||||
app, err := provider.AppByOIDCClientID(ctx, request.ApplicationID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -1257,8 +1260,8 @@ func linkingIDPConfigExistingInAllowedIDPs(linkingUsers []*domain.ExternalUser,
|
||||
func userGrantRequired(ctx context.Context, request *domain.AuthRequest, user *user_model.UserView, userGrantProvider userGrantProvider) (_ bool, err error) {
|
||||
var project *query.Project
|
||||
switch request.Request.Type() {
|
||||
case domain.AuthRequestTypeOIDC:
|
||||
project, err = userGrantProvider.ProjectByOIDCClientID(ctx, request.ApplicationID)
|
||||
case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML:
|
||||
project, err = userGrantProvider.ProjectByClientID(ctx, request.ApplicationID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -1278,8 +1281,8 @@ func userGrantRequired(ctx context.Context, request *domain.AuthRequest, user *u
|
||||
func projectRequired(ctx context.Context, request *domain.AuthRequest, projectProvider projectProvider) (_ bool, err error) {
|
||||
var project *query.Project
|
||||
switch request.Request.Type() {
|
||||
case domain.AuthRequestTypeOIDC:
|
||||
project, err = projectProvider.ProjectByOIDCClientID(ctx, request.ApplicationID)
|
||||
case domain.AuthRequestTypeOIDC, domain.AuthRequestTypeSAML:
|
||||
project, err = projectProvider.ProjectByClientID(ctx, request.ApplicationID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@@ -22,6 +22,10 @@ import (
|
||||
user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
|
||||
)
|
||||
|
||||
var (
|
||||
testNow = time.Now()
|
||||
)
|
||||
|
||||
type mockViewNoUserSession struct{}
|
||||
|
||||
func (m *mockViewNoUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
|
||||
@@ -191,7 +195,7 @@ type mockUserGrants struct {
|
||||
userGrants int
|
||||
}
|
||||
|
||||
func (m *mockUserGrants) ProjectByOIDCClientID(ctx context.Context, s string) (*query.Project, error) {
|
||||
func (m *mockUserGrants) ProjectByClientID(ctx context.Context, s string) (*query.Project, error) {
|
||||
return &query.Project{ProjectRoleCheck: m.roleCheck}, nil
|
||||
}
|
||||
|
||||
@@ -208,7 +212,7 @@ type mockProject struct {
|
||||
projectCheck bool
|
||||
}
|
||||
|
||||
func (m *mockProject) ProjectByOIDCClientID(ctx context.Context, s string) (*query.Project, error) {
|
||||
func (m *mockProject) ProjectByClientID(ctx context.Context, s string) (*query.Project, error) {
|
||||
return &query.Project{HasProjectCheck: m.projectCheck}, nil
|
||||
}
|
||||
|
||||
@@ -615,8 +619,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"passwordless verified, email not verified, email verification step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordlessVerification: time.Now().Add(-5 * time.Minute),
|
||||
MultiFactorVerification: time.Now().Add(-5 * time.Minute),
|
||||
PasswordlessVerification: testNow.Add(-5 * time.Minute),
|
||||
MultiFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -667,7 +671,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user (no external verification), external login step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
IsEmailVerified: true,
|
||||
@@ -699,8 +703,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user (external verification set), callback",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
IsEmailVerified: true,
|
||||
@@ -759,8 +763,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user (no password check needed), callback",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -795,7 +799,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"password verified, passwordless set up, mfa not verified, mfa check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -829,7 +833,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"mfa not verified, mfa check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -862,8 +866,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user, mfa not verified, mfa check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -898,8 +902,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"password change required and email verified, password change step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -931,8 +935,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"email not verified and no password change required, mail verification step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -961,8 +965,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"email not verified and password change required, mail verification step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -992,8 +996,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"email verified and no password change required, redirect to callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1027,8 +1031,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true and authenticated, redirect to callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1063,8 +1067,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true, authenticated and native, login succeeded step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1099,8 +1103,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true, authenticated and required user grants missing, grant required step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1137,8 +1141,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true, authenticated and required user grants exist, redirect to callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1176,8 +1180,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true, authenticated and required project missing, project required step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1214,8 +1218,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"prompt none, checkLoggedIn true, authenticated and required project exist, redirect to callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1253,7 +1257,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"linking users, password step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1287,8 +1291,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"linking users, linking step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: testNow.Add(-5 * time.Minute),
|
||||
SecondFactorVerification: testNow.Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
@@ -1463,7 +1467,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
|
||||
user: &user_model.UserView{
|
||||
HumanView: &user_model.HumanView{
|
||||
MFAMaxSetUp: domain.MFALevelNotSetUp,
|
||||
MFAInitSkipped: time.Now().UTC(),
|
||||
MFAInitSkipped: testNow,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1486,7 +1490,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
|
||||
OTPState: user_model.MFAStateReady,
|
||||
},
|
||||
},
|
||||
userSession: &user_model.UserSessionView{SecondFactorVerification: time.Now().UTC().Add(-5 * time.Hour)},
|
||||
userSession: &user_model.UserSessionView{SecondFactorVerification: testNow.Add(-5 * time.Hour)},
|
||||
},
|
||||
nil,
|
||||
true,
|
||||
@@ -1569,7 +1573,7 @@ func TestAuthRequestRepo_mfaSkippedOrSetUp(t *testing.T) {
|
||||
user: &user_model.UserView{
|
||||
HumanView: &user_model.HumanView{
|
||||
MFAMaxSetUp: -1,
|
||||
MFAInitSkipped: time.Now().UTC().Add(-10 * time.Hour),
|
||||
MFAInitSkipped: testNow.Add(-10 * time.Hour),
|
||||
},
|
||||
},
|
||||
request: &domain.AuthRequest{
|
||||
@@ -1587,7 +1591,7 @@ func TestAuthRequestRepo_mfaSkippedOrSetUp(t *testing.T) {
|
||||
user: &user_model.UserView{
|
||||
HumanView: &user_model.HumanView{
|
||||
MFAMaxSetUp: -1,
|
||||
MFAInitSkipped: time.Now().UTC().Add(-40 * 24 * time.Hour),
|
||||
MFAInitSkipped: testNow.Add(-40 * 24 * time.Hour),
|
||||
},
|
||||
},
|
||||
request: &domain.AuthRequest{
|
||||
@@ -1645,13 +1649,13 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"error user events, old view model state",
|
||||
args{
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
eventProvider: &mockEventErrUser{},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
SecondFactorVerification: time.Time{},
|
||||
MultiFactorVerification: time.Time{},
|
||||
},
|
||||
@@ -1661,7 +1665,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"new user events but error, old view model state",
|
||||
args{
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
@@ -1669,12 +1673,12 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
&es_models.Event{
|
||||
AggregateType: user_repo.AggregateType,
|
||||
Type: es_models.EventType(user_repo.UserV1MFAOTPCheckSucceededType),
|
||||
CreationDate: time.Now().UTC().Round(1 * time.Second),
|
||||
CreationDate: testNow,
|
||||
},
|
||||
},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
SecondFactorVerification: time.Time{},
|
||||
MultiFactorVerification: time.Time{},
|
||||
},
|
||||
@@ -1684,7 +1688,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"new user events but other agentID, old view model state",
|
||||
args{
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
@@ -1692,7 +1696,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
&es_models.Event{
|
||||
AggregateType: user_repo.AggregateType,
|
||||
Type: es_models.EventType(user_repo.UserV1MFAOTPCheckSucceededType),
|
||||
CreationDate: time.Now().UTC().Round(1 * time.Second),
|
||||
CreationDate: testNow,
|
||||
Data: func() []byte {
|
||||
data, _ := json.Marshal(&user_es_model.AuthRequest{UserAgentID: "otherID"})
|
||||
return data
|
||||
@@ -1701,7 +1705,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
SecondFactorVerification: time.Time{},
|
||||
MultiFactorVerification: time.Time{},
|
||||
},
|
||||
@@ -1711,7 +1715,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"new user events, new view model state",
|
||||
args{
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
@@ -1719,7 +1723,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
&es_models.Event{
|
||||
AggregateType: user_repo.AggregateType,
|
||||
Type: es_models.EventType(user_repo.UserV1MFAOTPCheckSucceededType),
|
||||
CreationDate: time.Now().UTC().Round(1 * time.Second),
|
||||
CreationDate: testNow,
|
||||
Data: func() []byte {
|
||||
data, _ := json.Marshal(&user_es_model.AuthRequest{UserAgentID: "agentID"})
|
||||
return data
|
||||
@@ -1728,9 +1732,9 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
SecondFactorVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
ChangeDate: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
SecondFactorVerification: testNow,
|
||||
ChangeDate: testNow,
|
||||
},
|
||||
nil,
|
||||
},
|
||||
@@ -1738,7 +1742,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"new user events (user deleted), precondition failed error",
|
||||
args{
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
@@ -1816,7 +1820,7 @@ func Test_userByID(t *testing.T) {
|
||||
&es_models.Event{
|
||||
AggregateType: user_repo.AggregateType,
|
||||
Type: es_models.EventType(user_repo.UserV1PasswordChangedType),
|
||||
CreationDate: time.Now().UTC().Round(1 * time.Second),
|
||||
CreationDate: testNow,
|
||||
Data: nil,
|
||||
},
|
||||
},
|
||||
@@ -1842,7 +1846,7 @@ func Test_userByID(t *testing.T) {
|
||||
&es_models.Event{
|
||||
AggregateType: user_repo.AggregateType,
|
||||
Type: es_models.EventType(user_repo.UserV1PasswordChangedType),
|
||||
CreationDate: time.Now().UTC().Round(1 * time.Second),
|
||||
CreationDate: testNow,
|
||||
Data: func() []byte {
|
||||
data, _ := json.Marshal(user_es_model.Password{ChangeRequired: false, Secret: &crypto.CryptoValue{}})
|
||||
return data
|
||||
@@ -1851,13 +1855,13 @@ func Test_userByID(t *testing.T) {
|
||||
},
|
||||
},
|
||||
&user_model.UserView{
|
||||
ChangeDate: time.Now().UTC().Round(1 * time.Second),
|
||||
ChangeDate: testNow,
|
||||
State: user_model.UserStateActive,
|
||||
UserName: "UserName",
|
||||
HumanView: &user_model.HumanView{
|
||||
PasswordSet: true,
|
||||
PasswordChangeRequired: false,
|
||||
PasswordChanged: time.Now().UTC().Round(1 * time.Second),
|
||||
PasswordChanged: testNow,
|
||||
FirstName: "FirstName",
|
||||
},
|
||||
},
|
||||
|
@@ -1,10 +1,11 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
api_http "github.com/zitadel/zitadel/internal/api/http"
|
||||
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -23,6 +24,8 @@ import (
|
||||
)
|
||||
|
||||
type Commands struct {
|
||||
httpClient *http.Client
|
||||
|
||||
eventstore *eventstore.Eventstore
|
||||
static static.Storage
|
||||
idGenerator id.Generator
|
||||
@@ -40,14 +43,17 @@ type Commands struct {
|
||||
applicationKeySize int
|
||||
domainVerificationAlg crypto.EncryptionAlgorithm
|
||||
domainVerificationGenerator crypto.Generator
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType http.CheckType) error
|
||||
domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error
|
||||
|
||||
multifactors domain.MultifactorConfigs
|
||||
webauthnConfig *webauthn_helper.Config
|
||||
keySize int
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
privateKeyLifetime time.Duration
|
||||
publicKeyLifetime time.Duration
|
||||
multifactors domain.MultifactorConfigs
|
||||
webauthnConfig *webauthn_helper.Config
|
||||
keySize int
|
||||
keyAlgorithm crypto.EncryptionAlgorithm
|
||||
certificateAlgorithm crypto.EncryptionAlgorithm
|
||||
certKeySize int
|
||||
privateKeyLifetime time.Duration
|
||||
publicKeyLifetime time.Duration
|
||||
certificateLifetime time.Duration
|
||||
}
|
||||
|
||||
func StartCommands(es *eventstore.Eventstore,
|
||||
@@ -64,7 +70,9 @@ func StartCommands(es *eventstore.Eventstore,
|
||||
smsEncryption,
|
||||
userEncryption,
|
||||
domainVerificationEncryption,
|
||||
oidcEncryption crypto.EncryptionAlgorithm,
|
||||
oidcEncryption,
|
||||
samlEncryption crypto.EncryptionAlgorithm,
|
||||
httpClient *http.Client,
|
||||
) (repo *Commands, err error) {
|
||||
if externalDomain == "" {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified")
|
||||
@@ -78,15 +86,19 @@ func StartCommands(es *eventstore.Eventstore,
|
||||
externalSecure: externalSecure,
|
||||
externalPort: externalPort,
|
||||
keySize: defaults.KeyConfig.Size,
|
||||
certKeySize: defaults.KeyConfig.CertificateSize,
|
||||
privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime,
|
||||
publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime,
|
||||
certificateLifetime: defaults.KeyConfig.CertificateLifetime,
|
||||
idpConfigEncryption: idpConfigEncryption,
|
||||
smtpEncryption: smtpEncryption,
|
||||
smsEncryption: smsEncryption,
|
||||
userEncryption: userEncryption,
|
||||
domainVerificationAlg: domainVerificationEncryption,
|
||||
keyAlgorithm: oidcEncryption,
|
||||
certificateAlgorithm: samlEncryption,
|
||||
webauthnConfig: webAuthN,
|
||||
httpClient: httpClient,
|
||||
}
|
||||
|
||||
instance_repo.RegisterEventMappers(repo.eventstore)
|
||||
@@ -109,7 +121,7 @@ func StartCommands(es *eventstore.Eventstore,
|
||||
}
|
||||
|
||||
repo.domainVerificationGenerator = crypto.NewEncryptionGenerator(defaults.DomainVerification.VerificationGenerator, repo.domainVerificationAlg)
|
||||
repo.domainVerificationValidator = http.ValidateDomain
|
||||
repo.domainVerificationValidator = api_http.ValidateDomain
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
|
@@ -2,6 +2,10 @@ package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
@@ -34,3 +38,138 @@ func (c *Commands) GenerateSigningKeyPair(ctx context.Context, algorithm string)
|
||||
privateKeyExp, publicKeyExp))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) GenerateSAMLCACertificate(ctx context.Context, algorithm string) error {
|
||||
now := time.Now().UTC()
|
||||
after := now.Add(c.certificateLifetime)
|
||||
randInt, err := rand.Int(rand.Reader, big.NewInt(1000))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
privateCrypto, publicCrypto, certificateCrypto, err := crypto.GenerateEncryptedKeyPairWithCACertificate(c.certKeySize, c.keyAlgorithm, c.certificateAlgorithm, &crypto.CertificateInformations{
|
||||
SerialNumber: randInt,
|
||||
Organisation: []string{"ZITADEL"},
|
||||
CommonName: "ZITADEL SAML CA",
|
||||
NotBefore: now,
|
||||
NotAfter: after,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyPairWriteModel := NewKeyPairWriteModel(keyID, authz.GetInstance(ctx).InstanceID())
|
||||
keyAgg := KeyPairAggregateFromWriteModel(&keyPairWriteModel.WriteModel)
|
||||
_, err = c.eventstore.Push(ctx,
|
||||
keypair.NewAddedEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
domain.KeyUsageSAMLCA,
|
||||
algorithm,
|
||||
privateCrypto, publicCrypto,
|
||||
after, after,
|
||||
),
|
||||
keypair.NewAddedCertificateEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
certificateCrypto,
|
||||
after,
|
||||
),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) GenerateSAMLResponseCertificate(ctx context.Context, algorithm string, caPrivateKey *rsa.PrivateKey, caCertificate []byte) error {
|
||||
now := time.Now().UTC()
|
||||
after := now.Add(c.certificateLifetime)
|
||||
randInt, err := rand.Int(rand.Reader, big.NewInt(1000))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
privateCrypto, publicCrypto, certificateCrypto, err := crypto.GenerateEncryptedKeyPairWithCertificate(c.certKeySize, c.keyAlgorithm, c.certificateAlgorithm, caPrivateKey, caCertificate, &crypto.CertificateInformations{
|
||||
SerialNumber: randInt,
|
||||
Organisation: []string{"ZITADEL"},
|
||||
CommonName: "ZITADEL SAML response",
|
||||
NotBefore: now,
|
||||
NotAfter: after,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyPairWriteModel := NewKeyPairWriteModel(keyID, authz.GetInstance(ctx).InstanceID())
|
||||
keyAgg := KeyPairAggregateFromWriteModel(&keyPairWriteModel.WriteModel)
|
||||
_, err = c.eventstore.Push(ctx,
|
||||
keypair.NewAddedEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
domain.KeyUsageSAMLResponseSinging,
|
||||
algorithm,
|
||||
privateCrypto, publicCrypto,
|
||||
after, after,
|
||||
),
|
||||
keypair.NewAddedCertificateEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
certificateCrypto,
|
||||
after,
|
||||
),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Commands) GenerateSAMLMetadataCertificate(ctx context.Context, algorithm string, caPrivateKey *rsa.PrivateKey, caCertificate []byte) error {
|
||||
now := time.Now().UTC()
|
||||
after := now.Add(c.certificateLifetime)
|
||||
randInt, err := rand.Int(rand.Reader, big.NewInt(1000))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
privateCrypto, publicCrypto, certificateCrypto, err := crypto.GenerateEncryptedKeyPairWithCertificate(c.certKeySize, c.keyAlgorithm, c.certificateAlgorithm, caPrivateKey, caCertificate, &crypto.CertificateInformations{
|
||||
SerialNumber: randInt,
|
||||
Organisation: []string{"ZITADEL"},
|
||||
CommonName: "ZITADEL SAML metadata",
|
||||
NotBefore: now,
|
||||
NotAfter: after,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyID, err := c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keyPairWriteModel := NewKeyPairWriteModel(keyID, authz.GetInstance(ctx).InstanceID())
|
||||
keyAgg := KeyPairAggregateFromWriteModel(&keyPairWriteModel.WriteModel)
|
||||
_, err = c.eventstore.Push(ctx,
|
||||
keypair.NewAddedEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
domain.KeyUsageSAMLMetadataSigning,
|
||||
algorithm,
|
||||
privateCrypto, publicCrypto,
|
||||
after, after),
|
||||
keypair.NewAddedCertificateEvent(
|
||||
ctx,
|
||||
keyAgg,
|
||||
certificateCrypto,
|
||||
after,
|
||||
),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@@ -9,10 +9,11 @@ import (
|
||||
type KeyPairWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
Usage domain.KeyUsage
|
||||
Algorithm string
|
||||
PrivateKey *domain.Key
|
||||
PublicKey *domain.Key
|
||||
Usage domain.KeyUsage
|
||||
Algorithm string
|
||||
PrivateKey *domain.Key
|
||||
PublicKey *domain.Key
|
||||
Certificate *domain.Key
|
||||
}
|
||||
|
||||
func NewKeyPairWriteModel(aggregateID, resourceOwner string) *KeyPairWriteModel {
|
||||
@@ -42,6 +43,11 @@ func (wm *KeyPairWriteModel) Reduce() error {
|
||||
Key: e.PublicKey.Key,
|
||||
Expiry: e.PublicKey.Expiry,
|
||||
}
|
||||
case *keypair.AddedCertificateEvent:
|
||||
wm.Certificate = &domain.Key{
|
||||
Key: e.Certificate.Key,
|
||||
Expiry: e.Certificate.Expiry,
|
||||
}
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
@@ -53,11 +59,10 @@ func (wm *KeyPairWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
AddQuery().
|
||||
AggregateTypes(keypair.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(keypair.AddedEventType).
|
||||
EventTypes(keypair.AddedEventType, keypair.AddedCertificateEventType).
|
||||
Builder()
|
||||
}
|
||||
|
||||
func KeyPairAggregateFromWriteModel(wm *eventstore.WriteModel) *eventstore.Aggregate {
|
||||
return eventstore.AggregateFromWriteModel(wm, keypair.AggregateType, keypair.AggregateVersion)
|
||||
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
@@ -276,9 +277,20 @@ func (c *Commands) RemoveProject(ctx context.Context, projectID, resourceOwner s
|
||||
if existingProject.State == domain.ProjectStateUnspecified || existingProject.State == domain.ProjectStateRemoved {
|
||||
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-3M9sd", "Errors.Project.NotFound")
|
||||
}
|
||||
|
||||
samlEntityIDsAgg, err := c.getSAMLEntityIdsWriteModelByProjectID(ctx, projectID, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uniqueConstraints := make([]*eventstore.EventUniqueConstraint, len(samlEntityIDsAgg.EntityIDs))
|
||||
for i, entityID := range samlEntityIDsAgg.EntityIDs {
|
||||
uniqueConstraints[i] = project.NewRemoveSAMLConfigEntityIDUniqueConstraint(entityID.EntityID)
|
||||
}
|
||||
|
||||
projectAgg := ProjectAggregateFromWriteModel(&existingProject.WriteModel)
|
||||
events := []eventstore.Command{
|
||||
project.NewProjectRemovedEvent(ctx, projectAgg, existingProject.Name),
|
||||
project.NewProjectRemovedEvent(ctx, projectAgg, existingProject.Name, uniqueConstraints),
|
||||
}
|
||||
|
||||
for _, grantID := range cascadingUserGrantIDs {
|
||||
@@ -309,3 +321,12 @@ func (c *Commands) getProjectWriteModelByID(ctx context.Context, projectID, reso
|
||||
}
|
||||
return projectWriteModel, nil
|
||||
}
|
||||
|
||||
func (c *Commands) getSAMLEntityIdsWriteModelByProjectID(ctx context.Context, projectID, resourceOwner string) (*SAMLEntityIDsWriteModel, error) {
|
||||
samlEntityIDsAgg := NewSAMLEntityIDsWriteModel(projectID, resourceOwner)
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, samlEntityIDsAgg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return samlEntityIDsAgg, nil
|
||||
}
|
||||
|
@@ -118,7 +118,13 @@ func (c *Commands) RemoveApplication(ctx context.Context, projectID, appID, reso
|
||||
}
|
||||
projectAgg := ProjectAggregateFromWriteModel(&existingApp.WriteModel)
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, project.NewApplicationRemovedEvent(ctx, projectAgg, appID, existingApp.Name))
|
||||
entityID := ""
|
||||
samlWriteModel, err := c.getSAMLAppWriteModel(ctx, projectID, appID, resourceOwner)
|
||||
if err == nil && samlWriteModel.State != domain.AppStateUnspecified && samlWriteModel.State != domain.AppStateRemoved && samlWriteModel.saml {
|
||||
entityID = samlWriteModel.EntityID
|
||||
}
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, project.NewApplicationRemovedEvent(ctx, projectAgg, appID, existingApp.Name, entityID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
146
internal/command/project_application_saml.go
Normal file
146
internal/command/project_application_saml.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider/xml"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
)
|
||||
|
||||
func (c *Commands) AddSAMLApplication(ctx context.Context, application *domain.SAMLApp, resourceOwner string) (_ *domain.SAMLApp, err error) {
|
||||
if application == nil || application.AggregateID == "" {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "PROJECT-35Fn0", "Errors.Project.App.Invalid")
|
||||
}
|
||||
|
||||
_, err = c.getProjectByID(ctx, application.AggregateID, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(err, "PROJECT-3p9ss", "Errors.Project.NotFound")
|
||||
}
|
||||
|
||||
addedApplication := NewSAMLApplicationWriteModel(application.AggregateID, resourceOwner)
|
||||
projectAgg := ProjectAggregateFromWriteModel(&addedApplication.WriteModel)
|
||||
events, err := c.addSAMLApplication(ctx, projectAgg, application)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addedApplication.AppID = application.AppID
|
||||
pushedEvents, err := c.eventstore.Push(ctx, events...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = AppendAndReduce(addedApplication, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := samlWriteModelToSAMLConfig(addedApplication)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *Commands) addSAMLApplication(ctx context.Context, projectAgg *eventstore.Aggregate, samlApp *domain.SAMLApp) (events []eventstore.Command, err error) {
|
||||
|
||||
if samlApp.AppName == "" || !samlApp.IsValid() {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "PROJECT-1n9df", "Errors.Project.App.Invalid")
|
||||
}
|
||||
|
||||
if samlApp.Metadata == nil && samlApp.MetadataURL == "" {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "SAML-podix9", "Errors.Project.App.SAMLMetadataMissing")
|
||||
}
|
||||
|
||||
if samlApp.MetadataURL != "" {
|
||||
data, err := xml.ReadMetadataFromURL(c.httpClient, samlApp.MetadataURL)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInvalidArgument(err, "SAML-wmqlo1", "Errors.Project.App.SAMLMetadataMissing")
|
||||
}
|
||||
samlApp.Metadata = data
|
||||
}
|
||||
|
||||
entity, err := xml.ParseMetadataXmlIntoStruct(samlApp.Metadata)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInvalidArgument(err, "SAML-bquso", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
|
||||
samlApp.AppID, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []eventstore.Command{
|
||||
project.NewApplicationAddedEvent(ctx, projectAgg, samlApp.AppID, samlApp.AppName),
|
||||
project.NewSAMLConfigAddedEvent(ctx,
|
||||
projectAgg,
|
||||
samlApp.AppID,
|
||||
string(entity.EntityID),
|
||||
samlApp.Metadata,
|
||||
samlApp.MetadataURL,
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Commands) ChangeSAMLApplication(ctx context.Context, samlApp *domain.SAMLApp, resourceOwner string) (*domain.SAMLApp, error) {
|
||||
if !samlApp.IsValid() || samlApp.AppID == "" || samlApp.AggregateID == "" {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "COMMAND-5n9fs", "Errors.Project.App.SAMLConfigInvalid")
|
||||
}
|
||||
|
||||
existingSAML, err := c.getSAMLAppWriteModel(ctx, samlApp.AggregateID, samlApp.AppID, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingSAML.State == domain.AppStateUnspecified || existingSAML.State == domain.AppStateRemoved {
|
||||
return nil, caos_errs.ThrowNotFound(nil, "COMMAND-2n8uU", "Errors.Project.App.NotExisting")
|
||||
}
|
||||
if !existingSAML.IsSAML() {
|
||||
return nil, caos_errs.ThrowInvalidArgument(nil, "COMMAND-GBr35", "Errors.Project.App.IsNotSAML")
|
||||
}
|
||||
projectAgg := ProjectAggregateFromWriteModel(&existingSAML.WriteModel)
|
||||
|
||||
if samlApp.MetadataURL != "" {
|
||||
data, err := xml.ReadMetadataFromURL(c.httpClient, samlApp.MetadataURL)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInvalidArgument(err, "SAML-J3kg3", "Errors.Project.App.SAMLMetadataMissing")
|
||||
}
|
||||
samlApp.Metadata = data
|
||||
}
|
||||
|
||||
entity, err := xml.ParseMetadataXmlIntoStruct(samlApp.Metadata)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInvalidArgument(err, "SAML-3fk2b", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
|
||||
changedEvent, hasChanged, err := existingSAML.NewChangedEvent(
|
||||
ctx,
|
||||
projectAgg,
|
||||
samlApp.AppID,
|
||||
string(entity.EntityID),
|
||||
samlApp.Metadata,
|
||||
samlApp.MetadataURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !hasChanged {
|
||||
return nil, caos_errs.ThrowPreconditionFailed(nil, "COMMAND-1m88i", "Errors.NoChangesFound")
|
||||
}
|
||||
|
||||
pushedEvents, err := c.eventstore.Push(ctx, changedEvent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = AppendAndReduce(existingSAML, pushedEvents...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return samlWriteModelToSAMLConfig(existingSAML), nil
|
||||
}
|
||||
|
||||
func (c *Commands) getSAMLAppWriteModel(ctx context.Context, projectID, appID, resourceOwner string) (*SAMLApplicationWriteModel, error) {
|
||||
appWriteModel := NewSAMLApplicationWriteModelWithAppID(projectID, appID, resourceOwner)
|
||||
err := c.eventstore.FilterToQueryReducer(ctx, appWriteModel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return appWriteModel, nil
|
||||
}
|
268
internal/command/project_application_saml_model.go
Normal file
268
internal/command/project_application_saml_model.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
)
|
||||
|
||||
type SAMLApplicationWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
AppID string
|
||||
AppName string
|
||||
EntityID string
|
||||
Metadata []byte
|
||||
MetadataURL string
|
||||
|
||||
State domain.AppState
|
||||
saml bool
|
||||
}
|
||||
|
||||
func NewSAMLApplicationWriteModelWithAppID(projectID, appID, resourceOwner string) *SAMLApplicationWriteModel {
|
||||
return &SAMLApplicationWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: projectID,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
AppID: appID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSAMLApplicationWriteModel(projectID, resourceOwner string) *SAMLApplicationWriteModel {
|
||||
return &SAMLApplicationWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: projectID,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) AppendEvents(events ...eventstore.Event) {
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case *project.ApplicationAddedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.ApplicationChangedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.ApplicationDeactivatedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.ApplicationReactivatedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.ApplicationRemovedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.SAMLConfigAddedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.SAMLConfigChangedEvent:
|
||||
if e.AppID != wm.AppID {
|
||||
continue
|
||||
}
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.ProjectRemovedEvent:
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *project.ApplicationAddedEvent:
|
||||
wm.AppName = e.Name
|
||||
wm.State = domain.AppStateActive
|
||||
case *project.ApplicationChangedEvent:
|
||||
wm.AppName = e.Name
|
||||
case *project.ApplicationDeactivatedEvent:
|
||||
if wm.State == domain.AppStateRemoved {
|
||||
continue
|
||||
}
|
||||
wm.State = domain.AppStateInactive
|
||||
case *project.ApplicationReactivatedEvent:
|
||||
if wm.State == domain.AppStateRemoved {
|
||||
continue
|
||||
}
|
||||
wm.State = domain.AppStateActive
|
||||
case *project.ApplicationRemovedEvent:
|
||||
wm.State = domain.AppStateRemoved
|
||||
case *project.SAMLConfigAddedEvent:
|
||||
wm.appendAddSAMLEvent(e)
|
||||
case *project.SAMLConfigChangedEvent:
|
||||
wm.appendChangeSAMLEvent(e)
|
||||
case *project.ProjectRemovedEvent:
|
||||
wm.State = domain.AppStateRemoved
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) appendAddSAMLEvent(e *project.SAMLConfigAddedEvent) {
|
||||
wm.saml = true
|
||||
wm.Metadata = e.Metadata
|
||||
wm.MetadataURL = e.MetadataURL
|
||||
wm.EntityID = e.EntityID
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) appendChangeSAMLEvent(e *project.SAMLConfigChangedEvent) {
|
||||
wm.saml = true
|
||||
if e.Metadata != nil {
|
||||
wm.Metadata = e.Metadata
|
||||
}
|
||||
if e.MetadataURL != nil {
|
||||
wm.MetadataURL = *e.MetadataURL
|
||||
}
|
||||
if e.EntityID != "" {
|
||||
wm.EntityID = e.EntityID
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
ResourceOwner(wm.ResourceOwner).
|
||||
AddQuery().
|
||||
AggregateTypes(project.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
project.ApplicationAddedType,
|
||||
project.ApplicationChangedType,
|
||||
project.ApplicationDeactivatedType,
|
||||
project.ApplicationReactivatedType,
|
||||
project.ApplicationRemovedType,
|
||||
project.SAMLConfigAddedType,
|
||||
project.SAMLConfigChangedType,
|
||||
project.ProjectRemovedType).
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) NewChangedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
appID string,
|
||||
entityID string,
|
||||
metadata []byte,
|
||||
metadataURL string,
|
||||
) (*project.SAMLConfigChangedEvent, bool, error) {
|
||||
changes := make([]project.SAMLConfigChanges, 0)
|
||||
var err error
|
||||
if !reflect.DeepEqual(wm.Metadata, metadata) {
|
||||
changes = append(changes, project.ChangeMetadata(metadata))
|
||||
}
|
||||
if wm.MetadataURL != metadataURL {
|
||||
changes = append(changes, project.ChangeMetadataURL(metadataURL))
|
||||
}
|
||||
if wm.EntityID != entityID {
|
||||
changes = append(changes, project.ChangeEntityID(entityID))
|
||||
}
|
||||
|
||||
if len(changes) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
changeEvent, err := project.NewSAMLConfigChangedEvent(ctx, aggregate, appID, wm.EntityID, changes)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
return changeEvent, true, nil
|
||||
}
|
||||
|
||||
func (wm *SAMLApplicationWriteModel) IsSAML() bool {
|
||||
return wm.saml
|
||||
}
|
||||
|
||||
type AppIDToEntityID struct {
|
||||
AppID string
|
||||
EntityID string
|
||||
}
|
||||
|
||||
type SAMLEntityIDsWriteModel struct {
|
||||
eventstore.WriteModel
|
||||
|
||||
EntityIDs []*AppIDToEntityID
|
||||
}
|
||||
|
||||
func NewSAMLEntityIDsWriteModel(projectID, resourceOwner string) *SAMLEntityIDsWriteModel {
|
||||
return &SAMLEntityIDsWriteModel{
|
||||
WriteModel: eventstore.WriteModel{
|
||||
AggregateID: projectID,
|
||||
ResourceOwner: resourceOwner,
|
||||
},
|
||||
EntityIDs: []*AppIDToEntityID{},
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLEntityIDsWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
ResourceOwner(wm.ResourceOwner).
|
||||
AddQuery().
|
||||
AggregateTypes(project.AggregateType).
|
||||
AggregateIDs(wm.AggregateID).
|
||||
EventTypes(
|
||||
project.ApplicationRemovedType,
|
||||
project.SAMLConfigAddedType,
|
||||
project.SAMLConfigChangedType).
|
||||
Builder()
|
||||
}
|
||||
|
||||
func (wm *SAMLEntityIDsWriteModel) AppendEvents(events ...eventstore.Event) {
|
||||
for _, event := range events {
|
||||
switch e := event.(type) {
|
||||
case *project.ApplicationRemovedEvent:
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.SAMLConfigAddedEvent:
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
case *project.SAMLConfigChangedEvent:
|
||||
if e.EntityID != "" {
|
||||
wm.WriteModel.AppendEvents(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (wm *SAMLEntityIDsWriteModel) Reduce() error {
|
||||
for _, event := range wm.Events {
|
||||
switch e := event.(type) {
|
||||
case *project.ApplicationRemovedEvent:
|
||||
removeAppIDFromEntityIDs(wm.EntityIDs, e.AppID)
|
||||
case *project.SAMLConfigAddedEvent:
|
||||
wm.EntityIDs = append(wm.EntityIDs, &AppIDToEntityID{AppID: e.AppID, EntityID: e.EntityID})
|
||||
case *project.SAMLConfigChangedEvent:
|
||||
for i := range wm.EntityIDs {
|
||||
item := wm.EntityIDs[i]
|
||||
if e.AppID == item.AppID && e.EntityID != "" {
|
||||
item.EntityID = e.EntityID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return wm.WriteModel.Reduce()
|
||||
}
|
||||
|
||||
func removeAppIDFromEntityIDs(items []*AppIDToEntityID, appID string) []*AppIDToEntityID {
|
||||
for i := len(items) - 1; i >= 0; i-- {
|
||||
if items[i].AppID == appID {
|
||||
items[i] = items[len(items)-1]
|
||||
items[len(items)-1] = nil
|
||||
items = items[:len(items)-1]
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
776
internal/command/project_application_saml_test.go
Normal file
776
internal/command/project_application_saml_test.go
Normal file
@@ -0,0 +1,776 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
id_mock "github.com/zitadel/zitadel/internal/id/mock"
|
||||
"github.com/zitadel/zitadel/internal/repository/project"
|
||||
)
|
||||
|
||||
var testMetadata = []byte(`<?xml version="1.0"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
|
||||
validUntil="2022-08-26T14:08:16Z"
|
||||
cacheDuration="PT604800S"
|
||||
entityID="https://test.com/saml/metadata">
|
||||
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="false" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
|
||||
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
||||
Location="https://test.com/saml/acs"
|
||||
index="1" />
|
||||
|
||||
</md:SPSSODescriptor>
|
||||
</md:EntityDescriptor>
|
||||
`)
|
||||
var testMetadataChangedEntityID = []byte(`<?xml version="1.0"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
|
||||
validUntil="2022-08-26T14:08:16Z"
|
||||
cacheDuration="PT604800S"
|
||||
entityID="https://test2.com/saml/metadata">
|
||||
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="false" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
|
||||
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
||||
Location="https://test.com/saml/acs"
|
||||
index="1" />
|
||||
|
||||
</md:SPSSODescriptor>
|
||||
</md:EntityDescriptor>
|
||||
`)
|
||||
|
||||
func TestCommandSide_AddSAMLApplication(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
httpClient *http.Client
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
samlApp *domain.SAMLApp
|
||||
resourceOwner string
|
||||
}
|
||||
type res struct {
|
||||
want *domain.SAMLApp
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no aggregate id, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project not existing, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid app, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create saml app, metadata not parsable",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: []byte("test metadata"),
|
||||
MetadataURL: "",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create saml app, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"",
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddApplicationUniqueConstraint("app", "project1")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: testMetadata,
|
||||
MetadataURL: "",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: testMetadata,
|
||||
MetadataURL: "",
|
||||
State: domain.AppStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create saml app metadataURL, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"http://localhost:8080/saml/metadata",
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddApplicationUniqueConstraint("app", "project1")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "app1"),
|
||||
httpClient: newTestClient(200, testMetadata),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: nil,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: testMetadata,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
State: domain.AppStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create saml app metadataURL, http error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified),
|
||||
),
|
||||
),
|
||||
),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t),
|
||||
httpClient: newTestClient(http.StatusNotFound, nil),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppName: "app",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: nil,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
httpClient: tt.fields.httpClient,
|
||||
}
|
||||
|
||||
got, err := r.AddSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommandSide_ChangeSAMLApplication(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
httpClient *http.Client
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
samlApp *domain.SAMLApp
|
||||
resourceOwner string
|
||||
}
|
||||
type res struct {
|
||||
want *domain.SAMLApp
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid app, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppID: "app1",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing appid, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppID: "",
|
||||
Metadata: []byte("just not empty"),
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing aggregateid, invalid argument error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "",
|
||||
},
|
||||
AppID: "appid",
|
||||
Metadata: []byte("just not empty"),
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "app not existing, not found error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
},
|
||||
AppID: "app1",
|
||||
Metadata: []byte("just not empty"),
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no changes, precondition error, metadataURL",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"http://localhost:8080/saml/metadata",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
httpClient: newTestClient(http.StatusOK, testMetadata),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppName: "app",
|
||||
AppID: "app1",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: nil,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no changes, precondition error, metadata",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"",
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
httpClient: nil,
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppName: "app",
|
||||
AppID: "app1",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
Metadata: testMetadata,
|
||||
MetadataURL: "",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: errors.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "change saml app, ok, metadataURL",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"http://localhost:8080/saml/metadata",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
newSAMLAppChangedEventMetadataURL(context.Background(),
|
||||
"app1",
|
||||
"project1",
|
||||
"org1",
|
||||
"https://test.com/saml/metadata",
|
||||
"https://test2.com/saml/metadata",
|
||||
testMetadataChangedEntityID,
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddSAMLConfigEntityIDUniqueConstraint("https://test2.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
httpClient: newTestClient(http.StatusOK, testMetadataChangedEntityID),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test2.com/saml/metadata",
|
||||
Metadata: nil,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test2.com/saml/metadata",
|
||||
Metadata: testMetadataChangedEntityID,
|
||||
MetadataURL: "http://localhost:8080/saml/metadata",
|
||||
State: domain.AppStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "change saml app, ok, metadata",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
testMetadata,
|
||||
"",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
newSAMLAppChangedEventMetadata(context.Background(),
|
||||
"app1",
|
||||
"project1",
|
||||
"org1",
|
||||
"https://test.com/saml/metadata",
|
||||
"https://test2.com/saml/metadata",
|
||||
testMetadataChangedEntityID,
|
||||
),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewAddSAMLConfigEntityIDUniqueConstraint("https://test2.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
httpClient: nil,
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
samlApp: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test2.com/saml/metadata",
|
||||
Metadata: testMetadataChangedEntityID,
|
||||
MetadataURL: "",
|
||||
},
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: "project1",
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
AppID: "app1",
|
||||
AppName: "app",
|
||||
EntityID: "https://test2.com/saml/metadata",
|
||||
Metadata: testMetadataChangedEntityID,
|
||||
MetadataURL: "",
|
||||
State: domain.AppStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
httpClient: tt.fields.httpClient,
|
||||
}
|
||||
got, err := r.ChangeSAMLApplication(tt.args.ctx, tt.args.samlApp, tt.args.resourceOwner)
|
||||
if tt.res.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.res.err != nil && !tt.res.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.res.err == nil {
|
||||
assert.Equal(t, tt.res.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newSAMLAppChangedEventMetadata(ctx context.Context, appID, projectID, resourceOwner, oldEntityID, entityID string, metadata []byte) *project.SAMLConfigChangedEvent {
|
||||
changes := []project.SAMLConfigChanges{
|
||||
project.ChangeEntityID(entityID),
|
||||
project.ChangeMetadata(metadata),
|
||||
}
|
||||
event, _ := project.NewSAMLConfigChangedEvent(ctx,
|
||||
&project.NewAggregate(projectID, resourceOwner).Aggregate,
|
||||
appID,
|
||||
oldEntityID,
|
||||
changes,
|
||||
)
|
||||
return event
|
||||
}
|
||||
|
||||
func newSAMLAppChangedEventMetadataURL(ctx context.Context, appID, projectID, resourceOwner, oldEntityID, entityID string, metadata []byte) *project.SAMLConfigChangedEvent {
|
||||
changes := []project.SAMLConfigChanges{
|
||||
project.ChangeEntityID(entityID),
|
||||
project.ChangeMetadata(metadata),
|
||||
}
|
||||
event, _ := project.NewSAMLConfigChangedEvent(ctx,
|
||||
&project.NewAggregate(projectID, resourceOwner).Aggregate,
|
||||
appID,
|
||||
oldEntityID,
|
||||
changes,
|
||||
)
|
||||
return event
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) *http.Response
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface.
|
||||
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return fn(req), nil
|
||||
}
|
||||
|
||||
// NewTestClient returns *http.Client with Transport replaced to avoid making real calls
|
||||
func newTestClient(httpStatus int, metadata []byte) *http.Client {
|
||||
fn := roundTripperFunc(func(req *http.Request) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: httpStatus,
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer(metadata)),
|
||||
Header: make(http.Header), //must be non-nil value
|
||||
}
|
||||
})
|
||||
return &http.Client{
|
||||
Transport: fn,
|
||||
}
|
||||
}
|
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@@ -580,6 +581,58 @@ func TestCommandSide_RemoveApplication(t *testing.T) {
|
||||
err: caos_errs.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "app remove, entityID, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
)),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
)),
|
||||
eventFromEventPusher(project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"",
|
||||
)),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(project.NewApplicationRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
"https://test.com/saml/metadata",
|
||||
)),
|
||||
}, /**/
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveApplicationUniqueConstraint("app", "project1")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
projectID: "project1",
|
||||
appID: "app1",
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "app remove, ok",
|
||||
fields: fields{
|
||||
@@ -592,12 +645,15 @@ func TestCommandSide_RemoveApplication(t *testing.T) {
|
||||
"app",
|
||||
)),
|
||||
),
|
||||
// app is not saml, or no saml config available
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(project.NewApplicationRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
"",
|
||||
)),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveApplicationUniqueConstraint("app", "project1")),
|
||||
|
@@ -57,6 +57,18 @@ func oidcWriteModelToOIDCConfig(writeModel *OIDCApplicationWriteModel) *domain.O
|
||||
}
|
||||
}
|
||||
|
||||
func samlWriteModelToSAMLConfig(writeModel *SAMLApplicationWriteModel) *domain.SAMLApp {
|
||||
return &domain.SAMLApp{
|
||||
ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel),
|
||||
AppID: writeModel.AppID,
|
||||
AppName: writeModel.AppName,
|
||||
State: writeModel.State,
|
||||
Metadata: writeModel.Metadata,
|
||||
MetadataURL: writeModel.MetadataURL,
|
||||
EntityID: writeModel.EntityID,
|
||||
}
|
||||
}
|
||||
|
||||
func apiWriteModelToAPIConfig(writeModel *APIApplicationWriteModel) *domain.APIApp {
|
||||
return &domain.APIApp{
|
||||
ObjectRoot: writeModelToObjectRoot(writeModel.WriteModel),
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@@ -50,6 +51,7 @@ func TestCommandSide_AddProjectRole(t *testing.T) {
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -253,6 +255,7 @@ func TestCommandSide_BulkAddProjectRole(t *testing.T) {
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -503,6 +506,7 @@ func TestCommandSide_ChangeProjectRole(t *testing.T) {
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
|
@@ -269,7 +269,8 @@ func TestCommandSide_ChangeProject(t *testing.T) {
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project"),
|
||||
"project",
|
||||
nil),
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -542,7 +543,8 @@ func TestCommandSide_DeactivateProject(t *testing.T) {
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project"),
|
||||
"project",
|
||||
nil),
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -721,7 +723,8 @@ func TestCommandSide_ReactivateProject(t *testing.T) {
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project"),
|
||||
"project",
|
||||
nil),
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -900,7 +903,8 @@ func TestCommandSide_RemoveProject(t *testing.T) {
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project"),
|
||||
"project",
|
||||
nil),
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -915,7 +919,7 @@ func TestCommandSide_RemoveProject(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project remove, ok",
|
||||
name: "project remove, without entityConstraints, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
@@ -927,12 +931,15 @@ func TestCommandSide_RemoveProject(t *testing.T) {
|
||||
domain.PrivateLabelingSettingAllowLoginUserResourceOwnerPolicy),
|
||||
),
|
||||
),
|
||||
// no saml application events
|
||||
expectFilter(),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project"),
|
||||
"project",
|
||||
nil),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveProjectNameUniqueConstraint("project", "org1")),
|
||||
@@ -950,6 +957,150 @@ func TestCommandSide_RemoveProject(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project remove, with entityConstraints, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingAllowLoginUserResourceOwnerPolicy),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"http://localhost:8080/saml/metadata",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project",
|
||||
[]*eventstore.EventUniqueConstraint{
|
||||
project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata"),
|
||||
}),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveProjectNameUniqueConstraint("project", "org1")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
projectID: "project1",
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project remove, with multiple entityConstraints, ok",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project", true, true, true,
|
||||
domain.PrivateLabelingSettingAllowLoginUserResourceOwnerPolicy),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"app",
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app1",
|
||||
"https://test1.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app2",
|
||||
"app",
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app2",
|
||||
"https://test2.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"",
|
||||
),
|
||||
),
|
||||
eventFromEventPusher(project.NewApplicationAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app3",
|
||||
"app",
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
project.NewSAMLConfigAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"app3",
|
||||
"https://test3.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"",
|
||||
),
|
||||
),
|
||||
),
|
||||
expectPush(
|
||||
[]*repository.Event{
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"project",
|
||||
[]*eventstore.EventUniqueConstraint{
|
||||
project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test1.com/saml/metadata"),
|
||||
project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test2.com/saml/metadata"),
|
||||
project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test3.com/saml/metadata"),
|
||||
}),
|
||||
),
|
||||
},
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveProjectNameUniqueConstraint("project", "org1")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test1.com/saml/metadata")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test2.com/saml/metadata")),
|
||||
uniqueConstraintsFromEventConstraint(project.NewRemoveSAMLConfigEntityIDUniqueConstraint("https://test3.com/saml/metadata")),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
projectID: "project1",
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
want: &domain.ObjectDetails{
|
||||
ResourceOwner: "org1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/repository/idpconfig"
|
||||
@@ -92,6 +93,10 @@ func (rm *UniqueConstraintReadModel) Reduce() error {
|
||||
rm.addUniqueConstraint(e.Aggregate().ID, e.AppID, project.NewAddApplicationUniqueConstraint(e.Name, e.Aggregate().ID))
|
||||
case *project.ApplicationChangedEvent:
|
||||
rm.changeUniqueConstraint(e.Aggregate().ID, e.AppID, project.NewAddApplicationUniqueConstraint(e.Name, e.Aggregate().ID))
|
||||
case *project.SAMLConfigAddedEvent:
|
||||
rm.addUniqueConstraint(e.Aggregate().ID, e.AppID, project.NewAddSAMLConfigEntityIDUniqueConstraint(e.EntityID))
|
||||
case *project.SAMLConfigChangedEvent:
|
||||
rm.addUniqueConstraint(e.Aggregate().ID, e.AppID, project.NewRemoveSAMLConfigEntityIDUniqueConstraint(e.EntityID))
|
||||
case *project.ApplicationRemovedEvent:
|
||||
rm.removeUniqueConstraint(e.Aggregate().ID, e.AppID, project.UniqueAppNameType)
|
||||
case *project.GrantAddedEvent:
|
||||
|
@@ -133,6 +133,7 @@ func TestCommandSide_AddUserGrant(t *testing.T) {
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
@@ -819,6 +820,7 @@ func TestCommandSide_ChangeUserGrant(t *testing.T) {
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
|
@@ -37,7 +37,9 @@ type Notifications struct {
|
||||
}
|
||||
|
||||
type KeyConfig struct {
|
||||
Size int
|
||||
PrivateKeyLifetime time.Duration
|
||||
PublicKeyLifetime time.Duration
|
||||
Size int
|
||||
PrivateKeyLifetime time.Duration
|
||||
PublicKeyLifetime time.Duration
|
||||
CertificateSize int
|
||||
CertificateLifetime time.Duration
|
||||
}
|
||||
|
@@ -1,11 +1,16 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
)
|
||||
|
||||
func GenerateKeyPair(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
|
||||
@@ -24,6 +29,104 @@ func GenerateEncryptedKeyPair(bits int, alg EncryptionAlgorithm) (*CryptoValue,
|
||||
return EncryptKeys(privateKey, publicKey, alg)
|
||||
}
|
||||
|
||||
type CertificateInformations struct {
|
||||
SerialNumber *big.Int
|
||||
Organisation []string
|
||||
CommonName string
|
||||
NotBefore time.Time
|
||||
NotAfter time.Time
|
||||
KeyUsage x509.KeyUsage
|
||||
ExtKeyUsage []x509.ExtKeyUsage
|
||||
}
|
||||
|
||||
func GenerateEncryptedKeyPairWithCACertificate(bits int, keyAlg, certAlg EncryptionAlgorithm, informations *CertificateInformations) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
privateKey, publicKey, cert, err := GenerateCACertificate(bits, informations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptPriv, encryptPub, encryptCaCert, err := EncryptKeysAndCert(privateKey, publicKey, cert, keyAlg, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptPriv, encryptPub, encryptCaCert, nil
|
||||
}
|
||||
|
||||
func GenerateEncryptedKeyPairWithCertificate(bits int, keyAlg, certAlg EncryptionAlgorithm, caPrivateKey *rsa.PrivateKey, caCertificate []byte, informations *CertificateInformations) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
privateKey, publicKey, cert, err := GenerateCertificate(bits, caPrivateKey, caCertificate, informations)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptPriv, encryptPub, encryptCaCert, err := EncryptKeysAndCert(privateKey, publicKey, cert, keyAlg, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptPriv, encryptPub, encryptCaCert, nil
|
||||
}
|
||||
|
||||
func GenerateCACertificate(bits int, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
return generateCertificate(bits, nil, nil, informations)
|
||||
}
|
||||
|
||||
func GenerateCertificate(bits int, caPrivateKey *rsa.PrivateKey, ca []byte, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
return generateCertificate(bits, caPrivateKey, ca, informations)
|
||||
}
|
||||
|
||||
func generateCertificate(bits int, caPrivateKey *rsa.PrivateKey, ca []byte, informations *CertificateInformations) (*rsa.PrivateKey, *rsa.PublicKey, []byte, error) {
|
||||
notBefore := time.Now()
|
||||
if !informations.NotBefore.IsZero() {
|
||||
notBefore = informations.NotBefore
|
||||
}
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: informations.SerialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: informations.CommonName,
|
||||
Organization: informations.Organisation,
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: informations.NotAfter,
|
||||
KeyUsage: informations.KeyUsage,
|
||||
ExtKeyUsage: informations.ExtKeyUsage,
|
||||
}
|
||||
|
||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, bits)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certBytes := make([]byte, 0)
|
||||
if ca == nil {
|
||||
cert.IsCA = true
|
||||
cert.BasicConstraintsValid = true
|
||||
|
||||
certBytes, err = x509.CreateCertificate(rand.Reader, cert, cert, &certPrivKey.PublicKey, certPrivKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
} else {
|
||||
caCert, err := x509.ParseCertificate(ca)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certBytes, err = x509.CreateCertificate(rand.Reader, cert, caCert, &certPrivKey.PublicKey, caPrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
x509Cert, err := x509.ParseCertificate(certBytes)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
certPem, err := CertificateToBytes(x509Cert)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
return certPrivKey, &certPrivKey.PublicKey, certPem, nil
|
||||
}
|
||||
|
||||
func PrivateKeyToBytes(priv *rsa.PrivateKey) []byte {
|
||||
return pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
@@ -101,3 +204,34 @@ func EncryptKeys(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, alg Encry
|
||||
}
|
||||
return encryptedPrivateKey, encryptedPublicKey, nil
|
||||
}
|
||||
|
||||
func CertificateToBytes(cert *x509.Certificate) ([]byte, error) {
|
||||
certPem := new(bytes.Buffer)
|
||||
if err := pem.Encode(certPem, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return certPem.Bytes(), nil
|
||||
}
|
||||
|
||||
func BytesToCertificate(data []byte) ([]byte, error) {
|
||||
block, _ := pem.Decode(data)
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return nil, fmt.Errorf("failed to decode PEM block containing public key")
|
||||
}
|
||||
return block.Bytes, nil
|
||||
}
|
||||
|
||||
func EncryptKeysAndCert(privateKey *rsa.PrivateKey, publicKey *rsa.PublicKey, cert []byte, keyAlg, certAlg EncryptionAlgorithm) (*CryptoValue, *CryptoValue, *CryptoValue, error) {
|
||||
encryptedPrivateKey, encryptedPublicKey, err := EncryptKeys(privateKey, publicKey, keyAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
encryptedCertificate, err := Encrypt(cert, certAlg)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
return encryptedPrivateKey, encryptedPublicKey, encryptedCertificate, nil
|
||||
}
|
||||
|
40
internal/domain/application_saml.go
Normal file
40
internal/domain/application_saml.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
type SAMLApp struct {
|
||||
models.ObjectRoot
|
||||
|
||||
AppID string
|
||||
AppName string
|
||||
EntityID string
|
||||
Metadata []byte
|
||||
MetadataURL string
|
||||
|
||||
State AppState
|
||||
}
|
||||
|
||||
func (a *SAMLApp) GetApplicationName() string {
|
||||
return a.AppName
|
||||
}
|
||||
|
||||
func (a *SAMLApp) GetState() AppState {
|
||||
return a.State
|
||||
}
|
||||
|
||||
func (a *SAMLApp) GetMetadata() []byte {
|
||||
return a.Metadata
|
||||
}
|
||||
|
||||
func (a *SAMLApp) GetMetadataURL() string {
|
||||
return a.MetadataURL
|
||||
}
|
||||
|
||||
func (a *SAMLApp) IsValid() bool {
|
||||
if a.MetadataURL == "" && a.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
@@ -119,6 +119,8 @@ func NewAuthRequestFromType(requestType AuthRequestType) (*AuthRequest, error) {
|
||||
switch requestType {
|
||||
case AuthRequestTypeOIDC:
|
||||
return &AuthRequest{Request: &AuthRequestOIDC{}}, nil
|
||||
case AuthRequestTypeSAML:
|
||||
return &AuthRequest{Request: &AuthRequestSAML{}}, nil
|
||||
}
|
||||
return nil, errors.ThrowInvalidArgument(nil, "DOMAIN-ds2kl", "invalid request type")
|
||||
}
|
||||
|
@@ -10,22 +10,32 @@ import (
|
||||
type KeyPair struct {
|
||||
es_models.ObjectRoot
|
||||
|
||||
Usage KeyUsage
|
||||
Algorithm string
|
||||
PrivateKey *Key
|
||||
PublicKey *Key
|
||||
Usage KeyUsage
|
||||
Algorithm string
|
||||
PrivateKey *Key
|
||||
PublicKey *Key
|
||||
Certificate *Key
|
||||
}
|
||||
|
||||
type KeyUsage int32
|
||||
|
||||
const (
|
||||
KeyUsageSigning KeyUsage = iota
|
||||
KeyUsageSAMLMetadataSigning
|
||||
KeyUsageSAMLResponseSinging
|
||||
KeyUsageSAMLCA
|
||||
)
|
||||
|
||||
func (u KeyUsage) String() string {
|
||||
switch u {
|
||||
case KeyUsageSigning:
|
||||
return "sig"
|
||||
case KeyUsageSAMLCA:
|
||||
return "saml_ca"
|
||||
case KeyUsageSAMLResponseSinging:
|
||||
return "saml_response_sig"
|
||||
case KeyUsageSAMLMetadataSigning:
|
||||
return "saml_metadata_sig"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -38,7 +48,8 @@ type Key struct {
|
||||
func (k *KeyPair) IsValid() bool {
|
||||
return k.Algorithm != "" &&
|
||||
k.PrivateKey != nil && k.PrivateKey.IsValid() &&
|
||||
k.PublicKey != nil && k.PublicKey.IsValid()
|
||||
k.PublicKey != nil && k.PublicKey.IsValid() &&
|
||||
k.Certificate != nil && k.Certificate.IsValid()
|
||||
}
|
||||
|
||||
func (k *Key) IsValid() bool {
|
||||
|
@@ -39,6 +39,13 @@ func (a *AuthRequestOIDC) IsValid() bool {
|
||||
}
|
||||
|
||||
type AuthRequestSAML struct {
|
||||
ID string
|
||||
RequestID string
|
||||
BindingType string
|
||||
Code string
|
||||
Issuer string
|
||||
IssuerName string
|
||||
Destination string
|
||||
}
|
||||
|
||||
func (a *AuthRequestSAML) Type() AuthRequestType {
|
||||
|
@@ -13,6 +13,7 @@ type Application struct {
|
||||
Type AppType
|
||||
OIDCConfig *OIDCConfig
|
||||
APIConfig *APIConfig
|
||||
SAMLConfig *SAMLConfig
|
||||
}
|
||||
|
||||
type AppState int32
|
||||
@@ -45,5 +46,8 @@ func (a *Application) IsValid(includeConfig bool) bool {
|
||||
if a.Type == AppTypeAPI && !a.APIConfig.IsValid() {
|
||||
return false
|
||||
}
|
||||
if a.Type == AppTypeSAML && !a.SAMLConfig.IsValid() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
16
internal/project/model/saml_config.go
Normal file
16
internal/project/model/saml_config.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
type SAMLConfig struct {
|
||||
es_models.ObjectRoot
|
||||
AppID string
|
||||
Metadata []byte
|
||||
MetadataURL string
|
||||
}
|
||||
|
||||
func (c *SAMLConfig) IsValid() bool {
|
||||
return !(c.Metadata == nil && c.MetadataURL == "")
|
||||
}
|
@@ -16,6 +16,7 @@ type Application struct {
|
||||
Type int32 `json:"appType,omitempty"`
|
||||
OIDCConfig *OIDCConfig `json:"-"`
|
||||
APIConfig *APIConfig `json:"-"`
|
||||
SAMLConfig *SAMLConfig `json:"-"`
|
||||
}
|
||||
|
||||
type ApplicationID struct {
|
||||
|
@@ -0,0 +1,25 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
type SAMLConfig struct {
|
||||
es_models.ObjectRoot
|
||||
AppID string `json:"appId"`
|
||||
Metadata []byte `json:"metadata,omitempty"`
|
||||
MetadataURL string `json:"metadataUrl,omitempty"`
|
||||
}
|
||||
|
||||
func (o *SAMLConfig) setData(event *es_models.Event) error {
|
||||
o.ObjectRoot.AppendEvent(event)
|
||||
if err := json.Unmarshal(event.Data, o); err != nil {
|
||||
logging.Log("EVEN-d8e3s").WithError(err).Error("could not unmarshal event data")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -57,6 +57,10 @@ type ApplicationView struct {
|
||||
IDTokenUserinfoAssertion bool `json:"idTokenUserinfoAssertion" gorm:"column:id_token_userinfo_assertion"`
|
||||
ClockSkew time.Duration `json:"clockSkew" gorm:"column:clock_skew"`
|
||||
|
||||
IsSAML bool `json:"-" gorm:"column:is_saml"`
|
||||
Metadata []byte `json:"metadata" gorm:"column:metadata"`
|
||||
MetadataURL string `json:"metadata_url" gorm:"column:metadata_url"`
|
||||
|
||||
Sequence uint64 `json:"-" gorm:"sequence"`
|
||||
}
|
||||
|
||||
@@ -90,7 +94,9 @@ func (a *ApplicationView) AppendEventIfMyApp(event *models.Event) (err error) {
|
||||
project.APIConfigAddedType,
|
||||
project.APIConfigChangedType,
|
||||
project.ApplicationDeactivatedType,
|
||||
project.ApplicationReactivatedType:
|
||||
project.ApplicationReactivatedType,
|
||||
project.SAMLConfigAddedType,
|
||||
project.SAMLConfigChangedType:
|
||||
err = view.SetData(event)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -130,6 +136,9 @@ func (a *ApplicationView) AppendEvent(event *models.Event) (err error) {
|
||||
}
|
||||
a.setCompliance()
|
||||
return a.setOriginAllowList()
|
||||
case project.SAMLConfigAddedType:
|
||||
a.IsSAML = true
|
||||
return a.SetData(event)
|
||||
case project.APIConfigAddedType:
|
||||
a.IsOIDC = false
|
||||
return a.SetData(event)
|
||||
@@ -142,6 +151,8 @@ func (a *ApplicationView) AppendEvent(event *models.Event) (err error) {
|
||||
}
|
||||
a.setCompliance()
|
||||
return a.setOriginAllowList()
|
||||
case project.SAMLConfigChangedType:
|
||||
return a.SetData(event)
|
||||
case project.APIConfigChangedType:
|
||||
return a.SetData(event)
|
||||
case project.ProjectChangedType:
|
||||
|
@@ -33,6 +33,7 @@ type App struct {
|
||||
Name string
|
||||
|
||||
OIDCConfig *OIDCApp
|
||||
SAMLConfig *SAMLApp
|
||||
APIConfig *APIApp
|
||||
}
|
||||
|
||||
@@ -56,6 +57,12 @@ type OIDCApp struct {
|
||||
AllowedOrigins database.StringArray
|
||||
}
|
||||
|
||||
type SAMLApp struct {
|
||||
Metadata []byte
|
||||
MetadataURL string
|
||||
EntityID string
|
||||
}
|
||||
|
||||
type APIApp struct {
|
||||
ClientID string
|
||||
AuthMethodType domain.APIAuthMethodType
|
||||
@@ -116,6 +123,28 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
appSAMLConfigsTable = table{
|
||||
name: projection.AppSAMLTable,
|
||||
}
|
||||
AppSAMLConfigColumnAppID = Column{
|
||||
name: projection.AppSAMLConfigColumnAppID,
|
||||
table: appSAMLConfigsTable,
|
||||
}
|
||||
AppSAMLConfigColumnEntityID = Column{
|
||||
name: projection.AppSAMLConfigColumnEntityID,
|
||||
table: appSAMLConfigsTable,
|
||||
}
|
||||
AppSAMLConfigColumnMetadata = Column{
|
||||
name: projection.AppSAMLConfigColumnMetadata,
|
||||
table: appSAMLConfigsTable,
|
||||
}
|
||||
AppSAMLConfigColumnMetadataURL = Column{
|
||||
name: projection.AppSAMLConfigColumnMetadataURL,
|
||||
table: appSAMLConfigsTable,
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
appAPIConfigsTable = table{
|
||||
name: projection.AppAPITable,
|
||||
@@ -225,6 +254,54 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo
|
||||
return scan(row)
|
||||
}
|
||||
|
||||
func (q *Queries) AppByID(ctx context.Context, appID string) (*App, error) {
|
||||
stmt, scan := prepareAppQuery()
|
||||
query, args, err := stmt.Where(
|
||||
sq.Eq{
|
||||
AppColumnID.identifier(): appID,
|
||||
AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
},
|
||||
).ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement")
|
||||
}
|
||||
|
||||
row := q.client.QueryRowContext(ctx, query, args...)
|
||||
return scan(row)
|
||||
}
|
||||
|
||||
func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string) (*App, error) {
|
||||
stmt, scan := prepareAppQuery()
|
||||
query, args, err := stmt.Where(
|
||||
sq.Eq{
|
||||
AppSAMLConfigColumnEntityID.identifier(): entityID,
|
||||
},
|
||||
).ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement")
|
||||
}
|
||||
|
||||
row := q.client.QueryRowContext(ctx, query, args...)
|
||||
return scan(row)
|
||||
}
|
||||
|
||||
func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (*Project, error) {
|
||||
stmt, scan := prepareProjectByAppQuery()
|
||||
query, args, err := stmt.Where(
|
||||
sq.Or{
|
||||
sq.Eq{AppOIDCConfigColumnClientID.identifier(): appID},
|
||||
sq.Eq{AppAPIConfigColumnClientID.identifier(): appID},
|
||||
sq.Eq{AppSAMLConfigColumnAppID.identifier(): appID},
|
||||
},
|
||||
).ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-XhJi3", "Errors.Query.SQLStatement")
|
||||
}
|
||||
|
||||
row := q.client.QueryRowContext(ctx, query, args...)
|
||||
return scan(row)
|
||||
}
|
||||
|
||||
func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string) (string, error) {
|
||||
stmt, scan := prepareProjectIDByAppQuery()
|
||||
query, args, err := stmt.Where(
|
||||
@@ -249,6 +326,7 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string) (stri
|
||||
sq.Or{
|
||||
sq.Eq{AppOIDCConfigColumnClientID.identifier(): appID},
|
||||
sq.Eq{AppAPIConfigColumnClientID.identifier(): appID},
|
||||
sq.Eq{AppSAMLConfigColumnAppID.identifier(): appID},
|
||||
},
|
||||
},
|
||||
).ToSql()
|
||||
@@ -389,15 +467,22 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
|
||||
AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(),
|
||||
AppOIDCConfigColumnClockSkew.identifier(),
|
||||
AppOIDCConfigColumnAdditionalOrigins.identifier(),
|
||||
|
||||
AppSAMLConfigColumnAppID.identifier(),
|
||||
AppSAMLConfigColumnEntityID.identifier(),
|
||||
AppSAMLConfigColumnMetadata.identifier(),
|
||||
AppSAMLConfigColumnMetadataURL.identifier(),
|
||||
).From(appsTable.identifier()).
|
||||
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
|
||||
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*App, error) {
|
||||
app := new(App)
|
||||
|
||||
var (
|
||||
apiConfig = sqlAPIConfig{}
|
||||
oidcConfig = sqlOIDCConfig{}
|
||||
samlConfig = sqlSAMLConfig{}
|
||||
)
|
||||
|
||||
err := row.Scan(
|
||||
@@ -430,6 +515,11 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
|
||||
&oidcConfig.iDTokenUserinfoAssertion,
|
||||
&oidcConfig.clockSkew,
|
||||
&oidcConfig.additionalOrigins,
|
||||
|
||||
&samlConfig.appID,
|
||||
&samlConfig.entityID,
|
||||
&samlConfig.metadata,
|
||||
&samlConfig.metadataURL,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
@@ -441,6 +531,7 @@ func prepareAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) {
|
||||
|
||||
apiConfig.set(app)
|
||||
oidcConfig.set(app)
|
||||
samlConfig.set(app)
|
||||
|
||||
return app, nil
|
||||
}
|
||||
@@ -452,6 +543,7 @@ func prepareProjectIDByAppQuery() (sq.SelectBuilder, func(*sql.Row) (projectID s
|
||||
).From(appsTable.identifier()).
|
||||
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
|
||||
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (projectID string, err error) {
|
||||
err = row.Scan(
|
||||
&projectID,
|
||||
@@ -485,6 +577,7 @@ func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, err
|
||||
Join(join(AppColumnProjectID, ProjectColumnID)).
|
||||
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
|
||||
PlaceholderFormat(sq.Dollar),
|
||||
func(row *sql.Row) (*Project, error) {
|
||||
p := new(Project)
|
||||
@@ -542,10 +635,16 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
|
||||
AppOIDCConfigColumnIDTokenUserinfoAssertion.identifier(),
|
||||
AppOIDCConfigColumnClockSkew.identifier(),
|
||||
AppOIDCConfigColumnAdditionalOrigins.identifier(),
|
||||
|
||||
AppSAMLConfigColumnAppID.identifier(),
|
||||
AppSAMLConfigColumnEntityID.identifier(),
|
||||
AppSAMLConfigColumnMetadata.identifier(),
|
||||
AppSAMLConfigColumnMetadataURL.identifier(),
|
||||
countColumn.identifier(),
|
||||
).From(appsTable.identifier()).
|
||||
LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)).
|
||||
LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)).
|
||||
PlaceholderFormat(sq.Dollar), func(row *sql.Rows) (*Apps, error) {
|
||||
apps := &Apps{Apps: []*App{}}
|
||||
|
||||
@@ -554,6 +653,7 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
|
||||
var (
|
||||
apiConfig = sqlAPIConfig{}
|
||||
oidcConfig = sqlOIDCConfig{}
|
||||
samlConfig = sqlSAMLConfig{}
|
||||
)
|
||||
|
||||
err := row.Scan(
|
||||
@@ -586,6 +686,12 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
|
||||
&oidcConfig.iDTokenUserinfoAssertion,
|
||||
&oidcConfig.clockSkew,
|
||||
&oidcConfig.additionalOrigins,
|
||||
|
||||
&samlConfig.appID,
|
||||
&samlConfig.entityID,
|
||||
&samlConfig.metadata,
|
||||
&samlConfig.metadataURL,
|
||||
|
||||
&apps.Count,
|
||||
)
|
||||
|
||||
@@ -595,6 +701,7 @@ func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) {
|
||||
|
||||
apiConfig.set(app)
|
||||
oidcConfig.set(app)
|
||||
samlConfig.set(app)
|
||||
|
||||
apps.Apps = append(apps.Apps, app)
|
||||
}
|
||||
@@ -681,6 +788,24 @@ func (c sqlOIDCConfig) set(app *App) {
|
||||
logging.LogWithFields("app", app.ID).OnError(err).Warn("unable to set allowed origins")
|
||||
}
|
||||
|
||||
type sqlSAMLConfig struct {
|
||||
appID sql.NullString
|
||||
entityID sql.NullString
|
||||
metadataURL sql.NullString
|
||||
metadata []byte
|
||||
}
|
||||
|
||||
func (c sqlSAMLConfig) set(app *App) {
|
||||
if !c.appID.Valid {
|
||||
return
|
||||
}
|
||||
app.SAMLConfig = &SAMLApp{
|
||||
MetadataURL: c.metadataURL.String,
|
||||
Metadata: c.metadata,
|
||||
EntityID: c.entityID.String,
|
||||
}
|
||||
}
|
||||
|
||||
type sqlAPIConfig struct {
|
||||
appID sql.NullString
|
||||
clientID sql.NullString
|
||||
|
@@ -15,80 +15,93 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
expectedAppQuery = regexp.QuoteMeta(`SELECT projections.apps2.id,` +
|
||||
` projections.apps2.name,` +
|
||||
` projections.apps2.project_id,` +
|
||||
` projections.apps2.creation_date,` +
|
||||
` projections.apps2.change_date,` +
|
||||
` projections.apps2.resource_owner,` +
|
||||
` projections.apps2.state,` +
|
||||
` projections.apps2.sequence,` +
|
||||
expectedAppQuery = regexp.QuoteMeta(`SELECT projections.apps3.id,` +
|
||||
` projections.apps3.name,` +
|
||||
` projections.apps3.project_id,` +
|
||||
` projections.apps3.creation_date,` +
|
||||
` projections.apps3.change_date,` +
|
||||
` projections.apps3.resource_owner,` +
|
||||
` projections.apps3.state,` +
|
||||
` projections.apps3.sequence,` +
|
||||
// api config
|
||||
` projections.apps2_api_configs.app_id,` +
|
||||
` projections.apps2_api_configs.client_id,` +
|
||||
` projections.apps2_api_configs.auth_method,` +
|
||||
` projections.apps3_api_configs.app_id,` +
|
||||
` projections.apps3_api_configs.client_id,` +
|
||||
` projections.apps3_api_configs.auth_method,` +
|
||||
// oidc config
|
||||
` projections.apps2_oidc_configs.app_id,` +
|
||||
` projections.apps2_oidc_configs.version,` +
|
||||
` projections.apps2_oidc_configs.client_id,` +
|
||||
` projections.apps2_oidc_configs.redirect_uris,` +
|
||||
` projections.apps2_oidc_configs.response_types,` +
|
||||
` projections.apps2_oidc_configs.grant_types,` +
|
||||
` projections.apps2_oidc_configs.application_type,` +
|
||||
` projections.apps2_oidc_configs.auth_method_type,` +
|
||||
` projections.apps2_oidc_configs.post_logout_redirect_uris,` +
|
||||
` projections.apps2_oidc_configs.is_dev_mode,` +
|
||||
` projections.apps2_oidc_configs.access_token_type,` +
|
||||
` projections.apps2_oidc_configs.access_token_role_assertion,` +
|
||||
` projections.apps2_oidc_configs.id_token_role_assertion,` +
|
||||
` projections.apps2_oidc_configs.id_token_userinfo_assertion,` +
|
||||
` projections.apps2_oidc_configs.clock_skew,` +
|
||||
` projections.apps2_oidc_configs.additional_origins` +
|
||||
` FROM projections.apps2` +
|
||||
` LEFT JOIN projections.apps2_api_configs ON projections.apps2.id = projections.apps2_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps2_oidc_configs ON projections.apps2.id = projections.apps2_oidc_configs.app_id`)
|
||||
expectedAppsQuery = regexp.QuoteMeta(`SELECT projections.apps2.id,` +
|
||||
` projections.apps2.name,` +
|
||||
` projections.apps2.project_id,` +
|
||||
` projections.apps2.creation_date,` +
|
||||
` projections.apps2.change_date,` +
|
||||
` projections.apps2.resource_owner,` +
|
||||
` projections.apps2.state,` +
|
||||
` projections.apps2.sequence,` +
|
||||
` projections.apps3_oidc_configs.app_id,` +
|
||||
` projections.apps3_oidc_configs.version,` +
|
||||
` projections.apps3_oidc_configs.client_id,` +
|
||||
` projections.apps3_oidc_configs.redirect_uris,` +
|
||||
` projections.apps3_oidc_configs.response_types,` +
|
||||
` projections.apps3_oidc_configs.grant_types,` +
|
||||
` projections.apps3_oidc_configs.application_type,` +
|
||||
` projections.apps3_oidc_configs.auth_method_type,` +
|
||||
` projections.apps3_oidc_configs.post_logout_redirect_uris,` +
|
||||
` projections.apps3_oidc_configs.is_dev_mode,` +
|
||||
` projections.apps3_oidc_configs.access_token_type,` +
|
||||
` projections.apps3_oidc_configs.access_token_role_assertion,` +
|
||||
` projections.apps3_oidc_configs.id_token_role_assertion,` +
|
||||
` projections.apps3_oidc_configs.id_token_userinfo_assertion,` +
|
||||
` projections.apps3_oidc_configs.clock_skew,` +
|
||||
` projections.apps3_oidc_configs.additional_origins,` +
|
||||
//saml config
|
||||
` projections.apps3_saml_configs.app_id,` +
|
||||
` projections.apps3_saml_configs.entity_id,` +
|
||||
` projections.apps3_saml_configs.metadata,` +
|
||||
` projections.apps3_saml_configs.metadata_url` +
|
||||
` FROM projections.apps3` +
|
||||
` LEFT JOIN projections.apps3_api_configs ON projections.apps3.id = projections.apps3_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_oidc_configs ON projections.apps3.id = projections.apps3_oidc_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_saml_configs ON projections.apps3.id = projections.apps3_saml_configs.app_id`)
|
||||
expectedAppsQuery = regexp.QuoteMeta(`SELECT projections.apps3.id,` +
|
||||
` projections.apps3.name,` +
|
||||
` projections.apps3.project_id,` +
|
||||
` projections.apps3.creation_date,` +
|
||||
` projections.apps3.change_date,` +
|
||||
` projections.apps3.resource_owner,` +
|
||||
` projections.apps3.state,` +
|
||||
` projections.apps3.sequence,` +
|
||||
// api config
|
||||
` projections.apps2_api_configs.app_id,` +
|
||||
` projections.apps2_api_configs.client_id,` +
|
||||
` projections.apps2_api_configs.auth_method,` +
|
||||
` projections.apps3_api_configs.app_id,` +
|
||||
` projections.apps3_api_configs.client_id,` +
|
||||
` projections.apps3_api_configs.auth_method,` +
|
||||
// oidc config
|
||||
` projections.apps2_oidc_configs.app_id,` +
|
||||
` projections.apps2_oidc_configs.version,` +
|
||||
` projections.apps2_oidc_configs.client_id,` +
|
||||
` projections.apps2_oidc_configs.redirect_uris,` +
|
||||
` projections.apps2_oidc_configs.response_types,` +
|
||||
` projections.apps2_oidc_configs.grant_types,` +
|
||||
` projections.apps2_oidc_configs.application_type,` +
|
||||
` projections.apps2_oidc_configs.auth_method_type,` +
|
||||
` projections.apps2_oidc_configs.post_logout_redirect_uris,` +
|
||||
` projections.apps2_oidc_configs.is_dev_mode,` +
|
||||
` projections.apps2_oidc_configs.access_token_type,` +
|
||||
` projections.apps2_oidc_configs.access_token_role_assertion,` +
|
||||
` projections.apps2_oidc_configs.id_token_role_assertion,` +
|
||||
` projections.apps2_oidc_configs.id_token_userinfo_assertion,` +
|
||||
` projections.apps2_oidc_configs.clock_skew,` +
|
||||
` projections.apps2_oidc_configs.additional_origins,` +
|
||||
` projections.apps3_oidc_configs.app_id,` +
|
||||
` projections.apps3_oidc_configs.version,` +
|
||||
` projections.apps3_oidc_configs.client_id,` +
|
||||
` projections.apps3_oidc_configs.redirect_uris,` +
|
||||
` projections.apps3_oidc_configs.response_types,` +
|
||||
` projections.apps3_oidc_configs.grant_types,` +
|
||||
` projections.apps3_oidc_configs.application_type,` +
|
||||
` projections.apps3_oidc_configs.auth_method_type,` +
|
||||
` projections.apps3_oidc_configs.post_logout_redirect_uris,` +
|
||||
` projections.apps3_oidc_configs.is_dev_mode,` +
|
||||
` projections.apps3_oidc_configs.access_token_type,` +
|
||||
` projections.apps3_oidc_configs.access_token_role_assertion,` +
|
||||
` projections.apps3_oidc_configs.id_token_role_assertion,` +
|
||||
` projections.apps3_oidc_configs.id_token_userinfo_assertion,` +
|
||||
` projections.apps3_oidc_configs.clock_skew,` +
|
||||
` projections.apps3_oidc_configs.additional_origins,` +
|
||||
//saml config
|
||||
` projections.apps3_saml_configs.app_id,` +
|
||||
` projections.apps3_saml_configs.entity_id,` +
|
||||
` projections.apps3_saml_configs.metadata,` +
|
||||
` projections.apps3_saml_configs.metadata_url,` +
|
||||
` COUNT(*) OVER ()` +
|
||||
` FROM projections.apps2` +
|
||||
` LEFT JOIN projections.apps2_api_configs ON projections.apps2.id = projections.apps2_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps2_oidc_configs ON projections.apps2.id = projections.apps2_oidc_configs.app_id`)
|
||||
expectedAppIDsQuery = regexp.QuoteMeta(`SELECT projections.apps2_api_configs.client_id,` +
|
||||
` projections.apps2_oidc_configs.client_id` +
|
||||
` FROM projections.apps2` +
|
||||
` LEFT JOIN projections.apps2_api_configs ON projections.apps2.id = projections.apps2_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps2_oidc_configs ON projections.apps2.id = projections.apps2_oidc_configs.app_id`)
|
||||
expectedProjectIDByAppQuery = regexp.QuoteMeta(`SELECT projections.apps2.project_id` +
|
||||
` FROM projections.apps2` +
|
||||
` LEFT JOIN projections.apps2_api_configs ON projections.apps2.id = projections.apps2_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps2_oidc_configs ON projections.apps2.id = projections.apps2_oidc_configs.app_id`)
|
||||
` FROM projections.apps3` +
|
||||
` LEFT JOIN projections.apps3_api_configs ON projections.apps3.id = projections.apps3_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_oidc_configs ON projections.apps3.id = projections.apps3_oidc_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_saml_configs ON projections.apps3.id = projections.apps3_saml_configs.app_id`)
|
||||
expectedAppIDsQuery = regexp.QuoteMeta(`SELECT projections.apps3_api_configs.client_id,` +
|
||||
` projections.apps3_oidc_configs.client_id` +
|
||||
` FROM projections.apps3` +
|
||||
` LEFT JOIN projections.apps3_api_configs ON projections.apps3.id = projections.apps3_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_oidc_configs ON projections.apps3.id = projections.apps3_oidc_configs.app_id`)
|
||||
expectedProjectIDByAppQuery = regexp.QuoteMeta(`SELECT projections.apps3.project_id` +
|
||||
` FROM projections.apps3` +
|
||||
` LEFT JOIN projections.apps3_api_configs ON projections.apps3.id = projections.apps3_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_oidc_configs ON projections.apps3.id = projections.apps3_oidc_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_saml_configs ON projections.apps3.id = projections.apps3_saml_configs.app_id`)
|
||||
expectedProjectByAppQuery = regexp.QuoteMeta(`SELECT projections.projects2.id,` +
|
||||
` projections.projects2.creation_date,` +
|
||||
` projections.projects2.change_date,` +
|
||||
@@ -101,9 +114,10 @@ var (
|
||||
` projections.projects2.has_project_check,` +
|
||||
` projections.projects2.private_labeling_setting` +
|
||||
` FROM projections.projects2` +
|
||||
` JOIN projections.apps2 ON projections.projects2.id = projections.apps2.project_id` +
|
||||
` LEFT JOIN projections.apps2_api_configs ON projections.apps2.id = projections.apps2_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps2_oidc_configs ON projections.apps2.id = projections.apps2_oidc_configs.app_id`)
|
||||
` JOIN projections.apps3 ON projections.projects2.id = projections.apps3.project_id` +
|
||||
` LEFT JOIN projections.apps3_api_configs ON projections.apps3.id = projections.apps3_api_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_oidc_configs ON projections.apps3.id = projections.apps3_oidc_configs.app_id` +
|
||||
` LEFT JOIN projections.apps3_saml_configs ON projections.apps3.id = projections.apps3_saml_configs.app_id`)
|
||||
|
||||
appCols = database.StringArray{
|
||||
"id",
|
||||
@@ -135,6 +149,11 @@ var (
|
||||
"id_token_userinfo_assertion",
|
||||
"clock_skew",
|
||||
"additional_origins",
|
||||
//saml config
|
||||
"app_id",
|
||||
"entity_id",
|
||||
"metadata",
|
||||
"metadata_url",
|
||||
}
|
||||
appsCols = append(appCols, "count")
|
||||
)
|
||||
@@ -200,6 +219,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -260,6 +284,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -285,6 +314,75 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "prepareAppsQuery saml app",
|
||||
prepare: prepareAppsQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
expectedAppsQuery,
|
||||
appsCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
"app-id",
|
||||
"app-name",
|
||||
"project-id",
|
||||
testNow,
|
||||
testNow,
|
||||
"ro",
|
||||
domain.AppStateActive,
|
||||
uint64(20211109),
|
||||
// api config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// oidc config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
"app-id",
|
||||
"https://test.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &Apps{
|
||||
SearchResponse: SearchResponse{
|
||||
Count: 1,
|
||||
},
|
||||
Apps: []*App{
|
||||
{
|
||||
ID: "app-id",
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
ResourceOwner: "ro",
|
||||
State: domain.AppStateActive,
|
||||
Sequence: 20211109,
|
||||
Name: "app-name",
|
||||
ProjectID: "project-id",
|
||||
SAMLConfig: &SAMLApp{
|
||||
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
MetadataURL: "https://test.com/saml/metadata",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareAppsQuery oidc app",
|
||||
@@ -324,6 +422,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -403,6 +506,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -482,6 +590,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -561,6 +674,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -640,6 +758,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -719,6 +842,11 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"api-app-id",
|
||||
@@ -750,13 +878,54 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"saml-app-id",
|
||||
"app-name",
|
||||
"project-id",
|
||||
testNow,
|
||||
testNow,
|
||||
"ro",
|
||||
domain.AppStateActive,
|
||||
uint64(20211109),
|
||||
// api config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// oidc config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
"saml-app-id",
|
||||
"https://test.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &Apps{
|
||||
SearchResponse: SearchResponse{
|
||||
Count: 2,
|
||||
Count: 3,
|
||||
},
|
||||
Apps: []*App{
|
||||
{
|
||||
@@ -802,6 +971,21 @@ func Test_AppsPrepare(t *testing.T) {
|
||||
AuthMethodType: domain.APIAuthMethodTypePrivateKeyJWT,
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "saml-app-id",
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
ResourceOwner: "ro",
|
||||
State: domain.AppStateActive,
|
||||
Sequence: 20211109,
|
||||
Name: "app-name",
|
||||
ProjectID: "project-id",
|
||||
SAMLConfig: &SAMLApp{
|
||||
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
MetadataURL: "https://test.com/saml/metadata",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -896,6 +1080,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
),
|
||||
},
|
||||
@@ -948,6 +1137,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -1005,6 +1199,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -1038,6 +1237,68 @@ func Test_AppPrepare(t *testing.T) {
|
||||
AllowedOrigins: database.StringArray{"https://redirect.to", "additional.origin"},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
name: "prepareAppQuery saml app",
|
||||
prepare: prepareAppQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
expectedAppQuery,
|
||||
appCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
"app-id",
|
||||
"app-name",
|
||||
"project-id",
|
||||
testNow,
|
||||
testNow,
|
||||
"ro",
|
||||
domain.AppStateActive,
|
||||
uint64(20211109),
|
||||
// api config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// oidc config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
// saml config
|
||||
"app-id",
|
||||
"https://test.com/saml/metadata",
|
||||
[]byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
"https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &App{
|
||||
ID: "app-id",
|
||||
CreationDate: testNow,
|
||||
ChangeDate: testNow,
|
||||
ResourceOwner: "ro",
|
||||
State: domain.AppStateActive,
|
||||
Sequence: 20211109,
|
||||
Name: "app-name",
|
||||
ProjectID: "project-id",
|
||||
SAMLConfig: &SAMLApp{
|
||||
Metadata: []byte("<?xml version=\"1.0\"?>\n<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\"\n validUntil=\"2022-08-26T14:08:16Z\"\n cacheDuration=\"PT604800S\"\n entityID=\"https://test.com/saml/metadata\">\n <md:SPSSODescriptor AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"false\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>\n <md:AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"\n Location=\"https://test.com/saml/acs\"\n index=\"1\" />\n \n </md:SPSSODescriptor>\n</md:EntityDescriptor>"),
|
||||
MetadataURL: "https://test.com/saml/metadata",
|
||||
EntityID: "https://test.com/saml/metadata",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareAppQuery oidc app IsDevMode inactive",
|
||||
@@ -1077,6 +1338,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -1149,6 +1415,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -1221,6 +1492,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
true,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -1293,6 +1569,11 @@ func Test_AppPrepare(t *testing.T) {
|
||||
false,
|
||||
1 * time.Second,
|
||||
database.StringArray{"additional.origin"},
|
||||
// saml config
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
),
|
||||
|
156
internal/query/certificate.go
Normal file
156
internal/query/certificate.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
)
|
||||
|
||||
type Certificate interface {
|
||||
Key
|
||||
Expiry() time.Time
|
||||
Key() *crypto.CryptoValue
|
||||
Certificate() []byte
|
||||
}
|
||||
|
||||
type Certificates struct {
|
||||
SearchResponse
|
||||
Certificates []Certificate
|
||||
}
|
||||
|
||||
type rsaCertificate struct {
|
||||
key
|
||||
expiry time.Time
|
||||
privateKey *crypto.CryptoValue
|
||||
certificate []byte
|
||||
}
|
||||
|
||||
func (c *rsaCertificate) Expiry() time.Time {
|
||||
return c.expiry
|
||||
}
|
||||
|
||||
func (c *rsaCertificate) Key() *crypto.CryptoValue {
|
||||
return c.privateKey
|
||||
}
|
||||
|
||||
func (c *rsaCertificate) Certificate() []byte {
|
||||
return c.certificate
|
||||
}
|
||||
|
||||
var (
|
||||
certificateTable = table{
|
||||
name: projection.CertificateTable,
|
||||
}
|
||||
CertificateColID = Column{
|
||||
name: projection.CertificateColumnID,
|
||||
table: certificateTable,
|
||||
}
|
||||
CertificateColExpiry = Column{
|
||||
name: projection.CertificateColumnExpiry,
|
||||
table: certificateTable,
|
||||
}
|
||||
CertificateColCertificate = Column{
|
||||
name: projection.CertificateColumnCertificate,
|
||||
table: certificateTable,
|
||||
}
|
||||
)
|
||||
|
||||
func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage domain.KeyUsage) (*Certificates, error) {
|
||||
query, scan := prepareCertificateQuery()
|
||||
if t.IsZero() {
|
||||
t = time.Now()
|
||||
}
|
||||
stmt, args, err := query.Where(
|
||||
sq.And{
|
||||
sq.Eq{
|
||||
KeyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
KeyColUse.identifier(): usage,
|
||||
},
|
||||
sq.Gt{
|
||||
CertificateColExpiry.identifier(): t,
|
||||
},
|
||||
sq.Gt{
|
||||
KeyPrivateColExpiry.identifier(): t,
|
||||
},
|
||||
}).OrderBy(KeyPrivateColExpiry.identifier()).ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-SDfkg", "Errors.Query.SQLStatement")
|
||||
}
|
||||
|
||||
rows, err := q.client.QueryContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-Sgan4", "Errors.Internal")
|
||||
}
|
||||
keys, err := scan(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys.LatestSequence, err = q.latestSequence(ctx, keyTable)
|
||||
if !errors.IsNotFound(err) {
|
||||
return keys, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func prepareCertificateQuery() (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) {
|
||||
return sq.Select(
|
||||
KeyColID.identifier(),
|
||||
KeyColCreationDate.identifier(),
|
||||
KeyColChangeDate.identifier(),
|
||||
KeyColSequence.identifier(),
|
||||
KeyColResourceOwner.identifier(),
|
||||
KeyColAlgorithm.identifier(),
|
||||
KeyColUse.identifier(),
|
||||
CertificateColExpiry.identifier(),
|
||||
CertificateColCertificate.identifier(),
|
||||
KeyPrivateColKey.identifier(),
|
||||
countColumn.identifier(),
|
||||
).From(keyTable.identifier()).
|
||||
LeftJoin(join(CertificateColID, KeyColID)).
|
||||
LeftJoin(join(KeyPrivateColID, KeyColID)).
|
||||
PlaceholderFormat(sq.Dollar),
|
||||
func(rows *sql.Rows) (*Certificates, error) {
|
||||
certificates := make([]Certificate, 0)
|
||||
var count uint64
|
||||
for rows.Next() {
|
||||
k := new(rsaCertificate)
|
||||
err := rows.Scan(
|
||||
&k.id,
|
||||
&k.creationDate,
|
||||
&k.changeDate,
|
||||
&k.sequence,
|
||||
&k.resourceOwner,
|
||||
&k.algorithm,
|
||||
&k.use,
|
||||
&k.expiry,
|
||||
&k.certificate,
|
||||
&k.privateKey,
|
||||
&count,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certificates = append(certificates, k)
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, errors.ThrowInternal(err, "QUERY-rKd6k", "Errors.Query.CloseRows")
|
||||
}
|
||||
|
||||
return &Certificates{
|
||||
Certificates: certificates,
|
||||
SearchResponse: SearchResponse{
|
||||
Count: count,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
169
internal/query/certificate_test.go
Normal file
169
internal/query/certificate_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func Test_CertificatePrepares(t *testing.T) {
|
||||
type want struct {
|
||||
sqlExpectations sqlExpectation
|
||||
err checkErr
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
prepare interface{}
|
||||
want want
|
||||
object interface{}
|
||||
}{
|
||||
{
|
||||
name: "prepareCertificateQuery no result",
|
||||
prepare: prepareCertificateQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_certificate.expiry,`+
|
||||
` projections.keys3_certificate.certificate,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_certificate ON projections.keys3.id = projections.keys3_certificate.id`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !errs.IsNotFound(err) {
|
||||
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
object: &Certificates{Certificates: []Certificate{}},
|
||||
},
|
||||
{
|
||||
name: "prepareCertificateQuery found",
|
||||
prepare: prepareCertificateQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_certificate.expiry,`+
|
||||
` projections.keys3_certificate.certificate,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_certificate ON projections.keys3.id = projections.keys3_certificate.id`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
[]string{
|
||||
"id",
|
||||
"creation_date",
|
||||
"change_date",
|
||||
"sequence",
|
||||
"resource_owner",
|
||||
"algorithm",
|
||||
"use",
|
||||
"expiry",
|
||||
"certificate",
|
||||
"key",
|
||||
"count",
|
||||
},
|
||||
[][]driver.Value{
|
||||
{
|
||||
"key-id",
|
||||
testNow,
|
||||
testNow,
|
||||
uint64(20211109),
|
||||
"ro",
|
||||
"",
|
||||
1,
|
||||
testNow,
|
||||
[]byte(`privateKey`),
|
||||
[]byte(`{"Algorithm": "enc", "Crypted": "cHJpdmF0ZUtleQ==", "CryptoType": 0, "KeyID": "id"}`),
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &Certificates{
|
||||
SearchResponse: SearchResponse{
|
||||
Count: 1,
|
||||
},
|
||||
Certificates: []Certificate{
|
||||
&rsaCertificate{
|
||||
key: key{
|
||||
id: "key-id",
|
||||
creationDate: testNow,
|
||||
changeDate: testNow,
|
||||
sequence: 20211109,
|
||||
resourceOwner: "ro",
|
||||
algorithm: "",
|
||||
use: domain.KeyUsageSAMLMetadataSigning,
|
||||
},
|
||||
expiry: testNow,
|
||||
certificate: []byte("privateKey"),
|
||||
privateKey: &crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
KeyID: "id",
|
||||
Crypted: []byte("privateKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareCertificateQuery sql err",
|
||||
prepare: prepareCertificateQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_certificate.expiry,`+
|
||||
` projections.keys3_certificate.certificate,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_certificate ON projections.keys3.id = projections.keys3_certificate.id`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !errors.Is(err, sql.ErrConnDone) {
|
||||
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
object: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err)
|
||||
})
|
||||
}
|
||||
}
|
@@ -200,7 +200,10 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (*PublicKey
|
||||
return nil, err
|
||||
}
|
||||
keys.LatestSequence, err = q.latestSequence(ctx, keyTable)
|
||||
return keys, err
|
||||
if !errors.IsNotFound(err) {
|
||||
return keys, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (*PrivateKeys, error) {
|
||||
|
@@ -31,18 +31,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePublicKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_public.expiry,`+
|
||||
` projections.keys2_public.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_public.expiry,`+
|
||||
` projections.keys3_public.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_public ON projections.keys2.id = projections.keys2_public.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_public ON projections.keys3.id = projections.keys3_public.id`),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
@@ -60,18 +60,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePublicKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_public.expiry,`+
|
||||
` projections.keys2_public.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_public.expiry,`+
|
||||
` projections.keys3_public.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_public ON projections.keys2.id = projections.keys2_public.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_public ON projections.keys3.id = projections.keys3_public.id`),
|
||||
[]string{
|
||||
"id",
|
||||
"creation_date",
|
||||
@@ -128,18 +128,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePublicKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_public.expiry,`+
|
||||
` projections.keys2_public.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_public.expiry,`+
|
||||
` projections.keys3_public.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_public ON projections.keys2.id = projections.keys2_public.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_public ON projections.keys3.id = projections.keys3_public.id`),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
@@ -156,18 +156,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePrivateKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_private.expiry,`+
|
||||
` projections.keys2_private.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_private.expiry,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_private ON projections.keys2.id = projections.keys2_private.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
@@ -185,18 +185,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePrivateKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_private.expiry,`+
|
||||
` projections.keys2_private.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_private.expiry,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_private ON projections.keys2.id = projections.keys2_private.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
[]string{
|
||||
"id",
|
||||
"creation_date",
|
||||
@@ -255,18 +255,18 @@ func Test_KeyPrepares(t *testing.T) {
|
||||
prepare: preparePrivateKeysQuery,
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(`SELECT projections.keys2.id,`+
|
||||
` projections.keys2.creation_date,`+
|
||||
` projections.keys2.change_date,`+
|
||||
` projections.keys2.sequence,`+
|
||||
` projections.keys2.resource_owner,`+
|
||||
` projections.keys2.algorithm,`+
|
||||
` projections.keys2.use,`+
|
||||
` projections.keys2_private.expiry,`+
|
||||
` projections.keys2_private.key,`+
|
||||
regexp.QuoteMeta(`SELECT projections.keys3.id,`+
|
||||
` projections.keys3.creation_date,`+
|
||||
` projections.keys3.change_date,`+
|
||||
` projections.keys3.sequence,`+
|
||||
` projections.keys3.resource_owner,`+
|
||||
` projections.keys3.algorithm,`+
|
||||
` projections.keys3.use,`+
|
||||
` projections.keys3_private.expiry,`+
|
||||
` projections.keys3_private.key,`+
|
||||
` COUNT(*) OVER ()`+
|
||||
` FROM projections.keys2`+
|
||||
` LEFT JOIN projections.keys2_private ON projections.keys2.id = projections.keys2_private.id`),
|
||||
` FROM projections.keys3`+
|
||||
` LEFT JOIN projections.keys3_private ON projections.keys3.id = projections.keys3_private.id`),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
|
@@ -13,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
AppProjectionTable = "projections.apps2"
|
||||
AppProjectionTable = "projections.apps3"
|
||||
AppAPITable = AppProjectionTable + "_" + appAPITableSuffix
|
||||
AppOIDCTable = AppProjectionTable + "_" + appOIDCTableSuffix
|
||||
AppSAMLTable = AppProjectionTable + "_" + appSAMLTableSuffix
|
||||
|
||||
AppColumnID = "id"
|
||||
AppColumnName = "name"
|
||||
@@ -53,6 +54,13 @@ const (
|
||||
AppOIDCConfigColumnIDTokenUserinfoAssertion = "id_token_userinfo_assertion"
|
||||
AppOIDCConfigColumnClockSkew = "clock_skew"
|
||||
AppOIDCConfigColumnAdditionalOrigins = "additional_origins"
|
||||
|
||||
appSAMLTableSuffix = "saml_configs"
|
||||
AppSAMLConfigColumnAppID = "app_id"
|
||||
AppSAMLConfigColumnInstanceID = "instance_id"
|
||||
AppSAMLConfigColumnEntityID = "entity_id"
|
||||
AppSAMLConfigColumnMetadata = "metadata"
|
||||
AppSAMLConfigColumnMetadataURL = "metadata_url"
|
||||
)
|
||||
|
||||
type appProjection struct {
|
||||
@@ -116,6 +124,18 @@ func newAppProjection(ctx context.Context, config crdb.StatementHandlerConfig) *
|
||||
crdb.WithForeignKey(crdb.NewForeignKeyOfPublicKeys("fk_oidc_ref_apps")),
|
||||
crdb.WithIndex(crdb.NewIndex("oidc_client_id_idx", []string{AppOIDCConfigColumnClientID})),
|
||||
),
|
||||
crdb.NewSuffixedTable([]*crdb.Column{
|
||||
crdb.NewColumn(AppSAMLConfigColumnAppID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AppSAMLConfigColumnInstanceID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AppSAMLConfigColumnEntityID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(AppSAMLConfigColumnMetadata, crdb.ColumnTypeBytes),
|
||||
crdb.NewColumn(AppSAMLConfigColumnMetadataURL, crdb.ColumnTypeText),
|
||||
},
|
||||
crdb.NewPrimaryKey(AppSAMLConfigColumnInstanceID, AppSAMLConfigColumnAppID),
|
||||
appSAMLTableSuffix,
|
||||
crdb.WithForeignKey(crdb.NewForeignKeyOfPublicKeys("fk_saml_ref_apps")),
|
||||
crdb.WithIndex(crdb.NewIndex("saml_entity_id_idx", []string{AppSAMLConfigColumnEntityID})),
|
||||
),
|
||||
)
|
||||
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
|
||||
return p
|
||||
@@ -174,6 +194,14 @@ func (p *appProjection) reducers() []handler.AggregateReducer {
|
||||
Event: project.OIDCConfigSecretChangedType,
|
||||
Reduce: p.reduceOIDCConfigSecretChanged,
|
||||
},
|
||||
{
|
||||
Event: project.SAMLConfigAddedType,
|
||||
Reduce: p.reduceSAMLConfigAdded,
|
||||
},
|
||||
{
|
||||
Event: project.SAMLConfigChangedType,
|
||||
Reduce: p.reduceSAMLConfigChanged,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -535,3 +563,77 @@ func (p *appProjection) reduceOIDCConfigSecretChanged(event eventstore.Event) (*
|
||||
),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *appProjection) reduceSAMLConfigAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*project.SAMLConfigAddedEvent)
|
||||
if !ok {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "HANDL-GMHU1", "reduce.wrong.event.type")
|
||||
}
|
||||
return crdb.NewMultiStatement(
|
||||
e,
|
||||
crdb.AddCreateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(AppSAMLConfigColumnAppID, e.AppID),
|
||||
handler.NewCol(AppSAMLConfigColumnInstanceID, e.Aggregate().InstanceID),
|
||||
handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID),
|
||||
handler.NewCol(AppSAMLConfigColumnMetadata, e.Metadata),
|
||||
handler.NewCol(AppSAMLConfigColumnMetadataURL, e.MetadataURL),
|
||||
},
|
||||
crdb.WithTableSuffix(appSAMLTableSuffix),
|
||||
),
|
||||
crdb.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(AppColumnChangeDate, e.CreationDate()),
|
||||
handler.NewCol(AppColumnSequence, e.Sequence()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(AppColumnID, e.AppID),
|
||||
handler.NewCond(AppColumnInstanceID, e.Aggregate().InstanceID),
|
||||
},
|
||||
),
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *appProjection) reduceSAMLConfigChanged(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*project.SAMLConfigChangedEvent)
|
||||
if !ok {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "HANDL-GMHU2", "reduce.wrong.event.type")
|
||||
}
|
||||
|
||||
cols := make([]handler.Column, 0, 3)
|
||||
if e.Metadata != nil {
|
||||
cols = append(cols, handler.NewCol(AppSAMLConfigColumnMetadata, e.Metadata))
|
||||
}
|
||||
if e.MetadataURL != nil {
|
||||
cols = append(cols, handler.NewCol(AppSAMLConfigColumnMetadataURL, *e.MetadataURL))
|
||||
}
|
||||
if e.EntityID != "" {
|
||||
cols = append(cols, handler.NewCol(AppSAMLConfigColumnEntityID, e.EntityID))
|
||||
}
|
||||
|
||||
if len(cols) == 0 {
|
||||
return crdb.NewNoOpStatement(e), nil
|
||||
}
|
||||
|
||||
return crdb.NewMultiStatement(
|
||||
e,
|
||||
crdb.AddUpdateStatement(
|
||||
cols,
|
||||
[]handler.Condition{
|
||||
handler.NewCond(AppSAMLConfigColumnAppID, e.AppID),
|
||||
handler.NewCond(AppSAMLConfigColumnInstanceID, e.Aggregate().InstanceID),
|
||||
},
|
||||
crdb.WithTableSuffix(appSAMLTableSuffix),
|
||||
),
|
||||
crdb.AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(AppColumnChangeDate, e.CreationDate()),
|
||||
handler.NewCol(AppColumnSequence, e.Sequence()),
|
||||
},
|
||||
[]handler.Condition{
|
||||
handler.NewCond(AppColumnID, e.AppID),
|
||||
handler.NewCond(AppColumnInstanceID, e.Aggregate().InstanceID),
|
||||
},
|
||||
),
|
||||
), nil
|
||||
}
|
||||
|
@@ -44,7 +44,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.apps2 (id, name, project_id, creation_date, change_date, resource_owner, instance_id, state, sequence) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
||||
expectedStmt: "INSERT INTO projections.apps3 (id, name, project_id, creation_date, change_date, resource_owner, instance_id, state, sequence) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
|
||||
expectedArgs: []interface{}{
|
||||
"app-id",
|
||||
"my-app",
|
||||
@@ -82,7 +82,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (name, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (name, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedArgs: []interface{}{
|
||||
"my-app",
|
||||
anyArg{},
|
||||
@@ -115,7 +115,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (state, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (state, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedArgs: []interface{}{
|
||||
domain.AppStateInactive,
|
||||
anyArg{},
|
||||
@@ -148,7 +148,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (state, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (state, change_date, sequence) = ($1, $2, $3) WHERE (id = $4) AND (instance_id = $5)",
|
||||
expectedArgs: []interface{}{
|
||||
domain.AppStateActive,
|
||||
anyArg{},
|
||||
@@ -181,7 +181,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.apps2 WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedStmt: "DELETE FROM projections.apps3 WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"app-id",
|
||||
"instance-id",
|
||||
@@ -209,7 +209,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.apps2 WHERE (project_id = $1) AND (instance_id = $2)",
|
||||
expectedStmt: "DELETE FROM projections.apps3 WHERE (project_id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
@@ -242,7 +242,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.apps2_api_configs (app_id, instance_id, client_id, client_secret, auth_method) VALUES ($1, $2, $3, $4, $5)",
|
||||
expectedStmt: "INSERT INTO projections.apps3_api_configs (app_id, instance_id, client_id, client_secret, auth_method) VALUES ($1, $2, $3, $4, $5)",
|
||||
expectedArgs: []interface{}{
|
||||
"app-id",
|
||||
"instance-id",
|
||||
@@ -252,7 +252,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
@@ -287,7 +287,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2_api_configs SET (client_secret, auth_method) = ($1, $2) WHERE (app_id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3_api_configs SET (client_secret, auth_method) = ($1, $2) WHERE (app_id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
domain.APIAuthMethodTypePrivateKeyJWT,
|
||||
@@ -296,7 +296,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
@@ -351,7 +351,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2_api_configs SET client_secret = $1 WHERE (app_id = $2) AND (instance_id = $3)",
|
||||
expectedStmt: "UPDATE projections.apps3_api_configs SET client_secret = $1 WHERE (app_id = $2) AND (instance_id = $3)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
"app-id",
|
||||
@@ -359,7 +359,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
@@ -407,7 +407,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.apps2_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)",
|
||||
expectedStmt: "INSERT INTO projections.apps3_oidc_configs (app_id, instance_id, version, client_id, client_secret, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18)",
|
||||
expectedArgs: []interface{}{
|
||||
"app-id",
|
||||
"instance-id",
|
||||
@@ -430,7 +430,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
@@ -476,7 +476,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2_oidc_configs SET (version, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) WHERE (app_id = $15) AND (instance_id = $16)",
|
||||
expectedStmt: "UPDATE projections.apps3_oidc_configs SET (version, redirect_uris, response_types, grant_types, application_type, auth_method_type, post_logout_redirect_uris, is_dev_mode, access_token_type, access_token_role_assertion, id_token_role_assertion, id_token_userinfo_assertion, clock_skew, additional_origins) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) WHERE (app_id = $15) AND (instance_id = $16)",
|
||||
expectedArgs: []interface{}{
|
||||
domain.OIDCVersionV1,
|
||||
database.StringArray{"redirect.one.ch", "redirect.two.ch"},
|
||||
@@ -497,7 +497,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
@@ -552,7 +552,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2_oidc_configs SET client_secret = $1 WHERE (app_id = $2) AND (instance_id = $3)",
|
||||
expectedStmt: "UPDATE projections.apps3_oidc_configs SET client_secret = $1 WHERE (app_id = $2) AND (instance_id = $3)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
"app-id",
|
||||
@@ -560,7 +560,7 @@ func TestAppProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "UPDATE projections.apps2 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedStmt: "UPDATE projections.apps3 SET (change_date, sequence) = ($1, $2) WHERE (id = $3) AND (instance_id = $4)",
|
||||
expectedArgs: []interface{}{
|
||||
anyArg{},
|
||||
uint64(15),
|
||||
|
@@ -13,9 +13,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
KeyProjectionTable = "projections.keys2"
|
||||
KeyProjectionTable = "projections.keys3"
|
||||
KeyPrivateTable = KeyProjectionTable + "_" + privateKeyTableSuffix
|
||||
KeyPublicTable = KeyProjectionTable + "_" + publicKeyTableSuffix
|
||||
CertificateTable = KeyProjectionTable + "_" + certificateTableSuffix
|
||||
|
||||
KeyColumnID = "id"
|
||||
KeyColumnCreationDate = "creation_date"
|
||||
@@ -37,14 +38,21 @@ const (
|
||||
KeyPublicColumnInstanceID = "instance_id"
|
||||
KeyPublicColumnExpiry = "expiry"
|
||||
KeyPublicColumnKey = "key"
|
||||
|
||||
certificateTableSuffix = "certificate"
|
||||
CertificateColumnID = "id"
|
||||
CertificateColumnInstanceID = "instance_id"
|
||||
CertificateColumnExpiry = "expiry"
|
||||
CertificateColumnCertificate = "certificate"
|
||||
)
|
||||
|
||||
type keyProjection struct {
|
||||
crdb.StatementHandler
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
certEncryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
}
|
||||
|
||||
func newKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, keyEncryptionAlgorithm crypto.EncryptionAlgorithm) *keyProjection {
|
||||
func newKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm) *keyProjection {
|
||||
p := new(keyProjection)
|
||||
config.ProjectionName = KeyProjectionTable
|
||||
config.Reducers = p.reducers()
|
||||
@@ -82,8 +90,19 @@ func newKeyProjection(ctx context.Context, config crdb.StatementHandlerConfig, k
|
||||
publicKeyTableSuffix,
|
||||
crdb.WithForeignKey(crdb.NewForeignKeyOfPublicKeys("fk_public_ref_keys")),
|
||||
),
|
||||
crdb.NewSuffixedTable([]*crdb.Column{
|
||||
crdb.NewColumn(CertificateColumnID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(CertificateColumnInstanceID, crdb.ColumnTypeText),
|
||||
crdb.NewColumn(CertificateColumnExpiry, crdb.ColumnTypeTimestamp),
|
||||
crdb.NewColumn(CertificateColumnCertificate, crdb.ColumnTypeBytes),
|
||||
},
|
||||
crdb.NewPrimaryKey(CertificateColumnInstanceID, CertificateColumnID),
|
||||
certificateTableSuffix,
|
||||
crdb.WithForeignKey(crdb.NewForeignKeyOfPublicKeys("fk_certificate_ref_keys")),
|
||||
),
|
||||
)
|
||||
p.encryptionAlgorithm = keyEncryptionAlgorithm
|
||||
p.certEncryptionAlgorithm = certEncryptionAlgorithm
|
||||
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
|
||||
|
||||
return p
|
||||
@@ -98,6 +117,10 @@ func (p *keyProjection) reducers() []handler.AggregateReducer {
|
||||
Event: keypair.AddedEventType,
|
||||
Reduce: p.reduceKeyPairAdded,
|
||||
},
|
||||
{
|
||||
Event: keypair.AddedCertificateEventType,
|
||||
Reduce: p.reduceCertificateAdded,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -151,5 +174,34 @@ func (p *keyProjection) reduceKeyPairAdded(event eventstore.Event) (*handler.Sta
|
||||
crdb.WithTableSuffix(publicKeyTableSuffix),
|
||||
))
|
||||
}
|
||||
|
||||
return crdb.NewMultiStatement(e, creates...), nil
|
||||
}
|
||||
|
||||
func (p *keyProjection) reduceCertificateAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*keypair.AddedCertificateEvent)
|
||||
if !ok {
|
||||
return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-SAbr09", "reduce.wrong.event.type %s", keypair.AddedCertificateEventType)
|
||||
}
|
||||
|
||||
if e.Certificate.Expiry.Before(time.Now()) {
|
||||
return crdb.NewNoOpStatement(e), nil
|
||||
}
|
||||
|
||||
certificate, err := crypto.Decrypt(e.Certificate.Key, p.certEncryptionAlgorithm)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "HANDL-Dajwig2f", "cannot decrypt certificate")
|
||||
}
|
||||
|
||||
creates := []func(eventstore.Event) crdb.Exec{crdb.AddCreateStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol(CertificateColumnID, e.Aggregate().ID),
|
||||
handler.NewCol(CertificateColumnInstanceID, e.Aggregate().InstanceID),
|
||||
handler.NewCol(CertificateColumnExpiry, e.Certificate.Expiry),
|
||||
handler.NewCol(CertificateColumnCertificate, certificate),
|
||||
},
|
||||
crdb.WithTableSuffix(certificateTableSuffix),
|
||||
)}
|
||||
|
||||
return crdb.NewMultiStatement(e, creates...), nil
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -31,7 +32,7 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
event: getEvent(testEvent(
|
||||
repository.EventType(keypair.AddedEventType),
|
||||
keypair.AggregateType,
|
||||
keypairAddedEventData(time.Now().Add(time.Hour)),
|
||||
keypairAddedEventData(domain.KeyUsageSigning, time.Now().Add(time.Hour)),
|
||||
), keypair.AddedEventMapper),
|
||||
},
|
||||
reduce: (&keyProjection{encryptionAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t))}).reduceKeyPairAdded,
|
||||
@@ -43,7 +44,7 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.keys2 (id, creation_date, change_date, resource_owner, instance_id, sequence, algorithm, use) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
|
||||
expectedStmt: "INSERT INTO projections.keys3 (id, creation_date, change_date, resource_owner, instance_id, sequence, algorithm, use) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
anyArg{},
|
||||
@@ -56,7 +57,7 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.keys2_private (id, instance_id, expiry, key) VALUES ($1, $2, $3, $4)",
|
||||
expectedStmt: "INSERT INTO projections.keys3_private (id, instance_id, expiry, key) VALUES ($1, $2, $3, $4)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
@@ -70,7 +71,7 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.keys2_public (id, instance_id, expiry, key) VALUES ($1, $2, $3, $4)",
|
||||
expectedStmt: "INSERT INTO projections.keys3_public (id, instance_id, expiry, key) VALUES ($1, $2, $3, $4)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
@@ -88,7 +89,7 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
event: getEvent(testEvent(
|
||||
repository.EventType(keypair.AddedEventType),
|
||||
keypair.AggregateType,
|
||||
keypairAddedEventData(time.Now().Add(-time.Hour)),
|
||||
keypairAddedEventData(domain.KeyUsageSigning, time.Now().Add(-time.Hour)),
|
||||
), keypair.AddedEventMapper),
|
||||
},
|
||||
reduce: (&keyProjection{}).reduceKeyPairAdded,
|
||||
@@ -100,6 +101,36 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
executer: &testExecuter{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceCertificateAdded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
repository.EventType(keypair.AddedCertificateEventType),
|
||||
keypair.AggregateType,
|
||||
certificateAddedEventData(domain.KeyUsageSAMLMetadataSigning, time.Now().Add(time.Hour)),
|
||||
), keypair.AddedCertificateEventMapper),
|
||||
},
|
||||
reduce: (&keyProjection{certEncryptionAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t))}).reduceCertificateAdded,
|
||||
want: wantReduce{
|
||||
projection: KeyProjectionTable,
|
||||
aggregateType: eventstore.AggregateType("key_pair"),
|
||||
sequence: 15,
|
||||
previousSequence: 10,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.keys3_certificate (id, instance_id, expiry, certificate) VALUES ($1, $2, $3, $4)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
anyArg{},
|
||||
[]byte("privateKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -116,6 +147,10 @@ func TestKeyProjection_reduces(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func keypairAddedEventData(t time.Time) []byte {
|
||||
return []byte(`{"algorithm": "algorithm", "usage": 0, "privateKey": {"key": {"cryptoType": 0, "algorithm": "enc", "keyID": "id", "crypted": "cHJpdmF0ZUtleQ=="}, "expiry": "` + t.Format(time.RFC3339) + `"}, "publicKey": {"key": {"cryptoType": 0, "algorithm": "enc", "keyID": "id", "crypted": "cHVibGljS2V5"}, "expiry": "` + t.Format(time.RFC3339) + `"}}`)
|
||||
func keypairAddedEventData(usage domain.KeyUsage, t time.Time) []byte {
|
||||
return []byte(`{"algorithm": "algorithm", "usage": ` + fmt.Sprintf("%d", usage) + `, "privateKey": {"key": {"cryptoType": 0, "algorithm": "enc", "keyID": "id", "crypted": "cHJpdmF0ZUtleQ=="}, "expiry": "` + t.Format(time.RFC3339) + `"}, "publicKey": {"key": {"cryptoType": 0, "algorithm": "enc", "keyID": "id", "crypted": "cHVibGljS2V5"}, "expiry": "` + t.Format(time.RFC3339) + `"}}`)
|
||||
}
|
||||
|
||||
func certificateAddedEventData(usage domain.KeyUsage, t time.Time) []byte {
|
||||
return []byte(`{"algorithm": "algorithm", "usage": ` + fmt.Sprintf("%d", usage) + `, "certificate": {"key": {"cryptoType": 0, "algorithm": "enc", "keyID": "id", "crypted": "cHJpdmF0ZUtleQ=="}, "expiry": "` + t.Format(time.RFC3339) + `"}}`)
|
||||
}
|
||||
|
@@ -62,7 +62,7 @@ var (
|
||||
NotificationsProjection interface{}
|
||||
)
|
||||
|
||||
func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, config Config, keyEncryptionAlgorithm crypto.EncryptionAlgorithm) error {
|
||||
func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, config Config, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm) error {
|
||||
projectionConfig = crdb.StatementHandlerConfig{
|
||||
ProjectionHandlerConfig: handler.ProjectionHandlerConfig{
|
||||
HandlerConfig: handler.HandlerConfig{
|
||||
@@ -120,7 +120,7 @@ func Start(ctx context.Context, sqlClient *sql.DB, es *eventstore.Eventstore, co
|
||||
SMSConfigProjection = newSMSConfigProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sms_config"]))
|
||||
OIDCSettingsProjection = newOIDCSettingsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["oidc_settings"]))
|
||||
DebugNotificationProviderProjection = newDebugNotificationProviderProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["debug_notification_provider"]))
|
||||
KeyProjection = newKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), keyEncryptionAlgorithm)
|
||||
KeyProjection = newKeyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["keys"]), keyEncryptionAlgorithm, certEncryptionAlgorithm)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -4,14 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/rakyll/statik/fs"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
sd "github.com/zitadel/zitadel/internal/config/systemdefaults"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@@ -42,7 +43,7 @@ type Queries struct {
|
||||
multifactors domain.MultifactorConfigs
|
||||
}
|
||||
|
||||
func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *sql.DB, projections projection.Config, defaults sd.SystemDefaults, idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, zitadelRoles []authz.RoleMapping) (repo *Queries, err error) {
|
||||
func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *sql.DB, projections projection.Config, defaults sd.SystemDefaults, idpConfigEncryption, otpEncryption, keyEncryptionAlgorithm crypto.EncryptionAlgorithm, certEncryptionAlgorithm crypto.EncryptionAlgorithm, zitadelRoles []authz.RoleMapping) (repo *Queries, err error) {
|
||||
statikLoginFS, err := fs.NewWithNamespace("login")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to start login statik dir")
|
||||
@@ -79,7 +80,7 @@ func StartQueries(ctx context.Context, es *eventstore.Eventstore, sqlClient *sql
|
||||
},
|
||||
}
|
||||
|
||||
err = projection.Start(ctx, sqlClient, es, projections, keyEncryptionAlgorithm)
|
||||
err = projection.Start(ctx, sqlClient, es, projections, keyEncryptionAlgorithm, certEncryptionAlgorithm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
61
internal/repository/keypair/certificate.go
Normal file
61
internal/repository/keypair/certificate.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package keypair
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
AddedCertificateEventType = eventTypePrefix + "certificate.added"
|
||||
)
|
||||
|
||||
type AddedCertificateEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
Certificate *Key `json:"certificate"`
|
||||
}
|
||||
|
||||
func (e *AddedCertificateEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *AddedCertificateEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAddedCertificateEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
certificateCrypto *crypto.CryptoValue,
|
||||
certificateExpiration time.Time) *AddedCertificateEvent {
|
||||
return &AddedCertificateEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
AddedCertificateEventType,
|
||||
),
|
||||
Certificate: &Key{
|
||||
Key: certificateCrypto,
|
||||
Expiry: certificateExpiration,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func AddedCertificateEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
e := &AddedCertificateEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
|
||||
err := json.Unmarshal(event.Data, e)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "KEY-4n9vs", "unable to unmarshal certificate added")
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
@@ -6,4 +6,5 @@ import (
|
||||
|
||||
func RegisterEventMappers(es *eventstore.Eventstore) {
|
||||
es.RegisterFilterEventMapper(AddedEventType, AddedEventMapper)
|
||||
es.RegisterFilterEventMapper(AddedCertificateEventType, AddedCertificateEventMapper)
|
||||
}
|
||||
|
@@ -217,8 +217,9 @@ func ApplicationReactivatedEventMapper(event *repository.Event) (eventstore.Even
|
||||
type ApplicationRemovedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
AppID string `json:"appId,omitempty"`
|
||||
name string
|
||||
AppID string `json:"appId,omitempty"`
|
||||
name string
|
||||
entityID string
|
||||
}
|
||||
|
||||
func (e *ApplicationRemovedEvent) Data() interface{} {
|
||||
@@ -226,7 +227,11 @@ func (e *ApplicationRemovedEvent) Data() interface{} {
|
||||
}
|
||||
|
||||
func (e *ApplicationRemovedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return []*eventstore.EventUniqueConstraint{NewRemoveApplicationUniqueConstraint(e.name, e.Aggregate().ID)}
|
||||
remove := []*eventstore.EventUniqueConstraint{NewRemoveApplicationUniqueConstraint(e.name, e.Aggregate().ID)}
|
||||
if e.entityID != "" {
|
||||
remove = append(remove, NewRemoveSAMLConfigEntityIDUniqueConstraint(e.entityID))
|
||||
}
|
||||
return remove
|
||||
}
|
||||
|
||||
func NewApplicationRemovedEvent(
|
||||
@@ -234,6 +239,7 @@ func NewApplicationRemovedEvent(
|
||||
aggregate *eventstore.Aggregate,
|
||||
appID,
|
||||
name string,
|
||||
entityID string,
|
||||
) *ApplicationRemovedEvent {
|
||||
return &ApplicationRemovedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
@@ -241,8 +247,9 @@ func NewApplicationRemovedEvent(
|
||||
aggregate,
|
||||
ApplicationRemovedType,
|
||||
),
|
||||
AppID: appID,
|
||||
name: name,
|
||||
AppID: appID,
|
||||
name: name,
|
||||
entityID: entityID,
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -41,5 +41,7 @@ func RegisterEventMappers(es *eventstore.Eventstore) {
|
||||
RegisterFilterEventMapper(APIConfigChangedType, APIConfigChangedEventMapper).
|
||||
RegisterFilterEventMapper(APIConfigSecretChangedType, APIConfigSecretChangedEventMapper).
|
||||
RegisterFilterEventMapper(ApplicationKeyAddedEventType, ApplicationKeyAddedEventMapper).
|
||||
RegisterFilterEventMapper(ApplicationKeyRemovedEventType, ApplicationKeyRemovedEventMapper)
|
||||
RegisterFilterEventMapper(ApplicationKeyRemovedEventType, ApplicationKeyRemovedEventMapper).
|
||||
RegisterFilterEventMapper(SAMLConfigAddedType, SAMLConfigAddedEventMapper).
|
||||
RegisterFilterEventMapper(SAMLConfigChangedType, SAMLConfigChangedEventMapper)
|
||||
}
|
||||
|
@@ -240,7 +240,8 @@ func ProjectReactivatedEventMapper(event *repository.Event) (eventstore.Event, e
|
||||
type ProjectRemovedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
Name string
|
||||
Name string
|
||||
entityIDUniqueContraints []*eventstore.EventUniqueConstraint
|
||||
}
|
||||
|
||||
func (e *ProjectRemovedEvent) Data() interface{} {
|
||||
@@ -248,13 +249,20 @@ func (e *ProjectRemovedEvent) Data() interface{} {
|
||||
}
|
||||
|
||||
func (e *ProjectRemovedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return []*eventstore.EventUniqueConstraint{NewRemoveProjectNameUniqueConstraint(e.Name, e.Aggregate().ResourceOwner)}
|
||||
constraints := []*eventstore.EventUniqueConstraint{NewRemoveProjectNameUniqueConstraint(e.Name, e.Aggregate().ResourceOwner)}
|
||||
if e.entityIDUniqueContraints != nil {
|
||||
for _, constraint := range e.entityIDUniqueContraints {
|
||||
constraints = append(constraints, constraint)
|
||||
}
|
||||
}
|
||||
return constraints
|
||||
}
|
||||
|
||||
func NewProjectRemovedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
name string,
|
||||
entityIDUniqueContraints []*eventstore.EventUniqueConstraint,
|
||||
) *ProjectRemovedEvent {
|
||||
return &ProjectRemovedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
@@ -262,7 +270,8 @@ func NewProjectRemovedEvent(
|
||||
aggregate,
|
||||
ProjectRemovedType,
|
||||
),
|
||||
Name: name,
|
||||
Name: name,
|
||||
entityIDUniqueContraints: entityIDUniqueContraints,
|
||||
}
|
||||
}
|
||||
|
||||
|
163
internal/repository/project/saml_config.go
Normal file
163
internal/repository/project/saml_config.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
UniqueEntityIDType = "entity_ids"
|
||||
SAMLConfigAddedType = applicationEventTypePrefix + "config.saml.added"
|
||||
SAMLConfigChangedType = applicationEventTypePrefix + "config.saml.changed"
|
||||
)
|
||||
|
||||
type SAMLConfigAddedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
AppID string `json:"appId"`
|
||||
EntityID string `json:"entityId"`
|
||||
Metadata []byte `json:"metadata,omitempty"`
|
||||
MetadataURL string `json:"metadata_url,omitempty"`
|
||||
}
|
||||
|
||||
func (e *SAMLConfigAddedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func NewAddSAMLConfigEntityIDUniqueConstraint(entityID string) *eventstore.EventUniqueConstraint {
|
||||
return eventstore.NewAddEventUniqueConstraint(
|
||||
UniqueEntityIDType,
|
||||
entityID,
|
||||
"Errors.Project.App.SAMLEntityIDAlreadyExists")
|
||||
}
|
||||
|
||||
func NewRemoveSAMLConfigEntityIDUniqueConstraint(entityID string) *eventstore.EventUniqueConstraint {
|
||||
return eventstore.NewRemoveEventUniqueConstraint(
|
||||
UniqueEntityIDType,
|
||||
entityID)
|
||||
}
|
||||
|
||||
func (e *SAMLConfigAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
return []*eventstore.EventUniqueConstraint{NewAddSAMLConfigEntityIDUniqueConstraint(e.EntityID)}
|
||||
}
|
||||
|
||||
func NewSAMLConfigAddedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
appID string,
|
||||
entityID string,
|
||||
metadata []byte,
|
||||
metadataURL string,
|
||||
) *SAMLConfigAddedEvent {
|
||||
return &SAMLConfigAddedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
SAMLConfigAddedType,
|
||||
),
|
||||
AppID: appID,
|
||||
EntityID: entityID,
|
||||
Metadata: metadata,
|
||||
MetadataURL: metadataURL,
|
||||
}
|
||||
}
|
||||
|
||||
func SAMLConfigAddedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
e := &SAMLConfigAddedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
|
||||
err := json.Unmarshal(event.Data, e)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "SAML-BDd15", "unable to unmarshal saml config")
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
type SAMLConfigChangedEvent struct {
|
||||
eventstore.BaseEvent `json:"-"`
|
||||
|
||||
AppID string `json:"appId"`
|
||||
EntityID string `json:"entityId"`
|
||||
Metadata []byte `json:"metadata,omitempty"`
|
||||
MetadataURL *string `json:"metadata_url,omitempty"`
|
||||
oldEntityID string
|
||||
}
|
||||
|
||||
func (e *SAMLConfigChangedEvent) Data() interface{} {
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *SAMLConfigChangedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint {
|
||||
if e.EntityID != "" {
|
||||
return []*eventstore.EventUniqueConstraint{
|
||||
NewRemoveSAMLConfigEntityIDUniqueConstraint(e.oldEntityID),
|
||||
NewAddSAMLConfigEntityIDUniqueConstraint(e.EntityID),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSAMLConfigChangedEvent(
|
||||
ctx context.Context,
|
||||
aggregate *eventstore.Aggregate,
|
||||
appID string,
|
||||
oldEntityID string,
|
||||
changes []SAMLConfigChanges,
|
||||
) (*SAMLConfigChangedEvent, error) {
|
||||
if len(changes) == 0 {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "SAML-i8idç", "Errors.NoChangesFound")
|
||||
}
|
||||
|
||||
changeEvent := &SAMLConfigChangedEvent{
|
||||
BaseEvent: *eventstore.NewBaseEventForPush(
|
||||
ctx,
|
||||
aggregate,
|
||||
SAMLConfigChangedType,
|
||||
),
|
||||
AppID: appID,
|
||||
oldEntityID: oldEntityID,
|
||||
}
|
||||
for _, change := range changes {
|
||||
change(changeEvent)
|
||||
}
|
||||
return changeEvent, nil
|
||||
}
|
||||
|
||||
type SAMLConfigChanges func(event *SAMLConfigChangedEvent)
|
||||
|
||||
func ChangeMetadata(metadata []byte) func(event *SAMLConfigChangedEvent) {
|
||||
return func(e *SAMLConfigChangedEvent) {
|
||||
e.Metadata = metadata
|
||||
}
|
||||
}
|
||||
|
||||
func ChangeMetadataURL(metadataURL string) func(event *SAMLConfigChangedEvent) {
|
||||
return func(e *SAMLConfigChangedEvent) {
|
||||
e.MetadataURL = &metadataURL
|
||||
}
|
||||
}
|
||||
|
||||
func ChangeEntityID(entityID string) func(event *SAMLConfigChangedEvent) {
|
||||
return func(e *SAMLConfigChangedEvent) {
|
||||
e.EntityID = entityID
|
||||
}
|
||||
}
|
||||
|
||||
func SAMLConfigChangedEventMapper(event *repository.Event) (eventstore.Event, error) {
|
||||
e := &SAMLConfigChangedEvent{
|
||||
BaseEvent: *eventstore.BaseEventFromRepo(event),
|
||||
}
|
||||
|
||||
err := json.Unmarshal(event.Data, e)
|
||||
if err != nil {
|
||||
return nil, errors.ThrowInternal(err, "SAML-BFd15", "unable to unmarshal saml config")
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
@@ -247,9 +247,14 @@ Errors:
|
||||
NotExisting: Applikation existiert nicht
|
||||
IsNotOIDC: Applikation ist nicht vom Typ OIDC
|
||||
IsNotAPI: Applikation ist nicht vom Typ API
|
||||
IsNotSAML: Applikation ist nicht vom Typ SAML
|
||||
NotActive: Applikation ist nicht aktiv
|
||||
NotInactive: Applikation ist nickt inaktiv
|
||||
OIDCConfigInvalid: OIDC Konfiguration ist ungültig
|
||||
SAMLConfigInvalid: SAML Konfiguration ist ungültig
|
||||
SAMLMetadataMissing: SAML Metadata ist nicht vorhanden
|
||||
SAMLMetadataFormat: SAML Metadata Formatfehler
|
||||
SAMLEntityIDAlreadyExisting: SAML EntityID existiert bereits
|
||||
APIConfigInvalid: API Konfiguration ist ungültig
|
||||
OIDCAuthMethodNoSecret: Gewählte OIDC Auth Method benötigt kein Secret
|
||||
APIAuthMethodNoSecret: Gewählte API Auth Method benötigt kein Secret
|
||||
@@ -751,6 +756,9 @@ EventTypes:
|
||||
added: Applikations Schlüssel hinzugefügt
|
||||
removed: Applikations Schlüssel entfernt
|
||||
config:
|
||||
saml:
|
||||
added: SAML Konfiguration hinzugefügt
|
||||
changed: SAML Konfiguration geändert
|
||||
oidc:
|
||||
added: OIDC Konfiguration hinzugefügt
|
||||
changed: OIDC Konfiguration geändert
|
||||
|
@@ -249,8 +249,13 @@ Errors:
|
||||
NotInactive: Application is not inactive
|
||||
OIDCConfigInvalid: OIDC configuration is invalid
|
||||
APIConfigInvalid: API configuration is invalid
|
||||
IsNotOIDC: Application is not type oidc
|
||||
SAMLConfigInvalid: SAML configuration is invalid
|
||||
IsNotOIDC: Application is not type OIDC
|
||||
IsNotAPI: Application is not type API
|
||||
IsNotSAML: Application is not type SAML
|
||||
SAMLMetadataMissing: SAML metadata is missing
|
||||
SAMLMetadataFormat: SAML Metadata format error
|
||||
SAMLEntityIDAlreadyExisting: SAML EntityID already existing
|
||||
OIDCAuthMethodNoSecret: Chosen OIDC Auth Method does not require a secret
|
||||
APIAuthMethodNoSecret: Chosen API Auth Method does not require a secret
|
||||
AuthMethodNoPrivateKeyJWT: Chosen Auth Method does not require a key
|
||||
@@ -751,6 +756,9 @@ EventTypes:
|
||||
added: Application key added
|
||||
removed: Application key removed
|
||||
config:
|
||||
saml:
|
||||
added: SAML Configuration added
|
||||
changed: SAML Configuration changed
|
||||
oidc:
|
||||
added: OIDC Configuration added
|
||||
changed: OIDC Configuration changed
|
||||
|
@@ -249,8 +249,13 @@ Errors:
|
||||
NotInactive: L'application n'est pas inactive
|
||||
OIDCConfigInvalid: La configuration de l'OIDC n'est pas valide
|
||||
APIConfigInvalid: La configuration de l'API n'est pas valide
|
||||
IsNotOIDC: L'application n'est pas de type oidc
|
||||
SAMLConfigInvalid: La configuration de l'SAML n'est pas valide
|
||||
IsNotOIDC: L'application n'est pas de type OIDC
|
||||
IsNotAPI: L'application n'est pas de type API
|
||||
IsNotSAML: L'application n'est pas de type SAML
|
||||
SAMLMetadataMissing: Les métadonnées SAML sont manquantes
|
||||
SAMLMetadataFormat: Erreur de format des métadonnées SAML
|
||||
SAMLEntityIDAlreadyExisting: SAML EntityID déjà existant
|
||||
OIDCAuthMethodNoSecret: La méthode d'authentification OIDC choisie ne nécessite pas de secret.
|
||||
APIAuthMethodNoSecret: La méthode d'authentification API choisie ne nécessite pas de secret.
|
||||
AuthMethodNoPrivateKeyJWT: La méthode d'authentification choisie ne nécessite pas de clé.
|
||||
@@ -751,6 +756,9 @@ EventTypes:
|
||||
added: Clé d'application ajoutée
|
||||
removed: Clé d'application supprimée
|
||||
config:
|
||||
saml:
|
||||
added: Configuration SAML ajoutée
|
||||
changed: La configuration de SAML a été modifiée
|
||||
oidc:
|
||||
added: Configuration OIDC ajoutée
|
||||
changed: Modification de la configuration de l'OIDC
|
||||
|
@@ -249,8 +249,13 @@ Errors:
|
||||
NotInactive: L'applicazione non è inattiva
|
||||
OIDCConfigInvalid: La configurazione OIDC non è valida
|
||||
APIConfigInvalid: La configurazione API non è valida
|
||||
IsNotOIDC: L'applicazione non è di tipo oidc
|
||||
SAMLConfigInvalid: La configurazione SAML non è valida
|
||||
IsNotOIDC: L'applicazione non è di tipo OIDC
|
||||
IsNotAPI: L'applicazione non è di tipo API
|
||||
IsNotSAML: L'applicazione non è di tipo SAML
|
||||
SAMLMetadataMissing: Mancano i metadati SAML
|
||||
SAMLMetadataFormat: Errore nel formato dei metadati SAML
|
||||
SAMLEntityIDAlreadyExisting: EntityID SAML già esistente
|
||||
OIDCAuthMethodNoSecret: Il metodo di autorizzazione OIDC scelto non richiede un segreto
|
||||
APIAuthMethodNoSecret: Il metodo di autorizzazione API scelto non richiede un segreto
|
||||
AuthMethodNoPrivateKeyJWT: Il metodo di autorizzazione scelto non richiede una chiave
|
||||
@@ -751,6 +756,9 @@ EventTypes:
|
||||
added: Chiave di applicazione aggiunta
|
||||
removed: Chiave di applicazione rimossa
|
||||
config:
|
||||
saml:
|
||||
added: Configurazione SAML aggiunta
|
||||
changed: Configurazione SAML modificata
|
||||
oidc:
|
||||
added: Configurazione OIDC aggiunta
|
||||
changed: Configurazione OIDC modificata
|
||||
|
@@ -750,6 +750,9 @@ EventTypes:
|
||||
added: 添加应用 Key
|
||||
removed: 删除应用 Key
|
||||
config:
|
||||
saml:
|
||||
added: 添加 SAML 配置
|
||||
changed: 更改 SAML 配置
|
||||
oidc:
|
||||
added: 添加 OIDC 配置
|
||||
changed: 更改 OIDC 配置
|
||||
|
Reference in New Issue
Block a user