mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 07:57:32 +00:00
feat(saml): implementation of saml for ZITADEL v2 (#3618)
This commit is contained in:
@@ -67,6 +67,16 @@ func (s *Server) AddOIDCApp(ctx context.Context, req *mgmt_pb.AddOIDCAppRequest)
|
||||
ComplianceProblems: project_grpc.ComplianceProblemsToLocalizedMessages(app.Compliance.Problems),
|
||||
}, nil
|
||||
}
|
||||
func (s *Server) AddSAMLApp(ctx context.Context, req *mgmt_pb.AddSAMLAppRequest) (*mgmt_pb.AddSAMLAppResponse, error) {
|
||||
app, err := s.command.AddSAMLApplication(ctx, AddSAMLAppRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mgmt_pb.AddSAMLAppResponse{
|
||||
AppId: app.AppID,
|
||||
Details: object_grpc.AddToDetailsPb(app.Sequence, app.ChangeDate, app.ResourceOwner),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) AddAPIApp(ctx context.Context, req *mgmt_pb.AddAPIAppRequest) (*mgmt_pb.AddAPIAppResponse, error) {
|
||||
appSecretGenerator, err := s.query.InitHashGenerator(ctx, domain.SecretGeneratorTypeAppSecret, s.passwordHashAlg)
|
||||
@@ -109,6 +119,20 @@ func (s *Server) UpdateOIDCAppConfig(ctx context.Context, req *mgmt_pb.UpdateOID
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpdateSAMLAppConfig(ctx context.Context, req *mgmt_pb.UpdateSAMLAppConfigRequest) (*mgmt_pb.UpdateSAMLAppConfigResponse, error) {
|
||||
config, err := s.command.ChangeSAMLApplication(ctx, UpdateSAMLAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mgmt_pb.UpdateSAMLAppConfigResponse{
|
||||
Details: object_grpc.ChangeToDetailsPb(
|
||||
config.Sequence,
|
||||
config.ChangeDate,
|
||||
config.ResourceOwner,
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) UpdateAPIAppConfig(ctx context.Context, req *mgmt_pb.UpdateAPIAppConfigRequest) (*mgmt_pb.UpdateAPIAppConfigResponse, error) {
|
||||
config, err := s.command.ChangeAPIApplication(ctx, UpdateAPIAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
|
||||
if err != nil {
|
||||
|
@@ -59,6 +59,17 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) *domain.OIDCApp {
|
||||
}
|
||||
}
|
||||
|
||||
func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) *domain.SAMLApp {
|
||||
return &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: req.ProjectId,
|
||||
},
|
||||
AppName: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
}
|
||||
}
|
||||
|
||||
func AddAPIAppRequestToDomain(app *mgmt_pb.AddAPIAppRequest) *domain.APIApp {
|
||||
return &domain.APIApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
@@ -98,6 +109,17 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) *domain.SAMLApp {
|
||||
return &domain.SAMLApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
AggregateID: app.ProjectId,
|
||||
},
|
||||
AppID: app.AppId,
|
||||
Metadata: app.GetMetadataXml(),
|
||||
MetadataURL: app.GetMetadataUrl(),
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateAPIAppConfigRequestToDomain(app *mgmt_pb.UpdateAPIAppConfigRequest) *domain.APIApp {
|
||||
return &domain.APIApp{
|
||||
ObjectRoot: models.ObjectRoot{
|
||||
|
@@ -33,6 +33,9 @@ func AppConfigToPb(app *query.App) app_pb.AppConfig {
|
||||
if app.OIDCConfig != nil {
|
||||
return AppOIDCConfigToPb(app.OIDCConfig)
|
||||
}
|
||||
if app.SAMLConfig != nil {
|
||||
return AppSAMLConfigToPb(app.SAMLConfig)
|
||||
}
|
||||
return AppAPIConfigToPb(app.APIConfig)
|
||||
}
|
||||
|
||||
@@ -61,6 +64,14 @@ func AppOIDCConfigToPb(app *query.OIDCApp) *app_pb.App_OidcConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func AppSAMLConfigToPb(app *query.SAMLApp) app_pb.AppConfig {
|
||||
return &app_pb.App_SamlConfig{
|
||||
SamlConfig: &app_pb.SAMLConfig{
|
||||
Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func AppAPIConfigToPb(app *query.APIApp) app_pb.AppConfig {
|
||||
return &app_pb.App_ApiConfig{
|
||||
ApiConfig: &app_pb.APIConfig{
|
||||
|
99
internal/api/saml/auth_request_converter.go
Normal file
99
internal/api/saml/auth_request_converter.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
var _ models.AuthRequestInt = &AuthRequest{}
|
||||
|
||||
type AuthRequest struct {
|
||||
*domain.AuthRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetApplicationID() string {
|
||||
return a.ApplicationID
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
func (a *AuthRequest) GetRelayState() string {
|
||||
return a.TransferState
|
||||
}
|
||||
func (a *AuthRequest) GetAccessConsumerServiceURL() string {
|
||||
return a.CallbackURI
|
||||
}
|
||||
|
||||
func (a *AuthRequest) GetNameID() string {
|
||||
return a.UserName
|
||||
}
|
||||
|
||||
func (a *AuthRequest) saml() *domain.AuthRequestSAML {
|
||||
return a.Request.(*domain.AuthRequestSAML)
|
||||
}
|
||||
func (a *AuthRequest) GetAuthRequestID() string {
|
||||
return a.saml().ID
|
||||
}
|
||||
func (a *AuthRequest) GetBindingType() string {
|
||||
return a.saml().BindingType
|
||||
}
|
||||
func (a *AuthRequest) GetIssuer() string {
|
||||
return a.saml().Issuer
|
||||
}
|
||||
func (a *AuthRequest) GetIssuerName() string {
|
||||
return a.saml().IssuerName
|
||||
}
|
||||
func (a *AuthRequest) GetDestination() string {
|
||||
return a.saml().Destination
|
||||
}
|
||||
func (a *AuthRequest) GetCode() string {
|
||||
return a.saml().Code
|
||||
}
|
||||
func (a *AuthRequest) GetUserID() string {
|
||||
return a.UserID
|
||||
}
|
||||
func (a *AuthRequest) GetUserName() string {
|
||||
return a.UserName
|
||||
}
|
||||
func (a *AuthRequest) Done() bool {
|
||||
for _, step := range a.PossibleSteps {
|
||||
if step.Type() == domain.NextStepRedirectToCallback {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ models.AuthRequestInt, err error) {
|
||||
if _, ok := authReq.Request.(*domain.AuthRequestSAML); !ok {
|
||||
return nil, errors.ThrowInvalidArgument(nil, "SAML-Hbz7A", "auth request is not of type saml")
|
||||
}
|
||||
return &AuthRequest{authReq}, nil
|
||||
}
|
||||
|
||||
func CreateAuthRequestToBusiness(ctx context.Context, authReq *samlp.AuthnRequestType, acsUrl, protocolBinding, applicationID, relayState, userAgentID string) *domain.AuthRequest {
|
||||
return &domain.AuthRequest{
|
||||
CreationDate: time.Now(),
|
||||
AgentID: userAgentID,
|
||||
ApplicationID: applicationID,
|
||||
CallbackURI: acsUrl,
|
||||
TransferState: relayState,
|
||||
InstanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
Request: &domain.AuthRequestSAML{
|
||||
ID: authReq.Id,
|
||||
BindingType: protocolBinding,
|
||||
Code: "",
|
||||
Issuer: authReq.Issuer.Text,
|
||||
IssuerName: authReq.Issuer.SPProvidedID,
|
||||
Destination: authReq.Destination,
|
||||
},
|
||||
}
|
||||
}
|
202
internal/api/saml/certificate.go
Normal file
202
internal/api/saml/certificate.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/saml/pkg/provider/key"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/repository/keypair"
|
||||
)
|
||||
|
||||
const (
|
||||
locksTable = "projections.locks"
|
||||
signingKey = "signing_key"
|
||||
samlUser = "SAML"
|
||||
|
||||
retryBackoff = 500 * time.Millisecond
|
||||
retryCount = 3
|
||||
lockDuration = retryCount * retryBackoff * 5
|
||||
gracefulPeriod = 10 * time.Minute
|
||||
)
|
||||
|
||||
type CertificateAndKey struct {
|
||||
algorithm jose.SignatureAlgorithm
|
||||
id string
|
||||
key interface{}
|
||||
certificate interface{}
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||
return c.algorithm
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) Key() interface{} {
|
||||
return c.key
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) Certificate() interface{} {
|
||||
return c.certificate
|
||||
}
|
||||
|
||||
func (c *CertificateAndKey) ID() string {
|
||||
return c.id
|
||||
}
|
||||
|
||||
func (p *Storage) GetCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (certAndKey *key.CertificateAndKey, err error) {
|
||||
err = retry(func() error {
|
||||
certAndKey, err = p.getCertificateAndKey(ctx, usage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if certAndKey == nil {
|
||||
return errors.ThrowInternal(err, "SAML-8u01nks", "no certificate found")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return certAndKey, err
|
||||
}
|
||||
|
||||
func (p *Storage) getCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (*key.CertificateAndKey, error) {
|
||||
certs, err := p.query.ActiveCertificates(ctx, time.Now().Add(gracefulPeriod), usage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(certs.Certificates) > 0 {
|
||||
return p.certificateToCertificateAndKey(selectCertificate(certs.Certificates))
|
||||
}
|
||||
|
||||
var sequence uint64
|
||||
if certs.LatestSequence != nil {
|
||||
sequence = certs.LatestSequence.Sequence
|
||||
}
|
||||
|
||||
return nil, p.refreshCertificate(ctx, usage, sequence)
|
||||
}
|
||||
|
||||
func (p *Storage) refreshCertificate(
|
||||
ctx context.Context,
|
||||
usage domain.KeyUsage,
|
||||
sequence uint64,
|
||||
) error {
|
||||
ok, err := p.ensureIsLatestCertificate(ctx, sequence)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("could not ensure latest key")
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
logging.Warn("view not up to date, retrying later")
|
||||
return err
|
||||
}
|
||||
err = p.lockAndGenerateCertificateAndKey(ctx, usage, sequence)
|
||||
logging.OnError(err).Warn("could not create signing key")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) ensureIsLatestCertificate(ctx context.Context, sequence uint64) (bool, error) {
|
||||
maxSequence, err := p.getMaxKeySequence(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error retrieving new events: %w", err)
|
||||
}
|
||||
return sequence == maxSequence, nil
|
||||
}
|
||||
|
||||
func (p *Storage) lockAndGenerateCertificateAndKey(ctx context.Context, usage domain.KeyUsage, sequence uint64) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
ctx = setSAMLCtx(ctx)
|
||||
|
||||
errs := p.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
|
||||
err, ok := <-errs
|
||||
if err != nil || !ok {
|
||||
if errors.IsErrorAlreadyExists(err) {
|
||||
return nil
|
||||
}
|
||||
logging.OnError(err).Warn("initial lock failed")
|
||||
return err
|
||||
}
|
||||
|
||||
switch usage {
|
||||
case domain.KeyUsageSAMLMetadataSigning, domain.KeyUsageSAMLResponseSinging:
|
||||
certAndKey, err := p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while reading ca certificate: %w", err)
|
||||
}
|
||||
if certAndKey.Key == nil || certAndKey.Certificate == nil {
|
||||
return fmt.Errorf("has no ca certificate")
|
||||
}
|
||||
|
||||
switch usage {
|
||||
case domain.KeyUsageSAMLMetadataSigning:
|
||||
return p.command.GenerateSAMLMetadataCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
|
||||
case domain.KeyUsageSAMLResponseSinging:
|
||||
return p.command.GenerateSAMLResponseCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
|
||||
default:
|
||||
return fmt.Errorf("unknown usage")
|
||||
}
|
||||
case domain.KeyUsageSAMLCA:
|
||||
return p.command.GenerateSAMLCACertificate(setSAMLCtx(ctx), p.certificateAlgorithm)
|
||||
default:
|
||||
return fmt.Errorf("unknown certificate usage")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Storage) getMaxKeySequence(ctx context.Context) (uint64, error) {
|
||||
return p.eventstore.LatestSequence(ctx,
|
||||
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
|
||||
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
|
||||
AddQuery().
|
||||
AggregateTypes(keypair.AggregateType).
|
||||
Builder(),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Storage) certificateToCertificateAndKey(certificate query.Certificate) (_ *key.CertificateAndKey, err error) {
|
||||
keyData, err := crypto.Decrypt(certificate.Key(), p.encAlg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privateKey, err := crypto.BytesToPrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert, err := crypto.BytesToCertificate(certificate.Certificate())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &key.CertificateAndKey{
|
||||
Key: privateKey,
|
||||
Certificate: cert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func selectCertificate(certs []query.Certificate) query.Certificate {
|
||||
return certs[len(certs)-1]
|
||||
}
|
||||
|
||||
func setSAMLCtx(ctx context.Context) context.Context {
|
||||
return authz.SetCtxData(ctx, authz.CtxData{UserID: samlUser, OrgID: authz.GetInstance(ctx).InstanceID()})
|
||||
}
|
||||
|
||||
func retry(retryable func() error) (err error) {
|
||||
for i := 0; i < retryCount; i++ {
|
||||
err = retryable()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(retryBackoff)
|
||||
}
|
||||
return err
|
||||
}
|
102
internal/api/saml/provider.go
Normal file
102
internal/api/saml/provider.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
HandlerPrefix = "/saml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ProviderConfig *provider.Config
|
||||
}
|
||||
|
||||
func NewProvider(
|
||||
ctx context.Context,
|
||||
conf Config,
|
||||
externalSecure bool,
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
repo repository.Repository,
|
||||
encAlg crypto.EncryptionAlgorithm,
|
||||
certEncAlg crypto.EncryptionAlgorithm,
|
||||
es *eventstore.Eventstore,
|
||||
projections *sql.DB,
|
||||
instanceHandler,
|
||||
userAgentCookie func(http.Handler) http.Handler,
|
||||
) (*provider.Provider, error) {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
|
||||
provStorage, err := newStorage(
|
||||
command,
|
||||
query,
|
||||
repo,
|
||||
encAlg,
|
||||
certEncAlg,
|
||||
es,
|
||||
projections,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options := []provider.Option{
|
||||
provider.WithHttpInterceptors(
|
||||
middleware.MetricsHandler(metricTypes),
|
||||
middleware.TelemetryHandler(),
|
||||
middleware.NoCacheInterceptor().Handler,
|
||||
instanceHandler,
|
||||
userAgentCookie,
|
||||
http_utils.CopyHeadersToContext,
|
||||
),
|
||||
}
|
||||
if !externalSecure {
|
||||
options = append(options, provider.WithAllowInsecure())
|
||||
}
|
||||
|
||||
return provider.NewProvider(
|
||||
ctx,
|
||||
provStorage,
|
||||
HandlerPrefix,
|
||||
conf.ProviderConfig,
|
||||
options...,
|
||||
)
|
||||
}
|
||||
|
||||
func newStorage(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
repo repository.Repository,
|
||||
encAlg crypto.EncryptionAlgorithm,
|
||||
certEncAlg crypto.EncryptionAlgorithm,
|
||||
es *eventstore.Eventstore,
|
||||
projections *sql.DB,
|
||||
) (*Storage, error) {
|
||||
return &Storage{
|
||||
encAlg: encAlg,
|
||||
certEncAlg: certEncAlg,
|
||||
locker: crdb.NewLocker(projections, locksTable, signingKey),
|
||||
eventstore: es,
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: query,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
}, nil
|
||||
}
|
187
internal/api/saml/storage.go
Normal file
187
internal/api/saml/storage.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
"github.com/zitadel/saml/pkg/provider/key"
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/serviceprovider"
|
||||
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
var _ provider.EntityStorage = &Storage{}
|
||||
var _ provider.IdentityProviderStorage = &Storage{}
|
||||
var _ provider.AuthStorage = &Storage{}
|
||||
var _ provider.UserStorage = &Storage{}
|
||||
|
||||
type Storage struct {
|
||||
certChan <-chan interface{}
|
||||
defaultCertificateLifetime time.Duration
|
||||
|
||||
currentCACertificate query.Certificate
|
||||
currentMetadataCertificate query.Certificate
|
||||
currentResponseCertificate query.Certificate
|
||||
|
||||
locker crdb.Locker
|
||||
certificateAlgorithm string
|
||||
encAlg crypto.EncryptionAlgorithm
|
||||
certEncAlg crypto.EncryptionAlgorithm
|
||||
|
||||
eventstore *eventstore.Eventstore
|
||||
repo repository.Repository
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
defaultLoginURL string
|
||||
}
|
||||
|
||||
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
|
||||
app, err := p.query.AppBySAMLEntityID(ctx, entityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return serviceprovider.NewServiceProvider(
|
||||
app.ID,
|
||||
&serviceprovider.Config{
|
||||
Metadata: app.SAMLConfig.Metadata,
|
||||
},
|
||||
p.defaultLoginURL,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) {
|
||||
app, err := p.query.AppByID(ctx, appID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return app.SAMLConfig.EntityID, nil
|
||||
}
|
||||
|
||||
func (p *Storage) Health(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) GetCA(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
|
||||
}
|
||||
|
||||
func (p *Storage) GetMetadataSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLMetadataSigning)
|
||||
}
|
||||
|
||||
func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||||
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLResponseSinging)
|
||||
}
|
||||
|
||||
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||
}
|
||||
|
||||
authRequest := CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID)
|
||||
|
||||
resp, err := p.repo.CreateAuthRequest(ctx, authRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||||
}
|
||||
resp, err := p.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
user, err := p.query.GetUserByID(ctx, true, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
loginNameSQ, err := query.NewUserLoginNamesSearchQuery(loginName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := p.query.GetUser(ctx, true, loginNameSQ)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setUserinfo(user, userinfo, attributes)
|
||||
return nil
|
||||
}
|
||||
|
||||
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int) {
|
||||
if len(attributes) == 0 {
|
||||
userinfo.SetUsername(user.PreferredLoginName)
|
||||
userinfo.SetUserID(user.ID)
|
||||
if user.Human == nil {
|
||||
return
|
||||
}
|
||||
userinfo.SetEmail(user.Human.Email)
|
||||
userinfo.SetSurname(user.Human.LastName)
|
||||
userinfo.SetGivenName(user.Human.FirstName)
|
||||
userinfo.SetFullName(user.Human.DisplayName)
|
||||
return
|
||||
}
|
||||
for _, attribute := range attributes {
|
||||
switch attribute {
|
||||
case provider.AttributeEmail:
|
||||
if user.Human != nil {
|
||||
userinfo.SetEmail(user.Human.Email)
|
||||
}
|
||||
case provider.AttributeSurname:
|
||||
if user.Human != nil {
|
||||
userinfo.SetSurname(user.Human.LastName)
|
||||
}
|
||||
case provider.AttributeFullName:
|
||||
if user.Human != nil {
|
||||
userinfo.SetFullName(user.Human.DisplayName)
|
||||
}
|
||||
case provider.AttributeGivenName:
|
||||
if user.Human != nil {
|
||||
userinfo.SetGivenName(user.Human.FirstName)
|
||||
}
|
||||
case provider.AttributeUsername:
|
||||
userinfo.SetUsername(user.PreferredLoginName)
|
||||
case provider.AttributeUserID:
|
||||
userinfo.SetUserID(user.ID)
|
||||
}
|
||||
}
|
||||
}
|
@@ -374,7 +374,7 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
|
||||
return
|
||||
}
|
||||
linkingUser = l.mapExternalNotFoundOptionFormDataToLoginUser(data)
|
||||
}
|
||||
}
|
||||
|
||||
user, externalIDP, metadata := l.mapExternalUserToLoginUser(orgIamPolicy, linkingUser, idpConfig)
|
||||
|
||||
|
@@ -37,6 +37,7 @@ type Login struct {
|
||||
externalSecure bool
|
||||
consolePath string
|
||||
oidcAuthCallbackURL func(context.Context, string) string
|
||||
samlAuthCallbackURL func(context.Context, string) string
|
||||
idpConfigAlg crypto.EncryptionAlgorithm
|
||||
userCodeAlg crypto.EncryptionAlgorithm
|
||||
}
|
||||
@@ -61,10 +62,12 @@ func CreateLogin(config Config,
|
||||
staticStorage static.Storage,
|
||||
consolePath string,
|
||||
oidcAuthCallbackURL func(context.Context, string) string,
|
||||
samlAuthCallbackURL func(context.Context, string) string,
|
||||
externalSecure bool,
|
||||
userAgentCookie,
|
||||
issuerInterceptor,
|
||||
instanceHandler,
|
||||
oidcInstanceHandler,
|
||||
samlInstanceHandler mux.MiddlewareFunc,
|
||||
assetCache mux.MiddlewareFunc,
|
||||
userCodeAlg crypto.EncryptionAlgorithm,
|
||||
idpConfigAlg crypto.EncryptionAlgorithm,
|
||||
@@ -73,6 +76,7 @@ func CreateLogin(config Config,
|
||||
|
||||
login := &Login{
|
||||
oidcAuthCallbackURL: oidcAuthCallbackURL,
|
||||
samlAuthCallbackURL: samlAuthCallbackURL,
|
||||
externalSecure: externalSecure,
|
||||
consolePath: consolePath,
|
||||
command: command,
|
||||
@@ -91,7 +95,7 @@ func CreateLogin(config Config,
|
||||
cacheInterceptor := createCacheInterceptor(config.Cache.MaxAge, config.Cache.SharedMaxAge, assetCache)
|
||||
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
|
||||
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), instanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
|
||||
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
|
||||
login.parser = form.NewParser()
|
||||
return login, nil
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -43,11 +44,26 @@ func (l *Login) renderSuccessAndCallback(w http.ResponseWriter, r *http.Request,
|
||||
userData: l.getUserData(r, authReq, "Login Successful", errID, errMessage),
|
||||
}
|
||||
if authReq != nil {
|
||||
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "") //the id will be set via the html (maybe change this with the login refactoring)
|
||||
//the id will be set via the html (maybe change this with the login refactoring)
|
||||
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); ok {
|
||||
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "")
|
||||
} else if _, ok := authReq.Request.(*domain.AuthRequestSAML); ok {
|
||||
data.RedirectURI = l.samlAuthCallbackURL(r.Context(), "")
|
||||
}
|
||||
}
|
||||
l.renderer.RenderTemplate(w, r, l.getTranslator(r.Context(), authReq), l.renderer.Templates[tmplLoginSuccess], data, nil)
|
||||
}
|
||||
|
||||
func (l *Login) redirectToCallback(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) {
|
||||
http.Redirect(w, r, l.oidcAuthCallbackURL(r.Context(), authReq.ID), http.StatusFound)
|
||||
var callback string
|
||||
switch authReq.Request.(type) {
|
||||
case *domain.AuthRequestOIDC:
|
||||
callback = l.oidcAuthCallbackURL(r.Context(), authReq.ID)
|
||||
case *domain.AuthRequestSAML:
|
||||
callback = l.samlAuthCallbackURL(r.Context(), authReq.ID)
|
||||
default:
|
||||
l.renderInternalError(w, r, authReq, caos_errs.ThrowInternal(nil, "LOGIN-rhjQF", "Errors.AuthRequest.RequestTypeNotSupported"))
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callback, http.StatusFound)
|
||||
}
|
||||
|
Reference in New Issue
Block a user