feat: add saml request to link to sessions (#9001)

# Which Problems Are Solved

It is currently not possible to use SAML with the Session API.

# How the Problems Are Solved

Add SAML service, to get and resolve SAML requests.
Add SAML session and SAML request aggregate, which can be linked to the
Session to get back a SAMLResponse from the API directly.

# Additional Changes

Update of dependency zitadel/saml to provide all functionality for
handling of SAML requests and responses.

# Additional Context

Closes #6053

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Stefan Benz
2024-12-19 12:11:40 +01:00
committed by GitHub
parent 50d2b26a28
commit c3b97a91a2
57 changed files with 3947 additions and 22 deletions

View File

@@ -0,0 +1,99 @@
package saml
import (
"context"
"encoding/base64"
"net/url"
"github.com/zitadel/saml/pkg/provider"
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/saml/pkg/provider/xml"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
)
func (p *Provider) CreateErrorResponse(authReq models.AuthRequestInt, reason domain.SAMLErrorReason, description string) (string, string, error) {
resp := &provider.Response{
ProtocolBinding: authReq.GetBindingType(),
RelayState: authReq.GetRelayState(),
AcsUrl: authReq.GetAccessConsumerServiceURL(),
RequestID: authReq.GetAuthRequestID(),
Issuer: authReq.GetDestination(),
Audience: authReq.GetIssuer(),
}
return createResponse(p.AuthCallbackErrorResponse(resp, domain.SAMLErrorReasonToString(reason), description), authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
}
func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthRequestInt) (string, string, error) {
resp := &provider.Response{
ProtocolBinding: authReq.GetBindingType(),
RelayState: authReq.GetRelayState(),
AcsUrl: authReq.GetAccessConsumerServiceURL(),
RequestID: authReq.GetAuthRequestID(),
Issuer: authReq.GetDestination(),
Audience: authReq.GetIssuer(),
}
samlResponse, err := p.AuthCallbackResponse(ctx, authReq, resp)
if err != nil {
return "", "", err
}
if err := p.command.CreateSAMLSessionFromSAMLRequest(
setContextUserSystem(ctx),
authReq.GetID(),
samlComplianceChecker(),
samlResponse.Id,
p.Expiration(),
); err != nil {
return "", "", err
}
return createResponse(samlResponse, authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
}
func createResponse(samlResponse interface{}, binding, acs, relayState, sigAlg, sig string) (string, string, error) {
respData, err := xml.Marshal(samlResponse)
if err != nil {
return "", "", err
}
switch binding {
case provider.PostBinding:
return acs, base64.StdEncoding.EncodeToString(respData), nil
case provider.RedirectBinding:
respData, err := xml.DeflateAndBase64(respData)
if err != nil {
return "", "", err
}
parsed, err := url.Parse(acs)
if err != nil {
return "", "", err
}
values := parsed.Query()
values.Add("SAMLResponse", string(respData))
values.Add("RelayState", relayState)
values.Add("SigAlg", sigAlg)
values.Add("Signature", sig)
parsed.RawQuery = values.Encode()
return parsed.String(), "", nil
}
return "", "", nil
}
func setContextUserSystem(ctx context.Context) context.Context {
data := authz.CtxData{
UserID: "SYSTEM",
}
return authz.SetCtxData(ctx, data)
}
func samlComplianceChecker() command.SAMLRequestComplianceChecker {
return func(_ context.Context, samlReq *command.SAMLRequestWriteModel) error {
if err := samlReq.CheckAuthenticated(); err != nil {
return err
}
return nil
}
}

View File

@@ -0,0 +1,45 @@
package saml
import (
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/zitadel/internal/command"
)
var _ models.AuthRequestInt = &AuthRequestV2{}
type AuthRequestV2 struct {
*command.CurrentSAMLRequest
}
func (a *AuthRequestV2) GetApplicationID() string {
return a.ApplicationID
}
func (a *AuthRequestV2) GetID() string {
return a.ID
}
func (a *AuthRequestV2) GetRelayState() string {
return a.RelayState
}
func (a *AuthRequestV2) GetAccessConsumerServiceURL() string {
return a.ACSURL
}
func (a *AuthRequestV2) GetAuthRequestID() string {
return a.RequestID
}
func (a *AuthRequestV2) GetBindingType() string {
return a.Binding
}
func (a *AuthRequestV2) GetIssuer() string {
return a.Issuer
}
func (a *AuthRequestV2) GetDestination() string {
return a.Destination
}
func (a *AuthRequestV2) GetUserID() string {
return a.UserID
}
func (a *AuthRequestV2) Done() bool {
return a.UserID != "" && a.SessionID != ""
}

View File

@@ -24,7 +24,13 @@ const (
)
type Config struct {
ProviderConfig *provider.Config
ProviderConfig *provider.Config
DefaultLoginURLV2 string
}
type Provider struct {
*provider.Provider
command *command.Commands
}
func NewProvider(
@@ -40,7 +46,7 @@ func NewProvider(
instanceHandler,
userAgentCookie func(http.Handler) http.Handler,
accessHandler *middleware.AccessInterceptor,
) (*provider.Provider, error) {
) (*Provider, error) {
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
provStorage, err := newStorage(
@@ -51,6 +57,8 @@ func NewProvider(
certEncAlg,
es,
projections,
fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
conf.DefaultLoginURLV2,
)
if err != nil {
return nil, err
@@ -73,12 +81,19 @@ func NewProvider(
options = append(options, provider.WithAllowInsecure())
}
return provider.NewProvider(
p, err := provider.NewProvider(
provStorage,
HandlerPrefix,
conf.ProviderConfig,
options...,
)
if err != nil {
return nil, err
}
return &Provider{
p,
command,
}, nil
}
func newStorage(
@@ -89,16 +104,19 @@ func newStorage(
certEncAlg crypto.EncryptionAlgorithm,
es *eventstore.Eventstore,
db *database.DB,
defaultLoginURL string,
defaultLoginURLV2 string,
) (*Storage, error) {
return &Storage{
encAlg: encAlg,
certEncAlg: certEncAlg,
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
eventstore: es,
repo: repo,
command: command,
query: query,
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
encAlg: encAlg,
certEncAlg: certEncAlg,
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
eventstore: es,
repo: repo,
command: command,
query: query,
defaultLoginURL: defaultLoginURL,
defaultLoginURLv2: defaultLoginURLV2,
}, nil
}

View File

@@ -3,6 +3,7 @@ package saml
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/dop251/goja"
@@ -16,6 +17,7 @@ import (
"github.com/zitadel/zitadel/internal/actions"
"github.com/zitadel/zitadel/internal/actions/object"
"github.com/zitadel/zitadel/internal/activity"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
@@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{}
var _ provider.AuthStorage = &Storage{}
var _ provider.UserStorage = &Storage{}
const (
LoginClientHeader = "x-zitadel-login-client"
)
type Storage struct {
certChan <-chan interface{}
defaultCertificateLifetime time.Duration
@@ -51,7 +57,8 @@ type Storage struct {
command *command.Commands
query *query.Queries
defaultLoginURL string
defaultLoginURL string
defaultLoginURLv2 string
}
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
@@ -64,7 +71,12 @@ func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*servicep
&serviceprovider.Config{
Metadata: app.SAMLConfig.Metadata,
},
p.defaultLoginURL,
func(id string) string {
if strings.HasPrefix(id, command.IDPrefixV2) {
return p.defaultLoginURLv2 + id
}
return p.defaultLoginURL + id
},
)
}
@@ -95,6 +107,38 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient)
}
return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
}
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID, loginClient string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
samlRequest := &command.SAMLRequest{
ApplicationID: applicationID,
ACSURL: acsUrl,
RelayState: relayState,
RequestID: req.Id,
Binding: protocolBinding,
Issuer: req.Issuer.Text,
Destination: req.Destination,
LoginClient: loginClient,
}
aar, err := p.command.AddSAMLRequest(ctx, samlRequest)
if err != nil {
return nil, err
}
return &AuthRequestV2{aar}, nil
}
func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
@@ -113,6 +157,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := p.command.GetCurrentSAMLRequest(ctx, id)
if err != nil {
return nil, err
}
return &AuthRequestV2{req}, nil
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")