From 905da945ff9cdbe6d34f31fdeb1c17670c70eec2 Mon Sep 17 00:00:00 2001 From: Stefan Benz <46600784+stebenz@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:30:35 +0100 Subject: [PATCH] feat: add saml request to link to sessions --- internal/api/grpc/saml/oidc.go | 203 ++++++ internal/api/grpc/saml/server.go | 59 ++ .../api/saml/auth_request_converter_v2.go | 60 ++ internal/api/saml/storage.go | 56 ++ internal/command/saml_request.go | 162 +++++ internal/command/saml_request_model.go | 94 +++ internal/command/saml_request_test.go | 668 ++++++++++++++++++ internal/domain/saml_error_reason.go | 11 + internal/domain/saml_request.go | 10 + internal/repository/samlrequest/aggregate.go | 26 + internal/repository/samlrequest/eventstore.go | 10 + .../repository/samlrequest/saml_request.go | 175 +++++ proto/zitadel/oidc/v2/oidc_service.proto | 4 +- proto/zitadel/saml/v2/authorization.proto | 84 +++ proto/zitadel/saml/v2/saml_service.proto | 142 ++++ .../zitadel/settings/v2/login_settings.proto | 1 + 16 files changed, 1763 insertions(+), 2 deletions(-) create mode 100644 internal/api/grpc/saml/oidc.go create mode 100644 internal/api/grpc/saml/server.go create mode 100644 internal/api/saml/auth_request_converter_v2.go create mode 100644 internal/command/saml_request.go create mode 100644 internal/command/saml_request_model.go create mode 100644 internal/command/saml_request_test.go create mode 100644 internal/domain/saml_error_reason.go create mode 100644 internal/domain/saml_request.go create mode 100644 internal/repository/samlrequest/aggregate.go create mode 100644 internal/repository/samlrequest/eventstore.go create mode 100644 internal/repository/samlrequest/saml_request.go create mode 100644 proto/zitadel/saml/v2/authorization.proto create mode 100644 proto/zitadel/saml/v2/saml_service.proto diff --git a/internal/api/grpc/saml/oidc.go b/internal/api/grpc/saml/oidc.go new file mode 100644 index 0000000000..826c198fad --- /dev/null +++ b/internal/api/grpc/saml/oidc.go @@ -0,0 +1,203 @@ +package oidc + +import ( + "context" + + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v3/pkg/op" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" +) + +func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) { + authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true) + if err != nil { + logging.WithError(err).Error("query authRequest by ID") + return nil, err + } + return &oidc_pb.GetAuthRequestResponse{ + AuthRequest: authRequestToPb(authRequest), + }, nil +} + +func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest { + pba := &oidc_pb.AuthRequest{ + Id: a.ID, + CreationDate: timestamppb.New(a.CreationDate), + ClientId: a.ClientID, + Scope: a.Scope, + RedirectUri: a.RedirectURI, + Prompt: promptsToPb(a.Prompt), + UiLocales: a.UiLocales, + LoginHint: a.LoginHint, + HintUserId: a.HintUserID, + } + if a.MaxAge != nil { + pba.MaxAge = durationpb.New(*a.MaxAge) + } + return pba +} + +func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt { + out := make([]oidc_pb.Prompt, len(promps)) + for i, p := range promps { + out[i] = promptToPb(p) + } + return out +} + +func promptToPb(p domain.Prompt) oidc_pb.Prompt { + switch p { + case domain.PromptUnspecified: + return oidc_pb.Prompt_PROMPT_UNSPECIFIED + case domain.PromptNone: + return oidc_pb.Prompt_PROMPT_NONE + case domain.PromptLogin: + return oidc_pb.Prompt_PROMPT_LOGIN + case domain.PromptConsent: + return oidc_pb.Prompt_PROMPT_CONSENT + case domain.PromptSelectAccount: + return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT + case domain.PromptCreate: + return oidc_pb.Prompt_PROMPT_CREATE + default: + return oidc_pb.Prompt_PROMPT_UNSPECIFIED + } +} + +func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) { + switch v := req.GetCallbackKind().(type) { + case *oidc_pb.CreateCallbackRequest_Error: + return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error) + case *oidc_pb.CreateCallbackRequest_Session: + return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session) + default: + return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) + } +} + +func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError())) + if err != nil { + return nil, err + } + authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} + callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider()) + if err != nil { + return nil, err + } + return &oidc_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true) + if err != nil { + return nil, err + } + authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} + ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin()) + var callback string + if aar.ResponseType == domain.OIDCResponseTypeCode { + callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider()) + } else { + callback, err = s.op.CreateTokenCallbackURL(ctx, authReq) + } + if err != nil { + return nil, err + } + return &oidc_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason { + switch errorReason { + case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED: + return domain.OIDCErrorReasonUnspecified + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: + return domain.OIDCErrorReasonInvalidRequest + case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: + return domain.OIDCErrorReasonUnauthorizedClient + case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: + return domain.OIDCErrorReasonAccessDenied + case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: + return domain.OIDCErrorReasonUnsupportedResponseType + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: + return domain.OIDCErrorReasonInvalidScope + case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: + return domain.OIDCErrorReasonServerError + case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: + return domain.OIDCErrorReasonTemporaryUnavailable + case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: + return domain.OIDCErrorReasonInteractionRequired + case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: + return domain.OIDCErrorReasonLoginRequired + case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: + return domain.OIDCErrorReasonAccountSelectionRequired + case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: + return domain.OIDCErrorReasonConsentRequired + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: + return domain.OIDCErrorReasonInvalidRequestURI + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: + return domain.OIDCErrorReasonInvalidRequestObject + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: + return domain.OIDCErrorReasonRequestNotSupported + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: + return domain.OIDCErrorReasonRequestURINotSupported + case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: + return domain.OIDCErrorReasonRegistrationNotSupported + default: + return domain.OIDCErrorReasonUnspecified + } +} + +func errorReasonToOIDC(reason oidc_pb.ErrorReason) string { + switch reason { + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: + return "invalid_request" + case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: + return "unauthorized_client" + case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: + return "access_denied" + case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: + return "unsupported_response_type" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: + return "invalid_scope" + case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: + return "temporarily_unavailable" + case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: + return "interaction_required" + case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: + return "login_required" + case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: + return "account_selection_required" + case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: + return "consent_required" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: + return "invalid_request_uri" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: + return "invalid_request_object" + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: + return "request_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: + return "request_uri_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: + return "registration_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: + fallthrough + default: + return "server_error" + } +} diff --git a/internal/api/grpc/saml/server.go b/internal/api/grpc/saml/server.go new file mode 100644 index 0000000000..446b584dd8 --- /dev/null +++ b/internal/api/grpc/saml/server.go @@ -0,0 +1,59 @@ +package oidc + +import ( + "google.golang.org/grpc" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/server" + "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/query" + saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" +) + +var _ saml_pb.SAMLServiceServer = (*Server)(nil) + +type Server struct { + saml_pb.UnimplementedSAMLServiceServer + command *command.Commands + query *query.Queries + + op *oidc.Server + externalSecure bool +} + +type Config struct{} + +func CreateServer( + command *command.Commands, + query *query.Queries, + op *oidc.Server, + externalSecure bool, +) *Server { + return &Server{ + command: command, + query: query, + op: op, + externalSecure: externalSecure, + } +} + +func (s *Server) RegisterServer(grpcServer *grpc.Server) { + saml_pb.RegisterSAMLServiceServer(grpcServer, s) +} + +func (s *Server) AppName() string { + return saml_pb.SAMLService_ServiceDesc.ServiceName +} + +func (s *Server) MethodPrefix() string { + return saml_pb.SAMLService_ServiceDesc.ServiceName +} + +func (s *Server) AuthMethods() authz.MethodMapping { + return saml_pb.SAMLService_AuthMethods +} + +func (s *Server) RegisterGateway() server.RegisterGatewayFunc { + return saml_pb.RegisterSAMLServiceHandler +} diff --git a/internal/api/saml/auth_request_converter_v2.go b/internal/api/saml/auth_request_converter_v2.go new file mode 100644 index 0000000000..441f97139c --- /dev/null +++ b/internal/api/saml/auth_request_converter_v2.go @@ -0,0 +1,60 @@ +package saml + +import ( + "github.com/zitadel/saml/pkg/provider/models" + + "github.com/zitadel/zitadel/internal/command" +) + +var _ models.AuthRequestInt = &AuthRequestV2{} + +type AuthRequestV2 struct { + *command.CurrentSAMLRequest +} + +func (a *AuthRequestV2) GetApplicationID() string { + return a.ApplicationID +} + +func (a *AuthRequestV2) GetID() string { + return a.ID +} +func (a *AuthRequestV2) GetRelayState() string { + return a.RelayState +} +func (a *AuthRequestV2) GetAccessConsumerServiceURL() string { + return a.ACSURL +} + +func (a *AuthRequestV2) GetNameID() string { + return a.UserID +} + +func (a *AuthRequestV2) GetAuthRequestID() string { + return a.RequestID +} +func (a *AuthRequestV2) GetBindingType() string { + return a.Binding +} +func (a *AuthRequestV2) GetIssuer() string { + return a.Issuer +} +func (a *AuthRequestV2) GetIssuerName() string { + return a.IssuerName +} +func (a *AuthRequestV2) GetDestination() string { + return a.Destination +} +func (a *AuthRequestV2) GetCode() string { + return "" +} +func (a *AuthRequestV2) GetUserID() string { + return a.UserID +} +func (a *AuthRequestV2) GetUserName() string { + return "" +} + +func (a *AuthRequestV2) Done() bool { + return a.UserID != "" && a.SessionID != "" +} diff --git a/internal/api/saml/storage.go b/internal/api/saml/storage.go index 76173c2592..7aa7fc0522 100644 --- a/internal/api/saml/storage.go +++ b/internal/api/saml/storage.go @@ -3,6 +3,7 @@ package saml import ( "context" "encoding/json" + "strings" "time" "github.com/dop251/goja" @@ -16,6 +17,7 @@ import ( "github.com/zitadel/zitadel/internal/actions" "github.com/zitadel/zitadel/internal/actions/object" "github.com/zitadel/zitadel/internal/activity" + http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/auth/repository" "github.com/zitadel/zitadel/internal/command" @@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{} var _ provider.AuthStorage = &Storage{} var _ provider.UserStorage = &Storage{} +const ( + LoginClientHeader = "x-zitadel-login-client" +) + type Storage struct { certChan <-chan interface{} defaultCertificateLifetime time.Duration @@ -95,6 +101,47 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + headers, _ := http_utils.HeadersFromCtx(ctx) + if loginClient := headers.Get(LoginClientHeader); loginClient != "" { + return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID) + } + + userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) + if !ok { + return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id") + } + + resp, err := p.repo.CreateAuthRequest(ctx, CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID)) + if err != nil { + return nil, err + } + + return AuthRequestFromBusiness(resp) +} + +func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (models.AuthRequestInt, error) { + samlRequest := &command.SAMLRequest{ + ApplicationID: applicationID, + ACSURL: acsUrl, + RelayState: relayState, + RequestID: req.Id, + Binding: protocolBinding, + Issuer: req.Issuer.Text, + IssuerName: req.Issuer.SPProvidedID, + Destination: req.Destination, + } + + aar, err := p.command.AddSAMLRequest(ctx, samlRequest) + if err != nil { + return nil, err + } + return &AuthRequestV2{aar}, nil +} + +func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id") @@ -113,6 +160,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + if strings.HasPrefix(id, command.IDPrefixV2) { + req, err := p.command.GetCurrentSAMLRequest(ctx, id) + if err != nil { + return nil, err + } + return &AuthRequestV2{req}, nil + } + userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id") diff --git a/internal/command/saml_request.go b/internal/command/saml_request.go new file mode 100644 index 0000000000..60ca622b88 --- /dev/null +++ b/internal/command/saml_request.go @@ -0,0 +1,162 @@ +package command + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/samlrequest" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type SAMLRequest struct { + ID string + LoginClient string + + ApplicationID string + EntityID string + ACSURL string + RelayState string + RequestID string + Binding string + Issuer string + IssuerName string + Destination string +} + +type CurrentSAMLRequest struct { + *SAMLRequest + SessionID string + UserID string + AuthMethods []domain.UserAuthMethodType + AuthTime time.Time +} + +func (c *Commands) AddSAMLRequest(ctx context.Context, samlRequest *SAMLRequest) (_ *CurrentSAMLRequest, err error) { + id, err := c.idGenerator.Next() + if err != nil { + return nil, err + } + samlRequest.ID = IDPrefixV2 + id + writeModel, err := c.getSAMLRequestWriteModel(ctx, samlRequest.ID) + if err != nil { + return nil, err + } + if writeModel.SAMLRequestState != domain.SAMLRequestStateUnspecified { + return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.SAMLRequest.AlreadyExisting") + } + err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewAddedEvent( + ctx, + &authrequest.NewAggregate(samlRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate, + samlRequest.LoginClient, + samlRequest.ApplicationID, + samlRequest.ACSURL, + samlRequest.RelayState, + samlRequest.RequestID, + samlRequest.Binding, + samlRequest.Issuer, + samlRequest.IssuerName, + samlRequest.Destination, + )) + if err != nil { + return nil, err + } + return samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil +} + +func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID, sessionToken string) (*domain.ObjectDetails, *CurrentSAMLRequest, error) { + writeModel, err := c.getSAMLRequestWriteModel(ctx, id) + if err != nil { + return nil, nil, err + } + if writeModel.SAMLRequestState == domain.SAMLRequestStateUnspecified { + return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.SAMLRequest.NotExisting") + } + if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded { + return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.SAMLRequest.AlreadyHandled") + } + sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()) + err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) + if err != nil { + return nil, nil, err + } + if err = sessionWriteModel.CheckIsActive(); err != nil { + return nil, nil, err + } + if err := c.sessionTokenVerifier(ctx, sessionToken, sessionWriteModel.AggregateID, sessionWriteModel.TokenID); err != nil { + return nil, nil, err + } + + if err := c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewSessionLinkedEvent( + ctx, &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + sessionID, + sessionWriteModel.UserID, + sessionWriteModel.AuthenticationTime(), + sessionWriteModel.AuthMethodTypes(), + )); err != nil { + return nil, nil, err + } + return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil +} + +func (c *Commands) FailSAMLRequest(ctx context.Context, id string, reason domain.SAMLErrorReason) (*domain.ObjectDetails, *CurrentSAMLRequest, error) { + writeModel, err := c.getSAMLRequestWriteModel(ctx, id) + if err != nil { + return nil, nil, err + } + if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded { + return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.SAMLRequest.AlreadyHandled") + } + err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewFailedEvent( + ctx, + &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + reason, + )) + if err != nil { + return nil, nil, err + } + return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil +} + +func samlRequestWriteModelToCurrentSAMLRequest(writeModel *SAMLRequestWriteModel) (_ *CurrentSAMLRequest) { + return &CurrentSAMLRequest{ + SAMLRequest: &SAMLRequest{ + ID: writeModel.AggregateID, + ApplicationID: writeModel.ApplicationID, + ACSURL: writeModel.ACSURL, + RelayState: writeModel.RelayState, + RequestID: writeModel.RequestID, + Binding: writeModel.Binding, + Issuer: writeModel.Issuer, + IssuerName: writeModel.IssuerName, + Destination: writeModel.Destination, + }, + SessionID: writeModel.SessionID, + UserID: writeModel.UserID, + AuthMethods: writeModel.AuthMethods, + AuthTime: writeModel.AuthTime, + } +} + +func (c *Commands) GetCurrentSAMLRequest(ctx context.Context, id string) (_ *CurrentSAMLRequest, err error) { + wm, err := c.getSAMLRequestWriteModel(ctx, id) + if err != nil { + return nil, err + } + return samlRequestWriteModelToCurrentSAMLRequest(wm), nil +} + +func (c *Commands) getSAMLRequestWriteModel(ctx context.Context, id string) (writeModel *SAMLRequestWriteModel, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + writeModel = NewSAMLRequestWriteModel(ctx, id) + err = c.eventstore.FilterToQueryReducer(ctx, writeModel) + if err != nil { + return nil, err + } + return writeModel, nil +} diff --git a/internal/command/saml_request_model.go b/internal/command/saml_request_model.go new file mode 100644 index 0000000000..f6a5ddcb89 --- /dev/null +++ b/internal/command/saml_request_model.go @@ -0,0 +1,94 @@ +package command + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/samlrequest" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type SAMLRequestWriteModel struct { + eventstore.WriteModel + aggregate *eventstore.Aggregate + + ApplicationID string + ACSURL string + RelayState string + RequestID string + Binding string + Issuer string + IssuerName string + Destination string + + SessionID string + UserID string + AuthTime time.Time + AuthMethods []domain.UserAuthMethodType + SAMLRequestState domain.SAMLRequestState +} + +func NewSAMLRequestWriteModel(ctx context.Context, id string) *SAMLRequestWriteModel { + return &SAMLRequestWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: id, + }, + aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + } +} + +func (m *SAMLRequestWriteModel) Reduce() error { + for _, event := range m.Events { + switch e := event.(type) { + case *samlrequest.AddedEvent: + m.ApplicationID = e.ApplicationID + m.ACSURL = e.ACSURL + m.RelayState = e.RelayState + m.RequestID = e.RequestID + m.Binding = e.Binding + m.Issuer = e.Issuer + m.IssuerName = e.IssuerName + m.Destination = e.Destination + m.SAMLRequestState = domain.SAMLRequestStateAdded + case *samlrequest.SessionLinkedEvent: + m.SessionID = e.SessionID + m.UserID = e.UserID + m.AuthTime = e.AuthTime + m.AuthMethods = e.AuthMethods + case *samlrequest.FailedEvent: + m.SAMLRequestState = domain.SAMLRequestStateFailed + case *samlrequest.SucceededEvent: + m.SAMLRequestState = domain.SAMLRequestStateSucceeded + } + } + return m.WriteModel.Reduce() +} + +func (m *SAMLRequestWriteModel) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(samlrequest.AggregateType). + AggregateIDs(m.AggregateID). + Builder() +} + +// CheckAuthenticated checks that the auth request exists, a session must have been linked +func (m *SAMLRequestWriteModel) CheckAuthenticated() error { + if m.SessionID == "" { + return zerrors.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.SAMLRequest.NotAuthenticated") + } + // in case of OIDC Code Flow, the code must have been exchanged + if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged { + return nil + } + // in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet + if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) && + m.AuthRequestState == domain.AuthRequestStateAdded { + return nil + } + return zerrors.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.SAMLRequest.NotAuthenticated") +} diff --git a/internal/command/saml_request_test.go b/internal/command/saml_request_test.go new file mode 100644 index 0000000000..05317e86f9 --- /dev/null +++ b/internal/command/saml_request_test.go @@ -0,0 +1,668 @@ +package command + +import ( + "context" + "net" + "net/http" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/id" + "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/samlrequest" + "github.com/zitadel/zitadel/internal/repository/session" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestCommands_AddSAMLRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore func(t *testing.T) *eventstore.Eventstore + idGenerator id.Generator + } + type args struct { + ctx context.Context + request *SAMLRequest + } + tests := []struct { + name string + fields fields + args args + want *CurrentSAMLRequest + wantErr error + }{ + { + "already exists error", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"), + }, + args{ + ctx: mockCtx, + request: &SAMLRequest{}, + }, + nil, + zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"), + }, + { + "added", + fields{ + eventstore: expectEventstore( + expectFilter(), + expectPush( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"), + }, + args{ + ctx: mockCtx, + request: &SAMLRequest{ + LoginClient: "login", + ApplicationID: "application", + ACSURL: "acs", + RelayState: "relaystate", + RequestID: "request", + Binding: "binding", + Issuer: "issuer", + IssuerName: "name", + Destination: "destination", + }, + }, + &CurrentSAMLRequest{ + SAMLRequest: &SAMLRequest{ + ID: "V2_id", + LoginClient: "login", + ApplicationID: "application", + ACSURL: "acs", + RelayState: "relaystate", + RequestID: "request", + Binding: "binding", + Issuer: "issuer", + IssuerName: "name", + Destination: "destination", + }, + }, + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + idGenerator: tt.fields.idGenerator, + } + got, err := c.AddSAMLRequest(tt.args.ctx, tt.args.request) + require.ErrorIs(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCommands_LinkSessionToSAMLRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore func(t *testing.T) *eventstore.Eventstore + tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + } + type args struct { + ctx context.Context + id string + sessionID string + sessionToken string + checkLoginClient bool + } + type res struct { + details *domain.ObjectDetails + authReq *CurrentSAMLRequest + wantErr error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "authRequest not found", + fields{ + eventstore: expectEventstore( + expectFilter(), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: mockCtx, + id: "id", + sessionID: "sessionID", + }, + res{ + wantErr: zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"), + }, + }, + { + "authRequest not existing", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + eventFromEventPusher( + authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate, + domain.OIDCErrorReasonUnspecified), + ), + ), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: mockCtx, + id: "id", + sessionID: "sessionID", + }, + res{ + wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"), + }, + }, + { + "wrong login client", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"), + id: "id", + sessionID: "sessionID", + sessionToken: "token", + checkLoginClient: true, + }, + res{ + wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"), + }, + }, + { + "session not existing", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + expectFilter(), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + }, + res{ + wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"), + }, + }, + { + "session expired", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + "userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + testNow.Add(-5*time.Minute)), + ), + eventFromEventPusher( + session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + 2*time.Minute), + ), + ), + ), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + }, + res{ + wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"), + }, + }, + { + "invalid session token", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), + ), + ), + tokenVerifier: newMockTokenVerifierInvalid(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + sessionToken: "invalid", + }, + res{ + wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"), + }, + }, + { + "linked", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "login", + "application", + "acs", + "relaystate", + "request", + "binding", + "issuer", + "name", + "destination", + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + "userID", "org1", testNow, &language.Afrikaans), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + testNow), + ), + eventFromEventPusherWithCreationDateNow( + session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + 2*time.Minute), + ), + ), + expectPush( + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + ), + ), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentSAMLRequest{ + SAMLRequest: &SAMLRequest{ + ID: "V2_id", + ApplicationID: "application", + ACSURL: "acs", + RelayState: "relaystate", + RequestID: "request", + Binding: "binding", + Issuer: "issuer", + IssuerName: "name", + Destination: "destination", + }, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + }, + }, + }, + { + "linked with login client check", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + domain.OIDCResponseModeQuery, + nil, + nil, + nil, + nil, + nil, + nil, + true, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, + &session.NewAggregate("sessionID", "instance1").Aggregate, + &domain.UserAgent{ + FingerprintID: gu.Ptr("fp1"), + IP: net.ParseIP("1.2.3.4"), + Description: gu.Ptr("firefox"), + Header: http.Header{"foo": []string{"bar"}}, + }, + )), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + "userID", "org1", testNow, &language.Afrikaans), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + testNow), + ), + eventFromEventPusherWithCreationDateNow( + session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate, + 2*time.Minute), + ), + ), + expectPush( + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + ), + ), + ), + tokenVerifier: newMockTokenVerifierValid(), + }, + args{ + ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"), + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + checkLoginClient: true, + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentSAMLRequest{ + SAMLRequest: &SAMLRequest{ + ID: "V2_id", + ApplicationID: "application", + ACSURL: "acs", + RelayState: "relaystate", + RequestID: "request", + Binding: "binding", + Issuer: "issuer", + IssuerName: "name", + Destination: "destination", + }, + SessionID: "sessionID", + UserID: "userID", + AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + sessionTokenVerifier: tt.fields.tokenVerifier, + } + details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken) + require.ErrorIs(t, err, tt.res.wantErr) + assertObjectDetails(t, tt.res.details, details) + if err == nil { + assert.WithinRange(t, got.AuthTime, testNow, testNow) + got.AuthTime = time.Time{} + } + assert.Equal(t, tt.res.authReq, got) + }) + } +} + +func TestCommands_FailSAMLRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore func(t *testing.T) *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + reason domain.OIDCErrorReason + } + type res struct { + details *domain.ObjectDetails + authReq *CurrentAuthRequest + wantErr error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "authRequest not existing", + fields{ + eventstore: expectEventstore( + expectFilter(), + ), + }, + args{ + ctx: mockCtx, + id: "foo", + reason: domain.OIDCErrorReasonLoginRequired, + }, + res{ + wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"), + }, + }, + { + "failed", + fields{ + eventstore: expectEventstore( + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + domain.OIDCResponseModeQuery, + nil, + nil, + nil, + nil, + nil, + nil, + true, + ), + ), + ), + expectPush( + authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + domain.OIDCErrorReasonLoginRequired), + ), + ), + }, + args{ + ctx: mockCtx, + id: "V2_id", + reason: domain.OIDCErrorReasonLoginRequired, + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + ResponseMode: domain.OIDCResponseModeQuery, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore(t), + } + details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason) + require.ErrorIs(t, err, tt.res.wantErr) + assertObjectDetails(t, tt.res.details, details) + assert.Equal(t, tt.res.authReq, got) + }) + } +} diff --git a/internal/domain/saml_error_reason.go b/internal/domain/saml_error_reason.go new file mode 100644 index 0000000000..b302360a06 --- /dev/null +++ b/internal/domain/saml_error_reason.go @@ -0,0 +1,11 @@ +package domain + +type SAMLErrorReason int32 + +const ( + SAMLErrorReasonUnspecified SAMLErrorReason = iota +) + +func SAMLErrorReasonFromError(err error) SAMLErrorReason { + return SAMLErrorReasonUnspecified +} diff --git a/internal/domain/saml_request.go b/internal/domain/saml_request.go new file mode 100644 index 0000000000..8cf13be544 --- /dev/null +++ b/internal/domain/saml_request.go @@ -0,0 +1,10 @@ +package domain + +type SAMLRequestState int + +const ( + SAMLRequestStateUnspecified SAMLRequestState = iota + SAMLRequestStateAdded + SAMLRequestStateFailed + SAMLRequestStateSucceeded +) diff --git a/internal/repository/samlrequest/aggregate.go b/internal/repository/samlrequest/aggregate.go new file mode 100644 index 0000000000..551d64c70b --- /dev/null +++ b/internal/repository/samlrequest/aggregate.go @@ -0,0 +1,26 @@ +package samlrequest + +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + AggregateType = "saml_request" + AggregateVersion = "v1" +) + +type Aggregate struct { + eventstore.Aggregate +} + +func NewAggregate(id, instanceID string) *Aggregate { + return &Aggregate{ + Aggregate: eventstore.Aggregate{ + Type: AggregateType, + Version: AggregateVersion, + ID: id, + ResourceOwner: instanceID, + InstanceID: instanceID, + }, + } +} diff --git a/internal/repository/samlrequest/eventstore.go b/internal/repository/samlrequest/eventstore.go new file mode 100644 index 0000000000..85cbec4460 --- /dev/null +++ b/internal/repository/samlrequest/eventstore.go @@ -0,0 +1,10 @@ +package samlrequest + +import "github.com/zitadel/zitadel/internal/eventstore" + +func init() { + eventstore.RegisterFilterEventMapper(AggregateType, AddedType, eventstore.GenericEventMapper[AddedEvent]) + eventstore.RegisterFilterEventMapper(AggregateType, SessionLinkedType, eventstore.GenericEventMapper[SessionLinkedEvent]) + eventstore.RegisterFilterEventMapper(AggregateType, FailedType, eventstore.GenericEventMapper[FailedEvent]) + eventstore.RegisterFilterEventMapper(AggregateType, SucceededType, eventstore.GenericEventMapper[SucceededEvent]) +} diff --git a/internal/repository/samlrequest/saml_request.go b/internal/repository/samlrequest/saml_request.go new file mode 100644 index 0000000000..fc0ca3ee58 --- /dev/null +++ b/internal/repository/samlrequest/saml_request.go @@ -0,0 +1,175 @@ +package samlrequest + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + samlRequestEventPrefix = "saml_request." + AddedType = samlRequestEventPrefix + "added" + FailedType = samlRequestEventPrefix + "failed" + SessionLinkedType = samlRequestEventPrefix + "session.linked" + SucceededType = samlRequestEventPrefix + "succeeded" +) + +type AddedEvent struct { + *eventstore.BaseEvent `json:"-"` + + LoginClient string `json:"loginClient,omitempty"` + ApplicationID string `json:"application_id,omitempty"` + ACSURL string `json:"acs_url,omitempty"` + RelayState string `json:"relay_state,omitempty"` + RequestID string `json:"request_id,omitempty"` + Binding string `json:"binding,omitempty"` + Issuer string `json:"issuer,omitempty"` + IssuerName string `json:"issuer_name,omitempty"` + Destination string `json:"destination,omitempty"` +} + +func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} + +func (e *AddedEvent) Payload() interface{} { + return e +} + +func (e *AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func NewAddedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, + loginClient, + applicationID string, + acsURL string, + relayState string, + requestID string, + binding string, + issuer string, + issuerName string, + destination string, +) *AddedEvent { + return &AddedEvent{ + BaseEvent: eventstore.NewBaseEventForPush( + ctx, + aggregate, + AddedType, + ), + LoginClient: loginClient, + ApplicationID: applicationID, + ACSURL: acsURL, + RelayState: relayState, + RequestID: requestID, + Binding: binding, + Issuer: issuer, + IssuerName: issuerName, + Destination: destination, + } +} + +type SessionLinkedEvent struct { + *eventstore.BaseEvent `json:"-"` + + SessionID string `json:"session_id"` + UserID string `json:"user_id"` + AuthTime time.Time `json:"auth_time"` + AuthMethods []domain.UserAuthMethodType `json:"auth_methods"` +} + +func (e *SessionLinkedEvent) Payload() interface{} { + return e +} + +func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func NewSessionLinkedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, + sessionID, + userID string, + authTime time.Time, + authMethods []domain.UserAuthMethodType, +) *SessionLinkedEvent { + return &SessionLinkedEvent{ + BaseEvent: eventstore.NewBaseEventForPush( + ctx, + aggregate, + SessionLinkedType, + ), + SessionID: sessionID, + UserID: userID, + AuthTime: authTime, + AuthMethods: authMethods, + } +} + +func (e *SessionLinkedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} + +type FailedEvent struct { + *eventstore.BaseEvent `json:"-"` + + Reason domain.SAMLErrorReason `json:"reason,omitempty"` +} + +func (e *FailedEvent) Payload() interface{} { + return e +} + +func (e *FailedEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func NewFailedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + reason domain.SAMLErrorReason, +) *FailedEvent { + return &FailedEvent{ + BaseEvent: eventstore.NewBaseEventForPush( + ctx, + aggregate, + FailedType, + ), + Reason: reason, + } +} + +func (e *FailedEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} + +type SucceededEvent struct { + *eventstore.BaseEvent `json:"-"` +} + +func (e *SucceededEvent) Payload() interface{} { + return nil +} + +func (e *SucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint { + return nil +} + +func NewSucceededEvent(ctx context.Context, + aggregate *eventstore.Aggregate, +) *SucceededEvent { + return &SucceededEvent{ + BaseEvent: eventstore.NewBaseEventForPush( + ctx, + aggregate, + SucceededType, + ), + } +} + +func (e *SucceededEvent) SetBaseEvent(event *eventstore.BaseEvent) { + e.BaseEvent = event +} diff --git a/proto/zitadel/oidc/v2/oidc_service.proto b/proto/zitadel/oidc/v2/oidc_service.proto index 85044e9570..1d16f662f0 100644 --- a/proto/zitadel/oidc/v2/oidc_service.proto +++ b/proto/zitadel/oidc/v2/oidc_service.proto @@ -169,8 +169,8 @@ message CreateCallbackRequest { string auth_request_id = 1 [ (validate.rules).string = {min_len: 1, max_len: 200}, (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { - description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set."; - ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError"; + description: "ID of the SAML Request."; + example: "\"163840776835432705\""; } ]; diff --git a/proto/zitadel/saml/v2/authorization.proto b/proto/zitadel/saml/v2/authorization.proto new file mode 100644 index 0000000000..9bdada5ad0 --- /dev/null +++ b/proto/zitadel/saml/v2/authorization.proto @@ -0,0 +1,84 @@ +syntax = "proto3"; + +package zitadel.saml.v2; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml"; + +message SAMLRequest{ + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = { + external_docs: { + url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest"; + description: "Find out more about SAML Auth Request parameters"; + } + }; + + string id = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "ID of the authorization request"; + } + ]; + + google.protobuf.Timestamp creation_date = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Time when the auth request was created"; + } + ]; + + string issuer = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "SAML entity ID of the application that created the auth request"; + } + ]; + + string assertion_consumer_url = 4 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Base URI that points back to the application"; + } + ]; + + string relay_state = 5 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "RelayState provided by the application for the request"; + } + ]; + + string binding = 6 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Binding used by the application for the request"; + } + ]; +} + +message AuthorizationError { + ErrorReason error = 1; + optional string error_description = 2; + optional string error_uri = 3; +} + +enum ErrorReason { + ERROR_REASON_UNSPECIFIED = 0; + + // Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1 + ERROR_REASON_INVALID_REQUEST = 1; + ERROR_REASON_UNAUTHORIZED_CLIENT = 2; + ERROR_REASON_ACCESS_DENIED = 3; + ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4; + ERROR_REASON_INVALID_SCOPE = 5; + ERROR_REASON_SERVER_ERROR = 6; + ERROR_REASON_TEMPORARY_UNAVAILABLE = 7; + + // Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError + ERROR_REASON_INTERACTION_REQUIRED = 8; + ERROR_REASON_LOGIN_REQUIRED = 9; + ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10; + ERROR_REASON_CONSENT_REQUIRED = 11; + ERROR_REASON_INVALID_REQUEST_URI = 12; + ERROR_REASON_INVALID_REQUEST_OBJECT = 13; + ERROR_REASON_REQUEST_NOT_SUPPORTED = 14; + ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15; + ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16; +} \ No newline at end of file diff --git a/proto/zitadel/saml/v2/saml_service.proto b/proto/zitadel/saml/v2/saml_service.proto new file mode 100644 index 0000000000..80a29c64b7 --- /dev/null +++ b/proto/zitadel/saml/v2/saml_service.proto @@ -0,0 +1,142 @@ +syntax = "proto3"; + +package zitadel.saml.v2; + +import "zitadel/object/v2/object.proto"; +import "zitadel/protoc_gen_zitadel/v2/options.proto"; +import "zitadel/saml/v2/authorization.proto"; +import "google/api/annotations.proto"; +import "google/api/field_behavior.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; +import "validate/validate.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml"; + +option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = { + info: { + title: "SAML Service"; + version: "2.0"; + description: "Get SAML Auth Request details and create callback URLs."; + contact:{ + name: "ZITADEL" + url: "https://zitadel.com" + email: "hi@zitadel.com" + } + license: { + name: "Apache 2.0", + url: "https://github.com/zitadel/zitadel/blob/main/LICENSE"; + }; + }; + schemes: HTTPS; + schemes: HTTP; + + consumes: "application/json"; + consumes: "application/grpc"; + + produces: "application/json"; + produces: "application/grpc"; + + consumes: "application/grpc-web+proto"; + produces: "application/grpc-web+proto"; + + host: "$CUSTOM-DOMAIN"; + base_path: "/"; + + external_docs: { + description: "Detailed information about ZITADEL", + url: "https://zitadel.com/docs" + } + security_definitions: { + security: { + key: "OAuth2"; + value: { + type: TYPE_OAUTH2; + flow: FLOW_ACCESS_CODE; + authorization_url: "$CUSTOM-DOMAIN/oauth/v2/authorize"; + token_url: "$CUSTOM-DOMAIN/oauth/v2/token"; + scopes: { + scope: { + key: "openid"; + value: "openid"; + } + scope: { + key: "urn:zitadel:iam:org:project:id:zitadel:aud"; + value: "urn:zitadel:iam:org:project:id:zitadel:aud"; + } + } + } + } + } + security: { + security_requirement: { + key: "OAuth2"; + value: { + scope: "openid"; + scope: "urn:zitadel:iam:org:project:id:zitadel:aud"; + } + } + } + responses: { + key: "403"; + value: { + description: "Returned when the user does not have permission to access the resource."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } + responses: { + key: "404"; + value: { + description: "Returned when the resource does not exist."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } +}; + +service SAMLService { + rpc GetAuthRequest (GetSAMLRequestRequest) returns (GetSAMLRequestResponse) { + option (google.api.http) = { + get: "/v2/saml/saml_requests/{saml_request_id}" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "authenticated" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "Get SAML Request details"; + description: "Get SAML Request details by ID. Returns details that are parsed from the application's SAML Request." + responses: { + key: "200" + value: { + description: "OK"; + } + }; + }; + } +} + +message GetSAMLRequestRequest { + string saml_request_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "ID of the SAML Request, as obtained from the redirect URL."; + example: "\"163840776835432705\""; + } + ]; +} + +message GetSAMLRequestResponse { + SAMLRequest saml_request = 1; +} diff --git a/proto/zitadel/settings/v2/login_settings.proto b/proto/zitadel/settings/v2/login_settings.proto index 9fdbb45993..e578ce5f20 100644 --- a/proto/zitadel/settings/v2/login_settings.proto +++ b/proto/zitadel/settings/v2/login_settings.proto @@ -134,6 +134,7 @@ message IdentityProvider { string id = 1; string name = 2; IdentityProviderType type = 3; + bool is_linking_allowed = 4; } enum IdentityProviderType {