fixup! fixup! fixup! fixup! Merge branch 'main' into import_export_merge

This commit is contained in:
Iraq Jaber
2025-06-06 13:08:14 +02:00
parent 400fb97d6d
commit 4384652b5e
7 changed files with 111 additions and 95 deletions

View File

@@ -157,15 +157,16 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
} }
// Exec mocks base method. // Exec mocks base method.
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{arg0, arg1} varargs := []any{arg0, arg1}
for _, a := range arg2 { for _, a := range arg2 {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
ret := m.ctrl.Call(m, "Exec", varargs...) ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int64)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// Exec indicates an expected call of Exec. // Exec indicates an expected call of Exec.
@@ -182,19 +183,19 @@ type MockPoolExecCall struct {
} }
// Return rewrite *gomock.Call.Return // Return rewrite *gomock.Call.Return
func (c *MockPoolExecCall) Return(arg0 error) *MockPoolExecCall { func (c *MockPoolExecCall) Return(arg0 int64, arg1 error) *MockPoolExecCall {
c.Call = c.Call.Return(arg0) c.Call = c.Call.Return(arg0, arg1)
return c return c
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) error) *MockPoolExecCall { func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockPoolExecCall { func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }
@@ -387,15 +388,16 @@ func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.Tran
} }
// Exec mocks base method. // Exec mocks base method.
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{arg0, arg1} varargs := []any{arg0, arg1}
for _, a := range arg2 { for _, a := range arg2 {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
ret := m.ctrl.Call(m, "Exec", varargs...) ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int64)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// Exec indicates an expected call of Exec. // Exec indicates an expected call of Exec.
@@ -412,19 +414,19 @@ type MockClientExecCall struct {
} }
// Return rewrite *gomock.Call.Return // Return rewrite *gomock.Call.Return
func (c *MockClientExecCall) Return(arg0 error) *MockClientExecCall { func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
c.Call = c.Call.Return(arg0) c.Call = c.Call.Return(arg0, arg1)
return c return c
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) error) *MockClientExecCall { func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockClientExecCall { func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }
@@ -975,15 +977,16 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
} }
// Exec mocks base method. // Exec mocks base method.
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{arg0, arg1} varargs := []any{arg0, arg1}
for _, a := range arg2 { for _, a := range arg2 {
varargs = append(varargs, a) varargs = append(varargs, a)
} }
ret := m.ctrl.Call(m, "Exec", varargs...) ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(int64)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// Exec indicates an expected call of Exec. // Exec indicates an expected call of Exec.
@@ -1000,19 +1003,19 @@ type MockTransactionExecCall struct {
} }
// Return rewrite *gomock.Call.Return // Return rewrite *gomock.Call.Return
func (c *MockTransactionExecCall) Return(arg0 error) *MockTransactionExecCall { func (c *MockTransactionExecCall) Return(arg0 int64, arg1 error) *MockTransactionExecCall {
c.Call = c.Call.Return(arg0) c.Call = c.Call.Return(arg0, arg1)
return c return c
} }
// Do rewrite *gomock.Call.Do // Do rewrite *gomock.Call.Do
func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) error) *MockTransactionExecCall { func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall {
c.Call = c.Call.Do(f) c.Call = c.Call.Do(f)
return c return c
} }
// DoAndReturn rewrite *gomock.Call.DoAndReturn // DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockTransactionExecCall { func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall {
c.Call = c.Call.DoAndReturn(f) c.Call = c.Call.DoAndReturn(f)
return c return c
} }

View File

@@ -33,30 +33,30 @@ const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, cons
// Get implements [domain.InstanceRepository]. // Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) { func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) {
i.builder = database.StatementBuilder{} builder := database.StatementBuilder{}
i.builder.WriteString(queryInstanceStmt) builder.WriteString(queryInstanceStmt)
// return only non deleted isntances // return only non deleted isntances
opts = append(opts, database.IsNull(i.DeletedAtColumn())) opts = append(opts, database.IsNull(i.DeletedAtColumn()))
andCondition := database.And(opts...) andCondition := database.And(opts...)
andCondition.Write(&i.builder) andCondition.Write(&builder)
return scanInstance(i.client.QueryRow(ctx, i.builder.String(), i.builder.Args()...)) return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...))
} }
// List implements [domain.InstanceRepository]. // List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*domain.Instance, error) { func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*domain.Instance, error) {
i.builder = database.StatementBuilder{} builder := database.StatementBuilder{}
i.builder.WriteString(queryInstanceStmt) builder.WriteString(queryInstanceStmt)
// return only non deleted isntances // return only non deleted isntances
opts = append(opts, database.IsNull(i.DeletedAtColumn())) opts = append(opts, database.IsNull(i.DeletedAtColumn()))
andCondition := database.And(opts...) andCondition := database.And(opts...)
andCondition.Write(&i.builder) andCondition.Write(&builder)
rows, err := i.client.Query(ctx, i.builder.String(), i.builder.Args()...) rows, err := i.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -71,11 +71,11 @@ const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org
// Create implements [domain.InstanceRepository]. // Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error { func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
i.builder = database.StatementBuilder{} builder := database.StatementBuilder{}
i.builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage) builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage)
i.builder.WriteString(createInstanceStmt) builder.WriteString(createInstanceStmt)
err := i.client.QueryRow(ctx, i.builder.String(), i.builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
if err != nil { if err != nil {
var pgErr *pgconn.PgError var pgErr *pgconn.PgError
if errors.As(err, &pgErr) { if errors.As(err, &pgErr) {
@@ -101,14 +101,14 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
// Update implements [domain.InstanceRepository]. // Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
i.builder = database.StatementBuilder{} builder := database.StatementBuilder{}
i.builder.WriteString(`UPDATE zitadel.instances SET `) builder.WriteString(`UPDATE zitadel.instances SET `)
database.Changes(changes).Write(&i.builder) database.Changes(changes).Write(&builder)
i.writeCondition(condition) i.writeCondition(&builder, condition)
stmt := i.builder.String() stmt := builder.String()
rowsAffected, err := i.client.Exec(ctx, stmt, i.builder.Args()...) rowsAffected, err := i.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err return rowsAffected, err
} }
@@ -117,12 +117,12 @@ func (i instance) Delete(ctx context.Context, condition database.Condition) erro
if condition == nil { if condition == nil {
return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted) return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted)
} }
i.builder = database.StatementBuilder{} builder := database.StatementBuilder{}
i.builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`) builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`)
i.builder.AppendArgs(time.Now()) builder.AppendArgs(time.Now())
i.writeCondition(condition) i.writeCondition(&builder, condition)
_, err := i.client.Exec(ctx, i.builder.String(), i.builder.Args()...) _, err := i.client.Exec(ctx, builder.String(), builder.Args()...)
return err return err
} }
@@ -203,12 +203,15 @@ func (instance) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at") return database.NewColumn("deleted_at")
} }
func (i *instance) writeCondition(condition database.Condition) { func (i *instance) writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil { if condition == nil {
return return
} }
i.builder.WriteString(" WHERE ") builder.WriteString(" WHERE ")
condition.Write(&i.builder) condition.Write(builder)
} }
func scanInstance(scanner database.Scanner) (*domain.Instance, error) { func scanInstance(scanner database.Scanner) (*domain.Instance, error) {

View File

@@ -3,6 +3,6 @@ package repository
import "github.com/zitadel/zitadel/backend/v3/storage/database" import "github.com/zitadel/zitadel/backend/v3/storage/database"
type repository struct { type repository struct {
builder database.StatementBuilder // builder database.StatementBuilder
client database.QueryExecutor client database.QueryExecutor
} }

View File

@@ -47,13 +47,14 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
opt(options) opt(options)
} }
u.builder.WriteString(queryUserStmt) builder := database.StatementBuilder{}
options.WriteCondition(&u.builder) builder.WriteString(queryUserStmt)
options.WriteOrderBy(&u.builder) options.WriteCondition(&builder)
options.WriteLimit(&u.builder) options.WriteOrderBy(&builder)
options.WriteOffset(&u.builder) options.WriteLimit(&builder)
options.WriteOffset(&builder)
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.Args()...) rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -84,13 +85,14 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
opt(options) opt(options)
} }
u.builder.WriteString(queryUserStmt) builder := database.StatementBuilder{}
options.WriteCondition(&u.builder) builder.WriteString(queryUserStmt)
options.WriteOrderBy(&u.builder) options.WriteCondition(&builder)
options.WriteLimit(&u.builder) options.WriteOrderBy(&builder)
options.WriteOffset(&u.builder) options.WriteLimit(&builder)
options.WriteOffset(&builder)
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...)) return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
} }
const ( const (
@@ -104,23 +106,25 @@ const (
// Create implements [domain.UserRepository]. // Create implements [domain.UserRepository].
func (u *user) Create(ctx context.Context, user *domain.User) error { func (u *user) Create(ctx context.Context, user *domain.User) error {
u.builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type()) builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) { switch trait := user.Traits.(type) {
case *domain.Human: case *domain.Human:
u.builder.WriteString(createHumanStmt) builder.WriteString(createHumanStmt)
u.builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt) builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *domain.Machine: case *domain.Machine:
u.builder.WriteString(createMachineStmt) builder.WriteString(createMachineStmt)
u.builder.AppendArgs(trait.Description) builder.AppendArgs(trait.Description)
} }
return u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt) return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
} }
// Delete implements [domain.UserRepository]. // Delete implements [domain.UserRepository].
func (u *user) Delete(ctx context.Context, condition database.Condition) error { func (u *user) Delete(ctx context.Context, condition database.Condition) error {
u.builder.WriteString("DELETE FROM users") builder := database.StatementBuilder{}
u.writeCondition(condition) builder.WriteString("DELETE FROM users")
_, err := u.client.Exec(ctx, u.builder.String(), u.builder.Args()...) u.writeCondition(builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
return err return err
} }
@@ -219,12 +223,15 @@ func (user) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at") return database.NewColumn("deleted_at")
} }
func (u *user) writeCondition(condition database.Condition) { func (u *user) writeCondition(
builder database.StatementBuilder,
condition database.Condition,
) {
if condition == nil { if condition == nil {
return return
} }
u.builder.WriteString(" WHERE ") builder.WriteString(" WHERE ")
condition.Write(&u.builder) condition.Write(&builder)
} }
func (u user) columns() database.Columns { func (u user) columns() database.Columns {

View File

@@ -24,14 +24,14 @@ const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_hu
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) { func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
var email domain.Email var email domain.Email
u.builder.WriteString(userEmailQuery) builder := database.StatementBuilder{}
u.writeCondition(condition) builder.WriteString(userEmailQuery)
u.writeCondition(builder, condition)
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan( err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address, &email.Address,
&email.VerifiedAt, &email.VerifiedAt,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -40,13 +40,14 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
// Update implements [domain.HumanRepository]. // Update implements [domain.HumanRepository].
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
h.builder.WriteString(`UPDATE human_users SET `) builder := database.StatementBuilder{}
database.Changes(changes).Write(&h.builder) builder.WriteString(`UPDATE human_users SET `)
h.writeCondition(condition) database.Changes(changes).Write(&builder)
h.writeCondition(builder, condition)
stmt := h.builder.String() stmt := builder.String()
_, err := h.client.Exec(ctx, stmt, h.builder.Args()...) _, err := h.client.Exec(ctx, stmt, builder.Args()...)
return err return err
} }

View File

@@ -19,12 +19,13 @@ var _ domain.MachineRepository = (*userMachine)(nil)
// Update implements [domain.MachineRepository]. // Update implements [domain.MachineRepository].
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
m.builder.WriteString("UPDATE user_machines SET ") builder := database.StatementBuilder{}
database.Changes(changes).Write(&m.builder) builder.WriteString("UPDATE user_machines SET ")
m.writeCondition(condition) database.Changes(changes).Write(&builder)
m.writeCondition(builder, condition)
m.writeReturning() m.writeReturning()
_, err := m.client.Exec(ctx, m.builder.String(), m.builder.Args()...) _, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
return err return err
} }
@@ -60,6 +61,7 @@ func (m userMachine) columns() database.Columns {
} }
func (m *userMachine) writeReturning() { func (m *userMachine) writeReturning() {
m.builder.WriteString(" RETURNING ") builder := database.StatementBuilder{}
m.columns().Write(&m.builder) builder.WriteString(" RETURNING ")
m.columns().Write(&builder)
} }

View File

@@ -54,16 +54,16 @@ func (mr *MockMessageMockRecorder) GetContent() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetContent", reflect.TypeOf((*MockMessage)(nil).GetContent)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetContent", reflect.TypeOf((*MockMessage)(nil).GetContent))
} }
// GetTriggeringEvent mocks base method. // GetTriggeringEventType mocks base method.
func (m *MockMessage) GetTriggeringEvent() eventstore.Event { func (m *MockMessage) GetTriggeringEventType() eventstore.EventType {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTriggeringEvent") ret := m.ctrl.Call(m, "GetTriggeringEventType")
ret0, _ := ret[0].(eventstore.Event) ret0, _ := ret[0].(eventstore.EventType)
return ret0 return ret0
} }
// GetTriggeringEvent indicates an expected call of GetTriggeringEvent. // GetTriggeringEventType indicates an expected call of GetTriggeringEventType.
func (mr *MockMessageMockRecorder) GetTriggeringEvent() *gomock.Call { func (mr *MockMessageMockRecorder) GetTriggeringEventType() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTriggeringEvent", reflect.TypeOf((*MockMessage)(nil).GetTriggeringEvent)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTriggeringEventType", reflect.TypeOf((*MockMessage)(nil).GetTriggeringEventType))
} }