mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-13 13:13:38 +00:00
chore: move the go code into a subfolder
This commit is contained in:
253
apps/api/internal/idp/providers/azuread/azuread.go
Normal file
253
apps/api/internal/idp/providers/azuread/azuread.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
const (
|
||||
issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0"
|
||||
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
|
||||
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
|
||||
keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
|
||||
userURL string = "https://graph.microsoft.com/v1.0/me"
|
||||
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
|
||||
|
||||
ScopeUserRead string = "User.Read"
|
||||
)
|
||||
|
||||
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an
|
||||
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope
|
||||
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name.
|
||||
//
|
||||
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
|
||||
type TenantType string
|
||||
|
||||
const (
|
||||
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active
|
||||
// Directory to sign in to the application.
|
||||
CommonTenant TenantType = "common"
|
||||
|
||||
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign in to the application.
|
||||
OrganizationsTenant TenantType = "organizations"
|
||||
|
||||
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign in to the application.
|
||||
ConsumersTenant TenantType = "consumers"
|
||||
)
|
||||
|
||||
var _ idp.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider is the [idp.Provider] implementation for AzureAD (V2 Endpoints)
|
||||
type Provider struct {
|
||||
*oauth.Provider
|
||||
tenant TenantType
|
||||
emailVerified bool
|
||||
options []oauth.ProviderOpts
|
||||
}
|
||||
|
||||
// issuer returns the OIDC issuer based on the [TenantType]
|
||||
func (p *Provider) issuer() string {
|
||||
return fmt.Sprintf(issuerTemplate, p.tenant)
|
||||
}
|
||||
|
||||
// keysEndpoint returns the OIDC jwks_url based on the [TenantType]
|
||||
func (p *Provider) keysEndpoint() string {
|
||||
return fmt.Sprintf(keysURLTemplate, p.tenant)
|
||||
}
|
||||
|
||||
type ProviderOptions func(*Provider)
|
||||
|
||||
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
|
||||
// default is CommonTenant
|
||||
func WithTenant(tenantType TenantType) ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.tenant = tenantType
|
||||
}
|
||||
}
|
||||
|
||||
// WithEmailVerified allows to set every email received as verified
|
||||
func WithEmailVerified() ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.emailVerified = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithOAuthOptions allows to specify [oauth.ProviderOpts] like [oauth.WithLinkingAllowed]
|
||||
func WithOAuthOptions(opts ...oauth.ProviderOpts) ProviderOptions {
|
||||
return func(p *Provider) {
|
||||
p.options = append(p.options, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// New creates an AzureAD provider using the [oauth.Provider] (OAuth 2.0 generic provider).
|
||||
// By default, it uses the [CommonTenant] and unverified emails.
|
||||
func New(name, clientID, clientSecret, redirectURI string, scopes []string, opts ...ProviderOptions) (*Provider, error) {
|
||||
provider := &Provider{
|
||||
tenant: CommonTenant,
|
||||
options: make([]oauth.ProviderOpts, 0),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(provider)
|
||||
}
|
||||
config := newConfig(provider.tenant, clientID, clientSecret, redirectURI, scopes)
|
||||
rp, err := oauth.New(
|
||||
config,
|
||||
name,
|
||||
userURL,
|
||||
func() idp.User {
|
||||
return &User{isEmailVerified: provider.emailVerified}
|
||||
},
|
||||
provider.options...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
provider.Provider = rp
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func newConfig(tenant TenantType, clientID, secret, callbackURL string, scopes []string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: secret,
|
||||
RedirectURL: callbackURL,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf(authURLTemplate, tenant),
|
||||
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant),
|
||||
},
|
||||
Scopes: ensureMinimalScope(scopes),
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMinimalScope ensures that at least openid and `User.Read` ist set
|
||||
// if none is provided it will request `openid profile email phone User.Read`
|
||||
func ensureMinimalScope(scopes []string) []string {
|
||||
if len(scopes) == 0 {
|
||||
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, ScopeUserRead}
|
||||
}
|
||||
var openIDSet, userReadSet bool
|
||||
for _, scope := range scopes {
|
||||
if scope == oidc.ScopeOpenID {
|
||||
openIDSet = true
|
||||
continue
|
||||
}
|
||||
if scope == ScopeUserRead {
|
||||
userReadSet = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !openIDSet {
|
||||
scopes = append(scopes, oidc.ScopeOpenID)
|
||||
}
|
||||
if !userReadSet {
|
||||
scopes = append(scopes, ScopeUserRead)
|
||||
}
|
||||
return scopes
|
||||
}
|
||||
|
||||
func (p *Provider) User() idp.User {
|
||||
return p.Provider.User()
|
||||
}
|
||||
|
||||
// User represents the structure return on the userinfo endpoint and implements the [idp.User] interface
|
||||
//
|
||||
// AzureAD does not return an `email_verified` claim.
|
||||
// The verification can be automatically activated on the provider ([WithEmailVerified])
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
BusinessPhones []domain.PhoneNumber `json:"businessPhones"`
|
||||
DisplayName string `json:"displayName"`
|
||||
FirstName string `json:"givenName"`
|
||||
JobTitle string `json:"jobTitle"`
|
||||
Email domain.EmailAddress `json:"mail"`
|
||||
MobilePhone domain.PhoneNumber `json:"mobilePhone"`
|
||||
OfficeLocation string `json:"officeLocation"`
|
||||
PreferredLanguage string `json:"preferredLanguage"`
|
||||
LastName string `json:"surname"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
isEmailVerified bool
|
||||
}
|
||||
|
||||
// GetID is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetID() string {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
// GetFirstName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetFirstName() string {
|
||||
return u.FirstName
|
||||
}
|
||||
|
||||
// GetLastName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetLastName() string {
|
||||
return u.LastName
|
||||
}
|
||||
|
||||
// GetDisplayName is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetDisplayName() string {
|
||||
return u.DisplayName
|
||||
}
|
||||
|
||||
// GetNickname is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's nickname.
|
||||
func (u *User) GetNickname() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPreferredUsername is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetPreferredUsername() string {
|
||||
return u.UserPrincipalName
|
||||
}
|
||||
|
||||
// GetEmail is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetEmail() domain.EmailAddress {
|
||||
if u.Email == "" {
|
||||
// if the user used a social login on Azure as well, the email will be empty
|
||||
// but is used as username
|
||||
return domain.EmailAddress(u.UserPrincipalName)
|
||||
}
|
||||
return u.Email
|
||||
}
|
||||
|
||||
// IsEmailVerified is an implementation of the [idp.User] interface
|
||||
// returning the value specified in the creation of the [Provider].
|
||||
// Default is false because AzureAD does not return an `email_verified` claim.
|
||||
// The verification can be automatically activated on the provider ([WithEmailVerified]).
|
||||
func (u *User) IsEmailVerified() bool {
|
||||
return u.isEmailVerified
|
||||
}
|
||||
|
||||
// GetPhone is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's phone.
|
||||
func (u *User) GetPhone() domain.PhoneNumber {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsPhoneVerified is an implementation of the [idp.User] interface.
|
||||
// It returns false because AzureAD does not provide the user's phone.
|
||||
func (u *User) IsPhoneVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPreferredLanguage is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetPreferredLanguage() language.Tag {
|
||||
return language.Make(u.PreferredLanguage)
|
||||
}
|
||||
|
||||
// GetProfile is an implementation of the [idp.User] interface.
|
||||
// It returns an empty string because AzureAD does not provide the user's profile page.
|
||||
func (u *User) GetProfile() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAvatarURL is an implementation of the [idp.User] interface.
|
||||
func (u *User) GetAvatarURL() string {
|
||||
return ""
|
||||
}
|
185
apps/api/internal/idp/providers/azuread/azuread_test.go
Normal file
185
apps/api/internal/idp/providers/azuread/azuread_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
openid "github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
|
||||
)
|
||||
|
||||
func TestProvider_BeginAuth(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
options []ProviderOptions
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want idp.Session
|
||||
}{
|
||||
{
|
||||
name: "default common tenant",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tenant",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithTenant(ConsumersTenant),
|
||||
},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "custom scopes",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: []string{openid.ScopeOpenID, openid.ScopeProfile, "custom"},
|
||||
options: []ProviderOptions{
|
||||
WithTenant(ConsumersTenant),
|
||||
},
|
||||
},
|
||||
want: &oidc.Session{
|
||||
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+custom+User.Read&state=testState",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
r := require.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
r.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
session, err := provider.BeginAuth(ctx, "testState")
|
||||
r.NoError(err)
|
||||
|
||||
wantAuth, wantErr := tt.want.GetAuth(ctx)
|
||||
gotAuth, gotErr := session.GetAuth(ctx)
|
||||
a.Equal(wantAuth, gotAuth)
|
||||
a.ErrorIs(gotErr, wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
options []ProviderOptions
|
||||
}
|
||||
type want struct {
|
||||
name string
|
||||
tenant TenantType
|
||||
emailVerified bool
|
||||
linkingAllowed bool
|
||||
creationAllowed bool
|
||||
autoCreation bool
|
||||
autoUpdate bool
|
||||
pkce bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "default common tenant",
|
||||
fields: fields{
|
||||
name: "default common tenant",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
scopes: nil,
|
||||
options: nil,
|
||||
},
|
||||
want: want{
|
||||
name: "default common tenant",
|
||||
tenant: CommonTenant,
|
||||
emailVerified: false,
|
||||
linkingAllowed: false,
|
||||
creationAllowed: false,
|
||||
autoCreation: false,
|
||||
autoUpdate: false,
|
||||
pkce: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all set",
|
||||
fields: fields{
|
||||
name: "custom tenant",
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithTenant("tenant"),
|
||||
WithEmailVerified(),
|
||||
WithOAuthOptions(
|
||||
oauth.WithLinkingAllowed(),
|
||||
oauth.WithCreationAllowed(),
|
||||
oauth.WithAutoCreation(),
|
||||
oauth.WithAutoUpdate(),
|
||||
oauth.WithRelyingPartyOption(rp.WithPKCE(nil)),
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
name: "custom tenant",
|
||||
tenant: "tenant",
|
||||
emailVerified: true,
|
||||
linkingAllowed: true,
|
||||
creationAllowed: true,
|
||||
autoCreation: true,
|
||||
autoUpdate: true,
|
||||
pkce: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
a.Equal(tt.want.name, provider.Name())
|
||||
a.Equal(tt.want.tenant, provider.tenant)
|
||||
a.Equal(tt.want.emailVerified, provider.emailVerified)
|
||||
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
|
||||
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
|
||||
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
|
||||
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
|
||||
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
|
||||
})
|
||||
}
|
||||
}
|
106
apps/api/internal/idp/providers/azuread/session.go
Normal file
106
apps/api/internal/idp/providers/azuread/session.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/v3/pkg/http"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
var _ idp.Session = (*Session)(nil)
|
||||
|
||||
// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality
|
||||
type Session struct {
|
||||
*Provider
|
||||
Code string
|
||||
|
||||
OAuthSession *oauth.Session
|
||||
}
|
||||
|
||||
func NewSession(provider *Provider, code string) *Session {
|
||||
return &Session{Provider: provider, Code: code}
|
||||
}
|
||||
|
||||
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
|
||||
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
|
||||
return s.oauth().GetAuth(ctx)
|
||||
}
|
||||
|
||||
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
|
||||
func (s *Session) RetrievePreviousID() (string, error) {
|
||||
req, err := http.NewRequest("GET", userinfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken)
|
||||
userinfo := new(oidc.UserInfo)
|
||||
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return userinfo.Subject, nil
|
||||
}
|
||||
|
||||
// PersistentParameters implements the [idp.Session] interface.
|
||||
func (s *Session) PersistentParameters() map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
// FetchUser implements the [idp.Session] interface.
|
||||
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
|
||||
// call the specified userEndpoint and map the received information into an [idp.User].
|
||||
// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it.
|
||||
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
user, err = s.oauth().FetchUser(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// since azure will sign the id_token always with the issuer of the application it might differ from
|
||||
// the issuer the auth and token were based on, e.g. when allowing all account types to login,
|
||||
// then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
|
||||
// even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0`
|
||||
if s.Provider.tenant == CommonTenant ||
|
||||
s.Provider.tenant == OrganizationsTenant ||
|
||||
s.Provider.tenant == ConsumersTenant {
|
||||
return user, nil
|
||||
}
|
||||
idToken, ok := s.oauth().Tokens.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return user, nil
|
||||
}
|
||||
idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint()))
|
||||
s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.oauth().Tokens.IDToken = idToken
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Session) ExpiresAt() time.Time {
|
||||
if s.OAuthSession == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return s.OAuthSession.ExpiresAt()
|
||||
}
|
||||
|
||||
// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session].
|
||||
func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] {
|
||||
return s.oauth().Tokens
|
||||
}
|
||||
|
||||
func (s *Session) oauth() *oauth.Session {
|
||||
if s.OAuthSession != nil {
|
||||
return s.OAuthSession
|
||||
}
|
||||
s.OAuthSession = &oauth.Session{
|
||||
Code: s.Code,
|
||||
Provider: s.Provider.Provider,
|
||||
}
|
||||
return s.OAuthSession
|
||||
}
|
416
apps/api/internal/idp/providers/azuread/session_test.go
Normal file
416
apps/api/internal/idp/providers/azuread/session_test.go
Normal file
@@ -0,0 +1,416 @@
|
||||
package azuread
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/h2non/gock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
|
||||
)
|
||||
|
||||
func TestSession_FetchUser(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
options []ProviderOptions
|
||||
authURL string
|
||||
code string
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type want struct {
|
||||
err func(error) bool
|
||||
user idp.User
|
||||
id string
|
||||
firstName string
|
||||
lastName string
|
||||
displayName string
|
||||
nickName string
|
||||
preferredUsername string
|
||||
email string
|
||||
isEmailVerified bool
|
||||
phone string
|
||||
isPhoneVerified bool
|
||||
preferredLanguage language.Tag
|
||||
avatarURL string
|
||||
profile string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated session, error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: nil,
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, oauth.ErrCodeMissing)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user error",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(http.StatusInternalServerError)
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub2",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(err error) bool {
|
||||
return err.Error() == "http status not ok: 500 Internal Server Error "
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
isEmailVerified: false,
|
||||
},
|
||||
id: "id",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: false,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.English,
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful fetch with email verified",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
options: []ProviderOptions{
|
||||
WithEmailVerified(),
|
||||
},
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/v1.0/me").
|
||||
Reply(200).
|
||||
JSON(userinfo())
|
||||
},
|
||||
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
user: &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
isEmailVerified: true,
|
||||
},
|
||||
id: "id",
|
||||
firstName: "firstname",
|
||||
lastName: "lastname",
|
||||
displayName: "firstname lastname",
|
||||
nickName: "",
|
||||
preferredUsername: "username",
|
||||
email: "email",
|
||||
isEmailVerified: true,
|
||||
phone: "",
|
||||
isPhoneVerified: false,
|
||||
preferredLanguage: language.English,
|
||||
profile: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
|
||||
require.NoError(t, err)
|
||||
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
Code: tt.fields.code,
|
||||
|
||||
OAuthSession: &oauth.Session{
|
||||
AuthURL: tt.fields.authURL,
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider.Provider,
|
||||
Code: tt.fields.code,
|
||||
},
|
||||
}
|
||||
|
||||
user, err := session.FetchUser(context.Background())
|
||||
if tt.want.err != nil && !tt.want.err(err) {
|
||||
a.Fail("invalid error", err)
|
||||
}
|
||||
if tt.want.err == nil {
|
||||
a.NoError(err)
|
||||
a.Equal(tt.want.user, user)
|
||||
a.Equal(tt.want.id, user.GetID())
|
||||
a.Equal(tt.want.firstName, user.GetFirstName())
|
||||
a.Equal(tt.want.lastName, user.GetLastName())
|
||||
a.Equal(tt.want.displayName, user.GetDisplayName())
|
||||
a.Equal(tt.want.nickName, user.GetNickname())
|
||||
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
|
||||
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
|
||||
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
|
||||
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
|
||||
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
|
||||
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
|
||||
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
|
||||
a.Equal(tt.want.profile, user.GetProfile())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func userinfo() *User {
|
||||
return &User{
|
||||
ID: "id",
|
||||
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
|
||||
DisplayName: "firstname lastname",
|
||||
FirstName: "firstname",
|
||||
JobTitle: "title",
|
||||
Email: "email",
|
||||
MobilePhone: "mobile",
|
||||
OfficeLocation: "office",
|
||||
PreferredLanguage: "en",
|
||||
LastName: "lastname",
|
||||
UserPrincipalName: "username",
|
||||
}
|
||||
}
|
||||
|
||||
func TestSession_RetrievePreviousID(t *testing.T) {
|
||||
type fields struct {
|
||||
name string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURI string
|
||||
scopes []string
|
||||
httpMock func()
|
||||
tokens *oidc.Tokens[*oidc.IDTokenClaims]
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
err bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid token",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/oidc/userinfo").
|
||||
Reply(401)
|
||||
},
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
fields: fields{
|
||||
clientID: "clientID",
|
||||
clientSecret: "clientSecret",
|
||||
redirectURI: "redirectURI",
|
||||
httpMock: func() {
|
||||
gock.New("https://graph.microsoft.com").
|
||||
Get("/oidc/userinfo").
|
||||
Reply(200).
|
||||
JSON(`{"sub":"sub"}`)
|
||||
},
|
||||
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
|
||||
Token: &oauth2.Token{
|
||||
AccessToken: "accessToken",
|
||||
TokenType: oidc.BearerToken,
|
||||
},
|
||||
IDTokenClaims: oidc.NewIDTokenClaims(
|
||||
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
|
||||
"sub",
|
||||
[]string{"clientID"},
|
||||
time.Now().Add(1*time.Hour),
|
||||
time.Now().Add(-1*time.Second),
|
||||
"nonce",
|
||||
"",
|
||||
nil,
|
||||
"clientID",
|
||||
0,
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
id: "sub",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer gock.Off()
|
||||
tt.fields.httpMock()
|
||||
a := assert.New(t)
|
||||
|
||||
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
|
||||
require.NoError(t, err)
|
||||
session := &Session{
|
||||
Provider: provider,
|
||||
OAuthSession: &oauth.Session{
|
||||
Tokens: tt.fields.tokens,
|
||||
Provider: provider.Provider,
|
||||
}}
|
||||
|
||||
id, err := session.RetrievePreviousID()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
a.Equal(tt.res.id, id)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user