feat: add basic structure of idp templates (#5053)

add basic structure and implement first providers for IDP templates to be able to manage and use them in the future
This commit is contained in:
Livio Spring
2023-01-23 08:11:40 +01:00
committed by GitHub
parent 7b5135e637
commit 598a4d2d4b
29 changed files with 3907 additions and 54 deletions

View File

@@ -0,0 +1,102 @@
package oauth
import (
"encoding/json"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.User = (*UserMapper)(nil)
// UserMapper is an implementation of [idp.User].
// It can be used in ZITADEL actions to map the raw `info`
type UserMapper struct {
ID string
FirstName string
LastName string
DisplayName string
NickName string
PreferredUsername string
Email string
EmailVerified bool
Phone string
PhoneVerified bool
PreferredLanguage string
AvatarURL string
Profile string
info map[string]interface{}
}
func (u *UserMapper) UnmarshalJSON(data []byte) error {
if u.info == nil {
u.info = make(map[string]interface{})
}
return json.Unmarshal(data, &u.info)
}
// GetID is an implementation of the [idp.User] interface.
func (u *UserMapper) GetID() string {
return u.ID
}
// GetFirstName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetFirstName() string {
return u.FirstName
}
// GetLastName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetLastName() string {
return u.LastName
}
// GetDisplayName is an implementation of the [idp.User] interface.
func (u *UserMapper) GetDisplayName() string {
return u.DisplayName
}
// GetNickname is an implementation of the [idp.User] interface.
func (u *UserMapper) GetNickname() string {
return u.NickName
}
// GetPreferredUsername is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredUsername() string {
return u.PreferredUsername
}
// GetEmail is an implementation of the [idp.User] interface.
func (u *UserMapper) GetEmail() string {
return u.Email
}
// IsEmailVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsEmailVerified() bool {
return u.EmailVerified
}
// GetPhone is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPhone() string {
return u.Phone
}
// IsPhoneVerified is an implementation of the [idp.User] interface.
func (u *UserMapper) IsPhoneVerified() bool {
return u.PhoneVerified
}
// GetPreferredLanguage is an implementation of the [idp.User] interface.
func (u *UserMapper) GetPreferredLanguage() language.Tag {
return language.Make(u.PreferredLanguage)
}
// GetAvatarURL is an implementation of the [idp.User] interface.
func (u *UserMapper) GetAvatarURL() string {
return u.AvatarURL
}
// GetProfile is an implementation of the [idp.User] interface.
func (u *UserMapper) GetProfile() string {
return u.Profile
}

View File

@@ -0,0 +1,112 @@
package oauth
import (
"context"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Provider = (*Provider)(nil)
// Provider is the [idp.Provider] implementation for a generic OAuth 2.0 provider
type Provider struct {
rp.RelyingParty
options []rp.Option
name string
userEndpoint string
userMapper func() idp.User
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one.
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information.
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing.
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication.
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// WithRelyingPartyOption allows to set an additional [rp.Option] like [rp.WithPKCE].
func WithRelyingPartyOption(option rp.Option) ProviderOpts {
return func(p *Provider) {
p.options = append(p.options, option)
}
}
// New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userEndpoint: userEndpoint,
userMapper: userMapper,
}
for _, option := range options {
option(provider)
}
provider.RelyingParty, err = rp.NewRelyingPartyOAuth(config, provider.options...)
if err != nil {
return nil, err
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an OAuth2.0 authorization request as AuthURL.
func (p *Provider) BeginAuth(ctx context.Context, state string, _ ...any) (idp.Session, error) {
url := rp.AuthURL(state, p.RelyingParty)
return &Session{AuthURL: url, Provider: p}, nil
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,153 @@
package oauth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/client/rp"
"golang.org/x/oauth2"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
}
tests := []struct {
name string
fields fields
want idp.Session
}{
{
name: "successful auth",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
r.NoError(err)
session, err := provider.BeginAuth(context.Background(), "testState")
r.NoError(err)
a.Equal(tt.want.GetAuthURL(), session.GetAuthURL())
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
userMapper func() idp.User
options []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: nil,
},
want: want{
name: "oauth",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "oauth",
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
},
want: want{
name: "oauth",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
a.Equal(tt.want.pkce, provider.RelyingParty.IsPKCE())
})
}
}

View File

@@ -0,0 +1,63 @@
package oauth
import (
"context"
"errors"
"net/http"
"github.com/zitadel/oidc/v2/pkg/client/rp"
httphelper "github.com/zitadel/oidc/v2/pkg/http"
"github.com/zitadel/oidc/v2/pkg/oidc"
"github.com/zitadel/zitadel/internal/idp"
)
var ErrCodeMissing = errors.New("no auth code provided")
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OAuth2.0 provider.
type Session struct {
AuthURL string
Code string
Tokens *oidc.Tokens
Provider *Provider
}
// GetAuthURL implements the [idp.Session] interface.
func (s *Session) GetAuthURL() string {
return s.AuthURL
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
if err = s.authorize(ctx); err != nil {
return nil, err
}
}
req, err := http.NewRequest("GET", s.Provider.userEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("authorization", s.Tokens.TokenType+" "+s.Tokens.AccessToken)
mapper := s.Provider.userMapper()
if err := httphelper.HttpRequest(s.Provider.RelyingParty.HttpClient(), req, &mapper); err != nil {
return nil, err
}
return mapper, nil
}
func (s *Session) authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
s.Tokens, err = rp.CodeExchange(ctx, s.Code, s.Provider.RelyingParty)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,273 @@
package oauth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/h2non/gock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_FetchUser(t *testing.T) {
type fields struct {
config *oauth2.Config
name string
userEndpoint string
httpMock func(issuer string)
userMapper func() idp.User
authURL string
code string
tokens *oidc.Tokens
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "unauthenticated session, error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrCodeMissing)
},
},
},
{
name: "user error",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(http.StatusInternalServerError)
},
userMapper: func() idp.User {
return &UserMapper{
ID: "userID",
}
},
authURL: "https://oauth2.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
err: func(err error) bool {
return err.Error() == "http status not ok: 500 Internal Server Error "
},
},
},
{
name: "successful fetch",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return &UserMapper{}
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
},
},
},
want: want{
user: &UserMapper{
info: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
{
name: "successful fetch with code exchange",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
userEndpoint: "https://oauth2.com/user",
httpMock: func(issuer string) {
gock.New(issuer).
Post("/token").
BodyString("client_id=clientID&client_secret=clientSecret&code=code&grant_type=authorization_code&redirect_uri=redirectURI").
Reply(200).
JSON(&oidc.AccessTokenResponse{
AccessToken: "accessToken",
TokenType: oidc.BearerToken,
RefreshToken: "",
ExpiresIn: 3600,
IDToken: "",
State: "testState"})
gock.New(issuer).
Get("/user").
Reply(200).
JSON(map[string]interface{}{
"userID": "id",
"custom": "claim",
})
},
userMapper: func() idp.User {
return &UserMapper{}
},
authURL: "https://issuer.com/authorize?client_id=clientID&redirect_uri=redirectURI&response_type=code&scope=user&state=testState",
tokens: nil,
code: "code",
},
want: want{
user: &UserMapper{
info: map[string]interface{}{
"userID": "id",
"custom": "claim",
},
},
id: "",
firstName: "",
lastName: "",
displayName: "",
nickName: "",
preferredUsername: "",
email: "",
isEmailVerified: false,
phone: "",
isPhoneVerified: false,
preferredLanguage: language.Und,
avatarURL: "",
profile: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer gock.Off()
tt.fields.httpMock("https://oauth2.com")
a := assert.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
require.NoError(t, err)
session := &Session{
AuthURL: tt.fields.authURL,
Code: tt.fields.code,
Tokens: tt.fields.tokens,
Provider: provider,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}