mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:37:32 +00:00
feat: add SAML as identity provider (#6454)
* feat: first implementation for saml sp * fix: add command side instance and org for saml provider * fix: add query side instance and org for saml provider * fix: request handling in event and retrieval of finished intent * fix: add review changes and integration tests * fix: add integration tests for saml idp * fix: correct unit tests with review changes * fix: add saml session unit test * fix: add saml session unit test * fix: add saml session unit test * fix: changes from review * fix: changes from review * fix: proto build error * fix: proto build error * fix: proto build error * fix: proto require metadata oneof * fix: login with saml provider * fix: integration test for saml assertion * lint client.go * fix json tag * fix: linting * fix import * fix: linting * fix saml idp query * fix: linting * lint: try all issues * revert linting config * fix: add regenerate endpoints * fix: translations * fix mk.yaml * ignore acs path for user agent cookie * fix: add AuthFromProvider test for saml * fix: integration test for saml retrieve information --------- Co-authored-by: Livio Spring <livio.a@gmail.com>
This commit is contained in:
@@ -1,18 +1,23 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/form"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
@@ -25,19 +30,26 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/ldap"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
openid "github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
saml2 "github.com/zitadel/zitadel/internal/idp/providers/saml"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
)
|
||||
|
||||
const (
|
||||
HandlerPrefix = "/idps"
|
||||
callbackPath = "/callback"
|
||||
ldapCallbackPath = callbackPath + "/ldap"
|
||||
HandlerPrefix = "/idps"
|
||||
|
||||
idpPrefix = "/{" + varIDPID + ":[0-9]+}"
|
||||
|
||||
callbackPath = "/callback"
|
||||
metadataPath = idpPrefix + "/saml/metadata"
|
||||
acsPath = idpPrefix + "/saml/acs"
|
||||
certificatePath = idpPrefix + "/saml/certificate"
|
||||
|
||||
paramIntentID = "id"
|
||||
paramToken = "token"
|
||||
paramUserID = "user"
|
||||
paramError = "error"
|
||||
paramErrorDescription = "error_description"
|
||||
varIDPID = "idpid"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
@@ -46,6 +58,8 @@ type Handler struct {
|
||||
parser *form.Parser
|
||||
encryptionAlgorithm crypto.EncryptionAlgorithm
|
||||
callbackURL func(ctx context.Context) string
|
||||
samlRootURL func(ctx context.Context, idpID string) string
|
||||
loginSAMLRootURL func(ctx context.Context) string
|
||||
}
|
||||
|
||||
type externalIDPCallbackData struct {
|
||||
@@ -58,6 +72,12 @@ type externalIDPCallbackData struct {
|
||||
User string `schema:"user"`
|
||||
}
|
||||
|
||||
type externalSAMLIDPCallbackData struct {
|
||||
IDPID string
|
||||
Response string
|
||||
RelayState string
|
||||
}
|
||||
|
||||
// CallbackURL generates the instance specific URL to the IDP callback handler
|
||||
func CallbackURL(externalSecure bool) func(ctx context.Context) string {
|
||||
return func(ctx context.Context) string {
|
||||
@@ -65,6 +85,18 @@ func CallbackURL(externalSecure bool) func(ctx context.Context) string {
|
||||
}
|
||||
}
|
||||
|
||||
func SAMLRootURL(externalSecure bool) func(ctx context.Context, idpID string) string {
|
||||
return func(ctx context.Context, idpID string) string {
|
||||
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + HandlerPrefix + "/" + idpID + "/"
|
||||
}
|
||||
}
|
||||
|
||||
func LoginSAMLRootURL(externalSecure bool) func(ctx context.Context) string {
|
||||
return func(ctx context.Context) string {
|
||||
return http_utils.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), externalSecure) + login.HandlerPrefix + login.EndpointSAMLACS
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
commands *command.Commands,
|
||||
queries *query.Queries,
|
||||
@@ -78,14 +110,166 @@ func NewHandler(
|
||||
parser: form.NewParser(),
|
||||
encryptionAlgorithm: encryptionAlgorithm,
|
||||
callbackURL: CallbackURL(externalSecure),
|
||||
samlRootURL: SAMLRootURL(externalSecure),
|
||||
loginSAMLRootURL: LoginSAMLRootURL(externalSecure),
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.Use(instanceInterceptor)
|
||||
router.HandleFunc(callbackPath, h.handleCallback)
|
||||
router.HandleFunc(metadataPath, h.handleMetadata)
|
||||
router.HandleFunc(certificatePath, h.handleCertificate)
|
||||
router.HandleFunc(acsPath, h.handleACS)
|
||||
return router
|
||||
}
|
||||
|
||||
func parseSAMLRequest(r *http.Request) *externalSAMLIDPCallbackData {
|
||||
vars := mux.Vars(r)
|
||||
return &externalSAMLIDPCallbackData{
|
||||
IDPID: vars[varIDPID],
|
||||
Response: r.FormValue("SAMLResponse"),
|
||||
RelayState: r.FormValue("RelayState"),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) getProvider(ctx context.Context, idpID string) (idp.Provider, error) {
|
||||
return h.commands.GetProvider(ctx, idpID, h.callbackURL(ctx), h.samlRootURL(ctx, idpID))
|
||||
}
|
||||
|
||||
func (h *Handler) handleCertificate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data := parseSAMLRequest(r)
|
||||
|
||||
provider, err := h.getProvider(ctx, data.IDPID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
samlProvider, ok := provider.(*saml2.Provider)
|
||||
if !ok {
|
||||
http.Error(w, z_errs.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
certPem := new(bytes.Buffer)
|
||||
if _, err := certPem.Write(samlProvider.Certificate); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=idp.crt")
|
||||
w.Header().Set("Content-Type", r.Header.Get("Content-Type"))
|
||||
_, err = io.Copy(w, certPem)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Errorf("failed to response with certificate: %w", err).Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleMetadata(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data := parseSAMLRequest(r)
|
||||
|
||||
provider, err := h.getProvider(ctx, data.IDPID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
samlProvider, ok := provider.(*saml2.Provider)
|
||||
if !ok {
|
||||
http.Error(w, z_errs.ThrowInvalidArgument(nil, "SAML-lrud8s9coi", "Errors.Intent.IDPInvalid").Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sp, err := samlProvider.GetSP()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
metadata := sp.ServiceProvider.Metadata()
|
||||
|
||||
for i, spDesc := range metadata.SPSSODescriptors {
|
||||
spDesc.AssertionConsumerServices = append(
|
||||
spDesc.AssertionConsumerServices,
|
||||
saml.IndexedEndpoint{
|
||||
Binding: saml.HTTPPostBinding,
|
||||
Location: h.loginSAMLRootURL(ctx),
|
||||
Index: len(spDesc.AssertionConsumerServices) + 1,
|
||||
}, saml.IndexedEndpoint{
|
||||
Binding: saml.HTTPArtifactBinding,
|
||||
Location: h.loginSAMLRootURL(ctx),
|
||||
Index: len(spDesc.AssertionConsumerServices) + 2,
|
||||
},
|
||||
)
|
||||
metadata.SPSSODescriptors[i] = spDesc
|
||||
}
|
||||
|
||||
buf, _ := xml.MarshalIndent(metadata, "", " ")
|
||||
w.Header().Set("Content-Type", "application/samlmetadata+xml")
|
||||
_, err = w.Write(buf)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleACS(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data := parseSAMLRequest(r)
|
||||
|
||||
provider, err := h.getProvider(ctx, data.IDPID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
samlProvider, ok := provider.(*saml2.Provider)
|
||||
if !ok {
|
||||
err := z_errs.ThrowInvalidArgument(nil, "SAML-ui9wyux0hp", "Errors.Intent.IDPInvalid")
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sp, err := samlProvider.GetSP()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
intent, err := h.commands.GetActiveIntent(ctx, data.RelayState)
|
||||
if err != nil {
|
||||
if z_errs.IsNotFound(err) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
|
||||
session := saml2.Session{
|
||||
ServiceProvider: sp,
|
||||
RequestID: intent.RequestID,
|
||||
Request: r,
|
||||
}
|
||||
|
||||
idpUser, err := session.FetchUser(r.Context())
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
redirectToFailureURLErr(w, r, intent, err)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := h.checkExternalUser(ctx, intent.IDPID, idpUser.GetID())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(err).Error("could not check if idp user already exists")
|
||||
|
||||
token, err := h.commands.SucceedSAMLIDPIntent(ctx, intent, idpUser, userID, session.Assertion)
|
||||
if err != nil {
|
||||
redirectToFailureURLErr(w, r, intent, z_errs.ThrowInternal(err, "IDP-JdD3g", "Errors.Intent.TokenCreationFailed"))
|
||||
return
|
||||
}
|
||||
redirectToSuccessURL(w, r, intent, token, userID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data, err := h.parseCallbackRequest(r)
|
||||
@@ -111,7 +295,7 @@ func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.commands.GetProvider(ctx, intent.IDPID, h.callbackURL(ctx))
|
||||
provider, err := h.getProvider(ctx, intent.IDPID)
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
@@ -119,7 +303,7 @@ func (h *Handler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
idpUser, idpSession, err := h.fetchIDPUser(ctx, provider, data.Code, data.User)
|
||||
idpUser, idpSession, err := h.fetchIDPUserFromCode(ctx, provider, data.Code, data.User)
|
||||
if err != nil {
|
||||
cmdErr := h.commands.FailIDPIntent(ctx, intent, err.Error())
|
||||
logging.WithFields("intent", intent.AggregateID).OnError(cmdErr).Error("failed to push failed event on idp intent")
|
||||
@@ -170,23 +354,6 @@ func (h *Handler) parseCallbackRequest(r *http.Request) (*externalIDPCallbackDat
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getActiveIntent(w http.ResponseWriter, r *http.Request, state string) *command.IDPIntentWriteModel {
|
||||
intent, err := h.commands.GetIntentWriteModel(r.Context(), state, "")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return nil
|
||||
}
|
||||
if intent.State == domain.IDPIntentStateUnspecified {
|
||||
http.Error(w, reason("IDP-Hk38e", "Errors.Intent.NotStarted"), http.StatusBadRequest)
|
||||
return nil
|
||||
}
|
||||
if intent.State != domain.IDPIntentStateStarted {
|
||||
redirectToFailureURL(w, r, intent, "IDP-Sfrgs", "Errors.Intent.NotStarted")
|
||||
return nil
|
||||
}
|
||||
return intent
|
||||
}
|
||||
|
||||
func redirectToSuccessURL(w http.ResponseWriter, r *http.Request, intent *command.IDPIntentWriteModel, token, userID string) {
|
||||
queries := intent.SuccessURL.Query()
|
||||
queries.Set(paramIntentID, intent.AggregateID)
|
||||
@@ -218,7 +385,7 @@ func redirectToFailureURL(w http.ResponseWriter, r *http.Request, i *command.IDP
|
||||
http.Redirect(w, r, i.FailureURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provider, code string, appleUser string) (user idp.User, idpTokens idp.Session, err error) {
|
||||
func (h *Handler) fetchIDPUserFromCode(ctx context.Context, identityProvider idp.Provider, code string, appleUser string) (user idp.User, idpTokens idp.Session, err error) {
|
||||
var session idp.Session
|
||||
switch provider := identityProvider.(type) {
|
||||
case *oauth.Provider:
|
||||
@@ -235,7 +402,7 @@ func (h *Handler) fetchIDPUser(ctx context.Context, identityProvider idp.Provide
|
||||
session = &openid.Session{Provider: provider.Provider, Code: code}
|
||||
case *apple.Provider:
|
||||
session = &apple.Session{Session: &openid.Session{Provider: provider.Provider, Code: code}, UserFormValue: appleUser}
|
||||
case *jwt.Provider, *ldap.Provider:
|
||||
case *jwt.Provider, *ldap.Provider, *saml2.Provider:
|
||||
return nil, nil, z_errs.ThrowInvalidArgument(nil, "IDP-52jmn", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
default:
|
||||
return nil, nil, z_errs.ThrowUnimplemented(nil, "IDP-SSDg", "Errors.ExternalIDP.IDPTypeNotImplemented")
|
||||
|
488
internal/api/idp/idp_integration_test.go
Normal file
488
internal/api/idp/idp_integration_test.go
Normal file
@@ -0,0 +1,488 @@
|
||||
//go:build integration
|
||||
|
||||
package idp_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlidp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
saml_xml "github.com/zitadel/saml/pkg/provider/xml"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
|
||||
)
|
||||
|
||||
var (
|
||||
CTX context.Context
|
||||
ErrCTX context.Context
|
||||
Tester *integration.Tester
|
||||
Client user.UserServiceClient
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, errCtx, cancel := integration.Contexts(time.Hour)
|
||||
defer cancel()
|
||||
|
||||
Tester = integration.NewTester(ctx)
|
||||
defer Tester.Done()
|
||||
|
||||
CTX, ErrCTX = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx
|
||||
Client = Tester.Client.UserV2
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
|
||||
func TestServer_SAMLCertificate(t *testing.T) {
|
||||
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
|
||||
oauthIdpID := Tester.AddGenericOAuthProvider(t)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
idpID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "saml certificate, invalid idp",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: "unknown",
|
||||
},
|
||||
want: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "saml certificate, invalid idp type",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: oauthIdpID,
|
||||
},
|
||||
want: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "saml certificate, ok",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: samlRedirectIdpID,
|
||||
},
|
||||
want: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
certificateURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/certificate"
|
||||
resp, err := http.Get(certificateURL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, resp.StatusCode)
|
||||
if tt.want == http.StatusOK {
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
block, _ := pem.Decode(b)
|
||||
_, err = x509.ParseCertificate(block.Bytes)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SAMLMetadata(t *testing.T) {
|
||||
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
|
||||
oauthIdpID := Tester.AddGenericOAuthProvider(t)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
idpID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "saml metadata, invalid idp",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: "unknown",
|
||||
},
|
||||
want: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "saml metadata, invalid idp type",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: oauthIdpID,
|
||||
},
|
||||
want: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "saml metadata, ok",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
idpID: samlRedirectIdpID,
|
||||
},
|
||||
want: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
metadataURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/metadata"
|
||||
resp, err := http.Get(metadataURL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, resp.StatusCode)
|
||||
if tt.want == http.StatusOK {
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = saml_xml.ParseMetadataXmlIntoStruct(b)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_SAMLACS(t *testing.T) {
|
||||
userHuman := Tester.CreateHumanUser(CTX)
|
||||
samlRedirectIdpID := Tester.AddSAMLRedirectProvider(t)
|
||||
externalUserID := "test1"
|
||||
linkedExternalUserID := "test2"
|
||||
Tester.CreateUserIDPlink(CTX, userHuman.UserId, linkedExternalUserID, samlRedirectIdpID, linkedExternalUserID)
|
||||
idp, err := getIDP(
|
||||
http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure),
|
||||
[]string{samlRedirectIdpID},
|
||||
externalUserID,
|
||||
linkedExternalUserID,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
successURL string
|
||||
failureURL string
|
||||
idpID string
|
||||
username string
|
||||
intentID string
|
||||
response string
|
||||
}
|
||||
type want struct {
|
||||
successful bool
|
||||
user string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "intent invalid",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
successURL: "https://example.com/success",
|
||||
failureURL: "https://example.com/failure",
|
||||
idpID: samlRedirectIdpID,
|
||||
username: externalUserID,
|
||||
intentID: "notexisting",
|
||||
},
|
||||
want: want{
|
||||
successful: false,
|
||||
user: "",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "response invalid",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
successURL: "https://example.com/success",
|
||||
failureURL: "https://example.com/failure",
|
||||
idpID: samlRedirectIdpID,
|
||||
username: externalUserID,
|
||||
response: "invalid",
|
||||
},
|
||||
want: want{
|
||||
successful: false,
|
||||
user: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "saml flow redirect, ok",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
successURL: "https://example.com/success",
|
||||
failureURL: "https://example.com/failure",
|
||||
idpID: samlRedirectIdpID,
|
||||
username: externalUserID,
|
||||
},
|
||||
want: want{
|
||||
successful: true,
|
||||
user: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "saml flow redirect with link, ok",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
successURL: "https://example.com/success",
|
||||
failureURL: "https://example.com/failure",
|
||||
idpID: samlRedirectIdpID,
|
||||
username: linkedExternalUserID,
|
||||
},
|
||||
want: want{
|
||||
successful: true,
|
||||
user: userHuman.UserId,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := Client.StartIdentityProviderFlow(tt.args.ctx,
|
||||
&user.StartIdentityProviderFlowRequest{
|
||||
IdpId: tt.args.idpID,
|
||||
Content: &user.StartIdentityProviderFlowRequest_Urls{
|
||||
Urls: &user.RedirectURLs{
|
||||
SuccessUrl: tt.args.successURL,
|
||||
FailureUrl: tt.args.failureURL,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
// can't fail as covered in other tests
|
||||
require.NoError(t, err)
|
||||
|
||||
//parse returned URL to continue flow to callback with the same intentID==RelayState
|
||||
authURL, err := url.Parse(got.GetAuthUrl())
|
||||
require.NoError(t, err)
|
||||
samlRequest := &http.Request{Method: http.MethodGet, URL: authURL}
|
||||
assert.NotEmpty(t, authURL)
|
||||
|
||||
//generate necessary information to create request to callback URL
|
||||
relayState := authURL.Query().Get("RelayState")
|
||||
//test purposes, use defined intentID
|
||||
if tt.args.intentID != "" {
|
||||
relayState = tt.args.intentID
|
||||
}
|
||||
callbackURL := http_util.BuildOrigin(Tester.Host(), Tester.Server.Config.ExternalSecure) + "/idps/" + tt.args.idpID + "/saml/acs"
|
||||
response := createResponse(t, idp, samlRequest, tt.args.username)
|
||||
//test purposes, use defined response
|
||||
if tt.args.response != "" {
|
||||
response = tt.args.response
|
||||
}
|
||||
req := httpPostFormRequest(t, callbackURL, relayState, response)
|
||||
|
||||
//do request to callback URL and check redirect to either success or failure url
|
||||
location, err := integration.CheckRedirect(req)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, relayState, location.Query().Get("id"))
|
||||
if tt.want.successful {
|
||||
assert.True(t, strings.HasPrefix(location.String(), tt.args.successURL))
|
||||
assert.NotEmpty(t, location.Query().Get("token"))
|
||||
assert.Equal(t, tt.want.user, location.Query().Get("user"))
|
||||
} else {
|
||||
assert.True(t, strings.HasPrefix(location.String(), tt.args.failureURL))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var key = func() crypto.PrivateKey {
|
||||
b, _ := pem.Decode([]byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEA0OhbMuizgtbFOfwbK7aURuXhZx6VRuAs3nNibiuifwCGz6u9
|
||||
yy7bOR0P+zqN0YkjxaokqFgra7rXKCdeABmoLqCC0U+cGmLNwPOOA0PaD5q5xKhQ
|
||||
4Me3rt/R9C4Ca6k3/OnkxnKwnogcsmdgs2l8liT3qVHP04Oc7Uymq2v09bGb6nPu
|
||||
fOrkXS9F6mSClxHG/q59AGOWsXK1xzIRV1eu8W2SNdyeFVU1JHiQe444xLoPul5t
|
||||
InWasKayFsPlJfWNc8EoU8COjNhfo/GovFTHVjh9oUR/gwEFVwifIHihRE0Hazn2
|
||||
EQSLaOr2LM0TsRsQroFjmwSGgI+X2bfbMTqWOQIDAQABAoIBAFWZwDTeESBdrLcT
|
||||
zHZe++cJLxE4AObn2LrWANEv5AeySYsyzjRBYObIN9IzrgTb8uJ900N/zVr5VkxH
|
||||
xUa5PKbOcowd2NMfBTw5EEnaNbILLm+coHdanrNzVu59I9TFpAFoPavrNt/e2hNo
|
||||
NMGPSdOkFi81LLl4xoadz/WR6O/7N2famM+0u7C2uBe+TrVwHyuqboYoidJDhO8M
|
||||
w4WlY9QgAUhkPyzZqrl+VfF1aDTGVf4LJgaVevfFCas8Ws6DQX5q4QdIoV6/0vXi
|
||||
B1M+aTnWjHuiIzjBMWhcYW2+I5zfwNWRXaxdlrYXRukGSdnyO+DH/FhHePJgmlkj
|
||||
NInADDkCgYEA6MEQFOFSCc/ELXYWgStsrtIlJUcsLdLBsy1ocyQa2lkVUw58TouW
|
||||
RciE6TjW9rp31pfQUnO2l6zOUC6LT9Jvlb9PSsyW+rvjtKB5PjJI6W0hjX41wEO6
|
||||
fshFELMJd9W+Ezao2AsP2hZJ8McCF8no9e00+G4xTAyxHsNI2AFTCQcCgYEA5cWZ
|
||||
JwNb4t7YeEajPt9xuYNUOQpjvQn1aGOV7KcwTx5ELP/Hzi723BxHs7GSdrLkkDmi
|
||||
Gpb+mfL4wxCt0fK0i8GFQsRn5eusyq9hLqP/bmjpHoXe/1uajFbE1fZQR+2LX05N
|
||||
3ATlKaH2hdfCJedFa4wf43+cl6Yhp6ZA0Yet1r8CgYEAwiu1j8W9G+RRA5/8/DtO
|
||||
yrUTOfsbFws4fpLGDTA0mq0whf6Soy/96C90+d9qLaC3srUpnG9eB0CpSOjbXXbv
|
||||
kdxseLkexwOR3bD2FHX8r4dUM2bzznZyEaxfOaQypN8SV5ME3l60Fbr8ajqLO288
|
||||
wlTmGM5Mn+YCqOg/T7wjGmcCgYBpzNfdl/VafOROVbBbhgXWtzsz3K3aYNiIjbp+
|
||||
MunStIwN8GUvcn6nEbqOaoiXcX4/TtpuxfJMLw4OvAJdtxUdeSmEee2heCijV6g3
|
||||
ErrOOy6EqH3rNWHvlxChuP50cFQJuYOueO6QggyCyruSOnDDuc0BM0SGq6+5g5s7
|
||||
H++S/wKBgQDIkqBtFr9UEf8d6JpkxS0RXDlhSMjkXmkQeKGFzdoJcYVFIwq8jTNB
|
||||
nJrVIGs3GcBkqGic+i7rTO1YPkquv4dUuiIn+vKZVoO6b54f+oPBXd4S0BnuEqFE
|
||||
rdKNuCZhiaE2XD9L/O9KP1fh5bfEcKwazQ23EvpJHBMm8BGC+/YZNw==
|
||||
-----END RSA PRIVATE KEY-----`))
|
||||
k, _ := x509.ParsePKCS1PrivateKey(b.Bytes)
|
||||
return k
|
||||
}()
|
||||
|
||||
var cert = func() *x509.Certificate {
|
||||
b, _ := pem.Decode([]byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNV
|
||||
BAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5
|
||||
NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEB
|
||||
BQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8A
|
||||
hs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+a
|
||||
ucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWx
|
||||
m+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6
|
||||
D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURN
|
||||
B2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0O
|
||||
BBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56
|
||||
zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5
|
||||
pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uv
|
||||
NONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEf
|
||||
y/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL
|
||||
/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsb
|
||||
GFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTL
|
||||
UzreO96WzlBBMtY=
|
||||
-----END CERTIFICATE-----`))
|
||||
c, _ := x509.ParseCertificate(b.Bytes)
|
||||
return c
|
||||
}()
|
||||
|
||||
func getIDP(zitadelBaseURL string, idpIDs []string, user1, user2 string) (*saml.IdentityProvider, error) {
|
||||
baseURL, err := url.Parse("http://localhost:8000")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
store := &samlidp.MemoryStore{}
|
||||
hashedPassword1, _ := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost)
|
||||
err = store.Put("/users/"+user1, samlidp.User{
|
||||
Name: user1,
|
||||
HashedPassword: hashedPassword1,
|
||||
Groups: []string{"Administrators", "Users"},
|
||||
Email: "test@example.com",
|
||||
CommonName: "Test Test",
|
||||
Surname: "Test",
|
||||
GivenName: "Test",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hashedPassword2, _ := bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost)
|
||||
err = store.Put("/users/"+user2, samlidp.User{
|
||||
Name: user2,
|
||||
HashedPassword: hashedPassword2,
|
||||
Groups: []string{"Administrators", "Users"},
|
||||
Email: "test@example.com",
|
||||
CommonName: "Test Test",
|
||||
Surname: "Test",
|
||||
GivenName: "Test",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, idpID := range idpIDs {
|
||||
metadata, err := saml_xml.ReadMetadataFromURL(http.DefaultClient, zitadelBaseURL+"/idps/"+idpID+"/saml/metadata")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entity := new(saml.EntityDescriptor)
|
||||
if err := xml.Unmarshal(metadata, entity); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := store.Put("/services/"+idpID, samlidp.Service{
|
||||
Name: idpID,
|
||||
Metadata: *entity,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
idpServer, err := samlidp.New(samlidp.Options{
|
||||
URL: *baseURL,
|
||||
Key: key,
|
||||
Certificate: cert,
|
||||
Store: store,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if idpServer.IDP.AssertionMaker == nil {
|
||||
idpServer.IDP.AssertionMaker = &saml.DefaultAssertionMaker{}
|
||||
}
|
||||
return &idpServer.IDP, nil
|
||||
}
|
||||
|
||||
func createResponse(t *testing.T, idp *saml.IdentityProvider, req *http.Request, username string) string {
|
||||
authnReq, err := saml.NewIdpAuthnRequest(idp, req)
|
||||
assert.NoError(t, authnReq.Validate())
|
||||
|
||||
err = idp.AssertionMaker.MakeAssertion(authnReq, &saml.Session{
|
||||
CreateTime: time.Now().UTC(),
|
||||
Index: "",
|
||||
NameID: username,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
err = authnReq.MakeResponse()
|
||||
assert.NoError(t, err)
|
||||
|
||||
doc := etree.NewDocument()
|
||||
doc.SetRoot(authnReq.ResponseEl)
|
||||
responseBuf, err := doc.WriteToBytes()
|
||||
assert.NoError(t, err)
|
||||
responseBuf = append([]byte("<?xml version=\"1.0\"?>"), responseBuf...)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(responseBuf)
|
||||
}
|
||||
|
||||
func httpGETRequest(t *testing.T, callbackURL string, relayState, response, sig, sigAlg string) *http.Request {
|
||||
req, err := http.NewRequest("GET", callbackURL, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Add("RelayState", relayState)
|
||||
q.Add("SAMLResponse", response)
|
||||
if sig != "" {
|
||||
q.Add("Sig", sig)
|
||||
}
|
||||
if sigAlg != "" {
|
||||
q.Add("SigAlg", sigAlg)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
return req
|
||||
}
|
||||
|
||||
func httpPostFormRequest(t *testing.T, callbackURL, relayState, response string) *http.Request {
|
||||
body := url.Values{
|
||||
"SAMLResponse": {response},
|
||||
"RelayState": {relayState},
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", callbackURL, strings.NewReader(body.Encode()))
|
||||
assert.NoError(t, err)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.ParseForm()
|
||||
return req
|
||||
}
|
Reference in New Issue
Block a user