feat: add PKCE option to generic OAuth2 / OIDC identity providers (#9373)

# Which Problems Are Solved

Some OAuth2 and OIDC providers require the use of PKCE for all their
clients. While ZITADEL already recommended the same for its clients, it
did not yet support the option on the IdP configuration.

# How the Problems Are Solved

- A new boolean `use_pkce` is added to the add/update generic OAuth/OIDC
endpoints.
- A new checkbox is added to the generic OAuth and OIDC provider
templates.
- The `rp.WithPKCE` option is added to the provider if the use of PKCE
has been set.
- The `rp.WithCodeChallenge` and `rp.WithCodeVerifier` options are added
to the OIDC/Auth BeginAuth and CodeExchange function.
- Store verifier or any other persistent argument in the intent or auth
request.
- Create corresponding session object before creating the intent, to be
able to store the information.
- (refactored session structs to use a constructor for unified creation
and better overview of actual usage)

Here's a screenshot showing the URI including the PKCE params:


![use_pkce_in_url](https://github.com/zitadel/zitadel/assets/30386061/eaeab123-a5da-4826-b001-2ae9efa35169)

# Additional Changes

None.

# Additional Context

- Closes #6449
- This PR replaces the existing PR (#8228) of @doncicuto. The base he
did was cherry picked. Thank you very much for that!

---------

Co-authored-by: Miguel Cabrerizo <doncicuto@gmail.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
This commit is contained in:
Livio Spring
2025-02-26 13:20:47 +01:00
committed by GitHub
parent 32ec7d0aa9
commit 8f88c4cf5b
79 changed files with 801 additions and 169 deletions

View File

@@ -17,6 +17,10 @@ type Session struct {
UserFormValue string
}
func NewSession(provider *Provider, code, userFormValue string) *Session {
return &Session{Session: oidc.NewSession(provider.Provider, code, nil), UserFormValue: userFormValue}
}
type userFormValue struct {
Name userNamesFormValue `json:"name,omitempty" schema:"name"`
}

View File

@@ -20,6 +20,10 @@ type Session struct {
OAuthSession *oauth.Session
}
func NewSession(provider *Provider, code string) *Session {
return &Session{Provider: provider, Code: code}
}
// GetAuth implements the [idp.Provider] interface by calling the wrapped [oauth.Session].
func (s *Session) GetAuth(ctx context.Context) (content string, redirect bool) {
return s.oauth().GetAuth(ctx)
@@ -39,6 +43,11 @@ func (s *Session) RetrievePreviousID() (string, error) {
return userinfo.Subject, nil
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].

View File

@@ -30,11 +30,20 @@ type Session struct {
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
func NewSession(provider *Provider, tokens *oidc.Tokens[*oidc.IDTokenClaims]) *Session {
return &Session{Provider: provider, Tokens: tokens}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
// It will map the received idToken into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {

View File

@@ -34,11 +34,21 @@ type Session struct {
Entry *ldap.Entry
}
func NewSession(provider *Provider, username, password string) *Session {
return &Session{Provider: provider, User: username, Password: password}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
return idp.Redirect(s.loginUrl)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
func (s *Session) FetchUser(_ context.Context) (_ idp.User, err error) {
var user *ldap.Entry
for _, server := range s.Provider.servers {

View File

@@ -23,6 +23,7 @@ type Provider struct {
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
generateVerifier func() string
}
type ProviderOpts func(provider *Provider)
@@ -66,9 +67,10 @@ func WithRelyingPartyOption(option rp.Option) ProviderOpts {
// New creates a generic OAuth 2.0 provider
func New(config *oauth2.Config, name, userEndpoint string, userMapper func() idp.User, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userEndpoint: userEndpoint,
userMapper: userMapper,
name: name,
userEndpoint: userEndpoint,
userMapper: userMapper,
generateVerifier: oauth2.GenerateVerifier,
}
for _, option := range options {
option(provider)
@@ -99,8 +101,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Pa
if !loginHintSet {
opts = append(opts, rp.WithPrompt(oidc.PromptSelectAccount))
}
var codeVerifier string
if p.RelyingParty.IsPKCE() {
codeVerifier = p.generateVerifier()
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
}
func loginHint(hint string) rp.AuthURLOpt {

View File

@@ -18,6 +18,7 @@ func TestProvider_BeginAuth(t *testing.T) {
name string
userEndpoint string
userMapper func() idp.User
options []ProviderOpts
}
tests := []struct {
name string
@@ -25,7 +26,7 @@ func TestProvider_BeginAuth(t *testing.T) {
want idp.Session
}{
{
name: "successful auth",
name: "successful auth without PKCE",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
@@ -40,14 +41,40 @@ func TestProvider_BeginAuth(t *testing.T) {
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
{
name: "successful auth with PKCE",
fields: fields{
config: &oauth2.Config{
ClientID: "clientID",
ClientSecret: "clientSecret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://oauth2.com/authorize",
TokenURL: "https://oauth2.com/token",
},
RedirectURL: "redirectURI",
Scopes: []string{"user"},
},
options: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
WithRelyingPartyOption(rp.WithPKCE(nil)),
},
},
want: &Session{AuthURL: "https://oauth2.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=user&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
r := require.New(t)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper)
provider, err := New(tt.fields.config, tt.fields.name, tt.fields.userEndpoint, tt.fields.userMapper, tt.fields.options...)
r.NoError(err)
provider.generateVerifier = func() string {
return "pkceOAuthVerifier"
}
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")

View File

@@ -14,22 +14,40 @@ import (
var ErrCodeMissing = errors.New("no auth code provided")
const (
CodeVerifier = "codeVerifier"
)
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OAuth2.0 provider.
type Session struct {
AuthURL string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
AuthURL string
CodeVerifier string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
Provider *Provider
}
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
verifier, _ := idpArguments[CodeVerifier].(string)
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
}
// GetAuth implements the [idp.Session] interface.
func (s *Session) GetAuth(ctx context.Context) (string, bool) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
if s.CodeVerifier == "" {
return nil
}
return map[string]any{CodeVerifier: s.CodeVerifier}
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OAuth 2.0 code exchange if needed to retrieve the access token,
// call the specified userEndpoint and map the received information into an [idp.User].
@@ -55,7 +73,11 @@ func (s *Session) authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty)
var opts []rp.CodeExchangeOpt
if s.CodeVerifier != "" {
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
return err
}

View File

@@ -24,6 +24,7 @@ type Provider struct {
useIDToken bool
userInfoMapper func(info *oidc.UserInfo) idp.User
authOptions []func(bool) rp.AuthURLOpt
generateVerifier func() string
}
type ProviderOpts func(provider *Provider)
@@ -102,8 +103,9 @@ var DefaultMapper UserInfoMapper = func(info *oidc.UserInfo) idp.User {
// New creates a generic OIDC provider
func New(name, issuer, clientID, clientSecret, redirectURI string, scopes []string, userInfoMapper UserInfoMapper, options ...ProviderOpts) (provider *Provider, err error) {
provider = &Provider{
name: name,
userInfoMapper: userInfoMapper,
name: name,
userInfoMapper: userInfoMapper,
generateVerifier: oauth2.GenerateVerifier,
}
for _, option := range options {
option(provider)
@@ -150,8 +152,15 @@ func (p *Provider) BeginAuth(ctx context.Context, state string, params ...idp.Pa
opts = append(opts, opt)
}
}
var codeVerifier string
if p.RelyingParty.IsPKCE() {
codeVerifier = p.generateVerifier()
opts = append(opts, rp.WithCodeChallenge(oidc.NewSHACodeChallenge(codeVerifier)))
}
url := rp.AuthURL(state, p.RelyingParty, opts...)
return &Session{AuthURL: url, Provider: p}, nil
return &Session{AuthURL: url, Provider: p, CodeVerifier: codeVerifier}, nil
}
func loginHint(hint string) rp.AuthURLOpt {

View File

@@ -31,7 +31,7 @@ func TestProvider_BeginAuth(t *testing.T) {
want idp.Session
}{
{
name: "successful auth",
name: "successful auth without PKCE",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
@@ -55,6 +55,31 @@ func TestProvider_BeginAuth(t *testing.T) {
},
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
},
{
name: "successful auth with PKCE",
fields: fields{
name: "oidc",
issuer: "https://issuer.com",
clientID: "clientID",
clientSecret: "clientSecret",
redirectURI: "redirectURI",
scopes: []string{"openid"},
userMapper: DefaultMapper,
httpMock: func(issuer string) {
gock.New(issuer).
Get(oidc.DiscoveryEndpoint).
Reply(200).
JSON(&oidc.DiscoveryConfiguration{
Issuer: issuer,
AuthorizationEndpoint: issuer + "/authorize",
TokenEndpoint: issuer + "/token",
UserinfoEndpoint: issuer + "/userinfo",
})
},
opts: []ProviderOpts{WithSelectAccount(), WithRelyingPartyOption(rp.WithPKCE(nil))},
},
want: &Session{AuthURL: "https://issuer.com/authorize?client_id=clientID&code_challenge=2ZoH_a01aprzLkwVbjlPsBo4m8mJ_zOKkaDqYM7Oh5w&code_challenge_method=S256&prompt=select_account&redirect_uri=redirectURI&response_type=code&scope=openid&state=testState"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -65,6 +90,9 @@ func TestProvider_BeginAuth(t *testing.T) {
provider, err := New(tt.fields.name, tt.fields.issuer, tt.fields.clientID, tt.fields.clientSecret, tt.fields.redirectURI, tt.fields.scopes, tt.fields.userMapper, tt.fields.opts...)
r.NoError(err)
provider.generateVerifier = func() string {
return "pkceOAuthVerifier"
}
ctx := context.Background()
session, err := provider.BeginAuth(ctx, "testState")

View File

@@ -10,6 +10,7 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/idp"
"github.com/zitadel/zitadel/internal/idp/providers/oauth"
)
var ErrCodeMissing = errors.New("no auth code provided")
@@ -18,10 +19,16 @@ var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the OIDC provider.
type Session struct {
Provider *Provider
AuthURL string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
Provider *Provider
AuthURL string
CodeVerifier string
Code string
Tokens *oidc.Tokens[*oidc.IDTokenClaims]
}
func NewSession(provider *Provider, code string, idpArguments map[string]any) *Session {
verifier, _ := idpArguments[oauth.CodeVerifier].(string)
return &Session{Provider: provider, Code: code, CodeVerifier: verifier}
}
// GetAuth implements the [idp.Session] interface.
@@ -29,6 +36,14 @@ func (s *Session) GetAuth(ctx context.Context) (string, bool) {
return idp.Redirect(s.AuthURL)
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
if s.CodeVerifier == "" {
return nil
}
return map[string]any{oauth.CodeVerifier: s.CodeVerifier}
}
// FetchUser implements the [idp.Session] interface.
// It will execute an OIDC code exchange if needed to retrieve the tokens,
// call the userinfo endpoint and map the received information into an [idp.User].
@@ -61,7 +76,11 @@ func (s *Session) Authorize(ctx context.Context) (err error) {
if s.Code == "" {
return ErrCodeMissing
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty)
var opts []rp.CodeExchangeOpt
if s.CodeVerifier != "" {
opts = append(opts, rp.WithCodeVerifier(s.CodeVerifier))
}
s.Tokens, err = rp.CodeExchange[*oidc.IDTokenClaims](ctx, s.Code, s.Provider.RelyingParty, opts...)
return err
}

View File

@@ -60,6 +60,11 @@ func (s *Session) GetAuth(ctx context.Context) (string, bool) {
return idp.Form(resp.content.String())
}
// PersistentParameters implements the [idp.Session] interface.
func (s *Session) PersistentParameters() map[string]any {
return nil
}
// FetchUser implements the [idp.Session] interface.
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.RequestID == "" || s.Request == nil {

View File

@@ -7,6 +7,7 @@ import (
// Session is the minimal implementation for a session of a 3rd party authentication [Provider]
type Session interface {
GetAuth(ctx context.Context) (content string, redirect bool)
PersistentParameters() map[string]any
FetchUser(ctx context.Context) (User, error)
}