mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-11 20:12:18 +00:00
feat: add saml request to link to sessions
This commit is contained in:
60
internal/api/saml/auth_request_converter_v2.go
Normal file
60
internal/api/saml/auth_request_converter_v2.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
)
|
||||
|
||||
var _ models.AuthRequestInt = &AuthRequestV2{}
|
||||
|
||||
type AuthRequestV2 struct {
|
||||
*command.CurrentSAMLRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetApplicationID() string {
|
||||
return a.ApplicationID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
func (a *AuthRequestV2) GetRelayState() string {
|
||||
return a.RelayState
|
||||
}
|
||||
func (a *AuthRequestV2) GetAccessConsumerServiceURL() string {
|
||||
return a.ACSURL
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetNameID() string {
|
||||
return a.UserID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetAuthRequestID() string {
|
||||
return a.RequestID
|
||||
}
|
||||
func (a *AuthRequestV2) GetBindingType() string {
|
||||
return a.Binding
|
||||
}
|
||||
func (a *AuthRequestV2) GetIssuer() string {
|
||||
return a.Issuer
|
||||
}
|
||||
func (a *AuthRequestV2) GetIssuerName() string {
|
||||
return a.IssuerName
|
||||
}
|
||||
func (a *AuthRequestV2) GetDestination() string {
|
||||
return a.Destination
|
||||
}
|
||||
func (a *AuthRequestV2) GetCode() string {
|
||||
return ""
|
||||
}
|
||||
func (a *AuthRequestV2) GetUserID() string {
|
||||
return a.UserID
|
||||
}
|
||||
func (a *AuthRequestV2) GetUserName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) Done() bool {
|
||||
return a.UserID != "" && a.SessionID != ""
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package saml
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/actions"
|
||||
"github.com/zitadel/zitadel/internal/actions/object"
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
@@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{}
|
||||
var _ provider.AuthStorage = &Storage{}
|
||||
var _ provider.UserStorage = &Storage{}
|
||||
|
||||
const (
|
||||
LoginClientHeader = "x-zitadel-login-client"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
certChan <-chan interface{}
|
||||
defaultCertificateLifetime time.Duration
|
||||
@@ -95,6 +101,47 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn
|
||||
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||
}
|
||||
|
||||
resp, err := p.repo.CreateAuthRequest(ctx, CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (models.AuthRequestInt, error) {
|
||||
samlRequest := &command.SAMLRequest{
|
||||
ApplicationID: applicationID,
|
||||
ACSURL: acsUrl,
|
||||
RelayState: relayState,
|
||||
RequestID: req.Id,
|
||||
Binding: protocolBinding,
|
||||
Issuer: req.Issuer.Text,
|
||||
IssuerName: req.Issuer.SPProvidedID,
|
||||
Destination: req.Destination,
|
||||
}
|
||||
|
||||
aar, err := p.command.AddSAMLRequest(ctx, samlRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{aar}, nil
|
||||
}
|
||||
|
||||
func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||
@@ -113,6 +160,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
|
||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
req, err := p.command.GetCurrentSAMLRequest(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{req}, nil
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||||
|
||||
Reference in New Issue
Block a user