mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 11:04:25 +00:00
feat: add callback endpoint and query side for saml requests
This commit is contained in:
parent
905da945ff
commit
f321b070ba
250
internal/api/grpc/saml/integration/saml_test.go
Normal file
250
internal/api/grpc/saml/integration/saml_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
//go:build integration
|
||||
|
||||
package saml_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/session/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Instance *integration.Instance
|
||||
Client saml_pb.SAMLServiceClient
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Instance = integration.NewInstance(ctx)
|
||||
Client = Instance.Client.SAMLv2
|
||||
|
||||
CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner)
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestServer_GetAuthRequest(t *testing.T) {
|
||||
entityID := "https://sp.example.com"
|
||||
|
||||
project, err := Instance.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
client, err := Instance.CreateSAMLClient(CTX, project.GetId(), entityID, entityID+"/saml/v2/sso", entityID+"/saml/v2/slo")
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(CTX, "", Instance.Users[integration.UserTypeOrgOwner].ID, "acs", "relaystate")
|
||||
require.NoError(t, err)
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
AuthRequestID string
|
||||
want *oidc_pb.GetAuthRequestResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
AuthRequestID: "123",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
AuthRequestID: authRequestID,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.GetAuthRequest(CTX, &saml_pb.GetSAMLRequestRequest{
|
||||
SamlRequestId: tt.AuthRequestID,
|
||||
})
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
authRequest := got.GetSamlRequest()
|
||||
assert.NotNil(t, authRequest)
|
||||
assert.Equal(t, authRequestID, authRequest.GetId())
|
||||
assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateCallback(t *testing.T) {
|
||||
project, err := Instance.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
client, err := Instance.CreateOIDCNativeClient(CTX, redirectURI, logoutRedirectURI, project.GetId(), false)
|
||||
require.NoError(t, err)
|
||||
sessionResp, err := Instance.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{
|
||||
UserId: Instance.Users[integration.UserTypeOrgOwner].ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *oidc_pb.CreateCallbackRequest
|
||||
AuthError string
|
||||
want *oidc_pb.CreateCallbackResponse
|
||||
wantURL *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: "123",
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users[integration.UserTypeOrgOwner].ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: "foo",
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session token invalid",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Error{
|
||||
Error: &oidc_pb.AuthorizationError{
|
||||
Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
ErrorUri: gu.Ptr("https://example.com/docs"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`),
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "code callback",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateOIDCAuthRequest(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`,
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "implicit",
|
||||
req: &oidc_pb.CreateCallbackRequest{
|
||||
AuthRequestId: func() string {
|
||||
client, err := Instance.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
authRequestID, err := Instance.CreateOIDCAuthRequestImplicit(CTX, client.GetClientId(), Instance.Users.Get(integration.UserTypeOrgOwner).ID, redirectURIImplicit)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
|
||||
Session: &oidc_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &oidc_pb.CreateCallbackResponse{
|
||||
CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`,
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.CreateCallback(CTX, tt.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -1,203 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
|
||||
)
|
||||
|
||||
func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) {
|
||||
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("query authRequest by ID")
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.GetAuthRequestResponse{
|
||||
AuthRequest: authRequestToPb(authRequest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest {
|
||||
pba := &oidc_pb.AuthRequest{
|
||||
Id: a.ID,
|
||||
CreationDate: timestamppb.New(a.CreationDate),
|
||||
ClientId: a.ClientID,
|
||||
Scope: a.Scope,
|
||||
RedirectUri: a.RedirectURI,
|
||||
Prompt: promptsToPb(a.Prompt),
|
||||
UiLocales: a.UiLocales,
|
||||
LoginHint: a.LoginHint,
|
||||
HintUserId: a.HintUserID,
|
||||
}
|
||||
if a.MaxAge != nil {
|
||||
pba.MaxAge = durationpb.New(*a.MaxAge)
|
||||
}
|
||||
return pba
|
||||
}
|
||||
|
||||
func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt {
|
||||
out := make([]oidc_pb.Prompt, len(promps))
|
||||
for i, p := range promps {
|
||||
out[i] = promptToPb(p)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promptToPb(p domain.Prompt) oidc_pb.Prompt {
|
||||
switch p {
|
||||
case domain.PromptUnspecified:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
case domain.PromptNone:
|
||||
return oidc_pb.Prompt_PROMPT_NONE
|
||||
case domain.PromptLogin:
|
||||
return oidc_pb.Prompt_PROMPT_LOGIN
|
||||
case domain.PromptConsent:
|
||||
return oidc_pb.Prompt_PROMPT_CONSENT
|
||||
case domain.PromptSelectAccount:
|
||||
return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT
|
||||
case domain.PromptCreate:
|
||||
return oidc_pb.Prompt_PROMPT_CREATE
|
||||
default:
|
||||
return oidc_pb.Prompt_PROMPT_UNSPECIFIED
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
switch v := req.GetCallbackKind().(type) {
|
||||
case *oidc_pb.CreateCallbackRequest_Error:
|
||||
return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error)
|
||||
case *oidc_pb.CreateCallbackRequest_Session:
|
||||
return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session)
|
||||
default:
|
||||
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op.Provider())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
|
||||
var callback string
|
||||
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
|
||||
} else {
|
||||
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oidc_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason {
|
||||
switch errorReason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return domain.OIDCErrorReasonInvalidRequest
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return domain.OIDCErrorReasonUnauthorizedClient
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return domain.OIDCErrorReasonAccessDenied
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return domain.OIDCErrorReasonUnsupportedResponseType
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return domain.OIDCErrorReasonInvalidScope
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
return domain.OIDCErrorReasonServerError
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return domain.OIDCErrorReasonTemporaryUnavailable
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonInteractionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return domain.OIDCErrorReasonLoginRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return domain.OIDCErrorReasonAccountSelectionRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return domain.OIDCErrorReasonConsentRequired
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return domain.OIDCErrorReasonInvalidRequestURI
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return domain.OIDCErrorReasonInvalidRequestObject
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestNotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRequestURINotSupported
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return domain.OIDCErrorReasonRegistrationNotSupported
|
||||
default:
|
||||
return domain.OIDCErrorReasonUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func errorReasonToOIDC(reason oidc_pb.ErrorReason) string {
|
||||
switch reason {
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST:
|
||||
return "invalid_request"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT:
|
||||
return "unauthorized_client"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED:
|
||||
return "access_denied"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE:
|
||||
return "unsupported_response_type"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE:
|
||||
return "invalid_scope"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE:
|
||||
return "temporarily_unavailable"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED:
|
||||
return "interaction_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED:
|
||||
return "login_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED:
|
||||
return "account_selection_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED:
|
||||
return "consent_required"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI:
|
||||
return "invalid_request_uri"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT:
|
||||
return "invalid_request_object"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED:
|
||||
return "request_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED:
|
||||
return "request_uri_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED:
|
||||
return "registration_not_supported"
|
||||
case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR:
|
||||
fallthrough
|
||||
default:
|
||||
return "server_error"
|
||||
}
|
||||
}
|
132
internal/api/grpc/saml/saml.go
Normal file
132
internal/api/grpc/saml/saml.go
Normal file
@ -0,0 +1,132 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/internal/api/saml"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
)
|
||||
|
||||
func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) {
|
||||
authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetSamlRequestId(), true)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("query samlRequest by ID")
|
||||
return nil, err
|
||||
}
|
||||
return &saml_pb.GetSAMLRequestResponse{
|
||||
SamlRequest: samlRequestToPb(authRequest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func samlRequestToPb(a *query.AuthRequest) *saml_pb.SAMLRequest {
|
||||
return &saml_pb.SAMLRequest{
|
||||
Id: a.ID,
|
||||
CreationDate: timestamppb.New(a.CreationDate),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) CreateCallback(ctx context.Context, req *saml_pb.CreateCallbackRequest) (*saml_pb.CreateCallbackResponse, error) {
|
||||
switch v := req.GetCallbackKind().(type) {
|
||||
case *saml_pb.CreateCallbackRequest_Error:
|
||||
return s.failAuthRequest(ctx, req.GetSamlRequestId(), v.Error)
|
||||
case *saml_pb.CreateCallbackRequest_Session:
|
||||
return s.linkSessionToAuthRequest(ctx, req.GetSamlRequestId(), v.Session)
|
||||
default:
|
||||
return nil, zerrors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) failAuthRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
|
||||
callback, err := saml.CreateErrorCallbackURL(authReq, errorReasonToSAML(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &saml_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) linkSessionToAuthRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateCallbackResponse, error) {
|
||||
details, aar, err := s.command.LinkSessionToAuthRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar}
|
||||
ctx = op.ContextWithIssuer(ctx, http.DomainContext(ctx).Origin())
|
||||
var callback string
|
||||
if aar.ResponseType == domain.OIDCResponseTypeCode {
|
||||
callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op.Provider())
|
||||
} else {
|
||||
callback, err = s.op.CreateTokenCallbackURL(ctx, authReq)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &saml_pb.CreateCallbackResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
CallbackUrl: callback,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason {
|
||||
switch errorReason {
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return domain.SAMLErrorReasonUnspecified
|
||||
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
|
||||
return domain.SAMLErrorReasonVersionMissmatch
|
||||
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
|
||||
return domain.SAMLErrorReasonAuthNFailed
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
|
||||
return domain.SAMLErrorReasonInvalidAttrNameOrValue
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
|
||||
return domain.SAMLErrorReasonInvalidNameIDPolicy
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
|
||||
return domain.SAMLErrorReasonRequestDenied
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
|
||||
return domain.SAMLErrorReasonRequestUnsupported
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
|
||||
return domain.SAMLErrorReasonUnsupportedBinding
|
||||
default:
|
||||
return domain.SAMLErrorReasonUnspecified
|
||||
}
|
||||
}
|
||||
|
||||
func errorReasonToSAML(reason saml_pb.ErrorReason) string {
|
||||
switch reason {
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return "unspecified error"
|
||||
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
|
||||
return provider.StatusCodeVersionMissmatch
|
||||
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
|
||||
return provider.StatusCodeAuthNFailed
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
|
||||
return provider.StatusCodeInvalidAttrNameOrValue
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
|
||||
return provider.StatusCodeInvalidNameIDPolicy
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
|
||||
return provider.StatusCodeRequestDenied
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
|
||||
return provider.StatusCodeRequestUnsupported
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
|
||||
return provider.StatusCodeUnsupportedBinding
|
||||
default:
|
||||
return "unspecified error"
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package oidc
|
||||
package saml
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
@ -157,6 +157,11 @@ func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequest
|
||||
return AuthRequestFromBusiness(resp)
|
||||
}
|
||||
|
||||
func CreateErrorCallbackURL(authReq models.AuthRequestInt, reason, description, uri string) (string, error) {
|
||||
// TODO handling for errors
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
@ -4,6 +4,13 @@ type SAMLErrorReason int32
|
||||
|
||||
const (
|
||||
SAMLErrorReasonUnspecified SAMLErrorReason = iota
|
||||
SAMLErrorReasonVersionMissmatch
|
||||
SAMLErrorReasonAuthNFailed
|
||||
SAMLErrorReasonInvalidAttrNameOrValue
|
||||
SAMLErrorReasonInvalidNameIDPolicy
|
||||
SAMLErrorReasonRequestDenied
|
||||
SAMLErrorReasonRequestUnsupported
|
||||
SAMLErrorReasonUnsupportedBinding
|
||||
)
|
||||
|
||||
func SAMLErrorReasonFromError(err error) SAMLErrorReason {
|
||||
|
@ -34,6 +34,7 @@ import (
|
||||
user_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/user/v3alpha"
|
||||
userschema_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/userschema/v3alpha"
|
||||
webkey_v3alpha "github.com/zitadel/zitadel/pkg/grpc/resources/webkey/v3alpha"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/session/v2"
|
||||
session_v2beta "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/settings/v2"
|
||||
@ -65,6 +66,7 @@ type Client struct {
|
||||
WebKeyV3Alpha webkey_v3alpha.ZITADELWebKeysClient
|
||||
IDPv2 idp_pb.IdentityProviderServiceClient
|
||||
UserV3Alpha user_v3alpha.ZITADELUsersClient
|
||||
SAMLv2 saml_pb.SAMLServiceClient
|
||||
}
|
||||
|
||||
func newClient(ctx context.Context, target string) (*Client, error) {
|
||||
@ -96,6 +98,7 @@ func newClient(ctx context.Context, target string) (*Client, error) {
|
||||
WebKeyV3Alpha: webkey_v3alpha.NewZITADELWebKeysClient(cc),
|
||||
IDPv2: idp_pb.NewIdentityProviderServiceClient(cc),
|
||||
UserV3Alpha: user_v3alpha.NewZITADELUsersClient(cc),
|
||||
SAMLv2: saml_pb.NewSAMLServiceClient(cc),
|
||||
}
|
||||
return client, client.pollHealth(ctx)
|
||||
}
|
||||
|
94
internal/integration/saml.go
Normal file
94
internal/integration/saml.go
Normal file
@ -0,0 +1,94 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
|
||||
oidc_internal "github.com/zitadel/zitadel/internal/api/oidc"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/management"
|
||||
)
|
||||
|
||||
func (i *Instance) CreateSAMLClient(ctx context.Context, projectID, entityID, acsURL, logoutURL string) (*management.AddSAMLAppResponse, error) {
|
||||
samlSPMetadata := `<?xml version="1.0"?>
|
||||
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
|
||||
validUntil="2024-12-05T17:23:27Z"
|
||||
cacheDuration="PT604800S"
|
||||
entityID="` + entityID + `">
|
||||
<md:SPSSODescriptor AuthnRequestsSigned="true" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
|
||||
Location="` + logoutURL + `" />
|
||||
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
|
||||
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
|
||||
Location="` + acsURL + `"
|
||||
index="1" />
|
||||
</md:SPSSODescriptor>
|
||||
</md:EntityDescriptor>`
|
||||
|
||||
resp, err := i.Client.Mgmt.AddSAMLApp(ctx, &management.AddSAMLAppRequest{
|
||||
ProjectId: projectID,
|
||||
Name: fmt.Sprintf("app-%s", gofakeit.AppName()),
|
||||
Metadata: &management.AddSAMLAppRequest_MetadataXml{MetadataXml: []byte(samlSPMetadata)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, await(func() error {
|
||||
_, err := i.Client.Mgmt.GetProjectByID(ctx, &management.GetProjectByIDRequest{
|
||||
Id: projectID,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = i.Client.Mgmt.GetAppByID(ctx, &management.GetAppByIDRequest{
|
||||
ProjectId: projectID,
|
||||
AppId: resp.GetAppId(),
|
||||
})
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (i *Instance) CreateSAMLAuthRequest(ctx context.Context, entityID, loginClient, acsURL, relayState string) (authRequestID string, err error) {
|
||||
binding := saml.HTTPRedirectBinding
|
||||
entityDescriptor := new(saml.EntityDescriptor)
|
||||
rootURL, err := url.Parse(entityID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
m, _ := samlsp.New(samlsp.Options{
|
||||
URL: *rootURL,
|
||||
/* TODO
|
||||
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
|
||||
Certificate: keyPair.Leaf,
|
||||
*/
|
||||
IDPMetadata: entityDescriptor,
|
||||
})
|
||||
|
||||
authReq, err := m.ServiceProvider.MakeAuthenticationRequest(acsURL, binding, m.ResponseBinding)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
redirectURL, err := authReq.Redirect(relayState, &m.ServiceProvider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := GetRequest(redirectURL.String(), map[string]string{oidc_internal.LoginClientHeader: loginClient})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get request: %w", err)
|
||||
}
|
||||
|
||||
loc, err := CheckRedirect(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("check redirect: %w", err)
|
||||
}
|
||||
|
||||
//TODO get id from loc
|
||||
return loc.String(), nil
|
||||
}
|
@ -69,6 +69,7 @@ var (
|
||||
DeviceAuthProjection *handler.Handler
|
||||
SessionProjection *handler.Handler
|
||||
AuthRequestProjection *handler.Handler
|
||||
SamlRequestProjection *handler.Handler
|
||||
MilestoneProjection *handler.Handler
|
||||
QuotaProjection *quotaProjection
|
||||
LimitsProjection *handler.Handler
|
||||
@ -156,6 +157,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore,
|
||||
DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"]))
|
||||
SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"]))
|
||||
AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"]))
|
||||
SamlRequestProjection = newSamlRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["saml_requests"]))
|
||||
MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"]))
|
||||
QuotaProjection = newQuotaProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["quotas"]))
|
||||
LimitsProjection = newLimitsProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["limits"]))
|
||||
@ -287,6 +289,7 @@ func newProjectionsList() {
|
||||
DeviceAuthProjection,
|
||||
SessionProjection,
|
||||
AuthRequestProjection,
|
||||
SamlRequestProjection,
|
||||
MilestoneProjection,
|
||||
QuotaProjection.handler,
|
||||
LimitsProjection,
|
||||
|
132
internal/query/projection/saml_request.go
Normal file
132
internal/query/projection/saml_request.go
Normal file
@ -0,0 +1,132 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
old_handler "github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const (
|
||||
SamlRequestsProjectionTable = "projections.saml_requests"
|
||||
|
||||
SamlRequestColumnID = "id"
|
||||
SamlRequestColumnCreationDate = "creation_date"
|
||||
SamlRequestColumnChangeDate = "change_date"
|
||||
SamlRequestColumnSequence = "sequence"
|
||||
SamlRequestColumnResourceOwner = "resource_owner"
|
||||
SamlRequestColumnInstanceID = "instance_id"
|
||||
SamlRequestColumnLoginClient = "login_client"
|
||||
SamlRequestColumnIssuer = "issuer"
|
||||
SamlRequestColumnACS = "acs"
|
||||
SamlRequestColumnRelayState = "relay_state"
|
||||
SamlRequestColumnBinding = "binding"
|
||||
)
|
||||
|
||||
type samlRequestProjection struct{}
|
||||
|
||||
// Name implements handler.Projection.
|
||||
func (*samlRequestProjection) Name() string {
|
||||
return SamlRequestsProjectionTable
|
||||
}
|
||||
|
||||
func newSamlRequestProjection(ctx context.Context, config handler.Config) *handler.Handler {
|
||||
return handler.NewHandler(ctx, &config, new(samlRequestProjection))
|
||||
}
|
||||
|
||||
func (*samlRequestProjection) Init() *old_handler.Check {
|
||||
return handler.NewMultiTableCheck(
|
||||
handler.NewTable([]*handler.InitColumn{
|
||||
handler.NewColumn(SamlRequestColumnID, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnCreationDate, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(SamlRequestColumnChangeDate, handler.ColumnTypeTimestamp),
|
||||
handler.NewColumn(SamlRequestColumnSequence, handler.ColumnTypeInt64),
|
||||
handler.NewColumn(SamlRequestColumnResourceOwner, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnInstanceID, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnLoginClient, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnIssuer, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnACS, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnRelayState, handler.ColumnTypeText),
|
||||
handler.NewColumn(SamlRequestColumnBinding, handler.ColumnTypeText),
|
||||
},
|
||||
handler.NewPrimaryKey(SamlRequestColumnInstanceID, SamlRequestColumnID),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) Reducers() []handler.AggregateReducer {
|
||||
return []handler.AggregateReducer{
|
||||
{
|
||||
Aggregate: samlrequest.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: samlrequest.AddedType,
|
||||
Reduce: p.reduceSamlRequestAdded,
|
||||
},
|
||||
{
|
||||
Event: samlrequest.SucceededType,
|
||||
Reduce: p.reduceSamlRequestEnded,
|
||||
},
|
||||
{
|
||||
Event: samlrequest.FailedType,
|
||||
Reduce: p.reduceSamlRequestEnded,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Aggregate: instance.AggregateType,
|
||||
EventReducers: []handler.EventReducer{
|
||||
{
|
||||
Event: instance.InstanceRemovedEventType,
|
||||
Reduce: reduceInstanceRemovedHelper(SamlRequestColumnInstanceID),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) reduceSamlRequestAdded(event eventstore.Event) (*handler.Statement, error) {
|
||||
e, ok := event.(*samlrequest.AddedEvent)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", samlrequest.AddedType)
|
||||
}
|
||||
|
||||
return handler.NewCreateStatement(
|
||||
e,
|
||||
[]handler.Column{
|
||||
handler.NewCol(SamlRequestColumnID, e.Aggregate().ID),
|
||||
handler.NewCol(SamlRequestColumnInstanceID, e.Aggregate().InstanceID),
|
||||
handler.NewCol(SamlRequestColumnCreationDate, e.CreationDate()),
|
||||
handler.NewCol(SamlRequestColumnChangeDate, e.CreationDate()),
|
||||
handler.NewCol(SamlRequestColumnResourceOwner, e.Aggregate().ResourceOwner),
|
||||
handler.NewCol(SamlRequestColumnSequence, e.Sequence()),
|
||||
handler.NewCol(SamlRequestColumnLoginClient, e.LoginClient),
|
||||
handler.NewCol(SamlRequestColumnIssuer, e.Issuer),
|
||||
handler.NewCol(SamlRequestColumnACS, e.ACSURL),
|
||||
handler.NewCol(SamlRequestColumnRelayState, e.RelayState),
|
||||
handler.NewCol(SamlRequestColumnBinding, e.Binding),
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
func (p *samlRequestProjection) reduceSamlRequestEnded(event eventstore.Event) (*handler.Statement, error) {
|
||||
switch event.(type) {
|
||||
case *samlrequest.SucceededEvent,
|
||||
*samlrequest.FailedEvent:
|
||||
break
|
||||
default:
|
||||
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{samlrequest.SucceededType, samlrequest.FailedType})
|
||||
}
|
||||
|
||||
return handler.NewDeleteStatement(
|
||||
event,
|
||||
[]handler.Condition{
|
||||
handler.NewCond(SamlRequestColumnID, event.Aggregate().ID),
|
||||
handler.NewCond(SamlRequestColumnInstanceID, event.Aggregate().InstanceID),
|
||||
},
|
||||
), nil
|
||||
}
|
123
internal/query/projection/saml_request_test.go
Normal file
123
internal/query/projection/saml_request_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package projection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/repository/samlrequest"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestSamlRequestProjection_reduces(t *testing.T) {
|
||||
type args struct {
|
||||
event func(t *testing.T) eventstore.Event
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
reduce func(event eventstore.Event) (*handler.Statement, error)
|
||||
want wantReduce
|
||||
}{
|
||||
{
|
||||
name: "reduceSamlRequestAdded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.AddedType,
|
||||
samlrequest.AggregateType,
|
||||
[]byte(`{"login_client": "loginClient", "issuer": "issuer", "acs_url": "acs", "relay_state": "relayState", "binding": "binding"}`),
|
||||
), eventstore.GenericEventMapper[samlrequest.AddedEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestAdded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "INSERT INTO projections.saml_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, issuer, acs, relay_state, binding) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
anyArg{},
|
||||
anyArg{},
|
||||
"ro-id",
|
||||
uint64(15),
|
||||
"loginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceSamlRequestFailed",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.FailedType,
|
||||
samlrequest.AggregateType,
|
||||
[]byte(`{"reason": 0}`),
|
||||
), eventstore.GenericEventMapper[samlrequest.FailedEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reduceSamlRequestSucceeded",
|
||||
args: args{
|
||||
event: getEvent(testEvent(
|
||||
samlrequest.SucceededType,
|
||||
samlrequest.AggregateType,
|
||||
nil,
|
||||
), eventstore.GenericEventMapper[samlrequest.SucceededEvent]),
|
||||
},
|
||||
reduce: (&samlRequestProjection{}).reduceSamlRequestEnded,
|
||||
want: wantReduce{
|
||||
aggregateType: eventstore.AggregateType("saml_request"),
|
||||
sequence: 15,
|
||||
executer: &testExecuter{
|
||||
executions: []execution{
|
||||
{
|
||||
expectedStmt: "DELETE FROM projections.saml_requests WHERE (id = $1) AND (instance_id = $2)",
|
||||
expectedArgs: []interface{}{
|
||||
"agg-id",
|
||||
"instance-id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
event := baseEvent(t)
|
||||
got, err := tt.reduce(event)
|
||||
if !zerrors.IsErrorInvalidArgument(err) {
|
||||
t.Errorf("no wrong event mapping: %v, got: %v", err, got)
|
||||
}
|
||||
|
||||
event = tt.args.event(t)
|
||||
got, err = tt.reduce(event)
|
||||
assertReduce(t, got, err, SamlRequestsProjectionTable, tt.want)
|
||||
})
|
||||
}
|
||||
}
|
81
internal/query/saml_request.go
Normal file
81
internal/query/saml_request.go
Normal file
@ -0,0 +1,81 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type SamlRequest struct {
|
||||
ID string
|
||||
CreationDate time.Time
|
||||
LoginClient string
|
||||
Issuer string
|
||||
ACS string
|
||||
RelayState string
|
||||
Binding string
|
||||
}
|
||||
|
||||
func (a *SamlRequest) checkLoginClient(ctx context.Context) error {
|
||||
if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient {
|
||||
return zerrors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//go:embed saml_request_by_id.sql
|
||||
var samlRequestByIDQuery string
|
||||
|
||||
func (q *Queries) samlRequestByIDQuery(ctx context.Context) string {
|
||||
return fmt.Sprintf(samlRequestByIDQuery, q.client.Timetravel(call.Took(ctx)))
|
||||
}
|
||||
|
||||
func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *SamlRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerSamlRequestProjection")
|
||||
ctx, err = projection.SamlRequestProjection.Trigger(ctx, handler.WithAwaitRunning())
|
||||
logging.OnError(err).Debug("trigger failed")
|
||||
traceSpan.EndWithError(err)
|
||||
}
|
||||
|
||||
dst := new(SamlRequest)
|
||||
err = q.client.QueryRowContext(
|
||||
ctx,
|
||||
func(row *sql.Row) error {
|
||||
return row.Scan(
|
||||
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.Issuer, &dst.ACS, &dst.RelayState, &dst.Binding,
|
||||
)
|
||||
},
|
||||
q.samlRequestByIDQuery(ctx),
|
||||
id, authz.GetInstance(ctx).InstanceID(),
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, zerrors.ThrowNotFound(err, "QUERY-Thee9", "Errors.SamlRequest.NotExisting")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal")
|
||||
}
|
||||
|
||||
if checkLoginClient {
|
||||
if err = dst.checkLoginClient(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
}
|
11
internal/query/saml_request_by_id.sql
Normal file
11
internal/query/saml_request_by_id.sql
Normal file
@ -0,0 +1,11 @@
|
||||
select
|
||||
id,
|
||||
creation_date,
|
||||
login_client,
|
||||
issuer,
|
||||
acs,
|
||||
relay_state,
|
||||
binding
|
||||
from projections.saml_requests %s
|
||||
where id = $1 and instance_id = $2
|
||||
limit 1;
|
127
internal/query/saml_request_test.go
Normal file
127
internal/query/saml_request_test.go
Normal file
@ -0,0 +1,127 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/query/projection"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestQueries_SamlRequestByID(t *testing.T) {
|
||||
expQuery := regexp.QuoteMeta(fmt.Sprintf(
|
||||
samlRequestByIDQuery,
|
||||
asOfSystemTime,
|
||||
))
|
||||
|
||||
cols := []string{
|
||||
projection.SamlRequestColumnID,
|
||||
projection.SamlRequestColumnCreationDate,
|
||||
projection.SamlRequestColumnLoginClient,
|
||||
projection.SamlRequestColumnIssuer,
|
||||
projection.SamlRequestColumnACS,
|
||||
projection.SamlRequestColumnRelayState,
|
||||
projection.SamlRequestColumnBinding,
|
||||
}
|
||||
type args struct {
|
||||
shouldTriggerBulk bool
|
||||
id string
|
||||
checkLoginClient bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
expect sqlExpectation
|
||||
want *SamlRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "success, all values",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"loginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
}, "123", "instanceID"),
|
||||
want: &SamlRequest{
|
||||
ID: "id",
|
||||
CreationDate: testNow,
|
||||
LoginClient: "loginClient",
|
||||
Issuer: "issuer",
|
||||
ACS: "acs",
|
||||
RelayState: "relayState",
|
||||
Binding: "binding",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no rows",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.SamlRequest.NotExisting"),
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
},
|
||||
expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"),
|
||||
},
|
||||
{
|
||||
name: "wrong login client",
|
||||
args: args{
|
||||
shouldTriggerBulk: false,
|
||||
id: "123",
|
||||
checkLoginClient: true,
|
||||
},
|
||||
expect: mockQuery(expQuery, cols, []driver.Value{
|
||||
"id",
|
||||
testNow,
|
||||
"wrongLoginClient",
|
||||
"issuer",
|
||||
"acs",
|
||||
"relayState",
|
||||
"binding",
|
||||
}, "123", "instanceID"),
|
||||
wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient"),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
execMock(t, tt.expect, func(db *sql.DB) {
|
||||
q := &Queries{
|
||||
client: &database.DB{
|
||||
DB: db,
|
||||
Database: &prepareDB{},
|
||||
},
|
||||
}
|
||||
ctx := authz.NewMockContext("instanceID", "orgID", "loginClient")
|
||||
|
||||
got, err := q.SamlRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
@ -19,7 +19,7 @@ const (
|
||||
type AddedEvent struct {
|
||||
*eventstore.BaseEvent `json:"-"`
|
||||
|
||||
LoginClient string `json:"loginClient,omitempty"`
|
||||
LoginClient string `json:"login_client,omitempty"`
|
||||
ApplicationID string `json:"application_id,omitempty"`
|
||||
ACSURL string `json:"acs_url,omitempty"`
|
||||
RelayState string `json:"relay_state,omitempty"`
|
||||
|
@ -169,7 +169,7 @@ message CreateCallbackRequest {
|
||||
string auth_request_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "ID of the SAML Request.";
|
||||
description: "ID of the Auth Request.";
|
||||
example: "\"163840776835432705\"";
|
||||
}
|
||||
];
|
||||
|
@ -62,23 +62,11 @@ message AuthorizationError {
|
||||
enum ErrorReason {
|
||||
ERROR_REASON_UNSPECIFIED = 0;
|
||||
|
||||
// Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1
|
||||
ERROR_REASON_INVALID_REQUEST = 1;
|
||||
ERROR_REASON_UNAUTHORIZED_CLIENT = 2;
|
||||
ERROR_REASON_ACCESS_DENIED = 3;
|
||||
ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4;
|
||||
ERROR_REASON_INVALID_SCOPE = 5;
|
||||
ERROR_REASON_SERVER_ERROR = 6;
|
||||
ERROR_REASON_TEMPORARY_UNAVAILABLE = 7;
|
||||
|
||||
// Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError
|
||||
ERROR_REASON_INTERACTION_REQUIRED = 8;
|
||||
ERROR_REASON_LOGIN_REQUIRED = 9;
|
||||
ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10;
|
||||
ERROR_REASON_CONSENT_REQUIRED = 11;
|
||||
ERROR_REASON_INVALID_REQUEST_URI = 12;
|
||||
ERROR_REASON_INVALID_REQUEST_OBJECT = 13;
|
||||
ERROR_REASON_REQUEST_NOT_SUPPORTED = 14;
|
||||
ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15;
|
||||
ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16;
|
||||
ERROR_REASON_VERSION_MISSMATCH = 1;
|
||||
ERROR_REASON_AUTH_N_FAILED = 2;
|
||||
ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE = 3;
|
||||
ERROR_REASON_INVALID_NAMEID_POLICY = 4;
|
||||
ERROR_REASON_REQUEST_DENIED =5;
|
||||
ERROR_REASON_REQUEST_UNSUPPORTED = 6;
|
||||
ERROR_REASON_UNSUPPORTED_BINDING = 7;
|
||||
}
|
@ -123,6 +123,30 @@ service SAMLService {
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) {
|
||||
option (google.api.http) = {
|
||||
post: "/v2/saml/saml_requests/{saml_request_id}"
|
||||
body: "*"
|
||||
};
|
||||
|
||||
option (zitadel.protoc_gen_zitadel.v2.options) = {
|
||||
auth_option: {
|
||||
permission: "authenticated"
|
||||
}
|
||||
};
|
||||
|
||||
option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = {
|
||||
summary: "Finalize a SAML Request and get the callback URL.";
|
||||
description: "Finalize a SAML Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an SAML request."
|
||||
responses: {
|
||||
key: "200"
|
||||
value: {
|
||||
description: "OK";
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
message GetSAMLRequestRequest {
|
||||
@ -140,3 +164,54 @@ message GetSAMLRequestRequest {
|
||||
message GetSAMLRequestResponse {
|
||||
SAMLRequest saml_request = 1;
|
||||
}
|
||||
|
||||
message CreateCallbackRequest {
|
||||
string saml_request_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "ID of the SAML Request.";
|
||||
example: "\"163840776835432705\"";
|
||||
}
|
||||
];
|
||||
|
||||
oneof callback_kind {
|
||||
option (validate.required) = true;
|
||||
Session session = 2;
|
||||
AuthorizationError error = 3 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set.";
|
||||
}
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
message Session {
|
||||
string session_id = 1 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
min_length: 1;
|
||||
max_length: 200;
|
||||
description: "ID of the session, used to login the user. Connects the session to the SAML Request.";
|
||||
example: "\"163840776835432705\"";
|
||||
}
|
||||
];
|
||||
|
||||
string session_token = 2 [
|
||||
(validate.rules).string = {min_len: 1, max_len: 200},
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
min_length: 1;
|
||||
max_length: 200;
|
||||
description: "Token to verify the session is valid";
|
||||
}
|
||||
];
|
||||
}
|
||||
|
||||
message CreateCallbackResponse {
|
||||
zitadel.object.v2.Details details = 1;
|
||||
string callback_url = 2 [
|
||||
(grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = {
|
||||
description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user.";
|
||||
example: "\"https://client.example.org/cb\""
|
||||
}
|
||||
];
|
||||
}
|
Loading…
Reference in New Issue
Block a user