// Mirror of https://github.com/zitadel/zitadel.git, synced 2025-01-12 01:33:40 +00:00.
// Original file: 188 lines, 5.6 KiB, Go.
|
package saml
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"time"
|
||
|
|
||
|
"github.com/zitadel/saml/pkg/provider"
|
||
|
"github.com/zitadel/saml/pkg/provider/key"
|
||
|
"github.com/zitadel/saml/pkg/provider/models"
|
||
|
"github.com/zitadel/saml/pkg/provider/serviceprovider"
|
||
|
"github.com/zitadel/saml/pkg/provider/xml/samlp"
|
||
|
|
||
|
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||
|
"github.com/zitadel/zitadel/internal/auth/repository"
|
||
|
"github.com/zitadel/zitadel/internal/command"
|
||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||
|
"github.com/zitadel/zitadel/internal/domain"
|
||
|
"github.com/zitadel/zitadel/internal/errors"
|
||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||
|
"github.com/zitadel/zitadel/internal/eventstore/handler/crdb"
|
||
|
"github.com/zitadel/zitadel/internal/query"
|
||
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||
|
)
|
||
|
|
||
|
var _ provider.EntityStorage = &Storage{}
|
||
|
var _ provider.IdentityProviderStorage = &Storage{}
|
||
|
var _ provider.AuthStorage = &Storage{}
|
||
|
var _ provider.UserStorage = &Storage{}
|
||
|
|
||
|
// Storage backs the SAML identity provider. Its methods resolve service
// providers and entity IDs, return SAML certificates/keys, create and load
// auth requests, and fill user attributes into SAML responses.
type Storage struct {
	// certChan is a receive-only channel; presumably it signals certificate
	// changes — TODO confirm against the code that reads it.
	certChan <-chan interface{}
	// defaultCertificateLifetime: presumably the validity period used when new
	// certificates are generated — verify at the point of use.
	defaultCertificateLifetime time.Duration

	// Current certificates per SAML key usage (CA, metadata signing, response
	// signing). NOTE(review): looks like a cache refreshed elsewhere — confirm.
	currentCACertificate       query.Certificate
	currentMetadataCertificate query.Certificate
	currentResponseCertificate query.Certificate

	// locker: NOTE(review): likely guards certificate/key creation so only one
	// instance does it at a time — confirm where it is used.
	locker               crdb.Locker
	certificateAlgorithm string
	encAlg               crypto.EncryptionAlgorithm
	certEncAlg           crypto.EncryptionAlgorithm

	eventstore *eventstore.Eventstore
	// repo creates auth requests (CreateAuthRequest) and loads them with a
	// logged-in check (AuthRequestByIDCheckLoggedIn).
	repo    repository.Repository
	command *command.Commands
	// query looks up applications (by SAML entity ID or app ID) and users.
	query *query.Queries

	// defaultLoginURL is passed to every service provider built by GetEntityByID.
	defaultLoginURL string
}
|
||
|
|
||
|
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
|
||
|
app, err := p.query.AppBySAMLEntityID(ctx, entityID)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return serviceprovider.NewServiceProvider(
|
||
|
app.ID,
|
||
|
&serviceprovider.Config{
|
||
|
Metadata: app.SAMLConfig.Metadata,
|
||
|
},
|
||
|
p.defaultLoginURL,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) GetEntityIDByAppID(ctx context.Context, appID string) (string, error) {
|
||
|
app, err := p.query.AppByID(ctx, appID)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return app.SAMLConfig.EntityID, nil
|
||
|
}
|
||
|
|
||
|
func (p *Storage) Health(context.Context) error {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *Storage) GetCA(ctx context.Context) (*key.CertificateAndKey, error) {
|
||
|
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLCA)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) GetMetadataSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||
|
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLMetadataSigning)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAndKey, error) {
|
||
|
return p.GetCertificateAndKey(ctx, domain.KeyUsageSAMLResponseSinging)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||
|
ctx, span := tracing.NewSpan(ctx)
|
||
|
defer func() { span.EndWithError(err) }()
|
||
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||
|
if !ok {
|
||
|
return nil, errors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||
|
}
|
||
|
|
||
|
authRequest := CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID)
|
||
|
|
||
|
resp, err := p.repo.CreateAuthRequest(ctx, authRequest)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return AuthRequestFromBusiness(resp)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||
|
ctx, span := tracing.NewSpan(ctx)
|
||
|
defer func() { span.EndWithError(err) }()
|
||
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||
|
if !ok {
|
||
|
return nil, errors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||
|
}
|
||
|
resp, err := p.repo.AuthRequestByIDCheckLoggedIn(ctx, id, userAgentID)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return AuthRequestFromBusiness(resp)
|
||
|
}
|
||
|
|
||
|
func (p *Storage) SetUserinfoWithUserID(ctx context.Context, userinfo models.AttributeSetter, userID string, attributes []int) (err error) {
|
||
|
ctx, span := tracing.NewSpan(ctx)
|
||
|
defer func() { span.EndWithError(err) }()
|
||
|
user, err := p.query.GetUserByID(ctx, true, userID)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
setUserinfo(user, userinfo, attributes)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.AttributeSetter, loginName string, attributes []int) (err error) {
|
||
|
ctx, span := tracing.NewSpan(ctx)
|
||
|
defer func() { span.EndWithError(err) }()
|
||
|
|
||
|
loginNameSQ, err := query.NewUserLoginNamesSearchQuery(loginName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
user, err := p.query.GetUser(ctx, true, loginNameSQ)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
setUserinfo(user, userinfo, attributes)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func setUserinfo(user *query.User, userinfo models.AttributeSetter, attributes []int) {
|
||
|
if len(attributes) == 0 {
|
||
|
userinfo.SetUsername(user.PreferredLoginName)
|
||
|
userinfo.SetUserID(user.ID)
|
||
|
if user.Human == nil {
|
||
|
return
|
||
|
}
|
||
|
userinfo.SetEmail(user.Human.Email)
|
||
|
userinfo.SetSurname(user.Human.LastName)
|
||
|
userinfo.SetGivenName(user.Human.FirstName)
|
||
|
userinfo.SetFullName(user.Human.DisplayName)
|
||
|
return
|
||
|
}
|
||
|
for _, attribute := range attributes {
|
||
|
switch attribute {
|
||
|
case provider.AttributeEmail:
|
||
|
if user.Human != nil {
|
||
|
userinfo.SetEmail(user.Human.Email)
|
||
|
}
|
||
|
case provider.AttributeSurname:
|
||
|
if user.Human != nil {
|
||
|
userinfo.SetSurname(user.Human.LastName)
|
||
|
}
|
||
|
case provider.AttributeFullName:
|
||
|
if user.Human != nil {
|
||
|
userinfo.SetFullName(user.Human.DisplayName)
|
||
|
}
|
||
|
case provider.AttributeGivenName:
|
||
|
if user.Human != nil {
|
||
|
userinfo.SetGivenName(user.Human.FirstName)
|
||
|
}
|
||
|
case provider.AttributeUsername:
|
||
|
userinfo.SetUsername(user.PreferredLoginName)
|
||
|
case provider.AttributeUserID:
|
||
|
userinfo.SetUserID(user.ID)
|
||
|
}
|
||
|
}
|
||
|
}
|