chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,253 @@
package azuread
import (
"fmt"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
const (
issuerTemplate string = "https://login.microsoftonline.com/%s/v2.0"
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
keysURLTemplate string = "https://login.microsoftonline.com/%s/discovery/v2.0/keys"
userURL string = "https://graph.microsoft.com/v1.0/me"
userinfoEndpoint string = "https://graph.microsoft.com/oidc/userinfo"
ScopeUserRead string = "User.Read"
)
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name.
//
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
type TenantType string
const (
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active
// Directory to sign in to the application.
CommonTenant TenantType = "common"
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign in to the application.
OrganizationsTenant TenantType = "organizations"
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign in to the application.
ConsumersTenant TenantType = "consumers"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for AzureAD (V2 Endpoints)
type Provider struct {
*oauth.Provider
tenant TenantType
emailVerified bool
options []oauth.ProviderOpts
}
// issuer returns the OIDC issuer based on the [TenantType]
func (p *Provider) issuer() string {
return fmt.Sprintf(issuerTemplate, p.tenant)
}
// keysEndpoint returns the OIDC jwks_url based on the [TenantType]
func (p *Provider) keysEndpoint() string {
return fmt.Sprintf(keysURLTemplate, p.tenant)
}
type ProviderOptions func(*Provider)
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
// default is CommonTenant
func WithTenant(tenantType TenantType) ProviderOptions {
return func(p *Provider) {
p.tenant = tenantType
}
}
// WithEmailVerified allows to set every email received as verified
func WithEmailVerified() ProviderOptions {
return func(p *Provider) {
p.emailVerified = true
}
}
// WithOAuthOptions allows to specify [oauth.ProviderOpts] like [oauth.WithLinkingAllowed]
func WithOAuthOptions(opts ...oauth.ProviderOpts) ProviderOptions {
return func(p *Provider) {
p.options = append(p.options, opts...)
}
}
// New creates an AzureAD provider using the [oauth.Provider] (OAuth 2.0 generic provider).
// By default, it uses the [CommonTenant] and unverified emails.
func New(name, clientID, clientSecret, redirectURI string, scopes []string, opts ...ProviderOptions) (*Provider, error) {
provider := &Provider{
tenant: CommonTenant,
options: make([]oauth.ProviderOpts, 0),
}
for _, opt := range opts {
opt(provider)
}
config := newConfig(provider.tenant, clientID, clientSecret, redirectURI, scopes)
rp, err := oauth.New(
config,
name,
userURL,
func() idp.User {
return &User{isEmailVerified: provider.emailVerified}
},
provider.options...,
)
if err != nil {
return nil, err
}
provider.Provider = rp
return provider, nil
}
func newConfig(tenant TenantType, clientID, secret, callbackURL string, scopes []string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: secret,
RedirectURL: callbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf(authURLTemplate, tenant),
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant),
},
Scopes: ensureMinimalScope(scopes),
}
}
// ensureMinimalScope ensures that at least openid and `User.Read` ist set
// if none is provided it will request `openid profile email phone User.Read`
func ensureMinimalScope(scopes []string) []string {
if len(scopes) == 0 {
return []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, ScopeUserRead}
}
var openIDSet, userReadSet bool
for _, scope := range scopes {
if scope == oidc.ScopeOpenID {
openIDSet = true
continue
}
if scope == ScopeUserRead {
userReadSet = true
continue
}
}
if !openIDSet {
scopes = append(scopes, oidc.ScopeOpenID)
}
if !userReadSet {
scopes = append(scopes, ScopeUserRead)
}
return scopes
}
func (p *Provider) User() idp.User {
return p.Provider.User()
}
// User represents the structure return on the userinfo endpoint and implements the [idp.User] interface
//
// AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified])
type User struct {
ID string `json:"id"`
BusinessPhones []domain.PhoneNumber `json:"businessPhones"`
DisplayName string `json:"displayName"`
FirstName string `json:"givenName"`
JobTitle string `json:"jobTitle"`
Email domain.EmailAddress `json:"mail"`
MobilePhone domain.PhoneNumber `json:"mobilePhone"`
OfficeLocation string `json:"officeLocation"`
PreferredLanguage string `json:"preferredLanguage"`
LastName string `json:"surname"`
UserPrincipalName string `json:"userPrincipalName"`
isEmailVerified bool
}
// GetID is an implementation of the [idp.User] interface.
func (u *User) GetID() string {
return u.ID
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *User) GetFirstName() string {
return u.FirstName
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *User) GetLastName() string {
return u.LastName
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *User) GetDisplayName() string {
return u.DisplayName
}
// GetNickname is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's nickname.
func (u *User) GetNickname() string {
return ""
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *User) GetPreferredUsername() string {
return u.UserPrincipalName
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *User) GetEmail() domain.EmailAddress {
if u.Email == "" {
// if the user used a social login on Azure as well, the email will be empty
// but is used as username
return domain.EmailAddress(u.UserPrincipalName)
}
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface
// returning the value specified in the creation of the [Provider].
// Default is false because AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified]).
func (u *User) IsEmailVerified() bool {
return u.isEmailVerified
}
// GetPhone is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's phone.
func (u *User) GetPhone() domain.PhoneNumber {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
// It returns false because AzureAD does not provide the user's phone.
func (u *User) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *User) GetPreferredLanguage() language.Tag {
return language.Make(u.PreferredLanguage)
}
// GetProfile is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's profile page.
func (u *User) GetProfile() string {
return ""
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *User) GetAvatarURL() string {
return ""
}

View File

@@ -0,0 +1,185 @@
package azuread
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/client/rp"
openid "github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
options []ProviderOptions
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "default common tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
},
},
{
name: "tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant(ConsumersTenant),
},
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email+phone+User.Read&state=testState",
},
},
{
name: "custom scopes",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{openid.ScopeOpenID, openid.ScopeProfile, "custom"},
options: []ProviderOptions{
WithTenant(ConsumersTenant),
},
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid+profile+custom+User.Read&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
r.NoError(err)
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")
r.NoError(err)
wantAuth, wantErr := tt.want.GetAuth(ctx)
gotAuth, gotErr := session.GetAuth(ctx)
a.Equal(wantAuth, gotAuth)
a.ErrorIs(gotErr, wantErr)
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
options []ProviderOptions
}
type want struct {
name string
tenant TenantType
emailVerified bool
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default common tenant",
fields: fields{
name: "default common tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: nil,
options: nil,
},
want: want{
name: "default common tenant",
tenant: CommonTenant,
emailVerified: false,
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all set",
fields: fields{
name: "custom tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant("tenant"),
WithEmailVerified(),
WithOAuthOptions(
oauth.WithLinkingAllowed(),
oauth.WithCreationAllowed(),
oauth.WithAutoCreation(),
oauth.WithAutoUpdate(),
oauth.WithRelyingPartyOption(rp.WithPKCE(nil)),
),
},
},
want: want{
name: "custom tenant",
tenant: "tenant",
emailVerified: true,
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.tenant, provider.tenant)
a.Equal(tt.want.emailVerified, provider.emailVerified)
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,106 @@
package azuread
import (
"context"
"net/http"
"time"
"github.com/zitadel/oidc/v3/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
var _ idp.Session = (*Session)(nil)
// Session extends the [oauth.Session] to be able to handle the id_token and to implement the [idp.SessionSupportsMigration] functionality
type Session struct {
*Provider
Code string
OAuthSession *oauth.Session
}
func NewSession(provider *Provider, code string) *Session {
return &Session{Provider: provider, Code: code}
}
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
func (s *Session) GetAuth(ctx context.Context) (idp.Auth, error) {
return s.oauth().GetAuth(ctx)
}
// RetrievePreviousID implements the [idp.SessionSupportsMigration] interface by returning the `sub` from the userinfo endpoint
func (s *Session) RetrievePreviousID() (string, error) {
req, err := http.NewRequest("GET", userinfoEndpoint, nil)
if err != nil {
return "", err
}
req.Header.Set("authorization", s.oauth().Tokens.TokenType+" "+s.oauth().Tokens.AccessToken)
userinfo := new(oidc.UserInfo)
if err := httphelper.HttpRequest(s.Provider.HttpClient(), req, &userinfo); err != nil {
return "", err
}
return userinfo.Subject, nil
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
// In case of a specific TenantID as [TenantType] it will additionally extract the id_token and validate it.
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
user, err = s.oauth().FetchUser(ctx)
if err != nil {
return nil, err
}
// since azure will sign the id_token always with the issuer of the application it might differ from
// the issuer the auth and token were based on, e.g. when allowing all account types to login,
// then the auth endpoint must be `https://login.microsoftonline.com/common/oauth2/v2.0/authorize`
// even though the issuer would be like `https://login.microsoftonline.com/d8cdd43f-fd94-4576-8deb-f3bfea72dc2e/v2.0`
if s.Provider.tenant == CommonTenant ||
s.Provider.tenant == OrganizationsTenant ||
s.Provider.tenant == ConsumersTenant {
return user, nil
}
idToken, ok := s.oauth().Tokens.Extra("id_token").(string)
if !ok {
return user, nil
}
idTokenVerifier := rp.NewIDTokenVerifier(s.Provider.issuer(), s.Provider.OAuthConfig().ClientID, rp.NewRemoteKeySet(s.Provider.HttpClient(), s.Provider.keysEndpoint()))
s.oauth().Tokens.IDTokenClaims, err = rp.VerifyTokens[*oidc.IDTokenClaims](ctx, s.oauth().Tokens.AccessToken, idToken, idTokenVerifier)
if err != nil {
return nil, err
}
s.oauth().Tokens.IDToken = idToken
return user, nil
}
func (s *Session) ExpiresAt() time.Time {
if s.OAuthSession == nil {
return time.Time{}
}
return s.OAuthSession.ExpiresAt()
}
// Tokens returns the [oidc.Tokens] of the underlying [oauth.Session].
func (s *Session) Tokens() *oidc.Tokens[*oidc.IDTokenClaims] {
return s.oauth().Tokens
}
func (s *Session) oauth() *oauth.Session {
if s.OAuthSession != nil {
return s.OAuthSession
}
s.OAuthSession = &oauth.Session{
Code: s.Code,
Provider: s.Provider.Provider,
}
return s.OAuthSession
}

View File

@@ -0,0 +1,416 @@
package azuread
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v3/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
options []ProviderOptions
authURL string
code string
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, oauth.ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(http.StatusInternalServerError)
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
isEmailVerified: false,
},
id: "id",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
profile: "",
},
},
{
name: "successful fetch with email verified",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithEmailVerified(),
},
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/v1.0/me").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
isEmailVerified: true,
},
id: "id",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
session := &Session{
Provider: provider,
Code: tt.fields.code,
OAuthSession: &oauth.Session{
AuthURL: tt.fields.authURL,
Tokens: tt.fields.tokens,
Provider: provider.Provider,
Code: tt.fields.code,
},
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(domain.EmailAddress(tt.want.email), user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(domain.PhoneNumber(tt.want.phone), user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *User {
return &User{
ID: "id",
BusinessPhones: []domain.PhoneNumber{"phone1", "phone2"},
DisplayName: "firstname lastname",
FirstName: "firstname",
JobTitle: "title",
Email: "email",
MobilePhone: "mobile",
OfficeLocation: "office",
PreferredLanguage: "en",
LastName: "lastname",
UserPrincipalName: "username",
}
}
func TestSession_RetrievePreviousID(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
scopes []string
httpMock func()
tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
type res struct {
id string
err bool
}
tests := []struct {
name string
fields fields
res res
}{
{
name: "invalid token",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(401)
},
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
res: res{
err: true,
},
},
{
name: "success",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(200).
JSON(`{"sub":"sub"}`)
},
tokens: &oidc.Tokens[*oidc.IDTokenClaims]{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
res: res{
id: "sub",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes)
require.NoError(t, err)
session := &Session{
Provider: provider,
OAuthSession: &oauth.Session{
Tokens: tt.fields.tokens,
Provider: provider.Provider,
}}
id, err := session.RetrievePreviousID()
if tt.res.err {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
a.Equal(tt.res.id, id)
})
}
}