refactor: database interaction and error handling (#10762)

This pull request introduces a significant refactoring of the database
interaction layer, focusing on improving explicitness, transactional
control, and error handling. The core change is the removal of the
stateful `QueryExecutor` from repository instances. Instead, it is now
passed as an argument to each method that interacts with the database.

This change makes transaction management more explicit and flexible, as
the same repository instance can be used with a database pool or a
specific transaction without needing to be re-instantiated.

### Key Changes

- **Explicit `QueryExecutor` Passing:**
- All repository methods (`Get`, `List`, `Create`, `Update`, `Delete`,
etc.) in `InstanceRepository`, `OrganizationRepository`,
`UserRepository`, and their sub-repositories now require a
`database.QueryExecutor` (e.g., a `*pgxpool.Pool` or `pgx.Tx`) as the
first argument.
- Repository constructors no longer accept a `QueryExecutor`. For
example, `repository.InstanceRepository(pool)` is now
`repository.InstanceRepository()`.

- **Enhanced Error Handling:**
- A new `database.MissingConditionError` is introduced to enforce
required query conditions, such as ensuring an `instance_id` is always
present in `UPDATE` and `DELETE` operations.
- The database error wrapper in the `postgres` package now correctly
identifies and wraps `pgx.ErrTooManyRows` and similar errors from the
`scany` library into a `database.MultipleRowsFoundError`.

- **Improved Database Conditions:**
- The `database.Condition` interface now includes a
`ContainsColumn(Column) bool` method. This allows for runtime checks to
ensure that critical filters (like `instance_id`) are included in a
query, preventing accidental cross-tenant data modification.
- A new `database.Exists()` condition has been added to support `EXISTS`
subqueries, enabling more complex filtering logic, such as finding an
organization that has a specific domain.

- **Repository and Interface Refactoring:**
- The method for loading related entities (e.g., domains for an
organization) has been changed from a boolean flag (`Domains(true)`) to
a more explicit, chainable method (`LoadDomains()`). This returns a new
repository instance configured to load the sub-resource, promoting
immutability.
- The custom `OrgIdentifierCondition` has been removed in favor of using
the standard `database.Condition` interface, simplifying the API.

- **Code Cleanup and Test Updates:**
  - Unnecessary struct embeddings and metadata have been removed.
- All integration and repository tests have been updated to reflect the
new method signatures, passing the database pool or transaction object
explicitly.
- New tests have been added to cover the new `ExistsDomain`
functionality and other enhancements.

These changes make the data access layer more robust, predictable, and
easier to work with, especially in the context of database transactions.
This commit is contained in:
Silvan
2025-09-24 12:12:31 +02:00
committed by GitHub
parent 09d09ab337
commit cccfc816f6
53 changed files with 3900 additions and 1430 deletions

View File

@@ -69,6 +69,8 @@ type instanceConditions interface {
IDCondition(instanceID string) database.Condition
// NameCondition returns a filter on the name field.
NameCondition(op database.TextOperation, name string) database.Condition
// ExistsDomain returns a filter on the instance domains.
ExistsDomain(cond database.Condition) database.Condition
}
// instanceChanges define all the changes for the instance table.
@@ -99,17 +101,16 @@ type InstanceRepository interface {
// Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Instance, error)
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error)
Create(ctx context.Context, client database.QueryExecutor, instance *Instance) error
Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error)
// Domains returns the domain sub repository for the instance.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
Domains(shouldLoad bool) InstanceDomainRepository
// LoadDomains loads the domains of the given instance.
// If it is called the [Instance].Domains field will be set on future calls to Get or List.
LoadDomains() InstanceRepository
}
type CreateInstance struct {

View File

@@ -65,15 +65,15 @@ type InstanceDomainRepository interface {
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*InstanceDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*InstanceDomain, error)
// Add adds a new domain to the instance.
Add(ctx context.Context, domain *AddInstanceDomain) error
Add(ctx context.Context, client database.QueryExecutor, domain *AddInstanceDomain) error
// Update updates an existing domain in the instance.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the instance.
Remove(ctx context.Context, condition database.Condition) (int64, error)
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
}

View File

@@ -26,13 +26,6 @@ type Organization struct {
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
}
// OrgIdentifierCondition is used to help specify a single Organization,
// it will either be used as the organization ID or organization name,
// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
type OrgIdentifierCondition interface {
database.Condition
}
// organizationColumns define all the columns of the instance table.
type organizationColumns interface {
// IDColumn returns the column for the id field.
@@ -52,13 +45,15 @@ type organizationColumns interface {
// organizationConditions define all the conditions for the instance table.
type organizationConditions interface {
// IDCondition returns an equal filter on the id field.
IDCondition(instanceID string) OrgIdentifierCondition
IDCondition(instanceID string) database.Condition
// NameCondition returns a filter on the name field.
NameCondition(name string) OrgIdentifierCondition
NameCondition(op database.TextOperation, name string) database.Condition
// InstanceIDCondition returns a filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// StateCondition returns a filter on the name field.
StateCondition(state OrgState) database.Condition
// ExistsDomain returns a filter on the organizations domains.
ExistsDomain(cond database.Condition) database.Condition
}
// organizationChanges define all the changes for the instance table.
@@ -75,17 +70,16 @@ type OrganizationRepository interface {
organizationConditions
organizationChanges
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Organization, error)
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Organization, error)
Create(ctx context.Context, instance *Organization) error
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
Create(ctx context.Context, client database.QueryExecutor, org *Organization) error
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
// Domains returns the domain sub repository for the organization.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) OrganizationDomainRepository
// LoadDomains loads the domains of the given organizations.
// If it is called the [Organization].Domains field will be set on future calls to Get or List.
LoadDomains() OrganizationRepository
}
type CreateOrganization struct {

View File

@@ -70,15 +70,15 @@ type OrganizationDomainRepository interface {
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*OrganizationDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*OrganizationDomain, error)
// Add adds a new domain to the organization.
Add(ctx context.Context, domain *AddOrganizationDomain) error
Add(ctx context.Context, client database.QueryExecutor, domain *AddOrganizationDomain) error
// Update updates an existing domain in the organization.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the organization.
Remove(ctx context.Context, condition database.Condition) (int64, error)
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
}

View File

@@ -57,13 +57,13 @@ type UserRepository interface {
userConditions
userChanges
// Get returns a user based on the given condition.
Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*User, error)
// List returns a list of users based on the given condition.
List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*User, error)
// Create creates a new user.
Create(ctx context.Context, user *User) error
Create(ctx context.Context, client database.QueryExecutor, user *User) error
// Delete removes users based on the given condition.
Delete(ctx context.Context, condition database.Condition) error
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error
// Human returns the [HumanRepository].
Human() HumanRepository
// Machine returns the [MachineRepository].
@@ -143,9 +143,9 @@ type HumanRepository interface {
humanChanges
// Get returns an email based on the given condition.
GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*Email, error)
// Update updates human users based on the given condition and changes.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
}
// machineColumns define all the columns of the machine table which inherits the user table.
@@ -172,7 +172,7 @@ type machineChanges interface {
// MachineRepository is the interface for the machine repository it inherits the user repository.
type MachineRepository interface {
// Update updates machine users based on the given condition and changes.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
machineColumns
machineConditions

View File

@@ -1,9 +1,14 @@
package database
import "slices"
// Change represents a change to a column in a database table.
// Its written in the SET clause of an UPDATE statement.
type Change interface {
// Write writes the change to the given statement builder.
Write(builder *StatementBuilder)
// IsOnColumn checks if the change is on the given column.
IsOnColumn(col Column) bool
}
type change[V Value] struct {
@@ -13,6 +18,8 @@ type change[V Value] struct {
var _ Change = (*change[string])(nil)
// NewChange creates a new Change for the given column and value.
// If you want to set a column to NULL, use [NewChangePtr].
func NewChange[V Value](col Column, value V) Change {
return &change[V]{
column: col,
@@ -20,6 +27,8 @@ func NewChange[V Value](col Column, value V) Change {
}
}
// NewChangePtr creates a new Change for the given column and value pointer.
// If the value pointer is nil, the column will be set to NULL.
func NewChangePtr[V Value](col Column, value *V) Change {
if value == nil {
return NewChange(col, NullInstruction)
@@ -34,19 +43,31 @@ func (c change[V]) Write(builder *StatementBuilder) {
builder.WriteArg(c.value)
}
// IsOnColumn implements [Change].
func (c change[V]) IsOnColumn(col Column) bool {
return c.column.Equals(col)
}
type Changes []Change
func NewChanges(cols ...Change) Change {
return Changes(cols)
}
// IsOnColumn implements [Change].
func (c Changes) IsOnColumn(col Column) bool {
return slices.ContainsFunc(c, func(change Change) bool {
return change.IsOnColumn(col)
})
}
// Write implements [Change].
func (m Changes) Write(builder *StatementBuilder) {
for i, col := range m {
for i, change := range m {
if i > 0 {
builder.WriteString(", ")
}
col.Write(builder)
change.Write(builder)
}
}

View File

@@ -0,0 +1,68 @@
package database
import (
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestChangeWrite(t *testing.T) {
type want struct {
stmt string
args []any
}
for _, test := range []struct {
name string
change Change
want want
}{
{
name: "change",
change: NewChange(NewColumn("table", "column"), "value"),
want: want{
stmt: "column = $1",
args: []any{"value"},
},
},
{
name: "change ptr to null",
change: NewChangePtr[int](NewColumn("table", "column"), nil),
want: want{
stmt: "column = NULL",
args: nil,
},
},
{
name: "change ptr to value",
change: NewChangePtr(NewColumn("table", "column"), gu.Ptr(42)),
want: want{
stmt: "column = $1",
args: []any{42},
},
},
{
name: "multiple changes",
change: NewChanges(
NewChange(NewColumn("table", "column1"), "value1"),
NewChangePtr[int](NewColumn("table", "column2"), nil),
NewChange(NewColumn("table", "column3"), 123),
),
want: want{
stmt: "column1 = $1, column2 = NULL, column3 = $2",
args: []any{"value1", 123},
},
},
} {
t.Run(test.name, func(t *testing.T) {
var builder StatementBuilder
test.change.Write(&builder)
assert.Equal(t, test.want.stmt, builder.String())
require.Len(t, builder.Args(), len(test.want.args))
for i, arg := range test.want.args {
assert.Equal(t, arg, builder.Args()[i])
}
})
}
}

View File

@@ -3,8 +3,9 @@ package database
type Columns []Column
// WriteQualified implements [Column].
func (m Columns) WriteQualified(builder *StatementBuilder) {
for i, col := range m {
// Columns are separated by ", ".
func (c Columns) WriteQualified(builder *StatementBuilder) {
for i, col := range c {
if i > 0 {
builder.WriteString(", ")
}
@@ -13,8 +14,9 @@ func (m Columns) WriteQualified(builder *StatementBuilder) {
}
// WriteUnqualified implements [Column].
func (m Columns) WriteUnqualified(builder *StatementBuilder) {
for i, col := range m {
// Columns are separated by ", ".
func (c Columns) WriteUnqualified(builder *StatementBuilder) {
for i, col := range c {
if i > 0 {
builder.WriteString(", ")
}
@@ -22,11 +24,33 @@ func (m Columns) WriteUnqualified(builder *StatementBuilder) {
}
}
// Equals implements [Column].
func (c Columns) Equals(col Column) bool {
if col == nil {
return c == nil
}
other, ok := col.(Columns)
if !ok || len(other) != len(c) {
return false
}
for i, col := range c {
if !col.Equals(other[i]) {
return false
}
}
return true
}
var _ Column = (Columns)(nil)
// Column represents a column in a database table.
type Column interface {
// Write(builder *StatementBuilder)
// WriteQualified writes the column with the table name as prefix.
WriteQualified(builder *StatementBuilder)
// WriteUnqualified writes the column without the table name as prefix.
WriteUnqualified(builder *StatementBuilder)
// Equals checks if two columns are equal.
Equals(col Column) bool
}
type column struct {
@@ -35,7 +59,7 @@ type column struct {
}
func NewColumn(table, name string) Column {
return column{table: table, name: name}
return &column{table: table, name: name}
}
// WriteQualified implements [Column].
@@ -51,35 +75,69 @@ func (c column) WriteUnqualified(builder *StatementBuilder) {
builder.WriteString(c.name)
}
var _ Column = (*column)(nil)
// Equals implements [Column].
func (c *column) Equals(col Column) bool {
if col == nil {
return c == nil
}
toMatch, ok := col.(*column)
if !ok {
return false
}
return c.table == toMatch.table && c.name == toMatch.name
}
// // ignoreCaseColumn represents two database columns, one for the
// // original value and one for the lower case value.
// type ignoreCaseColumn interface {
// Column
// WriteIgnoreCase(builder *StatementBuilder)
// }
// LowerColumn returns a column that represents LOWER(col).
func LowerColumn(col Column) Column {
return &functionColumn{fn: functionLower, col: col}
}
// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
// return ignoreCaseCol{
// column: col,
// suffix: suffix,
// }
// }
// SHA256Column returns a column that represents SHA256(col).
func SHA256Column(col Column) Column {
return &functionColumn{fn: functionSHA256, col: col}
}
// type ignoreCaseCol struct {
// column Column
// suffix string
// }
type functionColumn struct {
fn function
col Column
}
// // WriteIgnoreCase implements [ignoreCaseColumn].
// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
// c.column.WriteQualified(builder)
// builder.WriteString(c.suffix)
// }
type function string
// // WriteQualified implements [ignoreCaseColumn].
// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
// c.column.WriteQualified(builder)
// builder.WriteString(c.suffix)
// }
const (
_ function = ""
functionLower function = "LOWER"
functionSHA256 function = "SHA256"
)
// WriteQualified implements [Column].
func (c functionColumn) WriteQualified(builder *StatementBuilder) {
builder.Grow(len(c.fn) + 2)
builder.WriteString(string(c.fn))
builder.WriteRune('(')
c.col.WriteQualified(builder)
builder.WriteRune(')')
}
// WriteUnqualified implements [Column].
func (c functionColumn) WriteUnqualified(builder *StatementBuilder) {
builder.Grow(len(c.fn) + 2)
builder.WriteString(string(c.fn))
builder.WriteRune('(')
c.col.WriteUnqualified(builder)
builder.WriteRune(')')
}
// Equals implements [Column].
func (c *functionColumn) Equals(col Column) bool {
if col == nil {
return c == nil
}
toMatch, ok := col.(*functionColumn)
if !ok || toMatch.fn != c.fn {
return false
}
return c.col.Equals(toMatch.col)
}
var _ Column = (*functionColumn)(nil)

View File

@@ -0,0 +1,212 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestWriteUnqualified(t *testing.T) {
for _, tests := range []struct {
name string
column Column
expected string
}{
{
name: "column",
column: NewColumn("table", "column"),
expected: "column",
},
{
name: "columns",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
expected: "column1, column2",
},
{
name: "function column",
column: SHA256Column(NewColumn("table", "column")),
expected: "SHA256(column)",
},
} {
t.Run(tests.name, func(t *testing.T) {
var builder StatementBuilder
tests.column.WriteUnqualified(&builder)
assert.Equal(t, tests.expected, builder.String())
})
}
}
func TestWriteQualified(t *testing.T) {
for _, tests := range []struct {
name string
column Column
expected string
}{
{
name: "column",
column: NewColumn("table", "column"),
expected: "table.column",
},
{
name: "columns",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
expected: "table.column1, table.column2",
},
{
name: "function column",
column: SHA256Column(NewColumn("table", "column")),
expected: "SHA256(table.column)",
},
} {
t.Run(tests.name, func(t *testing.T) {
var builder StatementBuilder
tests.column.WriteQualified(&builder)
assert.Equal(t, tests.expected, builder.String())
})
}
}
func TestEquals(t *testing.T) {
for _, tests := range []struct {
name string
column Column
toCheck Column
expected bool
}{
{
name: "column equal",
column: NewColumn("table", "column"),
toCheck: NewColumn("table", "column"),
expected: true,
},
{
name: "column nil check",
column: NewColumn("table", "column"),
toCheck: nil,
expected: false,
},
{
name: "column both nil",
column: (*column)(nil),
toCheck: nil,
expected: true,
},
{
name: "column not equal (different name)",
column: NewColumn("table", "column"),
toCheck: NewColumn("table", "column2"),
expected: false,
},
{
name: "column not equal (different type)",
column: NewColumn("table", "column"),
toCheck: SHA256Column(NewColumn("table", "column")),
expected: false,
},
{
name: "columns equal",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
toCheck: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
expected: true,
},
{
name: "columns nil check",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
toCheck: nil,
expected: false,
},
{
name: "columns both nil",
column: Columns(nil),
toCheck: nil,
expected: true,
},
{
name: "columns not equal (different type)",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
toCheck: NewColumn("table", "column1"),
expected: false,
},
{
name: "columns not equal (different length)",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
toCheck: Columns{
NewColumn("table", "column1"),
},
expected: false,
},
{
name: "columns not equal (different order)",
column: Columns{
NewColumn("table", "column1"),
NewColumn("table", "column2"),
},
toCheck: Columns{
NewColumn("table", "column2"),
NewColumn("table", "column1"),
},
expected: false,
},
{
name: "function column equal",
column: SHA256Column(NewColumn("table", "column")),
toCheck: SHA256Column(NewColumn("table", "column")),
expected: true,
},
{
name: "function nil check",
column: SHA256Column(NewColumn("table", "column")),
toCheck: nil,
expected: false,
},
{
name: "function both nil",
column: (*functionColumn)(nil),
toCheck: nil,
expected: true,
},
{
name: "function column not equal (different function)",
column: SHA256Column(NewColumn("table", "column")),
toCheck: LowerColumn(NewColumn("table", "column")),
expected: false,
},
{
name: "function column not equal (different inner column)",
column: SHA256Column(NewColumn("table", "column")),
toCheck: SHA256Column(NewColumn("table", "column2")),
expected: false,
},
{
name: "function column not equal (different type)",
column: SHA256Column(NewColumn("table", "column")),
toCheck: NewColumn("table", "column2"),
expected: false,
},
} {
t.Run(tests.name, func(t *testing.T) {
assert.Equal(t, tests.expected, tests.column.Equals(tests.toCheck))
})
}
}

View File

@@ -4,6 +4,9 @@ package database
// Its written after the WHERE keyword in a SQL statement.
type Condition interface {
Write(builder *StatementBuilder)
// IsRestrictingColumn is used to check if the condition filters for a specific column.
// It acts as a save guard database operations that should be specific on the given column.
IsRestrictingColumn(col Column) bool
}
type and struct {
@@ -11,7 +14,7 @@ type and struct {
}
// Write implements [Condition].
func (a *and) Write(builder *StatementBuilder) {
func (a and) Write(builder *StatementBuilder) {
if len(a.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
@@ -29,6 +32,16 @@ func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
// IsRestrictingColumn implements [Condition].
func (a and) IsRestrictingColumn(col Column) bool {
for _, condition := range a.conditions {
if condition.IsRestrictingColumn(col) {
return true
}
}
return false
}
var _ Condition = (*and)(nil)
type or struct {
@@ -36,7 +49,7 @@ type or struct {
}
// Write implements [Condition].
func (o *or) Write(builder *StatementBuilder) {
func (o or) Write(builder *StatementBuilder) {
if len(o.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
@@ -54,6 +67,17 @@ func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
// IsRestrictingColumn implements [Condition].
// It returns true only if all conditions are restricting the given column.
func (o or) IsRestrictingColumn(col Column) bool {
for _, condition := range o.conditions {
if !condition.IsRestrictingColumn(col) {
return false
}
}
return true
}
var _ Condition = (*or)(nil)
type isNull struct {
@@ -61,7 +85,7 @@ type isNull struct {
}
// Write implements [Condition].
func (i *isNull) Write(builder *StatementBuilder) {
func (i isNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NULL")
}
@@ -71,6 +95,12 @@ func IsNull(column Column) *isNull {
return &isNull{column: column}
}
// IsRestrictingColumn implements [Condition].
// It returns false because it cannot be used for restricting a column.
func (i isNull) IsRestrictingColumn(col Column) bool {
return false
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
@@ -78,7 +108,7 @@ type isNotNull struct {
}
// Write implements [Condition].
func (i *isNotNull) Write(builder *StatementBuilder) {
func (i isNotNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NOT NULL")
}
@@ -88,43 +118,122 @@ func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
// IsRestrictingColumn implements [Condition].
// It returns false because it cannot be used for restricting a column.
func (i isNotNull) IsRestrictingColumn(col Column) bool {
return false
}
var _ Condition = (*isNotNull)(nil)
type valueCondition func(builder *StatementBuilder)
type valueCondition struct {
write func(builder *StatementBuilder)
col Column
}
// NewTextCondition creates a condition that compares a text column with a value.
func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeTextOperation(builder, col, op, value)
})
// If you want to use ignore case operations, consider using [NewTextIgnoreCaseCondition].
func NewTextCondition[T Text](col Column, op TextOperation, value T) Condition {
return valueCondition{
col: col,
write: func(builder *StatementBuilder) {
writeTextOperation[T](builder, col, op, value)
},
}
}
// NewTextIgnoreCaseCondition creates a condition that compares a text column with a value, ignoring case by lowercasing both.
func NewTextIgnoreCaseCondition[T Text](col Column, op TextOperation, value T) Condition {
return valueCondition{
col: col,
write: func(builder *StatementBuilder) {
writeTextOperation[T](builder, LowerColumn(col), op, LowerValue(value))
},
}
}
// NewDateCondition creates a condition that compares a numeric column with a value.
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeNumberOperation(builder, col, op, value)
})
return valueCondition{
col: col,
write: func(builder *StatementBuilder) {
writeNumberOperation[V](builder, col, op, value)
},
}
}
// NewDateCondition creates a condition that compares a boolean column with a value.
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeBooleanOperation(builder, col, value)
})
return valueCondition{
col: col,
write: func(builder *StatementBuilder) {
writeBooleanOperation[V](builder, col, value)
},
}
}
// NewBytesCondition creates a condition that compares a BYTEA column with a value.
func NewBytesCondition[V Bytes](col Column, op BytesOperation, value any) Condition {
return valueCondition{
col: col,
write: func(builder *StatementBuilder) {
writeBytesOperation[V](builder, col, op, value)
},
}
}
// NewColumnCondition creates a condition that compares two columns on equality.
func NewColumnCondition(col1, col2 Column) Condition {
return valueCondition(func(builder *StatementBuilder) {
return valueCondition{
col: col1,
write: func(builder *StatementBuilder) {
col1.WriteQualified(builder)
builder.WriteString(" = ")
col2.WriteQualified(builder)
})
},
}
}
// Write implements [Condition].
func (c valueCondition) Write(builder *StatementBuilder) {
c(builder)
c.write(builder)
}
// IsRestrictingColumn implements [Condition].
func (i valueCondition) IsRestrictingColumn(col Column) bool {
return i.col.Equals(col)
}
var _ Condition = (*valueCondition)(nil)
// existsCondition is a helper to write an EXISTS (SELECT 1 FROM <table> WHERE <condition>) clause.
// It implements Condition so it can be composed with other conditions using And/Or.
type existsCondition struct {
table string
condition Condition
}
// Exists creates a condition that checks for the existence of rows in a subquery.
func Exists(table string, condition Condition) Condition {
return &existsCondition{
table: table,
condition: condition,
}
}
// Write implements [Condition].
func (e existsCondition) Write(builder *StatementBuilder) {
builder.WriteString(" EXISTS (SELECT 1 FROM ")
builder.WriteString(e.table)
builder.WriteString(" WHERE ")
e.condition.Write(builder)
builder.WriteString(")")
}
// IsRestrictingColumn implements [Condition].
func (e existsCondition) IsRestrictingColumn(col Column) bool {
// Forward to the inner condition so safety checks (like instance_id presence) can still work.
return e.condition.IsRestrictingColumn(col)
}
var _ Condition = (*existsCondition)(nil)

View File

@@ -0,0 +1,248 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrite(t *testing.T) {
type want struct {
stmt string
args []any
}
for _, test := range []struct {
name string
cond Condition
want want
}{
{
name: "no and condition",
cond: And(),
want: want{
stmt: "",
args: nil,
},
},
{
name: "one and condition",
cond: And(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
),
want: want{
stmt: "table.column1 = other_table.column2",
args: nil,
},
},
{
name: "multiple and condition",
cond: And(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
),
want: want{
stmt: "(table.column1 = other_table.column2 AND table.column3 = other_table.column4)",
args: nil,
},
},
{
name: "no or condition",
cond: Or(),
want: want{
stmt: "",
args: nil,
},
},
{
name: "one or condition",
cond: Or(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
),
want: want{
stmt: "table.column1 = other_table.column2",
args: nil,
},
},
{
name: "multiple or condition",
cond: Or(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
),
want: want{
stmt: "(table.column1 = other_table.column2 OR table.column3 = other_table.column4)",
args: nil,
},
},
{
name: "is null condition",
cond: IsNull(NewColumn("table", "column1")),
want: want{
stmt: "table.column1 IS NULL",
args: nil,
},
},
{
name: "is not null condition",
cond: IsNotNull(NewColumn("table", "column1")),
want: want{
stmt: "table.column1 IS NOT NULL",
args: nil,
},
},
{
name: "text condition",
cond: NewTextCondition(NewColumn("table", "column1"), TextOperationEqual, "some text"),
want: want{
stmt: "table.column1 = $1",
args: []any{"some text"},
},
},
{
name: "text ignore case condition",
cond: NewTextIgnoreCaseCondition(NewColumn("table", "column1"), TextOperationNotEqual, "some TEXT"),
want: want{
stmt: "LOWER(table.column1) <> LOWER($1)",
args: []any{"some TEXT"},
},
},
{
name: "number condition",
cond: NewNumberCondition(NewColumn("table", "column1"), NumberOperationEqual, 42),
want: want{
stmt: "table.column1 = $1",
args: []any{42},
},
},
{
name: "boolean condition",
cond: NewBooleanCondition(NewColumn("table", "column1"), true),
want: want{
stmt: "table.column1 = $1",
args: []any{true},
},
},
{
name: "bytes condition",
cond: NewBytesCondition[[]byte](NewColumn("table", "column1"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}),
want: want{
stmt: "table.column1 = $1",
args: []any{[]byte{0x01, 0x02, 0x03}},
},
},
{
name: "column condition",
cond: NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
want: want{
stmt: "table.column1 = other_table.column2",
args: nil,
},
},
{
name: "exists condition",
cond: Exists("table", And(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
)),
want: want{
stmt: " EXISTS (SELECT 1 FROM table WHERE (table.column1 = other_table.column2 AND table.column3 = other_table.column4))",
args: nil,
},
},
} {
t.Run(test.name, func(t *testing.T) {
var builder StatementBuilder
test.cond.Write(&builder)
assert.Equal(t, test.want.stmt, builder.String())
require.Len(t, builder.Args(), len(test.want.args))
for i, arg := range test.want.args {
assert.Equal(t, arg, builder.Args()[i])
}
})
}
}
func TestIsRestrictingColumn(t *testing.T) {
for _, test := range []struct {
name string
col Column
cond Condition
want bool
}{
{
name: "and with restricting column",
col: NewColumn("table", "column1"),
cond: And(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
),
want: true,
},
{
name: "and without restricting column",
col: NewColumn("table", "column1"),
cond: And(
NewColumnCondition(NewColumn("table", "column2"), NewColumn("other_table", "column3")),
IsNull(NewColumn("table", "column4")),
IsNotNull(NewColumn("table", "column5")),
),
want: false,
},
{
name: "or with restricting column",
col: NewColumn("table", "column1"),
cond: Or(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
),
want: true,
},
{
name: "or without restricting column",
col: NewColumn("table", "column1"),
cond: Or(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
IsNotNull(NewColumn("table", "column4")),
IsNull(NewColumn("table", "column5")),
),
want: false,
},
{
name: "is null never restricts",
col: NewColumn("table", "column1"),
cond: IsNull(NewColumn("table", "column1")),
want: false,
},
{
name: "is not null never restricts",
col: NewColumn("table", "column1"),
cond: IsNotNull(NewColumn("table", "column1")),
want: false,
},
{
name: "exists with restricting column",
col: NewColumn("table", "column1"),
cond: Exists("table", And(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
)),
want: true,
},
{
name: "exists without restricting column",
col: NewColumn("table", "column1"),
cond: Exists("table", Or(
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
IsNotNull(NewColumn("table", "column4")),
IsNull(NewColumn("table", "column5")),
)),
want: false,
},
} {
t.Run(test.name, func(t *testing.T) {
isRestricting := test.cond.IsRestrictingColumn(test.col)
assert.Equal(t, test.want, isRestricting)
})
}
}

View File

@@ -10,7 +10,7 @@ type Pool interface {
QueryExecutor
Migrator
Acquire(ctx context.Context) (Client, error)
Acquire(ctx context.Context) (Connection, error)
Close(ctx context.Context) error
Ping(ctx context.Context) error
@@ -22,8 +22,8 @@ type PoolTest interface {
MigrateTest(ctx context.Context) error
}
// Client is a single database connection which can be released back to the pool.
type Client interface {
// Connection is a single database connection which can be released back to the pool.
type Connection interface {
Beginner
QueryExecutor
Migrator

View File

@@ -1,9 +1,9 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction)
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Connection,Row,Rows,Transaction)
//
// Generated by this command:
//
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
//
// Package dbmock is a generated GoMock package.
@@ -21,6 +21,7 @@ import (
type MockPool struct {
ctrl *gomock.Controller
recorder *MockPoolMockRecorder
isgomock struct{}
}
// MockPoolMockRecorder is the mock recorder for MockPool.
@@ -41,18 +42,18 @@ func (m *MockPool) EXPECT() *MockPoolMockRecorder {
}
// Acquire mocks base method.
func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) {
func (m *MockPool) Acquire(ctx context.Context) (database.Connection, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Acquire", arg0)
ret0, _ := ret[0].(database.Client)
ret := m.ctrl.Call(m, "Acquire", ctx)
ret0, _ := ret[0].(database.Connection)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Acquire indicates an expected call of Acquire.
func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall {
func (mr *MockPoolMockRecorder) Acquire(ctx any) *MockPoolAcquireCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), ctx)
return &MockPoolAcquireCall{Call: call}
}
@@ -62,36 +63,36 @@ type MockPoolAcquireCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall {
func (c *MockPoolAcquireCall) Return(arg0 database.Connection, arg1 error) *MockPoolAcquireCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Begin mocks base method.
func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
func (m *MockPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
ret := m.ctrl.Call(m, "Begin", ctx, opts)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall {
func (mr *MockPoolMockRecorder) Begin(ctx, opts any) *MockPoolBeginCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), ctx, opts)
return &MockPoolBeginCall{Call: call}
}
@@ -119,17 +120,17 @@ func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.Transa
}
// Close mocks base method.
func (m *MockPool) Close(arg0 context.Context) error {
func (m *MockPool) Close(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close", arg0)
ret := m.ctrl.Call(m, "Close", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall {
func (mr *MockPoolMockRecorder) Close(ctx any) *MockPoolCloseCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), ctx)
return &MockPoolCloseCall{Call: call}
}
@@ -157,10 +158,10 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
}
// Exec mocks base method.
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
func (m *MockPool) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -170,9 +171,9 @@ func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64,
}
// Exec indicates an expected call of Exec.
func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall {
func (mr *MockPoolMockRecorder) Exec(ctx, stmt any, args ...any) *MockPoolExecCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...)
return &MockPoolExecCall{Call: call}
}
@@ -201,17 +202,17 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (
}
// Migrate mocks base method.
func (m *MockPool) Migrate(arg0 context.Context) error {
func (m *MockPool) Migrate(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Migrate", arg0)
ret := m.ctrl.Call(m, "Migrate", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Migrate indicates an expected call of Migrate.
func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall {
func (mr *MockPoolMockRecorder) Migrate(ctx any) *MockPoolMigrateCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), ctx)
return &MockPoolMigrateCall{Call: call}
}
@@ -238,11 +239,49 @@ func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPo
return c
}
// Query mocks base method.
func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
// Ping mocks base method.
func (m *MockPool) Ping(ctx context.Context) error {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
ret := m.ctrl.Call(m, "Ping", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Ping indicates an expected call of Ping.
func (mr *MockPoolMockRecorder) Ping(ctx any) *MockPoolPingCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockPool)(nil).Ping), ctx)
return &MockPoolPingCall{Call: call}
}
// MockPoolPingCall wrap *gomock.Call
type MockPoolPingCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockPoolPingCall) Return(arg0 error) *MockPoolPingCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPoolPingCall) Do(f func(context.Context) error) *MockPoolPingCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPoolPingCall) DoAndReturn(f func(context.Context) error) *MockPoolPingCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Query mocks base method.
func (m *MockPool) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -252,9 +291,9 @@ func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (databa
}
// Query indicates an expected call of Query.
func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall {
func (mr *MockPoolMockRecorder) Query(ctx, stmt any, args ...any) *MockPoolQueryCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...)
return &MockPoolQueryCall{Call: call}
}
@@ -283,10 +322,10 @@ func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any)
}
// QueryRow mocks base method.
func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
func (m *MockPool) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -295,9 +334,9 @@ func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) data
}
// QueryRow indicates an expected call of QueryRow.
func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall {
func (mr *MockPoolMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockPoolQueryRowCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...)
return &MockPoolQueryRowCall{Call: call}
}
@@ -325,73 +364,74 @@ func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...an
return c
}
// MockClient is a mock of Client interface.
type MockClient struct {
// MockConnection is a mock of Connection interface.
type MockConnection struct {
ctrl *gomock.Controller
recorder *MockClientMockRecorder
recorder *MockConnectionMockRecorder
isgomock struct{}
}
// MockClientMockRecorder is the mock recorder for MockClient.
type MockClientMockRecorder struct {
mock *MockClient
// MockConnectionMockRecorder is the mock recorder for MockConnection.
type MockConnectionMockRecorder struct {
mock *MockConnection
}
// NewMockClient creates a new mock instance.
func NewMockClient(ctrl *gomock.Controller) *MockClient {
mock := &MockClient{ctrl: ctrl}
mock.recorder = &MockClientMockRecorder{mock}
// NewMockConnection creates a new mock instance.
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
mock := &MockConnection{ctrl: ctrl}
mock.recorder = &MockConnectionMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockClient) EXPECT() *MockClientMockRecorder {
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
return m.recorder
}
// Begin mocks base method.
func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
func (m *MockConnection) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
ret := m.ctrl.Call(m, "Begin", ctx, opts)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall {
func (mr *MockConnectionMockRecorder) Begin(ctx, opts any) *MockConnectionBeginCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1)
return &MockClientBeginCall{Call: call}
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockConnection)(nil).Begin), ctx, opts)
return &MockConnectionBeginCall{Call: call}
}
// MockClientBeginCall wrap *gomock.Call
type MockClientBeginCall struct {
// MockConnectionBeginCall wrap *gomock.Call
type MockConnectionBeginCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall {
func (c *MockConnectionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockConnectionBeginCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
func (c *MockConnectionBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
func (c *MockConnectionBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Exec mocks base method.
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
func (m *MockConnection) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -401,79 +441,117 @@ func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64
}
// Exec indicates an expected call of Exec.
func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall {
func (mr *MockConnectionMockRecorder) Exec(ctx, stmt any, args ...any) *MockConnectionExecCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...)
return &MockClientExecCall{Call: call}
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
return &MockConnectionExecCall{Call: call}
}
// MockClientExecCall wrap *gomock.Call
type MockClientExecCall struct {
// MockConnectionExecCall wrap *gomock.Call
type MockConnectionExecCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
func (c *MockConnectionExecCall) Return(arg0 int64, arg1 error) *MockConnectionExecCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
func (c *MockConnectionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
func (c *MockConnectionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Migrate mocks base method.
func (m *MockClient) Migrate(arg0 context.Context) error {
func (m *MockConnection) Migrate(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Migrate", arg0)
ret := m.ctrl.Call(m, "Migrate", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Migrate indicates an expected call of Migrate.
func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall {
func (mr *MockConnectionMockRecorder) Migrate(ctx any) *MockConnectionMigrateCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0)
return &MockClientMigrateCall{Call: call}
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockConnection)(nil).Migrate), ctx)
return &MockConnectionMigrateCall{Call: call}
}
// MockClientMigrateCall wrap *gomock.Call
type MockClientMigrateCall struct {
// MockConnectionMigrateCall wrap *gomock.Call
type MockConnectionMigrateCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall {
func (c *MockConnectionMigrateCall) Return(arg0 error) *MockConnectionMigrateCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall {
func (c *MockConnectionMigrateCall) Do(f func(context.Context) error) *MockConnectionMigrateCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall {
func (c *MockConnectionMigrateCall) DoAndReturn(f func(context.Context) error) *MockConnectionMigrateCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Ping mocks base method.
func (m *MockConnection) Ping(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Ping", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Ping indicates an expected call of Ping.
func (mr *MockConnectionMockRecorder) Ping(ctx any) *MockConnectionPingCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockConnection)(nil).Ping), ctx)
return &MockConnectionPingCall{Call: call}
}
// MockConnectionPingCall wrap *gomock.Call
type MockConnectionPingCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockConnectionPingCall) Return(arg0 error) *MockConnectionPingCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockConnectionPingCall) Do(f func(context.Context) error) *MockConnectionPingCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockConnectionPingCall) DoAndReturn(f func(context.Context) error) *MockConnectionPingCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Query mocks base method.
func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
func (m *MockConnection) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -483,41 +561,41 @@ func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (data
}
// Query indicates an expected call of Query.
func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall {
func (mr *MockConnectionMockRecorder) Query(ctx, stmt any, args ...any) *MockConnectionQueryCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...)
return &MockClientQueryCall{Call: call}
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockConnection)(nil).Query), varargs...)
return &MockConnectionQueryCall{Call: call}
}
// MockClientQueryCall wrap *gomock.Call
type MockClientQueryCall struct {
// MockConnectionQueryCall wrap *gomock.Call
type MockConnectionQueryCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall {
func (c *MockConnectionQueryCall) Return(arg0 database.Rows, arg1 error) *MockConnectionQueryCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
func (c *MockConnectionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
func (c *MockConnectionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// QueryRow mocks base method.
func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
func (m *MockConnection) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -526,70 +604,70 @@ func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) da
}
// QueryRow indicates an expected call of QueryRow.
func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall {
func (mr *MockConnectionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockConnectionQueryRowCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...)
return &MockClientQueryRowCall{Call: call}
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockConnection)(nil).QueryRow), varargs...)
return &MockConnectionQueryRowCall{Call: call}
}
// MockClientQueryRowCall wrap *gomock.Call
type MockClientQueryRowCall struct {
// MockConnectionQueryRowCall wrap *gomock.Call
type MockConnectionQueryRowCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall {
func (c *MockConnectionQueryRowCall) Return(arg0 database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
func (c *MockConnectionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
func (c *MockConnectionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
// Release mocks base method.
func (m *MockClient) Release(arg0 context.Context) error {
func (m *MockConnection) Release(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Release", arg0)
ret := m.ctrl.Call(m, "Release", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Release indicates an expected call of Release.
func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall {
func (mr *MockConnectionMockRecorder) Release(ctx any) *MockConnectionReleaseCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0)
return &MockClientReleaseCall{Call: call}
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockConnection)(nil).Release), ctx)
return &MockConnectionReleaseCall{Call: call}
}
// MockClientReleaseCall wrap *gomock.Call
type MockClientReleaseCall struct {
// MockConnectionReleaseCall wrap *gomock.Call
type MockConnectionReleaseCall struct {
*gomock.Call
}
// Return rewrite *gomock.Call.Return
func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall {
func (c *MockConnectionReleaseCall) Return(arg0 error) *MockConnectionReleaseCall {
c.Call = c.Call.Return(arg0)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall {
func (c *MockConnectionReleaseCall) Do(f func(context.Context) error) *MockConnectionReleaseCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall {
func (c *MockConnectionReleaseCall) DoAndReturn(f func(context.Context) error) *MockConnectionReleaseCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
@@ -598,6 +676,7 @@ func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *Mock
type MockRow struct {
ctrl *gomock.Controller
recorder *MockRowMockRecorder
isgomock struct{}
}
// MockRowMockRecorder is the mock recorder for MockRow.
@@ -618,10 +697,10 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder {
}
// Scan mocks base method.
func (m *MockRow) Scan(arg0 ...any) error {
func (m *MockRow) Scan(dest ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
for _, a := range dest {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
@@ -630,9 +709,9 @@ func (m *MockRow) Scan(arg0 ...any) error {
}
// Scan indicates an expected call of Scan.
func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall {
func (mr *MockRowMockRecorder) Scan(dest ...any) *MockRowScanCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), dest...)
return &MockRowScanCall{Call: call}
}
@@ -663,6 +742,7 @@ func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall {
type MockRows struct {
ctrl *gomock.Controller
recorder *MockRowsMockRecorder
isgomock struct{}
}
// MockRowsMockRecorder is the mock recorder for MockRows.
@@ -797,10 +877,10 @@ func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall {
}
// Scan mocks base method.
func (m *MockRows) Scan(arg0 ...any) error {
func (m *MockRows) Scan(dest ...any) error {
m.ctrl.T.Helper()
varargs := []any{}
for _, a := range arg0 {
for _, a := range dest {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Scan", varargs...)
@@ -809,9 +889,9 @@ func (m *MockRows) Scan(arg0 ...any) error {
}
// Scan indicates an expected call of Scan.
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall {
func (mr *MockRowsMockRecorder) Scan(dest ...any) *MockRowsScanCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), dest...)
return &MockRowsScanCall{Call: call}
}
@@ -842,6 +922,7 @@ func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall {
type MockTransaction struct {
ctrl *gomock.Controller
recorder *MockTransactionMockRecorder
isgomock struct{}
}
// MockTransactionMockRecorder is the mock recorder for MockTransaction.
@@ -862,18 +943,18 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder {
}
// Begin mocks base method.
func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) {
func (m *MockTransaction) Begin(ctx context.Context) (database.Transaction, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Begin", arg0)
ret := m.ctrl.Call(m, "Begin", ctx)
ret0, _ := ret[0].(database.Transaction)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Begin indicates an expected call of Begin.
func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall {
func (mr *MockTransactionMockRecorder) Begin(ctx any) *MockTransactionBeginCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), ctx)
return &MockTransactionBeginCall{Call: call}
}
@@ -901,17 +982,17 @@ func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database
}
// Commit mocks base method.
func (m *MockTransaction) Commit(arg0 context.Context) error {
func (m *MockTransaction) Commit(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Commit", arg0)
ret := m.ctrl.Call(m, "Commit", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Commit indicates an expected call of Commit.
func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall {
func (mr *MockTransactionMockRecorder) Commit(ctx any) *MockTransactionCommitCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), ctx)
return &MockTransactionCommitCall{Call: call}
}
@@ -939,17 +1020,17 @@ func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *
}
// End mocks base method.
func (m *MockTransaction) End(arg0 context.Context, arg1 error) error {
func (m *MockTransaction) End(ctx context.Context, err error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "End", arg0, arg1)
ret := m.ctrl.Call(m, "End", ctx, err)
ret0, _ := ret[0].(error)
return ret0
}
// End indicates an expected call of End.
func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall {
func (mr *MockTransactionMockRecorder) End(ctx, err any) *MockTransactionEndCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), ctx, err)
return &MockTransactionEndCall{Call: call}
}
@@ -977,10 +1058,10 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
}
// Exec mocks base method.
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
func (m *MockTransaction) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
@@ -990,9 +1071,9 @@ func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (
}
// Exec indicates an expected call of Exec.
func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall {
func (mr *MockTransactionMockRecorder) Exec(ctx, stmt any, args ...any) *MockTransactionExecCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...)
return &MockTransactionExecCall{Call: call}
}
@@ -1021,10 +1102,10 @@ func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ..
}
// Query mocks base method.
func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
func (m *MockTransaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Query", varargs...)
@@ -1034,9 +1115,9 @@ func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any)
}
// Query indicates an expected call of Query.
func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall {
func (mr *MockTransactionMockRecorder) Query(ctx, stmt any, args ...any) *MockTransactionQueryCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...)
return &MockTransactionQueryCall{Call: call}
}
@@ -1065,10 +1146,10 @@ func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, .
}
// QueryRow mocks base method.
func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
func (m *MockTransaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs := []any{ctx, stmt}
for _, a := range args {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "QueryRow", varargs...)
@@ -1077,9 +1158,9 @@ func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...an
}
// QueryRow indicates an expected call of QueryRow.
func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall {
func (mr *MockTransactionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockTransactionQueryRowCall {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
varargs := append([]any{ctx, stmt}, args...)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...)
return &MockTransactionQueryRowCall{Call: call}
}
@@ -1108,17 +1189,17 @@ func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string
}
// Rollback mocks base method.
func (m *MockTransaction) Rollback(arg0 context.Context) error {
func (m *MockTransaction) Rollback(ctx context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Rollback", arg0)
ret := m.ctrl.Call(m, "Rollback", ctx)
ret0, _ := ret[0].(error)
return ret0
}
// Rollback indicates an expected call of Rollback.
func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall {
func (mr *MockTransactionMockRecorder) Rollback(ctx any) *MockTransactionRollbackCall {
mr.mock.ctrl.T.Helper()
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0)
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), ctx)
return &MockTransactionRollbackCall{Call: call}
}

View File

@@ -13,15 +13,15 @@ type pgxConn struct {
*pgxpool.Conn
}
var _ database.Client = (*pgxConn)(nil)
var _ database.Connection = (*pgxConn)(nil)
// Release implements [database.Client].
// Release implements [database.Connection].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
// Begin implements [database.Connection].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
@@ -30,8 +30,8 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions)
return &Transaction{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
// Query implements [database.Connection].
// Subtle: this method shadows the method (*Conn).Query of [pgxConn.Conn].
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
if err != nil {
@@ -40,14 +40,14 @@ func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.
return &Rows{rows}, nil
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
// QueryRow implements [database.Connection].
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
// QueryRow implements [database.Connection].
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Conn.Exec(ctx, sql, args...)
if err != nil {

View File

@@ -2,6 +2,7 @@ package postgres
import (
"errors"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
@@ -16,6 +17,18 @@ func wrapError(err error) error {
if errors.Is(err, pgx.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
if errors.Is(err, pgx.ErrTooManyRows) {
return database.NewMultipleRowsFoundError(err)
}
// scany only exports its errors as strings
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
return database.NewMultipleRowsFoundError(err)
}
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
return database.NewScanError(err)
}
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 004_correct_set_updated_at/up.sql
up004CorrectSetUpdatedAt string
//go:embed 004_correct_set_updated_at/down.sql
down004CorrectSetUpdatedAt string
)
func init() {
registerSQLMigration(4, up004CorrectSetUpdatedAt, down004CorrectSetUpdatedAt)
}

View File

@@ -0,0 +1,23 @@
CREATE OR REPLACE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.instances
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.organizations
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
BEFORE UPDATE ON zitadel.instance_domains
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
BEFORE UPDATE ON zitadel.org_domains
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -0,0 +1,23 @@
CREATE OR REPLACE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.instances
FOR EACH ROW
WHEN (NEW.updated_at IS NULL)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.organizations
FOR EACH ROW
WHEN (NEW.updated_at IS NULL)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
BEFORE UPDATE ON zitadel.instance_domains
FOR EACH ROW
WHEN (NEW.updated_at IS NULL)
EXECUTE FUNCTION zitadel.set_updated_at();
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
BEFORE UPDATE ON zitadel.org_domains
FOR EACH ROW
WHEN (NEW.updated_at IS NULL)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -22,8 +22,8 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool {
}
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
func (p *pgxPool) Acquire(ctx context.Context) (database.Connection, error) {
conn, err := p.Pool.Acquire(ctx)
if err != nil {
return nil, wrapError(err)
}
@@ -32,8 +32,8 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
func (p *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := p.Pool.Query(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
@@ -42,14 +42,14 @@ func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.Pool.QueryRow(ctx, sql, args...)}
func (p *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{p.Pool.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Pool.Exec(ctx, sql, args...)
func (p *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := p.Pool.Exec(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
@@ -57,8 +57,8 @@ func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
func (p *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := p.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, wrapError(err)
}
@@ -66,23 +66,23 @@ func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions)
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
func (p *pgxPool) Close(_ context.Context) error {
p.Pool.Close()
return nil
}
// Ping implements [database.Pool].
func (c *pgxPool) Ping(ctx context.Context) error {
return wrapError(c.Pool.Ping(ctx))
func (p *pgxPool) Ping(ctx context.Context) error {
return wrapError(p.Pool.Ping(ctx))
}
// Migrate implements [database.Migrator].
func (c *pgxPool) Migrate(ctx context.Context) error {
func (p *pgxPool) Migrate(ctx context.Context) error {
if isMigrated {
return nil
}
client, err := c.Pool.Acquire(ctx)
client, err := p.Pool.Acquire(ctx)
if err != nil {
return err
}
@@ -93,8 +93,8 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
}
// Migrate implements [database.PoolTest].
func (c *pgxPool) MigrateTest(ctx context.Context) error {
client, err := c.Pool.Acquire(ctx)
func (p *pgxPool) MigrateTest(ctx context.Context) error {
client, err := p.Pool.Acquire(ctx)
if err != nil {
return err
}

View File

@@ -11,18 +11,18 @@ type sqlConn struct {
*sql.Conn
}
func SQLConn(conn *sql.Conn) database.Client {
func SQLConn(conn *sql.Conn) database.Connection {
return &sqlConn{Conn: conn}
}
var _ database.Client = (*sqlConn)(nil)
var _ database.Connection = (*sqlConn)(nil)
// Release implements [database.Client].
// Release implements [database.Connection].
func (c *sqlConn) Release(_ context.Context) error {
return c.Close()
}
// Begin implements [database.Client].
// Begin implements [database.Connection].
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {

View File

@@ -3,6 +3,7 @@ package sql
import (
"database/sql"
"errors"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
@@ -19,6 +20,15 @@ func wrapError(err error) error {
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
// scany only exports its errors as strings
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
return database.NewMultipleRowsFoundError(err)
}
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
return database.NewScanError(err)
}
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)

View File

@@ -20,8 +20,8 @@ func SQLPool(db *sql.DB) *sqlPool {
}
// Acquire implements [database.Pool].
func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Conn(ctx)
func (p *sqlPool) Acquire(ctx context.Context) (database.Connection, error) {
conn, err := p.Conn(ctx)
if err != nil {
return nil, wrapError(err)
}
@@ -30,9 +30,9 @@ func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
func (p *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
//nolint:rowserrcheck // Rows.Close is called by the caller
rows, err := c.QueryContext(ctx, sql, args...)
rows, err := p.QueryContext(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
@@ -41,14 +41,14 @@ func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.QueryRowContext(ctx, sql, args...)}
func (p *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{p.QueryRowContext(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.ExecContext(ctx, sql, args...)
func (p *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := p.ExecContext(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
@@ -56,8 +56,8 @@ func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
}
// Begin implements [database.Pool].
func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
func (p *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := p.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {
return nil, wrapError(err)
}
@@ -65,16 +65,16 @@ func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions)
}
// Ping implements [database.Pool].
func (c *sqlPool) Ping(ctx context.Context) error {
return wrapError(c.PingContext(ctx))
func (p *sqlPool) Ping(ctx context.Context) error {
return wrapError(p.PingContext(ctx))
}
// Close implements [database.Pool].
func (c *sqlPool) Close(_ context.Context) error {
return c.DB.Close()
func (p *sqlPool) Close(_ context.Context) error {
return p.DB.Close()
}
// Migrate implements [database.Migrator].
func (c *sqlPool) Migrate(ctx context.Context) error {
func (p *sqlPool) Migrate(ctx context.Context) error {
return ErrMigrate
}

View File

@@ -5,7 +5,34 @@ import (
"fmt"
)
var ErrNoChanges = errors.New("update must contain a change")
var (
ErrNoChanges = errors.New("update must contain a change")
)
type MissingConditionError struct {
col Column
}
func NewMissingConditionError(col Column) error {
return &MissingConditionError{
col: col,
}
}
func (e *MissingConditionError) Error() string {
var builder StatementBuilder
builder.WriteString("missing condition for column")
if e.col != nil {
builder.WriteString(" on ")
e.col.WriteQualified(&builder)
}
return builder.String()
}
func (e *MissingConditionError) Is(target error) bool {
_, ok := target.(*MissingConditionError)
return ok
}
// NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context.
@@ -20,6 +47,9 @@ func NewNoRowFoundError(original error) error {
}
func (e *NoRowFoundError) Error() string {
if e.original != nil {
return fmt.Sprintf("no row found: %v", e.original)
}
return "no row found"
}
@@ -36,18 +66,19 @@ func (e *NoRowFoundError) Unwrap() error {
// It wraps the dialect specific original error to provide more context.
type MultipleRowsFoundError struct {
original error
count int
}
func NewMultipleRowsFoundError(original error, count int) error {
func NewMultipleRowsFoundError(original error) error {
return &MultipleRowsFoundError{
original: original,
count: count,
}
}
func (e *MultipleRowsFoundError) Error() string {
return fmt.Sprintf("multiple rows found: %d", e.count)
if e.original != nil {
return fmt.Sprintf("multiple rows found: %v", e.original)
}
return "multiple rows found"
}
func (e *MultipleRowsFoundError) Is(target error) bool {
@@ -77,8 +108,8 @@ type IntegrityViolationError struct {
original error
}
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
return &IntegrityViolationError{
func newIntegrityViolationError(typ IntegrityType, table, constraint string, original error) IntegrityViolationError {
return IntegrityViolationError{
integrityType: typ,
table: table,
constraint: constraint,
@@ -87,7 +118,10 @@ func NewIntegrityViolationError(typ IntegrityType, table, constraint string, ori
}
func (e *IntegrityViolationError) Error() string {
if e.original != nil {
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
}
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q)", e.integrityType, e.table, e.constraint)
}
func (e *IntegrityViolationError) Is(target error) bool {
@@ -108,12 +142,7 @@ type CheckError struct {
func NewCheckError(table, constraint string, original error) error {
return &CheckError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeCheck,
table: table,
constraint: constraint,
original: original,
},
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeCheck, table, constraint, original),
}
}
@@ -135,12 +164,7 @@ type UniqueError struct {
func NewUniqueError(table, constraint string, original error) error {
return &UniqueError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeUnique,
table: table,
constraint: constraint,
original: original,
},
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeUnique, table, constraint, original),
}
}
@@ -162,12 +186,7 @@ type ForeignKeyError struct {
func NewForeignKeyError(table, constraint string, original error) error {
return &ForeignKeyError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeForeign,
table: table,
constraint: constraint,
original: original,
},
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeForeign, table, constraint, original),
}
}
@@ -189,12 +208,7 @@ type NotNullError struct {
func NewNotNullError(table, constraint string, original error) error {
return &NotNullError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeNotNull,
table: table,
constraint: constraint,
original: original,
},
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeNotNull, table, constraint, original),
}
}
@@ -207,6 +221,31 @@ func (e *NotNullError) Unwrap() error {
return &e.IntegrityViolationError
}
// ScanError is returned when scanning rows into objects failed.
// It wraps the original error to provide more context.
type ScanError struct {
original error
}
func NewScanError(original error) error {
return &ScanError{
original: original,
}
}
func (e *ScanError) Error() string {
return fmt.Sprintf("Scan error: %v", e.original)
}
func (e *ScanError) Is(target error) bool {
_, ok := target.(*ScanError)
return ok
}
func (e *ScanError) Unwrap() error {
return e.original
}
// UnknownError is returned when an unknown error occurs.
// It wraps the dialect specific original error to provide more context.
// It is used to indicate that an error occurred that does not fit into any of the other categories.

View File

@@ -0,0 +1,304 @@
package database
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestError(t *testing.T) {
for _, test := range []struct {
name string
err error
want string
}{
{
name: "missing condition without column",
err: NewMissingConditionError(nil),
want: "missing condition for column",
},
{
name: "missing condition with column",
err: NewMissingConditionError(NewColumn("table", "column")),
want: "missing condition for column on table.column",
},
{
name: "no row found without original error",
err: NewNoRowFoundError(nil),
want: "no row found",
},
{
name: "no row found with original error",
err: NewNoRowFoundError(errors.New("original error")),
want: "no row found: original error",
},
{
name: "multiple rows found without original error",
err: NewMultipleRowsFoundError(nil),
want: "multiple rows found",
},
{
name: "multiple rows found with original error",
err: NewMultipleRowsFoundError(errors.New("original error")),
want: "multiple rows found: original error",
},
{
name: "check violation without original error",
err: NewCheckError("table", "constraint", nil),
want: `integrity violation of type "check" on "table" (constraint: "constraint")`,
},
{
name: "check violation with original error",
err: NewCheckError("table", "constraint", errors.New("original error")),
want: `integrity violation of type "check" on "table" (constraint: "constraint"): original error`,
},
{
name: "unique violation without original error",
err: NewUniqueError("table", "constraint", nil),
want: `integrity violation of type "unique" on "table" (constraint: "constraint")`,
},
{
name: "unique violation with original error",
err: NewUniqueError("table", "constraint", errors.New("original error")),
want: `integrity violation of type "unique" on "table" (constraint: "constraint"): original error`,
},
{
name: "foreign key violation without original error",
err: NewForeignKeyError("table", "constraint", nil),
want: `integrity violation of type "foreign" on "table" (constraint: "constraint")`,
},
{
name: "foreign key violation with original error",
err: NewForeignKeyError("table", "constraint", errors.New("original error")),
want: `integrity violation of type "foreign" on "table" (constraint: "constraint"): original error`,
},
{
name: "not null violation without original error",
err: NewNotNullError("table", "constraint", nil),
want: `integrity violation of type "not null" on "table" (constraint: "constraint")`,
},
{
name: "not null violation with original error",
err: NewNotNullError("table", "constraint", errors.New("original error")),
want: `integrity violation of type "not null" on "table" (constraint: "constraint"): original error`,
},
{
name: "unknown error",
err: NewUnknownError(errors.New("original error")),
want: `unknown database error: original error`,
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.want, test.err.Error())
})
}
}
func TestUnwrap(t *testing.T) {
originalErr := errors.New("original error")
for _, test := range []struct {
name string
err error
want error
}{
{
name: "missing condition without column",
err: NewMissingConditionError(nil),
want: nil,
},
{
name: "missing condition with column",
err: NewMissingConditionError(NewColumn("table", "column")),
want: nil,
},
{
name: "no row found without original error",
err: NewNoRowFoundError(nil),
want: nil,
},
{
name: "no row found with original error",
err: NewNoRowFoundError(errors.New("original error")),
want: originalErr,
},
{
name: "multiple rows found without original error",
err: NewMultipleRowsFoundError(nil),
want: nil,
},
{
name: "multiple rows found with original error",
err: NewMultipleRowsFoundError(originalErr),
want: originalErr,
},
{
name: "check violation without original error",
err: NewCheckError("table", "constraint", nil),
want: &IntegrityViolationError{
integrityType: IntegrityTypeCheck,
table: "table",
constraint: "constraint",
original: nil,
},
},
{
name: "check violation with original error",
err: NewCheckError("table", "constraint", originalErr),
want: &IntegrityViolationError{
integrityType: IntegrityTypeCheck,
table: "table",
constraint: "constraint",
original: originalErr,
},
},
{
name: "unique violation without original error",
err: NewUniqueError("table", "constraint", nil),
want: &IntegrityViolationError{
integrityType: IntegrityTypeUnique,
table: "table",
constraint: "constraint",
original: nil,
},
},
{
name: "unique violation with original error",
err: NewUniqueError("table", "constraint", originalErr),
want: &IntegrityViolationError{
integrityType: IntegrityTypeUnique,
table: "table",
constraint: "constraint",
original: originalErr,
},
},
{
name: "foreign key violation without original error",
err: NewForeignKeyError("table", "constraint", nil),
want: &IntegrityViolationError{
integrityType: IntegrityTypeForeign,
table: "table",
constraint: "constraint",
original: nil,
},
},
{
name: "foreign key violation with original error",
err: NewForeignKeyError("table", "constraint", originalErr),
want: &IntegrityViolationError{
integrityType: IntegrityTypeForeign,
table: "table",
constraint: "constraint",
original: originalErr,
},
},
{
name: "not null violation without original error",
err: NewNotNullError("table", "constraint", nil),
want: &IntegrityViolationError{
integrityType: IntegrityTypeNotNull,
table: "table",
constraint: "constraint",
original: nil,
},
},
{
name: "not null violation with original error",
err: NewNotNullError("table", "constraint", originalErr),
want: &IntegrityViolationError{
integrityType: IntegrityTypeNotNull,
table: "table",
constraint: "constraint",
original: originalErr,
},
},
{
name: "unwrap integrity violation error",
err: errors.Unwrap(NewNotNullError("table", "constraint", originalErr)),
want: originalErr,
},
{
name: "unknown error",
err: NewUnknownError(originalErr),
want: originalErr,
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.want, errors.Unwrap(test.err))
})
}
}
func TestIs(t *testing.T) {
originalErr := errors.New("original error")
for _, test := range []struct {
name string
err error
want error
}{
{
name: "missing condition",
err: NewMissingConditionError(NewColumn("table", "column")),
want: new(MissingConditionError),
},
{
name: "no row found",
err: NewNoRowFoundError(errors.New("original error")),
want: new(NoRowFoundError),
},
{
name: "multiple rows found",
err: NewMultipleRowsFoundError(originalErr),
want: new(MultipleRowsFoundError),
},
{
name: "check violation is for integrity",
err: NewCheckError("table", "constraint", nil),
want: new(IntegrityViolationError),
},
{
name: "check violation is check violation",
err: NewCheckError("table", "constraint", nil),
want: new(CheckError),
},
{
name: "unique violation is for integrity",
err: NewUniqueError("table", "constraint", nil),
want: new(IntegrityViolationError),
},
{
name: "unique violation is unique violation",
err: NewUniqueError("table", "constraint", nil),
want: new(UniqueError),
},
{
name: "foreign key violation is for integrity",
err: NewForeignKeyError("table", "constraint", nil),
want: new(IntegrityViolationError),
},
{
name: "foreign key violation is foreign key violation",
err: NewForeignKeyError("table", "constraint", nil),
want: new(ForeignKeyError),
},
{
name: "not null violation is for integrity",
err: NewNotNullError("table", "constraint", nil),
want: new(IntegrityViolationError),
},
{
name: "not null violation is not null violation",
err: NewNotNullError("table", "constraint", nil),
want: new(NotNullError),
},
{
name: "unknown error",
err: NewUnknownError(originalErr),
want: new(UnknownError),
},
} {
t.Run(test.name, func(t *testing.T) {
assert.ErrorIs(t, test.err, test.want)
})
}
}

View File

@@ -21,8 +21,8 @@ import (
func TestServer_TestInstanceDomainReduces(t *testing.T) {
instance := integration.NewInstance(CTX)
instanceRepo := repository.InstanceRepository(pool)
instanceDomainRepo := instanceRepo.Domains(true)
instanceRepo := repository.InstanceRepository()
instanceDomainRepo := repository.InstanceDomainRepository()
t.Cleanup(func() {
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
@@ -36,7 +36,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for instance to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := instanceRepo.Get(CTX,
_, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
)
assert.NoError(t, err)
@@ -66,7 +66,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -96,7 +96,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
t.Cleanup(func() {
// first we change the primary domain to something else
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -125,7 +125,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -151,7 +151,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that set primary reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -180,7 +180,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := instanceDomainRepo.Get(CTX,
_, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -202,7 +202,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -241,7 +241,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -271,7 +271,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := instanceDomainRepo.Get(CTX,
_, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
@@ -293,7 +293,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
domain, err := instanceDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),

View File

@@ -17,7 +17,7 @@ import (
)
func TestServer_TestInstanceReduces(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
instanceRepo := repository.InstanceRepository()
t.Run("test instance add reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
@@ -46,7 +46,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
)
require.NoError(t, err)
@@ -92,7 +92,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -110,7 +110,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -137,7 +137,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
@@ -151,7 +151,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instance, err := instanceRepo.Get(CTX, pool,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
// event instance.removed

View File

@@ -22,8 +22,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
})
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
orgDomainRepo := orgRepo.Domains(false)
orgRepo := repository.OrganizationRepository()
orgDomainRepo := repository.OrganizationDomainRepository()
t.Cleanup(func() {
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
@@ -37,8 +37,13 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Wait for org to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := orgRepo.Get(CTX,
database.WithCondition(orgRepo.IDCondition(org.GetId())),
_, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.InstanceIDCondition(Instance.Instance.Id),
orgRepo.IDCondition(org.GetId()),
),
),
)
assert.NoError(t, err)
}, retryDuration, tick)
@@ -68,7 +73,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
gottenDomain, err := orgDomainRepo.Get(CTX,
gottenDomain, err := orgDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
@@ -107,7 +112,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
domain, err := orgDomainRepo.Get(CTX,
domain, err := orgDomainRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),

View File

@@ -19,7 +19,7 @@ import (
func TestServer_TestOrganizationReduces(t *testing.T) {
instanceID := Instance.ID()
orgRepo := repository.OrganizationRepository(pool)
orgRepo := repository.OrganizationRepository()
t.Run("test org add reduces", func(t *testing.T) {
beforeCreate := time.Now()
@@ -42,7 +42,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(org.GetId()),
@@ -92,7 +92,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -137,7 +137,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -177,11 +177,11 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
})
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
orgRepo := repository.OrganizationRepository()
// 3. check org deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -203,7 +203,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -230,10 +230,10 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
require.NoError(t, err)
// 2. check org retrievable
orgRepo := repository.OrganizationRepository(pool)
orgRepo := repository.OrganizationRepository()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := orgRepo.Get(CTX,
_, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
@@ -252,7 +252,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
organization, err := orgRepo.Get(CTX, pool,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),

View File

@@ -1,3 +1,3 @@
package database
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction

View File

@@ -6,16 +6,40 @@ import (
"golang.org/x/exp/constraints"
)
type Value interface {
Boolean | Number | Text | Instruction
type wrappedValue[V Value] struct {
value V
fn function
}
func LowerValue[T Value](v T) wrappedValue[T] {
return wrappedValue[T]{value: v, fn: functionLower}
}
func SHA256Value[T Value](v T) wrappedValue[T] {
return wrappedValue[T]{value: v, fn: functionSHA256}
}
func (b wrappedValue[V]) WriteArg(builder *StatementBuilder) {
builder.Grow(len(b.fn) + 5)
builder.WriteString(string(b.fn))
builder.WriteRune('(')
builder.WriteArg(b.value)
builder.WriteRune(')')
}
var _ argWriter = (*wrappedValue[string])(nil)
type Value interface {
Boolean | Number | Text | Instruction | Bytes
}
//go:generate enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go
type Operation interface {
BooleanOperation | NumberOperation | TextOperation
NumberOperation | TextOperation | BytesOperation
}
type Text interface {
~string | ~[]byte
~string | Bytes
}
// TextOperation are operations that can be performed on text values.
@@ -23,60 +47,17 @@ type TextOperation uint8
const (
// TextOperationEqual compares two strings for equality.
TextOperationEqual TextOperation = iota + 1
// TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
TextOperationEqualIgnoreCase
TextOperationEqual TextOperation = iota + 1 // =
// TextOperationNotEqual compares two strings for inequality.
TextOperationNotEqual
// TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
TextOperationNotEqualIgnoreCase
TextOperationNotEqual // <>
// TextOperationStartsWith checks if the first string starts with the second.
TextOperationStartsWith
// TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
TextOperationStartsWithIgnoreCase
TextOperationStartsWith // LIKE
)
var textOperations = map[TextOperation]string{
TextOperationEqual: " = ",
TextOperationEqualIgnoreCase: " LIKE ",
TextOperationNotEqual: " <> ",
TextOperationNotEqualIgnoreCase: " NOT LIKE ",
TextOperationStartsWith: " LIKE ",
TextOperationStartsWithIgnoreCase: " LIKE ",
}
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
switch op {
case TextOperationEqual, TextOperationNotEqual:
col.WriteQualified(builder)
builder.WriteString(textOperations[op])
builder.WriteArg(value)
case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
builder.WriteString("LOWER(")
col.WriteQualified(builder)
builder.WriteString(")")
builder.WriteString(textOperations[op])
builder.WriteString("LOWER(")
builder.WriteArg(value)
builder.WriteString(")")
case TextOperationStartsWith:
col.WriteQualified(builder)
builder.WriteString(textOperations[op])
builder.WriteArg(value)
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value any) {
writeOperation[T](builder, col, op.String(), value)
if op == TextOperationStartsWith {
builder.WriteString(" || '%'")
case TextOperationStartsWithIgnoreCase:
builder.WriteString("LOWER(")
col.WriteQualified(builder)
builder.WriteString(")")
builder.WriteString(textOperations[op])
builder.WriteString("LOWER(")
builder.WriteArg(value)
builder.WriteString(")")
builder.WriteString(" || '%'")
default:
panic("unsupported text operation")
}
}
@@ -89,48 +70,60 @@ type NumberOperation uint8
const (
// NumberOperationEqual compares two numbers for equality.
NumberOperationEqual NumberOperation = iota + 1
NumberOperationEqual NumberOperation = iota + 1 // =
// NumberOperationNotEqual compares two numbers for inequality.
NumberOperationNotEqual
NumberOperationNotEqual // <>
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
NumberOperationLessThan
NumberOperationLessThan // <
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
NumberOperationAtLeast
NumberOperationAtLeast // <=
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
NumberOperationGreaterThan
NumberOperationGreaterThan // >
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
NumberOperationAtMost
NumberOperationAtMost // >=
)
var numberOperations = map[NumberOperation]string{
NumberOperationEqual: " = ",
NumberOperationNotEqual: " <> ",
NumberOperationLessThan: " < ",
NumberOperationAtLeast: " <= ",
NumberOperationGreaterThan: " > ",
NumberOperationAtMost: " >= ",
}
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
col.WriteQualified(builder)
builder.WriteString(numberOperations[op])
builder.WriteArg(value)
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value any) {
writeOperation[T](builder, col, op.String(), value)
}
type Boolean interface {
~bool
}
// BooleanOperation are operations that can be performed on boolean values.
type BooleanOperation uint8
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value any) {
writeOperation[T](builder, col, "=", value)
}
type Bytes interface {
~[]byte
}
// BytesOperation are operations that can be performed on bytea values.
type BytesOperation uint8
const (
BooleanOperationIsTrue BooleanOperation = iota + 1
BooleanOperationIsFalse
BytesOperationEqual BytesOperation = iota + 1 // =
BytesOperationNotEqual // <>
)
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
func writeBytesOperation[T Bytes](builder *StatementBuilder, col Column, op BytesOperation, value any) {
writeOperation[T](builder, col, op.String(), value)
}
func writeOperation[V Value](builder *StatementBuilder, col Column, op string, value any) {
if op == "" {
panic("unsupported operation")
}
switch value.(type) {
case V, wrappedValue[V], *wrappedValue[V]:
default:
panic("unsupported value type")
}
col.WriteQualified(builder)
builder.WriteString(" = ")
builder.WriteRune(' ')
builder.WriteString(op)
builder.WriteRune(' ')
builder.WriteArg(value)
}

View File

@@ -0,0 +1,241 @@
// Code generated by "enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go"; DO NOT EDIT.
package database
import (
"fmt"
"strings"
)
const _NumberOperationName = "=<><<=>>="
var _NumberOperationIndex = [...]uint8{0, 1, 3, 4, 6, 7, 9}
const _NumberOperationLowerName = "=<><<=>>="
func (i NumberOperation) String() string {
i -= 1
if i >= NumberOperation(len(_NumberOperationIndex)-1) {
return fmt.Sprintf("NumberOperation(%d)", i+1)
}
return _NumberOperationName[_NumberOperationIndex[i]:_NumberOperationIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _NumberOperationNoOp() {
var x [1]struct{}
_ = x[NumberOperationEqual-(1)]
_ = x[NumberOperationNotEqual-(2)]
_ = x[NumberOperationLessThan-(3)]
_ = x[NumberOperationAtLeast-(4)]
_ = x[NumberOperationGreaterThan-(5)]
_ = x[NumberOperationAtMost-(6)]
}
var _NumberOperationValues = []NumberOperation{NumberOperationEqual, NumberOperationNotEqual, NumberOperationLessThan, NumberOperationAtLeast, NumberOperationGreaterThan, NumberOperationAtMost}
var _NumberOperationNameToValueMap = map[string]NumberOperation{
_NumberOperationName[0:1]: NumberOperationEqual,
_NumberOperationLowerName[0:1]: NumberOperationEqual,
_NumberOperationName[1:3]: NumberOperationNotEqual,
_NumberOperationLowerName[1:3]: NumberOperationNotEqual,
_NumberOperationName[3:4]: NumberOperationLessThan,
_NumberOperationLowerName[3:4]: NumberOperationLessThan,
_NumberOperationName[4:6]: NumberOperationAtLeast,
_NumberOperationLowerName[4:6]: NumberOperationAtLeast,
_NumberOperationName[6:7]: NumberOperationGreaterThan,
_NumberOperationLowerName[6:7]: NumberOperationGreaterThan,
_NumberOperationName[7:9]: NumberOperationAtMost,
_NumberOperationLowerName[7:9]: NumberOperationAtMost,
}
var _NumberOperationNames = []string{
_NumberOperationName[0:1],
_NumberOperationName[1:3],
_NumberOperationName[3:4],
_NumberOperationName[4:6],
_NumberOperationName[6:7],
_NumberOperationName[7:9],
}
// NumberOperationString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func NumberOperationString(s string) (NumberOperation, error) {
if val, ok := _NumberOperationNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _NumberOperationNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to NumberOperation values", s)
}
// NumberOperationValues returns all values of the enum
func NumberOperationValues() []NumberOperation {
return _NumberOperationValues
}
// NumberOperationStrings returns a slice of all String values of the enum
func NumberOperationStrings() []string {
strs := make([]string, len(_NumberOperationNames))
copy(strs, _NumberOperationNames)
return strs
}
// IsANumberOperation returns "true" if the value is listed in the enum definition. "false" otherwise
func (i NumberOperation) IsANumberOperation() bool {
for _, v := range _NumberOperationValues {
if i == v {
return true
}
}
return false
}
const _TextOperationName = "=<>LIKE"
var _TextOperationIndex = [...]uint8{0, 1, 3, 7}
const _TextOperationLowerName = "=<>like"
func (i TextOperation) String() string {
i -= 1
if i >= TextOperation(len(_TextOperationIndex)-1) {
return fmt.Sprintf("TextOperation(%d)", i+1)
}
return _TextOperationName[_TextOperationIndex[i]:_TextOperationIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _TextOperationNoOp() {
var x [1]struct{}
_ = x[TextOperationEqual-(1)]
_ = x[TextOperationNotEqual-(2)]
_ = x[TextOperationStartsWith-(3)]
}
var _TextOperationValues = []TextOperation{TextOperationEqual, TextOperationNotEqual, TextOperationStartsWith}
var _TextOperationNameToValueMap = map[string]TextOperation{
_TextOperationName[0:1]: TextOperationEqual,
_TextOperationLowerName[0:1]: TextOperationEqual,
_TextOperationName[1:3]: TextOperationNotEqual,
_TextOperationLowerName[1:3]: TextOperationNotEqual,
_TextOperationName[3:7]: TextOperationStartsWith,
_TextOperationLowerName[3:7]: TextOperationStartsWith,
}
var _TextOperationNames = []string{
_TextOperationName[0:1],
_TextOperationName[1:3],
_TextOperationName[3:7],
}
// TextOperationString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func TextOperationString(s string) (TextOperation, error) {
if val, ok := _TextOperationNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _TextOperationNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to TextOperation values", s)
}
// TextOperationValues returns all values of the enum
func TextOperationValues() []TextOperation {
return _TextOperationValues
}
// TextOperationStrings returns a slice of all String values of the enum
func TextOperationStrings() []string {
strs := make([]string, len(_TextOperationNames))
copy(strs, _TextOperationNames)
return strs
}
// IsATextOperation returns "true" if the value is listed in the enum definition. "false" otherwise
func (i TextOperation) IsATextOperation() bool {
for _, v := range _TextOperationValues {
if i == v {
return true
}
}
return false
}
const _BytesOperationName = "=<>"
var _BytesOperationIndex = [...]uint8{0, 1, 3}
const _BytesOperationLowerName = "=<>"
func (i BytesOperation) String() string {
i -= 1
if i >= BytesOperation(len(_BytesOperationIndex)-1) {
return fmt.Sprintf("BytesOperation(%d)", i+1)
}
return _BytesOperationName[_BytesOperationIndex[i]:_BytesOperationIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _BytesOperationNoOp() {
var x [1]struct{}
_ = x[BytesOperationEqual-(1)]
_ = x[BytesOperationNotEqual-(2)]
}
var _BytesOperationValues = []BytesOperation{BytesOperationEqual, BytesOperationNotEqual}
var _BytesOperationNameToValueMap = map[string]BytesOperation{
_BytesOperationName[0:1]: BytesOperationEqual,
_BytesOperationLowerName[0:1]: BytesOperationEqual,
_BytesOperationName[1:3]: BytesOperationNotEqual,
_BytesOperationLowerName[1:3]: BytesOperationNotEqual,
}
var _BytesOperationNames = []string{
_BytesOperationName[0:1],
_BytesOperationName[1:3],
}
// BytesOperationString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func BytesOperationString(s string) (BytesOperation, error) {
if val, ok := _BytesOperationNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _BytesOperationNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to BytesOperation values", s)
}
// BytesOperationValues returns all values of the enum
func BytesOperationValues() []BytesOperation {
return _BytesOperationValues
}
// BytesOperationStrings returns a slice of all String values of the enum
func BytesOperationStrings() []string {
strs := make([]string, len(_BytesOperationNames))
copy(strs, _BytesOperationNames)
return strs
}
// IsABytesOperation returns "true" if the value is listed in the enum definition. "false" otherwise
func (i BytesOperation) IsABytesOperation() bool {
for _, v := range _BytesOperationValues {
if i == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,244 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_writeOperation(t *testing.T) {
type want struct {
shouldPanic bool
stmt string
args []any
}
tests := []struct {
name string
write func(builder *StatementBuilder)
// col Column
// op string
// value any
want want
}{
{
name: "unsupported operation panics",
write: func(builder *StatementBuilder) {
writeOperation[string](builder, NewColumn("table", "column"), "", "value")
},
want: want{
shouldPanic: true,
},
},
{
name: "unsupported value type panics",
write: func(builder *StatementBuilder) {
writeOperation[string](builder, NewColumn("table", "column"), " = ", struct{}{})
},
want: want{
shouldPanic: true,
},
},
{
name: "unsupported wrapped value type panics",
write: func(builder *StatementBuilder) {
writeOperation[string](builder, NewColumn("table", "column"), " = ", SHA256Value(123))
},
want: want{
shouldPanic: true,
},
},
{
name: "text equal",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationEqual, "value")
},
want: want{
stmt: "table.column = $1",
args: []any{"value"},
},
},
{
name: "text not equal",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationNotEqual, "value")
},
want: want{
stmt: "table.column <> $1",
args: []any{"value"},
},
},
{
name: "text starts with",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationStartsWith, "value")
},
want: want{
stmt: "table.column LIKE $1 || '%'",
args: []any{"value"},
},
},
{
name: "text equal with wrapped value",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationEqual, LowerValue("value"))
},
want: want{
stmt: "LOWER(table.column) = LOWER($1)",
args: []any{"value"},
},
},
{
name: "text not equal with wrapped value",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationNotEqual, LowerValue("value"))
},
want: want{
stmt: "LOWER(table.column) <> LOWER($1)",
args: []any{"value"},
},
},
{
name: "text starts with with wrapped value",
write: func(builder *StatementBuilder) {
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationStartsWith, LowerValue("value"))
},
want: want{
stmt: "LOWER(table.column) LIKE LOWER($1) || '%'",
args: []any{"value"},
},
},
{
name: "number equal",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationEqual, 123)
},
want: want{
stmt: "table.column = $1",
args: []any{123},
},
},
{
name: "number not equal",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationNotEqual, 123)
},
want: want{
stmt: "table.column <> $1",
args: []any{123},
},
},
{
name: "number less than",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationLessThan, 123)
},
want: want{
stmt: "table.column < $1",
args: []any{123},
},
},
{
name: "number less than or equal",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtLeast, 123)
},
want: want{
stmt: "table.column <= $1",
args: []any{123},
},
},
{
name: "number greater than",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationGreaterThan, 123)
},
want: want{
stmt: "table.column > $1",
args: []any{123},
},
},
{
name: "number greater than or equal",
write: func(builder *StatementBuilder) {
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtMost, 123)
},
want: want{
stmt: "table.column >= $1",
args: []any{123},
},
},
{
name: "boolean is true",
write: func(builder *StatementBuilder) {
writeBooleanOperation[bool](builder, NewColumn("table", "column"), true)
},
want: want{
stmt: "table.column = $1",
args: []any{true},
},
},
{
name: "boolean is false",
write: func(builder *StatementBuilder) {
writeBooleanOperation[bool](builder, NewColumn("table", "column"), false)
},
want: want{
stmt: "table.column = $1",
args: []any{false},
},
},
{
name: "bytes equal",
write: func(builder *StatementBuilder) {
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationEqual, []byte{0x01, 0x02, 0x03})
},
want: want{
stmt: "table.column = $1",
args: []any{[]byte{0x01, 0x02, 0x03}},
},
},
{
name: "bytes not equal",
write: func(builder *StatementBuilder) {
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationNotEqual, []byte{0x01, 0x02, 0x03})
},
want: want{
stmt: "table.column <> $1",
args: []any{[]byte{0x01, 0x02, 0x03}},
},
},
{
name: "bytes equal with wrapped value",
write: func(builder *StatementBuilder) {
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
},
want: want{
stmt: "SHA256(table.column) = SHA256($1)",
args: []any{[]byte{0x01, 0x02, 0x03}},
},
},
{
name: "bytes not equal with wrapped value",
write: func(builder *StatementBuilder) {
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationNotEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
},
want: want{
stmt: "SHA256(table.column) <> SHA256($1)",
args: []any{[]byte{0x01, 0x02, 0x03}},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, tt.want.shouldPanic, r != nil)
}()
var builder StatementBuilder
tt.write(&builder)
assert.Equal(t, tt.want.stmt, builder.String())
assert.Equal(t, tt.want.args, builder.Args())
})
}
}

View File

@@ -0,0 +1,28 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_orderBy_Write(t *testing.T) {
tests := []struct {
name string
want string
order Order
}{
{
name: "order by column",
want: " ORDER BY table.column",
order: OrderBy(NewColumn("table", "column")),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var builder StatementBuilder
tt.order.Write(&builder)
assert.Equal(t, tt.want, builder.String())
})
}
}

View File

@@ -0,0 +1,147 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestQueryOptions(t *testing.T) {
type want struct {
stmt string
args []any
}
for _, test := range []struct {
name string
options []QueryOption
want want
}{
{
name: "no options",
want: want{
stmt: "",
args: nil,
},
},
{
name: "limit option",
options: []QueryOption{
WithLimit(10),
},
want: want{
stmt: " LIMIT $1",
args: []any{uint32(10)},
},
},
{
name: "offset option",
options: []QueryOption{
WithOffset(5),
},
want: want{
stmt: " OFFSET $1",
args: []any{uint32(5)},
},
},
{
name: "order by asc option",
options: []QueryOption{
WithOrderByAscending(NewColumn("table", "column")),
},
want: want{
stmt: " ORDER BY table.column",
args: nil,
},
},
{
name: "order by desc option",
options: []QueryOption{
WithOrderByDescending(NewColumn("table", "column")),
},
want: want{
stmt: " ORDER BY table.column DESC",
args: nil,
},
},
{
name: "order by option",
options: []QueryOption{
WithOrderBy(OrderDirectionAsc, NewColumn("table", "column1"), NewColumn("table", "column2")),
},
want: want{
stmt: " ORDER BY table.column1, table.column2",
args: nil,
},
},
{
name: "condition option",
options: []QueryOption{
WithCondition(NewBooleanCondition(NewColumn("table", "column"), true)),
},
want: want{
stmt: " WHERE table.column = $1",
args: []any{true},
},
},
{
name: "group by option",
options: []QueryOption{
WithGroupBy(NewColumn("table", "column")),
},
want: want{
stmt: " GROUP BY table.column",
args: nil,
},
},
{
name: "left join option",
options: []QueryOption{
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
},
want: want{
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id",
args: nil,
},
},
{
name: "permission check option",
options: []QueryOption{
WithPermissionCheck("permission"),
},
want: want{
stmt: "",
args: nil,
},
},
{
name: "multiple options",
options: []QueryOption{
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
WithCondition(NewNumberCondition(NewColumn("table", "column"), NumberOperationEqual, 123)),
WithOrderByDescending(NewColumn("table", "column")),
WithLimit(10),
WithOffset(5),
},
want: want{
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id WHERE table.column = $1 ORDER BY table.column DESC LIMIT $2 OFFSET $3",
args: []any{123, uint32(10), uint32(5)},
},
},
} {
t.Run(test.name, func(t *testing.T) {
var b StatementBuilder
var opts QueryOpts
for _, option := range test.options {
option(&opts)
}
opts.Write(&b)
assert.Equal(t, test.want.stmt, b.String())
require.Len(t, b.Args(), len(test.want.args))
for i := range test.want.args {
assert.Equal(t, test.want.args[i], b.Args()[i])
}
})
}
}

View File

@@ -3,19 +3,35 @@ package repository
import (
"encoding/json"
"errors"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type JSONArray[T any] []*T
func (a JSONArray[T]) Scan(src any) error {
var ErrScanSource = errors.New("unsupported scan source")
func (a *JSONArray[T]) Scan(src any) (err error) {
var rawJSON []byte
switch s := src.(type) {
case string:
return json.Unmarshal([]byte(s), &a)
if len(s) == 0 {
return nil
}
rawJSON = []byte(s)
case []byte:
return json.Unmarshal(s, &a)
if len(s) == 0 {
return nil
}
rawJSON = s
case nil:
return nil
default:
return errors.New("unsupported scan source")
return ErrScanSource
}
err = json.Unmarshal(rawJSON, a)
if err != nil {
return database.NewScanError(err)
}
return nil
}

View File

@@ -0,0 +1,117 @@
package repository_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
type testObject struct {
ID string `json:"id"`
Number int `json:"number"`
Active bool `json:"active"`
}
func TestJSONArray_Scan(t *testing.T) {
type want struct {
err error
res []*testObject
}
tests := []struct {
name string
src any
want want
}{
{
name: "nil",
src: nil,
want: want{
err: nil,
res: nil,
},
},
{
name: "[]byte with data",
src: []byte(`[{"id":"1","number":1,"active":true}]`),
want: want{
err: nil,
res: []*testObject{{ID: "1", Number: 1, Active: true}},
},
},
{
name: "[]byte without data",
src: []byte(`[]`),
want: want{
err: nil,
res: nil,
},
},
{
name: "nil []byte",
src: []byte(nil),
want: want{
err: nil,
res: nil,
},
},
{
name: "empty []byte",
src: []byte{},
want: want{
err: nil,
res: nil,
},
},
{
name: "string with data",
src: `[{"id":"1","number":1,"active":true}]`,
want: want{
err: nil,
res: []*testObject{{ID: "1", Number: 1, Active: true}},
},
},
{
name: "string without data",
src: string(`[]`),
want: want{
err: nil,
res: nil,
},
},
{
name: "empty string",
src: "",
want: want{
err: nil,
res: nil,
},
},
{
name: "wrong type",
src: []int{1, 2, 3},
want: want{
err: repository.ErrScanSource,
res: nil,
},
},
{
name: "badly formatted JSON",
src: []byte(`this is not a JSON`),
want: want{
err: new(database.ScanError),
res: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var a repository.JSONArray[testObject]
gotErr := a.Scan(tt.src)
require.ErrorIs(t, gotErr, tt.want.err)
require.Len(t, a, len(tt.want.res))
})
}
}

View File

@@ -13,17 +13,20 @@ import (
var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
repository
shouldLoadDomains bool
domainRepo *instanceDomain
domainRepo instanceDomain
}
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
return &instance{
repository: repository{
client: client,
},
}
func InstanceRepository() domain.InstanceRepository {
return new(instance)
}
func (instance) qualifiedTableName() string {
return "zitadel.instances"
}
func (instance) unqualifiedTableName() string {
return "instances"
}
// -------------------------------------------------------------
@@ -37,7 +40,7 @@ const (
)
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
@@ -52,11 +55,11 @@ func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*doma
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstance(ctx, i.client, &builder)
return scanInstance(ctx, client, &builder)
}
// List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
@@ -71,27 +74,11 @@ func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*d
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstances(ctx, i.client, &builder)
}
func (i *instance) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 2)
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !i.shouldLoadDomains {
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
}
return database.WithLeftJoin(
"zitadel.instance_domains",
database.And(columns...),
)
return scanInstances(ctx, client, &builder)
}
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -103,42 +90,46 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
updatedAt = instance.UpdatedAt
}
builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
builder.WriteString(`INSERT INTO `)
builder.WriteString(i.qualifiedTableName())
builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
}
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET `)
database.Changes(changes).Write(&builder)
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
stmt := builder.String()
return i.client.Exec(ctx, stmt, builder.Args()...)
return client.Exec(ctx, stmt, builder.Args()...)
}
// Delete implements [domain.InstanceRepository].
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.instances`)
builder.WriteString(`DELETE FROM `)
builder.WriteString(i.qualifiedTableName())
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -185,53 +176,77 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
return database.NewTextCondition(i.NameColumn(), op, name)
}
// ExistsDomain creates a correlated [database.Exists] condition on instance_domains.
// Use this filter to make sure the Instance returned contains a specific domain.
// of the instance in the aggregated result.
// Example usage:
//
// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated
// instance, _ := instanceRepo.Get(ctx,
// database.WithCondition(
// database.And(
// instanceRepo.InstanceIDCondition(instanceID),
// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
// ),
// ),
// )
func (i instance) ExistsDomain(cond database.Condition) database.Condition {
return database.Exists(
i.domainRepo.qualifiedTableName(),
database.And(
database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()),
cond,
),
)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column {
return database.NewColumn("instances", "id")
func (i instance) IDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "id")
}
// NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column {
return database.NewColumn("instances", "name")
func (i instance) NameColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "name")
}
// CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column {
return database.NewColumn("instances", "created_at")
func (i instance) CreatedAtColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column {
return database.NewColumn("instances", "default_org_id")
func (i instance) DefaultOrgIDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column {
return database.NewColumn("instances", "iam_project_id")
func (i instance) IAMProjectIDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column {
return database.NewColumn("instances", "console_client_id")
func (i instance) ConsoleClientIDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column {
return database.NewColumn("instances", "console_app_id")
func (i instance) ConsoleAppIDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column {
return database.NewColumn("instances", "default_language")
func (i instance) DefaultLanguageColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column {
return database.NewColumn("instances", "updated_at")
func (i instance) UpdatedAtColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
}
// -------------------------------------------------------------
@@ -282,19 +297,24 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
// sub repositories
// -------------------------------------------------------------
// Domains implements [domain.InstanceRepository].
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
if !i.shouldLoadDomains {
i.shouldLoadDomains = shouldLoad
func (i *instance) LoadDomains() domain.InstanceRepository {
return &instance{
shouldLoadDomains: true,
}
if i.domainRepo != nil {
return i.domainRepo
}
i.domainRepo = &instanceDomain{
repository: i.repository,
instance: i,
}
return i.domainRepo
}
func (i *instance) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 2)
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()))
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !i.shouldLoadDomains {
columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn()))
}
return database.WithLeftJoin(
i.domainRepo.qualifiedTableName(),
database.And(columns...),
)
}

View File

@@ -10,9 +10,18 @@ import (
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
type instanceDomain struct {
repository
*instance
type instanceDomain struct{}
func InstanceDomainRepository() domain.InstanceDomainRepository {
return new(instanceDomain)
}
func (instanceDomain) qualifiedTableName() string {
return "zitadel.instance_domains"
}
func (instanceDomain) unqualifiedTableName() string {
return "instance_domains"
}
// -------------------------------------------------------------
@@ -23,8 +32,7 @@ const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_d
`FROM zitadel.instance_domains`
// Get implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
func (i instanceDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -34,12 +42,11 @@ func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption)
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
return scanInstanceDomain(ctx, i.client, &builder)
return scanInstanceDomain(ctx, client, &builder)
}
// List implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
func (i instanceDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -49,11 +56,11 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
return scanInstanceDomains(ctx, i.client, &builder)
return scanInstanceDomains(ctx, client, &builder)
}
// Add implements [domain.InstanceDomainRepository].
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
func (i instanceDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddInstanceDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -69,33 +76,41 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
func (i instanceDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
}
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
var builder database.StatementBuilder
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.InstanceDomainRepository].
func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
var builder database.StatementBuilder
func (i instanceDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
condition.Write(&builder)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -146,40 +161,38 @@ func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition
// -------------------------------------------------------------
// CreatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
func (instanceDomain) CreatedAtColumn() database.Column {
return database.NewColumn("instance_domains", "created_at")
func (i instanceDomain) CreatedAtColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "created_at")
}
// DomainColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) DomainColumn() database.Column {
return database.NewColumn("instance_domains", "domain")
func (i instanceDomain) DomainColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "domain")
}
// InstanceIDColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) InstanceIDColumn() database.Column {
return database.NewColumn("instance_domains", "instance_id")
func (i instanceDomain) InstanceIDColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "instance_id")
}
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsPrimaryColumn() database.Column {
return database.NewColumn("instance_domains", "is_primary")
func (i instanceDomain) IsPrimaryColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "is_primary")
}
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
func (instanceDomain) UpdatedAtColumn() database.Column {
return database.NewColumn("instance_domains", "updated_at")
func (i instanceDomain) UpdatedAtColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
}
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsGeneratedColumn() database.Column {
return database.NewColumn("instance_domains", "is_generated")
func (i instanceDomain) IsGeneratedColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "is_generated")
}
// TypeColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) TypeColumn() database.Column {
return database.NewColumn("instance_domains", "type")
func (i instanceDomain) TypeColumn() database.Column {
return database.NewColumn(i.unqualifiedTableName(), "type")
}
// -------------------------------------------------------------

View File

@@ -1,7 +1,6 @@
package repository_test
import (
"context"
"testing"
"time"
@@ -16,31 +15,43 @@ import (
)
func TestAddInstanceDomain(t *testing.T) {
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
beforeAdd := time.Now()
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain
instanceDomain domain.AddInstanceDomain
err error
}{
{
name: "happy path custom domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -50,7 +61,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "happy path trusted domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
@@ -58,7 +69,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add primary domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(true),
@@ -68,7 +79,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add custom domain without domain name",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: "",
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -79,7 +90,7 @@ func TestAddInstanceDomain(t *testing.T) {
{
name: "add trusted domain without domain name",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: "",
Type: domain.DomainTypeTrusted,
},
@@ -87,23 +98,23 @@ func TestAddInstanceDomain(t *testing.T) {
},
{
name: "add custom domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
}
err := domainRepo.Add(ctx, instanceDomain)
err := domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
@@ -114,22 +125,20 @@ func TestAddInstanceDomain(t *testing.T) {
},
{
name: "add trusted domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
domainName := gofakeit.DomainName()
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
}
err := domainRepo.Add(ctx, instanceDomain)
err := domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
InstanceID: instance.ID,
Domain: instanceDomain.Domain,
Type: domain.DomainTypeTrusted,
}
},
@@ -196,26 +205,23 @@ func TestAddInstanceDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
beforeAdd := time.Now()
tx, err := pool.Begin(t.Context(), nil)
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
err := savepoint.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository(tx)
domainRepo := instanceRepo.Domains(false)
var instanceDomain *domain.AddInstanceDomain
if test.testFunc != nil {
instanceDomain = test.testFunc(ctx, t, domainRepo)
instanceDomain = test.testFunc(t, savepoint)
} else {
instanceDomain = &test.instanceDomain
}
err = domainRepo.Add(ctx, instanceDomain)
err = domainRepo.Add(t.Context(), savepoint, instanceDomain)
afterAdd := time.Now()
if test.err != nil {
assert.ErrorIs(t, err, test.err)
@@ -232,10 +238,21 @@ func TestAddInstanceDomain(t *testing.T) {
}
func TestGetInstanceDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -243,38 +260,29 @@ func TestGetInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domains
domainRepo := instanceRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName2,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), domain1)
err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -289,19 +297,19 @@ func TestGetInstanceDomain(t *testing.T) {
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expected: &domain.InstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
InstanceID: instance.ID,
Domain: domain1.Domain,
IsPrimary: gu.Ptr(true),
},
},
{
name: "get by domain name",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain)),
},
expected: &domain.InstanceDomain{
InstanceID: instanceID,
Domain: domainName2,
InstanceID: instance.ID,
Domain: domain2.Domain,
IsPrimary: gu.Ptr(false),
},
},
@@ -318,7 +326,7 @@ func TestGetInstanceDomain(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
result, err := domainRepo.Get(ctx, test.opts...)
result, err := domainRepo.Get(ctx, tx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -335,10 +343,21 @@ func TestGetInstanceDomain(t *testing.T) {
}
func TestListInstanceDomains(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -346,42 +365,35 @@ func TestListInstanceDomains(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add multiple domains
domainRepo := instanceRepo.Domains(false)
domains := []domain.AddInstanceDomain{
{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
}
for i := range domains {
err = domainRepo.Add(t.Context(), &domains[i])
err = domainRepo.Add(t.Context(), tx, &domains[i])
require.NoError(t, err)
}
@@ -405,7 +417,7 @@ func TestListInstanceDomains(t *testing.T) {
{
name: "list by instance",
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
},
expectedCount: 3,
},
@@ -422,12 +434,12 @@ func TestListInstanceDomains(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
results, err := domainRepo.List(ctx, test.opts...)
results, err := domainRepo.List(ctx, tx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, instance.ID, result.InstanceID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
@@ -437,10 +449,21 @@ func TestListInstanceDomains(t *testing.T) {
}
func TestUpdateInstanceDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -448,29 +471,19 @@ func TestUpdateInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domain
domainRepo := instanceRepo.Domains(false)
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), instanceDomain)
err = domainRepo.Add(t.Context(), tx, instanceDomain)
require.NoError(t, err)
tests := []struct {
@@ -482,19 +495,28 @@ func TestUpdateInstanceDomain(t *testing.T) {
}{
{
name: "set primary",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 1,
},
{
name: "update non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 0,
},
{
name: "no changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
),
changes: []database.Change{},
expected: 0,
err: database.ErrNoChanges,
@@ -503,9 +525,7 @@ func TestUpdateInstanceDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -516,7 +536,7 @@ func TestUpdateInstanceDomain(t *testing.T) {
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
@@ -529,10 +549,21 @@ func TestUpdateInstanceDomain(t *testing.T) {
}
func TestRemoveInstanceDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -540,37 +571,28 @@ func TestRemoveInstanceDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// add domains
domainRepo := instanceRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domain1 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
InstanceID: instanceID,
InstanceID: instance.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), domain1)
err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -580,35 +602,42 @@ func TestRemoveInstanceDomain(t *testing.T) {
}{
{
name: "remove by domain name",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
),
expected: 1,
},
{
name: "remove by primary condition",
condition: domainRepo.IsPrimaryCondition(false),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.IsPrimaryCondition(false),
),
expected: 1, // domain2 should still exist and be non-primary
},
{
name: "remove non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
),
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
// count before removal
beforeCount, err := domainRepo.List(ctx)
beforeCount, err := domainRepo.List(t.Context(), tx)
require.NoError(t, err)
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
rowsAffected, err := domainRepo.Remove(t.Context(), tx, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
afterCount, err := domainRepo.List(ctx)
afterCount, err := domainRepo.List(t.Context(), tx)
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
@@ -616,8 +645,7 @@ func TestRemoveInstanceDomain(t *testing.T) {
}
func TestInstanceDomainConditions(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
domainRepo := instanceRepo.Domains(false)
domainRepo := repository.InstanceDomainRepository()
tests := []struct {
name string
@@ -671,8 +699,7 @@ func TestInstanceDomainConditions(t *testing.T) {
}
func TestInstanceDomainChanges(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
domainRepo := instanceRepo.Domains(false)
domainRepo := repository.InstanceDomainRepository()
tests := []struct {
name string

View File

@@ -2,6 +2,7 @@ package repository_test
import (
"context"
"strconv"
"testing"
"time"
@@ -16,9 +17,20 @@ import (
)
func TestCreateInstance(t *testing.T) {
beforeCreate := time.Now()
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(context.Background())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
}{
@@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) {
},
{
name: "adding same instance twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) {
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
// change the name to make sure same only the id clashes
inst.Name = gofakeit.Name()
require.NoError(t, err)
@@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) {
},
func() struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
} {
@@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) {
instanceName := gofakeit.Name()
return struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
instance domain.Instance
err error
}{
name: "adding instance with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
@@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) {
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
// change the id
@@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) {
{
name: "adding instance with no id",
instance: func() domain.Instance {
// instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
// ID: instanceId,
Name: instanceName,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
err = savepoint.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
instance = tt.testFunc(t, savepoint)
} else {
instance = &tt.instance
}
instanceRepo := repository.InstanceRepository(pool)
// create instance
beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance)
err = instanceRepo.Create(t.Context(), tx, instance)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
@@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) {
afterCreate := time.Now()
// check instance values
instance, err = instanceRepo.Get(ctx,
instance, err = instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) {
}
func TestUpdateInstance(t *testing.T) {
beforeUpdate := time.Now()
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(context.Background())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
rowsAffected int64
getErr error
}{
{
name: "happy path",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
return &inst
},
@@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) {
},
{
name: "update deleted instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
// delete instance
affectedRows, err := instanceRepo.Delete(ctx,
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
inst.ID,
)
require.NoError(t, err)
@@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) {
},
{
name: "update non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
ID: gofakeit.UUID(),
}
return &inst
},
@@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
instance := tt.testFunc(t, tx)
instance := tt.testFunc(ctx, t)
beforeUpdate := time.Now()
// update name
newName := "new_" + instance.Name
rowsAffected, err := instanceRepo.Update(ctx,
rowsAffected, err := instanceRepo.Update(t.Context(), tx,
instance.ID,
instanceRepo.SetName(newName),
)
@@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) {
}
// check instance values
instance, err = instanceRepo.Get(ctx,
instance, err = instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) {
}
func TestGetInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(context.Background())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
domainRepo := repository.InstanceDomainRepository()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
testFunc func(t *testing.T) *domain.Instance
err error
}
tests := []test{
func() test {
instanceId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceName := gofakeit.Name()
testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
ID: gofakeit.UUID(),
Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
return &inst
},
@@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) {
}(),
{
name: "happy path including domains",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
ID: gofakeit.NewCrypto().UUID(),
Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
domainRepo := instanceRepo.Domains(false)
d := &domain.AddInstanceDomain{
InstanceID: inst.ID,
Domain: gofakeit.DomainName(),
@@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) {
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(ctx, d)
err = domainRepo.Add(t.Context(), tx, d)
require.NoError(t, err)
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
@@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) {
},
{
name: "get non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
testFunc: func(t *testing.T) *domain.Instance {
inst := domain.Instance{
ID: "get non existent instance",
}
@@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
instance = tt.testFunc(t)
}
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
returnedInstance, err := instanceRepo.Get(t.Context(), tx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
@@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) {
}
func TestListInstance(t *testing.T) {
ctx := context.Background()
pool, stop, err := newEmbeddedDB(ctx)
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer stop()
defer func() {
err := tx.Rollback(context.Background())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance
conditionClauses []database.Condition
noInstanceReturned bool
}
tests := []test{
{
name: "happy path single instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) {
},
{
name: "happy path multiple instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) {
},
},
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceID := gofakeit.BeerName()
return test{
name: "instance filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) {
return instances
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)},
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
instanceName := gofakeit.BeerName()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
@@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
})
instances := tt.testFunc(ctx, t)
instanceRepo := repository.InstanceRepository(pool)
defer func() {
err = savepoint.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instances := tt.testFunc(t, savepoint)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
@@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) {
}
// check instance values
returnedInstances, err := instanceRepo.List(ctx,
returnedInstances, err := instanceRepo.List(t.Context(), tx,
database.WithCondition(condition),
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
database.WithOrderByAscending(instanceRepo.IDColumn()),
)
require.NoError(t, err)
if tt.noInstanceReturned {
assert.Nil(t, returnedInstances)
assert.Len(t, returnedInstances, 0)
return
}
@@ -609,25 +620,31 @@ func TestListInstance(t *testing.T) {
}
func TestDeleteInstance(t *testing.T) {
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(context.Background())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
testFunc func(t *testing.T, tx database.QueryExecutor)
instanceID string
noOfDeletedRows int64
}
tests := []test{
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
var noOfInstances int64 = 1
instanceID := gofakeit.NewCrypto().UUID()
return test{
name: "happy path delete single instance filter id",
testFunc: func(ctx context.Context, t *testing.T) {
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
testFunc: func(t *testing.T, tx database.QueryExecutor) {
inst := domain.Instance{
ID: instanceId,
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -637,14 +654,11 @@ func TestDeleteInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
instanceID: instanceId,
noOfDeletedRows: noOfInstances,
instanceID: instanceID,
noOfDeletedRows: 1,
}
}(),
func() test {
@@ -655,18 +669,14 @@ func TestDeleteInstance(t *testing.T) {
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
instanceID := gofakeit.Name()
return test{
name: "deleted already deleted instance",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
testFunc: func(t *testing.T, tx database.QueryExecutor) {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
ID: instanceID,
Name: gofakeit.BeerName(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
@@ -675,20 +685,17 @@ func TestDeleteInstance(t *testing.T) {
}
// create instance
err := instanceRepo.Create(ctx, &inst)
err := instanceRepo.Create(t.Context(), tx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
// delete instance
affectedRows, err := instanceRepo.Delete(ctx,
instances[0].ID,
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
inst.ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
},
instanceID: instanceName,
instanceID: instanceID,
// this test should return 0 affected rows as the instance was already deleted
noOfDeletedRows: 0,
}
@@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
err = savepoint.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
if tt.testFunc != nil {
tt.testFunc(ctx, t)
tt.testFunc(t, savepoint)
}
// delete instance
noOfDeletedRows, err := instanceRepo.Delete(ctx,
tt.instanceID,
)
noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
instance, err := instanceRepo.Get(t.Context(), savepoint,
database.WithCondition(
instanceRepo.IDCondition(tt.instanceID),
),

View File

@@ -14,25 +14,24 @@ import (
var _ domain.OrganizationRepository = (*org)(nil)
type org struct {
repository
shouldLoadDomains bool
domainRepo domain.OrganizationDomainRepository
domainRepo orgDomain
}
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
return &org{
repository: repository{
client: client,
},
}
func (o org) unqualifiedTableName() string {
return "organizations"
}
func OrganizationRepository() domain.OrganizationRepository {
return new(org)
}
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
` , jsonb_agg(json_build_object('instanceId', org_domains.instance_id, 'orgId', org_domains.org_id, 'domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository].
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
func (o org) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
@@ -43,15 +42,19 @@ func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Or
opt(options)
}
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganization(ctx, o.client, &builder)
return scanOrganization(ctx, client, &builder)
}
// List implements [domain.OrganizationRepository].
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
func (o org) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
@@ -62,30 +65,15 @@ func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain
opt(options)
}
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganizations(ctx, o.client, &builder)
}
func (o *org) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 3)
columns = append(columns,
database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
)
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !o.shouldLoadDomains {
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
}
return database.WithLeftJoin(
"zitadel.org_domains",
database.And(columns...),
)
return scanOrganizations(ctx, client, &builder)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
@@ -93,46 +81,48 @@ const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, ins
` RETURNING created_at, updated_at`
// Create implements [domain.OrganizationRepository].
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
func (o org) Create(ctx context.Context, client database.QueryExecutor, organization *domain.Organization) error {
builder := database.StatementBuilder{}
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
builder.WriteString(createOrganizationStmt)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
}
// Update implements [domain.OrganizationRepository].
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
func (o org) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
builder := database.StatementBuilder{}
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
}
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.organizations SET `)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions := []database.Condition{id, instanceIDCondition}
database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...))
writeCondition(&builder, condition)
stmt := builder.String()
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
rowsAffected, err := client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// Delete implements [domain.OrganizationRepository].
func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
builder := database.StatementBuilder{}
func (o org) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.organizations`)
writeCondition(&builder, condition)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions := []database.Condition{id, instanceIDCondition}
writeCondition(&builder, database.And(conditions...))
return o.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -154,13 +144,13 @@ func (o org) SetState(state domain.OrgState) database.Change {
// -------------------------------------------------------------
// IDCondition implements [domain.organizationConditions].
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
func (o org) IDCondition(id string) database.Condition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.organizationConditions].
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
func (o org) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(o.NameColumn(), op, name)
}
// InstanceIDCondition implements [domain.organizationConditions].
@@ -173,38 +163,62 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
}
// ExistsDomain creates a correlated [database.Exists] condition on org_domains.
// Use this filter to make sure the organization returned contains a specific domain.
// Example usage:
//
// domainRepo := orgRepo.Domains(true) // ensure domains are loaded/aggregated
// org, _ := orgRepo.Get(ctx,
// database.WithCondition(
// database.And(
// orgRepo.InstanceIDCondition(instanceID),
// orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
// ),
// ),
// )
func (o org) ExistsDomain(cond database.Condition) database.Condition {
return database.Exists(
o.domainRepo.qualifiedTableName(),
database.And(
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
cond,
),
)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.organizationColumns].
func (org) IDColumn() database.Column {
return database.NewColumn("organizations", "id")
func (o org) IDColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "id")
}
// NameColumn implements [domain.organizationColumns].
func (org) NameColumn() database.Column {
return database.NewColumn("organizations", "name")
func (o org) NameColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "name")
}
// InstanceIDColumn implements [domain.organizationColumns].
func (org) InstanceIDColumn() database.Column {
return database.NewColumn("organizations", "instance_id")
func (o org) InstanceIDColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
}
// StateColumn implements [domain.organizationColumns].
func (org) StateColumn() database.Column {
return database.NewColumn("organizations", "state")
func (o org) StateColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "state")
}
// CreatedAtColumn implements [domain.organizationColumns].
func (org) CreatedAtColumn() database.Column {
return database.NewColumn("organizations", "created_at")
func (o org) CreatedAtColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "created_at")
}
// UpdatedAtColumn implements [domain.organizationColumns].
func (org) UpdatedAtColumn() database.Column {
return database.NewColumn("organizations", "updated_at")
func (o org) UpdatedAtColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
}
// -------------------------------------------------------------
@@ -255,20 +269,27 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
// sub repositories
// -------------------------------------------------------------
// Domains implements [domain.OrganizationRepository].
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
if !o.shouldLoadDomains {
o.shouldLoadDomains = shouldLoad
func (o org) LoadDomains() domain.OrganizationRepository {
return &org{
shouldLoadDomains: true,
}
if o.domainRepo != nil {
return o.domainRepo
}
o.domainRepo = &orgDomain{
repository: o.repository,
org: o,
}
return o.domainRepo
}
func (o org) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 3)
columns = append(columns,
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
)
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !o.shouldLoadDomains {
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
}
return database.WithLeftJoin(
o.domainRepo.qualifiedTableName(),
database.And(columns...),
)
}

View File

@@ -10,9 +10,18 @@ import (
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
type orgDomain struct {
repository
*org
type orgDomain struct{}
func OrganizationDomainRepository() domain.OrganizationDomainRepository {
return new(orgDomain)
}
func (orgDomain) qualifiedTableName() string {
return "zitadel.org_domains"
}
func (orgDomain) unqualifiedTableName() string {
return "org_domains"
}
// -------------------------------------------------------------
@@ -24,36 +33,44 @@ const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_veri
// Get implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
func (o orgDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
return scanOrganizationDomain(ctx, o.client, &builder)
return scanOrganizationDomain(ctx, client, &builder)
}
// List implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
func (o orgDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
return scanOrganizationDomains(ctx, o.client, &builder)
return scanOrganizationDomains(ctx, client, &builder)
}
// Add implements [domain.OrganizationDomainRepository].
func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
func (o orgDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddOrganizationDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
@@ -69,33 +86,47 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
func (o orgDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
}
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
return 0, database.NewMissingConditionError(o.OrgIDColumn())
}
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.org_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
return o.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.OrganizationDomainRepository].
func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
var builder database.StatementBuilder
func (o orgDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
}
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
return 0, database.NewMissingConditionError(o.OrgIDColumn())
}
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.org_domains `)
writeCondition(&builder, condition)
return o.client.Exec(ctx, builder.String(), builder.Args()...)
return client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
@@ -158,45 +189,43 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
func (orgDomain) CreatedAtColumn() database.Column {
return database.NewColumn("org_domains", "created_at")
func (o orgDomain) CreatedAtColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "created_at")
}
// DomainColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) DomainColumn() database.Column {
return database.NewColumn("org_domains", "domain")
func (o orgDomain) DomainColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "domain")
}
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
func (orgDomain) InstanceIDColumn() database.Column {
return database.NewColumn("org_domains", "instance_id")
func (o orgDomain) InstanceIDColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
}
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsPrimaryColumn() database.Column {
return database.NewColumn("org_domains", "is_primary")
func (o orgDomain) IsPrimaryColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "is_primary")
}
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsVerifiedColumn() database.Column {
return database.NewColumn("org_domains", "is_verified")
func (o orgDomain) IsVerifiedColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "is_verified")
}
// OrgIDColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) OrgIDColumn() database.Column {
return database.NewColumn("org_domains", "org_id")
func (o orgDomain) OrgIDColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "org_id")
}
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
func (orgDomain) UpdatedAtColumn() database.Column {
return database.NewColumn("org_domains", "updated_at")
func (o orgDomain) UpdatedAtColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
}
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) ValidationTypeColumn() database.Column {
return database.NewColumn("org_domains", "validation_type")
func (o orgDomain) ValidationTypeColumn() database.Column {
return database.NewColumn(o.unqualifiedTableName(), "validation_type")
}
// -------------------------------------------------------------

View File

@@ -1,7 +1,6 @@
package repository_test
import (
"context"
"testing"
"github.com/brianvoe/gofakeit/v6"
@@ -15,6 +14,15 @@ import (
)
func TestAddOrganizationDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
@@ -26,8 +34,11 @@ func TestAddOrganizationDomain(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
domainRepo := repository.OrganizationDomainRepository()
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// create organization
@@ -41,7 +52,7 @@ func TestAddOrganizationDomain(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain
organizationDomain domain.AddOrganizationDomain
err error
}{
@@ -92,7 +103,7 @@ func TestAddOrganizationDomain(t *testing.T) {
},
{
name: "add domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain {
domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
@@ -104,7 +115,7 @@ func TestAddOrganizationDomain(t *testing.T) {
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
err := domainRepo.Add(ctx, organizationDomain)
err := domainRepo.Add(t.Context(), tx, organizationDomain)
require.NoError(t, err)
// return same domain again
@@ -169,28 +180,26 @@ func TestAddOrganizationDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
tx, err := pool.Begin(t.Context(), nil)
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
err = savepoint.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
err = orgRepo.Create(t.Context(), savepoint, &organization)
require.NoError(t, err)
domainRepo := orgRepo.Domains(false)
var organizationDomain *domain.AddOrganizationDomain
if test.testFunc != nil {
organizationDomain = test.testFunc(ctx, t, domainRepo)
organizationDomain = test.testFunc(t, savepoint)
} else {
organizationDomain = &test.organizationDomain
}
err = domainRepo.Add(ctx, organizationDomain)
err = domainRepo.Add(t.Context(), tx, organizationDomain)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -204,6 +213,16 @@ func TestAddOrganizationDomain(t *testing.T) {
}
func TestGetOrganizationDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
domainRepo := repository.OrganizationDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
@@ -225,29 +244,18 @@ func TestGetOrganizationDomain(t *testing.T) {
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domains
domainRepo := orgRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -255,15 +263,15 @@ func TestGetOrganizationDomain(t *testing.T) {
domain2 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
err = domainRepo.Add(t.Context(), domain1)
err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -275,12 +283,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get primary domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instanceID),
domainRepo.IsPrimaryCondition(true),
)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -289,12 +300,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get by domain name",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instanceID),
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
Domain: domain2.Domain,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
@@ -303,13 +317,16 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get by org ID",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instanceID),
domainRepo.OrgIDCondition(orgID),
domainRepo.IsPrimaryCondition(true),
)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -318,12 +335,15 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get verified domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instanceID),
domainRepo.IsVerifiedCondition(true),
)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
Domain: domain1.Domain,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
@@ -332,7 +352,10 @@ func TestGetOrganizationDomain(t *testing.T) {
{
name: "get non-existent domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instanceID),
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
)),
},
err: new(database.NoRowFoundError),
},
@@ -340,9 +363,7 @@ func TestGetOrganizationDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
result, err := domainRepo.Get(ctx, test.opts...)
result, err := domainRepo.Get(t.Context(), tx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -362,10 +383,22 @@ func TestGetOrganizationDomain(t *testing.T) {
}
func TestListOrganizationDomains(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
domainRepo := repository.OrganizationDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -375,50 +408,40 @@ func TestListOrganizationDomains(t *testing.T) {
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
InstanceID: instanceID,
InstanceID: instance.ID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add multiple domains
domainRepo := orgRepo.Domains(false)
domains := []domain.AddOrganizationDomain{
{
InstanceID: instanceID,
OrgID: orgID,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
{
InstanceID: instanceID,
OrgID: orgID,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
},
{
InstanceID: instanceID,
OrgID: orgID,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: false,
@@ -427,7 +450,7 @@ func TestListOrganizationDomains(t *testing.T) {
}
for i := range domains {
err = domainRepo.Add(t.Context(), &domains[i])
err = domainRepo.Add(t.Context(), tx, &domains[i])
require.NoError(t, err)
}
@@ -438,41 +461,58 @@ func TestListOrganizationDomains(t *testing.T) {
}{
{
name: "list all domains",
opts: []database.QueryOption{},
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
},
expectedCount: 3,
},
{
name: "list verified domains",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.IsVerifiedCondition(true),
)),
},
expectedCount: 2,
},
{
name: "list primary domains",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.IsPrimaryCondition(true),
)),
},
expectedCount: 1,
},
{
name: "list by organization",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
)),
},
expectedCount: 3,
},
{
name: "list by instance",
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.InstanceIDCondition(instance.ID),
)),
},
expectedCount: 3,
},
{
name: "list non-existent organization",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition("non-existent"),
)),
},
expectedCount: 0,
},
@@ -482,13 +522,13 @@ func TestListOrganizationDomains(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
results, err := domainRepo.List(ctx, test.opts...)
results, err := domainRepo.List(ctx, tx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, orgID, result.OrgID)
assert.Equal(t, instance.ID, result.InstanceID)
assert.Equal(t, organization.ID, result.OrgID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
@@ -498,10 +538,22 @@ func TestListOrganizationDomains(t *testing.T) {
}
func TestUpdateOrganizationDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
domainRepo := repository.OrganizationDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -511,41 +563,30 @@ func TestUpdateOrganizationDomain(t *testing.T) {
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
InstanceID: instanceID,
InstanceID: instance.ID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domain
domainRepo := orgRepo.Domains(false)
domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
err = domainRepo.Add(t.Context(), organizationDomain)
err = domainRepo.Add(t.Context(), tx, organizationDomain)
require.NoError(t, err)
tests := []struct {
@@ -557,25 +598,41 @@ func TestUpdateOrganizationDomain(t *testing.T) {
}{
{
name: "set verified",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{domainRepo.SetVerified()},
expected: 1,
},
{
name: "set primary",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 1,
},
{
name: "set validation type",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
expected: 1,
},
{
name: "multiple changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{
domainRepo.SetVerified(),
domainRepo.SetPrimary(),
@@ -585,19 +642,31 @@ func TestUpdateOrganizationDomain(t *testing.T) {
},
{
name: "update by org ID and domain",
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{domainRepo.SetVerified()},
expected: 1,
},
{
name: "update non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
),
changes: []database.Change{domainRepo.SetVerified()},
expected: 0,
},
{
name: "no changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
),
changes: []database.Change{},
expected: 0,
err: database.ErrNoChanges,
@@ -606,9 +675,7 @@ func TestUpdateOrganizationDomain(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
@@ -619,7 +686,7 @@ func TestUpdateOrganizationDomain(t *testing.T) {
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
@@ -632,10 +699,22 @@ func TestUpdateOrganizationDomain(t *testing.T) {
}
func TestRemoveOrganizationDomain(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err = tx.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
domainRepo := repository.OrganizationDomainRepository()
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
@@ -645,53 +724,40 @@ func TestRemoveOrganizationDomain(t *testing.T) {
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
ID: gofakeit.UUID(),
Name: gofakeit.Name(),
InstanceID: instanceID,
InstanceID: instance.ID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
err = orgRepo.Create(t.Context(), tx, &organization)
require.NoError(t, err)
// add domains
domainRepo := orgRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
domain2 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
InstanceID: instance.ID,
OrgID: organization.ID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
err = domainRepo.Add(t.Context(), domain1)
err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
tests := []struct {
@@ -701,49 +767,69 @@ func TestRemoveOrganizationDomain(t *testing.T) {
}{
{
name: "remove by domain name",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
),
expected: 1,
},
{
name: "remove by primary condition",
condition: domainRepo.IsPrimaryCondition(false),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.IsPrimaryCondition(false),
),
expected: 1, // domain2 should still exist and be non-primary
},
{
name: "remove by org ID and domain",
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
),
expected: 1,
},
{
name: "remove non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
condition: database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
),
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
snapshot, err := tx.Begin(ctx)
snapshot, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
require.NoError(t, snapshot.Rollback(ctx))
err = snapshot.Rollback(t.Context())
if err != nil {
t.Log("error during rollback:", err)
}
}()
orgRepo := repository.OrganizationRepository(snapshot)
domainRepo := orgRepo.Domains(false)
// count before removal
beforeCount, err := domainRepo.List(ctx)
beforeCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
)))
require.NoError(t, err)
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
rowsAffected, err := domainRepo.Remove(t.Context(), snapshot, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
afterCount, err := domainRepo.List(ctx)
afterCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
domainRepo.InstanceIDCondition(instance.ID),
domainRepo.OrgIDCondition(organization.ID),
)))
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
@@ -751,8 +837,7 @@ func TestRemoveOrganizationDomain(t *testing.T) {
}
func TestOrganizationDomainConditions(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string
@@ -811,8 +896,7 @@ func TestOrganizationDomainConditions(t *testing.T) {
}
func TestOrganizationDomainChanges(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string
@@ -851,8 +935,7 @@ func TestOrganizationDomainChanges(t *testing.T) {
}
func TestOrganizationDomainColumns(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
domainRepo := repository.OrganizationDomainRepository()
tests := []struct {
name string

View File

@@ -1,7 +1,7 @@
package repository_test
import (
"context"
"strconv"
"testing"
"time"
@@ -15,6 +15,18 @@ import (
)
func TestCreateOrganization(t *testing.T) {
beforeCreate := time.Now()
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
instanceRepo := repository.InstanceRepository()
organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -26,13 +38,12 @@ func TestCreateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
testFunc func(t *testing.T, client database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
}{
@@ -67,8 +78,7 @@ func TestCreateOrganization(t *testing.T) {
},
{
name: "adding org with same id twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationRepo := repository.OrganizationRepository(pool)
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -79,7 +89,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the name to make sure same only the id clashes
org.Name = gofakeit.Name()
@@ -89,8 +99,7 @@ func TestCreateOrganization(t *testing.T) {
},
{
name: "adding org with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
organizationRepo := repository.OrganizationRepository(pool)
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -101,7 +110,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the id to make sure same name+instance causes an error
org.ID = gofakeit.Name()
@@ -111,7 +120,7 @@ func TestCreateOrganization(t *testing.T) {
},
func() struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
} {
@@ -120,12 +129,12 @@ func TestCreateOrganization(t *testing.T) {
return struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
organization domain.Organization
err error
}{
name: "adding org with same name, different instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
// create instance
instId := gofakeit.Name()
instance := domain.Instance{
@@ -137,12 +146,9 @@ func TestCreateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(ctx, &instance)
err := instanceRepo.Create(t.Context(), tx, &instance)
assert.Nil(t, err)
organizationRepo := repository.OrganizationRepository(pool)
org := domain.Organization{
ID: gofakeit.Name(),
Name: organizationName,
@@ -150,7 +156,7 @@ func TestCreateOrganization(t *testing.T) {
State: domain.OrgStateActive,
}
err = organizationRepo.Create(ctx, &org)
err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// change the id to make it unique
@@ -214,19 +220,25 @@ func TestCreateOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
defer func() {
err = savepoint.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
var organization *domain.Organization
if tt.testFunc != nil {
organization = tt.testFunc(ctx, t)
organization = tt.testFunc(t, savepoint)
} else {
organization = &tt.organization
}
organizationRepo := repository.OrganizationRepository(pool)
// create organization
beforeCreate := time.Now()
err = organizationRepo.Create(ctx, organization)
err = organizationRepo.Create(t.Context(), savepoint, organization)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
@@ -234,7 +246,7 @@ func TestCreateOrganization(t *testing.T) {
afterCreate := time.Now()
// check organization values
organization, err = organizationRepo.Get(ctx,
organization, err = organizationRepo.Get(t.Context(), savepoint,
database.WithCondition(
database.And(
organizationRepo.IDCondition(organization.ID),
@@ -255,6 +267,19 @@ func TestCreateOrganization(t *testing.T) {
}
func TestUpdateOrganization(t *testing.T) {
beforeUpdate := time.Now()
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
instanceRepo := repository.InstanceRepository()
organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -266,20 +291,18 @@ func TestUpdateOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
organizationRepo := repository.OrganizationRepository(pool)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
testFunc func(t *testing.T) *domain.Organization
update []database.Change
rowsAffected int64
}{
{
name: "happy path update name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -291,7 +314,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// update with updated value
@@ -303,7 +326,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "update deleted organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -315,13 +338,15 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// delete instance
_, err = organizationRepo.Delete(ctx,
_, err = organizationRepo.Delete(t.Context(), tx,
database.And(
organizationRepo.InstanceIDCondition(org.InstanceID),
organizationRepo.IDCondition(org.ID),
org.InstanceID,
),
)
require.NoError(t, err)
@@ -332,7 +357,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "happy path change state",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
organizationName := gofakeit.Name()
@@ -344,7 +369,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// update with updated value
@@ -356,7 +381,7 @@ func TestUpdateOrganization(t *testing.T) {
},
{
name: "update non existent organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
@@ -370,16 +395,14 @@ func TestUpdateOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
organizationRepo := repository.OrganizationRepository(pool)
createdOrg := tt.testFunc(ctx, t)
createdOrg := tt.testFunc(t)
// update org
beforeUpdate := time.Now()
rowsAffected, err := organizationRepo.Update(ctx,
rowsAffected, err := organizationRepo.Update(t.Context(), tx,
database.And(
organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
organizationRepo.IDCondition(createdOrg.ID),
createdOrg.InstanceID,
),
tt.update...,
)
afterUpdate := time.Now()
@@ -392,7 +415,7 @@ func TestUpdateOrganization(t *testing.T) {
}
// check organization values
organization, err := organizationRepo.Get(ctx,
organization, err := organizationRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
organizationRepo.IDCondition(createdOrg.ID),
@@ -411,6 +434,19 @@ func TestUpdateOrganization(t *testing.T) {
}
func TestGetOrganization(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
err := tx.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
orgDomainRepo := repository.OrganizationDomainRepository()
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -422,11 +458,9 @@ func TestGetOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
// create organization
// this org is created as an additional org which should NOT
@@ -437,13 +471,13 @@ func TestGetOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
err = orgRepo.Create(t.Context(), &org)
err = orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
orgIdentifierCondition domain.OrgIdentifierCondition
testFunc func(t *testing.T) *domain.Organization
orgIdentifierCondition database.Condition
err error
}
@@ -452,7 +486,7 @@ func TestGetOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationName := gofakeit.Name()
org := domain.Organization{
@@ -463,7 +497,7 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
err := orgRepo.Create(ctx, &org)
err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
return &org
@@ -475,7 +509,7 @@ func TestGetOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "happy path get using id including domain",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationName := gofakeit.Name()
org := domain.Organization{
@@ -486,7 +520,7 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
err := orgRepo.Create(ctx, &org)
err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
d := &domain.AddOrganizationDomain{
@@ -496,7 +530,7 @@ func TestGetOrganization(t *testing.T) {
IsVerified: true,
IsPrimary: true,
}
err = orgRepo.Domains(false).Add(ctx, d)
err = orgDomainRepo.Add(t.Context(), tx, d)
require.NoError(t, err)
org.Domains = []*domain.OrganizationDomain{
@@ -521,7 +555,7 @@ func TestGetOrganization(t *testing.T) {
organizationName := gofakeit.Name()
return test{
name: "happy path get using name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
organizationId := gofakeit.Name()
org := domain.Organization{
@@ -532,39 +566,36 @@ func TestGetOrganization(t *testing.T) {
}
// create organization
err := orgRepo.Create(ctx, &org)
err := orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
return &org
},
orgIdentifierCondition: orgRepo.NameCondition(organizationName),
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, organizationName),
}
}(),
{
name: "get non existent organization",
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
testFunc: func(t *testing.T) *domain.Organization {
org := domain.Organization{
ID: "non existent org",
Name: "non existent org",
}
return &org
},
orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name"),
err: new(database.NoRowFoundError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
orgRepo := repository.OrganizationRepository(pool)
var org *domain.Organization
if tt.testFunc != nil {
org = tt.testFunc(ctx, t)
org = tt.testFunc(t)
}
// get org values
returnedOrg, err := orgRepo.Get(ctx,
returnedOrg, err := orgRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
@@ -592,11 +623,17 @@ func TestGetOrganization(t *testing.T) {
}
func TestListOrganization(t *testing.T) {
ctx := t.Context()
pool, stop, err := newEmbeddedDB(ctx)
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer stop()
organizationRepo := repository.OrganizationRepository(pool)
defer func() {
err := tx.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
instanceRepo := repository.InstanceRepository()
organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
@@ -609,33 +646,32 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err = instanceRepo.Create(ctx, &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Organization
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Organization
conditionClauses []database.Condition
noOrganizationReturned bool
}
tests := []test{
{
name: "happy path single organization no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive,
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -646,20 +682,20 @@ func TestListOrganization(t *testing.T) {
},
{
name: "happy path multiple organization no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
noOfOrganizations := 5
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateActive,
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -672,7 +708,7 @@ func TestListOrganization(t *testing.T) {
organizationId := gofakeit.Name()
return test{
name: "organization filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
@@ -682,7 +718,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
err = organizationRepo.Create(ctx, &org)
err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
noOfOrganizations := 1
@@ -697,7 +733,7 @@ func TestListOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -705,12 +741,15 @@ func TestListOrganization(t *testing.T) {
return organizations
},
conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)},
conditionClauses: []database.Condition{
organizationRepo.InstanceIDCondition(instanceId),
organizationRepo.IDCondition(organizationId),
},
}
}(),
{
name: "multiple organization filter on state",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create organization
// this org is created as an additional org which should NOT
// be returned in the results of this test case
@@ -720,7 +759,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId,
State: domain.OrgStateActive,
}
err = organizationRepo.Create(ctx, &org)
err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
noOfOrganizations := 5
@@ -728,14 +767,14 @@ func TestListOrganization(t *testing.T) {
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId,
State: domain.OrgStateInactive,
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -743,13 +782,16 @@ func TestListOrganization(t *testing.T) {
return organizations
},
conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)},
conditionClauses: []database.Condition{
organizationRepo.InstanceIDCondition(instanceId),
organizationRepo.StateCondition(domain.OrgStateInactive),
},
},
func() test {
instanceId_2 := gofakeit.Name()
return test{
name: "multiple organization filter on instance",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
// create instance 1
instanceId_1 := gofakeit.Name()
instance := domain.Instance{
@@ -761,8 +803,7 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err = instanceRepo.Create(ctx, &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
assert.Nil(t, err)
// create organization
@@ -774,7 +815,7 @@ func TestListOrganization(t *testing.T) {
InstanceID: instanceId_1,
State: domain.OrgStateActive,
}
err = organizationRepo.Create(ctx, &org)
err = organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
// create instance 2
@@ -787,7 +828,7 @@ func TestListOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err = instanceRepo.Create(ctx, &instance_2)
err = instanceRepo.Create(t.Context(), tx, &instance_2)
assert.Nil(t, err)
noOfOrganizations := 5
@@ -795,14 +836,14 @@ func TestListOrganization(t *testing.T) {
for i := range noOfOrganizations {
org := domain.Organization{
ID: gofakeit.Name(),
ID: strconv.Itoa(i),
Name: gofakeit.Name(),
InstanceID: instanceId_2,
State: domain.OrgStateActive,
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -816,22 +857,25 @@ func TestListOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations")
savepoint, err := tx.Begin(t.Context())
require.NoError(t, err)
})
defer func() {
err = savepoint.Rollback(t.Context())
if err != nil {
t.Logf("error during rollback: %v", err)
}
}()
organizations := tt.testFunc(t, savepoint)
organizations := tt.testFunc(ctx, t)
var condition database.Condition
condition := organizationRepo.InstanceIDCondition(instanceId)
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check organization values
returnedOrgs, err := organizationRepo.List(ctx,
returnedOrgs, err := organizationRepo.List(t.Context(), tx,
database.WithCondition(condition),
database.WithOrderByAscending(organizationRepo.CreatedAtColumn()),
database.WithOrderByAscending(organizationRepo.CreatedAtColumn(), organizationRepo.IDColumn()),
)
require.NoError(t, err)
if tt.noOrganizationReturned {
@@ -851,6 +895,15 @@ func TestListOrganization(t *testing.T) {
}
func TestDeleteOrganization(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = tx.Rollback(t.Context())
})
instanceRepo := repository.InstanceRepository()
organizationRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
instance := domain.Instance{
@@ -862,24 +915,22 @@ func TestDeleteOrganization(t *testing.T) {
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
err = instanceRepo.Create(t.Context(), tx, &instance)
require.NoError(t, err)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
orgIdentifierCondition domain.OrgIdentifierCondition
testFunc func(t *testing.T)
orgIdentifierCondition database.Condition
noOfDeletedRows int64
}
tests := []test{
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationId := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter id",
testFunc: func(ctx context.Context, t *testing.T) {
testFunc: func(t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -891,7 +942,7 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
@@ -902,12 +953,11 @@ func TestDeleteOrganization(t *testing.T) {
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
var noOfOrganizations int64 = 1
return test{
name: "happy path delete organization filter name",
testFunc: func(ctx context.Context, t *testing.T) {
testFunc: func(t *testing.T) {
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -919,30 +969,28 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
}
},
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
noOfDeletedRows: noOfOrganizations,
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
non_existent_organization_name := gofakeit.Name()
return test{
name: "delete non existent organization",
orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name),
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, non_existent_organization_name),
}
}(),
func() test {
organizationRepo := repository.OrganizationRepository(pool)
organizationName := gofakeit.Name()
return test{
name: "deleted already deleted organization",
testFunc: func(ctx context.Context, t *testing.T) {
testFunc: func(t *testing.T) {
noOfOrganizations := 1
organizations := make([]*domain.Organization, noOfOrganizations)
for i := range noOfOrganizations {
@@ -955,21 +1003,23 @@ func TestDeleteOrganization(t *testing.T) {
}
// create organization
err := organizationRepo.Create(ctx, &org)
err := organizationRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
organizations[i] = &org
}
// delete organization
affectedRows, err := organizationRepo.Delete(ctx,
organizationRepo.NameCondition(organizationName),
organizations[0].InstanceID,
affectedRows, err := organizationRepo.Delete(t.Context(), tx,
database.And(
organizationRepo.InstanceIDCondition(organizations[0].InstanceID),
organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
),
)
assert.Equal(t, int64(1), affectedRows)
require.NoError(t, err)
},
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
// this test should return 0 affected rows as the org was already deleted
noOfDeletedRows: 0,
}
@@ -977,23 +1027,22 @@ func TestDeleteOrganization(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
organizationRepo := repository.OrganizationRepository(pool)
if tt.testFunc != nil {
tt.testFunc(ctx, t)
tt.testFunc(t)
}
// delete organization
noOfDeletedRows, err := organizationRepo.Delete(ctx,
noOfDeletedRows, err := organizationRepo.Delete(t.Context(), tx,
database.And(
organizationRepo.InstanceIDCondition(instanceId),
tt.orgIdentifierCondition,
instanceId,
),
)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check organization was deleted
organization, err := organizationRepo.Get(ctx,
organization, err := organizationRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
tt.orgIdentifierCondition,
@@ -1006,3 +1055,99 @@ func TestDeleteOrganization(t *testing.T) {
})
}
}
func TestGetOrganizationWithSubResources(t *testing.T) {
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
t.Cleanup(func() {
_ = tx.Rollback(t.Context())
})
instanceRepo := repository.InstanceRepository()
orgRepo := repository.OrganizationRepository()
// create instance
instanceId := gofakeit.Name()
err = instanceRepo.Create(t.Context(), tx, &domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
})
require.NoError(t, err)
// create organization
org := domain.Organization{
ID: "1",
Name: "org-name",
InstanceID: instanceId,
State: domain.OrgStateActive,
}
err = orgRepo.Create(t.Context(), tx, &org)
require.NoError(t, err)
err = orgRepo.Create(t.Context(), tx, &domain.Organization{
ID: "without-sub-resources",
Name: "org-name-2",
InstanceID: instanceId,
State: domain.OrgStateActive,
})
require.NoError(t, err)
t.Run("domains", func(t *testing.T) {
domainRepo := repository.OrganizationDomainRepository()
domain1 := &domain.AddOrganizationDomain{
InstanceID: org.InstanceID,
OrgID: org.ID,
Domain: "domain1.com",
IsVerified: true,
IsPrimary: true,
}
err = domainRepo.Add(t.Context(), tx, domain1)
require.NoError(t, err)
domain2 := &domain.AddOrganizationDomain{
InstanceID: org.InstanceID,
OrgID: org.ID,
Domain: "domain2.com",
IsVerified: false,
IsPrimary: false,
}
err = domainRepo.Add(t.Context(), tx, domain2)
require.NoError(t, err)
t.Run("org by domain", func(t *testing.T) {
orgRepo := orgRepo.LoadDomains()
returnedOrg, err := orgRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
orgRepo.InstanceIDCondition(instanceId),
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
),
),
)
require.NoError(t, err)
assert.Equal(t, org.ID, returnedOrg.ID)
assert.Len(t, returnedOrg.Domains, 2)
})
t.Run("ensure org by domain works without LoadDomains", func(t *testing.T) {
returnedOrg, err := orgRepo.Get(t.Context(), tx,
database.WithCondition(
database.And(
orgRepo.InstanceIDCondition(instanceId),
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
),
),
)
require.NoError(t, err)
assert.Equal(t, org.ID, returnedOrg.ID)
assert.Len(t, returnedOrg.Domains, 0)
})
})
}

View File

@@ -4,10 +4,6 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type repository struct {
client database.QueryExecutor
}
func writeCondition(
builder *database.StatementBuilder,
condition database.Condition,

View File

@@ -12,16 +12,10 @@ const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_a
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
` FROM users_view users`
type user struct {
repository
}
type user struct{}
func UserRepository(client database.QueryExecutor) domain.UserRepository {
return &user{
repository: repository{
client: client,
},
}
func UserRepository() domain.UserRepository {
return new(user)
}
var _ domain.UserRepository = (*user)(nil)
@@ -31,17 +25,17 @@ var _ domain.UserRepository = (*user)(nil)
// -------------------------------------------------------------
// Human implements [domain.UserRepository].
func (u *user) Human() domain.HumanRepository {
func (u user) Human() domain.HumanRepository {
return &userHuman{user: u}
}
// Machine implements [domain.UserRepository].
func (u *user) Machine() domain.MachineRepository {
func (u user) Machine() domain.MachineRepository {
return &userMachine{user: u}
}
// List implements [domain.UserRepository].
func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
func (u user) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (users []*domain.User, err error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -54,7 +48,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
options.WriteLimit(&builder)
options.WriteOffset(&builder)
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
@@ -79,7 +73,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
}
// Get implements [domain.UserRepository].
func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
func (u user) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.User, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
@@ -92,7 +86,7 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
options.WriteLimit(&builder)
options.WriteOffset(&builder)
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
return scanUser(client.QueryRow(ctx, builder.String(), builder.Args()...))
}
const (
@@ -105,7 +99,7 @@ const (
)
// Create implements [domain.UserRepository].
func (u *user) Create(ctx context.Context, user *domain.User) error {
func (u user) Create(ctx context.Context, client database.QueryExecutor, user *domain.User) error {
builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
@@ -116,15 +110,15 @@ func (u *user) Create(ctx context.Context, user *domain.User) error {
builder.WriteString(createMachineStmt)
builder.AppendArgs(trait.Description)
}
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
}
// Delete implements [domain.UserRepository].
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
func (u user) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error {
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
writeCondition(&builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
return err
}

View File

@@ -13,7 +13,7 @@ import (
// -------------------------------------------------------------
type userHuman struct {
*user
user
}
var _ domain.HumanRepository = (*userHuman)(nil)
@@ -21,14 +21,14 @@ var _ domain.HumanRepository = (*userHuman)(nil)
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
// GetEmail implements [domain.HumanRepository].
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
func (u *userHuman) GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*domain.Email, error) {
var email domain.Email
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
writeCondition(&builder, condition)
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
err := client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
&email.VerifiedAt,
)
@@ -39,7 +39,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
}
// Update implements [domain.HumanRepository].
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
func (h userHuman) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
@@ -47,7 +47,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha
stmt := builder.String()
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
_, err := client.Exec(ctx, stmt, builder.Args()...)
return err
}

View File

@@ -8,7 +8,7 @@ import (
)
type userMachine struct {
*user
user
}
var _ domain.MachineRepository = (*userMachine)(nil)
@@ -18,14 +18,14 @@ var _ domain.MachineRepository = (*userMachine)(nil)
// -------------------------------------------------------------
// Update implements [domain.MachineRepository].
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
func (m userMachine) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
m.writeReturning()
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
return err
}

View File

@@ -1,6 +1,7 @@
package database
import (
"encoding/hex"
"strconv"
"strings"
)
@@ -20,8 +21,16 @@ type StatementBuilder struct {
existingArgs map[any]string
}
type argWriter interface {
WriteArg(builder *StatementBuilder)
}
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
func (b *StatementBuilder) WriteArg(arg any) {
if writer, ok := arg.(argWriter); ok {
writer.WriteArg(b)
return
}
b.WriteString(b.AppendArg(arg))
}
@@ -41,7 +50,13 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
if placeholder, ok := b.existingArgs[arg]; ok {
// the key is used to work around the following panic:
// runtime error: hash of unhashable type []uint8
key := arg
if argBytes, ok := arg.([]uint8); ok {
key = `\\bytes-` + hex.EncodeToString(argBytes)
}
if placeholder, ok := b.existingArgs[key]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok {
@@ -50,7 +65,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
b.args = append(b.args, arg)
placeholder = "$" + strconv.Itoa(len(b.args))
b.existingArgs[arg] = placeholder
b.existingArgs[key] = placeholder
return placeholder
}

View File

@@ -0,0 +1,144 @@
package database
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStatementBuilder_AppendArg(t *testing.T) {
t.Run("same arg returns same placeholder", func(t *testing.T) {
var b StatementBuilder
placeholder1 := b.AppendArg("same")
placeholder2 := b.AppendArg("same")
assert.Equal(t, placeholder1, placeholder2)
assert.Len(t, b.Args(), 1)
assert.Len(t, b.existingArgs, 1)
})
t.Run("same arg different types", func(t *testing.T) {
var b StatementBuilder
placeholder1 := b.AppendArg("same")
placeholder2 := b.AppendArg([]byte("same"))
placeholder3 := b.AppendArg("same")
assert.NotEqual(t, placeholder1, placeholder2)
assert.Equal(t, placeholder1, placeholder3)
assert.Len(t, b.Args(), 2)
assert.Len(t, b.existingArgs, 2)
})
t.Run("Instruction args are always different", func(t *testing.T) {
var b StatementBuilder
placeholder1 := b.AppendArg(DefaultInstruction)
placeholder2 := b.AppendArg(DefaultInstruction)
assert.Equal(t, placeholder1, placeholder2)
assert.Len(t, b.Args(), 0)
assert.Len(t, b.existingArgs, 0)
})
}
func TestStatementBuilder_AppendArgs(t *testing.T) {
t.Run("same arg returns same placeholder", func(t *testing.T) {
var b StatementBuilder
b.AppendArgs("same", "same")
assert.Len(t, b.Args(), 1)
assert.Len(t, b.existingArgs, 1)
})
t.Run("same arg different types", func(t *testing.T) {
var b StatementBuilder
b.AppendArgs("same", []byte("same"), "same")
assert.Len(t, b.Args(), 2)
assert.Len(t, b.existingArgs, 2)
})
t.Run("Instruction args are always different", func(t *testing.T) {
var b StatementBuilder
b.AppendArgs(DefaultInstruction, DefaultInstruction)
assert.Len(t, b.Args(), 0)
assert.Len(t, b.existingArgs, 0)
})
}
func TestStatementBuilder_WriteArg(t *testing.T) {
for _, test := range []struct {
name string
arg any
wantSQL string
wantArg []any
}{
{
name: "primitive arg",
arg: "test",
wantSQL: "$1",
wantArg: []any{"test"},
},
{
name: "argWriter arg",
arg: SHA256Value("wrapped"),
wantSQL: "SHA256($1)",
wantArg: []any{"wrapped"},
},
{
name: "Instruction arg",
arg: DefaultInstruction,
wantSQL: "DEFAULT",
wantArg: []any{},
},
} {
t.Run(test.name, func(t *testing.T) {
var b StatementBuilder
b.WriteArg(test.arg)
assert.Equal(t, test.wantSQL, b.String())
require.Len(t, b.Args(), len(test.wantArg))
for i := range test.wantArg {
assert.Equal(t, test.wantArg[i], b.Args()[i])
}
})
}
}
func TestStatementBuilder_WriteArgs(t *testing.T) {
for _, test := range []struct {
name string
args []any
wantSQL string
wantArg []any
}{
{
name: "primitive args",
args: []any{"test", 123, true, uint32(123)},
wantSQL: "$1, $2, $3, $4",
wantArg: []any{"test", 123, true, uint32(123)},
},
{
name: "argWriter args",
args: []any{SHA256Value("wrapped"), LowerValue("ASDF")},
wantSQL: "SHA256($1), LOWER($2)",
wantArg: []any{"wrapped", "ASDF"},
},
{
name: "Instruction args",
args: []any{DefaultInstruction, NowInstruction, NullInstruction},
wantSQL: "DEFAULT, NOW(), NULL",
wantArg: []any{},
},
{
name: "mixed args",
args: []any{123, uint32(123), NowInstruction, NullInstruction, SHA256Value("wrapped"), LowerValue("ASDF")},
wantSQL: "$1, $2, NOW(), NULL, SHA256($3), LOWER($4)",
wantArg: []any{123, uint32(123), "wrapped", "ASDF"},
},
} {
t.Run(test.name, func(t *testing.T) {
var b StatementBuilder
b.WriteArgs(test.args...)
assert.Equal(t, test.wantSQL, b.String())
require.Len(t, b.Args(), len(test.wantArg))
for i := range test.wantArg {
assert.Equal(t, test.wantArg[i], b.Args()[i])
}
})
}
}

View File

@@ -66,7 +66,7 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event event
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex)
}
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
InstanceID: e.Aggregate().InstanceID,
Domain: e.Domain,
IsPrimary: gu.Ptr(false),
@@ -88,25 +88,15 @@ func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event events
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
domainRepo := repository.InstanceDomainRepository()
condition := database.And(
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
domainRepo.TypeCondition(domain.DomainTypeCustom),
)
_, err := domainRepo.Update(ctx,
condition,
),
domainRepo.SetPrimary(),
)
if err != nil {
return err
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
_, err = domainRepo.Update(ctx,
condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -123,8 +113,8 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eve
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
_, err := domainRepo.Remove(ctx,
domainRepo := repository.InstanceDomainRepository()
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
@@ -145,7 +135,7 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex)
}
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
InstanceID: e.Aggregate().InstanceID,
Domain: e.Domain,
Type: domain.DomainTypeTrusted,
@@ -165,8 +155,8 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event ev
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
_, err := domainRepo.Remove(ctx,
domainRepo := repository.InstanceDomainRepository()
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),

View File

@@ -74,7 +74,7 @@ func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{
return repository.InstanceRepository().Create(ctx, v3_sql.SQLTx(tx), &domain.Instance{
ID: e.Aggregate().ID,
Name: e.Name,
CreatedAt: e.CreationDate(),
@@ -93,8 +93,8 @@ func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Ev
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
return p.updateInstance(ctx, event, repo, repo.SetName(e.Name))
repo := repository.InstanceRepository()
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetName(e.Name))
}), nil
}
@@ -108,7 +108,7 @@ func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Eve
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
_, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID)
_, err := repository.InstanceRepository().Delete(ctx, v3_sql.SQLTx(tx), e.Aggregate().ID)
return err
}), nil
}
@@ -124,8 +124,8 @@ func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID))
repo := repository.InstanceRepository()
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultOrg(e.OrgID))
}), nil
}
@@ -140,8 +140,8 @@ func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Even
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID))
repo := repository.InstanceRepository()
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetIAMProject(e.ProjectID))
}), nil
}
@@ -156,8 +156,8 @@ func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event)
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
repo := repository.InstanceRepository()
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
}), nil
}
@@ -172,18 +172,18 @@ func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
}
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language))
repo := repository.InstanceRepository()
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultLanguage(e.Language))
}), nil
}
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
_, err := repo.Update(ctx, event.Aggregate().ID, changes...)
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, tx database.Transaction, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
_, err := repo.Update(ctx, tx, event.Aggregate().ID, changes...)
if err != nil {
return err
}
instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
instance, err := repo.Get(ctx, tx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
if err != nil {
return err
}
@@ -192,7 +192,7 @@ func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
_, err = repo.Update(ctx,
_, err = repo.Update(ctx, tx,
event.Aggregate().ID,
repo.SetUpdatedAt(event.CreatedAt()),
)

View File

@@ -65,7 +65,7 @@ func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*ha
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex)
}
return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{
return repository.OrganizationDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddOrganizationDomain{
InstanceID: e.Aggregate().InstanceID,
OrgID: e.Aggregate().ResourceOwner,
Domain: e.Domain,
@@ -85,23 +85,14 @@ func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event)
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
condition := database.And(
domainRepo := repository.OrganizationDomainRepository()
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
)
_, err := domainRepo.Update(ctx,
condition,
),
domainRepo.SetPrimary(),
)
if err != nil {
return err
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
_, err = domainRepo.Update(ctx,
condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -118,8 +109,8 @@ func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (*
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
_, err := domainRepo.Remove(ctx,
domainRepo := repository.OrganizationDomainRepository()
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
@@ -149,24 +140,15 @@ func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
condition := database.And(
domainRepo := repository.OrganizationDomainRepository()
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
)
_, err := domainRepo.Update(ctx,
condition,
),
domainRepo.SetValidationType(validationType),
)
if err != nil {
return err
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
_, err = domainRepo.Update(ctx,
condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
@@ -183,28 +165,17 @@ func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) (
if !ok {
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex)
}
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
domainRepo := repository.OrganizationDomainRepository()
condition := database.And(
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
database.And(
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
)
_, err := domainRepo.Update(ctx,
condition,
),
domainRepo.SetVerified(),
domainRepo.SetUpdatedAt(e.CreationDate()),
)
if err != nil {
return err
}
// we need to split the update into two statements because multiple events can have the same creation date
// therefore we first do not set the updated_at timestamp
_, err = domainRepo.Update(ctx,
condition,
domainRepo.SetUpdatedAt(e.CreationDate()),
)
return err
}), nil
}