diff --git a/internal/api/grpc/saml/integration/saml_test.go b/internal/api/grpc/saml/integration/saml_test.go new file mode 100644 index 0000000000..5e33ea9673 --- /dev/null +++ b/internal/api/grpc/saml/integration/saml_test.go @@ -0,0 +1,250 @@ +//go:build integration + +package saml_test + +import ( + "context" + "net/url" + "os" + "regexp" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/object/v2" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" + saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" + "github.com/zitadel/zitadel/pkg/grpc/session/v2" +) + +var ( + CTX context.Context + Instance *integration.Instance + Client saml_pb.SAMLServiceClient +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + Instance = integration.NewInstance(ctx) + Client = Instance.Client.SAMLv2 + + CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner) + return m.Run() + }()) +} + +func TestServer_GetAuthRequest(t *testing.T) { + entityID := "https://sp.example.com" + + project, err := Instance.CreateProject(CTX) + require.NoError(t, err) + client, err := Instance.CreateSAMLClient(CTX, project.GetId(), entityID, entityID+"/saml/v2/sso", entityID+"/saml/v2/slo") + require.NoError(t, err) + authRequestID, err := Instance.CreateSAMLAuthRequest(CTX, "", Instance.Users[integration.UserTypeOrgOwner].ID, "acs", "relaystate") + require.NoError(t, err) + now := time.Now() + + tests := []struct { + name string + AuthRequestID string + want *oidc_pb.GetAuthRequestResponse + wantErr bool + }{ + { + name: "Not found", + AuthRequestID: "123", + wantErr: true, + }, + { + name: "success", + AuthRequestID: authRequestID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Client.GetAuthRequest(CTX, &saml_pb.GetSAMLRequestRequest{ + SamlRequestId: tt.AuthRequestID, + }) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + authRequest := got.GetSamlRequest() + assert.NotNil(t, authRequest) + assert.Equal(t, authRequestID, authRequest.GetId()) + assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second)) + }) + } +} + +func TestServer_CreateCallback(t *testing.T) { + project, err := Instance.CreateProject(CTX) + require.NoError(t, err) + client, err := Instance.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId(), false) + require.NoError(t, err) + sessionResp, err := Instance.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: Instance.Users[integration.UserTypeOrgOwner].ID, + }, + }, + }, + }) + require.NoError(t, err) + + tests := []struct { + name string + req *oidc_pb.CreateCallbackRequest + AuthError string + want *oidc_pb.CreateCallbackResponse + wantURL *url.URL + wantErr bool + }{ + { + name: "Not found", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: "123", + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + wantErr: true, + }, + { + name: "session not found", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users[integration.UserTypeOrgOwner].ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: "foo", + SessionToken: "bar", + }, + }, + }, + wantErr: true, + }, + { + name: "session token invalid", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: "bar", + }, + }, + }, + wantErr: true, + }, + { + name: "fail callback", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Error{ + Error: &oidc_pb.AuthorizationError{ + Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED, + ErrorDescription: gu.Ptr("nope"), + ErrorUri: gu.Ptr("https://example.com/docs"), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`), + Details: &object.Details{ + ChangeDate: timestamppb.Now(), + ResourceOwner: Instance.ID(), + }, + }, + wantErr: false, + }, + { + name: "code callback", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`, + Details: &object.Details{ + ChangeDate: timestamppb.Now(), + ResourceOwner: Instance.ID(), + }, + }, + wantErr: false, + }, + { + name: "implicit", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + client, err := Instance.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit) + require.NoError(t, err) + authRequestID, err := Instance.CreateOIDCAuthRequestImplicit(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`, + Details: &object.Details{ + ChangeDate: timestamppb.Now(), + ResourceOwner: Instance.ID(), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Client.CreateCallback(CTX, tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + integration.AssertDetails(t, tt.want, got) + if tt.want != nil { + assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl()) + } + }) + } +} diff --git a/internal/api/grpc/saml/oidc.go b/internal/api/grpc/saml/oidc.go deleted file mode 100644 index 826c198fad..0000000000 --- a/internal/api/grpc/saml/oidc.go +++ /dev/null @@ -1,203 +0,0 @@ -package oidc - -import ( - "context" - - "github.com/zitadel/logging" - "github.com/zitadel/oidc/v3/pkg/op" - "google.golang.org/protobuf/types/known/durationpb" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/zitadel/zitadel/internal/api/grpc/object/v2" - "github.com/zitadel/zitadel/internal/api/http" - "github.com/zitadel/zitadel/internal/api/oidc" - "github.com/zitadel/zitadel/internal/domain" - "github.com/zitadel/zitadel/internal/query" - "github.com/zitadel/zitadel/internal/zerrors" - oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2" -) - -func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) { - authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true) - if err != nil { - logging.WithError(err).Error("query authRequest by ID") - return nil, err - } - return &oidc_pb.GetAuthRequestResponse{ - AuthRequest: authRequestToPb(authRequest), - }, nil -} - -func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest { - pba := &oidc_pb.AuthRequest{ - Id: a.ID, - CreationDate: timestamppb.New(a.CreationDate), - ClientId: a.ClientID, - Scope: a.Scope, - RedirectUri: a.RedirectURI, - Prompt: promptsToPb(a.Prompt), - UiLocales: a.UiLocales, - LoginHint: a.LoginHint, - HintUserId: a.HintUserID, - } - if a.MaxAge != nil { - pba.MaxAge = durationpb.New(*a.MaxAge) - } - return pba -} - -func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt { - out := make([]oidc_pb.Prompt, len(promps)) - for i, p := range promps { - out[i] = promptToPb(p) - } - return out -} - -func promptToPb(p domain.Prompt) oidc_pb.Prompt { - switch p { - case domain.PromptUnspecified: - return oidc_pb.Prompt_PROMPT_UNSPECIFIED - case domain.PromptNone: - return oidc_pb.Prompt_PROMPT_NONE - case domain.PromptLogin: - return oidc_pb.Prompt_PROMPT_LOGIN - case domain.PromptConsent: - return oidc_pb.Prompt_PROMPT_CONSENT - case domain.PromptSelectAccount: - return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT - case domain.PromptCreate: - return oidc_pb.Prompt_PROMPT_CREATE - default: - return oidc_pb.Prompt_PROMPT_UNSPECIFIED - } -} - -func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) { - switch v := req.GetCallbackKind().(type) { - case *oidc_pb.CreateCallbackRequest_Error: - return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error) - case *oidc_pb.CreateCallbackRequest_Session: - return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session) - default: - return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) - } -} - -func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) { - details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError())) - if err != nil { - return nil, err - } - authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} - callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider()) - if err != nil { - return nil, err - } - return &oidc_pb.CreateCallbackResponse{ - Details: object.DomainToDetailsPb(details), - CallbackUrl: callback, - }, nil -} - -func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) { - details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true) - if err != nil { - return nil, err - } - authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} - ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin()) - var callback string - if aar.ResponseType == domain.OIDCResponseTypeCode { - callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider()) - } else { - callback, err = s.op.CreateTokenCallbackURL(ctx, authReq) - } - if err != nil { - return nil, err - } - return &oidc_pb.CreateCallbackResponse{ - Details: object.DomainToDetailsPb(details), - CallbackUrl: callback, - }, nil -} - -func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason { - switch errorReason { - case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED: - return domain.OIDCErrorReasonUnspecified - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: - return domain.OIDCErrorReasonInvalidRequest - case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: - return domain.OIDCErrorReasonUnauthorizedClient - case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: - return domain.OIDCErrorReasonAccessDenied - case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: - return domain.OIDCErrorReasonUnsupportedResponseType - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: - return domain.OIDCErrorReasonInvalidScope - case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: - return domain.OIDCErrorReasonServerError - case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: - return domain.OIDCErrorReasonTemporaryUnavailable - case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: - return domain.OIDCErrorReasonInteractionRequired - case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: - return domain.OIDCErrorReasonLoginRequired - case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: - return domain.OIDCErrorReasonAccountSelectionRequired - case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: - return domain.OIDCErrorReasonConsentRequired - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: - return domain.OIDCErrorReasonInvalidRequestURI - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: - return domain.OIDCErrorReasonInvalidRequestObject - case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: - return domain.OIDCErrorReasonRequestNotSupported - case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: - return domain.OIDCErrorReasonRequestURINotSupported - case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: - return domain.OIDCErrorReasonRegistrationNotSupported - default: - return domain.OIDCErrorReasonUnspecified - } -} - -func errorReasonToOIDC(reason oidc_pb.ErrorReason) string { - switch reason { - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: - return "invalid_request" - case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: - return "unauthorized_client" - case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: - return "access_denied" - case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: - return "unsupported_response_type" - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: - return "invalid_scope" - case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: - return "temporarily_unavailable" - case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: - return "interaction_required" - case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: - return "login_required" - case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: - return "account_selection_required" - case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: - return "consent_required" - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: - return "invalid_request_uri" - case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: - return "invalid_request_object" - case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: - return "request_not_supported" - case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: - return "request_uri_not_supported" - case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: - return "registration_not_supported" - case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: - fallthrough - default: - return "server_error" - } -} diff --git a/internal/api/grpc/saml/saml.go b/internal/api/grpc/saml/saml.go new file mode 100644 index 0000000000..a0e265667e --- /dev/null +++ b/internal/api/grpc/saml/saml.go @@ -0,0 +1,132 @@ +package saml + +import ( + "context" + + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v3/pkg/op" + "github.com/zitadel/saml/pkg/provider" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/api/saml" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" +) + +func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) { + authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetSamlRequestId(), true) + if err != nil { + logging.WithError(err).Error("query samlRequest by ID") + return nil, err + } + return &saml_pb.GetSAMLRequestResponse{ + SamlRequest: samlRequestToPb(authRequest), + }, nil +} + +func samlRequestToPb(a *query.AuthRequest) *saml_pb.SAMLRequest { + return &saml_pb.SAMLRequest{ + Id: a.ID, + CreationDate: timestamppb.New(a.CreationDate), + } +} + +func (s *Server) CreateCallback(ctx context.Context, req *saml_pb.CreateCallbackRequest) (*saml_pb.CreateCallbackResponse, error) { + switch v := req.GetCallbackKind().(type) { + case *saml_pb.CreateCallbackRequest_Error: + return s.failAuthRequest(ctx, req.GetSamlRequestId(), v.Error) + case *saml_pb.CreateCallbackRequest_Session: + return s.linkSessionToAuthRequest(ctx, req.GetSamlRequestId(), v.Session) + default: + return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) + } +} + +func (s *Server) failAuthRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError())) + if err != nil { + return nil, err + } + authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar} + callback, err := saml.CreateErrorCallbackURL(authReq, errorReasonToSAML(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri()) + if err != nil { + return nil, err + } + return &saml_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func (s *Server) linkSessionToAuthRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.LinkSessionToAuthRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true) + if err != nil { + return nil, err + } + authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} + ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin()) + var callback string + if aar.ResponseType == domain.OIDCResponseTypeCode { + callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider()) + } else { + callback, err = s.op.CreateTokenCallbackURL(ctx, authReq) + } + if err != nil { + return nil, err + } + return &saml_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason { + switch errorReason { + case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED: + return domain.SAMLErrorReasonUnspecified + case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH: + return domain.SAMLErrorReasonVersionMissmatch + case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED: + return domain.SAMLErrorReasonAuthNFailed + case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE: + return domain.SAMLErrorReasonInvalidAttrNameOrValue + case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY: + return domain.SAMLErrorReasonInvalidNameIDPolicy + case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED: + return domain.SAMLErrorReasonRequestDenied + case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED: + return domain.SAMLErrorReasonRequestUnsupported + case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING: + return domain.SAMLErrorReasonUnsupportedBinding + default: + return domain.SAMLErrorReasonUnspecified + } +} + +func errorReasonToSAML(reason saml_pb.ErrorReason) string { + switch reason { + case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED: + return "unspecified error" + case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH: + return provider.StatusCodeVersionMissmatch + case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED: + return provider.StatusCodeAuthNFailed + case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE: + return provider.StatusCodeInvalidAttrNameOrValue + case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY: + return provider.StatusCodeInvalidNameIDPolicy + case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED: + return provider.StatusCodeRequestDenied + case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED: + return provider.StatusCodeRequestUnsupported + case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING: + return provider.StatusCodeUnsupportedBinding + default: + return "unspecified error" + } +} diff --git a/internal/api/grpc/saml/server.go b/internal/api/grpc/saml/server.go index 446b584dd8..2528f58e6e 100644 --- a/internal/api/grpc/saml/server.go +++ b/internal/api/grpc/saml/server.go @@ -1,4 +1,4 @@ -package oidc +package saml import ( "google.golang.org/grpc" diff --git a/internal/api/saml/storage.go b/internal/api/saml/storage.go index 7aa7fc0522..2344357f16 100644 --- a/internal/api/saml/storage.go +++ b/internal/api/saml/storage.go @@ -157,6 +157,11 @@ func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequest return AuthRequestFromBusiness(resp) } +func CreateErrorCallbackURL(authReq models.AuthRequestInt, reason, description, uri string) (string, error) { + // TODO handling for errors + return "", nil +} + func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() diff --git a/internal/domain/saml_error_reason.go b/internal/domain/saml_error_reason.go index b302360a06..d3f821f39f 100644 --- a/internal/domain/saml_error_reason.go +++ b/internal/domain/saml_error_reason.go @@ -4,6 +4,13 @@ type SAMLErrorReason int32 const ( SAMLErrorReasonUnspecified SAMLErrorReason = iota + SAMLErrorReasonVersionMissmatch + SAMLErrorReasonAuthNFailed + SAMLErrorReasonInvalidAttrNameOrValue + SAMLErrorReasonInvalidNameIDPolicy + SAMLErrorReasonRequestDenied + SAMLErrorReasonRequestUnsupported + SAMLErrorReasonUnsupportedBinding ) func SAMLErrorReasonFromError(err error) SAMLErrorReason { diff --git a/internal/integration/client.go b/internal/integration/client.go index dde8822acd..ebb70ed3dd 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -34,6 +34,7 @@ import ( user_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/user/v3alpha" userschema_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/userschema/v3alpha" webkey_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/webkey/v3alpha" + saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2" "github.com/zitadel/zitadel/pkg/grpc/session/v2" session_v2beta "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" "github.com/zitadel/zitadel/pkg/grpc/settings/v2" @@ -65,6 +66,7 @@ type Client struct { WebKeyV3Alpha webkey_v3alpha.ZITADELWebKeysClient IDPv2 idp_pb.IdentityProviderServiceClient UserV3Alpha user_v3alpha.ZITADELUsersClient + SAMLv2 saml_pb.SAMLServiceClient } func newClient(ctx context.Context, target string) (*Client, error) { @@ -96,6 +98,7 @@ func newClient(ctx context.Context, target string) (*Client, error) { WebKeyV3Alpha: webkey_v3alpha.NewZITADELWebKeysClient(cc), IDPv2: idp_pb.NewIdentityProviderServiceClient(cc), UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc), + SAMLv2: saml_pb.NewSAMLServiceClient(cc), } return client, client.pollHealth(ctx) } diff --git a/internal/integration/saml.go b/internal/integration/saml.go new file mode 100644 index 0000000000..a9b8af1dc7 --- /dev/null +++ b/internal/integration/saml.go @@ -0,0 +1,94 @@ +package integration + +import ( + "context" + "fmt" + "net/url" + + "github.com/brianvoe/gofakeit/v6" + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + + oidc_internal "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/pkg/grpc/management" +) + +func (i *Instance) CreateSAMLClient(ctx context.Context, projectID, entityID, acsURL, logoutURL string) (*management.AddSAMLAppResponse, error) { + samlSPMetadata := ` + + + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + + +` + + resp, err := i.Client.Mgmt.AddSAMLApp(ctx, &management.AddSAMLAppRequest{ + ProjectId: projectID, + Name: fmt.Sprintf("app-%s", gofakeit.AppName()), + Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: []byte(samlSPMetadata)}, + }) + if err != nil { + return nil, err + } + return resp, await(func() error { + _, err := i.Client.Mgmt.GetProjectByID(ctx, &management.GetProjectByIDRequest{ + Id: projectID, + }) + if err != nil { + return err + } + _, err = i.Client.Mgmt.GetAppByID(ctx, &management.GetAppByIDRequest{ + ProjectId: projectID, + AppId: resp.GetAppId(), + }) + return err + }) +} + +func (i *Instance) CreateSAMLAuthRequest(ctx context.Context, entityID, loginClient, acsURL, relayState string) (authRequestID string, err error) { + binding := saml.HTTPRedirectBinding + entityDescriptor := new(saml.EntityDescriptor) + rootURL, err := url.Parse(entityID) + if err != nil { + return "", err + } + + m, _ := samlsp.New(samlsp.Options{ + URL: *rootURL, + /* TODO + Key: keyPair.PrivateKey.(*rsa.PrivateKey), + Certificate: keyPair.Leaf, + */ + IDPMetadata: entityDescriptor, + }) + + authReq, err := m.ServiceProvider.MakeAuthenticationRequest(acsURL, binding, m.ResponseBinding) + if err != nil { + return "", err + } + + redirectURL, err := authReq.Redirect(relayState, &m.ServiceProvider) + if err != nil { + return "", err + } + + req, err := GetRequest(redirectURL.String(), map[string]string{oidc_internal.LoginClientHeader: loginClient}) + if err != nil { + return "", fmt.Errorf("get request: %w", err) + } + + loc, err := CheckRedirect(req) + if err != nil { + return "", fmt.Errorf("check redirect: %w", err) + } + + //TODO get id from loc + return loc.String(), nil +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index a23ae72330..ea244d3b37 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -69,6 +69,7 @@ var ( DeviceAuthProjection *handler.Handler SessionProjection *handler.Handler AuthRequestProjection *handler.Handler + SamlRequestProjection *handler.Handler MilestoneProjection *handler.Handler QuotaProjection *quotaProjection LimitsProjection *handler.Handler @@ -156,6 +157,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"])) SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"])) AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"])) + SamlRequestProjection = newSamlRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["saml_requests"])) MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"])) QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"])) LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"])) @@ -287,6 +289,7 @@ func newProjectionsList() { DeviceAuthProjection, SessionProjection, AuthRequestProjection, + SamlRequestProjection, MilestoneProjection, QuotaProjection.handler, LimitsProjection, diff --git a/internal/query/projection/saml_request.go b/internal/query/projection/saml_request.go new file mode 100644 index 0000000000..610619d31c --- /dev/null +++ b/internal/query/projection/saml_request.go @@ -0,0 +1,132 @@ +package projection + +import ( + "context" + + "github.com/zitadel/zitadel/internal/eventstore" + old_handler "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/samlrequest" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + SamlRequestsProjectionTable = "projections.saml_requests" + + SamlRequestColumnID = "id" + SamlRequestColumnCreationDate = "creation_date" + SamlRequestColumnChangeDate = "change_date" + SamlRequestColumnSequence = "sequence" + SamlRequestColumnResourceOwner = "resource_owner" + SamlRequestColumnInstanceID = "instance_id" + SamlRequestColumnLoginClient = "login_client" + SamlRequestColumnIssuer = "issuer" + SamlRequestColumnACS = "acs" + SamlRequestColumnRelayState = "relay_state" + SamlRequestColumnBinding = "binding" +) + +type samlRequestProjection struct{} + +// Name implements handler.Projection. +func (*samlRequestProjection) Name() string { + return SamlRequestsProjectionTable +} + +func newSamlRequestProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(samlRequestProjection)) +} + +func (*samlRequestProjection) Init() *old_handler.Check { + return handler.NewMultiTableCheck( + handler.NewTable([]*handler.InitColumn{ + handler.NewColumn(SamlRequestColumnID, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnCreationDate, handler.ColumnTypeTimestamp), + handler.NewColumn(SamlRequestColumnChangeDate, handler.ColumnTypeTimestamp), + handler.NewColumn(SamlRequestColumnSequence, handler.ColumnTypeInt64), + handler.NewColumn(SamlRequestColumnResourceOwner, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnInstanceID, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnLoginClient, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnIssuer, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnACS, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnRelayState, handler.ColumnTypeText), + handler.NewColumn(SamlRequestColumnBinding, handler.ColumnTypeText), + }, + handler.NewPrimaryKey(SamlRequestColumnInstanceID, SamlRequestColumnID), + ), + ) +} + +func (p *samlRequestProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: samlrequest.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: samlrequest.AddedType, + Reduce: p.reduceSamlRequestAdded, + }, + { + Event: samlrequest.SucceededType, + Reduce: p.reduceSamlRequestEnded, + }, + { + Event: samlrequest.FailedType, + Reduce: p.reduceSamlRequestEnded, + }, + }, + }, + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: reduceInstanceRemovedHelper(SamlRequestColumnInstanceID), + }, + }, + }, + } +} + +func (p *samlRequestProjection) reduceSamlRequestAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*samlrequest.AddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", samlrequest.AddedType) + } + + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(SamlRequestColumnID, e.Aggregate().ID), + handler.NewCol(SamlRequestColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(SamlRequestColumnCreationDate, e.CreationDate()), + handler.NewCol(SamlRequestColumnChangeDate, e.CreationDate()), + handler.NewCol(SamlRequestColumnResourceOwner, e.Aggregate().ResourceOwner), + handler.NewCol(SamlRequestColumnSequence, e.Sequence()), + handler.NewCol(SamlRequestColumnLoginClient, e.LoginClient), + handler.NewCol(SamlRequestColumnIssuer, e.Issuer), + handler.NewCol(SamlRequestColumnACS, e.ACSURL), + handler.NewCol(SamlRequestColumnRelayState, e.RelayState), + handler.NewCol(SamlRequestColumnBinding, e.Binding), + }, + ), nil +} + +func (p *samlRequestProjection) reduceSamlRequestEnded(event eventstore.Event) (*handler.Statement, error) { + switch event.(type) { + case *samlrequest.SucceededEvent, + *samlrequest.FailedEvent: + break + default: + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{samlrequest.SucceededType, samlrequest.FailedType}) + } + + return handler.NewDeleteStatement( + event, + []handler.Condition{ + handler.NewCond(SamlRequestColumnID, event.Aggregate().ID), + handler.NewCond(SamlRequestColumnInstanceID, event.Aggregate().InstanceID), + }, + ), nil +} diff --git a/internal/query/projection/saml_request_test.go b/internal/query/projection/saml_request_test.go new file mode 100644 index 0000000000..b0fe842d03 --- /dev/null +++ b/internal/query/projection/saml_request_test.go @@ -0,0 +1,123 @@ +package projection + +import ( + "testing" + + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/samlrequest" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestSamlRequestProjection_reduces(t *testing.T) { + type args struct { + event func(t *testing.T) eventstore.Event + } + tests := []struct { + name string + args args + reduce func(event eventstore.Event) (*handler.Statement, error) + want wantReduce + }{ + { + name: "reduceSamlRequestAdded", + args: args{ + event: getEvent(testEvent( + samlrequest.AddedType, + samlrequest.AggregateType, + []byte(`{"login_client": "loginClient", "issuer": "issuer", "acs_url": "acs", "relay_state": "relayState", "binding": "binding"}`), + ), eventstore.GenericEventMapper[samlrequest.AddedEvent]), + }, + reduce: (&samlRequestProjection{}).reduceSamlRequestAdded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("saml_request"), + sequence: 15, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "INSERT INTO projections.saml_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, issuer, acs, relay_state, binding) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + anyArg{}, + anyArg{}, + "ro-id", + uint64(15), + "loginClient", + "issuer", + "acs", + "relayState", + "binding", + }, + }, + }, + }, + }, + }, + { + name: "reduceSamlRequestFailed", + args: args{ + event: getEvent(testEvent( + samlrequest.FailedType, + samlrequest.AggregateType, + []byte(`{"reason": 0}`), + ), eventstore.GenericEventMapper[samlrequest.FailedEvent]), + }, + reduce: (&samlRequestProjection{}).reduceSamlRequestEnded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("saml_request"), + sequence: 15, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + }, + }, + }, + }, + }, + }, + { + name: "reduceSamlRequestSucceeded", + args: args{ + event: getEvent(testEvent( + samlrequest.SucceededType, + samlrequest.AggregateType, + nil, + ), eventstore.GenericEventMapper[samlrequest.SucceededEvent]), + }, + reduce: (&samlRequestProjection{}).reduceSamlRequestEnded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("saml_request"), + sequence: 15, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := baseEvent(t) + got, err := tt.reduce(event) + if !zerrors.IsErrorInvalidArgument(err) { + t.Errorf("no wrong event mapping: %v, got: %v", err, got) + } + + event = tt.args.event(t) + got, err = tt.reduce(event) + assertReduce(t, got, err, SamlRequestsProjectionTable, tt.want) + }) + } +} diff --git a/internal/query/saml_request.go b/internal/query/saml_request.go new file mode 100644 index 0000000000..a0f6fdc6cd --- /dev/null +++ b/internal/query/saml_request.go @@ -0,0 +1,81 @@ +package query + +import ( + "context" + "database/sql" + _ "embed" + "errors" + "fmt" + "time" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/telemetry/tracing" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type SamlRequest struct { + ID string + CreationDate time.Time + LoginClient string + Issuer string + ACS string + RelayState string + Binding string +} + +func (a *SamlRequest) checkLoginClient(ctx context.Context) error { + if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient { + return zerrors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient") + } + return nil +} + +//go:embed saml_request_by_id.sql +var samlRequestByIDQuery string + +func (q *Queries) samlRequestByIDQuery(ctx context.Context) string { + return fmt.Sprintf(samlRequestByIDQuery, q.client.Timetravel(call.Took(ctx))) +} + +func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *SamlRequest, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if shouldTriggerBulk { + _, traceSpan := tracing.NewNamedSpan(ctx, "TriggerSamlRequestProjection") + ctx, err = projection.SamlRequestProjection.Trigger(ctx, handler.WithAwaitRunning()) + logging.OnError(err).Debug("trigger failed") + traceSpan.EndWithError(err) + } + + dst := new(SamlRequest) + err = q.client.QueryRowContext( + ctx, + func(row *sql.Row) error { + return row.Scan( + &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.Issuer, &dst.ACS, &dst.RelayState, &dst.Binding, + ) + }, + q.samlRequestByIDQuery(ctx), + id, authz.GetInstance(ctx).InstanceID(), + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-Thee9", "Errors.SamlRequest.NotExisting") + } + if err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal") + } + + if checkLoginClient { + if err = dst.checkLoginClient(ctx); err != nil { + return nil, err + } + } + + return dst, nil +} diff --git a/internal/query/saml_request_by_id.sql b/internal/query/saml_request_by_id.sql new file mode 100644 index 0000000000..ac1c60058f --- /dev/null +++ b/internal/query/saml_request_by_id.sql @@ -0,0 +1,11 @@ +select + id, + creation_date, + login_client, + issuer, + acs, + relay_state, + binding +from projections.saml_requests %s +where id = $1 and instance_id = $2 +limit 1; diff --git a/internal/query/saml_request_test.go b/internal/query/saml_request_test.go new file mode 100644 index 0000000000..5cf58369cb --- /dev/null +++ b/internal/query/saml_request_test.go @@ -0,0 +1,127 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + _ "embed" + "fmt" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/zerrors" +) + +func TestQueries_SamlRequestByID(t *testing.T) { + expQuery := regexp.QuoteMeta(fmt.Sprintf( + samlRequestByIDQuery, + asOfSystemTime, + )) + + cols := []string{ + projection.SamlRequestColumnID, + projection.SamlRequestColumnCreationDate, + projection.SamlRequestColumnLoginClient, + projection.SamlRequestColumnIssuer, + projection.SamlRequestColumnACS, + projection.SamlRequestColumnRelayState, + projection.SamlRequestColumnBinding, + } + type args struct { + shouldTriggerBulk bool + id string + checkLoginClient bool + } + tests := []struct { + name string + args args + expect sqlExpectation + want *SamlRequest + wantErr error + }{ + { + name: "success, all values", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "loginClient", + "issuer", + "acs", + "relayState", + "binding", + }, "123", "instanceID"), + want: &SamlRequest{ + ID: "id", + CreationDate: testNow, + LoginClient: "loginClient", + Issuer: "issuer", + ACS: "acs", + RelayState: "relayState", + Binding: "binding", + }, + }, + { + name: "no rows", + args: args{ + shouldTriggerBulk: false, + id: "123", + }, + expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"), + wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.SamlRequest.NotExisting"), + }, + { + name: "query error", + args: args{ + shouldTriggerBulk: false, + id: "123", + }, + expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"), + wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"), + }, + { + name: "wrong login client", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "wrongLoginClient", + "issuer", + "acs", + "relayState", + "binding", + }, "123", "instanceID"), + wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock(t, tt.expect, func(db *sql.DB) { + q := &Queries{ + client: &database.DB{ + DB: db, + Database: &prepareDB{}, + }, + } + ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") + + got, err := q.SamlRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + }) + } +} diff --git a/internal/repository/samlrequest/saml_request.go b/internal/repository/samlrequest/saml_request.go index fc0ca3ee58..9997d9c54c 100644 --- a/internal/repository/samlrequest/saml_request.go +++ b/internal/repository/samlrequest/saml_request.go @@ -19,7 +19,7 @@ const ( type AddedEvent struct { *eventstore.BaseEvent `json:"-"` - LoginClient string `json:"loginClient,omitempty"` + LoginClient string `json:"login_client,omitempty"` ApplicationID string `json:"application_id,omitempty"` ACSURL string `json:"acs_url,omitempty"` RelayState string `json:"relay_state,omitempty"` diff --git a/proto/zitadel/oidc/v2/oidc_service.proto b/proto/zitadel/oidc/v2/oidc_service.proto index 1d16f662f0..3c36057afa 100644 --- a/proto/zitadel/oidc/v2/oidc_service.proto +++ b/proto/zitadel/oidc/v2/oidc_service.proto @@ -169,7 +169,7 @@ message CreateCallbackRequest { string auth_request_id = 1 [ (validate.rules).string = {min_len: 1, max_len: 200}, (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { - description: "ID of the SAML Request."; + description: "ID of the Auth Request."; example: "\"163840776835432705\""; } ]; diff --git a/proto/zitadel/saml/v2/authorization.proto b/proto/zitadel/saml/v2/authorization.proto index 9bdada5ad0..4cf90eeb90 100644 --- a/proto/zitadel/saml/v2/authorization.proto +++ b/proto/zitadel/saml/v2/authorization.proto @@ -62,23 +62,11 @@ message AuthorizationError { enum ErrorReason { ERROR_REASON_UNSPECIFIED = 0; - // Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1 - ERROR_REASON_INVALID_REQUEST = 1; - ERROR_REASON_UNAUTHORIZED_CLIENT = 2; - ERROR_REASON_ACCESS_DENIED = 3; - ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4; - ERROR_REASON_INVALID_SCOPE = 5; - ERROR_REASON_SERVER_ERROR = 6; - ERROR_REASON_TEMPORARY_UNAVAILABLE = 7; - - // Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError - ERROR_REASON_INTERACTION_REQUIRED = 8; - ERROR_REASON_LOGIN_REQUIRED = 9; - ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10; - ERROR_REASON_CONSENT_REQUIRED = 11; - ERROR_REASON_INVALID_REQUEST_URI = 12; - ERROR_REASON_INVALID_REQUEST_OBJECT = 13; - ERROR_REASON_REQUEST_NOT_SUPPORTED = 14; - ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15; - ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16; + ERROR_REASON_VERSION_MISSMATCH = 1; + ERROR_REASON_AUTH_N_FAILED = 2; + ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE = 3; + ERROR_REASON_INVALID_NAMEID_POLICY = 4; + ERROR_REASON_REQUEST_DENIED =5; + ERROR_REASON_REQUEST_UNSUPPORTED = 6; + ERROR_REASON_UNSUPPORTED_BINDING = 7; } \ No newline at end of file diff --git a/proto/zitadel/saml/v2/saml_service.proto b/proto/zitadel/saml/v2/saml_service.proto index 80a29c64b7..78f4d3f78f 100644 --- a/proto/zitadel/saml/v2/saml_service.proto +++ b/proto/zitadel/saml/v2/saml_service.proto @@ -123,6 +123,30 @@ service SAMLService { }; }; } + + rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) { + option (google.api.http) = { + post: "/v2/saml/saml_requests/{saml_request_id}" + body: "*" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "authenticated" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "Finalize a SAML Request and get the callback URL."; + description: "Finalize a SAML Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an SAML request." + responses: { + key: "200" + value: { + description: "OK"; + } + }; + }; + } } message GetSAMLRequestRequest { @@ -140,3 +164,54 @@ message GetSAMLRequestRequest { message GetSAMLRequestResponse { SAMLRequest saml_request = 1; } + +message CreateCallbackRequest { + string saml_request_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "ID of the SAML Request."; + example: "\"163840776835432705\""; + } + ]; + + oneof callback_kind { + option (validate.required) = true; + Session session = 2; + AuthorizationError error = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set."; + } + ]; + } +} + +message Session { + string session_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "ID of the session, used to login the user. Connects the session to the SAML Request."; + example: "\"163840776835432705\""; + } + ]; + + string session_token = 2 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "Token to verify the session is valid"; + } + ]; +} + +message CreateCallbackResponse { + zitadel.object.v2.Details details = 1; + string callback_url = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user."; + example: "\"https://client.example.org/cb\"" + } + ]; +} \ No newline at end of file