feat(saml): implementation of saml for ZITADEL v2 (#3618)

This commit is contained in:
Stefan Benz
2022-09-12 17:18:08 +01:00
committed by GitHub
parent 01a92ba5d9
commit 7a5f7f82cf
134 changed files with 5570 additions and 1293 deletions

View File

@@ -67,6 +67,16 @@ func (s *Server) AddOIDCApp(ctx context.Context, req *mgmt_pb.AddOIDCAppRequest)
ComplianceProblems: project_grpc.ComplianceProblemsToLocalizedMessages(app.Compliance.Problems),
}, nil
}
func (s *Server) AddSAMLApp(ctx context.Context, req *mgmt_pb.AddSAMLAppRequest) (*mgmt_pb.AddSAMLAppResponse, error) {
app, err := s.command.AddSAMLApplication(ctx, AddSAMLAppRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}
return &mgmt_pb.AddSAMLAppResponse{
AppId: app.AppID,
Details: object_grpc.AddToDetailsPb(app.Sequence, app.ChangeDate, app.ResourceOwner),
}, nil
}
func (s *Server) AddAPIApp(ctx context.Context, req *mgmt_pb.AddAPIAppRequest) (*mgmt_pb.AddAPIAppResponse, error) {
appSecretGenerator, err := s.query.InitHashGenerator(ctx, domain.SecretGeneratorTypeAppSecret, s.passwordHashAlg)
@@ -109,6 +119,20 @@ func (s *Server) UpdateOIDCAppConfig(ctx context.Context, req *mgmt_pb.UpdateOID
}, nil
}
func (s *Server) UpdateSAMLAppConfig(ctx context.Context, req *mgmt_pb.UpdateSAMLAppConfigRequest) (*mgmt_pb.UpdateSAMLAppConfigResponse, error) {
config, err := s.command.ChangeSAMLApplication(ctx, UpdateSAMLAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
if err != nil {
return nil, err
}
return &mgmt_pb.UpdateSAMLAppConfigResponse{
Details: object_grpc.ChangeToDetailsPb(
config.Sequence,
config.ChangeDate,
config.ResourceOwner,
),
}, nil
}
func (s *Server) UpdateAPIAppConfig(ctx context.Context, req *mgmt_pb.UpdateAPIAppConfigRequest) (*mgmt_pb.UpdateAPIAppConfigResponse, error) {
config, err := s.command.ChangeAPIApplication(ctx, UpdateAPIAppConfigRequestToDomain(req), authz.GetCtxData(ctx).OrgID)
if err != nil {

View File

@@ -59,6 +59,17 @@ func AddOIDCAppRequestToDomain(req *mgmt_pb.AddOIDCAppRequest) *domain.OIDCApp {
}
}
func AddSAMLAppRequestToDomain(req *mgmt_pb.AddSAMLAppRequest) *domain.SAMLApp {
return &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: req.ProjectId,
},
AppName: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
}
}
func AddAPIAppRequestToDomain(app *mgmt_pb.AddAPIAppRequest) *domain.APIApp {
return &domain.APIApp{
ObjectRoot: models.ObjectRoot{
@@ -98,6 +109,17 @@ func UpdateOIDCAppConfigRequestToDomain(app *mgmt_pb.UpdateOIDCAppConfigRequest)
}
}
func UpdateSAMLAppConfigRequestToDomain(app *mgmt_pb.UpdateSAMLAppConfigRequest) *domain.SAMLApp {
return &domain.SAMLApp{
ObjectRoot: models.ObjectRoot{
AggregateID: app.ProjectId,
},
AppID: app.AppId,
Metadata: app.GetMetadataXml(),
MetadataURL: app.GetMetadataUrl(),
}
}
func UpdateAPIAppConfigRequestToDomain(app *mgmt_pb.UpdateAPIAppConfigRequest) *domain.APIApp {
return &domain.APIApp{
ObjectRoot: models.ObjectRoot{

View File

@@ -33,6 +33,9 @@ func AppConfigToPb(app *query.App) app_pb.AppConfig {
if app.OIDCConfig != nil {
return AppOIDCConfigToPb(app.OIDCConfig)
}
if app.SAMLConfig != nil {
return AppSAMLConfigToPb(app.SAMLConfig)
}
return AppAPIConfigToPb(app.APIConfig)
}
@@ -61,6 +64,14 @@ func AppOIDCConfigToPb(app *query.OIDCApp) *app_pb.App_OidcConfig {
}
}
func AppSAMLConfigToPb(app *query.SAMLApp) app_pb.AppConfig {
return &app_pb.App_SamlConfig{
SamlConfig: &app_pb.SAMLConfig{
Metadata: &app_pb.SAMLConfig_MetadataXml{MetadataXml: app.Metadata},
},
}
}
func AppAPIConfigToPb(app *query.APIApp) app_pb.AppConfig {
return &app_pb.App_ApiConfig{
ApiConfig: &app_pb.APIConfig{

View File

@@ -0,0 +1,99 @@
package saml
import (
"context"
"time"
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
)
var _ models.AuthRequestInt = &AuthRequest{}
type AuthRequest struct {
*domain.AuthRequest
}
func (a *AuthRequest) GetApplicationID() string {
return a.ApplicationID
}
func (a *AuthRequest) GetID() string {
return a.ID
}
func (a *AuthRequest) GetRelayState() string {
return a.TransferState
}
func (a *AuthRequest) GetAccessConsumerServiceURL() string {
return a.CallbackURI
}
func (a *AuthRequest) GetNameID() string {
return a.UserName
}
func (a *AuthRequest) saml() *domain.AuthRequestSAML {
return a.Request.(*domain.AuthRequestSAML)
}
func (a *AuthRequest) GetAuthRequestID() string {
return a.saml().ID
}
func (a *AuthRequest) GetBindingType() string {
return a.saml().BindingType
}
func (a *AuthRequest) GetIssuer() string {
return a.saml().Issuer
}
func (a *AuthRequest) GetIssuerName() string {
return a.saml().IssuerName
}
func (a *AuthRequest) GetDestination() string {
return a.saml().Destination
}
func (a *AuthRequest) GetCode() string {
return a.saml().Code
}
func (a *AuthRequest) GetUserID() string {
return a.UserID
}
func (a *AuthRequest) GetUserName() string {
return a.UserName
}
func (a *AuthRequest) Done() bool {
for _, step := range a.PossibleSteps {
if step.Type() == domain.NextStepRedirectToCallback {
return true
}
}
return false
}
func AuthRequestFromBusiness(authReq *domain.AuthRequest) (_ models.AuthRequestInt, err error) {
if _, ok := authReq.Request.(*domain.AuthRequestSAML); !ok {
return nil, errors.ThrowInvalidArgument(nil, "SAML-Hbz7A", "auth request is not of type saml")
}
return &AuthRequest{authReq}, nil
}
func CreateAuthRequestToBusiness(ctx context.Context, authReq *samlp.AuthnRequestType, acsUrl, protocolBinding, applicationID, relayState, userAgentID string) *domain.AuthRequest {
return &domain.AuthRequest{
CreationDate: time.Now(),
AgentID: userAgentID,
ApplicationID: applicationID,
CallbackURI: acsUrl,
TransferState: relayState,
InstanceID: authz.GetInstance(ctx).InstanceID(),
Request: &domain.AuthRequestSAML{
ID: authReq.Id,
BindingType: protocolBinding,
Code: "",
Issuer: authReq.Issuer.Text,
IssuerName: authReq.Issuer.SPProvidedID,
Destination: authReq.Destination,
},
}
}

View File

@@ -0,0 +1,202 @@
package saml
import (
"context"
"fmt"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/saml/pkg/provider/key"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/keypair"
)
const (
locksTable = "projections.locks"
signingKey = "signing_key"
samlUser = "SAML"
retryBackoff = 500 * time.Millisecond
retryCount = 3
lockDuration = retryCount * retryBackoff * 5
gracefulPeriod = 10 * time.Minute
)
type CertificateAndKey struct {
algorithm jose.SignatureAlgorithm
id string
key interface{}
certificate interface{}
}
func (c *CertificateAndKey) SignatureAlgorithm() jose.SignatureAlgorithm {
return c.algorithm
}
func (c *CertificateAndKey) Key() interface{} {
return c.key
}
func (c *CertificateAndKey) Certificate() interface{} {
return c.certificate
}
func (c *CertificateAndKey) ID() string {
return c.id
}
func (p *Storage) GetCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (certAndKey *key.CertificateAndKey, err error) {
err = retry(func() error {
certAndKey, err = p.getCertificateAndKey(ctx, usage)
if err != nil {
return err
}
if certAndKey == nil {
return errors.ThrowInternal(err, "SAML-8u01nks", "no certificate found")
}
return nil
})
return certAndKey, err
}
func (p *Storage) getCertificateAndKey(ctx context.Context, usage domain.KeyUsage) (*key.CertificateAndKey, error) {
certs, err := p.query.ActiveCertificates(ctx, time.Now().Add(gracefulPeriod), usage)
if err != nil {
return nil, err
}
if len(certs.Certificates) > 0 {
return p.certificateToCertificateAndKey(selectCertificate(certs.Certificates))
}
var sequence uint64
if certs.LatestSequence != nil {
sequence = certs.LatestSequence.Sequence
}
return nil, p.refreshCertificate(ctx, usage, sequence)
}
func (p *Storage) refreshCertificate(
ctx context.Context,
usage domain.KeyUsage,
sequence uint64,
) error {
ok, err := p.ensureIsLatestCertificate(ctx, sequence)
if err != nil {
logging.WithError(err).Error("could not ensure latest key")
return err
}
if !ok {
logging.Warn("view not up to date, retrying later")
return err
}
err = p.lockAndGenerateCertificateAndKey(ctx, usage, sequence)
logging.OnError(err).Warn("could not create signing key")
return nil
}
func (p *Storage) ensureIsLatestCertificate(ctx context.Context, sequence uint64) (bool, error) {
maxSequence, err := p.getMaxKeySequence(ctx)
if err != nil {
return false, fmt.Errorf("error retrieving new events: %w", err)
}
return sequence == maxSequence, nil
}
func (p *Storage) lockAndGenerateCertificateAndKey(ctx context.Context, usage domain.KeyUsage, sequence uint64) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ctx = setSAMLCtx(ctx)
errs := p.locker.Lock(ctx, lockDuration, authz.GetInstance(ctx).InstanceID())
err, ok := <-errs
if err != nil || !ok {
if errors.IsErrorAlreadyExists(err) {
return nil
}
logging.OnError(err).Warn("initial lock failed")
return err
}
switch usage {
case domain.KeyUsageSAMLMetadataSigning, domain.KeyUsageSAMLResponseSinging:
certAndKey, err := p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
if err != nil {
return fmt.Errorf("error while reading ca certificate: %w", err)
}
if certAndKey.Key == nil || certAndKey.Certificate == nil {
return fmt.Errorf("has no ca certificate")
}
switch usage {
case domain.KeyUsageSAMLMetadataSigning:
return p.command.GenerateSAMLMetadataCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
case domain.KeyUsageSAMLResponseSinging:
return p.command.GenerateSAMLResponseCertificate(setSAMLCtx(ctx), p.certificateAlgorithm, certAndKey.Key, certAndKey.Certificate)
default:
return fmt.Errorf("unknown usage")
}
case domain.KeyUsageSAMLCA:
return p.command.GenerateSAMLCACertificate(setSAMLCtx(ctx), p.certificateAlgorithm)
default:
return fmt.Errorf("unknown certificate usage")
}
}
func (p *Storage) getMaxKeySequence(ctx context.Context) (uint64, error) {
return p.eventstore.LatestSequence(ctx,
eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence).
ResourceOwner(authz.GetInstance(ctx).InstanceID()).
AddQuery().
AggregateTypes(keypair.AggregateType).
Builder(),
)
}
func (p *Storage) certificateToCertificateAndKey(certificate query.Certificate) (_ *key.CertificateAndKey, err error) {
keyData, err := crypto.Decrypt(certificate.Key(), p.encAlg)
if err != nil {
return nil, err
}
privateKey, err := crypto.BytesToPrivateKey(keyData)
if err != nil {
return nil, err
}
cert, err := crypto.BytesToCertificate(certificate.Certificate())
if err != nil {
return nil, err
}
return &key.CertificateAndKey{
Key: privateKey,
Certificate: cert,
}, nil
}
func selectCertificate(certs []query.Certificate) query.Certificate {
return certs[len(certs)-1]
}
func setSAMLCtx(ctx context.Context) context.Context {
return authz.SetCtxData(ctx, authz.CtxData{UserID: samlUser, OrgID: authz.GetInstance(ctx).InstanceID()})
}
func retry(retryable func() error) (err error) {
for i := 0; i < retryCount; i++ {
err = retryable()
if err == nil {
return nil
}
time.Sleep(retryBackoff)
}
return err
}

View File

@@ -0,0 +1,102 @@
package saml
import (
"context"
"database/sql"
"fmt"
"net/http"
"github.com/zitadel/saml/pkg/provider"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
)
const (
HandlerPrefix = "/saml/v2"
)
type Config struct {
ProviderConfig *provider.Config
}
func NewProvider(
ctx context.Context,
conf Config,
externalSecure bool,
command *command.Commands,
query *query.Queries,
repo repository.Repository,
encAlg crypto.EncryptionAlgorithm,
certEncAlg crypto.EncryptionAlgorithm,
es *eventstore.Eventstore,
projections *sql.DB,
instanceHandler,
userAgentCookie func(http.Handler) http.Handler,
) (*provider.Provider, error) {
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
provStorage, err := newStorage(
command,
query,
repo,
encAlg,
certEncAlg,
es,
projections,
)
if err != nil {
return nil, err
}
options := []provider.Option{
provider.WithHttpInterceptors(
middleware.MetricsHandler(metricTypes),
middleware.TelemetryHandler(),
middleware.NoCacheInterceptor().Handler,
instanceHandler,
userAgentCookie,
http_utils.CopyHeadersToContext,
),
}
if !externalSecure {
options = append(options, provider.WithAllowInsecure())
}
return provider.NewProvider(
ctx,
provStorage,
HandlerPrefix,
conf.ProviderConfig,
options...,
)
}
func newStorage(
command *command.Commands,
query *query.Queries,
repo repository.Repository,
encAlg crypto.EncryptionAlgorithm,
certEncAlg crypto.EncryptionAlgorithm,
es *eventstore.Eventstore,
projections *sql.DB,
) (*Storage, error) {
return &Storage{
encAlg: encAlg,
certEncAlg: certEncAlg,
locker: crdb.NewLocker(projections, locksTable, signingKey),
eventstore: es,
repo: repo,
command: command,
query: query,
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
}, nil
}

View File

@@ -0,0 +1,187 @@
package saml
import (
"context"
"time"
"github.com/zitadel/saml/pkg/provider"
"github.com/zitadel/saml/pkg/provider/key"
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/serviceprovider"
"github.com/zitadel/saml/pkg/provider/xml/samlp"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
var _ provider.EntityStorage = &Storage{}
var _ provider.IdentityProviderStorage = &Storage{}
var _ provider.AuthStorage = &Storage{}
var _ provider.UserStorage = &Storage{}
type Storage struct {
certChan <-chan interface{}
defaultCertificateLifetime time.Duration
currentCACertificate query.Certificate
currentMetadataCertificate query.Certificate
currentResponseCertificate query.Certificate
locker crdb.Locker
certificateAlgorithm string
encAlg crypto.EncryptionAlgorithm
certEncAlg crypto.EncryptionAlgorithm
eventstore *eventstore.Eventstore
repo repository.Repository
command *command.Commands
query *query.Queries
defaultLoginURL string
}
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
app, err := p.query.AppBySAMLEntityID(ctx, entityID)
if err != nil {
return nil, err
}
return serviceprovider.NewServiceProvider(
app.ID,
&serviceprovider.Config{
Metadata: app.SAMLConfig.Metadata,
},
p.defaultLoginURL,
)
}
func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) {
app, err := p.query.AppByID(ctx, appID)
if err != nil {
return "", err
}
return app.SAMLConfig.EntityID, nil
}
func (p *Storage) Health(context.Context) error {
return nil
}
func (p *Storage) GetCA(ctx context.Context) (*key.CertificateAndKey, error) {
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
}
func (p *Storage) GetMetadataSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLMetadataSigning)
}
func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLResponseSinging)
}
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
}
authRequest := CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID)
resp, err := p.repo.CreateAuthRequest(ctx, authRequest)
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(resp)
}
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, errors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
}
resp, err := p.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(resp)
}
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
user, err := p.query.GetUserByID(ctx, true, userID)
if err != nil {
return err
}
setUserinfo(user, userinfo, attributes)
return nil
}
func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
loginNameSQ, err := query.NewUserLoginNamesSearchQuery(loginName)
if err != nil {
return err
}
user, err := p.query.GetUser(ctx, true, loginNameSQ)
if err != nil {
return err
}
setUserinfo(user, userinfo, attributes)
return nil
}
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int) {
if len(attributes) == 0 {
userinfo.SetUsername(user.PreferredLoginName)
userinfo.SetUserID(user.ID)
if user.Human == nil {
return
}
userinfo.SetEmail(user.Human.Email)
userinfo.SetSurname(user.Human.LastName)
userinfo.SetGivenName(user.Human.FirstName)
userinfo.SetFullName(user.Human.DisplayName)
return
}
for _, attribute := range attributes {
switch attribute {
case provider.AttributeEmail:
if user.Human != nil {
userinfo.SetEmail(user.Human.Email)
}
case provider.AttributeSurname:
if user.Human != nil {
userinfo.SetSurname(user.Human.LastName)
}
case provider.AttributeFullName:
if user.Human != nil {
userinfo.SetFullName(user.Human.DisplayName)
}
case provider.AttributeGivenName:
if user.Human != nil {
userinfo.SetGivenName(user.Human.FirstName)
}
case provider.AttributeUsername:
userinfo.SetUsername(user.PreferredLoginName)
case provider.AttributeUserID:
userinfo.SetUserID(user.ID)
}
}
}

View File

@@ -374,7 +374,7 @@ func (l *Login) handleAutoRegister(w http.ResponseWriter, r *http.Request, authR
return
}
linkingUser = l.mapExternalNotFoundOptionFormDataToLoginUser(data)
}
}
user, externalIDP, metadata := l.mapExternalUserToLoginUser(orgIamPolicy, linkingUser, idpConfig)

View File

@@ -37,6 +37,7 @@ type Login struct {
externalSecure bool
consolePath string
oidcAuthCallbackURL func(context.Context, string) string
samlAuthCallbackURL func(context.Context, string) string
idpConfigAlg crypto.EncryptionAlgorithm
userCodeAlg crypto.EncryptionAlgorithm
}
@@ -61,10 +62,12 @@ func CreateLogin(config Config,
staticStorage static.Storage,
consolePath string,
oidcAuthCallbackURL func(context.Context, string) string,
samlAuthCallbackURL func(context.Context, string) string,
externalSecure bool,
userAgentCookie,
issuerInterceptor,
instanceHandler,
oidcInstanceHandler,
samlInstanceHandler mux.MiddlewareFunc,
assetCache mux.MiddlewareFunc,
userCodeAlg crypto.EncryptionAlgorithm,
idpConfigAlg crypto.EncryptionAlgorithm,
@@ -73,6 +76,7 @@ func CreateLogin(config Config,
login := &Login{
oidcAuthCallbackURL: oidcAuthCallbackURL,
samlAuthCallbackURL: samlAuthCallbackURL,
externalSecure: externalSecure,
consolePath: consolePath,
command: command,
@@ -91,7 +95,7 @@ func CreateLogin(config Config,
cacheInterceptor := createCacheInterceptor(config.Cache.MaxAge, config.Cache.SharedMaxAge, assetCache)
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), instanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
login.parser = form.NewParser()
return login, nil

View File

@@ -4,6 +4,7 @@ import (
"net/http"
"github.com/zitadel/zitadel/internal/domain"
caos_errs "github.com/zitadel/zitadel/internal/errors"
)
const (
@@ -43,11 +44,26 @@ func (l *Login) renderSuccessAndCallback(w http.ResponseWriter, r *http.Request,
userData: l.getUserData(r, authReq, "Login Successful", errID, errMessage),
}
if authReq != nil {
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "") //the id will be set via the html (maybe change this with the login refactoring)
//the id will be set via the html (maybe change this with the login refactoring)
if _, ok := authReq.Request.(*domain.AuthRequestOIDC); ok {
data.RedirectURI = l.oidcAuthCallbackURL(r.Context(), "")
} else if _, ok := authReq.Request.(*domain.AuthRequestSAML); ok {
data.RedirectURI = l.samlAuthCallbackURL(r.Context(), "")
}
}
l.renderer.RenderTemplate(w, r, l.getTranslator(r.Context(), authReq), l.renderer.Templates[tmplLoginSuccess], data, nil)
}
func (l *Login) redirectToCallback(w http.ResponseWriter, r *http.Request, authReq *domain.AuthRequest) {
http.Redirect(w, r, l.oidcAuthCallbackURL(r.Context(), authReq.ID), http.StatusFound)
var callback string
switch authReq.Request.(type) {
case *domain.AuthRequestOIDC:
callback = l.oidcAuthCallbackURL(r.Context(), authReq.ID)
case *domain.AuthRequestSAML:
callback = l.samlAuthCallbackURL(r.Context(), authReq.ID)
default:
l.renderInternalError(w, r, authReq, caos_errs.ThrowInternal(nil, "LOGIN-rhjQF", "Errors.AuthRequest.RequestTypeNotSupported"))
return
}
http.Redirect(w, r, callback, http.StatusFound)
}