feat: add basic structure of idp templates (#5053)

add basic structure and implement first providers for IDP templates to be able to manage and use them in the future
This commit is contained in:
Livio Spring
2023-01-23 08:11:40 +01:00
committed by GitHub
parent 7b5135e637
commit 598a4d2d4b
29 changed files with 3907 additions and 54 deletions

View File

@@ -0,0 +1,136 @@
package jwt
import (
"context"
"encoding/base64"
"errors"
"net/url"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
const (
queryAuthRequestID = "authRequestID"
queryUserAgentID = "userAgentID"
)
var _ idp.Provider = (*Provider)(nil)
var (
ErrNoTokens = errors.New("no tokens provided")
ErrMissingUserAgentID = errors.New("userAgentID missing")
)
// Provider is the [idp.Provider] implementation for a JWT provider
type Provider struct {
name string
headerName string
issuer string
jwtEndpoint string
keysEndpoint string
isLinkingAllowed bool
isCreationAllowed bool
isAutoCreation bool
isAutoUpdate bool
encryptionAlg crypto.EncryptionAlgorithm
}
type ProviderOpts func(provider *Provider)
// WithLinkingAllowed allows end users to link the federated user to an existing one
func WithLinkingAllowed() ProviderOpts {
return func(p *Provider) {
p.isLinkingAllowed = true
}
}
// WithCreationAllowed allows end users to create a new user using the federated information
func WithCreationAllowed() ProviderOpts {
return func(p *Provider) {
p.isCreationAllowed = true
}
}
// WithAutoCreation enables that federated users are automatically created if not already existing
func WithAutoCreation() ProviderOpts {
return func(p *Provider) {
p.isAutoCreation = true
}
}
// WithAutoUpdate enables that information retrieved from the provider is automatically used to update
// the existing user on each authentication
func WithAutoUpdate() ProviderOpts {
return func(p *Provider) {
p.isAutoUpdate = true
}
}
// New creates a JWT provider
func New(name, issuer, jwtEndpoint, keysEndpoint, headerName string, encryptionAlg crypto.EncryptionAlgorithm, options ...ProviderOpts) (*Provider, error) {
provider := &Provider{
name: name,
issuer: issuer,
jwtEndpoint: jwtEndpoint,
keysEndpoint: keysEndpoint,
headerName: headerName,
encryptionAlg: encryptionAlg,
}
for _, option := range options {
option(provider)
}
return provider, nil
}
// Name implements the [idp.Provider] interface
func (p *Provider) Name() string {
return p.name
}
// BeginAuth implements the [idp.Provider] interface.
// It will create a [Session] with an AuthURL, pointing to the jwtEndpoint
// with the authRequest and encrypted userAgent ids.
func (p *Provider) BeginAuth(ctx context.Context, state string, params ...any) (idp.Session, error) {
if len(params) != 1 {
return nil, ErrMissingUserAgentID
}
userAgentID, ok := params[0].(string)
if !ok {
return nil, ErrMissingUserAgentID
}
redirect, err := url.Parse(p.jwtEndpoint)
if err != nil {
return nil, err
}
q := redirect.Query()
q.Set(queryAuthRequestID, state)
nonce, err := p.encryptionAlg.Encrypt([]byte(userAgentID))
if err != nil {
return nil, err
}
q.Set(queryUserAgentID, base64.RawURLEncoding.EncodeToString(nonce))
redirect.RawQuery = q.Encode()
return &Session{AuthURL: redirect.String()}, nil
}
// IsLinkingAllowed implements the [idp.Provider] interface.
func (p *Provider) IsLinkingAllowed() bool {
return p.isLinkingAllowed
}
// IsCreationAllowed implements the [idp.Provider] interface.
func (p *Provider) IsCreationAllowed() bool {
return p.isCreationAllowed
}
// IsAutoCreation implements the [idp.Provider] interface.
func (p *Provider) IsAutoCreation() bool {
return p.isAutoCreation
}
// IsAutoUpdate implements the [idp.Provider] interface.
func (p *Provider) IsAutoUpdate() bool {
return p.isAutoUpdate
}

View File

@@ -0,0 +1,222 @@
package jwt
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/idp"
)
func TestProvider_BeginAuth(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
}
type args struct {
params []any
}
type want struct {
session idp.Session
err func(error) bool
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "missing userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: nil,
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{
name: "invalid userAgentID error",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
0,
},
},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrMissingUserAgentID)
},
},
},
{
name: "successful auth",
fields: fields{
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
},
args: args{
params: []any{
"agent",
},
},
want: want{
session: &Session{AuthURL: "https://auth.com/jwt?authRequestID=testState&userAgentID=YWdlbnQ"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
)
require.NoError(t, err)
session, err := provider.BeginAuth(context.Background(), "testState", tt.args.params...)
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.session.GetAuthURL(), session.GetAuthURL())
}
})
}
}
func TestProvider_Options(t *testing.T) {
type fields struct {
name string
issuer string
jwtEndpoint string
keysEndpoint string
headerName string
encryptionAlg func(t *testing.T) crypto.EncryptionAlgorithm
opts []ProviderOpts
}
type want struct {
name string
linkingAllowed bool
creationAllowed bool
autoCreation bool
autoUpdate bool
pkce bool
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "default",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: nil,
},
want: want{
name: "jwt",
linkingAllowed: false,
creationAllowed: false,
autoCreation: false,
autoUpdate: false,
pkce: false,
},
},
{
name: "all true",
fields: fields{
name: "jwt",
issuer: "https://jwt.com",
jwtEndpoint: "https://auth.com/jwt",
keysEndpoint: "https://jwt.com/keys",
headerName: "jwt-header",
encryptionAlg: func(t *testing.T) crypto.EncryptionAlgorithm {
return crypto.CreateMockEncryptionAlg(gomock.NewController(t))
},
opts: []ProviderOpts{
WithLinkingAllowed(),
WithCreationAllowed(),
WithAutoCreation(),
WithAutoUpdate(),
},
},
want: want{
name: "jwt",
linkingAllowed: true,
creationAllowed: true,
autoCreation: true,
autoUpdate: true,
pkce: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
provider, err := New(
tt.fields.name,
tt.fields.issuer,
tt.fields.jwtEndpoint,
tt.fields.keysEndpoint,
tt.fields.headerName,
tt.fields.encryptionAlg(t),
tt.fields.opts...,
)
require.NoError(t, err)
a.Equal(tt.want.name, provider.Name())
a.Equal(tt.want.linkingAllowed, provider.IsLinkingAllowed())
a.Equal(tt.want.creationAllowed, provider.IsCreationAllowed())
a.Equal(tt.want.autoCreation, provider.IsAutoCreation())
a.Equal(tt.want.autoUpdate, provider.IsAutoUpdate())
})
}
}

View File

@@ -0,0 +1,72 @@
package jwt
import (
"context"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
var _ idp.Session = (*Session)(nil)
// Session is the [idp.Session] implementation for the JWT provider
type Session struct {
AuthURL string
Tokens *oidc.Tokens
}
// GetAuthURL implements the [idp.Session] interface
func (s *Session) GetAuthURL() string {
return s.AuthURL
}
// FetchUser implements the [idp.Session] interface.
// It will map the received idToken into an [idp.User].
func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
if s.Tokens == nil {
return nil, ErrNoTokens
}
return &User{s.Tokens.IDTokenClaims}, nil
}
type User struct {
oidc.IDTokenClaims
}
func (u *User) GetID() string {
return u.IDTokenClaims.GetSubject()
}
func (u *User) GetFirstName() string {
return u.IDTokenClaims.GetGivenName()
}
func (u *User) GetLastName() string {
return u.IDTokenClaims.GetFamilyName()
}
func (u *User) GetDisplayName() string {
return u.IDTokenClaims.GetName()
}
func (u *User) GetNickname() string {
return u.IDTokenClaims.GetNickname()
}
func (u *User) GetPhone() string {
return u.IDTokenClaims.GetPhoneNumber()
}
func (u *User) IsPhoneVerified() bool {
return u.IDTokenClaims.IsPhoneNumberVerified()
}
func (u *User) GetPreferredLanguage() language.Tag {
return u.IDTokenClaims.GetLocale()
}
func (u *User) GetAvatarURL() string {
return u.IDTokenClaims.GetPicture()
}

View File

@@ -0,0 +1,145 @@
package jwt
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/oidc/v2/pkg/oidc"
"golang.org/x/oauth2"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/idp"
)
func TestSession_FetchUser(t *testing.T) {
type fields struct {
authURL string
tokens *oidc.Tokens
}
type want struct {
err func(error) bool
user idp.User
id string
firstName string
lastName string
displayName string
nickName string
preferredUsername string
email string
isEmailVerified bool
phone string
isPhoneVerified bool
preferredLanguage language.Tag
avatarURL string
profile string
}
tests := []struct {
name string
fields fields
want want
}{
{
name: "no tokens",
fields: fields{},
want: want{
err: func(err error) bool {
return errors.Is(err, ErrNoTokens)
},
},
},
{
name: "successful fetch",
fields: fields{
authURL: "https://auth.com/jwt?authRequestID=testState",
tokens: &oidc.Tokens{
Token: &oauth2.Token{},
IDTokenClaims: func() oidc.IDTokenClaims {
claims := oidc.EmptyIDTokenClaims()
userinfo := oidc.NewUserInfo()
userinfo.SetSubject("sub")
userinfo.SetPicture("picture")
userinfo.SetName("firstname lastname")
userinfo.SetEmail("email", true)
userinfo.SetGivenName("firstname")
userinfo.SetFamilyName("lastname")
userinfo.SetNickname("nickname")
userinfo.SetPreferredUsername("username")
userinfo.SetProfile("profile")
userinfo.SetPhone("phone", true)
userinfo.SetLocale(language.English)
claims.SetUserinfo(userinfo)
return claims
}(),
},
},
want: want{
user: &User{
IDTokenClaims: func() oidc.IDTokenClaims {
claims := oidc.EmptyIDTokenClaims()
userinfo := oidc.NewUserInfo()
userinfo.SetSubject("sub")
userinfo.SetPicture("picture")
userinfo.SetName("firstname lastname")
userinfo.SetEmail("email", true)
userinfo.SetGivenName("firstname")
userinfo.SetFamilyName("lastname")
userinfo.SetNickname("nickname")
userinfo.SetPreferredUsername("username")
userinfo.SetProfile("profile")
userinfo.SetPhone("phone", true)
userinfo.SetLocale(language.English)
claims.SetUserinfo(userinfo)
return claims
}(),
},
id: "sub",
firstName: "firstname",
lastName: "lastname",
displayName: "firstname lastname",
nickName: "nickname",
preferredUsername: "username",
email: "email",
isEmailVerified: true,
phone: "phone",
isPhoneVerified: true,
preferredLanguage: language.English,
avatarURL: "picture",
profile: "profile",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := assert.New(t)
session := &Session{
AuthURL: tt.fields.authURL,
Tokens: tt.fields.tokens,
}
user, err := session.FetchUser(context.Background())
if tt.want.err != nil && !tt.want.err(err) {
a.Fail("invalid error", err)
}
if tt.want.err == nil {
a.NoError(err)
a.Equal(tt.want.user, user)
a.Equal(tt.want.id, user.GetID())
a.Equal(tt.want.firstName, user.GetFirstName())
a.Equal(tt.want.lastName, user.GetLastName())
a.Equal(tt.want.displayName, user.GetDisplayName())
a.Equal(tt.want.nickName, user.GetNickname())
a.Equal(tt.want.preferredUsername, user.GetPreferredUsername())
a.Equal(tt.want.email, user.GetEmail())
a.Equal(tt.want.isEmailVerified, user.IsEmailVerified())
a.Equal(tt.want.phone, user.GetPhone())
a.Equal(tt.want.isPhoneVerified, user.IsPhoneVerified())
a.Equal(tt.want.preferredLanguage, user.GetPreferredLanguage())
a.Equal(tt.want.avatarURL, user.GetAvatarURL())
a.Equal(tt.want.profile, user.GetProfile())
}
})
}
}