chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,51 @@
package idp
import (
"context"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
)
// Provider is the minimal implementation for a 3rd party authentication provider
type Provider interface {
Name() string
BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
IsLinkingAllowed() bool
IsCreationAllowed() bool
IsAutoCreation() bool
IsAutoUpdate() bool
}
// User contains the information of a federated user.
type User interface {
GetID() string
GetFirstName() string
GetLastName() string
GetDisplayName() string
GetNickname() string
GetPreferredUsername() string
GetEmail() domain.EmailAddress
IsEmailVerified() bool
GetPhone() domain.PhoneNumber
IsPhoneVerified() bool
GetPreferredLanguage() language.Tag
GetAvatarURL() string
GetProfile() string
}
// Parameter allows to pass specific parameter to the BeginAuth function
type Parameter interface {
setValue()
}
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
type UserAgentID string
func (p UserAgentID) setValue() {}
// LoginHintParam allows to pass a login_hint to BeginAuth
type LoginHintParam string
func (p LoginHintParam) setValue() {}

View File

@@ -0,0 +1,68 @@
package apple
import (
"crypto/x509"
"encoding/pem"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/crypto"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
name = "Apple"
issuer = "https://appleid.apple.com"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for Apple
type Provider struct {
*oidc.Provider
}
func New(clientID, teamID, keyID, callbackURL string, key []byte, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
secret, err := clientSecretFromPrivateKey(key, teamID, clientID, keyID)
if err != nil {
return nil, err
}
options = append(options, oidc.WithResponseMode("form_post"))
rp, err := oidc.New(name, issuer, clientID, secret, callbackURL, scopes, oidc.DefaultMapper, options...)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}
// clientSecretFromPrivateKey uses the private key to create and sign a JWT, which has to be used as client_secret at Apple.
func clientSecretFromPrivateKey(key []byte, teamID, clientID, keyID string) (string, error) {
block, _ := pem.Decode(key)
b := block.Bytes
pk, err := x509.ParsePKCS8PrivateKey(b)
if err != nil {
return "", err
}
signingKey := jose.SigningKey{
Algorithm: jose.ES256,
Key: &jose.JSONWebKey{Key: pk, KeyID: keyID},
}
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
if err != nil {
return "", err
}
iat := time.Now().Add(-2 * time.Second)
exp := iat.Add(time.Hour)
return crypto.Sign(&openid.JWTTokenRequest{
Issuer: teamID,
Subject: clientID,
Audience: []string{issuer},
ExpiresAt: openid.FromTime(exp),
IssuedAt: openid.FromTime(iat),
}, signer)
}

View File

@@ -0,0 +1,71 @@
package apple
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
privateKey = `-----BEGIN PRIVATE KEY-----
MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgXn/LDURaetCoymSj
fRslBiBwzBSa8ifiyfYGIWNStYGgCgYIKoZIzj0DAQehRANCAATymZXIsGrXnl6b
+80miSiVOCcLnyaYa2uQBQvQwgB7GibXhrzF+D/MRTV4P7P8+Lg1K9Khkjc59eNK
4RrQP4g7
-----END PRIVATE KEY-----
`
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
teamID string
keyID string
privateKey []byte
redirectURI string
scopes []string
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
teamID: "teamID",
keyID: "keyID",
privateKey: []byte(privateKey),
redirectURI: "redirectURI",
scopes: []string{"openid"},
},
want: &Session{
Session: &oidc.Session{
AuthURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.teamID, tt.fields.keyID, tt.fields.redirectURI, tt.fields.privateKey, tt.fields.scopes)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
auth, err := session.GetAuth(ctx)
authExpected, errExpected := tt.want.GetAuth(ctx)
a.ErrorIs(err, errExpected)
a.Equal(authExpected, auth)
})
}
}

View File

@@ -0,0 +1,74 @@
package apple
import (
"context"
"encoding/json"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
var _ idp.Session = (*Session)(nil)
// Session extends the [oidc.Session] with the formValues returned from the callback.
// This enables to parse the user (name and email), which Apple only returns as form params on registration
type Session struct {
*oidc.Session
UserFormValue string
}
func NewSession(provider *Provider, code, userFormValue string) *Session {
return &Session{Session: oidc.NewSession(provider.Provider, code, nil), UserFormValue: userFormValue}
}
type userFormValue struct {
Name userNamesFormValue `json:"name,omitempty" schema:"name"`
}
type userNamesFormValue struct {
FirstName string `json:"firstName,omitempty" schema:"firstName"`
LastName string `json:"lastName,omitempty" schema:"lastName"`
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OIDC code exchange if needed to retrieve the tokens,
// extract the information from the id_token and if available also from the `user` form value.
// The information will be mapped into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
if err = s.Authorize(ctx); err != nil {
return nil, err
}
}
info := s.Tokens.IDTokenClaims.GetUserInfo()
userName := userFormValue{}
if s.UserFormValue != "" {
if err = json.Unmarshal([]byte(s.UserFormValue), &userName); err != nil {
return nil, err
}
}
return NewUser(info, userName.Name), nil
}
func NewUser(info *openid.UserInfo, names userNamesFormValue) *User {
user := oidc.NewUser(info)
user.GivenName = names.FirstName
user.FamilyName = names.LastName
return &User{User: user}
}
func InitUser() idp.User {
return &User{User: oidc.InitUser()}
}
// User extends the [oidc.User] by returning the email as preferred_username, since Apple does not return the latter.
type User struct {
*oidc.User
}
func (u *User) GetPreferredUsername() string {
return u.Email
}

View File

@@ -0,0 +1,217 @@
package apple
import (
"context"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
clientID string
teamID string
keyID string
privateKey []byte
redirectURI string
scopes []string
httpMock func()
authURL string
code string
tokens *openid.Tokens[*openid.IDTokenClaims]
userFormValue string
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
nonceSupported bool
isPrivateEmail bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
teamID: "teamID",
keyID: "keyID",
privateKey: []byte(privateKey),
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {},
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: oidc.ErrCodeMissing,
},
},
{
name: "no user param",
fields: fields{
clientID: "clientID",
teamID: "teamID",
keyID: "keyID",
privateKey: []byte(privateKey),
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {},
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: id_token(),
},
userFormValue: "",
},
want: want{
id: "sub",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "email",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
nonceSupported: true,
isPrivateEmail: true,
},
},
{
name: "with user param",
fields: fields{
clientID: "clientID",
teamID: "teamID",
keyID: "keyID",
privateKey: []byte(privateKey),
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {},
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: id_token(),
},
userFormValue: `{"name": {"firstName": "firstName", "lastName": "lastName"}}`,
},
want: want{
id: "sub",
firstName: "firstName",
lastName: "lastName",
displayName: "",
nickName: "",
preferredUsername: "email",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
nonceSupported: true,
isPrivateEmail: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
// call the real discovery endpoint
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
provider, err := New(tt.fields.clientID, tt.fields.teamID, tt.fields.keyID, tt.fields.redirectURI, tt.fields.privateKey, tt.fields.scopes)
require.NoError(t, err)
session := &Session{
Session: &oidc.Session{
Provider: provider.Provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
},
UserFormValue: tt.fields.userFormValue,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func id_token() *openid.IDTokenClaims {
return &openid.IDTokenClaims{
TokenClaims: openid.TokenClaims{
Issuer: issuer,
Subject: "sub",
Audience: []string{"clientID"},
Expiration: openid.FromTime(time.Now().Add(1 * time.Hour)),
IssuedAt: openid.FromTime(time.Now().Add(-1 * time.Second)),
AuthTime: openid.FromTime(time.Now().Add(-1 * time.Second)),
Nonce: "nonce",
ClientID: "clientID",
},
UserInfoEmail: openid.UserInfoEmail{
Email: "email",
EmailVerified: true,
},
Claims: map[string]any{
"nonce_supported": true,
"is_private_email": true,
},
}
}

View File

@@ -0,0 +1,253 @@
package azuread
import (
"fmt"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
const (
issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0"
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
userURL string = "https://graph.microsoft.com/v1.0/me"
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
ScopeUserRead string = "User.Read"
)
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name.
//
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
type TenantType string
const (
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active
// Directory to sign in to the application.
CommonTenant TenantType = "common"
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign in to the application.
OrganizationsTenant TenantType = "organizations"
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign in to the application.
ConsumersTenant TenantType = "consumers"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for AzureAD (V2 Endpoints)
type Provider struct {
*oauth.Provider
tenant TenantType
emailVerified bool
options []oauth.ProviderOpts
}
// issuer returns the OIDC issuer based on the [TenantType]
func (p *Provider) issuer() string {
return fmt.Sprintf(issuerTemplate, p.tenant)
}
// keysEndpoint returns the OIDC jwks_url based on the [TenantType]
func (p *Provider) keysEndpoint() string {
return fmt.Sprintf(keysURLTemplate, p.tenant)
}
type ProviderOptions func(*Provider)
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
// default is CommonTenant
func WithTenant(tenantType TenantType) ProviderOptions {
return func(p *Provider) {
p.tenant = tenantType
}
}
// WithEmailVerified allows to set every email received as verified
func WithEmailVerified() ProviderOptions {
return func(p *Provider) {
p.emailVerified = true
}
}
// WithOAuthOptions allows to specify [oauth.ProviderOpts] like [oauth.WithLinkingAllowed]
func WithOAuthOptions(opts ...oauth.ProviderOpts) ProviderOptions {
return func(p *Provider) {
p.options = append(p.options, opts...)
}
}
// New creates an AzureAD provider using the [oauth.Provider] (OAuth 2.0 generic provider).
// By default, it uses the [CommonTenant] and unverified emails.
func New(name, clientID, clientSecret, redirectURI string, scopes []string, opts ...ProviderOptions) (*Provider, error) {
provider := &Provider{
tenant: CommonTenant,
options: make([]oauth.ProviderOpts, 0),
}
for _, opt := range opts {
opt(provider)
}
config := newConfig(provider.tenant, clientID, clientSecret, redirectURI, scopes)
rp, err := oauth.New(
config,
name,
userURL,
func() idp.User {
return &User{isEmailVerified: provider.emailVerified}
},
provider.options...,
)
if err != nil {
return nil, err
}
provider.Provider = rp
return provider, nil
}
func newConfig(tenant TenantType, clientID, secret, callbackURL string, scopes []string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: secret,
RedirectURL: callbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf(authURLTemplate, tenant),
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant),
},
Scopes: ensureMinimalScope(scopes),
}
}
// ensureMinimalScope ensures that at least openid and `User.Read` ist set
// if none is provided it will request `openid profile email phone User.Read`
func ensureMinimalScope(scopes []string) []string {
if len(scopes) == 0 {
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, ScopeUserRead}
}
var openIDSet, userReadSet bool
for _, scope := range scopes {
if scope == oidc.ScopeOpenID {
openIDSet = true
continue
}
if scope == ScopeUserRead {
userReadSet = true
continue
}
}
if !openIDSet {
scopes = append(scopes, oidc.ScopeOpenID)
}
if !userReadSet {
scopes = append(scopes, ScopeUserRead)
}
return scopes
}
func (p *Provider) User() idp.User {
return p.Provider.User()
}
// User represents the structure return on the userinfo endpoint and implements the [idp.User] interface
//
// AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified])
type User struct {
ID string `json:"id"`
BusinessPhones []domain.PhoneNumber `json:"businessPhones"`
DisplayName string `json:"displayName"`
FirstName string `json:"givenName"`
JobTitle string `json:"jobTitle"`
Email domain.EmailAddress `json:"mail"`
MobilePhone domain.PhoneNumber `json:"mobilePhone"`
OfficeLocation string `json:"officeLocation"`
PreferredLanguage string `json:"preferredLanguage"`
LastName string `json:"surname"`
UserPrincipalName string `json:"userPrincipalName"`
isEmailVerified bool
}
// GetID is an implementation of the [idp.User] interface.
func (u *User) GetID() string {
return u.ID
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *User) GetFirstName() string {
return u.FirstName
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *User) GetLastName() string {
return u.LastName
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *User) GetDisplayName() string {
return u.DisplayName
}
// GetNickname is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's nickname.
func (u *User) GetNickname() string {
return ""
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *User) GetPreferredUsername() string {
return u.UserPrincipalName
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *User) GetEmail() domain.EmailAddress {
if u.Email == "" {
// if the user used a social login on Azure as well, the email will be empty
// but is used as username
return domain.EmailAddress(u.UserPrincipalName)
}
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface
// returning the value specified in the creation of the [Provider].
// Default is false because AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified]).
func (u *User) IsEmailVerified() bool {
return u.isEmailVerified
}
// GetPhone is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's phone.
func (u *User) GetPhone() domain.PhoneNumber {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
// It returns false because AzureAD does not provide the user's phone.
func (u *User) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *User) GetPreferredLanguage() language.Tag {
return language.Make(u.PreferredLanguage)
}
// GetProfile is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's profile page.
func (u *User) GetProfile() string {
return ""
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *User) GetAvatarURL() string {
return ""
}

View File

@@ -0,0 +1,185 @@
package azuread
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
options []ProviderOptions
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "default common tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
},
},
{
name: "tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant(ConsumersTenant),
},
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
},
},
{
name: "custom scopes",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{openid.ScopeOpenID, openid.ScopeProfile, "custom"},
options: []ProviderOptions{
WithTenant(ConsumersTenant),
},
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+custom+User.Read&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
options []ProviderOptions
}
type want struct {
name string
tenant TenantType
emailVerified bool
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default common tenant",
fields: fields{
name: "default common tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: nil,
options: nil,
},
want: want{
name: "default common tenant",
tenant: CommonTenant,
emailVerified: false,
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all set",
fields: fields{
name: "custom tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant("tenant"),
WithEmailVerified(),
WithOAuthOptions(
oauth.WithLinkingAllowed(),
oauth.WithCreationAllowed(),
oauth.WithAutoCreation(),
oauth.WithAutoUpdate(),
oauth.WithRelyingPartyOption(rp.WithPKCE(nil)),
),
},
},
want: want{
name: "custom tenant",
tenant: "tenant",
emailVerified: true,
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.tenant, provider.tenant)
a.Equal(tt.want.emailVerified, provider.emailVerified)
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,106 @@
package azuread
import (
"context"
"net/http"
"time"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
var _ idp.Session = (*Session)(nil)
// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality
type Session struct {
*Provider
Code string
OAuthSession *oauth.Session
}
func NewSession(provider *Provider, code string) *Session {
return &Session{Provider: provider, Code: code}
}
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return s.oauth().GetAuth(ctx)
}
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
func (s *Session) RetrievePreviousID() (string, error) {
req, err := http.NewRequest("GET", userinfoEndpoint, nil)
if err != nil {
return "", err
}
req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken)
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
return "", err
}
return userinfo.Subject, nil
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it.
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
user, err = s.oauth().FetchUser(ctx)
if err != nil {
return nil, err
}
// since azure will sign the id_token always with the issuer of the application it might differ from
// the issuer the auth and token were based on, e.g. when allowing all account types to login,
// then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
// even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0`
if s.Provider.tenant == CommonTenant ||
s.Provider.tenant == OrganizationsTenant ||
s.Provider.tenant == ConsumersTenant {
return user, nil
}
idToken, ok := s.oauth().Tokens.Extra("id_token").(string)
if !ok {
return user, nil
}
idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint()))
s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier)
if err != nil {
return nil, err
}
s.oauth().Tokens.IDToken = idToken
return user, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.OAuthSession == nil {
return time.Time{}
}
return s.OAuthSession.ExpiresAt()
}
// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session].
func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] {
return s.oauth().Tokens
}
func (s *Session) oauth() *oauth.Session {
if s.OAuthSession != nil {
return s.OAuthSession
}
s.OAuthSession = &oauth.Session{
Code: s.Code,
Provider: s.Provider.Provider,
}
return s.OAuthSession
}

View File

@@ -0,0 +1,416 @@
package azuread
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
options []ProviderOptions
authURL string
code string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, oauth.ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(http.StatusInternalServerError)
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
isEmailVerified: false,
},
id: "id",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
profile: "",
},
},
{
name: "successful fetch with email verified",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithEmailVerified(),
},
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
isEmailVerified: true,
},
id: "id",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
session := &Session{
Provider: provider,
Code: tt.fields.code,
OAuthSession: &oauth.Session{
AuthURL: tt.fields.authURL,
Tokens: tt.fields.tokens,
Provider: provider.Provider,
Code: tt.fields.code,
},
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *User {
return &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
}
}
func TestSession_RetrievePreviousID(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type res struct {
id string
err bool
}
tests := []struct {
name string
fields fields
res res
}{
{
name: "invalid token",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(401)
},
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
res: res{
err: true,
},
},
{
name: "success",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(200).
JSON(`{"sub":"sub"}`)
},
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
res: res{
id: "sub",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
require.NoError(t, err)
session := &Session{
Provider: provider,
OAuthSession: &oauth.Session{
Tokens: tt.fields.tokens,
Provider: provider.Provider,
}}
id, err := session.RetrievePreviousID()
if tt.res.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
a.Equal(tt.res.id, id)
})
}
}

View File

@@ -0,0 +1,190 @@
package github
import (
"strconv"
"time"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
const (
authURL = "https://github.com/login/oauth/authorize"
tokenURL = "https://github.com/login/oauth/access_token"
profileURL = "https://api.github.com/user"
name = "GitHub"
)
var _ idp.Provider = (*Provider)(nil)
// New creates a GitHub.com provider using the [oauth.Provider] (OAuth 2.0 generic provider)
func New(clientID, secret, callbackURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
return NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL, scopes, options...)
}
// NewCustomURL creates a GitHub provider using the [oauth.Provider] (OAuth 2.0 generic provider)
// with custom endpoints, e.g. GitHub Enterprise server
func NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
rp, err := oauth.New(
newConfig(clientID, secret, callbackURL, authURL, tokenURL, scopes),
name,
profileURL,
func() idp.User {
return new(User)
},
options...,
)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}
// Provider is the [idp.Provider] implementation for GitHub
type Provider struct {
*oauth.Provider
}
func newConfig(clientID, secret, callbackURL, authURL, tokenURL string, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: clientID,
ClientSecret: secret,
RedirectURL: callbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
},
Scopes: scopes,
}
return c
}
// User is a representation of the authenticated GitHub user and implements the [idp.User] interface
// https://docs.github.com/en/rest/users/users?apiVersion=2022-11-28#get-the-authenticated-user
type User struct {
Login string `json:"login"`
ID int `json:"id"`
NodeId string `json:"node_id"`
AvatarUrl string `json:"avatar_url"`
GravatarId string `json:"gravatar_id"`
Url string `json:"url"`
HtmlUrl string `json:"html_url"`
FollowersUrl string `json:"followers_url"`
FollowingUrl string `json:"following_url"`
GistsUrl string `json:"gists_url"`
StarredUrl string `json:"starred_url"`
SubscriptionsUrl string `json:"subscriptions_url"`
OrganizationsUrl string `json:"organizations_url"`
ReposUrl string `json:"repos_url"`
EventsUrl string `json:"events_url"`
ReceivedEventsUrl string `json:"received_events_url"`
Type string `json:"type"`
SiteAdmin bool `json:"site_admin"`
Name string `json:"name"`
Company string `json:"company"`
Blog string `json:"blog"`
Location string `json:"location"`
Email domain.EmailAddress `json:"email"`
Hireable bool `json:"hireable"`
Bio string `json:"bio"`
TwitterUsername string `json:"twitter_username"`
PublicRepos int `json:"public_repos"`
PublicGists int `json:"public_gists"`
Followers int `json:"followers"`
Following int `json:"following"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
PrivateGists int `json:"private_gists"`
TotalPrivateRepos int `json:"total_private_repos"`
OwnedPrivateRepos int `json:"owned_private_repos"`
DiskUsage int `json:"disk_usage"`
Collaborators int `json:"collaborators"`
TwoFactorAuthentication bool `json:"two_factor_authentication"`
Plan struct {
Name string `json:"name"`
Space int `json:"space"`
PrivateRepos int `json:"private_repos"`
Collaborators int `json:"collaborators"`
} `json:"plan"`
}
// GetID is an implementation of the [idp.User] interface.
func (u *User) GetID() string {
return strconv.Itoa(u.ID)
}
// GetFirstName is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's firstname.
func (u *User) GetFirstName() string {
return ""
}
// GetLastName is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's lastname.
func (u *User) GetLastName() string {
// GitHub does not provide the user's lastname
return ""
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *User) GetDisplayName() string {
return u.Name
}
// GetNickname is an implementation of the [idp.User] interface
// returning the login name of the GitHub user.
func (u *User) GetNickname() string {
return u.Login
}
// GetPreferredUsername is an implementation of the [idp.User] interface
// returning the login name of the GitHub user.
func (u *User) GetPreferredUsername() string {
return u.Login
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *User) GetEmail() domain.EmailAddress {
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface.
// It returns true because GitHub validates emails themselves.
func (u *User) IsEmailVerified() bool {
return true
}
// GetPhone is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's phone.
func (u *User) GetPhone() domain.PhoneNumber {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface
// it returns false because GitHub does not provide the user's phone
func (u *User) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
// It returns [language.Und] because GitHub does not provide the user's language.
func (u *User) GetPreferredLanguage() language.Tag {
return language.Und
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *User) GetProfile() string {
return u.HtmlUrl
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *User) GetAvatarURL() string {
return u.AvatarUrl
}

View File

@@ -0,0 +1,57 @@
package github
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
options []oauth.ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oauth.Session{
AuthURL: "https://github.com/login/oauth/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}

View File

@@ -0,0 +1,216 @@
package github
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
httpMock func()
authURL string
code string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
scopes []string
options []oauth.ProviderOpts
}
type args struct {
session idp.Session
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(200).
JSON(userinfo())
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: nil,
},
args: args{
&oauth.Session{},
},
want: want{
err: func(err error) bool {
return errors.Is(err, oauth.ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(http.StatusInternalServerError)
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
args: args{
&oauth.Session{},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(200).
JSON(userinfo())
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
args: args{
&oauth.Session{},
},
want: want{
user: &User{
Login: "login",
ID: 1,
AvatarUrl: "avatarURL",
GravatarId: "gravatarID",
Name: "name",
Email: "email",
HtmlUrl: "htmlURL",
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
},
id: "1",
firstName: "",
lastName: "",
displayName: "name",
nickName: "login",
preferredUsername: "login",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "avatarURL",
profile: "htmlURL",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
session := &oauth.Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider.Provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *User {
return &User{
Login: "login",
ID: 1,
AvatarUrl: "avatarURL",
GravatarId: "gravatarID",
Name: "name",
Email: "email",
HtmlUrl: "htmlURL",
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
}
}

View File

@@ -0,0 +1,45 @@
package gitlab
import (
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
issuer = "https://gitlab.com"
name = "GitLab"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for Gitlab
type Provider struct {
*oidc.Provider
}
// New creates a GitLab.com provider using the [oidc.Provider] (OIDC generic provider)
func New(clientID, clientSecret, redirectURI string, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
return NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI, scopes, options...)
}
// NewCustomIssuer creates a GitLab provider using the [oidc.Provider] (OIDC generic provider)
// with a custom issuer for self-managed instances
func NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
if len(scopes) == 0 {
// the OIDC provider would set `openid profile email phone` as default scope,
// but since gitlab does not handle unknown scopes correctly (phone) and returns an error,
// we will just set a separate default list
scopes = []string{openid.ScopeOpenID, openid.ScopeProfile, openid.ScopeEmail}
}
// gitlab is currently not able to handle the prompt `select_account`:
// https://gitlab.com/gitlab-org/gitlab/-/issues/377368
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, scopes, oidc.DefaultMapper, options...)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}

View File

@@ -0,0 +1,68 @@
package gitlab
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
opts []oidc.ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
},
want: &oidc.Session{
AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
},
},
{
name: "successful auth default scopes",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.opts...)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}

View File

@@ -0,0 +1,225 @@
package gitlab
import (
"context"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
authURL string
code string
tokens *openid.Tokens[*openid.IDTokenClaims]
options []oidc.ProviderOpts
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: oidc.ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
// call the real discovery endpoint
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
session := &oidc.Session{
Provider: provider.Provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *openid.UserInfo {
return &openid.UserInfo{
Subject: "sub",
UserInfoProfile: openid.UserInfoProfile{
GivenName: "firstname",
FamilyName: "lastname",
Name: "firstname lastname",
Nickname: "nickname",
PreferredUsername: "username",
Locale: openid.NewLocale(language.English),
Picture: "picture",
Profile: "profile",
},
UserInfoEmail: openid.UserInfoEmail{
Email: "email",
EmailVerified: openid.Bool(true),
},
UserInfoPhone: openid.UserInfoPhone{
PhoneNumber: "phone",
PhoneNumberVerified: true,
},
}
}

View File

@@ -0,0 +1,51 @@
package google
import (
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
issuer = "https://accounts.google.com"
name = "Google"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for Google
type Provider struct {
*oidc.Provider
}
// New creates a Google provider using the [oidc.Provider] (OIDC generic provider)
func New(clientID, clientSecret, redirectURI string, scopes []string, opts ...oidc.ProviderOpts) (*Provider, error) {
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, scopes, userMapper, append(opts, oidc.WithSelectAccount())...)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}
var userMapper = func(info *openid.UserInfo) idp.User {
return &User{oidc.DefaultMapper(info)}
}
func InitUser() idp.User {
return &User{oidc.InitUser()}
}
// User is a representation of the authenticated Google and implements the [idp.User] interface
// by wrapping an [idp.User] (implemented by [oidc.User]). It overwrites the [GetPreferredUsername] to use the `email` claim.
type User struct {
idp.User
}
// GetPreferredUsername implements the [idp.User] interface.
// It returns the email, because Google does not return a username.
func (u *User) GetPreferredUsername() string {
return string(u.GetEmail())
}

View File

@@ -0,0 +1,57 @@
package google
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
},
want: &oidc.Session{
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}

View File

@@ -0,0 +1,222 @@
package google
import (
"context"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
authURL string
code string
tokens *openid.Tokens[*openid.IDTokenClaims]
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
hostedDomain string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: oidc.ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens[*openid.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "email",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "",
hostedDomain: "hosted domain",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
// call the real discovery endpoint
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
require.NoError(t, err)
session := &oidc.Session{
Provider: provider.Provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *openid.UserInfo {
return &openid.UserInfo{
Subject: "sub",
UserInfoProfile: openid.UserInfoProfile{
GivenName: "firstname",
FamilyName: "lastname",
Name: "firstname lastname",
Locale: openid.NewLocale(language.English),
Picture: "picture",
},
UserInfoEmail: openid.UserInfoEmail{
Email: "email",
EmailVerified: openid.Bool(true),
},
Claims: map[string]any{
"hd": "hosted domain",
},
}
}

View File

@@ -0,0 +1,141 @@
package jwt
import (
"context"
"encoding/base64"
"errors"
"net/url"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
const (
QueryAuthRequestID = "authRequestID"
QueryUserAgentID = "userAgentID"
)
var _ idp.Provider = (*Provider)(nil)
var (
ErrMissingState = errors.New("state missing")
)
// Provider is the [idp.Provider] implementation for a JWT provider
type Provider struct {
name string
headerName string
issuer string
jwtEndpoint string
keysEndpoint string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
encryptionAlg crypto.EncryptionAlgorithm
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// New creates a JWT provider
func New(name, issuer, jwtEndpoint, keysEndpoint, headerName string, encryptionAlg crypto.EncryptionAlgorithm, options ...ProviderOpts) (*Provider, error) {
provider := &Provider{
name: name,
issuer: issuer,
jwtEndpoint: jwtEndpoint,
keysEndpoint: keysEndpoint,
headerName: headerName,
encryptionAlg: encryptionAlg,
}
for _, option := range options {
option(provider)
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
// with the authRequest and encrypted userAgent ids.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
if state == "" {
return nil, ErrMissingState
}
userAgentID := userAgentIDFromParams(state, params...)
redirect, err := url.Parse(p.jwtEndpoint)
if err != nil {
return nil, err
}
q := redirect.Query()
q.Set(QueryAuthRequestID, state)
nonce, err := p.encryptionAlg.Encrypt([]byte(userAgentID))
if err != nil {
return nil, err
}
q.Set(QueryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce))
redirect.RawQuery = q.Encode()
return &Session{AuthURL: redirect.String()}, nil
}
func userAgentIDFromParams(state string, params ...idp.Parameter) string {
for _, param := range params {
if id, ok := param.(idp.UserAgentID); ok {
return string(id)
}
}
return state
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,226 @@
package jwt
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
}
type args struct {
state string
params []idp.Parameter
}
type want struct {
session idp.Session
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "missing state, error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
state: "",
params: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingState)
},
},
},
{
name: "missing userAgentID, fallback to state",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
state: "testState",
params: nil,
},
want: want{
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=dGVzdFN0YXRl"},
},
},
{
name: "successful auth",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
state: "testState",
params: []idp.Parameter{
idp.UserAgentID("agent"),
},
},
want: want{
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=YWdlbnQ"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
)
require.NoError(t, err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, tt.args.state, tt.args.params...)
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
wantAuth, wantErr := tt.want.session.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
}
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
opts []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: nil,
},
want: want{
name: "jwt",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
},
},
want: want{
name: "jwt",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
tt.fields.opts...,
)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
})
}
}

View File

@@ -0,0 +1,169 @@
package jwt
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Session = (*Session)(nil)
var (
ErrNoTokens = errors.New("no tokens provided")
ErrInvalidToken = errors.New("invalid tokens provided")
)
// Session is the [idp.Session] implementation for the JWT provider
type Session struct {
*Provider
AuthURL string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
func NewSession(provider *Provider, tokens *oidc.Tokens[*oidc.IDTokenClaims]) *Session {
return &Session{Provider: provider, Tokens: tokens}
}
func NewSessionFromRequest(provider *Provider, r *http.Request) *Session {
token := strings.TrimPrefix(r.Header.Get(provider.headerName), oidc.PrefixBearer)
return NewSession(provider, &oidc.Tokens[*oidc.IDTokenClaims]{IDToken: token, Token: &oauth2.Token{}})
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
// It will map the received idToken into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
return nil, ErrNoTokens
}
s.Tokens.IDTokenClaims, err = s.validateToken(ctx, s.Tokens.IDToken)
if err != nil {
return nil, err
}
return &User{s.Tokens.IDTokenClaims}, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.Tokens == nil || s.Tokens.IDTokenClaims == nil {
return time.Time{}
}
return s.Tokens.IDTokenClaims.GetExpiration()
}
func (s *Session) validateToken(ctx context.Context, token string) (*oidc.IDTokenClaims, error) {
logging.Debug("begin token validation")
// TODO: be able to specify them in the template: https://github.com/zitadel/zitadel/issues/5322
offset := 3 * time.Second
maxAge := time.Hour
claims := new(oidc.IDTokenClaims)
payload, err := oidc.ParseToken(token, claims)
if err != nil {
return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrInvalidToken, err)
}
if err = oidc.CheckIssuer(claims, s.Provider.issuer); err != nil {
return nil, fmt.Errorf("%w: invalid issuer: %v", ErrInvalidToken, err)
}
logging.Debug("begin signature validation")
keySet := rp.NewRemoteKeySet(http.DefaultClient, s.Provider.keysEndpoint)
if err = oidc.CheckSignature(ctx, token, payload, claims, nil, keySet); err != nil {
return nil, fmt.Errorf("%w: invalid signature: %v", ErrInvalidToken, err)
}
if !claims.GetExpiration().IsZero() {
if err = oidc.CheckExpiration(claims, offset); err != nil {
return nil, fmt.Errorf("%w: expired: %v", ErrInvalidToken, err)
}
}
if !claims.GetIssuedAt().IsZero() {
if err = oidc.CheckIssuedAt(claims, maxAge, offset); err != nil {
return nil, fmt.Errorf("%w: %v", ErrInvalidToken, err)
}
}
return claims, nil
}
func InitUser() *User {
return &User{
IDTokenClaims: &oidc.IDTokenClaims{},
}
}
type User struct {
*oidc.IDTokenClaims
}
func (u *User) GetID() string {
return u.Subject
}
func (u *User) GetFirstName() string {
return u.GivenName
}
func (u *User) GetLastName() string {
return u.FamilyName
}
func (u *User) GetDisplayName() string {
return u.Name
}
func (u *User) GetNickname() string {
return u.Nickname
}
func (u *User) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *User) GetEmail() domain.EmailAddress {
return domain.EmailAddress(u.Email)
}
func (u *User) IsEmailVerified() bool {
return bool(u.EmailVerified)
}
func (u *User) GetPhone() domain.PhoneNumber {
return domain.PhoneNumber(u.IDTokenClaims.PhoneNumber)
}
func (u *User) IsPhoneVerified() bool {
return u.PhoneNumberVerified
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.Locale.Tag()
}
func (u *User) GetAvatarURL() string {
return u.Picture
}
func (u *User) GetProfile() string {
return u.Profile
}

View File

@@ -0,0 +1,312 @@
package jwt
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
httpMock func(issuer string)
authURL string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type want struct {
err func(error) bool
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "no tokens",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
httpMock: func(issuer string) {
gock.New(issuer).
Get("/keys").
Reply(200).
JSON(keys(t))
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrNoTokens)
},
},
},
{
name: "invalid token",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
httpMock: func(issuer string) {
gock.New(issuer).
Get("/keys").
Reply(200).
JSON(keys(t))
},
authURL: "https://auth.com/jwt?authRequestID=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{},
IDToken: "invalidToken",
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrInvalidToken)
},
},
},
{
name: "successful fetch",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
httpMock: func(issuer string) {
gock.New(issuer).
Get("/keys").
Reply(200).
JSON(keys(t))
},
authURL: "https://auth.com/jwt?authRequestID=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{},
IDToken: idToken(t, "https://jwt.com"),
IDTokenClaims: &oidc.IDTokenClaims{
TokenClaims: oidc.TokenClaims{
Subject: "sub",
},
UserInfoProfile: oidc.UserInfoProfile{
Picture: "picture",
Name: "firstname lastname",
GivenName: "firstname",
FamilyName: "lastname",
Nickname: "nickname",
PreferredUsername: "username",
Profile: "profile",
Locale: oidc.NewLocale(language.English),
},
UserInfoEmail: oidc.UserInfoEmail{
Email: "email",
EmailVerified: oidc.Bool(true),
},
UserInfoPhone: oidc.UserInfoPhone{
PhoneNumber: "phone",
PhoneNumberVerified: true,
},
},
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
)
require.NoError(t, err)
session := &Session{
Provider: provider,
AuthURL: tt.fields.authURL,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func idToken(t *testing.T, issuer string) string {
claims := oidc.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Minute),
"",
"",
nil,
"clientID",
0,
)
claims.UserInfoProfile = oidc.UserInfoProfile{
GivenName: "firstname",
FamilyName: "lastname",
Name: "firstname lastname",
Nickname: "nickname",
PreferredUsername: "username",
Locale: oidc.NewLocale(language.English),
Picture: "picture",
Profile: "profile",
}
claims.UserInfoEmail = oidc.UserInfoEmail{
Email: "email",
EmailVerified: oidc.Bool(true),
}
claims.UserInfoPhone = oidc.UserInfoPhone{
PhoneNumber: "phone",
PhoneNumberVerified: true,
}
privateKey, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAs38btwb3c7r0tMaQpGvBmY+mPwMU/LpfuPoC0k2t4RsKp0fv
40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuyrFALIj3Ff1UcKIk0hOH5DDsfh7/q
2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/MfSydZdcmIqlkUpfQmtzExw9+tSe5
Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNuMZbmlCoBru+rC8ITlTX/0V1ZcsSb
L8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a+kjL/KGZbR14Ua2eo6tonBZLC5DH
WM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly6QIDAQABAoIBAQCPj1nbSPcg2KZe
73FAD+8HopyUSSK//1AP4eXfzcEECVy77g0u9+R6XlkzsZCsZ4g6NN8ounqfyw3c
YlpAIkcFCf/dowoSjT+4LASVQyatYZwWNqjgAIU4KgMG/rKnNahPTiBYe7peMB1j
EaPjnt8uPkCk8y7NCi3y4Pk24tt/WM5KbJK2NQhUi1csGnleDfE+0blV0l/e6C68
W5cbnbWAroMqae/Yon3XVZiXX0m+l2f6ZzIgKaD18J+eEM8FjJC+jQKiRe1i9v3K
nQrLwh/gn8J10FcbKn3xqslKVidzASIrNIzHT9j/Z5T9NXuAKa7IV2x+Dtdus+wq
iBsUunwBAoGBANpYew+8i9vDwK4/SefduDTuzJ0H9lWTjtbiWQ+KYZoeJ7q3/qns
jsmi+mjxkXxXg1RrGbNbjtbl3RXXIrUeeBB0lglRJUjc3VK7VvNoyXIWsiqhCspH
IJ9Yuknv4mXB01m/glbSCS/xu4RTgf5aOG4jUiRb9+dCIpvDxI9gbXEVAoGBANJz
hIJkplIJ+biTi3G1Oz17qkUkInNXzAEzKD9Atoz5AIAiR1ivOMLOlbucfjevw/Nw
TnpkMs9xqCefKupTlsriXtZI88m7ZKzAmolYsPolOy/Jhi31h9JFVTEfKGqVS+dk
A4ndhgdW9RUeNJPY2YVCARXQrWpueweQDA1cNaeFAoGAPJsYtXqBW6PPRM5+ZiSt
78tk8iV2o7RMjqrPS7f+dXfvUS2nO2VVEPTzCtQarOfhpToBLT65vD6bimdn09w8
OV0TFEz4y2u65y7m6LNqTwertpdy1ki97l0DgGhccCBH2P6GYDD2qd8wTH+dcot6
ZF/begopGoDJ+HBzi9SZLC0CgYBZzPslHMevyBvr++GLwrallKhiWnns1/DwLiEl
ZHrBCtuA0Z+6IwLIdZiE9tEQ+ApYTXrfVPQteqUzSwLn/IUiy5eGPpjwYushoAoR
Q2w5QTvRN1/vKo8rVXR1woLfgBdkhFPSN1mitiNcQIhU8jpXV4PZCDOHb99FqdzK
sqcedQKBgQCOmgbqxGsnT2WQhoOdzln+NOo6Tx+FveLLqat2KzpY59W4noeI2Awn
HfIQgWUAW9dsjVVOXMP1jhq8U9hmH/PFWA11V/iCdk1NTxZEw87VAOeWuajpdDHG
+iex349j8h2BcQ4Zd0FWu07gGFnS/yuDJPn6jBhRusdieEcxLRjTKg==
-----END RSA PRIVATE KEY-----
`))
if err != nil {
t.Fatal(err)
}
signer, err := jose.NewSigner(jose.SigningKey{Key: privateKey, Algorithm: "RS256"}, &jose.SignerOptions{})
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(data)
if err != nil {
t.Fatal(err)
}
idToken, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return idToken
}
func keys(t *testing.T) *jose.JSONWebKeySet {
privateKey, err := crypto.BytesToPublicKey([]byte(`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
6QIDAQAB
-----END PUBLIC KEY-----
`))
if err != nil {
t.Fatal(err)
}
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{Key: privateKey, Algorithm: "RS256", Use: oidc.KeyUseSignature}}}
}

View File

@@ -0,0 +1,290 @@
package ldap
import (
"context"
"time"
"github.com/zitadel/zitadel/internal/idp"
)
const DefaultPort = "389"
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic LDAP provider
type Provider struct {
name string
servers []string
startTLS bool
baseDN string
bindDN string
bindPassword string
userBase string
userObjectClasses []string
userFilters []string
timeout time.Duration
rootCA []byte
loginUrl string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
idAttribute string
firstNameAttribute string
lastNameAttribute string
displayNameAttribute string
nickNameAttribute string
preferredUsernameAttribute string
emailAttribute string
emailVerifiedAttribute string
phoneAttribute string
phoneVerifiedAttribute string
preferredLanguageAttribute string
avatarURLAttribute string
profileAttribute string
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithoutStartTLS configures to communication insecure with the LDAP server without startTLS
func WithoutStartTLS() ProviderOpts {
return func(p *Provider) {
p.startTLS = false
}
}
// WithCustomIDAttribute configures to map the LDAP attribute to the user, default is the uniqueUserAttribute
func WithCustomIDAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.idAttribute = name
}
}
// WithFirstNameAttribute configures to map the LDAP attribute to the user
func WithFirstNameAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.firstNameAttribute = name
}
}
// WithLastNameAttribute configures to map the LDAP attribute to the user
func WithLastNameAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.lastNameAttribute = name
}
}
// WithDisplayNameAttribute configures to map the LDAP attribute to the user
func WithDisplayNameAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.displayNameAttribute = name
}
}
// WithNickNameAttribute configures to map the LDAP attribute to the user
func WithNickNameAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.nickNameAttribute = name
}
}
// WithPreferredUsernameAttribute configures to map the LDAP attribute to the user
func WithPreferredUsernameAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.preferredUsernameAttribute = name
}
}
// WithEmailAttribute configures to map the LDAP attribute to the user
func WithEmailAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.emailAttribute = name
}
}
// WithEmailVerifiedAttribute configures to map the LDAP attribute to the user
func WithEmailVerifiedAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.emailVerifiedAttribute = name
}
}
// WithPhoneAttribute configures to map the LDAP attribute to the user
func WithPhoneAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.phoneAttribute = name
}
}
// WithPhoneVerifiedAttribute configures to map the LDAP attribute to the user
func WithPhoneVerifiedAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.phoneVerifiedAttribute = name
}
}
// WithPreferredLanguageAttribute configures to map the LDAP attribute to the user
func WithPreferredLanguageAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.preferredLanguageAttribute = name
}
}
// WithAvatarURLAttribute configures to map the LDAP attribute to the user
func WithAvatarURLAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.avatarURLAttribute = name
}
}
// WithProfileAttribute configures to map the LDAP attribute to the user
func WithProfileAttribute(name string) ProviderOpts {
return func(p *Provider) {
p.profileAttribute = name
}
}
func New(
name string,
servers []string,
baseDN string,
bindDN string,
bindPassword string,
userBase string,
userObjectClasses []string,
userFilters []string,
timeout time.Duration,
rootCA []byte,
loginUrl string,
options ...ProviderOpts,
) *Provider {
provider := &Provider{
name: name,
servers: servers,
startTLS: true,
baseDN: baseDN,
bindDN: bindDN,
bindPassword: bindPassword,
userBase: userBase,
userObjectClasses: userObjectClasses,
userFilters: userFilters,
timeout: timeout,
rootCA: rootCA,
loginUrl: loginUrl,
}
for _, option := range options {
option(provider)
}
return provider
}
func (p *Provider) Name() string {
return p.name
}
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
return &Session{
Provider: p,
loginUrl: p.loginUrl + state,
}, nil
}
func (p *Provider) GetSession(username, password string) *Session {
return &Session{
Provider: p,
User: username,
Password: password,
}
}
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}
func (p *Provider) getNecessaryAttributes() []string {
attributes := []string{p.userBase}
if p.idAttribute != "" {
attributes = append(attributes, p.idAttribute)
}
if p.firstNameAttribute != "" {
attributes = append(attributes, p.firstNameAttribute)
}
if p.lastNameAttribute != "" {
attributes = append(attributes, p.lastNameAttribute)
}
if p.displayNameAttribute != "" {
attributes = append(attributes, p.displayNameAttribute)
}
if p.nickNameAttribute != "" {
attributes = append(attributes, p.nickNameAttribute)
}
if p.preferredUsernameAttribute != "" {
attributes = append(attributes, p.preferredUsernameAttribute)
}
if p.emailAttribute != "" {
attributes = append(attributes, p.emailAttribute)
}
if p.emailVerifiedAttribute != "" {
attributes = append(attributes, p.emailVerifiedAttribute)
}
if p.phoneAttribute != "" {
attributes = append(attributes, p.phoneAttribute)
}
if p.phoneVerifiedAttribute != "" {
attributes = append(attributes, p.phoneVerifiedAttribute)
}
if p.preferredLanguageAttribute != "" {
attributes = append(attributes, p.preferredLanguageAttribute)
}
if p.avatarURLAttribute != "" {
attributes = append(attributes, p.avatarURLAttribute)
}
if p.profileAttribute != "" {
attributes = append(attributes, p.profileAttribute)
}
return attributes
}

View File

@@ -0,0 +1,207 @@
package ldap
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
servers []string
baseDN string
bindDN string
bindPassword string
userBase string
userObjectClasses []string
userFilters []string
timeout time.Duration
rootCA []byte
loginUrl string
opts []ProviderOpts
}
type want struct {
name string
rootCA []byte
startTls bool
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
idAttribute string
firstNameAttribute string
lastNameAttribute string
displayNameAttribute string
nickNameAttribute string
preferredUsernameAttribute string
emailAttribute string
emailVerifiedAttribute string
phoneAttribute string
phoneVerifiedAttribute string
preferredLanguageAttribute string
avatarURLAttribute string
profileAttribute string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "ldap",
servers: []string{"server"},
baseDN: "base",
bindDN: "binddn",
bindPassword: "password",
userBase: "user",
userObjectClasses: []string{"object"},
userFilters: []string{"filter"},
timeout: 30 * time.Second,
loginUrl: "url",
opts: nil,
},
want: want{
name: "ldap",
startTls: true,
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
idAttribute: "",
},
},
{
name: "all true",
fields: fields{
name: "ldap",
servers: []string{"server"},
baseDN: "base",
bindDN: "binddn",
bindPassword: "password",
userBase: "user",
userObjectClasses: []string{"object"},
userFilters: []string{"filter"},
timeout: 30 * time.Second,
loginUrl: "url",
opts: []ProviderOpts{
WithoutStartTLS(),
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
},
},
want: want{
name: "ldap",
startTls: false,
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
idAttribute: "",
},
}, {
name: "all true, attributes set",
fields: fields{
name: "ldap",
servers: []string{"server"},
baseDN: "base",
bindDN: "binddn",
bindPassword: "password",
userBase: "user",
userObjectClasses: []string{"object"},
userFilters: []string{"filter"},
timeout: 30 * time.Second,
rootCA: []byte("certificate"),
loginUrl: "url",
opts: []ProviderOpts{
WithoutStartTLS(),
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithCustomIDAttribute("id"),
WithFirstNameAttribute("first"),
WithLastNameAttribute("last"),
WithDisplayNameAttribute("display"),
WithNickNameAttribute("nick"),
WithPreferredUsernameAttribute("prefUser"),
WithEmailAttribute("email"),
WithEmailVerifiedAttribute("emailVerified"),
WithPhoneAttribute("phone"),
WithPhoneVerifiedAttribute("phoneVerified"),
WithPreferredLanguageAttribute("prefLang"),
WithAvatarURLAttribute("avatar"),
WithProfileAttribute("profile"),
},
},
want: want{
name: "ldap",
rootCA: []byte("certificate"),
startTls: false,
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
idAttribute: "id",
firstNameAttribute: "first",
lastNameAttribute: "last",
displayNameAttribute: "display",
nickNameAttribute: "nick",
preferredUsernameAttribute: "prefUser",
emailAttribute: "email",
emailVerifiedAttribute: "emailVerified",
phoneAttribute: "phone",
phoneVerifiedAttribute: "phoneVerified",
preferredLanguageAttribute: "prefLang",
avatarURLAttribute: "avatar",
profileAttribute: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider := New(
tt.fields.name,
tt.fields.servers,
tt.fields.baseDN,
tt.fields.bindDN,
tt.fields.bindPassword,
tt.fields.userBase,
tt.fields.userObjectClasses,
tt.fields.userFilters,
tt.fields.timeout,
tt.fields.rootCA,
tt.fields.loginUrl,
tt.fields.opts...,
)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.rootCA, provider.rootCA)
a.Equal(tt.want.startTls, provider.startTLS)
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.idAttribute, provider.idAttribute)
a.Equal(tt.want.firstNameAttribute, provider.firstNameAttribute)
a.Equal(tt.want.lastNameAttribute, provider.lastNameAttribute)
a.Equal(tt.want.displayNameAttribute, provider.displayNameAttribute)
a.Equal(tt.want.nickNameAttribute, provider.nickNameAttribute)
a.Equal(tt.want.preferredUsernameAttribute, provider.preferredUsernameAttribute)
a.Equal(tt.want.emailAttribute, provider.emailAttribute)
a.Equal(tt.want.emailVerifiedAttribute, provider.emailVerifiedAttribute)
a.Equal(tt.want.phoneAttribute, provider.phoneAttribute)
a.Equal(tt.want.phoneVerifiedAttribute, provider.phoneVerifiedAttribute)
a.Equal(tt.want.preferredLanguageAttribute, provider.preferredLanguageAttribute)
a.Equal(tt.want.avatarURLAttribute, provider.avatarURLAttribute)
a.Equal(tt.want.profileAttribute, provider.profileAttribute)
})
}
}

View File

@@ -0,0 +1,333 @@
package ldap
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"net"
"net/url"
"strconv"
"time"
"unicode/utf8"
"github.com/go-ldap/ldap/v3"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
var ErrFailedLogin = errors.New("user failed to login")
var ErrUnableToAppendRootCA = errors.New("unable to append rootCA")
var _ idp.Session = (*Session)(nil)
type Session struct {
Provider *Provider
loginUrl string
User string
Password string
Entry *ldap.Entry
}
func NewSession(provider *Provider, username, password string) *Session {
return &Session{Provider: provider, User: username, Password: password}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return idp.Redirect(s.loginUrl)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
var user *ldap.Entry
for _, server := range s.Provider.servers {
user, err = tryBind(server,
s.Provider.startTLS,
s.Provider.bindDN,
s.Provider.bindPassword,
s.Provider.baseDN,
s.Provider.getNecessaryAttributes(),
s.Provider.userObjectClasses,
s.Provider.userFilters,
s.User,
s.Password,
s.Provider.timeout,
s.Provider.rootCA)
// If there were invalid credentials or multiple users with the credentials cancel process
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
return nil, err
}
// If a user bind was successful and user is filled continue with login, otherwise try next server
if err == nil && user != nil {
break
}
}
if err != nil {
return nil, err
}
s.Entry = user
return mapLDAPEntryToUser(
user,
s.Provider.idAttribute,
s.Provider.firstNameAttribute,
s.Provider.lastNameAttribute,
s.Provider.displayNameAttribute,
s.Provider.nickNameAttribute,
s.Provider.preferredUsernameAttribute,
s.Provider.emailAttribute,
s.Provider.emailVerifiedAttribute,
s.Provider.phoneAttribute,
s.Provider.phoneVerifiedAttribute,
s.Provider.preferredLanguageAttribute,
s.Provider.avatarURLAttribute,
s.Provider.profileAttribute,
)
}
func (s *Session) ExpiresAt() time.Time {
return time.Time{} // falls back to the default expiration time
}
func tryBind(
server string,
startTLS bool,
bindDN string,
bindPassword string,
baseDN string,
attributes []string,
objectClasses []string,
userFilters []string,
username string,
password string,
timeout time.Duration,
rootCA []byte,
) (*ldap.Entry, error) {
conn, err := getConnection(server, startTLS, timeout, rootCA)
if err != nil {
return nil, err
}
defer conn.Close()
if err := conn.Bind(bindDN, bindPassword); err != nil {
return nil, err
}
return trySearchAndUserBind(
conn,
baseDN,
attributes,
objectClasses,
userFilters,
username,
password,
timeout,
)
}
func getConnection(
server string,
startTLS bool,
timeout time.Duration,
rootCA []byte,
) (*ldap.Conn, error) {
if timeout == 0 {
timeout = ldap.DefaultTimeout
}
dialer := make([]ldap.DialOpt, 1, 2)
dialer[0] = ldap.DialWithDialer(&net.Dialer{Timeout: timeout})
u, err := url.Parse(server)
if err != nil {
return nil, err
}
if u.Scheme == "ldaps" && len(rootCA) > 0 {
rootCAs := x509.NewCertPool()
if ok := rootCAs.AppendCertsFromPEM(rootCA); !ok {
return nil, ErrUnableToAppendRootCA
}
dialer = append(dialer, ldap.DialWithTLSConfig(&tls.Config{
RootCAs: rootCAs,
}))
}
conn, err := ldap.DialURL(server, dialer...)
if err != nil {
return nil, err
}
if u.Scheme == "ldap" && startTLS {
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
if err != nil {
return nil, err
}
}
return conn, nil
}
func trySearchAndUserBind(
conn *ldap.Conn,
baseDN string,
attributes []string,
objectClasses []string,
userFilters []string,
username string,
password string,
timeout time.Duration,
) (*ldap.Entry, error) {
searchQuery := queriesAndToSearchQuery(
objectClassesToSearchQuery(objectClasses),
queriesOrToSearchQuery(
userFiltersToSearchQuery(userFilters, username)...,
),
)
// Search for user with the unique attribute for the userDN
searchRequest := ldap.NewSearchRequest(
baseDN,
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
searchQuery,
attributes,
nil,
)
sr, err := conn.Search(searchRequest)
if err != nil {
return nil, err
}
if len(sr.Entries) != 1 {
logging.WithFields("entries", len(sr.Entries)).Info("ldap: no single user found")
return nil, ErrNoSingleUser
}
user := sr.Entries[0]
// Bind as the user to verify their password
userDN, err := ldap.ParseDN(user.DN)
if err != nil {
logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user parse DN failed")
return nil, err
}
if err = conn.Bind(userDN.String(), password); err != nil {
logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user bind failed")
return nil, ErrFailedLogin
}
return user, nil
}
func queriesAndToSearchQuery(queries ...string) string {
if len(queries) == 0 {
return ""
}
if len(queries) == 1 {
return queries[0]
}
joinQueries := "(&"
for _, s := range queries {
joinQueries += s
}
return joinQueries + ")"
}
func queriesOrToSearchQuery(queries ...string) string {
if len(queries) == 0 {
return ""
}
if len(queries) == 1 {
return queries[0]
}
joinQueries := "(|"
for _, s := range queries {
joinQueries += s
}
return joinQueries + ")"
}
func objectClassesToSearchQuery(classes []string) string {
searchQuery := ""
for _, class := range classes {
searchQuery += "(objectClass=" + class + ")"
}
return searchQuery
}
func userFiltersToSearchQuery(filters []string, username string) []string {
searchQueries := make([]string, len(filters))
for i, filter := range filters {
searchQueries[i] = "(" + filter + "=" + username + ")"
}
return searchQueries
}
func mapLDAPEntryToUser(
user *ldap.Entry,
idAttribute,
firstNameAttribute,
lastNameAttribute,
displayNameAttribute,
nickNameAttribute,
preferredUsernameAttribute,
emailAttribute,
emailVerifiedAttribute,
phoneAttribute,
phoneVerifiedAttribute,
preferredLanguageAttribute,
avatarURLAttribute,
profileAttribute string,
) (_ *User, err error) {
var emailVerified bool
if v := user.GetAttributeValue(emailVerifiedAttribute); v != "" {
emailVerified, err = strconv.ParseBool(v)
if err != nil {
return nil, err
}
}
var phoneVerified bool
if v := user.GetAttributeValue(phoneVerifiedAttribute); v != "" {
phoneVerified, err = strconv.ParseBool(v)
if err != nil {
return nil, err
}
}
return NewUser(
getAttributeValue(user, idAttribute),
getAttributeValue(user, firstNameAttribute),
getAttributeValue(user, lastNameAttribute),
getAttributeValue(user, displayNameAttribute),
getAttributeValue(user, nickNameAttribute),
getAttributeValue(user, preferredUsernameAttribute),
domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
emailVerified,
domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
phoneVerified,
language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
user.GetAttributeValue(avatarURLAttribute),
user.GetAttributeValue(profileAttribute),
), nil
}
func getAttributeValue(user *ldap.Entry, attribute string) string {
// return an empty string if no attribute is needed
if attribute == "" {
return ""
}
value := user.GetAttributeValue(attribute)
if utf8.ValidString(value) {
return value
}
return base64.StdEncoding.EncodeToString(user.GetRawAttributeValue(attribute))
}

View File

@@ -0,0 +1,400 @@
package ldap
import (
"testing"
"github.com/go-ldap/ldap/v3"
"github.com/stretchr/testify/assert"
"golang.org/x/text/language"
)
func TestProvider_objectClassesToSearchQuery(t *testing.T) {
tests := []struct {
name string
fields []string
want string
}{
{
name: "zero",
fields: []string{},
want: "",
},
{
name: "one",
fields: []string{"test"},
want: "(objectClass=test)",
},
{
name: "three",
fields: []string{"test1", "test2", "test3"},
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)",
},
{
name: "five",
fields: []string{"test1", "test2", "test3", "test4", "test5"},
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)(objectClass=test4)(objectClass=test5)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
a.Equal(tt.want, objectClassesToSearchQuery(tt.fields))
})
}
}
func TestProvider_userFiltersToSearchQuery(t *testing.T) {
tests := []struct {
name string
fields []string
username string
want []string
}{
{
name: "zero",
fields: []string{},
username: "user",
want: []string{},
},
{
name: "one",
fields: []string{"test"},
username: "user",
want: []string{"(test=user)"},
},
{
name: "three",
fields: []string{"test1", "test2", "test3"},
username: "user",
want: []string{"(test1=user)", "(test2=user)", "(test3=user)"},
},
{
name: "five",
fields: []string{"test1", "test2", "test3", "test4", "test5"},
username: "user",
want: []string{"(test1=user)", "(test2=user)", "(test3=user)", "(test4=user)", "(test5=user)"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
a.Equal(tt.want, userFiltersToSearchQuery(tt.fields, tt.username))
})
}
}
func TestProvider_queriesAndToSearchQuery(t *testing.T) {
tests := []struct {
name string
fields []string
want string
}{
{
name: "zero",
fields: []string{},
want: "",
},
{
name: "one",
fields: []string{"(test)"},
want: "(test)",
},
{
name: "three",
fields: []string{"(test1)", "(test2)", "(test3)"},
want: "(&(test1)(test2)(test3))",
},
{
name: "five",
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
want: "(&(test1)(test2)(test3)(test4)(test5))",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
a.Equal(tt.want, queriesAndToSearchQuery(tt.fields...))
})
}
}
func TestProvider_queriesOrToSearchQuery(t *testing.T) {
tests := []struct {
name string
fields []string
want string
}{
{
name: "zero",
fields: []string{},
want: "",
},
{
name: "one",
fields: []string{"(test)"},
want: "(test)",
},
{
name: "three",
fields: []string{"(test1)", "(test2)", "(test3)"},
want: "(|(test1)(test2)(test3))",
},
{
name: "five",
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
want: "(|(test1)(test2)(test3)(test4)(test5))",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
a.Equal(tt.want, queriesOrToSearchQuery(tt.fields...))
})
}
}
func TestProvider_mapLDAPEntryToUser(t *testing.T) {
type fields struct {
user *ldap.Entry
idAttribute string
firstNameAttribute string
lastNameAttribute string
displayNameAttribute string
nickNameAttribute string
preferredUsernameAttribute string
emailAttribute string
emailVerifiedAttribute string
phoneAttribute string
phoneVerifiedAttribute string
preferredLanguageAttribute string
avatarURLAttribute string
profileAttribute string
}
type want struct {
user *User
err func(error) bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "empty",
fields: fields{
user: &ldap.Entry{
Attributes: []*ldap.EntryAttribute{
{Name: "id", Values: []string{"id"}},
{Name: "first", Values: []string{"first"}},
{Name: "last", Values: []string{"last"}},
{Name: "display", Values: []string{"display"}},
{Name: "nick", Values: []string{"nick"}},
{Name: "preferred", Values: []string{"preferred"}},
{Name: "email", Values: []string{"email"}},
{Name: "emailVerified", Values: []string{"false"}},
{Name: "phone", Values: []string{"phone"}},
{Name: "phoneVerified", Values: []string{"false"}},
{Name: "lang", Values: []string{"und"}},
{Name: "avatar", Values: []string{"avatar"}},
{Name: "profile", Values: []string{"profile"}},
},
},
idAttribute: "",
firstNameAttribute: "",
lastNameAttribute: "",
displayNameAttribute: "",
nickNameAttribute: "",
preferredUsernameAttribute: "",
emailAttribute: "",
emailVerifiedAttribute: "",
phoneAttribute: "",
phoneVerifiedAttribute: "",
preferredLanguageAttribute: "",
avatarURLAttribute: "",
profileAttribute: "",
},
want: want{
user: &User{
ID: "",
FirstName: "",
LastName: "",
DisplayName: "",
NickName: "",
PreferredUsername: "",
Email: "",
EmailVerified: false,
Phone: "",
PhoneVerified: false,
PreferredLanguage: language.Tag{},
AvatarURL: "",
Profile: "",
},
},
},
{
name: "failed parse emailVerified",
fields: fields{
user: &ldap.Entry{
Attributes: []*ldap.EntryAttribute{
{Name: "id", Values: []string{"id"}},
{Name: "first", Values: []string{"first"}},
{Name: "last", Values: []string{"last"}},
{Name: "display", Values: []string{"display"}},
{Name: "nick", Values: []string{"nick"}},
{Name: "preferred", Values: []string{"preferred"}},
{Name: "email", Values: []string{"email"}},
{Name: "emailVerified", Values: []string{"failure"}},
{Name: "phone", Values: []string{"phone"}},
{Name: "phoneVerified", Values: []string{"false"}},
{Name: "lang", Values: []string{"und"}},
{Name: "avatar", Values: []string{"avatar"}},
{Name: "profile", Values: []string{"profile"}},
},
},
idAttribute: "id",
firstNameAttribute: "first",
lastNameAttribute: "last",
displayNameAttribute: "display",
nickNameAttribute: "nick",
preferredUsernameAttribute: "preferred",
emailAttribute: "email",
emailVerifiedAttribute: "emailVerified",
phoneAttribute: "phone",
phoneVerifiedAttribute: "phoneVerified",
preferredLanguageAttribute: "lang",
avatarURLAttribute: "avatar",
profileAttribute: "profile",
},
want: want{
err: func(err error) bool {
return err != nil
},
},
},
{
name: "failed parse phoneVerified",
fields: fields{
user: &ldap.Entry{
Attributes: []*ldap.EntryAttribute{
{Name: "id", Values: []string{"id"}},
{Name: "first", Values: []string{"first"}},
{Name: "last", Values: []string{"last"}},
{Name: "display", Values: []string{"display"}},
{Name: "nick", Values: []string{"nick"}},
{Name: "preferred", Values: []string{"preferred"}},
{Name: "email", Values: []string{"email"}},
{Name: "emailVerified", Values: []string{"false"}},
{Name: "phone", Values: []string{"phone"}},
{Name: "phoneVerified", Values: []string{"failure"}},
{Name: "lang", Values: []string{"und"}},
{Name: "avatar", Values: []string{"avatar"}},
{Name: "profile", Values: []string{"profile"}},
},
},
idAttribute: "id",
firstNameAttribute: "first",
lastNameAttribute: "last",
displayNameAttribute: "display",
nickNameAttribute: "nick",
preferredUsernameAttribute: "preferred",
emailAttribute: "email",
emailVerifiedAttribute: "emailVerified",
phoneAttribute: "phone",
phoneVerifiedAttribute: "phoneVerified",
preferredLanguageAttribute: "lang",
avatarURLAttribute: "avatar",
profileAttribute: "profile",
},
want: want{
err: func(err error) bool {
return err != nil
},
},
},
{
name: "full user",
fields: fields{
user: &ldap.Entry{
Attributes: []*ldap.EntryAttribute{
{Name: "id", Values: []string{"id"}},
{Name: "first", Values: []string{"first"}},
{Name: "last", Values: []string{"last"}},
{Name: "display", Values: []string{"display"}},
{Name: "nick", Values: []string{"nick"}},
{Name: "preferred", Values: []string{"preferred"}},
{Name: "email", Values: []string{"email"}},
{Name: "emailVerified", Values: []string{"false"}},
{Name: "phone", Values: []string{"phone"}},
{Name: "phoneVerified", Values: []string{"false"}},
{Name: "lang", Values: []string{"und"}},
{Name: "avatar", Values: []string{"avatar"}},
{Name: "profile", Values: []string{"profile"}},
},
},
idAttribute: "id",
firstNameAttribute: "first",
lastNameAttribute: "last",
displayNameAttribute: "display",
nickNameAttribute: "nick",
preferredUsernameAttribute: "preferred",
emailAttribute: "email",
emailVerifiedAttribute: "emailVerified",
phoneAttribute: "phone",
phoneVerifiedAttribute: "phoneVerified",
preferredLanguageAttribute: "lang",
avatarURLAttribute: "avatar",
profileAttribute: "profile",
},
want: want{
user: &User{
ID: "id",
FirstName: "first",
LastName: "last",
DisplayName: "display",
NickName: "nick",
PreferredUsername: "preferred",
Email: "email",
EmailVerified: false,
Phone: "phone",
PhoneVerified: false,
PreferredLanguage: language.Make("und"),
AvatarURL: "avatar",
Profile: "profile",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := mapLDAPEntryToUser(
tt.fields.user,
tt.fields.idAttribute,
tt.fields.firstNameAttribute,
tt.fields.lastNameAttribute,
tt.fields.displayNameAttribute,
tt.fields.nickNameAttribute,
tt.fields.preferredUsernameAttribute,
tt.fields.emailAttribute,
tt.fields.emailVerifiedAttribute,
tt.fields.phoneAttribute,
tt.fields.phoneVerifiedAttribute,
tt.fields.preferredLanguageAttribute,
tt.fields.avatarURLAttribute,
tt.fields.profileAttribute,
)
if tt.want.err == nil {
assert.NoError(t, err)
}
if tt.want.err != nil && !tt.want.err(err) {
t.Errorf("got wrong err: %v ", err)
}
if tt.want.err == nil {
assert.Equal(t, tt.want.user, got)
}
})
}
}

View File

@@ -0,0 +1,95 @@
package ldap
import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
)
type User struct {
ID string `json:"id,omitempty"`
FirstName string `json:"firstName,omitempty"`
LastName string `json:"lastName,omitempty"`
DisplayName string `json:"displayName,omitempty"`
NickName string `json:"nickName,omitempty"`
PreferredUsername string `json:"preferredUsername,omitempty"`
Email domain.EmailAddress `json:"email,omitempty"`
EmailVerified bool `json:"emailVerified,omitempty"`
Phone domain.PhoneNumber `json:"phone,omitempty"`
PhoneVerified bool `json:"phoneVerified,omitempty"`
PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"`
AvatarURL string `json:"avatarURL,omitempty"`
Profile string `json:"profile,omitempty"`
}
func NewUser(
id string,
firstName string,
lastName string,
displayName string,
nickName string,
preferredUsername string,
email domain.EmailAddress,
emailVerified bool,
phone domain.PhoneNumber,
phoneVerified bool,
preferredLanguage language.Tag,
avatarURL string,
profile string,
) *User {
return &User{
id,
firstName,
lastName,
displayName,
nickName,
preferredUsername,
email,
emailVerified,
phone,
phoneVerified,
preferredLanguage,
avatarURL,
profile,
}
}
func (u *User) GetID() string {
return u.ID
}
func (u *User) GetFirstName() string {
return u.FirstName
}
func (u *User) GetLastName() string {
return u.LastName
}
func (u *User) GetDisplayName() string {
return u.DisplayName
}
func (u *User) GetNickname() string {
return u.NickName
}
func (u *User) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *User) GetEmail() domain.EmailAddress {
return u.Email
}
func (u *User) IsEmailVerified() bool {
return u.EmailVerified
}
func (u *User) GetPhone() domain.PhoneNumber {
return u.Phone
}
func (u *User) IsPhoneVerified() bool {
return u.PhoneVerified
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.PreferredLanguage
}
func (u *User) GetAvatarURL() string {
return u.AvatarURL
}
func (u *User) GetProfile() string {
return u.Profile
}

View File

@@ -0,0 +1,110 @@
package oauth
import (
"encoding/json"
"fmt"
"strconv"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.User = (*UserMapper)(nil)
// UserMapper is an implementation of [idp.User].
// It can be used in ZITADEL actions to map the `RawInfo`
type UserMapper struct {
idAttribute string
RawInfo map[string]interface{}
}
func NewUserMapper(idAttribute string) *UserMapper {
return &UserMapper{
idAttribute: idAttribute,
RawInfo: make(map[string]interface{}),
}
}
func (u *UserMapper) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &u.RawInfo)
}
// GetID is an implementation of the [idp.User] interface.
func (u *UserMapper) GetID() string {
id, ok := u.RawInfo[u.idAttribute]
if !ok {
return ""
}
switch i := id.(type) {
case string:
return i
case int:
return strconv.Itoa(i)
case float64:
return strconv.FormatFloat(i, 'f', -1, 64)
default:
return fmt.Sprint(i)
}
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetFirstName() string {
return ""
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetLastName() string {
return ""
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetDisplayName() string {
return ""
}
// GetNickname is an implementation of the [idp.User] interface.
func (u *UserMapper) GetNickname() string {
return ""
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredUsername() string {
return ""
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *UserMapper) GetEmail() domain.EmailAddress {
return ""
}
// IsEmailVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsEmailVerified() bool {
return false
}
// GetPhone is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPhone() domain.PhoneNumber {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredLanguage() language.Tag {
return language.Und
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *UserMapper) GetAvatarURL() string {
return ""
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *UserMapper) GetProfile() string {
return ""
}

View File

@@ -0,0 +1,143 @@
package oauth
import (
"context"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic OAuth 2.0 provider
type Provider struct {
rp.RelyingParty
options []rp.Option
name string
userEndpoint string
user func() idp.User
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
generateVerifier func() string
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
return func(p *Provider) {
p.options = append(p.options, option)
}
}
// New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userEndpoint: userEndpoint,
user: user,
generateVerifier: oauth2.GenerateVerifier,
}
for _, option := range options {
option(provider)
}
provider.RelyingParty, err = rp.NewRelyingPartyOAuth(config, provider.options...)
if err != nil {
return nil, err
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params {
if username, ok := param.(idp.LoginHintParam); ok {
loginHintSet = true
opts = append(opts, loginHint(string(username)))
}
}
if !loginHintSet {
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
}
var codeVerifier string
if p.RelyingParty.IsPKCE() {
codeVerifier = p.generateVerifier()
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
}
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}
func (p *Provider) User() idp.User {
return p.user()
}

View File

@@ -0,0 +1,184 @@
package oauth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
options []ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth without PKCE",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
{
name: "successful auth with PKCE",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
r.NoError(err)
provider.generateVerifier = func() string {
return "pkceOAuthVerifier"
}
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
options []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: nil,
},
want: want{
name: "oauth",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
},
want: want{
name: "oauth",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,91 @@
package oauth
import (
"context"
"errors"
"net/http"
"time"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
var ErrCodeMissing = errors.New("no auth code provided")
const (
CodeVerifier = "codeVerifier"
)
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OAuth2.0 provider.
type Session struct {
AuthURL string
CodeVerifier string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
Provider *Provider
}
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
verifier, _ := idpArguments[CodeVerifier].(string)
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
if s.CodeVerifier == "" {
return nil
}
return map[string]any{CodeVerifier: s.CodeVerifier}
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
if s.Tokens == nil {
if err = s.authorize(ctx); err != nil {
return nil, err
}
}
req, err := http.NewRequest("GET", s.Provider.userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
user := s.Provider.User()
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
return nil, err
}
return user, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.Tokens == nil {
return time.Time{}
}
return s.Tokens.Expiry
}
func (s *Session) authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
var opts []rp.CodeExchangeOpt
if s.CodeVerifier != "" {
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
return err
}

View File

@@ -0,0 +1,274 @@
package oauth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_FetchUser(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
httpMock func(issuer string)
userMapper func() idp.User
authURL string
code string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(http.StatusInternalServerError)
},
userMapper: func() idp.User {
return NewUserMapper("userID")
},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return NewUserMapper("userID")
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
user: &UserMapper{
idAttribute: "userID",
RawInfo: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "id",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
{
name: "successful fetch with code exchange",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Post("/token").
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
Reply(200).
JSON(&oidc.AccessTokenResponse{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
RefreshToken: "",
ExpiresIn: 3600,
IDToken: "",
State: "testState"})
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return NewUserMapper("userID")
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
code: "code",
},
want: want{
user: &UserMapper{
idAttribute: "userID",
RawInfo: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "id",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock("https://oauth2.com")
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
require.NoError(t, err)
session := &Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}

View File

@@ -0,0 +1,190 @@
package oidc
import (
"context"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic OIDC provider
type Provider struct {
rp.RelyingParty
options []rp.Option
name string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
useIDToken bool
userInfoMapper func(info *oidc.UserInfo) idp.User
authOptions []func(bool) rp.AuthURLOpt
generateVerifier func() string
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithIDTokenMapping enables that information to map the user is retrieved from the id_token and not the userinfo endpoint.
func WithIDTokenMapping() ProviderOpts {
return func(p *Provider) {
p.useIDToken = true
}
}
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
return func(p *Provider) {
p.options = append(p.options, option)
}
}
// WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
func WithSelectAccount() ProviderOpts {
return func(p *Provider) {
p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
if loginHintSet {
return nil
}
return rp.WithPrompt(oidc.PromptSelectAccount)
})
}
}
// WithResponseMode sets the `response_mode` params in the auth request
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
return func(p *Provider) {
paramOpt := rp.WithResponseModeURLParam(mode)
p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
return rp.AuthURLOpt(paramOpt)
})
}
}
type UserInfoMapper func(info *oidc.UserInfo) idp.User
var DefaultMapper UserInfoMapper = func(info *oidc.UserInfo) idp.User {
return NewUser(info)
}
// New creates a generic OIDC provider
func New(name, issuer, clientID, clientSecret, redirectURI string, scopes []string, userInfoMapper UserInfoMapper, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userInfoMapper: userInfoMapper,
generateVerifier: oauth2.GenerateVerifier,
}
for _, option := range options {
option(provider)
}
provider.RelyingParty, err = rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, setDefaultScope(scopes), provider.options...)
if err != nil {
return nil, err
}
return provider, nil
}
// setDefaultScope ensures that at least openid ist set
// if none is provided it will request `openid profile email phone`
func setDefaultScope(scopes []string) []string {
if len(scopes) == 0 {
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone}
}
for _, scope := range scopes {
if scope == oidc.ScopeOpenID {
return scopes
}
}
return append(scopes, oidc.ScopeOpenID)
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OIDC authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
opts := make([]rp.AuthURLOpt, 0)
var loginHintSet bool
for _, param := range params {
if username, ok := param.(idp.LoginHintParam); ok {
loginHintSet = true
opts = append(opts, loginHint(string(username)))
}
}
for _, option := range p.authOptions {
if opt := option(loginHintSet); opt != nil {
opts = append(opts, opt)
}
}
var codeVerifier string
if p.RelyingParty.IsPKCE() {
codeVerifier = p.generateVerifier()
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
}
func loginHint(hint string) rp.AuthURLOpt {
return func() []oauth2.AuthCodeOption {
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
}
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,221 @@
package oidc
import (
"context"
"testing"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
scopes []string
userMapper func(info *oidc.UserInfo) idp.User
httpMock func(issuer string)
opts []ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth without PKCE",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
opts: []ProviderOpts{WithSelectAccount()},
},
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
},
{
name: "successful auth with PKCE",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
opts: []ProviderOpts{WithSelectAccount(), WithRelyingPartyOption(rp.WithPKCE(nil))},
},
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.fields.opts...)
r.NoError(err)
provider.generateVerifier = func() string {
return "pkceOAuthVerifier"
}
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
scopes []string
userMapper func(info *oidc.UserInfo) idp.User
opts []ProviderOpts
httpMock func(issuer string)
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
opts: nil,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
},
want: want{
name: "oidc",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
opts: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
},
want: want{
name: "oidc",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.fields.opts...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
})
}
}

View File

@@ -0,0 +1,157 @@
package oidc
import (
"context"
"errors"
"time"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
var ErrCodeMissing = errors.New("no auth code provided")
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OIDC provider.
type Session struct {
Provider *Provider
AuthURL string
CodeVerifier string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
verifier, _ := idpArguments[oauth.CodeVerifier].(string)
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
if s.CodeVerifier == "" {
return nil
}
return map[string]any{oauth.CodeVerifier: s.CodeVerifier}
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OIDC code exchange if needed to retrieve the tokens,
// call the userinfo endpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
if err = s.Authorize(ctx); err != nil {
return nil, err
}
}
var info *oidc.UserInfo
if s.Provider.useIDToken {
info = s.Tokens.IDTokenClaims.GetUserInfo()
} else {
info, err = rp.Userinfo[*oidc.UserInfo](ctx,
s.Tokens.AccessToken,
s.Tokens.TokenType,
s.Tokens.IDTokenClaims.GetSubject(),
s.Provider.RelyingParty,
)
if err != nil {
return nil, err
}
}
u := s.Provider.userInfoMapper(info)
return u, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.Tokens == nil {
return time.Time{}
}
return s.Tokens.Expiry
}
func (s *Session) Authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
var opts []rp.CodeExchangeOpt
if s.CodeVerifier != "" {
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
return err
}
func NewUser(info *oidc.UserInfo) *User {
return &User{UserInfo: info}
}
func InitUser() *User {
return &User{UserInfo: &oidc.UserInfo{}}
}
type User struct {
*oidc.UserInfo
}
func (u *User) GetID() string {
return u.Subject
}
func (u *User) GetFirstName() string {
return u.GivenName
}
func (u *User) GetLastName() string {
return u.FamilyName
}
func (u *User) GetDisplayName() string {
return u.Name
}
func (u *User) GetNickname() string {
return u.Nickname
}
func (u *User) GetPreferredUsername() string {
return u.PreferredUsername
}
func (u *User) GetEmail() domain.EmailAddress {
return domain.EmailAddress(u.UserInfo.Email)
}
func (u *User) IsEmailVerified() bool {
return bool(u.EmailVerified)
}
func (u *User) GetPhone() domain.PhoneNumber {
return domain.PhoneNumber(u.PhoneNumber)
}
func (u *User) IsPhoneVerified() bool {
return u.PhoneNumberVerified
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.Locale.Tag()
}
func (u *User) GetAvatarURL() string {
return u.Picture
}
func (u *User) GetProfile() string {
return u.Profile
}

View File

@@ -0,0 +1,471 @@
package oidc
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
scopes []string
userMapper func(*oidc.UserInfo) idp.User
httpMock func(issuer string)
authURL string
code string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
opts []ProviderOpts
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://issuer.com",
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://issuer.com",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
{
name: "use ID token",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: func() *oidc.IDTokenClaims {
claims := oidc.NewIDTokenClaims(
"https://issuer.com",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
)
claims.SetUserInfo(userinfo())
return claims
}(),
},
},
opts: []ProviderOpts{
WithIDTokenMapping(),
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
{
name: "successful fetch with token exchange",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
JwksURI: issuer + "/keys",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Post("/token").
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
Reply(200).
JSON(tokenResponse(t, issuer))
gock.New(issuer).
Get("/keys").
Reply(200).
JSON(keys(t))
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
code: "code",
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.opts...)
require.NoError(t, err)
session := &Session{
Provider: provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
require.NoError(t, err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *oidc.UserInfo {
return &oidc.UserInfo{
Subject: "sub",
UserInfoProfile: oidc.UserInfoProfile{
GivenName: "firstname",
FamilyName: "lastname",
Name: "firstname lastname",
Nickname: "nickname",
PreferredUsername: "username",
Locale: oidc.NewLocale(language.English),
Picture: "picture",
Profile: "profile",
},
UserInfoEmail: oidc.UserInfoEmail{
Email: "email",
EmailVerified: oidc.Bool(true),
},
UserInfoPhone: oidc.UserInfoPhone{
PhoneNumber: "phone",
PhoneNumberVerified: true,
},
}
}
func tokenResponse(t *testing.T, issuer string) *oidc.AccessTokenResponse {
claims := oidc.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Minute),
"",
"",
nil,
"clientID",
0,
)
privateKey, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAs38btwb3c7r0tMaQpGvBmY+mPwMU/LpfuPoC0k2t4RsKp0fv
40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuyrFALIj3Ff1UcKIk0hOH5DDsfh7/q
2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/MfSydZdcmIqlkUpfQmtzExw9+tSe5
Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNuMZbmlCoBru+rC8ITlTX/0V1ZcsSb
L8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a+kjL/KGZbR14Ua2eo6tonBZLC5DH
WM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly6QIDAQABAoIBAQCPj1nbSPcg2KZe
73FAD+8HopyUSSK//1AP4eXfzcEECVy77g0u9+R6XlkzsZCsZ4g6NN8ounqfyw3c
YlpAIkcFCf/dowoSjT+4LASVQyatYZwWNqjgAIU4KgMG/rKnNahPTiBYe7peMB1j
EaPjnt8uPkCk8y7NCi3y4Pk24tt/WM5KbJK2NQhUi1csGnleDfE+0blV0l/e6C68
W5cbnbWAroMqae/Yon3XVZiXX0m+l2f6ZzIgKaD18J+eEM8FjJC+jQKiRe1i9v3K
nQrLwh/gn8J10FcbKn3xqslKVidzASIrNIzHT9j/Z5T9NXuAKa7IV2x+Dtdus+wq
iBsUunwBAoGBANpYew+8i9vDwK4/SefduDTuzJ0H9lWTjtbiWQ+KYZoeJ7q3/qns
jsmi+mjxkXxXg1RrGbNbjtbl3RXXIrUeeBB0lglRJUjc3VK7VvNoyXIWsiqhCspH
IJ9Yuknv4mXB01m/glbSCS/xu4RTgf5aOG4jUiRb9+dCIpvDxI9gbXEVAoGBANJz
hIJkplIJ+biTi3G1Oz17qkUkInNXzAEzKD9Atoz5AIAiR1ivOMLOlbucfjevw/Nw
TnpkMs9xqCefKupTlsriXtZI88m7ZKzAmolYsPolOy/Jhi31h9JFVTEfKGqVS+dk
A4ndhgdW9RUeNJPY2YVCARXQrWpueweQDA1cNaeFAoGAPJsYtXqBW6PPRM5+ZiSt
78tk8iV2o7RMjqrPS7f+dXfvUS2nO2VVEPTzCtQarOfhpToBLT65vD6bimdn09w8
OV0TFEz4y2u65y7m6LNqTwertpdy1ki97l0DgGhccCBH2P6GYDD2qd8wTH+dcot6
ZF/begopGoDJ+HBzi9SZLC0CgYBZzPslHMevyBvr++GLwrallKhiWnns1/DwLiEl
ZHrBCtuA0Z+6IwLIdZiE9tEQ+ApYTXrfVPQteqUzSwLn/IUiy5eGPpjwYushoAoR
Q2w5QTvRN1/vKo8rVXR1woLfgBdkhFPSN1mitiNcQIhU8jpXV4PZCDOHb99FqdzK
sqcedQKBgQCOmgbqxGsnT2WQhoOdzln+NOo6Tx+FveLLqat2KzpY59W4noeI2Awn
HfIQgWUAW9dsjVVOXMP1jhq8U9hmH/PFWA11V/iCdk1NTxZEw87VAOeWuajpdDHG
+iex349j8h2BcQ4Zd0FWu07gGFnS/yuDJPn6jBhRusdieEcxLRjTKg==
-----END RSA PRIVATE KEY-----
`))
if err != nil {
t.Fatal(err)
}
signer, err := jose.NewSigner(jose.SigningKey{Key: privateKey, Algorithm: "RS256"}, &jose.SignerOptions{})
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(data)
if err != nil {
t.Fatal(err)
}
idToken, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return &oidc.AccessTokenResponse{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
RefreshToken: "",
ExpiresIn: 3600,
IDToken: idToken,
State: "testState",
}
}
func keys(t *testing.T) *jose.JSONWebKeySet {
privateKey, err := crypto.BytesToPublicKey([]byte(`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
6QIDAQAB
-----END PUBLIC KEY-----
`))
if err != nil {
t.Fatal(err)
}
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{Key: privateKey, Algorithm: "RS256", Use: oidc.KeyUseSignature}}}
}

View File

@@ -0,0 +1,89 @@
package saml
import (
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.User = (*UserMapper)(nil)
// UserMapper is an implementation of [idp.User].
type UserMapper struct {
ID string `json:"id,omitempty"`
Attributes map[string][]string `json:"attributes,omitempty"`
}
func NewUser() *UserMapper {
return &UserMapper{Attributes: map[string][]string{}}
}
func (u *UserMapper) SetID(id string) {
u.ID = id
}
// GetID is an implementation of the [idp.User] interface.
func (u *UserMapper) GetID() string {
return u.ID
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetFirstName() string {
return ""
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetLastName() string {
return ""
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetDisplayName() string {
return ""
}
// GetNickname is an implementation of the [idp.User] interface.
func (u *UserMapper) GetNickname() string {
return ""
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredUsername() string {
return ""
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *UserMapper) GetEmail() domain.EmailAddress {
return ""
}
// IsEmailVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsEmailVerified() bool {
return false
}
// GetPhone is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPhone() domain.PhoneNumber {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredLanguage() language.Tag {
return language.Und
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *UserMapper) GetAvatarURL() string {
return ""
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *UserMapper) GetProfile() string {
return ""
}

View File

@@ -0,0 +1,58 @@
package requesttracker
import (
"context"
"net/http"
"github.com/crewjam/saml/samlsp"
)
type GetRequest func(ctx context.Context, intentID string) (*samlsp.TrackedRequest, error)
type AddRequest func(ctx context.Context, intentID, requestID string) error
type RequestTracker struct {
addRequest AddRequest
getRequest GetRequest
}
func New(addRequestF AddRequest, getRequestF GetRequest) samlsp.RequestTracker {
return &RequestTracker{
addRequest: addRequestF,
getRequest: getRequestF,
}
}
func (rt *RequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (index string, err error) {
// intentID is stored in r.URL
intentID := r.URL.String()
if err := rt.addRequest(r.Context(), intentID, samlRequestID); err != nil {
return "", err
}
return intentID, nil
}
func (rt *RequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
// error is not handled in SP logic
return nil
}
func (rt *RequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
// RelayState is the context of the auth flow and as such contains the intentID
intentID := r.FormValue("RelayState")
request, err := rt.getRequest(r.Context(), intentID)
if err != nil {
return nil
}
return []samlsp.TrackedRequest{
{
Index: request.Index,
SAMLRequestID: request.SAMLRequestID,
URI: request.URI,
},
}
}
func (rt *RequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
return rt.getRequest(r.Context(), index)
}

View File

@@ -0,0 +1,278 @@
package saml
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/xml"
"io"
"net/url"
"time"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"golang.org/x/text/encoding/ianaindex"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic SAML provider
type Provider struct {
name string
requestTracker samlsp.RequestTracker
Certificate []byte
spOptions *samlsp.Options
binding string
nameIDFormat saml.NameIDFormat
transientMappingAttributeName string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
func WithSignedRequest() ProviderOpts {
return func(p *Provider) {
p.spOptions.SignRequest = true
}
}
func WithBinding(binding string) ProviderOpts {
return func(p *Provider) {
p.binding = binding
}
}
func WithNameIDFormat(format domain.SAMLNameIDFormat) ProviderOpts {
return func(p *Provider) {
p.nameIDFormat = nameIDFormatFromDomain(format)
}
}
func WithTransientMappingAttributeName(attribute string) ProviderOpts {
return func(p *Provider) {
p.transientMappingAttributeName = attribute
}
}
func WithCustomRequestTracker(tracker samlsp.RequestTracker) ProviderOpts {
return func(p *Provider) {
p.requestTracker = tracker
}
}
func WithEntityID(entityID string) ProviderOpts {
return func(p *Provider) {
p.spOptions.EntityID = entityID
}
}
// ParseMetadata parses the metadata with the provided XML encoding and returns the EntityDescriptor
func ParseMetadata(metadata []byte) (*saml.EntityDescriptor, error) {
entityDescriptor := new(saml.EntityDescriptor)
reader := bytes.NewReader(metadata)
decoder := xml.NewDecoder(reader)
decoder.CharsetReader = func(charset string, reader io.Reader) (io.Reader, error) {
enc, err := ianaindex.IANA.Encoding(charset)
if err != nil {
return nil, err
}
return enc.NewDecoder().Reader(reader), nil
}
if err := decoder.Decode(entityDescriptor); err != nil {
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
// reset reader to start of metadata so we can try to parse it as an EntitiesDescriptor
if _, err := reader.Seek(0, io.SeekStart); err != nil {
return nil, err
}
entities := &EntitiesDescriptor{}
if err := decoder.Decode(entities); err != nil {
return nil, err
}
for i, e := range entities.EntityDescriptors {
if len(e.IDPSSODescriptors) > 0 {
return &entities.EntityDescriptors[i], nil
}
}
return nil, zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor")
}
return nil, err
}
return entityDescriptor, nil
}
func New(
name string,
rootURLStr string,
metadata []byte,
certificate []byte,
key []byte,
options ...ProviderOpts,
) (*Provider, error) {
entityDescriptor, err := ParseMetadata(metadata)
if err != nil {
return nil, err
}
keyPair, err := tls.X509KeyPair(certificate, key)
if err != nil {
return nil, err
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return nil, err
}
rootURL, err := url.Parse(rootURLStr)
if err != nil {
return nil, err
}
opts := samlsp.Options{
URL: *rootURL,
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
Certificate: keyPair.Leaf,
IDPMetadata: entityDescriptor,
SignRequest: false,
}
provider := &Provider{
name: name,
spOptions: &opts,
Certificate: certificate,
// the library uses transient as default, which does not make sense for federating accounts
nameIDFormat: saml.PersistentNameIDFormat,
}
for _, option := range options {
option(provider)
}
return provider, nil
}
func (p *Provider) Name() string {
return p.name
}
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}
func (p *Provider) GetSP() (*samlsp.Middleware, error) {
sp, err := samlsp.New(*p.spOptions)
if err != nil {
return nil, zerrors.ThrowInternal(err, "SAML-qee09ffuq5", "Errors.Intent.IDPInvalid")
}
sp.ServiceProvider.AuthnNameIDFormat = p.nameIDFormat
if p.requestTracker != nil {
sp.RequestTracker = p.requestTracker
}
if p.binding != "" {
sp.Binding = p.binding
}
sp.ServiceProvider.MetadataValidDuration = time.Until(sp.ServiceProvider.Certificate.NotAfter)
return sp, nil
}
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
m, err := p.GetSP()
if err != nil {
return nil, err
}
return &Session{
ServiceProvider: m,
state: state,
}, nil
}
func (p *Provider) TransientMappingAttributeName() string {
return p.transientMappingAttributeName
}
func nameIDFormatFromDomain(format domain.SAMLNameIDFormat) saml.NameIDFormat {
switch format {
case domain.SAMLNameIDFormatUnspecified:
return saml.UnspecifiedNameIDFormat
case domain.SAMLNameIDFormatEmailAddress:
return saml.EmailAddressNameIDFormat
case domain.SAMLNameIDFormatPersistent:
return saml.PersistentNameIDFormat
case domain.SAMLNameIDFormatTransient:
return saml.TransientNameIDFormat
default:
return saml.UnspecifiedNameIDFormat
}
}
// EntitiesDescriptor is a workaround until we eventually fork the crewjam/saml library, since maintenance on that repo seems to have stopped.
// This is to be able to handle xsd:duration format using the UnmarshalXML method.
// crewjam/saml only implements the xsd:dateTime format for EntityDescriptor, but not EntitiesDescriptor.
type EntitiesDescriptor saml.EntitiesDescriptor
// UnmarshalXML implements xml.Unmarshaler
func (m *EntitiesDescriptor) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
type Alias EntitiesDescriptor
aux := &struct {
ValidUntil *saml.RelaxedTime `xml:"validUntil,attr,omitempty"`
CacheDuration *saml.Duration `xml:"cacheDuration,attr,omitempty"`
*Alias
}{
Alias: (*Alias)(m),
}
if err := d.DecodeElement(aux, &start); err != nil {
return err
}
m.ValidUntil = (*time.Time)(aux.ValidUntil)
m.CacheDuration = (*time.Duration)(aux.CacheDuration)
return nil
}

View File

@@ -0,0 +1,438 @@
package saml
import (
"context"
"encoding/xml"
"net/url"
"testing"
"time"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestProvider_BeginAuth(t *testing.T) {
requestTracker := requesttracker.New(
func(ctx context.Context, authRequestID, samlRequestID string) error {
assert.Equal(t, "state", authRequestID)
return nil
},
func(ctx context.Context, authRequestID string) (*samlsp.TrackedRequest, error) {
return &samlsp.TrackedRequest{
SAMLRequestID: "state",
Index: authRequestID,
}, nil
},
)
type fields struct {
name string
rootURL string
metadata []byte
certificate []byte
key []byte
options []ProviderOpts
}
type args struct {
state string
}
type want struct {
err func(error) bool
authType idp.Auth
ssoURL string
relayState string
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "redirect binding, success",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
options: []ProviderOpts{
WithCustomRequestTracker(requestTracker),
},
},
args: args{
state: "state",
},
want: want{
authType: &idp.RedirectAuth{},
ssoURL: "http://localhost:8000/sso",
relayState: "state",
},
},
{
name: "post binding, success",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
options: []ProviderOpts{
WithCustomRequestTracker(requestTracker),
},
},
args: args{
state: "state",
},
want: want{
authType: &idp.FormAuth{},
ssoURL: "http://localhost:8000/sso",
relayState: "state",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.rootURL,
tt.fields.metadata,
tt.fields.certificate,
tt.fields.key,
tt.fields.options...,
)
require.NoError(t, err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, tt.args.state, nil)
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
gotAuth, gotErr := session.GetAuth(ctx)
a.NoError(gotErr)
a.IsType(tt.want.authType, gotAuth)
var ssoURL, relayState, samlRequest string
switch auth := gotAuth.(type) {
case *idp.RedirectAuth:
gotRedirect, err := url.Parse(auth.RedirectURL)
a.NoError(err)
gotQuery := gotRedirect.Query()
ssoURL = gotRedirect.Scheme + "://" + gotRedirect.Host + gotRedirect.Path
relayState = gotQuery.Get("RelayState")
samlRequest = gotQuery.Get("SAMLRequest")
case *idp.FormAuth:
ssoURL = auth.URL
relayState = auth.Fields["RelayState"]
samlRequest = auth.Fields["SAMLRequest"]
}
a.Equal(tt.want.ssoURL, ssoURL)
a.Equal(tt.want.relayState, relayState)
a.NotEmpty(samlRequest)
}
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
rootURL string
metadata []byte
key []byte
certificate []byte
options []ProviderOpts
}
type want struct {
err bool
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
binding string
nameIDFormat saml.NameIDFormat
transientMappingAttributeName string
withSignedRequest bool
requesttracker samlsp.RequestTracker
entityID string
}
tests := []struct {
name string
fields fields
want want
}{{
name: "failed metadata",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
metadata: []byte(">xml<"),
options: nil,
},
want: want{
err: true,
},
},
{
name: "failed keypair cert",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
options: nil,
},
want: want{
err: true,
},
},
{
name: "failed keypair key",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
options: nil,
},
want: want{
err: true,
},
},
{
name: "failed url",
fields: fields{
name: "saml",
rootURL: "%%",
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
options: nil,
},
want: want{
err: true,
},
},
{
name: "default",
fields: fields{
name: "saml",
rootURL: "https://localhost:8080",
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
options: nil,
},
want: want{
name: "saml",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
nameIDFormat: saml.PersistentNameIDFormat,
},
},
{
name: "all set / true",
fields: fields{
name: "saml",
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithBinding("binding"),
WithSignedRequest(),
WithCustomRequestTracker(&requesttracker.RequestTracker{}),
WithEntityID("entityID"),
WithNameIDFormat(domain.SAMLNameIDFormatTransient),
WithTransientMappingAttributeName("attribute"),
},
},
want: want{
name: "saml",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
binding: "binding",
entityID: "entityID",
nameIDFormat: saml.TransientNameIDFormat,
transientMappingAttributeName: "attribute",
withSignedRequest: true,
requesttracker: &requesttracker.RequestTracker{},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.rootURL, tt.fields.metadata, tt.fields.certificate, tt.fields.key, tt.fields.options...)
if tt.want.err {
require.Error(t, err)
} else {
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.binding, provider.binding)
a.Equal(tt.want.nameIDFormat, provider.nameIDFormat)
a.Equal(tt.want.transientMappingAttributeName, provider.transientMappingAttributeName)
a.Equal(tt.want.withSignedRequest, provider.spOptions.SignRequest)
a.Equal(tt.want.requesttracker, provider.requestTracker)
a.Equal(tt.want.entityID, provider.spOptions.EntityID)
}
})
}
}
func TestParseMetadata(t *testing.T) {
type args struct {
metadata []byte
}
tests := []struct {
name string
args args
want *saml.EntityDescriptor
wantErr error
}{
{
"invalid",
args{
metadata: []byte(`<Test></Test>`),
},
nil,
xml.UnmarshalError("expected element type <EntityDescriptor> but have <Test>"),
},
{
"valid entity descriptor",
args{
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
},
&saml.EntityDescriptor{
EntityID: "http://localhost:8000/metadata",
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "IDPSSODescriptor",
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPRedirectBinding,
Location: "http://localhost:8000/sso",
},
},
},
},
},
nil,
},
{
"valid entity descriptor, non utf-8",
args{
metadata: []byte(`<?xml version="1.0" encoding="windows-1252"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
},
&saml.EntityDescriptor{
EntityID: "http://localhost:8000/metadata",
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "IDPSSODescriptor",
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPRedirectBinding,
Location: "http://localhost:8000/sso",
},
},
},
},
},
nil,
},
{
"entities descriptor without IDPSSODescriptor",
args{
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"></EntityDescriptor></EntitiesDescriptor>`),
},
nil,
zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor"),
},
{
"valid entities descriptor",
args{
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor></EntitiesDescriptor>`),
},
&saml.EntityDescriptor{
EntityID: "http://localhost:8000/metadata",
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "IDPSSODescriptor",
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPRedirectBinding,
Location: "http://localhost:8000/sso",
},
},
},
},
},
nil,
},
{
"valid entities using xsd duration descriptor",
args{
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" cacheDuration="PT5H"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata" cacheDuration="PT5H"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor></EntitiesDescriptor>`),
},
&saml.EntityDescriptor{
EntityID: "http://localhost:8000/metadata",
CacheDuration: 5 * time.Hour,
IDPSSODescriptors: []saml.IDPSSODescriptor{
{
XMLName: xml.Name{
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
Local: "IDPSSODescriptor",
},
SingleSignOnServices: []saml.Endpoint{
{
Binding: saml.HTTPRedirectBinding,
Location: "http://localhost:8000/sso",
},
},
},
},
},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseMetadata(tt.args.metadata)
assert.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,180 @@
package saml
import (
"context"
"encoding/base64"
"errors"
"net/http"
"net/url"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the SAML provider.
type Session struct {
ServiceProvider *samlsp.Middleware
state string
TransientMappingAttributeName string
RequestID string
Request *http.Request
Assertion *saml.Assertion
}
func NewSession(provider *Provider, requestID string, request *http.Request) (*Session, error) {
sp, err := provider.GetSP()
if err != nil {
return nil, err
}
return &Session{
ServiceProvider: sp,
TransientMappingAttributeName: provider.TransientMappingAttributeName(),
RequestID: requestID,
Request: request,
}, nil
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
url, err := url.Parse(s.state)
if err != nil {
return nil, err
}
request := &http.Request{
URL: url,
}
return s.auth(request.WithContext(ctx))
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.RequestID == "" || s.Request == nil {
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-d09hy0wkex", "Errors.Intent.ResponseInvalid")
}
s.Assertion, err = s.ServiceProvider.ServiceProvider.ParseResponse(s.Request, []string{s.RequestID})
if err != nil {
invalidRespErr := new(saml.InvalidResponseError)
if errors.As(err, &invalidRespErr) {
return nil, zerrors.ThrowInvalidArgument(invalidRespErr.PrivateErr, "SAML-ajl3irfs", "Errors.Intent.ResponseInvalid")
}
return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
}
// nameID is required, but at least in ADFS it will not be sent unless explicitly configured
if s.Assertion.Subject == nil || s.Assertion.Subject.NameID == nil {
return nil, zerrors.ThrowInvalidArgument(err, "SAML-EFG32", "Errors.Intent.ResponseInvalid")
}
nameID := s.Assertion.Subject.NameID
userMapper := NewUser()
// use the nameID as default mapping id
userMapper.SetID(nameID.Value)
if nameID.Format == string(saml.TransientNameIDFormat) {
mappingID, err := s.transientMappingID()
if err != nil {
return nil, err
}
userMapper.SetID(mappingID)
}
for _, statement := range s.Assertion.AttributeStatements {
for _, attribute := range statement.Attributes {
values := make([]string, len(attribute.Values))
for i := range attribute.Values {
values[i] = attribute.Values[i].Value
}
userMapper.Attributes[attribute.Name] = values
}
}
return userMapper, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.Assertion == nil || s.Assertion.Conditions == nil {
return time.Time{}
}
return s.Assertion.Conditions.NotOnOrAfter
}
func (s *Session) transientMappingID() (string, error) {
for _, statement := range s.Assertion.AttributeStatements {
for _, attribute := range statement.Attributes {
if attribute.Name != s.TransientMappingAttributeName {
continue
}
if len(attribute.Values) != 1 {
return "", zerrors.ThrowInvalidArgument(nil, "SAML-Soij4", "Errors.Intent.MissingSingleMappingAttribute")
}
return attribute.Values[0].Value, nil
}
}
return "", zerrors.ThrowInvalidArgument(nil, "SAML-swwg2", "Errors.Intent.MissingSingleMappingAttribute")
}
// auth is a modified copy of the [samlsp.Middleware.HandleStartAuthFlow] method.
// Instead of writing the response to the http.ResponseWriter, it returns the auth request as an [idp.Auth].
// In case of an error, it returns the error directly and does not write to the response.
func (s *Session) auth(r *http.Request) (idp.Auth, error) {
if r.URL.Path == s.ServiceProvider.ServiceProvider.AcsURL.Path {
// should never occur, but was handled in the original method, so we keep it here
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-Eoi24", "don't wrap Middleware with RequireAccount")
}
var binding, bindingLocation string
if s.ServiceProvider.Binding != "" {
binding = s.ServiceProvider.Binding
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
} else {
binding = saml.HTTPRedirectBinding
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
if bindingLocation == "" {
binding = saml.HTTPPostBinding
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
}
}
authReq, err := s.ServiceProvider.ServiceProvider.MakeAuthenticationRequest(bindingLocation, binding, s.ServiceProvider.ResponseBinding)
if err != nil {
return nil, err
}
relayState, err := s.ServiceProvider.RequestTracker.TrackRequest(nil, r, authReq.ID)
if err != nil {
return nil, err
}
if binding == saml.HTTPRedirectBinding {
redirectURL, err := authReq.Redirect(relayState, &s.ServiceProvider.ServiceProvider)
if err != nil {
return nil, err
}
return idp.Redirect(redirectURL.String())
}
if binding == saml.HTTPPostBinding {
doc := etree.NewDocument()
doc.SetRoot(authReq.Element())
reqBuf, err := doc.WriteToBytes()
if err != nil {
return nil, err
}
encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
return idp.Form(authReq.Destination,
map[string]string{
"SAMLRequest": encodedReqBuf,
"RelayState": relayState,
})
}
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-Eoi24", "Errors.Intent.Invalid")
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,51 @@
package idp
import (
"context"
"time"
)
// Session is the minimal implementation for a session of a 3rd party authentication [Provider]
type Session interface {
GetAuth(ctx context.Context) (Auth, error)
PersistentParameters() map[string]any
FetchUser(ctx context.Context) (User, error)
ExpiresAt() time.Time
}
type Auth interface {
auth()
}
type RedirectAuth struct {
RedirectURL string
}
func (r *RedirectAuth) auth() {}
type FormAuth struct {
URL string
Fields map[string]string
}
func (f *FormAuth) auth() {}
// SessionSupportsMigration is an optional extension to the Session interface.
// It can be implemented to support migrating users, were the initial external id has changed because of a migration of the Provider type.
// E.g. when a user was linked on a generic OIDC provider and this provider has now been migrated to an AzureAD provider.
// In this case OIDC used the `sub` claim and Azure now uses the id of the user endpoint, which differ.
// The RetrievePreviousID will return the `sub` claim again, so that the user can be matched and safely migrated to the new id.
type SessionSupportsMigration interface {
RetrievePreviousID() (previousID string, err error)
}
func Redirect(redirectURL string) (*RedirectAuth, error) {
return &RedirectAuth{RedirectURL: redirectURL}, nil
}
func Form(url string, fields map[string]string) (*FormAuth, error) {
return &FormAuth{
URL: url,
Fields: fields,
}, nil
}