feat: add callback endpoint and query side for saml requests

This commit is contained in:
Stefan Benz 2024-12-03 20:09:19 +01:00
parent 905da945ff
commit f321b070ba
No known key found for this signature in database
GPG Key ID: 071AA751ED4F9D31
18 changed files with 1053 additions and 225 deletions

View File

@ -0,0 +1,250 @@
//go:build integration
package saml_test
import (
"context"
"net/url"
"os"
"regexp"
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
"github.com/zitadel/zitadel/pkg/grpc/session/v2"
)
var (
CTX context.Context
Instance *integration.Instance
Client saml_pb.SAMLServiceClient
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
Instance = integration.NewInstance(ctx)
Client = Instance.Client.SAMLv2
CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner)
return m.Run()
}())
}
func TestServer_GetAuthRequest(t *testing.T) {
entityID := "https://sp.example.com"
project, err := Instance.CreateProject(CTX)
require.NoError(t, err)
client, err := Instance.CreateSAMLClient(CTX, project.GetId(), entityID, entityID+"/saml/v2/sso", entityID+"/saml/v2/slo")
require.NoError(t, err)
authRequestID, err := Instance.CreateSAMLAuthRequest(CTX, "", Instance.Users[integration.UserTypeOrgOwner].ID, "acs", "relaystate")
require.NoError(t, err)
now := time.Now()
tests := []struct {
name string
AuthRequestID string
want *oidc_pb.GetAuthRequestResponse
wantErr bool
}{
{
name: "Not found",
AuthRequestID: "123",
wantErr: true,
},
{
name: "success",
AuthRequestID: authRequestID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.GetAuthRequest(CTX, &saml_pb.GetSAMLRequestRequest{
SamlRequestId: tt.AuthRequestID,
})
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
authRequest := got.GetSamlRequest()
assert.NotNil(t, authRequest)
assert.Equal(t, authRequestID, authRequest.GetId())
assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
})
}
}
func TestServer_CreateCallback(t *testing.T) {
project, err := Instance.CreateProject(CTX)
require.NoError(t, err)
client, err := Instance.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId(), false)
require.NoError(t, err)
sessionResp, err := Instance.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
Checks: &session.Checks{
User: &session.CheckUser{
Search: &session.CheckUser_UserId{
UserId: Instance.Users[integration.UserTypeOrgOwner].ID,
},
},
},
})
require.NoError(t, err)
tests := []struct {
name string
req *oidc_pb.CreateCallbackRequest
AuthError string
want *oidc_pb.CreateCallbackResponse
wantURL *url.URL
wantErr bool
}{
{
name: "Not found",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: "123",
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
wantErr: true,
},
{
name: "session not found",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users[integration.UserTypeOrgOwner].ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: "foo",
SessionToken: "bar",
},
},
},
wantErr: true,
},
{
name: "session token invalid",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: "bar",
},
},
},
wantErr: true,
},
{
name: "fail callback",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Error{
Error: &oidc_pb.AuthorizationError{
Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
ErrorDescription: gu.Ptr("nope"),
ErrorUri: gu.Ptr("https://example.com/docs"),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`),
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Instance.ID(),
},
},
wantErr: false,
},
{
name: "code callback",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`,
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Instance.ID(),
},
},
wantErr: false,
},
{
name: "implicit",
req: &oidc_pb.CreateCallbackRequest{
AuthRequestId: func() string {
client, err := Instance.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
require.NoError(t, err)
authRequestID, err := Instance.CreateOIDCAuthRequestImplicit(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURIImplicit)
require.NoError(t, err)
return authRequestID
}(),
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionResp.GetSessionId(),
SessionToken: sessionResp.GetSessionToken(),
},
},
},
want: &oidc_pb.CreateCallbackResponse{
CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`,
Details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Instance.ID(),
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.CreateCallback(CTX, tt.req)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
integration.AssertDetails(t, tt.want, got)
if tt.want != nil {
assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl())
}
})
}
}

View File

@ -1,203 +0,0 @@
package oidc
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/op"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
)
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
if err != nil {
logging.WithError(err).Error("query authRequest by ID")
return nil, err
}
return &oidc_pb.GetAuthRequestResponse{
AuthRequest: authRequestToPb(authRequest),
}, nil
}
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
pba := &oidc_pb.AuthRequest{
Id: a.ID,
CreationDate: timestamppb.New(a.CreationDate),
ClientId: a.ClientID,
Scope: a.Scope,
RedirectUri: a.RedirectURI,
Prompt: promptsToPb(a.Prompt),
UiLocales: a.UiLocales,
LoginHint: a.LoginHint,
HintUserId: a.HintUserID,
}
if a.MaxAge != nil {
pba.MaxAge = durationpb.New(*a.MaxAge)
}
return pba
}
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
out := make([]oidc_pb.Prompt, len(promps))
for i, p := range promps {
out[i] = promptToPb(p)
}
return out
}
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
switch p {
case domain.PromptUnspecified:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
case domain.PromptNone:
return oidc_pb.Prompt_PROMPT_NONE
case domain.PromptLogin:
return oidc_pb.Prompt_PROMPT_LOGIN
case domain.PromptConsent:
return oidc_pb.Prompt_PROMPT_CONSENT
case domain.PromptSelectAccount:
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
case domain.PromptCreate:
return oidc_pb.Prompt_PROMPT_CREATE
default:
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
}
}
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
switch v := req.GetCallbackKind().(type) {
case *oidc_pb.CreateCallbackRequest_Error:
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
case *oidc_pb.CreateCallbackRequest_Session:
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
default:
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
}
}
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider())
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
var callback string
if aar.ResponseType == domain.OIDCResponseTypeCode {
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
} else {
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
}
if err != nil {
return nil, err
}
return &oidc_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
switch errorReason {
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return domain.OIDCErrorReasonUnspecified
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return domain.OIDCErrorReasonInvalidRequest
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return domain.OIDCErrorReasonUnauthorizedClient
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return domain.OIDCErrorReasonAccessDenied
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return domain.OIDCErrorReasonUnsupportedResponseType
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return domain.OIDCErrorReasonInvalidScope
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
return domain.OIDCErrorReasonServerError
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return domain.OIDCErrorReasonTemporaryUnavailable
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return domain.OIDCErrorReasonInteractionRequired
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return domain.OIDCErrorReasonLoginRequired
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return domain.OIDCErrorReasonAccountSelectionRequired
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return domain.OIDCErrorReasonConsentRequired
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return domain.OIDCErrorReasonInvalidRequestURI
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return domain.OIDCErrorReasonInvalidRequestObject
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestNotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return domain.OIDCErrorReasonRequestURINotSupported
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return domain.OIDCErrorReasonRegistrationNotSupported
default:
return domain.OIDCErrorReasonUnspecified
}
}
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
switch reason {
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
return "invalid_request"
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
return "unauthorized_client"
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
return "access_denied"
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
return "unsupported_response_type"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
return "invalid_scope"
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
return "temporarily_unavailable"
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
return "interaction_required"
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
return "login_required"
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
return "account_selection_required"
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
return "consent_required"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
return "invalid_request_uri"
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
return "invalid_request_object"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
return "request_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
return "request_uri_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
return "registration_not_supported"
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
fallthrough
default:
return "server_error"
}
}

View File

@ -0,0 +1,132 @@
package saml
import (
"context"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/op"
"github.com/zitadel/saml/pkg/provider"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/internal/api/saml"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
)
func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) {
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetSamlRequestId(), true)
if err != nil {
logging.WithError(err).Error("query samlRequest by ID")
return nil, err
}
return &saml_pb.GetSAMLRequestResponse{
SamlRequest: samlRequestToPb(authRequest),
}, nil
}
func samlRequestToPb(a *query.AuthRequest) *saml_pb.SAMLRequest {
return &saml_pb.SAMLRequest{
Id: a.ID,
CreationDate: timestamppb.New(a.CreationDate),
}
}
func (s *Server) CreateCallback(ctx context.Context, req *saml_pb.CreateCallbackRequest) (*saml_pb.CreateCallbackResponse, error) {
switch v := req.GetCallbackKind().(type) {
case *saml_pb.CreateCallbackRequest_Error:
return s.failAuthRequest(ctx, req.GetSamlRequestId(), v.Error)
case *saml_pb.CreateCallbackRequest_Session:
return s.linkSessionToAuthRequest(ctx, req.GetSamlRequestId(), v.Session)
default:
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
}
}
func (s *Server) failAuthRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError()))
if err != nil {
return nil, err
}
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
callback, err := saml.CreateErrorCallbackURL(authReq, errorReasonToSAML(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri())
if err != nil {
return nil, err
}
return &saml_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func (s *Server) linkSessionToAuthRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateCallbackResponse, error) {
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true)
if err != nil {
return nil, err
}
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
var callback string
if aar.ResponseType == domain.OIDCResponseTypeCode {
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
} else {
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
}
if err != nil {
return nil, err
}
return &saml_pb.CreateCallbackResponse{
Details: object.DomainToDetailsPb(details),
CallbackUrl: callback,
}, nil
}
func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason {
switch errorReason {
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return domain.SAMLErrorReasonUnspecified
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
return domain.SAMLErrorReasonVersionMissmatch
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
return domain.SAMLErrorReasonAuthNFailed
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
return domain.SAMLErrorReasonInvalidAttrNameOrValue
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
return domain.SAMLErrorReasonInvalidNameIDPolicy
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
return domain.SAMLErrorReasonRequestDenied
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
return domain.SAMLErrorReasonRequestUnsupported
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
return domain.SAMLErrorReasonUnsupportedBinding
default:
return domain.SAMLErrorReasonUnspecified
}
}
func errorReasonToSAML(reason saml_pb.ErrorReason) string {
switch reason {
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
return "unspecified error"
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
return provider.StatusCodeVersionMissmatch
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
return provider.StatusCodeAuthNFailed
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
return provider.StatusCodeInvalidAttrNameOrValue
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
return provider.StatusCodeInvalidNameIDPolicy
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
return provider.StatusCodeRequestDenied
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
return provider.StatusCodeRequestUnsupported
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
return provider.StatusCodeUnsupportedBinding
default:
return "unspecified error"
}
}

View File

@ -1,4 +1,4 @@
package oidc
package saml
import (
"google.golang.org/grpc"

View File

@ -157,6 +157,11 @@ func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequest
return AuthRequestFromBusiness(resp)
}
func CreateErrorCallbackURL(authReq models.AuthRequestInt, reason, description, uri string) (string, error) {
// TODO handling for errors
return "", nil
}
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

View File

@ -4,6 +4,13 @@ type SAMLErrorReason int32
const (
SAMLErrorReasonUnspecified SAMLErrorReason = iota
SAMLErrorReasonVersionMissmatch
SAMLErrorReasonAuthNFailed
SAMLErrorReasonInvalidAttrNameOrValue
SAMLErrorReasonInvalidNameIDPolicy
SAMLErrorReasonRequestDenied
SAMLErrorReasonRequestUnsupported
SAMLErrorReasonUnsupportedBinding
)
func SAMLErrorReasonFromError(err error) SAMLErrorReason {

View File

@ -34,6 +34,7 @@ import (
user_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/user/v3alpha"
userschema_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/userschema/v3alpha"
webkey_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/webkey/v3alpha"
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
"github.com/zitadel/zitadel/pkg/grpc/session/v2"
session_v2beta "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/settings/v2"
@ -65,6 +66,7 @@ type Client struct {
WebKeyV3Alpha webkey_v3alpha.ZITADELWebKeysClient
IDPv2 idp_pb.IdentityProviderServiceClient
UserV3Alpha user_v3alpha.ZITADELUsersClient
SAMLv2 saml_pb.SAMLServiceClient
}
func newClient(ctx context.Context, target string) (*Client, error) {
@ -96,6 +98,7 @@ func newClient(ctx context.Context, target string) (*Client, error) {
WebKeyV3Alpha: webkey_v3alpha.NewZITADELWebKeysClient(cc),
IDPv2: idp_pb.NewIdentityProviderServiceClient(cc),
UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc),
SAMLv2: saml_pb.NewSAMLServiceClient(cc),
}
return client, client.pollHealth(ctx)
}

View File

@ -0,0 +1,94 @@
package integration
import (
"context"
"fmt"
"net/url"
"github.com/brianvoe/gofakeit/v6"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
oidc_internal "github.com/zitadel/zitadel/internal/api/oidc"
"github.com/zitadel/zitadel/pkg/grpc/management"
)
func (i *Instance) CreateSAMLClient(ctx context.Context, projectID, entityID, acsURL, logoutURL string) (*management.AddSAMLAppResponse, error) {
samlSPMetadata := `<?xml version="1.0"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
validUntil="2024-12-05T17:23:27Z"
cacheDuration="PT604800S"
entityID="` + entityID + `">
<md:SPSSODescriptor AuthnRequestsSigned="true" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
Location="` + logoutURL + `" />
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
Location="` + acsURL + `"
index="1" />
</md:SPSSODescriptor>
</md:EntityDescriptor>`
resp, err := i.Client.Mgmt.AddSAMLApp(ctx, &management.AddSAMLAppRequest{
ProjectId: projectID,
Name: fmt.Sprintf("app-%s", gofakeit.AppName()),
Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: []byte(samlSPMetadata)},
})
if err != nil {
return nil, err
}
return resp, await(func() error {
_, err := i.Client.Mgmt.GetProjectByID(ctx, &management.GetProjectByIDRequest{
Id: projectID,
})
if err != nil {
return err
}
_, err = i.Client.Mgmt.GetAppByID(ctx, &management.GetAppByIDRequest{
ProjectId: projectID,
AppId: resp.GetAppId(),
})
return err
})
}
func (i *Instance) CreateSAMLAuthRequest(ctx context.Context, entityID, loginClient, acsURL, relayState string) (authRequestID string, err error) {
binding := saml.HTTPRedirectBinding
entityDescriptor := new(saml.EntityDescriptor)
rootURL, err := url.Parse(entityID)
if err != nil {
return "", err
}
m, _ := samlsp.New(samlsp.Options{
URL: *rootURL,
/* TODO
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
Certificate: keyPair.Leaf,
*/
IDPMetadata: entityDescriptor,
})
authReq, err := m.ServiceProvider.MakeAuthenticationRequest(acsURL, binding, m.ResponseBinding)
if err != nil {
return "", err
}
redirectURL, err := authReq.Redirect(relayState, &m.ServiceProvider)
if err != nil {
return "", err
}
req, err := GetRequest(redirectURL.String(), map[string]string{oidc_internal.LoginClientHeader: loginClient})
if err != nil {
return "", fmt.Errorf("get request: %w", err)
}
loc, err := CheckRedirect(req)
if err != nil {
return "", fmt.Errorf("check redirect: %w", err)
}
//TODO get id from loc
return loc.String(), nil
}

View File

@ -69,6 +69,7 @@ var (
DeviceAuthProjection *handler.Handler
SessionProjection *handler.Handler
AuthRequestProjection *handler.Handler
SamlRequestProjection *handler.Handler
MilestoneProjection *handler.Handler
QuotaProjection *quotaProjection
LimitsProjection *handler.Handler
@ -156,6 +157,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
SamlRequestProjection = newSamlRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["saml_requests"]))
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"]))
@ -287,6 +289,7 @@ func newProjectionsList() {
DeviceAuthProjection,
SessionProjection,
AuthRequestProjection,
SamlRequestProjection,
MilestoneProjection,
QuotaProjection.handler,
LimitsProjection,

View File

@ -0,0 +1,132 @@
package projection
import (
"context"
"github.com/zitadel/zitadel/internal/eventstore"
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
SamlRequestsProjectionTable = "projections.saml_requests"
SamlRequestColumnID = "id"
SamlRequestColumnCreationDate = "creation_date"
SamlRequestColumnChangeDate = "change_date"
SamlRequestColumnSequence = "sequence"
SamlRequestColumnResourceOwner = "resource_owner"
SamlRequestColumnInstanceID = "instance_id"
SamlRequestColumnLoginClient = "login_client"
SamlRequestColumnIssuer = "issuer"
SamlRequestColumnACS = "acs"
SamlRequestColumnRelayState = "relay_state"
SamlRequestColumnBinding = "binding"
)
type samlRequestProjection struct{}
// Name implements handler.Projection.
func (*samlRequestProjection) Name() string {
return SamlRequestsProjectionTable
}
func newSamlRequestProjection(ctx context.Context, config handler.Config) *handler.Handler {
return handler.NewHandler(ctx, &config, new(samlRequestProjection))
}
func (*samlRequestProjection) Init() *old_handler.Check {
return handler.NewMultiTableCheck(
handler.NewTable([]*handler.InitColumn{
handler.NewColumn(SamlRequestColumnID, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnCreationDate, handler.ColumnTypeTimestamp),
handler.NewColumn(SamlRequestColumnChangeDate, handler.ColumnTypeTimestamp),
handler.NewColumn(SamlRequestColumnSequence, handler.ColumnTypeInt64),
handler.NewColumn(SamlRequestColumnResourceOwner, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnInstanceID, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnLoginClient, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnIssuer, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnACS, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnRelayState, handler.ColumnTypeText),
handler.NewColumn(SamlRequestColumnBinding, handler.ColumnTypeText),
},
handler.NewPrimaryKey(SamlRequestColumnInstanceID, SamlRequestColumnID),
),
)
}
func (p *samlRequestProjection) Reducers() []handler.AggregateReducer {
return []handler.AggregateReducer{
{
Aggregate: samlrequest.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: samlrequest.AddedType,
Reduce: p.reduceSamlRequestAdded,
},
{
Event: samlrequest.SucceededType,
Reduce: p.reduceSamlRequestEnded,
},
{
Event: samlrequest.FailedType,
Reduce: p.reduceSamlRequestEnded,
},
},
},
{
Aggregate: instance.AggregateType,
EventReducers: []handler.EventReducer{
{
Event: instance.InstanceRemovedEventType,
Reduce: reduceInstanceRemovedHelper(SamlRequestColumnInstanceID),
},
},
},
}
}
func (p *samlRequestProjection) reduceSamlRequestAdded(event eventstore.Event) (*handler.Statement, error) {
e, ok := event.(*samlrequest.AddedEvent)
if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", samlrequest.AddedType)
}
return handler.NewCreateStatement(
e,
[]handler.Column{
handler.NewCol(SamlRequestColumnID, e.Aggregate().ID),
handler.NewCol(SamlRequestColumnInstanceID, e.Aggregate().InstanceID),
handler.NewCol(SamlRequestColumnCreationDate, e.CreationDate()),
handler.NewCol(SamlRequestColumnChangeDate, e.CreationDate()),
handler.NewCol(SamlRequestColumnResourceOwner, e.Aggregate().ResourceOwner),
handler.NewCol(SamlRequestColumnSequence, e.Sequence()),
handler.NewCol(SamlRequestColumnLoginClient, e.LoginClient),
handler.NewCol(SamlRequestColumnIssuer, e.Issuer),
handler.NewCol(SamlRequestColumnACS, e.ACSURL),
handler.NewCol(SamlRequestColumnRelayState, e.RelayState),
handler.NewCol(SamlRequestColumnBinding, e.Binding),
},
), nil
}
func (p *samlRequestProjection) reduceSamlRequestEnded(event eventstore.Event) (*handler.Statement, error) {
switch event.(type) {
case *samlrequest.SucceededEvent,
*samlrequest.FailedEvent:
break
default:
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{samlrequest.SucceededType, samlrequest.FailedType})
}
return handler.NewDeleteStatement(
event,
[]handler.Condition{
handler.NewCond(SamlRequestColumnID, event.Aggregate().ID),
handler.NewCond(SamlRequestColumnInstanceID, event.Aggregate().InstanceID),
},
), nil
}

View File

@ -0,0 +1,123 @@
package projection
import (
"testing"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/samlrequest"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestSamlRequestProjection_reduces(t *testing.T) {
type args struct {
event func(t *testing.T) eventstore.Event
}
tests := []struct {
name string
args args
reduce func(event eventstore.Event) (*handler.Statement, error)
want wantReduce
}{
{
name: "reduceSamlRequestAdded",
args: args{
event: getEvent(testEvent(
samlrequest.AddedType,
samlrequest.AggregateType,
[]byte(`{"login_client": "loginClient", "issuer": "issuer", "acs_url": "acs", "relay_state": "relayState", "binding": "binding"}`),
), eventstore.GenericEventMapper[samlrequest.AddedEvent]),
},
reduce: (&samlRequestProjection{}).reduceSamlRequestAdded,
want: wantReduce{
aggregateType: eventstore.AggregateType("saml_request"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "INSERT INTO projections.saml_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, issuer, acs, relay_state, binding) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
anyArg{},
anyArg{},
"ro-id",
uint64(15),
"loginClient",
"issuer",
"acs",
"relayState",
"binding",
},
},
},
},
},
},
{
name: "reduceSamlRequestFailed",
args: args{
event: getEvent(testEvent(
samlrequest.FailedType,
samlrequest.AggregateType,
[]byte(`{"reason": 0}`),
), eventstore.GenericEventMapper[samlrequest.FailedEvent]),
},
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
want: wantReduce{
aggregateType: eventstore.AggregateType("saml_request"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
{
name: "reduceSamlRequestSucceeded",
args: args{
event: getEvent(testEvent(
samlrequest.SucceededType,
samlrequest.AggregateType,
nil,
), eventstore.GenericEventMapper[samlrequest.SucceededEvent]),
},
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
want: wantReduce{
aggregateType: eventstore.AggregateType("saml_request"),
sequence: 15,
executer: &testExecuter{
executions: []execution{
{
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
expectedArgs: []interface{}{
"agg-id",
"instance-id",
},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
event := baseEvent(t)
got, err := tt.reduce(event)
if !zerrors.IsErrorInvalidArgument(err) {
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
}
event = tt.args.event(t)
got, err = tt.reduce(event)
assertReduce(t, got, err, SamlRequestsProjectionTable, tt.want)
})
}
}

View File

@ -0,0 +1,81 @@
package query
import (
"context"
"database/sql"
_ "embed"
"errors"
"fmt"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
type SamlRequest struct {
ID string
CreationDate time.Time
LoginClient string
Issuer string
ACS string
RelayState string
Binding string
}
func (a *SamlRequest) checkLoginClient(ctx context.Context) error {
if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient {
return zerrors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient")
}
return nil
}
//go:embed saml_request_by_id.sql
var samlRequestByIDQuery string
func (q *Queries) samlRequestByIDQuery(ctx context.Context) string {
return fmt.Sprintf(samlRequestByIDQuery, q.client.Timetravel(call.Took(ctx)))
}
func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *SamlRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if shouldTriggerBulk {
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerSamlRequestProjection")
ctx, err = projection.SamlRequestProjection.Trigger(ctx, handler.WithAwaitRunning())
logging.OnError(err).Debug("trigger failed")
traceSpan.EndWithError(err)
}
dst := new(SamlRequest)
err = q.client.QueryRowContext(
ctx,
func(row *sql.Row) error {
return row.Scan(
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.Issuer, &dst.ACS, &dst.RelayState, &dst.Binding,
)
},
q.samlRequestByIDQuery(ctx),
id, authz.GetInstance(ctx).InstanceID(),
)
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-Thee9", "Errors.SamlRequest.NotExisting")
}
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal")
}
if checkLoginClient {
if err = dst.checkLoginClient(ctx); err != nil {
return nil, err
}
}
return dst, nil
}

View File

@ -0,0 +1,11 @@
select
id,
creation_date,
login_client,
issuer,
acs,
relay_state,
binding
from projections.saml_requests %s
where id = $1 and instance_id = $2
limit 1;

View File

@ -0,0 +1,127 @@
package query
import (
"database/sql"
"database/sql/driver"
_ "embed"
"fmt"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestQueries_SamlRequestByID(t *testing.T) {
expQuery := regexp.QuoteMeta(fmt.Sprintf(
samlRequestByIDQuery,
asOfSystemTime,
))
cols := []string{
projection.SamlRequestColumnID,
projection.SamlRequestColumnCreationDate,
projection.SamlRequestColumnLoginClient,
projection.SamlRequestColumnIssuer,
projection.SamlRequestColumnACS,
projection.SamlRequestColumnRelayState,
projection.SamlRequestColumnBinding,
}
type args struct {
shouldTriggerBulk bool
id string
checkLoginClient bool
}
tests := []struct {
name string
args args
expect sqlExpectation
want *SamlRequest
wantErr error
}{
{
name: "success, all values",
args: args{
shouldTriggerBulk: false,
id: "123",
checkLoginClient: true,
},
expect: mockQuery(expQuery, cols, []driver.Value{
"id",
testNow,
"loginClient",
"issuer",
"acs",
"relayState",
"binding",
}, "123", "instanceID"),
want: &SamlRequest{
ID: "id",
CreationDate: testNow,
LoginClient: "loginClient",
Issuer: "issuer",
ACS: "acs",
RelayState: "relayState",
Binding: "binding",
},
},
{
name: "no rows",
args: args{
shouldTriggerBulk: false,
id: "123",
},
expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"),
wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.SamlRequest.NotExisting"),
},
{
name: "query error",
args: args{
shouldTriggerBulk: false,
id: "123",
},
expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"),
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"),
},
{
name: "wrong login client",
args: args{
shouldTriggerBulk: false,
id: "123",
checkLoginClient: true,
},
expect: mockQuery(expQuery, cols, []driver.Value{
"id",
testNow,
"wrongLoginClient",
"issuer",
"acs",
"relayState",
"binding",
}, "123", "instanceID"),
wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
execMock(t, tt.expect, func(db *sql.DB) {
q := &Queries{
client: &database.DB{
DB: db,
Database: &prepareDB{},
},
}
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
got, err := q.SamlRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
})
}
}

View File

@ -19,7 +19,7 @@ const (
type AddedEvent struct {
*eventstore.BaseEvent `json:"-"`
LoginClient string `json:"loginClient,omitempty"`
LoginClient string `json:"login_client,omitempty"`
ApplicationID string `json:"application_id,omitempty"`
ACSURL string `json:"acs_url,omitempty"`
RelayState string `json:"relay_state,omitempty"`

View File

@ -169,7 +169,7 @@ message CreateCallbackRequest {
string auth_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "ID of the SAML Request.";
description: "ID of the Auth Request.";
example: "\"163840776835432705\"";
}
];

View File

@ -62,23 +62,11 @@ message AuthorizationError {
enum ErrorReason {
ERROR_REASON_UNSPECIFIED = 0;
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
ERROR_REASON_INVALID_REQUEST = 1;
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
ERROR_REASON_ACCESS_DENIED = 3;
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
ERROR_REASON_INVALID_SCOPE = 5;
ERROR_REASON_SERVER_ERROR = 6;
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
ERROR_REASON_INTERACTION_REQUIRED = 8;
ERROR_REASON_LOGIN_REQUIRED = 9;
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
ERROR_REASON_CONSENT_REQUIRED = 11;
ERROR_REASON_INVALID_REQUEST_URI = 12;
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
ERROR_REASON_VERSION_MISSMATCH = 1;
ERROR_REASON_AUTH_N_FAILED = 2;
ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE = 3;
ERROR_REASON_INVALID_NAMEID_POLICY = 4;
ERROR_REASON_REQUEST_DENIED =5;
ERROR_REASON_REQUEST_UNSUPPORTED = 6;
ERROR_REASON_UNSUPPORTED_BINDING = 7;
}

View File

@ -123,6 +123,30 @@ service SAMLService {
};
};
}
rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) {
option (google.api.http) = {
post: "/v2/saml/saml_requests/{saml_request_id}"
body: "*"
};
option (zitadel.protoc_gen_zitadel.v2.options) = {
auth_option: {
permission: "authenticated"
}
};
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
summary: "Finalize a SAML Request and get the callback URL.";
description: "Finalize a SAML Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an SAML request."
responses: {
key: "200"
value: {
description: "OK";
}
};
};
}
}
message GetSAMLRequestRequest {
@ -140,3 +164,54 @@ message GetSAMLRequestRequest {
message GetSAMLRequestResponse {
SAMLRequest saml_request = 1;
}
message CreateCallbackRequest {
string saml_request_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "ID of the SAML Request.";
example: "\"163840776835432705\"";
}
];
oneof callback_kind {
option (validate.required) = true;
Session session = 2;
AuthorizationError error = 3 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
}
];
}
}
message Session {
string session_id = 1 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "ID of the session, used to login the user. Connects the session to the SAML Request.";
example: "\"163840776835432705\"";
}
];
string session_token = 2 [
(validate.rules).string = {min_len: 1, max_len: 200},
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
min_length: 1;
max_length: 200;
description: "Token to verify the session is valid";
}
];
}
message CreateCallbackResponse {
zitadel.object.v2.Details details = 1;
string callback_url = 2 [
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user.";
example: "\"https://client.example.org/cb\""
}
];
}