feat: add saml request to link to sessions

This commit is contained in:
Stefan Benz 2024-12-03 10:30:35 +01:00
parent 26e936aec3
commit 905da945ff
No known key found for this signature in database
GPG Key ID: 071AA751ED4F9D31
16 changed files with 1763 additions and 2 deletions

View File

@ -0,0 +1,203 @@
package oidc
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/op"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
)
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
if err != nil {
logging.WithError(err).Error("query authRequest by ID")
return nil, err
}
return &oidc_pb.GetAuthRequestResponse{
AuthRequest: authRequestToPb(authRequest),
}, nil
}
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
pba := &oidc_pb.AuthRequest{
Id: a.ID,
CreationDate: timestamppb.New(a.CreationDate),
ClientId: a.ClientID,
Scope: a.Scope,
RedirectUri: a.RedirectURI,
Prompt: promptsToPb(a.Prompt),
UiLocales: a.UiLocales,
LoginHint: a.LoginHint,
HintUserId: a.HintUserID,
}
if a.MaxAge != nil {
pba.MaxAge = durationpb.New(*a.MaxAge)
}
return pba
}
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
out := make([]oidc_pb.Prompt, len(promps))
for i, p := range promps {
out[i] = promptToPb(p)
}
return out
}
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
switch p {
case domain.PromptUnspecified:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
case domain.PromptNone:
return oidc_pb.Prompt_PROMPT_NONE
case domain.PromptLogin:
return oidc_pb.Prompt_PROMPT_LOGIN
case domain.PromptConsent:
return oidc_pb.Prompt_PROMPT_CONSENT
case domain.PromptSelectAccount:
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
case domain.PromptCreate:
return oidc_pb.Prompt_PROMPT_CREATE
default:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
}
}
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
switch v := req.GetCallbackKind().(type) {
case *oidc_pb.CreateCallbackRequest_Error:
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
case *oidc_pb.CreateCallbackRequest_Session:
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
default:
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
}
}
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider())
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
var callback string
if aar.ResponseType == domain.OIDCResponseTypeCode {
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
} else {
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
}
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
switch errorReason {
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return domain.OIDCErrorReasonUnspecified
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return domain.OIDCErrorReasonInvalidRequest
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return domain.OIDCErrorReasonUnauthorizedClient
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return domain.OIDCErrorReasonAccessDenied
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return domain.OIDCErrorReasonUnsupportedResponseType
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return domain.OIDCErrorReasonInvalidScope
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
return domain.OIDCErrorReasonServerError
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return domain.OIDCErrorReasonTemporaryUnavailable
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return domain.OIDCErrorReasonInteractionRequired
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return domain.OIDCErrorReasonLoginRequired
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return domain.OIDCErrorReasonAccountSelectionRequired
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return domain.OIDCErrorReasonConsentRequired
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return domain.OIDCErrorReasonInvalidRequestURI
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return domain.OIDCErrorReasonInvalidRequestObject
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestNotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestURINotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return domain.OIDCErrorReasonRegistrationNotSupported
default:
return domain.OIDCErrorReasonUnspecified
}
}
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
switch reason {
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return "invalid_request"
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return "unauthorized_client"
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return "access_denied"
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return "unsupported_response_type"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return "invalid_scope"
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return "temporarily_unavailable"
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return "interaction_required"
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return "login_required"
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return "account_selection_required"
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return "consent_required"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return "invalid_request_uri"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return "invalid_request_object"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return "request_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return "request_uri_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return "registration_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
fallthrough
default:
return "server_error"
}
}

View File

@ -0,0 +1,59 @@
package oidc
import (
"google.golang.org/grpc"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/query"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
)
var _ saml_pb.SAMLServiceServer = (*Server)(nil)
type Server struct {
saml_pb.UnimplementedSAMLServiceServer
command *command.Commands
query *query.Queries
op *oidc.Server
externalSecure bool
}
type Config struct{}
func CreateServer(
command *command.Commands,
query *query.Queries,
op *oidc.Server,
externalSecure bool,
) *Server {
return &Server{
command: command,
query: query,
op: op,
externalSecure: externalSecure,
}
}
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
saml_pb.RegisterSAMLServiceServer(grpcServer, s)
}
func (s *Server) AppName() string {
return saml_pb.SAMLService_ServiceDesc.ServiceName
}
func (s *Server) MethodPrefix() string {
return saml_pb.SAMLService_ServiceDesc.ServiceName
}
func (s *Server) AuthMethods() authz.MethodMapping {
return saml_pb.SAMLService_AuthMethods
}
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
return saml_pb.RegisterSAMLServiceHandler
}

View File

@ -0,0 +1,60 @@
package saml
import (
"github.com/zitadel/saml/pkg/provider/models"
"github.com/zitadel/zitadel/internal/command"
)
var _ models.AuthRequestInt = &AuthRequestV2{}
type AuthRequestV2 struct {
*command.CurrentSAMLRequest
}
func (a *AuthRequestV2) GetApplicationID() string {
return a.ApplicationID
}
func (a *AuthRequestV2) GetID() string {
return a.ID
}
func (a *AuthRequestV2) GetRelayState() string {
return a.RelayState
}
func (a *AuthRequestV2) GetAccessConsumerServiceURL() string {
return a.ACSURL
}
func (a *AuthRequestV2) GetNameID() string {
return a.UserID
}
func (a *AuthRequestV2) GetAuthRequestID() string {
return a.RequestID
}
func (a *AuthRequestV2) GetBindingType() string {
return a.Binding
}
func (a *AuthRequestV2) GetIssuer() string {
return a.Issuer
}
func (a *AuthRequestV2) GetIssuerName() string {
return a.IssuerName
}
func (a *AuthRequestV2) GetDestination() string {
return a.Destination
}
func (a *AuthRequestV2) GetCode() string {
return ""
}
func (a *AuthRequestV2) GetUserID() string {
return a.UserID
}
func (a *AuthRequestV2) GetUserName() string {
return ""
}
func (a *AuthRequestV2) Done() bool {
return a.UserID != "" && a.SessionID != ""
}

View File

@ -3,6 +3,7 @@ package saml
import (
"context"
"encoding/json"
"strings"
"time"
"github.com/dop251/goja"
@ -16,6 +17,7 @@ import (
"github.com/zitadel/zitadel/internal/actions"
"github.com/zitadel/zitadel/internal/actions/object"
"github.com/zitadel/zitadel/internal/activity"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/auth/repository"
"github.com/zitadel/zitadel/internal/command"
@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{}
var _ provider.AuthStorage = &Storage{}
var _ provider.UserStorage = &Storage{}
const (
LoginClientHeader = "x-zitadel-login-client"
)
type Storage struct {
certChan <-chan interface{}
defaultCertificateLifetime time.Duration
@ -95,6 +101,47 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
}
resp, err := p.repo.CreateAuthRequest(ctx, CreateAuthRequestToBusiness(ctx, req, acsUrl, protocolBinding, applicationID, relayState, userAgentID))
if err != nil {
return nil, err
}
return AuthRequestFromBusiness(resp)
}
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (models.AuthRequestInt, error) {
samlRequest := &command.SAMLRequest{
ApplicationID: applicationID,
ACSURL: acsUrl,
RelayState: relayState,
RequestID: req.Id,
Binding: protocolBinding,
Issuer: req.Issuer.Text,
IssuerName: req.Issuer.SPProvidedID,
Destination: req.Destination,
}
aar, err := p.command.AddSAMLRequest(ctx, samlRequest)
if err != nil {
return nil, err
}
return &AuthRequestV2{aar}, nil
}
func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
@ -113,6 +160,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if strings.HasPrefix(id, command.IDPrefixV2) {
req, err := p.command.GetCurrentSAMLRequest(ctx, id)
if err != nil {
return nil, err
}
return &AuthRequestV2{req}, nil
}
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
if !ok {
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")

View File

@ -0,0 +1,162 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
type SAMLRequest struct {
ID string
LoginClient string
ApplicationID string
EntityID string
ACSURL string
RelayState string
RequestID string
Binding string
Issuer string
IssuerName string
Destination string
}
type CurrentSAMLRequest struct {
*SAMLRequest
SessionID string
UserID string
AuthMethods []domain.UserAuthMethodType
AuthTime time.Time
}
func (c *Commands) AddSAMLRequest(ctx context.Context, samlRequest *SAMLRequest) (_ *CurrentSAMLRequest, err error) {
id, err := c.idGenerator.Next()
if err != nil {
return nil, err
}
samlRequest.ID = IDPrefixV2 + id
writeModel, err := c.getSAMLRequestWriteModel(ctx, samlRequest.ID)
if err != nil {
return nil, err
}
if writeModel.SAMLRequestState != domain.SAMLRequestStateUnspecified {
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.SAMLRequest.AlreadyExisting")
}
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewAddedEvent(
ctx,
&authrequest.NewAggregate(samlRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate,
samlRequest.LoginClient,
samlRequest.ApplicationID,
samlRequest.ACSURL,
samlRequest.RelayState,
samlRequest.RequestID,
samlRequest.Binding,
samlRequest.Issuer,
samlRequest.IssuerName,
samlRequest.Destination,
))
if err != nil {
return nil, err
}
return samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
}
func (c *Commands) LinkSessionToSAMLRequest(ctx context.Context, id, sessionID, sessionToken string) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.SAMLRequestState == domain.SAMLRequestStateUnspecified {
return nil, nil, zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.SAMLRequest.NotExisting")
}
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.SAMLRequest.AlreadyHandled")
}
sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID())
err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel)
if err != nil {
return nil, nil, err
}
if err = sessionWriteModel.CheckIsActive(); err != nil {
return nil, nil, err
}
if err := c.sessionTokenVerifier(ctx, sessionToken, sessionWriteModel.AggregateID, sessionWriteModel.TokenID); err != nil {
return nil, nil, err
}
if err := c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewSessionLinkedEvent(
ctx, &samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
sessionID,
sessionWriteModel.UserID,
sessionWriteModel.AuthenticationTime(),
sessionWriteModel.AuthMethodTypes(),
)); err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
}
func (c *Commands) FailSAMLRequest(ctx context.Context, id string, reason domain.SAMLErrorReason) (*domain.ObjectDetails, *CurrentSAMLRequest, error) {
writeModel, err := c.getSAMLRequestWriteModel(ctx, id)
if err != nil {
return nil, nil, err
}
if writeModel.SAMLRequestState != domain.SAMLRequestStateAdded {
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.SAMLRequest.AlreadyHandled")
}
err = c.pushAppendAndReduce(ctx, writeModel, samlrequest.NewFailedEvent(
ctx,
&samlrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
reason,
))
if err != nil {
return nil, nil, err
}
return writeModelToObjectDetails(&writeModel.WriteModel), samlRequestWriteModelToCurrentSAMLRequest(writeModel), nil
}
func samlRequestWriteModelToCurrentSAMLRequest(writeModel *SAMLRequestWriteModel) (_ *CurrentSAMLRequest) {
return &CurrentSAMLRequest{
SAMLRequest: &SAMLRequest{
ID: writeModel.AggregateID,
ApplicationID: writeModel.ApplicationID,
ACSURL: writeModel.ACSURL,
RelayState: writeModel.RelayState,
RequestID: writeModel.RequestID,
Binding: writeModel.Binding,
Issuer: writeModel.Issuer,
IssuerName: writeModel.IssuerName,
Destination: writeModel.Destination,
},
SessionID: writeModel.SessionID,
UserID: writeModel.UserID,
AuthMethods: writeModel.AuthMethods,
AuthTime: writeModel.AuthTime,
}
}
func (c *Commands) GetCurrentSAMLRequest(ctx context.Context, id string) (_ *CurrentSAMLRequest, err error) {
wm, err := c.getSAMLRequestWriteModel(ctx, id)
if err != nil {
return nil, err
}
return samlRequestWriteModelToCurrentSAMLRequest(wm), nil
}
func (c *Commands) getSAMLRequestWriteModel(ctx context.Context, id string) (writeModel *SAMLRequestWriteModel, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
writeModel = NewSAMLRequestWriteModel(ctx, id)
err = c.eventstore.FilterToQueryReducer(ctx, writeModel)
if err != nil {
return nil, err
}
return writeModel, nil
}

View File

@ -0,0 +1,94 @@
package command
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/zerrors"
)
type SAMLRequestWriteModel struct {
eventstore.WriteModel
aggregate *eventstore.Aggregate
ApplicationID string
ACSURL string
RelayState string
RequestID string
Binding string
Issuer string
IssuerName string
Destination string
SessionID string
UserID string
AuthTime time.Time
AuthMethods []domain.UserAuthMethodType
SAMLRequestState domain.SAMLRequestState
}
func NewSAMLRequestWriteModel(ctx context.Context, id string) *SAMLRequestWriteModel {
return &SAMLRequestWriteModel{
WriteModel: eventstore.WriteModel{
AggregateID: id,
},
aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate,
}
}
func (m *SAMLRequestWriteModel) Reduce() error {
for _, event := range m.Events {
switch e := event.(type) {
case *samlrequest.AddedEvent:
m.ApplicationID = e.ApplicationID
m.ACSURL = e.ACSURL
m.RelayState = e.RelayState
m.RequestID = e.RequestID
m.Binding = e.Binding
m.Issuer = e.Issuer
m.IssuerName = e.IssuerName
m.Destination = e.Destination
m.SAMLRequestState = domain.SAMLRequestStateAdded
case *samlrequest.SessionLinkedEvent:
m.SessionID = e.SessionID
m.UserID = e.UserID
m.AuthTime = e.AuthTime
m.AuthMethods = e.AuthMethods
case *samlrequest.FailedEvent:
m.SAMLRequestState = domain.SAMLRequestStateFailed
case *samlrequest.SucceededEvent:
m.SAMLRequestState = domain.SAMLRequestStateSucceeded
}
}
return m.WriteModel.Reduce()
}
func (m *SAMLRequestWriteModel) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(samlrequest.AggregateType).
AggregateIDs(m.AggregateID).
Builder()
}
// CheckAuthenticated checks that the auth request exists, a session must have been linked
func (m *SAMLRequestWriteModel) CheckAuthenticated() error {
if m.SessionID == "" {
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.SAMLRequest.NotAuthenticated")
}
// in case of OIDC Code Flow, the code must have been exchanged
if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged {
return nil
}
// in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet
if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) &&
m.AuthRequestState == domain.AuthRequestStateAdded {
return nil
}
return zerrors.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.SAMLRequest.NotAuthenticated")
}

View File

@ -0,0 +1,668 @@
package command
import (
"context"
"net"
"net/http"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/id/mock"
"github.com/zitadel/zitadel/internal/repository/authrequest"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestCommands_AddSAMLRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore func(t *testing.T) *eventstore.Eventstore
idGenerator id.Generator
}
type args struct {
ctx context.Context
request *SAMLRequest
}
tests := []struct {
name string
fields fields
args args
want *CurrentSAMLRequest
wantErr error
}{
{
"already exists error",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &SAMLRequest{},
},
nil,
zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"),
},
{
"added",
fields{
eventstore: expectEventstore(
expectFilter(),
expectPush(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"),
},
args{
ctx: mockCtx,
request: &SAMLRequest{
LoginClient: "login",
ApplicationID: "application",
ACSURL: "acs",
RelayState: "relaystate",
RequestID: "request",
Binding: "binding",
Issuer: "issuer",
IssuerName: "name",
Destination: "destination",
},
},
&CurrentSAMLRequest{
SAMLRequest: &SAMLRequest{
ID: "V2_id",
LoginClient: "login",
ApplicationID: "application",
ACSURL: "acs",
RelayState: "relaystate",
RequestID: "request",
Binding: "binding",
Issuer: "issuer",
IssuerName: "name",
Destination: "destination",
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator,
}
got, err := c.AddSAMLRequest(tt.args.ctx, tt.args.request)
require.ErrorIs(t, tt.wantErr, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestCommands_LinkSessionToSAMLRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore func(t *testing.T) *eventstore.Eventstore
tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error)
}
type args struct {
ctx context.Context
id string
sessionID string
sessionToken string
checkLoginClient bool
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentSAMLRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not found",
fields{
eventstore: expectEventstore(
expectFilter(),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: zerrors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"),
},
},
{
"authRequest not existing",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
eventFromEventPusher(
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate,
domain.OIDCErrorReasonUnspecified),
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: mockCtx,
id: "id",
sessionID: "sessionID",
},
res{
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"wrong login client",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"),
id: "id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"),
},
},
{
"session not existing",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
expectFilter(),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
},
res{
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
},
},
{
"session expired",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow.Add(-5*time.Minute), &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow.Add(-5*time.Minute)),
),
eventFromEventPusher(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
2*time.Minute),
),
),
),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
},
res{
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Hkl3d", "Errors.Session.Expired"),
},
},
{
"invalid session token",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
),
),
tokenVerifier: newMockTokenVerifierInvalid(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "invalid",
},
res{
wantErr: zerrors.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"),
},
},
{
"linked",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
samlrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"login",
"application",
"acs",
"relaystate",
"request",
"binding",
"issuer",
"name",
"destination",
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow),
),
eventFromEventPusherWithCreationDateNow(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
2*time.Minute),
),
),
expectPush(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: mockCtx,
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentSAMLRequest{
SAMLRequest: &SAMLRequest{
ID: "V2_id",
ApplicationID: "application",
ACSURL: "acs",
RelayState: "relaystate",
RequestID: "request",
Binding: "binding",
Issuer: "issuer",
IssuerName: "name",
Destination: "destination",
},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
{
"linked with login client check",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
domain.OIDCResponseModeQuery,
nil,
nil,
nil,
nil,
nil,
nil,
true,
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(mockCtx,
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
)),
eventFromEventPusher(
session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
testNow),
),
eventFromEventPusherWithCreationDateNow(
session.NewLifetimeSetEvent(mockCtx, &session.NewAggregate("sessionID", "instance1").Aggregate,
2*time.Minute),
),
),
expectPush(
authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
),
tokenVerifier: newMockTokenVerifierValid(),
},
args{
ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"),
id: "V2_id",
sessionID: "sessionID",
sessionToken: "token",
checkLoginClient: true,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentSAMLRequest{
SAMLRequest: &SAMLRequest{
ID: "V2_id",
ApplicationID: "application",
ACSURL: "acs",
RelayState: "relaystate",
RequestID: "request",
Binding: "binding",
Issuer: "issuer",
IssuerName: "name",
Destination: "destination",
},
SessionID: "sessionID",
UserID: "userID",
AuthMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
sessionTokenVerifier: tt.fields.tokenVerifier,
}
details, got, err := c.LinkSessionToSAMLRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken)
require.ErrorIs(t, err, tt.res.wantErr)
assertObjectDetails(t, tt.res.details, details)
if err == nil {
assert.WithinRange(t, got.AuthTime, testNow, testNow)
got.AuthTime = time.Time{}
}
assert.Equal(t, tt.res.authReq, got)
})
}
}
func TestCommands_FailSAMLRequest(t *testing.T) {
mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient")
type fields struct {
eventstore func(t *testing.T) *eventstore.Eventstore
}
type args struct {
ctx context.Context
id string
reason domain.OIDCErrorReason
}
type res struct {
details *domain.ObjectDetails
authReq *CurrentAuthRequest
wantErr error
}
tests := []struct {
name string
fields fields
args args
res res
}{
{
"authRequest not existing",
fields{
eventstore: expectEventstore(
expectFilter(),
),
},
args{
ctx: mockCtx,
id: "foo",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
wantErr: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"),
},
},
{
"failed",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
domain.OIDCResponseModeQuery,
nil,
nil,
nil,
nil,
nil,
nil,
true,
),
),
),
expectPush(
authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate,
domain.OIDCErrorReasonLoginRequired),
),
),
},
args{
ctx: mockCtx,
id: "V2_id",
reason: domain.OIDCErrorReasonLoginRequired,
},
res{
details: &domain.ObjectDetails{ResourceOwner: "instanceID"},
authReq: &CurrentAuthRequest{
AuthRequest: &AuthRequest{
ID: "V2_id",
LoginClient: "loginClient",
ClientID: "clientID",
RedirectURI: "redirectURI",
State: "state",
Nonce: "nonce",
Scope: []string{"openid"},
Audience: []string{"audience"},
ResponseType: domain.OIDCResponseTypeCode,
ResponseMode: domain.OIDCResponseModeQuery,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Commands{
eventstore: tt.fields.eventstore(t),
}
details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason)
require.ErrorIs(t, err, tt.res.wantErr)
assertObjectDetails(t, tt.res.details, details)
assert.Equal(t, tt.res.authReq, got)
})
}
}

View File

@ -0,0 +1,11 @@
package domain
type SAMLErrorReason int32
const (
SAMLErrorReasonUnspecified SAMLErrorReason = iota
)
func SAMLErrorReasonFromError(err error) SAMLErrorReason {
return SAMLErrorReasonUnspecified
}

View File

@ -0,0 +1,10 @@
package domain
type SAMLRequestState int
const (
SAMLRequestStateUnspecified SAMLRequestState = iota
SAMLRequestStateAdded
SAMLRequestStateFailed
SAMLRequestStateSucceeded
)

View File

@ -0,0 +1,26 @@
package samlrequest
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
AggregateType = "saml_request"
AggregateVersion = "v1"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, instanceID string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ResourceOwner: instanceID,
InstanceID: instanceID,
},
}
}

View File

@ -0,0 +1,10 @@
package samlrequest
import "github.com/zitadel/zitadel/internal/eventstore"
func init() {
eventstore.RegisterFilterEventMapper(AggregateType, AddedType, eventstore.GenericEventMapper[AddedEvent])
eventstore.RegisterFilterEventMapper(AggregateType, SessionLinkedType, eventstore.GenericEventMapper[SessionLinkedEvent])
eventstore.RegisterFilterEventMapper(AggregateType, FailedType, eventstore.GenericEventMapper[FailedEvent])
eventstore.RegisterFilterEventMapper(AggregateType, SucceededType, eventstore.GenericEventMapper[SucceededEvent])
}

View File

@ -0,0 +1,175 @@
package samlrequest
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
samlRequestEventPrefix = "saml_request."
AddedType = samlRequestEventPrefix + "added"
FailedType = samlRequestEventPrefix + "failed"
SessionLinkedType = samlRequestEventPrefix + "session.linked"
SucceededType = samlRequestEventPrefix + "succeeded"
)
type AddedEvent struct {
*eventstore.BaseEvent `json:"-"`
LoginClient string `json:"loginClient,omitempty"`
ApplicationID string `json:"application_id,omitempty"`
ACSURL string `json:"acs_url,omitempty"`
RelayState string `json:"relay_state,omitempty"`
RequestID string `json:"request_id,omitempty"`
Binding string `json:"binding,omitempty"`
Issuer string `json:"issuer,omitempty"`
IssuerName string `json:"issuer_name,omitempty"`
Destination string `json:"destination,omitempty"`
}
func (e *AddedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = event
}
func (e *AddedEvent) Payload() interface{} {
return e
}
func (e *AddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func NewAddedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
loginClient,
applicationID string,
acsURL string,
relayState string,
requestID string,
binding string,
issuer string,
issuerName string,
destination string,
) *AddedEvent {
return &AddedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
aggregate,
AddedType,
),
LoginClient: loginClient,
ApplicationID: applicationID,
ACSURL: acsURL,
RelayState: relayState,
RequestID: requestID,
Binding: binding,
Issuer: issuer,
IssuerName: issuerName,
Destination: destination,
}
}
type SessionLinkedEvent struct {
*eventstore.BaseEvent `json:"-"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
AuthTime time.Time `json:"auth_time"`
AuthMethods []domain.UserAuthMethodType `json:"auth_methods"`
}
func (e *SessionLinkedEvent) Payload() interface{} {
return e
}
func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func NewSessionLinkedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
sessionID,
userID string,
authTime time.Time,
authMethods []domain.UserAuthMethodType,
) *SessionLinkedEvent {
return &SessionLinkedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
aggregate,
SessionLinkedType,
),
SessionID: sessionID,
UserID: userID,
AuthTime: authTime,
AuthMethods: authMethods,
}
}
func (e *SessionLinkedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = event
}
type FailedEvent struct {
*eventstore.BaseEvent `json:"-"`
Reason domain.SAMLErrorReason `json:"reason,omitempty"`
}
func (e *FailedEvent) Payload() interface{} {
return e
}
func (e *FailedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func NewFailedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
reason domain.SAMLErrorReason,
) *FailedEvent {
return &FailedEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
aggregate,
FailedType,
),
Reason: reason,
}
}
func (e *FailedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = event
}
type SucceededEvent struct {
*eventstore.BaseEvent `json:"-"`
}
func (e *SucceededEvent) Payload() interface{} {
return nil
}
func (e *SucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func NewSucceededEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *SucceededEvent {
return &SucceededEvent{
BaseEvent: eventstore.NewBaseEventForPush(
ctx,
aggregate,
SucceededType,
),
}
}
func (e *SucceededEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = event
}

View File

@ -169,8 +169,8 @@ message CreateCallbackRequest {
string auth_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError";
description: "ID of the SAML Request.";
example: "\"163840776835432705\"";
}
];

View File

@ -0,0 +1,84 @@
syntax = "proto3";
package zitadel.saml.v2;
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml";
message SAMLRequest{
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = {
external_docs: {
url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest";
description: "Find out more about SAML Auth Request parameters";
}
};
string id = 1 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "ID of the authorization request";
}
];
google.protobuf.Timestamp creation_date = 2 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Time when the auth request was created";
}
];
string issuer = 3 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "SAML entity ID of the application that created the auth request";
}
];
string assertion_consumer_url = 4 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Base URI that points back to the application";
}
];
string relay_state = 5 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "RelayState provided by the application for the request";
}
];
string binding = 6 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Binding used by the application for the request";
}
];
}
message AuthorizationError {
ErrorReason error = 1;
optional string error_description = 2;
optional string error_uri = 3;
}
enum ErrorReason {
ERROR_REASON_UNSPECIFIED = 0;
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
ERROR_REASON_INVALID_REQUEST = 1;
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
ERROR_REASON_ACCESS_DENIED = 3;
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
ERROR_REASON_INVALID_SCOPE = 5;
ERROR_REASON_SERVER_ERROR = 6;
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
ERROR_REASON_INTERACTION_REQUIRED = 8;
ERROR_REASON_LOGIN_REQUIRED = 9;
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
ERROR_REASON_CONSENT_REQUIRED = 11;
ERROR_REASON_INVALID_REQUEST_URI = 12;
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
}

View File

@ -0,0 +1,142 @@
syntax = "proto3";
package zitadel.saml.v2;
import "zitadel/object/v2/object.proto";
import "zitadel/protoc_gen_zitadel/v2/options.proto";
import "zitadel/saml/v2/authorization.proto";
import "google/api/annotations.proto";
import "google/api/field_behavior.proto";
import "protoc-gen-openapiv2/options/annotations.proto";
import "validate/validate.proto";
option go_package = "github.com/zitadel/zitadel/pkg/grpc/saml/v2;saml";
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = {
info: {
title: "SAML Service";
version: "2.0";
description: "Get SAML Auth Request details and create callback URLs.";
contact:{
name: "ZITADEL"
url: "https://zitadel.com"
email: "hi@zitadel.com"
}
license: {
name: "Apache 2.0",
url: "https://github.com/zitadel/zitadel/blob/main/LICENSE";
};
};
schemes: HTTPS;
schemes: HTTP;
consumes: "application/json";
consumes: "application/grpc";
produces: "application/json";
produces: "application/grpc";
consumes: "application/grpc-web+proto";
produces: "application/grpc-web+proto";
host: "$CUSTOM-DOMAIN";
base_path: "/";
external_docs: {
description: "Detailed information about ZITADEL",
url: "https://zitadel.com/docs"
}
security_definitions: {
security: {
key: "OAuth2";
value: {
type: TYPE_OAUTH2;
flow: FLOW_ACCESS_CODE;
authorization_url: "$CUSTOM-DOMAIN/oauth/v2/authorize";
token_url: "$CUSTOM-DOMAIN/oauth/v2/token";
scopes: {
scope: {
key: "openid";
value: "openid";
}
scope: {
key: "urn:zitadel:iam:org:project:id:zitadel:aud";
value: "urn:zitadel:iam:org:project:id:zitadel:aud";
}
}
}
}
}
security: {
security_requirement: {
key: "OAuth2";
value: {
scope: "openid";
scope: "urn:zitadel:iam:org:project:id:zitadel:aud";
}
}
}
responses: {
key: "403";
value: {
description: "Returned when the user does not have permission to access the resource.";
schema: {
json_schema: {
ref: "#/definitions/rpcStatus";
}
}
}
}
responses: {
key: "404";
value: {
description: "Returned when the resource does not exist.";
schema: {
json_schema: {
ref: "#/definitions/rpcStatus";
}
}
}
}
};
service SAMLService {
rpc GetAuthRequest (GetSAMLRequestRequest) returns (GetSAMLRequestResponse) {
option (google.api.http) = {
get: "/v2/saml/saml_requests/{saml_request_id}"
};
option (zitadel.protoc_gen_zitadel.v2.options) = {
auth_option: {
permission: "authenticated"
}
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "Get SAML Request details";
description: "Get SAML Request details by ID. Returns details that are parsed from the application's SAML Request."
responses: {
key: "200"
value: {
description: "OK";
}
};
};
}
}
message GetSAMLRequestRequest {
string saml_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "ID of the SAML Request, as obtained from the redirect URL.";
example: "\"163840776835432705\"";
}
];
}
message GetSAMLRequestResponse {
SAMLRequest saml_request = 1;
}

View File

@ -134,6 +134,7 @@ message IdentityProvider {
string id = 1;
string name = 2;
IdentityProviderType type = 3;
bool is_linking_allowed = 4;
}
enum IdentityProviderType {