feat: add SAML as identity provider (#6454)

* feat: first implementation for saml sp

* fix: add command side instance and org for saml provider

* fix: add query side instance and org for saml provider

* fix: request handling in event and retrieval of finished intent

* fix: add review changes and integration tests

* fix: add integration tests for saml idp

* fix: correct unit tests with review changes

* fix: add saml session unit test

* fix: add saml session unit test

* fix: add saml session unit test

* fix: changes from review

* fix: changes from review

* fix: proto build error

* fix: proto build error

* fix: proto build error

* fix: proto require metadata oneof

* fix: login with saml provider

* fix: integration test for saml assertion

* lint client.go

* fix json tag

* fix: linting

* fix import

* fix: linting

* fix saml idp query

* fix: linting

* lint: try all issues

* revert linting config

* fix: add regenerate endpoints

* fix: translations

* fix mk.yaml

* ignore acs path for user agent cookie

* fix: add AuthFromProvider test for saml

* fix: integration test for saml retrieve information

---------

Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
Stefan Benz
2023-09-29 11:26:14 +02:00
committed by GitHub
parent 2e99d0fe1b
commit 15fd3045e0
82 changed files with 6301 additions and 245 deletions

View File

@@ -426,6 +426,37 @@ func (s *Server) UpdateAppleProvider(ctx context.Context, req *admin_pb.UpdateAp
}, nil
}
func (s *Server) AddSAMLProvider(ctx context.Context, req *admin_pb.AddSAMLProviderRequest) (*admin_pb.AddSAMLProviderResponse, error) {
id, details, err := s.command.AddInstanceSAMLProvider(ctx, addSAMLProviderToCommand(req))
if err != nil {
return nil, err
}
return &admin_pb.AddSAMLProviderResponse{
Id: id,
Details: object_pb.DomainToAddDetailsPb(details),
}, nil
}
func (s *Server) UpdateSAMLProvider(ctx context.Context, req *admin_pb.UpdateSAMLProviderRequest) (*admin_pb.UpdateSAMLProviderResponse, error) {
details, err := s.command.UpdateInstanceSAMLProvider(ctx, req.Id, updateSAMLProviderToCommand(req))
if err != nil {
return nil, err
}
return &admin_pb.UpdateSAMLProviderResponse{
Details: object_pb.DomainToChangeDetailsPb(details),
}, nil
}
func (s *Server) RegenerateSAMLProviderCertificate(ctx context.Context, req *admin_pb.RegenerateSAMLProviderCertificateRequest) (*admin_pb.RegenerateSAMLProviderCertificateResponse, error) {
details, err := s.command.RegenerateInstanceSAMLProviderCertificate(ctx, req.Id)
if err != nil {
return nil, err
}
return &admin_pb.RegenerateSAMLProviderCertificateResponse{
Details: object_pb.DomainToChangeDetailsPb(details),
}, nil
}
func (s *Server) DeleteProvider(ctx context.Context, req *admin_pb.DeleteProviderRequest) (*admin_pb.DeleteProviderResponse, error) {
details, err := s.command.DeleteInstanceProvider(ctx, req.Id)
if err != nil {

View File

@@ -1,6 +1,8 @@
package admin
import (
"github.com/crewjam/saml"
idp_grpc "github.com/zitadel/zitadel/internal/api/grpc/idp"
"github.com/zitadel/zitadel/internal/api/grpc/object"
"github.com/zitadel/zitadel/internal/command"
@@ -9,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/query"
admin_pb "github.com/zitadel/zitadel/pkg/grpc/admin"
idp_pb "github.com/zitadel/zitadel/pkg/grpc/idp"
)
func addOIDCIDPRequestToDomain(req *admin_pb.AddOIDCIDPRequest) *domain.IDPConfig {
@@ -464,3 +467,40 @@ func updateAppleProviderToCommand(req *admin_pb.UpdateAppleProviderRequest) comm
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func addSAMLProviderToCommand(req *admin_pb.AddSAMLProviderRequest) command.SAMLProvider {
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func updateSAMLProviderToCommand(req *admin_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func bindingToCommand(binding idp_pb.SAMLBinding) string {
switch binding {
case idp_pb.SAMLBinding_SAML_BINDING_UNSPECIFIED:
return ""
case idp_pb.SAMLBinding_SAML_BINDING_POST:
return saml.HTTPPostBinding
case idp_pb.SAMLBinding_SAML_BINDING_REDIRECT:
return saml.HTTPRedirectBinding
case idp_pb.SAMLBinding_SAML_BINDING_ARTIFACT:
return saml.HTTPArtifactBinding
default:
return ""
}
}

View File

@@ -418,6 +418,37 @@ func (s *Server) UpdateAppleProvider(ctx context.Context, req *mgmt_pb.UpdateApp
}, nil
}
func (s *Server) AddSAMLProvider(ctx context.Context, req *mgmt_pb.AddSAMLProviderRequest) (*mgmt_pb.AddSAMLProviderResponse, error) {
id, details, err := s.command.AddOrgSAMLProvider(ctx, authz.GetCtxData(ctx).OrgID, addSAMLProviderToCommand(req))
if err != nil {
return nil, err
}
return &mgmt_pb.AddSAMLProviderResponse{
Id: id,
Details: object_pb.DomainToAddDetailsPb(details),
}, nil
}
func (s *Server) UpdateSAMLProvider(ctx context.Context, req *mgmt_pb.UpdateSAMLProviderRequest) (*mgmt_pb.UpdateSAMLProviderResponse, error) {
details, err := s.command.UpdateOrgSAMLProvider(ctx, authz.GetCtxData(ctx).OrgID, req.Id, updateSAMLProviderToCommand(req))
if err != nil {
return nil, err
}
return &mgmt_pb.UpdateSAMLProviderResponse{
Details: object_pb.DomainToChangeDetailsPb(details),
}, nil
}
func (s *Server) RegenerateSAMLProviderCertificate(ctx context.Context, req *mgmt_pb.RegenerateSAMLProviderCertificateRequest) (*mgmt_pb.RegenerateSAMLProviderCertificateResponse, error) {
details, err := s.command.RegenerateOrgSAMLProviderCertificate(ctx, authz.GetCtxData(ctx).OrgID, req.Id)
if err != nil {
return nil, err
}
return &mgmt_pb.RegenerateSAMLProviderCertificateResponse{
Details: object_pb.DomainToChangeDetailsPb(details),
}, nil
}
func (s *Server) DeleteProvider(ctx context.Context, req *mgmt_pb.DeleteProviderRequest) (*mgmt_pb.DeleteProviderResponse, error) {
details, err := s.command.DeleteOrgProvider(ctx, authz.GetCtxData(ctx).OrgID, req.Id)
if err != nil {

View File

@@ -3,6 +3,8 @@ package management
import (
"context"
"github.com/crewjam/saml"
"github.com/zitadel/zitadel/internal/api/authz"
idp_grpc "github.com/zitadel/zitadel/internal/api/grpc/idp"
"github.com/zitadel/zitadel/internal/api/grpc/object"
@@ -12,6 +14,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
iam_model "github.com/zitadel/zitadel/internal/iam/model"
"github.com/zitadel/zitadel/internal/query"
idp_pb "github.com/zitadel/zitadel/pkg/grpc/idp"
mgmt_pb "github.com/zitadel/zitadel/pkg/grpc/management"
)
@@ -481,3 +484,40 @@ func updateAppleProviderToCommand(req *mgmt_pb.UpdateAppleProviderRequest) comma
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func addSAMLProviderToCommand(req *mgmt_pb.AddSAMLProviderRequest) command.SAMLProvider {
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func updateSAMLProviderToCommand(req *mgmt_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
return command.SAMLProvider{
Name: req.Name,
Metadata: req.GetMetadataXml(),
MetadataURL: req.GetMetadataUrl(),
Binding: bindingToCommand(req.Binding),
WithSignedRequest: req.WithSignedRequest,
IDPOptions: idp_grpc.OptionsToCommand(req.ProviderOptions),
}
}
func bindingToCommand(binding idp_pb.SAMLBinding) string {
switch binding {
case idp_pb.SAMLBinding_SAML_BINDING_UNSPECIFIED:
return ""
case idp_pb.SAMLBinding_SAML_BINDING_POST:
return saml.HTTPPostBinding
case idp_pb.SAMLBinding_SAML_BINDING_REDIRECT:
return saml.HTTPRedirectBinding
case idp_pb.SAMLBinding_SAML_BINDING_ARTIFACT:
return saml.HTTPArtifactBinding
default:
return ""
}
}

View File

@@ -22,6 +22,7 @@ type Server struct {
userCodeAlg crypto.EncryptionAlgorithm
idpAlg crypto.EncryptionAlgorithm
idpCallback func(ctx context.Context) string
samlRootURL func(ctx context.Context, idpID string) string
}
type Config struct{}
@@ -32,6 +33,7 @@ func CreateServer(
userCodeAlg crypto.EncryptionAlgorithm,
idpAlg crypto.EncryptionAlgorithm,
idpCallback func(ctx context.Context) string,
samlRootURL func(ctx context.Context, idpID string) string,
) *Server {
return &Server{
command: command,
@@ -39,6 +41,7 @@ func CreateServer(
userCodeAlg: userCodeAlg,
idpAlg: idpAlg,
idpCallback: idpCallback,
samlRootURL: samlRootURL,
}
}

View File

@@ -145,14 +145,23 @@ func (s *Server) startIDPIntent(ctx context.Context, idpID string, urls *user.Re
if err != nil {
return nil, err
}
authURL, err := s.command.AuthURLFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx))
content, redirect, err := s.command.AuthFromProvider(ctx, idpID, intentWriteModel.AggregateID, s.idpCallback(ctx), s.samlRootURL(ctx, idpID))
if err != nil {
return nil, err
}
return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: authURL},
}, nil
if redirect {
return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{AuthUrl: content},
}, nil
} else {
return &user.StartIdentityProviderIntentResponse{
Details: object.DomainToDetailsPb(details),
NextStep: &user.StartIdentityProviderIntentResponse_PostForm{
PostForm: []byte(content),
},
}, nil
}
}
func (s *Server) startLDAPIntent(ctx context.Context, idpID string, ldapCredentials *user.LDAPCredentials) (*user.StartIdentityProviderIntentResponse, error) {
@@ -206,7 +215,7 @@ func (s *Server) checkLinkedExternalUser(ctx context.Context, idpID, externalUse
}
func (s *Server) ldapLogin(ctx context.Context, idpID, username, password string) (idp.User, string, map[string][]string, error) {
provider, err := s.command.GetProvider(ctx, idpID, "")
provider, err := s.command.GetProvider(ctx, idpID, "", "")
if err != nil {
return nil, "", nil, err
}
@@ -279,6 +288,14 @@ func idpIntentToIDPIntentPb(intent *command.IDPIntentWriteModel, alg crypto.Encr
information.IdpInformation.Access = access
}
if intent.Assertion != nil {
assertion, err := crypto.Decrypt(intent.Assertion, alg)
if err != nil {
return nil, err
}
information.IdpInformation.Access = IDPSAMLResponseToPb(assertion)
}
return information, nil
}
@@ -330,6 +347,14 @@ func IDPEntryAttributesToPb(entryAttributes map[string][]string) (*user.IDPInfor
}, nil
}
func IDPSAMLResponseToPb(assertion []byte) *user.IDPInformation_Saml {
return &user.IDPInformation_Saml{
Saml: &user.IDPSAMLAccessInformation{
Assertion: assertion,
},
}
}
func (s *Server) checkIntentToken(token string, intentID string) error {
return crypto.CheckToken(s.idpAlg, token, intentID)
}

View File

@@ -5,8 +5,8 @@ package user_test
import (
"context"
"fmt"
"net/url"
"os"
"strings"
"testing"
"time"
@@ -624,14 +624,24 @@ func TestServer_AddIDPLink(t *testing.T) {
func TestServer_StartIdentityProviderIntent(t *testing.T) {
idpID := Tester.AddGenericOAuthProvider(t)
samlIdpID := Tester.AddSAMLProvider(t)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
samlPostIdpID := Tester.AddSAMLPostProvider(t)
type args struct {
ctx context.Context
req *user.StartIdentityProviderIntentRequest
}
type want struct {
details *object.Details
url string
parametersExisting []string
parametersEqual map[string]string
postForm bool
}
tests := []struct {
name string
args args
want *user.StartIdentityProviderIntentResponse
want want
wantErr bool
}{
{
@@ -642,11 +652,10 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) {
IdpId: idpID,
},
},
want: nil,
wantErr: true,
},
{
name: "next step auth url",
name: "next step oauth auth url",
args: args{
CTX,
&user.StartIdentityProviderIntentRequest{
@@ -659,14 +668,91 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) {
},
},
},
want: &user.StartIdentityProviderIntentResponse{
Details: &object.Details{
want: want{
details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
NextStep: &user.StartIdentityProviderIntentResponse_AuthUrl{
AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=",
url: "https://example.com/oauth/v2/authorize",
parametersEqual: map[string]string{
"client_id": "clientID",
"prompt": "select_account",
"redirect_uri": "http://localhost:8080/idps/callback",
"response_type": "code",
"scope": "openid profile email",
},
parametersExisting: []string{"state"},
},
wantErr: false,
},
{
name: "next step saml default",
args: args{
CTX,
&user.StartIdentityProviderIntentRequest{
IdpId: samlIdpID,
Content: &user.StartIdentityProviderIntentRequest_Urls{
Urls: &user.RedirectURLs{
SuccessUrl: "https://example.com/success",
FailureUrl: "https://example.com/failure",
},
},
},
},
want: want{
details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
url: "http://localhost:8000/sso",
parametersExisting: []string{"RelayState", "SAMLRequest"},
},
wantErr: false,
},
{
name: "next step saml auth url",
args: args{
CTX,
&user.StartIdentityProviderIntentRequest{
IdpId: samlRedirectIdpID,
Content: &user.StartIdentityProviderIntentRequest_Urls{
Urls: &user.RedirectURLs{
SuccessUrl: "https://example.com/success",
FailureUrl: "https://example.com/failure",
},
},
},
},
want: want{
details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
url: "http://localhost:8000/sso",
parametersExisting: []string{"RelayState", "SAMLRequest"},
},
wantErr: false,
},
{
name: "next step saml form",
args: args{
CTX,
&user.StartIdentityProviderIntentRequest{
IdpId: samlPostIdpID,
Content: &user.StartIdentityProviderIntentRequest_Urls{
Urls: &user.RedirectURLs{
SuccessUrl: "https://example.com/success",
FailureUrl: "https://example.com/failure",
},
},
},
},
want: want{
details: &object.Details{
ChangeDate: timestamppb.Now(),
ResourceOwner: Tester.Organisation.ID,
},
postForm: true,
},
wantErr: false,
},
@@ -680,12 +766,25 @@ func TestServer_StartIdentityProviderIntent(t *testing.T) {
require.NoError(t, err)
}
if nextStep := tt.want.GetNextStep(); nextStep != nil {
if !strings.HasPrefix(got.GetAuthUrl(), tt.want.GetAuthUrl()) {
assert.Failf(t, "auth url does not match", "expected: %s, but got: %s", tt.want.GetAuthUrl(), got.GetAuthUrl())
if tt.want.url != "" {
authUrl, err := url.Parse(got.GetAuthUrl())
assert.NoError(t, err)
assert.Len(t, authUrl.Query(), len(tt.want.parametersEqual)+len(tt.want.parametersExisting))
for _, existing := range tt.want.parametersExisting {
assert.True(t, authUrl.Query().Has(existing))
}
for key, equal := range tt.want.parametersEqual {
assert.Equal(t, equal, authUrl.Query().Get(key))
}
}
integration.AssertDetails(t, tt.want, got)
if tt.want.postForm {
assert.NotEmpty(t, got.GetPostForm())
}
integration.AssertDetails(t, &user.StartIdentityProviderIntentResponse{
Details: tt.want.details,
}, got)
})
}
}
@@ -697,6 +796,7 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) {
successfulWithUserID, WithUsertoken, WithUserchangeDate, WithUsersequence := Tester.CreateSuccessfulOAuthIntent(t, idpID, "user", "id")
ldapSuccessfulID, ldapToken, ldapChangeDate, ldapSequence := Tester.CreateSuccessfulLDAPIntent(t, idpID, "", "id")
ldapSuccessfulWithUserID, ldapWithUserToken, ldapWithUserChangeDate, ldapWithUserSequence := Tester.CreateSuccessfulLDAPIntent(t, idpID, "user", "id")
samlSuccessfulID, samlToken, samlChangeDate, samlSequence := Tester.CreateSuccessfulSAMLIntent(t, idpID, "", "id")
type args struct {
ctx context.Context
req *user.RetrieveIdentityProviderIntentRequest
@@ -895,6 +995,44 @@ func TestServer_RetrieveIdentityProviderIntent(t *testing.T) {
},
wantErr: false,
},
{
name: "retrieve successful saml intent",
args: args{
CTX,
&user.RetrieveIdentityProviderIntentRequest{
IdpIntentId: samlSuccessfulID,
IdpIntentToken: samlToken,
},
},
want: &user.RetrieveIdentityProviderIntentResponse{
Details: &object.Details{
ChangeDate: timestamppb.New(samlChangeDate),
ResourceOwner: Tester.Organisation.ID,
Sequence: samlSequence,
},
IdpInformation: &user.IDPInformation{
Access: &user.IDPInformation_Saml{
Saml: &user.IDPSAMLAccessInformation{
Assertion: []byte("<Assertion xmlns=\"urn:oasis:names:tc:SAML:2.0:assertion\" ID=\"id\" IssueInstant=\"0001-01-01T00:00:00Z\" Version=\"\"><Issuer xmlns=\"urn:oasis:names:tc:SAML:2.0:assertion\" NameQualifier=\"\" SPNameQualifier=\"\" Format=\"\" SPProvidedID=\"\"></Issuer></Assertion>"),
},
},
IdpId: idpID,
UserId: "id",
UserName: "",
RawInformation: func() *structpb.Struct {
s, err := structpb.NewStruct(map[string]interface{}{
"id": "id",
"attributes": map[string]interface{}{
"attribute1": []interface{}{"value1"},
},
})
require.NoError(t, err)
return s
}(),
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -1,18 +1,23 @@
package idp
import (
"bytes"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"github.com/crewjam/saml"
"github.com/gorilla/mux"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
z_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/form"
"github.com/zitadel/zitadel/internal/idp"
@@ -25,19 +30,26 @@ import (
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
saml2 "github.com/zitadel/zitadel/internal/idp/providers/saml"
"github.com/zitadel/zitadel/internal/query"
)
const (
HandlerPrefix = "/idps"
callbackPath = "/callback"
ldapCallbackPath = callbackPath + "/ldap"
HandlerPrefix = "/idps"
idpPrefix = "/{" + varIDPID + ":[0-9]+}"
callbackPath = "/callback"
metadataPath = idpPrefix + "/saml/metadata"
acsPath = idpPrefix + "/saml/acs"
certificatePath = idpPrefix + "/saml/certificate"
paramIntentID = "id"
paramToken = "token"
paramUserID = "user"
paramError = "error"
paramErrorDescription = "error_description"
varIDPID = "idpid"
)
type Handler struct {
@@ -46,6 +58,8 @@ type Handler struct {
parser *form.Parser
encryptionAlgorithm crypto.EncryptionAlgorithm
callbackURL func(ctx context.Context) string
samlRootURL func(ctx context.Context, idpID string) string
loginSAMLRootURL func(ctx context.Context) string
}
type externalIDPCallbackData struct {
@@ -58,6 +72,12 @@ type externalIDPCallbackData struct {
User string `schema:"user"`
}
type externalSAMLIDPCallbackData struct {
IDPID string
Response string
RelayState string
}
// CallbackURL generates the instance specific URL to the IDP callback handler
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
return func(ctx context.Context) string {
@@ -65,6 +85,18 @@ func CallbackURL(externalSecure bool) func(ctx context.Context) string {
}
}
func SAMLRootURL(externalSecure bool) func(ctx context.Context, idpID string) string {
return func(ctx context.Context, idpID string) string {
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + "/" + idpID + "/"
}
}
func LoginSAMLRootURL(externalSecure bool) func(ctx context.Context) string {
return func(ctx context.Context) string {
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + login.HandlerPrefix + login.EndpointSAMLACS
}
}
func NewHandler(
commands *command.Commands,
queries *query.Queries,
@@ -78,14 +110,166 @@ func NewHandler(
parser: form.NewParser(),
encryptionAlgorithm: encryptionAlgorithm,
callbackURL: CallbackURL(externalSecure),
samlRootURL: SAMLRootURL(externalSecure),
loginSAMLRootURL: LoginSAMLRootURL(externalSecure),
}
router := mux.NewRouter()
router.Use(instanceInterceptor)
router.HandleFunc(callbackPath, h.handleCallback)
router.HandleFunc(metadataPath, h.handleMetadata)
router.HandleFunc(certificatePath, h.handleCertificate)
router.HandleFunc(acsPath, h.handleACS)
return router
}
func parseSAMLRequest(r *http.Request) *externalSAMLIDPCallbackData {
vars := mux.Vars(r)
return &externalSAMLIDPCallbackData{
IDPID: vars[varIDPID],
Response: r.FormValue("SAMLResponse"),
RelayState: r.FormValue("RelayState"),
}
}
func (h *Handler) getProvider(ctx context.Context, idpID string) (idp.Provider, error) {
return h.commands.GetProvider(ctx, idpID, h.callbackURL(ctx), h.samlRootURL(ctx, idpID))
}
func (h *Handler) handleCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := parseSAMLRequest(r)
provider, err := h.getProvider(ctx, data.IDPID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
samlProvider, ok := provider.(*saml2.Provider)
if !ok {
http.Error(w, z_errs.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
return
}
certPem := new(bytes.Buffer)
if _, err := certPem.Write(samlProvider.Certificate); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Content-Disposition", "attachment; filename=idp.crt")
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
_, err = io.Copy(w, certPem)
if err != nil {
http.Error(w, fmt.Errorf("failed to response with certificate: %w", err).Error(), http.StatusInternalServerError)
return
}
}
func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := parseSAMLRequest(r)
provider, err := h.getProvider(ctx, data.IDPID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
samlProvider, ok := provider.(*saml2.Provider)
if !ok {
http.Error(w, z_errs.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
return
}
sp, err := samlProvider.GetSP()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
metadata := sp.ServiceProvider.Metadata()
for i, spDesc := range metadata.SPSSODescriptors {
spDesc.AssertionConsumerServices = append(
spDesc.AssertionConsumerServices,
saml.IndexedEndpoint{
Binding: saml.HTTPPostBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 1,
}, saml.IndexedEndpoint{
Binding: saml.HTTPArtifactBinding,
Location: h.loginSAMLRootURL(ctx),
Index: len(spDesc.AssertionConsumerServices) + 2,
},
)
metadata.SPSSODescriptors[i] = spDesc
}
buf, _ := xml.MarshalIndent(metadata, "", " ")
w.Header().Set("Content-Type", "application/samlmetadata+xml")
_, err = w.Write(buf)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := parseSAMLRequest(r)
provider, err := h.getProvider(ctx, data.IDPID)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
samlProvider, ok := provider.(*saml2.Provider)
if !ok {
err := z_errs.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid")
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
sp, err := samlProvider.GetSP()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
if err != nil {
if z_errs.IsNotFound(err) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
redirectToFailureURLErr(w, r, intent, err)
return
}
session := saml2.Session{
ServiceProvider: sp,
RequestID: intent.RequestID,
Request: r,
}
idpUser, err := session.FetchUser(r.Context())
if err != nil {
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
redirectToFailureURLErr(w, r, intent, err)
return
}
userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
token, err := h.commands.SucceedSAMLIDPIntent(ctx, intent, idpUser, userID, session.Assertion)
if err != nil {
redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
return
}
redirectToSuccessURL(w, r, intent, token, userID)
}
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data, err := h.parseCallbackRequest(r)
@@ -111,7 +295,7 @@ func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
return
}
provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
provider, err := h.getProvider(ctx, intent.IDPID)
if err != nil {
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
@@ -119,7 +303,7 @@ func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
return
}
idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code, data.User)
idpUser, idpSession, err := h.fetchIDPUserFromCode(ctx, provider, data.Code, data.User)
if err != nil {
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
@@ -170,23 +354,6 @@ func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackDat
return data, nil
}
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return nil
}
if intent.State == domain.IDPIntentStateUnspecified {
http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
return nil
}
if intent.State != domain.IDPIntentStateStarted {
redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
return nil
}
return intent
}
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
queries := intent.SuccessURL.Query()
queries.Set(paramIntentID, intent.AggregateID)
@@ -218,7 +385,7 @@ func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDP
http.Redirect(w, r, i.FailureURL.String(), http.StatusFound)
}
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string, appleUser string) (user idp.User, idpTokens idp.Session, err error) {
func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp.Provider, code string, appleUser string) (user idp.User, idpTokens idp.Session, err error) {
var session idp.Session
switch provider := identityProvider.(type) {
case *oauth.Provider:
@@ -235,7 +402,7 @@ func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provide
session = &openid.Session{Provider: provider.Provider, Code: code}
case *apple.Provider:
session = &apple.Session{Session: &openid.Session{Provider: provider.Provider, Code: code}, UserFormValue: appleUser}
case *jwt.Provider, *ldap.Provider:
case *jwt.Provider, *ldap.Provider, *saml2.Provider:
return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
default:
return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")

View File

@@ -0,0 +1,488 @@
//go:build integration
package idp_test
import (
"context"
"crypto"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"io"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlidp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
saml_xml "github.com/zitadel/saml/pkg/provider/xml"
"golang.org/x/crypto/bcrypt"
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/integration"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
)
var (
CTX context.Context
ErrCTX context.Context
Tester *integration.Tester
Client user.UserServiceClient
)
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, errCtx, cancel := integration.Contexts(time.Hour)
defer cancel()
Tester = integration.NewTester(ctx)
defer Tester.Done()
CTX, ErrCTX = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
Client = Tester.Client.UserV2
return m.Run()
}())
}
func TestServer_SAMLCertificate(t *testing.T) {
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
oauthIdpID := Tester.AddGenericOAuthProvider(t)
type args struct {
ctx context.Context
idpID string
}
tests := []struct {
name string
args args
want int
}{
{
name: "saml certificate, invalid idp",
args: args{
ctx: CTX,
idpID: "unknown",
},
want: http.StatusNotFound,
},
{
name: "saml certificate, invalid idp type",
args: args{
ctx: CTX,
idpID: oauthIdpID,
},
want: http.StatusBadRequest,
},
{
name: "saml certificate, ok",
args: args{
ctx: CTX,
idpID: samlRedirectIdpID,
},
want: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certificateURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/certificate"
resp, err := http.Get(certificateURL)
assert.NoError(t, err)
assert.Equal(t, tt.want, resp.StatusCode)
if tt.want == http.StatusOK {
b, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
assert.NoError(t, err)
block, _ := pem.Decode(b)
_, err = x509.ParseCertificate(block.Bytes)
assert.NoError(t, err)
}
})
}
}
func TestServer_SAMLMetadata(t *testing.T) {
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
oauthIdpID := Tester.AddGenericOAuthProvider(t)
type args struct {
ctx context.Context
idpID string
}
tests := []struct {
name string
args args
want int
}{
{
name: "saml metadata, invalid idp",
args: args{
ctx: CTX,
idpID: "unknown",
},
want: http.StatusNotFound,
},
{
name: "saml metadata, invalid idp type",
args: args{
ctx: CTX,
idpID: oauthIdpID,
},
want: http.StatusBadRequest,
},
{
name: "saml metadata, ok",
args: args{
ctx: CTX,
idpID: samlRedirectIdpID,
},
want: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
metadataURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/metadata"
resp, err := http.Get(metadataURL)
assert.NoError(t, err)
assert.Equal(t, tt.want, resp.StatusCode)
if tt.want == http.StatusOK {
b, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
assert.NoError(t, err)
_, err = saml_xml.ParseMetadataXmlIntoStruct(b)
assert.NoError(t, err)
}
})
}
}
func TestServer_SAMLACS(t *testing.T) {
userHuman := Tester.CreateHumanUser(CTX)
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
externalUserID := "test1"
linkedExternalUserID := "test2"
Tester.CreateUserIDPlink(CTX, userHuman.UserId, linkedExternalUserID, samlRedirectIdpID, linkedExternalUserID)
idp, err := getIDP(
http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure),
[]string{samlRedirectIdpID},
externalUserID,
linkedExternalUserID,
)
assert.NoError(t, err)
type args struct {
ctx context.Context
successURL string
failureURL string
idpID string
username string
intentID string
response string
}
type want struct {
successful bool
user string
}
tests := []struct {
name string
args args
want want
wantErr bool
}{
{
name: "intent invalid",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
intentID: "notexisting",
},
want: want{
successful: false,
user: "",
},
wantErr: true,
},
{
name: "response invalid",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
response: "invalid",
},
want: want{
successful: false,
user: "",
},
},
{
name: "saml flow redirect, ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: externalUserID,
},
want: want{
successful: true,
user: "",
},
},
{
name: "saml flow redirect with link, ok",
args: args{
ctx: CTX,
successURL: "https://example.com/success",
failureURL: "https://example.com/failure",
idpID: samlRedirectIdpID,
username: linkedExternalUserID,
},
want: want{
successful: true,
user: userHuman.UserId,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Client.StartIdentityProviderFlow(tt.args.ctx,
&user.StartIdentityProviderFlowRequest{
IdpId: tt.args.idpID,
Content: &user.StartIdentityProviderFlowRequest_Urls{
Urls: &user.RedirectURLs{
SuccessUrl: tt.args.successURL,
FailureUrl: tt.args.failureURL,
},
},
},
)
// can't fail as covered in other tests
require.NoError(t, err)
//parse returned URL to continue flow to callback with the same intentID==RelayState
authURL, err := url.Parse(got.GetAuthUrl())
require.NoError(t, err)
samlRequest := &http.Request{Method: http.MethodGet, URL: authURL}
assert.NotEmpty(t, authURL)
//generate necessary information to create request to callback URL
relayState := authURL.Query().Get("RelayState")
//test purposes, use defined intentID
if tt.args.intentID != "" {
relayState = tt.args.intentID
}
callbackURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/acs"
response := createResponse(t, idp, samlRequest, tt.args.username)
//test purposes, use defined response
if tt.args.response != "" {
response = tt.args.response
}
req := httpPostFormRequest(t, callbackURL, relayState, response)
//do request to callback URL and check redirect to either success or failure url
location, err := integration.CheckRedirect(req)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, relayState, location.Query().Get("id"))
if tt.want.successful {
assert.True(t, strings.HasPrefix(location.String(), tt.args.successURL))
assert.NotEmpty(t, location.Query().Get("token"))
assert.Equal(t, tt.want.user, location.Query().Get("user"))
} else {
assert.True(t, strings.HasPrefix(location.String(), tt.args.failureURL))
}
}
})
}
}
var key = func() crypto.PrivateKey {
b, _ := pem.Decode([]byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA0OhbMuizgtbFOfwbK7aURuXhZx6VRuAs3nNibiuifwCGz6u9
yy7bOR0P+zqN0YkjxaokqFgra7rXKCdeABmoLqCC0U+cGmLNwPOOA0PaD5q5xKhQ
4Me3rt/R9C4Ca6k3/OnkxnKwnogcsmdgs2l8liT3qVHP04Oc7Uymq2v09bGb6nPu
fOrkXS9F6mSClxHG/q59AGOWsXK1xzIRV1eu8W2SNdyeFVU1JHiQe444xLoPul5t
InWasKayFsPlJfWNc8EoU8COjNhfo/GovFTHVjh9oUR/gwEFVwifIHihRE0Hazn2
EQSLaOr2LM0TsRsQroFjmwSGgI+X2bfbMTqWOQIDAQABAoIBAFWZwDTeESBdrLcT
zHZe++cJLxE4AObn2LrWANEv5AeySYsyzjRBYObIN9IzrgTb8uJ900N/zVr5VkxH
xUa5PKbOcowd2NMfBTw5EEnaNbILLm+coHdanrNzVu59I9TFpAFoPavrNt/e2hNo
NMGPSdOkFi81LLl4xoadz/WR6O/7N2famM+0u7C2uBe+TrVwHyuqboYoidJDhO8M
w4WlY9QgAUhkPyzZqrl+VfF1aDTGVf4LJgaVevfFCas8Ws6DQX5q4QdIoV6/0vXi
B1M+aTnWjHuiIzjBMWhcYW2+I5zfwNWRXaxdlrYXRukGSdnyO+DH/FhHePJgmlkj
NInADDkCgYEA6MEQFOFSCc/ELXYWgStsrtIlJUcsLdLBsy1ocyQa2lkVUw58TouW
RciE6TjW9rp31pfQUnO2l6zOUC6LT9Jvlb9PSsyW+rvjtKB5PjJI6W0hjX41wEO6
fshFELMJd9W+Ezao2AsP2hZJ8McCF8no9e00+G4xTAyxHsNI2AFTCQcCgYEA5cWZ
JwNb4t7YeEajPt9xuYNUOQpjvQn1aGOV7KcwTx5ELP/Hzi723BxHs7GSdrLkkDmi
Gpb+mfL4wxCt0fK0i8GFQsRn5eusyq9hLqP/bmjpHoXe/1uajFbE1fZQR+2LX05N
3ATlKaH2hdfCJedFa4wf43+cl6Yhp6ZA0Yet1r8CgYEAwiu1j8W9G+RRA5/8/DtO
yrUTOfsbFws4fpLGDTA0mq0whf6Soy/96C90+d9qLaC3srUpnG9eB0CpSOjbXXbv
kdxseLkexwOR3bD2FHX8r4dUM2bzznZyEaxfOaQypN8SV5ME3l60Fbr8ajqLO288
wlTmGM5Mn+YCqOg/T7wjGmcCgYBpzNfdl/VafOROVbBbhgXWtzsz3K3aYNiIjbp+
MunStIwN8GUvcn6nEbqOaoiXcX4/TtpuxfJMLw4OvAJdtxUdeSmEee2heCijV6g3
ErrOOy6EqH3rNWHvlxChuP50cFQJuYOueO6QggyCyruSOnDDuc0BM0SGq6+5g5s7
H++S/wKBgQDIkqBtFr9UEf8d6JpkxS0RXDlhSMjkXmkQeKGFzdoJcYVFIwq8jTNB
nJrVIGs3GcBkqGic+i7rTO1YPkquv4dUuiIn+vKZVoO6b54f+oPBXd4S0BnuEqFE
rdKNuCZhiaE2XD9L/O9KP1fh5bfEcKwazQ23EvpJHBMm8BGC+/YZNw==
-----END RSA PRIVATE KEY-----`))
k, _ := x509.ParsePKCS1PrivateKey(b.Bytes)
return k
}()
var cert = func() *x509.Certificate {
b, _ := pem.Decode([]byte(`-----BEGIN CERTIFICATE-----
MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNV
BAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5
NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEB
BQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8A
hs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+a
ucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWx
m+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6
D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURN
B2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0O
BBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56
zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5
pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uv
NONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEf
y/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL
/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsb
GFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTL
UzreO96WzlBBMtY=
-----END CERTIFICATE-----`))
c, _ := x509.ParseCertificate(b.Bytes)
return c
}()
func getIDP(zitadelBaseURL string, idpIDs []string, user1, user2 string) (*saml.IdentityProvider, error) {
baseURL, err := url.Parse("http://localhost:8000")
if err != nil {
return nil, err
}
store := &samlidp.MemoryStore{}
hashedPassword1, _ := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost)
err = store.Put("/users/"+user1, samlidp.User{
Name: user1,
HashedPassword: hashedPassword1,
Groups: []string{"Administrators", "Users"},
Email: "test@example.com",
CommonName: "Test Test",
Surname: "Test",
GivenName: "Test",
})
if err != nil {
return nil, err
}
hashedPassword2, _ := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost)
err = store.Put("/users/"+user2, samlidp.User{
Name: user2,
HashedPassword: hashedPassword2,
Groups: []string{"Administrators", "Users"},
Email: "test@example.com",
CommonName: "Test Test",
Surname: "Test",
GivenName: "Test",
})
if err != nil {
return nil, err
}
for _, idpID := range idpIDs {
metadata, err := saml_xml.ReadMetadataFromURL(http.DefaultClient, zitadelBaseURL+"/idps/"+idpID+"/saml/metadata")
if err != nil {
return nil, err
}
entity := new(saml.EntityDescriptor)
if err := xml.Unmarshal(metadata, entity); err != nil {
return nil, err
}
if err := store.Put("/services/"+idpID, samlidp.Service{
Name: idpID,
Metadata: *entity,
}); err != nil {
return nil, err
}
}
idpServer, err := samlidp.New(samlidp.Options{
URL: *baseURL,
Key: key,
Certificate: cert,
Store: store,
})
if err != nil {
return nil, err
}
if idpServer.IDP.AssertionMaker == nil {
idpServer.IDP.AssertionMaker = &saml.DefaultAssertionMaker{}
}
return &idpServer.IDP, nil
}
func createResponse(t *testing.T, idp *saml.IdentityProvider, req *http.Request, username string) string {
authnReq, err := saml.NewIdpAuthnRequest(idp, req)
assert.NoError(t, authnReq.Validate())
err = idp.AssertionMaker.MakeAssertion(authnReq, &saml.Session{
CreateTime: time.Now().UTC(),
Index: "",
NameID: username,
})
assert.NoError(t, err)
err = authnReq.MakeResponse()
assert.NoError(t, err)
doc := etree.NewDocument()
doc.SetRoot(authnReq.ResponseEl)
responseBuf, err := doc.WriteToBytes()
assert.NoError(t, err)
responseBuf = append([]byte("<?xml version=\"1.0\"?>"), responseBuf...)
return base64.StdEncoding.EncodeToString(responseBuf)
}
func httpGETRequest(t *testing.T, callbackURL string, relayState, response, sig, sigAlg string) *http.Request {
req, err := http.NewRequest("GET", callbackURL, nil)
require.NoError(t, err)
q := req.URL.Query()
q.Add("RelayState", relayState)
q.Add("SAMLResponse", response)
if sig != "" {
q.Add("Sig", sig)
}
if sigAlg != "" {
q.Add("SigAlg", sigAlg)
}
req.URL.RawQuery = q.Encode()
return req
}
func httpPostFormRequest(t *testing.T, callbackURL, relayState, response string) *http.Request {
body := url.Values{
"SAMLResponse": {response},
"RelayState": {relayState},
}
req, err := http.NewRequest("POST", callbackURL, strings.NewReader(body.Encode()))
assert.NoError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.ParseForm()
return req
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"strings"
"github.com/crewjam/saml/samlsp"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
@@ -12,6 +13,7 @@ import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
@@ -27,6 +29,8 @@ import (
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
"github.com/zitadel/zitadel/internal/idp/providers/saml"
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
"github.com/zitadel/zitadel/internal/query"
)
@@ -43,6 +47,9 @@ type externalIDPCallbackData struct {
State string `schema:"state"`
Code string `schema:"code"`
RelayState string `schema:"RelayState"`
Method string `schema:"Method"`
// Apple returns a user on first registration
User string `schema:"user"`
}
@@ -167,6 +174,8 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai
provider, err = l.appleProvider(r.Context(), identityProvider)
case domain.IDPTypeLDAP:
provider, err = l.ldapProvider(r.Context(), identityProvider)
case domain.IDPTypeSAML:
provider, err = l.samlProvider(r.Context(), identityProvider)
case domain.IDPTypeUnspecified:
fallthrough
default:
@@ -183,7 +192,17 @@ func (l *Login) handleIDP(w http.ResponseWriter, r *http.Request, authReq *domai
l.renderLogin(w, r, authReq, err)
return
}
http.Redirect(w, r, session.GetAuthURL(), http.StatusFound)
content, redirect := session.GetAuth(r.Context())
if redirect {
http.Redirect(w, r, content, http.StatusFound)
return
}
_, err = w.Write([]byte(content))
if err != nil {
l.renderError(w, r, authReq, err)
return
}
}
// handleExternalLoginCallbackForm handles the callback from a IDP with form_post.
@@ -195,6 +214,7 @@ func (l *Login) handleExternalLoginCallbackForm(w http.ResponseWriter, r *http.R
l.renderLogin(w, r, nil, err)
return
}
r.Form.Add("Method", http.MethodPost)
http.Redirect(w, r, HandlerPrefix+EndpointExternalLoginCallback+"?"+r.Form.Encode(), 302)
}
@@ -207,6 +227,15 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
l.renderLogin(w, r, nil, err)
return
}
if data.State == "" {
data.State = data.RelayState
}
// workaround because of CSRF on external identity provider flows
if data.Method == http.MethodPost {
r.Method = http.MethodPost
r.PostForm = r.Form
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
authReq, err := l.authRepo.AuthRequestByID(r.Context(), data.State, userAgentID)
if err != nil {
@@ -284,6 +313,18 @@ func (l *Login) handleExternalLoginCallback(w http.ResponseWriter, r *http.Reque
return
}
session = &apple.Session{Session: &openid.Session{Provider: provider.(*apple.Provider).Provider, Code: data.Code}, UserFormValue: data.User}
case domain.IDPTypeSAML:
provider, err = l.samlProvider(r.Context(), identityProvider)
if err != nil {
l.externalAuthFailed(w, r, authReq, nil, nil, err)
return
}
sp, err := provider.(*saml.Provider).GetSP()
if err != nil {
l.externalAuthFailed(w, r, authReq, nil, nil, err)
return
}
session = &saml.Session{ServiceProvider: sp, RequestID: authReq.SAMLRequestID, Request: r}
case domain.IDPTypeJWT,
domain.IDPTypeLDAP,
domain.IDPTypeUnspecified:
@@ -881,6 +922,49 @@ func (l *Login) oauthProvider(ctx context.Context, identityProvider *query.IDPTe
)
}
func (l *Login) samlProvider(ctx context.Context, identityProvider *query.IDPTemplate) (*saml.Provider, error) {
key, err := crypto.Decrypt(identityProvider.SAMLIDPTemplate.Key, l.idpConfigAlg)
if err != nil {
return nil, err
}
opts := make([]saml.ProviderOpts, 0, 2)
if identityProvider.SAMLIDPTemplate.WithSignedRequest {
opts = append(opts, saml.WithSignedRequest())
}
if identityProvider.SAMLIDPTemplate.Binding != "" {
opts = append(opts, saml.WithBinding(identityProvider.SAMLIDPTemplate.Binding))
}
opts = append(opts,
saml.WithEntityID(http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), l.externalSecure)+"/idps/"+identityProvider.ID+"/saml/metadata"),
saml.WithCustomRequestTracker(
requesttracker.New(
func(ctx context.Context, authRequestID, samlRequestID string) error {
useragent, _ := http_mw.UserAgentIDFromCtx(ctx)
return l.authRepo.SaveSAMLRequestID(ctx, authRequestID, samlRequestID, useragent)
},
func(ctx context.Context, authRequestID string) (*samlsp.TrackedRequest, error) {
useragent, _ := http_mw.UserAgentIDFromCtx(ctx)
auhRequest, err := l.authRepo.AuthRequestByID(ctx, authRequestID, useragent)
if err != nil {
return nil, err
}
return &samlsp.TrackedRequest{
SAMLRequestID: auhRequest.SAMLRequestID,
Index: authRequestID,
}, nil
},
),
))
return saml.New(
identityProvider.Name,
l.baseURL(ctx)+EndpointExternalLogin+"/",
identityProvider.SAMLIDPTemplate.Metadata,
identityProvider.SAMLIDPTemplate.Certificate,
key,
opts...,
)
}
func (l *Login) azureProvider(ctx context.Context, identityProvider *query.IDPTemplate) (*azuread.Provider, error) {
secret, err := crypto.DecryptString(identityProvider.AzureADIDPTemplate.ClientSecret, l.idpConfigAlg)
if err != nil {

View File

@@ -126,7 +126,7 @@ func createCSRFInterceptor(cookieName string, csrfCookieKey []byte, externalSecu
}
// ignore form post callback
// it will redirect to the "normal" callback, where the cookie is set again
if r.URL.Path == EndpointExternalLoginCallbackFormPost && r.Method == http.MethodPost {
if (r.URL.Path == EndpointExternalLoginCallbackFormPost || r.URL.Path == EndpointSAMLACS) && r.Method == http.MethodPost {
handler.ServeHTTP(w, r)
return
}

View File

@@ -14,6 +14,7 @@ const (
EndpointExternalLogin = "/login/externalidp"
EndpointExternalLoginCallback = "/login/externalidp/callback"
EndpointExternalLoginCallbackFormPost = "/login/externalidp/callback/form"
EndpointSAMLACS = "/login/externalidp/saml/acs"
EndpointJWTAuthorize = "/login/jwt/authorize"
EndpointJWTCallback = "/login/jwt/callback"
EndpointLDAPLogin = "/login/ldap"
@@ -73,6 +74,8 @@ func CreateRouter(login *Login, staticDir http.FileSystem, interceptors ...mux.M
router.HandleFunc(EndpointExternalLogin, login.handleExternalLogin).Methods(http.MethodGet)
router.HandleFunc(EndpointExternalLoginCallback, login.handleExternalLoginCallback).Methods(http.MethodGet)
router.HandleFunc(EndpointExternalLoginCallbackFormPost, login.handleExternalLoginCallbackForm).Methods(http.MethodPost)
router.HandleFunc(EndpointSAMLACS, login.handleExternalLoginCallback).Methods(http.MethodGet)
router.HandleFunc(EndpointSAMLACS, login.handleExternalLoginCallbackForm).Methods(http.MethodPost)
router.HandleFunc(EndpointJWTAuthorize, login.handleJWTRequest).Methods(http.MethodGet)
router.HandleFunc(EndpointJWTCallback, login.handleJWTCallback).Methods(http.MethodGet)
router.HandleFunc(EndpointPasswordlessLogin, login.handlePasswordlessVerification).Methods(http.MethodPost)