feat: add basic structure of idp templates (#5053)

add basic structure and implement first providers for IDP templates to be able to manage and use them in the future
This commit is contained in:
Livio Spring
2023-01-23 08:11:40 +01:00
committed by GitHub
parent 7b5135e637
commit 598a4d2d4b
29 changed files with 3907 additions and 54 deletions

View File

@@ -0,0 +1,205 @@
package azuread
import (
"fmt"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
const (
authURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/authorize"
tokenURLTemplate string = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
userinfoURL string = "https://graph.microsoft.com/oidc/userinfo"
)
// TenantType are the well known tenant types to scope the users that can authenticate. TenantType is not an
// exclusive list of Azure Tenants which can be used. A consumer can also use their own Tenant ID to scope
// authentication to their specific Tenant either through the Tenant ID or the friendly domain name.
//
// see also https://docs.microsoft.com/en-us/azure/active-directory/develop/active-directory-v2-protocols#endpoints
type TenantType string
const (
// CommonTenant allows users with both personal Microsoft accounts and work/school accounts from Azure Active
// Directory to sign in to the application.
CommonTenant TenantType = "common"
// OrganizationsTenant allows only users with work/school accounts from Azure Active Directory to sign in to the application.
OrganizationsTenant TenantType = "organizations"
// ConsumersTenant allows only users with personal Microsoft accounts (MSA) to sign in to the application.
ConsumersTenant TenantType = "consumers"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for AzureAD (V2 Endpoints)
type Provider struct {
*oauth.Provider
tenant TenantType
emailVerified bool
options []oauth.ProviderOpts
}
type ProviderOptions func(*Provider)
// WithTenant allows to set a [TenantType] (can also be a Tenant ID)
// default is CommonTenant
func WithTenant(tenantType TenantType) ProviderOptions {
return func(p *Provider) {
p.tenant = tenantType
}
}
// WithEmailVerified allows to set every email received as verified
func WithEmailVerified() ProviderOptions {
return func(p *Provider) {
p.emailVerified = true
}
}
// WithOAuthOptions allows to specify [oauth.ProviderOpts] like [oauth.WithLinkingAllowed]
func WithOAuthOptions(opts ...oauth.ProviderOpts) ProviderOptions {
return func(p *Provider) {
p.options = append(p.options, opts...)
}
}
// New creates an AzureAD provider using the [oauth.Provider] (OAuth 2.0 generic provider).
// By default, it uses the [CommonTenant] and unverified emails.
func New(name, clientID, clientSecret, redirectURI string, opts ...ProviderOptions) (*Provider, error) {
provider := &Provider{
tenant: CommonTenant,
options: make([]oauth.ProviderOpts, 0),
}
for _, opt := range opts {
opt(provider)
}
config := newConfig(provider.tenant, clientID, clientSecret, redirectURI, []string{oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail})
rp, err := oauth.New(
config,
name,
userinfoURL,
func() idp.User {
return &User{isEmailVerified: provider.emailVerified}
},
provider.options...,
)
if err != nil {
return nil, err
}
provider.Provider = rp
return provider, nil
}
func newConfig(tenant TenantType, clientID, secret, callbackURL string, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: clientID,
ClientSecret: secret,
RedirectURL: callbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf(authURLTemplate, tenant),
TokenURL: fmt.Sprintf(tokenURLTemplate, tenant),
},
Scopes: []string{oidc.ScopeOpenID},
}
if len(scopes) > 0 {
c.Scopes = scopes
}
return c
}
// User represents the structure return on the userinfo endpoint and implements the [idp.User] interface
//
// AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified])
type User struct {
Sub string `json:"sub"`
FamilyName string `json:"family_name"`
GivenName string `json:"given_name"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Email string `json:"email"`
Picture string `json:"picture"`
isEmailVerified bool
}
// GetID is an implementation of the [idp.User] interface.
func (u *User) GetID() string {
return u.Sub
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *User) GetFirstName() string {
return u.GivenName
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *User) GetLastName() string {
return u.FamilyName
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *User) GetDisplayName() string {
return u.Name
}
// GetNickname is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's nickname.
func (u *User) GetNickname() string {
return ""
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *User) GetPreferredUsername() string {
return u.PreferredUsername
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *User) GetEmail() string {
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface
// returning the value specified in the creation of the [Provider].
// Default is false because AzureAD does not return an `email_verified` claim.
// The verification can be automatically activated on the provider ([WithEmailVerified]).
func (u *User) IsEmailVerified() bool {
return u.isEmailVerified
}
// GetPhone is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's phone.
func (u *User) GetPhone() string {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
// It returns false because AzureAD does not provide the user's phone.
func (u *User) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
// It returns [language.Und] because AzureAD does not provide the user's language
func (u *User) GetPreferredLanguage() language.Tag {
// AzureAD does not provide the user's language
return language.Und
}
// GetProfile is an implementation of the [idp.User] interface.
// It returns an empty string because AzureAD does not provide the user's profile page.
func (u *User) GetProfile() string {
return ""
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *User) GetAvatarURL() string {
return u.Picture
}

View File

@@ -0,0 +1,162 @@
package azuread
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
options []ProviderOptions
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "default common tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
},
},
{
name: "tenant",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant(ConsumersTenant),
},
},
want: &oidc.Session{
AuthURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.options...)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
options []ProviderOptions
}
type want struct {
name string
tenant TenantType
emailVerified bool
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default common tenant",
fields: fields{
name: "default common tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: nil,
},
want: want{
name: "default common tenant",
tenant: CommonTenant,
emailVerified: false,
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all set",
fields: fields{
name: "custom tenant",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithTenant("tenant"),
WithEmailVerified(),
WithOAuthOptions(
oauth.WithLinkingAllowed(),
oauth.WithCreationAllowed(),
oauth.WithAutoCreation(),
oauth.WithAutoUpdate(),
oauth.WithRelyingPartyOption(rp.WithPKCE(nil)),
),
},
},
want: want{
name: "custom tenant",
tenant: "tenant",
emailVerified: true,
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.tenant, provider.tenant)
a.Equal(tt.want.emailVerified, provider.emailVerified)
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,285 @@
package azuread
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
clientID string
clientSecret string
redirectURI string
httpMock func()
options []ProviderOptions
authURL string
code string
tokens *oidc.Tokens
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, oauth.ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(http.StatusInternalServerError)
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
Sub: "sub",
FamilyName: "lastname",
GivenName: "firstname",
Name: "firstname lastname",
PreferredUsername: "username",
Email: "email",
Picture: "picture",
isEmailVerified: false,
},
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "picture",
profile: "",
},
},
{
name: "successful fetch with email verified",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
options: []ProviderOptions{
WithEmailVerified(),
},
httpMock: func() {
gock.New("https://graph.microsoft.com").
Get("/oidc/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://login.microsoftonline.com/consumers/oauth2/v2.0/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid+profile+email&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://login.microsoftonline.com/consumers/oauth2/v2.0",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
user: &User{
Sub: "sub",
FamilyName: "lastname",
GivenName: "firstname",
Name: "firstname lastname",
PreferredUsername: "username",
Email: "email",
Picture: "picture",
isEmailVerified: true,
},
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "picture",
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.options...)
require.NoError(t, err)
session := &oauth.Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider.Provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() oidc.UserInfoSetter {
userinfo := oidc.NewUserInfo()
userinfo.SetSubject("sub")
userinfo.SetName("firstname lastname")
userinfo.SetPreferredUsername("username")
userinfo.SetNickname("nickname")
userinfo.SetEmail("email", false) // azure add does not send the email_verified claim
userinfo.SetPicture("picture")
userinfo.SetGivenName("firstname")
userinfo.SetFamilyName("lastname")
return userinfo
}

View File

@@ -0,0 +1,189 @@
package github
import (
"strconv"
"time"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
const (
authURL = "https://github.com/login/oauth/authorize"
tokenURL = "https://github.com/login/oauth/access_token"
profileURL = "https://api.github.com/user"
name = "GitHub"
)
var _ idp.Provider = (*Provider)(nil)
// New creates a GitHub.com provider using the [oauth.Provider] (OAuth 2.0 generic provider)
func New(clientID, secret, callbackURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
return NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL, scopes, options...)
}
// NewCustomURL creates a GitHub provider using the [oauth.Provider] (OAuth 2.0 generic provider)
// with custom endpoints, e.g. GitHub Enterprise server
func NewCustomURL(name, clientID, secret, callbackURL, authURL, tokenURL, profileURL string, scopes []string, options ...oauth.ProviderOpts) (*Provider, error) {
rp, err := oauth.New(
newConfig(clientID, secret, callbackURL, authURL, tokenURL, scopes),
name,
profileURL,
func() idp.User {
return new(User)
},
options...,
)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}
// Provider is the [idp.Provider] implementation for GitHub
type Provider struct {
*oauth.Provider
}
func newConfig(clientID, secret, callbackURL, authURL, tokenURL string, scopes []string) *oauth2.Config {
c := &oauth2.Config{
ClientID: clientID,
ClientSecret: secret,
RedirectURL: callbackURL,
Endpoint: oauth2.Endpoint{
AuthURL: authURL,
TokenURL: tokenURL,
},
Scopes: scopes,
}
return c
}
// User is a representation of the authenticated GitHub user and implements the [idp.User] interface
// https://docs.github.com/en/rest/users/users?apiVersion=2022-11-28#get-the-authenticated-user
type User struct {
Login string `json:"login"`
ID int `json:"id"`
NodeId string `json:"node_id"`
AvatarUrl string `json:"avatar_url"`
GravatarId string `json:"gravatar_id"`
Url string `json:"url"`
HtmlUrl string `json:"html_url"`
FollowersUrl string `json:"followers_url"`
FollowingUrl string `json:"following_url"`
GistsUrl string `json:"gists_url"`
StarredUrl string `json:"starred_url"`
SubscriptionsUrl string `json:"subscriptions_url"`
OrganizationsUrl string `json:"organizations_url"`
ReposUrl string `json:"repos_url"`
EventsUrl string `json:"events_url"`
ReceivedEventsUrl string `json:"received_events_url"`
Type string `json:"type"`
SiteAdmin bool `json:"site_admin"`
Name string `json:"name"`
Company string `json:"company"`
Blog string `json:"blog"`
Location string `json:"location"`
Email string `json:"email"`
Hireable bool `json:"hireable"`
Bio string `json:"bio"`
TwitterUsername string `json:"twitter_username"`
PublicRepos int `json:"public_repos"`
PublicGists int `json:"public_gists"`
Followers int `json:"followers"`
Following int `json:"following"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
PrivateGists int `json:"private_gists"`
TotalPrivateRepos int `json:"total_private_repos"`
OwnedPrivateRepos int `json:"owned_private_repos"`
DiskUsage int `json:"disk_usage"`
Collaborators int `json:"collaborators"`
TwoFactorAuthentication bool `json:"two_factor_authentication"`
Plan struct {
Name string `json:"name"`
Space int `json:"space"`
PrivateRepos int `json:"private_repos"`
Collaborators int `json:"collaborators"`
} `json:"plan"`
}
// GetID is an implementation of the [idp.User] interface.
func (u *User) GetID() string {
return strconv.Itoa(u.ID)
}
// GetFirstName is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's firstname.
func (u *User) GetFirstName() string {
return ""
}
// GetLastName is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's lastname.
func (u *User) GetLastName() string {
// GitHub does not provide the user's lastname
return ""
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *User) GetDisplayName() string {
return u.Name
}
// GetNickname is an implementation of the [idp.User] interface
// returning the login name of the GitHub user.
func (u *User) GetNickname() string {
return u.Login
}
// GetPreferredUsername is an implementation of the [idp.User] interface
// returning the login name of the GitHub user.
func (u *User) GetPreferredUsername() string {
return u.Login
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *User) GetEmail() string {
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface.
// It returns true because GitHub validates emails themselves.
func (u *User) IsEmailVerified() bool {
return true
}
// GetPhone is an implementation of the [idp.User] interface.
// It returns an empty string because GitHub does not provide the user's phone.
func (u *User) GetPhone() string {
return ""
}
// IsPhoneVerified is an implementation of the [idp.User] interface
// it returns false because GitHub does not provide the user's phone
func (u *User) IsPhoneVerified() bool {
return false
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
// It returns [language.Und] because GitHub does not provide the user's language.
func (u *User) GetPreferredLanguage() language.Tag {
return language.Und
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *User) GetProfile() string {
return u.HtmlUrl
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *User) GetAvatarURL() string {
return u.AvatarUrl
}

View File

@@ -0,0 +1,53 @@
package github
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
scopes []string
options []oauth.ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oauth.Session{
AuthURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}

View File

@@ -0,0 +1,215 @@
package github
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
httpMock func()
authURL string
code string
tokens *oidc.Tokens
scopes []string
options []oauth.ProviderOpts
}
type args struct {
session idp.Session
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(200).
JSON(userinfo())
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: nil,
},
args: args{
&oauth.Session{},
},
want: want{
err: func(err error) bool {
return errors.Is(err, oauth.ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(http.StatusInternalServerError)
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
args: args{
&oauth.Session{},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://api.github.com").
Get("/user").
Reply(200).
JSON(userinfo())
},
authURL: "https://github.com/login/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
args: args{
&oauth.Session{},
},
want: want{
user: &User{
Login: "login",
ID: 1,
AvatarUrl: "avatarURL",
GravatarId: "gravatarID",
Name: "name",
Email: "email",
HtmlUrl: "htmlURL",
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
},
id: "1",
firstName: "",
lastName: "",
displayName: "name",
nickName: "login",
preferredUsername: "login",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "avatarURL",
profile: "htmlURL",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.options...)
require.NoError(t, err)
session := &oauth.Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider.Provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() *User {
return &User{
Login: "login",
ID: 1,
AvatarUrl: "avatarURL",
GravatarId: "gravatarID",
Name: "name",
Email: "email",
HtmlUrl: "htmlURL",
CreatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
UpdatedAt: time.Date(2023, 01, 10, 11, 10, 35, 0, time.UTC),
}
}

View File

@@ -0,0 +1,35 @@
package gitlab
import (
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
issuer = "https://gitlab.com"
name = "GitLab"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for Gitlab
type Provider struct {
*oidc.Provider
}
// New creates a GitLab.com provider using the [oidc.Provider] (OIDC generic provider)
func New(clientID, clientSecret, redirectURI string, options ...oidc.ProviderOpts) (*Provider, error) {
return NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI, options...)
}
// NewCustomIssuer creates a GitLab provider using the [oidc.Provider] (OIDC generic provider)
// with a custom issuer for self-managed instances
func NewCustomIssuer(name, issuer, clientID, clientSecret, redirectURI string, options ...oidc.ProviderOpts) (*Provider, error) {
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, oidc.DefaultMapper, options...)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}

View File

@@ -0,0 +1,52 @@
package gitlab
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
opts []oidc.ProviderOpts
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.opts...)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}

View File

@@ -0,0 +1,212 @@
package gitlab
import (
"context"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
openid "github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
httpMock func()
authURL string
code string
tokens *openid.Tokens
options []oidc.ProviderOpts
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: oidc.ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://gitlab.com/oauth").
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://gitlab.com/oauth/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
// call the real discovery endpoint
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.options...)
require.NoError(t, err)
session := &oidc.Session{
Provider: provider.Provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() openid.UserInfoSetter {
info := openid.NewUserInfo()
info.SetSubject("sub")
info.SetGivenName("firstname")
info.SetFamilyName("lastname")
info.SetName("firstname lastname")
info.SetNickname("nickname")
info.SetPreferredUsername("username")
info.SetEmail("email", true)
info.SetPhone("phone", true)
info.SetLocale(language.English)
info.SetPicture("picture")
info.SetProfile("profile")
return info
}

View File

@@ -0,0 +1,47 @@
package google
import (
openid "github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
const (
issuer = "https://accounts.google.com"
name = "Google"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for Google
type Provider struct {
*oidc.Provider
}
// New creates a Google provider using the [oidc.Provider] (OIDC generic provider)
func New(clientID, clientSecret, redirectURI string, opts ...oidc.ProviderOpts) (*Provider, error) {
rp, err := oidc.New(name, issuer, clientID, clientSecret, redirectURI, userMapper, opts...)
if err != nil {
return nil, err
}
return &Provider{
Provider: rp,
}, nil
}
var userMapper = func(info openid.UserInfo) idp.User {
return &User{oidc.DefaultMapper(info)}
}
// User is a representation of the authenticated Google and implements the [idp.User] interface
// by wrapping an [idp.User] (implemented by [oidc.User]). It overwrites the [GetPreferredUsername] to use the `email` claim.
type User struct {
idp.User
}
// GetPreferredUsername implements the [idp.User] interface.
// It returns the email, because Google does not return a username.
func (u *User) GetPreferredUsername() string {
return u.GetEmail()
}

View File

@@ -0,0 +1,51 @@
package google
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
},
want: &oidc.Session{
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}

View File

@@ -0,0 +1,210 @@
package google
import (
"context"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
openid "github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp/providers/oidc"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
clientID string
clientSecret string
redirectURI string
httpMock func()
authURL string
code string
tokens *openid.Tokens
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
hostedDomain string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: oidc.ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
httpMock: func() {
gock.New("https://openidconnect.googleapis.com").
Get("/v1/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://accounts.google.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &openid.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: openid.BearerToken,
},
IDTokenClaims: openid.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "",
preferredUsername: "email",
email: "email",
isEmailVerified: true,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "",
hostedDomain: "hosted domain",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock()
a := assert.New(t)
// call the real discovery endpoint
gock.New(issuer).Get(openid.DiscoveryEndpoint).EnableNetworking()
provider, err := New(tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI)
require.NoError(t, err)
session := &oidc.Session{
Provider: provider.Provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() openid.UserInfoSetter {
info := openid.NewUserInfo()
info.SetSubject("sub")
info.SetGivenName("firstname")
info.SetFamilyName("lastname")
info.SetName("firstname lastname")
info.SetEmail("email", true)
info.SetLocale(language.English)
info.SetPicture("picture")
info.AppendClaims("hd", "hosted domain")
return info
}

View File

@@ -0,0 +1,136 @@
package jwt
import (
"context"
"encoding/base64"
"errors"
"net/url"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
const (
queryAuthRequestID = "authRequestID"
queryUserAgentID = "userAgentID"
)
var _ idp.Provider = (*Provider)(nil)
var (
ErrNoTokens = errors.New("no tokens provided")
ErrMissingUserAgentID = errors.New("userAgentID missing")
)
// Provider is the [idp.Provider] implementation for a JWT provider
type Provider struct {
name string
headerName string
issuer string
jwtEndpoint string
keysEndpoint string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
encryptionAlg crypto.EncryptionAlgorithm
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// New creates a JWT provider
func New(name, issuer, jwtEndpoint, keysEndpoint, headerName string, encryptionAlg crypto.EncryptionAlgorithm, options ...ProviderOpts) (*Provider, error) {
provider := &Provider{
name: name,
issuer: issuer,
jwtEndpoint: jwtEndpoint,
keysEndpoint: keysEndpoint,
headerName: headerName,
encryptionAlg: encryptionAlg,
}
for _, option := range options {
option(provider)
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
// with the authRequest and encrypted userAgent ids.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
if len(params) != 1 {
return nil, ErrMissingUserAgentID
}
userAgentID, ok := params[0].(string)
if !ok {
return nil, ErrMissingUserAgentID
}
redirect, err := url.Parse(p.jwtEndpoint)
if err != nil {
return nil, err
}
q := redirect.Query()
q.Set(queryAuthRequestID, state)
nonce, err := p.encryptionAlg.Encrypt([]byte(userAgentID))
if err != nil {
return nil, err
}
q.Set(queryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce))
redirect.RawQuery = q.Encode()
return &Session{AuthURL: redirect.String()}, nil
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,222 @@
package jwt
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
}
type args struct {
params []any
}
type want struct {
session idp.Session
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "missing userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{
name: "invalid userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
0,
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{
name: "successful auth",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
"agent",
},
},
want: want{
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=YWdlbnQ"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
)
require.NoError(t, err)
session, err := provider.BeginAuth(context.Background(), "testState", tt.args.params...)
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.session.GetAuthURL(), session.GetAuthURL())
}
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
opts []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: nil,
},
want: want{
name: "jwt",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
},
},
want: want{
name: "jwt",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
tt.fields.opts...,
)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
})
}
}

View File

@@ -0,0 +1,72 @@
package jwt
import (
"context"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the JWT provider
type Session struct {
AuthURL string
Tokens *oidc.Tokens
}
// GetAuthURL implements the [idp.Session] interface
func (s *Session) GetAuthURL() string {
return s.AuthURL
}
// FetchUser implements the [idp.Session] interface.
// It will map the received idToken into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
return nil, ErrNoTokens
}
return &User{s.Tokens.IDTokenClaims}, nil
}
type User struct {
oidc.IDTokenClaims
}
func (u *User) GetID() string {
return u.IDTokenClaims.GetSubject()
}
func (u *User) GetFirstName() string {
return u.IDTokenClaims.GetGivenName()
}
func (u *User) GetLastName() string {
return u.IDTokenClaims.GetFamilyName()
}
func (u *User) GetDisplayName() string {
return u.IDTokenClaims.GetName()
}
func (u *User) GetNickname() string {
return u.IDTokenClaims.GetNickname()
}
func (u *User) GetPhone() string {
return u.IDTokenClaims.GetPhoneNumber()
}
func (u *User) IsPhoneVerified() bool {
return u.IDTokenClaims.IsPhoneNumberVerified()
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.IDTokenClaims.GetLocale()
}
func (u *User) GetAvatarURL() string {
return u.IDTokenClaims.GetPicture()
}

View File

@@ -0,0 +1,145 @@
package jwt
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
authURL string
tokens *oidc.Tokens
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "no tokens",
fields: fields{},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrNoTokens)
},
},
},
{
name: "successful fetch",
fields: fields{
authURL: "https://auth.com/jwt?authRequestID=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{},
IDTokenClaims: func() oidc.IDTokenClaims {
claims := oidc.EmptyIDTokenClaims()
userinfo := oidc.NewUserInfo()
userinfo.SetSubject("sub")
userinfo.SetPicture("picture")
userinfo.SetName("firstname lastname")
userinfo.SetEmail("email", true)
userinfo.SetGivenName("firstname")
userinfo.SetFamilyName("lastname")
userinfo.SetNickname("nickname")
userinfo.SetPreferredUsername("username")
userinfo.SetProfile("profile")
userinfo.SetPhone("phone", true)
userinfo.SetLocale(language.English)
claims.SetUserinfo(userinfo)
return claims
}(),
},
},
want: want{
user: &User{
IDTokenClaims: func() oidc.IDTokenClaims {
claims := oidc.EmptyIDTokenClaims()
userinfo := oidc.NewUserInfo()
userinfo.SetSubject("sub")
userinfo.SetPicture("picture")
userinfo.SetName("firstname lastname")
userinfo.SetEmail("email", true)
userinfo.SetGivenName("firstname")
userinfo.SetFamilyName("lastname")
userinfo.SetNickname("nickname")
userinfo.SetPreferredUsername("username")
userinfo.SetProfile("profile")
userinfo.SetPhone("phone", true)
userinfo.SetLocale(language.English)
claims.SetUserinfo(userinfo)
return claims
}(),
},
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
session := &Session{
AuthURL: tt.fields.authURL,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}

View File

@@ -0,0 +1,102 @@
package oauth
import (
"encoding/json"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.User = (*UserMapper)(nil)
// UserMapper is an implementation of [idp.User].
// It can be used in ZITADEL actions to map the raw `info`
type UserMapper struct {
ID string
FirstName string
LastName string
DisplayName string
NickName string
PreferredUsername string
Email string
EmailVerified bool
Phone string
PhoneVerified bool
PreferredLanguage string
AvatarURL string
Profile string
info map[string]interface{}
}
func (u *UserMapper) UnmarshalJSON(data []byte) error {
if u.info == nil {
u.info = make(map[string]interface{})
}
return json.Unmarshal(data, &u.info)
}
// GetID is an implementation of the [idp.User] interface.
func (u *UserMapper) GetID() string {
return u.ID
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetFirstName() string {
return u.FirstName
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetLastName() string {
return u.LastName
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetDisplayName() string {
return u.DisplayName
}
// GetNickname is an implementation of the [idp.User] interface.
func (u *UserMapper) GetNickname() string {
return u.NickName
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredUsername() string {
return u.PreferredUsername
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *UserMapper) GetEmail() string {
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsEmailVerified() bool {
return u.EmailVerified
}
// GetPhone is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPhone() string {
return u.Phone
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsPhoneVerified() bool {
return u.PhoneVerified
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredLanguage() language.Tag {
return language.Make(u.PreferredLanguage)
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *UserMapper) GetAvatarURL() string {
return u.AvatarURL
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *UserMapper) GetProfile() string {
return u.Profile
}

View File

@@ -0,0 +1,112 @@
package oauth
import (
"context"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic OAuth 2.0 provider
type Provider struct {
rp.RelyingParty
options []rp.Option
name string
userEndpoint string
userMapper func() idp.User
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
return func(p *Provider) {
p.options = append(p.options, option)
}
}
// New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userEndpoint: userEndpoint,
userMapper: userMapper,
}
for _, option := range options {
option(provider)
}
provider.RelyingParty, err = rp.NewRelyingPartyOAuth(config, provider.options...)
if err != nil {
return nil, err
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) {
url := rp.AuthURL(state, p.RelyingParty)
return &Session{AuthURL: url, Provider: p}, nil
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,153 @@
package oauth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
options []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: nil,
},
want: want{
name: "oauth",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
},
want: want{
name: "oauth",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,63 @@
package oauth
import (
"context"
"errors"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
var ErrCodeMissing = errors.New("no auth code provided")
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OAuth2.0 provider.
type Session struct {
AuthURL string
Code string
Tokens *oidc.Tokens
Provider *Provider
}
// GetAuthURL implements the [idp.Session] interface.
func (s *Session) GetAuthURL() string {
return s.AuthURL
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
if err = s.authorize(ctx); err != nil {
return nil, err
}
}
req, err := http.NewRequest("GET", s.Provider.userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
mapper := s.Provider.userMapper()
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil {
return nil, err
}
return mapper, nil
}
func (s *Session) authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
s.Tokens, err = rp.CodeExchange(ctx, s.Code, s.Provider.RelyingParty)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,273 @@
package oauth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_FetchUser(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
httpMock func(issuer string)
userMapper func() idp.User
authURL string
code string
tokens *oidc.Tokens
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(http.StatusInternalServerError)
},
userMapper: func() idp.User {
return &UserMapper{
ID: "userID",
}
},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return &UserMapper{}
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
user: &UserMapper{
info: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
{
name: "successful fetch with code exchange",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Post("/token").
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
Reply(200).
JSON(&oidc.AccessTokenResponse{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
RefreshToken: "",
ExpiresIn: 3600,
IDToken: "",
State: "testState"})
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return &UserMapper{}
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
code: "code",
},
want: want{
user: &UserMapper{
info: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock("https://oauth2.com")
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
require.NoError(t, err)
session := &Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}

View File

@@ -0,0 +1,116 @@
package oidc
import (
"context"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic OIDC provider
type Provider struct {
rp.RelyingParty
options []rp.Option
name string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
userInfoMapper func(info oidc.UserInfo) idp.User
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
return func(p *Provider) {
p.options = append(p.options, option)
}
}
type UserInfoMapper func(info oidc.UserInfo) idp.User
var DefaultMapper UserInfoMapper = func(info oidc.UserInfo) idp.User {
return NewUser(info)
}
// New creates a generic OIDC provider
func New(name, issuer, clientID, clientSecret, redirectURI string, userInfoMapper UserInfoMapper, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userInfoMapper: userInfoMapper,
}
for _, option := range options {
option(provider)
}
provider.RelyingParty, err = rp.NewRelyingPartyOIDC(issuer, clientID, clientSecret, redirectURI, []string{oidc.ScopeOpenID}, provider.options...)
if err != nil {
return nil, err
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OIDC authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) {
url := rp.AuthURL(state, p.RelyingParty)
return &Session{AuthURL: url, Provider: p}, nil
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,182 @@
package oidc
import (
"context"
"testing"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
userMapper func(info oidc.UserInfo) idp.User
httpMock func(issuer string)
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
},
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.userMapper)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
userMapper func(info oidc.UserInfo) idp.User
opts []ProviderOpts
httpMock func(issuer string)
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
opts: nil,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
},
want: want{
name: "oidc",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
opts: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
},
want: want{
name: "oidc",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.userMapper, tt.fields.opts...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
})
}
}

View File

@@ -0,0 +1,99 @@
package oidc
import (
"context"
"errors"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
var ErrCodeMissing = errors.New("no auth code provided")
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OIDC provider.
type Session struct {
Provider *Provider
AuthURL string
Code string
Tokens *oidc.Tokens
}
// GetAuthURL implements the [idp.Session] interface.
func (s *Session) GetAuthURL() string {
return s.AuthURL
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OIDC code exchange if needed to retrieve the tokens,
// call the userinfo endpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
if err = s.authorize(ctx); err != nil {
return nil, err
}
}
info, err := rp.Userinfo(
s.Tokens.AccessToken,
s.Tokens.TokenType,
s.Tokens.IDTokenClaims.GetSubject(),
s.Provider.RelyingParty,
)
if err != nil {
return nil, err
}
u := s.Provider.userInfoMapper(info)
return u, nil
}
func (s *Session) authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
s.Tokens, err = rp.CodeExchange(ctx, s.Code, s.Provider.RelyingParty)
return err
}
func NewUser(info oidc.UserInfo) *User {
return &User{UserInfo: info}
}
type User struct {
oidc.UserInfo
}
func (u *User) GetID() string {
return u.GetSubject()
}
func (u *User) GetFirstName() string {
return u.GetGivenName()
}
func (u *User) GetLastName() string {
return u.GetFamilyName()
}
func (u *User) GetDisplayName() string {
return u.GetName()
}
func (u *User) GetPhone() string {
return u.GetPhoneNumber()
}
func (u *User) IsPhoneVerified() bool {
return u.IsPhoneNumberVerified()
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.GetLocale()
}
func (u *User) GetAvatarURL() string {
return u.GetPicture()
}

View File

@@ -0,0 +1,392 @@
package oidc
import (
"context"
"encoding/json"
"errors"
"testing"
"time"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"gopkg.in/square/go-jose.v2"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
name string
issuer string
clientID string
clientSecret string
redirectURI string
userMapper func(oidc.UserInfo) idp.User
httpMock func(issuer string)
authURL string
code string
tokens *oidc.Tokens
}
type want struct {
err error
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
},
want: want{
err: ErrCodeMissing,
},
},
{
name: "userinfo error",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://issuer.com",
"sub2",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
err: rp.ErrUserInfoSubNotMatching,
},
},
{
name: "successful fetch",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
IDTokenClaims: oidc.NewIDTokenClaims(
"https://issuer.com",
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Second),
"nonce",
"",
nil,
"clientID",
0,
),
},
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
{
name: "successful fetch with token exchange",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
JwksURI: issuer + "/keys",
UserinfoEndpoint: issuer + "/userinfo",
})
gock.New(issuer).
Post("/token").
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
Reply(200).
JSON(tokenResponse(t, issuer))
gock.New(issuer).
Get("/keys").
Reply(200).
JSON(keys(t))
gock.New(issuer).
Get("/userinfo").
Reply(200).
JSON(userinfo())
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState",
tokens: nil,
code: "code",
},
want: want{
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock(tt.fields.issuer)
a := assert.New(t)
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.userMapper)
require.NoError(t, err)
session := &Session{
Provider: provider,
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !errors.Is(err, tt.want.err) {
a.Fail("invalid error", "expected %v, got %v", tt.want.err, err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}
func userinfo() oidc.UserInfoSetter {
info := oidc.NewUserInfo()
info.SetSubject("sub")
info.SetGivenName("firstname")
info.SetFamilyName("lastname")
info.SetName("firstname lastname")
info.SetNickname("nickname")
info.SetPreferredUsername("username")
info.SetEmail("email", true)
info.SetPhone("phone", true)
info.SetLocale(language.English)
info.SetPicture("picture")
info.SetProfile("profile")
return info
}
func tokenResponse(t *testing.T, issuer string) *oidc.AccessTokenResponse {
claims := oidc.NewIDTokenClaims(
issuer,
"sub",
[]string{"clientID"},
time.Now().Add(1*time.Hour),
time.Now().Add(-1*time.Minute),
"",
"",
nil,
"clientID",
0,
)
privateKey, err := crypto.BytesToPrivateKey([]byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAs38btwb3c7r0tMaQpGvBmY+mPwMU/LpfuPoC0k2t4RsKp0fv
40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuyrFALIj3Ff1UcKIk0hOH5DDsfh7/q
2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/MfSydZdcmIqlkUpfQmtzExw9+tSe5
Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNuMZbmlCoBru+rC8ITlTX/0V1ZcsSb
L8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a+kjL/KGZbR14Ua2eo6tonBZLC5DH
WM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly6QIDAQABAoIBAQCPj1nbSPcg2KZe
73FAD+8HopyUSSK//1AP4eXfzcEECVy77g0u9+R6XlkzsZCsZ4g6NN8ounqfyw3c
YlpAIkcFCf/dowoSjT+4LASVQyatYZwWNqjgAIU4KgMG/rKnNahPTiBYe7peMB1j
EaPjnt8uPkCk8y7NCi3y4Pk24tt/WM5KbJK2NQhUi1csGnleDfE+0blV0l/e6C68
W5cbnbWAroMqae/Yon3XVZiXX0m+l2f6ZzIgKaD18J+eEM8FjJC+jQKiRe1i9v3K
nQrLwh/gn8J10FcbKn3xqslKVidzASIrNIzHT9j/Z5T9NXuAKa7IV2x+Dtdus+wq
iBsUunwBAoGBANpYew+8i9vDwK4/SefduDTuzJ0H9lWTjtbiWQ+KYZoeJ7q3/qns
jsmi+mjxkXxXg1RrGbNbjtbl3RXXIrUeeBB0lglRJUjc3VK7VvNoyXIWsiqhCspH
IJ9Yuknv4mXB01m/glbSCS/xu4RTgf5aOG4jUiRb9+dCIpvDxI9gbXEVAoGBANJz
hIJkplIJ+biTi3G1Oz17qkUkInNXzAEzKD9Atoz5AIAiR1ivOMLOlbucfjevw/Nw
TnpkMs9xqCefKupTlsriXtZI88m7ZKzAmolYsPolOy/Jhi31h9JFVTEfKGqVS+dk
A4ndhgdW9RUeNJPY2YVCARXQrWpueweQDA1cNaeFAoGAPJsYtXqBW6PPRM5+ZiSt
78tk8iV2o7RMjqrPS7f+dXfvUS2nO2VVEPTzCtQarOfhpToBLT65vD6bimdn09w8
OV0TFEz4y2u65y7m6LNqTwertpdy1ki97l0DgGhccCBH2P6GYDD2qd8wTH+dcot6
ZF/begopGoDJ+HBzi9SZLC0CgYBZzPslHMevyBvr++GLwrallKhiWnns1/DwLiEl
ZHrBCtuA0Z+6IwLIdZiE9tEQ+ApYTXrfVPQteqUzSwLn/IUiy5eGPpjwYushoAoR
Q2w5QTvRN1/vKo8rVXR1woLfgBdkhFPSN1mitiNcQIhU8jpXV4PZCDOHb99FqdzK
sqcedQKBgQCOmgbqxGsnT2WQhoOdzln+NOo6Tx+FveLLqat2KzpY59W4noeI2Awn
HfIQgWUAW9dsjVVOXMP1jhq8U9hmH/PFWA11V/iCdk1NTxZEw87VAOeWuajpdDHG
+iex349j8h2BcQ4Zd0FWu07gGFnS/yuDJPn6jBhRusdieEcxLRjTKg==
-----END RSA PRIVATE KEY-----
`))
if err != nil {
t.Fatal(err)
}
signer, err := jose.NewSigner(jose.SigningKey{Key: privateKey, Algorithm: "RS256"}, &jose.SignerOptions{})
if err != nil {
t.Fatal(err)
}
data, err := json.Marshal(claims)
if err != nil {
t.Fatal(err)
}
jws, err := signer.Sign(data)
if err != nil {
t.Fatal(err)
}
idToken, err := jws.CompactSerialize()
if err != nil {
t.Fatal(err)
}
return &oidc.AccessTokenResponse{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
RefreshToken: "",
ExpiresIn: 3600,
IDToken: idToken,
State: "testState",
}
}
func keys(t *testing.T) *jose.JSONWebKeySet {
privateKey, err := crypto.BytesToPublicKey([]byte(`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAs38btwb3c7r0tMaQpGvB
mY+mPwMU/LpfuPoC0k2t4RsKp0fv40SMl50CRrHgk395wch8PMPYbl3+8TtYAJuy
rFALIj3Ff1UcKIk0hOH5DDsfh7/q2wFuncTmS6bifYo8CfSq2vDGnM7nZnEvxY/M
fSydZdcmIqlkUpfQmtzExw9+tSe5Dxq6gn5JtlGgLgZGt69r5iMMrTEGhhVAXzNu
MZbmlCoBru+rC8ITlTX/0V1ZcsSbL8tYWhthyu9x6yjo1bH85wiVI4gs0MhU8f2a
+kjL/KGZbR14Ua2eo6tonBZLC5DHWM2TkYXgRCDPufjcgmzN0Lm91E4P8KvBcvly
6QIDAQAB
-----END PUBLIC KEY-----
`))
if err != nil {
t.Fatal(err)
}
return &jose.JSONWebKeySet{Keys: []jose.JSONWebKey{{Key: privateKey, Algorithm: "RS256", Use: oidc.KeyUseSignature}}}
}