mirror of
https://github.com/zitadel/zitadel.git
synced 2025-04-22 12:41:32 +00:00
feat: add saml request to link to sessions
This commit is contained in:
parent
26e936aec3
commit
905da945ff
203
internal/api/grpc/saml/oidc.go
Normal file
203
internal/api/grpc/saml/oidc.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||||
|
"github.com/zitadel/zitadel/internal/api/http"
|
||||||
|
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/query"
|
||||||
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
|
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
|
||||||
|
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
|
||||||
|
if err != nil {
|
||||||
|
logging.WithError(err).Error("query authRequest by ID")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &oidc_pb.GetAuthRequestResponse{
|
||||||
|
AuthRequest: authRequestToPb(authRequest),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
|
||||||
|
pba := &oidc_pb.AuthRequest{
|
||||||
|
Id: a.ID,
|
||||||
|
CreationDate: timestamppb.New(a.CreationDate),
|
||||||
|
ClientId: a.ClientID,
|
||||||
|
Scope: a.Scope,
|
||||||
|
RedirectUri: a.RedirectURI,
|
||||||
|
Prompt: promptsToPb(a.Prompt),
|
||||||
|
UiLocales: a.UiLocales,
|
||||||
|
LoginHint: a.LoginHint,
|
||||||
|
HintUserId: a.HintUserID,
|
||||||
|
}
|
||||||
|
if a.MaxAge != nil {
|
||||||
|
pba.MaxAge = durationpb.New(*a.MaxAge)
|
||||||
|
}
|
||||||
|
return pba
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
|
||||||
|
out := make([]oidc_pb.Prompt, len(promps))
|
||||||
|
for i, p := range promps {
|
||||||
|
out[i] = promptToPb(p)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
|
||||||
|
switch p {
|
||||||
|
case domain.PromptUnspecified:
|
||||||
|
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||||
|
case domain.PromptNone:
|
||||||
|
return oidc_pb.Prompt_PROMPT_NONE
|
||||||
|
case domain.PromptLogin:
|
||||||
|
return oidc_pb.Prompt_PROMPT_LOGIN
|
||||||
|
case domain.PromptConsent:
|
||||||
|
return oidc_pb.Prompt_PROMPT_CONSENT
|
||||||
|
case domain.PromptSelectAccount:
|
||||||
|
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
|
||||||
|
case domain.PromptCreate:
|
||||||
|
return oidc_pb.Prompt_PROMPT_CREATE
|
||||||
|
default:
|
||||||
|
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
|
||||||
|
switch v := req.GetCallbackKind().(type) {
|
||||||
|
case *oidc_pb.CreateCallbackRequest_Error:
|
||||||
|
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
|
||||||
|
case *oidc_pb.CreateCallbackRequest_Session:
|
||||||
|
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
|
||||||
|
default:
|
||||||
|
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
|
||||||
|
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||||
|
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &oidc_pb.CreateCallbackResponse{
|
||||||
|
Details: object.DomainToDetailsPb(details),
|
||||||
|
CallbackUrl: callback,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
|
||||||
|
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||||
|
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
|
||||||
|
var callback string
|
||||||
|
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||||
|
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
|
||||||
|
} else {
|
||||||
|
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &oidc_pb.CreateCallbackResponse{
|
||||||
|
Details: object.DomainToDetailsPb(details),
|
||||||
|
CallbackUrl: callback,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
|
||||||
|
switch errorReason {
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||||
|
return domain.OIDCErrorReasonUnspecified
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||||
|
return domain.OIDCErrorReasonInvalidRequest
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||||
|
return domain.OIDCErrorReasonUnauthorizedClient
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||||
|
return domain.OIDCErrorReasonAccessDenied
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||||
|
return domain.OIDCErrorReasonUnsupportedResponseType
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||||
|
return domain.OIDCErrorReasonInvalidScope
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||||
|
return domain.OIDCErrorReasonServerError
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||||
|
return domain.OIDCErrorReasonTemporaryUnavailable
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||||
|
return domain.OIDCErrorReasonInteractionRequired
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||||
|
return domain.OIDCErrorReasonLoginRequired
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||||
|
return domain.OIDCErrorReasonAccountSelectionRequired
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||||
|
return domain.OIDCErrorReasonConsentRequired
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||||
|
return domain.OIDCErrorReasonInvalidRequestURI
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||||
|
return domain.OIDCErrorReasonInvalidRequestObject
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||||
|
return domain.OIDCErrorReasonRequestNotSupported
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||||
|
return domain.OIDCErrorReasonRequestURINotSupported
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||||
|
return domain.OIDCErrorReasonRegistrationNotSupported
|
||||||
|
default:
|
||||||
|
return domain.OIDCErrorReasonUnspecified
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
|
||||||
|
switch reason {
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||||
|
return "invalid_request"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||||
|
return "unauthorized_client"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||||
|
return "access_denied"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||||
|
return "unsupported_response_type"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||||
|
return "invalid_scope"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||||
|
return "temporarily_unavailable"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||||
|
return "interaction_required"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||||
|
return "login_required"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||||
|
return "account_selection_required"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||||
|
return "consent_required"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||||
|
return "invalid_request_uri"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||||
|
return "invalid_request_object"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||||
|
return "request_not_supported"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||||
|
return "request_uri_not_supported"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||||
|
return "registration_not_supported"
|
||||||
|
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return "server_error"
|
||||||
|
}
|
||||||
|
}
|
59
internal/api/grpc/saml/server.go
Normal file
59
internal/api/grpc/saml/server.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/api/grpc/server"
|
||||||
|
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||||
|
"github.com/zitadel/zitadel/internal/command"
|
||||||
|
"github.com/zitadel/zitadel/internal/query"
|
||||||
|
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ saml_pb.SAMLServiceServer = (*Server)(nil)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
saml_pb.UnimplementedSAMLServiceServer
|
||||||
|
command *command.Commands
|
||||||
|
query *query.Queries
|
||||||
|
|
||||||
|
op *oidc.Server
|
||||||
|
externalSecure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct{}
|
||||||
|
|
||||||
|
func CreateServer(
|
||||||
|
command *command.Commands,
|
||||||
|
query *query.Queries,
|
||||||
|
op *oidc.Server,
|
||||||
|
externalSecure bool,
|
||||||
|
) *Server {
|
||||||
|
return &Server{
|
||||||
|
command: command,
|
||||||
|
query: query,
|
||||||
|
op: op,
|
||||||
|
externalSecure: externalSecure,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
|
||||||
|
saml_pb.RegisterSAMLServiceServer(grpcServer, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) AppName() string {
|
||||||
|
return saml_pb.SAMLService_ServiceDesc.ServiceName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) MethodPrefix() string {
|
||||||
|
return saml_pb.SAMLService_ServiceDesc.ServiceName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) AuthMethods() authz.MethodMapping {
|
||||||
|
return saml_pb.SAMLService_AuthMethods
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
|
||||||
|
return saml_pb.RegisterSAMLServiceHandler
|
||||||
|
}
|
60
internal/api/saml/auth_request_converter_v2.go
Normal file
60
internal/api/saml/auth_request_converter_v2.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package saml
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zitadel/saml/pkg/provider/models"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/command"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ models.AuthRequestInt = &AuthRequestV2{}
|
||||||
|
|
||||||
|
type AuthRequestV2 struct {
|
||||||
|
*command.CurrentSAMLRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestV2) GetApplicationID() string {
|
||||||
|
return a.ApplicationID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestV2) GetID() string {
|
||||||
|
return a.ID
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetRelayState() string {
|
||||||
|
return a.RelayState
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetAccessConsumerServiceURL() string {
|
||||||
|
return a.ACSURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestV2) GetNameID() string {
|
||||||
|
return a.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestV2) GetAuthRequestID() string {
|
||||||
|
return a.RequestID
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetBindingType() string {
|
||||||
|
return a.Binding
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetIssuer() string {
|
||||||
|
return a.Issuer
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetIssuerName() string {
|
||||||
|
return a.IssuerName
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetDestination() string {
|
||||||
|
return a.Destination
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetCode() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetUserID() string {
|
||||||
|
return a.UserID
|
||||||
|
}
|
||||||
|
func (a *AuthRequestV2) GetUserName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthRequestV2) Done() bool {
|
||||||
|
return a.UserID != "" && a.SessionID != ""
|
||||||
|
}
|
@ -3,6 +3,7 @@ package saml
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dop251/goja"
|
"github.com/dop251/goja"
|
||||||
@ -16,6 +17,7 @@ import (
|
|||||||
"github.com/zitadel/zitadel/internal/actions"
|
"github.com/zitadel/zitadel/internal/actions"
|
||||||
"github.com/zitadel/zitadel/internal/actions/object"
|
"github.com/zitadel/zitadel/internal/actions/object"
|
||||||
"github.com/zitadel/zitadel/internal/activity"
|
"github.com/zitadel/zitadel/internal/activity"
|
||||||
|
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||||
"github.com/zitadel/zitadel/internal/command"
|
"github.com/zitadel/zitadel/internal/command"
|
||||||
@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{}
|
|||||||
var _ provider.AuthStorage = &Storage{}
|
var _ provider.AuthStorage = &Storage{}
|
||||||
var _ provider.UserStorage = &Storage{}
|
var _ provider.UserStorage = &Storage{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
LoginClientHeader = "x-zitadel-login-client"
|
||||||
|
)
|
||||||
|
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
certChan <-chan interface{}
|
certChan <-chan interface{}
|
||||||
defaultCertificateLifetime time.Duration
|
defaultCertificateLifetime time.Duration
|
||||||
@ -95,6 +101,47 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn
|
|||||||
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
|
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||||
|
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||||
|
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||||
|
if !ok {
|
||||||
|
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.repo.CreateAuthRequest(ctx, CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return AuthRequestFromBusiness(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (models.AuthRequestInt, error) {
|
||||||
|
samlRequest := &command.SAMLRequest{
|
||||||
|
ApplicationID: applicationID,
|
||||||
|
ACSURL: acsUrl,
|
||||||
|
RelayState: relayState,
|
||||||
|
RequestID: req.Id,
|
||||||
|
Binding: protocolBinding,
|
||||||
|
Issuer: req.Issuer.Text,
|
||||||
|
IssuerName: req.Issuer.SPProvidedID,
|
||||||
|
Destination: req.Destination,
|
||||||
|
}
|
||||||
|
|
||||||
|
aar, err := p.command.AddSAMLRequest(ctx, samlRequest)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AuthRequestV2{aar}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||||
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
|
defer func() { span.EndWithError(err) }()
|
||||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||||
@ -113,6 +160,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
|
|||||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
|
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||||
|
req, err := p.command.GetCurrentSAMLRequest(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AuthRequestV2{req}, nil
|
||||||
|
}
|
||||||
|
|
||||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||||||
|
162
internal/command/saml_request.go
Normal file
162
internal/command/saml_request.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SAMLRequest struct {
|
||||||
|
ID string
|
||||||
|
LoginClient string
|
||||||
|
|
||||||
|
ApplicationID string
|
||||||
|
EntityID string
|
||||||
|
ACSURL string
|
||||||
|
RelayState string
|
||||||
|
RequestID string
|
||||||
|
Binding string
|
||||||
|
Issuer string
|
||||||
|
IssuerName string
|
||||||
|
Destination string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CurrentSAMLRequest struct {
|
||||||
|
*SAMLRequest
|
||||||
|
SessionID string
|
||||||
|
UserID string
|
||||||
|
AuthMethods []domain.UserAuthMethodType
|
||||||
|
AuthTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Commands) AddSAMLRequest(ctx context.Context, samlRequest *SAMLRequest) (_ *CurrentSAMLRequest, err error) {
|
||||||
|
id, err := c.idGenerator.Next()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
samlRequest.ID = IDPrefixV2 + id
|
||||||
|
writeModel, err := c.getSAMLRequestWriteModel(ctx, samlRequest.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if writeModel.SAMLRequestState != domain.SAMLRequestStateUnspecified {
|
||||||
|
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.SAMLRequest.AlreadyExisting")
|
||||||
|
}
|
||||||
|
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewAddedEvent(
|
||||||
|
ctx,
|
||||||
|
&authrequest.NewAggregate(samlRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||||
|
samlRequest.LoginClient,
|
||||||
|
samlRequest.ApplicationID,
|
||||||
|
samlRequest.ACSURL,
|
||||||
|
samlRequest.RelayState,
|
||||||
|
samlRequest.RequestID,
|
||||||
|
samlRequest.Binding,
|
||||||
|
samlRequest.Issuer,
|
||||||
|
samlRequest.IssuerName,
|
||||||
|
samlRequest.Destination,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID, sessionToken string) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
|
||||||
|
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if writeModel.SAMLRequestState == domain.SAMLRequestStateUnspecified {
|
||||||
|
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.SAMLRequest.NotExisting")
|
||||||
|
}
|
||||||
|
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
|
||||||
|
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.SAMLRequest.AlreadyHandled")
|
||||||
|
}
|
||||||
|
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID())
|
||||||
|
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if err = sessionWriteModel.CheckIsActive(); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if err := c.sessionTokenVerifier(ctx, sessionToken, sessionWriteModel.AggregateID, sessionWriteModel.TokenID); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewSessionLinkedEvent(
|
||||||
|
ctx, &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||||
|
sessionID,
|
||||||
|
sessionWriteModel.UserID,
|
||||||
|
sessionWriteModel.AuthenticationTime(),
|
||||||
|
sessionWriteModel.AuthMethodTypes(),
|
||||||
|
)); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Commands) FailSAMLRequest(ctx context.Context, id string, reason domain.SAMLErrorReason) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
|
||||||
|
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
|
||||||
|
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.SAMLRequest.AlreadyHandled")
|
||||||
|
}
|
||||||
|
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewFailedEvent(
|
||||||
|
ctx,
|
||||||
|
&samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||||
|
reason,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func samlRequestWriteModelToCurrentSAMLRequest(writeModel *SAMLRequestWriteModel) (_ *CurrentSAMLRequest) {
|
||||||
|
return &CurrentSAMLRequest{
|
||||||
|
SAMLRequest: &SAMLRequest{
|
||||||
|
ID: writeModel.AggregateID,
|
||||||
|
ApplicationID: writeModel.ApplicationID,
|
||||||
|
ACSURL: writeModel.ACSURL,
|
||||||
|
RelayState: writeModel.RelayState,
|
||||||
|
RequestID: writeModel.RequestID,
|
||||||
|
Binding: writeModel.Binding,
|
||||||
|
Issuer: writeModel.Issuer,
|
||||||
|
IssuerName: writeModel.IssuerName,
|
||||||
|
Destination: writeModel.Destination,
|
||||||
|
},
|
||||||
|
SessionID: writeModel.SessionID,
|
||||||
|
UserID: writeModel.UserID,
|
||||||
|
AuthMethods: writeModel.AuthMethods,
|
||||||
|
AuthTime: writeModel.AuthTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Commands) GetCurrentSAMLRequest(ctx context.Context, id string) (_ *CurrentSAMLRequest, err error) {
|
||||||
|
wm, err := c.getSAMLRequestWriteModel(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return samlRequestWriteModelToCurrentSAMLRequest(wm), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Commands) getSAMLRequestWriteModel(ctx context.Context, id string) (writeModel *SAMLRequestWriteModel, err error) {
|
||||||
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
|
writeModel = NewSAMLRequestWriteModel(ctx, id)
|
||||||
|
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return writeModel, nil
|
||||||
|
}
|
94
internal/command/saml_request_model.go
Normal file
94
internal/command/saml_request_model.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SAMLRequestWriteModel struct {
|
||||||
|
eventstore.WriteModel
|
||||||
|
aggregate *eventstore.Aggregate
|
||||||
|
|
||||||
|
ApplicationID string
|
||||||
|
ACSURL string
|
||||||
|
RelayState string
|
||||||
|
RequestID string
|
||||||
|
Binding string
|
||||||
|
Issuer string
|
||||||
|
IssuerName string
|
||||||
|
Destination string
|
||||||
|
|
||||||
|
SessionID string
|
||||||
|
UserID string
|
||||||
|
AuthTime time.Time
|
||||||
|
AuthMethods []domain.UserAuthMethodType
|
||||||
|
SAMLRequestState domain.SAMLRequestState
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSAMLRequestWriteModel(ctx context.Context, id string) *SAMLRequestWriteModel {
|
||||||
|
return &SAMLRequestWriteModel{
|
||||||
|
WriteModel: eventstore.WriteModel{
|
||||||
|
AggregateID: id,
|
||||||
|
},
|
||||||
|
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SAMLRequestWriteModel) Reduce() error {
|
||||||
|
for _, event := range m.Events {
|
||||||
|
switch e := event.(type) {
|
||||||
|
case *samlrequest.AddedEvent:
|
||||||
|
m.ApplicationID = e.ApplicationID
|
||||||
|
m.ACSURL = e.ACSURL
|
||||||
|
m.RelayState = e.RelayState
|
||||||
|
m.RequestID = e.RequestID
|
||||||
|
m.Binding = e.Binding
|
||||||
|
m.Issuer = e.Issuer
|
||||||
|
m.IssuerName = e.IssuerName
|
||||||
|
m.Destination = e.Destination
|
||||||
|
m.SAMLRequestState = domain.SAMLRequestStateAdded
|
||||||
|
case *samlrequest.SessionLinkedEvent:
|
||||||
|
m.SessionID = e.SessionID
|
||||||
|
m.UserID = e.UserID
|
||||||
|
m.AuthTime = e.AuthTime
|
||||||
|
m.AuthMethods = e.AuthMethods
|
||||||
|
case *samlrequest.FailedEvent:
|
||||||
|
m.SAMLRequestState = domain.SAMLRequestStateFailed
|
||||||
|
case *samlrequest.SucceededEvent:
|
||||||
|
m.SAMLRequestState = domain.SAMLRequestStateSucceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m.WriteModel.Reduce()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SAMLRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
|
||||||
|
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||||
|
AddQuery().
|
||||||
|
AggregateTypes(samlrequest.AggregateType).
|
||||||
|
AggregateIDs(m.AggregateID).
|
||||||
|
Builder()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAuthenticated checks that the auth request exists, a session must have been linked
|
||||||
|
func (m *SAMLRequestWriteModel) CheckAuthenticated() error {
|
||||||
|
if m.SessionID == "" {
|
||||||
|
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.SAMLRequest.NotAuthenticated")
|
||||||
|
}
|
||||||
|
// in case of OIDC Code Flow, the code must have been exchanged
|
||||||
|
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
|
||||||
|
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
|
||||||
|
m.AuthRequestState == domain.AuthRequestStateAdded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.SAMLRequest.NotAuthenticated")
|
||||||
|
}
|
668
internal/command/saml_request_test.go
Normal file
668
internal/command/saml_request_test.go
Normal file
@ -0,0 +1,668 @@
|
|||||||
|
package command
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
"github.com/zitadel/zitadel/internal/id"
|
||||||
|
"github.com/zitadel/zitadel/internal/id/mock"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/authrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||||
|
"github.com/zitadel/zitadel/internal/repository/session"
|
||||||
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCommands_AddSAMLRequest(t *testing.T) {
|
||||||
|
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||||
|
type fields struct {
|
||||||
|
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||||
|
idGenerator id.Generator
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
request *SAMLRequest
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want *CurrentSAMLRequest
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"already exists error",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
request: &SAMLRequest{},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"added",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(),
|
||||||
|
expectPush(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
request: &SAMLRequest{
|
||||||
|
LoginClient: "login",
|
||||||
|
ApplicationID: "application",
|
||||||
|
ACSURL: "acs",
|
||||||
|
RelayState: "relaystate",
|
||||||
|
RequestID: "request",
|
||||||
|
Binding: "binding",
|
||||||
|
Issuer: "issuer",
|
||||||
|
IssuerName: "name",
|
||||||
|
Destination: "destination",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
&CurrentSAMLRequest{
|
||||||
|
SAMLRequest: &SAMLRequest{
|
||||||
|
ID: "V2_id",
|
||||||
|
LoginClient: "login",
|
||||||
|
ApplicationID: "application",
|
||||||
|
ACSURL: "acs",
|
||||||
|
RelayState: "relaystate",
|
||||||
|
RequestID: "request",
|
||||||
|
Binding: "binding",
|
||||||
|
Issuer: "issuer",
|
||||||
|
IssuerName: "name",
|
||||||
|
Destination: "destination",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Commands{
|
||||||
|
eventstore: tt.fields.eventstore(t),
|
||||||
|
idGenerator: tt.fields.idGenerator,
|
||||||
|
}
|
||||||
|
got, err := c.AddSAMLRequest(tt.args.ctx, tt.args.request)
|
||||||
|
require.ErrorIs(t, tt.wantErr, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
|
||||||
|
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||||
|
type fields struct {
|
||||||
|
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||||
|
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
id string
|
||||||
|
sessionID string
|
||||||
|
sessionToken string
|
||||||
|
checkLoginClient bool
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
details *domain.ObjectDetails
|
||||||
|
authReq *CurrentSAMLRequest
|
||||||
|
wantErr error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"authRequest not found",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"authRequest not existing",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
eventFromEventPusher(
|
||||||
|
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
|
||||||
|
domain.OIDCErrorReasonUnspecified),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"wrong login client",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
|
||||||
|
id: "id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
sessionToken: "token",
|
||||||
|
checkLoginClient: true,
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"session not existing",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectFilter(),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "V2_id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"session expired",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewAddedEvent(mockCtx,
|
||||||
|
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
&domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
|
||||||
|
),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
testNow.Add(-5*time.Minute)),
|
||||||
|
),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
2*time.Minute),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "V2_id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
sessionToken: "token",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"invalid session token",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewAddedEvent(mockCtx,
|
||||||
|
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
&domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierInvalid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "V2_id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
sessionToken: "invalid",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"linked",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"login",
|
||||||
|
"application",
|
||||||
|
"acs",
|
||||||
|
"relaystate",
|
||||||
|
"request",
|
||||||
|
"binding",
|
||||||
|
"issuer",
|
||||||
|
"name",
|
||||||
|
"destination",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewAddedEvent(mockCtx,
|
||||||
|
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
&domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
"userID", "org1", testNow, &language.Afrikaans),
|
||||||
|
),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
testNow),
|
||||||
|
),
|
||||||
|
eventFromEventPusherWithCreationDateNow(
|
||||||
|
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
2*time.Minute),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectPush(
|
||||||
|
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"sessionID",
|
||||||
|
"userID",
|
||||||
|
testNow,
|
||||||
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "V2_id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
sessionToken: "token",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||||
|
authReq: &CurrentSAMLRequest{
|
||||||
|
SAMLRequest: &SAMLRequest{
|
||||||
|
ID: "V2_id",
|
||||||
|
ApplicationID: "application",
|
||||||
|
ACSURL: "acs",
|
||||||
|
RelayState: "relaystate",
|
||||||
|
RequestID: "request",
|
||||||
|
Binding: "binding",
|
||||||
|
Issuer: "issuer",
|
||||||
|
IssuerName: "name",
|
||||||
|
Destination: "destination",
|
||||||
|
},
|
||||||
|
SessionID: "sessionID",
|
||||||
|
UserID: "userID",
|
||||||
|
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"linked with login client check",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"loginClient",
|
||||||
|
"clientID",
|
||||||
|
"redirectURI",
|
||||||
|
"state",
|
||||||
|
"nonce",
|
||||||
|
[]string{"openid"},
|
||||||
|
[]string{"audience"},
|
||||||
|
domain.OIDCResponseTypeCode,
|
||||||
|
domain.OIDCResponseModeQuery,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewAddedEvent(mockCtx,
|
||||||
|
&session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
&domain.UserAgent{
|
||||||
|
FingerprintID: gu.Ptr("fp1"),
|
||||||
|
IP: net.ParseIP("1.2.3.4"),
|
||||||
|
Description: gu.Ptr("firefox"),
|
||||||
|
Header: http.Header{"foo": []string{"bar"}},
|
||||||
|
},
|
||||||
|
)),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
"userID", "org1", testNow, &language.Afrikaans),
|
||||||
|
),
|
||||||
|
eventFromEventPusher(
|
||||||
|
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
testNow),
|
||||||
|
),
|
||||||
|
eventFromEventPusherWithCreationDateNow(
|
||||||
|
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
|
||||||
|
2*time.Minute),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectPush(
|
||||||
|
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"sessionID",
|
||||||
|
"userID",
|
||||||
|
testNow,
|
||||||
|
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
tokenVerifier: newMockTokenVerifierValid(),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
|
||||||
|
id: "V2_id",
|
||||||
|
sessionID: "sessionID",
|
||||||
|
sessionToken: "token",
|
||||||
|
checkLoginClient: true,
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||||
|
authReq: &CurrentSAMLRequest{
|
||||||
|
SAMLRequest: &SAMLRequest{
|
||||||
|
ID: "V2_id",
|
||||||
|
ApplicationID: "application",
|
||||||
|
ACSURL: "acs",
|
||||||
|
RelayState: "relaystate",
|
||||||
|
RequestID: "request",
|
||||||
|
Binding: "binding",
|
||||||
|
Issuer: "issuer",
|
||||||
|
IssuerName: "name",
|
||||||
|
Destination: "destination",
|
||||||
|
},
|
||||||
|
SessionID: "sessionID",
|
||||||
|
UserID: "userID",
|
||||||
|
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Commands{
|
||||||
|
eventstore: tt.fields.eventstore(t),
|
||||||
|
sessionTokenVerifier: tt.fields.tokenVerifier,
|
||||||
|
}
|
||||||
|
details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken)
|
||||||
|
require.ErrorIs(t, err, tt.res.wantErr)
|
||||||
|
assertObjectDetails(t, tt.res.details, details)
|
||||||
|
if err == nil {
|
||||||
|
assert.WithinRange(t, got.AuthTime, testNow, testNow)
|
||||||
|
got.AuthTime = time.Time{}
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.res.authReq, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommands_FailSAMLRequest(t *testing.T) {
|
||||||
|
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||||
|
type fields struct {
|
||||||
|
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
id string
|
||||||
|
reason domain.OIDCErrorReason
|
||||||
|
}
|
||||||
|
type res struct {
|
||||||
|
details *domain.ObjectDetails
|
||||||
|
authReq *CurrentAuthRequest
|
||||||
|
wantErr error
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
res res
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"authRequest not existing",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "foo",
|
||||||
|
reason: domain.OIDCErrorReasonLoginRequired,
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"failed",
|
||||||
|
fields{
|
||||||
|
eventstore: expectEventstore(
|
||||||
|
expectFilter(
|
||||||
|
eventFromEventPusher(
|
||||||
|
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
"loginClient",
|
||||||
|
"clientID",
|
||||||
|
"redirectURI",
|
||||||
|
"state",
|
||||||
|
"nonce",
|
||||||
|
[]string{"openid"},
|
||||||
|
[]string{"audience"},
|
||||||
|
domain.OIDCResponseTypeCode,
|
||||||
|
domain.OIDCResponseModeQuery,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
expectPush(
|
||||||
|
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
|
||||||
|
domain.OIDCErrorReasonLoginRequired),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
args{
|
||||||
|
ctx: mockCtx,
|
||||||
|
id: "V2_id",
|
||||||
|
reason: domain.OIDCErrorReasonLoginRequired,
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
|
||||||
|
authReq: &CurrentAuthRequest{
|
||||||
|
AuthRequest: &AuthRequest{
|
||||||
|
ID: "V2_id",
|
||||||
|
LoginClient: "loginClient",
|
||||||
|
ClientID: "clientID",
|
||||||
|
RedirectURI: "redirectURI",
|
||||||
|
State: "state",
|
||||||
|
Nonce: "nonce",
|
||||||
|
Scope: []string{"openid"},
|
||||||
|
Audience: []string{"audience"},
|
||||||
|
ResponseType: domain.OIDCResponseTypeCode,
|
||||||
|
ResponseMode: domain.OIDCResponseModeQuery,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Commands{
|
||||||
|
eventstore: tt.fields.eventstore(t),
|
||||||
|
}
|
||||||
|
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
|
||||||
|
require.ErrorIs(t, err, tt.res.wantErr)
|
||||||
|
assertObjectDetails(t, tt.res.details, details)
|
||||||
|
assert.Equal(t, tt.res.authReq, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
11
internal/domain/saml_error_reason.go
Normal file
11
internal/domain/saml_error_reason.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
type SAMLErrorReason int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
SAMLErrorReasonUnspecified SAMLErrorReason = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
func SAMLErrorReasonFromError(err error) SAMLErrorReason {
|
||||||
|
return SAMLErrorReasonUnspecified
|
||||||
|
}
|
10
internal/domain/saml_request.go
Normal file
10
internal/domain/saml_request.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package domain
|
||||||
|
|
||||||
|
type SAMLRequestState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
SAMLRequestStateUnspecified SAMLRequestState = iota
|
||||||
|
SAMLRequestStateAdded
|
||||||
|
SAMLRequestStateFailed
|
||||||
|
SAMLRequestStateSucceeded
|
||||||
|
)
|
26
internal/repository/samlrequest/aggregate.go
Normal file
26
internal/repository/samlrequest/aggregate.go
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package samlrequest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AggregateType = "saml_request"
|
||||||
|
AggregateVersion = "v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Aggregate struct {
|
||||||
|
eventstore.Aggregate
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAggregate(id, instanceID string) *Aggregate {
|
||||||
|
return &Aggregate{
|
||||||
|
Aggregate: eventstore.Aggregate{
|
||||||
|
Type: AggregateType,
|
||||||
|
Version: AggregateVersion,
|
||||||
|
ID: id,
|
||||||
|
ResourceOwner: instanceID,
|
||||||
|
InstanceID: instanceID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
10
internal/repository/samlrequest/eventstore.go
Normal file
10
internal/repository/samlrequest/eventstore.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package samlrequest
|
||||||
|
|
||||||
|
import "github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
eventstore.RegisterFilterEventMapper(AggregateType, AddedType, eventstore.GenericEventMapper[AddedEvent])
|
||||||
|
eventstore.RegisterFilterEventMapper(AggregateType, SessionLinkedType, eventstore.GenericEventMapper[SessionLinkedEvent])
|
||||||
|
eventstore.RegisterFilterEventMapper(AggregateType, FailedType, eventstore.GenericEventMapper[FailedEvent])
|
||||||
|
eventstore.RegisterFilterEventMapper(AggregateType, SucceededType, eventstore.GenericEventMapper[SucceededEvent])
|
||||||
|
}
|
175
internal/repository/samlrequest/saml_request.go
Normal file
175
internal/repository/samlrequest/saml_request.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package samlrequest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/eventstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
samlRequestEventPrefix = "saml_request."
|
||||||
|
AddedType = samlRequestEventPrefix + "added"
|
||||||
|
FailedType = samlRequestEventPrefix + "failed"
|
||||||
|
SessionLinkedType = samlRequestEventPrefix + "session.linked"
|
||||||
|
SucceededType = samlRequestEventPrefix + "succeeded"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddedEvent struct {
|
||||||
|
*eventstore.BaseEvent `json:"-"`
|
||||||
|
|
||||||
|
LoginClient string `json:"loginClient,omitempty"`
|
||||||
|
ApplicationID string `json:"application_id,omitempty"`
|
||||||
|
ACSURL string `json:"acs_url,omitempty"`
|
||||||
|
RelayState string `json:"relay_state,omitempty"`
|
||||||
|
RequestID string `json:"request_id,omitempty"`
|
||||||
|
Binding string `json:"binding,omitempty"`
|
||||||
|
Issuer string `json:"issuer,omitempty"`
|
||||||
|
IssuerName string `json:"issuer_name,omitempty"`
|
||||||
|
Destination string `json:"destination,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
|
||||||
|
e.BaseEvent = event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AddedEvent) Payload() interface{} {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAddedEvent(ctx context.Context,
|
||||||
|
aggregate *eventstore.Aggregate,
|
||||||
|
loginClient,
|
||||||
|
applicationID string,
|
||||||
|
acsURL string,
|
||||||
|
relayState string,
|
||||||
|
requestID string,
|
||||||
|
binding string,
|
||||||
|
issuer string,
|
||||||
|
issuerName string,
|
||||||
|
destination string,
|
||||||
|
) *AddedEvent {
|
||||||
|
return &AddedEvent{
|
||||||
|
BaseEvent: eventstore.NewBaseEventForPush(
|
||||||
|
ctx,
|
||||||
|
aggregate,
|
||||||
|
AddedType,
|
||||||
|
),
|
||||||
|
LoginClient: loginClient,
|
||||||
|
ApplicationID: applicationID,
|
||||||
|
ACSURL: acsURL,
|
||||||
|
RelayState: relayState,
|
||||||
|
RequestID: requestID,
|
||||||
|
Binding: binding,
|
||||||
|
Issuer: issuer,
|
||||||
|
IssuerName: issuerName,
|
||||||
|
Destination: destination,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionLinkedEvent struct {
|
||||||
|
*eventstore.BaseEvent `json:"-"`
|
||||||
|
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
AuthTime time.Time `json:"auth_time"`
|
||||||
|
AuthMethods []domain.UserAuthMethodType `json:"auth_methods"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SessionLinkedEvent) Payload() interface{} {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionLinkedEvent(ctx context.Context,
|
||||||
|
aggregate *eventstore.Aggregate,
|
||||||
|
sessionID,
|
||||||
|
userID string,
|
||||||
|
authTime time.Time,
|
||||||
|
authMethods []domain.UserAuthMethodType,
|
||||||
|
) *SessionLinkedEvent {
|
||||||
|
return &SessionLinkedEvent{
|
||||||
|
BaseEvent: eventstore.NewBaseEventForPush(
|
||||||
|
ctx,
|
||||||
|
aggregate,
|
||||||
|
SessionLinkedType,
|
||||||
|
),
|
||||||
|
SessionID: sessionID,
|
||||||
|
UserID: userID,
|
||||||
|
AuthTime: authTime,
|
||||||
|
AuthMethods: authMethods,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SessionLinkedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
|
||||||
|
e.BaseEvent = event
|
||||||
|
}
|
||||||
|
|
||||||
|
type FailedEvent struct {
|
||||||
|
*eventstore.BaseEvent `json:"-"`
|
||||||
|
|
||||||
|
Reason domain.SAMLErrorReason `json:"reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *FailedEvent) Payload() interface{} {
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *FailedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFailedEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
aggregate *eventstore.Aggregate,
|
||||||
|
reason domain.SAMLErrorReason,
|
||||||
|
) *FailedEvent {
|
||||||
|
return &FailedEvent{
|
||||||
|
BaseEvent: eventstore.NewBaseEventForPush(
|
||||||
|
ctx,
|
||||||
|
aggregate,
|
||||||
|
FailedType,
|
||||||
|
),
|
||||||
|
Reason: reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *FailedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
|
||||||
|
e.BaseEvent = event
|
||||||
|
}
|
||||||
|
|
||||||
|
type SucceededEvent struct {
|
||||||
|
*eventstore.BaseEvent `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SucceededEvent) Payload() interface{} {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSucceededEvent(ctx context.Context,
|
||||||
|
aggregate *eventstore.Aggregate,
|
||||||
|
) *SucceededEvent {
|
||||||
|
return &SucceededEvent{
|
||||||
|
BaseEvent: eventstore.NewBaseEventForPush(
|
||||||
|
ctx,
|
||||||
|
aggregate,
|
||||||
|
SucceededType,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SucceededEvent) SetBaseEvent(event *eventstore.BaseEvent) {
|
||||||
|
e.BaseEvent = event
|
||||||
|
}
|
@ -169,8 +169,8 @@ message CreateCallbackRequest {
|
|||||||
string auth_request_id = 1 [
|
string auth_request_id = 1 [
|
||||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
|
description: "ID of the SAML Request.";
|
||||||
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
|
example: "\"163840776835432705\"";
|
||||||
}
|
}
|
||||||
];
|
];
|
||||||
|
|
||||||
|
84
proto/zitadel/saml/v2/authorization.proto
Normal file
84
proto/zitadel/saml/v2/authorization.proto
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package zitadel.saml.v2;
|
||||||
|
|
||||||
|
import "google/protobuf/duration.proto";
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
import "protoc-gen-openapiv2/options/annotations.proto";
|
||||||
|
|
||||||
|
option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml";
|
||||||
|
|
||||||
|
message SAMLRequest{
|
||||||
|
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = {
|
||||||
|
external_docs: {
|
||||||
|
url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest";
|
||||||
|
description: "Find out more about SAML Auth Request parameters";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
string id = 1 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "ID of the authorization request";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
google.protobuf.Timestamp creation_date = 2 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "Time when the auth request was created";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
string issuer = 3 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "SAML entity ID of the application that created the auth request";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
string assertion_consumer_url = 4 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "Base URI that points back to the application";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
string relay_state = 5 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "RelayState provided by the application for the request";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
|
||||||
|
string binding = 6 [
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
description: "Binding used by the application for the request";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
message AuthorizationError {
|
||||||
|
ErrorReason error = 1;
|
||||||
|
optional string error_description = 2;
|
||||||
|
optional string error_uri = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ErrorReason {
|
||||||
|
ERROR_REASON_UNSPECIFIED = 0;
|
||||||
|
|
||||||
|
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
|
||||||
|
ERROR_REASON_INVALID_REQUEST = 1;
|
||||||
|
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
|
||||||
|
ERROR_REASON_ACCESS_DENIED = 3;
|
||||||
|
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
|
||||||
|
ERROR_REASON_INVALID_SCOPE = 5;
|
||||||
|
ERROR_REASON_SERVER_ERROR = 6;
|
||||||
|
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
|
||||||
|
|
||||||
|
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||||
|
ERROR_REASON_INTERACTION_REQUIRED = 8;
|
||||||
|
ERROR_REASON_LOGIN_REQUIRED = 9;
|
||||||
|
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
|
||||||
|
ERROR_REASON_CONSENT_REQUIRED = 11;
|
||||||
|
ERROR_REASON_INVALID_REQUEST_URI = 12;
|
||||||
|
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
|
||||||
|
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
|
||||||
|
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
|
||||||
|
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
|
||||||
|
}
|
142
proto/zitadel/saml/v2/saml_service.proto
Normal file
142
proto/zitadel/saml/v2/saml_service.proto
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package zitadel.saml.v2;
|
||||||
|
|
||||||
|
import "zitadel/object/v2/object.proto";
|
||||||
|
import "zitadel/protoc_gen_zitadel/v2/options.proto";
|
||||||
|
import "zitadel/saml/v2/authorization.proto";
|
||||||
|
import "google/api/annotations.proto";
|
||||||
|
import "google/api/field_behavior.proto";
|
||||||
|
import "protoc-gen-openapiv2/options/annotations.proto";
|
||||||
|
import "validate/validate.proto";
|
||||||
|
|
||||||
|
option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml";
|
||||||
|
|
||||||
|
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
|
||||||
|
info: {
|
||||||
|
title: "SAML Service";
|
||||||
|
version: "2.0";
|
||||||
|
description: "Get SAML Auth Request details and create callback URLs.";
|
||||||
|
contact:{
|
||||||
|
name: "ZITADEL"
|
||||||
|
url: "https://zitadel.com"
|
||||||
|
email: "hi@zitadel.com"
|
||||||
|
}
|
||||||
|
license: {
|
||||||
|
name: "Apache 2.0",
|
||||||
|
url: "https://github.com/zitadel/zitadel/blob/main/LICENSE";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
schemes: HTTPS;
|
||||||
|
schemes: HTTP;
|
||||||
|
|
||||||
|
consumes: "application/json";
|
||||||
|
consumes: "application/grpc";
|
||||||
|
|
||||||
|
produces: "application/json";
|
||||||
|
produces: "application/grpc";
|
||||||
|
|
||||||
|
consumes: "application/grpc-web+proto";
|
||||||
|
produces: "application/grpc-web+proto";
|
||||||
|
|
||||||
|
host: "$CUSTOM-DOMAIN";
|
||||||
|
base_path: "/";
|
||||||
|
|
||||||
|
external_docs: {
|
||||||
|
description: "Detailed information about ZITADEL",
|
||||||
|
url: "https://zitadel.com/docs"
|
||||||
|
}
|
||||||
|
security_definitions: {
|
||||||
|
security: {
|
||||||
|
key: "OAuth2";
|
||||||
|
value: {
|
||||||
|
type: TYPE_OAUTH2;
|
||||||
|
flow: FLOW_ACCESS_CODE;
|
||||||
|
authorization_url: "$CUSTOM-DOMAIN/oauth/v2/authorize";
|
||||||
|
token_url: "$CUSTOM-DOMAIN/oauth/v2/token";
|
||||||
|
scopes: {
|
||||||
|
scope: {
|
||||||
|
key: "openid";
|
||||||
|
value: "openid";
|
||||||
|
}
|
||||||
|
scope: {
|
||||||
|
key: "urn:zitadel:iam:org:project:id:zitadel:aud";
|
||||||
|
value: "urn:zitadel:iam:org:project:id:zitadel:aud";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
security: {
|
||||||
|
security_requirement: {
|
||||||
|
key: "OAuth2";
|
||||||
|
value: {
|
||||||
|
scope: "openid";
|
||||||
|
scope: "urn:zitadel:iam:org:project:id:zitadel:aud";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responses: {
|
||||||
|
key: "403";
|
||||||
|
value: {
|
||||||
|
description: "Returned when the user does not have permission to access the resource.";
|
||||||
|
schema: {
|
||||||
|
json_schema: {
|
||||||
|
ref: "#/definitions/rpcStatus";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responses: {
|
||||||
|
key: "404";
|
||||||
|
value: {
|
||||||
|
description: "Returned when the resource does not exist.";
|
||||||
|
schema: {
|
||||||
|
json_schema: {
|
||||||
|
ref: "#/definitions/rpcStatus";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
service SAMLService {
|
||||||
|
rpc GetAuthRequest (GetSAMLRequestRequest) returns (GetSAMLRequestResponse) {
|
||||||
|
option (google.api.http) = {
|
||||||
|
get: "/v2/saml/saml_requests/{saml_request_id}"
|
||||||
|
};
|
||||||
|
|
||||||
|
option (zitadel.protoc_gen_zitadel.v2.options) = {
|
||||||
|
auth_option: {
|
||||||
|
permission: "authenticated"
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
|
||||||
|
summary: "Get SAML Request details";
|
||||||
|
description: "Get SAML Request details by ID. Returns details that are parsed from the application's SAML Request."
|
||||||
|
responses: {
|
||||||
|
key: "200"
|
||||||
|
value: {
|
||||||
|
description: "OK";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetSAMLRequestRequest {
|
||||||
|
string saml_request_id = 1 [
|
||||||
|
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||||
|
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||||
|
min_length: 1;
|
||||||
|
max_length: 200;
|
||||||
|
description: "ID of the SAML Request, as obtained from the redirect URL.";
|
||||||
|
example: "\"163840776835432705\"";
|
||||||
|
}
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetSAMLRequestResponse {
|
||||||
|
SAMLRequest saml_request = 1;
|
||||||
|
}
|
@ -134,6 +134,7 @@ message IdentityProvider {
|
|||||||
string id = 1;
|
string id = 1;
|
||||||
string name = 2;
|
string name = 2;
|
||||||
IdentityProviderType type = 3;
|
IdentityProviderType type = 3;
|
||||||
|
bool is_linking_allowed = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
enum IdentityProviderType {
|
enum IdentityProviderType {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user