diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go index 709dda0be0..f11f67a628 100644 --- a/backend/v3/storage/database/database.go +++ b/backend/v3/storage/database/database.go @@ -8,6 +8,7 @@ import ( type Pool interface { Beginner QueryExecutor + Migrator Acquire(ctx context.Context) (Client, error) Close(ctx context.Context) error @@ -17,6 +18,7 @@ type Pool interface { type Client interface { Beginner QueryExecutor + Migrator Release(ctx context.Context) error } diff --git a/backend/v3/storage/database/dbmock/database.mock.go b/backend/v3/storage/database/dbmock/database.mock.go index 215804ded8..02060efb3f 100644 --- a/backend/v3/storage/database/dbmock/database.mock.go +++ b/backend/v3/storage/database/dbmock/database.mock.go @@ -199,6 +199,44 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) e return c } +// Migrate mocks base method. +func (m *MockPool) Migrate(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Migrate", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Migrate indicates an expected call of Migrate. +func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0) + return &MockPoolMigrateCall{Call: call} +} + +// MockPoolMigrateCall wrap *gomock.Call +type MockPoolMigrateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolMigrateCall) Return(arg0 error) *MockPoolMigrateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolMigrateCall) Do(f func(context.Context) error) *MockPoolMigrateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPoolMigrateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Query mocks base method. func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { m.ctrl.T.Helper() @@ -391,6 +429,44 @@ func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) return c } +// Migrate mocks base method. +func (m *MockClient) Migrate(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Migrate", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Migrate indicates an expected call of Migrate. +func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0) + return &MockClientMigrateCall{Call: call} +} + +// MockClientMigrateCall wrap *gomock.Call +type MockClientMigrateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + // Query mocks base method. func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { m.ctrl.T.Helper() diff --git a/backend/v3/storage/database/dialect/postgres/config.go b/backend/v3/storage/database/dialect/postgres/config.go index 630c603d14..3bab1f8b1f 100644 --- a/backend/v3/storage/database/dialect/postgres/config.go +++ b/backend/v3/storage/database/dialect/postgres/config.go @@ -13,8 +13,9 @@ import ( ) var ( - _ database.Connector = (*Config)(nil) - Name = "postgres" + _ database.Connector = (*Config)(nil) + Name = "postgres" + isMigrated bool ) type Config struct { @@ -45,7 +46,7 @@ func (c *Config) Connect(ctx context.Context) (database.Pool, error) { if err = pool.Ping(ctx); err != nil { return nil, err } - return &pgxPool{pool}, nil + return &pgxPool{Pool: pool}, nil } func (c *Config) getPool(ctx context.Context) (*pgxpool.Pool, error) { diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go index 5e5942680a..0cb5d8f16a 100644 --- a/backend/v3/storage/database/dialect/postgres/conn.go +++ b/backend/v3/storage/database/dialect/postgres/conn.go @@ -9,11 +9,12 @@ import ( "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration" ) -type pgxConn struct{ *pgxpool.Conn } +type pgxConn struct { + *pgxpool.Conn +} var ( - _ database.Client = (*pgxConn)(nil) - _ database.Migrator = (*pgxConn)(nil) + _ database.Client = (*pgxConn)(nil) ) // Release implements [database.Client]. @@ -53,5 +54,10 @@ func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error { // Migrate implements [database.Migrator]. func (c *pgxConn) Migrate(ctx context.Context) error { - return migration.Migrate(ctx, c.Conn.Conn()) + if isMigrated { + return nil + } + err := migration.Migrate(ctx, c.Conn.Conn()) + isMigrated = err == nil + return err } diff --git a/backend/v3/storage/database/dialect/postgres/embedded/start.go b/backend/v3/storage/database/dialect/postgres/embedded/start.go new file mode 100644 index 0000000000..9a5f3ea82b --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/embedded/start.go @@ -0,0 +1,50 @@ +// embedded is used for testing purposes +package embedded + +import ( + "net" + "os" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" +) + +// StartEmbedded starts an embedded postgres v16 instance and returns a database connector and a stop function +// the database is started on a random port and data are stored in a temporary directory +// its used for testing purposes only +func StartEmbedded() (connector database.Connector, stop func(), err error) { + path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*") + logging.OnError(err).Fatal("unable to create temp dir") + + port, close := getPort() + + config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path) + embedded := embeddedpostgres.NewDatabase(config) + + close() + err = embedded.Start() + logging.OnError(err).Fatal("unable to start db") + + connector, err = postgres.DecodeConfig(config.GetConnectionURL()) + if err != nil { + return nil, nil, err + } + + return connector, func() { + logging.OnError(embedded.Stop()).Error("unable to stop db") + }, nil +} + +// getPort returns a free port and locks it until close is called +func getPort() (port uint16, close func()) { + l, err := net.Listen("tcp", ":0") + logging.OnError(err).Fatal("unable to get port") + port = uint16(l.Addr().(*net.TCPAddr).Port) + logging.WithFields("port", port).Info("Port is available") + return port, func() { + logging.OnError(l.Close()).Error("unable to close port listener") + } +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go b/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go new file mode 100644 index 0000000000..37680cb94f --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go @@ -0,0 +1,60 @@ +package migration_test + +import ( + "context" + "testing" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded" +) + +func TestMigrate(t *testing.T) { + tests := []struct { + name string + stmt string + args []any + res []any + }{ + { + name: "schema", + stmt: "SELECT EXISTS(SELECT 1 FROM information_schema.schemata where schema_name = 'zitadel') ;", + res: []any{true}, + }, + { + name: "001", + stmt: "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE schemaname = 'zitadel' and tablename=$1)", + args: []any{"instances"}, + res: []any{true}, + }, + } + + ctx := context.Background() + + connector, stop, err := embedded.StartEmbedded() + require.NoError(t, err, "failed to start embedded postgres") + defer stop() + + client, err := connector.Connect(ctx) + require.NoError(t, err, "failed to connect to embedded postgres") + + err = client.(database.Migrator).Migrate(ctx) + require.NoError(t, err, "failed to execute migration steps") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := make([]any, len(tt.res)) + for i := range got { + got[i] = new(any) + tt.res[i] = gu.Ptr(tt.res[i]) + } + + require.NoError(t, client.QueryRow(ctx, tt.stmt, tt.args...).Scan(got...), "failed to execute check query") + + assert.Equal(t, tt.res, got, "query result does not match") + }) + } +} diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go index a8a416f362..e3006d91eb 100644 --- a/backend/v3/storage/database/dialect/postgres/pool.go +++ b/backend/v3/storage/database/dialect/postgres/pool.go @@ -9,11 +9,12 @@ import ( "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration" ) -type pgxPool struct{ *pgxpool.Pool } +type pgxPool struct { + *pgxpool.Pool +} var ( - _ database.Pool = (*pgxPool)(nil) - _ database.Migrator = (*pgxPool)(nil) + _ database.Pool = (*pgxPool)(nil) ) // Acquire implements [database.Pool]. @@ -22,7 +23,7 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { if err != nil { return nil, err } - return &pgxConn{conn}, nil + return &pgxConn{Conn: conn}, nil } // Query implements [database.Pool]. @@ -62,9 +63,16 @@ func (c *pgxPool) Close(_ context.Context) error { // Migrate implements [database.Migrator]. func (c *pgxPool) Migrate(ctx context.Context) error { + if isMigrated { + return nil + } + client, err := c.Pool.Acquire(ctx) if err != nil { return err } - return migration.Migrate(ctx, client.Conn()) + + err = migration.Migrate(ctx, client.Conn()) + isMigrated = err == nil + return err } diff --git a/backend/v3/storage/database/migration.go b/backend/v3/storage/database/migration.go index 7aa1101148..5a5b5af6fc 100644 --- a/backend/v3/storage/database/migration.go +++ b/backend/v3/storage/database/migration.go @@ -4,5 +4,6 @@ import "context" type Migrator interface { // Migrate executes migrations to setup the database. + // The method can be called once per running Zitadel. Migrate(ctx context.Context) error } diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go new file mode 100644 index 0000000000..996e8d1b2c --- /dev/null +++ b/backend/v3/storage/database/repository/org_test.go @@ -0,0 +1,16 @@ +package repository + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestBla is an example and can be removed later +func TestBla(t *testing.T) { + var count int + err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count) + assert.NoError(t, err) + assert.Equal(t, 0, count) +} diff --git a/backend/v3/storage/database/repository/repository_test.go b/backend/v3/storage/database/repository/repository_test.go new file mode 100644 index 0000000000..7cbca2114f --- /dev/null +++ b/backend/v3/storage/database/repository/repository_test.go @@ -0,0 +1,41 @@ +package repository + +import ( + "context" + "log" + "os" + "testing" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded" +) + +func TestMain(m *testing.M) { + os.Exit(runTests(m)) +} + +var pool database.Pool + +func runTests(m *testing.M) int { + connector, stop, err := embedded.StartEmbedded() + if err != nil { + log.Fatalf("unable to start embedded postgres: %v", err) + } + defer stop() + + ctx := context.Background() + + pool, err = connector.Connect(ctx) + if err != nil { + log.Printf("unable to connect to embedded postgres: %v", err) + return 1 + } + + err = pool.Migrate(ctx) + if err != nil { + log.Printf("unable to migrate database: %v", err) + return 1 + } + + return m.Run() +} diff --git a/cmd/setup/54.go b/cmd/setup/54.go index a241937261..b31055a9af 100644 --- a/cmd/setup/54.go +++ b/cmd/setup/54.go @@ -4,7 +4,6 @@ import ( "context" _ "embed" - v3_db "github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" @@ -15,18 +14,13 @@ type TransactionalTables struct { } func (mig *TransactionalTables) Execute(ctx context.Context, _ eventstore.Event) error { - _, err := mig.dbClient.ExecContext(ctx, "CREATE SCHEMA IF NOT EXISTS zitadel") - if err != nil { - return err - } - config := &postgres.Config{Pool: mig.dbClient.Pool} pool, err := config.Connect(ctx) if err != nil { return err } - return pool.(v3_db.Migrator).Migrate(ctx) + return pool.Migrate(ctx) } func (mig *TransactionalTables) String() string {