diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go
index 03dd40cf2ec..446331a320e 100644
--- a/backend/v3/domain/instance.go
+++ b/backend/v3/domain/instance.go
@@ -69,6 +69,8 @@ type instanceConditions interface {
IDCondition(instanceID string) database.Condition
// NameCondition returns a filter on the name field.
NameCondition(op database.TextOperation, name string) database.Condition
+ // ExistsDomain returns a filter on the instance domains.
+ ExistsDomain(cond database.Condition) database.Condition
}
// instanceChanges define all the changes for the instance table.
@@ -99,17 +101,16 @@ type InstanceRepository interface {
// Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository
- Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
- List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
+ Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Instance, error)
+ List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Instance, error)
- Create(ctx context.Context, instance *Instance) error
- Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
- Delete(ctx context.Context, id string) (int64, error)
+ Create(ctx context.Context, client database.QueryExecutor, instance *Instance) error
+ Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error)
+ Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error)
- // Domains returns the domain sub repository for the instance.
- // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
- // If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
- Domains(shouldLoad bool) InstanceDomainRepository
+ // LoadDomains loads the domains of the given instance.
+ // If it is called the [Instance].Domains field will be set on future calls to Get or List.
+ LoadDomains() InstanceRepository
}
type CreateInstance struct {
diff --git a/backend/v3/domain/instance_domain.go b/backend/v3/domain/instance_domain.go
index 4c6a71b2e9f..786be287ee7 100644
--- a/backend/v3/domain/instance_domain.go
+++ b/backend/v3/domain/instance_domain.go
@@ -65,15 +65,15 @@ type InstanceDomainRepository interface {
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
- Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
+ Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*InstanceDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
- List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
+ List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*InstanceDomain, error)
// Add adds a new domain to the instance.
- Add(ctx context.Context, domain *AddInstanceDomain) error
+ Add(ctx context.Context, client database.QueryExecutor, domain *AddInstanceDomain) error
// Update updates an existing domain in the instance.
- Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
+ Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the instance.
- Remove(ctx context.Context, condition database.Condition) (int64, error)
+ Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
}
diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go
index 570e96c39e9..32d6528ba86 100644
--- a/backend/v3/domain/organization.go
+++ b/backend/v3/domain/organization.go
@@ -26,13 +26,6 @@ type Organization struct {
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
}
-// OrgIdentifierCondition is used to help specify a single Organization,
-// it will either be used as the organization ID or organization name,
-// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
-type OrgIdentifierCondition interface {
- database.Condition
-}
-
// organizationColumns define all the columns of the instance table.
type organizationColumns interface {
// IDColumn returns the column for the id field.
@@ -52,13 +45,15 @@ type organizationColumns interface {
// organizationConditions define all the conditions for the instance table.
type organizationConditions interface {
// IDCondition returns an equal filter on the id field.
- IDCondition(instanceID string) OrgIdentifierCondition
+ IDCondition(instanceID string) database.Condition
// NameCondition returns a filter on the name field.
- NameCondition(name string) OrgIdentifierCondition
+ NameCondition(op database.TextOperation, name string) database.Condition
// InstanceIDCondition returns a filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// StateCondition returns a filter on the name field.
StateCondition(state OrgState) database.Condition
+ // ExistsDomain returns a filter on the organizations domains.
+ ExistsDomain(cond database.Condition) database.Condition
}
// organizationChanges define all the changes for the instance table.
@@ -75,17 +70,16 @@ type OrganizationRepository interface {
organizationConditions
organizationChanges
- Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
- List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
+ Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Organization, error)
+ List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Organization, error)
- Create(ctx context.Context, instance *Organization) error
- Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
- Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
+ Create(ctx context.Context, client database.QueryExecutor, org *Organization) error
+ Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
+ Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
- // Domains returns the domain sub repository for the organization.
- // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
- // If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
- Domains(shouldLoad bool) OrganizationDomainRepository
+ // LoadDomains loads the domains of the given organizations.
+ // If it is called the [Organization].Domains field will be set on future calls to Get or List.
+ LoadDomains() OrganizationRepository
}
type CreateOrganization struct {
diff --git a/backend/v3/domain/organization_domain.go b/backend/v3/domain/organization_domain.go
index c0868e3a620..1e1b3877c49 100644
--- a/backend/v3/domain/organization_domain.go
+++ b/backend/v3/domain/organization_domain.go
@@ -70,15 +70,15 @@ type OrganizationDomainRepository interface {
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
- Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
+ Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*OrganizationDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
- List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
+ List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*OrganizationDomain, error)
// Add adds a new domain to the organization.
- Add(ctx context.Context, domain *AddOrganizationDomain) error
+ Add(ctx context.Context, client database.QueryExecutor, domain *AddOrganizationDomain) error
// Update updates an existing domain in the organization.
- Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
+ Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the organization.
- Remove(ctx context.Context, condition database.Condition) (int64, error)
+ Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
}
diff --git a/backend/v3/domain/user.go b/backend/v3/domain/user.go
index fae0d75b6e7..333b5c96fbb 100644
--- a/backend/v3/domain/user.go
+++ b/backend/v3/domain/user.go
@@ -57,13 +57,13 @@ type UserRepository interface {
userConditions
userChanges
// Get returns a user based on the given condition.
- Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
+ Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*User, error)
// List returns a list of users based on the given condition.
- List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
+ List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*User, error)
// Create creates a new user.
- Create(ctx context.Context, user *User) error
+ Create(ctx context.Context, client database.QueryExecutor, user *User) error
// Delete removes users based on the given condition.
- Delete(ctx context.Context, condition database.Condition) error
+ Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error
// Human returns the [HumanRepository].
Human() HumanRepository
// Machine returns the [MachineRepository].
@@ -143,9 +143,9 @@ type HumanRepository interface {
humanChanges
// Get returns an email based on the given condition.
- GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
+ GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*Email, error)
// Update updates human users based on the given condition and changes.
- Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
+ Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
}
// machineColumns define all the columns of the machine table which inherits the user table.
@@ -172,7 +172,7 @@ type machineChanges interface {
// MachineRepository is the interface for the machine repository it inherits the user repository.
type MachineRepository interface {
// Update updates machine users based on the given condition and changes.
- Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
+ Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
machineColumns
machineConditions
diff --git a/backend/v3/storage/database/change.go b/backend/v3/storage/database/change.go
index 724e029dbe9..6581a2c4ce0 100644
--- a/backend/v3/storage/database/change.go
+++ b/backend/v3/storage/database/change.go
@@ -1,9 +1,14 @@
package database
+import "slices"
+
// Change represents a change to a column in a database table.
// Its written in the SET clause of an UPDATE statement.
type Change interface {
+ // Write writes the change to the given statement builder.
Write(builder *StatementBuilder)
+ // IsOnColumn checks if the change is on the given column.
+ IsOnColumn(col Column) bool
}
type change[V Value] struct {
@@ -13,6 +18,8 @@ type change[V Value] struct {
var _ Change = (*change[string])(nil)
+// NewChange creates a new Change for the given column and value.
+// If you want to set a column to NULL, use [NewChangePtr].
func NewChange[V Value](col Column, value V) Change {
return &change[V]{
column: col,
@@ -20,6 +27,8 @@ func NewChange[V Value](col Column, value V) Change {
}
}
+// NewChangePtr creates a new Change for the given column and value pointer.
+// If the value pointer is nil, the column will be set to NULL.
func NewChangePtr[V Value](col Column, value *V) Change {
if value == nil {
return NewChange(col, NullInstruction)
@@ -34,19 +43,31 @@ func (c change[V]) Write(builder *StatementBuilder) {
builder.WriteArg(c.value)
}
+// IsOnColumn implements [Change].
+func (c change[V]) IsOnColumn(col Column) bool {
+ return c.column.Equals(col)
+}
+
type Changes []Change
func NewChanges(cols ...Change) Change {
return Changes(cols)
}
+// IsOnColumn implements [Change].
+func (c Changes) IsOnColumn(col Column) bool {
+ return slices.ContainsFunc(c, func(change Change) bool {
+ return change.IsOnColumn(col)
+ })
+}
+
// Write implements [Change].
func (m Changes) Write(builder *StatementBuilder) {
- for i, col := range m {
+ for i, change := range m {
if i > 0 {
builder.WriteString(", ")
}
- col.Write(builder)
+ change.Write(builder)
}
}
diff --git a/backend/v3/storage/database/change_test.go b/backend/v3/storage/database/change_test.go
new file mode 100644
index 00000000000..8b2733ddcea
--- /dev/null
+++ b/backend/v3/storage/database/change_test.go
@@ -0,0 +1,68 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/muhlemmer/gu"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestChangeWrite(t *testing.T) {
+ type want struct {
+ stmt string
+ args []any
+ }
+ for _, test := range []struct {
+ name string
+ change Change
+ want want
+ }{
+ {
+ name: "change",
+ change: NewChange(NewColumn("table", "column"), "value"),
+ want: want{
+ stmt: "column = $1",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "change ptr to null",
+ change: NewChangePtr[int](NewColumn("table", "column"), nil),
+ want: want{
+ stmt: "column = NULL",
+ args: nil,
+ },
+ },
+ {
+ name: "change ptr to value",
+ change: NewChangePtr(NewColumn("table", "column"), gu.Ptr(42)),
+ want: want{
+ stmt: "column = $1",
+ args: []any{42},
+ },
+ },
+ {
+ name: "multiple changes",
+ change: NewChanges(
+ NewChange(NewColumn("table", "column1"), "value1"),
+ NewChangePtr[int](NewColumn("table", "column2"), nil),
+ NewChange(NewColumn("table", "column3"), 123),
+ ),
+ want: want{
+ stmt: "column1 = $1, column2 = NULL, column3 = $2",
+ args: []any{"value1", 123},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ var builder StatementBuilder
+ test.change.Write(&builder)
+ assert.Equal(t, test.want.stmt, builder.String())
+ require.Len(t, builder.Args(), len(test.want.args))
+ for i, arg := range test.want.args {
+ assert.Equal(t, arg, builder.Args()[i])
+ }
+ })
+ }
+}
diff --git a/backend/v3/storage/database/column.go b/backend/v3/storage/database/column.go
index 7f57637d38a..f1316048bbe 100644
--- a/backend/v3/storage/database/column.go
+++ b/backend/v3/storage/database/column.go
@@ -3,8 +3,9 @@ package database
type Columns []Column
// WriteQualified implements [Column].
-func (m Columns) WriteQualified(builder *StatementBuilder) {
- for i, col := range m {
+// Columns are separated by ", ".
+func (c Columns) WriteQualified(builder *StatementBuilder) {
+ for i, col := range c {
if i > 0 {
builder.WriteString(", ")
}
@@ -13,8 +14,9 @@ func (m Columns) WriteQualified(builder *StatementBuilder) {
}
// WriteUnqualified implements [Column].
-func (m Columns) WriteUnqualified(builder *StatementBuilder) {
- for i, col := range m {
+// Columns are separated by ", ".
+func (c Columns) WriteUnqualified(builder *StatementBuilder) {
+ for i, col := range c {
if i > 0 {
builder.WriteString(", ")
}
@@ -22,11 +24,33 @@ func (m Columns) WriteUnqualified(builder *StatementBuilder) {
}
}
+// Equals implements [Column].
+func (c Columns) Equals(col Column) bool {
+ if col == nil {
+ return c == nil
+ }
+ other, ok := col.(Columns)
+ if !ok || len(other) != len(c) {
+ return false
+ }
+ for i, col := range c {
+ if !col.Equals(other[i]) {
+ return false
+ }
+ }
+ return true
+}
+
+var _ Column = (Columns)(nil)
+
// Column represents a column in a database table.
type Column interface {
- // Write(builder *StatementBuilder)
+ // WriteQualified writes the column with the table name as prefix.
WriteQualified(builder *StatementBuilder)
+ // WriteUnqualified writes the column without the table name as prefix.
WriteUnqualified(builder *StatementBuilder)
+ // Equals checks if two columns are equal.
+ Equals(col Column) bool
}
type column struct {
@@ -35,7 +59,7 @@ type column struct {
}
func NewColumn(table, name string) Column {
- return column{table: table, name: name}
+ return &column{table: table, name: name}
}
// WriteQualified implements [Column].
@@ -51,35 +75,69 @@ func (c column) WriteUnqualified(builder *StatementBuilder) {
builder.WriteString(c.name)
}
-var _ Column = (*column)(nil)
+// Equals implements [Column].
+func (c *column) Equals(col Column) bool {
+ if col == nil {
+ return c == nil
+ }
+ toMatch, ok := col.(*column)
+ if !ok {
+ return false
+ }
+ return c.table == toMatch.table && c.name == toMatch.name
+}
-// // ignoreCaseColumn represents two database columns, one for the
-// // original value and one for the lower case value.
-// type ignoreCaseColumn interface {
-// Column
-// WriteIgnoreCase(builder *StatementBuilder)
-// }
+// LowerColumn returns a column that represents LOWER(col).
+func LowerColumn(col Column) Column {
+ return &functionColumn{fn: functionLower, col: col}
+}
-// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
-// return ignoreCaseCol{
-// column: col,
-// suffix: suffix,
-// }
-// }
+// SHA256Column returns a column that represents SHA256(col).
+func SHA256Column(col Column) Column {
+ return &functionColumn{fn: functionSHA256, col: col}
+}
-// type ignoreCaseCol struct {
-// column Column
-// suffix string
-// }
+type functionColumn struct {
+ fn function
+ col Column
+}
-// // WriteIgnoreCase implements [ignoreCaseColumn].
-// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
-// c.column.WriteQualified(builder)
-// builder.WriteString(c.suffix)
-// }
+type function string
-// // WriteQualified implements [ignoreCaseColumn].
-// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
-// c.column.WriteQualified(builder)
-// builder.WriteString(c.suffix)
-// }
+const (
+ _ function = ""
+ functionLower function = "LOWER"
+ functionSHA256 function = "SHA256"
+)
+
+// WriteQualified implements [Column].
+func (c functionColumn) WriteQualified(builder *StatementBuilder) {
+ builder.Grow(len(c.fn) + 2)
+ builder.WriteString(string(c.fn))
+ builder.WriteRune('(')
+ c.col.WriteQualified(builder)
+ builder.WriteRune(')')
+}
+
+// WriteUnqualified implements [Column].
+func (c functionColumn) WriteUnqualified(builder *StatementBuilder) {
+ builder.Grow(len(c.fn) + 2)
+ builder.WriteString(string(c.fn))
+ builder.WriteRune('(')
+ c.col.WriteUnqualified(builder)
+ builder.WriteRune(')')
+}
+
+// Equals implements [Column].
+func (c *functionColumn) Equals(col Column) bool {
+ if col == nil {
+ return c == nil
+ }
+ toMatch, ok := col.(*functionColumn)
+ if !ok || toMatch.fn != c.fn {
+ return false
+ }
+ return c.col.Equals(toMatch.col)
+}
+
+var _ Column = (*functionColumn)(nil)
diff --git a/backend/v3/storage/database/column_test.go b/backend/v3/storage/database/column_test.go
new file mode 100644
index 00000000000..216dddfb5e8
--- /dev/null
+++ b/backend/v3/storage/database/column_test.go
@@ -0,0 +1,212 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestWriteUnqualified(t *testing.T) {
+ for _, tests := range []struct {
+ name string
+ column Column
+ expected string
+ }{
+ {
+ name: "column",
+ column: NewColumn("table", "column"),
+ expected: "column",
+ },
+ {
+ name: "columns",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ expected: "column1, column2",
+ },
+ {
+ name: "function column",
+ column: SHA256Column(NewColumn("table", "column")),
+ expected: "SHA256(column)",
+ },
+ } {
+ t.Run(tests.name, func(t *testing.T) {
+ var builder StatementBuilder
+ tests.column.WriteUnqualified(&builder)
+ assert.Equal(t, tests.expected, builder.String())
+ })
+ }
+}
+
+func TestWriteQualified(t *testing.T) {
+ for _, tests := range []struct {
+ name string
+ column Column
+ expected string
+ }{
+ {
+ name: "column",
+ column: NewColumn("table", "column"),
+ expected: "table.column",
+ },
+ {
+ name: "columns",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ expected: "table.column1, table.column2",
+ },
+ {
+ name: "function column",
+ column: SHA256Column(NewColumn("table", "column")),
+ expected: "SHA256(table.column)",
+ },
+ } {
+ t.Run(tests.name, func(t *testing.T) {
+ var builder StatementBuilder
+ tests.column.WriteQualified(&builder)
+ assert.Equal(t, tests.expected, builder.String())
+ })
+ }
+}
+
+func TestEquals(t *testing.T) {
+ for _, tests := range []struct {
+ name string
+ column Column
+ toCheck Column
+ expected bool
+ }{
+ {
+ name: "column equal",
+ column: NewColumn("table", "column"),
+ toCheck: NewColumn("table", "column"),
+ expected: true,
+ },
+ {
+ name: "column nil check",
+ column: NewColumn("table", "column"),
+ toCheck: nil,
+ expected: false,
+ },
+ {
+ name: "column both nil",
+ column: (*column)(nil),
+ toCheck: nil,
+ expected: true,
+ },
+ {
+ name: "column not equal (different name)",
+ column: NewColumn("table", "column"),
+ toCheck: NewColumn("table", "column2"),
+ expected: false,
+ },
+ {
+ name: "column not equal (different type)",
+ column: NewColumn("table", "column"),
+ toCheck: SHA256Column(NewColumn("table", "column")),
+ expected: false,
+ },
+ {
+ name: "columns equal",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ toCheck: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ expected: true,
+ },
+ {
+ name: "columns nil check",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ toCheck: nil,
+ expected: false,
+ },
+ {
+ name: "columns both nil",
+ column: Columns(nil),
+ toCheck: nil,
+ expected: true,
+ },
+ {
+ name: "columns not equal (different type)",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ toCheck: NewColumn("table", "column1"),
+ expected: false,
+ },
+ {
+ name: "columns not equal (different length)",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ toCheck: Columns{
+ NewColumn("table", "column1"),
+ },
+ expected: false,
+ },
+ {
+ name: "columns not equal (different order)",
+ column: Columns{
+ NewColumn("table", "column1"),
+ NewColumn("table", "column2"),
+ },
+ toCheck: Columns{
+ NewColumn("table", "column2"),
+ NewColumn("table", "column1"),
+ },
+ expected: false,
+ },
+ {
+ name: "function column equal",
+ column: SHA256Column(NewColumn("table", "column")),
+ toCheck: SHA256Column(NewColumn("table", "column")),
+ expected: true,
+ },
+ {
+ name: "function nil check",
+ column: SHA256Column(NewColumn("table", "column")),
+ toCheck: nil,
+ expected: false,
+ },
+ {
+ name: "function both nil",
+ column: (*functionColumn)(nil),
+ toCheck: nil,
+ expected: true,
+ },
+ {
+ name: "function column not equal (different function)",
+ column: SHA256Column(NewColumn("table", "column")),
+ toCheck: LowerColumn(NewColumn("table", "column")),
+ expected: false,
+ },
+ {
+ name: "function column not equal (different inner column)",
+ column: SHA256Column(NewColumn("table", "column")),
+ toCheck: SHA256Column(NewColumn("table", "column2")),
+ expected: false,
+ },
+ {
+ name: "function column not equal (different type)",
+ column: SHA256Column(NewColumn("table", "column")),
+ toCheck: NewColumn("table", "column2"),
+ expected: false,
+ },
+ } {
+ t.Run(tests.name, func(t *testing.T) {
+ assert.Equal(t, tests.expected, tests.column.Equals(tests.toCheck))
+ })
+ }
+}
diff --git a/backend/v3/storage/database/condition.go b/backend/v3/storage/database/condition.go
index 5c8da5ff4b2..6bc1124f12e 100644
--- a/backend/v3/storage/database/condition.go
+++ b/backend/v3/storage/database/condition.go
@@ -4,6 +4,9 @@ package database
// Its written after the WHERE keyword in a SQL statement.
type Condition interface {
Write(builder *StatementBuilder)
+ // IsRestrictingColumn is used to check if the condition filters for a specific column.
+ // It acts as a save guard database operations that should be specific on the given column.
+ IsRestrictingColumn(col Column) bool
}
type and struct {
@@ -11,7 +14,7 @@ type and struct {
}
// Write implements [Condition].
-func (a *and) Write(builder *StatementBuilder) {
+func (a and) Write(builder *StatementBuilder) {
if len(a.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
@@ -29,6 +32,16 @@ func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
+// IsRestrictingColumn implements [Condition].
+func (a and) IsRestrictingColumn(col Column) bool {
+ for _, condition := range a.conditions {
+ if condition.IsRestrictingColumn(col) {
+ return true
+ }
+ }
+ return false
+}
+
var _ Condition = (*and)(nil)
type or struct {
@@ -36,7 +49,7 @@ type or struct {
}
// Write implements [Condition].
-func (o *or) Write(builder *StatementBuilder) {
+func (o or) Write(builder *StatementBuilder) {
if len(o.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
@@ -54,6 +67,17 @@ func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
+// IsRestrictingColumn implements [Condition].
+// It returns true only if all conditions are restricting the given column.
+func (o or) IsRestrictingColumn(col Column) bool {
+ for _, condition := range o.conditions {
+ if !condition.IsRestrictingColumn(col) {
+ return false
+ }
+ }
+ return true
+}
+
var _ Condition = (*or)(nil)
type isNull struct {
@@ -61,7 +85,7 @@ type isNull struct {
}
// Write implements [Condition].
-func (i *isNull) Write(builder *StatementBuilder) {
+func (i isNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NULL")
}
@@ -71,6 +95,12 @@ func IsNull(column Column) *isNull {
return &isNull{column: column}
}
+// IsRestrictingColumn implements [Condition].
+// It returns false because it cannot be used for restricting a column.
+func (i isNull) IsRestrictingColumn(col Column) bool {
+ return false
+}
+
var _ Condition = (*isNull)(nil)
type isNotNull struct {
@@ -78,7 +108,7 @@ type isNotNull struct {
}
// Write implements [Condition].
-func (i *isNotNull) Write(builder *StatementBuilder) {
+func (i isNotNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NOT NULL")
}
@@ -88,43 +118,122 @@ func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
+// IsRestrictingColumn implements [Condition].
+// It returns false because it cannot be used for restricting a column.
+func (i isNotNull) IsRestrictingColumn(col Column) bool {
+ return false
+}
+
var _ Condition = (*isNotNull)(nil)
-type valueCondition func(builder *StatementBuilder)
+type valueCondition struct {
+ write func(builder *StatementBuilder)
+ col Column
+}
// NewTextCondition creates a condition that compares a text column with a value.
-func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
- return valueCondition(func(builder *StatementBuilder) {
- writeTextOperation(builder, col, op, value)
- })
+// If you want to use ignore case operations, consider using [NewTextIgnoreCaseCondition].
+func NewTextCondition[T Text](col Column, op TextOperation, value T) Condition {
+ return valueCondition{
+ col: col,
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[T](builder, col, op, value)
+ },
+ }
+}
+
+// NewTextIgnoreCaseCondition creates a condition that compares a text column with a value, ignoring case by lowercasing both.
+func NewTextIgnoreCaseCondition[T Text](col Column, op TextOperation, value T) Condition {
+ return valueCondition{
+ col: col,
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[T](builder, LowerColumn(col), op, LowerValue(value))
+ },
+ }
}
// NewDateCondition creates a condition that compares a numeric column with a value.
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
- return valueCondition(func(builder *StatementBuilder) {
- writeNumberOperation(builder, col, op, value)
- })
+ return valueCondition{
+ col: col,
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[V](builder, col, op, value)
+ },
+ }
}
// NewDateCondition creates a condition that compares a boolean column with a value.
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
- return valueCondition(func(builder *StatementBuilder) {
- writeBooleanOperation(builder, col, value)
- })
+ return valueCondition{
+ col: col,
+ write: func(builder *StatementBuilder) {
+ writeBooleanOperation[V](builder, col, value)
+ },
+ }
+}
+
+// NewBytesCondition creates a condition that compares a BYTEA column with a value.
+func NewBytesCondition[V Bytes](col Column, op BytesOperation, value any) Condition {
+ return valueCondition{
+ col: col,
+ write: func(builder *StatementBuilder) {
+ writeBytesOperation[V](builder, col, op, value)
+ },
+ }
}
// NewColumnCondition creates a condition that compares two columns on equality.
func NewColumnCondition(col1, col2 Column) Condition {
- return valueCondition(func(builder *StatementBuilder) {
- col1.WriteQualified(builder)
- builder.WriteString(" = ")
- col2.WriteQualified(builder)
- })
+ return valueCondition{
+ col: col1,
+ write: func(builder *StatementBuilder) {
+ col1.WriteQualified(builder)
+ builder.WriteString(" = ")
+ col2.WriteQualified(builder)
+ },
+ }
}
// Write implements [Condition].
func (c valueCondition) Write(builder *StatementBuilder) {
- c(builder)
+ c.write(builder)
+}
+
+// IsRestrictingColumn implements [Condition].
+func (i valueCondition) IsRestrictingColumn(col Column) bool {
+ return i.col.Equals(col)
}
var _ Condition = (*valueCondition)(nil)
+
+// existsCondition is a helper to write an EXISTS (SELECT 1 FROM
WHERE ) clause.
+// It implements Condition so it can be composed with other conditions using And/Or.
+type existsCondition struct {
+ table string
+ condition Condition
+}
+
+// Exists creates a condition that checks for the existence of rows in a subquery.
+func Exists(table string, condition Condition) Condition {
+ return &existsCondition{
+ table: table,
+ condition: condition,
+ }
+}
+
+// Write implements [Condition].
+func (e existsCondition) Write(builder *StatementBuilder) {
+ builder.WriteString(" EXISTS (SELECT 1 FROM ")
+ builder.WriteString(e.table)
+ builder.WriteString(" WHERE ")
+ e.condition.Write(builder)
+ builder.WriteString(")")
+}
+
+// IsRestrictingColumn implements [Condition].
+func (e existsCondition) IsRestrictingColumn(col Column) bool {
+ // Forward to the inner condition so safety checks (like instance_id presence) can still work.
+ return e.condition.IsRestrictingColumn(col)
+}
+
+var _ Condition = (*existsCondition)(nil)
diff --git a/backend/v3/storage/database/condition_test.go b/backend/v3/storage/database/condition_test.go
new file mode 100644
index 00000000000..a15dcd8fd62
--- /dev/null
+++ b/backend/v3/storage/database/condition_test.go
@@ -0,0 +1,248 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWrite(t *testing.T) {
+ type want struct {
+ stmt string
+ args []any
+ }
+ for _, test := range []struct {
+ name string
+ cond Condition
+ want want
+ }{
+ {
+ name: "no and condition",
+ cond: And(),
+ want: want{
+ stmt: "",
+ args: nil,
+ },
+ },
+ {
+ name: "one and condition",
+ cond: And(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ ),
+ want: want{
+ stmt: "table.column1 = other_table.column2",
+ args: nil,
+ },
+ },
+ {
+ name: "multiple and condition",
+ cond: And(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
+ ),
+ want: want{
+ stmt: "(table.column1 = other_table.column2 AND table.column3 = other_table.column4)",
+ args: nil,
+ },
+ },
+ {
+ name: "no or condition",
+ cond: Or(),
+ want: want{
+ stmt: "",
+ args: nil,
+ },
+ },
+ {
+ name: "one or condition",
+ cond: Or(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ ),
+ want: want{
+ stmt: "table.column1 = other_table.column2",
+ args: nil,
+ },
+ },
+ {
+ name: "multiple or condition",
+ cond: Or(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
+ ),
+ want: want{
+ stmt: "(table.column1 = other_table.column2 OR table.column3 = other_table.column4)",
+ args: nil,
+ },
+ },
+ {
+ name: "is null condition",
+ cond: IsNull(NewColumn("table", "column1")),
+ want: want{
+ stmt: "table.column1 IS NULL",
+ args: nil,
+ },
+ },
+ {
+ name: "is not null condition",
+ cond: IsNotNull(NewColumn("table", "column1")),
+ want: want{
+ stmt: "table.column1 IS NOT NULL",
+ args: nil,
+ },
+ },
+ {
+ name: "text condition",
+ cond: NewTextCondition(NewColumn("table", "column1"), TextOperationEqual, "some text"),
+ want: want{
+ stmt: "table.column1 = $1",
+ args: []any{"some text"},
+ },
+ },
+ {
+ name: "text ignore case condition",
+ cond: NewTextIgnoreCaseCondition(NewColumn("table", "column1"), TextOperationNotEqual, "some TEXT"),
+ want: want{
+ stmt: "LOWER(table.column1) <> LOWER($1)",
+ args: []any{"some TEXT"},
+ },
+ },
+ {
+ name: "number condition",
+ cond: NewNumberCondition(NewColumn("table", "column1"), NumberOperationEqual, 42),
+ want: want{
+ stmt: "table.column1 = $1",
+ args: []any{42},
+ },
+ },
+ {
+ name: "boolean condition",
+ cond: NewBooleanCondition(NewColumn("table", "column1"), true),
+ want: want{
+ stmt: "table.column1 = $1",
+ args: []any{true},
+ },
+ },
+ {
+ name: "bytes condition",
+ cond: NewBytesCondition[[]byte](NewColumn("table", "column1"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}),
+ want: want{
+ stmt: "table.column1 = $1",
+ args: []any{[]byte{0x01, 0x02, 0x03}},
+ },
+ },
+ {
+ name: "column condition",
+ cond: NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ want: want{
+ stmt: "table.column1 = other_table.column2",
+ args: nil,
+ },
+ },
+ {
+ name: "exists condition",
+ cond: Exists("table", And(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
+ )),
+ want: want{
+ stmt: " EXISTS (SELECT 1 FROM table WHERE (table.column1 = other_table.column2 AND table.column3 = other_table.column4))",
+ args: nil,
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ var builder StatementBuilder
+ test.cond.Write(&builder)
+ assert.Equal(t, test.want.stmt, builder.String())
+ require.Len(t, builder.Args(), len(test.want.args))
+ for i, arg := range test.want.args {
+ assert.Equal(t, arg, builder.Args()[i])
+ }
+ })
+ }
+}
+
+func TestIsRestrictingColumn(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ col Column
+ cond Condition
+ want bool
+ }{
+ {
+ name: "and with restricting column",
+ col: NewColumn("table", "column1"),
+ cond: And(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
+ ),
+ want: true,
+ },
+ {
+ name: "and without restricting column",
+ col: NewColumn("table", "column1"),
+ cond: And(
+ NewColumnCondition(NewColumn("table", "column2"), NewColumn("other_table", "column3")),
+ IsNull(NewColumn("table", "column4")),
+ IsNotNull(NewColumn("table", "column5")),
+ ),
+ want: false,
+ },
+ {
+ name: "or with restricting column",
+ col: NewColumn("table", "column1"),
+ cond: Or(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
+ ),
+ want: true,
+ },
+ {
+ name: "or without restricting column",
+ col: NewColumn("table", "column1"),
+ cond: Or(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
+ IsNotNull(NewColumn("table", "column4")),
+ IsNull(NewColumn("table", "column5")),
+ ),
+ want: false,
+ },
+ {
+ name: "is null never restricts",
+ col: NewColumn("table", "column1"),
+ cond: IsNull(NewColumn("table", "column1")),
+ want: false,
+ },
+ {
+ name: "is not null never restricts",
+ col: NewColumn("table", "column1"),
+ cond: IsNotNull(NewColumn("table", "column1")),
+ want: false,
+ },
+ {
+ name: "exists with restricting column",
+ col: NewColumn("table", "column1"),
+ cond: Exists("table", And(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
+ )),
+ want: true,
+ },
+ {
+ name: "exists without restricting column",
+ col: NewColumn("table", "column1"),
+ cond: Exists("table", Or(
+ NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
+ IsNotNull(NewColumn("table", "column4")),
+ IsNull(NewColumn("table", "column5")),
+ )),
+ want: false,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ isRestricting := test.cond.IsRestrictingColumn(test.col)
+ assert.Equal(t, test.want, isRestricting)
+ })
+ }
+}
diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go
index 7cdeb9c0c3c..c241b42e589 100644
--- a/backend/v3/storage/database/database.go
+++ b/backend/v3/storage/database/database.go
@@ -10,7 +10,7 @@ type Pool interface {
QueryExecutor
Migrator
- Acquire(ctx context.Context) (Client, error)
+ Acquire(ctx context.Context) (Connection, error)
Close(ctx context.Context) error
Ping(ctx context.Context) error
@@ -22,8 +22,8 @@ type PoolTest interface {
MigrateTest(ctx context.Context) error
}
-// Client is a single database connection which can be released back to the pool.
-type Client interface {
+// Connection is a single database connection which can be released back to the pool.
+type Connection interface {
Beginner
QueryExecutor
Migrator
diff --git a/backend/v3/storage/database/dbmock/database.mock.go b/backend/v3/storage/database/dbmock/database.mock.go
index 1ff898257c0..2c8b176b7e1 100644
--- a/backend/v3/storage/database/dbmock/database.mock.go
+++ b/backend/v3/storage/database/dbmock/database.mock.go
@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction)
+// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Connection,Row,Rows,Transaction)
//
// Generated by this command:
//
-// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
+// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
//
// Package dbmock is a generated GoMock package.
@@ -21,6 +21,7 @@ import (
type MockPool struct {
ctrl *gomock.Controller
recorder *MockPoolMockRecorder
+ isgomock struct{}
}
// MockPoolMockRecorder is the mock recorder for MockPool.
@@ -41,18 +42,18 @@ func (m *MockPool) EXPECT() *MockPoolMockRecorder {
}
// Acquire mocks base method.
-func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) {
+func (m *MockPool) Acquire(ctx context.Context) (database.Connection, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Acquire", arg0)
- ret0, _ := ret[0].(database.Client)
+ ret := m.ctrl.Call(m, "Acquire", ctx)
+ ret0, _ := ret[0].(database.Connection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Acquire indicates an expected call of Acquire.
-func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall {
+func (mr *MockPoolMockRecorder) Acquire(ctx any) *MockPoolAcquireCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), ctx)
return &MockPoolAcquireCall{Call: call}
}
@@ -62,36 +63,36 @@ type MockPoolAcquireCall struct {
}
// Return rewrite *gomock.Call.Return
-func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall {
+func (c *MockPoolAcquireCall) Return(arg0 database.Connection, arg1 error) *MockPoolAcquireCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
+func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
+func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Begin mocks base method.
-func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
+func (m *MockPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Begin", arg0, arg1)
+ ret := m.ctrl.Call(m, "Begin", ctx, opts)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
-func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall {
+func (mr *MockPoolMockRecorder) Begin(ctx, opts any) *MockPoolBeginCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), ctx, opts)
return &MockPoolBeginCall{Call: call}
}
@@ -119,17 +120,17 @@ func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.Transa
}
// Close mocks base method.
-func (m *MockPool) Close(arg0 context.Context) error {
+func (m *MockPool) Close(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Close", arg0)
+ ret := m.ctrl.Call(m, "Close", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
-func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall {
+func (mr *MockPoolMockRecorder) Close(ctx any) *MockPoolCloseCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), ctx)
return &MockPoolCloseCall{Call: call}
}
@@ -157,10 +158,10 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
}
// Exec mocks base method.
-func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
+func (m *MockPool) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -170,9 +171,9 @@ func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64,
}
// Exec indicates an expected call of Exec.
-func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall {
+func (mr *MockPoolMockRecorder) Exec(ctx, stmt any, args ...any) *MockPoolExecCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...)
return &MockPoolExecCall{Call: call}
}
@@ -201,17 +202,17 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (
}
// Migrate mocks base method.
-func (m *MockPool) Migrate(arg0 context.Context) error {
+func (m *MockPool) Migrate(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Migrate", arg0)
+ ret := m.ctrl.Call(m, "Migrate", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Migrate indicates an expected call of Migrate.
-func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall {
+func (mr *MockPoolMockRecorder) Migrate(ctx any) *MockPoolMigrateCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), ctx)
return &MockPoolMigrateCall{Call: call}
}
@@ -238,11 +239,49 @@ func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPo
return c
}
-// Query mocks base method.
-func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
+// Ping mocks base method.
+func (m *MockPool) Ping(ctx context.Context) error {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ ret := m.ctrl.Call(m, "Ping", ctx)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Ping indicates an expected call of Ping.
+func (mr *MockPoolMockRecorder) Ping(ctx any) *MockPoolPingCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockPool)(nil).Ping), ctx)
+ return &MockPoolPingCall{Call: call}
+}
+
+// MockPoolPingCall wrap *gomock.Call
+type MockPoolPingCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockPoolPingCall) Return(arg0 error) *MockPoolPingCall {
+ c.Call = c.Call.Return(arg0)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockPoolPingCall) Do(f func(context.Context) error) *MockPoolPingCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockPoolPingCall) DoAndReturn(f func(context.Context) error) *MockPoolPingCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Query mocks base method.
+func (m *MockPool) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
+ m.ctrl.T.Helper()
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -252,9 +291,9 @@ func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (databa
}
// Query indicates an expected call of Query.
-func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall {
+func (mr *MockPoolMockRecorder) Query(ctx, stmt any, args ...any) *MockPoolQueryCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...)
return &MockPoolQueryCall{Call: call}
}
@@ -283,10 +322,10 @@ func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any)
}
// QueryRow mocks base method.
-func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
+func (m *MockPool) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -295,9 +334,9 @@ func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) data
}
// QueryRow indicates an expected call of QueryRow.
-func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall {
+func (mr *MockPoolMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockPoolQueryRowCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...)
return &MockPoolQueryRowCall{Call: call}
}
@@ -325,73 +364,74 @@ func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...an
return c
}
-// MockClient is a mock of Client interface.
-type MockClient struct {
+// MockConnection is a mock of Connection interface.
+type MockConnection struct {
ctrl *gomock.Controller
- recorder *MockClientMockRecorder
+ recorder *MockConnectionMockRecorder
+ isgomock struct{}
}
-// MockClientMockRecorder is the mock recorder for MockClient.
-type MockClientMockRecorder struct {
- mock *MockClient
+// MockConnectionMockRecorder is the mock recorder for MockConnection.
+type MockConnectionMockRecorder struct {
+ mock *MockConnection
}
-// NewMockClient creates a new mock instance.
-func NewMockClient(ctrl *gomock.Controller) *MockClient {
- mock := &MockClient{ctrl: ctrl}
- mock.recorder = &MockClientMockRecorder{mock}
+// NewMockConnection creates a new mock instance.
+func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
+ mock := &MockConnection{ctrl: ctrl}
+ mock.recorder = &MockConnectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
-func (m *MockClient) EXPECT() *MockClientMockRecorder {
+func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
return m.recorder
}
// Begin mocks base method.
-func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
+func (m *MockConnection) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Begin", arg0, arg1)
+ ret := m.ctrl.Call(m, "Begin", ctx, opts)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
-func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall {
+func (mr *MockConnectionMockRecorder) Begin(ctx, opts any) *MockConnectionBeginCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1)
- return &MockClientBeginCall{Call: call}
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockConnection)(nil).Begin), ctx, opts)
+ return &MockConnectionBeginCall{Call: call}
}
-// MockClientBeginCall wrap *gomock.Call
-type MockClientBeginCall struct {
+// MockConnectionBeginCall wrap *gomock.Call
+type MockConnectionBeginCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall {
+func (c *MockConnectionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockConnectionBeginCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
+func (c *MockConnectionBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
+func (c *MockConnectionBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Exec mocks base method.
-func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
+func (m *MockConnection) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -401,79 +441,117 @@ func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64
}
// Exec indicates an expected call of Exec.
-func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall {
+func (mr *MockConnectionMockRecorder) Exec(ctx, stmt any, args ...any) *MockConnectionExecCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...)
- return &MockClientExecCall{Call: call}
+ varargs := append([]any{ctx, stmt}, args...)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
+ return &MockConnectionExecCall{Call: call}
}
-// MockClientExecCall wrap *gomock.Call
-type MockClientExecCall struct {
+// MockConnectionExecCall wrap *gomock.Call
+type MockConnectionExecCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
+func (c *MockConnectionExecCall) Return(arg0 int64, arg1 error) *MockConnectionExecCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
+func (c *MockConnectionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
+func (c *MockConnectionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Migrate mocks base method.
-func (m *MockClient) Migrate(arg0 context.Context) error {
+func (m *MockConnection) Migrate(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Migrate", arg0)
+ ret := m.ctrl.Call(m, "Migrate", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Migrate indicates an expected call of Migrate.
-func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall {
+func (mr *MockConnectionMockRecorder) Migrate(ctx any) *MockConnectionMigrateCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0)
- return &MockClientMigrateCall{Call: call}
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockConnection)(nil).Migrate), ctx)
+ return &MockConnectionMigrateCall{Call: call}
}
-// MockClientMigrateCall wrap *gomock.Call
-type MockClientMigrateCall struct {
+// MockConnectionMigrateCall wrap *gomock.Call
+type MockConnectionMigrateCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall {
+func (c *MockConnectionMigrateCall) Return(arg0 error) *MockConnectionMigrateCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall {
+func (c *MockConnectionMigrateCall) Do(f func(context.Context) error) *MockConnectionMigrateCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall {
+func (c *MockConnectionMigrateCall) DoAndReturn(f func(context.Context) error) *MockConnectionMigrateCall {
+ c.Call = c.Call.DoAndReturn(f)
+ return c
+}
+
+// Ping mocks base method.
+func (m *MockConnection) Ping(ctx context.Context) error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Ping", ctx)
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Ping indicates an expected call of Ping.
+func (mr *MockConnectionMockRecorder) Ping(ctx any) *MockConnectionPingCall {
+ mr.mock.ctrl.T.Helper()
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockConnection)(nil).Ping), ctx)
+ return &MockConnectionPingCall{Call: call}
+}
+
+// MockConnectionPingCall wrap *gomock.Call
+type MockConnectionPingCall struct {
+ *gomock.Call
+}
+
+// Return rewrite *gomock.Call.Return
+func (c *MockConnectionPingCall) Return(arg0 error) *MockConnectionPingCall {
+ c.Call = c.Call.Return(arg0)
+ return c
+}
+
+// Do rewrite *gomock.Call.Do
+func (c *MockConnectionPingCall) Do(f func(context.Context) error) *MockConnectionPingCall {
+ c.Call = c.Call.Do(f)
+ return c
+}
+
+// DoAndReturn rewrite *gomock.Call.DoAndReturn
+func (c *MockConnectionPingCall) DoAndReturn(f func(context.Context) error) *MockConnectionPingCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Query mocks base method.
-func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
+func (m *MockConnection) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -483,41 +561,41 @@ func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (data
}
// Query indicates an expected call of Query.
-func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall {
+func (mr *MockConnectionMockRecorder) Query(ctx, stmt any, args ...any) *MockConnectionQueryCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...)
- return &MockClientQueryCall{Call: call}
+ varargs := append([]any{ctx, stmt}, args...)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockConnection)(nil).Query), varargs...)
+ return &MockConnectionQueryCall{Call: call}
}
-// MockClientQueryCall wrap *gomock.Call
-type MockClientQueryCall struct {
+// MockConnectionQueryCall wrap *gomock.Call
+type MockConnectionQueryCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall {
+func (c *MockConnectionQueryCall) Return(arg0 database.Rows, arg1 error) *MockConnectionQueryCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
+func (c *MockConnectionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
+func (c *MockConnectionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// QueryRow mocks base method.
-func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
+func (m *MockConnection) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -526,70 +604,70 @@ func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) da
}
// QueryRow indicates an expected call of QueryRow.
-func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall {
+func (mr *MockConnectionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockConnectionQueryRowCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...)
- return &MockClientQueryRowCall{Call: call}
+ varargs := append([]any{ctx, stmt}, args...)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockConnection)(nil).QueryRow), varargs...)
+ return &MockConnectionQueryRowCall{Call: call}
}
-// MockClientQueryRowCall wrap *gomock.Call
-type MockClientQueryRowCall struct {
+// MockConnectionQueryRowCall wrap *gomock.Call
+type MockConnectionQueryRowCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall {
+func (c *MockConnectionQueryRowCall) Return(arg0 database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
+func (c *MockConnectionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
+func (c *MockConnectionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Release mocks base method.
-func (m *MockClient) Release(arg0 context.Context) error {
+func (m *MockConnection) Release(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Release", arg0)
+ ret := m.ctrl.Call(m, "Release", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Release indicates an expected call of Release.
-func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall {
+func (mr *MockConnectionMockRecorder) Release(ctx any) *MockConnectionReleaseCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0)
- return &MockClientReleaseCall{Call: call}
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockConnection)(nil).Release), ctx)
+ return &MockConnectionReleaseCall{Call: call}
}
-// MockClientReleaseCall wrap *gomock.Call
-type MockClientReleaseCall struct {
+// MockConnectionReleaseCall wrap *gomock.Call
+type MockConnectionReleaseCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
-func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall {
+func (c *MockConnectionReleaseCall) Return(arg0 error) *MockConnectionReleaseCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
-func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall {
+func (c *MockConnectionReleaseCall) Do(f func(context.Context) error) *MockConnectionReleaseCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
-func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall {
+func (c *MockConnectionReleaseCall) DoAndReturn(f func(context.Context) error) *MockConnectionReleaseCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
@@ -598,6 +676,7 @@ func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *Mock
type MockRow struct {
ctrl *gomock.Controller
recorder *MockRowMockRecorder
+ isgomock struct{}
}
// MockRowMockRecorder is the mock recorder for MockRow.
@@ -618,10 +697,10 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder {
}
// Scan mocks base method.
-func (m *MockRow) Scan(arg0 ...any) error {
+func (m *MockRow) Scan(dest ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
- for _, a := range arg0 {
+ for _, a := range dest {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
@@ -630,9 +709,9 @@ func (m *MockRow) Scan(arg0 ...any) error {
}
// Scan indicates an expected call of Scan.
-func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall {
+func (mr *MockRowMockRecorder) Scan(dest ...any) *MockRowScanCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), dest...)
return &MockRowScanCall{Call: call}
}
@@ -663,6 +742,7 @@ func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall {
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
+ isgomock struct{}
}
// MockRowsMockRecorder is the mock recorder for MockRows.
@@ -797,10 +877,10 @@ func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall {
}
// Scan mocks base method.
-func (m *MockRows) Scan(arg0 ...any) error {
+func (m *MockRows) Scan(dest ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
- for _, a := range arg0 {
+ for _, a := range dest {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
@@ -809,9 +889,9 @@ func (m *MockRows) Scan(arg0 ...any) error {
}
// Scan indicates an expected call of Scan.
-func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall {
+func (mr *MockRowsMockRecorder) Scan(dest ...any) *MockRowsScanCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), dest...)
return &MockRowsScanCall{Call: call}
}
@@ -842,6 +922,7 @@ func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall {
type MockTransaction struct {
ctrl *gomock.Controller
recorder *MockTransactionMockRecorder
+ isgomock struct{}
}
// MockTransactionMockRecorder is the mock recorder for MockTransaction.
@@ -862,18 +943,18 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder {
}
// Begin mocks base method.
-func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) {
+func (m *MockTransaction) Begin(ctx context.Context) (database.Transaction, error) {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Begin", arg0)
+ ret := m.ctrl.Call(m, "Begin", ctx)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
-func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall {
+func (mr *MockTransactionMockRecorder) Begin(ctx any) *MockTransactionBeginCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), ctx)
return &MockTransactionBeginCall{Call: call}
}
@@ -901,17 +982,17 @@ func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database
}
// Commit mocks base method.
-func (m *MockTransaction) Commit(arg0 context.Context) error {
+func (m *MockTransaction) Commit(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Commit", arg0)
+ ret := m.ctrl.Call(m, "Commit", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
-func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall {
+func (mr *MockTransactionMockRecorder) Commit(ctx any) *MockTransactionCommitCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), ctx)
return &MockTransactionCommitCall{Call: call}
}
@@ -939,17 +1020,17 @@ func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *
}
// End mocks base method.
-func (m *MockTransaction) End(arg0 context.Context, arg1 error) error {
+func (m *MockTransaction) End(ctx context.Context, err error) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "End", arg0, arg1)
+ ret := m.ctrl.Call(m, "End", ctx, err)
ret0, _ := ret[0].(error)
return ret0
}
// End indicates an expected call of End.
-func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall {
+func (mr *MockTransactionMockRecorder) End(ctx, err any) *MockTransactionEndCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), ctx, err)
return &MockTransactionEndCall{Call: call}
}
@@ -977,10 +1058,10 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
}
// Exec mocks base method.
-func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
+func (m *MockTransaction) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -990,9 +1071,9 @@ func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (
}
// Exec indicates an expected call of Exec.
-func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall {
+func (mr *MockTransactionMockRecorder) Exec(ctx, stmt any, args ...any) *MockTransactionExecCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...)
return &MockTransactionExecCall{Call: call}
}
@@ -1021,10 +1102,10 @@ func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ..
}
// Query mocks base method.
-func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
+func (m *MockTransaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -1034,9 +1115,9 @@ func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any)
}
// Query indicates an expected call of Query.
-func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall {
+func (mr *MockTransactionMockRecorder) Query(ctx, stmt any, args ...any) *MockTransactionQueryCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...)
return &MockTransactionQueryCall{Call: call}
}
@@ -1065,10 +1146,10 @@ func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, .
}
// QueryRow mocks base method.
-func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
+func (m *MockTransaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
- varargs := []any{arg0, arg1}
- for _, a := range arg2 {
+ varargs := []any{ctx, stmt}
+ for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -1077,9 +1158,9 @@ func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...an
}
// QueryRow indicates an expected call of QueryRow.
-func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall {
+func (mr *MockTransactionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockTransactionQueryRowCall {
mr.mock.ctrl.T.Helper()
- varargs := append([]any{arg0, arg1}, arg2...)
+ varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...)
return &MockTransactionQueryRowCall{Call: call}
}
@@ -1108,17 +1189,17 @@ func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string
}
// Rollback mocks base method.
-func (m *MockTransaction) Rollback(arg0 context.Context) error {
+func (m *MockTransaction) Rollback(ctx context.Context) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Rollback", arg0)
+ ret := m.ctrl.Call(m, "Rollback", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
-func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall {
+func (mr *MockTransactionMockRecorder) Rollback(ctx any) *MockTransactionRollbackCall {
mr.mock.ctrl.T.Helper()
- call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0)
+ call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), ctx)
return &MockTransactionRollbackCall{Call: call}
}
diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go
index a8639692689..8d780feb417 100644
--- a/backend/v3/storage/database/dialect/postgres/conn.go
+++ b/backend/v3/storage/database/dialect/postgres/conn.go
@@ -13,15 +13,15 @@ type pgxConn struct {
*pgxpool.Conn
}
-var _ database.Client = (*pgxConn)(nil)
+var _ database.Connection = (*pgxConn)(nil)
-// Release implements [database.Client].
+// Release implements [database.Connection].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
-// Begin implements [database.Client].
+// Begin implements [database.Connection].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
@@ -30,8 +30,8 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions)
return &Transaction{tx}, nil
}
-// Query implements sql.Client.
-// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
+// Query implements [database.Connection].
+// Subtle: this method shadows the method (*Conn).Query of [pgxConn.Conn].
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
if err != nil {
@@ -40,14 +40,14 @@ func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.
return &Rows{rows}, nil
}
-// QueryRow implements sql.Client.
-// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
+// QueryRow implements [database.Connection].
+// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
}
-// Exec implements [database.Pool].
-// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
+// QueryRow implements [database.Connection].
+// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Conn.Exec(ctx, sql, args...)
if err != nil {
diff --git a/backend/v3/storage/database/dialect/postgres/error.go b/backend/v3/storage/database/dialect/postgres/error.go
index 89b3f8837a9..92fe2147af4 100644
--- a/backend/v3/storage/database/dialect/postgres/error.go
+++ b/backend/v3/storage/database/dialect/postgres/error.go
@@ -2,6 +2,7 @@ package postgres
import (
"errors"
+ "strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
@@ -16,6 +17,18 @@ func wrapError(err error) error {
if errors.Is(err, pgx.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
+ if errors.Is(err, pgx.ErrTooManyRows) {
+ return database.NewMultipleRowsFoundError(err)
+ }
+
+ // scany only exports its errors as strings
+ if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
+ return database.NewMultipleRowsFoundError(err)
+ }
+ if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
+ return database.NewScanError(err)
+ }
+
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)
diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go
new file mode 100644
index 00000000000..01dea1e38a8
--- /dev/null
+++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at.go
@@ -0,0 +1,16 @@
+package migration
+
+import (
+ _ "embed"
+)
+
+var (
+ //go:embed 004_correct_set_updated_at/up.sql
+ up004CorrectSetUpdatedAt string
+ //go:embed 004_correct_set_updated_at/down.sql
+ down004CorrectSetUpdatedAt string
+)
+
+func init() {
+ registerSQLMigration(4, up004CorrectSetUpdatedAt, down004CorrectSetUpdatedAt)
+}
diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql
new file mode 100644
index 00000000000..5e905f9fa98
--- /dev/null
+++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/down.sql
@@ -0,0 +1,23 @@
+CREATE OR REPLACE TRIGGER trigger_set_updated_at
+BEFORE UPDATE ON zitadel.instances
+FOR EACH ROW
+WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
+EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trigger_set_updated_at
+BEFORE UPDATE ON zitadel.organizations
+FOR EACH ROW
+WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
+EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
+ BEFORE UPDATE ON zitadel.instance_domains
+ FOR EACH ROW
+ WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
+ EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
+ BEFORE UPDATE ON zitadel.org_domains
+ FOR EACH ROW
+ WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
+ EXECUTE FUNCTION zitadel.set_updated_at();
\ No newline at end of file
diff --git a/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql
new file mode 100644
index 00000000000..6292197b6a9
--- /dev/null
+++ b/backend/v3/storage/database/dialect/postgres/migration/004_correct_set_updated_at/up.sql
@@ -0,0 +1,23 @@
+CREATE OR REPLACE TRIGGER trigger_set_updated_at
+BEFORE UPDATE ON zitadel.instances
+FOR EACH ROW
+WHEN (NEW.updated_at IS NULL)
+EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trigger_set_updated_at
+BEFORE UPDATE ON zitadel.organizations
+FOR EACH ROW
+WHEN (NEW.updated_at IS NULL)
+EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
+ BEFORE UPDATE ON zitadel.instance_domains
+ FOR EACH ROW
+ WHEN (NEW.updated_at IS NULL)
+ EXECUTE FUNCTION zitadel.set_updated_at();
+
+CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
+ BEFORE UPDATE ON zitadel.org_domains
+ FOR EACH ROW
+ WHEN (NEW.updated_at IS NULL)
+ EXECUTE FUNCTION zitadel.set_updated_at();
\ No newline at end of file
diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go
index 07cba119b5e..275be54e11e 100644
--- a/backend/v3/storage/database/dialect/postgres/pool.go
+++ b/backend/v3/storage/database/dialect/postgres/pool.go
@@ -22,8 +22,8 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool {
}
// Acquire implements [database.Pool].
-func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
- conn, err := c.Pool.Acquire(ctx)
+func (p *pgxPool) Acquire(ctx context.Context) (database.Connection, error) {
+ conn, err := p.Pool.Acquire(ctx)
if err != nil {
return nil, wrapError(err)
}
@@ -32,8 +32,8 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
-func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
- rows, err := c.Pool.Query(ctx, sql, args...)
+func (p *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
+ rows, err := p.Pool.Query(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
@@ -42,14 +42,14 @@ func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
-func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
- return &Row{c.Pool.QueryRow(ctx, sql, args...)}
+func (p *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
+ return &Row{p.Pool.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
-func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
- res, err := c.Pool.Exec(ctx, sql, args...)
+func (p *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
+ res, err := p.Pool.Exec(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
@@ -57,8 +57,8 @@ func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
}
// Begin implements [database.Pool].
-func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
- tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
+func (p *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
+ tx, err := p.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, wrapError(err)
}
@@ -66,23 +66,23 @@ func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions)
}
// Close implements [database.Pool].
-func (c *pgxPool) Close(_ context.Context) error {
- c.Pool.Close()
+func (p *pgxPool) Close(_ context.Context) error {
+ p.Pool.Close()
return nil
}
// Ping implements [database.Pool].
-func (c *pgxPool) Ping(ctx context.Context) error {
- return wrapError(c.Pool.Ping(ctx))
+func (p *pgxPool) Ping(ctx context.Context) error {
+ return wrapError(p.Pool.Ping(ctx))
}
// Migrate implements [database.Migrator].
-func (c *pgxPool) Migrate(ctx context.Context) error {
+func (p *pgxPool) Migrate(ctx context.Context) error {
if isMigrated {
return nil
}
- client, err := c.Pool.Acquire(ctx)
+ client, err := p.Pool.Acquire(ctx)
if err != nil {
return err
}
@@ -93,8 +93,8 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
}
// Migrate implements [database.PoolTest].
-func (c *pgxPool) MigrateTest(ctx context.Context) error {
- client, err := c.Pool.Acquire(ctx)
+func (p *pgxPool) MigrateTest(ctx context.Context) error {
+ client, err := p.Pool.Acquire(ctx)
if err != nil {
return err
}
diff --git a/backend/v3/storage/database/dialect/sql/conn.go b/backend/v3/storage/database/dialect/sql/conn.go
index 2a4e4c91b8b..9d9f5472df0 100644
--- a/backend/v3/storage/database/dialect/sql/conn.go
+++ b/backend/v3/storage/database/dialect/sql/conn.go
@@ -11,18 +11,18 @@ type sqlConn struct {
*sql.Conn
}
-func SQLConn(conn *sql.Conn) database.Client {
+func SQLConn(conn *sql.Conn) database.Connection {
return &sqlConn{Conn: conn}
}
-var _ database.Client = (*sqlConn)(nil)
+var _ database.Connection = (*sqlConn)(nil)
-// Release implements [database.Client].
+// Release implements [database.Connection].
func (c *sqlConn) Release(_ context.Context) error {
return c.Close()
}
-// Begin implements [database.Client].
+// Begin implements [database.Connection].
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {
diff --git a/backend/v3/storage/database/dialect/sql/error.go b/backend/v3/storage/database/dialect/sql/error.go
index f01b1e70e13..4879180a559 100644
--- a/backend/v3/storage/database/dialect/sql/error.go
+++ b/backend/v3/storage/database/dialect/sql/error.go
@@ -3,6 +3,7 @@ package sql
import (
"database/sql"
"errors"
+ "strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
@@ -19,6 +20,15 @@ func wrapError(err error) error {
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
+
+ // scany only exports its errors as strings
+ if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
+ return database.NewMultipleRowsFoundError(err)
+ }
+ if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
+ return database.NewScanError(err)
+ }
+
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)
diff --git a/backend/v3/storage/database/dialect/sql/pool.go b/backend/v3/storage/database/dialect/sql/pool.go
index 2d37520be78..521d1d1e12c 100644
--- a/backend/v3/storage/database/dialect/sql/pool.go
+++ b/backend/v3/storage/database/dialect/sql/pool.go
@@ -20,8 +20,8 @@ func SQLPool(db *sql.DB) *sqlPool {
}
// Acquire implements [database.Pool].
-func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
- conn, err := c.Conn(ctx)
+func (p *sqlPool) Acquire(ctx context.Context) (database.Connection, error) {
+ conn, err := p.Conn(ctx)
if err != nil {
return nil, wrapError(err)
}
@@ -30,9 +30,9 @@ func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
-func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
+func (p *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
//nolint:rowserrcheck // Rows.Close is called by the caller
- rows, err := c.QueryContext(ctx, sql, args...)
+ rows, err := p.QueryContext(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
@@ -41,14 +41,14 @@ func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
-func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
- return &Row{c.QueryRowContext(ctx, sql, args...)}
+func (p *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
+ return &Row{p.QueryRowContext(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
-func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
- res, err := c.ExecContext(ctx, sql, args...)
+func (p *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
+ res, err := p.ExecContext(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
@@ -56,8 +56,8 @@ func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
}
// Begin implements [database.Pool].
-func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
- tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
+func (p *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
+ tx, err := p.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {
return nil, wrapError(err)
}
@@ -65,16 +65,16 @@ func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions)
}
// Ping implements [database.Pool].
-func (c *sqlPool) Ping(ctx context.Context) error {
- return wrapError(c.PingContext(ctx))
+func (p *sqlPool) Ping(ctx context.Context) error {
+ return wrapError(p.PingContext(ctx))
}
// Close implements [database.Pool].
-func (c *sqlPool) Close(_ context.Context) error {
- return c.DB.Close()
+func (p *sqlPool) Close(_ context.Context) error {
+ return p.DB.Close()
}
// Migrate implements [database.Migrator].
-func (c *sqlPool) Migrate(ctx context.Context) error {
+func (p *sqlPool) Migrate(ctx context.Context) error {
return ErrMigrate
}
diff --git a/backend/v3/storage/database/errors.go b/backend/v3/storage/database/errors.go
index 418019fd304..de6cb794c3f 100644
--- a/backend/v3/storage/database/errors.go
+++ b/backend/v3/storage/database/errors.go
@@ -5,7 +5,34 @@ import (
"fmt"
)
-var ErrNoChanges = errors.New("update must contain a change")
+var (
+ ErrNoChanges = errors.New("update must contain a change")
+)
+
+type MissingConditionError struct {
+ col Column
+}
+
+func NewMissingConditionError(col Column) error {
+ return &MissingConditionError{
+ col: col,
+ }
+}
+
+func (e *MissingConditionError) Error() string {
+ var builder StatementBuilder
+ builder.WriteString("missing condition for column")
+ if e.col != nil {
+ builder.WriteString(" on ")
+ e.col.WriteQualified(&builder)
+ }
+ return builder.String()
+}
+
+func (e *MissingConditionError) Is(target error) bool {
+ _, ok := target.(*MissingConditionError)
+ return ok
+}
// NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context.
@@ -20,6 +47,9 @@ func NewNoRowFoundError(original error) error {
}
func (e *NoRowFoundError) Error() string {
+ if e.original != nil {
+ return fmt.Sprintf("no row found: %v", e.original)
+ }
return "no row found"
}
@@ -36,18 +66,19 @@ func (e *NoRowFoundError) Unwrap() error {
// It wraps the dialect specific original error to provide more context.
type MultipleRowsFoundError struct {
original error
- count int
}
-func NewMultipleRowsFoundError(original error, count int) error {
+func NewMultipleRowsFoundError(original error) error {
return &MultipleRowsFoundError{
original: original,
- count: count,
}
}
func (e *MultipleRowsFoundError) Error() string {
- return fmt.Sprintf("multiple rows found: %d", e.count)
+ if e.original != nil {
+ return fmt.Sprintf("multiple rows found: %v", e.original)
+ }
+ return "multiple rows found"
}
func (e *MultipleRowsFoundError) Is(target error) bool {
@@ -77,8 +108,8 @@ type IntegrityViolationError struct {
original error
}
-func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
- return &IntegrityViolationError{
+func newIntegrityViolationError(typ IntegrityType, table, constraint string, original error) IntegrityViolationError {
+ return IntegrityViolationError{
integrityType: typ,
table: table,
constraint: constraint,
@@ -87,7 +118,10 @@ func NewIntegrityViolationError(typ IntegrityType, table, constraint string, ori
}
func (e *IntegrityViolationError) Error() string {
- return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
+ if e.original != nil {
+ return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
+ }
+ return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q)", e.integrityType, e.table, e.constraint)
}
func (e *IntegrityViolationError) Is(target error) bool {
@@ -108,12 +142,7 @@ type CheckError struct {
func NewCheckError(table, constraint string, original error) error {
return &CheckError{
- IntegrityViolationError: IntegrityViolationError{
- integrityType: IntegrityTypeCheck,
- table: table,
- constraint: constraint,
- original: original,
- },
+ IntegrityViolationError: newIntegrityViolationError(IntegrityTypeCheck, table, constraint, original),
}
}
@@ -135,12 +164,7 @@ type UniqueError struct {
func NewUniqueError(table, constraint string, original error) error {
return &UniqueError{
- IntegrityViolationError: IntegrityViolationError{
- integrityType: IntegrityTypeUnique,
- table: table,
- constraint: constraint,
- original: original,
- },
+ IntegrityViolationError: newIntegrityViolationError(IntegrityTypeUnique, table, constraint, original),
}
}
@@ -162,12 +186,7 @@ type ForeignKeyError struct {
func NewForeignKeyError(table, constraint string, original error) error {
return &ForeignKeyError{
- IntegrityViolationError: IntegrityViolationError{
- integrityType: IntegrityTypeForeign,
- table: table,
- constraint: constraint,
- original: original,
- },
+ IntegrityViolationError: newIntegrityViolationError(IntegrityTypeForeign, table, constraint, original),
}
}
@@ -189,12 +208,7 @@ type NotNullError struct {
func NewNotNullError(table, constraint string, original error) error {
return &NotNullError{
- IntegrityViolationError: IntegrityViolationError{
- integrityType: IntegrityTypeNotNull,
- table: table,
- constraint: constraint,
- original: original,
- },
+ IntegrityViolationError: newIntegrityViolationError(IntegrityTypeNotNull, table, constraint, original),
}
}
@@ -207,6 +221,31 @@ func (e *NotNullError) Unwrap() error {
return &e.IntegrityViolationError
}
+// ScanError is returned when scanning rows into objects failed.
+// It wraps the original error to provide more context.
+type ScanError struct {
+ original error
+}
+
+func NewScanError(original error) error {
+ return &ScanError{
+ original: original,
+ }
+}
+
+func (e *ScanError) Error() string {
+ return fmt.Sprintf("Scan error: %v", e.original)
+}
+
+func (e *ScanError) Is(target error) bool {
+ _, ok := target.(*ScanError)
+ return ok
+}
+
+func (e *ScanError) Unwrap() error {
+ return e.original
+}
+
// UnknownError is returned when an unknown error occurs.
// It wraps the dialect specific original error to provide more context.
// It is used to indicate that an error occurred that does not fit into any of the other categories.
diff --git a/backend/v3/storage/database/errors_test.go b/backend/v3/storage/database/errors_test.go
new file mode 100644
index 00000000000..1d8e928c0f9
--- /dev/null
+++ b/backend/v3/storage/database/errors_test.go
@@ -0,0 +1,304 @@
+package database
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestError(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ err error
+ want string
+ }{
+ {
+ name: "missing condition without column",
+ err: NewMissingConditionError(nil),
+ want: "missing condition for column",
+ },
+ {
+ name: "missing condition with column",
+ err: NewMissingConditionError(NewColumn("table", "column")),
+ want: "missing condition for column on table.column",
+ },
+ {
+ name: "no row found without original error",
+ err: NewNoRowFoundError(nil),
+ want: "no row found",
+ },
+ {
+ name: "no row found with original error",
+ err: NewNoRowFoundError(errors.New("original error")),
+ want: "no row found: original error",
+ },
+ {
+ name: "multiple rows found without original error",
+ err: NewMultipleRowsFoundError(nil),
+ want: "multiple rows found",
+ },
+ {
+ name: "multiple rows found with original error",
+ err: NewMultipleRowsFoundError(errors.New("original error")),
+ want: "multiple rows found: original error",
+ },
+ {
+ name: "check violation without original error",
+ err: NewCheckError("table", "constraint", nil),
+ want: `integrity violation of type "check" on "table" (constraint: "constraint")`,
+ },
+ {
+ name: "check violation with original error",
+ err: NewCheckError("table", "constraint", errors.New("original error")),
+ want: `integrity violation of type "check" on "table" (constraint: "constraint"): original error`,
+ },
+ {
+ name: "unique violation without original error",
+ err: NewUniqueError("table", "constraint", nil),
+ want: `integrity violation of type "unique" on "table" (constraint: "constraint")`,
+ },
+ {
+ name: "unique violation with original error",
+ err: NewUniqueError("table", "constraint", errors.New("original error")),
+ want: `integrity violation of type "unique" on "table" (constraint: "constraint"): original error`,
+ },
+ {
+ name: "foreign key violation without original error",
+ err: NewForeignKeyError("table", "constraint", nil),
+ want: `integrity violation of type "foreign" on "table" (constraint: "constraint")`,
+ },
+ {
+ name: "foreign key violation with original error",
+ err: NewForeignKeyError("table", "constraint", errors.New("original error")),
+ want: `integrity violation of type "foreign" on "table" (constraint: "constraint"): original error`,
+ },
+ {
+ name: "not null violation without original error",
+ err: NewNotNullError("table", "constraint", nil),
+ want: `integrity violation of type "not null" on "table" (constraint: "constraint")`,
+ },
+ {
+ name: "not null violation with original error",
+ err: NewNotNullError("table", "constraint", errors.New("original error")),
+ want: `integrity violation of type "not null" on "table" (constraint: "constraint"): original error`,
+ },
+ {
+ name: "unknown error",
+ err: NewUnknownError(errors.New("original error")),
+ want: `unknown database error: original error`,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ assert.Equal(t, test.want, test.err.Error())
+ })
+ }
+}
+
+func TestUnwrap(t *testing.T) {
+ originalErr := errors.New("original error")
+ for _, test := range []struct {
+ name string
+ err error
+ want error
+ }{
+ {
+ name: "missing condition without column",
+ err: NewMissingConditionError(nil),
+ want: nil,
+ },
+ {
+ name: "missing condition with column",
+ err: NewMissingConditionError(NewColumn("table", "column")),
+ want: nil,
+ },
+ {
+ name: "no row found without original error",
+ err: NewNoRowFoundError(nil),
+ want: nil,
+ },
+ {
+ name: "no row found with original error",
+ err: NewNoRowFoundError(errors.New("original error")),
+ want: originalErr,
+ },
+ {
+ name: "multiple rows found without original error",
+ err: NewMultipleRowsFoundError(nil),
+ want: nil,
+ },
+ {
+ name: "multiple rows found with original error",
+ err: NewMultipleRowsFoundError(originalErr),
+ want: originalErr,
+ },
+ {
+ name: "check violation without original error",
+ err: NewCheckError("table", "constraint", nil),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeCheck,
+ table: "table",
+ constraint: "constraint",
+ original: nil,
+ },
+ },
+ {
+ name: "check violation with original error",
+ err: NewCheckError("table", "constraint", originalErr),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeCheck,
+ table: "table",
+ constraint: "constraint",
+ original: originalErr,
+ },
+ },
+ {
+ name: "unique violation without original error",
+ err: NewUniqueError("table", "constraint", nil),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeUnique,
+ table: "table",
+ constraint: "constraint",
+ original: nil,
+ },
+ },
+ {
+ name: "unique violation with original error",
+ err: NewUniqueError("table", "constraint", originalErr),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeUnique,
+ table: "table",
+ constraint: "constraint",
+ original: originalErr,
+ },
+ },
+ {
+ name: "foreign key violation without original error",
+ err: NewForeignKeyError("table", "constraint", nil),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeForeign,
+ table: "table",
+ constraint: "constraint",
+ original: nil,
+ },
+ },
+ {
+ name: "foreign key violation with original error",
+ err: NewForeignKeyError("table", "constraint", originalErr),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeForeign,
+ table: "table",
+ constraint: "constraint",
+ original: originalErr,
+ },
+ },
+ {
+ name: "not null violation without original error",
+ err: NewNotNullError("table", "constraint", nil),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeNotNull,
+ table: "table",
+ constraint: "constraint",
+ original: nil,
+ },
+ },
+ {
+ name: "not null violation with original error",
+ err: NewNotNullError("table", "constraint", originalErr),
+ want: &IntegrityViolationError{
+ integrityType: IntegrityTypeNotNull,
+ table: "table",
+ constraint: "constraint",
+ original: originalErr,
+ },
+ },
+ {
+ name: "unwrap integrity violation error",
+ err: errors.Unwrap(NewNotNullError("table", "constraint", originalErr)),
+ want: originalErr,
+ },
+ {
+ name: "unknown error",
+ err: NewUnknownError(originalErr),
+ want: originalErr,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ assert.Equal(t, test.want, errors.Unwrap(test.err))
+ })
+ }
+}
+
+func TestIs(t *testing.T) {
+ originalErr := errors.New("original error")
+ for _, test := range []struct {
+ name string
+ err error
+ want error
+ }{
+ {
+ name: "missing condition",
+ err: NewMissingConditionError(NewColumn("table", "column")),
+ want: new(MissingConditionError),
+ },
+ {
+ name: "no row found",
+ err: NewNoRowFoundError(errors.New("original error")),
+ want: new(NoRowFoundError),
+ },
+ {
+ name: "multiple rows found",
+ err: NewMultipleRowsFoundError(originalErr),
+ want: new(MultipleRowsFoundError),
+ },
+ {
+ name: "check violation is for integrity",
+ err: NewCheckError("table", "constraint", nil),
+ want: new(IntegrityViolationError),
+ },
+ {
+ name: "check violation is check violation",
+ err: NewCheckError("table", "constraint", nil),
+ want: new(CheckError),
+ },
+ {
+ name: "unique violation is for integrity",
+ err: NewUniqueError("table", "constraint", nil),
+ want: new(IntegrityViolationError),
+ },
+ {
+ name: "unique violation is unique violation",
+ err: NewUniqueError("table", "constraint", nil),
+ want: new(UniqueError),
+ },
+ {
+ name: "foreign key violation is for integrity",
+ err: NewForeignKeyError("table", "constraint", nil),
+ want: new(IntegrityViolationError),
+ },
+ {
+ name: "foreign key violation is foreign key violation",
+ err: NewForeignKeyError("table", "constraint", nil),
+ want: new(ForeignKeyError),
+ },
+ {
+ name: "not null violation is for integrity",
+ err: NewNotNullError("table", "constraint", nil),
+ want: new(IntegrityViolationError),
+ },
+ {
+ name: "not null violation is not null violation",
+ err: NewNotNullError("table", "constraint", nil),
+ want: new(NotNullError),
+ },
+ {
+ name: "unknown error",
+ err: NewUnknownError(originalErr),
+ want: new(UnknownError),
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ assert.ErrorIs(t, test.err, test.want)
+ })
+ }
+}
diff --git a/backend/v3/storage/database/events_testing/instance_domain_test.go b/backend/v3/storage/database/events_testing/instance_domain_test.go
index c42ee1a0ff0..8dc6d141765 100644
--- a/backend/v3/storage/database/events_testing/instance_domain_test.go
+++ b/backend/v3/storage/database/events_testing/instance_domain_test.go
@@ -21,8 +21,8 @@ import (
func TestServer_TestInstanceDomainReduces(t *testing.T) {
instance := integration.NewInstance(CTX)
- instanceRepo := repository.InstanceRepository(pool)
- instanceDomainRepo := instanceRepo.Domains(true)
+ instanceRepo := repository.InstanceRepository()
+ instanceDomainRepo := repository.InstanceDomainRepository()
t.Cleanup(func() {
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
@@ -36,7 +36,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for instance to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- _, err := instanceRepo.Get(CTX,
+ _, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
)
assert.NoError(t, err)
@@ -66,7 +66,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -96,7 +96,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
t.Cleanup(func() {
// first we change the primary domain to something else
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -125,7 +125,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -151,7 +151,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that set primary reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -180,7 +180,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- _, err := instanceDomainRepo.Get(CTX,
+ _, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -202,7 +202,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -241,7 +241,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -271,7 +271,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- _, err := instanceDomainRepo.Get(CTX,
+ _, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -293,7 +293,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := instanceDomainRepo.Get(CTX,
+ domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
diff --git a/backend/v3/storage/database/events_testing/instance_test.go b/backend/v3/storage/database/events_testing/instance_test.go
index 977fe7c7a9c..46ffe310e53 100644
--- a/backend/v3/storage/database/events_testing/instance_test.go
+++ b/backend/v3/storage/database/events_testing/instance_test.go
@@ -17,7 +17,7 @@ import (
)
func TestServer_TestInstanceReduces(t *testing.T) {
- instanceRepo := repository.InstanceRepository(pool)
+ instanceRepo := repository.InstanceRepository()
t.Run("test instance add reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
@@ -46,7 +46,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- instance, err := instanceRepo.Get(CTX,
+ instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
)
require.NoError(t, err)
@@ -92,7 +92,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- instance, err := instanceRepo.Get(CTX,
+ instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -110,7 +110,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- instance, err := instanceRepo.Get(CTX,
+ instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -137,7 +137,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- instance, err := instanceRepo.Get(CTX,
+ instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -151,7 +151,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- instance, err := instanceRepo.Get(CTX,
+ instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
// event instance.removed
diff --git a/backend/v3/storage/database/events_testing/org_domain_test.go b/backend/v3/storage/database/events_testing/org_domain_test.go
index af554a22817..4e35fca93ea 100644
--- a/backend/v3/storage/database/events_testing/org_domain_test.go
+++ b/backend/v3/storage/database/events_testing/org_domain_test.go
@@ -22,8 +22,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
})
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(pool)
- orgDomainRepo := orgRepo.Domains(false)
+ orgRepo := repository.OrganizationRepository()
+ orgDomainRepo := repository.OrganizationDomainRepository()
t.Cleanup(func() {
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
@@ -37,8 +37,13 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Wait for org to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- _, err := orgRepo.Get(CTX,
- database.WithCondition(orgRepo.IDCondition(org.GetId())),
+ _, err := orgRepo.Get(CTX, pool,
+ database.WithCondition(
+ database.And(
+ orgRepo.InstanceIDCondition(Instance.Instance.Id),
+ orgRepo.IDCondition(org.GetId()),
+ ),
+ ),
)
assert.NoError(t, err)
}, retryDuration, tick)
@@ -68,7 +73,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- gottenDomain, err := orgDomainRepo.Get(CTX,
+ gottenDomain, err := orgDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
@@ -107,7 +112,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- domain, err := orgDomainRepo.Get(CTX,
+ domain, err := orgDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
diff --git a/backend/v3/storage/database/events_testing/organization_test.go b/backend/v3/storage/database/events_testing/organization_test.go
index 7c89cfbcd51..bcca1d9ad0f 100644
--- a/backend/v3/storage/database/events_testing/organization_test.go
+++ b/backend/v3/storage/database/events_testing/organization_test.go
@@ -19,7 +19,7 @@ import (
func TestServer_TestOrganizationReduces(t *testing.T) {
instanceID := Instance.ID()
- orgRepo := repository.OrganizationRepository(pool)
+ orgRepo := repository.OrganizationRepository()
t.Run("test org add reduces", func(t *testing.T) {
beforeCreate := time.Now()
@@ -42,7 +42,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(org.GetId()),
@@ -92,7 +92,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -137,7 +137,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -177,11 +177,11 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
})
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(pool)
+ orgRepo := repository.OrganizationRepository()
// 3. check org deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -203,7 +203,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -230,10 +230,10 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
require.NoError(t, err)
// 2. check org retrievable
- orgRepo := repository.OrganizationRepository(pool)
+ orgRepo := repository.OrganizationRepository()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- _, err := orgRepo.Get(CTX,
+ _, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -252,7 +252,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
- organization, err := orgRepo.Get(CTX,
+ organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
diff --git a/backend/v3/storage/database/gen_mock.go b/backend/v3/storage/database/gen_mock.go
index 04d204cfa1b..40ffd3a7b29 100644
--- a/backend/v3/storage/database/gen_mock.go
+++ b/backend/v3/storage/database/gen_mock.go
@@ -1,3 +1,3 @@
package database
-//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
+//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
diff --git a/backend/v3/storage/database/operators.go b/backend/v3/storage/database/operators.go
index c8820d918db..93cf1d32cd6 100644
--- a/backend/v3/storage/database/operators.go
+++ b/backend/v3/storage/database/operators.go
@@ -6,16 +6,40 @@ import (
"golang.org/x/exp/constraints"
)
-type Value interface {
- Boolean | Number | Text | Instruction
+type wrappedValue[V Value] struct {
+ value V
+ fn function
}
+func LowerValue[T Value](v T) wrappedValue[T] {
+ return wrappedValue[T]{value: v, fn: functionLower}
+}
+
+func SHA256Value[T Value](v T) wrappedValue[T] {
+ return wrappedValue[T]{value: v, fn: functionSHA256}
+}
+
+func (b wrappedValue[V]) WriteArg(builder *StatementBuilder) {
+ builder.Grow(len(b.fn) + 5)
+ builder.WriteString(string(b.fn))
+ builder.WriteRune('(')
+ builder.WriteArg(b.value)
+ builder.WriteRune(')')
+}
+
+var _ argWriter = (*wrappedValue[string])(nil)
+
+type Value interface {
+ Boolean | Number | Text | Instruction | Bytes
+}
+
+//go:generate enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go
type Operation interface {
- BooleanOperation | NumberOperation | TextOperation
+ NumberOperation | TextOperation | BytesOperation
}
type Text interface {
- ~string | ~[]byte
+ ~string | Bytes
}
// TextOperation are operations that can be performed on text values.
@@ -23,60 +47,17 @@ type TextOperation uint8
const (
// TextOperationEqual compares two strings for equality.
- TextOperationEqual TextOperation = iota + 1
- // TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
- TextOperationEqualIgnoreCase
+ TextOperationEqual TextOperation = iota + 1 // =
// TextOperationNotEqual compares two strings for inequality.
- TextOperationNotEqual
- // TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
- TextOperationNotEqualIgnoreCase
+ TextOperationNotEqual // <>
// TextOperationStartsWith checks if the first string starts with the second.
- TextOperationStartsWith
- // TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
- TextOperationStartsWithIgnoreCase
+ TextOperationStartsWith // LIKE
)
-var textOperations = map[TextOperation]string{
- TextOperationEqual: " = ",
- TextOperationEqualIgnoreCase: " LIKE ",
- TextOperationNotEqual: " <> ",
- TextOperationNotEqualIgnoreCase: " NOT LIKE ",
- TextOperationStartsWith: " LIKE ",
- TextOperationStartsWithIgnoreCase: " LIKE ",
-}
-
-func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
- switch op {
- case TextOperationEqual, TextOperationNotEqual:
- col.WriteQualified(builder)
- builder.WriteString(textOperations[op])
- builder.WriteArg(value)
- case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
- builder.WriteString("LOWER(")
- col.WriteQualified(builder)
- builder.WriteString(")")
-
- builder.WriteString(textOperations[op])
- builder.WriteString("LOWER(")
- builder.WriteArg(value)
- builder.WriteString(")")
- case TextOperationStartsWith:
- col.WriteQualified(builder)
- builder.WriteString(textOperations[op])
- builder.WriteArg(value)
+func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value any) {
+ writeOperation[T](builder, col, op.String(), value)
+ if op == TextOperationStartsWith {
builder.WriteString(" || '%'")
- case TextOperationStartsWithIgnoreCase:
- builder.WriteString("LOWER(")
- col.WriteQualified(builder)
- builder.WriteString(")")
-
- builder.WriteString(textOperations[op])
- builder.WriteString("LOWER(")
- builder.WriteArg(value)
- builder.WriteString(")")
- builder.WriteString(" || '%'")
- default:
- panic("unsupported text operation")
}
}
@@ -89,48 +70,60 @@ type NumberOperation uint8
const (
// NumberOperationEqual compares two numbers for equality.
- NumberOperationEqual NumberOperation = iota + 1
+ NumberOperationEqual NumberOperation = iota + 1 // =
// NumberOperationNotEqual compares two numbers for inequality.
- NumberOperationNotEqual
+ NumberOperationNotEqual // <>
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
- NumberOperationLessThan
+ NumberOperationLessThan // <
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
- NumberOperationAtLeast
+ NumberOperationAtLeast // <=
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
- NumberOperationGreaterThan
+ NumberOperationGreaterThan // >
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
- NumberOperationAtMost
+ NumberOperationAtMost // >=
)
-var numberOperations = map[NumberOperation]string{
- NumberOperationEqual: " = ",
- NumberOperationNotEqual: " <> ",
- NumberOperationLessThan: " < ",
- NumberOperationAtLeast: " <= ",
- NumberOperationGreaterThan: " > ",
- NumberOperationAtMost: " >= ",
-}
-
-func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
- col.WriteQualified(builder)
- builder.WriteString(numberOperations[op])
- builder.WriteArg(value)
+func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value any) {
+ writeOperation[T](builder, col, op.String(), value)
}
type Boolean interface {
~bool
}
-// BooleanOperation are operations that can be performed on boolean values.
-type BooleanOperation uint8
+func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value any) {
+ writeOperation[T](builder, col, "=", value)
+}
+
+type Bytes interface {
+ ~[]byte
+}
+
+// BytesOperation are operations that can be performed on bytea values.
+type BytesOperation uint8
const (
- BooleanOperationIsTrue BooleanOperation = iota + 1
- BooleanOperationIsFalse
+ BytesOperationEqual BytesOperation = iota + 1 // =
+ BytesOperationNotEqual // <>
)
-func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
+func writeBytesOperation[T Bytes](builder *StatementBuilder, col Column, op BytesOperation, value any) {
+ writeOperation[T](builder, col, op.String(), value)
+}
+
+func writeOperation[V Value](builder *StatementBuilder, col Column, op string, value any) {
+ if op == "" {
+ panic("unsupported operation")
+ }
+
+ switch value.(type) {
+ case V, wrappedValue[V], *wrappedValue[V]:
+ default:
+ panic("unsupported value type")
+ }
col.WriteQualified(builder)
- builder.WriteString(" = ")
+ builder.WriteRune(' ')
+ builder.WriteString(op)
+ builder.WriteRune(' ')
builder.WriteArg(value)
}
diff --git a/backend/v3/storage/database/operators_enumer.go b/backend/v3/storage/database/operators_enumer.go
new file mode 100644
index 00000000000..8c73ad37aaf
--- /dev/null
+++ b/backend/v3/storage/database/operators_enumer.go
@@ -0,0 +1,241 @@
+// Code generated by "enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go"; DO NOT EDIT.
+
+package database
+
+import (
+ "fmt"
+ "strings"
+)
+
+const _NumberOperationName = "=<><<=>>="
+
+var _NumberOperationIndex = [...]uint8{0, 1, 3, 4, 6, 7, 9}
+
+const _NumberOperationLowerName = "=<><<=>>="
+
+func (i NumberOperation) String() string {
+ i -= 1
+ if i >= NumberOperation(len(_NumberOperationIndex)-1) {
+ return fmt.Sprintf("NumberOperation(%d)", i+1)
+ }
+ return _NumberOperationName[_NumberOperationIndex[i]:_NumberOperationIndex[i+1]]
+}
+
+// An "invalid array index" compiler error signifies that the constant values have changed.
+// Re-run the stringer command to generate them again.
+func _NumberOperationNoOp() {
+ var x [1]struct{}
+ _ = x[NumberOperationEqual-(1)]
+ _ = x[NumberOperationNotEqual-(2)]
+ _ = x[NumberOperationLessThan-(3)]
+ _ = x[NumberOperationAtLeast-(4)]
+ _ = x[NumberOperationGreaterThan-(5)]
+ _ = x[NumberOperationAtMost-(6)]
+}
+
+var _NumberOperationValues = []NumberOperation{NumberOperationEqual, NumberOperationNotEqual, NumberOperationLessThan, NumberOperationAtLeast, NumberOperationGreaterThan, NumberOperationAtMost}
+
+var _NumberOperationNameToValueMap = map[string]NumberOperation{
+ _NumberOperationName[0:1]: NumberOperationEqual,
+ _NumberOperationLowerName[0:1]: NumberOperationEqual,
+ _NumberOperationName[1:3]: NumberOperationNotEqual,
+ _NumberOperationLowerName[1:3]: NumberOperationNotEqual,
+ _NumberOperationName[3:4]: NumberOperationLessThan,
+ _NumberOperationLowerName[3:4]: NumberOperationLessThan,
+ _NumberOperationName[4:6]: NumberOperationAtLeast,
+ _NumberOperationLowerName[4:6]: NumberOperationAtLeast,
+ _NumberOperationName[6:7]: NumberOperationGreaterThan,
+ _NumberOperationLowerName[6:7]: NumberOperationGreaterThan,
+ _NumberOperationName[7:9]: NumberOperationAtMost,
+ _NumberOperationLowerName[7:9]: NumberOperationAtMost,
+}
+
+var _NumberOperationNames = []string{
+ _NumberOperationName[0:1],
+ _NumberOperationName[1:3],
+ _NumberOperationName[3:4],
+ _NumberOperationName[4:6],
+ _NumberOperationName[6:7],
+ _NumberOperationName[7:9],
+}
+
+// NumberOperationString retrieves an enum value from the enum constants string name.
+// Throws an error if the param is not part of the enum.
+func NumberOperationString(s string) (NumberOperation, error) {
+ if val, ok := _NumberOperationNameToValueMap[s]; ok {
+ return val, nil
+ }
+
+ if val, ok := _NumberOperationNameToValueMap[strings.ToLower(s)]; ok {
+ return val, nil
+ }
+ return 0, fmt.Errorf("%s does not belong to NumberOperation values", s)
+}
+
+// NumberOperationValues returns all values of the enum
+func NumberOperationValues() []NumberOperation {
+ return _NumberOperationValues
+}
+
+// NumberOperationStrings returns a slice of all String values of the enum
+func NumberOperationStrings() []string {
+ strs := make([]string, len(_NumberOperationNames))
+ copy(strs, _NumberOperationNames)
+ return strs
+}
+
+// IsANumberOperation returns "true" if the value is listed in the enum definition. "false" otherwise
+func (i NumberOperation) IsANumberOperation() bool {
+ for _, v := range _NumberOperationValues {
+ if i == v {
+ return true
+ }
+ }
+ return false
+}
+
+const _TextOperationName = "=<>LIKE"
+
+var _TextOperationIndex = [...]uint8{0, 1, 3, 7}
+
+const _TextOperationLowerName = "=<>like"
+
+func (i TextOperation) String() string {
+ i -= 1
+ if i >= TextOperation(len(_TextOperationIndex)-1) {
+ return fmt.Sprintf("TextOperation(%d)", i+1)
+ }
+ return _TextOperationName[_TextOperationIndex[i]:_TextOperationIndex[i+1]]
+}
+
+// An "invalid array index" compiler error signifies that the constant values have changed.
+// Re-run the stringer command to generate them again.
+func _TextOperationNoOp() {
+ var x [1]struct{}
+ _ = x[TextOperationEqual-(1)]
+ _ = x[TextOperationNotEqual-(2)]
+ _ = x[TextOperationStartsWith-(3)]
+}
+
+var _TextOperationValues = []TextOperation{TextOperationEqual, TextOperationNotEqual, TextOperationStartsWith}
+
+var _TextOperationNameToValueMap = map[string]TextOperation{
+ _TextOperationName[0:1]: TextOperationEqual,
+ _TextOperationLowerName[0:1]: TextOperationEqual,
+ _TextOperationName[1:3]: TextOperationNotEqual,
+ _TextOperationLowerName[1:3]: TextOperationNotEqual,
+ _TextOperationName[3:7]: TextOperationStartsWith,
+ _TextOperationLowerName[3:7]: TextOperationStartsWith,
+}
+
+var _TextOperationNames = []string{
+ _TextOperationName[0:1],
+ _TextOperationName[1:3],
+ _TextOperationName[3:7],
+}
+
+// TextOperationString retrieves an enum value from the enum constants string name.
+// Throws an error if the param is not part of the enum.
+func TextOperationString(s string) (TextOperation, error) {
+ if val, ok := _TextOperationNameToValueMap[s]; ok {
+ return val, nil
+ }
+
+ if val, ok := _TextOperationNameToValueMap[strings.ToLower(s)]; ok {
+ return val, nil
+ }
+ return 0, fmt.Errorf("%s does not belong to TextOperation values", s)
+}
+
+// TextOperationValues returns all values of the enum
+func TextOperationValues() []TextOperation {
+ return _TextOperationValues
+}
+
+// TextOperationStrings returns a slice of all String values of the enum
+func TextOperationStrings() []string {
+ strs := make([]string, len(_TextOperationNames))
+ copy(strs, _TextOperationNames)
+ return strs
+}
+
+// IsATextOperation returns "true" if the value is listed in the enum definition. "false" otherwise
+func (i TextOperation) IsATextOperation() bool {
+ for _, v := range _TextOperationValues {
+ if i == v {
+ return true
+ }
+ }
+ return false
+}
+
+const _BytesOperationName = "=<>"
+
+var _BytesOperationIndex = [...]uint8{0, 1, 3}
+
+const _BytesOperationLowerName = "=<>"
+
+func (i BytesOperation) String() string {
+ i -= 1
+ if i >= BytesOperation(len(_BytesOperationIndex)-1) {
+ return fmt.Sprintf("BytesOperation(%d)", i+1)
+ }
+ return _BytesOperationName[_BytesOperationIndex[i]:_BytesOperationIndex[i+1]]
+}
+
+// An "invalid array index" compiler error signifies that the constant values have changed.
+// Re-run the stringer command to generate them again.
+func _BytesOperationNoOp() {
+ var x [1]struct{}
+ _ = x[BytesOperationEqual-(1)]
+ _ = x[BytesOperationNotEqual-(2)]
+}
+
+var _BytesOperationValues = []BytesOperation{BytesOperationEqual, BytesOperationNotEqual}
+
+var _BytesOperationNameToValueMap = map[string]BytesOperation{
+ _BytesOperationName[0:1]: BytesOperationEqual,
+ _BytesOperationLowerName[0:1]: BytesOperationEqual,
+ _BytesOperationName[1:3]: BytesOperationNotEqual,
+ _BytesOperationLowerName[1:3]: BytesOperationNotEqual,
+}
+
+var _BytesOperationNames = []string{
+ _BytesOperationName[0:1],
+ _BytesOperationName[1:3],
+}
+
+// BytesOperationString retrieves an enum value from the enum constants string name.
+// Throws an error if the param is not part of the enum.
+func BytesOperationString(s string) (BytesOperation, error) {
+ if val, ok := _BytesOperationNameToValueMap[s]; ok {
+ return val, nil
+ }
+
+ if val, ok := _BytesOperationNameToValueMap[strings.ToLower(s)]; ok {
+ return val, nil
+ }
+ return 0, fmt.Errorf("%s does not belong to BytesOperation values", s)
+}
+
+// BytesOperationValues returns all values of the enum
+func BytesOperationValues() []BytesOperation {
+ return _BytesOperationValues
+}
+
+// BytesOperationStrings returns a slice of all String values of the enum
+func BytesOperationStrings() []string {
+ strs := make([]string, len(_BytesOperationNames))
+ copy(strs, _BytesOperationNames)
+ return strs
+}
+
+// IsABytesOperation returns "true" if the value is listed in the enum definition. "false" otherwise
+func (i BytesOperation) IsABytesOperation() bool {
+ for _, v := range _BytesOperationValues {
+ if i == v {
+ return true
+ }
+ }
+ return false
+}
diff --git a/backend/v3/storage/database/operators_test.go b/backend/v3/storage/database/operators_test.go
new file mode 100644
index 00000000000..d7de372b8de
--- /dev/null
+++ b/backend/v3/storage/database/operators_test.go
@@ -0,0 +1,244 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_writeOperation(t *testing.T) {
+ type want struct {
+ shouldPanic bool
+ stmt string
+ args []any
+ }
+ tests := []struct {
+ name string
+ write func(builder *StatementBuilder)
+ // col Column
+ // op string
+ // value any
+ want want
+ }{
+ {
+ name: "unsupported operation panics",
+ write: func(builder *StatementBuilder) {
+ writeOperation[string](builder, NewColumn("table", "column"), "", "value")
+ },
+ want: want{
+ shouldPanic: true,
+ },
+ },
+ {
+ name: "unsupported value type panics",
+ write: func(builder *StatementBuilder) {
+ writeOperation[string](builder, NewColumn("table", "column"), " = ", struct{}{})
+ },
+ want: want{
+ shouldPanic: true,
+ },
+ },
+ {
+ name: "unsupported wrapped value type panics",
+ write: func(builder *StatementBuilder) {
+ writeOperation[string](builder, NewColumn("table", "column"), " = ", SHA256Value(123))
+ },
+ want: want{
+ shouldPanic: true,
+ },
+ },
+ {
+ name: "text equal",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationEqual, "value")
+ },
+ want: want{
+ stmt: "table.column = $1",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "text not equal",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationNotEqual, "value")
+ },
+ want: want{
+ stmt: "table.column <> $1",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "text starts with",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationStartsWith, "value")
+ },
+ want: want{
+ stmt: "table.column LIKE $1 || '%'",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "text equal with wrapped value",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationEqual, LowerValue("value"))
+ },
+ want: want{
+ stmt: "LOWER(table.column) = LOWER($1)",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "text not equal with wrapped value",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationNotEqual, LowerValue("value"))
+ },
+ want: want{
+ stmt: "LOWER(table.column) <> LOWER($1)",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "text starts with with wrapped value",
+ write: func(builder *StatementBuilder) {
+ writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationStartsWith, LowerValue("value"))
+ },
+ want: want{
+ stmt: "LOWER(table.column) LIKE LOWER($1) || '%'",
+ args: []any{"value"},
+ },
+ },
+ {
+ name: "number equal",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationEqual, 123)
+ },
+ want: want{
+ stmt: "table.column = $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "number not equal",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationNotEqual, 123)
+ },
+ want: want{
+ stmt: "table.column <> $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "number less than",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationLessThan, 123)
+ },
+ want: want{
+ stmt: "table.column < $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "number less than or equal",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtLeast, 123)
+ },
+ want: want{
+ stmt: "table.column <= $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "number greater than",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationGreaterThan, 123)
+ },
+ want: want{
+ stmt: "table.column > $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "number greater than or equal",
+ write: func(builder *StatementBuilder) {
+ writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtMost, 123)
+ },
+ want: want{
+ stmt: "table.column >= $1",
+ args: []any{123},
+ },
+ },
+ {
+ name: "boolean is true",
+ write: func(builder *StatementBuilder) {
+ writeBooleanOperation[bool](builder, NewColumn("table", "column"), true)
+ },
+ want: want{
+ stmt: "table.column = $1",
+ args: []any{true},
+ },
+ },
+ {
+ name: "boolean is false",
+ write: func(builder *StatementBuilder) {
+ writeBooleanOperation[bool](builder, NewColumn("table", "column"), false)
+ },
+ want: want{
+ stmt: "table.column = $1",
+ args: []any{false},
+ },
+ },
+ {
+ name: "bytes equal",
+ write: func(builder *StatementBuilder) {
+ writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationEqual, []byte{0x01, 0x02, 0x03})
+ },
+ want: want{
+ stmt: "table.column = $1",
+ args: []any{[]byte{0x01, 0x02, 0x03}},
+ },
+ },
+ {
+ name: "bytes not equal",
+ write: func(builder *StatementBuilder) {
+ writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationNotEqual, []byte{0x01, 0x02, 0x03})
+ },
+ want: want{
+ stmt: "table.column <> $1",
+ args: []any{[]byte{0x01, 0x02, 0x03}},
+ },
+ },
+ {
+ name: "bytes equal with wrapped value",
+ write: func(builder *StatementBuilder) {
+ writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
+ },
+ want: want{
+ stmt: "SHA256(table.column) = SHA256($1)",
+ args: []any{[]byte{0x01, 0x02, 0x03}},
+ },
+ },
+ {
+ name: "bytes not equal with wrapped value",
+ write: func(builder *StatementBuilder) {
+ writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationNotEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
+ },
+ want: want{
+ stmt: "SHA256(table.column) <> SHA256($1)",
+ args: []any{[]byte{0x01, 0x02, 0x03}},
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ defer func() {
+ r := recover()
+ assert.Equal(t, tt.want.shouldPanic, r != nil)
+ }()
+ var builder StatementBuilder
+ tt.write(&builder)
+
+ assert.Equal(t, tt.want.stmt, builder.String())
+ assert.Equal(t, tt.want.args, builder.Args())
+ })
+ }
+}
diff --git a/backend/v3/storage/database/order_test.go b/backend/v3/storage/database/order_test.go
new file mode 100644
index 00000000000..f90aafe97f9
--- /dev/null
+++ b/backend/v3/storage/database/order_test.go
@@ -0,0 +1,28 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func Test_orderBy_Write(t *testing.T) {
+ tests := []struct {
+ name string
+ want string
+ order Order
+ }{
+ {
+ name: "order by column",
+ want: " ORDER BY table.column",
+ order: OrderBy(NewColumn("table", "column")),
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var builder StatementBuilder
+ tt.order.Write(&builder)
+ assert.Equal(t, tt.want, builder.String())
+ })
+ }
+}
diff --git a/backend/v3/storage/database/query_test.go b/backend/v3/storage/database/query_test.go
new file mode 100644
index 00000000000..5e8e1a58361
--- /dev/null
+++ b/backend/v3/storage/database/query_test.go
@@ -0,0 +1,147 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestQueryOptions(t *testing.T) {
+ type want struct {
+ stmt string
+ args []any
+ }
+ for _, test := range []struct {
+ name string
+ options []QueryOption
+ want want
+ }{
+ {
+ name: "no options",
+ want: want{
+ stmt: "",
+ args: nil,
+ },
+ },
+ {
+ name: "limit option",
+ options: []QueryOption{
+ WithLimit(10),
+ },
+ want: want{
+ stmt: " LIMIT $1",
+ args: []any{uint32(10)},
+ },
+ },
+ {
+ name: "offset option",
+ options: []QueryOption{
+ WithOffset(5),
+ },
+ want: want{
+ stmt: " OFFSET $1",
+ args: []any{uint32(5)},
+ },
+ },
+ {
+ name: "order by asc option",
+ options: []QueryOption{
+ WithOrderByAscending(NewColumn("table", "column")),
+ },
+ want: want{
+ stmt: " ORDER BY table.column",
+ args: nil,
+ },
+ },
+ {
+ name: "order by desc option",
+ options: []QueryOption{
+ WithOrderByDescending(NewColumn("table", "column")),
+ },
+ want: want{
+ stmt: " ORDER BY table.column DESC",
+ args: nil,
+ },
+ },
+ {
+ name: "order by option",
+ options: []QueryOption{
+ WithOrderBy(OrderDirectionAsc, NewColumn("table", "column1"), NewColumn("table", "column2")),
+ },
+ want: want{
+ stmt: " ORDER BY table.column1, table.column2",
+ args: nil,
+ },
+ },
+ {
+ name: "condition option",
+ options: []QueryOption{
+ WithCondition(NewBooleanCondition(NewColumn("table", "column"), true)),
+ },
+ want: want{
+ stmt: " WHERE table.column = $1",
+ args: []any{true},
+ },
+ },
+ {
+ name: "group by option",
+ options: []QueryOption{
+ WithGroupBy(NewColumn("table", "column")),
+ },
+ want: want{
+ stmt: " GROUP BY table.column",
+ args: nil,
+ },
+ },
+ {
+ name: "left join option",
+ options: []QueryOption{
+ WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
+ },
+ want: want{
+ stmt: " LEFT JOIN other_table ON table.id = other_table.table_id",
+ args: nil,
+ },
+ },
+ {
+ name: "permission check option",
+ options: []QueryOption{
+ WithPermissionCheck("permission"),
+ },
+ want: want{
+ stmt: "",
+ args: nil,
+ },
+ },
+ {
+ name: "multiple options",
+ options: []QueryOption{
+ WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
+ WithCondition(NewNumberCondition(NewColumn("table", "column"), NumberOperationEqual, 123)),
+ WithOrderByDescending(NewColumn("table", "column")),
+ WithLimit(10),
+ WithOffset(5),
+ },
+ want: want{
+ stmt: " LEFT JOIN other_table ON table.id = other_table.table_id WHERE table.column = $1 ORDER BY table.column DESC LIMIT $2 OFFSET $3",
+ args: []any{123, uint32(10), uint32(5)},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ var b StatementBuilder
+ var opts QueryOpts
+ for _, option := range test.options {
+ option(&opts)
+ }
+ opts.Write(&b)
+ assert.Equal(t, test.want.stmt, b.String())
+ require.Len(t, b.Args(), len(test.want.args))
+ for i := range test.want.args {
+ assert.Equal(t, test.want.args[i], b.Args()[i])
+ }
+ })
+ }
+
+}
diff --git a/backend/v3/storage/database/repository/array.go b/backend/v3/storage/database/repository/array.go
index 77851033d61..49c150fd5d6 100644
--- a/backend/v3/storage/database/repository/array.go
+++ b/backend/v3/storage/database/repository/array.go
@@ -3,19 +3,35 @@ package repository
import (
"encoding/json"
"errors"
+
+ "github.com/zitadel/zitadel/backend/v3/storage/database"
)
type JSONArray[T any] []*T
-func (a JSONArray[T]) Scan(src any) error {
+var ErrScanSource = errors.New("unsupported scan source")
+
+func (a *JSONArray[T]) Scan(src any) (err error) {
+ var rawJSON []byte
switch s := src.(type) {
case string:
- return json.Unmarshal([]byte(s), &a)
+ if len(s) == 0 {
+ return nil
+ }
+ rawJSON = []byte(s)
case []byte:
- return json.Unmarshal(s, &a)
+ if len(s) == 0 {
+ return nil
+ }
+ rawJSON = s
case nil:
return nil
default:
- return errors.New("unsupported scan source")
+ return ErrScanSource
}
+ err = json.Unmarshal(rawJSON, a)
+ if err != nil {
+ return database.NewScanError(err)
+ }
+ return nil
}
diff --git a/backend/v3/storage/database/repository/array_test.go b/backend/v3/storage/database/repository/array_test.go
new file mode 100644
index 00000000000..9f320190f0d
--- /dev/null
+++ b/backend/v3/storage/database/repository/array_test.go
@@ -0,0 +1,117 @@
+package repository_test
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/zitadel/zitadel/backend/v3/storage/database"
+ "github.com/zitadel/zitadel/backend/v3/storage/database/repository"
+)
+
+type testObject struct {
+ ID string `json:"id"`
+ Number int `json:"number"`
+ Active bool `json:"active"`
+}
+
+func TestJSONArray_Scan(t *testing.T) {
+ type want struct {
+ err error
+ res []*testObject
+ }
+ tests := []struct {
+ name string
+ src any
+ want want
+ }{
+ {
+ name: "nil",
+ src: nil,
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "[]byte with data",
+ src: []byte(`[{"id":"1","number":1,"active":true}]`),
+ want: want{
+ err: nil,
+ res: []*testObject{{ID: "1", Number: 1, Active: true}},
+ },
+ },
+ {
+ name: "[]byte without data",
+ src: []byte(`[]`),
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "nil []byte",
+ src: []byte(nil),
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "empty []byte",
+ src: []byte{},
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "string with data",
+ src: `[{"id":"1","number":1,"active":true}]`,
+ want: want{
+ err: nil,
+ res: []*testObject{{ID: "1", Number: 1, Active: true}},
+ },
+ },
+ {
+ name: "string without data",
+ src: string(`[]`),
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "empty string",
+ src: "",
+ want: want{
+ err: nil,
+ res: nil,
+ },
+ },
+ {
+ name: "wrong type",
+ src: []int{1, 2, 3},
+ want: want{
+ err: repository.ErrScanSource,
+ res: nil,
+ },
+ },
+ {
+ name: "badly formatted JSON",
+ src: []byte(`this is not a JSON`),
+ want: want{
+ err: new(database.ScanError),
+ res: nil,
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var a repository.JSONArray[testObject]
+ gotErr := a.Scan(tt.src)
+ require.ErrorIs(t, gotErr, tt.want.err)
+ require.Len(t, a, len(tt.want.res))
+ })
+ }
+}
diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go
index fa5691fc94e..2d03bdc5e0e 100644
--- a/backend/v3/storage/database/repository/instance.go
+++ b/backend/v3/storage/database/repository/instance.go
@@ -13,17 +13,20 @@ import (
var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
- repository
shouldLoadDomains bool
- domainRepo *instanceDomain
+ domainRepo instanceDomain
}
-func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
- return &instance{
- repository: repository{
- client: client,
- },
- }
+func InstanceRepository() domain.InstanceRepository {
+ return new(instance)
+}
+
+func (instance) qualifiedTableName() string {
+ return "zitadel.instances"
+}
+
+func (instance) unqualifiedTableName() string {
+ return "instances"
}
// -------------------------------------------------------------
@@ -37,7 +40,7 @@ const (
)
// Get implements [domain.InstanceRepository].
-func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
+func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
@@ -52,11 +55,11 @@ func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*doma
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
- return scanInstance(ctx, i.client, &builder)
+ return scanInstance(ctx, client, &builder)
}
// List implements [domain.InstanceRepository].
-func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
+func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
@@ -71,27 +74,11 @@ func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*d
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
- return scanInstances(ctx, i.client, &builder)
-}
-
-func (i *instance) joinDomains() database.QueryOption {
- columns := make([]database.Condition, 0, 2)
- columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
-
- // If domains should not be joined, we make sure to return null for the domain columns
- // the query optimizer of the dialect should optimize this away if no domains are requested
- if !i.shouldLoadDomains {
- columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
- }
-
- return database.WithLeftJoin(
- "zitadel.instance_domains",
- database.And(columns...),
- )
+ return scanInstances(ctx, client, &builder)
}
// Create implements [domain.InstanceRepository].
-func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
+func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -103,42 +90,46 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
updatedAt = instance.UpdatedAt
}
- builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
+ builder.WriteString(`INSERT INTO `)
+ builder.WriteString(i.qualifiedTableName())
+ builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
- return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
+ return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
}
// Update implements [domain.InstanceRepository].
-func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
+func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
+ if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
+ changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
+ }
+
var builder database.StatementBuilder
-
builder.WriteString(`UPDATE zitadel.instances SET `)
-
database.Changes(changes).Write(&builder)
-
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
stmt := builder.String()
- return i.client.Exec(ctx, stmt, builder.Args()...)
+ return client.Exec(ctx, stmt, builder.Args()...)
}
// Delete implements [domain.InstanceRepository].
-func (i instance) Delete(ctx context.Context, id string) (int64, error) {
+func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) {
var builder database.StatementBuilder
- builder.WriteString(`DELETE FROM zitadel.instances`)
+ builder.WriteString(`DELETE FROM `)
+ builder.WriteString(i.qualifiedTableName())
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
- return i.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -185,53 +176,77 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
return database.NewTextCondition(i.NameColumn(), op, name)
}
+// ExistsDomain creates a correlated [database.Exists] condition on instance_domains.
+// Use this filter to make sure the Instance returned contains a specific domain.
+// of the instance in the aggregated result.
+// Example usage:
+//
+// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated
+// instance, _ := instanceRepo.Get(ctx,
+// database.WithCondition(
+// database.And(
+// instanceRepo.InstanceIDCondition(instanceID),
+// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
+// ),
+// ),
+// )
+func (i instance) ExistsDomain(cond database.Condition) database.Condition {
+ return database.Exists(
+ i.domainRepo.qualifiedTableName(),
+ database.And(
+ database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()),
+ cond,
+ ),
+ )
+}
+
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
-func (instance) IDColumn() database.Column {
- return database.NewColumn("instances", "id")
+func (i instance) IDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "id")
}
// NameColumn implements [domain.instanceColumns].
-func (instance) NameColumn() database.Column {
- return database.NewColumn("instances", "name")
+func (i instance) NameColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "name")
}
// CreatedAtColumn implements [domain.instanceColumns].
-func (instance) CreatedAtColumn() database.Column {
- return database.NewColumn("instances", "created_at")
+func (i instance) CreatedAtColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
-func (instance) DefaultOrgIDColumn() database.Column {
- return database.NewColumn("instances", "default_org_id")
+func (i instance) DefaultOrgIDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
-func (instance) IAMProjectIDColumn() database.Column {
- return database.NewColumn("instances", "iam_project_id")
+func (i instance) IAMProjectIDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
-func (instance) ConsoleClientIDColumn() database.Column {
- return database.NewColumn("instances", "console_client_id")
+func (i instance) ConsoleClientIDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
-func (instance) ConsoleAppIDColumn() database.Column {
- return database.NewColumn("instances", "console_app_id")
+func (i instance) ConsoleAppIDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
-func (instance) DefaultLanguageColumn() database.Column {
- return database.NewColumn("instances", "default_language")
+func (i instance) DefaultLanguageColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
-func (instance) UpdatedAtColumn() database.Column {
- return database.NewColumn("instances", "updated_at")
+func (i instance) UpdatedAtColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "updated_at")
}
// -------------------------------------------------------------
@@ -282,19 +297,24 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
// sub repositories
// -------------------------------------------------------------
-// Domains implements [domain.InstanceRepository].
-func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
- if !i.shouldLoadDomains {
- i.shouldLoadDomains = shouldLoad
+func (i *instance) LoadDomains() domain.InstanceRepository {
+ return &instance{
+ shouldLoadDomains: true,
}
-
- if i.domainRepo != nil {
- return i.domainRepo
- }
-
- i.domainRepo = &instanceDomain{
- repository: i.repository,
- instance: i,
- }
- return i.domainRepo
+}
+
+func (i *instance) joinDomains() database.QueryOption {
+ columns := make([]database.Condition, 0, 2)
+ columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()))
+
+ // If domains should not be joined, we make sure to return null for the domain columns
+ // the query optimizer of the dialect should optimize this away if no domains are requested
+ if !i.shouldLoadDomains {
+ columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn()))
+ }
+
+ return database.WithLeftJoin(
+ i.domainRepo.qualifiedTableName(),
+ database.And(columns...),
+ )
}
diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go
index 3aed56238ea..9cb4d9e127a 100644
--- a/backend/v3/storage/database/repository/instance_domain.go
+++ b/backend/v3/storage/database/repository/instance_domain.go
@@ -10,9 +10,18 @@ import (
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
-type instanceDomain struct {
- repository
- *instance
+type instanceDomain struct{}
+
+func InstanceDomainRepository() domain.InstanceDomainRepository {
+ return new(instanceDomain)
+}
+
+func (instanceDomain) qualifiedTableName() string {
+ return "zitadel.instance_domains"
+}
+
+func (instanceDomain) unqualifiedTableName() string {
+ return "instance_domains"
}
// -------------------------------------------------------------
@@ -23,8 +32,7 @@ const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_d
`FROM zitadel.instance_domains`
// Get implements [domain.InstanceDomainRepository].
-// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
-func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
+func (i instanceDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -34,12 +42,11 @@ func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption)
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
- return scanInstanceDomain(ctx, i.client, &builder)
+ return scanInstanceDomain(ctx, client, &builder)
}
// List implements [domain.InstanceDomainRepository].
-// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
-func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
+func (i instanceDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -49,11 +56,11 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
- return scanInstanceDomains(ctx, i.client, &builder)
+ return scanInstanceDomains(ctx, client, &builder)
}
// Add implements [domain.InstanceDomainRepository].
-func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
+func (i instanceDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddInstanceDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -69,33 +76,41 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
- return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
+ return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.InstanceDomainRepository].
-// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
-func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
+func (i instanceDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
+ if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(i.InstanceIDColumn())
+ }
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
- var builder database.StatementBuilder
+ if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
+ changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
+ }
+ var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
- return i.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.InstanceDomainRepository].
-func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
- var builder database.StatementBuilder
+func (i instanceDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
+ if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(i.InstanceIDColumn())
+ }
+ var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
condition.Write(&builder)
- return i.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -146,40 +161,38 @@ func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition
// -------------------------------------------------------------
// CreatedAtColumn implements [domain.InstanceDomainRepository].
-// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
-func (instanceDomain) CreatedAtColumn() database.Column {
- return database.NewColumn("instance_domains", "created_at")
+func (i instanceDomain) CreatedAtColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "created_at")
}
// DomainColumn implements [domain.InstanceDomainRepository].
-func (instanceDomain) DomainColumn() database.Column {
- return database.NewColumn("instance_domains", "domain")
+func (i instanceDomain) DomainColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "domain")
}
// InstanceIDColumn implements [domain.InstanceDomainRepository].
-func (instanceDomain) InstanceIDColumn() database.Column {
- return database.NewColumn("instance_domains", "instance_id")
+func (i instanceDomain) InstanceIDColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "instance_id")
}
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
-func (instanceDomain) IsPrimaryColumn() database.Column {
- return database.NewColumn("instance_domains", "is_primary")
+func (i instanceDomain) IsPrimaryColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "is_primary")
}
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
-// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
-func (instanceDomain) UpdatedAtColumn() database.Column {
- return database.NewColumn("instance_domains", "updated_at")
+func (i instanceDomain) UpdatedAtColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "updated_at")
}
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
-func (instanceDomain) IsGeneratedColumn() database.Column {
- return database.NewColumn("instance_domains", "is_generated")
+func (i instanceDomain) IsGeneratedColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "is_generated")
}
// TypeColumn implements [domain.InstanceDomainRepository].
-func (instanceDomain) TypeColumn() database.Column {
- return database.NewColumn("instance_domains", "type")
+func (i instanceDomain) TypeColumn() database.Column {
+ return database.NewColumn(i.unqualifiedTableName(), "type")
}
// -------------------------------------------------------------
diff --git a/backend/v3/storage/database/repository/instance_domain_test.go b/backend/v3/storage/database/repository/instance_domain_test.go
index cc18ada380f..1900e648461 100644
--- a/backend/v3/storage/database/repository/instance_domain_test.go
+++ b/backend/v3/storage/database/repository/instance_domain_test.go
@@ -1,7 +1,6 @@
package repository_test
import (
- "context"
"testing"
"time"
@@ -16,31 +15,43 @@ import (
)
func TestAddInstanceDomain(t *testing.T) {
+ // we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
+ beforeAdd := time.Now()
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
- Name: gofakeit.Name(),
+ ID: gofakeit.NewCrypto().UUID(),
+ Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain
instanceDomain domain.AddInstanceDomain
err error
}{
{
name: "happy path custom domain",
instanceDomain: domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -50,7 +61,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "happy path trusted domain",
instanceDomain: domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
@@ -58,7 +69,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add primary domain",
instanceDomain: domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(true),
@@ -68,7 +79,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add custom domain without domain name",
instanceDomain: domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: "",
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -79,7 +90,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add trusted domain without domain name",
instanceDomain: domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: "",
Type: domain.DomainTypeTrusted,
},
@@ -87,23 +98,23 @@ func TestAddInstanceDomain(t *testing.T) {
},
{
name: "add custom domain with same domain twice",
- testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
}
- err := domainRepo.Add(ctx, instanceDomain)
+ err := domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -114,22 +125,20 @@ func TestAddInstanceDomain(t *testing.T) {
},
{
name: "add trusted domain with same domain twice",
- testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
- domainName := gofakeit.DomainName()
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
instanceDomain := &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName,
+ InstanceID: instance.ID,
+ Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
}
- err := domainRepo.Add(ctx, instanceDomain)
+ err := domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName,
+ InstanceID: instance.ID,
+ Domain: instanceDomain.Domain,
Type: domain.DomainTypeTrusted,
}
},
@@ -196,26 +205,23 @@ func TestAddInstanceDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- // we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
- beforeAdd := time.Now()
- tx, err := pool.Begin(t.Context(), nil)
+ savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
+ err := savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
}()
- instanceRepo := repository.InstanceRepository(tx)
- domainRepo := instanceRepo.Domains(false)
var instanceDomain *domain.AddInstanceDomain
if test.testFunc != nil {
- instanceDomain = test.testFunc(ctx, t, domainRepo)
+ instanceDomain = test.testFunc(t, savepoint)
} else {
instanceDomain = &test.instanceDomain
}
- err = domainRepo.Add(ctx, instanceDomain)
+ err = domainRepo.Add(t.Context(), savepoint, instanceDomain)
afterAdd := time.Now()
if test.err != nil {
assert.ErrorIs(t, err, test.err)
@@ -232,10 +238,21 @@ func TestAddInstanceDomain(t *testing.T) {
}
func TestGetInstanceDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -243,38 +260,29 @@ func TestGetInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domains
- domainRepo := instanceRepo.Domains(false)
- domainName1 := gofakeit.DomainName()
- domainName2 := gofakeit.DomainName()
-
domain1 := &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName1,
+ InstanceID: instance.ID,
+ Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName2,
+ InstanceID: instance.ID,
+ Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
- err = domainRepo.Add(t.Context(), domain1)
+ err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
- err = domainRepo.Add(t.Context(), domain2)
+ err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -289,19 +297,19 @@ func TestGetInstanceDomain(t *testing.T) {
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expected: &domain.InstanceDomain{
- InstanceID: instanceID,
- Domain: domainName1,
+ InstanceID: instance.ID,
+ Domain: domain1.Domain,
IsPrimary: gu.Ptr(true),
},
},
{
name: "get by domain name",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
+ database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain)),
},
expected: &domain.InstanceDomain{
- InstanceID: instanceID,
- Domain: domainName2,
+ InstanceID: instance.ID,
+ Domain: domain2.Domain,
IsPrimary: gu.Ptr(false),
},
},
@@ -318,7 +326,7 @@ func TestGetInstanceDomain(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
- result, err := domainRepo.Get(ctx, test.opts...)
+ result, err := domainRepo.Get(ctx, tx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -335,10 +343,21 @@ func TestGetInstanceDomain(t *testing.T) {
}
func TestListInstanceDomains(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -346,42 +365,35 @@ func TestListInstanceDomains(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add multiple domains
- domainRepo := instanceRepo.Domains(false)
domains := []domain.AddInstanceDomain{
{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
}
for i := range domains {
- err = domainRepo.Add(t.Context(), &domains[i])
+ err = domainRepo.Add(t.Context(), tx, &domains[i])
require.NoError(t, err)
}
@@ -405,7 +417,7 @@ func TestListInstanceDomains(t *testing.T) {
{
name: "list by instance",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
+ database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
},
expectedCount: 3,
},
@@ -422,12 +434,12 @@ func TestListInstanceDomains(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
- results, err := domainRepo.List(ctx, test.opts...)
+ results, err := domainRepo.List(ctx, tx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
- assert.Equal(t, instanceID, result.InstanceID)
+ assert.Equal(t, instance.ID, result.InstanceID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
@@ -437,10 +449,21 @@ func TestListInstanceDomains(t *testing.T) {
}
func TestUpdateInstanceDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -448,29 +471,19 @@ func TestUpdateInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
-
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
-
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domain
- domainRepo := instanceRepo.Domains(false)
- domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName,
+ InstanceID: instance.ID,
+ Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
- err = domainRepo.Add(t.Context(), instanceDomain)
+ err = domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
tests := []struct {
@@ -481,31 +494,38 @@ func TestUpdateInstanceDomain(t *testing.T) {
err error
}{
{
- name: "set primary",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{domainRepo.SetPrimary()},
- expected: 1,
+ name: "set primary",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
+ ),
+ changes: []database.Change{domainRepo.SetPrimary()},
+ expected: 1,
},
{
- name: "update non-existent domain",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
- changes: []database.Change{domainRepo.SetPrimary()},
- expected: 0,
+ name: "update non-existent domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
+ ),
+ changes: []database.Change{domainRepo.SetPrimary()},
+ expected: 0,
},
{
- name: "no changes",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{},
- expected: 0,
- err: database.ErrNoChanges,
+ name: "no changes",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
+ ),
+ changes: []database.Change{},
+ expected: 0,
+ err: database.ErrNoChanges,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
+ rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -516,7 +536,7 @@ func TestUpdateInstanceDomain(t *testing.T) {
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
- result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
+ result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
@@ -529,10 +549,21 @@ func TestUpdateInstanceDomain(t *testing.T) {
}
func TestRemoveInstanceDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -540,37 +571,28 @@ func TestRemoveInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domains
- domainRepo := instanceRepo.Domains(false)
- domainName1 := gofakeit.DomainName()
-
domain1 := &domain.AddInstanceDomain{
- InstanceID: instanceID,
- Domain: domainName1,
+ InstanceID: instance.ID,
+ Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
- InstanceID: instanceID,
+ InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
- err = domainRepo.Add(t.Context(), domain1)
+ err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
- err = domainRepo.Add(t.Context(), domain2)
+ err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -579,36 +601,43 @@ func TestRemoveInstanceDomain(t *testing.T) {
expected int64
}{
{
- name: "remove by domain name",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
- expected: 1,
+ name: "remove by domain name",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
+ ),
+ expected: 1,
},
{
- name: "remove by primary condition",
- condition: domainRepo.IsPrimaryCondition(false),
- expected: 1, // domain2 should still exist and be non-primary
+ name: "remove by primary condition",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.IsPrimaryCondition(false),
+ ),
+ expected: 1, // domain2 should still exist and be non-primary
},
{
- name: "remove non-existent domain",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
- expected: 0,
+ name: "remove non-existent domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
+ ),
+ expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
// count before removal
- beforeCount, err := domainRepo.List(ctx)
+ beforeCount, err := domainRepo.List(t.Context(), tx)
require.NoError(t, err)
- rowsAffected, err := domainRepo.Remove(ctx, test.condition)
+ rowsAffected, err := domainRepo.Remove(t.Context(), tx, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
- afterCount, err := domainRepo.List(ctx)
+ afterCount, err := domainRepo.List(t.Context(), tx)
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
@@ -616,8 +645,7 @@ func TestRemoveInstanceDomain(t *testing.T) {
}
func TestInstanceDomainConditions(t *testing.T) {
- instanceRepo := repository.InstanceRepository(pool)
- domainRepo := instanceRepo.Domains(false)
+ domainRepo := repository.InstanceDomainRepository()
tests := []struct {
name string
@@ -671,8 +699,7 @@ func TestInstanceDomainConditions(t *testing.T) {
}
func TestInstanceDomainChanges(t *testing.T) {
- instanceRepo := repository.InstanceRepository(pool)
- domainRepo := instanceRepo.Domains(false)
+ domainRepo := repository.InstanceDomainRepository()
tests := []struct {
name string
diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go
index 728cbdd753e..d57a9c979e2 100644
--- a/backend/v3/storage/database/repository/instance_test.go
+++ b/backend/v3/storage/database/repository/instance_test.go
@@ -2,6 +2,7 @@ package repository_test
import (
"context"
+ "strconv"
"testing"
"time"
@@ -16,9 +17,20 @@ import (
)
func TestCreateInstance(t *testing.T) {
+ beforeCreate := time.Now()
+ tx, err := pool.Begin(context.Background(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(context.Background())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+ instanceRepo := repository.InstanceRepository()
+
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Instance
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
}{
@@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) {
},
{
name: "adding same instance twice",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
- instanceName := gofakeit.Name()
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
- Name: instanceName,
+ ID: gofakeit.UUID(),
+ Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) {
DefaultLanguage: "defaultLanguage",
}
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
+ require.NoError(t, err)
+
// change the name to make sure same only the id clashes
inst.Name = gofakeit.Name()
require.NoError(t, err)
@@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) {
},
func() struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Instance
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
} {
@@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) {
instanceName := gofakeit.Name()
return struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Instance
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
}{
name: "adding instance with same name twice",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
@@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) {
DefaultLanguage: "defaultLanguage",
}
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
// change the id
@@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) {
{
name: "adding instance with no id",
instance: func() domain.Instance {
- // instanceId := gofakeit.Name()
- instanceName := gofakeit.Name()
instance := domain.Instance{
- // ID: instanceId,
- Name: instanceName,
+ Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
+ savepoint, err := tx.Begin(t.Context())
+ require.NoError(t, err)
+ defer func() {
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
var instance *domain.Instance
if tt.testFunc != nil {
- instance = tt.testFunc(ctx, t)
+ instance = tt.testFunc(t, savepoint)
} else {
instance = &tt.instance
}
- instanceRepo := repository.InstanceRepository(pool)
// create instance
- beforeCreate := time.Now()
- err := instanceRepo.Create(ctx, instance)
+
+ err = instanceRepo.Create(t.Context(), tx, instance)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
@@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) {
afterCreate := time.Now()
// check instance values
- instance, err = instanceRepo.Get(ctx,
+ instance, err = instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) {
}
func TestUpdateInstance(t *testing.T) {
+ beforeUpdate := time.Now()
+ tx, err := pool.Begin(context.Background(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(context.Background())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Instance
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
rowsAffected int64
getErr error
}{
{
name: "happy path",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
- instanceName := gofakeit.Name()
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
- Name: instanceName,
+ ID: gofakeit.UUID(),
+ Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
return &inst
},
@@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) {
},
{
name: "update deleted instance",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
- instanceName := gofakeit.Name()
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
- Name: instanceName,
+ ID: gofakeit.UUID(),
+ Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
// delete instance
- affectedRows, err := instanceRepo.Delete(ctx,
+ affectedRows, err := instanceRepo.Delete(t.Context(), tx,
inst.ID,
)
require.NoError(t, err)
@@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) {
},
{
name: "update non existent instance",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceId := gofakeit.Name()
-
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
+ ID: gofakeit.UUID(),
}
return &inst
},
@@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- instanceRepo := repository.InstanceRepository(pool)
+ instance := tt.testFunc(t, tx)
- instance := tt.testFunc(ctx, t)
-
- beforeUpdate := time.Now()
// update name
newName := "new_" + instance.Name
- rowsAffected, err := instanceRepo.Update(ctx,
+ rowsAffected, err := instanceRepo.Update(t.Context(), tx,
instance.ID,
instanceRepo.SetName(newName),
)
@@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) {
}
// check instance values
- instance, err = instanceRepo.Get(ctx,
+ instance, err = instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) {
}
func TestGetInstance(t *testing.T) {
- instanceRepo := repository.InstanceRepository(pool)
+ tx, err := pool.Begin(context.Background(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(context.Background())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ domainRepo := repository.InstanceDomainRepository()
+
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Instance
+ testFunc func(t *testing.T) *domain.Instance
err error
}
-
tests := []test{
func() test {
- instanceId := gofakeit.Name()
return test{
name: "happy path get using id",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceName := gofakeit.Name()
-
+ testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
- Name: instanceName,
+ ID: gofakeit.UUID(),
+ Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
return &inst
},
@@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) {
}(),
{
name: "happy path including domains",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
- instanceName := gofakeit.Name()
-
+ testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
- ID: instanceId,
- Name: instanceName,
+ ID: gofakeit.NewCrypto().UUID(),
+ Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
- domainRepo := instanceRepo.Domains(false)
d := &domain.AddInstanceDomain{
InstanceID: inst.ID,
Domain: gofakeit.DomainName(),
@@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) {
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
- err = domainRepo.Add(ctx, d)
+ err = domainRepo.Add(t.Context(), tx, d)
require.NoError(t, err)
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
@@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) {
},
{
name: "get non existent instance",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
+ testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
ID: "get non existent instance",
}
@@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- instanceRepo := repository.InstanceRepository(pool)
-
var instance *domain.Instance
if tt.testFunc != nil {
- instance = tt.testFunc(ctx, t)
+ instance = tt.testFunc(t)
}
// check instance values
- returnedInstance, err := instanceRepo.Get(ctx,
+ returnedInstance, err := instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) {
}
func TestListInstance(t *testing.T) {
- ctx := context.Background()
- pool, stop, err := newEmbeddedDB(ctx)
+ tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
- defer stop()
+ defer func() {
+ err := tx.Rollback(context.Background())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
+ testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance
conditionClauses []database.Condition
noInstanceReturned bool
}
tests := []test{
{
name: "happy path single instance no filter",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) {
},
{
name: "happy path multiple instance no filter",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
- instanceRepo := repository.InstanceRepository(pool)
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) {
},
},
func() test {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
+ instanceID := gofakeit.BeerName()
return test{
name: "instance filter on id",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
- ID: instanceId,
+ ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) {
return instances
},
- conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
+ conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)},
}
}(),
func() test {
- instanceRepo := repository.InstanceRepository(pool)
- instanceName := gofakeit.Name()
+ instanceName := gofakeit.BeerName()
return test{
name: "multiple instance filter on name",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
- err := instanceRepo.Create(ctx, &inst)
+ err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- t.Cleanup(func() {
- _, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
- require.NoError(t, err)
- })
-
- instances := tt.testFunc(ctx, t)
-
- instanceRepo := repository.InstanceRepository(pool)
+ savepoint, err := tx.Begin(t.Context())
+ require.NoError(t, err)
+ defer func() {
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+ instances := tt.testFunc(t, savepoint)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
@@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) {
}
// check instance values
- returnedInstances, err := instanceRepo.List(ctx,
+ returnedInstances, err := instanceRepo.List(t.Context(), tx,
database.WithCondition(condition),
- database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
+ database.WithOrderByAscending(instanceRepo.IDColumn()),
)
require.NoError(t, err)
if tt.noInstanceReturned {
- assert.Nil(t, returnedInstances)
+ assert.Len(t, returnedInstances, 0)
return
}
@@ -609,42 +620,45 @@ func TestListInstance(t *testing.T) {
}
func TestDeleteInstance(t *testing.T) {
+ tx, err := pool.Begin(context.Background(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(context.Background())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T)
+ testFunc func(t *testing.T, tx database.QueryExecutor)
instanceID string
noOfDeletedRows int64
}
tests := []test{
func() test {
- instanceRepo := repository.InstanceRepository(pool)
- instanceId := gofakeit.Name()
- var noOfInstances int64 = 1
+ instanceID := gofakeit.NewCrypto().UUID()
return test{
name: "happy path delete single instance filter id",
- testFunc: func(ctx context.Context, t *testing.T) {
- instances := make([]*domain.Instance, noOfInstances)
- for i := range noOfInstances {
-
- inst := domain.Instance{
- ID: instanceId,
- Name: gofakeit.Name(),
- DefaultOrgID: "defaultOrgId",
- IAMProjectID: "iamProject",
- ConsoleClientID: "consoleCLient",
- ConsoleAppID: "consoleApp",
- DefaultLanguage: "defaultLanguage",
- }
-
- // create instance
- err := instanceRepo.Create(ctx, &inst)
- require.NoError(t, err)
-
- instances[i] = &inst
+ testFunc: func(t *testing.T, tx database.QueryExecutor) {
+ inst := domain.Instance{
+ ID: instanceID,
+ Name: gofakeit.Name(),
+ DefaultOrgID: "defaultOrgId",
+ IAMProjectID: "iamProject",
+ ConsoleClientID: "consoleCLient",
+ ConsoleAppID: "consoleApp",
+ DefaultLanguage: "defaultLanguage",
}
+
+ // create instance
+ err := instanceRepo.Create(t.Context(), tx, &inst)
+ require.NoError(t, err)
},
- instanceID: instanceId,
- noOfDeletedRows: noOfInstances,
+ instanceID: instanceID,
+ noOfDeletedRows: 1,
}
}(),
func() test {
@@ -655,40 +669,33 @@ func TestDeleteInstance(t *testing.T) {
}
}(),
func() test {
- instanceRepo := repository.InstanceRepository(pool)
- instanceName := gofakeit.Name()
+ instanceID := gofakeit.Name()
return test{
name: "deleted already deleted instance",
- testFunc: func(ctx context.Context, t *testing.T) {
- noOfInstances := 1
- instances := make([]*domain.Instance, noOfInstances)
- for i := range noOfInstances {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) {
- inst := domain.Instance{
- ID: gofakeit.Name(),
- Name: instanceName,
- DefaultOrgID: "defaultOrgId",
- IAMProjectID: "iamProject",
- ConsoleClientID: "consoleCLient",
- ConsoleAppID: "consoleApp",
- DefaultLanguage: "defaultLanguage",
- }
-
- // create instance
- err := instanceRepo.Create(ctx, &inst)
- require.NoError(t, err)
-
- instances[i] = &inst
+ inst := domain.Instance{
+ ID: instanceID,
+ Name: gofakeit.BeerName(),
+ DefaultOrgID: "defaultOrgId",
+ IAMProjectID: "iamProject",
+ ConsoleClientID: "consoleCLient",
+ ConsoleAppID: "consoleApp",
+ DefaultLanguage: "defaultLanguage",
}
+ // create instance
+ err := instanceRepo.Create(t.Context(), tx, &inst)
+ require.NoError(t, err)
+
// delete instance
- affectedRows, err := instanceRepo.Delete(ctx,
- instances[0].ID,
+ affectedRows, err := instanceRepo.Delete(t.Context(), tx,
+ inst.ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
},
- instanceID: instanceName,
+ instanceID: instanceID,
// this test should return 0 affected rows as the instance was already deleted
noOfDeletedRows: 0,
}
@@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- instanceRepo := repository.InstanceRepository(pool)
+ savepoint, err := tx.Begin(t.Context())
+ require.NoError(t, err)
+ defer func() {
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
if tt.testFunc != nil {
- tt.testFunc(ctx, t)
+ tt.testFunc(t, savepoint)
}
// delete instance
- noOfDeletedRows, err := instanceRepo.Delete(ctx,
- tt.instanceID,
- )
+ noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check instance was deleted
- instance, err := instanceRepo.Get(ctx,
+ instance, err := instanceRepo.Get(t.Context(), savepoint,
database.WithCondition(
instanceRepo.IDCondition(tt.instanceID),
),
diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go
index 85536c96dcb..515664879a0 100644
--- a/backend/v3/storage/database/repository/org.go
+++ b/backend/v3/storage/database/repository/org.go
@@ -14,25 +14,24 @@ import (
var _ domain.OrganizationRepository = (*org)(nil)
type org struct {
- repository
shouldLoadDomains bool
- domainRepo domain.OrganizationDomainRepository
+ domainRepo orgDomain
}
-func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
- return &org{
- repository: repository{
- client: client,
- },
- }
+func (o org) unqualifiedTableName() string {
+ return "organizations"
+}
+
+func OrganizationRepository() domain.OrganizationRepository {
+ return new(org)
}
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
- ` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
+ ` , jsonb_agg(json_build_object('instanceId', org_domains.instance_id, 'orgId', org_domains.org_id, 'domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository].
-func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
+func (o org) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
@@ -43,15 +42,19 @@ func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Or
opt(options)
}
+ if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return nil, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
- return scanOrganization(ctx, o.client, &builder)
+ return scanOrganization(ctx, client, &builder)
}
// List implements [domain.OrganizationRepository].
-func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
+func (o org) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
@@ -62,30 +65,15 @@ func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain
opt(options)
}
+ if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return nil, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
- return scanOrganizations(ctx, o.client, &builder)
-}
-
-func (o *org) joinDomains() database.QueryOption {
- columns := make([]database.Condition, 0, 3)
- columns = append(columns,
- database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
- database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
- )
-
- // If domains should not be joined, we make sure to return null for the domain columns
- // the query optimizer of the dialect should optimize this away if no domains are requested
- if !o.shouldLoadDomains {
- columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
- }
-
- return database.WithLeftJoin(
- "zitadel.org_domains",
- database.And(columns...),
- )
+ return scanOrganizations(ctx, client, &builder)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
@@ -93,46 +81,48 @@ const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, ins
` RETURNING created_at, updated_at`
// Create implements [domain.OrganizationRepository].
-func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
+func (o org) Create(ctx context.Context, client database.QueryExecutor, organization *domain.Organization) error {
builder := database.StatementBuilder{}
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
builder.WriteString(createOrganizationStmt)
- return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
+ return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
}
// Update implements [domain.OrganizationRepository].
-func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
+func (o org) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
- builder := database.StatementBuilder{}
+ if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+ if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
+ changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
+ }
+
+ var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.organizations SET `)
-
- instanceIDCondition := o.InstanceIDCondition(instanceID)
-
- conditions := []database.Condition{id, instanceIDCondition}
database.Changes(changes).Write(&builder)
- writeCondition(&builder, database.And(conditions...))
+ writeCondition(&builder, condition)
stmt := builder.String()
- rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
+ rowsAffected, err := client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// Delete implements [domain.OrganizationRepository].
-func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
- builder := database.StatementBuilder{}
+func (o org) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
+ if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+ var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.organizations`)
+ writeCondition(&builder, condition)
- instanceIDCondition := o.InstanceIDCondition(instanceID)
-
- conditions := []database.Condition{id, instanceIDCondition}
- writeCondition(&builder, database.And(conditions...))
-
- return o.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -154,13 +144,13 @@ func (o org) SetState(state domain.OrgState) database.Change {
// -------------------------------------------------------------
// IDCondition implements [domain.organizationConditions].
-func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
+func (o org) IDCondition(id string) database.Condition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.organizationConditions].
-func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
- return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
+func (o org) NameCondition(op database.TextOperation, name string) database.Condition {
+ return database.NewTextCondition(o.NameColumn(), op, name)
}
// InstanceIDCondition implements [domain.organizationConditions].
@@ -173,38 +163,62 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
}
+// ExistsDomain creates a correlated [database.Exists] condition on org_domains.
+// Use this filter to make sure the organization returned contains a specific domain.
+// Example usage:
+//
+// domainRepo := orgRepo.Domains(true) // ensure domains are loaded/aggregated
+// org, _ := orgRepo.Get(ctx,
+// database.WithCondition(
+// database.And(
+// orgRepo.InstanceIDCondition(instanceID),
+// orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
+// ),
+// ),
+// )
+func (o org) ExistsDomain(cond database.Condition) database.Condition {
+ return database.Exists(
+ o.domainRepo.qualifiedTableName(),
+ database.And(
+ database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
+ database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
+ cond,
+ ),
+ )
+}
+
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.organizationColumns].
-func (org) IDColumn() database.Column {
- return database.NewColumn("organizations", "id")
+func (o org) IDColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "id")
}
// NameColumn implements [domain.organizationColumns].
-func (org) NameColumn() database.Column {
- return database.NewColumn("organizations", "name")
+func (o org) NameColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "name")
}
// InstanceIDColumn implements [domain.organizationColumns].
-func (org) InstanceIDColumn() database.Column {
- return database.NewColumn("organizations", "instance_id")
+func (o org) InstanceIDColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "instance_id")
}
// StateColumn implements [domain.organizationColumns].
-func (org) StateColumn() database.Column {
- return database.NewColumn("organizations", "state")
+func (o org) StateColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "state")
}
// CreatedAtColumn implements [domain.organizationColumns].
-func (org) CreatedAtColumn() database.Column {
- return database.NewColumn("organizations", "created_at")
+func (o org) CreatedAtColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "created_at")
}
// UpdatedAtColumn implements [domain.organizationColumns].
-func (org) UpdatedAtColumn() database.Column {
- return database.NewColumn("organizations", "updated_at")
+func (o org) UpdatedAtColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "updated_at")
}
// -------------------------------------------------------------
@@ -255,20 +269,27 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
// sub repositories
// -------------------------------------------------------------
-// Domains implements [domain.OrganizationRepository].
-func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
- if !o.shouldLoadDomains {
- o.shouldLoadDomains = shouldLoad
+func (o org) LoadDomains() domain.OrganizationRepository {
+ return &org{
+ shouldLoadDomains: true,
}
-
- if o.domainRepo != nil {
- return o.domainRepo
- }
-
- o.domainRepo = &orgDomain{
- repository: o.repository,
- org: o,
- }
-
- return o.domainRepo
+}
+
+func (o org) joinDomains() database.QueryOption {
+ columns := make([]database.Condition, 0, 3)
+ columns = append(columns,
+ database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
+ database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
+ )
+
+ // If domains should not be joined, we make sure to return null for the domain columns
+ // the query optimizer of the dialect should optimize this away if no domains are requested
+ if !o.shouldLoadDomains {
+ columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
+ }
+
+ return database.WithLeftJoin(
+ o.domainRepo.qualifiedTableName(),
+ database.And(columns...),
+ )
}
diff --git a/backend/v3/storage/database/repository/org_domain.go b/backend/v3/storage/database/repository/org_domain.go
index 5b2e91acf28..4d73647edf2 100644
--- a/backend/v3/storage/database/repository/org_domain.go
+++ b/backend/v3/storage/database/repository/org_domain.go
@@ -10,9 +10,18 @@ import (
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
-type orgDomain struct {
- repository
- *org
+type orgDomain struct{}
+
+func OrganizationDomainRepository() domain.OrganizationDomainRepository {
+ return new(orgDomain)
+}
+
+func (orgDomain) qualifiedTableName() string {
+ return "zitadel.org_domains"
+}
+
+func (orgDomain) unqualifiedTableName() string {
+ return "org_domains"
}
// -------------------------------------------------------------
@@ -24,36 +33,44 @@ const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_veri
// Get implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
-func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
+func (o orgDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
+ if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return nil, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
- return scanOrganizationDomain(ctx, o.client, &builder)
+ return scanOrganizationDomain(ctx, client, &builder)
}
// List implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
-func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
+func (o orgDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
+ if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return nil, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
- return scanOrganizationDomains(ctx, o.client, &builder)
+ return scanOrganizationDomains(ctx, client, &builder)
}
// Add implements [domain.OrganizationDomainRepository].
-func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
+func (o orgDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddOrganizationDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -69,33 +86,47 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
- return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
+ return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
-func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
+func (o orgDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
+ if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+ if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
+ return 0, database.NewMissingConditionError(o.OrgIDColumn())
+ }
+ if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
+ changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
+ }
var builder database.StatementBuilder
-
builder.WriteString(`UPDATE zitadel.org_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
- return o.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.OrganizationDomainRepository].
-func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
- var builder database.StatementBuilder
+func (o orgDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
+ if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
+ return 0, database.NewMissingConditionError(o.InstanceIDColumn())
+ }
+ if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
+ return 0, database.NewMissingConditionError(o.OrgIDColumn())
+ }
+ var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.org_domains `)
writeCondition(&builder, condition)
- return o.client.Exec(ctx, builder.String(), builder.Args()...)
+ return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -158,45 +189,43 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
-func (orgDomain) CreatedAtColumn() database.Column {
- return database.NewColumn("org_domains", "created_at")
+func (o orgDomain) CreatedAtColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "created_at")
}
// DomainColumn implements [domain.OrganizationDomainRepository].
-func (orgDomain) DomainColumn() database.Column {
- return database.NewColumn("org_domains", "domain")
+func (o orgDomain) DomainColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "domain")
}
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
-// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
-func (orgDomain) InstanceIDColumn() database.Column {
- return database.NewColumn("org_domains", "instance_id")
+func (o orgDomain) InstanceIDColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "instance_id")
}
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
-func (orgDomain) IsPrimaryColumn() database.Column {
- return database.NewColumn("org_domains", "is_primary")
+func (o orgDomain) IsPrimaryColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "is_primary")
}
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
-func (orgDomain) IsVerifiedColumn() database.Column {
- return database.NewColumn("org_domains", "is_verified")
+func (o orgDomain) IsVerifiedColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "is_verified")
}
// OrgIDColumn implements [domain.OrganizationDomainRepository].
-func (orgDomain) OrgIDColumn() database.Column {
- return database.NewColumn("org_domains", "org_id")
+func (o orgDomain) OrgIDColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "org_id")
}
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
-// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
-func (orgDomain) UpdatedAtColumn() database.Column {
- return database.NewColumn("org_domains", "updated_at")
+func (o orgDomain) UpdatedAtColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "updated_at")
}
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
-func (orgDomain) ValidationTypeColumn() database.Column {
- return database.NewColumn("org_domains", "validation_type")
+func (o orgDomain) ValidationTypeColumn() database.Column {
+ return database.NewColumn(o.unqualifiedTableName(), "validation_type")
}
// -------------------------------------------------------------
diff --git a/backend/v3/storage/database/repository/org_domain_test.go b/backend/v3/storage/database/repository/org_domain_test.go
index a06d9eeee14..c0a108a2712 100644
--- a/backend/v3/storage/database/repository/org_domain_test.go
+++ b/backend/v3/storage/database/repository/org_domain_test.go
@@ -1,7 +1,6 @@
package repository_test
import (
- "context"
"testing"
"github.com/brianvoe/gofakeit/v6"
@@ -15,6 +14,15 @@ import (
)
func TestAddOrganizationDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
@@ -26,8 +34,11 @@ func TestAddOrganizationDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ domainRepo := repository.OrganizationDomainRepository()
+
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// create organization
@@ -41,7 +52,7 @@ func TestAddOrganizationDomain(t *testing.T) {
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain
organizationDomain domain.AddOrganizationDomain
err error
}{
@@ -92,7 +103,7 @@ func TestAddOrganizationDomain(t *testing.T) {
},
{
name: "add domain with same domain twice",
- testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain {
domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
@@ -104,7 +115,7 @@ func TestAddOrganizationDomain(t *testing.T) {
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
- err := domainRepo.Add(ctx, organizationDomain)
+ err := domainRepo.Add(t.Context(), tx, organizationDomain)
require.NoError(t, err)
// return same domain again
@@ -169,28 +180,26 @@ func TestAddOrganizationDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- tx, err := pool.Begin(t.Context(), nil)
+ savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
}()
- orgRepo := repository.OrganizationRepository(tx)
- err = orgRepo.Create(t.Context(), &organization)
+ err = orgRepo.Create(t.Context(), savepoint, &organization)
require.NoError(t, err)
- domainRepo := orgRepo.Domains(false)
-
var organizationDomain *domain.AddOrganizationDomain
if test.testFunc != nil {
- organizationDomain = test.testFunc(ctx, t, domainRepo)
+ organizationDomain = test.testFunc(t, savepoint)
} else {
organizationDomain = &test.organizationDomain
}
- err = domainRepo.Add(ctx, organizationDomain)
+ err = domainRepo.Add(t.Context(), tx, organizationDomain)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -204,6 +213,16 @@ func TestAddOrganizationDomain(t *testing.T) {
}
func TestGetOrganizationDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ require.NoError(t, tx.Rollback(t.Context()))
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ domainRepo := repository.OrganizationDomainRepository()
+
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
@@ -225,29 +244,18 @@ func TestGetOrganizationDomain(t *testing.T) {
State: domain.OrgStateActive,
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
-
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(tx)
- err = orgRepo.Create(t.Context(), &organization)
+ err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domains
- domainRepo := orgRepo.Domains(false)
- domainName1 := gofakeit.DomainName()
- domainName2 := gofakeit.DomainName()
domain1 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName1,
+ Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -255,15 +263,15 @@ func TestGetOrganizationDomain(t *testing.T) {
domain2 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName2,
+ Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
- err = domainRepo.Add(t.Context(), domain1)
+ err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
- err = domainRepo.Add(t.Context(), domain2)
+ err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -275,12 +283,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get primary domain",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.IsPrimaryCondition(true)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instanceID),
+ domainRepo.IsPrimaryCondition(true),
+ )),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName1,
+ Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -289,12 +300,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get by domain name",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instanceID),
+ domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
+ )),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName2,
+ Domain: domain2.Domain,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
@@ -303,13 +317,16 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get by org ID",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.OrgIDCondition(orgID)),
- database.WithCondition(domainRepo.IsPrimaryCondition(true)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instanceID),
+ domainRepo.OrgIDCondition(orgID),
+ domainRepo.IsPrimaryCondition(true),
+ )),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName1,
+ Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -318,12 +335,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get verified domain",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.IsVerifiedCondition(true)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instanceID),
+ domainRepo.IsVerifiedCondition(true),
+ )),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
- Domain: domainName1,
+ Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -332,7 +352,10 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get non-existent domain",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instanceID),
+ domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
+ )),
},
err: new(database.NoRowFoundError),
},
@@ -340,9 +363,7 @@ func TestGetOrganizationDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- result, err := domainRepo.Get(ctx, test.opts...)
+ result, err := domainRepo.Get(t.Context(), tx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -362,10 +383,22 @@ func TestGetOrganizationDomain(t *testing.T) {
}
func TestListOrganizationDomains(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ domainRepo := repository.OrganizationDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -375,50 +408,40 @@ func TestListOrganizationDomains(t *testing.T) {
}
// create organization
- orgID := gofakeit.UUID()
organization := domain.Organization{
- ID: orgID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
- InstanceID: instanceID,
+ InstanceID: instance.ID,
State: domain.OrgStateActive,
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
-
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(tx)
- err = orgRepo.Create(t.Context(), &organization)
+ err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add multiple domains
- domainRepo := orgRepo.Domains(false)
domains := []domain.AddOrganizationDomain{
{
- InstanceID: instanceID,
- OrgID: orgID,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
{
- InstanceID: instanceID,
- OrgID: orgID,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
},
{
- InstanceID: instanceID,
- OrgID: orgID,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: false,
@@ -427,7 +450,7 @@ func TestListOrganizationDomains(t *testing.T) {
}
for i := range domains {
- err = domainRepo.Add(t.Context(), &domains[i])
+ err = domainRepo.Add(t.Context(), tx, &domains[i])
require.NoError(t, err)
}
@@ -437,42 +460,59 @@ func TestListOrganizationDomains(t *testing.T) {
expectedCount int
}{
{
- name: "list all domains",
- opts: []database.QueryOption{},
+ name: "list all domains",
+ opts: []database.QueryOption{
+ database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
+ },
expectedCount: 3,
},
{
name: "list verified domains",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.IsVerifiedCondition(true)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.IsVerifiedCondition(true),
+ )),
},
expectedCount: 2,
},
{
name: "list primary domains",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.IsPrimaryCondition(true)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.IsPrimaryCondition(true),
+ )),
},
expectedCount: 1,
},
{
name: "list by organization",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.OrgIDCondition(orgID)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ )),
},
expectedCount: 3,
},
{
name: "list by instance",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.InstanceIDCondition(instance.ID),
+ )),
},
expectedCount: 3,
},
{
name: "list non-existent organization",
opts: []database.QueryOption{
- database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
+ database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition("non-existent"),
+ )),
},
expectedCount: 0,
},
@@ -482,13 +522,13 @@ func TestListOrganizationDomains(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
- results, err := domainRepo.List(ctx, test.opts...)
+ results, err := domainRepo.List(ctx, tx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
- assert.Equal(t, instanceID, result.InstanceID)
- assert.Equal(t, orgID, result.OrgID)
+ assert.Equal(t, instance.ID, result.InstanceID)
+ assert.Equal(t, organization.ID, result.OrgID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
@@ -498,10 +538,22 @@ func TestListOrganizationDomains(t *testing.T) {
}
func TestUpdateOrganizationDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ domainRepo := repository.OrganizationDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -511,41 +563,30 @@ func TestUpdateOrganizationDomain(t *testing.T) {
}
// create organization
- orgID := gofakeit.UUID()
organization := domain.Organization{
- ID: orgID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
- InstanceID: instanceID,
+ InstanceID: instance.ID,
State: domain.OrgStateActive,
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
-
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(tx)
- err = orgRepo.Create(t.Context(), &organization)
+ err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domain
- domainRepo := orgRepo.Domains(false)
- domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
- InstanceID: instanceID,
- OrgID: orgID,
- Domain: domainName,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
+ Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
- err = domainRepo.Add(t.Context(), organizationDomain)
+ err = domainRepo.Add(t.Context(), tx, organizationDomain)
require.NoError(t, err)
tests := []struct {
@@ -556,26 +597,42 @@ func TestUpdateOrganizationDomain(t *testing.T) {
err error
}{
{
- name: "set verified",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{domainRepo.SetVerified()},
- expected: 1,
+ name: "set verified",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
+ changes: []database.Change{domainRepo.SetVerified()},
+ expected: 1,
},
{
- name: "set primary",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{domainRepo.SetPrimary()},
- expected: 1,
+ name: "set primary",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
+ changes: []database.Change{domainRepo.SetPrimary()},
+ expected: 1,
},
{
- name: "set validation type",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
- expected: 1,
+ name: "set validation type",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
+ changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
+ expected: 1,
},
{
- name: "multiple changes",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
+ name: "multiple changes",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
changes: []database.Change{
domainRepo.SetVerified(),
domainRepo.SetPrimary(),
@@ -584,31 +641,41 @@ func TestUpdateOrganizationDomain(t *testing.T) {
expected: 1,
},
{
- name: "update by org ID and domain",
- condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
- changes: []database.Change{domainRepo.SetVerified()},
- expected: 1,
+ name: "update by org ID and domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
+ changes: []database.Change{domainRepo.SetVerified()},
+ expected: 1,
},
{
- name: "update non-existent domain",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
- changes: []database.Change{domainRepo.SetVerified()},
- expected: 0,
+ name: "update non-existent domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
+ ),
+ changes: []database.Change{domainRepo.SetVerified()},
+ expected: 0,
},
{
- name: "no changes",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
- changes: []database.Change{},
- expected: 0,
- err: database.ErrNoChanges,
+ name: "no changes",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
+ ),
+ changes: []database.Change{},
+ expected: 0,
+ err: database.ErrNoChanges,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
+ rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -619,7 +686,7 @@ func TestUpdateOrganizationDomain(t *testing.T) {
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
- result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
+ result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
@@ -632,10 +699,22 @@ func TestUpdateOrganizationDomain(t *testing.T) {
}
func TestRemoveOrganizationDomain(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err = tx.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ domainRepo := repository.OrganizationDomainRepository()
+
// create instance
- instanceID := gofakeit.UUID()
instance := domain.Instance{
- ID: instanceID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -645,53 +724,40 @@ func TestRemoveOrganizationDomain(t *testing.T) {
}
// create organization
- orgID := gofakeit.UUID()
organization := domain.Organization{
- ID: orgID,
+ ID: gofakeit.UUID(),
Name: gofakeit.Name(),
- InstanceID: instanceID,
+ InstanceID: instance.ID,
State: domain.OrgStateActive,
}
- tx, err := pool.Begin(t.Context(), nil)
- require.NoError(t, err)
- defer func() {
- require.NoError(t, tx.Rollback(t.Context()))
- }()
-
- instanceRepo := repository.InstanceRepository(tx)
- err = instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(tx)
- err = orgRepo.Create(t.Context(), &organization)
+ err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domains
- domainRepo := orgRepo.Domains(false)
- domainName1 := gofakeit.DomainName()
- domainName2 := gofakeit.DomainName()
-
domain1 := &domain.AddOrganizationDomain{
- InstanceID: instanceID,
- OrgID: orgID,
- Domain: domainName1,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
+ Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
domain2 := &domain.AddOrganizationDomain{
- InstanceID: instanceID,
- OrgID: orgID,
- Domain: domainName2,
+ InstanceID: instance.ID,
+ OrgID: organization.ID,
+ Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
- err = domainRepo.Add(t.Context(), domain1)
+ err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
- err = domainRepo.Add(t.Context(), domain2)
+ err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -700,50 +766,70 @@ func TestRemoveOrganizationDomain(t *testing.T) {
expected int64
}{
{
- name: "remove by domain name",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
- expected: 1,
+ name: "remove by domain name",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
+ ),
+ expected: 1,
},
{
- name: "remove by primary condition",
- condition: domainRepo.IsPrimaryCondition(false),
- expected: 1, // domain2 should still exist and be non-primary
+ name: "remove by primary condition",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.IsPrimaryCondition(false),
+ ),
+ expected: 1, // domain2 should still exist and be non-primary
},
{
- name: "remove by org ID and domain",
- condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
- expected: 1,
+ name: "remove by org ID and domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
+ ),
+ expected: 1,
},
{
- name: "remove non-existent domain",
- condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
- expected: 0,
+ name: "remove non-existent domain",
+ condition: database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
+ ),
+ expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- ctx := t.Context()
-
- snapshot, err := tx.Begin(ctx)
+ snapshot, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
- require.NoError(t, snapshot.Rollback(ctx))
+ err = snapshot.Rollback(t.Context())
+ if err != nil {
+ t.Log("error during rollback:", err)
+ }
}()
- orgRepo := repository.OrganizationRepository(snapshot)
- domainRepo := orgRepo.Domains(false)
-
// count before removal
- beforeCount, err := domainRepo.List(ctx)
+ beforeCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ )))
require.NoError(t, err)
- rowsAffected, err := domainRepo.Remove(ctx, test.condition)
+ rowsAffected, err := domainRepo.Remove(t.Context(), snapshot, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
- afterCount, err := domainRepo.List(ctx)
+ afterCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
+ domainRepo.InstanceIDCondition(instance.ID),
+ domainRepo.OrgIDCondition(organization.ID),
+ )))
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
@@ -751,8 +837,7 @@ func TestRemoveOrganizationDomain(t *testing.T) {
}
func TestOrganizationDomainConditions(t *testing.T) {
- orgRepo := repository.OrganizationRepository(pool)
- domainRepo := orgRepo.Domains(false)
+ domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string
@@ -811,8 +896,7 @@ func TestOrganizationDomainConditions(t *testing.T) {
}
func TestOrganizationDomainChanges(t *testing.T) {
- orgRepo := repository.OrganizationRepository(pool)
- domainRepo := orgRepo.Domains(false)
+ domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string
@@ -851,8 +935,7 @@ func TestOrganizationDomainChanges(t *testing.T) {
}
func TestOrganizationDomainColumns(t *testing.T) {
- orgRepo := repository.OrganizationRepository(pool)
- domainRepo := orgRepo.Domains(false)
+ domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string
diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go
index baaa02cff9c..c624ffe6c1c 100644
--- a/backend/v3/storage/database/repository/org_test.go
+++ b/backend/v3/storage/database/repository/org_test.go
@@ -1,7 +1,7 @@
package repository_test
import (
- "context"
+ "strconv"
"testing"
"time"
@@ -15,6 +15,18 @@ import (
)
func TestCreateOrganization(t *testing.T) {
+ beforeCreate := time.Now()
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -26,13 +38,12 @@ func TestCreateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Organization
+ testFunc func(t *testing.T, client database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
}{
@@ -67,8 +78,7 @@ func TestCreateOrganization(t *testing.T) {
},
{
name: "adding org with same id twice",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
- organizationRepo := repository.OrganizationRepository(pool)
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -79,7 +89,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the name to make sure same only the id clashes
org.Name = gofakeit.Name()
@@ -89,8 +99,7 @@ func TestCreateOrganization(t *testing.T) {
},
{
name: "adding org with same name twice",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
- organizationRepo := repository.OrganizationRepository(pool)
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -101,7 +110,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the id to make sure same name+instance causes an error
org.ID = gofakeit.Name()
@@ -111,7 +120,7 @@ func TestCreateOrganization(t *testing.T) {
},
func() struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Organization
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
} {
@@ -120,12 +129,12 @@ func TestCreateOrganization(t *testing.T) {
return struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Organization
+ testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
}{
name: "adding org with same name, different instance",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
// create instance
instId := gofakeit.Name()
instance := domain.Instance{
@@ -137,12 +146,9 @@ func TestCreateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(ctx, &instance)
+ err := instanceRepo.Create(t.Context(), tx, &instance)
assert.Nil(t, err)
- organizationRepo := repository.OrganizationRepository(pool)
-
org := domain.Organization{
ID: gofakeit.Name(),
Name: organizationName,
@@ -150,7 +156,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
- err = organizationRepo.Create(ctx, &org)
+ err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the id to make it unique
@@ -214,19 +220,25 @@ func TestCreateOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
+ savepoint, err := tx.Begin(t.Context())
+ require.NoError(t, err)
+ defer func() {
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
var organization *domain.Organization
if tt.testFunc != nil {
- organization = tt.testFunc(ctx, t)
+ organization = tt.testFunc(t, savepoint)
} else {
organization = &tt.organization
}
- organizationRepo := repository.OrganizationRepository(pool)
// create organization
- beforeCreate := time.Now()
- err = organizationRepo.Create(ctx, organization)
+
+ err = organizationRepo.Create(t.Context(), savepoint, organization)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
@@ -234,7 +246,7 @@ func TestCreateOrganization(t *testing.T) {
afterCreate := time.Now()
// check organization values
- organization, err = organizationRepo.Get(ctx,
+ organization, err = organizationRepo.Get(t.Context(), savepoint,
database.WithCondition(
database.And(
organizationRepo.IDCondition(organization.ID),
@@ -255,6 +267,19 @@ func TestCreateOrganization(t *testing.T) {
}
func TestUpdateOrganization(t *testing.T) {
+ beforeUpdate := time.Now()
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ organizationRepo := repository.OrganizationRepository()
+
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -266,20 +291,18 @@ func TestUpdateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
- organizationRepo := repository.OrganizationRepository(pool)
tests := []struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Organization
+ testFunc func(t *testing.T) *domain.Organization
update []database.Change
rowsAffected int64
}{
{
name: "happy path update name",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -291,7 +314,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// update with updated value
@@ -303,7 +326,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "update deleted organization",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -315,13 +338,15 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// delete instance
- _, err = organizationRepo.Delete(ctx,
- organizationRepo.IDCondition(org.ID),
- org.InstanceID,
+ _, err = organizationRepo.Delete(t.Context(), tx,
+ database.And(
+ organizationRepo.InstanceIDCondition(org.InstanceID),
+ organizationRepo.IDCondition(org.ID),
+ ),
)
require.NoError(t, err)
@@ -332,7 +357,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "happy path change state",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -344,7 +369,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// update with updated value
@@ -356,7 +381,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "update non existent organization",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
@@ -370,16 +395,14 @@ func TestUpdateOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- organizationRepo := repository.OrganizationRepository(pool)
-
- createdOrg := tt.testFunc(ctx, t)
+ createdOrg := tt.testFunc(t)
// update org
- beforeUpdate := time.Now()
- rowsAffected, err := organizationRepo.Update(ctx,
- organizationRepo.IDCondition(createdOrg.ID),
- createdOrg.InstanceID,
+ rowsAffected, err := organizationRepo.Update(t.Context(), tx,
+ database.And(
+ organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
+ organizationRepo.IDCondition(createdOrg.ID),
+ ),
tt.update...,
)
afterUpdate := time.Now()
@@ -392,7 +415,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// check organization values
- organization, err := organizationRepo.Get(ctx,
+ organization, err := organizationRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
organizationRepo.IDCondition(createdOrg.ID),
@@ -411,6 +434,19 @@ func TestUpdateOrganization(t *testing.T) {
}
func TestGetOrganization(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ defer func() {
+ err := tx.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+ orgDomainRepo := repository.OrganizationDomainRepository()
+
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -422,11 +458,9 @@ func TestGetOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
- require.NoError(t, err)
- orgRepo := repository.OrganizationRepository(pool)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
+ require.NoError(t, err)
// create organization
// this org is created as an additional org which should NOT
@@ -437,13 +471,13 @@ func TestGetOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
- err = orgRepo.Create(t.Context(), &org)
+ err = orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T) *domain.Organization
- orgIdentifierCondition domain.OrgIdentifierCondition
+ testFunc func(t *testing.T) *domain.Organization
+ orgIdentifierCondition database.Condition
err error
}
@@ -452,7 +486,7 @@ func TestGetOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "happy path get using id",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationName := gofakeit.Name()
org := domain.Organization{
@@ -463,7 +497,7 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
- err := orgRepo.Create(ctx, &org)
+ err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
return &org
@@ -475,7 +509,7 @@ func TestGetOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "happy path get using id including domain",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationName := gofakeit.Name()
org := domain.Organization{
@@ -486,7 +520,7 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
- err := orgRepo.Create(ctx, &org)
+ err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
d := &domain.AddOrganizationDomain{
@@ -496,7 +530,7 @@ func TestGetOrganization(t *testing.T) {
IsVerified: true,
IsPrimary: true,
}
- err = orgRepo.Domains(false).Add(ctx, d)
+ err = orgDomainRepo.Add(t.Context(), tx, d)
require.NoError(t, err)
org.Domains = []*domain.OrganizationDomain{
@@ -521,7 +555,7 @@ func TestGetOrganization(t *testing.T) {
organizationName := gofakeit.Name()
return test{
name: "happy path get using name",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
@@ -532,39 +566,36 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
- err := orgRepo.Create(ctx, &org)
+ err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
return &org
},
- orgIdentifierCondition: orgRepo.NameCondition(organizationName),
+ orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, organizationName),
}
}(),
{
name: "get non existent organization",
- testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
+ testFunc: func(t *testing.T) *domain.Organization {
org := domain.Organization{
ID: "non existent org",
Name: "non existent org",
}
return &org
},
- orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
+ orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name"),
err: new(database.NoRowFoundError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- orgRepo := repository.OrganizationRepository(pool)
-
var org *domain.Organization
if tt.testFunc != nil {
- org = tt.testFunc(ctx, t)
+ org = tt.testFunc(t)
}
// get org values
- returnedOrg, err := orgRepo.Get(ctx,
+ returnedOrg, err := orgRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
@@ -592,11 +623,17 @@ func TestGetOrganization(t *testing.T) {
}
func TestListOrganization(t *testing.T) {
- ctx := t.Context()
- pool, stop, err := newEmbeddedDB(ctx)
+ tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
- defer stop()
- organizationRepo := repository.OrganizationRepository(pool)
+ defer func() {
+ err := tx.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
+
+ instanceRepo := repository.InstanceRepository()
+ organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
@@ -609,33 +646,32 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err = instanceRepo.Create(ctx, &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T) []*domain.Organization
+ testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Organization
conditionClauses []database.Condition
noOrganizationReturned bool
}
tests := []test{
{
name: "happy path single organization no filter",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive,
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -646,20 +682,20 @@ func TestListOrganization(t *testing.T) {
},
{
name: "happy path multiple organization no filter",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
noOfOrganizations := 5
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive,
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -672,7 +708,7 @@ func TestListOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "organization filter on id",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
@@ -682,7 +718,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
- err = organizationRepo.Create(ctx, &org)
+ err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
noOfOrganizations := 1
@@ -697,7 +733,7 @@ func TestListOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -705,12 +741,15 @@ func TestListOrganization(t *testing.T) {
return organizations
},
- conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)},
+ conditionClauses: []database.Condition{
+ organizationRepo.InstanceIDCondition(instanceId),
+ organizationRepo.IDCondition(organizationId),
+ },
}
}(),
{
name: "multiple organization filter on state",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
@@ -720,7 +759,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
- err = organizationRepo.Create(ctx, &org)
+ err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
noOfOrganizations := 5
@@ -728,14 +767,14 @@ func TestListOrganization(t *testing.T) {
for i := range noOfOrganizations {
org := domain.Organization{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateInactive,
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -743,13 +782,16 @@ func TestListOrganization(t *testing.T) {
return organizations
},
- conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)},
+ conditionClauses: []database.Condition{
+ organizationRepo.InstanceIDCondition(instanceId),
+ organizationRepo.StateCondition(domain.OrgStateInactive),
+ },
},
func() test {
instanceId_2 := gofakeit.Name()
return test{
name: "multiple organization filter on instance",
- testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
+ testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create instance 1
instanceId_1 := gofakeit.Name()
instance := domain.Instance{
@@ -761,8 +803,7 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err = instanceRepo.Create(ctx, &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
assert.Nil(t, err)
// create organization
@@ -774,7 +815,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId_1,
State: domain.OrgStateActive,
}
- err = organizationRepo.Create(ctx, &org)
+ err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// create instance 2
@@ -787,7 +828,7 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- err = instanceRepo.Create(ctx, &instance_2)
+ err = instanceRepo.Create(t.Context(), tx, &instance_2)
assert.Nil(t, err)
noOfOrganizations := 5
@@ -795,14 +836,14 @@ func TestListOrganization(t *testing.T) {
for i := range noOfOrganizations {
org := domain.Organization{
- ID: gofakeit.Name(),
+ ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId_2,
State: domain.OrgStateActive,
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -816,22 +857,25 @@ func TestListOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- t.Cleanup(func() {
- _, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations")
- require.NoError(t, err)
- })
+ savepoint, err := tx.Begin(t.Context())
+ require.NoError(t, err)
+ defer func() {
+ err = savepoint.Rollback(t.Context())
+ if err != nil {
+ t.Logf("error during rollback: %v", err)
+ }
+ }()
+ organizations := tt.testFunc(t, savepoint)
- organizations := tt.testFunc(ctx, t)
-
- var condition database.Condition
+ condition := organizationRepo.InstanceIDCondition(instanceId)
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check organization values
- returnedOrgs, err := organizationRepo.List(ctx,
+ returnedOrgs, err := organizationRepo.List(t.Context(), tx,
database.WithCondition(condition),
- database.WithOrderByAscending(organizationRepo.CreatedAtColumn()),
+ database.WithOrderByAscending(organizationRepo.CreatedAtColumn(), organizationRepo.IDColumn()),
)
require.NoError(t, err)
if tt.noOrganizationReturned {
@@ -851,6 +895,15 @@ func TestListOrganization(t *testing.T) {
}
func TestDeleteOrganization(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ _ = tx.Rollback(t.Context())
+ })
+
+ instanceRepo := repository.InstanceRepository()
+ organizationRepo := repository.OrganizationRepository()
+
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -862,24 +915,22 @@ func TestDeleteOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
- instanceRepo := repository.InstanceRepository(pool)
- err := instanceRepo.Create(t.Context(), &instance)
+ err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
type test struct {
name string
- testFunc func(ctx context.Context, t *testing.T)
- orgIdentifierCondition domain.OrgIdentifierCondition
+ testFunc func(t *testing.T)
+ orgIdentifierCondition database.Condition
noOfDeletedRows int64
}
tests := []test{
func() test {
- organizationRepo := repository.OrganizationRepository(pool)
organizationId := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter id",
- testFunc: func(ctx context.Context, t *testing.T) {
+ testFunc: func(t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -891,7 +942,7 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -902,12 +953,11 @@ func TestDeleteOrganization(t *testing.T) {
}
}(),
func() test {
- organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter name",
- testFunc: func(ctx context.Context, t *testing.T) {
+ testFunc: func(t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -919,30 +969,28 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
- orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
+ orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
noOfDeletedRows: noOfOrganizations,
}
}(),
func() test {
- organizationRepo := repository.OrganizationRepository(pool)
non_existent_organization_name := gofakeit.Name()
return test{
name: "delete non existent organization",
- orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name),
+ orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, non_existent_organization_name),
}
}(),
func() test {
- organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
return test{
name: "deleted already deleted organization",
- testFunc: func(ctx context.Context, t *testing.T) {
+ testFunc: func(t *testing.T) {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -955,21 +1003,23 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
- err := organizationRepo.Create(ctx, &org)
+ err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
}
// delete organization
- affectedRows, err := organizationRepo.Delete(ctx,
- organizationRepo.NameCondition(organizationName),
- organizations[0].InstanceID,
+ affectedRows, err := organizationRepo.Delete(t.Context(), tx,
+ database.And(
+ organizationRepo.InstanceIDCondition(organizations[0].InstanceID),
+ organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
+ ),
)
assert.Equal(t, int64(1), affectedRows)
require.NoError(t, err)
},
- orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
+ orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
// this test should return 0 affected rows as the org was already deleted
noOfDeletedRows: 0,
}
@@ -977,23 +1027,22 @@ func TestDeleteOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- ctx := context.Background()
- organizationRepo := repository.OrganizationRepository(pool)
-
if tt.testFunc != nil {
- tt.testFunc(ctx, t)
+ tt.testFunc(t)
}
// delete organization
- noOfDeletedRows, err := organizationRepo.Delete(ctx,
- tt.orgIdentifierCondition,
- instanceId,
+ noOfDeletedRows, err := organizationRepo.Delete(t.Context(), tx,
+ database.And(
+ organizationRepo.InstanceIDCondition(instanceId),
+ tt.orgIdentifierCondition,
+ ),
)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check organization was deleted
- organization, err := organizationRepo.Get(ctx,
+ organization, err := organizationRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
@@ -1006,3 +1055,99 @@ func TestDeleteOrganization(t *testing.T) {
})
}
}
+
+func TestGetOrganizationWithSubResources(t *testing.T) {
+ tx, err := pool.Begin(t.Context(), nil)
+ require.NoError(t, err)
+ t.Cleanup(func() {
+ _ = tx.Rollback(t.Context())
+ })
+
+ instanceRepo := repository.InstanceRepository()
+ orgRepo := repository.OrganizationRepository()
+
+ // create instance
+ instanceId := gofakeit.Name()
+ err = instanceRepo.Create(t.Context(), tx, &domain.Instance{
+ ID: instanceId,
+ Name: gofakeit.Name(),
+ DefaultOrgID: "defaultOrgId",
+ IAMProjectID: "iamProject",
+ ConsoleClientID: "consoleCLient",
+ ConsoleAppID: "consoleApp",
+ DefaultLanguage: "defaultLanguage",
+ })
+ require.NoError(t, err)
+
+ // create organization
+ org := domain.Organization{
+ ID: "1",
+ Name: "org-name",
+ InstanceID: instanceId,
+ State: domain.OrgStateActive,
+ }
+ err = orgRepo.Create(t.Context(), tx, &org)
+ require.NoError(t, err)
+
+ err = orgRepo.Create(t.Context(), tx, &domain.Organization{
+ ID: "without-sub-resources",
+ Name: "org-name-2",
+ InstanceID: instanceId,
+ State: domain.OrgStateActive,
+ })
+ require.NoError(t, err)
+
+ t.Run("domains", func(t *testing.T) {
+ domainRepo := repository.OrganizationDomainRepository()
+
+ domain1 := &domain.AddOrganizationDomain{
+ InstanceID: org.InstanceID,
+ OrgID: org.ID,
+ Domain: "domain1.com",
+ IsVerified: true,
+ IsPrimary: true,
+ }
+ err = domainRepo.Add(t.Context(), tx, domain1)
+ require.NoError(t, err)
+
+ domain2 := &domain.AddOrganizationDomain{
+ InstanceID: org.InstanceID,
+ OrgID: org.ID,
+ Domain: "domain2.com",
+ IsVerified: false,
+ IsPrimary: false,
+ }
+ err = domainRepo.Add(t.Context(), tx, domain2)
+ require.NoError(t, err)
+
+ t.Run("org by domain", func(t *testing.T) {
+ orgRepo := orgRepo.LoadDomains()
+
+ returnedOrg, err := orgRepo.Get(t.Context(), tx,
+ database.WithCondition(
+ database.And(
+ orgRepo.InstanceIDCondition(instanceId),
+ orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
+ ),
+ ),
+ )
+ require.NoError(t, err)
+ assert.Equal(t, org.ID, returnedOrg.ID)
+ assert.Len(t, returnedOrg.Domains, 2)
+ })
+
+ t.Run("ensure org by domain works without LoadDomains", func(t *testing.T) {
+ returnedOrg, err := orgRepo.Get(t.Context(), tx,
+ database.WithCondition(
+ database.And(
+ orgRepo.InstanceIDCondition(instanceId),
+ orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
+ ),
+ ),
+ )
+ require.NoError(t, err)
+ assert.Equal(t, org.ID, returnedOrg.ID)
+ assert.Len(t, returnedOrg.Domains, 0)
+ })
+ })
+}
diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go
index c5b9ff81f09..af7d3b08752 100644
--- a/backend/v3/storage/database/repository/repository.go
+++ b/backend/v3/storage/database/repository/repository.go
@@ -4,10 +4,6 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
-type repository struct {
- client database.QueryExecutor
-}
-
func writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
diff --git a/backend/v3/storage/database/repository/user.go b/backend/v3/storage/database/repository/user.go
index 5953af78572..8527767c443 100644
--- a/backend/v3/storage/database/repository/user.go
+++ b/backend/v3/storage/database/repository/user.go
@@ -12,16 +12,10 @@ const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_a
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
` FROM users_view users`
-type user struct {
- repository
-}
+type user struct{}
-func UserRepository(client database.QueryExecutor) domain.UserRepository {
- return &user{
- repository: repository{
- client: client,
- },
- }
+func UserRepository() domain.UserRepository {
+ return new(user)
}
var _ domain.UserRepository = (*user)(nil)
@@ -31,17 +25,17 @@ var _ domain.UserRepository = (*user)(nil)
// -------------------------------------------------------------
// Human implements [domain.UserRepository].
-func (u *user) Human() domain.HumanRepository {
+func (u user) Human() domain.HumanRepository {
return &userHuman{user: u}
}
// Machine implements [domain.UserRepository].
-func (u *user) Machine() domain.MachineRepository {
+func (u user) Machine() domain.MachineRepository {
return &userMachine{user: u}
}
// List implements [domain.UserRepository].
-func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
+func (u user) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (users []*domain.User, err error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -54,7 +48,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
options.WriteLimit(&builder)
options.WriteOffset(&builder)
- rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
+ rows, err := client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
@@ -79,7 +73,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
}
// Get implements [domain.UserRepository].
-func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
+func (u user) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.User, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -92,7 +86,7 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
options.WriteLimit(&builder)
options.WriteOffset(&builder)
- return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
+ return scanUser(client.QueryRow(ctx, builder.String(), builder.Args()...))
}
const (
@@ -105,7 +99,7 @@ const (
)
// Create implements [domain.UserRepository].
-func (u *user) Create(ctx context.Context, user *domain.User) error {
+func (u user) Create(ctx context.Context, client database.QueryExecutor, user *domain.User) error {
builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
@@ -116,15 +110,15 @@ func (u *user) Create(ctx context.Context, user *domain.User) error {
builder.WriteString(createMachineStmt)
builder.AppendArgs(trait.Description)
}
- return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
+ return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
}
// Delete implements [domain.UserRepository].
-func (u *user) Delete(ctx context.Context, condition database.Condition) error {
+func (u user) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error {
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
writeCondition(&builder, condition)
- _, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
+ _, err := client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
diff --git a/backend/v3/storage/database/repository/user_human.go b/backend/v3/storage/database/repository/user_human.go
index ae7643c53c8..6bacc497d11 100644
--- a/backend/v3/storage/database/repository/user_human.go
+++ b/backend/v3/storage/database/repository/user_human.go
@@ -13,7 +13,7 @@ import (
// -------------------------------------------------------------
type userHuman struct {
- *user
+ user
}
var _ domain.HumanRepository = (*userHuman)(nil)
@@ -21,14 +21,14 @@ var _ domain.HumanRepository = (*userHuman)(nil)
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
// GetEmail implements [domain.HumanRepository].
-func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
+func (u *userHuman) GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*domain.Email, error) {
var email domain.Email
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
writeCondition(&builder, condition)
- err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
+ err := client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
&email.VerifiedAt,
)
@@ -39,7 +39,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
}
// Update implements [domain.HumanRepository].
-func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
+func (h userHuman) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
@@ -47,7 +47,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha
stmt := builder.String()
- _, err := h.client.Exec(ctx, stmt, builder.Args()...)
+ _, err := client.Exec(ctx, stmt, builder.Args()...)
return err
}
diff --git a/backend/v3/storage/database/repository/user_machine.go b/backend/v3/storage/database/repository/user_machine.go
index 2bda09d0e5d..82fe87ba5c0 100644
--- a/backend/v3/storage/database/repository/user_machine.go
+++ b/backend/v3/storage/database/repository/user_machine.go
@@ -8,7 +8,7 @@ import (
)
type userMachine struct {
- *user
+ user
}
var _ domain.MachineRepository = (*userMachine)(nil)
@@ -18,14 +18,14 @@ var _ domain.MachineRepository = (*userMachine)(nil)
// -------------------------------------------------------------
// Update implements [domain.MachineRepository].
-func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
+func (m userMachine) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
m.writeReturning()
- _, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
+ _, err := client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go
index 2858feae434..77af573f5ee 100644
--- a/backend/v3/storage/database/statement.go
+++ b/backend/v3/storage/database/statement.go
@@ -1,6 +1,7 @@
package database
import (
+ "encoding/hex"
"strconv"
"strings"
)
@@ -20,8 +21,16 @@ type StatementBuilder struct {
existingArgs map[any]string
}
+type argWriter interface {
+ WriteArg(builder *StatementBuilder)
+}
+
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
func (b *StatementBuilder) WriteArg(arg any) {
+ if writer, ok := arg.(argWriter); ok {
+ writer.WriteArg(b)
+ return
+ }
b.WriteString(b.AppendArg(arg))
}
@@ -41,7 +50,13 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
- if placeholder, ok := b.existingArgs[arg]; ok {
+ // the key is used to work around the following panic:
+ // runtime error: hash of unhashable type []uint8
+ key := arg
+ if argBytes, ok := arg.([]uint8); ok {
+ key = `\\bytes-` + hex.EncodeToString(argBytes)
+ }
+ if placeholder, ok := b.existingArgs[key]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok {
@@ -50,7 +65,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
b.args = append(b.args, arg)
placeholder = "$" + strconv.Itoa(len(b.args))
- b.existingArgs[arg] = placeholder
+ b.existingArgs[key] = placeholder
return placeholder
}
diff --git a/backend/v3/storage/database/statement_test.go b/backend/v3/storage/database/statement_test.go
new file mode 100644
index 00000000000..147043e8c28
--- /dev/null
+++ b/backend/v3/storage/database/statement_test.go
@@ -0,0 +1,144 @@
+package database
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestStatementBuilder_AppendArg(t *testing.T) {
+ t.Run("same arg returns same placeholder", func(t *testing.T) {
+ var b StatementBuilder
+ placeholder1 := b.AppendArg("same")
+ placeholder2 := b.AppendArg("same")
+ assert.Equal(t, placeholder1, placeholder2)
+ assert.Len(t, b.Args(), 1)
+ assert.Len(t, b.existingArgs, 1)
+ })
+
+ t.Run("same arg different types", func(t *testing.T) {
+ var b StatementBuilder
+ placeholder1 := b.AppendArg("same")
+ placeholder2 := b.AppendArg([]byte("same"))
+ placeholder3 := b.AppendArg("same")
+ assert.NotEqual(t, placeholder1, placeholder2)
+ assert.Equal(t, placeholder1, placeholder3)
+ assert.Len(t, b.Args(), 2)
+ assert.Len(t, b.existingArgs, 2)
+ })
+
+ t.Run("Instruction args are always different", func(t *testing.T) {
+ var b StatementBuilder
+ placeholder1 := b.AppendArg(DefaultInstruction)
+ placeholder2 := b.AppendArg(DefaultInstruction)
+ assert.Equal(t, placeholder1, placeholder2)
+ assert.Len(t, b.Args(), 0)
+ assert.Len(t, b.existingArgs, 0)
+ })
+}
+
+func TestStatementBuilder_AppendArgs(t *testing.T) {
+ t.Run("same arg returns same placeholder", func(t *testing.T) {
+ var b StatementBuilder
+ b.AppendArgs("same", "same")
+ assert.Len(t, b.Args(), 1)
+ assert.Len(t, b.existingArgs, 1)
+ })
+
+ t.Run("same arg different types", func(t *testing.T) {
+ var b StatementBuilder
+ b.AppendArgs("same", []byte("same"), "same")
+ assert.Len(t, b.Args(), 2)
+ assert.Len(t, b.existingArgs, 2)
+ })
+
+ t.Run("Instruction args are always different", func(t *testing.T) {
+ var b StatementBuilder
+ b.AppendArgs(DefaultInstruction, DefaultInstruction)
+ assert.Len(t, b.Args(), 0)
+ assert.Len(t, b.existingArgs, 0)
+ })
+}
+
+func TestStatementBuilder_WriteArg(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ arg any
+ wantSQL string
+ wantArg []any
+ }{
+ {
+ name: "primitive arg",
+ arg: "test",
+ wantSQL: "$1",
+ wantArg: []any{"test"},
+ },
+ {
+ name: "argWriter arg",
+ arg: SHA256Value("wrapped"),
+ wantSQL: "SHA256($1)",
+ wantArg: []any{"wrapped"},
+ },
+ {
+ name: "Instruction arg",
+ arg: DefaultInstruction,
+ wantSQL: "DEFAULT",
+ wantArg: []any{},
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ var b StatementBuilder
+ b.WriteArg(test.arg)
+ assert.Equal(t, test.wantSQL, b.String())
+ require.Len(t, b.Args(), len(test.wantArg))
+ for i := range test.wantArg {
+ assert.Equal(t, test.wantArg[i], b.Args()[i])
+ }
+ })
+ }
+}
+
+func TestStatementBuilder_WriteArgs(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ args []any
+ wantSQL string
+ wantArg []any
+ }{
+ {
+ name: "primitive args",
+ args: []any{"test", 123, true, uint32(123)},
+ wantSQL: "$1, $2, $3, $4",
+ wantArg: []any{"test", 123, true, uint32(123)},
+ },
+ {
+ name: "argWriter args",
+ args: []any{SHA256Value("wrapped"), LowerValue("ASDF")},
+ wantSQL: "SHA256($1), LOWER($2)",
+ wantArg: []any{"wrapped", "ASDF"},
+ },
+ {
+ name: "Instruction args",
+ args: []any{DefaultInstruction, NowInstruction, NullInstruction},
+ wantSQL: "DEFAULT, NOW(), NULL",
+ wantArg: []any{},
+ },
+ {
+ name: "mixed args",
+ args: []any{123, uint32(123), NowInstruction, NullInstruction, SHA256Value("wrapped"), LowerValue("ASDF")},
+ wantSQL: "$1, $2, NOW(), NULL, SHA256($3), LOWER($4)",
+ wantArg: []any{123, uint32(123), "wrapped", "ASDF"},
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ var b StatementBuilder
+ b.WriteArgs(test.args...)
+ assert.Equal(t, test.wantSQL, b.String())
+ require.Len(t, b.Args(), len(test.wantArg))
+ for i := range test.wantArg {
+ assert.Equal(t, test.wantArg[i], b.Args()[i])
+ }
+ })
+ }
+}
diff --git a/internal/query/projection/instance_domain_relational.go b/internal/query/projection/instance_domain_relational.go
index 0deb4e82a45..80606fda677 100644
--- a/internal/query/projection/instance_domain_relational.go
+++ b/internal/query/projection/instance_domain_relational.go
@@ -66,7 +66,7 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event event
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex)
}
- return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
+ return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
InstanceID: e.Aggregate().InstanceID,
Domain: e.Domain,
IsPrimary: gu.Ptr(false),
@@ -88,25 +88,15 @@ func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event events
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
+ domainRepo := repository.InstanceDomainRepository()
- condition := database.And(
- domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
- domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
- domainRepo.TypeCondition(domain.DomainTypeCustom),
- )
-
- _, err := domainRepo.Update(ctx,
- condition,
+ _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
+ database.And(
+ domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
+ domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
+ domainRepo.TypeCondition(domain.DomainTypeCustom),
+ ),
domainRepo.SetPrimary(),
- )
- if err != nil {
- return err
- }
- // we need to split the update into two statements because multiple events can have the same creation date
- // therefore we first do not set the updated_at timestamp
- _, err = domainRepo.Update(ctx,
- condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -123,8 +113,8 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eve
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
- _, err := domainRepo.Remove(ctx,
+ domainRepo := repository.InstanceDomainRepository()
+ _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
@@ -145,7 +135,7 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex)
}
- return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
+ return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
InstanceID: e.Aggregate().InstanceID,
Domain: e.Domain,
Type: domain.DomainTypeTrusted,
@@ -165,8 +155,8 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event ev
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
- _, err := domainRepo.Remove(ctx,
+ domainRepo := repository.InstanceDomainRepository()
+ _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
diff --git a/internal/query/projection/instance_relational.go b/internal/query/projection/instance_relational.go
index f79127bef86..3607d7986d6 100644
--- a/internal/query/projection/instance_relational.go
+++ b/internal/query/projection/instance_relational.go
@@ -74,7 +74,7 @@ func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{
+ return repository.InstanceRepository().Create(ctx, v3_sql.SQLTx(tx), &domain.Instance{
ID: e.Aggregate().ID,
Name: e.Name,
CreatedAt: e.CreationDate(),
@@ -93,8 +93,8 @@ func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Ev
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
- return p.updateInstance(ctx, event, repo, repo.SetName(e.Name))
+ repo := repository.InstanceRepository()
+ return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetName(e.Name))
}), nil
}
@@ -108,7 +108,7 @@ func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Eve
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- _, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID)
+ _, err := repository.InstanceRepository().Delete(ctx, v3_sql.SQLTx(tx), e.Aggregate().ID)
return err
}), nil
}
@@ -124,8 +124,8 @@ func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
- return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID))
+ repo := repository.InstanceRepository()
+ return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultOrg(e.OrgID))
}), nil
}
@@ -140,8 +140,8 @@ func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
- return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID))
+ repo := repository.InstanceRepository()
+ return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetIAMProject(e.ProjectID))
}), nil
}
@@ -156,8 +156,8 @@ func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event)
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
- return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
+ repo := repository.InstanceRepository()
+ return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
}), nil
}
@@ -172,18 +172,18 @@ func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
- repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
- return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language))
+ repo := repository.InstanceRepository()
+ return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultLanguage(e.Language))
}), nil
}
-func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
- _, err := repo.Update(ctx, event.Aggregate().ID, changes...)
+func (p *instanceRelationalProjection) updateInstance(ctx context.Context, tx database.Transaction, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
+ _, err := repo.Update(ctx, tx, event.Aggregate().ID, changes...)
if err != nil {
return err
}
- instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
+ instance, err := repo.Get(ctx, tx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
if err != nil {
return err
}
@@ -192,7 +192,7 @@ func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
- _, err = repo.Update(ctx,
+ _, err = repo.Update(ctx, tx,
event.Aggregate().ID,
repo.SetUpdatedAt(event.CreatedAt()),
)
diff --git a/internal/query/projection/org_domain_relational.go b/internal/query/projection/org_domain_relational.go
index 747a21a2cb0..2ac3dc44931 100644
--- a/internal/query/projection/org_domain_relational.go
+++ b/internal/query/projection/org_domain_relational.go
@@ -65,7 +65,7 @@ func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*ha
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex)
}
- return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{
+ return repository.OrganizationDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddOrganizationDomain{
InstanceID: e.Aggregate().InstanceID,
OrgID: e.Aggregate().ResourceOwner,
Domain: e.Domain,
@@ -85,23 +85,14 @@ func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event)
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
- condition := database.And(
- domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
- domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
- domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
- )
- _, err := domainRepo.Update(ctx,
- condition,
+ domainRepo := repository.OrganizationDomainRepository()
+ _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
+ database.And(
+ domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
+ domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
+ domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
+ ),
domainRepo.SetPrimary(),
- )
- if err != nil {
- return err
- }
- // we need to split the update into two statements because multiple events can have the same creation date
- // therefore we first do not set the updated_at timestamp
- _, err = domainRepo.Update(ctx,
- condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -118,8 +109,8 @@ func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (*
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
- _, err := domainRepo.Remove(ctx,
+ domainRepo := repository.OrganizationDomainRepository()
+ _, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
@@ -149,24 +140,15 @@ func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
- condition := database.And(
- domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
- domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
- domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
- )
+ domainRepo := repository.OrganizationDomainRepository()
- _, err := domainRepo.Update(ctx,
- condition,
+ _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
+ database.And(
+ domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
+ domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
+ domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
+ ),
domainRepo.SetValidationType(validationType),
- )
- if err != nil {
- return err
- }
- // we need to split the update into two statements because multiple events can have the same creation date
- // therefore we first do not set the updated_at timestamp
- _, err = domainRepo.Update(ctx,
- condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -183,28 +165,17 @@ func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) (
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex)
}
- domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
+ domainRepo := repository.OrganizationDomainRepository()
- condition := database.And(
- domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
- domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
- domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
- )
-
- _, err := domainRepo.Update(ctx,
- condition,
+ _, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
+ database.And(
+ domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
+ domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
+ domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
+ ),
domainRepo.SetVerified(),
domainRepo.SetUpdatedAt(e.CreationDate()),
)
- if err != nil {
- return err
- }
- // we need to split the update into two statements because multiple events can have the same creation date
- // therefore we first do not set the updated_at timestamp
- _, err = domainRepo.Update(ctx,
- condition,
- domainRepo.SetUpdatedAt(e.CreationDate()),
- )
return err
}), nil
}