mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 11:27:33 +00:00
chore: move the go code into a subfolder
This commit is contained in:
51
apps/api/internal/idp/provider.go
Normal file
51
apps/api/internal/idp/provider.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
// Provider is the minimal implementation for a 3rd party authentication provider
|
||||
type Provider interface {
|
||||
Name() string
|
||||
BeginAuth(ctx context.Context, state string, params ...Parameter) (Session, error)
|
||||
IsLinkingAllowed() bool
|
||||
IsCreationAllowed() bool
|
||||
IsAutoCreation() bool
|
||||
IsAutoUpdate() bool
|
||||
}
|
||||
|
||||
// User contains the information of a federated user.
|
||||
type User interface {
|
||||
GetID() string
|
||||
GetFirstName() string
|
||||
GetLastName() string
|
||||
GetDisplayName() string
|
||||
GetNickname() string
|
||||
GetPreferredUsername() string
|
||||
GetEmail() domain.EmailAddress
|
||||
IsEmailVerified() bool
|
||||
GetPhone() domain.PhoneNumber
|
||||
IsPhoneVerified() bool
|
||||
GetPreferredLanguage() language.Tag
|
||||
GetAvatarURL() string
|
||||
GetProfile() string
|
||||
}
|
||||
|
||||
// Parameter allows to pass specific parameter to the BeginAuth function
|
||||
type Parameter interface {
|
||||
setValue()
|
||||
}
|
||||
|
||||
// UserAgentID allows to pass the user agent ID of the auth request to BeginAuth
|
||||
type UserAgentID string
|
||||
|
||||
func (p UserAgentID) setValue() {}
|
||||
|
||||
// LoginHintParam allows to pass a login_hint to BeginAuth
|
||||
type LoginHintParam string
|
||||
|
||||
func (p LoginHintParam) setValue() {}
|
68
apps/api/internal/idp/providers/apple/apple.go
Normal file
68
apps/api/internal/idp/providers/apple/apple.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package apple
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/zitadel/oidc/v3/pkg/crypto"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
name = "Apple"
|
||||
issuer = "https://appleid.apple.com"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for Apple
|
||||
type Provider struct {
|
||||
*oidc.Provider
|
||||
}
|
||||
|
||||
func New(clientID, teamID, keyID, callbackURL string, key []byte, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
|
||||
secret, err := clientSecretFromPrivateKey(key, teamID, clientID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
options = append(options, oidc.WithResponseMode("form_post"))
|
||||
rp, err := oidc.New(name, issuer, clientID, secret, callbackURL, scopes, oidc.DefaultMapper, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Provider{
|
||||
Provider: rp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// clientSecretFromPrivateKey uses the private key to create and sign a JWT, which has to be used as client_secret at Apple.
|
||||
func clientSecretFromPrivateKey(key []byte, teamID, clientID, keyID string) (string, error) {
|
||||
block, _ := pem.Decode(key)
|
||||
b := block.Bytes
|
||||
pk, err := x509.ParsePKCS8PrivateKey(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
signingKey := jose.SigningKey{
|
||||
Algorithm: jose.ES256,
|
||||
Key: &jose.JSONWebKey{Key: pk, KeyID: keyID},
|
||||
}
|
||||
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
iat := time.Now().Add(-2 * time.Second)
|
||||
exp := iat.Add(time.Hour)
|
||||
return crypto.Sign(&openid.JWTTokenRequest{
|
||||
Issuer: teamID,
|
||||
Subject: clientID,
|
||||
Audience: []string{issuer},
|
||||
ExpiresAt: openid.FromTime(exp),
|
||||
IssuedAt: openid.FromTime(iat),
|
||||
}, signer)
|
||||
}
|
71
apps/api/internal/idp/providers/apple/apple_test.go
Normal file
71
apps/api/internal/idp/providers/apple/apple_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package apple
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
privateKey = `-----BEGIN PRIVATE KEY-----
|
||||
MIGTAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBHkwdwIBAQQgXn/LDURaetCoymSj
|
||||
fRslBiBwzBSa8ifiyfYGIWNStYGgCgYIKoZIzj0DAQehRANCAATymZXIsGrXnl6b
|
||||
+80miSiVOCcLnyaYa2uQBQvQwgB7GibXhrzF+D/MRTV4P7P8+Lg1K9Khkjc59eNK
|
||||
4RrQP4g7
|
||||
-----END PRIVATE KEY-----
|
||||
`
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
teamID string
|
||||
keyID string
|
||||
privateKey []byte
|
||||
redirectURI string
|
||||
scopes []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
teamID: "teamID",
|
||||
keyID: "keyID",
|
||||
privateKey: []byte(privateKey),
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
},
|
||||
want: &Session{
|
||||
Session: &oidc.Session{
|
||||
AuthURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.clientID, tt.fields.teamID, tt.fields.keyID, tt.fields.redirectURI, tt.fields.privateKey, tt.fields.scopes)
|
||||
r.NoError(err)
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
auth, err := session.GetAuth(ctx)
|
||||
authExpected, errExpected := tt.want.GetAuth(ctx)
|
||||
a.ErrorIs(err, errExpected)
|
||||
a.Equal(authExpected, auth)
|
||||
})
|
||||
}
|
||||
}
|
74
apps/api/internal/idp/providers/apple/session.go
Normal file
74
apps/api/internal/idp/providers/apple/session.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package apple
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session extends the [oidc.Session] with the formValues returned from the callback.
|
||||
// This enables to parse the user (name and email), which Apple only returns as form params on registration
|
||||
type Session struct {
|
||||
*oidc.Session
|
||||
UserFormValue string
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, code, userFormValue string) *Session {
|
||||
return &Session{Session: oidc.NewSession(provider.Provider, code, nil), UserFormValue: userFormValue}
|
||||
}
|
||||
|
||||
type userFormValue struct {
|
||||
Name userNamesFormValue `json:"name,omitempty" schema:"name"`
|
||||
}
|
||||
|
||||
type userNamesFormValue struct {
|
||||
FirstName string `json:"firstName,omitempty" schema:"firstName"`
|
||||
LastName string `json:"lastName,omitempty" schema:"lastName"`
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OIDC code exchange if needed to retrieve the tokens,
|
||||
// extract the information from the id_token and if available also from the `user` form value.
|
||||
// The information will be mapped into an [idp.User].
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
if s.Tokens == nil {
|
||||
if err = s.Authorize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
info := s.Tokens.IDTokenClaims.GetUserInfo()
|
||||
userName := userFormValue{}
|
||||
if s.UserFormValue != "" {
|
||||
if err = json.Unmarshal([]byte(s.UserFormValue), &userName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewUser(info, userName.Name), nil
|
||||
}
|
||||
|
||||
func NewUser(info *openid.UserInfo, names userNamesFormValue) *User {
|
||||
user := oidc.NewUser(info)
|
||||
user.GivenName = names.FirstName
|
||||
user.FamilyName = names.LastName
|
||||
return &User{User: user}
|
||||
}
|
||||
|
||||
func InitUser() idp.User {
|
||||
return &User{User: oidc.InitUser()}
|
||||
}
|
||||
|
||||
// User extends the [oidc.User] by returning the email as preferred_username, since Apple does not return the latter.
|
||||
type User struct {
|
||||
*oidc.User
|
||||
}
|
||||
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.Email
|
||||
}
|
217
apps/api/internal/idp/providers/apple/session_test.go
Normal file
217
apps/api/internal/idp/providers/apple/session_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package apple
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
teamID string
|
||||
keyID string
|
||||
privateKey []byte
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
authURL string
|
||||
code string
|
||||
tokens *openid.Tokens[*openid.IDTokenClaims]
|
||||
userFormValue string
|
||||
}
|
||||
type want struct {
|
||||
err error
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
nonceSupported bool
|
||||
isPrivateEmail bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
teamID: "teamID",
|
||||
keyID: "keyID",
|
||||
privateKey: []byte(privateKey),
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {},
|
||||
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: oidc.ErrCodeMissing,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no user param",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
teamID: "teamID",
|
||||
keyID: "keyID",
|
||||
privateKey: []byte(privateKey),
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {},
|
||||
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: id_token(),
|
||||
},
|
||||
userFormValue: "",
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "",
|
||||
lastName: "",
|
||||
displayName: "",
|
||||
nickName: "",
|
||||
preferredUsername: "email",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.Und,
|
||||
avatarURL: "",
|
||||
profile: "",
|
||||
nonceSupported: true,
|
||||
isPrivateEmail: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with user param",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
teamID: "teamID",
|
||||
keyID: "keyID",
|
||||
privateKey: []byte(privateKey),
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {},
|
||||
authURL: "https://appleid.apple.com/auth/authorize?client_id=clientID&redirect_uri=redirectURI&response_mode=form_post&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: id_token(),
|
||||
},
|
||||
userFormValue: `{"name": {"firstName": "firstName", "lastName": "lastName"}}`,
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstName",
|
||||
lastName: "lastName",
|
||||
displayName: "",
|
||||
nickName: "",
|
||||
preferredUsername: "email",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.Und,
|
||||
avatarURL: "",
|
||||
profile: "",
|
||||
nonceSupported: true,
|
||||
isPrivateEmail: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
// call the real discovery endpoint
|
||||
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
|
||||
provider, err := New(tt.fields.clientID, tt.fields.teamID, tt.fields.keyID, tt.fields.redirectURI, tt.fields.privateKey, tt.fields.scopes)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
Session: &oidc.Session{
|
||||
Provider: provider.Provider,
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
},
|
||||
UserFormValue: tt.fields.userFormValue,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
|
||||
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func id_token() *openid.IDTokenClaims {
|
||||
return &openid.IDTokenClaims{
|
||||
TokenClaims: openid.TokenClaims{
|
||||
Issuer: issuer,
|
||||
Subject: "sub",
|
||||
Audience: []string{"clientID"},
|
||||
Expiration: openid.FromTime(time.Now().Add(1 * time.Hour)),
|
||||
IssuedAt: openid.FromTime(time.Now().Add(-1 * time.Second)),
|
||||
AuthTime: openid.FromTime(time.Now().Add(-1 * time.Second)),
|
||||
Nonce: "nonce",
|
||||
ClientID: "clientID",
|
||||
},
|
||||
UserInfoEmail: openid.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: true,
|
||||
},
|
||||
Claims: map[string]any{
|
||||
"nonce_supported": true,
|
||||
"is_private_email": true,
|
||||
},
|
||||
}
|
||||
}
|
253
apps/api/internal/idp/providers/azuread/azuread.go
Normal file
253
apps/api/internal/idp/providers/azuread/azuread.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
const (
|
||||
issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0"
|
||||
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
|
||||
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
||||
keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
|
||||
userURL string = "https://graph.microsoft.com/v1.0/me"
|
||||
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
|
||||
|
||||
ScopeUserRead string = "User.Read"
|
||||
)
|
||||
|
||||
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an
|
||||
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope
|
||||
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name.
|
||||
//
|
||||
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
|
||||
type TenantType string
|
||||
|
||||
const (
|
||||
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active
|
||||
// Directory to sign in to the application.
|
||||
CommonTenant TenantType = "common"
|
||||
|
||||
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign in to the application.
|
||||
OrganizationsTenant TenantType = "organizations"
|
||||
|
||||
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign in to the application.
|
||||
ConsumersTenant TenantType = "consumers"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for AzureAD (V2 Endpoints)
|
||||
type Provider struct {
|
||||
*oauth.Provider
|
||||
tenant TenantType
|
||||
emailVerified bool
|
||||
options []oauth.ProviderOpts
|
||||
}
|
||||
|
||||
// issuer returns the OIDC issuer based on the [TenantType]
|
||||
func (p *Provider) issuer() string {
|
||||
return fmt.Sprintf(issuerTemplate, p.tenant)
|
||||
}
|
||||
|
||||
// keysEndpoint returns the OIDC jwks_url based on the [TenantType]
|
||||
func (p *Provider) keysEndpoint() string {
|
||||
return fmt.Sprintf(keysURLTemplate, p.tenant)
|
||||
}
|
||||
|
||||
type ProviderOptions func(*Provider)
|
||||
|
||||
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
|
||||
// default is CommonTenant
|
||||
func WithTenant(tenantType TenantType) ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.tenant = tenantType
|
||||
}
|
||||
}
|
||||
|
||||
// WithEmailVerified allows to set every email received as verified
|
||||
func WithEmailVerified() ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.emailVerified = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithOAuthOptions allows to specify [oauth.ProviderOpts] like [oauth.WithLinkingAllowed]
|
||||
func WithOAuthOptions(opts ...oauth.ProviderOpts) ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.options = append(p.options, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// New creates an AzureAD provider using the [oauth.Provider] (OAuth 2.0 generic provider).
|
||||
// By default, it uses the [CommonTenant] and unverified emails.
|
||||
func New(name, clientID, clientSecret, redirectURI string, scopes []string, opts ...ProviderOptions) (*Provider, error) {
|
||||
provider := &Provider{
|
||||
tenant: CommonTenant,
|
||||
options: make([]oauth.ProviderOpts, 0),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(provider)
|
||||
}
|
||||
config := newConfig(provider.tenant, clientID, clientSecret, redirectURI, scopes)
|
||||
rp, err := oauth.New(
|
||||
config,
|
||||
name,
|
||||
userURL,
|
||||
func() idp.User {
|
||||
return &User{isEmailVerified: provider.emailVerified}
|
||||
},
|
||||
provider.options...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.Provider = rp
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func newConfig(tenant TenantType, clientID, secret, callbackURL string, scopes []string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: secret,
|
||||
RedirectURL: callbackURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf(authURLTemplate, tenant),
|
||||
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant),
|
||||
},
|
||||
Scopes: ensureMinimalScope(scopes),
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMinimalScope ensures that at least openid and `User.Read` ist set
|
||||
// if none is provided it will request `openid profile email phone User.Read`
|
||||
func ensureMinimalScope(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, ScopeUserRead}
|
||||
}
|
||||
var openIDSet, userReadSet bool
|
||||
for _, scope := range scopes {
|
||||
if scope == oidc.ScopeOpenID {
|
||||
openIDSet = true
|
||||
continue
|
||||
}
|
||||
if scope == ScopeUserRead {
|
||||
userReadSet = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !openIDSet {
|
||||
scopes = append(scopes, oidc.ScopeOpenID)
|
||||
}
|
||||
if !userReadSet {
|
||||
scopes = append(scopes, ScopeUserRead)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
func (p *Provider) User() idp.User {
|
||||
return p.Provider.User()
|
||||
}
|
||||
|
||||
// User represents the structure return on the userinfo endpoint and implements the [idp.User] interface
|
||||
//
|
||||
// AzureAD does not return an `email_verified` claim.
|
||||
// The verification can be automatically activated on the provider ([WithEmailVerified])
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
BusinessPhones []domain.PhoneNumber `json:"businessPhones"`
|
||||
DisplayName string `json:"displayName"`
|
||||
FirstName string `json:"givenName"`
|
||||
JobTitle string `json:"jobTitle"`
|
||||
Email domain.EmailAddress `json:"mail"`
|
||||
MobilePhone domain.PhoneNumber `json:"mobilePhone"`
|
||||
OfficeLocation string `json:"officeLocation"`
|
||||
PreferredLanguage string `json:"preferredLanguage"`
|
||||
LastName string `json:"surname"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
isEmailVerified bool
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetID() string {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
// GetFirstName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetFirstName() string {
|
||||
return u.FirstName
|
||||
}
|
||||
|
||||
// GetLastName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetLastName() string {
|
||||
return u.LastName
|
||||
}
|
||||
|
||||
// GetDisplayName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.DisplayName
|
||||
}
|
||||
|
||||
// GetNickname is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's nickname.
|
||||
func (u *User) GetNickname() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPreferredUsername is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.UserPrincipalName
|
||||
}
|
||||
|
||||
// GetEmail is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
if u.Email == "" {
|
||||
// if the user used a social login on Azure as well, the email will be empty
|
||||
// but is used as username
|
||||
return domain.EmailAddress(u.UserPrincipalName)
|
||||
}
|
||||
return u.Email
|
||||
}
|
||||
|
||||
// IsEmailVerified is an implementation of the [idp.User] interface
|
||||
// returning the value specified in the creation of the [Provider].
|
||||
// Default is false because AzureAD does not return an `email_verified` claim.
|
||||
// The verification can be automatically activated on the provider ([WithEmailVerified]).
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return u.isEmailVerified
|
||||
}
|
||||
|
||||
// GetPhone is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's phone.
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsPhoneVerified is an implementation of the [idp.User] interface.
|
||||
// It returns false because AzureAD does not provide the user's phone.
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredLanguage is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return language.Make(u.PreferredLanguage)
|
||||
}
|
||||
|
||||
// GetProfile is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's profile page.
|
||||
func (u *User) GetProfile() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAvatarURL is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return ""
|
||||
}
|
185
apps/api/internal/idp/providers/azuread/azuread_test.go
Normal file
185
apps/api/internal/idp/providers/azuread/azuread_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
options []ProviderOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "default common tenant",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithTenant(ConsumersTenant),
|
||||
},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom scopes",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{openid.ScopeOpenID, openid.ScopeProfile, "custom"},
|
||||
options: []ProviderOptions{
|
||||
WithTenant(ConsumersTenant),
|
||||
},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+custom+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
r.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
options []ProviderOptions
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
tenant TenantType
|
||||
emailVerified bool
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
pkce bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default common tenant",
|
||||
fields: fields{
|
||||
name: "default common tenant",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: nil,
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "default common tenant",
|
||||
tenant: CommonTenant,
|
||||
emailVerified: false,
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
pkce: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all set",
|
||||
fields: fields{
|
||||
name: "custom tenant",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithTenant("tenant"),
|
||||
WithEmailVerified(),
|
||||
WithOAuthOptions(
|
||||
oauth.WithLinkingAllowed(),
|
||||
oauth.WithCreationAllowed(),
|
||||
oauth.WithAutoCreation(),
|
||||
oauth.WithAutoUpdate(),
|
||||
oauth.WithRelyingPartyOption(rp.WithPKCE(nil)),
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "custom tenant",
|
||||
tenant: "tenant",
|
||||
emailVerified: true,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
pkce: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.tenant, provider.tenant)
|
||||
a.Equal(tt.want.emailVerified, provider.emailVerified)
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
|
||||
})
|
||||
}
|
||||
}
|
106
apps/api/internal/idp/providers/azuread/session.go
Normal file
106
apps/api/internal/idp/providers/azuread/session.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality
|
||||
type Session struct {
|
||||
*Provider
|
||||
Code string
|
||||
|
||||
OAuthSession *oauth.Session
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, code string) *Session {
|
||||
return &Session{Provider: provider, Code: code}
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return s.oauth().GetAuth(ctx)
|
||||
}
|
||||
|
||||
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
|
||||
func (s *Session) RetrievePreviousID() (string, error) {
|
||||
req, err := http.NewRequest("GET", userinfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken)
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userinfo.Subject, nil
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||
// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it.
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
user, err = s.oauth().FetchUser(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// since azure will sign the id_token always with the issuer of the application it might differ from
|
||||
// the issuer the auth and token were based on, e.g. when allowing all account types to login,
|
||||
// then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
|
||||
// even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0`
|
||||
if s.Provider.tenant == CommonTenant ||
|
||||
s.Provider.tenant == OrganizationsTenant ||
|
||||
s.Provider.tenant == ConsumersTenant {
|
||||
return user, nil
|
||||
}
|
||||
idToken, ok := s.oauth().Tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return user, nil
|
||||
}
|
||||
idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint()))
|
||||
s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.oauth().Tokens.IDToken = idToken
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.OAuthSession == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.OAuthSession.ExpiresAt()
|
||||
}
|
||||
|
||||
// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session].
|
||||
func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] {
|
||||
return s.oauth().Tokens
|
||||
}
|
||||
|
||||
func (s *Session) oauth() *oauth.Session {
|
||||
if s.OAuthSession != nil {
|
||||
return s.OAuthSession
|
||||
}
|
||||
s.OAuthSession = &oauth.Session{
|
||||
Code: s.Code,
|
||||
Provider: s.Provider.Provider,
|
||||
}
|
||||
return s.OAuthSession
|
||||
}
|
416
apps/api/internal/idp/providers/azuread/session_test.go
Normal file
416
apps/api/internal/idp/providers/azuread/session_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
options []ProviderOptions
|
||||
authURL string
|
||||
code string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
user idp.User
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, oauth.ErrCodeMissing)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(http.StatusInternalServerError)
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub2",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err.Error() == "http status not ok: 500 Internal Server Error "
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
isEmailVerified: false,
|
||||
},
|
||||
id: "id",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: false,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.English,
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch with email verified",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithEmailVerified(),
|
||||
},
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
isEmailVerified: true,
|
||||
},
|
||||
id: "id",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.English,
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
Code: tt.fields.code,
|
||||
|
||||
OAuthSession: &oauth.Session{
|
||||
AuthURL: tt.fields.authURL,
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider.Provider,
|
||||
Code: tt.fields.code,
|
||||
},
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.user, user)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *User {
|
||||
return &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
}
|
||||
}
|
||||
|
||||
func TestSession_RetrievePreviousID(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
err bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid token",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/oidc/userinfo").
|
||||
Reply(401)
|
||||
},
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/oidc/userinfo").
|
||||
Reply(200).
|
||||
JSON(`{"sub":"sub"}`)
|
||||
},
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
id: "sub",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
||||
require.NoError(t, err)
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
OAuthSession: &oauth.Session{
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider.Provider,
|
||||
}}
|
||||
|
||||
id, err := session.RetrievePreviousID()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
a.Equal(tt.res.id, id)
|
||||
})
|
||||
}
|
||||
}
|
190
apps/api/internal/idp/providers/github/github.go
Normal file
190
apps/api/internal/idp/providers/github/github.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
const (
|
||||
authURL = "https://github.com/login/oauth/authorize"
|
||||
tokenURL = "https://github.com/login/oauth/access_token"
|
||||
profileURL = "https://api.github.com/user"
|
||||
name = "GitHub"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// New creates a GitHub.com provider using the [oauth.Provider] (OAuth 2.0 generic provider)
|
||||
func New(clientID, secret, callbackURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
|
||||
return NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL, scopes, options...)
|
||||
}
|
||||
|
||||
// NewCustomURL creates a GitHub provider using the [oauth.Provider] (OAuth 2.0 generic provider)
|
||||
// with custom endpoints, e.g. GitHub Enterprise server
|
||||
func NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
|
||||
rp, err := oauth.New(
|
||||
newConfig(clientID, secret, callbackURL, authURL, tokenURL, scopes),
|
||||
name,
|
||||
profileURL,
|
||||
func() idp.User {
|
||||
return new(User)
|
||||
},
|
||||
options...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Provider{
|
||||
Provider: rp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Provider is the [idp.Provider] implementation for GitHub
|
||||
type Provider struct {
|
||||
*oauth.Provider
|
||||
}
|
||||
|
||||
func newConfig(clientID, secret, callbackURL, authURL, tokenURL string, scopes []string) *oauth2.Config {
|
||||
c := &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: secret,
|
||||
RedirectURL: callbackURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: authURL,
|
||||
TokenURL: tokenURL,
|
||||
},
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// User is a representation of the authenticated GitHub user and implements the [idp.User] interface
|
||||
// https://docs.github.com/en/rest/users/users?apiVersion=2022-11-28#get-the-authenticated-user
|
||||
type User struct {
|
||||
Login string `json:"login"`
|
||||
ID int `json:"id"`
|
||||
NodeId string `json:"node_id"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
GravatarId string `json:"gravatar_id"`
|
||||
Url string `json:"url"`
|
||||
HtmlUrl string `json:"html_url"`
|
||||
FollowersUrl string `json:"followers_url"`
|
||||
FollowingUrl string `json:"following_url"`
|
||||
GistsUrl string `json:"gists_url"`
|
||||
StarredUrl string `json:"starred_url"`
|
||||
SubscriptionsUrl string `json:"subscriptions_url"`
|
||||
OrganizationsUrl string `json:"organizations_url"`
|
||||
ReposUrl string `json:"repos_url"`
|
||||
EventsUrl string `json:"events_url"`
|
||||
ReceivedEventsUrl string `json:"received_events_url"`
|
||||
Type string `json:"type"`
|
||||
SiteAdmin bool `json:"site_admin"`
|
||||
Name string `json:"name"`
|
||||
Company string `json:"company"`
|
||||
Blog string `json:"blog"`
|
||||
Location string `json:"location"`
|
||||
Email domain.EmailAddress `json:"email"`
|
||||
Hireable bool `json:"hireable"`
|
||||
Bio string `json:"bio"`
|
||||
TwitterUsername string `json:"twitter_username"`
|
||||
PublicRepos int `json:"public_repos"`
|
||||
PublicGists int `json:"public_gists"`
|
||||
Followers int `json:"followers"`
|
||||
Following int `json:"following"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
PrivateGists int `json:"private_gists"`
|
||||
TotalPrivateRepos int `json:"total_private_repos"`
|
||||
OwnedPrivateRepos int `json:"owned_private_repos"`
|
||||
DiskUsage int `json:"disk_usage"`
|
||||
Collaborators int `json:"collaborators"`
|
||||
TwoFactorAuthentication bool `json:"two_factor_authentication"`
|
||||
Plan struct {
|
||||
Name string `json:"name"`
|
||||
Space int `json:"space"`
|
||||
PrivateRepos int `json:"private_repos"`
|
||||
Collaborators int `json:"collaborators"`
|
||||
} `json:"plan"`
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetID() string {
|
||||
return strconv.Itoa(u.ID)
|
||||
}
|
||||
|
||||
// GetFirstName is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because GitHub does not provide the user's firstname.
|
||||
func (u *User) GetFirstName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLastName is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because GitHub does not provide the user's lastname.
|
||||
func (u *User) GetLastName() string {
|
||||
// GitHub does not provide the user's lastname
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetDisplayName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
// GetNickname is an implementation of the [idp.User] interface
|
||||
// returning the login name of the GitHub user.
|
||||
func (u *User) GetNickname() string {
|
||||
return u.Login
|
||||
}
|
||||
|
||||
// GetPreferredUsername is an implementation of the [idp.User] interface
|
||||
// returning the login name of the GitHub user.
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.Login
|
||||
}
|
||||
|
||||
// GetEmail is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
// IsEmailVerified is an implementation of the [idp.User] interface.
|
||||
// It returns true because GitHub validates emails themselves.
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetPhone is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because GitHub does not provide the user's phone.
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsPhoneVerified is an implementation of the [idp.User] interface
|
||||
// it returns false because GitHub does not provide the user's phone
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredLanguage is an implementation of the [idp.User] interface.
|
||||
// It returns [language.Und] because GitHub does not provide the user's language.
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return language.Und
|
||||
}
|
||||
|
||||
// GetProfile is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetProfile() string {
|
||||
return u.HtmlUrl
|
||||
}
|
||||
|
||||
// GetAvatarURL is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return u.AvatarUrl
|
||||
}
|
57
apps/api/internal/idp/providers/github/github_test.go
Normal file
57
apps/api/internal/idp/providers/github/github_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
options []oauth.ProviderOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
},
|
||||
want: &oauth.Session{
|
||||
AuthURL: "https://github.com/login/oauth/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&state=testState",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
r.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
216
apps/api/internal/idp/providers/github/session_test.go
Normal file
216
apps/api/internal/idp/providers/github/session_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
httpMock func()
|
||||
authURL string
|
||||
code string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
scopes []string
|
||||
options []oauth.ProviderOpts
|
||||
}
|
||||
type args struct {
|
||||
session idp.Session
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
user idp.User
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://api.github.com").
|
||||
Get("/user").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
args: args{
|
||||
&oauth.Session{},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, oauth.ErrCodeMissing)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://api.github.com").
|
||||
Get("/user").
|
||||
Reply(http.StatusInternalServerError)
|
||||
},
|
||||
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
&oauth.Session{},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err.Error() == "http status not ok: 500 Internal Server Error "
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://api.github.com").
|
||||
Get("/user").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
&oauth.Session{},
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
Login: "login",
|
||||
ID: 1,
|
||||
AvatarUrl: "avatarURL",
|
||||
GravatarId: "gravatarID",
|
||||
Name: "name",
|
||||
Email: "email",
|
||||
HtmlUrl: "htmlURL",
|
||||
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
|
||||
},
|
||||
id: "1",
|
||||
firstName: "",
|
||||
lastName: "",
|
||||
displayName: "name",
|
||||
nickName: "login",
|
||||
preferredUsername: "login",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.Und,
|
||||
avatarURL: "avatarURL",
|
||||
profile: "htmlURL",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &oauth.Session{
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider.Provider,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.user, user)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *User {
|
||||
return &User{
|
||||
Login: "login",
|
||||
ID: 1,
|
||||
AvatarUrl: "avatarURL",
|
||||
GravatarId: "gravatarID",
|
||||
Name: "name",
|
||||
Email: "email",
|
||||
HtmlUrl: "htmlURL",
|
||||
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
|
||||
}
|
||||
}
|
45
apps/api/internal/idp/providers/gitlab/gitlab.go
Normal file
45
apps/api/internal/idp/providers/gitlab/gitlab.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
issuer = "https://gitlab.com"
|
||||
name = "GitLab"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for Gitlab
|
||||
type Provider struct {
|
||||
*oidc.Provider
|
||||
}
|
||||
|
||||
// New creates a GitLab.com provider using the [oidc.Provider] (OIDC generic provider)
|
||||
func New(clientID, clientSecret, redirectURI string, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
|
||||
return NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI, scopes, options...)
|
||||
}
|
||||
|
||||
// NewCustomIssuer creates a GitLab provider using the [oidc.Provider] (OIDC generic provider)
|
||||
// with a custom issuer for self-managed instances
|
||||
func NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI string, scopes []string, options ...oidc.ProviderOpts) (*Provider, error) {
|
||||
if len(scopes) == 0 {
|
||||
// the OIDC provider would set `openid profile email phone` as default scope,
|
||||
// but since gitlab does not handle unknown scopes correctly (phone) and returns an error,
|
||||
// we will just set a separate default list
|
||||
scopes = []string{openid.ScopeOpenID, openid.ScopeProfile, openid.ScopeEmail}
|
||||
}
|
||||
// gitlab is currently not able to handle the prompt `select_account`:
|
||||
// https://gitlab.com/gitlab-org/gitlab/-/issues/377368
|
||||
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, scopes, oidc.DefaultMapper, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Provider{
|
||||
Provider: rp,
|
||||
}, nil
|
||||
}
|
68
apps/api/internal/idp/providers/gitlab/gitlab_test.go
Normal file
68
apps/api/internal/idp/providers/gitlab/gitlab_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
opts []oidc.ProviderOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful auth default scopes",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.opts...)
|
||||
r.NoError(err)
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
225
apps/api/internal/idp/providers/gitlab/session_test.go
Normal file
225
apps/api/internal/idp/providers/gitlab/session_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestProvider_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
authURL string
|
||||
code string
|
||||
tokens *openid.Tokens[*openid.IDTokenClaims]
|
||||
options []oidc.ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
err error
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://gitlab.com/oauth").
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: oidc.ErrCodeMissing,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "userinfo error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://gitlab.com/oauth").
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: openid.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub2",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: rp.ErrUserInfoSubNotMatching,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://gitlab.com/oauth").
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: openid.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "nickname",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "phone",
|
||||
isPhoneVerified: true,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
// call the real discovery endpoint
|
||||
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &oidc.Session{
|
||||
Provider: provider.Provider,
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
|
||||
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *openid.UserInfo {
|
||||
return &openid.UserInfo{
|
||||
Subject: "sub",
|
||||
UserInfoProfile: openid.UserInfoProfile{
|
||||
GivenName: "firstname",
|
||||
FamilyName: "lastname",
|
||||
Name: "firstname lastname",
|
||||
Nickname: "nickname",
|
||||
PreferredUsername: "username",
|
||||
Locale: openid.NewLocale(language.English),
|
||||
Picture: "picture",
|
||||
Profile: "profile",
|
||||
},
|
||||
UserInfoEmail: openid.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: openid.Bool(true),
|
||||
},
|
||||
UserInfoPhone: openid.UserInfoPhone{
|
||||
PhoneNumber: "phone",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
}
|
||||
}
|
51
apps/api/internal/idp/providers/google/google.go
Normal file
51
apps/api/internal/idp/providers/google/google.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
issuer = "https://accounts.google.com"
|
||||
name = "Google"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for Google
|
||||
type Provider struct {
|
||||
*oidc.Provider
|
||||
}
|
||||
|
||||
// New creates a Google provider using the [oidc.Provider] (OIDC generic provider)
|
||||
func New(clientID, clientSecret, redirectURI string, scopes []string, opts ...oidc.ProviderOpts) (*Provider, error) {
|
||||
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, scopes, userMapper, append(opts, oidc.WithSelectAccount())...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Provider{
|
||||
Provider: rp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var userMapper = func(info *openid.UserInfo) idp.User {
|
||||
return &User{oidc.DefaultMapper(info)}
|
||||
}
|
||||
|
||||
func InitUser() idp.User {
|
||||
return &User{oidc.InitUser()}
|
||||
}
|
||||
|
||||
// User is a representation of the authenticated Google and implements the [idp.User] interface
|
||||
// by wrapping an [idp.User] (implemented by [oidc.User]). It overwrites the [GetPreferredUsername] to use the `email` claim.
|
||||
type User struct {
|
||||
idp.User
|
||||
}
|
||||
|
||||
// GetPreferredUsername implements the [idp.User] interface.
|
||||
// It returns the email, because Google does not return a username.
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return string(u.GetEmail())
|
||||
}
|
57
apps/api/internal/idp/providers/google/google_test.go
Normal file
57
apps/api/internal/idp/providers/google/google_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
||||
r.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
222
apps/api/internal/idp/providers/google/session_test.go
Normal file
222
apps/api/internal/idp/providers/google/session_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
authURL string
|
||||
code string
|
||||
tokens *openid.Tokens[*openid.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err error
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
hostedDomain string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://openidconnect.googleapis.com").
|
||||
Get("/v1/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: oidc.ErrCodeMissing,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "userinfo error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://openidconnect.googleapis.com").
|
||||
Get("/v1/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: openid.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub2",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: rp.ErrUserInfoSubNotMatching,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
httpMock: func() {
|
||||
gock.New("https://openidconnect.googleapis.com").
|
||||
Get("/v1/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &openid.Tokens[*openid.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: openid.BearerToken,
|
||||
},
|
||||
IDTokenClaims: openid.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "",
|
||||
preferredUsername: "email",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "",
|
||||
hostedDomain: "hosted domain",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
// call the real discovery endpoint
|
||||
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
|
||||
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &oidc.Session{
|
||||
Provider: provider.Provider,
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
|
||||
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *openid.UserInfo {
|
||||
return &openid.UserInfo{
|
||||
Subject: "sub",
|
||||
UserInfoProfile: openid.UserInfoProfile{
|
||||
GivenName: "firstname",
|
||||
FamilyName: "lastname",
|
||||
Name: "firstname lastname",
|
||||
Locale: openid.NewLocale(language.English),
|
||||
Picture: "picture",
|
||||
},
|
||||
UserInfoEmail: openid.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: openid.Bool(true),
|
||||
},
|
||||
Claims: map[string]any{
|
||||
"hd": "hosted domain",
|
||||
},
|
||||
}
|
||||
}
|
141
apps/api/internal/idp/providers/jwt/jwt.go
Normal file
141
apps/api/internal/idp/providers/jwt/jwt.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/url"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
const (
|
||||
QueryAuthRequestID = "authRequestID"
|
||||
QueryUserAgentID = "userAgentID"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
var (
|
||||
ErrMissingState = errors.New("state missing")
|
||||
)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a JWT provider
|
||||
type Provider struct {
|
||||
name string
|
||||
headerName string
|
||||
issuer string
|
||||
jwtEndpoint string
|
||||
keysEndpoint string
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
encryptionAlg crypto.EncryptionAlgorithm
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
|
||||
// WithLinkingAllowed allows end users to link the federated user to an existing one
|
||||
func WithLinkingAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isLinkingAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreationAllowed allows end users to create a new user using the federated information
|
||||
func WithCreationAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isCreationAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCreation enables that federated users are automatically created if not already existing
|
||||
func WithAutoCreation() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoCreation = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
|
||||
// the existing user on each authentication
|
||||
func WithAutoUpdate() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a JWT provider
|
||||
func New(name, issuer, jwtEndpoint, keysEndpoint, headerName string, encryptionAlg crypto.EncryptionAlgorithm, options ...ProviderOpts) (*Provider, error) {
|
||||
provider := &Provider{
|
||||
name: name,
|
||||
issuer: issuer,
|
||||
jwtEndpoint: jwtEndpoint,
|
||||
keysEndpoint: keysEndpoint,
|
||||
headerName: headerName,
|
||||
encryptionAlg: encryptionAlg,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Name implements the [idp.Provider] interface
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
|
||||
// with the authRequest and encrypted userAgent ids.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
if state == "" {
|
||||
return nil, ErrMissingState
|
||||
}
|
||||
userAgentID := userAgentIDFromParams(state, params...)
|
||||
redirect, err := url.Parse(p.jwtEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := redirect.Query()
|
||||
q.Set(QueryAuthRequestID, state)
|
||||
nonce, err := p.encryptionAlg.Encrypt([]byte(userAgentID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q.Set(QueryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce))
|
||||
redirect.RawQuery = q.Encode()
|
||||
return &Session{AuthURL: redirect.String()}, nil
|
||||
}
|
||||
|
||||
func userAgentIDFromParams(state string, params ...idp.Parameter) string {
|
||||
for _, param := range params {
|
||||
if id, ok := param.(idp.UserAgentID); ok {
|
||||
return string(id)
|
||||
}
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
}
|
||||
|
||||
// IsCreationAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsCreationAllowed() bool {
|
||||
return p.isCreationAllowed
|
||||
}
|
||||
|
||||
// IsAutoCreation implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoCreation() bool {
|
||||
return p.isAutoCreation
|
||||
}
|
||||
|
||||
// IsAutoUpdate implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
226
apps/api/internal/idp/providers/jwt/jwt_test.go
Normal file
226
apps/api/internal/idp/providers/jwt/jwt_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
jwtEndpoint string
|
||||
keysEndpoint string
|
||||
headerName string
|
||||
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
||||
}
|
||||
type args struct {
|
||||
state string
|
||||
params []idp.Parameter
|
||||
}
|
||||
type want struct {
|
||||
session idp.Session
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "missing state, error",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
state: "",
|
||||
params: nil,
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, ErrMissingState)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing userAgentID, fallback to state",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
state: "testState",
|
||||
params: nil,
|
||||
},
|
||||
want: want{
|
||||
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=dGVzdFN0YXRl"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful auth",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
state: "testState",
|
||||
params: []idp.Parameter{
|
||||
idp.UserAgentID("agent"),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=YWdlbnQ"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(
|
||||
tt.fields.name,
|
||||
tt.fields.issuer,
|
||||
tt.fields.jwtEndpoint,
|
||||
tt.fields.keysEndpoint,
|
||||
tt.fields.headerName,
|
||||
tt.fields.encryptionAlg(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, tt.args.state, tt.args.params...)
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
wantAuth, wantErr := tt.want.session.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
jwtEndpoint string
|
||||
keysEndpoint string
|
||||
headerName string
|
||||
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
||||
opts []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
pkce bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "jwt",
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
opts: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "jwt",
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
pkce: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
fields: fields{
|
||||
name: "jwt",
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
opts: []ProviderOpts{
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "jwt",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
pkce: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(
|
||||
tt.fields.name,
|
||||
tt.fields.issuer,
|
||||
tt.fields.jwtEndpoint,
|
||||
tt.fields.keysEndpoint,
|
||||
tt.fields.headerName,
|
||||
tt.fields.encryptionAlg(t),
|
||||
tt.fields.opts...,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
})
|
||||
}
|
||||
}
|
169
apps/api/internal/idp/providers/jwt/session.go
Normal file
169
apps/api/internal/idp/providers/jwt/session.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
var (
|
||||
ErrNoTokens = errors.New("no tokens provided")
|
||||
ErrInvalidToken = errors.New("invalid tokens provided")
|
||||
)
|
||||
|
||||
// Session is the [idp.Session] implementation for the JWT provider
|
||||
type Session struct {
|
||||
*Provider
|
||||
AuthURL string
|
||||
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, tokens *oidc.Tokens[*oidc.IDTokenClaims]) *Session {
|
||||
return &Session{Provider: provider, Tokens: tokens}
|
||||
}
|
||||
|
||||
func NewSessionFromRequest(provider *Provider, r *http.Request) *Session {
|
||||
token := strings.TrimPrefix(r.Header.Get(provider.headerName), oidc.PrefixBearer)
|
||||
return NewSession(provider, &oidc.Tokens[*oidc.IDTokenClaims]{IDToken: token, Token: &oauth2.Token{}})
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return idp.Redirect(s.AuthURL)
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will map the received idToken into an [idp.User].
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
if s.Tokens == nil {
|
||||
return nil, ErrNoTokens
|
||||
}
|
||||
s.Tokens.IDTokenClaims, err = s.validateToken(ctx, s.Tokens.IDToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &User{s.Tokens.IDTokenClaims}, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.Tokens == nil || s.Tokens.IDTokenClaims == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.Tokens.IDTokenClaims.GetExpiration()
|
||||
}
|
||||
|
||||
func (s *Session) validateToken(ctx context.Context, token string) (*oidc.IDTokenClaims, error) {
|
||||
logging.Debug("begin token validation")
|
||||
// TODO: be able to specify them in the template: https://github.com/zitadel/zitadel/issues/5322
|
||||
offset := 3 * time.Second
|
||||
maxAge := time.Hour
|
||||
claims := new(oidc.IDTokenClaims)
|
||||
payload, err := oidc.ParseToken(token, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
if err = oidc.CheckIssuer(claims, s.Provider.issuer); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid issuer: %v", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
logging.Debug("begin signature validation")
|
||||
keySet := rp.NewRemoteKeySet(http.DefaultClient, s.Provider.keysEndpoint)
|
||||
if err = oidc.CheckSignature(ctx, token, payload, claims, nil, keySet); err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid signature: %v", ErrInvalidToken, err)
|
||||
}
|
||||
|
||||
if !claims.GetExpiration().IsZero() {
|
||||
if err = oidc.CheckExpiration(claims, offset); err != nil {
|
||||
return nil, fmt.Errorf("%w: expired: %v", ErrInvalidToken, err)
|
||||
}
|
||||
}
|
||||
|
||||
if !claims.GetIssuedAt().IsZero() {
|
||||
if err = oidc.CheckIssuedAt(claims, maxAge, offset); err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidToken, err)
|
||||
}
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func InitUser() *User {
|
||||
return &User{
|
||||
IDTokenClaims: &oidc.IDTokenClaims{},
|
||||
}
|
||||
}
|
||||
|
||||
type User struct {
|
||||
*oidc.IDTokenClaims
|
||||
}
|
||||
|
||||
func (u *User) GetID() string {
|
||||
return u.Subject
|
||||
}
|
||||
|
||||
func (u *User) GetFirstName() string {
|
||||
return u.GivenName
|
||||
}
|
||||
|
||||
func (u *User) GetLastName() string {
|
||||
return u.FamilyName
|
||||
}
|
||||
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u *User) GetNickname() string {
|
||||
return u.Nickname
|
||||
}
|
||||
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.PreferredUsername
|
||||
}
|
||||
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
return domain.EmailAddress(u.Email)
|
||||
}
|
||||
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return bool(u.EmailVerified)
|
||||
}
|
||||
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return domain.PhoneNumber(u.IDTokenClaims.PhoneNumber)
|
||||
}
|
||||
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return u.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return u.Locale.Tag()
|
||||
}
|
||||
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return u.Picture
|
||||
}
|
||||
|
||||
func (u *User) GetProfile() string {
|
||||
return u.Profile
|
||||
}
|
312
apps/api/internal/idp/providers/jwt/session_test.go
Normal file
312
apps/api/internal/idp/providers/jwt/session_test.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
jwtEndpoint string
|
||||
keysEndpoint string
|
||||
headerName string
|
||||
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
|
||||
httpMock func(issuer string)
|
||||
authURL string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no tokens",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get("/keys").
|
||||
Reply(200).
|
||||
JSON(keys(t))
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, ErrNoTokens)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid token",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get("/keys").
|
||||
Reply(200).
|
||||
JSON(keys(t))
|
||||
},
|
||||
authURL: "https://auth.com/jwt?authRequestID=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{},
|
||||
IDToken: "invalidToken",
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, ErrInvalidToken)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
issuer: "https://jwt.com",
|
||||
jwtEndpoint: "https://auth.com/jwt",
|
||||
keysEndpoint: "https://jwt.com/keys",
|
||||
headerName: "jwt-header",
|
||||
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
|
||||
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
|
||||
},
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get("/keys").
|
||||
Reply(200).
|
||||
JSON(keys(t))
|
||||
},
|
||||
authURL: "https://auth.com/jwt?authRequestID=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{},
|
||||
IDToken: idToken(t, "https://jwt.com"),
|
||||
IDTokenClaims: &oidc.IDTokenClaims{
|
||||
TokenClaims: oidc.TokenClaims{
|
||||
Subject: "sub",
|
||||
},
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
Picture: "picture",
|
||||
Name: "firstname lastname",
|
||||
GivenName: "firstname",
|
||||
FamilyName: "lastname",
|
||||
Nickname: "nickname",
|
||||
PreferredUsername: "username",
|
||||
Profile: "profile",
|
||||
Locale: oidc.NewLocale(language.English),
|
||||
},
|
||||
UserInfoEmail: oidc.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: oidc.Bool(true),
|
||||
},
|
||||
UserInfoPhone: oidc.UserInfoPhone{
|
||||
PhoneNumber: "phone",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "nickname",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "phone",
|
||||
isPhoneVerified: true,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock(tt.fields.issuer)
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(
|
||||
tt.fields.name,
|
||||
tt.fields.issuer,
|
||||
tt.fields.jwtEndpoint,
|
||||
tt.fields.keysEndpoint,
|
||||
tt.fields.headerName,
|
||||
tt.fields.encryptionAlg(t),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
AuthURL: tt.fields.authURL,
|
||||
Tokens: tt.fields.tokens,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func idToken(t *testing.T, issuer string) string {
|
||||
claims := oidc.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Minute),
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
)
|
||||
claims.UserInfoProfile = oidc.UserInfoProfile{
|
||||
GivenName: "firstname",
|
||||
FamilyName: "lastname",
|
||||
Name: "firstname lastname",
|
||||
Nickname: "nickname",
|
||||
PreferredUsername: "username",
|
||||
Locale: oidc.NewLocale(language.English),
|
||||
Picture: "picture",
|
||||
Profile: "profile",
|
||||
}
|
||||
claims.UserInfoEmail = oidc.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: oidc.Bool(true),
|
||||
}
|
||||
claims.UserInfoPhone = oidc.UserInfoPhone{
|
||||
PhoneNumber: "phone",
|
||||
PhoneNumberVerified: true,
|
||||
}
|
||||
|
||||
privateKey, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAs38btwb3c7r0tMaQpGvBmY+mPwMU/LpfuPoC0k2t4RsKp0fv
|
||||
40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuyrFALIj3Ff1UcKIk0hOH5DDsfh7/q
|
||||
2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/MfSydZdcmIqlkUpfQmtzExw9+tSe5
|
||||
Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNuMZbmlCoBru+rC8ITlTX/0V1ZcsSb
|
||||
L8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a+kjL/KGZbR14Ua2eo6tonBZLC5DH
|
||||
WM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly6QIDAQABAoIBAQCPj1nbSPcg2KZe
|
||||
73FAD+8HopyUSSK//1AP4eXfzcEECVy77g0u9+R6XlkzsZCsZ4g6NN8ounqfyw3c
|
||||
YlpAIkcFCf/dowoSjT+4LASVQyatYZwWNqjgAIU4KgMG/rKnNahPTiBYe7peMB1j
|
||||
EaPjnt8uPkCk8y7NCi3y4Pk24tt/WM5KbJK2NQhUi1csGnleDfE+0blV0l/e6C68
|
||||
W5cbnbWAroMqae/Yon3XVZiXX0m+l2f6ZzIgKaD18J+eEM8FjJC+jQKiRe1i9v3K
|
||||
nQrLwh/gn8J10FcbKn3xqslKVidzASIrNIzHT9j/Z5T9NXuAKa7IV2x+Dtdus+wq
|
||||
iBsUunwBAoGBANpYew+8i9vDwK4/SefduDTuzJ0H9lWTjtbiWQ+KYZoeJ7q3/qns
|
||||
jsmi+mjxkXxXg1RrGbNbjtbl3RXXIrUeeBB0lglRJUjc3VK7VvNoyXIWsiqhCspH
|
||||
IJ9Yuknv4mXB01m/glbSCS/xu4RTgf5aOG4jUiRb9+dCIpvDxI9gbXEVAoGBANJz
|
||||
hIJkplIJ+biTi3G1Oz17qkUkInNXzAEzKD9Atoz5AIAiR1ivOMLOlbucfjevw/Nw
|
||||
TnpkMs9xqCefKupTlsriXtZI88m7ZKzAmolYsPolOy/Jhi31h9JFVTEfKGqVS+dk
|
||||
A4ndhgdW9RUeNJPY2YVCARXQrWpueweQDA1cNaeFAoGAPJsYtXqBW6PPRM5+ZiSt
|
||||
78tk8iV2o7RMjqrPS7f+dXfvUS2nO2VVEPTzCtQarOfhpToBLT65vD6bimdn09w8
|
||||
OV0TFEz4y2u65y7m6LNqTwertpdy1ki97l0DgGhccCBH2P6GYDD2qd8wTH+dcot6
|
||||
ZF/begopGoDJ+HBzi9SZLC0CgYBZzPslHMevyBvr++GLwrallKhiWnns1/DwLiEl
|
||||
ZHrBCtuA0Z+6IwLIdZiE9tEQ+ApYTXrfVPQteqUzSwLn/IUiy5eGPpjwYushoAoR
|
||||
Q2w5QTvRN1/vKo8rVXR1woLfgBdkhFPSN1mitiNcQIhU8jpXV4PZCDOHb99FqdzK
|
||||
sqcedQKBgQCOmgbqxGsnT2WQhoOdzln+NOo6Tx+FveLLqat2KzpY59W4noeI2Awn
|
||||
HfIQgWUAW9dsjVVOXMP1jhq8U9hmH/PFWA11V/iCdk1NTxZEw87VAOeWuajpdDHG
|
||||
+iex349j8h2BcQ4Zd0FWu07gGFnS/yuDJPn6jBhRusdieEcxLRjTKg==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Key: privateKey, Algorithm: "RS256"}, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jws, err := signer.Sign(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
idToken, err := jws.CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return idToken
|
||||
}
|
||||
|
||||
func keys(t *testing.T) *jose.JSONWebKeySet {
|
||||
privateKey, err := crypto.BytesToPublicKey([]byte(`-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
|
||||
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
|
||||
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
|
||||
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
|
||||
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
|
||||
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
|
||||
6QIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{Key: privateKey, Algorithm: "RS256", Use: oidc.KeyUseSignature}}}
|
||||
}
|
290
apps/api/internal/idp/providers/ldap/ldap.go
Normal file
290
apps/api/internal/idp/providers/ldap/ldap.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
const DefaultPort = "389"
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a generic LDAP provider
|
||||
type Provider struct {
|
||||
name string
|
||||
servers []string
|
||||
startTLS bool
|
||||
baseDN string
|
||||
bindDN string
|
||||
bindPassword string
|
||||
userBase string
|
||||
userObjectClasses []string
|
||||
userFilters []string
|
||||
timeout time.Duration
|
||||
rootCA []byte
|
||||
|
||||
loginUrl string
|
||||
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
|
||||
idAttribute string
|
||||
firstNameAttribute string
|
||||
lastNameAttribute string
|
||||
displayNameAttribute string
|
||||
nickNameAttribute string
|
||||
preferredUsernameAttribute string
|
||||
emailAttribute string
|
||||
emailVerifiedAttribute string
|
||||
phoneAttribute string
|
||||
phoneVerifiedAttribute string
|
||||
preferredLanguageAttribute string
|
||||
avatarURLAttribute string
|
||||
profileAttribute string
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
|
||||
// WithLinkingAllowed allows end users to link the federated user to an existing one.
|
||||
func WithLinkingAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isLinkingAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreationAllowed allows end users to create a new user using the federated information.
|
||||
func WithCreationAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isCreationAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCreation enables that federated users are automatically created if not already existing.
|
||||
func WithAutoCreation() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoCreation = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
|
||||
// the existing user on each authentication.
|
||||
func WithAutoUpdate() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutStartTLS configures to communication insecure with the LDAP server without startTLS
|
||||
func WithoutStartTLS() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.startTLS = false
|
||||
}
|
||||
}
|
||||
|
||||
// WithCustomIDAttribute configures to map the LDAP attribute to the user, default is the uniqueUserAttribute
|
||||
func WithCustomIDAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.idAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithFirstNameAttribute configures to map the LDAP attribute to the user
|
||||
func WithFirstNameAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.firstNameAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithLastNameAttribute configures to map the LDAP attribute to the user
|
||||
func WithLastNameAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.lastNameAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisplayNameAttribute configures to map the LDAP attribute to the user
|
||||
func WithDisplayNameAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.displayNameAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithNickNameAttribute configures to map the LDAP attribute to the user
|
||||
func WithNickNameAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.nickNameAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithPreferredUsernameAttribute configures to map the LDAP attribute to the user
|
||||
func WithPreferredUsernameAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.preferredUsernameAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithEmailAttribute configures to map the LDAP attribute to the user
|
||||
func WithEmailAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.emailAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithEmailVerifiedAttribute configures to map the LDAP attribute to the user
|
||||
func WithEmailVerifiedAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.emailVerifiedAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithPhoneAttribute configures to map the LDAP attribute to the user
|
||||
func WithPhoneAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.phoneAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithPhoneVerifiedAttribute configures to map the LDAP attribute to the user
|
||||
func WithPhoneVerifiedAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.phoneVerifiedAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithPreferredLanguageAttribute configures to map the LDAP attribute to the user
|
||||
func WithPreferredLanguageAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.preferredLanguageAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithAvatarURLAttribute configures to map the LDAP attribute to the user
|
||||
func WithAvatarURLAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.avatarURLAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
// WithProfileAttribute configures to map the LDAP attribute to the user
|
||||
func WithProfileAttribute(name string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.profileAttribute = name
|
||||
}
|
||||
}
|
||||
|
||||
func New(
|
||||
name string,
|
||||
servers []string,
|
||||
baseDN string,
|
||||
bindDN string,
|
||||
bindPassword string,
|
||||
userBase string,
|
||||
userObjectClasses []string,
|
||||
userFilters []string,
|
||||
timeout time.Duration,
|
||||
rootCA []byte,
|
||||
loginUrl string,
|
||||
options ...ProviderOpts,
|
||||
) *Provider {
|
||||
provider := &Provider{
|
||||
name: name,
|
||||
servers: servers,
|
||||
startTLS: true,
|
||||
baseDN: baseDN,
|
||||
bindDN: bindDN,
|
||||
bindPassword: bindPassword,
|
||||
userBase: userBase,
|
||||
userObjectClasses: userObjectClasses,
|
||||
userFilters: userFilters,
|
||||
timeout: timeout,
|
||||
rootCA: rootCA,
|
||||
loginUrl: loginUrl,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||
return &Session{
|
||||
Provider: p,
|
||||
loginUrl: p.loginUrl + state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Provider) GetSession(username, password string) *Session {
|
||||
return &Session{
|
||||
Provider: p,
|
||||
User: username,
|
||||
Password: password,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
}
|
||||
|
||||
func (p *Provider) IsCreationAllowed() bool {
|
||||
return p.isCreationAllowed
|
||||
}
|
||||
|
||||
func (p *Provider) IsAutoCreation() bool {
|
||||
return p.isAutoCreation
|
||||
}
|
||||
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
||||
|
||||
func (p *Provider) getNecessaryAttributes() []string {
|
||||
attributes := []string{p.userBase}
|
||||
if p.idAttribute != "" {
|
||||
attributes = append(attributes, p.idAttribute)
|
||||
}
|
||||
if p.firstNameAttribute != "" {
|
||||
attributes = append(attributes, p.firstNameAttribute)
|
||||
}
|
||||
if p.lastNameAttribute != "" {
|
||||
attributes = append(attributes, p.lastNameAttribute)
|
||||
}
|
||||
if p.displayNameAttribute != "" {
|
||||
attributes = append(attributes, p.displayNameAttribute)
|
||||
}
|
||||
if p.nickNameAttribute != "" {
|
||||
attributes = append(attributes, p.nickNameAttribute)
|
||||
}
|
||||
if p.preferredUsernameAttribute != "" {
|
||||
attributes = append(attributes, p.preferredUsernameAttribute)
|
||||
}
|
||||
if p.emailAttribute != "" {
|
||||
attributes = append(attributes, p.emailAttribute)
|
||||
}
|
||||
if p.emailVerifiedAttribute != "" {
|
||||
attributes = append(attributes, p.emailVerifiedAttribute)
|
||||
}
|
||||
if p.phoneAttribute != "" {
|
||||
attributes = append(attributes, p.phoneAttribute)
|
||||
}
|
||||
if p.phoneVerifiedAttribute != "" {
|
||||
attributes = append(attributes, p.phoneVerifiedAttribute)
|
||||
}
|
||||
if p.preferredLanguageAttribute != "" {
|
||||
attributes = append(attributes, p.preferredLanguageAttribute)
|
||||
}
|
||||
if p.avatarURLAttribute != "" {
|
||||
attributes = append(attributes, p.avatarURLAttribute)
|
||||
}
|
||||
if p.profileAttribute != "" {
|
||||
attributes = append(attributes, p.profileAttribute)
|
||||
}
|
||||
return attributes
|
||||
}
|
207
apps/api/internal/idp/providers/ldap/ldap_test.go
Normal file
207
apps/api/internal/idp/providers/ldap/ldap_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
servers []string
|
||||
baseDN string
|
||||
bindDN string
|
||||
bindPassword string
|
||||
userBase string
|
||||
userObjectClasses []string
|
||||
userFilters []string
|
||||
timeout time.Duration
|
||||
rootCA []byte
|
||||
loginUrl string
|
||||
opts []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
rootCA []byte
|
||||
startTls bool
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
idAttribute string
|
||||
firstNameAttribute string
|
||||
lastNameAttribute string
|
||||
displayNameAttribute string
|
||||
nickNameAttribute string
|
||||
preferredUsernameAttribute string
|
||||
emailAttribute string
|
||||
emailVerifiedAttribute string
|
||||
phoneAttribute string
|
||||
phoneVerifiedAttribute string
|
||||
preferredLanguageAttribute string
|
||||
avatarURLAttribute string
|
||||
profileAttribute string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
loginUrl: "url",
|
||||
opts: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
startTls: true,
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
idAttribute: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
loginUrl: "url",
|
||||
opts: []ProviderOpts{
|
||||
WithoutStartTLS(),
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
startTls: false,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
idAttribute: "",
|
||||
},
|
||||
}, {
|
||||
name: "all true, attributes set",
|
||||
fields: fields{
|
||||
name: "ldap",
|
||||
servers: []string{"server"},
|
||||
baseDN: "base",
|
||||
bindDN: "binddn",
|
||||
bindPassword: "password",
|
||||
userBase: "user",
|
||||
userObjectClasses: []string{"object"},
|
||||
userFilters: []string{"filter"},
|
||||
timeout: 30 * time.Second,
|
||||
rootCA: []byte("certificate"),
|
||||
loginUrl: "url",
|
||||
opts: []ProviderOpts{
|
||||
WithoutStartTLS(),
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
WithCustomIDAttribute("id"),
|
||||
WithFirstNameAttribute("first"),
|
||||
WithLastNameAttribute("last"),
|
||||
WithDisplayNameAttribute("display"),
|
||||
WithNickNameAttribute("nick"),
|
||||
WithPreferredUsernameAttribute("prefUser"),
|
||||
WithEmailAttribute("email"),
|
||||
WithEmailVerifiedAttribute("emailVerified"),
|
||||
WithPhoneAttribute("phone"),
|
||||
WithPhoneVerifiedAttribute("phoneVerified"),
|
||||
WithPreferredLanguageAttribute("prefLang"),
|
||||
WithAvatarURLAttribute("avatar"),
|
||||
WithProfileAttribute("profile"),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "ldap",
|
||||
rootCA: []byte("certificate"),
|
||||
startTls: false,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "prefUser",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "prefLang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
provider := New(
|
||||
tt.fields.name,
|
||||
tt.fields.servers,
|
||||
tt.fields.baseDN,
|
||||
tt.fields.bindDN,
|
||||
tt.fields.bindPassword,
|
||||
tt.fields.userBase,
|
||||
tt.fields.userObjectClasses,
|
||||
tt.fields.userFilters,
|
||||
tt.fields.timeout,
|
||||
tt.fields.rootCA,
|
||||
tt.fields.loginUrl,
|
||||
tt.fields.opts...,
|
||||
)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.rootCA, provider.rootCA)
|
||||
a.Equal(tt.want.startTls, provider.startTLS)
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
|
||||
a.Equal(tt.want.idAttribute, provider.idAttribute)
|
||||
a.Equal(tt.want.firstNameAttribute, provider.firstNameAttribute)
|
||||
a.Equal(tt.want.lastNameAttribute, provider.lastNameAttribute)
|
||||
a.Equal(tt.want.displayNameAttribute, provider.displayNameAttribute)
|
||||
a.Equal(tt.want.nickNameAttribute, provider.nickNameAttribute)
|
||||
a.Equal(tt.want.preferredUsernameAttribute, provider.preferredUsernameAttribute)
|
||||
a.Equal(tt.want.emailAttribute, provider.emailAttribute)
|
||||
a.Equal(tt.want.emailVerifiedAttribute, provider.emailVerifiedAttribute)
|
||||
a.Equal(tt.want.phoneAttribute, provider.phoneAttribute)
|
||||
a.Equal(tt.want.phoneVerifiedAttribute, provider.phoneVerifiedAttribute)
|
||||
a.Equal(tt.want.preferredLanguageAttribute, provider.preferredLanguageAttribute)
|
||||
a.Equal(tt.want.avatarURLAttribute, provider.avatarURLAttribute)
|
||||
a.Equal(tt.want.profileAttribute, provider.profileAttribute)
|
||||
})
|
||||
}
|
||||
}
|
333
apps/api/internal/idp/providers/ldap/session.go
Normal file
333
apps/api/internal/idp/providers/ldap/session.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var ErrNoSingleUser = errors.New("user does not exist or too many entries returned")
|
||||
var ErrFailedLogin = errors.New("user failed to login")
|
||||
var ErrUnableToAppendRootCA = errors.New("unable to append rootCA")
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
type Session struct {
|
||||
Provider *Provider
|
||||
loginUrl string
|
||||
User string
|
||||
Password string
|
||||
Entry *ldap.Entry
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, username, password string) *Session {
|
||||
return &Session{Provider: provider, User: username, Password: password}
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return idp.Redirect(s.loginUrl)
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
|
||||
var user *ldap.Entry
|
||||
for _, server := range s.Provider.servers {
|
||||
user, err = tryBind(server,
|
||||
s.Provider.startTLS,
|
||||
s.Provider.bindDN,
|
||||
s.Provider.bindPassword,
|
||||
s.Provider.baseDN,
|
||||
s.Provider.getNecessaryAttributes(),
|
||||
s.Provider.userObjectClasses,
|
||||
s.Provider.userFilters,
|
||||
s.User,
|
||||
s.Password,
|
||||
s.Provider.timeout,
|
||||
s.Provider.rootCA)
|
||||
// If there were invalid credentials or multiple users with the credentials cancel process
|
||||
if err != nil && (errors.Is(err, ErrFailedLogin) || errors.Is(err, ErrNoSingleUser)) {
|
||||
return nil, err
|
||||
}
|
||||
// If a user bind was successful and user is filled continue with login, otherwise try next server
|
||||
if err == nil && user != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.Entry = user
|
||||
|
||||
return mapLDAPEntryToUser(
|
||||
user,
|
||||
s.Provider.idAttribute,
|
||||
s.Provider.firstNameAttribute,
|
||||
s.Provider.lastNameAttribute,
|
||||
s.Provider.displayNameAttribute,
|
||||
s.Provider.nickNameAttribute,
|
||||
s.Provider.preferredUsernameAttribute,
|
||||
s.Provider.emailAttribute,
|
||||
s.Provider.emailVerifiedAttribute,
|
||||
s.Provider.phoneAttribute,
|
||||
s.Provider.phoneVerifiedAttribute,
|
||||
s.Provider.preferredLanguageAttribute,
|
||||
s.Provider.avatarURLAttribute,
|
||||
s.Provider.profileAttribute,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
return time.Time{} // falls back to the default expiration time
|
||||
}
|
||||
|
||||
func tryBind(
|
||||
server string,
|
||||
startTLS bool,
|
||||
bindDN string,
|
||||
bindPassword string,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
rootCA []byte,
|
||||
) (*ldap.Entry, error) {
|
||||
conn, err := getConnection(server, startTLS, timeout, rootCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.Bind(bindDN, bindPassword); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return trySearchAndUserBind(
|
||||
conn,
|
||||
baseDN,
|
||||
attributes,
|
||||
objectClasses,
|
||||
userFilters,
|
||||
username,
|
||||
password,
|
||||
timeout,
|
||||
)
|
||||
}
|
||||
|
||||
func getConnection(
|
||||
server string,
|
||||
startTLS bool,
|
||||
timeout time.Duration,
|
||||
rootCA []byte,
|
||||
) (*ldap.Conn, error) {
|
||||
if timeout == 0 {
|
||||
timeout = ldap.DefaultTimeout
|
||||
}
|
||||
|
||||
dialer := make([]ldap.DialOpt, 1, 2)
|
||||
dialer[0] = ldap.DialWithDialer(&net.Dialer{Timeout: timeout})
|
||||
|
||||
u, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.Scheme == "ldaps" && len(rootCA) > 0 {
|
||||
rootCAs := x509.NewCertPool()
|
||||
if ok := rootCAs.AppendCertsFromPEM(rootCA); !ok {
|
||||
return nil, ErrUnableToAppendRootCA
|
||||
}
|
||||
|
||||
dialer = append(dialer, ldap.DialWithTLSConfig(&tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
}))
|
||||
}
|
||||
|
||||
conn, err := ldap.DialURL(server, dialer...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.Scheme == "ldap" && startTLS {
|
||||
err = conn.StartTLS(&tls.Config{ServerName: u.Host})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func trySearchAndUserBind(
|
||||
conn *ldap.Conn,
|
||||
baseDN string,
|
||||
attributes []string,
|
||||
objectClasses []string,
|
||||
userFilters []string,
|
||||
username string,
|
||||
password string,
|
||||
timeout time.Duration,
|
||||
) (*ldap.Entry, error) {
|
||||
searchQuery := queriesAndToSearchQuery(
|
||||
objectClassesToSearchQuery(objectClasses),
|
||||
queriesOrToSearchQuery(
|
||||
userFiltersToSearchQuery(userFilters, username)...,
|
||||
),
|
||||
)
|
||||
|
||||
// Search for user with the unique attribute for the userDN
|
||||
searchRequest := ldap.NewSearchRequest(
|
||||
baseDN,
|
||||
ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, int(timeout.Seconds()), false,
|
||||
searchQuery,
|
||||
attributes,
|
||||
nil,
|
||||
)
|
||||
|
||||
sr, err := conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(sr.Entries) != 1 {
|
||||
logging.WithFields("entries", len(sr.Entries)).Info("ldap: no single user found")
|
||||
return nil, ErrNoSingleUser
|
||||
}
|
||||
|
||||
user := sr.Entries[0]
|
||||
// Bind as the user to verify their password
|
||||
userDN, err := ldap.ParseDN(user.DN)
|
||||
if err != nil {
|
||||
logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user parse DN failed")
|
||||
return nil, err
|
||||
}
|
||||
if err = conn.Bind(userDN.String(), password); err != nil {
|
||||
logging.WithFields("userDN", user.DN).WithError(err).Info("ldap user bind failed")
|
||||
return nil, ErrFailedLogin
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func queriesAndToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(&"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func queriesOrToSearchQuery(queries ...string) string {
|
||||
if len(queries) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(queries) == 1 {
|
||||
return queries[0]
|
||||
}
|
||||
joinQueries := "(|"
|
||||
for _, s := range queries {
|
||||
joinQueries += s
|
||||
}
|
||||
return joinQueries + ")"
|
||||
}
|
||||
|
||||
func objectClassesToSearchQuery(classes []string) string {
|
||||
searchQuery := ""
|
||||
for _, class := range classes {
|
||||
searchQuery += "(objectClass=" + class + ")"
|
||||
}
|
||||
return searchQuery
|
||||
}
|
||||
|
||||
func userFiltersToSearchQuery(filters []string, username string) []string {
|
||||
searchQueries := make([]string, len(filters))
|
||||
for i, filter := range filters {
|
||||
searchQueries[i] = "(" + filter + "=" + username + ")"
|
||||
}
|
||||
return searchQueries
|
||||
}
|
||||
|
||||
func mapLDAPEntryToUser(
|
||||
user *ldap.Entry,
|
||||
idAttribute,
|
||||
firstNameAttribute,
|
||||
lastNameAttribute,
|
||||
displayNameAttribute,
|
||||
nickNameAttribute,
|
||||
preferredUsernameAttribute,
|
||||
emailAttribute,
|
||||
emailVerifiedAttribute,
|
||||
phoneAttribute,
|
||||
phoneVerifiedAttribute,
|
||||
preferredLanguageAttribute,
|
||||
avatarURLAttribute,
|
||||
profileAttribute string,
|
||||
) (_ *User, err error) {
|
||||
var emailVerified bool
|
||||
if v := user.GetAttributeValue(emailVerifiedAttribute); v != "" {
|
||||
emailVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var phoneVerified bool
|
||||
if v := user.GetAttributeValue(phoneVerifiedAttribute); v != "" {
|
||||
phoneVerified, err = strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return NewUser(
|
||||
getAttributeValue(user, idAttribute),
|
||||
getAttributeValue(user, firstNameAttribute),
|
||||
getAttributeValue(user, lastNameAttribute),
|
||||
getAttributeValue(user, displayNameAttribute),
|
||||
getAttributeValue(user, nickNameAttribute),
|
||||
getAttributeValue(user, preferredUsernameAttribute),
|
||||
domain.EmailAddress(user.GetAttributeValue(emailAttribute)),
|
||||
emailVerified,
|
||||
domain.PhoneNumber(user.GetAttributeValue(phoneAttribute)),
|
||||
phoneVerified,
|
||||
language.Make(user.GetAttributeValue(preferredLanguageAttribute)),
|
||||
user.GetAttributeValue(avatarURLAttribute),
|
||||
user.GetAttributeValue(profileAttribute),
|
||||
), nil
|
||||
}
|
||||
|
||||
func getAttributeValue(user *ldap.Entry, attribute string) string {
|
||||
// return an empty string if no attribute is needed
|
||||
if attribute == "" {
|
||||
return ""
|
||||
}
|
||||
value := user.GetAttributeValue(attribute)
|
||||
if utf8.ValidString(value) {
|
||||
return value
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(user.GetRawAttributeValue(attribute))
|
||||
}
|
400
apps/api/internal/idp/providers/ldap/session_test.go
Normal file
400
apps/api/internal/idp/providers/ldap/session_test.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
func TestProvider_objectClassesToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"test"},
|
||||
want: "(objectClass=test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"test1", "test2", "test3"},
|
||||
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"test1", "test2", "test3", "test4", "test5"},
|
||||
want: "(objectClass=test1)(objectClass=test2)(objectClass=test3)(objectClass=test4)(objectClass=test5)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, objectClassesToSearchQuery(tt.fields))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_userFiltersToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
username string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
username: "user",
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"test"},
|
||||
username: "user",
|
||||
want: []string{"(test=user)"},
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"test1", "test2", "test3"},
|
||||
username: "user",
|
||||
want: []string{"(test1=user)", "(test2=user)", "(test3=user)"},
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"test1", "test2", "test3", "test4", "test5"},
|
||||
username: "user",
|
||||
want: []string{"(test1=user)", "(test2=user)", "(test3=user)", "(test4=user)", "(test5=user)"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, userFiltersToSearchQuery(tt.fields, tt.username))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_queriesAndToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"(test)"},
|
||||
want: "(test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)"},
|
||||
want: "(&(test1)(test2)(test3))",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
|
||||
want: "(&(test1)(test2)(test3)(test4)(test5))",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, queriesAndToSearchQuery(tt.fields...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_queriesOrToSearchQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fields []string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "zero",
|
||||
fields: []string{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "one",
|
||||
fields: []string{"(test)"},
|
||||
want: "(test)",
|
||||
},
|
||||
{
|
||||
name: "three",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)"},
|
||||
want: "(|(test1)(test2)(test3))",
|
||||
},
|
||||
{
|
||||
name: "five",
|
||||
fields: []string{"(test1)", "(test2)", "(test3)", "(test4)", "(test5)"},
|
||||
want: "(|(test1)(test2)(test3)(test4)(test5))",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
a.Equal(tt.want, queriesOrToSearchQuery(tt.fields...))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_mapLDAPEntryToUser(t *testing.T) {
|
||||
type fields struct {
|
||||
user *ldap.Entry
|
||||
idAttribute string
|
||||
firstNameAttribute string
|
||||
lastNameAttribute string
|
||||
displayNameAttribute string
|
||||
nickNameAttribute string
|
||||
preferredUsernameAttribute string
|
||||
emailAttribute string
|
||||
emailVerifiedAttribute string
|
||||
phoneAttribute string
|
||||
phoneVerifiedAttribute string
|
||||
preferredLanguageAttribute string
|
||||
avatarURLAttribute string
|
||||
profileAttribute string
|
||||
}
|
||||
type want struct {
|
||||
user *User
|
||||
err func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "",
|
||||
firstNameAttribute: "",
|
||||
lastNameAttribute: "",
|
||||
displayNameAttribute: "",
|
||||
nickNameAttribute: "",
|
||||
preferredUsernameAttribute: "",
|
||||
emailAttribute: "",
|
||||
emailVerifiedAttribute: "",
|
||||
phoneAttribute: "",
|
||||
phoneVerifiedAttribute: "",
|
||||
preferredLanguageAttribute: "",
|
||||
avatarURLAttribute: "",
|
||||
profileAttribute: "",
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "",
|
||||
FirstName: "",
|
||||
LastName: "",
|
||||
DisplayName: "",
|
||||
NickName: "",
|
||||
PreferredUsername: "",
|
||||
Email: "",
|
||||
EmailVerified: false,
|
||||
Phone: "",
|
||||
PhoneVerified: false,
|
||||
PreferredLanguage: language.Tag{},
|
||||
AvatarURL: "",
|
||||
Profile: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed parse emailVerified",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"failure"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err != nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed parse phoneVerified",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"failure"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err != nil
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "full user",
|
||||
fields: fields{
|
||||
user: &ldap.Entry{
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
{Name: "id", Values: []string{"id"}},
|
||||
{Name: "first", Values: []string{"first"}},
|
||||
{Name: "last", Values: []string{"last"}},
|
||||
{Name: "display", Values: []string{"display"}},
|
||||
{Name: "nick", Values: []string{"nick"}},
|
||||
{Name: "preferred", Values: []string{"preferred"}},
|
||||
{Name: "email", Values: []string{"email"}},
|
||||
{Name: "emailVerified", Values: []string{"false"}},
|
||||
{Name: "phone", Values: []string{"phone"}},
|
||||
{Name: "phoneVerified", Values: []string{"false"}},
|
||||
{Name: "lang", Values: []string{"und"}},
|
||||
{Name: "avatar", Values: []string{"avatar"}},
|
||||
{Name: "profile", Values: []string{"profile"}},
|
||||
},
|
||||
},
|
||||
idAttribute: "id",
|
||||
firstNameAttribute: "first",
|
||||
lastNameAttribute: "last",
|
||||
displayNameAttribute: "display",
|
||||
nickNameAttribute: "nick",
|
||||
preferredUsernameAttribute: "preferred",
|
||||
emailAttribute: "email",
|
||||
emailVerifiedAttribute: "emailVerified",
|
||||
phoneAttribute: "phone",
|
||||
phoneVerifiedAttribute: "phoneVerified",
|
||||
preferredLanguageAttribute: "lang",
|
||||
avatarURLAttribute: "avatar",
|
||||
profileAttribute: "profile",
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "id",
|
||||
FirstName: "first",
|
||||
LastName: "last",
|
||||
DisplayName: "display",
|
||||
NickName: "nick",
|
||||
PreferredUsername: "preferred",
|
||||
Email: "email",
|
||||
EmailVerified: false,
|
||||
Phone: "phone",
|
||||
PhoneVerified: false,
|
||||
PreferredLanguage: language.Make("und"),
|
||||
AvatarURL: "avatar",
|
||||
Profile: "profile",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := mapLDAPEntryToUser(
|
||||
tt.fields.user,
|
||||
tt.fields.idAttribute,
|
||||
tt.fields.firstNameAttribute,
|
||||
tt.fields.lastNameAttribute,
|
||||
tt.fields.displayNameAttribute,
|
||||
tt.fields.nickNameAttribute,
|
||||
tt.fields.preferredUsernameAttribute,
|
||||
tt.fields.emailAttribute,
|
||||
tt.fields.emailVerifiedAttribute,
|
||||
tt.fields.phoneAttribute,
|
||||
tt.fields.phoneVerifiedAttribute,
|
||||
tt.fields.preferredLanguageAttribute,
|
||||
tt.fields.avatarURLAttribute,
|
||||
tt.fields.profileAttribute,
|
||||
)
|
||||
if tt.want.err == nil {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
t.Errorf("got wrong err: %v ", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
assert.Equal(t, tt.want.user, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
95
apps/api/internal/idp/providers/ldap/user.go
Normal file
95
apps/api/internal/idp/providers/ldap/user.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package ldap
|
||||
|
||||
import (
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
FirstName string `json:"firstName,omitempty"`
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
DisplayName string `json:"displayName,omitempty"`
|
||||
NickName string `json:"nickName,omitempty"`
|
||||
PreferredUsername string `json:"preferredUsername,omitempty"`
|
||||
Email domain.EmailAddress `json:"email,omitempty"`
|
||||
EmailVerified bool `json:"emailVerified,omitempty"`
|
||||
Phone domain.PhoneNumber `json:"phone,omitempty"`
|
||||
PhoneVerified bool `json:"phoneVerified,omitempty"`
|
||||
PreferredLanguage language.Tag `json:"preferredLanguage,omitempty"`
|
||||
AvatarURL string `json:"avatarURL,omitempty"`
|
||||
Profile string `json:"profile,omitempty"`
|
||||
}
|
||||
|
||||
func NewUser(
|
||||
id string,
|
||||
firstName string,
|
||||
lastName string,
|
||||
displayName string,
|
||||
nickName string,
|
||||
preferredUsername string,
|
||||
email domain.EmailAddress,
|
||||
emailVerified bool,
|
||||
phone domain.PhoneNumber,
|
||||
phoneVerified bool,
|
||||
preferredLanguage language.Tag,
|
||||
avatarURL string,
|
||||
profile string,
|
||||
) *User {
|
||||
return &User{
|
||||
id,
|
||||
firstName,
|
||||
lastName,
|
||||
displayName,
|
||||
nickName,
|
||||
preferredUsername,
|
||||
email,
|
||||
emailVerified,
|
||||
phone,
|
||||
phoneVerified,
|
||||
preferredLanguage,
|
||||
avatarURL,
|
||||
profile,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) GetID() string {
|
||||
return u.ID
|
||||
}
|
||||
func (u *User) GetFirstName() string {
|
||||
return u.FirstName
|
||||
}
|
||||
func (u *User) GetLastName() string {
|
||||
return u.LastName
|
||||
}
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.DisplayName
|
||||
}
|
||||
func (u *User) GetNickname() string {
|
||||
return u.NickName
|
||||
}
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.PreferredUsername
|
||||
}
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
return u.Email
|
||||
}
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return u.EmailVerified
|
||||
}
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return u.Phone
|
||||
}
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return u.PhoneVerified
|
||||
}
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return u.PreferredLanguage
|
||||
}
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return u.AvatarURL
|
||||
}
|
||||
func (u *User) GetProfile() string {
|
||||
return u.Profile
|
||||
}
|
110
apps/api/internal/idp/providers/oauth/mapper.go
Normal file
110
apps/api/internal/idp/providers/oauth/mapper.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var _ idp.User = (*UserMapper)(nil)
|
||||
|
||||
// UserMapper is an implementation of [idp.User].
|
||||
// It can be used in ZITADEL actions to map the `RawInfo`
|
||||
type UserMapper struct {
|
||||
idAttribute string
|
||||
RawInfo map[string]interface{}
|
||||
}
|
||||
|
||||
func NewUserMapper(idAttribute string) *UserMapper {
|
||||
return &UserMapper{
|
||||
idAttribute: idAttribute,
|
||||
RawInfo: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *UserMapper) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &u.RawInfo)
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetID() string {
|
||||
id, ok := u.RawInfo[u.idAttribute]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
switch i := id.(type) {
|
||||
case string:
|
||||
return i
|
||||
case int:
|
||||
return strconv.Itoa(i)
|
||||
case float64:
|
||||
return strconv.FormatFloat(i, 'f', -1, 64)
|
||||
default:
|
||||
return fmt.Sprint(i)
|
||||
}
|
||||
}
|
||||
|
||||
// GetFirstName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetFirstName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLastName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetLastName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetDisplayName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetDisplayName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetNickname is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetNickname() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPreferredUsername is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPreferredUsername() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetEmail is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetEmail() domain.EmailAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsEmailVerified is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) IsEmailVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPhone is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPhone() domain.PhoneNumber {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsPhoneVerified is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) IsPhoneVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredLanguage is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPreferredLanguage() language.Tag {
|
||||
return language.Und
|
||||
}
|
||||
|
||||
// GetAvatarURL is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetAvatarURL() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetProfile is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetProfile() string {
|
||||
return ""
|
||||
}
|
143
apps/api/internal/idp/providers/oauth/oauth2.go
Normal file
143
apps/api/internal/idp/providers/oauth/oauth2.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a generic OAuth 2.0 provider
|
||||
type Provider struct {
|
||||
rp.RelyingParty
|
||||
options []rp.Option
|
||||
name string
|
||||
userEndpoint string
|
||||
user func() idp.User
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
generateVerifier func() string
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
|
||||
// WithLinkingAllowed allows end users to link the federated user to an existing one.
|
||||
func WithLinkingAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isLinkingAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreationAllowed allows end users to create a new user using the federated information.
|
||||
func WithCreationAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isCreationAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCreation enables that federated users are automatically created if not already existing.
|
||||
func WithAutoCreation() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoCreation = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
|
||||
// the existing user on each authentication.
|
||||
func WithAutoUpdate() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
|
||||
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.options = append(p.options, option)
|
||||
}
|
||||
}
|
||||
|
||||
// New creates a generic OAuth 2.0 provider
|
||||
func New(config *oauth2.Config, name, userEndpoint string, user func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
|
||||
provider = &Provider{
|
||||
name: name,
|
||||
userEndpoint: userEndpoint,
|
||||
user: user,
|
||||
generateVerifier: oauth2.GenerateVerifier,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
}
|
||||
provider.RelyingParty, err = rp.NewRelyingPartyOAuth(config, provider.options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Name implements the [idp.Provider] interface
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
opts := make([]rp.AuthURLOpt, 0)
|
||||
var loginHintSet bool
|
||||
for _, param := range params {
|
||||
if username, ok := param.(idp.LoginHintParam); ok {
|
||||
loginHintSet = true
|
||||
opts = append(opts, loginHint(string(username)))
|
||||
}
|
||||
}
|
||||
if !loginHintSet {
|
||||
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
|
||||
}
|
||||
|
||||
var codeVerifier string
|
||||
if p.RelyingParty.IsPKCE() {
|
||||
codeVerifier = p.generateVerifier()
|
||||
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
|
||||
}
|
||||
|
||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
|
||||
}
|
||||
|
||||
func loginHint(hint string) rp.AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||
}
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
}
|
||||
|
||||
// IsCreationAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsCreationAllowed() bool {
|
||||
return p.isCreationAllowed
|
||||
}
|
||||
|
||||
// IsAutoCreation implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoCreation() bool {
|
||||
return p.isAutoCreation
|
||||
}
|
||||
|
||||
// IsAutoUpdate implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
||||
|
||||
func (p *Provider) User() idp.User {
|
||||
return p.user()
|
||||
}
|
184
apps/api/internal/idp/providers/oauth/oauth2_test.go
Normal file
184
apps/api/internal/idp/providers/oauth/oauth2_test.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
config *oauth2.Config
|
||||
name string
|
||||
userEndpoint string
|
||||
userMapper func() idp.User
|
||||
options []ProviderOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth without PKCE",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
},
|
||||
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
|
||||
},
|
||||
{
|
||||
name: "successful auth with PKCE",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
options: []ProviderOpts{
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
WithRelyingPartyOption(rp.WithPKCE(nil)),
|
||||
},
|
||||
},
|
||||
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
|
||||
r.NoError(err)
|
||||
provider.generateVerifier = func() string {
|
||||
return "pkceOAuthVerifier"
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
config *oauth2.Config
|
||||
name string
|
||||
userEndpoint string
|
||||
userMapper func() idp.User
|
||||
options []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
pkce bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "oauth",
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "oauth",
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
pkce: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
fields: fields{
|
||||
name: "oauth",
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
options: []ProviderOpts{
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
WithRelyingPartyOption(rp.WithPKCE(nil)),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "oauth",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
pkce: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
|
||||
})
|
||||
}
|
||||
}
|
91
apps/api/internal/idp/providers/oauth/session.go
Normal file
91
apps/api/internal/idp/providers/oauth/session.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var ErrCodeMissing = errors.New("no auth code provided")
|
||||
|
||||
const (
|
||||
CodeVerifier = "codeVerifier"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session is the [idp.Session] implementation for the OAuth2.0 provider.
|
||||
type Session struct {
|
||||
AuthURL string
|
||||
CodeVerifier string
|
||||
Code string
|
||||
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
|
||||
Provider *Provider
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
|
||||
verifier, _ := idpArguments[CodeVerifier].(string)
|
||||
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return idp.Redirect(s.AuthURL)
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
if s.CodeVerifier == "" {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{CodeVerifier: s.CodeVerifier}
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||
func (s *Session) FetchUser(ctx context.Context) (_ idp.User, err error) {
|
||||
if s.Tokens == nil {
|
||||
if err = s.authorize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequest("GET", s.Provider.userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
|
||||
user := s.Provider.User()
|
||||
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.Tokens == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.Tokens.Expiry
|
||||
}
|
||||
|
||||
func (s *Session) authorize(ctx context.Context) (err error) {
|
||||
if s.Code == "" {
|
||||
return ErrCodeMissing
|
||||
}
|
||||
var opts []rp.CodeExchangeOpt
|
||||
if s.CodeVerifier != "" {
|
||||
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
|
||||
}
|
||||
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
|
||||
|
||||
return err
|
||||
}
|
274
apps/api/internal/idp/providers/oauth/session_test.go
Normal file
274
apps/api/internal/idp/providers/oauth/session_test.go
Normal file
@@ -0,0 +1,274 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
func TestProvider_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
config *oauth2.Config
|
||||
name string
|
||||
userEndpoint string
|
||||
httpMock func(issuer string)
|
||||
userMapper func() idp.User
|
||||
authURL string
|
||||
code string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
user idp.User
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
userEndpoint: "https://oauth2.com/user",
|
||||
httpMock: func(issuer string) {},
|
||||
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, ErrCodeMissing)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user error",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
userEndpoint: "https://oauth2.com/user",
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get("/user").
|
||||
Reply(http.StatusInternalServerError)
|
||||
},
|
||||
userMapper: func() idp.User {
|
||||
return NewUserMapper("userID")
|
||||
},
|
||||
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err.Error() == "http status not ok: 500 Internal Server Error "
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
userEndpoint: "https://oauth2.com/user",
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get("/user").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"userID": "id",
|
||||
"custom": "claim",
|
||||
})
|
||||
},
|
||||
userMapper: func() idp.User {
|
||||
return NewUserMapper("userID")
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
user: &UserMapper{
|
||||
idAttribute: "userID",
|
||||
RawInfo: map[string]interface{}{
|
||||
"userID": "id",
|
||||
"custom": "claim",
|
||||
},
|
||||
},
|
||||
id: "id",
|
||||
firstName: "",
|
||||
lastName: "",
|
||||
displayName: "",
|
||||
nickName: "",
|
||||
preferredUsername: "",
|
||||
email: "",
|
||||
isEmailVerified: false,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.Und,
|
||||
avatarURL: "",
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch with code exchange",
|
||||
fields: fields{
|
||||
config: &oauth2.Config{
|
||||
ClientID: "clientID",
|
||||
ClientSecret: "clientSecret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://oauth2.com/authorize",
|
||||
TokenURL: "https://oauth2.com/token",
|
||||
},
|
||||
RedirectURL: "redirectURI",
|
||||
Scopes: []string{"user"},
|
||||
},
|
||||
userEndpoint: "https://oauth2.com/user",
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Post("/token").
|
||||
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
|
||||
Reply(200).
|
||||
JSON(&oidc.AccessTokenResponse{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
RefreshToken: "",
|
||||
ExpiresIn: 3600,
|
||||
IDToken: "",
|
||||
State: "testState"})
|
||||
gock.New(issuer).
|
||||
Get("/user").
|
||||
Reply(200).
|
||||
JSON(map[string]interface{}{
|
||||
"userID": "id",
|
||||
"custom": "claim",
|
||||
})
|
||||
},
|
||||
userMapper: func() idp.User {
|
||||
return NewUserMapper("userID")
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
|
||||
tokens: nil,
|
||||
code: "code",
|
||||
},
|
||||
want: want{
|
||||
user: &UserMapper{
|
||||
idAttribute: "userID",
|
||||
RawInfo: map[string]interface{}{
|
||||
"userID": "id",
|
||||
"custom": "claim",
|
||||
},
|
||||
},
|
||||
id: "id",
|
||||
firstName: "",
|
||||
lastName: "",
|
||||
displayName: "",
|
||||
nickName: "",
|
||||
preferredUsername: "",
|
||||
email: "",
|
||||
isEmailVerified: false,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.Und,
|
||||
avatarURL: "",
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock("https://oauth2.com")
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.user, user)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
190
apps/api/internal/idp/providers/oidc/oidc.go
Normal file
190
apps/api/internal/idp/providers/oidc/oidc.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a generic OIDC provider
|
||||
type Provider struct {
|
||||
rp.RelyingParty
|
||||
options []rp.Option
|
||||
name string
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
useIDToken bool
|
||||
userInfoMapper func(info *oidc.UserInfo) idp.User
|
||||
authOptions []func(bool) rp.AuthURLOpt
|
||||
generateVerifier func() string
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
|
||||
// WithLinkingAllowed allows end users to link the federated user to an existing one.
|
||||
func WithLinkingAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isLinkingAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreationAllowed allows end users to create a new user using the federated information.
|
||||
func WithCreationAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isCreationAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCreation enables that federated users are automatically created if not already existing.
|
||||
func WithAutoCreation() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoCreation = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
|
||||
// the existing user on each authentication.
|
||||
func WithAutoUpdate() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithIDTokenMapping enables that information to map the user is retrieved from the id_token and not the userinfo endpoint.
|
||||
func WithIDTokenMapping() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.useIDToken = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
|
||||
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.options = append(p.options, option)
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectAccount adds the select_account prompt to the auth request (if no login_hint is set)
|
||||
func WithSelectAccount() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.authOptions = append(p.authOptions, func(loginHintSet bool) rp.AuthURLOpt {
|
||||
if loginHintSet {
|
||||
return nil
|
||||
}
|
||||
return rp.WithPrompt(oidc.PromptSelectAccount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithResponseMode sets the `response_mode` params in the auth request
|
||||
func WithResponseMode(mode oidc.ResponseMode) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
paramOpt := rp.WithResponseModeURLParam(mode)
|
||||
p.authOptions = append(p.authOptions, func(_ bool) rp.AuthURLOpt {
|
||||
return rp.AuthURLOpt(paramOpt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type UserInfoMapper func(info *oidc.UserInfo) idp.User
|
||||
|
||||
var DefaultMapper UserInfoMapper = func(info *oidc.UserInfo) idp.User {
|
||||
return NewUser(info)
|
||||
}
|
||||
|
||||
// New creates a generic OIDC provider
|
||||
func New(name, issuer, clientID, clientSecret, redirectURI string, scopes []string, userInfoMapper UserInfoMapper, options ...ProviderOpts) (provider *Provider, err error) {
|
||||
provider = &Provider{
|
||||
name: name,
|
||||
userInfoMapper: userInfoMapper,
|
||||
generateVerifier: oauth2.GenerateVerifier,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
}
|
||||
provider.RelyingParty, err = rp.NewRelyingPartyOIDC(context.TODO(), issuer, clientID, clientSecret, redirectURI, setDefaultScope(scopes), provider.options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// setDefaultScope ensures that at least openid ist set
|
||||
// if none is provided it will request `openid profile email phone`
|
||||
func setDefaultScope(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone}
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == oidc.ScopeOpenID {
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
return append(scopes, oidc.ScopeOpenID)
|
||||
}
|
||||
|
||||
// Name implements the [idp.Provider] interface
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// BeginAuth implements the [idp.Provider] interface.
|
||||
// It will create a [Session] with an OIDC authorization request as AuthURL.
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Parameter) (idp.Session, error) {
|
||||
opts := make([]rp.AuthURLOpt, 0)
|
||||
var loginHintSet bool
|
||||
for _, param := range params {
|
||||
if username, ok := param.(idp.LoginHintParam); ok {
|
||||
loginHintSet = true
|
||||
opts = append(opts, loginHint(string(username)))
|
||||
}
|
||||
}
|
||||
for _, option := range p.authOptions {
|
||||
if opt := option(loginHintSet); opt != nil {
|
||||
opts = append(opts, opt)
|
||||
}
|
||||
}
|
||||
|
||||
var codeVerifier string
|
||||
if p.RelyingParty.IsPKCE() {
|
||||
codeVerifier = p.generateVerifier()
|
||||
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
|
||||
}
|
||||
|
||||
url := rp.AuthURL(state, p.RelyingParty, opts...)
|
||||
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
|
||||
}
|
||||
|
||||
func loginHint(hint string) rp.AuthURLOpt {
|
||||
return func() []oauth2.AuthCodeOption {
|
||||
return []oauth2.AuthCodeOption{oauth2.SetAuthURLParam("login_hint", hint)}
|
||||
}
|
||||
}
|
||||
|
||||
// IsLinkingAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
}
|
||||
|
||||
// IsCreationAllowed implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsCreationAllowed() bool {
|
||||
return p.isCreationAllowed
|
||||
}
|
||||
|
||||
// IsAutoCreation implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoCreation() bool {
|
||||
return p.isAutoCreation
|
||||
}
|
||||
|
||||
// IsAutoUpdate implements the [idp.Provider] interface.
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
221
apps/api/internal/idp/providers/oidc/oidc_test.go
Normal file
221
apps/api/internal/idp/providers/oidc/oidc_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
userMapper func(info *oidc.UserInfo) idp.User
|
||||
httpMock func(issuer string)
|
||||
opts []ProviderOpts
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "successful auth without PKCE",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
},
|
||||
opts: []ProviderOpts{WithSelectAccount()},
|
||||
},
|
||||
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
|
||||
},
|
||||
{
|
||||
name: "successful auth with PKCE",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
},
|
||||
opts: []ProviderOpts{WithSelectAccount(), WithRelyingPartyOption(rp.WithPKCE(nil))},
|
||||
},
|
||||
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock(tt.fields.issuer)
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.fields.opts...)
|
||||
r.NoError(err)
|
||||
provider.generateVerifier = func() string {
|
||||
return "pkceOAuthVerifier"
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
userMapper func(info *oidc.UserInfo) idp.User
|
||||
opts []ProviderOpts
|
||||
httpMock func(issuer string)
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
pkce bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
opts: nil,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "oidc",
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
pkce: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all true",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
opts: []ProviderOpts{
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
WithRelyingPartyOption(rp.WithPKCE(nil)),
|
||||
},
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "oidc",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
pkce: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock(tt.fields.issuer)
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.fields.opts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
})
|
||||
}
|
||||
}
|
157
apps/api/internal/idp/providers/oidc/session.go
Normal file
157
apps/api/internal/idp/providers/oidc/session.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
var ErrCodeMissing = errors.New("no auth code provided")
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session is the [idp.Session] implementation for the OIDC provider.
|
||||
type Session struct {
|
||||
Provider *Provider
|
||||
AuthURL string
|
||||
CodeVerifier string
|
||||
Code string
|
||||
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
|
||||
verifier, _ := idpArguments[oauth.CodeVerifier].(string)
|
||||
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return idp.Redirect(s.AuthURL)
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
if s.CodeVerifier == "" {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{oauth.CodeVerifier: s.CodeVerifier}
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OIDC code exchange if needed to retrieve the tokens,
|
||||
// call the userinfo endpoint and map the received information into an [idp.User].
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
if s.Tokens == nil {
|
||||
if err = s.Authorize(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var info *oidc.UserInfo
|
||||
if s.Provider.useIDToken {
|
||||
info = s.Tokens.IDTokenClaims.GetUserInfo()
|
||||
} else {
|
||||
info, err = rp.Userinfo[*oidc.UserInfo](ctx,
|
||||
s.Tokens.AccessToken,
|
||||
s.Tokens.TokenType,
|
||||
s.Tokens.IDTokenClaims.GetSubject(),
|
||||
s.Provider.RelyingParty,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
u := s.Provider.userInfoMapper(info)
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.Tokens == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.Tokens.Expiry
|
||||
}
|
||||
|
||||
func (s *Session) Authorize(ctx context.Context) (err error) {
|
||||
if s.Code == "" {
|
||||
return ErrCodeMissing
|
||||
}
|
||||
var opts []rp.CodeExchangeOpt
|
||||
if s.CodeVerifier != "" {
|
||||
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
|
||||
}
|
||||
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
|
||||
return err
|
||||
}
|
||||
|
||||
func NewUser(info *oidc.UserInfo) *User {
|
||||
return &User{UserInfo: info}
|
||||
}
|
||||
|
||||
func InitUser() *User {
|
||||
return &User{UserInfo: &oidc.UserInfo{}}
|
||||
}
|
||||
|
||||
type User struct {
|
||||
*oidc.UserInfo
|
||||
}
|
||||
|
||||
func (u *User) GetID() string {
|
||||
return u.Subject
|
||||
}
|
||||
|
||||
func (u *User) GetFirstName() string {
|
||||
return u.GivenName
|
||||
}
|
||||
|
||||
func (u *User) GetLastName() string {
|
||||
return u.FamilyName
|
||||
}
|
||||
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u *User) GetNickname() string {
|
||||
return u.Nickname
|
||||
}
|
||||
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.PreferredUsername
|
||||
}
|
||||
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
return domain.EmailAddress(u.UserInfo.Email)
|
||||
}
|
||||
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return bool(u.EmailVerified)
|
||||
}
|
||||
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return domain.PhoneNumber(u.PhoneNumber)
|
||||
}
|
||||
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return u.PhoneNumberVerified
|
||||
}
|
||||
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return u.Locale.Tag()
|
||||
}
|
||||
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return u.Picture
|
||||
}
|
||||
|
||||
func (u *User) GetProfile() string {
|
||||
return u.Profile
|
||||
}
|
471
apps/api/internal/idp/providers/oidc/session_test.go
Normal file
471
apps/api/internal/idp/providers/oidc/session_test.go
Normal file
@@ -0,0 +1,471 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v4"
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
issuer string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
userMapper func(*oidc.UserInfo) idp.User
|
||||
httpMock func(issuer string)
|
||||
authURL string
|
||||
code string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err error
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
opts []ProviderOpts
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
gock.New(issuer).
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: ErrCodeMissing,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "userinfo error",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
gock.New(issuer).
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://issuer.com",
|
||||
"sub2",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: rp.ErrUserInfoSubNotMatching,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
gock.New(issuer).
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://issuer.com",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "nickname",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "phone",
|
||||
isPhoneVerified: true,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "use ID token",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: func() *oidc.IDTokenClaims {
|
||||
claims := oidc.NewIDTokenClaims(
|
||||
"https://issuer.com",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
)
|
||||
claims.SetUserInfo(userinfo())
|
||||
return claims
|
||||
}(),
|
||||
},
|
||||
},
|
||||
opts: []ProviderOpts{
|
||||
WithIDTokenMapping(),
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "nickname",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "phone",
|
||||
isPhoneVerified: true,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch with token exchange",
|
||||
fields: fields{
|
||||
name: "oidc",
|
||||
issuer: "https://issuer.com",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{"openid"},
|
||||
userMapper: DefaultMapper,
|
||||
httpMock: func(issuer string) {
|
||||
gock.New(issuer).
|
||||
Get(oidc.DiscoveryEndpoint).
|
||||
Reply(200).
|
||||
JSON(&oidc.DiscoveryConfiguration{
|
||||
Issuer: issuer,
|
||||
AuthorizationEndpoint: issuer + "/authorize",
|
||||
TokenEndpoint: issuer + "/token",
|
||||
JwksURI: issuer + "/keys",
|
||||
UserinfoEndpoint: issuer + "/userinfo",
|
||||
})
|
||||
gock.New(issuer).
|
||||
Post("/token").
|
||||
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
|
||||
Reply(200).
|
||||
JSON(tokenResponse(t, issuer))
|
||||
gock.New(issuer).
|
||||
Get("/keys").
|
||||
Reply(200).
|
||||
JSON(keys(t))
|
||||
gock.New(issuer).
|
||||
Get("/userinfo").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
|
||||
tokens: nil,
|
||||
code: "code",
|
||||
},
|
||||
want: want{
|
||||
id: "sub",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "nickname",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "phone",
|
||||
isPhoneVerified: true,
|
||||
preferredLanguage: language.English,
|
||||
avatarURL: "picture",
|
||||
profile: "profile",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock(tt.fields.issuer)
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.opts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
AuthURL: tt.fields.authURL,
|
||||
Code: tt.fields.code,
|
||||
Tokens: tt.fields.tokens,
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
|
||||
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
require.NoError(t, err)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *oidc.UserInfo {
|
||||
return &oidc.UserInfo{
|
||||
Subject: "sub",
|
||||
UserInfoProfile: oidc.UserInfoProfile{
|
||||
GivenName: "firstname",
|
||||
FamilyName: "lastname",
|
||||
Name: "firstname lastname",
|
||||
Nickname: "nickname",
|
||||
PreferredUsername: "username",
|
||||
Locale: oidc.NewLocale(language.English),
|
||||
Picture: "picture",
|
||||
Profile: "profile",
|
||||
},
|
||||
UserInfoEmail: oidc.UserInfoEmail{
|
||||
Email: "email",
|
||||
EmailVerified: oidc.Bool(true),
|
||||
},
|
||||
UserInfoPhone: oidc.UserInfoPhone{
|
||||
PhoneNumber: "phone",
|
||||
PhoneNumberVerified: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func tokenResponse(t *testing.T, issuer string) *oidc.AccessTokenResponse {
|
||||
claims := oidc.NewIDTokenClaims(
|
||||
issuer,
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Minute),
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
)
|
||||
privateKey, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAs38btwb3c7r0tMaQpGvBmY+mPwMU/LpfuPoC0k2t4RsKp0fv
|
||||
40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuyrFALIj3Ff1UcKIk0hOH5DDsfh7/q
|
||||
2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/MfSydZdcmIqlkUpfQmtzExw9+tSe5
|
||||
Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNuMZbmlCoBru+rC8ITlTX/0V1ZcsSb
|
||||
L8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a+kjL/KGZbR14Ua2eo6tonBZLC5DH
|
||||
WM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly6QIDAQABAoIBAQCPj1nbSPcg2KZe
|
||||
73FAD+8HopyUSSK//1AP4eXfzcEECVy77g0u9+R6XlkzsZCsZ4g6NN8ounqfyw3c
|
||||
YlpAIkcFCf/dowoSjT+4LASVQyatYZwWNqjgAIU4KgMG/rKnNahPTiBYe7peMB1j
|
||||
EaPjnt8uPkCk8y7NCi3y4Pk24tt/WM5KbJK2NQhUi1csGnleDfE+0blV0l/e6C68
|
||||
W5cbnbWAroMqae/Yon3XVZiXX0m+l2f6ZzIgKaD18J+eEM8FjJC+jQKiRe1i9v3K
|
||||
nQrLwh/gn8J10FcbKn3xqslKVidzASIrNIzHT9j/Z5T9NXuAKa7IV2x+Dtdus+wq
|
||||
iBsUunwBAoGBANpYew+8i9vDwK4/SefduDTuzJ0H9lWTjtbiWQ+KYZoeJ7q3/qns
|
||||
jsmi+mjxkXxXg1RrGbNbjtbl3RXXIrUeeBB0lglRJUjc3VK7VvNoyXIWsiqhCspH
|
||||
IJ9Yuknv4mXB01m/glbSCS/xu4RTgf5aOG4jUiRb9+dCIpvDxI9gbXEVAoGBANJz
|
||||
hIJkplIJ+biTi3G1Oz17qkUkInNXzAEzKD9Atoz5AIAiR1ivOMLOlbucfjevw/Nw
|
||||
TnpkMs9xqCefKupTlsriXtZI88m7ZKzAmolYsPolOy/Jhi31h9JFVTEfKGqVS+dk
|
||||
A4ndhgdW9RUeNJPY2YVCARXQrWpueweQDA1cNaeFAoGAPJsYtXqBW6PPRM5+ZiSt
|
||||
78tk8iV2o7RMjqrPS7f+dXfvUS2nO2VVEPTzCtQarOfhpToBLT65vD6bimdn09w8
|
||||
OV0TFEz4y2u65y7m6LNqTwertpdy1ki97l0DgGhccCBH2P6GYDD2qd8wTH+dcot6
|
||||
ZF/begopGoDJ+HBzi9SZLC0CgYBZzPslHMevyBvr++GLwrallKhiWnns1/DwLiEl
|
||||
ZHrBCtuA0Z+6IwLIdZiE9tEQ+ApYTXrfVPQteqUzSwLn/IUiy5eGPpjwYushoAoR
|
||||
Q2w5QTvRN1/vKo8rVXR1woLfgBdkhFPSN1mitiNcQIhU8jpXV4PZCDOHb99FqdzK
|
||||
sqcedQKBgQCOmgbqxGsnT2WQhoOdzln+NOo6Tx+FveLLqat2KzpY59W4noeI2Awn
|
||||
HfIQgWUAW9dsjVVOXMP1jhq8U9hmH/PFWA11V/iCdk1NTxZEw87VAOeWuajpdDHG
|
||||
+iex349j8h2BcQ4Zd0FWu07gGFnS/yuDJPn6jBhRusdieEcxLRjTKg==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Key: privateKey, Algorithm: "RS256"}, &jose.SignerOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
jws, err := signer.Sign(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
idToken, err := jws.CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &oidc.AccessTokenResponse{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
RefreshToken: "",
|
||||
ExpiresIn: 3600,
|
||||
IDToken: idToken,
|
||||
State: "testState",
|
||||
}
|
||||
}
|
||||
|
||||
func keys(t *testing.T) *jose.JSONWebKeySet {
|
||||
privateKey, err := crypto.BytesToPublicKey([]byte(`-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
|
||||
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
|
||||
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
|
||||
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
|
||||
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
|
||||
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
|
||||
6QIDAQAB
|
||||
-----END PUBLIC KEY-----
|
||||
`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{Key: privateKey, Algorithm: "RS256", Use: oidc.KeyUseSignature}}}
|
||||
}
|
89
apps/api/internal/idp/providers/saml/mapper.go
Normal file
89
apps/api/internal/idp/providers/saml/mapper.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
)
|
||||
|
||||
var _ idp.User = (*UserMapper)(nil)
|
||||
|
||||
// UserMapper is an implementation of [idp.User].
|
||||
type UserMapper struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Attributes map[string][]string `json:"attributes,omitempty"`
|
||||
}
|
||||
|
||||
func NewUser() *UserMapper {
|
||||
return &UserMapper{Attributes: map[string][]string{}}
|
||||
}
|
||||
|
||||
func (u *UserMapper) SetID(id string) {
|
||||
u.ID = id
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetID() string {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
// GetFirstName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetFirstName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetLastName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetLastName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetDisplayName is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetDisplayName() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetNickname is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetNickname() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPreferredUsername is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPreferredUsername() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetEmail is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetEmail() domain.EmailAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsEmailVerified is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) IsEmailVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPhone is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPhone() domain.PhoneNumber {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsPhoneVerified is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) IsPhoneVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredLanguage is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetPreferredLanguage() language.Tag {
|
||||
return language.Und
|
||||
}
|
||||
|
||||
// GetAvatarURL is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetAvatarURL() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetProfile is an implementation of the [idp.User] interface.
|
||||
func (u *UserMapper) GetProfile() string {
|
||||
return ""
|
||||
}
|
@@ -0,0 +1,58 @@
|
||||
package requesttracker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
)
|
||||
|
||||
type GetRequest func(ctx context.Context, intentID string) (*samlsp.TrackedRequest, error)
|
||||
type AddRequest func(ctx context.Context, intentID, requestID string) error
|
||||
|
||||
type RequestTracker struct {
|
||||
addRequest AddRequest
|
||||
getRequest GetRequest
|
||||
}
|
||||
|
||||
func New(addRequestF AddRequest, getRequestF GetRequest) samlsp.RequestTracker {
|
||||
return &RequestTracker{
|
||||
addRequest: addRequestF,
|
||||
getRequest: getRequestF,
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (index string, err error) {
|
||||
// intentID is stored in r.URL
|
||||
intentID := r.URL.String()
|
||||
if err := rt.addRequest(r.Context(), intentID, samlRequestID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return intentID, nil
|
||||
}
|
||||
|
||||
func (rt *RequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
|
||||
// error is not handled in SP logic
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rt *RequestTracker) GetTrackedRequests(r *http.Request) []samlsp.TrackedRequest {
|
||||
// RelayState is the context of the auth flow and as such contains the intentID
|
||||
intentID := r.FormValue("RelayState")
|
||||
|
||||
request, err := rt.getRequest(r.Context(), intentID)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return []samlsp.TrackedRequest{
|
||||
{
|
||||
Index: request.Index,
|
||||
SAMLRequestID: request.SAMLRequestID,
|
||||
URI: request.URI,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (rt *RequestTracker) GetTrackedRequest(r *http.Request, index string) (*samlsp.TrackedRequest, error) {
|
||||
return rt.getRequest(r.Context(), index)
|
||||
}
|
278
apps/api/internal/idp/providers/saml/saml.go
Normal file
278
apps/api/internal/idp/providers/saml/saml.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
"golang.org/x/text/encoding/ianaindex"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for a generic SAML provider
|
||||
type Provider struct {
|
||||
name string
|
||||
|
||||
requestTracker samlsp.RequestTracker
|
||||
Certificate []byte
|
||||
|
||||
spOptions *samlsp.Options
|
||||
|
||||
binding string
|
||||
nameIDFormat saml.NameIDFormat
|
||||
transientMappingAttributeName string
|
||||
|
||||
isLinkingAllowed bool
|
||||
isCreationAllowed bool
|
||||
isAutoCreation bool
|
||||
isAutoUpdate bool
|
||||
}
|
||||
|
||||
type ProviderOpts func(provider *Provider)
|
||||
|
||||
// WithLinkingAllowed allows end users to link the federated user to an existing one.
|
||||
func WithLinkingAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isLinkingAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithCreationAllowed allows end users to create a new user using the federated information.
|
||||
func WithCreationAllowed() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isCreationAllowed = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoCreation enables that federated users are automatically created if not already existing.
|
||||
func WithAutoCreation() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoCreation = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
|
||||
// the existing user on each authentication.
|
||||
func WithAutoUpdate() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.isAutoUpdate = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithSignedRequest() ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.spOptions.SignRequest = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithBinding(binding string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.binding = binding
|
||||
}
|
||||
}
|
||||
|
||||
func WithNameIDFormat(format domain.SAMLNameIDFormat) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.nameIDFormat = nameIDFormatFromDomain(format)
|
||||
}
|
||||
}
|
||||
|
||||
func WithTransientMappingAttributeName(attribute string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.transientMappingAttributeName = attribute
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomRequestTracker(tracker samlsp.RequestTracker) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.requestTracker = tracker
|
||||
}
|
||||
}
|
||||
|
||||
func WithEntityID(entityID string) ProviderOpts {
|
||||
return func(p *Provider) {
|
||||
p.spOptions.EntityID = entityID
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMetadata parses the metadata with the provided XML encoding and returns the EntityDescriptor
|
||||
func ParseMetadata(metadata []byte) (*saml.EntityDescriptor, error) {
|
||||
entityDescriptor := new(saml.EntityDescriptor)
|
||||
reader := bytes.NewReader(metadata)
|
||||
decoder := xml.NewDecoder(reader)
|
||||
decoder.CharsetReader = func(charset string, reader io.Reader) (io.Reader, error) {
|
||||
enc, err := ianaindex.IANA.Encoding(charset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return enc.NewDecoder().Reader(reader), nil
|
||||
}
|
||||
if err := decoder.Decode(entityDescriptor); err != nil {
|
||||
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
|
||||
// reset reader to start of metadata so we can try to parse it as an EntitiesDescriptor
|
||||
if _, err := reader.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entities := &EntitiesDescriptor{}
|
||||
if err := decoder.Decode(entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, e := range entities.EntityDescriptors {
|
||||
if len(e.IDPSSODescriptors) > 0 {
|
||||
return &entities.EntityDescriptors[i], nil
|
||||
}
|
||||
}
|
||||
return nil, zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return entityDescriptor, nil
|
||||
}
|
||||
|
||||
func New(
|
||||
name string,
|
||||
rootURLStr string,
|
||||
metadata []byte,
|
||||
certificate []byte,
|
||||
key []byte,
|
||||
options ...ProviderOpts,
|
||||
) (*Provider, error) {
|
||||
entityDescriptor, err := ParseMetadata(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(certificate, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rootURL, err := url.Parse(rootURLStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
opts := samlsp.Options{
|
||||
URL: *rootURL,
|
||||
Key: keyPair.PrivateKey.(*rsa.PrivateKey),
|
||||
Certificate: keyPair.Leaf,
|
||||
IDPMetadata: entityDescriptor,
|
||||
SignRequest: false,
|
||||
}
|
||||
provider := &Provider{
|
||||
name: name,
|
||||
spOptions: &opts,
|
||||
Certificate: certificate,
|
||||
// the library uses transient as default, which does not make sense for federating accounts
|
||||
nameIDFormat: saml.PersistentNameIDFormat,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(provider)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (p *Provider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p *Provider) IsLinkingAllowed() bool {
|
||||
return p.isLinkingAllowed
|
||||
}
|
||||
|
||||
func (p *Provider) IsCreationAllowed() bool {
|
||||
return p.isCreationAllowed
|
||||
}
|
||||
|
||||
func (p *Provider) IsAutoCreation() bool {
|
||||
return p.isAutoCreation
|
||||
}
|
||||
|
||||
func (p *Provider) IsAutoUpdate() bool {
|
||||
return p.isAutoUpdate
|
||||
}
|
||||
|
||||
func (p *Provider) GetSP() (*samlsp.Middleware, error) {
|
||||
sp, err := samlsp.New(*p.spOptions)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "SAML-qee09ffuq5", "Errors.Intent.IDPInvalid")
|
||||
}
|
||||
sp.ServiceProvider.AuthnNameIDFormat = p.nameIDFormat
|
||||
if p.requestTracker != nil {
|
||||
sp.RequestTracker = p.requestTracker
|
||||
}
|
||||
if p.binding != "" {
|
||||
sp.Binding = p.binding
|
||||
}
|
||||
sp.ServiceProvider.MetadataValidDuration = time.Until(sp.ServiceProvider.Certificate.NotAfter)
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...idp.Parameter) (idp.Session, error) {
|
||||
m, err := p.GetSP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Session{
|
||||
ServiceProvider: m,
|
||||
state: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Provider) TransientMappingAttributeName() string {
|
||||
return p.transientMappingAttributeName
|
||||
}
|
||||
|
||||
func nameIDFormatFromDomain(format domain.SAMLNameIDFormat) saml.NameIDFormat {
|
||||
switch format {
|
||||
case domain.SAMLNameIDFormatUnspecified:
|
||||
return saml.UnspecifiedNameIDFormat
|
||||
case domain.SAMLNameIDFormatEmailAddress:
|
||||
return saml.EmailAddressNameIDFormat
|
||||
case domain.SAMLNameIDFormatPersistent:
|
||||
return saml.PersistentNameIDFormat
|
||||
case domain.SAMLNameIDFormatTransient:
|
||||
return saml.TransientNameIDFormat
|
||||
default:
|
||||
return saml.UnspecifiedNameIDFormat
|
||||
}
|
||||
}
|
||||
|
||||
// EntitiesDescriptor is a workaround until we eventually fork the crewjam/saml library, since maintenance on that repo seems to have stopped.
|
||||
// This is to be able to handle xsd:duration format using the UnmarshalXML method.
|
||||
// crewjam/saml only implements the xsd:dateTime format for EntityDescriptor, but not EntitiesDescriptor.
|
||||
type EntitiesDescriptor saml.EntitiesDescriptor
|
||||
|
||||
// UnmarshalXML implements xml.Unmarshaler
|
||||
func (m *EntitiesDescriptor) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
|
||||
type Alias EntitiesDescriptor
|
||||
aux := &struct {
|
||||
ValidUntil *saml.RelaxedTime `xml:"validUntil,attr,omitempty"`
|
||||
CacheDuration *saml.Duration `xml:"cacheDuration,attr,omitempty"`
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(m),
|
||||
}
|
||||
if err := d.DecodeElement(aux, &start); err != nil {
|
||||
return err
|
||||
}
|
||||
m.ValidUntil = (*time.Time)(aux.ValidUntil)
|
||||
m.CacheDuration = (*time.Duration)(aux.CacheDuration)
|
||||
return nil
|
||||
}
|
438
apps/api/internal/idp/providers/saml/saml_test.go
Normal file
438
apps/api/internal/idp/providers/saml/saml_test.go
Normal file
@@ -0,0 +1,438 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
requestTracker := requesttracker.New(
|
||||
func(ctx context.Context, authRequestID, samlRequestID string) error {
|
||||
assert.Equal(t, "state", authRequestID)
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, authRequestID string) (*samlsp.TrackedRequest, error) {
|
||||
return &samlsp.TrackedRequest{
|
||||
SAMLRequestID: "state",
|
||||
Index: authRequestID,
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
type fields struct {
|
||||
name string
|
||||
rootURL string
|
||||
metadata []byte
|
||||
certificate []byte
|
||||
key []byte
|
||||
options []ProviderOpts
|
||||
}
|
||||
type args struct {
|
||||
state string
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
authType idp.Auth
|
||||
ssoURL string
|
||||
relayState string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "redirect binding, success",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
options: []ProviderOpts{
|
||||
WithCustomRequestTracker(requestTracker),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
state: "state",
|
||||
},
|
||||
want: want{
|
||||
authType: &idp.RedirectAuth{},
|
||||
ssoURL: "http://localhost:8000/sso",
|
||||
relayState: "state",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "post binding, success",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
options: []ProviderOpts{
|
||||
WithCustomRequestTracker(requestTracker),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
state: "state",
|
||||
},
|
||||
want: want{
|
||||
authType: &idp.FormAuth{},
|
||||
ssoURL: "http://localhost:8000/sso",
|
||||
relayState: "state",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(
|
||||
tt.fields.name,
|
||||
tt.fields.rootURL,
|
||||
tt.fields.metadata,
|
||||
tt.fields.certificate,
|
||||
tt.fields.key,
|
||||
tt.fields.options...,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, tt.args.state, nil)
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.NoError(gotErr)
|
||||
a.IsType(tt.want.authType, gotAuth)
|
||||
|
||||
var ssoURL, relayState, samlRequest string
|
||||
switch auth := gotAuth.(type) {
|
||||
case *idp.RedirectAuth:
|
||||
gotRedirect, err := url.Parse(auth.RedirectURL)
|
||||
a.NoError(err)
|
||||
gotQuery := gotRedirect.Query()
|
||||
|
||||
ssoURL = gotRedirect.Scheme + "://" + gotRedirect.Host + gotRedirect.Path
|
||||
relayState = gotQuery.Get("RelayState")
|
||||
samlRequest = gotQuery.Get("SAMLRequest")
|
||||
case *idp.FormAuth:
|
||||
ssoURL = auth.URL
|
||||
relayState = auth.Fields["RelayState"]
|
||||
samlRequest = auth.Fields["SAMLRequest"]
|
||||
}
|
||||
a.Equal(tt.want.ssoURL, ssoURL)
|
||||
a.Equal(tt.want.relayState, relayState)
|
||||
a.NotEmpty(samlRequest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
rootURL string
|
||||
metadata []byte
|
||||
key []byte
|
||||
certificate []byte
|
||||
options []ProviderOpts
|
||||
}
|
||||
type want struct {
|
||||
err bool
|
||||
name string
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
binding string
|
||||
nameIDFormat saml.NameIDFormat
|
||||
transientMappingAttributeName string
|
||||
withSignedRequest bool
|
||||
requesttracker samlsp.RequestTracker
|
||||
entityID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{{
|
||||
name: "failed metadata",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
metadata: []byte(">xml<"),
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed keypair cert",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed keypair key",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "failed url",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "%%",
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
rootURL: "https://localhost:8080",
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "saml",
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
nameIDFormat: saml.PersistentNameIDFormat,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all set / true",
|
||||
fields: fields{
|
||||
name: "saml",
|
||||
key: []byte("-----BEGIN RSA PRIVATE KEY-----\nMIIEogIBAAKCAQEAxHd087RoEm9ywVWZ/H+tDWxQsmVvhfRz4jAq/RfU+OWXNH4J\njMMSHdFs0Q+WP98nNXRyc7fgbMb8NdmlB2yD4qLYapN5SDaBc5dh/3EnyFt53oSs\njTlKnQUPAeJr2qh/NY046CfyUyQMM4JR5OiQFo4TssfWnqdcgamGt0AEnk2lvbMZ\nKQdAqNS9lDzYbjMGavEQPTZE35mFXFQXjaooZXq+TIa7hbaq7/idH7cHNbLcPLgj\nfPQA8q+DYvnvhXlmq0LPQZH3Oiixf+SF2vRwrBzT2mqGD2OiOkUmhuPwyqEiiBHt\nfxklRtRU6WfLa1Gcb1PsV0uoBGpV3KybIl/GlwIDAQABAoIBAEQjDduLgOCL6Gem\n0X3hpdnW6/HC/jed/Sa//9jBECq2LYeWAqff64ON40hqOHi0YvvGA/+gEOSI6mWe\nsv5tIxxRz+6+cLybsq+tG96kluCE4TJMHy/nY7orS/YiWbd+4odnEApr+D3fbZ/b\nnZ1fDsHTyn8hkYx6jLmnWsJpIHDp7zxD76y7k2Bbg6DZrCGiVxngiLJk23dvz79W\np03lHLM7XE92aFwXQmhfxHGxrbuoB/9eY4ai5IHp36H4fw0vL6NXdNQAo/bhe0p9\nAYB7y0ZumF8Hg0Z/BmMeEzLy6HrYB+VE8cO93pNjhSyH+p2yDB/BlUyTiRLQAoM0\nVTmOZXECgYEA7NGlzpKNhyQEJihVqt0MW0LhKIO/xbBn+XgYfX6GpqPa/ucnMx5/\nVezpl3gK8IU4wPUhAyXXAHJiqNBcEeyxrw0MXLujDVMJgYaLysCLJdvMVgoY08mS\nK5IQivpbozpf4+0y3mOnA+Sy1kbfxv2X8xiWLODRQW3f3q/xoklwOR8CgYEA1GEe\nfaibOFTQAYcIVj77KXtBfYZsX3EGAyfAN9O7cKHq5oaxVstwnF47WxpuVtoKZxCZ\nbNm9D5WvQ9b+Ztpioe42tzwE7Bff/Osj868GcDdRPK7nFlh9N2yVn/D514dOYVwR\n4MBr1KrJzgRWt4QqS4H+to1GzudDTSNlG7gnK4kCgYBUi6AbOHzoYzZL/RhgcJwp\ntJ23nhmH1Su5h2OO4e3mbhcP66w19sxU+8iFN+kH5zfUw26utgKk+TE5vXExQQRK\nT2k7bg2PAzcgk80ybD0BHhA8I0yrx4m0nmfjhe/TPVLgh10iwgbtP+eM0i6v1vc5\nZWyvxu9N4ZEL6lpkqr0y1wKBgG/NAIQd8jhhTW7Aav8cAJQBsqQl038avJOEpYe+\nCnpsgoAAf/K0/f8TDCQVceh+t+MxtdK7fO9rWOxZjWsPo8Si5mLnUaAHoX4/OpnZ\nlYYVWMqdOEFnK+O1Yb7k2GFBdV2DXlX2dc1qavntBsls5ecB89id3pyk2aUN8Pf6\npYQhAoGAMGtrHFely9wyaxI0RTCyfmJbWZHGVGkv6ELK8wneJjdjl82XOBUGCg5q\naRCrTZ3dPitKwrUa6ibJCIFCIziiriBmjDvTHzkMvoJEap2TVxYNDR6IfINVsQ57\nlOsiC4A2uGq4Lbfld+gjoplJ5GX6qXtTgZ6m7eo0y7U6zm2tkN0=\n-----END RSA PRIVATE KEY-----\n"),
|
||||
certificate: []byte("-----BEGIN CERTIFICATE-----\nMIIC2zCCAcOgAwIBAgIIAy/jm1gAAdEwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UE\nChMHWklUQURFTDAeFw0yMzA4MzAwNzExMTVaFw0yNDA4MjkwNzExMTVaMBIxEDAO\nBgNVBAoTB1pJVEFERUwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDE\nd3TztGgSb3LBVZn8f60NbFCyZW+F9HPiMCr9F9T45Zc0fgmMwxId0WzRD5Y/3yc1\ndHJzt+Bsxvw12aUHbIPiothqk3lINoFzl2H/cSfIW3nehKyNOUqdBQ8B4mvaqH81\njTjoJ/JTJAwzglHk6JAWjhOyx9aep1yBqYa3QASeTaW9sxkpB0Co1L2UPNhuMwZq\n8RA9NkTfmYVcVBeNqihler5MhruFtqrv+J0ftwc1stw8uCN89ADyr4Ni+e+FeWar\nQs9Bkfc6KLF/5IXa9HCsHNPaaoYPY6I6RSaG4/DKoSKIEe1/GSVG1FTpZ8trUZxv\nU+xXS6gEalXcrJsiX8aXAgMBAAGjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUE\nDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQCx\n/dRNIj0N/16zJhZR/ahkc2AkvDXYxyr4JRT5wK9GQDNl/oaX3debRuSi/tfaXFIX\naJA6PxM4J49ZaiEpLrKfxMz5kAhjKchCBEMcH3mGt+iNZH7EOyTvHjpGrP2OZrsh\nO17yrvN3HuQxIU6roJlqtZz2iAADsoPtwOO4D7hupm9XTMkSnAmlMWOo/q46Jz89\n1sMxB+dXmH/zV0wgwh0omZfLV0u89mvdq269VhcjNBpBYSnN1ccqYWd5iwziob3I\nvaavGHGfkbvRUn/tKftYuTK30q03R+e9YbmlWZ0v695owh2e/apCzowQsCKfSVC8\nOxVyt5XkHq1tWwVyBmFp\n-----END CERTIFICATE-----\n"),
|
||||
metadata: []byte("<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2023-08-27T12:40:58.803Z\" cacheDuration=\"PT48H\" entityID=\"http://localhost:8000/metadata\">\n <IDPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\">\n <KeyDescriptor use=\"signing\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n </KeyDescriptor>\n <KeyDescriptor use=\"encryption\">\n <KeyInfo xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Data xmlns=\"http://www.w3.org/2000/09/xmldsig#\">\n <X509Certificate xmlns=\"http://www.w3.org/2000/09/xmldsig#\">MIIDBzCCAe+gAwIBAgIJAPr/Mrlc8EGhMA0GCSqGSIb3DQEBBQUAMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTAeFw0xNTEyMjgxOTE5NDVaFw0yNTEyMjUxOTE5NDVaMBoxGDAWBgNVBAMMD3d3dy5leGFtcGxlLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANDoWzLos4LWxTn8Gyu2lEbl4WcelUbgLN5zYm4ron8Ahs+rvcsu2zkdD/s6jdGJI8WqJKhYK2u61ygnXgAZqC6ggtFPnBpizcDzjgND2g+aucSoUODHt67f0fQuAmupN/zp5MZysJ6IHLJnYLNpfJYk96lRz9ODnO1Mpqtr9PWxm+pz7nzq5F0vRepkgpcRxv6ufQBjlrFytccyEVdXrvFtkjXcnhVVNSR4kHuOOMS6D7pebSJ1mrCmshbD5SX1jXPBKFPAjozYX6PxqLxUx1Y4faFEf4MBBVcInyB4oURNB2s59hEEi2jq9izNE7EbEK6BY5sEhoCPl9m32zE6ljkCAwEAAaNQME4wHQYDVR0OBBYEFB9ZklC1Ork2zl56zg08ei7ss/+iMB8GA1UdIwQYMBaAFB9ZklC1Ork2zl56zg08ei7ss/+iMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAAVoTSQ5pAirw8OR9FZ1bRSuTDhY9uxzl/OL7lUmsv2cMNeCB3BRZqm3mFt+cwN8GsH6f3uvNONIhgFpTGN5LEcXQz89zJEzB+qaHqmbFpHQl/sx2B8ezNgT/882H2IH00dXESEfy/+1gHg2pxjGnhRBN6el/gSaDiySIMKbilDrffuvxiCfbpPN0NRRiPJhd2ay9KuL/RxQRl1gl9cHaWiouWWba1bSBb2ZPhv2rPMUsFo98ntkGCObDX6Y1SpkqmoTbrsbGFsTG2DLxnvr4GdN1BSr0Uu/KV3adj47WkXVPeMYQti/bQmxQB8tRFhrw80qakTLUzreO96WzlBBMtY=</X509Certificate>\n </X509Data>\n </KeyInfo>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes128-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes192-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#aes256-cbc\"></EncryptionMethod>\n <EncryptionMethod Algorithm=\"http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p\"></EncryptionMethod>\n </KeyDescriptor>\n <NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</NameIDFormat>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n <SingleSignOnService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"http://localhost:8000/sso\"></SingleSignOnService>\n </IDPSSODescriptor>\n</EntityDescriptor>"),
|
||||
options: []ProviderOpts{
|
||||
WithLinkingAllowed(),
|
||||
WithCreationAllowed(),
|
||||
WithAutoCreation(),
|
||||
WithAutoUpdate(),
|
||||
WithBinding("binding"),
|
||||
WithSignedRequest(),
|
||||
WithCustomRequestTracker(&requesttracker.RequestTracker{}),
|
||||
WithEntityID("entityID"),
|
||||
WithNameIDFormat(domain.SAMLNameIDFormatTransient),
|
||||
WithTransientMappingAttributeName("attribute"),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "saml",
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
binding: "binding",
|
||||
entityID: "entityID",
|
||||
nameIDFormat: saml.TransientNameIDFormat,
|
||||
transientMappingAttributeName: "attribute",
|
||||
withSignedRequest: true,
|
||||
requesttracker: &requesttracker.RequestTracker{},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.rootURL, tt.fields.metadata, tt.fields.certificate, tt.fields.key, tt.fields.options...)
|
||||
if tt.want.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
a.Equal(tt.want.binding, provider.binding)
|
||||
a.Equal(tt.want.nameIDFormat, provider.nameIDFormat)
|
||||
a.Equal(tt.want.transientMappingAttributeName, provider.transientMappingAttributeName)
|
||||
a.Equal(tt.want.withSignedRequest, provider.spOptions.SignRequest)
|
||||
a.Equal(tt.want.requesttracker, provider.requestTracker)
|
||||
a.Equal(tt.want.entityID, provider.spOptions.EntityID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
type args struct {
|
||||
metadata []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *saml.EntityDescriptor
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"invalid",
|
||||
args{
|
||||
metadata: []byte(`<Test></Test>`),
|
||||
},
|
||||
nil,
|
||||
xml.UnmarshalError("expected element type <EntityDescriptor> but have <Test>"),
|
||||
},
|
||||
{
|
||||
"valid entity descriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid entity descriptor, non utf-8",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="windows-1252"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"entities descriptor without IDPSSODescriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"></EntityDescriptor></EntitiesDescriptor>`),
|
||||
},
|
||||
nil,
|
||||
zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor"),
|
||||
},
|
||||
{
|
||||
"valid entities descriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor></EntitiesDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid entities using xsd duration descriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" cacheDuration="PT5H"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata" cacheDuration="PT5H"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor></EntitiesDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
CacheDuration: 5 * time.Hour,
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseMetadata(tt.args.metadata)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
180
apps/api/internal/idp/providers/saml/session.go
Normal file
180
apps/api/internal/idp/providers/saml/session.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session is the [idp.Session] implementation for the SAML provider.
|
||||
type Session struct {
|
||||
ServiceProvider *samlsp.Middleware
|
||||
state string
|
||||
TransientMappingAttributeName string
|
||||
|
||||
RequestID string
|
||||
Request *http.Request
|
||||
|
||||
Assertion *saml.Assertion
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, requestID string, request *http.Request) (*Session, error) {
|
||||
sp, err := provider.GetSP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Session{
|
||||
ServiceProvider: sp,
|
||||
TransientMappingAttributeName: provider.TransientMappingAttributeName(),
|
||||
RequestID: requestID,
|
||||
Request: request,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Session] interface.
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
url, err := url.Parse(s.state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request := &http.Request{
|
||||
URL: url,
|
||||
}
|
||||
return s.auth(request.WithContext(ctx))
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
if s.RequestID == "" || s.Request == nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-d09hy0wkex", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
|
||||
s.Assertion, err = s.ServiceProvider.ServiceProvider.ParseResponse(s.Request, []string{s.RequestID})
|
||||
if err != nil {
|
||||
invalidRespErr := new(saml.InvalidResponseError)
|
||||
if errors.As(err, &invalidRespErr) {
|
||||
return nil, zerrors.ThrowInvalidArgument(invalidRespErr.PrivateErr, "SAML-ajl3irfs", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
|
||||
// nameID is required, but at least in ADFS it will not be sent unless explicitly configured
|
||||
if s.Assertion.Subject == nil || s.Assertion.Subject.NameID == nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "SAML-EFG32", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
nameID := s.Assertion.Subject.NameID
|
||||
userMapper := NewUser()
|
||||
// use the nameID as default mapping id
|
||||
userMapper.SetID(nameID.Value)
|
||||
if nameID.Format == string(saml.TransientNameIDFormat) {
|
||||
mappingID, err := s.transientMappingID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userMapper.SetID(mappingID)
|
||||
}
|
||||
for _, statement := range s.Assertion.AttributeStatements {
|
||||
for _, attribute := range statement.Attributes {
|
||||
values := make([]string, len(attribute.Values))
|
||||
for i := range attribute.Values {
|
||||
values[i] = attribute.Values[i].Value
|
||||
}
|
||||
userMapper.Attributes[attribute.Name] = values
|
||||
}
|
||||
}
|
||||
return userMapper, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.Assertion == nil || s.Assertion.Conditions == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.Assertion.Conditions.NotOnOrAfter
|
||||
}
|
||||
|
||||
func (s *Session) transientMappingID() (string, error) {
|
||||
for _, statement := range s.Assertion.AttributeStatements {
|
||||
for _, attribute := range statement.Attributes {
|
||||
if attribute.Name != s.TransientMappingAttributeName {
|
||||
continue
|
||||
}
|
||||
if len(attribute.Values) != 1 {
|
||||
return "", zerrors.ThrowInvalidArgument(nil, "SAML-Soij4", "Errors.Intent.MissingSingleMappingAttribute")
|
||||
}
|
||||
return attribute.Values[0].Value, nil
|
||||
}
|
||||
}
|
||||
return "", zerrors.ThrowInvalidArgument(nil, "SAML-swwg2", "Errors.Intent.MissingSingleMappingAttribute")
|
||||
}
|
||||
|
||||
// auth is a modified copy of the [samlsp.Middleware.HandleStartAuthFlow] method.
|
||||
// Instead of writing the response to the http.ResponseWriter, it returns the auth request as an [idp.Auth].
|
||||
// In case of an error, it returns the error directly and does not write to the response.
|
||||
func (s *Session) auth(r *http.Request) (idp.Auth, error) {
|
||||
if r.URL.Path == s.ServiceProvider.ServiceProvider.AcsURL.Path {
|
||||
// should never occur, but was handled in the original method, so we keep it here
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-Eoi24", "don't wrap Middleware with RequireAccount")
|
||||
}
|
||||
|
||||
var binding, bindingLocation string
|
||||
if s.ServiceProvider.Binding != "" {
|
||||
binding = s.ServiceProvider.Binding
|
||||
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
|
||||
} else {
|
||||
binding = saml.HTTPRedirectBinding
|
||||
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
|
||||
if bindingLocation == "" {
|
||||
binding = saml.HTTPPostBinding
|
||||
bindingLocation = s.ServiceProvider.ServiceProvider.GetSSOBindingLocation(binding)
|
||||
}
|
||||
}
|
||||
|
||||
authReq, err := s.ServiceProvider.ServiceProvider.MakeAuthenticationRequest(bindingLocation, binding, s.ServiceProvider.ResponseBinding)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relayState, err := s.ServiceProvider.RequestTracker.TrackRequest(nil, r, authReq.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if binding == saml.HTTPRedirectBinding {
|
||||
redirectURL, err := authReq.Redirect(relayState, &s.ServiceProvider.ServiceProvider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return idp.Redirect(redirectURL.String())
|
||||
}
|
||||
if binding == saml.HTTPPostBinding {
|
||||
doc := etree.NewDocument()
|
||||
doc.SetRoot(authReq.Element())
|
||||
reqBuf, err := doc.WriteToBytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
|
||||
return idp.Form(authReq.Destination,
|
||||
map[string]string{
|
||||
"SAMLRequest": encodedReqBuf,
|
||||
"RelayState": relayState,
|
||||
})
|
||||
}
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "SAML-Eoi24", "Errors.Intent.Invalid")
|
||||
}
|
328
apps/api/internal/idp/providers/saml/session_test.go
Normal file
328
apps/api/internal/idp/providers/saml/session_test.go
Normal file
File diff suppressed because one or more lines are too long
51
apps/api/internal/idp/session.go
Normal file
51
apps/api/internal/idp/session.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Session is the minimal implementation for a session of a 3rd party authentication [Provider]
|
||||
type Session interface {
|
||||
GetAuth(ctx context.Context) (Auth, error)
|
||||
PersistentParameters() map[string]any
|
||||
FetchUser(ctx context.Context) (User, error)
|
||||
ExpiresAt() time.Time
|
||||
}
|
||||
|
||||
type Auth interface {
|
||||
auth()
|
||||
}
|
||||
|
||||
type RedirectAuth struct {
|
||||
RedirectURL string
|
||||
}
|
||||
|
||||
func (r *RedirectAuth) auth() {}
|
||||
|
||||
type FormAuth struct {
|
||||
URL string
|
||||
Fields map[string]string
|
||||
}
|
||||
|
||||
func (f *FormAuth) auth() {}
|
||||
|
||||
// SessionSupportsMigration is an optional extension to the Session interface.
|
||||
// It can be implemented to support migrating users, were the initial external id has changed because of a migration of the Provider type.
|
||||
// E.g. when a user was linked on a generic OIDC provider and this provider has now been migrated to an AzureAD provider.
|
||||
// In this case OIDC used the `sub` claim and Azure now uses the id of the user endpoint, which differ.
|
||||
// The RetrievePreviousID will return the `sub` claim again, so that the user can be matched and safely migrated to the new id.
|
||||
type SessionSupportsMigration interface {
|
||||
RetrievePreviousID() (previousID string, err error)
|
||||
}
|
||||
|
||||
func Redirect(redirectURL string) (*RedirectAuth, error) {
|
||||
return &RedirectAuth{RedirectURL: redirectURL}, nil
|
||||
}
|
||||
|
||||
func Form(url string, fields map[string]string) (*FormAuth, error) {
|
||||
return &FormAuth{
|
||||
URL: url,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
Reference in New Issue
Block a user