mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 00:57:33 +00:00
feat: add saml request to link to sessions (#9001)
# Which Problems Are Solved It is currently not possible to use SAML with the Session API. # How the Problems Are Solved Add SAML service, to get and resolve SAML requests. Add SAML session and SAML request aggregate, which can be linked to the Session to get back a SAMLResponse from the API directly. # Additional Changes Update of dependency zitadel/saml to provide all functionality for handling of SAML requests and responses. # Additional Context Closes #6053 --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
367
internal/api/grpc/saml/v2/integration/saml_test.go
Normal file
367
internal/api/grpc/saml/v2/integration/saml_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
//go:build integration
|
||||
|
||||
package saml_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/object/v2"
|
||||
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/session/v2"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
Instance *integration.Instance
|
||||
Client saml_pb.SAMLServiceClient
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
Instance = integration.NewInstance(ctx)
|
||||
Client = Instance.Client.SAMLv2
|
||||
|
||||
CTX = Instance.WithAuthorization(ctx, integration.UserTypeOrgOwner)
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestServer_GetAuthRequest(t *testing.T) {
|
||||
rootURL := "https://sp.example.com"
|
||||
idpMetadata, err := Instance.GetSAMLIDPMetadata()
|
||||
require.NoError(t, err)
|
||||
spMiddlewareRedirect, err := integration.CreateSAMLSP(rootURL, idpMetadata, saml.HTTPRedirectBinding)
|
||||
require.NoError(t, err)
|
||||
spMiddlewarePost, err := integration.CreateSAMLSP(rootURL, idpMetadata, saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
|
||||
acsRedirect := idpMetadata.IDPSSODescriptors[0].SingleSignOnServices[0]
|
||||
acsPost := idpMetadata.IDPSSODescriptors[0].SingleSignOnServices[1]
|
||||
|
||||
project, err := Instance.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.CreateSAMLClient(CTX, project.GetId(), spMiddlewareRedirect)
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.CreateSAMLClient(CTX, project.GetId(), spMiddlewarePost)
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dep func() (string, error)
|
||||
want *oidc_pb.GetAuthRequestResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
dep: func() (string, error) {
|
||||
return "123", nil
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success, redirect binding",
|
||||
dep: func() (string, error) {
|
||||
return Instance.CreateSAMLAuthRequest(spMiddlewareRedirect, Instance.Users[integration.UserTypeOrgOwner].ID, acsRedirect, gofakeit.BitcoinAddress(), saml.HTTPRedirectBinding)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success, post binding",
|
||||
dep: func() (string, error) {
|
||||
return Instance.CreateSAMLAuthRequest(spMiddlewarePost, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authRequestID, err := tt.dep()
|
||||
require.NoError(t, err)
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
got, err := Client.GetSAMLRequest(CTX, &saml_pb.GetSAMLRequestRequest{
|
||||
SamlRequestId: authRequestID,
|
||||
})
|
||||
if tt.wantErr {
|
||||
assert.Error(ttt, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(ttt, err)
|
||||
authRequest := got.GetSamlRequest()
|
||||
assert.NotNil(ttt, authRequest)
|
||||
assert.Equal(ttt, authRequestID, authRequest.GetId())
|
||||
assert.WithinRange(ttt, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second))
|
||||
}, retryDuration, tick, "timeout waiting for expected saml request result")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_CreateResponse(t *testing.T) {
|
||||
idpMetadata, err := Instance.GetSAMLIDPMetadata()
|
||||
require.NoError(t, err)
|
||||
rootURLRedirect := "spredirect.example.com"
|
||||
spMiddlewareRedirect, err := integration.CreateSAMLSP("https://"+rootURLRedirect, idpMetadata, saml.HTTPRedirectBinding)
|
||||
require.NoError(t, err)
|
||||
rootURLPost := "sppost.example.com"
|
||||
spMiddlewarePost, err := integration.CreateSAMLSP("https://"+rootURLPost, idpMetadata, saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
|
||||
acsRedirect := idpMetadata.IDPSSODescriptors[0].SingleSignOnServices[0]
|
||||
acsPost := idpMetadata.IDPSSODescriptors[0].SingleSignOnServices[1]
|
||||
|
||||
project, err := Instance.CreateProject(CTX)
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.CreateSAMLClient(CTX, project.GetId(), spMiddlewareRedirect)
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.CreateSAMLClient(CTX, project.GetId(), spMiddlewarePost)
|
||||
require.NoError(t, err)
|
||||
|
||||
sessionResp, err := Instance.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{
|
||||
Checks: &session.Checks{
|
||||
User: &session.CheckUser{
|
||||
Search: &session.CheckUser_UserId{
|
||||
UserId: Instance.Users[integration.UserTypeOrgOwner].ID,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *saml_pb.CreateResponseRequest
|
||||
AuthError string
|
||||
want *saml_pb.CreateResponseResponse
|
||||
wantURL *url.URL
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Not found",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: "123",
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewareRedirect, Instance.Users[integration.UserTypeOrgOwner].ID, acsRedirect, gofakeit.BitcoinAddress(), saml.HTTPRedirectBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: "foo",
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "session token invalid",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewareRedirect, Instance.Users[integration.UserTypeOrgOwner].ID, acsRedirect, gofakeit.BitcoinAddress(), saml.HTTPRedirectBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: "bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail callback, post",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewarePost, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Error{
|
||||
Error: &saml_pb.AuthorizationError{
|
||||
Error: saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &saml_pb.CreateResponseResponse{
|
||||
Url: regexp.QuoteMeta(`https://` + rootURLPost + `/saml/acs`),
|
||||
Binding: &saml_pb.CreateResponseResponse_Post{Post: &saml_pb.PostResponse{
|
||||
RelayState: "notempty",
|
||||
SamlResponse: "notempty",
|
||||
}},
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "fail callback, post, already failed",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewarePost, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
Instance.FailSAMLAuthRequest(CTX, authRequestID, saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Error{
|
||||
Error: &saml_pb.AuthorizationError{
|
||||
Error: saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "fail callback, redirect",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewareRedirect, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Error{
|
||||
Error: &saml_pb.AuthorizationError{
|
||||
Error: saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED,
|
||||
ErrorDescription: gu.Ptr("nope"),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &saml_pb.CreateResponseResponse{
|
||||
Url: `https:\/\/` + rootURLRedirect + `\/saml\/acs\?RelayState=(.*)&SAMLResponse=(.*)`,
|
||||
Binding: &saml_pb.CreateResponseResponse_Redirect{Redirect: &saml_pb.RedirectResponse{}},
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "callback, redirect",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewareRedirect, Instance.Users[integration.UserTypeOrgOwner].ID, acsRedirect, gofakeit.BitcoinAddress(), saml.HTTPRedirectBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &saml_pb.CreateResponseResponse{
|
||||
Url: `https:\/\/` + rootURLRedirect + `\/saml\/acs\?RelayState=(.*)&SAMLResponse=(.*)&SigAlg=(.*)&Signature=(.*)`,
|
||||
Binding: &saml_pb.CreateResponseResponse_Redirect{Redirect: &saml_pb.RedirectResponse{}},
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "callback, post",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewarePost, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &saml_pb.CreateResponseResponse{
|
||||
Url: regexp.QuoteMeta(`https://` + rootURLPost + `/saml/acs`),
|
||||
Binding: &saml_pb.CreateResponseResponse_Post{Post: &saml_pb.PostResponse{
|
||||
RelayState: "notempty",
|
||||
SamlResponse: "notempty",
|
||||
}},
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.ID(),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "callback, post",
|
||||
req: &saml_pb.CreateResponseRequest{
|
||||
SamlRequestId: func() string {
|
||||
authRequestID, err := Instance.CreateSAMLAuthRequest(spMiddlewarePost, Instance.Users[integration.UserTypeOrgOwner].ID, acsPost, gofakeit.BitcoinAddress(), saml.HTTPPostBinding)
|
||||
require.NoError(t, err)
|
||||
Instance.SuccessfulSAMLAuthRequest(CTX, Instance.Users[integration.UserTypeOrgOwner].ID, authRequestID)
|
||||
return authRequestID
|
||||
}(),
|
||||
ResponseKind: &saml_pb.CreateResponseRequest_Session{
|
||||
Session: &saml_pb.Session{
|
||||
SessionId: sessionResp.GetSessionId(),
|
||||
SessionToken: sessionResp.GetSessionToken(),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.CreateResponse(CTX, tt.req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
integration.AssertDetails(t, tt.want, got)
|
||||
if tt.want != nil {
|
||||
assert.Regexp(t, regexp.MustCompile(tt.want.Url), got.GetUrl())
|
||||
if tt.want.GetPost() != nil {
|
||||
assert.NotEmpty(t, got.GetPost().GetRelayState())
|
||||
assert.NotEmpty(t, got.GetPost().GetSamlResponse())
|
||||
}
|
||||
if tt.want.GetRedirect() != nil {
|
||||
assert.NotNil(t, got.GetRedirect())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
112
internal/api/grpc/saml/v2/saml.go
Normal file
112
internal/api/grpc/saml/v2/saml.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object/v2"
|
||||
"github.com/zitadel/zitadel/internal/api/saml"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
)
|
||||
|
||||
func (s *Server) GetAuthRequest(ctx context.Context, req *saml_pb.GetSAMLRequestRequest) (*saml_pb.GetSAMLRequestResponse, error) {
|
||||
authRequest, err := s.query.SamlRequestByID(ctx, true, req.GetSamlRequestId(), true)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error("query samlRequest by ID")
|
||||
return nil, err
|
||||
}
|
||||
return &saml_pb.GetSAMLRequestResponse{
|
||||
SamlRequest: samlRequestToPb(authRequest),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func samlRequestToPb(a *query.SamlRequest) *saml_pb.SAMLRequest {
|
||||
return &saml_pb.SAMLRequest{
|
||||
Id: a.ID,
|
||||
CreationDate: timestamppb.New(a.CreationDate),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) CreateResponse(ctx context.Context, req *saml_pb.CreateResponseRequest) (*saml_pb.CreateResponseResponse, error) {
|
||||
switch v := req.GetResponseKind().(type) {
|
||||
case *saml_pb.CreateResponseRequest_Error:
|
||||
return s.failSAMLRequest(ctx, req.GetSamlRequestId(), v.Error)
|
||||
case *saml_pb.CreateResponseRequest_Session:
|
||||
return s.linkSessionToSAMLRequest(ctx, req.GetSamlRequestId(), v.Session)
|
||||
default:
|
||||
return nil, zerrors.ThrowUnimplementedf(nil, "SAMLv2-0Tfak3fBS0", "verification oneOf %T in method CreateResponse not implemented", v)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) failSAMLRequest(ctx context.Context, samlRequestID string, ae *saml_pb.AuthorizationError) (*saml_pb.CreateResponseResponse, error) {
|
||||
details, aar, err := s.command.FailSAMLRequest(ctx, samlRequestID, errorReasonToDomain(ae.GetError()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
|
||||
url, body, err := s.idp.CreateErrorResponse(authReq, errorReasonToDomain(ae.GetError()), ae.GetErrorDescription())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return createCallbackResponseFromBinding(details, url, body, authReq.RelayState), nil
|
||||
}
|
||||
|
||||
func (s *Server) linkSessionToSAMLRequest(ctx context.Context, samlRequestID string, session *saml_pb.Session) (*saml_pb.CreateResponseResponse, error) {
|
||||
details, aar, err := s.command.LinkSessionToSAMLRequest(ctx, samlRequestID, session.GetSessionId(), session.GetSessionToken(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authReq := &saml.AuthRequestV2{CurrentSAMLRequest: aar}
|
||||
url, body, err := s.idp.CreateResponse(ctx, authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return createCallbackResponseFromBinding(details, url, body, authReq.RelayState), nil
|
||||
}
|
||||
|
||||
func createCallbackResponseFromBinding(details *domain.ObjectDetails, url string, body string, relayState string) *saml_pb.CreateResponseResponse {
|
||||
resp := &saml_pb.CreateResponseResponse{
|
||||
Details: object.DomainToDetailsPb(details),
|
||||
Url: url,
|
||||
}
|
||||
|
||||
if body != "" {
|
||||
resp.Binding = &saml_pb.CreateResponseResponse_Post{
|
||||
Post: &saml_pb.PostResponse{
|
||||
RelayState: relayState,
|
||||
SamlResponse: body,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
resp.Binding = &saml_pb.CreateResponseResponse_Redirect{Redirect: &saml_pb.RedirectResponse{}}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func errorReasonToDomain(errorReason saml_pb.ErrorReason) domain.SAMLErrorReason {
|
||||
switch errorReason {
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSPECIFIED:
|
||||
return domain.SAMLErrorReasonUnspecified
|
||||
case saml_pb.ErrorReason_ERROR_REASON_VERSION_MISSMATCH:
|
||||
return domain.SAMLErrorReasonVersionMissmatch
|
||||
case saml_pb.ErrorReason_ERROR_REASON_AUTH_N_FAILED:
|
||||
return domain.SAMLErrorReasonAuthNFailed
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_ATTR_NAME_OR_VALUE:
|
||||
return domain.SAMLErrorReasonInvalidAttrNameOrValue
|
||||
case saml_pb.ErrorReason_ERROR_REASON_INVALID_NAMEID_POLICY:
|
||||
return domain.SAMLErrorReasonInvalidNameIDPolicy
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_DENIED:
|
||||
return domain.SAMLErrorReasonRequestDenied
|
||||
case saml_pb.ErrorReason_ERROR_REASON_REQUEST_UNSUPPORTED:
|
||||
return domain.SAMLErrorReasonRequestUnsupported
|
||||
case saml_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_BINDING:
|
||||
return domain.SAMLErrorReasonUnsupportedBinding
|
||||
default:
|
||||
return domain.SAMLErrorReasonUnspecified
|
||||
}
|
||||
}
|
59
internal/api/grpc/saml/v2/server.go
Normal file
59
internal/api/grpc/saml/v2/server.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server"
|
||||
"github.com/zitadel/zitadel/internal/api/saml"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
saml_pb "github.com/zitadel/zitadel/pkg/grpc/saml/v2"
|
||||
)
|
||||
|
||||
var _ saml_pb.SAMLServiceServer = (*Server)(nil)
|
||||
|
||||
type Server struct {
|
||||
saml_pb.UnimplementedSAMLServiceServer
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
idp *saml.Provider
|
||||
externalSecure bool
|
||||
}
|
||||
|
||||
type Config struct{}
|
||||
|
||||
func CreateServer(
|
||||
command *command.Commands,
|
||||
query *query.Queries,
|
||||
idp *saml.Provider,
|
||||
externalSecure bool,
|
||||
) *Server {
|
||||
return &Server{
|
||||
command: command,
|
||||
query: query,
|
||||
idp: idp,
|
||||
externalSecure: externalSecure,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) RegisterServer(grpcServer *grpc.Server) {
|
||||
saml_pb.RegisterSAMLServiceServer(grpcServer, s)
|
||||
}
|
||||
|
||||
func (s *Server) AppName() string {
|
||||
return saml_pb.SAMLService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) MethodPrefix() string {
|
||||
return saml_pb.SAMLService_ServiceDesc.ServiceName
|
||||
}
|
||||
|
||||
func (s *Server) AuthMethods() authz.MethodMapping {
|
||||
return saml_pb.SAMLService_AuthMethods
|
||||
}
|
||||
|
||||
func (s *Server) RegisterGateway() server.RegisterGatewayFunc {
|
||||
return saml_pb.RegisterSAMLServiceHandler
|
||||
}
|
99
internal/api/saml/auth_request.go
Normal file
99
internal/api/saml/auth_request.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net/url"
|
||||
|
||||
"github.com/zitadel/saml/pkg/provider"
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
"github.com/zitadel/saml/pkg/provider/xml"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
func (p *Provider) CreateErrorResponse(authReq models.AuthRequestInt, reason domain.SAMLErrorReason, description string) (string, string, error) {
|
||||
resp := &provider.Response{
|
||||
ProtocolBinding: authReq.GetBindingType(),
|
||||
RelayState: authReq.GetRelayState(),
|
||||
AcsUrl: authReq.GetAccessConsumerServiceURL(),
|
||||
RequestID: authReq.GetAuthRequestID(),
|
||||
Issuer: authReq.GetDestination(),
|
||||
Audience: authReq.GetIssuer(),
|
||||
}
|
||||
return createResponse(p.AuthCallbackErrorResponse(resp, domain.SAMLErrorReasonToString(reason), description), authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
|
||||
}
|
||||
|
||||
func (p *Provider) CreateResponse(ctx context.Context, authReq models.AuthRequestInt) (string, string, error) {
|
||||
resp := &provider.Response{
|
||||
ProtocolBinding: authReq.GetBindingType(),
|
||||
RelayState: authReq.GetRelayState(),
|
||||
AcsUrl: authReq.GetAccessConsumerServiceURL(),
|
||||
RequestID: authReq.GetAuthRequestID(),
|
||||
Issuer: authReq.GetDestination(),
|
||||
Audience: authReq.GetIssuer(),
|
||||
}
|
||||
samlResponse, err := p.AuthCallbackResponse(ctx, authReq, resp)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := p.command.CreateSAMLSessionFromSAMLRequest(
|
||||
setContextUserSystem(ctx),
|
||||
authReq.GetID(),
|
||||
samlComplianceChecker(),
|
||||
samlResponse.Id,
|
||||
p.Expiration(),
|
||||
); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return createResponse(samlResponse, authReq.GetBindingType(), authReq.GetAccessConsumerServiceURL(), resp.RelayState, resp.SigAlg, resp.Signature)
|
||||
}
|
||||
|
||||
func createResponse(samlResponse interface{}, binding, acs, relayState, sigAlg, sig string) (string, string, error) {
|
||||
respData, err := xml.Marshal(samlResponse)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
switch binding {
|
||||
case provider.PostBinding:
|
||||
return acs, base64.StdEncoding.EncodeToString(respData), nil
|
||||
case provider.RedirectBinding:
|
||||
respData, err := xml.DeflateAndBase64(respData)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
parsed, err := url.Parse(acs)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
values := parsed.Query()
|
||||
values.Add("SAMLResponse", string(respData))
|
||||
values.Add("RelayState", relayState)
|
||||
values.Add("SigAlg", sigAlg)
|
||||
values.Add("Signature", sig)
|
||||
parsed.RawQuery = values.Encode()
|
||||
return parsed.String(), "", nil
|
||||
}
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
func setContextUserSystem(ctx context.Context) context.Context {
|
||||
data := authz.CtxData{
|
||||
UserID: "SYSTEM",
|
||||
}
|
||||
return authz.SetCtxData(ctx, data)
|
||||
}
|
||||
|
||||
func samlComplianceChecker() command.SAMLRequestComplianceChecker {
|
||||
return func(_ context.Context, samlReq *command.SAMLRequestWriteModel) error {
|
||||
if err := samlReq.CheckAuthenticated(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
45
internal/api/saml/auth_request_converter_v2.go
Normal file
45
internal/api/saml/auth_request_converter_v2.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"github.com/zitadel/saml/pkg/provider/models"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
)
|
||||
|
||||
var _ models.AuthRequestInt = &AuthRequestV2{}
|
||||
|
||||
type AuthRequestV2 struct {
|
||||
*command.CurrentSAMLRequest
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetApplicationID() string {
|
||||
return a.ApplicationID
|
||||
}
|
||||
|
||||
func (a *AuthRequestV2) GetID() string {
|
||||
return a.ID
|
||||
}
|
||||
func (a *AuthRequestV2) GetRelayState() string {
|
||||
return a.RelayState
|
||||
}
|
||||
func (a *AuthRequestV2) GetAccessConsumerServiceURL() string {
|
||||
return a.ACSURL
|
||||
}
|
||||
func (a *AuthRequestV2) GetAuthRequestID() string {
|
||||
return a.RequestID
|
||||
}
|
||||
func (a *AuthRequestV2) GetBindingType() string {
|
||||
return a.Binding
|
||||
}
|
||||
func (a *AuthRequestV2) GetIssuer() string {
|
||||
return a.Issuer
|
||||
}
|
||||
func (a *AuthRequestV2) GetDestination() string {
|
||||
return a.Destination
|
||||
}
|
||||
func (a *AuthRequestV2) GetUserID() string {
|
||||
return a.UserID
|
||||
}
|
||||
func (a *AuthRequestV2) Done() bool {
|
||||
return a.UserID != "" && a.SessionID != ""
|
||||
}
|
@@ -24,7 +24,13 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ProviderConfig *provider.Config
|
||||
ProviderConfig *provider.Config
|
||||
DefaultLoginURLV2 string
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
*provider.Provider
|
||||
command *command.Commands
|
||||
}
|
||||
|
||||
func NewProvider(
|
||||
@@ -40,7 +46,7 @@ func NewProvider(
|
||||
instanceHandler,
|
||||
userAgentCookie func(http.Handler) http.Handler,
|
||||
accessHandler *middleware.AccessInterceptor,
|
||||
) (*provider.Provider, error) {
|
||||
) (*Provider, error) {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
|
||||
provStorage, err := newStorage(
|
||||
@@ -51,6 +57,8 @@ func NewProvider(
|
||||
certEncAlg,
|
||||
es,
|
||||
projections,
|
||||
fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
conf.DefaultLoginURLV2,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -73,12 +81,19 @@ func NewProvider(
|
||||
options = append(options, provider.WithAllowInsecure())
|
||||
}
|
||||
|
||||
return provider.NewProvider(
|
||||
p, err := provider.NewProvider(
|
||||
provStorage,
|
||||
HandlerPrefix,
|
||||
conf.ProviderConfig,
|
||||
options...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Provider{
|
||||
p,
|
||||
command,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newStorage(
|
||||
@@ -89,16 +104,19 @@ func newStorage(
|
||||
certEncAlg crypto.EncryptionAlgorithm,
|
||||
es *eventstore.Eventstore,
|
||||
db *database.DB,
|
||||
defaultLoginURL string,
|
||||
defaultLoginURLV2 string,
|
||||
) (*Storage, error) {
|
||||
return &Storage{
|
||||
encAlg: encAlg,
|
||||
certEncAlg: certEncAlg,
|
||||
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
|
||||
eventstore: es,
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: query,
|
||||
defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID),
|
||||
encAlg: encAlg,
|
||||
certEncAlg: certEncAlg,
|
||||
locker: crdb.NewLocker(db.DB, locksTable, signingKey),
|
||||
eventstore: es,
|
||||
repo: repo,
|
||||
command: command,
|
||||
query: query,
|
||||
defaultLoginURL: defaultLoginURL,
|
||||
defaultLoginURLv2: defaultLoginURLV2,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@@ -3,6 +3,7 @@ package saml
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/actions"
|
||||
"github.com/zitadel/zitadel/internal/actions/object"
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/auth/repository"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
@@ -33,6 +35,10 @@ var _ provider.IdentityProviderStorage = &Storage{}
|
||||
var _ provider.AuthStorage = &Storage{}
|
||||
var _ provider.UserStorage = &Storage{}
|
||||
|
||||
const (
|
||||
LoginClientHeader = "x-zitadel-login-client"
|
||||
)
|
||||
|
||||
type Storage struct {
|
||||
certChan <-chan interface{}
|
||||
defaultCertificateLifetime time.Duration
|
||||
@@ -51,7 +57,8 @@ type Storage struct {
|
||||
command *command.Commands
|
||||
query *query.Queries
|
||||
|
||||
defaultLoginURL string
|
||||
defaultLoginURL string
|
||||
defaultLoginURLv2 string
|
||||
}
|
||||
|
||||
func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*serviceprovider.ServiceProvider, error) {
|
||||
@@ -64,7 +71,12 @@ func (p *Storage) GetEntityByID(ctx context.Context, entityID string) (*servicep
|
||||
&serviceprovider.Config{
|
||||
Metadata: app.SAMLConfig.Metadata,
|
||||
},
|
||||
p.defaultLoginURL,
|
||||
func(id string) string {
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return p.defaultLoginURLv2 + id
|
||||
}
|
||||
return p.defaultLoginURL + id
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -95,6 +107,38 @@ func (p *Storage) GetResponseSigningKey(ctx context.Context) (*key.CertificateAn
|
||||
func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||
return p.createAuthRequestLoginClient(ctx, req, acsUrl, protocolBinding, relayState, applicationID, loginClient)
|
||||
}
|
||||
return p.createAuthRequest(ctx, req, acsUrl, protocolBinding, relayState, applicationID)
|
||||
}
|
||||
|
||||
func (p *Storage) createAuthRequestLoginClient(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID, loginClient string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
samlRequest := &command.SAMLRequest{
|
||||
ApplicationID: applicationID,
|
||||
ACSURL: acsUrl,
|
||||
RelayState: relayState,
|
||||
RequestID: req.Id,
|
||||
Binding: protocolBinding,
|
||||
Issuer: req.Issuer.Text,
|
||||
Destination: req.Destination,
|
||||
LoginClient: loginClient,
|
||||
}
|
||||
|
||||
aar, err := p.command.AddSAMLRequest(ctx, samlRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{aar}, nil
|
||||
}
|
||||
|
||||
func (p *Storage) createAuthRequest(ctx context.Context, req *samlp.AuthnRequestType, acsUrl, protocolBinding, relayState, applicationID string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-sd436", "no user agent id")
|
||||
@@ -113,6 +157,15 @@ func (p *Storage) CreateAuthRequest(ctx context.Context, req *samlp.AuthnRequest
|
||||
func (p *Storage) AuthRequestByID(ctx context.Context, id string) (_ models.AuthRequestInt, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
req, err := p.command.GetCurrentSAMLRequest(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthRequestV2{req}, nil
|
||||
}
|
||||
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "SAML-D3g21", "no user agent id")
|
||||
|
Reference in New Issue
Block a user