diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go index 03dd40cf2ec..446331a320e 100644 --- a/backend/v3/domain/instance.go +++ b/backend/v3/domain/instance.go @@ -69,6 +69,8 @@ type instanceConditions interface { IDCondition(instanceID string) database.Condition // NameCondition returns a filter on the name field. NameCondition(op database.TextOperation, name string) database.Condition + // ExistsDomain returns a filter on the instance domains. + ExistsDomain(cond database.Condition) database.Condition } // instanceChanges define all the changes for the instance table. @@ -99,17 +101,16 @@ type InstanceRepository interface { // Member returns the member repository which is a sub repository of the instance repository. // Member() MemberRepository - Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error) - List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error) + Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Instance, error) + List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Instance, error) - Create(ctx context.Context, instance *Instance) error - Update(ctx context.Context, id string, changes ...database.Change) (int64, error) - Delete(ctx context.Context, id string) (int64, error) + Create(ctx context.Context, client database.QueryExecutor, instance *Instance) error + Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) - // Domains returns the domain sub repository for the instance. - // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. - // If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future. - Domains(shouldLoad bool) InstanceDomainRepository + // LoadDomains loads the domains of the given instance. + // If it is called the [Instance].Domains field will be set on future calls to Get or List. + LoadDomains() InstanceRepository } type CreateInstance struct { diff --git a/backend/v3/domain/instance_domain.go b/backend/v3/domain/instance_domain.go index 4c6a71b2e9f..786be287ee7 100644 --- a/backend/v3/domain/instance_domain.go +++ b/backend/v3/domain/instance_domain.go @@ -65,15 +65,15 @@ type InstanceDomainRepository interface { // Get returns a single domain based on the criteria. // If no domain is found, it returns an error of type [database.ErrNotFound]. // If multiple domains are found, it returns an error of type [database.ErrMultipleRows]. - Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error) + Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*InstanceDomain, error) // List returns a list of domains based on the criteria. // If no domains are found, it returns an empty slice. - List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error) + List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*InstanceDomain, error) // Add adds a new domain to the instance. - Add(ctx context.Context, domain *AddInstanceDomain) error + Add(ctx context.Context, client database.QueryExecutor, domain *AddInstanceDomain) error // Update updates an existing domain in the instance. - Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) + Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) // Remove removes a domain from the instance. - Remove(ctx context.Context, condition database.Condition) (int64, error) + Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) } diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go index 570e96c39e9..32d6528ba86 100644 --- a/backend/v3/domain/organization.go +++ b/backend/v3/domain/organization.go @@ -26,13 +26,6 @@ type Organization struct { Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately } -// OrgIdentifierCondition is used to help specify a single Organization, -// it will either be used as the organization ID or organization name, -// as organizations can be identified either using (instanceID + ID) OR (instanceID + name) -type OrgIdentifierCondition interface { - database.Condition -} - // organizationColumns define all the columns of the instance table. type organizationColumns interface { // IDColumn returns the column for the id field. @@ -52,13 +45,15 @@ type organizationColumns interface { // organizationConditions define all the conditions for the instance table. type organizationConditions interface { // IDCondition returns an equal filter on the id field. - IDCondition(instanceID string) OrgIdentifierCondition + IDCondition(instanceID string) database.Condition // NameCondition returns a filter on the name field. - NameCondition(name string) OrgIdentifierCondition + NameCondition(op database.TextOperation, name string) database.Condition // InstanceIDCondition returns a filter on the instance id field. InstanceIDCondition(instanceID string) database.Condition // StateCondition returns a filter on the name field. StateCondition(state OrgState) database.Condition + // ExistsDomain returns a filter on the organizations domains. + ExistsDomain(cond database.Condition) database.Condition } // organizationChanges define all the changes for the instance table. @@ -75,17 +70,16 @@ type OrganizationRepository interface { organizationConditions organizationChanges - Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error) - List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error) + Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Organization, error) + List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Organization, error) - Create(ctx context.Context, instance *Organization) error - Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error) - Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error) + Create(ctx context.Context, client database.QueryExecutor, org *Organization) error + Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) + Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) - // Domains returns the domain sub repository for the organization. - // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. - // If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future. - Domains(shouldLoad bool) OrganizationDomainRepository + // LoadDomains loads the domains of the given organizations. + // If it is called the [Organization].Domains field will be set on future calls to Get or List. + LoadDomains() OrganizationRepository } type CreateOrganization struct { diff --git a/backend/v3/domain/organization_domain.go b/backend/v3/domain/organization_domain.go index c0868e3a620..1e1b3877c49 100644 --- a/backend/v3/domain/organization_domain.go +++ b/backend/v3/domain/organization_domain.go @@ -70,15 +70,15 @@ type OrganizationDomainRepository interface { // Get returns a single domain based on the criteria. // If no domain is found, it returns an error of type [database.ErrNotFound]. // If multiple domains are found, it returns an error of type [database.ErrMultipleRows]. - Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error) + Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*OrganizationDomain, error) // List returns a list of domains based on the criteria. // If no domains are found, it returns an empty slice. - List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error) + List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*OrganizationDomain, error) // Add adds a new domain to the organization. - Add(ctx context.Context, domain *AddOrganizationDomain) error + Add(ctx context.Context, client database.QueryExecutor, domain *AddOrganizationDomain) error // Update updates an existing domain in the organization. - Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) + Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) // Remove removes a domain from the organization. - Remove(ctx context.Context, condition database.Condition) (int64, error) + Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) } diff --git a/backend/v3/domain/user.go b/backend/v3/domain/user.go index fae0d75b6e7..333b5c96fbb 100644 --- a/backend/v3/domain/user.go +++ b/backend/v3/domain/user.go @@ -57,13 +57,13 @@ type UserRepository interface { userConditions userChanges // Get returns a user based on the given condition. - Get(ctx context.Context, opts ...database.QueryOption) (*User, error) + Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*User, error) // List returns a list of users based on the given condition. - List(ctx context.Context, opts ...database.QueryOption) ([]*User, error) + List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*User, error) // Create creates a new user. - Create(ctx context.Context, user *User) error + Create(ctx context.Context, client database.QueryExecutor, user *User) error // Delete removes users based on the given condition. - Delete(ctx context.Context, condition database.Condition) error + Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error // Human returns the [HumanRepository]. Human() HumanRepository // Machine returns the [MachineRepository]. @@ -143,9 +143,9 @@ type HumanRepository interface { humanChanges // Get returns an email based on the given condition. - GetEmail(ctx context.Context, condition database.Condition) (*Email, error) + GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*Email, error) // Update updates human users based on the given condition and changes. - Update(ctx context.Context, condition database.Condition, changes ...database.Change) error + Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error } // machineColumns define all the columns of the machine table which inherits the user table. @@ -172,7 +172,7 @@ type machineChanges interface { // MachineRepository is the interface for the machine repository it inherits the user repository. type MachineRepository interface { // Update updates machine users based on the given condition and changes. - Update(ctx context.Context, condition database.Condition, changes ...database.Change) error + Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error machineColumns machineConditions diff --git a/backend/v3/storage/database/change.go b/backend/v3/storage/database/change.go index 724e029dbe9..6581a2c4ce0 100644 --- a/backend/v3/storage/database/change.go +++ b/backend/v3/storage/database/change.go @@ -1,9 +1,14 @@ package database +import "slices" + // Change represents a change to a column in a database table. // Its written in the SET clause of an UPDATE statement. type Change interface { + // Write writes the change to the given statement builder. Write(builder *StatementBuilder) + // IsOnColumn checks if the change is on the given column. + IsOnColumn(col Column) bool } type change[V Value] struct { @@ -13,6 +18,8 @@ type change[V Value] struct { var _ Change = (*change[string])(nil) +// NewChange creates a new Change for the given column and value. +// If you want to set a column to NULL, use [NewChangePtr]. func NewChange[V Value](col Column, value V) Change { return &change[V]{ column: col, @@ -20,6 +27,8 @@ func NewChange[V Value](col Column, value V) Change { } } +// NewChangePtr creates a new Change for the given column and value pointer. +// If the value pointer is nil, the column will be set to NULL. func NewChangePtr[V Value](col Column, value *V) Change { if value == nil { return NewChange(col, NullInstruction) @@ -34,19 +43,31 @@ func (c change[V]) Write(builder *StatementBuilder) { builder.WriteArg(c.value) } +// IsOnColumn implements [Change]. +func (c change[V]) IsOnColumn(col Column) bool { + return c.column.Equals(col) +} + type Changes []Change func NewChanges(cols ...Change) Change { return Changes(cols) } +// IsOnColumn implements [Change]. +func (c Changes) IsOnColumn(col Column) bool { + return slices.ContainsFunc(c, func(change Change) bool { + return change.IsOnColumn(col) + }) +} + // Write implements [Change]. func (m Changes) Write(builder *StatementBuilder) { - for i, col := range m { + for i, change := range m { if i > 0 { builder.WriteString(", ") } - col.Write(builder) + change.Write(builder) } } diff --git a/backend/v3/storage/database/change_test.go b/backend/v3/storage/database/change_test.go new file mode 100644 index 00000000000..8b2733ddcea --- /dev/null +++ b/backend/v3/storage/database/change_test.go @@ -0,0 +1,68 @@ +package database + +import ( + "testing" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestChangeWrite(t *testing.T) { + type want struct { + stmt string + args []any + } + for _, test := range []struct { + name string + change Change + want want + }{ + { + name: "change", + change: NewChange(NewColumn("table", "column"), "value"), + want: want{ + stmt: "column = $1", + args: []any{"value"}, + }, + }, + { + name: "change ptr to null", + change: NewChangePtr[int](NewColumn("table", "column"), nil), + want: want{ + stmt: "column = NULL", + args: nil, + }, + }, + { + name: "change ptr to value", + change: NewChangePtr(NewColumn("table", "column"), gu.Ptr(42)), + want: want{ + stmt: "column = $1", + args: []any{42}, + }, + }, + { + name: "multiple changes", + change: NewChanges( + NewChange(NewColumn("table", "column1"), "value1"), + NewChangePtr[int](NewColumn("table", "column2"), nil), + NewChange(NewColumn("table", "column3"), 123), + ), + want: want{ + stmt: "column1 = $1, column2 = NULL, column3 = $2", + args: []any{"value1", 123}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + var builder StatementBuilder + test.change.Write(&builder) + assert.Equal(t, test.want.stmt, builder.String()) + require.Len(t, builder.Args(), len(test.want.args)) + for i, arg := range test.want.args { + assert.Equal(t, arg, builder.Args()[i]) + } + }) + } +} diff --git a/backend/v3/storage/database/column.go b/backend/v3/storage/database/column.go index 7f57637d38a..f1316048bbe 100644 --- a/backend/v3/storage/database/column.go +++ b/backend/v3/storage/database/column.go @@ -3,8 +3,9 @@ package database type Columns []Column // WriteQualified implements [Column]. -func (m Columns) WriteQualified(builder *StatementBuilder) { - for i, col := range m { +// Columns are separated by ", ". +func (c Columns) WriteQualified(builder *StatementBuilder) { + for i, col := range c { if i > 0 { builder.WriteString(", ") } @@ -13,8 +14,9 @@ func (m Columns) WriteQualified(builder *StatementBuilder) { } // WriteUnqualified implements [Column]. -func (m Columns) WriteUnqualified(builder *StatementBuilder) { - for i, col := range m { +// Columns are separated by ", ". +func (c Columns) WriteUnqualified(builder *StatementBuilder) { + for i, col := range c { if i > 0 { builder.WriteString(", ") } @@ -22,11 +24,33 @@ func (m Columns) WriteUnqualified(builder *StatementBuilder) { } } +// Equals implements [Column]. +func (c Columns) Equals(col Column) bool { + if col == nil { + return c == nil + } + other, ok := col.(Columns) + if !ok || len(other) != len(c) { + return false + } + for i, col := range c { + if !col.Equals(other[i]) { + return false + } + } + return true +} + +var _ Column = (Columns)(nil) + // Column represents a column in a database table. type Column interface { - // Write(builder *StatementBuilder) + // WriteQualified writes the column with the table name as prefix. WriteQualified(builder *StatementBuilder) + // WriteUnqualified writes the column without the table name as prefix. WriteUnqualified(builder *StatementBuilder) + // Equals checks if two columns are equal. + Equals(col Column) bool } type column struct { @@ -35,7 +59,7 @@ type column struct { } func NewColumn(table, name string) Column { - return column{table: table, name: name} + return &column{table: table, name: name} } // WriteQualified implements [Column]. @@ -51,35 +75,69 @@ func (c column) WriteUnqualified(builder *StatementBuilder) { builder.WriteString(c.name) } -var _ Column = (*column)(nil) +// Equals implements [Column]. +func (c *column) Equals(col Column) bool { + if col == nil { + return c == nil + } + toMatch, ok := col.(*column) + if !ok { + return false + } + return c.table == toMatch.table && c.name == toMatch.name +} -// // ignoreCaseColumn represents two database columns, one for the -// // original value and one for the lower case value. -// type ignoreCaseColumn interface { -// Column -// WriteIgnoreCase(builder *StatementBuilder) -// } +// LowerColumn returns a column that represents LOWER(col). +func LowerColumn(col Column) Column { + return &functionColumn{fn: functionLower, col: col} +} -// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn { -// return ignoreCaseCol{ -// column: col, -// suffix: suffix, -// } -// } +// SHA256Column returns a column that represents SHA256(col). +func SHA256Column(col Column) Column { + return &functionColumn{fn: functionSHA256, col: col} +} -// type ignoreCaseCol struct { -// column Column -// suffix string -// } +type functionColumn struct { + fn function + col Column +} -// // WriteIgnoreCase implements [ignoreCaseColumn]. -// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) { -// c.column.WriteQualified(builder) -// builder.WriteString(c.suffix) -// } +type function string -// // WriteQualified implements [ignoreCaseColumn]. -// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) { -// c.column.WriteQualified(builder) -// builder.WriteString(c.suffix) -// } +const ( + _ function = "" + functionLower function = "LOWER" + functionSHA256 function = "SHA256" +) + +// WriteQualified implements [Column]. +func (c functionColumn) WriteQualified(builder *StatementBuilder) { + builder.Grow(len(c.fn) + 2) + builder.WriteString(string(c.fn)) + builder.WriteRune('(') + c.col.WriteQualified(builder) + builder.WriteRune(')') +} + +// WriteUnqualified implements [Column]. +func (c functionColumn) WriteUnqualified(builder *StatementBuilder) { + builder.Grow(len(c.fn) + 2) + builder.WriteString(string(c.fn)) + builder.WriteRune('(') + c.col.WriteUnqualified(builder) + builder.WriteRune(')') +} + +// Equals implements [Column]. +func (c *functionColumn) Equals(col Column) bool { + if col == nil { + return c == nil + } + toMatch, ok := col.(*functionColumn) + if !ok || toMatch.fn != c.fn { + return false + } + return c.col.Equals(toMatch.col) +} + +var _ Column = (*functionColumn)(nil) diff --git a/backend/v3/storage/database/column_test.go b/backend/v3/storage/database/column_test.go new file mode 100644 index 00000000000..216dddfb5e8 --- /dev/null +++ b/backend/v3/storage/database/column_test.go @@ -0,0 +1,212 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteUnqualified(t *testing.T) { + for _, tests := range []struct { + name string + column Column + expected string + }{ + { + name: "column", + column: NewColumn("table", "column"), + expected: "column", + }, + { + name: "columns", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + expected: "column1, column2", + }, + { + name: "function column", + column: SHA256Column(NewColumn("table", "column")), + expected: "SHA256(column)", + }, + } { + t.Run(tests.name, func(t *testing.T) { + var builder StatementBuilder + tests.column.WriteUnqualified(&builder) + assert.Equal(t, tests.expected, builder.String()) + }) + } +} + +func TestWriteQualified(t *testing.T) { + for _, tests := range []struct { + name string + column Column + expected string + }{ + { + name: "column", + column: NewColumn("table", "column"), + expected: "table.column", + }, + { + name: "columns", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + expected: "table.column1, table.column2", + }, + { + name: "function column", + column: SHA256Column(NewColumn("table", "column")), + expected: "SHA256(table.column)", + }, + } { + t.Run(tests.name, func(t *testing.T) { + var builder StatementBuilder + tests.column.WriteQualified(&builder) + assert.Equal(t, tests.expected, builder.String()) + }) + } +} + +func TestEquals(t *testing.T) { + for _, tests := range []struct { + name string + column Column + toCheck Column + expected bool + }{ + { + name: "column equal", + column: NewColumn("table", "column"), + toCheck: NewColumn("table", "column"), + expected: true, + }, + { + name: "column nil check", + column: NewColumn("table", "column"), + toCheck: nil, + expected: false, + }, + { + name: "column both nil", + column: (*column)(nil), + toCheck: nil, + expected: true, + }, + { + name: "column not equal (different name)", + column: NewColumn("table", "column"), + toCheck: NewColumn("table", "column2"), + expected: false, + }, + { + name: "column not equal (different type)", + column: NewColumn("table", "column"), + toCheck: SHA256Column(NewColumn("table", "column")), + expected: false, + }, + { + name: "columns equal", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + toCheck: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + expected: true, + }, + { + name: "columns nil check", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + toCheck: nil, + expected: false, + }, + { + name: "columns both nil", + column: Columns(nil), + toCheck: nil, + expected: true, + }, + { + name: "columns not equal (different type)", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + toCheck: NewColumn("table", "column1"), + expected: false, + }, + { + name: "columns not equal (different length)", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + toCheck: Columns{ + NewColumn("table", "column1"), + }, + expected: false, + }, + { + name: "columns not equal (different order)", + column: Columns{ + NewColumn("table", "column1"), + NewColumn("table", "column2"), + }, + toCheck: Columns{ + NewColumn("table", "column2"), + NewColumn("table", "column1"), + }, + expected: false, + }, + { + name: "function column equal", + column: SHA256Column(NewColumn("table", "column")), + toCheck: SHA256Column(NewColumn("table", "column")), + expected: true, + }, + { + name: "function nil check", + column: SHA256Column(NewColumn("table", "column")), + toCheck: nil, + expected: false, + }, + { + name: "function both nil", + column: (*functionColumn)(nil), + toCheck: nil, + expected: true, + }, + { + name: "function column not equal (different function)", + column: SHA256Column(NewColumn("table", "column")), + toCheck: LowerColumn(NewColumn("table", "column")), + expected: false, + }, + { + name: "function column not equal (different inner column)", + column: SHA256Column(NewColumn("table", "column")), + toCheck: SHA256Column(NewColumn("table", "column2")), + expected: false, + }, + { + name: "function column not equal (different type)", + column: SHA256Column(NewColumn("table", "column")), + toCheck: NewColumn("table", "column2"), + expected: false, + }, + } { + t.Run(tests.name, func(t *testing.T) { + assert.Equal(t, tests.expected, tests.column.Equals(tests.toCheck)) + }) + } +} diff --git a/backend/v3/storage/database/condition.go b/backend/v3/storage/database/condition.go index 5c8da5ff4b2..6bc1124f12e 100644 --- a/backend/v3/storage/database/condition.go +++ b/backend/v3/storage/database/condition.go @@ -4,6 +4,9 @@ package database // Its written after the WHERE keyword in a SQL statement. type Condition interface { Write(builder *StatementBuilder) + // IsRestrictingColumn is used to check if the condition filters for a specific column. + // It acts as a save guard database operations that should be specific on the given column. + IsRestrictingColumn(col Column) bool } type and struct { @@ -11,7 +14,7 @@ type and struct { } // Write implements [Condition]. -func (a *and) Write(builder *StatementBuilder) { +func (a and) Write(builder *StatementBuilder) { if len(a.conditions) > 1 { builder.WriteString("(") defer builder.WriteString(")") @@ -29,6 +32,16 @@ func And(conditions ...Condition) *and { return &and{conditions: conditions} } +// IsRestrictingColumn implements [Condition]. +func (a and) IsRestrictingColumn(col Column) bool { + for _, condition := range a.conditions { + if condition.IsRestrictingColumn(col) { + return true + } + } + return false +} + var _ Condition = (*and)(nil) type or struct { @@ -36,7 +49,7 @@ type or struct { } // Write implements [Condition]. -func (o *or) Write(builder *StatementBuilder) { +func (o or) Write(builder *StatementBuilder) { if len(o.conditions) > 1 { builder.WriteString("(") defer builder.WriteString(")") @@ -54,6 +67,17 @@ func Or(conditions ...Condition) *or { return &or{conditions: conditions} } +// IsRestrictingColumn implements [Condition]. +// It returns true only if all conditions are restricting the given column. +func (o or) IsRestrictingColumn(col Column) bool { + for _, condition := range o.conditions { + if !condition.IsRestrictingColumn(col) { + return false + } + } + return true +} + var _ Condition = (*or)(nil) type isNull struct { @@ -61,7 +85,7 @@ type isNull struct { } // Write implements [Condition]. -func (i *isNull) Write(builder *StatementBuilder) { +func (i isNull) Write(builder *StatementBuilder) { i.column.WriteQualified(builder) builder.WriteString(" IS NULL") } @@ -71,6 +95,12 @@ func IsNull(column Column) *isNull { return &isNull{column: column} } +// IsRestrictingColumn implements [Condition]. +// It returns false because it cannot be used for restricting a column. +func (i isNull) IsRestrictingColumn(col Column) bool { + return false +} + var _ Condition = (*isNull)(nil) type isNotNull struct { @@ -78,7 +108,7 @@ type isNotNull struct { } // Write implements [Condition]. -func (i *isNotNull) Write(builder *StatementBuilder) { +func (i isNotNull) Write(builder *StatementBuilder) { i.column.WriteQualified(builder) builder.WriteString(" IS NOT NULL") } @@ -88,43 +118,122 @@ func IsNotNull(column Column) *isNotNull { return &isNotNull{column: column} } +// IsRestrictingColumn implements [Condition]. +// It returns false because it cannot be used for restricting a column. +func (i isNotNull) IsRestrictingColumn(col Column) bool { + return false +} + var _ Condition = (*isNotNull)(nil) -type valueCondition func(builder *StatementBuilder) +type valueCondition struct { + write func(builder *StatementBuilder) + col Column +} // NewTextCondition creates a condition that compares a text column with a value. -func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition { - return valueCondition(func(builder *StatementBuilder) { - writeTextOperation(builder, col, op, value) - }) +// If you want to use ignore case operations, consider using [NewTextIgnoreCaseCondition]. +func NewTextCondition[T Text](col Column, op TextOperation, value T) Condition { + return valueCondition{ + col: col, + write: func(builder *StatementBuilder) { + writeTextOperation[T](builder, col, op, value) + }, + } +} + +// NewTextIgnoreCaseCondition creates a condition that compares a text column with a value, ignoring case by lowercasing both. +func NewTextIgnoreCaseCondition[T Text](col Column, op TextOperation, value T) Condition { + return valueCondition{ + col: col, + write: func(builder *StatementBuilder) { + writeTextOperation[T](builder, LowerColumn(col), op, LowerValue(value)) + }, + } } // NewDateCondition creates a condition that compares a numeric column with a value. func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition { - return valueCondition(func(builder *StatementBuilder) { - writeNumberOperation(builder, col, op, value) - }) + return valueCondition{ + col: col, + write: func(builder *StatementBuilder) { + writeNumberOperation[V](builder, col, op, value) + }, + } } // NewDateCondition creates a condition that compares a boolean column with a value. func NewBooleanCondition[V Boolean](col Column, value V) Condition { - return valueCondition(func(builder *StatementBuilder) { - writeBooleanOperation(builder, col, value) - }) + return valueCondition{ + col: col, + write: func(builder *StatementBuilder) { + writeBooleanOperation[V](builder, col, value) + }, + } +} + +// NewBytesCondition creates a condition that compares a BYTEA column with a value. +func NewBytesCondition[V Bytes](col Column, op BytesOperation, value any) Condition { + return valueCondition{ + col: col, + write: func(builder *StatementBuilder) { + writeBytesOperation[V](builder, col, op, value) + }, + } } // NewColumnCondition creates a condition that compares two columns on equality. func NewColumnCondition(col1, col2 Column) Condition { - return valueCondition(func(builder *StatementBuilder) { - col1.WriteQualified(builder) - builder.WriteString(" = ") - col2.WriteQualified(builder) - }) + return valueCondition{ + col: col1, + write: func(builder *StatementBuilder) { + col1.WriteQualified(builder) + builder.WriteString(" = ") + col2.WriteQualified(builder) + }, + } } // Write implements [Condition]. func (c valueCondition) Write(builder *StatementBuilder) { - c(builder) + c.write(builder) +} + +// IsRestrictingColumn implements [Condition]. +func (i valueCondition) IsRestrictingColumn(col Column) bool { + return i.col.Equals(col) } var _ Condition = (*valueCondition)(nil) + +// existsCondition is a helper to write an EXISTS (SELECT 1 FROM WHERE ) clause. +// It implements Condition so it can be composed with other conditions using And/Or. +type existsCondition struct { + table string + condition Condition +} + +// Exists creates a condition that checks for the existence of rows in a subquery. +func Exists(table string, condition Condition) Condition { + return &existsCondition{ + table: table, + condition: condition, + } +} + +// Write implements [Condition]. +func (e existsCondition) Write(builder *StatementBuilder) { + builder.WriteString(" EXISTS (SELECT 1 FROM ") + builder.WriteString(e.table) + builder.WriteString(" WHERE ") + e.condition.Write(builder) + builder.WriteString(")") +} + +// IsRestrictingColumn implements [Condition]. +func (e existsCondition) IsRestrictingColumn(col Column) bool { + // Forward to the inner condition so safety checks (like instance_id presence) can still work. + return e.condition.IsRestrictingColumn(col) +} + +var _ Condition = (*existsCondition)(nil) diff --git a/backend/v3/storage/database/condition_test.go b/backend/v3/storage/database/condition_test.go new file mode 100644 index 00000000000..a15dcd8fd62 --- /dev/null +++ b/backend/v3/storage/database/condition_test.go @@ -0,0 +1,248 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWrite(t *testing.T) { + type want struct { + stmt string + args []any + } + for _, test := range []struct { + name string + cond Condition + want want + }{ + { + name: "no and condition", + cond: And(), + want: want{ + stmt: "", + args: nil, + }, + }, + { + name: "one and condition", + cond: And( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + ), + want: want{ + stmt: "table.column1 = other_table.column2", + args: nil, + }, + }, + { + name: "multiple and condition", + cond: And( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")), + ), + want: want{ + stmt: "(table.column1 = other_table.column2 AND table.column3 = other_table.column4)", + args: nil, + }, + }, + { + name: "no or condition", + cond: Or(), + want: want{ + stmt: "", + args: nil, + }, + }, + { + name: "one or condition", + cond: Or( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + ), + want: want{ + stmt: "table.column1 = other_table.column2", + args: nil, + }, + }, + { + name: "multiple or condition", + cond: Or( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")), + ), + want: want{ + stmt: "(table.column1 = other_table.column2 OR table.column3 = other_table.column4)", + args: nil, + }, + }, + { + name: "is null condition", + cond: IsNull(NewColumn("table", "column1")), + want: want{ + stmt: "table.column1 IS NULL", + args: nil, + }, + }, + { + name: "is not null condition", + cond: IsNotNull(NewColumn("table", "column1")), + want: want{ + stmt: "table.column1 IS NOT NULL", + args: nil, + }, + }, + { + name: "text condition", + cond: NewTextCondition(NewColumn("table", "column1"), TextOperationEqual, "some text"), + want: want{ + stmt: "table.column1 = $1", + args: []any{"some text"}, + }, + }, + { + name: "text ignore case condition", + cond: NewTextIgnoreCaseCondition(NewColumn("table", "column1"), TextOperationNotEqual, "some TEXT"), + want: want{ + stmt: "LOWER(table.column1) <> LOWER($1)", + args: []any{"some TEXT"}, + }, + }, + { + name: "number condition", + cond: NewNumberCondition(NewColumn("table", "column1"), NumberOperationEqual, 42), + want: want{ + stmt: "table.column1 = $1", + args: []any{42}, + }, + }, + { + name: "boolean condition", + cond: NewBooleanCondition(NewColumn("table", "column1"), true), + want: want{ + stmt: "table.column1 = $1", + args: []any{true}, + }, + }, + { + name: "bytes condition", + cond: NewBytesCondition[[]byte](NewColumn("table", "column1"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}), + want: want{ + stmt: "table.column1 = $1", + args: []any{[]byte{0x01, 0x02, 0x03}}, + }, + }, + { + name: "column condition", + cond: NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + want: want{ + stmt: "table.column1 = other_table.column2", + args: nil, + }, + }, + { + name: "exists condition", + cond: Exists("table", And( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")), + )), + want: want{ + stmt: " EXISTS (SELECT 1 FROM table WHERE (table.column1 = other_table.column2 AND table.column3 = other_table.column4))", + args: nil, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + var builder StatementBuilder + test.cond.Write(&builder) + assert.Equal(t, test.want.stmt, builder.String()) + require.Len(t, builder.Args(), len(test.want.args)) + for i, arg := range test.want.args { + assert.Equal(t, arg, builder.Args()[i]) + } + }) + } +} + +func TestIsRestrictingColumn(t *testing.T) { + for _, test := range []struct { + name string + col Column + cond Condition + want bool + }{ + { + name: "and with restricting column", + col: NewColumn("table", "column1"), + cond: And( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")), + ), + want: true, + }, + { + name: "and without restricting column", + col: NewColumn("table", "column1"), + cond: And( + NewColumnCondition(NewColumn("table", "column2"), NewColumn("other_table", "column3")), + IsNull(NewColumn("table", "column4")), + IsNotNull(NewColumn("table", "column5")), + ), + want: false, + }, + { + name: "or with restricting column", + col: NewColumn("table", "column1"), + cond: Or( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")), + ), + want: true, + }, + { + name: "or without restricting column", + col: NewColumn("table", "column1"), + cond: Or( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")), + IsNotNull(NewColumn("table", "column4")), + IsNull(NewColumn("table", "column5")), + ), + want: false, + }, + { + name: "is null never restricts", + col: NewColumn("table", "column1"), + cond: IsNull(NewColumn("table", "column1")), + want: false, + }, + { + name: "is not null never restricts", + col: NewColumn("table", "column1"), + cond: IsNotNull(NewColumn("table", "column1")), + want: false, + }, + { + name: "exists with restricting column", + col: NewColumn("table", "column1"), + cond: Exists("table", And( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")), + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")), + )), + want: true, + }, + { + name: "exists without restricting column", + col: NewColumn("table", "column1"), + cond: Exists("table", Or( + NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")), + IsNotNull(NewColumn("table", "column4")), + IsNull(NewColumn("table", "column5")), + )), + want: false, + }, + } { + t.Run(test.name, func(t *testing.T) { + isRestricting := test.cond.IsRestrictingColumn(test.col) + assert.Equal(t, test.want, isRestricting) + }) + } +} diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go index 7cdeb9c0c3c..c241b42e589 100644 --- a/backend/v3/storage/database/database.go +++ b/backend/v3/storage/database/database.go @@ -10,7 +10,7 @@ type Pool interface { QueryExecutor Migrator - Acquire(ctx context.Context) (Client, error) + Acquire(ctx context.Context) (Connection, error) Close(ctx context.Context) error Ping(ctx context.Context) error @@ -22,8 +22,8 @@ type PoolTest interface { MigrateTest(ctx context.Context) error } -// Client is a single database connection which can be released back to the pool. -type Client interface { +// Connection is a single database connection which can be released back to the pool. +type Connection interface { Beginner QueryExecutor Migrator diff --git a/backend/v3/storage/database/dbmock/database.mock.go b/backend/v3/storage/database/dbmock/database.mock.go index 1ff898257c0..2c8b176b7e1 100644 --- a/backend/v3/storage/database/dbmock/database.mock.go +++ b/backend/v3/storage/database/dbmock/database.mock.go @@ -1,9 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction) +// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Connection,Row,Rows,Transaction) // // Generated by this command: // -// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction +// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction // // Package dbmock is a generated GoMock package. @@ -21,6 +21,7 @@ import ( type MockPool struct { ctrl *gomock.Controller recorder *MockPoolMockRecorder + isgomock struct{} } // MockPoolMockRecorder is the mock recorder for MockPool. @@ -41,18 +42,18 @@ func (m *MockPool) EXPECT() *MockPoolMockRecorder { } // Acquire mocks base method. -func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) { +func (m *MockPool) Acquire(ctx context.Context) (database.Connection, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Acquire", arg0) - ret0, _ := ret[0].(database.Client) + ret := m.ctrl.Call(m, "Acquire", ctx) + ret0, _ := ret[0].(database.Connection) ret1, _ := ret[1].(error) return ret0, ret1 } // Acquire indicates an expected call of Acquire. -func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall { +func (mr *MockPoolMockRecorder) Acquire(ctx any) *MockPoolAcquireCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), ctx) return &MockPoolAcquireCall{Call: call} } @@ -62,36 +63,36 @@ type MockPoolAcquireCall struct { } // Return rewrite *gomock.Call.Return -func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall { +func (c *MockPoolAcquireCall) Return(arg0 database.Connection, arg1 error) *MockPoolAcquireCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { +func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { +func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall { c.Call = c.Call.DoAndReturn(f) return c } // Begin mocks base method. -func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { +func (m *MockPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret := m.ctrl.Call(m, "Begin", ctx, opts) ret0, _ := ret[0].(database.Transaction) ret1, _ := ret[1].(error) return ret0, ret1 } // Begin indicates an expected call of Begin. -func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall { +func (mr *MockPoolMockRecorder) Begin(ctx, opts any) *MockPoolBeginCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), ctx, opts) return &MockPoolBeginCall{Call: call} } @@ -119,17 +120,17 @@ func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.Transa } // Close mocks base method. -func (m *MockPool) Close(arg0 context.Context) error { +func (m *MockPool) Close(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close", arg0) + ret := m.ctrl.Call(m, "Close", ctx) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall { +func (mr *MockPoolMockRecorder) Close(ctx any) *MockPoolCloseCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), ctx) return &MockPoolCloseCall{Call: call} } @@ -157,10 +158,10 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool } // Exec mocks base method. -func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { +func (m *MockPool) Exec(ctx context.Context, stmt string, args ...any) (int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Exec", varargs...) @@ -170,9 +171,9 @@ func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, } // Exec indicates an expected call of Exec. -func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall { +func (mr *MockPoolMockRecorder) Exec(ctx, stmt any, args ...any) *MockPoolExecCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...) return &MockPoolExecCall{Call: call} } @@ -201,17 +202,17 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) ( } // Migrate mocks base method. -func (m *MockPool) Migrate(arg0 context.Context) error { +func (m *MockPool) Migrate(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Migrate", arg0) + ret := m.ctrl.Call(m, "Migrate", ctx) ret0, _ := ret[0].(error) return ret0 } // Migrate indicates an expected call of Migrate. -func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall { +func (mr *MockPoolMockRecorder) Migrate(ctx any) *MockPoolMigrateCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), ctx) return &MockPoolMigrateCall{Call: call} } @@ -238,11 +239,49 @@ func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPo return c } -// Query mocks base method. -func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { +// Ping mocks base method. +func (m *MockPool) Ping(ctx context.Context) error { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ping indicates an expected call of Ping. +func (mr *MockPoolMockRecorder) Ping(ctx any) *MockPoolPingCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockPool)(nil).Ping), ctx) + return &MockPoolPingCall{Call: call} +} + +// MockPoolPingCall wrap *gomock.Call +type MockPoolPingCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolPingCall) Return(arg0 error) *MockPoolPingCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolPingCall) Do(f func(context.Context) error) *MockPoolPingCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolPingCall) DoAndReturn(f func(context.Context) error) *MockPoolPingCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockPool) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Query", varargs...) @@ -252,9 +291,9 @@ func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (databa } // Query indicates an expected call of Query. -func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall { +func (mr *MockPoolMockRecorder) Query(ctx, stmt any, args ...any) *MockPoolQueryCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...) return &MockPoolQueryCall{Call: call} } @@ -283,10 +322,10 @@ func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any) } // QueryRow mocks base method. -func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { +func (m *MockPool) QueryRow(ctx context.Context, stmt string, args ...any) database.Row { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "QueryRow", varargs...) @@ -295,9 +334,9 @@ func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) data } // QueryRow indicates an expected call of QueryRow. -func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall { +func (mr *MockPoolMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockPoolQueryRowCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...) return &MockPoolQueryRowCall{Call: call} } @@ -325,73 +364,74 @@ func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...an return c } -// MockClient is a mock of Client interface. -type MockClient struct { +// MockConnection is a mock of Connection interface. +type MockConnection struct { ctrl *gomock.Controller - recorder *MockClientMockRecorder + recorder *MockConnectionMockRecorder + isgomock struct{} } -// MockClientMockRecorder is the mock recorder for MockClient. -type MockClientMockRecorder struct { - mock *MockClient +// MockConnectionMockRecorder is the mock recorder for MockConnection. +type MockConnectionMockRecorder struct { + mock *MockConnection } -// NewMockClient creates a new mock instance. -func NewMockClient(ctrl *gomock.Controller) *MockClient { - mock := &MockClient{ctrl: ctrl} - mock.recorder = &MockClientMockRecorder{mock} +// NewMockConnection creates a new mock instance. +func NewMockConnection(ctrl *gomock.Controller) *MockConnection { + mock := &MockConnection{ctrl: ctrl} + mock.recorder = &MockConnectionMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockClient) EXPECT() *MockClientMockRecorder { +func (m *MockConnection) EXPECT() *MockConnectionMockRecorder { return m.recorder } // Begin mocks base method. -func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { +func (m *MockConnection) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret := m.ctrl.Call(m, "Begin", ctx, opts) ret0, _ := ret[0].(database.Transaction) ret1, _ := ret[1].(error) return ret0, ret1 } // Begin indicates an expected call of Begin. -func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall { +func (mr *MockConnectionMockRecorder) Begin(ctx, opts any) *MockConnectionBeginCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1) - return &MockClientBeginCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockConnection)(nil).Begin), ctx, opts) + return &MockConnectionBeginCall{Call: call} } -// MockClientBeginCall wrap *gomock.Call -type MockClientBeginCall struct { +// MockConnectionBeginCall wrap *gomock.Call +type MockConnectionBeginCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall { +func (c *MockConnectionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockConnectionBeginCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { +func (c *MockConnectionBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { +func (c *MockConnectionBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall { c.Call = c.Call.DoAndReturn(f) return c } // Exec mocks base method. -func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { +func (m *MockConnection) Exec(ctx context.Context, stmt string, args ...any) (int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Exec", varargs...) @@ -401,79 +441,117 @@ func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64 } // Exec indicates an expected call of Exec. -func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall { +func (mr *MockConnectionMockRecorder) Exec(ctx, stmt any, args ...any) *MockConnectionExecCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...) - return &MockClientExecCall{Call: call} + varargs := append([]any{ctx, stmt}, args...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...) + return &MockConnectionExecCall{Call: call} } -// MockClientExecCall wrap *gomock.Call -type MockClientExecCall struct { +// MockConnectionExecCall wrap *gomock.Call +type MockConnectionExecCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall { +func (c *MockConnectionExecCall) Return(arg0 int64, arg1 error) *MockConnectionExecCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall { +func (c *MockConnectionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall { +func (c *MockConnectionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall { c.Call = c.Call.DoAndReturn(f) return c } // Migrate mocks base method. -func (m *MockClient) Migrate(arg0 context.Context) error { +func (m *MockConnection) Migrate(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Migrate", arg0) + ret := m.ctrl.Call(m, "Migrate", ctx) ret0, _ := ret[0].(error) return ret0 } // Migrate indicates an expected call of Migrate. -func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall { +func (mr *MockConnectionMockRecorder) Migrate(ctx any) *MockConnectionMigrateCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0) - return &MockClientMigrateCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockConnection)(nil).Migrate), ctx) + return &MockConnectionMigrateCall{Call: call} } -// MockClientMigrateCall wrap *gomock.Call -type MockClientMigrateCall struct { +// MockConnectionMigrateCall wrap *gomock.Call +type MockConnectionMigrateCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall { +func (c *MockConnectionMigrateCall) Return(arg0 error) *MockConnectionMigrateCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall { +func (c *MockConnectionMigrateCall) Do(f func(context.Context) error) *MockConnectionMigrateCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall { +func (c *MockConnectionMigrateCall) DoAndReturn(f func(context.Context) error) *MockConnectionMigrateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Ping mocks base method. +func (m *MockConnection) Ping(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Ping indicates an expected call of Ping. +func (mr *MockConnectionMockRecorder) Ping(ctx any) *MockConnectionPingCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockConnection)(nil).Ping), ctx) + return &MockConnectionPingCall{Call: call} +} + +// MockConnectionPingCall wrap *gomock.Call +type MockConnectionPingCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockConnectionPingCall) Return(arg0 error) *MockConnectionPingCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockConnectionPingCall) Do(f func(context.Context) error) *MockConnectionPingCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockConnectionPingCall) DoAndReturn(f func(context.Context) error) *MockConnectionPingCall { c.Call = c.Call.DoAndReturn(f) return c } // Query mocks base method. -func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { +func (m *MockConnection) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Query", varargs...) @@ -483,41 +561,41 @@ func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (data } // Query indicates an expected call of Query. -func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall { +func (mr *MockConnectionMockRecorder) Query(ctx, stmt any, args ...any) *MockConnectionQueryCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...) - return &MockClientQueryCall{Call: call} + varargs := append([]any{ctx, stmt}, args...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockConnection)(nil).Query), varargs...) + return &MockConnectionQueryCall{Call: call} } -// MockClientQueryCall wrap *gomock.Call -type MockClientQueryCall struct { +// MockConnectionQueryCall wrap *gomock.Call +type MockConnectionQueryCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall { +func (c *MockConnectionQueryCall) Return(arg0 database.Rows, arg1 error) *MockConnectionQueryCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { +func (c *MockConnectionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { +func (c *MockConnectionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall { c.Call = c.Call.DoAndReturn(f) return c } // QueryRow mocks base method. -func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { +func (m *MockConnection) QueryRow(ctx context.Context, stmt string, args ...any) database.Row { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "QueryRow", varargs...) @@ -526,70 +604,70 @@ func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) da } // QueryRow indicates an expected call of QueryRow. -func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall { +func (mr *MockConnectionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockConnectionQueryRowCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...) - return &MockClientQueryRowCall{Call: call} + varargs := append([]any{ctx, stmt}, args...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockConnection)(nil).QueryRow), varargs...) + return &MockConnectionQueryRowCall{Call: call} } -// MockClientQueryRowCall wrap *gomock.Call -type MockClientQueryRowCall struct { +// MockConnectionQueryRowCall wrap *gomock.Call +type MockConnectionQueryRowCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall { +func (c *MockConnectionQueryRowCall) Return(arg0 database.Row) *MockConnectionQueryRowCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { +func (c *MockConnectionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { +func (c *MockConnectionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall { c.Call = c.Call.DoAndReturn(f) return c } // Release mocks base method. -func (m *MockClient) Release(arg0 context.Context) error { +func (m *MockConnection) Release(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Release", arg0) + ret := m.ctrl.Call(m, "Release", ctx) ret0, _ := ret[0].(error) return ret0 } // Release indicates an expected call of Release. -func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall { +func (mr *MockConnectionMockRecorder) Release(ctx any) *MockConnectionReleaseCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0) - return &MockClientReleaseCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockConnection)(nil).Release), ctx) + return &MockConnectionReleaseCall{Call: call} } -// MockClientReleaseCall wrap *gomock.Call -type MockClientReleaseCall struct { +// MockConnectionReleaseCall wrap *gomock.Call +type MockConnectionReleaseCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall { +func (c *MockConnectionReleaseCall) Return(arg0 error) *MockConnectionReleaseCall { c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall { +func (c *MockConnectionReleaseCall) Do(f func(context.Context) error) *MockConnectionReleaseCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall { +func (c *MockConnectionReleaseCall) DoAndReturn(f func(context.Context) error) *MockConnectionReleaseCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -598,6 +676,7 @@ func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *Mock type MockRow struct { ctrl *gomock.Controller recorder *MockRowMockRecorder + isgomock struct{} } // MockRowMockRecorder is the mock recorder for MockRow. @@ -618,10 +697,10 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder { } // Scan mocks base method. -func (m *MockRow) Scan(arg0 ...any) error { +func (m *MockRow) Scan(dest ...any) error { m.ctrl.T.Helper() varargs := []any{} - for _, a := range arg0 { + for _, a := range dest { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Scan", varargs...) @@ -630,9 +709,9 @@ func (m *MockRow) Scan(arg0 ...any) error { } // Scan indicates an expected call of Scan. -func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall { +func (mr *MockRowMockRecorder) Scan(dest ...any) *MockRowScanCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), dest...) return &MockRowScanCall{Call: call} } @@ -663,6 +742,7 @@ func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall { type MockRows struct { ctrl *gomock.Controller recorder *MockRowsMockRecorder + isgomock struct{} } // MockRowsMockRecorder is the mock recorder for MockRows. @@ -797,10 +877,10 @@ func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall { } // Scan mocks base method. -func (m *MockRows) Scan(arg0 ...any) error { +func (m *MockRows) Scan(dest ...any) error { m.ctrl.T.Helper() varargs := []any{} - for _, a := range arg0 { + for _, a := range dest { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Scan", varargs...) @@ -809,9 +889,9 @@ func (m *MockRows) Scan(arg0 ...any) error { } // Scan indicates an expected call of Scan. -func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall { +func (mr *MockRowsMockRecorder) Scan(dest ...any) *MockRowsScanCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), dest...) return &MockRowsScanCall{Call: call} } @@ -842,6 +922,7 @@ func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall { type MockTransaction struct { ctrl *gomock.Controller recorder *MockTransactionMockRecorder + isgomock struct{} } // MockTransactionMockRecorder is the mock recorder for MockTransaction. @@ -862,18 +943,18 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { } // Begin mocks base method. -func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) { +func (m *MockTransaction) Begin(ctx context.Context) (database.Transaction, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Begin", arg0) + ret := m.ctrl.Call(m, "Begin", ctx) ret0, _ := ret[0].(database.Transaction) ret1, _ := ret[1].(error) return ret0, ret1 } // Begin indicates an expected call of Begin. -func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall { +func (mr *MockTransactionMockRecorder) Begin(ctx any) *MockTransactionBeginCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), ctx) return &MockTransactionBeginCall{Call: call} } @@ -901,17 +982,17 @@ func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database } // Commit mocks base method. -func (m *MockTransaction) Commit(arg0 context.Context) error { +func (m *MockTransaction) Commit(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit", arg0) + ret := m.ctrl.Call(m, "Commit", ctx) ret0, _ := ret[0].(error) return ret0 } // Commit indicates an expected call of Commit. -func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall { +func (mr *MockTransactionMockRecorder) Commit(ctx any) *MockTransactionCommitCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), ctx) return &MockTransactionCommitCall{Call: call} } @@ -939,17 +1020,17 @@ func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) * } // End mocks base method. -func (m *MockTransaction) End(arg0 context.Context, arg1 error) error { +func (m *MockTransaction) End(ctx context.Context, err error) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "End", arg0, arg1) + ret := m.ctrl.Call(m, "End", ctx, err) ret0, _ := ret[0].(error) return ret0 } // End indicates an expected call of End. -func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall { +func (mr *MockTransactionMockRecorder) End(ctx, err any) *MockTransactionEndCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), ctx, err) return &MockTransactionEndCall{Call: call} } @@ -977,10 +1058,10 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro } // Exec mocks base method. -func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { +func (m *MockTransaction) Exec(ctx context.Context, stmt string, args ...any) (int64, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Exec", varargs...) @@ -990,9 +1071,9 @@ func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) ( } // Exec indicates an expected call of Exec. -func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall { +func (mr *MockTransactionMockRecorder) Exec(ctx, stmt any, args ...any) *MockTransactionExecCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...) return &MockTransactionExecCall{Call: call} } @@ -1021,10 +1102,10 @@ func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, .. } // Query mocks base method. -func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { +func (m *MockTransaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "Query", varargs...) @@ -1034,9 +1115,9 @@ func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) } // Query indicates an expected call of Query. -func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall { +func (mr *MockTransactionMockRecorder) Query(ctx, stmt any, args ...any) *MockTransactionQueryCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...) return &MockTransactionQueryCall{Call: call} } @@ -1065,10 +1146,10 @@ func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, . } // QueryRow mocks base method. -func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { +func (m *MockTransaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row { m.ctrl.T.Helper() - varargs := []any{arg0, arg1} - for _, a := range arg2 { + varargs := []any{ctx, stmt} + for _, a := range args { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "QueryRow", varargs...) @@ -1077,9 +1158,9 @@ func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...an } // QueryRow indicates an expected call of QueryRow. -func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall { +func (mr *MockTransactionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockTransactionQueryRowCall { mr.mock.ctrl.T.Helper() - varargs := append([]any{arg0, arg1}, arg2...) + varargs := append([]any{ctx, stmt}, args...) call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...) return &MockTransactionQueryRowCall{Call: call} } @@ -1108,17 +1189,17 @@ func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string } // Rollback mocks base method. -func (m *MockTransaction) Rollback(arg0 context.Context) error { +func (m *MockTransaction) Rollback(ctx context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback", arg0) + ret := m.ctrl.Call(m, "Rollback", ctx) ret0, _ := ret[0].(error) return ret0 } // Rollback indicates an expected call of Rollback. -func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall { +func (mr *MockTransactionMockRecorder) Rollback(ctx any) *MockTransactionRollbackCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), ctx) return &MockTransactionRollbackCall{Call: call} } diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go index a8639692689..8d780feb417 100644 --- a/backend/v3/storage/database/dialect/postgres/conn.go +++ b/backend/v3/storage/database/dialect/postgres/conn.go @@ -13,15 +13,15 @@ type pgxConn struct { *pgxpool.Conn } -var _ database.Client = (*pgxConn)(nil) +var _ database.Connection = (*pgxConn)(nil) -// Release implements [database.Client]. +// Release implements [database.Connection]. func (c *pgxConn) Release(_ context.Context) error { c.Conn.Release() return nil } -// Begin implements [database.Client]. +// Begin implements [database.Connection]. func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts)) if err != nil { @@ -30,8 +30,8 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) return &Transaction{tx}, nil } -// Query implements sql.Client. -// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. +// Query implements [database.Connection]. +// Subtle: this method shadows the method (*Conn).Query of [pgxConn.Conn]. func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { rows, err := c.Conn.Query(ctx, sql, args...) if err != nil { @@ -40,14 +40,14 @@ func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database. return &Rows{rows}, nil } -// QueryRow implements sql.Client. -// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. +// QueryRow implements [database.Connection]. +// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn]. func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { return &Row{c.Conn.QueryRow(ctx, sql, args...)} } -// Exec implements [database.Pool]. -// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +// QueryRow implements [database.Connection]. +// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn]. func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) { res, err := c.Conn.Exec(ctx, sql, args...) if err != nil { diff --git a/backend/v3/storage/database/dialect/postgres/error.go b/backend/v3/storage/database/dialect/postgres/error.go index 89b3f8837a9..92fe2147af4 100644 --- a/backend/v3/storage/database/dialect/postgres/error.go +++ b/backend/v3/storage/database/dialect/postgres/error.go @@ -2,6 +2,7 @@ package postgres import ( "errors" + "strings" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -16,6 +17,18 @@ func wrapError(err error) error { if errors.Is(err, pgx.ErrNoRows) { return database.NewNoRowFoundError(err) } + if errors.Is(err, pgx.ErrTooManyRows) { + return database.NewMultipleRowsFoundError(err) + } + + // scany only exports its errors as strings + if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") { + return database.NewMultipleRowsFoundError(err) + } + if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") { + return database.NewScanError(err) + } + var pgxErr *pgconn.PgError if !errors.As(err, &pgxErr) { return database.NewUnknownError(err) diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go new file mode 100644 index 00000000000..01dea1e38a8 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 004_correct_set_updated_at/up.sql + up004CorrectSetUpdatedAt string + //go:embed 004_correct_set_updated_at/down.sql + down004CorrectSetUpdatedAt string +) + +func init() { + registerSQLMigration(4, up004CorrectSetUpdatedAt, down004CorrectSetUpdatedAt) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql new file mode 100644 index 00000000000..5e905f9fa98 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql @@ -0,0 +1,23 @@ +CREATE OR REPLACE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.instances +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.organizations +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains + BEFORE UPDATE ON zitadel.instance_domains + FOR EACH ROW + WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) + EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains + BEFORE UPDATE ON zitadel.org_domains + FOR EACH ROW + WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) + EXECUTE FUNCTION zitadel.set_updated_at(); \ No newline at end of file diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql new file mode 100644 index 00000000000..6292197b6a9 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql @@ -0,0 +1,23 @@ +CREATE OR REPLACE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.instances +FOR EACH ROW +WHEN (NEW.updated_at IS NULL) +EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.organizations +FOR EACH ROW +WHEN (NEW.updated_at IS NULL) +EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains + BEFORE UPDATE ON zitadel.instance_domains + FOR EACH ROW + WHEN (NEW.updated_at IS NULL) + EXECUTE FUNCTION zitadel.set_updated_at(); + +CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains + BEFORE UPDATE ON zitadel.org_domains + FOR EACH ROW + WHEN (NEW.updated_at IS NULL) + EXECUTE FUNCTION zitadel.set_updated_at(); \ No newline at end of file diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go index 07cba119b5e..275be54e11e 100644 --- a/backend/v3/storage/database/dialect/postgres/pool.go +++ b/backend/v3/storage/database/dialect/postgres/pool.go @@ -22,8 +22,8 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool { } // Acquire implements [database.Pool]. -func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { - conn, err := c.Pool.Acquire(ctx) +func (p *pgxPool) Acquire(ctx context.Context) (database.Connection, error) { + conn, err := p.Pool.Acquire(ctx) if err != nil { return nil, wrapError(err) } @@ -32,8 +32,8 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { // Query implements [database.Pool]. // Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. -func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { - rows, err := c.Pool.Query(ctx, sql, args...) +func (p *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := p.Pool.Query(ctx, sql, args...) if err != nil { return nil, wrapError(err) } @@ -42,14 +42,14 @@ func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database. // QueryRow implements [database.Pool]. // Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. -func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return &Row{c.Pool.QueryRow(ctx, sql, args...)} +func (p *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{p.Pool.QueryRow(ctx, sql, args...)} } // Exec implements [database.Pool]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. -func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { - res, err := c.Pool.Exec(ctx, sql, args...) +func (p *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := p.Pool.Exec(ctx, sql, args...) if err != nil { return 0, wrapError(err) } @@ -57,8 +57,8 @@ func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, err } // Begin implements [database.Pool]. -func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { - tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts)) +func (p *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := p.BeginTx(ctx, transactionOptionsToPgx(opts)) if err != nil { return nil, wrapError(err) } @@ -66,23 +66,23 @@ func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) } // Close implements [database.Pool]. -func (c *pgxPool) Close(_ context.Context) error { - c.Pool.Close() +func (p *pgxPool) Close(_ context.Context) error { + p.Pool.Close() return nil } // Ping implements [database.Pool]. -func (c *pgxPool) Ping(ctx context.Context) error { - return wrapError(c.Pool.Ping(ctx)) +func (p *pgxPool) Ping(ctx context.Context) error { + return wrapError(p.Pool.Ping(ctx)) } // Migrate implements [database.Migrator]. -func (c *pgxPool) Migrate(ctx context.Context) error { +func (p *pgxPool) Migrate(ctx context.Context) error { if isMigrated { return nil } - client, err := c.Pool.Acquire(ctx) + client, err := p.Pool.Acquire(ctx) if err != nil { return err } @@ -93,8 +93,8 @@ func (c *pgxPool) Migrate(ctx context.Context) error { } // Migrate implements [database.PoolTest]. -func (c *pgxPool) MigrateTest(ctx context.Context) error { - client, err := c.Pool.Acquire(ctx) +func (p *pgxPool) MigrateTest(ctx context.Context) error { + client, err := p.Pool.Acquire(ctx) if err != nil { return err } diff --git a/backend/v3/storage/database/dialect/sql/conn.go b/backend/v3/storage/database/dialect/sql/conn.go index 2a4e4c91b8b..9d9f5472df0 100644 --- a/backend/v3/storage/database/dialect/sql/conn.go +++ b/backend/v3/storage/database/dialect/sql/conn.go @@ -11,18 +11,18 @@ type sqlConn struct { *sql.Conn } -func SQLConn(conn *sql.Conn) database.Client { +func SQLConn(conn *sql.Conn) database.Connection { return &sqlConn{Conn: conn} } -var _ database.Client = (*sqlConn)(nil) +var _ database.Connection = (*sqlConn)(nil) -// Release implements [database.Client]. +// Release implements [database.Connection]. func (c *sqlConn) Release(_ context.Context) error { return c.Close() } -// Begin implements [database.Client]. +// Begin implements [database.Connection]. func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts)) if err != nil { diff --git a/backend/v3/storage/database/dialect/sql/error.go b/backend/v3/storage/database/dialect/sql/error.go index f01b1e70e13..4879180a559 100644 --- a/backend/v3/storage/database/dialect/sql/error.go +++ b/backend/v3/storage/database/dialect/sql/error.go @@ -3,6 +3,7 @@ package sql import ( "database/sql" "errors" + "strings" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" @@ -19,6 +20,15 @@ func wrapError(err error) error { if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) { return database.NewNoRowFoundError(err) } + + // scany only exports its errors as strings + if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") { + return database.NewMultipleRowsFoundError(err) + } + if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") { + return database.NewScanError(err) + } + var pgxErr *pgconn.PgError if !errors.As(err, &pgxErr) { return database.NewUnknownError(err) diff --git a/backend/v3/storage/database/dialect/sql/pool.go b/backend/v3/storage/database/dialect/sql/pool.go index 2d37520be78..521d1d1e12c 100644 --- a/backend/v3/storage/database/dialect/sql/pool.go +++ b/backend/v3/storage/database/dialect/sql/pool.go @@ -20,8 +20,8 @@ func SQLPool(db *sql.DB) *sqlPool { } // Acquire implements [database.Pool]. -func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) { - conn, err := c.Conn(ctx) +func (p *sqlPool) Acquire(ctx context.Context) (database.Connection, error) { + conn, err := p.Conn(ctx) if err != nil { return nil, wrapError(err) } @@ -30,9 +30,9 @@ func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) { // Query implements [database.Pool]. // Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. -func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { +func (p *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { //nolint:rowserrcheck // Rows.Close is called by the caller - rows, err := c.QueryContext(ctx, sql, args...) + rows, err := p.QueryContext(ctx, sql, args...) if err != nil { return nil, wrapError(err) } @@ -41,14 +41,14 @@ func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database. // QueryRow implements [database.Pool]. // Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. -func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return &Row{c.QueryRowContext(ctx, sql, args...)} +func (p *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{p.QueryRowContext(ctx, sql, args...)} } // Exec implements [database.Pool]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. -func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { - res, err := c.ExecContext(ctx, sql, args...) +func (p *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := p.ExecContext(ctx, sql, args...) if err != nil { return 0, wrapError(err) } @@ -56,8 +56,8 @@ func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, err } // Begin implements [database.Pool]. -func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { - tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts)) +func (p *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := p.BeginTx(ctx, transactionOptionsToSQL(opts)) if err != nil { return nil, wrapError(err) } @@ -65,16 +65,16 @@ func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) } // Ping implements [database.Pool]. -func (c *sqlPool) Ping(ctx context.Context) error { - return wrapError(c.PingContext(ctx)) +func (p *sqlPool) Ping(ctx context.Context) error { + return wrapError(p.PingContext(ctx)) } // Close implements [database.Pool]. -func (c *sqlPool) Close(_ context.Context) error { - return c.DB.Close() +func (p *sqlPool) Close(_ context.Context) error { + return p.DB.Close() } // Migrate implements [database.Migrator]. -func (c *sqlPool) Migrate(ctx context.Context) error { +func (p *sqlPool) Migrate(ctx context.Context) error { return ErrMigrate } diff --git a/backend/v3/storage/database/errors.go b/backend/v3/storage/database/errors.go index 418019fd304..de6cb794c3f 100644 --- a/backend/v3/storage/database/errors.go +++ b/backend/v3/storage/database/errors.go @@ -5,7 +5,34 @@ import ( "fmt" ) -var ErrNoChanges = errors.New("update must contain a change") +var ( + ErrNoChanges = errors.New("update must contain a change") +) + +type MissingConditionError struct { + col Column +} + +func NewMissingConditionError(col Column) error { + return &MissingConditionError{ + col: col, + } +} + +func (e *MissingConditionError) Error() string { + var builder StatementBuilder + builder.WriteString("missing condition for column") + if e.col != nil { + builder.WriteString(" on ") + e.col.WriteQualified(&builder) + } + return builder.String() +} + +func (e *MissingConditionError) Is(target error) bool { + _, ok := target.(*MissingConditionError) + return ok +} // NoRowFoundError is returned when QueryRow does not find any row. // It wraps the dialect specific original error to provide more context. @@ -20,6 +47,9 @@ func NewNoRowFoundError(original error) error { } func (e *NoRowFoundError) Error() string { + if e.original != nil { + return fmt.Sprintf("no row found: %v", e.original) + } return "no row found" } @@ -36,18 +66,19 @@ func (e *NoRowFoundError) Unwrap() error { // It wraps the dialect specific original error to provide more context. type MultipleRowsFoundError struct { original error - count int } -func NewMultipleRowsFoundError(original error, count int) error { +func NewMultipleRowsFoundError(original error) error { return &MultipleRowsFoundError{ original: original, - count: count, } } func (e *MultipleRowsFoundError) Error() string { - return fmt.Sprintf("multiple rows found: %d", e.count) + if e.original != nil { + return fmt.Sprintf("multiple rows found: %v", e.original) + } + return "multiple rows found" } func (e *MultipleRowsFoundError) Is(target error) bool { @@ -77,8 +108,8 @@ type IntegrityViolationError struct { original error } -func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error { - return &IntegrityViolationError{ +func newIntegrityViolationError(typ IntegrityType, table, constraint string, original error) IntegrityViolationError { + return IntegrityViolationError{ integrityType: typ, table: table, constraint: constraint, @@ -87,7 +118,10 @@ func NewIntegrityViolationError(typ IntegrityType, table, constraint string, ori } func (e *IntegrityViolationError) Error() string { - return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original) + if e.original != nil { + return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original) + } + return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q)", e.integrityType, e.table, e.constraint) } func (e *IntegrityViolationError) Is(target error) bool { @@ -108,12 +142,7 @@ type CheckError struct { func NewCheckError(table, constraint string, original error) error { return &CheckError{ - IntegrityViolationError: IntegrityViolationError{ - integrityType: IntegrityTypeCheck, - table: table, - constraint: constraint, - original: original, - }, + IntegrityViolationError: newIntegrityViolationError(IntegrityTypeCheck, table, constraint, original), } } @@ -135,12 +164,7 @@ type UniqueError struct { func NewUniqueError(table, constraint string, original error) error { return &UniqueError{ - IntegrityViolationError: IntegrityViolationError{ - integrityType: IntegrityTypeUnique, - table: table, - constraint: constraint, - original: original, - }, + IntegrityViolationError: newIntegrityViolationError(IntegrityTypeUnique, table, constraint, original), } } @@ -162,12 +186,7 @@ type ForeignKeyError struct { func NewForeignKeyError(table, constraint string, original error) error { return &ForeignKeyError{ - IntegrityViolationError: IntegrityViolationError{ - integrityType: IntegrityTypeForeign, - table: table, - constraint: constraint, - original: original, - }, + IntegrityViolationError: newIntegrityViolationError(IntegrityTypeForeign, table, constraint, original), } } @@ -189,12 +208,7 @@ type NotNullError struct { func NewNotNullError(table, constraint string, original error) error { return &NotNullError{ - IntegrityViolationError: IntegrityViolationError{ - integrityType: IntegrityTypeNotNull, - table: table, - constraint: constraint, - original: original, - }, + IntegrityViolationError: newIntegrityViolationError(IntegrityTypeNotNull, table, constraint, original), } } @@ -207,6 +221,31 @@ func (e *NotNullError) Unwrap() error { return &e.IntegrityViolationError } +// ScanError is returned when scanning rows into objects failed. +// It wraps the original error to provide more context. +type ScanError struct { + original error +} + +func NewScanError(original error) error { + return &ScanError{ + original: original, + } +} + +func (e *ScanError) Error() string { + return fmt.Sprintf("Scan error: %v", e.original) +} + +func (e *ScanError) Is(target error) bool { + _, ok := target.(*ScanError) + return ok +} + +func (e *ScanError) Unwrap() error { + return e.original +} + // UnknownError is returned when an unknown error occurs. // It wraps the dialect specific original error to provide more context. // It is used to indicate that an error occurred that does not fit into any of the other categories. diff --git a/backend/v3/storage/database/errors_test.go b/backend/v3/storage/database/errors_test.go new file mode 100644 index 00000000000..1d8e928c0f9 --- /dev/null +++ b/backend/v3/storage/database/errors_test.go @@ -0,0 +1,304 @@ +package database + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestError(t *testing.T) { + for _, test := range []struct { + name string + err error + want string + }{ + { + name: "missing condition without column", + err: NewMissingConditionError(nil), + want: "missing condition for column", + }, + { + name: "missing condition with column", + err: NewMissingConditionError(NewColumn("table", "column")), + want: "missing condition for column on table.column", + }, + { + name: "no row found without original error", + err: NewNoRowFoundError(nil), + want: "no row found", + }, + { + name: "no row found with original error", + err: NewNoRowFoundError(errors.New("original error")), + want: "no row found: original error", + }, + { + name: "multiple rows found without original error", + err: NewMultipleRowsFoundError(nil), + want: "multiple rows found", + }, + { + name: "multiple rows found with original error", + err: NewMultipleRowsFoundError(errors.New("original error")), + want: "multiple rows found: original error", + }, + { + name: "check violation without original error", + err: NewCheckError("table", "constraint", nil), + want: `integrity violation of type "check" on "table" (constraint: "constraint")`, + }, + { + name: "check violation with original error", + err: NewCheckError("table", "constraint", errors.New("original error")), + want: `integrity violation of type "check" on "table" (constraint: "constraint"): original error`, + }, + { + name: "unique violation without original error", + err: NewUniqueError("table", "constraint", nil), + want: `integrity violation of type "unique" on "table" (constraint: "constraint")`, + }, + { + name: "unique violation with original error", + err: NewUniqueError("table", "constraint", errors.New("original error")), + want: `integrity violation of type "unique" on "table" (constraint: "constraint"): original error`, + }, + { + name: "foreign key violation without original error", + err: NewForeignKeyError("table", "constraint", nil), + want: `integrity violation of type "foreign" on "table" (constraint: "constraint")`, + }, + { + name: "foreign key violation with original error", + err: NewForeignKeyError("table", "constraint", errors.New("original error")), + want: `integrity violation of type "foreign" on "table" (constraint: "constraint"): original error`, + }, + { + name: "not null violation without original error", + err: NewNotNullError("table", "constraint", nil), + want: `integrity violation of type "not null" on "table" (constraint: "constraint")`, + }, + { + name: "not null violation with original error", + err: NewNotNullError("table", "constraint", errors.New("original error")), + want: `integrity violation of type "not null" on "table" (constraint: "constraint"): original error`, + }, + { + name: "unknown error", + err: NewUnknownError(errors.New("original error")), + want: `unknown database error: original error`, + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.want, test.err.Error()) + }) + } +} + +func TestUnwrap(t *testing.T) { + originalErr := errors.New("original error") + for _, test := range []struct { + name string + err error + want error + }{ + { + name: "missing condition without column", + err: NewMissingConditionError(nil), + want: nil, + }, + { + name: "missing condition with column", + err: NewMissingConditionError(NewColumn("table", "column")), + want: nil, + }, + { + name: "no row found without original error", + err: NewNoRowFoundError(nil), + want: nil, + }, + { + name: "no row found with original error", + err: NewNoRowFoundError(errors.New("original error")), + want: originalErr, + }, + { + name: "multiple rows found without original error", + err: NewMultipleRowsFoundError(nil), + want: nil, + }, + { + name: "multiple rows found with original error", + err: NewMultipleRowsFoundError(originalErr), + want: originalErr, + }, + { + name: "check violation without original error", + err: NewCheckError("table", "constraint", nil), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeCheck, + table: "table", + constraint: "constraint", + original: nil, + }, + }, + { + name: "check violation with original error", + err: NewCheckError("table", "constraint", originalErr), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeCheck, + table: "table", + constraint: "constraint", + original: originalErr, + }, + }, + { + name: "unique violation without original error", + err: NewUniqueError("table", "constraint", nil), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeUnique, + table: "table", + constraint: "constraint", + original: nil, + }, + }, + { + name: "unique violation with original error", + err: NewUniqueError("table", "constraint", originalErr), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeUnique, + table: "table", + constraint: "constraint", + original: originalErr, + }, + }, + { + name: "foreign key violation without original error", + err: NewForeignKeyError("table", "constraint", nil), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeForeign, + table: "table", + constraint: "constraint", + original: nil, + }, + }, + { + name: "foreign key violation with original error", + err: NewForeignKeyError("table", "constraint", originalErr), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeForeign, + table: "table", + constraint: "constraint", + original: originalErr, + }, + }, + { + name: "not null violation without original error", + err: NewNotNullError("table", "constraint", nil), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeNotNull, + table: "table", + constraint: "constraint", + original: nil, + }, + }, + { + name: "not null violation with original error", + err: NewNotNullError("table", "constraint", originalErr), + want: &IntegrityViolationError{ + integrityType: IntegrityTypeNotNull, + table: "table", + constraint: "constraint", + original: originalErr, + }, + }, + { + name: "unwrap integrity violation error", + err: errors.Unwrap(NewNotNullError("table", "constraint", originalErr)), + want: originalErr, + }, + { + name: "unknown error", + err: NewUnknownError(originalErr), + want: originalErr, + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.want, errors.Unwrap(test.err)) + }) + } +} + +func TestIs(t *testing.T) { + originalErr := errors.New("original error") + for _, test := range []struct { + name string + err error + want error + }{ + { + name: "missing condition", + err: NewMissingConditionError(NewColumn("table", "column")), + want: new(MissingConditionError), + }, + { + name: "no row found", + err: NewNoRowFoundError(errors.New("original error")), + want: new(NoRowFoundError), + }, + { + name: "multiple rows found", + err: NewMultipleRowsFoundError(originalErr), + want: new(MultipleRowsFoundError), + }, + { + name: "check violation is for integrity", + err: NewCheckError("table", "constraint", nil), + want: new(IntegrityViolationError), + }, + { + name: "check violation is check violation", + err: NewCheckError("table", "constraint", nil), + want: new(CheckError), + }, + { + name: "unique violation is for integrity", + err: NewUniqueError("table", "constraint", nil), + want: new(IntegrityViolationError), + }, + { + name: "unique violation is unique violation", + err: NewUniqueError("table", "constraint", nil), + want: new(UniqueError), + }, + { + name: "foreign key violation is for integrity", + err: NewForeignKeyError("table", "constraint", nil), + want: new(IntegrityViolationError), + }, + { + name: "foreign key violation is foreign key violation", + err: NewForeignKeyError("table", "constraint", nil), + want: new(ForeignKeyError), + }, + { + name: "not null violation is for integrity", + err: NewNotNullError("table", "constraint", nil), + want: new(IntegrityViolationError), + }, + { + name: "not null violation is not null violation", + err: NewNotNullError("table", "constraint", nil), + want: new(NotNullError), + }, + { + name: "unknown error", + err: NewUnknownError(originalErr), + want: new(UnknownError), + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.ErrorIs(t, test.err, test.want) + }) + } +} diff --git a/backend/v3/storage/database/events_testing/instance_domain_test.go b/backend/v3/storage/database/events_testing/instance_domain_test.go index c42ee1a0ff0..8dc6d141765 100644 --- a/backend/v3/storage/database/events_testing/instance_domain_test.go +++ b/backend/v3/storage/database/events_testing/instance_domain_test.go @@ -21,8 +21,8 @@ import ( func TestServer_TestInstanceDomainReduces(t *testing.T) { instance := integration.NewInstance(CTX) - instanceRepo := repository.InstanceRepository(pool) - instanceDomainRepo := instanceRepo.Domains(true) + instanceRepo := repository.InstanceRepository() + instanceDomainRepo := repository.InstanceDomainRepository() t.Cleanup(func() { _, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{ @@ -36,7 +36,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Wait for instance to be created retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, err := instanceRepo.Get(CTX, + _, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(instance.ID())), ) assert.NoError(t, err) @@ -66,7 +66,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Test that domain add reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -96,7 +96,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { t.Cleanup(func() { // first we change the primary domain to something else - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -125,7 +125,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Wait for domain to be created retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -151,7 +151,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Test that set primary reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -180,7 +180,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Wait for domain to be created and verify it exists retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, err := instanceDomainRepo.Get(CTX, + _, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -202,7 +202,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Test that domain remove reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -241,7 +241,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Test that domain add reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -271,7 +271,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Wait for domain to be created and verify it exists retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, err := instanceDomainRepo.Get(CTX, + _, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), @@ -293,7 +293,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) { // Test that domain remove reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := instanceDomainRepo.Get(CTX, + domain, err := instanceDomainRepo.Get(CTX, pool, database.WithCondition( database.And( instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), diff --git a/backend/v3/storage/database/events_testing/instance_test.go b/backend/v3/storage/database/events_testing/instance_test.go index 977fe7c7a9c..46ffe310e53 100644 --- a/backend/v3/storage/database/events_testing/instance_test.go +++ b/backend/v3/storage/database/events_testing/instance_test.go @@ -17,7 +17,7 @@ import ( ) func TestServer_TestInstanceReduces(t *testing.T) { - instanceRepo := repository.InstanceRepository(pool) + instanceRepo := repository.InstanceRepository() t.Run("test instance add reduces", func(t *testing.T) { instanceName := gofakeit.Name() @@ -46,7 +46,7 @@ func TestServer_TestInstanceReduces(t *testing.T) { retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, + instance, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())), ) require.NoError(t, err) @@ -92,7 +92,7 @@ func TestServer_TestInstanceReduces(t *testing.T) { // check instance exists retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, + instance, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), ) require.NoError(t, err) @@ -110,7 +110,7 @@ func TestServer_TestInstanceReduces(t *testing.T) { retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, + instance, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), ) require.NoError(t, err) @@ -137,7 +137,7 @@ func TestServer_TestInstanceReduces(t *testing.T) { // check instance exists retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, + instance, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), ) require.NoError(t, err) @@ -151,7 +151,7 @@ func TestServer_TestInstanceReduces(t *testing.T) { retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - instance, err := instanceRepo.Get(CTX, + instance, err := instanceRepo.Get(CTX, pool, database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), ) // event instance.removed diff --git a/backend/v3/storage/database/events_testing/org_domain_test.go b/backend/v3/storage/database/events_testing/org_domain_test.go index af554a22817..4e35fca93ea 100644 --- a/backend/v3/storage/database/events_testing/org_domain_test.go +++ b/backend/v3/storage/database/events_testing/org_domain_test.go @@ -22,8 +22,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) { }) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(pool) - orgDomainRepo := orgRepo.Domains(false) + orgRepo := repository.OrganizationRepository() + orgDomainRepo := repository.OrganizationDomainRepository() t.Cleanup(func() { _, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{ @@ -37,8 +37,13 @@ func TestServer_TestOrgDomainReduces(t *testing.T) { // Wait for org to be created retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, err := orgRepo.Get(CTX, - database.WithCondition(orgRepo.IDCondition(org.GetId())), + _, err := orgRepo.Get(CTX, pool, + database.WithCondition( + database.And( + orgRepo.InstanceIDCondition(Instance.Instance.Id), + orgRepo.IDCondition(org.GetId()), + ), + ), ) assert.NoError(t, err) }, retryDuration, tick) @@ -68,7 +73,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) { // Test that domain add reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - gottenDomain, err := orgDomainRepo.Get(CTX, + gottenDomain, err := orgDomainRepo.Get(CTX, pool, database.WithCondition( database.And( orgDomainRepo.InstanceIDCondition(Instance.Instance.Id), @@ -107,7 +112,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) { // Test that domain remove reduces retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - domain, err := orgDomainRepo.Get(CTX, + domain, err := orgDomainRepo.Get(CTX, pool, database.WithCondition( database.And( orgDomainRepo.InstanceIDCondition(Instance.Instance.Id), diff --git a/backend/v3/storage/database/events_testing/organization_test.go b/backend/v3/storage/database/events_testing/organization_test.go index 7c89cfbcd51..bcca1d9ad0f 100644 --- a/backend/v3/storage/database/events_testing/organization_test.go +++ b/backend/v3/storage/database/events_testing/organization_test.go @@ -19,7 +19,7 @@ import ( func TestServer_TestOrganizationReduces(t *testing.T) { instanceID := Instance.ID() - orgRepo := repository.OrganizationRepository(pool) + orgRepo := repository.OrganizationRepository() t.Run("test org add reduces", func(t *testing.T) { beforeCreate := time.Now() @@ -42,7 +42,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) { retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(tt *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(org.GetId()), @@ -92,7 +92,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) { retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), @@ -137,7 +137,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) { retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), @@ -177,11 +177,11 @@ func TestServer_TestOrganizationReduces(t *testing.T) { }) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(pool) + orgRepo := repository.OrganizationRepository() // 3. check org deactivated retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), @@ -203,7 +203,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) { retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), @@ -230,10 +230,10 @@ func TestServer_TestOrganizationReduces(t *testing.T) { require.NoError(t, err) // 2. check org retrievable - orgRepo := repository.OrganizationRepository(pool) + orgRepo := repository.OrganizationRepository() retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - _, err := orgRepo.Get(CTX, + _, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), @@ -252,7 +252,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) { retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) assert.EventuallyWithT(t, func(t *assert.CollectT) { - organization, err := orgRepo.Get(CTX, + organization, err := orgRepo.Get(CTX, pool, database.WithCondition( database.And( orgRepo.IDCondition(organization.Id), diff --git a/backend/v3/storage/database/gen_mock.go b/backend/v3/storage/database/gen_mock.go index 04d204cfa1b..40ffd3a7b29 100644 --- a/backend/v3/storage/database/gen_mock.go +++ b/backend/v3/storage/database/gen_mock.go @@ -1,3 +1,3 @@ package database -//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction +//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction diff --git a/backend/v3/storage/database/operators.go b/backend/v3/storage/database/operators.go index c8820d918db..93cf1d32cd6 100644 --- a/backend/v3/storage/database/operators.go +++ b/backend/v3/storage/database/operators.go @@ -6,16 +6,40 @@ import ( "golang.org/x/exp/constraints" ) -type Value interface { - Boolean | Number | Text | Instruction +type wrappedValue[V Value] struct { + value V + fn function } +func LowerValue[T Value](v T) wrappedValue[T] { + return wrappedValue[T]{value: v, fn: functionLower} +} + +func SHA256Value[T Value](v T) wrappedValue[T] { + return wrappedValue[T]{value: v, fn: functionSHA256} +} + +func (b wrappedValue[V]) WriteArg(builder *StatementBuilder) { + builder.Grow(len(b.fn) + 5) + builder.WriteString(string(b.fn)) + builder.WriteRune('(') + builder.WriteArg(b.value) + builder.WriteRune(')') +} + +var _ argWriter = (*wrappedValue[string])(nil) + +type Value interface { + Boolean | Number | Text | Instruction | Bytes +} + +//go:generate enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go type Operation interface { - BooleanOperation | NumberOperation | TextOperation + NumberOperation | TextOperation | BytesOperation } type Text interface { - ~string | ~[]byte + ~string | Bytes } // TextOperation are operations that can be performed on text values. @@ -23,60 +47,17 @@ type TextOperation uint8 const ( // TextOperationEqual compares two strings for equality. - TextOperationEqual TextOperation = iota + 1 - // TextOperationEqualIgnoreCase compares two strings for equality, ignoring case. - TextOperationEqualIgnoreCase + TextOperationEqual TextOperation = iota + 1 // = // TextOperationNotEqual compares two strings for inequality. - TextOperationNotEqual - // TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case. - TextOperationNotEqualIgnoreCase + TextOperationNotEqual // <> // TextOperationStartsWith checks if the first string starts with the second. - TextOperationStartsWith - // TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case. - TextOperationStartsWithIgnoreCase + TextOperationStartsWith // LIKE ) -var textOperations = map[TextOperation]string{ - TextOperationEqual: " = ", - TextOperationEqualIgnoreCase: " LIKE ", - TextOperationNotEqual: " <> ", - TextOperationNotEqualIgnoreCase: " NOT LIKE ", - TextOperationStartsWith: " LIKE ", - TextOperationStartsWithIgnoreCase: " LIKE ", -} - -func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) { - switch op { - case TextOperationEqual, TextOperationNotEqual: - col.WriteQualified(builder) - builder.WriteString(textOperations[op]) - builder.WriteArg(value) - case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase: - builder.WriteString("LOWER(") - col.WriteQualified(builder) - builder.WriteString(")") - - builder.WriteString(textOperations[op]) - builder.WriteString("LOWER(") - builder.WriteArg(value) - builder.WriteString(")") - case TextOperationStartsWith: - col.WriteQualified(builder) - builder.WriteString(textOperations[op]) - builder.WriteArg(value) +func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value any) { + writeOperation[T](builder, col, op.String(), value) + if op == TextOperationStartsWith { builder.WriteString(" || '%'") - case TextOperationStartsWithIgnoreCase: - builder.WriteString("LOWER(") - col.WriteQualified(builder) - builder.WriteString(")") - - builder.WriteString(textOperations[op]) - builder.WriteString("LOWER(") - builder.WriteArg(value) - builder.WriteString(")") - builder.WriteString(" || '%'") - default: - panic("unsupported text operation") } } @@ -89,48 +70,60 @@ type NumberOperation uint8 const ( // NumberOperationEqual compares two numbers for equality. - NumberOperationEqual NumberOperation = iota + 1 + NumberOperationEqual NumberOperation = iota + 1 // = // NumberOperationNotEqual compares two numbers for inequality. - NumberOperationNotEqual + NumberOperationNotEqual // <> // NumberOperationLessThan compares two numbers to check if the first is less than the second. - NumberOperationLessThan + NumberOperationLessThan // < // NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second. - NumberOperationAtLeast + NumberOperationAtLeast // <= // NumberOperationGreaterThan compares two numbers to check if the first is greater than the second. - NumberOperationGreaterThan + NumberOperationGreaterThan // > // NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second. - NumberOperationAtMost + NumberOperationAtMost // >= ) -var numberOperations = map[NumberOperation]string{ - NumberOperationEqual: " = ", - NumberOperationNotEqual: " <> ", - NumberOperationLessThan: " < ", - NumberOperationAtLeast: " <= ", - NumberOperationGreaterThan: " > ", - NumberOperationAtMost: " >= ", -} - -func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) { - col.WriteQualified(builder) - builder.WriteString(numberOperations[op]) - builder.WriteArg(value) +func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value any) { + writeOperation[T](builder, col, op.String(), value) } type Boolean interface { ~bool } -// BooleanOperation are operations that can be performed on boolean values. -type BooleanOperation uint8 +func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value any) { + writeOperation[T](builder, col, "=", value) +} + +type Bytes interface { + ~[]byte +} + +// BytesOperation are operations that can be performed on bytea values. +type BytesOperation uint8 const ( - BooleanOperationIsTrue BooleanOperation = iota + 1 - BooleanOperationIsFalse + BytesOperationEqual BytesOperation = iota + 1 // = + BytesOperationNotEqual // <> ) -func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) { +func writeBytesOperation[T Bytes](builder *StatementBuilder, col Column, op BytesOperation, value any) { + writeOperation[T](builder, col, op.String(), value) +} + +func writeOperation[V Value](builder *StatementBuilder, col Column, op string, value any) { + if op == "" { + panic("unsupported operation") + } + + switch value.(type) { + case V, wrappedValue[V], *wrappedValue[V]: + default: + panic("unsupported value type") + } col.WriteQualified(builder) - builder.WriteString(" = ") + builder.WriteRune(' ') + builder.WriteString(op) + builder.WriteRune(' ') builder.WriteArg(value) } diff --git a/backend/v3/storage/database/operators_enumer.go b/backend/v3/storage/database/operators_enumer.go new file mode 100644 index 00000000000..8c73ad37aaf --- /dev/null +++ b/backend/v3/storage/database/operators_enumer.go @@ -0,0 +1,241 @@ +// Code generated by "enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go"; DO NOT EDIT. + +package database + +import ( + "fmt" + "strings" +) + +const _NumberOperationName = "=<><<=>>=" + +var _NumberOperationIndex = [...]uint8{0, 1, 3, 4, 6, 7, 9} + +const _NumberOperationLowerName = "=<><<=>>=" + +func (i NumberOperation) String() string { + i -= 1 + if i >= NumberOperation(len(_NumberOperationIndex)-1) { + return fmt.Sprintf("NumberOperation(%d)", i+1) + } + return _NumberOperationName[_NumberOperationIndex[i]:_NumberOperationIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _NumberOperationNoOp() { + var x [1]struct{} + _ = x[NumberOperationEqual-(1)] + _ = x[NumberOperationNotEqual-(2)] + _ = x[NumberOperationLessThan-(3)] + _ = x[NumberOperationAtLeast-(4)] + _ = x[NumberOperationGreaterThan-(5)] + _ = x[NumberOperationAtMost-(6)] +} + +var _NumberOperationValues = []NumberOperation{NumberOperationEqual, NumberOperationNotEqual, NumberOperationLessThan, NumberOperationAtLeast, NumberOperationGreaterThan, NumberOperationAtMost} + +var _NumberOperationNameToValueMap = map[string]NumberOperation{ + _NumberOperationName[0:1]: NumberOperationEqual, + _NumberOperationLowerName[0:1]: NumberOperationEqual, + _NumberOperationName[1:3]: NumberOperationNotEqual, + _NumberOperationLowerName[1:3]: NumberOperationNotEqual, + _NumberOperationName[3:4]: NumberOperationLessThan, + _NumberOperationLowerName[3:4]: NumberOperationLessThan, + _NumberOperationName[4:6]: NumberOperationAtLeast, + _NumberOperationLowerName[4:6]: NumberOperationAtLeast, + _NumberOperationName[6:7]: NumberOperationGreaterThan, + _NumberOperationLowerName[6:7]: NumberOperationGreaterThan, + _NumberOperationName[7:9]: NumberOperationAtMost, + _NumberOperationLowerName[7:9]: NumberOperationAtMost, +} + +var _NumberOperationNames = []string{ + _NumberOperationName[0:1], + _NumberOperationName[1:3], + _NumberOperationName[3:4], + _NumberOperationName[4:6], + _NumberOperationName[6:7], + _NumberOperationName[7:9], +} + +// NumberOperationString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func NumberOperationString(s string) (NumberOperation, error) { + if val, ok := _NumberOperationNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _NumberOperationNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to NumberOperation values", s) +} + +// NumberOperationValues returns all values of the enum +func NumberOperationValues() []NumberOperation { + return _NumberOperationValues +} + +// NumberOperationStrings returns a slice of all String values of the enum +func NumberOperationStrings() []string { + strs := make([]string, len(_NumberOperationNames)) + copy(strs, _NumberOperationNames) + return strs +} + +// IsANumberOperation returns "true" if the value is listed in the enum definition. "false" otherwise +func (i NumberOperation) IsANumberOperation() bool { + for _, v := range _NumberOperationValues { + if i == v { + return true + } + } + return false +} + +const _TextOperationName = "=<>LIKE" + +var _TextOperationIndex = [...]uint8{0, 1, 3, 7} + +const _TextOperationLowerName = "=<>like" + +func (i TextOperation) String() string { + i -= 1 + if i >= TextOperation(len(_TextOperationIndex)-1) { + return fmt.Sprintf("TextOperation(%d)", i+1) + } + return _TextOperationName[_TextOperationIndex[i]:_TextOperationIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _TextOperationNoOp() { + var x [1]struct{} + _ = x[TextOperationEqual-(1)] + _ = x[TextOperationNotEqual-(2)] + _ = x[TextOperationStartsWith-(3)] +} + +var _TextOperationValues = []TextOperation{TextOperationEqual, TextOperationNotEqual, TextOperationStartsWith} + +var _TextOperationNameToValueMap = map[string]TextOperation{ + _TextOperationName[0:1]: TextOperationEqual, + _TextOperationLowerName[0:1]: TextOperationEqual, + _TextOperationName[1:3]: TextOperationNotEqual, + _TextOperationLowerName[1:3]: TextOperationNotEqual, + _TextOperationName[3:7]: TextOperationStartsWith, + _TextOperationLowerName[3:7]: TextOperationStartsWith, +} + +var _TextOperationNames = []string{ + _TextOperationName[0:1], + _TextOperationName[1:3], + _TextOperationName[3:7], +} + +// TextOperationString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func TextOperationString(s string) (TextOperation, error) { + if val, ok := _TextOperationNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _TextOperationNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to TextOperation values", s) +} + +// TextOperationValues returns all values of the enum +func TextOperationValues() []TextOperation { + return _TextOperationValues +} + +// TextOperationStrings returns a slice of all String values of the enum +func TextOperationStrings() []string { + strs := make([]string, len(_TextOperationNames)) + copy(strs, _TextOperationNames) + return strs +} + +// IsATextOperation returns "true" if the value is listed in the enum definition. "false" otherwise +func (i TextOperation) IsATextOperation() bool { + for _, v := range _TextOperationValues { + if i == v { + return true + } + } + return false +} + +const _BytesOperationName = "=<>" + +var _BytesOperationIndex = [...]uint8{0, 1, 3} + +const _BytesOperationLowerName = "=<>" + +func (i BytesOperation) String() string { + i -= 1 + if i >= BytesOperation(len(_BytesOperationIndex)-1) { + return fmt.Sprintf("BytesOperation(%d)", i+1) + } + return _BytesOperationName[_BytesOperationIndex[i]:_BytesOperationIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _BytesOperationNoOp() { + var x [1]struct{} + _ = x[BytesOperationEqual-(1)] + _ = x[BytesOperationNotEqual-(2)] +} + +var _BytesOperationValues = []BytesOperation{BytesOperationEqual, BytesOperationNotEqual} + +var _BytesOperationNameToValueMap = map[string]BytesOperation{ + _BytesOperationName[0:1]: BytesOperationEqual, + _BytesOperationLowerName[0:1]: BytesOperationEqual, + _BytesOperationName[1:3]: BytesOperationNotEqual, + _BytesOperationLowerName[1:3]: BytesOperationNotEqual, +} + +var _BytesOperationNames = []string{ + _BytesOperationName[0:1], + _BytesOperationName[1:3], +} + +// BytesOperationString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func BytesOperationString(s string) (BytesOperation, error) { + if val, ok := _BytesOperationNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _BytesOperationNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to BytesOperation values", s) +} + +// BytesOperationValues returns all values of the enum +func BytesOperationValues() []BytesOperation { + return _BytesOperationValues +} + +// BytesOperationStrings returns a slice of all String values of the enum +func BytesOperationStrings() []string { + strs := make([]string, len(_BytesOperationNames)) + copy(strs, _BytesOperationNames) + return strs +} + +// IsABytesOperation returns "true" if the value is listed in the enum definition. "false" otherwise +func (i BytesOperation) IsABytesOperation() bool { + for _, v := range _BytesOperationValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/storage/database/operators_test.go b/backend/v3/storage/database/operators_test.go new file mode 100644 index 00000000000..d7de372b8de --- /dev/null +++ b/backend/v3/storage/database/operators_test.go @@ -0,0 +1,244 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_writeOperation(t *testing.T) { + type want struct { + shouldPanic bool + stmt string + args []any + } + tests := []struct { + name string + write func(builder *StatementBuilder) + // col Column + // op string + // value any + want want + }{ + { + name: "unsupported operation panics", + write: func(builder *StatementBuilder) { + writeOperation[string](builder, NewColumn("table", "column"), "", "value") + }, + want: want{ + shouldPanic: true, + }, + }, + { + name: "unsupported value type panics", + write: func(builder *StatementBuilder) { + writeOperation[string](builder, NewColumn("table", "column"), " = ", struct{}{}) + }, + want: want{ + shouldPanic: true, + }, + }, + { + name: "unsupported wrapped value type panics", + write: func(builder *StatementBuilder) { + writeOperation[string](builder, NewColumn("table", "column"), " = ", SHA256Value(123)) + }, + want: want{ + shouldPanic: true, + }, + }, + { + name: "text equal", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationEqual, "value") + }, + want: want{ + stmt: "table.column = $1", + args: []any{"value"}, + }, + }, + { + name: "text not equal", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationNotEqual, "value") + }, + want: want{ + stmt: "table.column <> $1", + args: []any{"value"}, + }, + }, + { + name: "text starts with", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationStartsWith, "value") + }, + want: want{ + stmt: "table.column LIKE $1 || '%'", + args: []any{"value"}, + }, + }, + { + name: "text equal with wrapped value", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationEqual, LowerValue("value")) + }, + want: want{ + stmt: "LOWER(table.column) = LOWER($1)", + args: []any{"value"}, + }, + }, + { + name: "text not equal with wrapped value", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationNotEqual, LowerValue("value")) + }, + want: want{ + stmt: "LOWER(table.column) <> LOWER($1)", + args: []any{"value"}, + }, + }, + { + name: "text starts with with wrapped value", + write: func(builder *StatementBuilder) { + writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationStartsWith, LowerValue("value")) + }, + want: want{ + stmt: "LOWER(table.column) LIKE LOWER($1) || '%'", + args: []any{"value"}, + }, + }, + { + name: "number equal", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationEqual, 123) + }, + want: want{ + stmt: "table.column = $1", + args: []any{123}, + }, + }, + { + name: "number not equal", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationNotEqual, 123) + }, + want: want{ + stmt: "table.column <> $1", + args: []any{123}, + }, + }, + { + name: "number less than", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationLessThan, 123) + }, + want: want{ + stmt: "table.column < $1", + args: []any{123}, + }, + }, + { + name: "number less than or equal", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtLeast, 123) + }, + want: want{ + stmt: "table.column <= $1", + args: []any{123}, + }, + }, + { + name: "number greater than", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationGreaterThan, 123) + }, + want: want{ + stmt: "table.column > $1", + args: []any{123}, + }, + }, + { + name: "number greater than or equal", + write: func(builder *StatementBuilder) { + writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtMost, 123) + }, + want: want{ + stmt: "table.column >= $1", + args: []any{123}, + }, + }, + { + name: "boolean is true", + write: func(builder *StatementBuilder) { + writeBooleanOperation[bool](builder, NewColumn("table", "column"), true) + }, + want: want{ + stmt: "table.column = $1", + args: []any{true}, + }, + }, + { + name: "boolean is false", + write: func(builder *StatementBuilder) { + writeBooleanOperation[bool](builder, NewColumn("table", "column"), false) + }, + want: want{ + stmt: "table.column = $1", + args: []any{false}, + }, + }, + { + name: "bytes equal", + write: func(builder *StatementBuilder) { + writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}) + }, + want: want{ + stmt: "table.column = $1", + args: []any{[]byte{0x01, 0x02, 0x03}}, + }, + }, + { + name: "bytes not equal", + write: func(builder *StatementBuilder) { + writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationNotEqual, []byte{0x01, 0x02, 0x03}) + }, + want: want{ + stmt: "table.column <> $1", + args: []any{[]byte{0x01, 0x02, 0x03}}, + }, + }, + { + name: "bytes equal with wrapped value", + write: func(builder *StatementBuilder) { + writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationEqual, SHA256Value([]byte{0x01, 0x02, 0x03})) + }, + want: want{ + stmt: "SHA256(table.column) = SHA256($1)", + args: []any{[]byte{0x01, 0x02, 0x03}}, + }, + }, + { + name: "bytes not equal with wrapped value", + write: func(builder *StatementBuilder) { + writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationNotEqual, SHA256Value([]byte{0x01, 0x02, 0x03})) + }, + want: want{ + stmt: "SHA256(table.column) <> SHA256($1)", + args: []any{[]byte{0x01, 0x02, 0x03}}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, tt.want.shouldPanic, r != nil) + }() + var builder StatementBuilder + tt.write(&builder) + + assert.Equal(t, tt.want.stmt, builder.String()) + assert.Equal(t, tt.want.args, builder.Args()) + }) + } +} diff --git a/backend/v3/storage/database/order_test.go b/backend/v3/storage/database/order_test.go new file mode 100644 index 00000000000..f90aafe97f9 --- /dev/null +++ b/backend/v3/storage/database/order_test.go @@ -0,0 +1,28 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_orderBy_Write(t *testing.T) { + tests := []struct { + name string + want string + order Order + }{ + { + name: "order by column", + want: " ORDER BY table.column", + order: OrderBy(NewColumn("table", "column")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var builder StatementBuilder + tt.order.Write(&builder) + assert.Equal(t, tt.want, builder.String()) + }) + } +} diff --git a/backend/v3/storage/database/query_test.go b/backend/v3/storage/database/query_test.go new file mode 100644 index 00000000000..5e8e1a58361 --- /dev/null +++ b/backend/v3/storage/database/query_test.go @@ -0,0 +1,147 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestQueryOptions(t *testing.T) { + type want struct { + stmt string + args []any + } + for _, test := range []struct { + name string + options []QueryOption + want want + }{ + { + name: "no options", + want: want{ + stmt: "", + args: nil, + }, + }, + { + name: "limit option", + options: []QueryOption{ + WithLimit(10), + }, + want: want{ + stmt: " LIMIT $1", + args: []any{uint32(10)}, + }, + }, + { + name: "offset option", + options: []QueryOption{ + WithOffset(5), + }, + want: want{ + stmt: " OFFSET $1", + args: []any{uint32(5)}, + }, + }, + { + name: "order by asc option", + options: []QueryOption{ + WithOrderByAscending(NewColumn("table", "column")), + }, + want: want{ + stmt: " ORDER BY table.column", + args: nil, + }, + }, + { + name: "order by desc option", + options: []QueryOption{ + WithOrderByDescending(NewColumn("table", "column")), + }, + want: want{ + stmt: " ORDER BY table.column DESC", + args: nil, + }, + }, + { + name: "order by option", + options: []QueryOption{ + WithOrderBy(OrderDirectionAsc, NewColumn("table", "column1"), NewColumn("table", "column2")), + }, + want: want{ + stmt: " ORDER BY table.column1, table.column2", + args: nil, + }, + }, + { + name: "condition option", + options: []QueryOption{ + WithCondition(NewBooleanCondition(NewColumn("table", "column"), true)), + }, + want: want{ + stmt: " WHERE table.column = $1", + args: []any{true}, + }, + }, + { + name: "group by option", + options: []QueryOption{ + WithGroupBy(NewColumn("table", "column")), + }, + want: want{ + stmt: " GROUP BY table.column", + args: nil, + }, + }, + { + name: "left join option", + options: []QueryOption{ + WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))), + }, + want: want{ + stmt: " LEFT JOIN other_table ON table.id = other_table.table_id", + args: nil, + }, + }, + { + name: "permission check option", + options: []QueryOption{ + WithPermissionCheck("permission"), + }, + want: want{ + stmt: "", + args: nil, + }, + }, + { + name: "multiple options", + options: []QueryOption{ + WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))), + WithCondition(NewNumberCondition(NewColumn("table", "column"), NumberOperationEqual, 123)), + WithOrderByDescending(NewColumn("table", "column")), + WithLimit(10), + WithOffset(5), + }, + want: want{ + stmt: " LEFT JOIN other_table ON table.id = other_table.table_id WHERE table.column = $1 ORDER BY table.column DESC LIMIT $2 OFFSET $3", + args: []any{123, uint32(10), uint32(5)}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + var b StatementBuilder + var opts QueryOpts + for _, option := range test.options { + option(&opts) + } + opts.Write(&b) + assert.Equal(t, test.want.stmt, b.String()) + require.Len(t, b.Args(), len(test.want.args)) + for i := range test.want.args { + assert.Equal(t, test.want.args[i], b.Args()[i]) + } + }) + } + +} diff --git a/backend/v3/storage/database/repository/array.go b/backend/v3/storage/database/repository/array.go index 77851033d61..49c150fd5d6 100644 --- a/backend/v3/storage/database/repository/array.go +++ b/backend/v3/storage/database/repository/array.go @@ -3,19 +3,35 @@ package repository import ( "encoding/json" "errors" + + "github.com/zitadel/zitadel/backend/v3/storage/database" ) type JSONArray[T any] []*T -func (a JSONArray[T]) Scan(src any) error { +var ErrScanSource = errors.New("unsupported scan source") + +func (a *JSONArray[T]) Scan(src any) (err error) { + var rawJSON []byte switch s := src.(type) { case string: - return json.Unmarshal([]byte(s), &a) + if len(s) == 0 { + return nil + } + rawJSON = []byte(s) case []byte: - return json.Unmarshal(s, &a) + if len(s) == 0 { + return nil + } + rawJSON = s case nil: return nil default: - return errors.New("unsupported scan source") + return ErrScanSource } + err = json.Unmarshal(rawJSON, a) + if err != nil { + return database.NewScanError(err) + } + return nil } diff --git a/backend/v3/storage/database/repository/array_test.go b/backend/v3/storage/database/repository/array_test.go new file mode 100644 index 00000000000..9f320190f0d --- /dev/null +++ b/backend/v3/storage/database/repository/array_test.go @@ -0,0 +1,117 @@ +package repository_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +type testObject struct { + ID string `json:"id"` + Number int `json:"number"` + Active bool `json:"active"` +} + +func TestJSONArray_Scan(t *testing.T) { + type want struct { + err error + res []*testObject + } + tests := []struct { + name string + src any + want want + }{ + { + name: "nil", + src: nil, + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "[]byte with data", + src: []byte(`[{"id":"1","number":1,"active":true}]`), + want: want{ + err: nil, + res: []*testObject{{ID: "1", Number: 1, Active: true}}, + }, + }, + { + name: "[]byte without data", + src: []byte(`[]`), + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "nil []byte", + src: []byte(nil), + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "empty []byte", + src: []byte{}, + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "string with data", + src: `[{"id":"1","number":1,"active":true}]`, + want: want{ + err: nil, + res: []*testObject{{ID: "1", Number: 1, Active: true}}, + }, + }, + { + name: "string without data", + src: string(`[]`), + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "empty string", + src: "", + want: want{ + err: nil, + res: nil, + }, + }, + { + name: "wrong type", + src: []int{1, 2, 3}, + want: want{ + err: repository.ErrScanSource, + res: nil, + }, + }, + { + name: "badly formatted JSON", + src: []byte(`this is not a JSON`), + want: want{ + err: new(database.ScanError), + res: nil, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var a repository.JSONArray[testObject] + gotErr := a.Scan(tt.src) + require.ErrorIs(t, gotErr, tt.want.err) + require.Len(t, a, len(tt.want.res)) + }) + } +} diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index fa5691fc94e..2d03bdc5e0e 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -13,17 +13,20 @@ import ( var _ domain.InstanceRepository = (*instance)(nil) type instance struct { - repository shouldLoadDomains bool - domainRepo *instanceDomain + domainRepo instanceDomain } -func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository { - return &instance{ - repository: repository{ - client: client, - }, - } +func InstanceRepository() domain.InstanceRepository { + return new(instance) +} + +func (instance) qualifiedTableName() string { + return "zitadel.instances" +} + +func (instance) unqualifiedTableName() string { + return "instances" } // ------------------------------------------------------------- @@ -37,7 +40,7 @@ const ( ) // Get implements [domain.InstanceRepository]. -func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) { +func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) { opts = append(opts, i.joinDomains(), database.WithGroupBy(i.IDColumn()), @@ -52,11 +55,11 @@ func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*doma builder.WriteString(queryInstanceStmt) options.Write(&builder) - return scanInstance(ctx, i.client, &builder) + return scanInstance(ctx, client, &builder) } // List implements [domain.InstanceRepository]. -func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) { +func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) { opts = append(opts, i.joinDomains(), database.WithGroupBy(i.IDColumn()), @@ -71,27 +74,11 @@ func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*d builder.WriteString(queryInstanceStmt) options.Write(&builder) - return scanInstances(ctx, i.client, &builder) -} - -func (i *instance) joinDomains() database.QueryOption { - columns := make([]database.Condition, 0, 2) - columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn())) - - // If domains should not be joined, we make sure to return null for the domain columns - // the query optimizer of the dialect should optimize this away if no domains are requested - if !i.shouldLoadDomains { - columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn())) - } - - return database.WithLeftJoin( - "zitadel.instance_domains", - database.And(columns...), - ) + return scanInstances(ctx, client, &builder) } // Create implements [domain.InstanceRepository]. -func (i *instance) Create(ctx context.Context, instance *domain.Instance) error { +func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error { var ( builder database.StatementBuilder createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction @@ -103,42 +90,46 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error updatedAt = instance.UpdatedAt } - builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`) + builder.WriteString(`INSERT INTO `) + builder.WriteString(i.qualifiedTableName()) + builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`) builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt) builder.WriteString(`) RETURNING created_at, updated_at`) - return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) + return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) } // Update implements [domain.InstanceRepository]. -func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { +func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) { if len(changes) == 0 { return 0, database.ErrNoChanges } + if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) { + changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction)) + } + var builder database.StatementBuilder - builder.WriteString(`UPDATE zitadel.instances SET `) - database.Changes(changes).Write(&builder) - idCondition := i.IDCondition(id) writeCondition(&builder, idCondition) stmt := builder.String() - return i.client.Exec(ctx, stmt, builder.Args()...) + return client.Exec(ctx, stmt, builder.Args()...) } // Delete implements [domain.InstanceRepository]. -func (i instance) Delete(ctx context.Context, id string) (int64, error) { +func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) { var builder database.StatementBuilder - builder.WriteString(`DELETE FROM zitadel.instances`) + builder.WriteString(`DELETE FROM `) + builder.WriteString(i.qualifiedTableName()) idCondition := i.IDCondition(id) writeCondition(&builder, idCondition) - return i.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- @@ -185,53 +176,77 @@ func (i instance) NameCondition(op database.TextOperation, name string) database return database.NewTextCondition(i.NameColumn(), op, name) } +// ExistsDomain creates a correlated [database.Exists] condition on instance_domains. +// Use this filter to make sure the Instance returned contains a specific domain. +// of the instance in the aggregated result. +// Example usage: +// +// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated +// instance, _ := instanceRepo.Get(ctx, +// database.WithCondition( +// database.And( +// instanceRepo.InstanceIDCondition(instanceID), +// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")), +// ), +// ), +// ) +func (i instance) ExistsDomain(cond database.Condition) database.Condition { + return database.Exists( + i.domainRepo.qualifiedTableName(), + database.And( + database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()), + cond, + ), + ) +} + // ------------------------------------------------------------- // columns // ------------------------------------------------------------- // IDColumn implements [domain.instanceColumns]. -func (instance) IDColumn() database.Column { - return database.NewColumn("instances", "id") +func (i instance) IDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "id") } // NameColumn implements [domain.instanceColumns]. -func (instance) NameColumn() database.Column { - return database.NewColumn("instances", "name") +func (i instance) NameColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "name") } // CreatedAtColumn implements [domain.instanceColumns]. -func (instance) CreatedAtColumn() database.Column { - return database.NewColumn("instances", "created_at") +func (i instance) CreatedAtColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "created_at") } // DefaultOrgIdColumn implements [domain.instanceColumns]. -func (instance) DefaultOrgIDColumn() database.Column { - return database.NewColumn("instances", "default_org_id") +func (i instance) DefaultOrgIDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "default_org_id") } // IAMProjectIDColumn implements [domain.instanceColumns]. -func (instance) IAMProjectIDColumn() database.Column { - return database.NewColumn("instances", "iam_project_id") +func (i instance) IAMProjectIDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "iam_project_id") } // ConsoleClientIDColumn implements [domain.instanceColumns]. -func (instance) ConsoleClientIDColumn() database.Column { - return database.NewColumn("instances", "console_client_id") +func (i instance) ConsoleClientIDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "console_client_id") } // ConsoleAppIDColumn implements [domain.instanceColumns]. -func (instance) ConsoleAppIDColumn() database.Column { - return database.NewColumn("instances", "console_app_id") +func (i instance) ConsoleAppIDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "console_app_id") } // DefaultLanguageColumn implements [domain.instanceColumns]. -func (instance) DefaultLanguageColumn() database.Column { - return database.NewColumn("instances", "default_language") +func (i instance) DefaultLanguageColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "default_language") } // UpdatedAtColumn implements [domain.instanceColumns]. -func (instance) UpdatedAtColumn() database.Column { - return database.NewColumn("instances", "updated_at") +func (i instance) UpdatedAtColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "updated_at") } // ------------------------------------------------------------- @@ -282,19 +297,24 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab // sub repositories // ------------------------------------------------------------- -// Domains implements [domain.InstanceRepository]. -func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository { - if !i.shouldLoadDomains { - i.shouldLoadDomains = shouldLoad +func (i *instance) LoadDomains() domain.InstanceRepository { + return &instance{ + shouldLoadDomains: true, } - - if i.domainRepo != nil { - return i.domainRepo - } - - i.domainRepo = &instanceDomain{ - repository: i.repository, - instance: i, - } - return i.domainRepo +} + +func (i *instance) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 2) + columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn())) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !i.shouldLoadDomains { + columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn())) + } + + return database.WithLeftJoin( + i.domainRepo.qualifiedTableName(), + database.And(columns...), + ) } diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go index 3aed56238ea..9cb4d9e127a 100644 --- a/backend/v3/storage/database/repository/instance_domain.go +++ b/backend/v3/storage/database/repository/instance_domain.go @@ -10,9 +10,18 @@ import ( var _ domain.InstanceDomainRepository = (*instanceDomain)(nil) -type instanceDomain struct { - repository - *instance +type instanceDomain struct{} + +func InstanceDomainRepository() domain.InstanceDomainRepository { + return new(instanceDomain) +} + +func (instanceDomain) qualifiedTableName() string { + return "zitadel.instance_domains" +} + +func (instanceDomain) unqualifiedTableName() string { + return "instance_domains" } // ------------------------------------------------------------- @@ -23,8 +32,7 @@ const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_d `FROM zitadel.instance_domains` // Get implements [domain.InstanceDomainRepository]. -// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance. -func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) { +func (i instanceDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.InstanceDomain, error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) @@ -34,12 +42,11 @@ func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) builder.WriteString(queryInstanceDomainStmt) options.Write(&builder) - return scanInstanceDomain(ctx, i.client, &builder) + return scanInstanceDomain(ctx, client, &builder) } // List implements [domain.InstanceDomainRepository]. -// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance. -func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) { +func (i instanceDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) @@ -49,11 +56,11 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) builder.WriteString(queryInstanceDomainStmt) options.Write(&builder) - return scanInstanceDomains(ctx, i.client, &builder) + return scanInstanceDomains(ctx, client, &builder) } // Add implements [domain.InstanceDomainRepository]. -func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error { +func (i instanceDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddInstanceDomain) error { var ( builder database.StatementBuilder createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction @@ -69,33 +76,41 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt) builder.WriteString(`) RETURNING created_at, updated_at`) - return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) + return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) } // Update implements [domain.InstanceDomainRepository]. -// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance. -func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { +func (i instanceDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) { + if !condition.IsRestrictingColumn(i.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(i.InstanceIDColumn()) + } if len(changes) == 0 { return 0, database.ErrNoChanges } - var builder database.StatementBuilder + if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) { + changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction)) + } + var builder database.StatementBuilder builder.WriteString(`UPDATE zitadel.instance_domains SET `) database.Changes(changes).Write(&builder) writeCondition(&builder, condition) - return i.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // Remove implements [domain.InstanceDomainRepository]. -func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { - var builder database.StatementBuilder +func (i instanceDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) { + if !condition.IsRestrictingColumn(i.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(i.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `) condition.Write(&builder) - return i.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- @@ -146,40 +161,38 @@ func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition // ------------------------------------------------------------- // CreatedAtColumn implements [domain.InstanceDomainRepository]. -// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance. -func (instanceDomain) CreatedAtColumn() database.Column { - return database.NewColumn("instance_domains", "created_at") +func (i instanceDomain) CreatedAtColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "created_at") } // DomainColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) DomainColumn() database.Column { - return database.NewColumn("instance_domains", "domain") +func (i instanceDomain) DomainColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "domain") } // InstanceIDColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) InstanceIDColumn() database.Column { - return database.NewColumn("instance_domains", "instance_id") +func (i instanceDomain) InstanceIDColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "instance_id") } // IsPrimaryColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) IsPrimaryColumn() database.Column { - return database.NewColumn("instance_domains", "is_primary") +func (i instanceDomain) IsPrimaryColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "is_primary") } // UpdatedAtColumn implements [domain.InstanceDomainRepository]. -// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance. -func (instanceDomain) UpdatedAtColumn() database.Column { - return database.NewColumn("instance_domains", "updated_at") +func (i instanceDomain) UpdatedAtColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "updated_at") } // IsGeneratedColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) IsGeneratedColumn() database.Column { - return database.NewColumn("instance_domains", "is_generated") +func (i instanceDomain) IsGeneratedColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "is_generated") } // TypeColumn implements [domain.InstanceDomainRepository]. -func (instanceDomain) TypeColumn() database.Column { - return database.NewColumn("instance_domains", "type") +func (i instanceDomain) TypeColumn() database.Column { + return database.NewColumn(i.unqualifiedTableName(), "type") } // ------------------------------------------------------------- diff --git a/backend/v3/storage/database/repository/instance_domain_test.go b/backend/v3/storage/database/repository/instance_domain_test.go index cc18ada380f..1900e648461 100644 --- a/backend/v3/storage/database/repository/instance_domain_test.go +++ b/backend/v3/storage/database/repository/instance_domain_test.go @@ -1,7 +1,6 @@ package repository_test import ( - "context" "testing" "time" @@ -16,31 +15,43 @@ import ( ) func TestAddInstanceDomain(t *testing.T) { + // we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields + beforeAdd := time.Now() + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, - Name: gofakeit.Name(), + ID: gofakeit.NewCrypto().UUID(), + Name: gofakeit.BeerName(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleClient", ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain instanceDomain domain.AddInstanceDomain err error }{ { name: "happy path custom domain", instanceDomain: domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), Type: domain.DomainTypeCustom, IsPrimary: gu.Ptr(false), @@ -50,7 +61,7 @@ func TestAddInstanceDomain(t *testing.T) { { name: "happy path trusted domain", instanceDomain: domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), Type: domain.DomainTypeTrusted, }, @@ -58,7 +69,7 @@ func TestAddInstanceDomain(t *testing.T) { { name: "add primary domain", instanceDomain: domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), Type: domain.DomainTypeCustom, IsPrimary: gu.Ptr(true), @@ -68,7 +79,7 @@ func TestAddInstanceDomain(t *testing.T) { { name: "add custom domain without domain name", instanceDomain: domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: "", Type: domain.DomainTypeCustom, IsPrimary: gu.Ptr(false), @@ -79,7 +90,7 @@ func TestAddInstanceDomain(t *testing.T) { { name: "add trusted domain without domain name", instanceDomain: domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: "", Type: domain.DomainTypeTrusted, }, @@ -87,23 +98,23 @@ func TestAddInstanceDomain(t *testing.T) { }, { name: "add custom domain with same domain twice", - testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain { + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain { domainName := gofakeit.DomainName() instanceDomain := &domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: domainName, Type: domain.DomainTypeCustom, IsPrimary: gu.Ptr(false), IsGenerated: gu.Ptr(false), } - err := domainRepo.Add(ctx, instanceDomain) + err := domainRepo.Add(t.Context(), tx, instanceDomain) require.NoError(t, err) // return same domain again return &domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: domainName, Type: domain.DomainTypeCustom, IsPrimary: gu.Ptr(false), @@ -114,22 +125,20 @@ func TestAddInstanceDomain(t *testing.T) { }, { name: "add trusted domain with same domain twice", - testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain { - domainName := gofakeit.DomainName() - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain { instanceDomain := &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName, + InstanceID: instance.ID, + Domain: gofakeit.DomainName(), Type: domain.DomainTypeTrusted, } - err := domainRepo.Add(ctx, instanceDomain) + err := domainRepo.Add(t.Context(), tx, instanceDomain) require.NoError(t, err) // return same domain again return &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName, + InstanceID: instance.ID, + Domain: instanceDomain.Domain, Type: domain.DomainTypeTrusted, } }, @@ -196,26 +205,23 @@ func TestAddInstanceDomain(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - // we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields - beforeAdd := time.Now() - tx, err := pool.Begin(t.Context(), nil) + savepoint, err := tx.Begin(t.Context()) require.NoError(t, err) defer func() { - require.NoError(t, tx.Rollback(t.Context())) + err := savepoint.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } }() - instanceRepo := repository.InstanceRepository(tx) - domainRepo := instanceRepo.Domains(false) var instanceDomain *domain.AddInstanceDomain if test.testFunc != nil { - instanceDomain = test.testFunc(ctx, t, domainRepo) + instanceDomain = test.testFunc(t, savepoint) } else { instanceDomain = &test.instanceDomain } - err = domainRepo.Add(ctx, instanceDomain) + err = domainRepo.Add(t.Context(), savepoint, instanceDomain) afterAdd := time.Now() if test.err != nil { assert.ErrorIs(t, err, test.err) @@ -232,10 +238,21 @@ func TestAddInstanceDomain(t *testing.T) { } func TestGetInstanceDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.NewCrypto().UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -243,38 +260,29 @@ func TestGetInstanceDomain(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) // add domains - domainRepo := instanceRepo.Domains(false) - domainName1 := gofakeit.DomainName() - domainName2 := gofakeit.DomainName() - domain1 := &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName1, + InstanceID: instance.ID, + Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(true), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } domain2 := &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName2, + InstanceID: instance.ID, + Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(false), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } - err = domainRepo.Add(t.Context(), domain1) + err = domainRepo.Add(t.Context(), tx, domain1) require.NoError(t, err) - err = domainRepo.Add(t.Context(), domain2) + err = domainRepo.Add(t.Context(), tx, domain2) require.NoError(t, err) tests := []struct { @@ -289,19 +297,19 @@ func TestGetInstanceDomain(t *testing.T) { database.WithCondition(domainRepo.IsPrimaryCondition(true)), }, expected: &domain.InstanceDomain{ - InstanceID: instanceID, - Domain: domainName1, + InstanceID: instance.ID, + Domain: domain1.Domain, IsPrimary: gu.Ptr(true), }, }, { name: "get by domain name", opts: []database.QueryOption{ - database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain)), }, expected: &domain.InstanceDomain{ - InstanceID: instanceID, - Domain: domainName2, + InstanceID: instance.ID, + Domain: domain2.Domain, IsPrimary: gu.Ptr(false), }, }, @@ -318,7 +326,7 @@ func TestGetInstanceDomain(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx := t.Context() - result, err := domainRepo.Get(ctx, test.opts...) + result, err := domainRepo.Get(ctx, tx, test.opts...) if test.err != nil { assert.ErrorIs(t, err, test.err) return @@ -335,10 +343,21 @@ func TestGetInstanceDomain(t *testing.T) { } func TestListInstanceDomains(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.NewCrypto().UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -346,42 +365,35 @@ func TestListInstanceDomains(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) // add multiple domains - domainRepo := instanceRepo.Domains(false) domains := []domain.AddInstanceDomain{ { - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(true), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, }, { - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(false), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, }, { - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), Type: domain.DomainTypeTrusted, }, } for i := range domains { - err = domainRepo.Add(t.Context(), &domains[i]) + err = domainRepo.Add(t.Context(), tx, &domains[i]) require.NoError(t, err) } @@ -405,7 +417,7 @@ func TestListInstanceDomains(t *testing.T) { { name: "list by instance", opts: []database.QueryOption{ - database.WithCondition(domainRepo.InstanceIDCondition(instanceID)), + database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)), }, expectedCount: 3, }, @@ -422,12 +434,12 @@ func TestListInstanceDomains(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx := t.Context() - results, err := domainRepo.List(ctx, test.opts...) + results, err := domainRepo.List(ctx, tx, test.opts...) require.NoError(t, err) assert.Len(t, results, test.expectedCount) for _, result := range results { - assert.Equal(t, instanceID, result.InstanceID) + assert.Equal(t, instance.ID, result.InstanceID) assert.NotEmpty(t, result.Domain) assert.NotEmpty(t, result.CreatedAt) assert.NotEmpty(t, result.UpdatedAt) @@ -437,10 +449,21 @@ func TestListInstanceDomains(t *testing.T) { } func TestUpdateInstanceDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -448,29 +471,19 @@ func TestUpdateInstanceDomain(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) // add domain - domainRepo := instanceRepo.Domains(false) - domainName := gofakeit.DomainName() instanceDomain := &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName, + InstanceID: instance.ID, + Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(false), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } - err = domainRepo.Add(t.Context(), instanceDomain) + err = domainRepo.Add(t.Context(), tx, instanceDomain) require.NoError(t, err) tests := []struct { @@ -481,31 +494,38 @@ func TestUpdateInstanceDomain(t *testing.T) { err error }{ { - name: "set primary", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{domainRepo.SetPrimary()}, - expected: 1, + name: "set primary", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain), + ), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 1, }, { - name: "update non-existent domain", - condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), - changes: []database.Change{domainRepo.SetPrimary()}, - expected: 0, + name: "update non-existent domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + ), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 0, }, { - name: "no changes", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{}, - expected: 0, - err: database.ErrNoChanges, + name: "no changes", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain), + ), + changes: []database.Change{}, + expected: 0, + err: database.ErrNoChanges, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...) + rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...) if test.err != nil { assert.ErrorIs(t, err, test.err) return @@ -516,7 +536,7 @@ func TestUpdateInstanceDomain(t *testing.T) { // verify changes were applied if rows were affected if rowsAffected > 0 && len(test.changes) > 0 { - result, err := domainRepo.Get(ctx, database.WithCondition(test.condition)) + result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition)) require.NoError(t, err) // We know changes were applied since rowsAffected > 0 @@ -529,10 +549,21 @@ func TestUpdateInstanceDomain(t *testing.T) { } func TestRemoveInstanceDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -540,37 +571,28 @@ func TestRemoveInstanceDomain(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) // add domains - domainRepo := instanceRepo.Domains(false) - domainName1 := gofakeit.DomainName() - domain1 := &domain.AddInstanceDomain{ - InstanceID: instanceID, - Domain: domainName1, + InstanceID: instance.ID, + Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(true), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } domain2 := &domain.AddInstanceDomain{ - InstanceID: instanceID, + InstanceID: instance.ID, Domain: gofakeit.DomainName(), IsPrimary: gu.Ptr(false), IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } - err = domainRepo.Add(t.Context(), domain1) + err = domainRepo.Add(t.Context(), tx, domain1) require.NoError(t, err) - err = domainRepo.Add(t.Context(), domain2) + err = domainRepo.Add(t.Context(), tx, domain2) require.NoError(t, err) tests := []struct { @@ -579,36 +601,43 @@ func TestRemoveInstanceDomain(t *testing.T) { expected int64 }{ { - name: "remove by domain name", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1), - expected: 1, + name: "remove by domain name", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain), + ), + expected: 1, }, { - name: "remove by primary condition", - condition: domainRepo.IsPrimaryCondition(false), - expected: 1, // domain2 should still exist and be non-primary + name: "remove by primary condition", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.IsPrimaryCondition(false), + ), + expected: 1, // domain2 should still exist and be non-primary }, { - name: "remove non-existent domain", - condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), - expected: 0, + name: "remove non-existent domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + ), + expected: 0, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - // count before removal - beforeCount, err := domainRepo.List(ctx) + beforeCount, err := domainRepo.List(t.Context(), tx) require.NoError(t, err) - rowsAffected, err := domainRepo.Remove(ctx, test.condition) + rowsAffected, err := domainRepo.Remove(t.Context(), tx, test.condition) require.NoError(t, err) assert.Equal(t, test.expected, rowsAffected) // verify removal - afterCount, err := domainRepo.List(ctx) + afterCount, err := domainRepo.List(t.Context(), tx) require.NoError(t, err) assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount)) }) @@ -616,8 +645,7 @@ func TestRemoveInstanceDomain(t *testing.T) { } func TestInstanceDomainConditions(t *testing.T) { - instanceRepo := repository.InstanceRepository(pool) - domainRepo := instanceRepo.Domains(false) + domainRepo := repository.InstanceDomainRepository() tests := []struct { name string @@ -671,8 +699,7 @@ func TestInstanceDomainConditions(t *testing.T) { } func TestInstanceDomainChanges(t *testing.T) { - instanceRepo := repository.InstanceRepository(pool) - domainRepo := instanceRepo.Domains(false) + domainRepo := repository.InstanceDomainRepository() tests := []struct { name string diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go index 728cbdd753e..d57a9c979e2 100644 --- a/backend/v3/storage/database/repository/instance_test.go +++ b/backend/v3/storage/database/repository/instance_test.go @@ -2,6 +2,7 @@ package repository_test import ( "context" + "strconv" "testing" "time" @@ -16,9 +17,20 @@ import ( ) func TestCreateInstance(t *testing.T) { + beforeCreate := time.Now() + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(context.Background()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + instanceRepo := repository.InstanceRepository() + tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance instance domain.Instance err error }{ @@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) { }, { name: "adding same instance twice", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance { inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.UUID(), + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) { DefaultLanguage: "defaultLanguage", } - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) + require.NoError(t, err) + // change the name to make sure same only the id clashes inst.Name = gofakeit.Name() require.NoError(t, err) @@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) { }, func() struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance instance domain.Instance err error } { @@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) { instanceName := gofakeit.Name() return struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance instance domain.Instance err error }{ name: "adding instance with same name twice", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceRepo := repository.InstanceRepository(pool) - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance { inst := domain.Instance{ ID: gofakeit.Name(), Name: instanceName, @@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) { DefaultLanguage: "defaultLanguage", } - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) // change the id @@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) { { name: "adding instance with no id", instance: func() domain.Instance { - // instanceId := gofakeit.Name() - instanceName := gofakeit.Name() instance := domain.Instance{ - // ID: instanceId, - Name: instanceName, + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() + savepoint, err := tx.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() var instance *domain.Instance if tt.testFunc != nil { - instance = tt.testFunc(ctx, t) + instance = tt.testFunc(t, savepoint) } else { instance = &tt.instance } - instanceRepo := repository.InstanceRepository(pool) // create instance - beforeCreate := time.Now() - err := instanceRepo.Create(ctx, instance) + + err = instanceRepo.Create(t.Context(), tx, instance) assert.ErrorIs(t, err, tt.err) if err != nil { return @@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) { afterCreate := time.Now() // check instance values - instance, err = instanceRepo.Get(ctx, + instance, err = instanceRepo.Get(t.Context(), tx, database.WithCondition( instanceRepo.IDCondition(instance.ID), ), @@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) { } func TestUpdateInstance(t *testing.T) { + beforeUpdate := time.Now() + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(context.Background()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance rowsAffected int64 getErr error }{ { name: "happy path", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance { inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.UUID(), + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) return &inst }, @@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) { }, { name: "update deleted instance", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance { inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.UUID(), + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) // delete instance - affectedRows, err := instanceRepo.Delete(ctx, + affectedRows, err := instanceRepo.Delete(t.Context(), tx, inst.ID, ) require.NoError(t, err) @@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) { }, { name: "update non existent instance", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceId := gofakeit.Name() - + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance { inst := domain.Instance{ - ID: instanceId, + ID: gofakeit.UUID(), } return &inst }, @@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - instanceRepo := repository.InstanceRepository(pool) + instance := tt.testFunc(t, tx) - instance := tt.testFunc(ctx, t) - - beforeUpdate := time.Now() // update name newName := "new_" + instance.Name - rowsAffected, err := instanceRepo.Update(ctx, + rowsAffected, err := instanceRepo.Update(t.Context(), tx, instance.ID, instanceRepo.SetName(newName), ) @@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) { } // check instance values - instance, err = instanceRepo.Get(ctx, + instance, err = instanceRepo.Get(t.Context(), tx, database.WithCondition( instanceRepo.IDCondition(instance.ID), ), @@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) { } func TestGetInstance(t *testing.T) { - instanceRepo := repository.InstanceRepository(pool) + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(context.Background()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + domainRepo := repository.InstanceDomainRepository() + type test struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Instance + testFunc func(t *testing.T) *domain.Instance err error } - tests := []test{ func() test { - instanceId := gofakeit.Name() return test{ name: "happy path get using id", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceName := gofakeit.Name() - + testFunc: func(t *testing.T) *domain.Instance { inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.UUID(), + Name: gofakeit.BeerName(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) return &inst }, @@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) { }(), { name: "happy path including domains", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - + testFunc: func(t *testing.T) *domain.Instance { inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.NewCrypto().UUID(), + Name: gofakeit.BeerName(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) - domainRepo := instanceRepo.Domains(false) d := &domain.AddInstanceDomain{ InstanceID: inst.ID, Domain: gofakeit.DomainName(), @@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) { IsGenerated: gu.Ptr(false), Type: domain.DomainTypeCustom, } - err = domainRepo.Add(ctx, d) + err = domainRepo.Add(t.Context(), tx, d) require.NoError(t, err) inst.Domains = append(inst.Domains, &domain.InstanceDomain{ @@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) { }, { name: "get non existent instance", - testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + testFunc: func(t *testing.T) *domain.Instance { inst := domain.Instance{ ID: "get non existent instance", } @@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - instanceRepo := repository.InstanceRepository(pool) - var instance *domain.Instance if tt.testFunc != nil { - instance = tt.testFunc(ctx, t) + instance = tt.testFunc(t) } // check instance values - returnedInstance, err := instanceRepo.Get(ctx, + returnedInstance, err := instanceRepo.Get(t.Context(), tx, database.WithCondition( instanceRepo.IDCondition(instance.ID), ), @@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) { } func TestListInstance(t *testing.T) { - ctx := context.Background() - pool, stop, err := newEmbeddedDB(ctx) + tx, err := pool.Begin(context.Background(), nil) require.NoError(t, err) - defer stop() + defer func() { + err := tx.Rollback(context.Background()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() type test struct { name string - testFunc func(ctx context.Context, t *testing.T) []*domain.Instance + testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance conditionClauses []database.Condition noInstanceReturned bool } tests := []test{ { name: "happy path single instance no filter", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { - instanceRepo := repository.InstanceRepository(pool) + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance { noOfInstances := 1 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { inst := domain.Instance{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) instances[i] = &inst @@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) { }, { name: "happy path multiple instance no filter", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { - instanceRepo := repository.InstanceRepository(pool) + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance { noOfInstances := 5 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { inst := domain.Instance{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) instances[i] = &inst @@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) { }, }, func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() + instanceID := gofakeit.BeerName() return test{ name: "instance filter on id", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance { noOfInstances := 1 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { inst := domain.Instance{ - ID: instanceId, + ID: instanceID, Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) instances[i] = &inst @@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) { return instances }, - conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)}, + conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)}, } }(), func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceName := gofakeit.Name() + instanceName := gofakeit.BeerName() return test{ name: "multiple instance filter on name", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance { noOfInstances := 5 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { inst := domain.Instance{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: instanceName, DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) { } // create instance - err := instanceRepo.Create(ctx, &inst) + err := instanceRepo.Create(t.Context(), tx, &inst) require.NoError(t, err) instances[i] = &inst @@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Cleanup(func() { - _, err := pool.Exec(ctx, "DELETE FROM zitadel.instances") - require.NoError(t, err) - }) - - instances := tt.testFunc(ctx, t) - - instanceRepo := repository.InstanceRepository(pool) + savepoint, err := tx.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + instances := tt.testFunc(t, savepoint) var condition database.Condition if len(tt.conditionClauses) > 0 { @@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) { } // check instance values - returnedInstances, err := instanceRepo.List(ctx, + returnedInstances, err := instanceRepo.List(t.Context(), tx, database.WithCondition(condition), - database.WithOrderByAscending(instanceRepo.CreatedAtColumn()), + database.WithOrderByAscending(instanceRepo.IDColumn()), ) require.NoError(t, err) if tt.noInstanceReturned { - assert.Nil(t, returnedInstances) + assert.Len(t, returnedInstances, 0) return } @@ -609,42 +620,45 @@ func TestListInstance(t *testing.T) { } func TestDeleteInstance(t *testing.T) { + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(context.Background()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + type test struct { name string - testFunc func(ctx context.Context, t *testing.T) + testFunc func(t *testing.T, tx database.QueryExecutor) instanceID string noOfDeletedRows int64 } tests := []test{ func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceId := gofakeit.Name() - var noOfInstances int64 = 1 + instanceID := gofakeit.NewCrypto().UUID() return test{ name: "happy path delete single instance filter id", - testFunc: func(ctx context.Context, t *testing.T) { - instances := make([]*domain.Instance, noOfInstances) - for i := range noOfInstances { - - inst := domain.Instance{ - ID: instanceId, - Name: gofakeit.Name(), - DefaultOrgID: "defaultOrgId", - IAMProjectID: "iamProject", - ConsoleClientID: "consoleCLient", - ConsoleAppID: "consoleApp", - DefaultLanguage: "defaultLanguage", - } - - // create instance - err := instanceRepo.Create(ctx, &inst) - require.NoError(t, err) - - instances[i] = &inst + testFunc: func(t *testing.T, tx database.QueryExecutor) { + inst := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", } + + // create instance + err := instanceRepo.Create(t.Context(), tx, &inst) + require.NoError(t, err) }, - instanceID: instanceId, - noOfDeletedRows: noOfInstances, + instanceID: instanceID, + noOfDeletedRows: 1, } }(), func() test { @@ -655,40 +669,33 @@ func TestDeleteInstance(t *testing.T) { } }(), func() test { - instanceRepo := repository.InstanceRepository(pool) - instanceName := gofakeit.Name() + instanceID := gofakeit.Name() return test{ name: "deleted already deleted instance", - testFunc: func(ctx context.Context, t *testing.T) { - noOfInstances := 1 - instances := make([]*domain.Instance, noOfInstances) - for i := range noOfInstances { + testFunc: func(t *testing.T, tx database.QueryExecutor) { - inst := domain.Instance{ - ID: gofakeit.Name(), - Name: instanceName, - DefaultOrgID: "defaultOrgId", - IAMProjectID: "iamProject", - ConsoleClientID: "consoleCLient", - ConsoleAppID: "consoleApp", - DefaultLanguage: "defaultLanguage", - } - - // create instance - err := instanceRepo.Create(ctx, &inst) - require.NoError(t, err) - - instances[i] = &inst + inst := domain.Instance{ + ID: instanceID, + Name: gofakeit.BeerName(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", } + // create instance + err := instanceRepo.Create(t.Context(), tx, &inst) + require.NoError(t, err) + // delete instance - affectedRows, err := instanceRepo.Delete(ctx, - instances[0].ID, + affectedRows, err := instanceRepo.Delete(t.Context(), tx, + inst.ID, ) require.NoError(t, err) assert.Equal(t, int64(1), affectedRows) }, - instanceID: instanceName, + instanceID: instanceID, // this test should return 0 affected rows as the instance was already deleted noOfDeletedRows: 0, } @@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - instanceRepo := repository.InstanceRepository(pool) + savepoint, err := tx.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() if tt.testFunc != nil { - tt.testFunc(ctx, t) + tt.testFunc(t, savepoint) } // delete instance - noOfDeletedRows, err := instanceRepo.Delete(ctx, - tt.instanceID, - ) + noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID) require.NoError(t, err) assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) // check instance was deleted - instance, err := instanceRepo.Get(ctx, + instance, err := instanceRepo.Get(t.Context(), savepoint, database.WithCondition( instanceRepo.IDCondition(tt.instanceID), ), diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go index 85536c96dcb..515664879a0 100644 --- a/backend/v3/storage/database/repository/org.go +++ b/backend/v3/storage/database/repository/org.go @@ -14,25 +14,24 @@ import ( var _ domain.OrganizationRepository = (*org)(nil) type org struct { - repository shouldLoadDomains bool - domainRepo domain.OrganizationDomainRepository + domainRepo orgDomain } -func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository { - return &org{ - repository: repository{ - client: client, - }, - } +func (o org) unqualifiedTableName() string { + return "organizations" +} + +func OrganizationRepository() domain.OrganizationRepository { + return new(org) } const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` + - ` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` + + ` , jsonb_agg(json_build_object('instanceId', org_domains.instance_id, 'orgId', org_domains.org_id, 'domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` + ` FROM zitadel.organizations` // Get implements [domain.OrganizationRepository]. -func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) { +func (o org) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Organization, error) { opts = append(opts, o.joinDomains(), database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()), @@ -43,15 +42,19 @@ func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Or opt(options) } + if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return nil, database.NewMissingConditionError(o.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(queryOrganizationStmt) options.Write(&builder) - return scanOrganization(ctx, o.client, &builder) + return scanOrganization(ctx, client, &builder) } // List implements [domain.OrganizationRepository]. -func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) { +func (o org) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Organization, error) { opts = append(opts, o.joinDomains(), database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()), @@ -62,30 +65,15 @@ func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain opt(options) } + if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return nil, database.NewMissingConditionError(o.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(queryOrganizationStmt) options.Write(&builder) - return scanOrganizations(ctx, o.client, &builder) -} - -func (o *org) joinDomains() database.QueryOption { - columns := make([]database.Condition, 0, 3) - columns = append(columns, - database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()), - database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()), - ) - - // If domains should not be joined, we make sure to return null for the domain columns - // the query optimizer of the dialect should optimize this away if no domains are requested - if !o.shouldLoadDomains { - columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn())) - } - - return database.WithLeftJoin( - "zitadel.org_domains", - database.And(columns...), - ) + return scanOrganizations(ctx, client, &builder) } const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` + @@ -93,46 +81,48 @@ const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, ins ` RETURNING created_at, updated_at` // Create implements [domain.OrganizationRepository]. -func (o *org) Create(ctx context.Context, organization *domain.Organization) error { +func (o org) Create(ctx context.Context, client database.QueryExecutor, organization *domain.Organization) error { builder := database.StatementBuilder{} builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State) builder.WriteString(createOrganizationStmt) - return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt) + return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt) } // Update implements [domain.OrganizationRepository]. -func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) { +func (o org) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) { if len(changes) == 0 { return 0, database.ErrNoChanges } - builder := database.StatementBuilder{} + if !condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(o.InstanceIDColumn()) + } + if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) { + changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction)) + } + + var builder database.StatementBuilder builder.WriteString(`UPDATE zitadel.organizations SET `) - - instanceIDCondition := o.InstanceIDCondition(instanceID) - - conditions := []database.Condition{id, instanceIDCondition} database.Changes(changes).Write(&builder) - writeCondition(&builder, database.And(conditions...)) + writeCondition(&builder, condition) stmt := builder.String() - rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...) + rowsAffected, err := client.Exec(ctx, stmt, builder.Args()...) return rowsAffected, err } // Delete implements [domain.OrganizationRepository]. -func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) { - builder := database.StatementBuilder{} +func (o org) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) { + if !condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(o.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(`DELETE FROM zitadel.organizations`) + writeCondition(&builder, condition) - instanceIDCondition := o.InstanceIDCondition(instanceID) - - conditions := []database.Condition{id, instanceIDCondition} - writeCondition(&builder, database.And(conditions...)) - - return o.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- @@ -154,13 +144,13 @@ func (o org) SetState(state domain.OrgState) database.Change { // ------------------------------------------------------------- // IDCondition implements [domain.organizationConditions]. -func (o org) IDCondition(id string) domain.OrgIdentifierCondition { +func (o org) IDCondition(id string) database.Condition { return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id) } // NameCondition implements [domain.organizationConditions]. -func (o org) NameCondition(name string) domain.OrgIdentifierCondition { - return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name) +func (o org) NameCondition(op database.TextOperation, name string) database.Condition { + return database.NewTextCondition(o.NameColumn(), op, name) } // InstanceIDCondition implements [domain.organizationConditions]. @@ -173,38 +163,62 @@ func (o org) StateCondition(state domain.OrgState) database.Condition { return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String()) } +// ExistsDomain creates a correlated [database.Exists] condition on org_domains. +// Use this filter to make sure the organization returned contains a specific domain. +// Example usage: +// +// domainRepo := orgRepo.Domains(true) // ensure domains are loaded/aggregated +// org, _ := orgRepo.Get(ctx, +// database.WithCondition( +// database.And( +// orgRepo.InstanceIDCondition(instanceID), +// orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")), +// ), +// ), +// ) +func (o org) ExistsDomain(cond database.Condition) database.Condition { + return database.Exists( + o.domainRepo.qualifiedTableName(), + database.And( + database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()), + database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()), + cond, + ), + ) +} + // ------------------------------------------------------------- // columns // ------------------------------------------------------------- // IDColumn implements [domain.organizationColumns]. -func (org) IDColumn() database.Column { - return database.NewColumn("organizations", "id") +func (o org) IDColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "id") } // NameColumn implements [domain.organizationColumns]. -func (org) NameColumn() database.Column { - return database.NewColumn("organizations", "name") +func (o org) NameColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "name") } // InstanceIDColumn implements [domain.organizationColumns]. -func (org) InstanceIDColumn() database.Column { - return database.NewColumn("organizations", "instance_id") +func (o org) InstanceIDColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "instance_id") } // StateColumn implements [domain.organizationColumns]. -func (org) StateColumn() database.Column { - return database.NewColumn("organizations", "state") +func (o org) StateColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "state") } // CreatedAtColumn implements [domain.organizationColumns]. -func (org) CreatedAtColumn() database.Column { - return database.NewColumn("organizations", "created_at") +func (o org) CreatedAtColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "created_at") } // UpdatedAtColumn implements [domain.organizationColumns]. -func (org) UpdatedAtColumn() database.Column { - return database.NewColumn("organizations", "updated_at") +func (o org) UpdatedAtColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "updated_at") } // ------------------------------------------------------------- @@ -255,20 +269,27 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d // sub repositories // ------------------------------------------------------------- -// Domains implements [domain.OrganizationRepository]. -func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository { - if !o.shouldLoadDomains { - o.shouldLoadDomains = shouldLoad +func (o org) LoadDomains() domain.OrganizationRepository { + return &org{ + shouldLoadDomains: true, } - - if o.domainRepo != nil { - return o.domainRepo - } - - o.domainRepo = &orgDomain{ - repository: o.repository, - org: o, - } - - return o.domainRepo +} + +func (o org) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 3) + columns = append(columns, + database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()), + database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()), + ) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !o.shouldLoadDomains { + columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn())) + } + + return database.WithLeftJoin( + o.domainRepo.qualifiedTableName(), + database.And(columns...), + ) } diff --git a/backend/v3/storage/database/repository/org_domain.go b/backend/v3/storage/database/repository/org_domain.go index 5b2e91acf28..4d73647edf2 100644 --- a/backend/v3/storage/database/repository/org_domain.go +++ b/backend/v3/storage/database/repository/org_domain.go @@ -10,9 +10,18 @@ import ( var _ domain.OrganizationDomainRepository = (*orgDomain)(nil) -type orgDomain struct { - repository - *org +type orgDomain struct{} + +func OrganizationDomainRepository() domain.OrganizationDomainRepository { + return new(orgDomain) +} + +func (orgDomain) qualifiedTableName() string { + return "zitadel.org_domains" +} + +func (orgDomain) unqualifiedTableName() string { + return "org_domains" } // ------------------------------------------------------------- @@ -24,36 +33,44 @@ const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_veri // Get implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org. -func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) { +func (o orgDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.OrganizationDomain, error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) } + if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return nil, database.NewMissingConditionError(o.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(queryOrganizationDomainStmt) options.Write(&builder) - return scanOrganizationDomain(ctx, o.client, &builder) + return scanOrganizationDomain(ctx, client, &builder) } // List implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org. -func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) { +func (o orgDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) } + if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return nil, database.NewMissingConditionError(o.InstanceIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(queryOrganizationDomainStmt) options.Write(&builder) - return scanOrganizationDomains(ctx, o.client, &builder) + return scanOrganizationDomains(ctx, client, &builder) } // Add implements [domain.OrganizationDomainRepository]. -func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error { +func (o orgDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddOrganizationDomain) error { var ( builder database.StatementBuilder createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction @@ -69,33 +86,47 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt) builder.WriteString(`) RETURNING created_at, updated_at`) - return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) + return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) } // Update implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org. -func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { +func (o orgDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) { if len(changes) == 0 { return 0, database.ErrNoChanges } + if !condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(o.InstanceIDColumn()) + } + if !condition.IsRestrictingColumn(o.OrgIDColumn()) { + return 0, database.NewMissingConditionError(o.OrgIDColumn()) + } + if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) { + changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction)) + } var builder database.StatementBuilder - builder.WriteString(`UPDATE zitadel.org_domains SET `) database.Changes(changes).Write(&builder) writeCondition(&builder, condition) - return o.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // Remove implements [domain.OrganizationDomainRepository]. -func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { - var builder database.StatementBuilder +func (o orgDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) { + if !condition.IsRestrictingColumn(o.InstanceIDColumn()) { + return 0, database.NewMissingConditionError(o.InstanceIDColumn()) + } + if !condition.IsRestrictingColumn(o.OrgIDColumn()) { + return 0, database.NewMissingConditionError(o.OrgIDColumn()) + } + var builder database.StatementBuilder builder.WriteString(`DELETE FROM zitadel.org_domains `) writeCondition(&builder, condition) - return o.client.Exec(ctx, builder.String(), builder.Args()...) + return client.Exec(ctx, builder.String(), builder.Args()...) } // ------------------------------------------------------------- @@ -158,45 +189,43 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition { // CreatedAtColumn implements [domain.OrganizationDomainRepository]. // Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org. -func (orgDomain) CreatedAtColumn() database.Column { - return database.NewColumn("org_domains", "created_at") +func (o orgDomain) CreatedAtColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "created_at") } // DomainColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) DomainColumn() database.Column { - return database.NewColumn("org_domains", "domain") +func (o orgDomain) DomainColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "domain") } // InstanceIDColumn implements [domain.OrganizationDomainRepository]. -// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org. -func (orgDomain) InstanceIDColumn() database.Column { - return database.NewColumn("org_domains", "instance_id") +func (o orgDomain) InstanceIDColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "instance_id") } // IsPrimaryColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) IsPrimaryColumn() database.Column { - return database.NewColumn("org_domains", "is_primary") +func (o orgDomain) IsPrimaryColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "is_primary") } // IsVerifiedColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) IsVerifiedColumn() database.Column { - return database.NewColumn("org_domains", "is_verified") +func (o orgDomain) IsVerifiedColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "is_verified") } // OrgIDColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) OrgIDColumn() database.Column { - return database.NewColumn("org_domains", "org_id") +func (o orgDomain) OrgIDColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "org_id") } // UpdatedAtColumn implements [domain.OrganizationDomainRepository]. -// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org. -func (orgDomain) UpdatedAtColumn() database.Column { - return database.NewColumn("org_domains", "updated_at") +func (o orgDomain) UpdatedAtColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "updated_at") } // ValidationTypeColumn implements [domain.OrganizationDomainRepository]. -func (orgDomain) ValidationTypeColumn() database.Column { - return database.NewColumn("org_domains", "validation_type") +func (o orgDomain) ValidationTypeColumn() database.Column { + return database.NewColumn(o.unqualifiedTableName(), "validation_type") } // ------------------------------------------------------------- diff --git a/backend/v3/storage/database/repository/org_domain_test.go b/backend/v3/storage/database/repository/org_domain_test.go index a06d9eeee14..c0a108a2712 100644 --- a/backend/v3/storage/database/repository/org_domain_test.go +++ b/backend/v3/storage/database/repository/org_domain_test.go @@ -1,7 +1,6 @@ package repository_test import ( - "context" "testing" "github.com/brianvoe/gofakeit/v6" @@ -15,6 +14,15 @@ import ( ) func TestAddOrganizationDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + // create instance instanceID := gofakeit.UUID() instance := domain.Instance{ @@ -26,8 +34,11 @@ func TestAddOrganizationDomain(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + domainRepo := repository.OrganizationDomainRepository() + + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) // create organization @@ -41,7 +52,7 @@ func TestAddOrganizationDomain(t *testing.T) { tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain organizationDomain domain.AddOrganizationDomain err error }{ @@ -92,7 +103,7 @@ func TestAddOrganizationDomain(t *testing.T) { }, { name: "add domain with same domain twice", - testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain { + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain { domainName := gofakeit.DomainName() organizationDomain := &domain.AddOrganizationDomain{ @@ -104,7 +115,7 @@ func TestAddOrganizationDomain(t *testing.T) { ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), } - err := domainRepo.Add(ctx, organizationDomain) + err := domainRepo.Add(t.Context(), tx, organizationDomain) require.NoError(t, err) // return same domain again @@ -169,28 +180,26 @@ func TestAddOrganizationDomain(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - tx, err := pool.Begin(t.Context(), nil) + savepoint, err := tx.Begin(t.Context()) require.NoError(t, err) defer func() { - require.NoError(t, tx.Rollback(t.Context())) + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } }() - orgRepo := repository.OrganizationRepository(tx) - err = orgRepo.Create(t.Context(), &organization) + err = orgRepo.Create(t.Context(), savepoint, &organization) require.NoError(t, err) - domainRepo := orgRepo.Domains(false) - var organizationDomain *domain.AddOrganizationDomain if test.testFunc != nil { - organizationDomain = test.testFunc(ctx, t, domainRepo) + organizationDomain = test.testFunc(t, savepoint) } else { organizationDomain = &test.organizationDomain } - err = domainRepo.Add(ctx, organizationDomain) + err = domainRepo.Add(t.Context(), tx, organizationDomain) if test.err != nil { assert.ErrorIs(t, err, test.err) return @@ -204,6 +213,16 @@ func TestAddOrganizationDomain(t *testing.T) { } func TestGetOrganizationDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + domainRepo := repository.OrganizationDomainRepository() + // create instance instanceID := gofakeit.UUID() instance := domain.Instance{ @@ -225,29 +244,18 @@ func TestGetOrganizationDomain(t *testing.T) { State: domain.OrgStateActive, } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(tx) - err = orgRepo.Create(t.Context(), &organization) + err = orgRepo.Create(t.Context(), tx, &organization) require.NoError(t, err) // add domains - domainRepo := orgRepo.Domains(false) - domainName1 := gofakeit.DomainName() - domainName2 := gofakeit.DomainName() domain1 := &domain.AddOrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName1, + Domain: gofakeit.DomainName(), IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), @@ -255,15 +263,15 @@ func TestGetOrganizationDomain(t *testing.T) { domain2 := &domain.AddOrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName2, + Domain: gofakeit.DomainName(), IsVerified: false, IsPrimary: false, ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), } - err = domainRepo.Add(t.Context(), domain1) + err = domainRepo.Add(t.Context(), tx, domain1) require.NoError(t, err) - err = domainRepo.Add(t.Context(), domain2) + err = domainRepo.Add(t.Context(), tx, domain2) require.NoError(t, err) tests := []struct { @@ -275,12 +283,15 @@ func TestGetOrganizationDomain(t *testing.T) { { name: "get primary domain", opts: []database.QueryOption{ - database.WithCondition(domainRepo.IsPrimaryCondition(true)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instanceID), + domainRepo.IsPrimaryCondition(true), + )), }, expected: &domain.OrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName1, + Domain: domain1.Domain, IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), @@ -289,12 +300,15 @@ func TestGetOrganizationDomain(t *testing.T) { { name: "get by domain name", opts: []database.QueryOption{ - database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instanceID), + domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain), + )), }, expected: &domain.OrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName2, + Domain: domain2.Domain, IsVerified: false, IsPrimary: false, ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), @@ -303,13 +317,16 @@ func TestGetOrganizationDomain(t *testing.T) { { name: "get by org ID", opts: []database.QueryOption{ - database.WithCondition(domainRepo.OrgIDCondition(orgID)), - database.WithCondition(domainRepo.IsPrimaryCondition(true)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instanceID), + domainRepo.OrgIDCondition(orgID), + domainRepo.IsPrimaryCondition(true), + )), }, expected: &domain.OrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName1, + Domain: domain1.Domain, IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), @@ -318,12 +335,15 @@ func TestGetOrganizationDomain(t *testing.T) { { name: "get verified domain", opts: []database.QueryOption{ - database.WithCondition(domainRepo.IsVerifiedCondition(true)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instanceID), + domainRepo.IsVerifiedCondition(true), + )), }, expected: &domain.OrganizationDomain{ InstanceID: instanceID, OrgID: orgID, - Domain: domainName1, + Domain: domain1.Domain, IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), @@ -332,7 +352,10 @@ func TestGetOrganizationDomain(t *testing.T) { { name: "get non-existent domain", opts: []database.QueryOption{ - database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instanceID), + domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + )), }, err: new(database.NoRowFoundError), }, @@ -340,9 +363,7 @@ func TestGetOrganizationDomain(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - result, err := domainRepo.Get(ctx, test.opts...) + result, err := domainRepo.Get(t.Context(), tx, test.opts...) if test.err != nil { assert.ErrorIs(t, err, test.err) return @@ -362,10 +383,22 @@ func TestGetOrganizationDomain(t *testing.T) { } func TestListOrganizationDomains(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + domainRepo := repository.OrganizationDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -375,50 +408,40 @@ func TestListOrganizationDomains(t *testing.T) { } // create organization - orgID := gofakeit.UUID() organization := domain.Organization{ - ID: orgID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), - InstanceID: instanceID, + InstanceID: instance.ID, State: domain.OrgStateActive, } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(tx) - err = orgRepo.Create(t.Context(), &organization) + err = orgRepo.Create(t.Context(), tx, &organization) require.NoError(t, err) // add multiple domains - domainRepo := orgRepo.Domains(false) domains := []domain.AddOrganizationDomain{ { - InstanceID: instanceID, - OrgID: orgID, + InstanceID: instance.ID, + OrgID: organization.ID, Domain: gofakeit.DomainName(), IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), }, { - InstanceID: instanceID, - OrgID: orgID, + InstanceID: instance.ID, + OrgID: organization.ID, Domain: gofakeit.DomainName(), IsVerified: false, IsPrimary: false, ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), }, { - InstanceID: instanceID, - OrgID: orgID, + InstanceID: instance.ID, + OrgID: organization.ID, Domain: gofakeit.DomainName(), IsVerified: true, IsPrimary: false, @@ -427,7 +450,7 @@ func TestListOrganizationDomains(t *testing.T) { } for i := range domains { - err = domainRepo.Add(t.Context(), &domains[i]) + err = domainRepo.Add(t.Context(), tx, &domains[i]) require.NoError(t, err) } @@ -437,42 +460,59 @@ func TestListOrganizationDomains(t *testing.T) { expectedCount int }{ { - name: "list all domains", - opts: []database.QueryOption{}, + name: "list all domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)), + }, expectedCount: 3, }, { name: "list verified domains", opts: []database.QueryOption{ - database.WithCondition(domainRepo.IsVerifiedCondition(true)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.IsVerifiedCondition(true), + )), }, expectedCount: 2, }, { name: "list primary domains", opts: []database.QueryOption{ - database.WithCondition(domainRepo.IsPrimaryCondition(true)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.IsPrimaryCondition(true), + )), }, expectedCount: 1, }, { name: "list by organization", opts: []database.QueryOption{ - database.WithCondition(domainRepo.OrgIDCondition(orgID)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + )), }, expectedCount: 3, }, { name: "list by instance", opts: []database.QueryOption{ - database.WithCondition(domainRepo.InstanceIDCondition(instanceID)), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.InstanceIDCondition(instance.ID), + )), }, expectedCount: 3, }, { name: "list non-existent organization", opts: []database.QueryOption{ - database.WithCondition(domainRepo.OrgIDCondition("non-existent")), + database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition("non-existent"), + )), }, expectedCount: 0, }, @@ -482,13 +522,13 @@ func TestListOrganizationDomains(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx := t.Context() - results, err := domainRepo.List(ctx, test.opts...) + results, err := domainRepo.List(ctx, tx, test.opts...) require.NoError(t, err) assert.Len(t, results, test.expectedCount) for _, result := range results { - assert.Equal(t, instanceID, result.InstanceID) - assert.Equal(t, orgID, result.OrgID) + assert.Equal(t, instance.ID, result.InstanceID) + assert.Equal(t, organization.ID, result.OrgID) assert.NotEmpty(t, result.Domain) assert.NotEmpty(t, result.CreatedAt) assert.NotEmpty(t, result.UpdatedAt) @@ -498,10 +538,22 @@ func TestListOrganizationDomains(t *testing.T) { } func TestUpdateOrganizationDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + domainRepo := repository.OrganizationDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -511,41 +563,30 @@ func TestUpdateOrganizationDomain(t *testing.T) { } // create organization - orgID := gofakeit.UUID() organization := domain.Organization{ - ID: orgID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), - InstanceID: instanceID, + InstanceID: instance.ID, State: domain.OrgStateActive, } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(tx) - err = orgRepo.Create(t.Context(), &organization) + err = orgRepo.Create(t.Context(), tx, &organization) require.NoError(t, err) // add domain - domainRepo := orgRepo.Domains(false) - domainName := gofakeit.DomainName() organizationDomain := &domain.AddOrganizationDomain{ - InstanceID: instanceID, - OrgID: orgID, - Domain: domainName, + InstanceID: instance.ID, + OrgID: organization.ID, + Domain: gofakeit.DomainName(), IsVerified: false, IsPrimary: false, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), } - err = domainRepo.Add(t.Context(), organizationDomain) + err = domainRepo.Add(t.Context(), tx, organizationDomain) require.NoError(t, err) tests := []struct { @@ -556,26 +597,42 @@ func TestUpdateOrganizationDomain(t *testing.T) { err error }{ { - name: "set verified", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{domainRepo.SetVerified()}, - expected: 1, + name: "set verified", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 1, }, { - name: "set primary", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{domainRepo.SetPrimary()}, - expected: 1, + name: "set primary", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 1, }, { - name: "set validation type", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)}, - expected: 1, + name: "set validation type", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), + changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)}, + expected: 1, }, { - name: "multiple changes", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + name: "multiple changes", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), changes: []database.Change{ domainRepo.SetVerified(), domainRepo.SetPrimary(), @@ -584,31 +641,41 @@ func TestUpdateOrganizationDomain(t *testing.T) { expected: 1, }, { - name: "update by org ID and domain", - condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)), - changes: []database.Change{domainRepo.SetVerified()}, - expected: 1, + name: "update by org ID and domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 1, }, { - name: "update non-existent domain", - condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), - changes: []database.Change{domainRepo.SetVerified()}, - expected: 0, + name: "update non-existent domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + ), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 0, }, { - name: "no changes", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), - changes: []database.Change{}, - expected: 0, - err: database.ErrNoChanges, + name: "no changes", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain), + ), + changes: []database.Change{}, + expected: 0, + err: database.ErrNoChanges, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...) + rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...) if test.err != nil { assert.ErrorIs(t, err, test.err) return @@ -619,7 +686,7 @@ func TestUpdateOrganizationDomain(t *testing.T) { // verify changes were applied if rows were affected if rowsAffected > 0 && len(test.changes) > 0 { - result, err := domainRepo.Get(ctx, database.WithCondition(test.condition)) + result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition)) require.NoError(t, err) // We know changes were applied since rowsAffected > 0 @@ -632,10 +699,22 @@ func TestUpdateOrganizationDomain(t *testing.T) { } func TestRemoveOrganizationDomain(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err = tx.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } + }() + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + domainRepo := repository.OrganizationDomainRepository() + // create instance - instanceID := gofakeit.UUID() instance := domain.Instance{ - ID: instanceID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -645,53 +724,40 @@ func TestRemoveOrganizationDomain(t *testing.T) { } // create organization - orgID := gofakeit.UUID() organization := domain.Organization{ - ID: orgID, + ID: gofakeit.UUID(), Name: gofakeit.Name(), - InstanceID: instanceID, + InstanceID: instance.ID, State: domain.OrgStateActive, } - tx, err := pool.Begin(t.Context(), nil) - require.NoError(t, err) - defer func() { - require.NoError(t, tx.Rollback(t.Context())) - }() - - instanceRepo := repository.InstanceRepository(tx) - err = instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) - orgRepo := repository.OrganizationRepository(tx) - err = orgRepo.Create(t.Context(), &organization) + err = orgRepo.Create(t.Context(), tx, &organization) require.NoError(t, err) // add domains - domainRepo := orgRepo.Domains(false) - domainName1 := gofakeit.DomainName() - domainName2 := gofakeit.DomainName() - domain1 := &domain.AddOrganizationDomain{ - InstanceID: instanceID, - OrgID: orgID, - Domain: domainName1, + InstanceID: instance.ID, + OrgID: organization.ID, + Domain: gofakeit.DomainName(), IsVerified: true, IsPrimary: true, ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), } domain2 := &domain.AddOrganizationDomain{ - InstanceID: instanceID, - OrgID: orgID, - Domain: domainName2, + InstanceID: instance.ID, + OrgID: organization.ID, + Domain: gofakeit.DomainName(), IsVerified: false, IsPrimary: false, ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), } - err = domainRepo.Add(t.Context(), domain1) + err = domainRepo.Add(t.Context(), tx, domain1) require.NoError(t, err) - err = domainRepo.Add(t.Context(), domain2) + err = domainRepo.Add(t.Context(), tx, domain2) require.NoError(t, err) tests := []struct { @@ -700,50 +766,70 @@ func TestRemoveOrganizationDomain(t *testing.T) { expected int64 }{ { - name: "remove by domain name", - condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1), - expected: 1, + name: "remove by domain name", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain), + ), + expected: 1, }, { - name: "remove by primary condition", - condition: domainRepo.IsPrimaryCondition(false), - expected: 1, // domain2 should still exist and be non-primary + name: "remove by primary condition", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.IsPrimaryCondition(false), + ), + expected: 1, // domain2 should still exist and be non-primary }, { - name: "remove by org ID and domain", - condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), - expected: 1, + name: "remove by org ID and domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain), + ), + expected: 1, }, { - name: "remove non-existent domain", - condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), - expected: 0, + name: "remove non-existent domain", + condition: database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + ), + expected: 0, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := t.Context() - - snapshot, err := tx.Begin(ctx) + snapshot, err := tx.Begin(t.Context()) require.NoError(t, err) defer func() { - require.NoError(t, snapshot.Rollback(ctx)) + err = snapshot.Rollback(t.Context()) + if err != nil { + t.Log("error during rollback:", err) + } }() - orgRepo := repository.OrganizationRepository(snapshot) - domainRepo := orgRepo.Domains(false) - // count before removal - beforeCount, err := domainRepo.List(ctx) + beforeCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + ))) require.NoError(t, err) - rowsAffected, err := domainRepo.Remove(ctx, test.condition) + rowsAffected, err := domainRepo.Remove(t.Context(), snapshot, test.condition) require.NoError(t, err) assert.Equal(t, test.expected, rowsAffected) // verify removal - afterCount, err := domainRepo.List(ctx) + afterCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And( + domainRepo.InstanceIDCondition(instance.ID), + domainRepo.OrgIDCondition(organization.ID), + ))) require.NoError(t, err) assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount)) }) @@ -751,8 +837,7 @@ func TestRemoveOrganizationDomain(t *testing.T) { } func TestOrganizationDomainConditions(t *testing.T) { - orgRepo := repository.OrganizationRepository(pool) - domainRepo := orgRepo.Domains(false) + domainRepo := repository.OrganizationDomainRepository() tests := []struct { name string @@ -811,8 +896,7 @@ func TestOrganizationDomainConditions(t *testing.T) { } func TestOrganizationDomainChanges(t *testing.T) { - orgRepo := repository.OrganizationRepository(pool) - domainRepo := orgRepo.Domains(false) + domainRepo := repository.OrganizationDomainRepository() tests := []struct { name string @@ -851,8 +935,7 @@ func TestOrganizationDomainChanges(t *testing.T) { } func TestOrganizationDomainColumns(t *testing.T) { - orgRepo := repository.OrganizationRepository(pool) - domainRepo := orgRepo.Domains(false) + domainRepo := repository.OrganizationDomainRepository() tests := []struct { name string diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go index baaa02cff9c..c624ffe6c1c 100644 --- a/backend/v3/storage/database/repository/org_test.go +++ b/backend/v3/storage/database/repository/org_test.go @@ -1,7 +1,7 @@ package repository_test import ( - "context" + "strconv" "testing" "time" @@ -15,6 +15,18 @@ import ( ) func TestCreateOrganization(t *testing.T) { + beforeCreate := time.Now() + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() + + instanceRepo := repository.InstanceRepository() + organizationRepo := repository.OrganizationRepository() // create instance instanceId := gofakeit.Name() instance := domain.Instance{ @@ -26,13 +38,12 @@ func TestCreateOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Organization + testFunc func(t *testing.T, client database.QueryExecutor) *domain.Organization organization domain.Organization err error }{ @@ -67,8 +78,7 @@ func TestCreateOrganization(t *testing.T) { }, { name: "adding org with same id twice", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { - organizationRepo := repository.OrganizationRepository(pool) + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization { organizationId := gofakeit.Name() organizationName := gofakeit.Name() @@ -79,7 +89,7 @@ func TestCreateOrganization(t *testing.T) { State: domain.OrgStateActive, } - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // change the name to make sure same only the id clashes org.Name = gofakeit.Name() @@ -89,8 +99,7 @@ func TestCreateOrganization(t *testing.T) { }, { name: "adding org with same name twice", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { - organizationRepo := repository.OrganizationRepository(pool) + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization { organizationId := gofakeit.Name() organizationName := gofakeit.Name() @@ -101,7 +110,7 @@ func TestCreateOrganization(t *testing.T) { State: domain.OrgStateActive, } - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // change the id to make sure same name+instance causes an error org.ID = gofakeit.Name() @@ -111,7 +120,7 @@ func TestCreateOrganization(t *testing.T) { }, func() struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Organization + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization organization domain.Organization err error } { @@ -120,12 +129,12 @@ func TestCreateOrganization(t *testing.T) { return struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Organization + testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization organization domain.Organization err error }{ name: "adding org with same name, different instance", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization { // create instance instId := gofakeit.Name() instance := domain.Instance{ @@ -137,12 +146,9 @@ func TestCreateOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(ctx, &instance) + err := instanceRepo.Create(t.Context(), tx, &instance) assert.Nil(t, err) - organizationRepo := repository.OrganizationRepository(pool) - org := domain.Organization{ ID: gofakeit.Name(), Name: organizationName, @@ -150,7 +156,7 @@ func TestCreateOrganization(t *testing.T) { State: domain.OrgStateActive, } - err = organizationRepo.Create(ctx, &org) + err = organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // change the id to make it unique @@ -214,19 +220,25 @@ func TestCreateOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() + savepoint, err := tx.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() var organization *domain.Organization if tt.testFunc != nil { - organization = tt.testFunc(ctx, t) + organization = tt.testFunc(t, savepoint) } else { organization = &tt.organization } - organizationRepo := repository.OrganizationRepository(pool) // create organization - beforeCreate := time.Now() - err = organizationRepo.Create(ctx, organization) + + err = organizationRepo.Create(t.Context(), savepoint, organization) assert.ErrorIs(t, err, tt.err) if err != nil { return @@ -234,7 +246,7 @@ func TestCreateOrganization(t *testing.T) { afterCreate := time.Now() // check organization values - organization, err = organizationRepo.Get(ctx, + organization, err = organizationRepo.Get(t.Context(), savepoint, database.WithCondition( database.And( organizationRepo.IDCondition(organization.ID), @@ -255,6 +267,19 @@ func TestCreateOrganization(t *testing.T) { } func TestUpdateOrganization(t *testing.T) { + beforeUpdate := time.Now() + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() + + instanceRepo := repository.InstanceRepository() + organizationRepo := repository.OrganizationRepository() + // create instance instanceId := gofakeit.Name() instance := domain.Instance{ @@ -266,20 +291,18 @@ func TestUpdateOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) - organizationRepo := repository.OrganizationRepository(pool) tests := []struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Organization + testFunc func(t *testing.T) *domain.Organization update []database.Change rowsAffected int64 }{ { name: "happy path update name", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationId := gofakeit.Name() organizationName := gofakeit.Name() @@ -291,7 +314,7 @@ func TestUpdateOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // update with updated value @@ -303,7 +326,7 @@ func TestUpdateOrganization(t *testing.T) { }, { name: "update deleted organization", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationId := gofakeit.Name() organizationName := gofakeit.Name() @@ -315,13 +338,15 @@ func TestUpdateOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // delete instance - _, err = organizationRepo.Delete(ctx, - organizationRepo.IDCondition(org.ID), - org.InstanceID, + _, err = organizationRepo.Delete(t.Context(), tx, + database.And( + organizationRepo.InstanceIDCondition(org.InstanceID), + organizationRepo.IDCondition(org.ID), + ), ) require.NoError(t, err) @@ -332,7 +357,7 @@ func TestUpdateOrganization(t *testing.T) { }, { name: "happy path change state", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationId := gofakeit.Name() organizationName := gofakeit.Name() @@ -344,7 +369,7 @@ func TestUpdateOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // update with updated value @@ -356,7 +381,7 @@ func TestUpdateOrganization(t *testing.T) { }, { name: "update non existent organization", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationId := gofakeit.Name() org := domain.Organization{ @@ -370,16 +395,14 @@ func TestUpdateOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - organizationRepo := repository.OrganizationRepository(pool) - - createdOrg := tt.testFunc(ctx, t) + createdOrg := tt.testFunc(t) // update org - beforeUpdate := time.Now() - rowsAffected, err := organizationRepo.Update(ctx, - organizationRepo.IDCondition(createdOrg.ID), - createdOrg.InstanceID, + rowsAffected, err := organizationRepo.Update(t.Context(), tx, + database.And( + organizationRepo.InstanceIDCondition(createdOrg.InstanceID), + organizationRepo.IDCondition(createdOrg.ID), + ), tt.update..., ) afterUpdate := time.Now() @@ -392,7 +415,7 @@ func TestUpdateOrganization(t *testing.T) { } // check organization values - organization, err := organizationRepo.Get(ctx, + organization, err := organizationRepo.Get(t.Context(), tx, database.WithCondition( database.And( organizationRepo.IDCondition(createdOrg.ID), @@ -411,6 +434,19 @@ func TestUpdateOrganization(t *testing.T) { } func TestGetOrganization(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + err := tx.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + orgDomainRepo := repository.OrganizationDomainRepository() + // create instance instanceId := gofakeit.Name() instance := domain.Instance{ @@ -422,11 +458,9 @@ func TestGetOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) - require.NoError(t, err) - orgRepo := repository.OrganizationRepository(pool) + err = instanceRepo.Create(t.Context(), tx, &instance) + require.NoError(t, err) // create organization // this org is created as an additional org which should NOT @@ -437,13 +471,13 @@ func TestGetOrganization(t *testing.T) { InstanceID: instanceId, State: domain.OrgStateActive, } - err = orgRepo.Create(t.Context(), &org) + err = orgRepo.Create(t.Context(), tx, &org) require.NoError(t, err) type test struct { name string - testFunc func(ctx context.Context, t *testing.T) *domain.Organization - orgIdentifierCondition domain.OrgIdentifierCondition + testFunc func(t *testing.T) *domain.Organization + orgIdentifierCondition database.Condition err error } @@ -452,7 +486,7 @@ func TestGetOrganization(t *testing.T) { organizationId := gofakeit.Name() return test{ name: "happy path get using id", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationName := gofakeit.Name() org := domain.Organization{ @@ -463,7 +497,7 @@ func TestGetOrganization(t *testing.T) { } // create organization - err := orgRepo.Create(ctx, &org) + err := orgRepo.Create(t.Context(), tx, &org) require.NoError(t, err) return &org @@ -475,7 +509,7 @@ func TestGetOrganization(t *testing.T) { organizationId := gofakeit.Name() return test{ name: "happy path get using id including domain", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationName := gofakeit.Name() org := domain.Organization{ @@ -486,7 +520,7 @@ func TestGetOrganization(t *testing.T) { } // create organization - err := orgRepo.Create(ctx, &org) + err := orgRepo.Create(t.Context(), tx, &org) require.NoError(t, err) d := &domain.AddOrganizationDomain{ @@ -496,7 +530,7 @@ func TestGetOrganization(t *testing.T) { IsVerified: true, IsPrimary: true, } - err = orgRepo.Domains(false).Add(ctx, d) + err = orgDomainRepo.Add(t.Context(), tx, d) require.NoError(t, err) org.Domains = []*domain.OrganizationDomain{ @@ -521,7 +555,7 @@ func TestGetOrganization(t *testing.T) { organizationName := gofakeit.Name() return test{ name: "happy path get using name", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { organizationId := gofakeit.Name() org := domain.Organization{ @@ -532,39 +566,36 @@ func TestGetOrganization(t *testing.T) { } // create organization - err := orgRepo.Create(ctx, &org) + err := orgRepo.Create(t.Context(), tx, &org) require.NoError(t, err) return &org }, - orgIdentifierCondition: orgRepo.NameCondition(organizationName), + orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, organizationName), } }(), { name: "get non existent organization", - testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + testFunc: func(t *testing.T) *domain.Organization { org := domain.Organization{ ID: "non existent org", Name: "non existent org", } return &org }, - orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"), + orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name"), err: new(database.NoRowFoundError), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - orgRepo := repository.OrganizationRepository(pool) - var org *domain.Organization if tt.testFunc != nil { - org = tt.testFunc(ctx, t) + org = tt.testFunc(t) } // get org values - returnedOrg, err := orgRepo.Get(ctx, + returnedOrg, err := orgRepo.Get(t.Context(), tx, database.WithCondition( database.And( tt.orgIdentifierCondition, @@ -592,11 +623,17 @@ func TestGetOrganization(t *testing.T) { } func TestListOrganization(t *testing.T) { - ctx := t.Context() - pool, stop, err := newEmbeddedDB(ctx) + tx, err := pool.Begin(t.Context(), nil) require.NoError(t, err) - defer stop() - organizationRepo := repository.OrganizationRepository(pool) + defer func() { + err := tx.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() + + instanceRepo := repository.InstanceRepository() + organizationRepo := repository.OrganizationRepository() // create instance instanceId := gofakeit.Name() @@ -609,33 +646,32 @@ func TestListOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err = instanceRepo.Create(ctx, &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) type test struct { name string - testFunc func(ctx context.Context, t *testing.T) []*domain.Organization + testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Organization conditionClauses []database.Condition noOrganizationReturned bool } tests := []test{ { name: "happy path single organization no filter", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization { noOfOrganizations := 1 organizations := make([]*domain.Organization, noOfOrganizations) for i := range noOfOrganizations { org := domain.Organization{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), InstanceID: instanceId, State: domain.OrgStateActive, } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -646,20 +682,20 @@ func TestListOrganization(t *testing.T) { }, { name: "happy path multiple organization no filter", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization { noOfOrganizations := 5 organizations := make([]*domain.Organization, noOfOrganizations) for i := range noOfOrganizations { org := domain.Organization{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), InstanceID: instanceId, State: domain.OrgStateActive, } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -672,7 +708,7 @@ func TestListOrganization(t *testing.T) { organizationId := gofakeit.Name() return test{ name: "organization filter on id", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization { // create organization // this org is created as an additional org which should NOT // be returned in the results of this test case @@ -682,7 +718,7 @@ func TestListOrganization(t *testing.T) { InstanceID: instanceId, State: domain.OrgStateActive, } - err = organizationRepo.Create(ctx, &org) + err = organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) noOfOrganizations := 1 @@ -697,7 +733,7 @@ func TestListOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -705,12 +741,15 @@ func TestListOrganization(t *testing.T) { return organizations }, - conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)}, + conditionClauses: []database.Condition{ + organizationRepo.InstanceIDCondition(instanceId), + organizationRepo.IDCondition(organizationId), + }, } }(), { name: "multiple organization filter on state", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization { // create organization // this org is created as an additional org which should NOT // be returned in the results of this test case @@ -720,7 +759,7 @@ func TestListOrganization(t *testing.T) { InstanceID: instanceId, State: domain.OrgStateActive, } - err = organizationRepo.Create(ctx, &org) + err = organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) noOfOrganizations := 5 @@ -728,14 +767,14 @@ func TestListOrganization(t *testing.T) { for i := range noOfOrganizations { org := domain.Organization{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), InstanceID: instanceId, State: domain.OrgStateInactive, } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -743,13 +782,16 @@ func TestListOrganization(t *testing.T) { return organizations }, - conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)}, + conditionClauses: []database.Condition{ + organizationRepo.InstanceIDCondition(instanceId), + organizationRepo.StateCondition(domain.OrgStateInactive), + }, }, func() test { instanceId_2 := gofakeit.Name() return test{ name: "multiple organization filter on instance", - testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization { // create instance 1 instanceId_1 := gofakeit.Name() instance := domain.Instance{ @@ -761,8 +803,7 @@ func TestListOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err = instanceRepo.Create(ctx, &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) assert.Nil(t, err) // create organization @@ -774,7 +815,7 @@ func TestListOrganization(t *testing.T) { InstanceID: instanceId_1, State: domain.OrgStateActive, } - err = organizationRepo.Create(ctx, &org) + err = organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) // create instance 2 @@ -787,7 +828,7 @@ func TestListOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - err = instanceRepo.Create(ctx, &instance_2) + err = instanceRepo.Create(t.Context(), tx, &instance_2) assert.Nil(t, err) noOfOrganizations := 5 @@ -795,14 +836,14 @@ func TestListOrganization(t *testing.T) { for i := range noOfOrganizations { org := domain.Organization{ - ID: gofakeit.Name(), + ID: strconv.Itoa(i), Name: gofakeit.Name(), InstanceID: instanceId_2, State: domain.OrgStateActive, } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -816,22 +857,25 @@ func TestListOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Cleanup(func() { - _, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations") - require.NoError(t, err) - }) + savepoint, err := tx.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err = savepoint.Rollback(t.Context()) + if err != nil { + t.Logf("error during rollback: %v", err) + } + }() + organizations := tt.testFunc(t, savepoint) - organizations := tt.testFunc(ctx, t) - - var condition database.Condition + condition := organizationRepo.InstanceIDCondition(instanceId) if len(tt.conditionClauses) > 0 { condition = database.And(tt.conditionClauses...) } // check organization values - returnedOrgs, err := organizationRepo.List(ctx, + returnedOrgs, err := organizationRepo.List(t.Context(), tx, database.WithCondition(condition), - database.WithOrderByAscending(organizationRepo.CreatedAtColumn()), + database.WithOrderByAscending(organizationRepo.CreatedAtColumn(), organizationRepo.IDColumn()), ) require.NoError(t, err) if tt.noOrganizationReturned { @@ -851,6 +895,15 @@ func TestListOrganization(t *testing.T) { } func TestDeleteOrganization(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = tx.Rollback(t.Context()) + }) + + instanceRepo := repository.InstanceRepository() + organizationRepo := repository.OrganizationRepository() + // create instance instanceId := gofakeit.Name() instance := domain.Instance{ @@ -862,24 +915,22 @@ func TestDeleteOrganization(t *testing.T) { ConsoleAppID: "consoleApp", DefaultLanguage: "defaultLanguage", } - instanceRepo := repository.InstanceRepository(pool) - err := instanceRepo.Create(t.Context(), &instance) + err = instanceRepo.Create(t.Context(), tx, &instance) require.NoError(t, err) type test struct { name string - testFunc func(ctx context.Context, t *testing.T) - orgIdentifierCondition domain.OrgIdentifierCondition + testFunc func(t *testing.T) + orgIdentifierCondition database.Condition noOfDeletedRows int64 } tests := []test{ func() test { - organizationRepo := repository.OrganizationRepository(pool) organizationId := gofakeit.Name() var noOfOrganizations int64 = 1 return test{ name: "happy path delete organization filter id", - testFunc: func(ctx context.Context, t *testing.T) { + testFunc: func(t *testing.T) { organizations := make([]*domain.Organization, noOfOrganizations) for i := range noOfOrganizations { @@ -891,7 +942,7 @@ func TestDeleteOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org @@ -902,12 +953,11 @@ func TestDeleteOrganization(t *testing.T) { } }(), func() test { - organizationRepo := repository.OrganizationRepository(pool) organizationName := gofakeit.Name() var noOfOrganizations int64 = 1 return test{ name: "happy path delete organization filter name", - testFunc: func(ctx context.Context, t *testing.T) { + testFunc: func(t *testing.T) { organizations := make([]*domain.Organization, noOfOrganizations) for i := range noOfOrganizations { @@ -919,30 +969,28 @@ func TestDeleteOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org } }, - orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName), noOfDeletedRows: noOfOrganizations, } }(), func() test { - organizationRepo := repository.OrganizationRepository(pool) non_existent_organization_name := gofakeit.Name() return test{ name: "delete non existent organization", - orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name), + orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, non_existent_organization_name), } }(), func() test { - organizationRepo := repository.OrganizationRepository(pool) organizationName := gofakeit.Name() return test{ name: "deleted already deleted organization", - testFunc: func(ctx context.Context, t *testing.T) { + testFunc: func(t *testing.T) { noOfOrganizations := 1 organizations := make([]*domain.Organization, noOfOrganizations) for i := range noOfOrganizations { @@ -955,21 +1003,23 @@ func TestDeleteOrganization(t *testing.T) { } // create organization - err := organizationRepo.Create(ctx, &org) + err := organizationRepo.Create(t.Context(), tx, &org) require.NoError(t, err) organizations[i] = &org } // delete organization - affectedRows, err := organizationRepo.Delete(ctx, - organizationRepo.NameCondition(organizationName), - organizations[0].InstanceID, + affectedRows, err := organizationRepo.Delete(t.Context(), tx, + database.And( + organizationRepo.InstanceIDCondition(organizations[0].InstanceID), + organizationRepo.NameCondition(database.TextOperationEqual, organizationName), + ), ) assert.Equal(t, int64(1), affectedRows) require.NoError(t, err) }, - orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName), // this test should return 0 affected rows as the org was already deleted noOfDeletedRows: 0, } @@ -977,23 +1027,22 @@ func TestDeleteOrganization(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - organizationRepo := repository.OrganizationRepository(pool) - if tt.testFunc != nil { - tt.testFunc(ctx, t) + tt.testFunc(t) } // delete organization - noOfDeletedRows, err := organizationRepo.Delete(ctx, - tt.orgIdentifierCondition, - instanceId, + noOfDeletedRows, err := organizationRepo.Delete(t.Context(), tx, + database.And( + organizationRepo.InstanceIDCondition(instanceId), + tt.orgIdentifierCondition, + ), ) require.NoError(t, err) assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) // check organization was deleted - organization, err := organizationRepo.Get(ctx, + organization, err := organizationRepo.Get(t.Context(), tx, database.WithCondition( database.And( tt.orgIdentifierCondition, @@ -1006,3 +1055,99 @@ func TestDeleteOrganization(t *testing.T) { }) } } + +func TestGetOrganizationWithSubResources(t *testing.T) { + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = tx.Rollback(t.Context()) + }) + + instanceRepo := repository.InstanceRepository() + orgRepo := repository.OrganizationRepository() + + // create instance + instanceId := gofakeit.Name() + err = instanceRepo.Create(t.Context(), tx, &domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + }) + require.NoError(t, err) + + // create organization + org := domain.Organization{ + ID: "1", + Name: "org-name", + InstanceID: instanceId, + State: domain.OrgStateActive, + } + err = orgRepo.Create(t.Context(), tx, &org) + require.NoError(t, err) + + err = orgRepo.Create(t.Context(), tx, &domain.Organization{ + ID: "without-sub-resources", + Name: "org-name-2", + InstanceID: instanceId, + State: domain.OrgStateActive, + }) + require.NoError(t, err) + + t.Run("domains", func(t *testing.T) { + domainRepo := repository.OrganizationDomainRepository() + + domain1 := &domain.AddOrganizationDomain{ + InstanceID: org.InstanceID, + OrgID: org.ID, + Domain: "domain1.com", + IsVerified: true, + IsPrimary: true, + } + err = domainRepo.Add(t.Context(), tx, domain1) + require.NoError(t, err) + + domain2 := &domain.AddOrganizationDomain{ + InstanceID: org.InstanceID, + OrgID: org.ID, + Domain: "domain2.com", + IsVerified: false, + IsPrimary: false, + } + err = domainRepo.Add(t.Context(), tx, domain2) + require.NoError(t, err) + + t.Run("org by domain", func(t *testing.T) { + orgRepo := orgRepo.LoadDomains() + + returnedOrg, err := orgRepo.Get(t.Context(), tx, + database.WithCondition( + database.And( + orgRepo.InstanceIDCondition(instanceId), + orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)), + ), + ), + ) + require.NoError(t, err) + assert.Equal(t, org.ID, returnedOrg.ID) + assert.Len(t, returnedOrg.Domains, 2) + }) + + t.Run("ensure org by domain works without LoadDomains", func(t *testing.T) { + returnedOrg, err := orgRepo.Get(t.Context(), tx, + database.WithCondition( + database.And( + orgRepo.InstanceIDCondition(instanceId), + orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)), + ), + ), + ) + require.NoError(t, err) + assert.Equal(t, org.ID, returnedOrg.ID) + assert.Len(t, returnedOrg.Domains, 0) + }) + }) +} diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go index c5b9ff81f09..af7d3b08752 100644 --- a/backend/v3/storage/database/repository/repository.go +++ b/backend/v3/storage/database/repository/repository.go @@ -4,10 +4,6 @@ import ( "github.com/zitadel/zitadel/backend/v3/storage/database" ) -type repository struct { - client database.QueryExecutor -} - func writeCondition( builder *database.StatementBuilder, condition database.Condition, diff --git a/backend/v3/storage/database/repository/user.go b/backend/v3/storage/database/repository/user.go index 5953af78572..8527767c443 100644 --- a/backend/v3/storage/database/repository/user.go +++ b/backend/v3/storage/database/repository/user.go @@ -12,16 +12,10 @@ const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_a ` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` + ` FROM users_view users` -type user struct { - repository -} +type user struct{} -func UserRepository(client database.QueryExecutor) domain.UserRepository { - return &user{ - repository: repository{ - client: client, - }, - } +func UserRepository() domain.UserRepository { + return new(user) } var _ domain.UserRepository = (*user)(nil) @@ -31,17 +25,17 @@ var _ domain.UserRepository = (*user)(nil) // ------------------------------------------------------------- // Human implements [domain.UserRepository]. -func (u *user) Human() domain.HumanRepository { +func (u user) Human() domain.HumanRepository { return &userHuman{user: u} } // Machine implements [domain.UserRepository]. -func (u *user) Machine() domain.MachineRepository { +func (u user) Machine() domain.MachineRepository { return &userMachine{user: u} } // List implements [domain.UserRepository]. -func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) { +func (u user) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (users []*domain.User, err error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) @@ -54,7 +48,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users [] options.WriteLimit(&builder) options.WriteOffset(&builder) - rows, err := u.client.Query(ctx, builder.String(), builder.Args()...) + rows, err := client.Query(ctx, builder.String(), builder.Args()...) if err != nil { return nil, err } @@ -79,7 +73,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users [] } // Get implements [domain.UserRepository]. -func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) { +func (u user) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.User, error) { options := new(database.QueryOpts) for _, opt := range opts { opt(options) @@ -92,7 +86,7 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U options.WriteLimit(&builder) options.WriteOffset(&builder) - return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...)) + return scanUser(client.QueryRow(ctx, builder.String(), builder.Args()...)) } const ( @@ -105,7 +99,7 @@ const ( ) // Create implements [domain.UserRepository]. -func (u *user) Create(ctx context.Context, user *domain.User) error { +func (u user) Create(ctx context.Context, client database.QueryExecutor, user *domain.User) error { builder := database.StatementBuilder{} builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type()) switch trait := user.Traits.(type) { @@ -116,15 +110,15 @@ func (u *user) Create(ctx context.Context, user *domain.User) error { builder.WriteString(createMachineStmt) builder.AppendArgs(trait.Description) } - return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt) + return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt) } // Delete implements [domain.UserRepository]. -func (u *user) Delete(ctx context.Context, condition database.Condition) error { +func (u user) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error { builder := database.StatementBuilder{} builder.WriteString("DELETE FROM users") writeCondition(&builder, condition) - _, err := u.client.Exec(ctx, builder.String(), builder.Args()...) + _, err := client.Exec(ctx, builder.String(), builder.Args()...) return err } diff --git a/backend/v3/storage/database/repository/user_human.go b/backend/v3/storage/database/repository/user_human.go index ae7643c53c8..6bacc497d11 100644 --- a/backend/v3/storage/database/repository/user_human.go +++ b/backend/v3/storage/database/repository/user_human.go @@ -13,7 +13,7 @@ import ( // ------------------------------------------------------------- type userHuman struct { - *user + user } var _ domain.HumanRepository = (*userHuman)(nil) @@ -21,14 +21,14 @@ var _ domain.HumanRepository = (*userHuman)(nil) const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h` // GetEmail implements [domain.HumanRepository]. -func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) { +func (u *userHuman) GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*domain.Email, error) { var email domain.Email builder := database.StatementBuilder{} builder.WriteString(userEmailQuery) writeCondition(&builder, condition) - err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan( + err := client.QueryRow(ctx, builder.String(), builder.Args()...).Scan( &email.Address, &email.VerifiedAt, ) @@ -39,7 +39,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) } // Update implements [domain.HumanRepository]. -func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { +func (h userHuman) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error { builder := database.StatementBuilder{} builder.WriteString(`UPDATE human_users SET `) database.Changes(changes).Write(&builder) @@ -47,7 +47,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha stmt := builder.String() - _, err := h.client.Exec(ctx, stmt, builder.Args()...) + _, err := client.Exec(ctx, stmt, builder.Args()...) return err } diff --git a/backend/v3/storage/database/repository/user_machine.go b/backend/v3/storage/database/repository/user_machine.go index 2bda09d0e5d..82fe87ba5c0 100644 --- a/backend/v3/storage/database/repository/user_machine.go +++ b/backend/v3/storage/database/repository/user_machine.go @@ -8,7 +8,7 @@ import ( ) type userMachine struct { - *user + user } var _ domain.MachineRepository = (*userMachine)(nil) @@ -18,14 +18,14 @@ var _ domain.MachineRepository = (*userMachine)(nil) // ------------------------------------------------------------- // Update implements [domain.MachineRepository]. -func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { +func (m userMachine) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error { builder := database.StatementBuilder{} builder.WriteString("UPDATE user_machines SET ") database.Changes(changes).Write(&builder) writeCondition(&builder, condition) m.writeReturning() - _, err := m.client.Exec(ctx, builder.String(), builder.Args()...) + _, err := client.Exec(ctx, builder.String(), builder.Args()...) return err } diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go index 2858feae434..77af573f5ee 100644 --- a/backend/v3/storage/database/statement.go +++ b/backend/v3/storage/database/statement.go @@ -1,6 +1,7 @@ package database import ( + "encoding/hex" "strconv" "strings" ) @@ -20,8 +21,16 @@ type StatementBuilder struct { existingArgs map[any]string } +type argWriter interface { + WriteArg(builder *StatementBuilder) +} + // WriteArgs adds the argument to the statement and writes the placeholder to the query. func (b *StatementBuilder) WriteArg(arg any) { + if writer, ok := arg.(argWriter); ok { + writer.WriteArg(b) + return + } b.WriteString(b.AppendArg(arg)) } @@ -41,7 +50,13 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { if b.existingArgs == nil { b.existingArgs = make(map[any]string) } - if placeholder, ok := b.existingArgs[arg]; ok { + // the key is used to work around the following panic: + // runtime error: hash of unhashable type []uint8 + key := arg + if argBytes, ok := arg.([]uint8); ok { + key = `\\bytes-` + hex.EncodeToString(argBytes) + } + if placeholder, ok := b.existingArgs[key]; ok { return placeholder } if instruction, ok := arg.(Instruction); ok { @@ -50,7 +65,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { b.args = append(b.args, arg) placeholder = "$" + strconv.Itoa(len(b.args)) - b.existingArgs[arg] = placeholder + b.existingArgs[key] = placeholder return placeholder } diff --git a/backend/v3/storage/database/statement_test.go b/backend/v3/storage/database/statement_test.go new file mode 100644 index 00000000000..147043e8c28 --- /dev/null +++ b/backend/v3/storage/database/statement_test.go @@ -0,0 +1,144 @@ +package database + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatementBuilder_AppendArg(t *testing.T) { + t.Run("same arg returns same placeholder", func(t *testing.T) { + var b StatementBuilder + placeholder1 := b.AppendArg("same") + placeholder2 := b.AppendArg("same") + assert.Equal(t, placeholder1, placeholder2) + assert.Len(t, b.Args(), 1) + assert.Len(t, b.existingArgs, 1) + }) + + t.Run("same arg different types", func(t *testing.T) { + var b StatementBuilder + placeholder1 := b.AppendArg("same") + placeholder2 := b.AppendArg([]byte("same")) + placeholder3 := b.AppendArg("same") + assert.NotEqual(t, placeholder1, placeholder2) + assert.Equal(t, placeholder1, placeholder3) + assert.Len(t, b.Args(), 2) + assert.Len(t, b.existingArgs, 2) + }) + + t.Run("Instruction args are always different", func(t *testing.T) { + var b StatementBuilder + placeholder1 := b.AppendArg(DefaultInstruction) + placeholder2 := b.AppendArg(DefaultInstruction) + assert.Equal(t, placeholder1, placeholder2) + assert.Len(t, b.Args(), 0) + assert.Len(t, b.existingArgs, 0) + }) +} + +func TestStatementBuilder_AppendArgs(t *testing.T) { + t.Run("same arg returns same placeholder", func(t *testing.T) { + var b StatementBuilder + b.AppendArgs("same", "same") + assert.Len(t, b.Args(), 1) + assert.Len(t, b.existingArgs, 1) + }) + + t.Run("same arg different types", func(t *testing.T) { + var b StatementBuilder + b.AppendArgs("same", []byte("same"), "same") + assert.Len(t, b.Args(), 2) + assert.Len(t, b.existingArgs, 2) + }) + + t.Run("Instruction args are always different", func(t *testing.T) { + var b StatementBuilder + b.AppendArgs(DefaultInstruction, DefaultInstruction) + assert.Len(t, b.Args(), 0) + assert.Len(t, b.existingArgs, 0) + }) +} + +func TestStatementBuilder_WriteArg(t *testing.T) { + for _, test := range []struct { + name string + arg any + wantSQL string + wantArg []any + }{ + { + name: "primitive arg", + arg: "test", + wantSQL: "$1", + wantArg: []any{"test"}, + }, + { + name: "argWriter arg", + arg: SHA256Value("wrapped"), + wantSQL: "SHA256($1)", + wantArg: []any{"wrapped"}, + }, + { + name: "Instruction arg", + arg: DefaultInstruction, + wantSQL: "DEFAULT", + wantArg: []any{}, + }, + } { + t.Run(test.name, func(t *testing.T) { + var b StatementBuilder + b.WriteArg(test.arg) + assert.Equal(t, test.wantSQL, b.String()) + require.Len(t, b.Args(), len(test.wantArg)) + for i := range test.wantArg { + assert.Equal(t, test.wantArg[i], b.Args()[i]) + } + }) + } +} + +func TestStatementBuilder_WriteArgs(t *testing.T) { + for _, test := range []struct { + name string + args []any + wantSQL string + wantArg []any + }{ + { + name: "primitive args", + args: []any{"test", 123, true, uint32(123)}, + wantSQL: "$1, $2, $3, $4", + wantArg: []any{"test", 123, true, uint32(123)}, + }, + { + name: "argWriter args", + args: []any{SHA256Value("wrapped"), LowerValue("ASDF")}, + wantSQL: "SHA256($1), LOWER($2)", + wantArg: []any{"wrapped", "ASDF"}, + }, + { + name: "Instruction args", + args: []any{DefaultInstruction, NowInstruction, NullInstruction}, + wantSQL: "DEFAULT, NOW(), NULL", + wantArg: []any{}, + }, + { + name: "mixed args", + args: []any{123, uint32(123), NowInstruction, NullInstruction, SHA256Value("wrapped"), LowerValue("ASDF")}, + wantSQL: "$1, $2, NOW(), NULL, SHA256($3), LOWER($4)", + wantArg: []any{123, uint32(123), "wrapped", "ASDF"}, + }, + } { + t.Run(test.name, func(t *testing.T) { + var b StatementBuilder + b.WriteArgs(test.args...) + assert.Equal(t, test.wantSQL, b.String()) + require.Len(t, b.Args(), len(test.wantArg)) + for i := range test.wantArg { + assert.Equal(t, test.wantArg[i], b.Args()[i]) + } + }) + } +} diff --git a/internal/query/projection/instance_domain_relational.go b/internal/query/projection/instance_domain_relational.go index 0deb4e82a45..80606fda677 100644 --- a/internal/query/projection/instance_domain_relational.go +++ b/internal/query/projection/instance_domain_relational.go @@ -66,7 +66,7 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event event if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex) } - return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{ + return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{ InstanceID: e.Aggregate().InstanceID, Domain: e.Domain, IsPrimary: gu.Ptr(false), @@ -88,25 +88,15 @@ func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event events if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) + domainRepo := repository.InstanceDomainRepository() - condition := database.And( - domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), - domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), - domainRepo.TypeCondition(domain.DomainTypeCustom), - ) - - _, err := domainRepo.Update(ctx, - condition, + _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx), + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + domainRepo.TypeCondition(domain.DomainTypeCustom), + ), domainRepo.SetPrimary(), - ) - if err != nil { - return err - } - // we need to split the update into two statements because multiple events can have the same creation date - // therefore we first do not set the updated_at timestamp - _, err = domainRepo.Update(ctx, - condition, domainRepo.SetUpdatedAt(e.CreationDate()), ) return err @@ -123,8 +113,8 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eve if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) - _, err := domainRepo.Remove(ctx, + domainRepo := repository.InstanceDomainRepository() + _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx), database.And( domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), @@ -145,7 +135,7 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event even if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex) } - return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{ + return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{ InstanceID: e.Aggregate().InstanceID, Domain: e.Domain, Type: domain.DomainTypeTrusted, @@ -165,8 +155,8 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event ev if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) - _, err := domainRepo.Remove(ctx, + domainRepo := repository.InstanceDomainRepository() + _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx), database.And( domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), diff --git a/internal/query/projection/instance_relational.go b/internal/query/projection/instance_relational.go index f79127bef86..3607d7986d6 100644 --- a/internal/query/projection/instance_relational.go +++ b/internal/query/projection/instance_relational.go @@ -74,7 +74,7 @@ func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Even if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{ + return repository.InstanceRepository().Create(ctx, v3_sql.SQLTx(tx), &domain.Instance{ ID: e.Aggregate().ID, Name: e.Name, CreatedAt: e.CreationDate(), @@ -93,8 +93,8 @@ func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Ev if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) - return p.updateInstance(ctx, event, repo, repo.SetName(e.Name)) + repo := repository.InstanceRepository() + return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetName(e.Name)) }), nil } @@ -108,7 +108,7 @@ func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Eve if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - _, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID) + _, err := repository.InstanceRepository().Delete(ctx, v3_sql.SQLTx(tx), e.Aggregate().ID) return err }), nil } @@ -124,8 +124,8 @@ func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Even if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) - return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID)) + repo := repository.InstanceRepository() + return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultOrg(e.OrgID)) }), nil } @@ -140,8 +140,8 @@ func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Even if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) - return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID)) + repo := repository.InstanceRepository() + return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetIAMProject(e.ProjectID)) }), nil } @@ -156,8 +156,8 @@ func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event) if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) - return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID)) + repo := repository.InstanceRepository() + return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID)) }), nil } @@ -172,18 +172,18 @@ func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) } - repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) - return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language)) + repo := repository.InstanceRepository() + return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultLanguage(e.Language)) }), nil } -func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error { - _, err := repo.Update(ctx, event.Aggregate().ID, changes...) +func (p *instanceRelationalProjection) updateInstance(ctx context.Context, tx database.Transaction, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error { + _, err := repo.Update(ctx, tx, event.Aggregate().ID, changes...) if err != nil { return err } - instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID))) + instance, err := repo.Get(ctx, tx, database.WithCondition(repo.IDCondition(event.Aggregate().ID))) if err != nil { return err } @@ -192,7 +192,7 @@ func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event } // we need to split the update into two statements because multiple events can have the same creation date // therefore we first do not set the updated_at timestamp - _, err = repo.Update(ctx, + _, err = repo.Update(ctx, tx, event.Aggregate().ID, repo.SetUpdatedAt(event.CreatedAt()), ) diff --git a/internal/query/projection/org_domain_relational.go b/internal/query/projection/org_domain_relational.go index 747a21a2cb0..2ac3dc44931 100644 --- a/internal/query/projection/org_domain_relational.go +++ b/internal/query/projection/org_domain_relational.go @@ -65,7 +65,7 @@ func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*ha if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex) } - return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{ + return repository.OrganizationDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddOrganizationDomain{ InstanceID: e.Aggregate().InstanceID, OrgID: e.Aggregate().ResourceOwner, Domain: e.Domain, @@ -85,23 +85,14 @@ func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event) if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) - condition := database.And( - domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), - domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), - domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), - ) - _, err := domainRepo.Update(ctx, - condition, + domainRepo := repository.OrganizationDomainRepository() + _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx), + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ), domainRepo.SetPrimary(), - ) - if err != nil { - return err - } - // we need to split the update into two statements because multiple events can have the same creation date - // therefore we first do not set the updated_at timestamp - _, err = domainRepo.Update(ctx, - condition, domainRepo.SetUpdatedAt(e.CreationDate()), ) return err @@ -118,8 +109,8 @@ func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (* if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) - _, err := domainRepo.Remove(ctx, + domainRepo := repository.OrganizationDomainRepository() + _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx), database.And( domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), @@ -149,24 +140,15 @@ func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) - condition := database.And( - domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), - domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), - domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), - ) + domainRepo := repository.OrganizationDomainRepository() - _, err := domainRepo.Update(ctx, - condition, + _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx), + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ), domainRepo.SetValidationType(validationType), - ) - if err != nil { - return err - } - // we need to split the update into two statements because multiple events can have the same creation date - // therefore we first do not set the updated_at timestamp - _, err = domainRepo.Update(ctx, - condition, domainRepo.SetUpdatedAt(e.CreationDate()), ) return err @@ -183,28 +165,17 @@ func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) ( if !ok { return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex) } - domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) + domainRepo := repository.OrganizationDomainRepository() - condition := database.And( - domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), - domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), - domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), - ) - - _, err := domainRepo.Update(ctx, - condition, + _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx), + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ), domainRepo.SetVerified(), domainRepo.SetUpdatedAt(e.CreationDate()), ) - if err != nil { - return err - } - // we need to split the update into two statements because multiple events can have the same creation date - // therefore we first do not set the updated_at timestamp - _, err = domainRepo.Update(ctx, - condition, - domainRepo.SetUpdatedAt(e.CreationDate()), - ) return err }), nil }