fixup! fixup! feat(db): Adding identity_providers table for relational database

This commit is contained in:
Iraq Jaber
2025-07-18 15:55:18 +01:00
parent d951a8b13e
commit 4638118e29
7 changed files with 403 additions and 35 deletions

View File

@@ -31,21 +31,28 @@ const (
type IdentityProvider struct { type IdentityProvider struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"` InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
OrgID string `json:"org_id,omitempty" db:"org_id"` OrgID string `json:"orgId,omitempty" db:"org_id"`
ID string `json:"id,omitempty" db:"id"` ID string `json:"id,omitempty" db:"id"`
State string `json:"state,omitempty" db:"state"` State string `json:"state,omitempty" db:"state"`
Name string `json:"name,omitempty" db:"name"` Name string `json:"name,omitempty" db:"name"`
Type string `json:"type,omitempty" db:"type"` Type string `json:"type,omitempty" db:"type"`
AllowCreation bool `json:"allow_creation,omitempty" db:"allow_creation"` AllowCreation bool `json:"allowCreation,omitempty" db:"allow_creation"`
AllowAutoCreation bool `json:"allow_auto_creation,omitempty" db:"allow_auto_creation"` AllowAutoCreation bool `json:"allowAutoCreation,omitempty" db:"allow_auto_creation"`
AllowAutoUpdate bool `json:"allow_auto_update,omitempty" db:"allow_auto_update"` AllowAutoUpdate bool `json:"allowAutoUpdate,omitempty" db:"allow_auto_update"`
AllowLinking bool `json:"allow_linking,omitempty" db:"allow_linking"` AllowLinking bool `json:"allowLinking,omitempty" db:"allow_linking"`
StylingType int16 `json:"styling_type,omitempty" db:"styling_type"` StylingType int16 `json:"stylingType,omitempty" db:"styling_type"`
Payload string `json:"payload,omitempty" db:"payload"` Payload string `json:"payload,omitempty" db:"payload"`
CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"` CreatedAt time.Time `json:"createdAt,omitempty" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"` UpdatedAt time.Time `json:"updatedAt,omitempty" db:"updated_at"`
} }
// IDPIdentifierCondition is used to help specify a single identity_provider,
// it will either be used as the identity_provider ID or identity_provider name,
// as identity_provider can be identified either using (instnaceID + OrgID + ID) OR (instanceID + OrgID + name)
type IDPIdentifierCondition interface {
database.Condition
}
type idProviderColumns interface { type idProviderColumns interface {
InstanceIDColumn() database.Column InstanceIDColumn() database.Column
OrgIDColumn() database.Column OrgIDColumn() database.Column
@@ -66,9 +73,9 @@ type idProviderColumns interface {
type idProviderConditions interface { type idProviderConditions interface {
InstanceIDCondition(id string) database.Condition InstanceIDCondition(id string) database.Condition
OrgIDCondition(id string) database.Condition OrgIDCondition(id string) database.Condition
IDCondition(id string) database.Condition IDCondition(id string) IDPIdentifierCondition
StateCondition(state IDPState) database.Condition StateCondition(state IDPState) database.Condition
NameCondition(name string) database.Condition NameCondition(name string) IDPIdentifierCondition
TypeCondition(typee IDPType) database.Condition TypeCondition(typee IDPType) database.Condition
AllowCreationCondition(allow bool) database.Condition AllowCreationCondition(allow bool) database.Condition
AllowAutoCreationCondition(allow bool) database.Condition AllowAutoCreationCondition(allow bool) database.Condition
@@ -94,10 +101,10 @@ type IDProviderRepository interface {
idProviderConditions idProviderConditions
idProviderChanges idProviderChanges
Get(ctx context.Context, id string) (*IdentityProvider, error) Get(ctx context.Context, id IDPIdentifierCondition, instnaceID string, orgID string) (*IdentityProvider, error)
List(ctx context.Context, conditions ...database.Condition) ([]*IdentityProvider, error) List(ctx context.Context, conditions ...database.Condition) ([]*IdentityProvider, error)
Create(ctx context.Context, idp *IdentityProvider) error Create(ctx context.Context, idp *IdentityProvider) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error) Update(ctx context.Context, id IDPIdentifierCondition, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error) Delete(ctx context.Context, id IDPIdentifierCondition) (int64, error)
} }

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 003_identity_providers_table/up.sql
up003IdentityProvidersTable string
//go:embed 003_identity_providers_table/down.sql
down003IdentityProvidersTable string
)
func init() {
registerSQLMigration(3, up003IdentityProvidersTable, down003IdentityProvidersTable)
}

View File

@@ -14,12 +14,12 @@ CREATE TYPE idp_type AS ENUM (
'apple' 'apple'
); );
CREATE TABLE identity_providers ( CREATE TABLE zitadel.identity_providers (
instance_id TEXT NOT NULL instance_id TEXT NOT NULL
, org_id TEXT , org_id TEXT
, id TEXT NOT NULL , id TEXT NOT NULL CHECK (id <> '')
, state idp_state NOT NULL DEFAULT 'active' , state idp_state NOT NULL DEFAULT 'active'
, name TEXT , name TEXT NOT NULL CHECK (name <> '')
, type idp_type NOT NULL , type idp_type NOT NULL
, allow_creation BOOLEAN NOT NULL DEFAULT TRUE , allow_creation BOOLEAN NOT NULL DEFAULT TRUE
, allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE , allow_auto_creation BOOLEAN NOT NULL DEFAULT TRUE
@@ -32,14 +32,15 @@ CREATE TABLE identity_providers (
, updated_at TIMESTAMPTZ NOT NULL DEFAULT now() , updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
, PRIMARY KEY (instance_id, id) , PRIMARY KEY (instance_id, id)
, CONSTRAINT identity_providers_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id) , CONSTRAINT identity_providers_id_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, id)
, FOREIGN KEY (instance_id) REFERENCES instances(id) , CONSTRAINT identity_providers_name_unique UNIQUE NULLS NOT DISTINCT (instance_id, org_id, name)
, FOREIGN KEY (instance_id, org_id) REFERENCES organizations(instance_id, id) , FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id)
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id)
); );
-- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL; -- CREATE INDEX idx_identity_providers_org_id ON identity_providers(instance_id, org_id) WHERE org_id IS NOT NULL;
CREATE INDEX idx_identity_providers_state ON identity_providers(instance_id, state); CREATE INDEX idx_identity_providers_state ON zitadel.identity_providers(instance_id, state);
CREATE INDEX idx_identity_providers_type ON identity_providers(instance_id, type); CREATE INDEX idx_identity_providers_type ON zitadel.identity_providers(instance_id, type);
-- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at); -- CREATE INDEX idx_identity_providers_created_at ON identity_providers(created_at);
-- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL; -- CREATE INDEX idx_identity_providers_deleted_at ON identity_providers(deleted_at) WHERE deleted_at IS NOT NULL;

View File

@@ -22,18 +22,18 @@ func IDProviderRepository(client database.QueryExecutor) domain.IDProviderReposi
} }
} }
const queryIDProviderStmt = `instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + const queryIDProviderStmt = `SELECT instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` + ` allow_auto_update, allow_linking, styling_type, payload, created_at, updated_at` +
` FROM zitadel.identity_providers` ` FROM zitadel.identity_providers`
func (i *idProvider) Get(ctx context.Context, id string) (*domain.IdentityProvider, error) { func (i *idProvider) Get(ctx context.Context, id domain.IDPIdentifierCondition, instnaceID string, orgID string) (*domain.IdentityProvider, error) {
builder := database.StatementBuilder{} builder := database.StatementBuilder{}
builder.WriteString(queryIDProviderStmt) builder.WriteString(queryIDProviderStmt)
idCondition := i.IDCondition(id) conditions := []database.Condition{id, i.InstanceIDCondition(instnaceID), i.OrgIDCondition(orgID)}
writeCondition(&builder, idCondition) writeCondition(&builder, database.And(conditions...))
return scanIDProvider(ctx, i.client, &builder) return scanIDProvider(ctx, i.client, &builder)
} }
@@ -56,13 +56,13 @@ func (i *idProvider) List(ctx context.Context, conditions ...database.Condition)
const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` + const createIDProviderStmt = `INSERT INTO zitadel.identity_providers` +
` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` + ` (instance_id, org_id, id, state, name, type, allow_creation, allow_auto_creation,` +
` allow_auto_update, allow_linking, styling_type, payload)` + ` allow_auto_update, allow_linking, styling_type, payload)` +
` VALUES ($1, $2, $3, $4, $5, $6, $,7, $8, $9, $10, $11, $12)` + ` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)` +
` RETURNING created_at, updated_at` ` RETURNING created_at, updated_at`
func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error { func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) error {
builder := database.StatementBuilder{} builder := database.StatementBuilder{}
builder.AppendArgs(idp.InstanceID, idp.OrgID, idp.ID, idp.State, idp.Name, idp.Type, idp.AllowCreation, builder.AppendArgs(idp.InstanceID, idp.OrgID, idp.ID, idp.State, idp.Name, idp.Type, idp.AllowCreation,
idp.AllowAutoCreation, idp.AllowLinking, idp.StylingType, idp.Payload) idp.AllowAutoCreation, idp.AllowAutoUpdate, idp.AllowLinking, idp.StylingType, idp.Payload)
builder.WriteString(createIDProviderStmt) builder.WriteString(createIDProviderStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt) err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&idp.CreatedAt, &idp.UpdatedAt)
@@ -72,14 +72,15 @@ func (i *idProvider) Create(ctx context.Context, idp *domain.IdentityProvider) e
return nil return nil
} }
func (i *idProvider) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { func (i *idProvider) Update(ctx context.Context, id domain.IDPIdentifierCondition, changes ...database.Change) (int64, error) {
if changes == nil { if changes == nil {
return 0, errors.New("Update must contain a condition") // (otherwise ALL identity_providers will be updated) return 0, errors.New("Update must contain a condition") // (otherwise ALL identity_providers will be updated)
} }
builder := database.StatementBuilder{} builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.identity_provider SET `) builder.WriteString(`UPDATE zitadel.identity_provider SET `)
conditions := []database.Condition{i.IDCondition(id)} // conditions := []database.Condition{i.IDCondition(id)}
conditions := []database.Condition{id}
database.Changes(changes).Write(&builder) database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...)) writeCondition(&builder, database.And(conditions...))
@@ -89,13 +90,13 @@ func (i *idProvider) Update(ctx context.Context, id string, changes ...database.
return rowsAffected, err return rowsAffected, err
} }
func (i *idProvider) Delete(ctx context.Context, id string) (int64, error) { func (i *idProvider) Delete(ctx context.Context, id domain.IDPIdentifierCondition) (int64, error) {
builder := database.StatementBuilder{} builder := database.StatementBuilder{}
builder.WriteString(`DELETE FROM zitadel.identity_providers`) builder.WriteString(`DELETE FROM zitadel.identity_providers`)
conditions := []database.Condition{i.IDCondition(id)} // conditions := []database.Condition{i.IDCondition(id)}
writeCondition(&builder, database.And(conditions...)) // writeCondition(&builder, database.And(conditions...))
return i.client.Exec(ctx, builder.String(), builder.Args()...) return i.client.Exec(ctx, builder.String(), builder.Args()...)
} }
@@ -172,7 +173,7 @@ func (i idProvider) OrgIDCondition(id string) database.Condition {
return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, id) return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, id)
} }
func (i idProvider) IDCondition(id string) database.Condition { func (i idProvider) IDCondition(id string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
} }
@@ -180,7 +181,7 @@ func (i idProvider) StateCondition(state domain.IDPState) database.Condition {
return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, state.String()) return database.NewTextCondition(i.OrgIDColumn(), database.TextOperationEqual, state.String())
} }
func (i idProvider) NameCondition(name string) database.Condition { func (i idProvider) NameCondition(name string) domain.IDPIdentifierCondition {
return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name) return database.NewTextCondition(i.NameColumn(), database.TextOperationEqual, name)
} }

View File

@@ -0,0 +1,345 @@
package repository_test
import (
"context"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestCreateIDProvider(t *testing.T) {
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// create org
orgId := gofakeit.Name()
org := domain.Organization{
ID: orgId,
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive.String(),
}
organizationRepo := repository.OrganizationRepository(pool)
err = organizationRepo.Create(t.Context(), &org)
require.NoError(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.IdentityProvider
idp domain.IdentityProvider
err error
}
// TESTS
tests := []test{
{
name: "happy path",
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
},
{
name: "create organization without name",
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
// Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.CheckError),
},
{
name: "adding org with same id twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider {
idpRepo := repository.IDProviderRepository(pool)
idp := domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
}
err := idpRepo.Create(ctx, &idp)
require.NoError(t, err)
// change the name to make sure same only the id clashes
org.Name = gofakeit.Name()
return &idp
},
err: new(database.UniqueError),
},
{
name: "adding org with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider {
idpRepo := repository.IDProviderRepository(pool)
idp := domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
}
err := idpRepo.Create(ctx, &idp)
require.NoError(t, err)
// change the id to make sure same name causes an error
idp.ID = gofakeit.Name()
return &idp
},
err: new(database.UniqueError),
},
func() test {
id := gofakeit.Name()
name := gofakeit.Name()
return test{
name: "adding org with same name, id, different instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.IdentityProvider {
// create instance
newInstId := gofakeit.Name()
instance := domain.Instance{
ID: newInstId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(ctx, &instance)
assert.Nil(t, err)
// create org
newOrgId := gofakeit.Name()
org := domain.Organization{
ID: newOrgId,
Name: gofakeit.Name(),
InstanceID: newInstId,
State: domain.OrgStateActive.String(),
}
organizationRepo := repository.OrganizationRepository(pool)
err = organizationRepo.Create(t.Context(), &org)
require.NoError(t, err)
idpRepo := repository.IDProviderRepository(pool)
idp := domain.IdentityProvider{
InstanceID: newInstId,
OrgID: newOrgId,
ID: id,
State: domain.IDPStateActive.String(),
Name: name,
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
}
err = idpRepo.Create(ctx, &idp)
require.NoError(t, err)
// change the instanceID to a different instance
idp.InstanceID = instanceId
// change the OrgId to a different organization
idp.OrgID = orgId
return &idp
},
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: id,
State: domain.IDPStateActive.String(),
Name: name,
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
}
}(),
{
name: "adding idp with no id",
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
// ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.CheckError),
},
{
name: "adding idp with no name",
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: orgId,
ID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
// Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.CheckError),
},
{
name: "adding idp with no instance id",
idp: domain.IdentityProvider{
// InstanceID: instanceId,
OrgID: orgId,
State: domain.IDPStateActive.String(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.IntegrityViolationError),
},
{
name: "adding organization with non existent instance id",
idp: domain.IdentityProvider{
InstanceID: gofakeit.Name(),
OrgID: orgId,
State: domain.IDPStateActive.String(),
ID: gofakeit.Name(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.ForeignKeyError),
},
{
name: "adding organization with non existent org id",
idp: domain.IdentityProvider{
InstanceID: instanceId,
OrgID: gofakeit.Name(),
State: domain.IDPStateActive.String(),
ID: gofakeit.Name(),
Name: gofakeit.Name(),
Type: domain.IDPTypeOIDC.String(),
AllowCreation: true,
AllowAutoCreation: true,
AllowAutoUpdate: true,
AllowLinking: true,
StylingType: 1,
Payload: "{}",
},
err: new(database.ForeignKeyError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
var idp *domain.IdentityProvider
if tt.testFunc != nil {
idp = tt.testFunc(ctx, t)
} else {
idp = &tt.idp
}
idpRepo := repository.IDProviderRepository(pool)
// create idp
beforeCreate := time.Now()
err = idpRepo.Create(ctx, idp)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
}
afterCreate := time.Now()
// check organization values
idp, err = idpRepo.Get(ctx,
idpRepo.IDCondition(idp.ID),
idp.InstanceID,
idp.OrgID,
)
require.NoError(t, err)
assert.Equal(t, tt.idp.ID, idp.ID)
assert.Equal(t, tt.idp.Name, idp.Name)
assert.Equal(t, tt.idp.InstanceID, idp.InstanceID)
assert.Equal(t, tt.idp.State, idp.State)
assert.WithinRange(t, idp.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, idp.UpdatedAt, beforeCreate, afterCreate)
})
}
}

View File

@@ -29,9 +29,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil { if b.existingArgs == nil {
b.existingArgs = make(map[any]string) b.existingArgs = make(map[any]string)
} }
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok { if instruction, ok := arg.(Instruction); ok {
return string(instruction) return string(instruction)
} }