mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 10:27:21 +00:00
refactor: database interaction and error handling (#10762)
This pull request introduces a significant refactoring of the database interaction layer, focusing on improving explicitness, transactional control, and error handling. The core change is the removal of the stateful `QueryExecutor` from repository instances. Instead, it is now passed as an argument to each method that interacts with the database. This change makes transaction management more explicit and flexible, as the same repository instance can be used with a database pool or a specific transaction without needing to be re-instantiated. ### Key Changes - **Explicit `QueryExecutor` Passing:** - All repository methods (`Get`, `List`, `Create`, `Update`, `Delete`, etc.) in `InstanceRepository`, `OrganizationRepository`, `UserRepository`, and their sub-repositories now require a `database.QueryExecutor` (e.g., a `*pgxpool.Pool` or `pgx.Tx`) as the first argument. - Repository constructors no longer accept a `QueryExecutor`. For example, `repository.InstanceRepository(pool)` is now `repository.InstanceRepository()`. - **Enhanced Error Handling:** - A new `database.MissingConditionError` is introduced to enforce required query conditions, such as ensuring an `instance_id` is always present in `UPDATE` and `DELETE` operations. - The database error wrapper in the `postgres` package now correctly identifies and wraps `pgx.ErrTooManyRows` and similar errors from the `scany` library into a `database.MultipleRowsFoundError`. - **Improved Database Conditions:** - The `database.Condition` interface now includes a `ContainsColumn(Column) bool` method. This allows for runtime checks to ensure that critical filters (like `instance_id`) are included in a query, preventing accidental cross-tenant data modification. - A new `database.Exists()` condition has been added to support `EXISTS` subqueries, enabling more complex filtering logic, such as finding an organization that has a specific domain. - **Repository and Interface Refactoring:** - The method for loading related entities (e.g., domains for an organization) has been changed from a boolean flag (`Domains(true)`) to a more explicit, chainable method (`LoadDomains()`). This returns a new repository instance configured to load the sub-resource, promoting immutability. - The custom `OrgIdentifierCondition` has been removed in favor of using the standard `database.Condition` interface, simplifying the API. - **Code Cleanup and Test Updates:** - Unnecessary struct embeddings and metadata have been removed. - All integration and repository tests have been updated to reflect the new method signatures, passing the database pool or transaction object explicitly. - New tests have been added to cover the new `ExistsDomain` functionality and other enhancements. These changes make the data access layer more robust, predictable, and easier to work with, especially in the context of database transactions.
This commit is contained in:
@@ -69,6 +69,8 @@ type instanceConditions interface {
|
||||
IDCondition(instanceID string) database.Condition
|
||||
// NameCondition returns a filter on the name field.
|
||||
NameCondition(op database.TextOperation, name string) database.Condition
|
||||
// ExistsDomain returns a filter on the instance domains.
|
||||
ExistsDomain(cond database.Condition) database.Condition
|
||||
}
|
||||
|
||||
// instanceChanges define all the changes for the instance table.
|
||||
@@ -99,17 +101,16 @@ type InstanceRepository interface {
|
||||
// Member returns the member repository which is a sub repository of the instance repository.
|
||||
// Member() MemberRepository
|
||||
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
|
||||
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Instance, error)
|
||||
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Instance, error)
|
||||
|
||||
Create(ctx context.Context, instance *Instance) error
|
||||
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, id string) (int64, error)
|
||||
Create(ctx context.Context, client database.QueryExecutor, instance *Instance) error
|
||||
Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error)
|
||||
|
||||
// Domains returns the domain sub repository for the instance.
|
||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
||||
// If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
|
||||
Domains(shouldLoad bool) InstanceDomainRepository
|
||||
// LoadDomains loads the domains of the given instance.
|
||||
// If it is called the [Instance].Domains field will be set on future calls to Get or List.
|
||||
LoadDomains() InstanceRepository
|
||||
}
|
||||
|
||||
type CreateInstance struct {
|
||||
|
||||
@@ -65,15 +65,15 @@ type InstanceDomainRepository interface {
|
||||
// Get returns a single domain based on the criteria.
|
||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
|
||||
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*InstanceDomain, error)
|
||||
// List returns a list of domains based on the criteria.
|
||||
// If no domains are found, it returns an empty slice.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
|
||||
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*InstanceDomain, error)
|
||||
|
||||
// Add adds a new domain to the instance.
|
||||
Add(ctx context.Context, domain *AddInstanceDomain) error
|
||||
Add(ctx context.Context, client database.QueryExecutor, domain *AddInstanceDomain) error
|
||||
// Update updates an existing domain in the instance.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
// Remove removes a domain from the instance.
|
||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
||||
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||
}
|
||||
|
||||
@@ -26,13 +26,6 @@ type Organization struct {
|
||||
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
|
||||
}
|
||||
|
||||
// OrgIdentifierCondition is used to help specify a single Organization,
|
||||
// it will either be used as the organization ID or organization name,
|
||||
// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
|
||||
type OrgIdentifierCondition interface {
|
||||
database.Condition
|
||||
}
|
||||
|
||||
// organizationColumns define all the columns of the instance table.
|
||||
type organizationColumns interface {
|
||||
// IDColumn returns the column for the id field.
|
||||
@@ -52,13 +45,15 @@ type organizationColumns interface {
|
||||
// organizationConditions define all the conditions for the instance table.
|
||||
type organizationConditions interface {
|
||||
// IDCondition returns an equal filter on the id field.
|
||||
IDCondition(instanceID string) OrgIdentifierCondition
|
||||
IDCondition(instanceID string) database.Condition
|
||||
// NameCondition returns a filter on the name field.
|
||||
NameCondition(name string) OrgIdentifierCondition
|
||||
NameCondition(op database.TextOperation, name string) database.Condition
|
||||
// InstanceIDCondition returns a filter on the instance id field.
|
||||
InstanceIDCondition(instanceID string) database.Condition
|
||||
// StateCondition returns a filter on the name field.
|
||||
StateCondition(state OrgState) database.Condition
|
||||
// ExistsDomain returns a filter on the organizations domains.
|
||||
ExistsDomain(cond database.Condition) database.Condition
|
||||
}
|
||||
|
||||
// organizationChanges define all the changes for the instance table.
|
||||
@@ -75,17 +70,16 @@ type OrganizationRepository interface {
|
||||
organizationConditions
|
||||
organizationChanges
|
||||
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
|
||||
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Organization, error)
|
||||
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Organization, error)
|
||||
|
||||
Create(ctx context.Context, instance *Organization) error
|
||||
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
|
||||
Create(ctx context.Context, client database.QueryExecutor, org *Organization) error
|
||||
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||
|
||||
// Domains returns the domain sub repository for the organization.
|
||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
||||
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
|
||||
Domains(shouldLoad bool) OrganizationDomainRepository
|
||||
// LoadDomains loads the domains of the given organizations.
|
||||
// If it is called the [Organization].Domains field will be set on future calls to Get or List.
|
||||
LoadDomains() OrganizationRepository
|
||||
}
|
||||
|
||||
type CreateOrganization struct {
|
||||
|
||||
@@ -70,15 +70,15 @@ type OrganizationDomainRepository interface {
|
||||
// Get returns a single domain based on the criteria.
|
||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
|
||||
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*OrganizationDomain, error)
|
||||
// List returns a list of domains based on the criteria.
|
||||
// If no domains are found, it returns an empty slice.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
|
||||
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*OrganizationDomain, error)
|
||||
|
||||
// Add adds a new domain to the organization.
|
||||
Add(ctx context.Context, domain *AddOrganizationDomain) error
|
||||
Add(ctx context.Context, client database.QueryExecutor, domain *AddOrganizationDomain) error
|
||||
// Update updates an existing domain in the organization.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
// Remove removes a domain from the organization.
|
||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
||||
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||
}
|
||||
|
||||
@@ -57,13 +57,13 @@ type UserRepository interface {
|
||||
userConditions
|
||||
userChanges
|
||||
// Get returns a user based on the given condition.
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
|
||||
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*User, error)
|
||||
// List returns a list of users based on the given condition.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
|
||||
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*User, error)
|
||||
// Create creates a new user.
|
||||
Create(ctx context.Context, user *User) error
|
||||
Create(ctx context.Context, client database.QueryExecutor, user *User) error
|
||||
// Delete removes users based on the given condition.
|
||||
Delete(ctx context.Context, condition database.Condition) error
|
||||
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error
|
||||
// Human returns the [HumanRepository].
|
||||
Human() HumanRepository
|
||||
// Machine returns the [MachineRepository].
|
||||
@@ -143,9 +143,9 @@ type HumanRepository interface {
|
||||
humanChanges
|
||||
|
||||
// Get returns an email based on the given condition.
|
||||
GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
|
||||
GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*Email, error)
|
||||
// Update updates human users based on the given condition and changes.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
||||
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
|
||||
}
|
||||
|
||||
// machineColumns define all the columns of the machine table which inherits the user table.
|
||||
@@ -172,7 +172,7 @@ type machineChanges interface {
|
||||
// MachineRepository is the interface for the machine repository it inherits the user repository.
|
||||
type MachineRepository interface {
|
||||
// Update updates machine users based on the given condition and changes.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
||||
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
|
||||
|
||||
machineColumns
|
||||
machineConditions
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package database
|
||||
|
||||
import "slices"
|
||||
|
||||
// Change represents a change to a column in a database table.
|
||||
// Its written in the SET clause of an UPDATE statement.
|
||||
type Change interface {
|
||||
// Write writes the change to the given statement builder.
|
||||
Write(builder *StatementBuilder)
|
||||
// IsOnColumn checks if the change is on the given column.
|
||||
IsOnColumn(col Column) bool
|
||||
}
|
||||
|
||||
type change[V Value] struct {
|
||||
@@ -13,6 +18,8 @@ type change[V Value] struct {
|
||||
|
||||
var _ Change = (*change[string])(nil)
|
||||
|
||||
// NewChange creates a new Change for the given column and value.
|
||||
// If you want to set a column to NULL, use [NewChangePtr].
|
||||
func NewChange[V Value](col Column, value V) Change {
|
||||
return &change[V]{
|
||||
column: col,
|
||||
@@ -20,6 +27,8 @@ func NewChange[V Value](col Column, value V) Change {
|
||||
}
|
||||
}
|
||||
|
||||
// NewChangePtr creates a new Change for the given column and value pointer.
|
||||
// If the value pointer is nil, the column will be set to NULL.
|
||||
func NewChangePtr[V Value](col Column, value *V) Change {
|
||||
if value == nil {
|
||||
return NewChange(col, NullInstruction)
|
||||
@@ -34,19 +43,31 @@ func (c change[V]) Write(builder *StatementBuilder) {
|
||||
builder.WriteArg(c.value)
|
||||
}
|
||||
|
||||
// IsOnColumn implements [Change].
|
||||
func (c change[V]) IsOnColumn(col Column) bool {
|
||||
return c.column.Equals(col)
|
||||
}
|
||||
|
||||
type Changes []Change
|
||||
|
||||
func NewChanges(cols ...Change) Change {
|
||||
return Changes(cols)
|
||||
}
|
||||
|
||||
// IsOnColumn implements [Change].
|
||||
func (c Changes) IsOnColumn(col Column) bool {
|
||||
return slices.ContainsFunc(c, func(change Change) bool {
|
||||
return change.IsOnColumn(col)
|
||||
})
|
||||
}
|
||||
|
||||
// Write implements [Change].
|
||||
func (m Changes) Write(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
for i, change := range m {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.Write(builder)
|
||||
change.Write(builder)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
68
backend/v3/storage/database/change_test.go
Normal file
68
backend/v3/storage/database/change_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestChangeWrite(t *testing.T) {
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
change Change
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "change",
|
||||
change: NewChange(NewColumn("table", "column"), "value"),
|
||||
want: want{
|
||||
stmt: "column = $1",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "change ptr to null",
|
||||
change: NewChangePtr[int](NewColumn("table", "column"), nil),
|
||||
want: want{
|
||||
stmt: "column = NULL",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "change ptr to value",
|
||||
change: NewChangePtr(NewColumn("table", "column"), gu.Ptr(42)),
|
||||
want: want{
|
||||
stmt: "column = $1",
|
||||
args: []any{42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
change: NewChanges(
|
||||
NewChange(NewColumn("table", "column1"), "value1"),
|
||||
NewChangePtr[int](NewColumn("table", "column2"), nil),
|
||||
NewChange(NewColumn("table", "column3"), 123),
|
||||
),
|
||||
want: want{
|
||||
stmt: "column1 = $1, column2 = NULL, column3 = $2",
|
||||
args: []any{"value1", 123},
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder StatementBuilder
|
||||
test.change.Write(&builder)
|
||||
assert.Equal(t, test.want.stmt, builder.String())
|
||||
require.Len(t, builder.Args(), len(test.want.args))
|
||||
for i, arg := range test.want.args {
|
||||
assert.Equal(t, arg, builder.Args()[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,8 +3,9 @@ package database
|
||||
type Columns []Column
|
||||
|
||||
// WriteQualified implements [Column].
|
||||
func (m Columns) WriteQualified(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
// Columns are separated by ", ".
|
||||
func (c Columns) WriteQualified(builder *StatementBuilder) {
|
||||
for i, col := range c {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
@@ -13,8 +14,9 @@ func (m Columns) WriteQualified(builder *StatementBuilder) {
|
||||
}
|
||||
|
||||
// WriteUnqualified implements [Column].
|
||||
func (m Columns) WriteUnqualified(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
// Columns are separated by ", ".
|
||||
func (c Columns) WriteUnqualified(builder *StatementBuilder) {
|
||||
for i, col := range c {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
@@ -22,11 +24,33 @@ func (m Columns) WriteUnqualified(builder *StatementBuilder) {
|
||||
}
|
||||
}
|
||||
|
||||
// Equals implements [Column].
|
||||
func (c Columns) Equals(col Column) bool {
|
||||
if col == nil {
|
||||
return c == nil
|
||||
}
|
||||
other, ok := col.(Columns)
|
||||
if !ok || len(other) != len(c) {
|
||||
return false
|
||||
}
|
||||
for i, col := range c {
|
||||
if !col.Equals(other[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ Column = (Columns)(nil)
|
||||
|
||||
// Column represents a column in a database table.
|
||||
type Column interface {
|
||||
// Write(builder *StatementBuilder)
|
||||
// WriteQualified writes the column with the table name as prefix.
|
||||
WriteQualified(builder *StatementBuilder)
|
||||
// WriteUnqualified writes the column without the table name as prefix.
|
||||
WriteUnqualified(builder *StatementBuilder)
|
||||
// Equals checks if two columns are equal.
|
||||
Equals(col Column) bool
|
||||
}
|
||||
|
||||
type column struct {
|
||||
@@ -35,7 +59,7 @@ type column struct {
|
||||
}
|
||||
|
||||
func NewColumn(table, name string) Column {
|
||||
return column{table: table, name: name}
|
||||
return &column{table: table, name: name}
|
||||
}
|
||||
|
||||
// WriteQualified implements [Column].
|
||||
@@ -51,35 +75,69 @@ func (c column) WriteUnqualified(builder *StatementBuilder) {
|
||||
builder.WriteString(c.name)
|
||||
}
|
||||
|
||||
var _ Column = (*column)(nil)
|
||||
// Equals implements [Column].
|
||||
func (c *column) Equals(col Column) bool {
|
||||
if col == nil {
|
||||
return c == nil
|
||||
}
|
||||
toMatch, ok := col.(*column)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return c.table == toMatch.table && c.name == toMatch.name
|
||||
}
|
||||
|
||||
// // ignoreCaseColumn represents two database columns, one for the
|
||||
// // original value and one for the lower case value.
|
||||
// type ignoreCaseColumn interface {
|
||||
// Column
|
||||
// WriteIgnoreCase(builder *StatementBuilder)
|
||||
// }
|
||||
// LowerColumn returns a column that represents LOWER(col).
|
||||
func LowerColumn(col Column) Column {
|
||||
return &functionColumn{fn: functionLower, col: col}
|
||||
}
|
||||
|
||||
// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
|
||||
// return ignoreCaseCol{
|
||||
// column: col,
|
||||
// suffix: suffix,
|
||||
// }
|
||||
// }
|
||||
// SHA256Column returns a column that represents SHA256(col).
|
||||
func SHA256Column(col Column) Column {
|
||||
return &functionColumn{fn: functionSHA256, col: col}
|
||||
}
|
||||
|
||||
// type ignoreCaseCol struct {
|
||||
// column Column
|
||||
// suffix string
|
||||
// }
|
||||
type functionColumn struct {
|
||||
fn function
|
||||
col Column
|
||||
}
|
||||
|
||||
// // WriteIgnoreCase implements [ignoreCaseColumn].
|
||||
// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
|
||||
// c.column.WriteQualified(builder)
|
||||
// builder.WriteString(c.suffix)
|
||||
// }
|
||||
type function string
|
||||
|
||||
// // WriteQualified implements [ignoreCaseColumn].
|
||||
// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
|
||||
// c.column.WriteQualified(builder)
|
||||
// builder.WriteString(c.suffix)
|
||||
// }
|
||||
const (
|
||||
_ function = ""
|
||||
functionLower function = "LOWER"
|
||||
functionSHA256 function = "SHA256"
|
||||
)
|
||||
|
||||
// WriteQualified implements [Column].
|
||||
func (c functionColumn) WriteQualified(builder *StatementBuilder) {
|
||||
builder.Grow(len(c.fn) + 2)
|
||||
builder.WriteString(string(c.fn))
|
||||
builder.WriteRune('(')
|
||||
c.col.WriteQualified(builder)
|
||||
builder.WriteRune(')')
|
||||
}
|
||||
|
||||
// WriteUnqualified implements [Column].
|
||||
func (c functionColumn) WriteUnqualified(builder *StatementBuilder) {
|
||||
builder.Grow(len(c.fn) + 2)
|
||||
builder.WriteString(string(c.fn))
|
||||
builder.WriteRune('(')
|
||||
c.col.WriteUnqualified(builder)
|
||||
builder.WriteRune(')')
|
||||
}
|
||||
|
||||
// Equals implements [Column].
|
||||
func (c *functionColumn) Equals(col Column) bool {
|
||||
if col == nil {
|
||||
return c == nil
|
||||
}
|
||||
toMatch, ok := col.(*functionColumn)
|
||||
if !ok || toMatch.fn != c.fn {
|
||||
return false
|
||||
}
|
||||
return c.col.Equals(toMatch.col)
|
||||
}
|
||||
|
||||
var _ Column = (*functionColumn)(nil)
|
||||
|
||||
212
backend/v3/storage/database/column_test.go
Normal file
212
backend/v3/storage/database/column_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWriteUnqualified(t *testing.T) {
|
||||
for _, tests := range []struct {
|
||||
name string
|
||||
column Column
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column",
|
||||
column: NewColumn("table", "column"),
|
||||
expected: "column",
|
||||
},
|
||||
{
|
||||
name: "columns",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
expected: "column1, column2",
|
||||
},
|
||||
{
|
||||
name: "function column",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
expected: "SHA256(column)",
|
||||
},
|
||||
} {
|
||||
t.Run(tests.name, func(t *testing.T) {
|
||||
var builder StatementBuilder
|
||||
tests.column.WriteUnqualified(&builder)
|
||||
assert.Equal(t, tests.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteQualified(t *testing.T) {
|
||||
for _, tests := range []struct {
|
||||
name string
|
||||
column Column
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column",
|
||||
column: NewColumn("table", "column"),
|
||||
expected: "table.column",
|
||||
},
|
||||
{
|
||||
name: "columns",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
expected: "table.column1, table.column2",
|
||||
},
|
||||
{
|
||||
name: "function column",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
expected: "SHA256(table.column)",
|
||||
},
|
||||
} {
|
||||
t.Run(tests.name, func(t *testing.T) {
|
||||
var builder StatementBuilder
|
||||
tests.column.WriteQualified(&builder)
|
||||
assert.Equal(t, tests.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEquals(t *testing.T) {
|
||||
for _, tests := range []struct {
|
||||
name string
|
||||
column Column
|
||||
toCheck Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "column equal",
|
||||
column: NewColumn("table", "column"),
|
||||
toCheck: NewColumn("table", "column"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "column nil check",
|
||||
column: NewColumn("table", "column"),
|
||||
toCheck: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "column both nil",
|
||||
column: (*column)(nil),
|
||||
toCheck: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "column not equal (different name)",
|
||||
column: NewColumn("table", "column"),
|
||||
toCheck: NewColumn("table", "column2"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "column not equal (different type)",
|
||||
column: NewColumn("table", "column"),
|
||||
toCheck: SHA256Column(NewColumn("table", "column")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "columns equal",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
toCheck: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "columns nil check",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
toCheck: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "columns both nil",
|
||||
column: Columns(nil),
|
||||
toCheck: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "columns not equal (different type)",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
toCheck: NewColumn("table", "column1"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "columns not equal (different length)",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
toCheck: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "columns not equal (different order)",
|
||||
column: Columns{
|
||||
NewColumn("table", "column1"),
|
||||
NewColumn("table", "column2"),
|
||||
},
|
||||
toCheck: Columns{
|
||||
NewColumn("table", "column2"),
|
||||
NewColumn("table", "column1"),
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "function column equal",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
toCheck: SHA256Column(NewColumn("table", "column")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "function nil check",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
toCheck: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "function both nil",
|
||||
column: (*functionColumn)(nil),
|
||||
toCheck: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "function column not equal (different function)",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
toCheck: LowerColumn(NewColumn("table", "column")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "function column not equal (different inner column)",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
toCheck: SHA256Column(NewColumn("table", "column2")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "function column not equal (different type)",
|
||||
column: SHA256Column(NewColumn("table", "column")),
|
||||
toCheck: NewColumn("table", "column2"),
|
||||
expected: false,
|
||||
},
|
||||
} {
|
||||
t.Run(tests.name, func(t *testing.T) {
|
||||
assert.Equal(t, tests.expected, tests.column.Equals(tests.toCheck))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,9 @@ package database
|
||||
// Its written after the WHERE keyword in a SQL statement.
|
||||
type Condition interface {
|
||||
Write(builder *StatementBuilder)
|
||||
// IsRestrictingColumn is used to check if the condition filters for a specific column.
|
||||
// It acts as a save guard database operations that should be specific on the given column.
|
||||
IsRestrictingColumn(col Column) bool
|
||||
}
|
||||
|
||||
type and struct {
|
||||
@@ -11,7 +14,7 @@ type and struct {
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (a *and) Write(builder *StatementBuilder) {
|
||||
func (a and) Write(builder *StatementBuilder) {
|
||||
if len(a.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
@@ -29,6 +32,16 @@ func And(conditions ...Condition) *and {
|
||||
return &and{conditions: conditions}
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
func (a and) IsRestrictingColumn(col Column) bool {
|
||||
for _, condition := range a.conditions {
|
||||
if condition.IsRestrictingColumn(col) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var _ Condition = (*and)(nil)
|
||||
|
||||
type or struct {
|
||||
@@ -36,7 +49,7 @@ type or struct {
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (o *or) Write(builder *StatementBuilder) {
|
||||
func (o or) Write(builder *StatementBuilder) {
|
||||
if len(o.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
@@ -54,6 +67,17 @@ func Or(conditions ...Condition) *or {
|
||||
return &or{conditions: conditions}
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
// It returns true only if all conditions are restricting the given column.
|
||||
func (o or) IsRestrictingColumn(col Column) bool {
|
||||
for _, condition := range o.conditions {
|
||||
if !condition.IsRestrictingColumn(col) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ Condition = (*or)(nil)
|
||||
|
||||
type isNull struct {
|
||||
@@ -61,7 +85,7 @@ type isNull struct {
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (i *isNull) Write(builder *StatementBuilder) {
|
||||
func (i isNull) Write(builder *StatementBuilder) {
|
||||
i.column.WriteQualified(builder)
|
||||
builder.WriteString(" IS NULL")
|
||||
}
|
||||
@@ -71,6 +95,12 @@ func IsNull(column Column) *isNull {
|
||||
return &isNull{column: column}
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
// It returns false because it cannot be used for restricting a column.
|
||||
func (i isNull) IsRestrictingColumn(col Column) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var _ Condition = (*isNull)(nil)
|
||||
|
||||
type isNotNull struct {
|
||||
@@ -78,7 +108,7 @@ type isNotNull struct {
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (i *isNotNull) Write(builder *StatementBuilder) {
|
||||
func (i isNotNull) Write(builder *StatementBuilder) {
|
||||
i.column.WriteQualified(builder)
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
}
|
||||
@@ -88,43 +118,122 @@ func IsNotNull(column Column) *isNotNull {
|
||||
return &isNotNull{column: column}
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
// It returns false because it cannot be used for restricting a column.
|
||||
func (i isNotNull) IsRestrictingColumn(col Column) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
var _ Condition = (*isNotNull)(nil)
|
||||
|
||||
type valueCondition func(builder *StatementBuilder)
|
||||
type valueCondition struct {
|
||||
write func(builder *StatementBuilder)
|
||||
col Column
|
||||
}
|
||||
|
||||
// NewTextCondition creates a condition that compares a text column with a value.
|
||||
func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeTextOperation(builder, col, op, value)
|
||||
})
|
||||
// If you want to use ignore case operations, consider using [NewTextIgnoreCaseCondition].
|
||||
func NewTextCondition[T Text](col Column, op TextOperation, value T) Condition {
|
||||
return valueCondition{
|
||||
col: col,
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[T](builder, col, op, value)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewTextIgnoreCaseCondition creates a condition that compares a text column with a value, ignoring case by lowercasing both.
|
||||
func NewTextIgnoreCaseCondition[T Text](col Column, op TextOperation, value T) Condition {
|
||||
return valueCondition{
|
||||
col: col,
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[T](builder, LowerColumn(col), op, LowerValue(value))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewDateCondition creates a condition that compares a numeric column with a value.
|
||||
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeNumberOperation(builder, col, op, value)
|
||||
})
|
||||
return valueCondition{
|
||||
col: col,
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[V](builder, col, op, value)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewDateCondition creates a condition that compares a boolean column with a value.
|
||||
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeBooleanOperation(builder, col, value)
|
||||
})
|
||||
return valueCondition{
|
||||
col: col,
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBooleanOperation[V](builder, col, value)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewBytesCondition creates a condition that compares a BYTEA column with a value.
|
||||
func NewBytesCondition[V Bytes](col Column, op BytesOperation, value any) Condition {
|
||||
return valueCondition{
|
||||
col: col,
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBytesOperation[V](builder, col, op, value)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewColumnCondition creates a condition that compares two columns on equality.
|
||||
func NewColumnCondition(col1, col2 Column) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
col1.WriteQualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
col2.WriteQualified(builder)
|
||||
})
|
||||
return valueCondition{
|
||||
col: col1,
|
||||
write: func(builder *StatementBuilder) {
|
||||
col1.WriteQualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
col2.WriteQualified(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (c valueCondition) Write(builder *StatementBuilder) {
|
||||
c(builder)
|
||||
c.write(builder)
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
func (i valueCondition) IsRestrictingColumn(col Column) bool {
|
||||
return i.col.Equals(col)
|
||||
}
|
||||
|
||||
var _ Condition = (*valueCondition)(nil)
|
||||
|
||||
// existsCondition is a helper to write an EXISTS (SELECT 1 FROM <table> WHERE <condition>) clause.
|
||||
// It implements Condition so it can be composed with other conditions using And/Or.
|
||||
type existsCondition struct {
|
||||
table string
|
||||
condition Condition
|
||||
}
|
||||
|
||||
// Exists creates a condition that checks for the existence of rows in a subquery.
|
||||
func Exists(table string, condition Condition) Condition {
|
||||
return &existsCondition{
|
||||
table: table,
|
||||
condition: condition,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (e existsCondition) Write(builder *StatementBuilder) {
|
||||
builder.WriteString(" EXISTS (SELECT 1 FROM ")
|
||||
builder.WriteString(e.table)
|
||||
builder.WriteString(" WHERE ")
|
||||
e.condition.Write(builder)
|
||||
builder.WriteString(")")
|
||||
}
|
||||
|
||||
// IsRestrictingColumn implements [Condition].
|
||||
func (e existsCondition) IsRestrictingColumn(col Column) bool {
|
||||
// Forward to the inner condition so safety checks (like instance_id presence) can still work.
|
||||
return e.condition.IsRestrictingColumn(col)
|
||||
}
|
||||
|
||||
var _ Condition = (*existsCondition)(nil)
|
||||
|
||||
248
backend/v3/storage/database/condition_test.go
Normal file
248
backend/v3/storage/database/condition_test.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
cond Condition
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no and condition",
|
||||
cond: And(),
|
||||
want: want{
|
||||
stmt: "",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "one and condition",
|
||||
cond: And(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
),
|
||||
want: want{
|
||||
stmt: "table.column1 = other_table.column2",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple and condition",
|
||||
cond: And(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||
),
|
||||
want: want{
|
||||
stmt: "(table.column1 = other_table.column2 AND table.column3 = other_table.column4)",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no or condition",
|
||||
cond: Or(),
|
||||
want: want{
|
||||
stmt: "",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "one or condition",
|
||||
cond: Or(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
),
|
||||
want: want{
|
||||
stmt: "table.column1 = other_table.column2",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple or condition",
|
||||
cond: Or(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||
),
|
||||
want: want{
|
||||
stmt: "(table.column1 = other_table.column2 OR table.column3 = other_table.column4)",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is null condition",
|
||||
cond: IsNull(NewColumn("table", "column1")),
|
||||
want: want{
|
||||
stmt: "table.column1 IS NULL",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "is not null condition",
|
||||
cond: IsNotNull(NewColumn("table", "column1")),
|
||||
want: want{
|
||||
stmt: "table.column1 IS NOT NULL",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text condition",
|
||||
cond: NewTextCondition(NewColumn("table", "column1"), TextOperationEqual, "some text"),
|
||||
want: want{
|
||||
stmt: "table.column1 = $1",
|
||||
args: []any{"some text"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text ignore case condition",
|
||||
cond: NewTextIgnoreCaseCondition(NewColumn("table", "column1"), TextOperationNotEqual, "some TEXT"),
|
||||
want: want{
|
||||
stmt: "LOWER(table.column1) <> LOWER($1)",
|
||||
args: []any{"some TEXT"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number condition",
|
||||
cond: NewNumberCondition(NewColumn("table", "column1"), NumberOperationEqual, 42),
|
||||
want: want{
|
||||
stmt: "table.column1 = $1",
|
||||
args: []any{42},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean condition",
|
||||
cond: NewBooleanCondition(NewColumn("table", "column1"), true),
|
||||
want: want{
|
||||
stmt: "table.column1 = $1",
|
||||
args: []any{true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes condition",
|
||||
cond: NewBytesCondition[[]byte](NewColumn("table", "column1"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}),
|
||||
want: want{
|
||||
stmt: "table.column1 = $1",
|
||||
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "column condition",
|
||||
cond: NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
want: want{
|
||||
stmt: "table.column1 = other_table.column2",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "exists condition",
|
||||
cond: Exists("table", And(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||
)),
|
||||
want: want{
|
||||
stmt: " EXISTS (SELECT 1 FROM table WHERE (table.column1 = other_table.column2 AND table.column3 = other_table.column4))",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder StatementBuilder
|
||||
test.cond.Write(&builder)
|
||||
assert.Equal(t, test.want.stmt, builder.String())
|
||||
require.Len(t, builder.Args(), len(test.want.args))
|
||||
for i, arg := range test.want.args {
|
||||
assert.Equal(t, arg, builder.Args()[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRestrictingColumn(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
col Column
|
||||
cond Condition
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "and with restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: And(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||
),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "and without restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: And(
|
||||
NewColumnCondition(NewColumn("table", "column2"), NewColumn("other_table", "column3")),
|
||||
IsNull(NewColumn("table", "column4")),
|
||||
IsNotNull(NewColumn("table", "column5")),
|
||||
),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "or with restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: Or(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||
),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "or without restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: Or(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||
IsNotNull(NewColumn("table", "column4")),
|
||||
IsNull(NewColumn("table", "column5")),
|
||||
),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "is null never restricts",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: IsNull(NewColumn("table", "column1")),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "is not null never restricts",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: IsNotNull(NewColumn("table", "column1")),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exists with restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: Exists("table", And(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||
)),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exists without restricting column",
|
||||
col: NewColumn("table", "column1"),
|
||||
cond: Exists("table", Or(
|
||||
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||
IsNotNull(NewColumn("table", "column4")),
|
||||
IsNull(NewColumn("table", "column5")),
|
||||
)),
|
||||
want: false,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
isRestricting := test.cond.IsRestrictingColumn(test.col)
|
||||
assert.Equal(t, test.want, isRestricting)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ type Pool interface {
|
||||
QueryExecutor
|
||||
Migrator
|
||||
|
||||
Acquire(ctx context.Context) (Client, error)
|
||||
Acquire(ctx context.Context) (Connection, error)
|
||||
Close(ctx context.Context) error
|
||||
|
||||
Ping(ctx context.Context) error
|
||||
@@ -22,8 +22,8 @@ type PoolTest interface {
|
||||
MigrateTest(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Client is a single database connection which can be released back to the pool.
|
||||
type Client interface {
|
||||
// Connection is a single database connection which can be released back to the pool.
|
||||
type Connection interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
Migrator
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction)
|
||||
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Connection,Row,Rows,Transaction)
|
||||
//
|
||||
// Generated by this command:
|
||||
//
|
||||
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
||||
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
|
||||
//
|
||||
|
||||
// Package dbmock is a generated GoMock package.
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
type MockPool struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockPoolMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockPoolMockRecorder is the mock recorder for MockPool.
|
||||
@@ -41,18 +42,18 @@ func (m *MockPool) EXPECT() *MockPoolMockRecorder {
|
||||
}
|
||||
|
||||
// Acquire mocks base method.
|
||||
func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) {
|
||||
func (m *MockPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Acquire", arg0)
|
||||
ret0, _ := ret[0].(database.Client)
|
||||
ret := m.ctrl.Call(m, "Acquire", ctx)
|
||||
ret0, _ := ret[0].(database.Connection)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Acquire indicates an expected call of Acquire.
|
||||
func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall {
|
||||
func (mr *MockPoolMockRecorder) Acquire(ctx any) *MockPoolAcquireCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), ctx)
|
||||
return &MockPoolAcquireCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -62,36 +63,36 @@ type MockPoolAcquireCall struct {
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall {
|
||||
func (c *MockPoolAcquireCall) Return(arg0 database.Connection, arg1 error) *MockPoolAcquireCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
|
||||
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
|
||||
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Begin mocks base method.
|
||||
func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
|
||||
func (m *MockPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "Begin", ctx, opts)
|
||||
ret0, _ := ret[0].(database.Transaction)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Begin indicates an expected call of Begin.
|
||||
func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall {
|
||||
func (mr *MockPoolMockRecorder) Begin(ctx, opts any) *MockPoolBeginCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), ctx, opts)
|
||||
return &MockPoolBeginCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -119,17 +120,17 @@ func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.Transa
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockPool) Close(arg0 context.Context) error {
|
||||
func (m *MockPool) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close", arg0)
|
||||
ret := m.ctrl.Call(m, "Close", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall {
|
||||
func (mr *MockPoolMockRecorder) Close(ctx any) *MockPoolCloseCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), ctx)
|
||||
return &MockPoolCloseCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -157,10 +158,10 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
||||
func (m *MockPool) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
@@ -170,9 +171,9 @@ func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64,
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall {
|
||||
func (mr *MockPoolMockRecorder) Exec(ctx, stmt any, args ...any) *MockPoolExecCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...)
|
||||
return &MockPoolExecCall{Call: call}
|
||||
}
|
||||
@@ -201,17 +202,17 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (
|
||||
}
|
||||
|
||||
// Migrate mocks base method.
|
||||
func (m *MockPool) Migrate(arg0 context.Context) error {
|
||||
func (m *MockPool) Migrate(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Migrate", arg0)
|
||||
ret := m.ctrl.Call(m, "Migrate", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Migrate indicates an expected call of Migrate.
|
||||
func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall {
|
||||
func (mr *MockPoolMockRecorder) Migrate(ctx any) *MockPoolMigrateCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), ctx)
|
||||
return &MockPoolMigrateCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -238,11 +239,49 @@ func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPo
|
||||
return c
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
||||
// Ping mocks base method.
|
||||
func (m *MockPool) Ping(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
ret := m.ctrl.Call(m, "Ping", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Ping indicates an expected call of Ping.
|
||||
func (mr *MockPoolMockRecorder) Ping(ctx any) *MockPoolPingCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockPool)(nil).Ping), ctx)
|
||||
return &MockPoolPingCall{Call: call}
|
||||
}
|
||||
|
||||
// MockPoolPingCall wrap *gomock.Call
|
||||
type MockPoolPingCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockPoolPingCall) Return(arg0 error) *MockPoolPingCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockPoolPingCall) Do(f func(context.Context) error) *MockPoolPingCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockPoolPingCall) DoAndReturn(f func(context.Context) error) *MockPoolPingCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockPool) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||
@@ -252,9 +291,9 @@ func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (databa
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall {
|
||||
func (mr *MockPoolMockRecorder) Query(ctx, stmt any, args ...any) *MockPoolQueryCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...)
|
||||
return &MockPoolQueryCall{Call: call}
|
||||
}
|
||||
@@ -283,10 +322,10 @@ func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any)
|
||||
}
|
||||
|
||||
// QueryRow mocks base method.
|
||||
func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
||||
func (m *MockPool) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||
@@ -295,9 +334,9 @@ func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) data
|
||||
}
|
||||
|
||||
// QueryRow indicates an expected call of QueryRow.
|
||||
func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall {
|
||||
func (mr *MockPoolMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockPoolQueryRowCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...)
|
||||
return &MockPoolQueryRowCall{Call: call}
|
||||
}
|
||||
@@ -325,73 +364,74 @@ func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...an
|
||||
return c
|
||||
}
|
||||
|
||||
// MockClient is a mock of Client interface.
|
||||
type MockClient struct {
|
||||
// MockConnection is a mock of Connection interface.
|
||||
type MockConnection struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockClientMockRecorder
|
||||
recorder *MockConnectionMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockClientMockRecorder is the mock recorder for MockClient.
|
||||
type MockClientMockRecorder struct {
|
||||
mock *MockClient
|
||||
// MockConnectionMockRecorder is the mock recorder for MockConnection.
|
||||
type MockConnectionMockRecorder struct {
|
||||
mock *MockConnection
|
||||
}
|
||||
|
||||
// NewMockClient creates a new mock instance.
|
||||
func NewMockClient(ctrl *gomock.Controller) *MockClient {
|
||||
mock := &MockClient{ctrl: ctrl}
|
||||
mock.recorder = &MockClientMockRecorder{mock}
|
||||
// NewMockConnection creates a new mock instance.
|
||||
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
|
||||
mock := &MockConnection{ctrl: ctrl}
|
||||
mock.recorder = &MockConnectionMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
||||
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Begin mocks base method.
|
||||
func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
|
||||
func (m *MockConnection) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "Begin", ctx, opts)
|
||||
ret0, _ := ret[0].(database.Transaction)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Begin indicates an expected call of Begin.
|
||||
func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall {
|
||||
func (mr *MockConnectionMockRecorder) Begin(ctx, opts any) *MockConnectionBeginCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1)
|
||||
return &MockClientBeginCall{Call: call}
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockConnection)(nil).Begin), ctx, opts)
|
||||
return &MockConnectionBeginCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientBeginCall wrap *gomock.Call
|
||||
type MockClientBeginCall struct {
|
||||
// MockConnectionBeginCall wrap *gomock.Call
|
||||
type MockConnectionBeginCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall {
|
||||
func (c *MockConnectionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockConnectionBeginCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
|
||||
func (c *MockConnectionBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
|
||||
func (c *MockConnectionBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
||||
func (m *MockConnection) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
@@ -401,79 +441,117 @@ func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall {
|
||||
func (mr *MockConnectionMockRecorder) Exec(ctx, stmt any, args ...any) *MockConnectionExecCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...)
|
||||
return &MockClientExecCall{Call: call}
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
|
||||
return &MockConnectionExecCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientExecCall wrap *gomock.Call
|
||||
type MockClientExecCall struct {
|
||||
// MockConnectionExecCall wrap *gomock.Call
|
||||
type MockConnectionExecCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
|
||||
func (c *MockConnectionExecCall) Return(arg0 int64, arg1 error) *MockConnectionExecCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
|
||||
func (c *MockConnectionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
|
||||
func (c *MockConnectionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Migrate mocks base method.
|
||||
func (m *MockClient) Migrate(arg0 context.Context) error {
|
||||
func (m *MockConnection) Migrate(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Migrate", arg0)
|
||||
ret := m.ctrl.Call(m, "Migrate", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Migrate indicates an expected call of Migrate.
|
||||
func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall {
|
||||
func (mr *MockConnectionMockRecorder) Migrate(ctx any) *MockConnectionMigrateCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0)
|
||||
return &MockClientMigrateCall{Call: call}
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockConnection)(nil).Migrate), ctx)
|
||||
return &MockConnectionMigrateCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientMigrateCall wrap *gomock.Call
|
||||
type MockClientMigrateCall struct {
|
||||
// MockConnectionMigrateCall wrap *gomock.Call
|
||||
type MockConnectionMigrateCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall {
|
||||
func (c *MockConnectionMigrateCall) Return(arg0 error) *MockConnectionMigrateCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall {
|
||||
func (c *MockConnectionMigrateCall) Do(f func(context.Context) error) *MockConnectionMigrateCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall {
|
||||
func (c *MockConnectionMigrateCall) DoAndReturn(f func(context.Context) error) *MockConnectionMigrateCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Ping mocks base method.
|
||||
func (m *MockConnection) Ping(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Ping", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Ping indicates an expected call of Ping.
|
||||
func (mr *MockConnectionMockRecorder) Ping(ctx any) *MockConnectionPingCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockConnection)(nil).Ping), ctx)
|
||||
return &MockConnectionPingCall{Call: call}
|
||||
}
|
||||
|
||||
// MockConnectionPingCall wrap *gomock.Call
|
||||
type MockConnectionPingCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockConnectionPingCall) Return(arg0 error) *MockConnectionPingCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockConnectionPingCall) Do(f func(context.Context) error) *MockConnectionPingCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockConnectionPingCall) DoAndReturn(f func(context.Context) error) *MockConnectionPingCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
||||
func (m *MockConnection) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||
@@ -483,41 +561,41 @@ func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (data
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall {
|
||||
func (mr *MockConnectionMockRecorder) Query(ctx, stmt any, args ...any) *MockConnectionQueryCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...)
|
||||
return &MockClientQueryCall{Call: call}
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockConnection)(nil).Query), varargs...)
|
||||
return &MockConnectionQueryCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientQueryCall wrap *gomock.Call
|
||||
type MockClientQueryCall struct {
|
||||
// MockConnectionQueryCall wrap *gomock.Call
|
||||
type MockConnectionQueryCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall {
|
||||
func (c *MockConnectionQueryCall) Return(arg0 database.Rows, arg1 error) *MockConnectionQueryCall {
|
||||
c.Call = c.Call.Return(arg0, arg1)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
|
||||
func (c *MockConnectionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
|
||||
func (c *MockConnectionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// QueryRow mocks base method.
|
||||
func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
||||
func (m *MockConnection) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||
@@ -526,70 +604,70 @@ func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) da
|
||||
}
|
||||
|
||||
// QueryRow indicates an expected call of QueryRow.
|
||||
func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall {
|
||||
func (mr *MockConnectionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockConnectionQueryRowCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...)
|
||||
return &MockClientQueryRowCall{Call: call}
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockConnection)(nil).QueryRow), varargs...)
|
||||
return &MockConnectionQueryRowCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientQueryRowCall wrap *gomock.Call
|
||||
type MockClientQueryRowCall struct {
|
||||
// MockConnectionQueryRowCall wrap *gomock.Call
|
||||
type MockConnectionQueryRowCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall {
|
||||
func (c *MockConnectionQueryRowCall) Return(arg0 database.Row) *MockConnectionQueryRowCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
|
||||
func (c *MockConnectionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
|
||||
func (c *MockConnectionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// Release mocks base method.
|
||||
func (m *MockClient) Release(arg0 context.Context) error {
|
||||
func (m *MockConnection) Release(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Release", arg0)
|
||||
ret := m.ctrl.Call(m, "Release", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Release indicates an expected call of Release.
|
||||
func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall {
|
||||
func (mr *MockConnectionMockRecorder) Release(ctx any) *MockConnectionReleaseCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0)
|
||||
return &MockClientReleaseCall{Call: call}
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockConnection)(nil).Release), ctx)
|
||||
return &MockConnectionReleaseCall{Call: call}
|
||||
}
|
||||
|
||||
// MockClientReleaseCall wrap *gomock.Call
|
||||
type MockClientReleaseCall struct {
|
||||
// MockConnectionReleaseCall wrap *gomock.Call
|
||||
type MockConnectionReleaseCall struct {
|
||||
*gomock.Call
|
||||
}
|
||||
|
||||
// Return rewrite *gomock.Call.Return
|
||||
func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall {
|
||||
func (c *MockConnectionReleaseCall) Return(arg0 error) *MockConnectionReleaseCall {
|
||||
c.Call = c.Call.Return(arg0)
|
||||
return c
|
||||
}
|
||||
|
||||
// Do rewrite *gomock.Call.Do
|
||||
func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall {
|
||||
func (c *MockConnectionReleaseCall) Do(f func(context.Context) error) *MockConnectionReleaseCall {
|
||||
c.Call = c.Call.Do(f)
|
||||
return c
|
||||
}
|
||||
|
||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||
func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall {
|
||||
func (c *MockConnectionReleaseCall) DoAndReturn(f func(context.Context) error) *MockConnectionReleaseCall {
|
||||
c.Call = c.Call.DoAndReturn(f)
|
||||
return c
|
||||
}
|
||||
@@ -598,6 +676,7 @@ func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *Mock
|
||||
type MockRow struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRowMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockRowMockRecorder is the mock recorder for MockRow.
|
||||
@@ -618,10 +697,10 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder {
|
||||
}
|
||||
|
||||
// Scan mocks base method.
|
||||
func (m *MockRow) Scan(arg0 ...any) error {
|
||||
func (m *MockRow) Scan(dest ...any) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
for _, a := range dest {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Scan", varargs...)
|
||||
@@ -630,9 +709,9 @@ func (m *MockRow) Scan(arg0 ...any) error {
|
||||
}
|
||||
|
||||
// Scan indicates an expected call of Scan.
|
||||
func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall {
|
||||
func (mr *MockRowMockRecorder) Scan(dest ...any) *MockRowScanCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), dest...)
|
||||
return &MockRowScanCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -663,6 +742,7 @@ func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall {
|
||||
type MockRows struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRowsMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockRowsMockRecorder is the mock recorder for MockRows.
|
||||
@@ -797,10 +877,10 @@ func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall {
|
||||
}
|
||||
|
||||
// Scan mocks base method.
|
||||
func (m *MockRows) Scan(arg0 ...any) error {
|
||||
func (m *MockRows) Scan(dest ...any) error {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{}
|
||||
for _, a := range arg0 {
|
||||
for _, a := range dest {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Scan", varargs...)
|
||||
@@ -809,9 +889,9 @@ func (m *MockRows) Scan(arg0 ...any) error {
|
||||
}
|
||||
|
||||
// Scan indicates an expected call of Scan.
|
||||
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall {
|
||||
func (mr *MockRowsMockRecorder) Scan(dest ...any) *MockRowsScanCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), dest...)
|
||||
return &MockRowsScanCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -842,6 +922,7 @@ func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall {
|
||||
type MockTransaction struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTransactionMockRecorder
|
||||
isgomock struct{}
|
||||
}
|
||||
|
||||
// MockTransactionMockRecorder is the mock recorder for MockTransaction.
|
||||
@@ -862,18 +943,18 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder {
|
||||
}
|
||||
|
||||
// Begin mocks base method.
|
||||
func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) {
|
||||
func (m *MockTransaction) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Begin", arg0)
|
||||
ret := m.ctrl.Call(m, "Begin", ctx)
|
||||
ret0, _ := ret[0].(database.Transaction)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Begin indicates an expected call of Begin.
|
||||
func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall {
|
||||
func (mr *MockTransactionMockRecorder) Begin(ctx any) *MockTransactionBeginCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), ctx)
|
||||
return &MockTransactionBeginCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -901,17 +982,17 @@ func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database
|
||||
}
|
||||
|
||||
// Commit mocks base method.
|
||||
func (m *MockTransaction) Commit(arg0 context.Context) error {
|
||||
func (m *MockTransaction) Commit(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Commit", arg0)
|
||||
ret := m.ctrl.Call(m, "Commit", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Commit indicates an expected call of Commit.
|
||||
func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall {
|
||||
func (mr *MockTransactionMockRecorder) Commit(ctx any) *MockTransactionCommitCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), ctx)
|
||||
return &MockTransactionCommitCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -939,17 +1020,17 @@ func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *
|
||||
}
|
||||
|
||||
// End mocks base method.
|
||||
func (m *MockTransaction) End(arg0 context.Context, arg1 error) error {
|
||||
func (m *MockTransaction) End(ctx context.Context, err error) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "End", arg0, arg1)
|
||||
ret := m.ctrl.Call(m, "End", ctx, err)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// End indicates an expected call of End.
|
||||
func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall {
|
||||
func (mr *MockTransactionMockRecorder) End(ctx, err any) *MockTransactionEndCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), ctx, err)
|
||||
return &MockTransactionEndCall{Call: call}
|
||||
}
|
||||
|
||||
@@ -977,10 +1058,10 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
|
||||
}
|
||||
|
||||
// Exec mocks base method.
|
||||
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
||||
func (m *MockTransaction) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||
@@ -990,9 +1071,9 @@ func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (
|
||||
}
|
||||
|
||||
// Exec indicates an expected call of Exec.
|
||||
func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall {
|
||||
func (mr *MockTransactionMockRecorder) Exec(ctx, stmt any, args ...any) *MockTransactionExecCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...)
|
||||
return &MockTransactionExecCall{Call: call}
|
||||
}
|
||||
@@ -1021,10 +1102,10 @@ func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ..
|
||||
}
|
||||
|
||||
// Query mocks base method.
|
||||
func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
||||
func (m *MockTransaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||
@@ -1034,9 +1115,9 @@ func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any)
|
||||
}
|
||||
|
||||
// Query indicates an expected call of Query.
|
||||
func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall {
|
||||
func (mr *MockTransactionMockRecorder) Query(ctx, stmt any, args ...any) *MockTransactionQueryCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...)
|
||||
return &MockTransactionQueryCall{Call: call}
|
||||
}
|
||||
@@ -1065,10 +1146,10 @@ func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, .
|
||||
}
|
||||
|
||||
// QueryRow mocks base method.
|
||||
func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
||||
func (m *MockTransaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
varargs := []any{ctx, stmt}
|
||||
for _, a := range args {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||
@@ -1077,9 +1158,9 @@ func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...an
|
||||
}
|
||||
|
||||
// QueryRow indicates an expected call of QueryRow.
|
||||
func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall {
|
||||
func (mr *MockTransactionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockTransactionQueryRowCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]any{arg0, arg1}, arg2...)
|
||||
varargs := append([]any{ctx, stmt}, args...)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...)
|
||||
return &MockTransactionQueryRowCall{Call: call}
|
||||
}
|
||||
@@ -1108,17 +1189,17 @@ func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string
|
||||
}
|
||||
|
||||
// Rollback mocks base method.
|
||||
func (m *MockTransaction) Rollback(arg0 context.Context) error {
|
||||
func (m *MockTransaction) Rollback(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Rollback", arg0)
|
||||
ret := m.ctrl.Call(m, "Rollback", ctx)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Rollback indicates an expected call of Rollback.
|
||||
func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall {
|
||||
func (mr *MockTransactionMockRecorder) Rollback(ctx any) *MockTransactionRollbackCall {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0)
|
||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), ctx)
|
||||
return &MockTransactionRollbackCall{Call: call}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ type pgxConn struct {
|
||||
*pgxpool.Conn
|
||||
}
|
||||
|
||||
var _ database.Client = (*pgxConn)(nil)
|
||||
var _ database.Connection = (*pgxConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
// Release implements [database.Connection].
|
||||
func (c *pgxConn) Release(_ context.Context) error {
|
||||
c.Conn.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
// Begin implements [database.Connection].
|
||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
@@ -30,8 +30,8 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions)
|
||||
return &Transaction{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
// Query implements [database.Connection].
|
||||
// Subtle: this method shadows the method (*Conn).Query of [pgxConn.Conn].
|
||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
@@ -40,14 +40,14 @@ func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
// QueryRow implements [database.Connection].
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
|
||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
// QueryRow implements [database.Connection].
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
|
||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.Conn.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package postgres
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
@@ -16,6 +17,18 @@ func wrapError(err error) error {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return database.NewNoRowFoundError(err)
|
||||
}
|
||||
if errors.Is(err, pgx.ErrTooManyRows) {
|
||||
return database.NewMultipleRowsFoundError(err)
|
||||
}
|
||||
|
||||
// scany only exports its errors as strings
|
||||
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
|
||||
return database.NewMultipleRowsFoundError(err)
|
||||
}
|
||||
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
|
||||
return database.NewScanError(err)
|
||||
}
|
||||
|
||||
var pgxErr *pgconn.PgError
|
||||
if !errors.As(err, &pgxErr) {
|
||||
return database.NewUnknownError(err)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 004_correct_set_updated_at/up.sql
|
||||
up004CorrectSetUpdatedAt string
|
||||
//go:embed 004_correct_set_updated_at/down.sql
|
||||
down004CorrectSetUpdatedAt string
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerSQLMigration(4, up004CorrectSetUpdatedAt, down004CorrectSetUpdatedAt)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.instances
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.organizations
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
|
||||
BEFORE UPDATE ON zitadel.instance_domains
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
|
||||
BEFORE UPDATE ON zitadel.org_domains
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
@@ -0,0 +1,23 @@
|
||||
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.instances
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.updated_at IS NULL)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.organizations
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.updated_at IS NULL)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
|
||||
BEFORE UPDATE ON zitadel.instance_domains
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.updated_at IS NULL)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
|
||||
BEFORE UPDATE ON zitadel.org_domains
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.updated_at IS NULL)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
@@ -22,8 +22,8 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool {
|
||||
}
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Pool.Acquire(ctx)
|
||||
func (p *pgxPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||
conn, err := p.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -32,8 +32,8 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
||||
func (p *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := p.Pool.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -42,14 +42,14 @@ func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.Pool.QueryRow(ctx, sql, args...)}
|
||||
func (p *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{p.Pool.QueryRow(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.Pool.Exec(ctx, sql, args...)
|
||||
func (p *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := p.Pool.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
@@ -57,8 +57,8 @@ func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
func (p *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := p.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -66,23 +66,23 @@ func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions)
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *pgxPool) Close(_ context.Context) error {
|
||||
c.Pool.Close()
|
||||
func (p *pgxPool) Close(_ context.Context) error {
|
||||
p.Pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping implements [database.Pool].
|
||||
func (c *pgxPool) Ping(ctx context.Context) error {
|
||||
return wrapError(c.Pool.Ping(ctx))
|
||||
func (p *pgxPool) Ping(ctx context.Context) error {
|
||||
return wrapError(p.Pool.Ping(ctx))
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *pgxPool) Migrate(ctx context.Context) error {
|
||||
func (p *pgxPool) Migrate(ctx context.Context) error {
|
||||
if isMigrated {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := c.Pool.Acquire(ctx)
|
||||
client, err := p.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,8 +93,8 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Migrate implements [database.PoolTest].
|
||||
func (c *pgxPool) MigrateTest(ctx context.Context) error {
|
||||
client, err := c.Pool.Acquire(ctx)
|
||||
func (p *pgxPool) MigrateTest(ctx context.Context) error {
|
||||
client, err := p.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -11,18 +11,18 @@ type sqlConn struct {
|
||||
*sql.Conn
|
||||
}
|
||||
|
||||
func SQLConn(conn *sql.Conn) database.Client {
|
||||
func SQLConn(conn *sql.Conn) database.Connection {
|
||||
return &sqlConn{Conn: conn}
|
||||
}
|
||||
|
||||
var _ database.Client = (*sqlConn)(nil)
|
||||
var _ database.Connection = (*sqlConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
// Release implements [database.Connection].
|
||||
func (c *sqlConn) Release(_ context.Context) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
// Begin implements [database.Connection].
|
||||
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
@@ -19,6 +20,15 @@ func wrapError(err error) error {
|
||||
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
|
||||
return database.NewNoRowFoundError(err)
|
||||
}
|
||||
|
||||
// scany only exports its errors as strings
|
||||
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
|
||||
return database.NewMultipleRowsFoundError(err)
|
||||
}
|
||||
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
|
||||
return database.NewScanError(err)
|
||||
}
|
||||
|
||||
var pgxErr *pgconn.PgError
|
||||
if !errors.As(err, &pgxErr) {
|
||||
return database.NewUnknownError(err)
|
||||
|
||||
@@ -20,8 +20,8 @@ func SQLPool(db *sql.DB) *sqlPool {
|
||||
}
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Conn(ctx)
|
||||
func (p *sqlPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||
conn, err := p.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -30,9 +30,9 @@ func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
func (p *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
//nolint:rowserrcheck // Rows.Close is called by the caller
|
||||
rows, err := c.QueryContext(ctx, sql, args...)
|
||||
rows, err := p.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -41,14 +41,14 @@ func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.QueryRowContext(ctx, sql, args...)}
|
||||
func (p *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{p.QueryRowContext(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.ExecContext(ctx, sql, args...)
|
||||
func (p *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := p.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
@@ -56,8 +56,8 @@ func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||
func (p *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := p.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
@@ -65,16 +65,16 @@ func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions)
|
||||
}
|
||||
|
||||
// Ping implements [database.Pool].
|
||||
func (c *sqlPool) Ping(ctx context.Context) error {
|
||||
return wrapError(c.PingContext(ctx))
|
||||
func (p *sqlPool) Ping(ctx context.Context) error {
|
||||
return wrapError(p.PingContext(ctx))
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *sqlPool) Close(_ context.Context) error {
|
||||
return c.DB.Close()
|
||||
func (p *sqlPool) Close(_ context.Context) error {
|
||||
return p.DB.Close()
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *sqlPool) Migrate(ctx context.Context) error {
|
||||
func (p *sqlPool) Migrate(ctx context.Context) error {
|
||||
return ErrMigrate
|
||||
}
|
||||
|
||||
@@ -5,7 +5,34 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrNoChanges = errors.New("update must contain a change")
|
||||
var (
|
||||
ErrNoChanges = errors.New("update must contain a change")
|
||||
)
|
||||
|
||||
type MissingConditionError struct {
|
||||
col Column
|
||||
}
|
||||
|
||||
func NewMissingConditionError(col Column) error {
|
||||
return &MissingConditionError{
|
||||
col: col,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *MissingConditionError) Error() string {
|
||||
var builder StatementBuilder
|
||||
builder.WriteString("missing condition for column")
|
||||
if e.col != nil {
|
||||
builder.WriteString(" on ")
|
||||
e.col.WriteQualified(&builder)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (e *MissingConditionError) Is(target error) bool {
|
||||
_, ok := target.(*MissingConditionError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// NoRowFoundError is returned when QueryRow does not find any row.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
@@ -20,6 +47,9 @@ func NewNoRowFoundError(original error) error {
|
||||
}
|
||||
|
||||
func (e *NoRowFoundError) Error() string {
|
||||
if e.original != nil {
|
||||
return fmt.Sprintf("no row found: %v", e.original)
|
||||
}
|
||||
return "no row found"
|
||||
}
|
||||
|
||||
@@ -36,18 +66,19 @@ func (e *NoRowFoundError) Unwrap() error {
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
type MultipleRowsFoundError struct {
|
||||
original error
|
||||
count int
|
||||
}
|
||||
|
||||
func NewMultipleRowsFoundError(original error, count int) error {
|
||||
func NewMultipleRowsFoundError(original error) error {
|
||||
return &MultipleRowsFoundError{
|
||||
original: original,
|
||||
count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *MultipleRowsFoundError) Error() string {
|
||||
return fmt.Sprintf("multiple rows found: %d", e.count)
|
||||
if e.original != nil {
|
||||
return fmt.Sprintf("multiple rows found: %v", e.original)
|
||||
}
|
||||
return "multiple rows found"
|
||||
}
|
||||
|
||||
func (e *MultipleRowsFoundError) Is(target error) bool {
|
||||
@@ -77,8 +108,8 @@ type IntegrityViolationError struct {
|
||||
original error
|
||||
}
|
||||
|
||||
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
|
||||
return &IntegrityViolationError{
|
||||
func newIntegrityViolationError(typ IntegrityType, table, constraint string, original error) IntegrityViolationError {
|
||||
return IntegrityViolationError{
|
||||
integrityType: typ,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
@@ -87,7 +118,10 @@ func NewIntegrityViolationError(typ IntegrityType, table, constraint string, ori
|
||||
}
|
||||
|
||||
func (e *IntegrityViolationError) Error() string {
|
||||
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
|
||||
if e.original != nil {
|
||||
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
|
||||
}
|
||||
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q)", e.integrityType, e.table, e.constraint)
|
||||
}
|
||||
|
||||
func (e *IntegrityViolationError) Is(target error) bool {
|
||||
@@ -108,12 +142,7 @@ type CheckError struct {
|
||||
|
||||
func NewCheckError(table, constraint string, original error) error {
|
||||
return &CheckError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeCheck,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeCheck, table, constraint, original),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,12 +164,7 @@ type UniqueError struct {
|
||||
|
||||
func NewUniqueError(table, constraint string, original error) error {
|
||||
return &UniqueError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeUnique,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeUnique, table, constraint, original),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,12 +186,7 @@ type ForeignKeyError struct {
|
||||
|
||||
func NewForeignKeyError(table, constraint string, original error) error {
|
||||
return &ForeignKeyError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeForeign,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeForeign, table, constraint, original),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,12 +208,7 @@ type NotNullError struct {
|
||||
|
||||
func NewNotNullError(table, constraint string, original error) error {
|
||||
return &NotNullError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeNotNull,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeNotNull, table, constraint, original),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,6 +221,31 @@ func (e *NotNullError) Unwrap() error {
|
||||
return &e.IntegrityViolationError
|
||||
}
|
||||
|
||||
// ScanError is returned when scanning rows into objects failed.
|
||||
// It wraps the original error to provide more context.
|
||||
type ScanError struct {
|
||||
original error
|
||||
}
|
||||
|
||||
func NewScanError(original error) error {
|
||||
return &ScanError{
|
||||
original: original,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ScanError) Error() string {
|
||||
return fmt.Sprintf("Scan error: %v", e.original)
|
||||
}
|
||||
|
||||
func (e *ScanError) Is(target error) bool {
|
||||
_, ok := target.(*ScanError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *ScanError) Unwrap() error {
|
||||
return e.original
|
||||
}
|
||||
|
||||
// UnknownError is returned when an unknown error occurs.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
// It is used to indicate that an error occurred that does not fit into any of the other categories.
|
||||
|
||||
304
backend/v3/storage/database/errors_test.go
Normal file
304
backend/v3/storage/database/errors_test.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "missing condition without column",
|
||||
err: NewMissingConditionError(nil),
|
||||
want: "missing condition for column",
|
||||
},
|
||||
{
|
||||
name: "missing condition with column",
|
||||
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||
want: "missing condition for column on table.column",
|
||||
},
|
||||
{
|
||||
name: "no row found without original error",
|
||||
err: NewNoRowFoundError(nil),
|
||||
want: "no row found",
|
||||
},
|
||||
{
|
||||
name: "no row found with original error",
|
||||
err: NewNoRowFoundError(errors.New("original error")),
|
||||
want: "no row found: original error",
|
||||
},
|
||||
{
|
||||
name: "multiple rows found without original error",
|
||||
err: NewMultipleRowsFoundError(nil),
|
||||
want: "multiple rows found",
|
||||
},
|
||||
{
|
||||
name: "multiple rows found with original error",
|
||||
err: NewMultipleRowsFoundError(errors.New("original error")),
|
||||
want: "multiple rows found: original error",
|
||||
},
|
||||
{
|
||||
name: "check violation without original error",
|
||||
err: NewCheckError("table", "constraint", nil),
|
||||
want: `integrity violation of type "check" on "table" (constraint: "constraint")`,
|
||||
},
|
||||
{
|
||||
name: "check violation with original error",
|
||||
err: NewCheckError("table", "constraint", errors.New("original error")),
|
||||
want: `integrity violation of type "check" on "table" (constraint: "constraint"): original error`,
|
||||
},
|
||||
{
|
||||
name: "unique violation without original error",
|
||||
err: NewUniqueError("table", "constraint", nil),
|
||||
want: `integrity violation of type "unique" on "table" (constraint: "constraint")`,
|
||||
},
|
||||
{
|
||||
name: "unique violation with original error",
|
||||
err: NewUniqueError("table", "constraint", errors.New("original error")),
|
||||
want: `integrity violation of type "unique" on "table" (constraint: "constraint"): original error`,
|
||||
},
|
||||
{
|
||||
name: "foreign key violation without original error",
|
||||
err: NewForeignKeyError("table", "constraint", nil),
|
||||
want: `integrity violation of type "foreign" on "table" (constraint: "constraint")`,
|
||||
},
|
||||
{
|
||||
name: "foreign key violation with original error",
|
||||
err: NewForeignKeyError("table", "constraint", errors.New("original error")),
|
||||
want: `integrity violation of type "foreign" on "table" (constraint: "constraint"): original error`,
|
||||
},
|
||||
{
|
||||
name: "not null violation without original error",
|
||||
err: NewNotNullError("table", "constraint", nil),
|
||||
want: `integrity violation of type "not null" on "table" (constraint: "constraint")`,
|
||||
},
|
||||
{
|
||||
name: "not null violation with original error",
|
||||
err: NewNotNullError("table", "constraint", errors.New("original error")),
|
||||
want: `integrity violation of type "not null" on "table" (constraint: "constraint"): original error`,
|
||||
},
|
||||
{
|
||||
name: "unknown error",
|
||||
err: NewUnknownError(errors.New("original error")),
|
||||
want: `unknown database error: original error`,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.want, test.err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnwrap(t *testing.T) {
|
||||
originalErr := errors.New("original error")
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
err error
|
||||
want error
|
||||
}{
|
||||
{
|
||||
name: "missing condition without column",
|
||||
err: NewMissingConditionError(nil),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "missing condition with column",
|
||||
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no row found without original error",
|
||||
err: NewNoRowFoundError(nil),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no row found with original error",
|
||||
err: NewNoRowFoundError(errors.New("original error")),
|
||||
want: originalErr,
|
||||
},
|
||||
{
|
||||
name: "multiple rows found without original error",
|
||||
err: NewMultipleRowsFoundError(nil),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "multiple rows found with original error",
|
||||
err: NewMultipleRowsFoundError(originalErr),
|
||||
want: originalErr,
|
||||
},
|
||||
{
|
||||
name: "check violation without original error",
|
||||
err: NewCheckError("table", "constraint", nil),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeCheck,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "check violation with original error",
|
||||
err: NewCheckError("table", "constraint", originalErr),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeCheck,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: originalErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unique violation without original error",
|
||||
err: NewUniqueError("table", "constraint", nil),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeUnique,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unique violation with original error",
|
||||
err: NewUniqueError("table", "constraint", originalErr),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeUnique,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: originalErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "foreign key violation without original error",
|
||||
err: NewForeignKeyError("table", "constraint", nil),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeForeign,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "foreign key violation with original error",
|
||||
err: NewForeignKeyError("table", "constraint", originalErr),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeForeign,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: originalErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not null violation without original error",
|
||||
err: NewNotNullError("table", "constraint", nil),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeNotNull,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not null violation with original error",
|
||||
err: NewNotNullError("table", "constraint", originalErr),
|
||||
want: &IntegrityViolationError{
|
||||
integrityType: IntegrityTypeNotNull,
|
||||
table: "table",
|
||||
constraint: "constraint",
|
||||
original: originalErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unwrap integrity violation error",
|
||||
err: errors.Unwrap(NewNotNullError("table", "constraint", originalErr)),
|
||||
want: originalErr,
|
||||
},
|
||||
{
|
||||
name: "unknown error",
|
||||
err: NewUnknownError(originalErr),
|
||||
want: originalErr,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.Equal(t, test.want, errors.Unwrap(test.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIs(t *testing.T) {
|
||||
originalErr := errors.New("original error")
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
err error
|
||||
want error
|
||||
}{
|
||||
{
|
||||
name: "missing condition",
|
||||
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||
want: new(MissingConditionError),
|
||||
},
|
||||
{
|
||||
name: "no row found",
|
||||
err: NewNoRowFoundError(errors.New("original error")),
|
||||
want: new(NoRowFoundError),
|
||||
},
|
||||
{
|
||||
name: "multiple rows found",
|
||||
err: NewMultipleRowsFoundError(originalErr),
|
||||
want: new(MultipleRowsFoundError),
|
||||
},
|
||||
{
|
||||
name: "check violation is for integrity",
|
||||
err: NewCheckError("table", "constraint", nil),
|
||||
want: new(IntegrityViolationError),
|
||||
},
|
||||
{
|
||||
name: "check violation is check violation",
|
||||
err: NewCheckError("table", "constraint", nil),
|
||||
want: new(CheckError),
|
||||
},
|
||||
{
|
||||
name: "unique violation is for integrity",
|
||||
err: NewUniqueError("table", "constraint", nil),
|
||||
want: new(IntegrityViolationError),
|
||||
},
|
||||
{
|
||||
name: "unique violation is unique violation",
|
||||
err: NewUniqueError("table", "constraint", nil),
|
||||
want: new(UniqueError),
|
||||
},
|
||||
{
|
||||
name: "foreign key violation is for integrity",
|
||||
err: NewForeignKeyError("table", "constraint", nil),
|
||||
want: new(IntegrityViolationError),
|
||||
},
|
||||
{
|
||||
name: "foreign key violation is foreign key violation",
|
||||
err: NewForeignKeyError("table", "constraint", nil),
|
||||
want: new(ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "not null violation is for integrity",
|
||||
err: NewNotNullError("table", "constraint", nil),
|
||||
want: new(IntegrityViolationError),
|
||||
},
|
||||
{
|
||||
name: "not null violation is not null violation",
|
||||
err: NewNotNullError("table", "constraint", nil),
|
||||
want: new(NotNullError),
|
||||
},
|
||||
{
|
||||
name: "unknown error",
|
||||
err: NewUnknownError(originalErr),
|
||||
want: new(UnknownError),
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
assert.ErrorIs(t, test.err, test.want)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -21,8 +21,8 @@ import (
|
||||
func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
instance := integration.NewInstance(CTX)
|
||||
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceDomainRepo := instanceRepo.Domains(true)
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
instanceDomainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
|
||||
@@ -36,7 +36,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Wait for instance to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceRepo.Get(CTX,
|
||||
_, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
@@ -66,7 +66,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -96,7 +96,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
|
||||
t.Cleanup(func() {
|
||||
// first we change the primary domain to something else
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -125,7 +125,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Wait for domain to be created
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -151,7 +151,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Test that set primary reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -180,7 +180,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
_, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -202,7 +202,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -241,7 +241,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -271,7 +271,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
_, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
@@ -293,7 +293,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
t.Run("test instance add reduces", func(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
@@ -46,7 +46,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
instance, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -92,7 +92,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
// check instance exists
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
instance, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -110,7 +110,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
instance, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -137,7 +137,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
// check instance exists
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
instance, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -151,7 +151,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
instance, err := instanceRepo.Get(CTX, pool,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
// event instance.removed
|
||||
|
||||
@@ -22,8 +22,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
orgDomainRepo := orgRepo.Domains(false)
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
orgDomainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
|
||||
@@ -37,8 +37,13 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
// Wait for org to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(orgRepo.IDCondition(org.GetId())),
|
||||
_, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||
orgRepo.IDCondition(org.GetId()),
|
||||
),
|
||||
),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
@@ -68,7 +73,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
gottenDomain, err := orgDomainRepo.Get(CTX,
|
||||
gottenDomain, err := orgDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||
@@ -107,7 +112,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
domain, err := orgDomainRepo.Get(CTX,
|
||||
domain, err := orgDomainRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
instanceID := Instance.ID()
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
|
||||
t.Run("test org add reduces", func(t *testing.T) {
|
||||
beforeCreate := time.Now()
|
||||
@@ -42,7 +42,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(org.GetId()),
|
||||
@@ -92,7 +92,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
@@ -137,7 +137,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
@@ -177,11 +177,11 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
// 3. check org deactivated
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
@@ -203,7 +203,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
@@ -230,10 +230,10 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. check org retrievable
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
_, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
@@ -252,7 +252,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
organization, err := orgRepo.Get(CTX, pool,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package database
|
||||
|
||||
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
||||
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
|
||||
|
||||
@@ -6,16 +6,40 @@ import (
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
Boolean | Number | Text | Instruction
|
||||
type wrappedValue[V Value] struct {
|
||||
value V
|
||||
fn function
|
||||
}
|
||||
|
||||
func LowerValue[T Value](v T) wrappedValue[T] {
|
||||
return wrappedValue[T]{value: v, fn: functionLower}
|
||||
}
|
||||
|
||||
func SHA256Value[T Value](v T) wrappedValue[T] {
|
||||
return wrappedValue[T]{value: v, fn: functionSHA256}
|
||||
}
|
||||
|
||||
func (b wrappedValue[V]) WriteArg(builder *StatementBuilder) {
|
||||
builder.Grow(len(b.fn) + 5)
|
||||
builder.WriteString(string(b.fn))
|
||||
builder.WriteRune('(')
|
||||
builder.WriteArg(b.value)
|
||||
builder.WriteRune(')')
|
||||
}
|
||||
|
||||
var _ argWriter = (*wrappedValue[string])(nil)
|
||||
|
||||
type Value interface {
|
||||
Boolean | Number | Text | Instruction | Bytes
|
||||
}
|
||||
|
||||
//go:generate enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go
|
||||
type Operation interface {
|
||||
BooleanOperation | NumberOperation | TextOperation
|
||||
NumberOperation | TextOperation | BytesOperation
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
~string | Bytes
|
||||
}
|
||||
|
||||
// TextOperation are operations that can be performed on text values.
|
||||
@@ -23,60 +47,17 @@ type TextOperation uint8
|
||||
|
||||
const (
|
||||
// TextOperationEqual compares two strings for equality.
|
||||
TextOperationEqual TextOperation = iota + 1
|
||||
// TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
|
||||
TextOperationEqualIgnoreCase
|
||||
TextOperationEqual TextOperation = iota + 1 // =
|
||||
// TextOperationNotEqual compares two strings for inequality.
|
||||
TextOperationNotEqual
|
||||
// TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
|
||||
TextOperationNotEqualIgnoreCase
|
||||
TextOperationNotEqual // <>
|
||||
// TextOperationStartsWith checks if the first string starts with the second.
|
||||
TextOperationStartsWith
|
||||
// TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
|
||||
TextOperationStartsWithIgnoreCase
|
||||
TextOperationStartsWith // LIKE
|
||||
)
|
||||
|
||||
var textOperations = map[TextOperation]string{
|
||||
TextOperationEqual: " = ",
|
||||
TextOperationEqualIgnoreCase: " LIKE ",
|
||||
TextOperationNotEqual: " <> ",
|
||||
TextOperationNotEqualIgnoreCase: " NOT LIKE ",
|
||||
TextOperationStartsWith: " LIKE ",
|
||||
TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
|
||||
switch op {
|
||||
case TextOperationEqual, TextOperationNotEqual:
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteArg(value)
|
||||
case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
|
||||
builder.WriteString("LOWER(")
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(")")
|
||||
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteArg(value)
|
||||
builder.WriteString(")")
|
||||
case TextOperationStartsWith:
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteArg(value)
|
||||
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value any) {
|
||||
writeOperation[T](builder, col, op.String(), value)
|
||||
if op == TextOperationStartsWith {
|
||||
builder.WriteString(" || '%'")
|
||||
case TextOperationStartsWithIgnoreCase:
|
||||
builder.WriteString("LOWER(")
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(")")
|
||||
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteArg(value)
|
||||
builder.WriteString(")")
|
||||
builder.WriteString(" || '%'")
|
||||
default:
|
||||
panic("unsupported text operation")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,48 +70,60 @@ type NumberOperation uint8
|
||||
|
||||
const (
|
||||
// NumberOperationEqual compares two numbers for equality.
|
||||
NumberOperationEqual NumberOperation = iota + 1
|
||||
NumberOperationEqual NumberOperation = iota + 1 // =
|
||||
// NumberOperationNotEqual compares two numbers for inequality.
|
||||
NumberOperationNotEqual
|
||||
NumberOperationNotEqual // <>
|
||||
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
|
||||
NumberOperationLessThan
|
||||
NumberOperationLessThan // <
|
||||
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
|
||||
NumberOperationAtLeast
|
||||
NumberOperationAtLeast // <=
|
||||
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
|
||||
NumberOperationGreaterThan
|
||||
NumberOperationGreaterThan // >
|
||||
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
|
||||
NumberOperationAtMost
|
||||
NumberOperationAtMost // >=
|
||||
)
|
||||
|
||||
var numberOperations = map[NumberOperation]string{
|
||||
NumberOperationEqual: " = ",
|
||||
NumberOperationNotEqual: " <> ",
|
||||
NumberOperationLessThan: " < ",
|
||||
NumberOperationAtLeast: " <= ",
|
||||
NumberOperationGreaterThan: " > ",
|
||||
NumberOperationAtMost: " >= ",
|
||||
}
|
||||
|
||||
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(numberOperations[op])
|
||||
builder.WriteArg(value)
|
||||
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value any) {
|
||||
writeOperation[T](builder, col, op.String(), value)
|
||||
}
|
||||
|
||||
type Boolean interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
// BooleanOperation are operations that can be performed on boolean values.
|
||||
type BooleanOperation uint8
|
||||
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value any) {
|
||||
writeOperation[T](builder, col, "=", value)
|
||||
}
|
||||
|
||||
type Bytes interface {
|
||||
~[]byte
|
||||
}
|
||||
|
||||
// BytesOperation are operations that can be performed on bytea values.
|
||||
type BytesOperation uint8
|
||||
|
||||
const (
|
||||
BooleanOperationIsTrue BooleanOperation = iota + 1
|
||||
BooleanOperationIsFalse
|
||||
BytesOperationEqual BytesOperation = iota + 1 // =
|
||||
BytesOperationNotEqual // <>
|
||||
)
|
||||
|
||||
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
|
||||
func writeBytesOperation[T Bytes](builder *StatementBuilder, col Column, op BytesOperation, value any) {
|
||||
writeOperation[T](builder, col, op.String(), value)
|
||||
}
|
||||
|
||||
func writeOperation[V Value](builder *StatementBuilder, col Column, op string, value any) {
|
||||
if op == "" {
|
||||
panic("unsupported operation")
|
||||
}
|
||||
|
||||
switch value.(type) {
|
||||
case V, wrappedValue[V], *wrappedValue[V]:
|
||||
default:
|
||||
panic("unsupported value type")
|
||||
}
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
builder.WriteRune(' ')
|
||||
builder.WriteString(op)
|
||||
builder.WriteRune(' ')
|
||||
builder.WriteArg(value)
|
||||
}
|
||||
|
||||
241
backend/v3/storage/database/operators_enumer.go
Normal file
241
backend/v3/storage/database/operators_enumer.go
Normal file
@@ -0,0 +1,241 @@
|
||||
// Code generated by "enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go"; DO NOT EDIT.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _NumberOperationName = "=<><<=>>="
|
||||
|
||||
var _NumberOperationIndex = [...]uint8{0, 1, 3, 4, 6, 7, 9}
|
||||
|
||||
const _NumberOperationLowerName = "=<><<=>>="
|
||||
|
||||
func (i NumberOperation) String() string {
|
||||
i -= 1
|
||||
if i >= NumberOperation(len(_NumberOperationIndex)-1) {
|
||||
return fmt.Sprintf("NumberOperation(%d)", i+1)
|
||||
}
|
||||
return _NumberOperationName[_NumberOperationIndex[i]:_NumberOperationIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _NumberOperationNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[NumberOperationEqual-(1)]
|
||||
_ = x[NumberOperationNotEqual-(2)]
|
||||
_ = x[NumberOperationLessThan-(3)]
|
||||
_ = x[NumberOperationAtLeast-(4)]
|
||||
_ = x[NumberOperationGreaterThan-(5)]
|
||||
_ = x[NumberOperationAtMost-(6)]
|
||||
}
|
||||
|
||||
var _NumberOperationValues = []NumberOperation{NumberOperationEqual, NumberOperationNotEqual, NumberOperationLessThan, NumberOperationAtLeast, NumberOperationGreaterThan, NumberOperationAtMost}
|
||||
|
||||
var _NumberOperationNameToValueMap = map[string]NumberOperation{
|
||||
_NumberOperationName[0:1]: NumberOperationEqual,
|
||||
_NumberOperationLowerName[0:1]: NumberOperationEqual,
|
||||
_NumberOperationName[1:3]: NumberOperationNotEqual,
|
||||
_NumberOperationLowerName[1:3]: NumberOperationNotEqual,
|
||||
_NumberOperationName[3:4]: NumberOperationLessThan,
|
||||
_NumberOperationLowerName[3:4]: NumberOperationLessThan,
|
||||
_NumberOperationName[4:6]: NumberOperationAtLeast,
|
||||
_NumberOperationLowerName[4:6]: NumberOperationAtLeast,
|
||||
_NumberOperationName[6:7]: NumberOperationGreaterThan,
|
||||
_NumberOperationLowerName[6:7]: NumberOperationGreaterThan,
|
||||
_NumberOperationName[7:9]: NumberOperationAtMost,
|
||||
_NumberOperationLowerName[7:9]: NumberOperationAtMost,
|
||||
}
|
||||
|
||||
var _NumberOperationNames = []string{
|
||||
_NumberOperationName[0:1],
|
||||
_NumberOperationName[1:3],
|
||||
_NumberOperationName[3:4],
|
||||
_NumberOperationName[4:6],
|
||||
_NumberOperationName[6:7],
|
||||
_NumberOperationName[7:9],
|
||||
}
|
||||
|
||||
// NumberOperationString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func NumberOperationString(s string) (NumberOperation, error) {
|
||||
if val, ok := _NumberOperationNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _NumberOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to NumberOperation values", s)
|
||||
}
|
||||
|
||||
// NumberOperationValues returns all values of the enum
|
||||
func NumberOperationValues() []NumberOperation {
|
||||
return _NumberOperationValues
|
||||
}
|
||||
|
||||
// NumberOperationStrings returns a slice of all String values of the enum
|
||||
func NumberOperationStrings() []string {
|
||||
strs := make([]string, len(_NumberOperationNames))
|
||||
copy(strs, _NumberOperationNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsANumberOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i NumberOperation) IsANumberOperation() bool {
|
||||
for _, v := range _NumberOperationValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const _TextOperationName = "=<>LIKE"
|
||||
|
||||
var _TextOperationIndex = [...]uint8{0, 1, 3, 7}
|
||||
|
||||
const _TextOperationLowerName = "=<>like"
|
||||
|
||||
func (i TextOperation) String() string {
|
||||
i -= 1
|
||||
if i >= TextOperation(len(_TextOperationIndex)-1) {
|
||||
return fmt.Sprintf("TextOperation(%d)", i+1)
|
||||
}
|
||||
return _TextOperationName[_TextOperationIndex[i]:_TextOperationIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _TextOperationNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[TextOperationEqual-(1)]
|
||||
_ = x[TextOperationNotEqual-(2)]
|
||||
_ = x[TextOperationStartsWith-(3)]
|
||||
}
|
||||
|
||||
var _TextOperationValues = []TextOperation{TextOperationEqual, TextOperationNotEqual, TextOperationStartsWith}
|
||||
|
||||
var _TextOperationNameToValueMap = map[string]TextOperation{
|
||||
_TextOperationName[0:1]: TextOperationEqual,
|
||||
_TextOperationLowerName[0:1]: TextOperationEqual,
|
||||
_TextOperationName[1:3]: TextOperationNotEqual,
|
||||
_TextOperationLowerName[1:3]: TextOperationNotEqual,
|
||||
_TextOperationName[3:7]: TextOperationStartsWith,
|
||||
_TextOperationLowerName[3:7]: TextOperationStartsWith,
|
||||
}
|
||||
|
||||
var _TextOperationNames = []string{
|
||||
_TextOperationName[0:1],
|
||||
_TextOperationName[1:3],
|
||||
_TextOperationName[3:7],
|
||||
}
|
||||
|
||||
// TextOperationString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func TextOperationString(s string) (TextOperation, error) {
|
||||
if val, ok := _TextOperationNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _TextOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to TextOperation values", s)
|
||||
}
|
||||
|
||||
// TextOperationValues returns all values of the enum
|
||||
func TextOperationValues() []TextOperation {
|
||||
return _TextOperationValues
|
||||
}
|
||||
|
||||
// TextOperationStrings returns a slice of all String values of the enum
|
||||
func TextOperationStrings() []string {
|
||||
strs := make([]string, len(_TextOperationNames))
|
||||
copy(strs, _TextOperationNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsATextOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i TextOperation) IsATextOperation() bool {
|
||||
for _, v := range _TextOperationValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const _BytesOperationName = "=<>"
|
||||
|
||||
var _BytesOperationIndex = [...]uint8{0, 1, 3}
|
||||
|
||||
const _BytesOperationLowerName = "=<>"
|
||||
|
||||
func (i BytesOperation) String() string {
|
||||
i -= 1
|
||||
if i >= BytesOperation(len(_BytesOperationIndex)-1) {
|
||||
return fmt.Sprintf("BytesOperation(%d)", i+1)
|
||||
}
|
||||
return _BytesOperationName[_BytesOperationIndex[i]:_BytesOperationIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _BytesOperationNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[BytesOperationEqual-(1)]
|
||||
_ = x[BytesOperationNotEqual-(2)]
|
||||
}
|
||||
|
||||
var _BytesOperationValues = []BytesOperation{BytesOperationEqual, BytesOperationNotEqual}
|
||||
|
||||
var _BytesOperationNameToValueMap = map[string]BytesOperation{
|
||||
_BytesOperationName[0:1]: BytesOperationEqual,
|
||||
_BytesOperationLowerName[0:1]: BytesOperationEqual,
|
||||
_BytesOperationName[1:3]: BytesOperationNotEqual,
|
||||
_BytesOperationLowerName[1:3]: BytesOperationNotEqual,
|
||||
}
|
||||
|
||||
var _BytesOperationNames = []string{
|
||||
_BytesOperationName[0:1],
|
||||
_BytesOperationName[1:3],
|
||||
}
|
||||
|
||||
// BytesOperationString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func BytesOperationString(s string) (BytesOperation, error) {
|
||||
if val, ok := _BytesOperationNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _BytesOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to BytesOperation values", s)
|
||||
}
|
||||
|
||||
// BytesOperationValues returns all values of the enum
|
||||
func BytesOperationValues() []BytesOperation {
|
||||
return _BytesOperationValues
|
||||
}
|
||||
|
||||
// BytesOperationStrings returns a slice of all String values of the enum
|
||||
func BytesOperationStrings() []string {
|
||||
strs := make([]string, len(_BytesOperationNames))
|
||||
copy(strs, _BytesOperationNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsABytesOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i BytesOperation) IsABytesOperation() bool {
|
||||
for _, v := range _BytesOperationValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
244
backend/v3/storage/database/operators_test.go
Normal file
244
backend/v3/storage/database/operators_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_writeOperation(t *testing.T) {
|
||||
type want struct {
|
||||
shouldPanic bool
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
write func(builder *StatementBuilder)
|
||||
// col Column
|
||||
// op string
|
||||
// value any
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "unsupported operation panics",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeOperation[string](builder, NewColumn("table", "column"), "", "value")
|
||||
},
|
||||
want: want{
|
||||
shouldPanic: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported value type panics",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeOperation[string](builder, NewColumn("table", "column"), " = ", struct{}{})
|
||||
},
|
||||
want: want{
|
||||
shouldPanic: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported wrapped value type panics",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeOperation[string](builder, NewColumn("table", "column"), " = ", SHA256Value(123))
|
||||
},
|
||||
want: want{
|
||||
shouldPanic: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationEqual, "value")
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column = $1",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text not equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationNotEqual, "value")
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column <> $1",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text starts with",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationStartsWith, "value")
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column LIKE $1 || '%'",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text equal with wrapped value",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationEqual, LowerValue("value"))
|
||||
},
|
||||
want: want{
|
||||
stmt: "LOWER(table.column) = LOWER($1)",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text not equal with wrapped value",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationNotEqual, LowerValue("value"))
|
||||
},
|
||||
want: want{
|
||||
stmt: "LOWER(table.column) <> LOWER($1)",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "text starts with with wrapped value",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationStartsWith, LowerValue("value"))
|
||||
},
|
||||
want: want{
|
||||
stmt: "LOWER(table.column) LIKE LOWER($1) || '%'",
|
||||
args: []any{"value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationEqual, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column = $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number not equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationNotEqual, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column <> $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number less than",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationLessThan, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column < $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number less than or equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtLeast, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column <= $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number greater than",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationGreaterThan, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column > $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "number greater than or equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtMost, 123)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column >= $1",
|
||||
args: []any{123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean is true",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBooleanOperation[bool](builder, NewColumn("table", "column"), true)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column = $1",
|
||||
args: []any{true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "boolean is false",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBooleanOperation[bool](builder, NewColumn("table", "column"), false)
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column = $1",
|
||||
args: []any{false},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationEqual, []byte{0x01, 0x02, 0x03})
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column = $1",
|
||||
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes not equal",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationNotEqual, []byte{0x01, 0x02, 0x03})
|
||||
},
|
||||
want: want{
|
||||
stmt: "table.column <> $1",
|
||||
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes equal with wrapped value",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
|
||||
},
|
||||
want: want{
|
||||
stmt: "SHA256(table.column) = SHA256($1)",
|
||||
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bytes not equal with wrapped value",
|
||||
write: func(builder *StatementBuilder) {
|
||||
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationNotEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
|
||||
},
|
||||
want: want{
|
||||
stmt: "SHA256(table.column) <> SHA256($1)",
|
||||
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
assert.Equal(t, tt.want.shouldPanic, r != nil)
|
||||
}()
|
||||
var builder StatementBuilder
|
||||
tt.write(&builder)
|
||||
|
||||
assert.Equal(t, tt.want.stmt, builder.String())
|
||||
assert.Equal(t, tt.want.args, builder.Args())
|
||||
})
|
||||
}
|
||||
}
|
||||
28
backend/v3/storage/database/order_test.go
Normal file
28
backend/v3/storage/database/order_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_orderBy_Write(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want string
|
||||
order Order
|
||||
}{
|
||||
{
|
||||
name: "order by column",
|
||||
want: " ORDER BY table.column",
|
||||
order: OrderBy(NewColumn("table", "column")),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var builder StatementBuilder
|
||||
tt.order.Write(&builder)
|
||||
assert.Equal(t, tt.want, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
147
backend/v3/storage/database/query_test.go
Normal file
147
backend/v3/storage/database/query_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestQueryOptions(t *testing.T) {
|
||||
type want struct {
|
||||
stmt string
|
||||
args []any
|
||||
}
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
options []QueryOption
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no options",
|
||||
want: want{
|
||||
stmt: "",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "limit option",
|
||||
options: []QueryOption{
|
||||
WithLimit(10),
|
||||
},
|
||||
want: want{
|
||||
stmt: " LIMIT $1",
|
||||
args: []any{uint32(10)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "offset option",
|
||||
options: []QueryOption{
|
||||
WithOffset(5),
|
||||
},
|
||||
want: want{
|
||||
stmt: " OFFSET $1",
|
||||
args: []any{uint32(5)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "order by asc option",
|
||||
options: []QueryOption{
|
||||
WithOrderByAscending(NewColumn("table", "column")),
|
||||
},
|
||||
want: want{
|
||||
stmt: " ORDER BY table.column",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "order by desc option",
|
||||
options: []QueryOption{
|
||||
WithOrderByDescending(NewColumn("table", "column")),
|
||||
},
|
||||
want: want{
|
||||
stmt: " ORDER BY table.column DESC",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "order by option",
|
||||
options: []QueryOption{
|
||||
WithOrderBy(OrderDirectionAsc, NewColumn("table", "column1"), NewColumn("table", "column2")),
|
||||
},
|
||||
want: want{
|
||||
stmt: " ORDER BY table.column1, table.column2",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "condition option",
|
||||
options: []QueryOption{
|
||||
WithCondition(NewBooleanCondition(NewColumn("table", "column"), true)),
|
||||
},
|
||||
want: want{
|
||||
stmt: " WHERE table.column = $1",
|
||||
args: []any{true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "group by option",
|
||||
options: []QueryOption{
|
||||
WithGroupBy(NewColumn("table", "column")),
|
||||
},
|
||||
want: want{
|
||||
stmt: " GROUP BY table.column",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "left join option",
|
||||
options: []QueryOption{
|
||||
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
|
||||
},
|
||||
want: want{
|
||||
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "permission check option",
|
||||
options: []QueryOption{
|
||||
WithPermissionCheck("permission"),
|
||||
},
|
||||
want: want{
|
||||
stmt: "",
|
||||
args: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple options",
|
||||
options: []QueryOption{
|
||||
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
|
||||
WithCondition(NewNumberCondition(NewColumn("table", "column"), NumberOperationEqual, 123)),
|
||||
WithOrderByDescending(NewColumn("table", "column")),
|
||||
WithLimit(10),
|
||||
WithOffset(5),
|
||||
},
|
||||
want: want{
|
||||
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id WHERE table.column = $1 ORDER BY table.column DESC LIMIT $2 OFFSET $3",
|
||||
args: []any{123, uint32(10), uint32(5)},
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
var opts QueryOpts
|
||||
for _, option := range test.options {
|
||||
option(&opts)
|
||||
}
|
||||
opts.Write(&b)
|
||||
assert.Equal(t, test.want.stmt, b.String())
|
||||
require.Len(t, b.Args(), len(test.want.args))
|
||||
for i := range test.want.args {
|
||||
assert.Equal(t, test.want.args[i], b.Args()[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,19 +3,35 @@ package repository
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type JSONArray[T any] []*T
|
||||
|
||||
func (a JSONArray[T]) Scan(src any) error {
|
||||
var ErrScanSource = errors.New("unsupported scan source")
|
||||
|
||||
func (a *JSONArray[T]) Scan(src any) (err error) {
|
||||
var rawJSON []byte
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
return json.Unmarshal([]byte(s), &a)
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
rawJSON = []byte(s)
|
||||
case []byte:
|
||||
return json.Unmarshal(s, &a)
|
||||
if len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
rawJSON = s
|
||||
case nil:
|
||||
return nil
|
||||
default:
|
||||
return errors.New("unsupported scan source")
|
||||
return ErrScanSource
|
||||
}
|
||||
err = json.Unmarshal(rawJSON, a)
|
||||
if err != nil {
|
||||
return database.NewScanError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
117
backend/v3/storage/database/repository/array_test.go
Normal file
117
backend/v3/storage/database/repository/array_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
)
|
||||
|
||||
type testObject struct {
|
||||
ID string `json:"id"`
|
||||
Number int `json:"number"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func TestJSONArray_Scan(t *testing.T) {
|
||||
type want struct {
|
||||
err error
|
||||
res []*testObject
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
src any
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
src: nil,
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[]byte with data",
|
||||
src: []byte(`[{"id":"1","number":1,"active":true}]`),
|
||||
want: want{
|
||||
err: nil,
|
||||
res: []*testObject{{ID: "1", Number: 1, Active: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "[]byte without data",
|
||||
src: []byte(`[]`),
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil []byte",
|
||||
src: []byte(nil),
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty []byte",
|
||||
src: []byte{},
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with data",
|
||||
src: `[{"id":"1","number":1,"active":true}]`,
|
||||
want: want{
|
||||
err: nil,
|
||||
res: []*testObject{{ID: "1", Number: 1, Active: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string without data",
|
||||
src: string(`[]`),
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
src: "",
|
||||
want: want{
|
||||
err: nil,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "wrong type",
|
||||
src: []int{1, 2, 3},
|
||||
want: want{
|
||||
err: repository.ErrScanSource,
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "badly formatted JSON",
|
||||
src: []byte(`this is not a JSON`),
|
||||
want: want{
|
||||
err: new(database.ScanError),
|
||||
res: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var a repository.JSONArray[testObject]
|
||||
gotErr := a.Scan(tt.src)
|
||||
require.ErrorIs(t, gotErr, tt.want.err)
|
||||
require.Len(t, a, len(tt.want.res))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -13,17 +13,20 @@ import (
|
||||
var _ domain.InstanceRepository = (*instance)(nil)
|
||||
|
||||
type instance struct {
|
||||
repository
|
||||
shouldLoadDomains bool
|
||||
domainRepo *instanceDomain
|
||||
domainRepo instanceDomain
|
||||
}
|
||||
|
||||
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
|
||||
return &instance{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
func InstanceRepository() domain.InstanceRepository {
|
||||
return new(instance)
|
||||
}
|
||||
|
||||
func (instance) qualifiedTableName() string {
|
||||
return "zitadel.instances"
|
||||
}
|
||||
|
||||
func (instance) unqualifiedTableName() string {
|
||||
return "instances"
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -37,7 +40,7 @@ const (
|
||||
)
|
||||
|
||||
// Get implements [domain.InstanceRepository].
|
||||
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
|
||||
func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) {
|
||||
opts = append(opts,
|
||||
i.joinDomains(),
|
||||
database.WithGroupBy(i.IDColumn()),
|
||||
@@ -52,11 +55,11 @@ func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*doma
|
||||
builder.WriteString(queryInstanceStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstance(ctx, i.client, &builder)
|
||||
return scanInstance(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.InstanceRepository].
|
||||
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
||||
func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
||||
opts = append(opts,
|
||||
i.joinDomains(),
|
||||
database.WithGroupBy(i.IDColumn()),
|
||||
@@ -71,27 +74,11 @@ func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*d
|
||||
builder.WriteString(queryInstanceStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstances(ctx, i.client, &builder)
|
||||
}
|
||||
|
||||
func (i *instance) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 2)
|
||||
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !i.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
"zitadel.instance_domains",
|
||||
database.And(columns...),
|
||||
)
|
||||
return scanInstances(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// Create implements [domain.InstanceRepository].
|
||||
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
|
||||
func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
@@ -103,42 +90,46 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
|
||||
updatedAt = instance.UpdatedAt
|
||||
}
|
||||
|
||||
builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
||||
builder.WriteString(`INSERT INTO `)
|
||||
builder.WriteString(i.qualifiedTableName())
|
||||
builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
||||
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
||||
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.InstanceRepository].
|
||||
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
|
||||
func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
|
||||
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`UPDATE zitadel.instances SET `)
|
||||
|
||||
database.Changes(changes).Write(&builder)
|
||||
|
||||
idCondition := i.IDCondition(id)
|
||||
writeCondition(&builder, idCondition)
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
return i.client.Exec(ctx, stmt, builder.Args()...)
|
||||
return client.Exec(ctx, stmt, builder.Args()...)
|
||||
}
|
||||
|
||||
// Delete implements [domain.InstanceRepository].
|
||||
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
|
||||
func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`DELETE FROM zitadel.instances`)
|
||||
builder.WriteString(`DELETE FROM `)
|
||||
builder.WriteString(i.qualifiedTableName())
|
||||
|
||||
idCondition := i.IDCondition(id)
|
||||
writeCondition(&builder, idCondition)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -185,53 +176,77 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
|
||||
return database.NewTextCondition(i.NameColumn(), op, name)
|
||||
}
|
||||
|
||||
// ExistsDomain creates a correlated [database.Exists] condition on instance_domains.
|
||||
// Use this filter to make sure the Instance returned contains a specific domain.
|
||||
// of the instance in the aggregated result.
|
||||
// Example usage:
|
||||
//
|
||||
// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated
|
||||
// instance, _ := instanceRepo.Get(ctx,
|
||||
// database.WithCondition(
|
||||
// database.And(
|
||||
// instanceRepo.InstanceIDCondition(instanceID),
|
||||
// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
|
||||
// ),
|
||||
// ),
|
||||
// )
|
||||
func (i instance) ExistsDomain(cond database.Condition) database.Condition {
|
||||
return database.Exists(
|
||||
i.domainRepo.qualifiedTableName(),
|
||||
database.And(
|
||||
database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()),
|
||||
cond,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDColumn implements [domain.instanceColumns].
|
||||
func (instance) IDColumn() database.Column {
|
||||
return database.NewColumn("instances", "id")
|
||||
func (i instance) IDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "id")
|
||||
}
|
||||
|
||||
// NameColumn implements [domain.instanceColumns].
|
||||
func (instance) NameColumn() database.Column {
|
||||
return database.NewColumn("instances", "name")
|
||||
func (i instance) NameColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "name")
|
||||
}
|
||||
|
||||
// CreatedAtColumn implements [domain.instanceColumns].
|
||||
func (instance) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "created_at")
|
||||
func (i instance) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "created_at")
|
||||
}
|
||||
|
||||
// DefaultOrgIdColumn implements [domain.instanceColumns].
|
||||
func (instance) DefaultOrgIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "default_org_id")
|
||||
func (i instance) DefaultOrgIDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "default_org_id")
|
||||
}
|
||||
|
||||
// IAMProjectIDColumn implements [domain.instanceColumns].
|
||||
func (instance) IAMProjectIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "iam_project_id")
|
||||
func (i instance) IAMProjectIDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "iam_project_id")
|
||||
}
|
||||
|
||||
// ConsoleClientIDColumn implements [domain.instanceColumns].
|
||||
func (instance) ConsoleClientIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "console_client_id")
|
||||
func (i instance) ConsoleClientIDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "console_client_id")
|
||||
}
|
||||
|
||||
// ConsoleAppIDColumn implements [domain.instanceColumns].
|
||||
func (instance) ConsoleAppIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "console_app_id")
|
||||
func (i instance) ConsoleAppIDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "console_app_id")
|
||||
}
|
||||
|
||||
// DefaultLanguageColumn implements [domain.instanceColumns].
|
||||
func (instance) DefaultLanguageColumn() database.Column {
|
||||
return database.NewColumn("instances", "default_language")
|
||||
func (i instance) DefaultLanguageColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "default_language")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.instanceColumns].
|
||||
func (instance) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "updated_at")
|
||||
func (i instance) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -282,19 +297,24 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
|
||||
// sub repositories
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Domains implements [domain.InstanceRepository].
|
||||
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
|
||||
if !i.shouldLoadDomains {
|
||||
i.shouldLoadDomains = shouldLoad
|
||||
func (i *instance) LoadDomains() domain.InstanceRepository {
|
||||
return &instance{
|
||||
shouldLoadDomains: true,
|
||||
}
|
||||
|
||||
if i.domainRepo != nil {
|
||||
return i.domainRepo
|
||||
}
|
||||
|
||||
i.domainRepo = &instanceDomain{
|
||||
repository: i.repository,
|
||||
instance: i,
|
||||
}
|
||||
return i.domainRepo
|
||||
}
|
||||
|
||||
func (i *instance) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 2)
|
||||
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()))
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !i.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
i.domainRepo.qualifiedTableName(),
|
||||
database.And(columns...),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,9 +10,18 @@ import (
|
||||
|
||||
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
|
||||
|
||||
type instanceDomain struct {
|
||||
repository
|
||||
*instance
|
||||
type instanceDomain struct{}
|
||||
|
||||
func InstanceDomainRepository() domain.InstanceDomainRepository {
|
||||
return new(instanceDomain)
|
||||
}
|
||||
|
||||
func (instanceDomain) qualifiedTableName() string {
|
||||
return "zitadel.instance_domains"
|
||||
}
|
||||
|
||||
func (instanceDomain) unqualifiedTableName() string {
|
||||
return "instance_domains"
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -23,8 +32,7 @@ const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_d
|
||||
`FROM zitadel.instance_domains`
|
||||
|
||||
// Get implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
|
||||
func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
|
||||
func (i instanceDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
@@ -34,12 +42,11 @@ func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption)
|
||||
builder.WriteString(queryInstanceDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstanceDomain(ctx, i.client, &builder)
|
||||
return scanInstanceDomain(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
|
||||
func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
|
||||
func (i instanceDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
@@ -49,11 +56,11 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
|
||||
builder.WriteString(queryInstanceDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstanceDomains(ctx, i.client, &builder)
|
||||
return scanInstanceDomains(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// Add implements [domain.InstanceDomainRepository].
|
||||
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
|
||||
func (i instanceDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddInstanceDomain) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
@@ -69,33 +76,41 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma
|
||||
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
|
||||
func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
func (i instanceDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
|
||||
}
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
var builder database.StatementBuilder
|
||||
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
|
||||
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// Remove implements [domain.InstanceDomainRepository].
|
||||
func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
func (i instanceDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
|
||||
condition.Write(&builder)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -146,40 +161,38 @@ func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// CreatedAtColumn implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
|
||||
func (instanceDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "created_at")
|
||||
func (i instanceDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "created_at")
|
||||
}
|
||||
|
||||
// DomainColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "domain")
|
||||
func (i instanceDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "domain")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "instance_id")
|
||||
func (i instanceDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "instance_id")
|
||||
}
|
||||
|
||||
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "is_primary")
|
||||
func (i instanceDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "is_primary")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
|
||||
func (instanceDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "updated_at")
|
||||
func (i instanceDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
|
||||
}
|
||||
|
||||
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) IsGeneratedColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "is_generated")
|
||||
func (i instanceDomain) IsGeneratedColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "is_generated")
|
||||
}
|
||||
|
||||
// TypeColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) TypeColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "type")
|
||||
func (i instanceDomain) TypeColumn() database.Column {
|
||||
return database.NewColumn(i.unqualifiedTableName(), "type")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,31 +15,43 @@ import (
|
||||
)
|
||||
|
||||
func TestAddInstanceDomain(t *testing.T) {
|
||||
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
|
||||
beforeAdd := time.Now()
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
ID: gofakeit.NewCrypto().UUID(),
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain
|
||||
instanceDomain domain.AddInstanceDomain
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "happy path custom domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
@@ -50,7 +61,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
{
|
||||
name: "happy path trusted domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
@@ -58,7 +69,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
{
|
||||
name: "add primary domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
@@ -68,7 +79,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
{
|
||||
name: "add custom domain without domain name",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: "",
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
@@ -79,7 +90,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
{
|
||||
name: "add trusted domain without domain name",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: "",
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
@@ -87,23 +98,23 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "add custom domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, instanceDomain)
|
||||
err := domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
return &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
@@ -114,22 +125,20 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "add trusted domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, instanceDomain)
|
||||
err := domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
return &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
InstanceID: instance.ID,
|
||||
Domain: instanceDomain.Domain,
|
||||
Type: domain.DomainTypeTrusted,
|
||||
}
|
||||
},
|
||||
@@ -196,26 +205,23 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
|
||||
beforeAdd := time.Now()
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
err := savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
|
||||
var instanceDomain *domain.AddInstanceDomain
|
||||
if test.testFunc != nil {
|
||||
instanceDomain = test.testFunc(ctx, t, domainRepo)
|
||||
instanceDomain = test.testFunc(t, savepoint)
|
||||
} else {
|
||||
instanceDomain = &test.instanceDomain
|
||||
}
|
||||
|
||||
err = domainRepo.Add(ctx, instanceDomain)
|
||||
err = domainRepo.Add(t.Context(), savepoint, instanceDomain)
|
||||
afterAdd := time.Now()
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
@@ -232,10 +238,21 @@ func TestAddInstanceDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetInstanceDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.NewCrypto().UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -243,38 +260,29 @@ func TestGetInstanceDomain(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
domain2 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName2,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -289,19 +297,19 @@ func TestGetInstanceDomain(t *testing.T) {
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expected: &domain.InstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
InstanceID: instance.ID,
|
||||
Domain: domain1.Domain,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get by domain name",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain)),
|
||||
},
|
||||
expected: &domain.InstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName2,
|
||||
InstanceID: instance.ID,
|
||||
Domain: domain2.Domain,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
},
|
||||
},
|
||||
@@ -318,7 +326,7 @@ func TestGetInstanceDomain(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
result, err := domainRepo.Get(ctx, test.opts...)
|
||||
result, err := domainRepo.Get(ctx, tx, test.opts...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
@@ -335,10 +343,21 @@ func TestGetInstanceDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListInstanceDomains(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.NewCrypto().UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -346,42 +365,35 @@ func TestListInstanceDomains(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add multiple domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domains := []domain.AddInstanceDomain{
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range domains {
|
||||
err = domainRepo.Add(t.Context(), &domains[i])
|
||||
err = domainRepo.Add(t.Context(), tx, &domains[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -405,7 +417,7 @@ func TestListInstanceDomains(t *testing.T) {
|
||||
{
|
||||
name: "list by instance",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
@@ -422,12 +434,12 @@ func TestListInstanceDomains(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
results, err := domainRepo.List(ctx, test.opts...)
|
||||
results, err := domainRepo.List(ctx, tx, test.opts...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, test.expectedCount)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Equal(t, instance.ID, result.InstanceID)
|
||||
assert.NotEmpty(t, result.Domain)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
@@ -437,10 +449,21 @@ func TestListInstanceDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateInstanceDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -448,29 +471,19 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domain
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName := gofakeit.DomainName()
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), instanceDomain)
|
||||
err = domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -481,31 +494,38 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "set primary",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
name: "set primary",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 0,
|
||||
name: "update non-existent domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "no changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
name: "no changes",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
||||
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
@@ -516,7 +536,7 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
||||
|
||||
// verify changes were applied if rows were affected
|
||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
||||
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
|
||||
require.NoError(t, err)
|
||||
|
||||
// We know changes were applied since rowsAffected > 0
|
||||
@@ -529,10 +549,21 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveInstanceDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -540,37 +571,28 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
domain2 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -579,36 +601,43 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "remove by domain name",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
||||
expected: 1,
|
||||
name: "remove by domain name",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
|
||||
),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove by primary condition",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
name: "remove by primary condition",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.IsPrimaryCondition(false),
|
||||
),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
},
|
||||
{
|
||||
name: "remove non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
expected: 0,
|
||||
name: "remove non-existent domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
),
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
// count before removal
|
||||
beforeCount, err := domainRepo.List(ctx)
|
||||
beforeCount, err := domainRepo.List(t.Context(), tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
||||
rowsAffected, err := domainRepo.Remove(t.Context(), tx, test.condition)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify removal
|
||||
afterCount, err := domainRepo.List(ctx)
|
||||
afterCount, err := domainRepo.List(t.Context(), tx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||
})
|
||||
@@ -616,8 +645,7 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInstanceDomainConditions(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -671,8 +699,7 @@ func TestInstanceDomainConditions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInstanceDomainChanges(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,9 +17,20 @@ import (
|
||||
)
|
||||
|
||||
func TestCreateInstance(t *testing.T) {
|
||||
beforeCreate := time.Now()
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
@@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "adding same instance twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) {
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the name to make sure same only the id clashes
|
||||
inst.Name = gofakeit.Name()
|
||||
require.NoError(t, err)
|
||||
@@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
},
|
||||
func() struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
} {
|
||||
@@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
return struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
name: "adding instance with same name twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
@@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the id
|
||||
@@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) {
|
||||
{
|
||||
name: "adding instance with no id",
|
||||
instance: func() domain.Instance {
|
||||
// instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
// ID: instanceId,
|
||||
Name: instanceName,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
instance = tt.testFunc(t, savepoint)
|
||||
} else {
|
||||
instance = &tt.instance
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
// create instance
|
||||
beforeCreate := time.Now()
|
||||
err := instanceRepo.Create(ctx, instance)
|
||||
|
||||
err = instanceRepo.Create(t.Context(), tx, instance)
|
||||
assert.ErrorIs(t, err, tt.err)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
afterCreate := time.Now()
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateInstance(t *testing.T) {
|
||||
beforeUpdate := time.Now()
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
rowsAffected int64
|
||||
getErr error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
@@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||
inst.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceId := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
ID: gofakeit.UUID(),
|
||||
}
|
||||
return &inst
|
||||
},
|
||||
@@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instance := tt.testFunc(t, tx)
|
||||
|
||||
instance := tt.testFunc(ctx, t)
|
||||
|
||||
beforeUpdate := time.Now()
|
||||
// update name
|
||||
newName := "new_" + instance.Name
|
||||
rowsAffected, err := instanceRepo.Update(ctx,
|
||||
rowsAffected, err := instanceRepo.Update(t.Context(), tx,
|
||||
instance.ID,
|
||||
instanceRepo.SetName(newName),
|
||||
)
|
||||
@@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetInstance(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T) *domain.Instance
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceId := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
@@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) {
|
||||
}(),
|
||||
{
|
||||
name: "happy path including domains",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.NewCrypto().UUID(),
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
d := &domain.AddInstanceDomain{
|
||||
InstanceID: inst.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
@@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) {
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
err = domainRepo.Add(ctx, d)
|
||||
err = domainRepo.Add(t.Context(), tx, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
|
||||
@@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "get non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: "get non existent instance",
|
||||
}
|
||||
@@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
instance = tt.testFunc(t)
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstance, err := instanceRepo.Get(ctx,
|
||||
returnedInstance, err := instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListInstance(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool, stop, err := newEmbeddedDB(ctx)
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer stop()
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance
|
||||
conditionClauses []database.Condition
|
||||
noInstanceReturned bool
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "happy path single instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "happy path multiple instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) {
|
||||
},
|
||||
},
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceID := gofakeit.BeerName()
|
||||
return test{
|
||||
name: "instance filter on id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) {
|
||||
|
||||
return instances
|
||||
},
|
||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
|
||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)},
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
instanceName := gofakeit.BeerName()
|
||||
return test{
|
||||
name: "multiple instance filter on name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
instances := tt.testFunc(ctx, t)
|
||||
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
instances := tt.testFunc(t, savepoint)
|
||||
|
||||
var condition database.Condition
|
||||
if len(tt.conditionClauses) > 0 {
|
||||
@@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstances, err := instanceRepo.List(ctx,
|
||||
returnedInstances, err := instanceRepo.List(t.Context(), tx,
|
||||
database.WithCondition(condition),
|
||||
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
|
||||
database.WithOrderByAscending(instanceRepo.IDColumn()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if tt.noInstanceReturned {
|
||||
assert.Nil(t, returnedInstances)
|
||||
assert.Len(t, returnedInstances, 0)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -609,42 +620,45 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteInstance(t *testing.T) {
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T)
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor)
|
||||
instanceID string
|
||||
noOfDeletedRows int64
|
||||
}
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
var noOfInstances int64 = 1
|
||||
instanceID := gofakeit.NewCrypto().UUID()
|
||||
return test{
|
||||
name: "happy path delete single instance filter id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||
inst := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
instanceID: instanceId,
|
||||
noOfDeletedRows: noOfInstances,
|
||||
instanceID: instanceID,
|
||||
noOfDeletedRows: 1,
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
@@ -655,40 +669,33 @@ func TestDeleteInstance(t *testing.T) {
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
instanceID := gofakeit.Name()
|
||||
return test{
|
||||
name: "deleted already deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
inst := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
instances[0].ID,
|
||||
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||
inst.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), affectedRows)
|
||||
},
|
||||
instanceID: instanceName,
|
||||
instanceID: instanceID,
|
||||
// this test should return 0 affected rows as the instance was already deleted
|
||||
noOfDeletedRows: 0,
|
||||
}
|
||||
@@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if tt.testFunc != nil {
|
||||
tt.testFunc(ctx, t)
|
||||
tt.testFunc(t, savepoint)
|
||||
}
|
||||
|
||||
// delete instance
|
||||
noOfDeletedRows, err := instanceRepo.Delete(ctx,
|
||||
tt.instanceID,
|
||||
)
|
||||
noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||
|
||||
// check instance was deleted
|
||||
instance, err := instanceRepo.Get(ctx,
|
||||
instance, err := instanceRepo.Get(t.Context(), savepoint,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(tt.instanceID),
|
||||
),
|
||||
|
||||
@@ -14,25 +14,24 @@ import (
|
||||
var _ domain.OrganizationRepository = (*org)(nil)
|
||||
|
||||
type org struct {
|
||||
repository
|
||||
shouldLoadDomains bool
|
||||
domainRepo domain.OrganizationDomainRepository
|
||||
domainRepo orgDomain
|
||||
}
|
||||
|
||||
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
|
||||
return &org{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
func (o org) unqualifiedTableName() string {
|
||||
return "organizations"
|
||||
}
|
||||
|
||||
func OrganizationRepository() domain.OrganizationRepository {
|
||||
return new(org)
|
||||
}
|
||||
|
||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||
` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
||||
` , jsonb_agg(json_build_object('instanceId', org_domains.instance_id, 'orgId', org_domains.org_id, 'domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
||||
` FROM zitadel.organizations`
|
||||
|
||||
// Get implements [domain.OrganizationRepository].
|
||||
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
|
||||
func (o org) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||
@@ -43,15 +42,19 @@ func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Or
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganization(ctx, o.client, &builder)
|
||||
return scanOrganization(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.OrganizationRepository].
|
||||
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
||||
func (o org) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||
@@ -62,30 +65,15 @@ func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizations(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
func (o *org) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 3)
|
||||
columns = append(columns,
|
||||
database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
|
||||
database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
|
||||
)
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !o.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
"zitadel.org_domains",
|
||||
database.And(columns...),
|
||||
)
|
||||
return scanOrganizations(ctx, client, &builder)
|
||||
}
|
||||
|
||||
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
|
||||
@@ -93,46 +81,48 @@ const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, ins
|
||||
` RETURNING created_at, updated_at`
|
||||
|
||||
// Create implements [domain.OrganizationRepository].
|
||||
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
|
||||
func (o org) Create(ctx context.Context, client database.QueryExecutor, organization *domain.Organization) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
|
||||
builder.WriteString(createOrganizationStmt)
|
||||
|
||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
|
||||
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.OrganizationRepository].
|
||||
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
|
||||
func (o org) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
builder := database.StatementBuilder{}
|
||||
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
|
||||
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(`UPDATE zitadel.organizations SET `)
|
||||
|
||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
||||
|
||||
conditions := []database.Condition{id, instanceIDCondition}
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
|
||||
rowsAffected, err := client.Exec(ctx, stmt, builder.Args()...)
|
||||
return rowsAffected, err
|
||||
}
|
||||
|
||||
// Delete implements [domain.OrganizationRepository].
|
||||
func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
|
||||
builder := database.StatementBuilder{}
|
||||
func (o org) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(`DELETE FROM zitadel.organizations`)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
||||
|
||||
conditions := []database.Condition{id, instanceIDCondition}
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -154,13 +144,13 @@ func (o org) SetState(state domain.OrgState) database.Change {
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDCondition implements [domain.organizationConditions].
|
||||
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
|
||||
func (o org) IDCondition(id string) database.Condition {
|
||||
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
|
||||
}
|
||||
|
||||
// NameCondition implements [domain.organizationConditions].
|
||||
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
|
||||
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
|
||||
func (o org) NameCondition(op database.TextOperation, name string) database.Condition {
|
||||
return database.NewTextCondition(o.NameColumn(), op, name)
|
||||
}
|
||||
|
||||
// InstanceIDCondition implements [domain.organizationConditions].
|
||||
@@ -173,38 +163,62 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
|
||||
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
|
||||
}
|
||||
|
||||
// ExistsDomain creates a correlated [database.Exists] condition on org_domains.
|
||||
// Use this filter to make sure the organization returned contains a specific domain.
|
||||
// Example usage:
|
||||
//
|
||||
// domainRepo := orgRepo.Domains(true) // ensure domains are loaded/aggregated
|
||||
// org, _ := orgRepo.Get(ctx,
|
||||
// database.WithCondition(
|
||||
// database.And(
|
||||
// orgRepo.InstanceIDCondition(instanceID),
|
||||
// orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
|
||||
// ),
|
||||
// ),
|
||||
// )
|
||||
func (o org) ExistsDomain(cond database.Condition) database.Condition {
|
||||
return database.Exists(
|
||||
o.domainRepo.qualifiedTableName(),
|
||||
database.And(
|
||||
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
|
||||
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
|
||||
cond,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDColumn implements [domain.organizationColumns].
|
||||
func (org) IDColumn() database.Column {
|
||||
return database.NewColumn("organizations", "id")
|
||||
func (o org) IDColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "id")
|
||||
}
|
||||
|
||||
// NameColumn implements [domain.organizationColumns].
|
||||
func (org) NameColumn() database.Column {
|
||||
return database.NewColumn("organizations", "name")
|
||||
func (o org) NameColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "name")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.organizationColumns].
|
||||
func (org) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("organizations", "instance_id")
|
||||
func (o org) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
|
||||
}
|
||||
|
||||
// StateColumn implements [domain.organizationColumns].
|
||||
func (org) StateColumn() database.Column {
|
||||
return database.NewColumn("organizations", "state")
|
||||
func (o org) StateColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "state")
|
||||
}
|
||||
|
||||
// CreatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("organizations", "created_at")
|
||||
func (o org) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "created_at")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("organizations", "updated_at")
|
||||
func (o org) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -255,20 +269,27 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
||||
// sub repositories
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Domains implements [domain.OrganizationRepository].
|
||||
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
|
||||
if !o.shouldLoadDomains {
|
||||
o.shouldLoadDomains = shouldLoad
|
||||
func (o org) LoadDomains() domain.OrganizationRepository {
|
||||
return &org{
|
||||
shouldLoadDomains: true,
|
||||
}
|
||||
|
||||
if o.domainRepo != nil {
|
||||
return o.domainRepo
|
||||
}
|
||||
|
||||
o.domainRepo = &orgDomain{
|
||||
repository: o.repository,
|
||||
org: o,
|
||||
}
|
||||
|
||||
return o.domainRepo
|
||||
}
|
||||
|
||||
func (o org) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 3)
|
||||
columns = append(columns,
|
||||
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
|
||||
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
|
||||
)
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !o.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
o.domainRepo.qualifiedTableName(),
|
||||
database.And(columns...),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,9 +10,18 @@ import (
|
||||
|
||||
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
|
||||
|
||||
type orgDomain struct {
|
||||
repository
|
||||
*org
|
||||
type orgDomain struct{}
|
||||
|
||||
func OrganizationDomainRepository() domain.OrganizationDomainRepository {
|
||||
return new(orgDomain)
|
||||
}
|
||||
|
||||
func (orgDomain) qualifiedTableName() string {
|
||||
return "zitadel.org_domains"
|
||||
}
|
||||
|
||||
func (orgDomain) unqualifiedTableName() string {
|
||||
return "org_domains"
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -24,36 +33,44 @@ const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_veri
|
||||
|
||||
// Get implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
|
||||
func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
|
||||
func (o orgDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizationDomain(ctx, o.client, &builder)
|
||||
return scanOrganizationDomain(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
|
||||
func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
|
||||
func (o orgDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizationDomains(ctx, o.client, &builder)
|
||||
return scanOrganizationDomains(ctx, client, &builder)
|
||||
}
|
||||
|
||||
// Add implements [domain.OrganizationDomainRepository].
|
||||
func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
|
||||
func (o orgDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddOrganizationDomain) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
@@ -69,33 +86,47 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
|
||||
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
|
||||
func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
func (o orgDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.OrgIDColumn())
|
||||
}
|
||||
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
|
||||
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`UPDATE zitadel.org_domains SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// Remove implements [domain.OrganizationDomainRepository].
|
||||
func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
func (o orgDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||
}
|
||||
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
|
||||
return 0, database.NewMissingConditionError(o.OrgIDColumn())
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(`DELETE FROM zitadel.org_domains `)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
@@ -158,45 +189,43 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
|
||||
|
||||
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
|
||||
func (orgDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "created_at")
|
||||
func (o orgDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "created_at")
|
||||
}
|
||||
|
||||
// DomainColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "domain")
|
||||
func (o orgDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "domain")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
|
||||
func (orgDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "instance_id")
|
||||
func (o orgDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
|
||||
}
|
||||
|
||||
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "is_primary")
|
||||
func (o orgDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "is_primary")
|
||||
}
|
||||
|
||||
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) IsVerifiedColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "is_verified")
|
||||
func (o orgDomain) IsVerifiedColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "is_verified")
|
||||
}
|
||||
|
||||
// OrgIDColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) OrgIDColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "org_id")
|
||||
func (o orgDomain) OrgIDColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "org_id")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
|
||||
func (orgDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "updated_at")
|
||||
func (o orgDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
|
||||
}
|
||||
|
||||
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) ValidationTypeColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "validation_type")
|
||||
func (o orgDomain) ValidationTypeColumn() database.Column {
|
||||
return database.NewColumn(o.unqualifiedTableName(), "validation_type")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
@@ -15,6 +14,15 @@ import (
|
||||
)
|
||||
|
||||
func TestAddOrganizationDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
@@ -26,8 +34,11 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create organization
|
||||
@@ -41,7 +52,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain
|
||||
organizationDomain domain.AddOrganizationDomain
|
||||
err error
|
||||
}{
|
||||
@@ -92,7 +103,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "add domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
organizationDomain := &domain.AddOrganizationDomain{
|
||||
@@ -104,7 +115,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, organizationDomain)
|
||||
err := domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
@@ -169,28 +180,26 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
err = orgRepo.Create(t.Context(), savepoint, &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
var organizationDomain *domain.AddOrganizationDomain
|
||||
if test.testFunc != nil {
|
||||
organizationDomain = test.testFunc(ctx, t, domainRepo)
|
||||
organizationDomain = test.testFunc(t, savepoint)
|
||||
} else {
|
||||
organizationDomain = &test.organizationDomain
|
||||
}
|
||||
|
||||
err = domainRepo.Add(ctx, organizationDomain)
|
||||
err = domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
@@ -204,6 +213,16 @@ func TestAddOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetOrganizationDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
@@ -225,29 +244,18 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
@@ -255,15 +263,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
domain2 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -275,12 +283,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
{
|
||||
name: "get primary domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instanceID),
|
||||
domainRepo.IsPrimaryCondition(true),
|
||||
)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
Domain: domain1.Domain,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
@@ -289,12 +300,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
{
|
||||
name: "get by domain name",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
|
||||
)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
Domain: domain2.Domain,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
@@ -303,13 +317,16 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
{
|
||||
name: "get by org ID",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instanceID),
|
||||
domainRepo.OrgIDCondition(orgID),
|
||||
domainRepo.IsPrimaryCondition(true),
|
||||
)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
Domain: domain1.Domain,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
@@ -318,12 +335,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
{
|
||||
name: "get verified domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instanceID),
|
||||
domainRepo.IsVerifiedCondition(true),
|
||||
)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
Domain: domain1.Domain,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
@@ -332,7 +352,10 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
{
|
||||
name: "get non-existent domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
)),
|
||||
},
|
||||
err: new(database.NoRowFoundError),
|
||||
},
|
||||
@@ -340,9 +363,7 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
result, err := domainRepo.Get(ctx, test.opts...)
|
||||
result, err := domainRepo.Get(t.Context(), tx, test.opts...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
@@ -362,10 +383,22 @@ func TestGetOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListOrganizationDomains(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -375,50 +408,40 @@ func TestListOrganizationDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add multiple domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domains := []domain.AddOrganizationDomain{
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: false,
|
||||
@@ -427,7 +450,7 @@ func TestListOrganizationDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
for i := range domains {
|
||||
err = domainRepo.Add(t.Context(), &domains[i])
|
||||
err = domainRepo.Add(t.Context(), tx, &domains[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -437,42 +460,59 @@ func TestListOrganizationDomains(t *testing.T) {
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "list all domains",
|
||||
opts: []database.QueryOption{},
|
||||
name: "list all domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list verified domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.IsVerifiedCondition(true),
|
||||
)),
|
||||
},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list primary domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.IsPrimaryCondition(true),
|
||||
)),
|
||||
},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "list by organization",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list by instance",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list non-existent organization",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
|
||||
database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition("non-existent"),
|
||||
)),
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
@@ -482,13 +522,13 @@ func TestListOrganizationDomains(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
results, err := domainRepo.List(ctx, test.opts...)
|
||||
results, err := domainRepo.List(ctx, tx, test.opts...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, test.expectedCount)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Equal(t, orgID, result.OrgID)
|
||||
assert.Equal(t, instance.ID, result.InstanceID)
|
||||
assert.Equal(t, organization.ID, result.OrgID)
|
||||
assert.NotEmpty(t, result.Domain)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
@@ -498,10 +538,22 @@ func TestListOrganizationDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -511,41 +563,30 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domain
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName := gofakeit.DomainName()
|
||||
organizationDomain := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), organizationDomain)
|
||||
err = domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -556,26 +597,42 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "set verified",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
name: "set verified",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "set primary",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
name: "set primary",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "set validation type",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
|
||||
expected: 1,
|
||||
name: "set validation type",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
name: "multiple changes",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{
|
||||
domainRepo.SetVerified(),
|
||||
domainRepo.SetPrimary(),
|
||||
@@ -584,31 +641,41 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update by org ID and domain",
|
||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
name: "update by org ID and domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 0,
|
||||
name: "update non-existent domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "no changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
name: "no changes",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||
),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
||||
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
@@ -619,7 +686,7 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
|
||||
// verify changes were applied if rows were affected
|
||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
||||
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
|
||||
require.NoError(t, err)
|
||||
|
||||
// We know changes were applied since rowsAffected > 0
|
||||
@@ -632,10 +699,22 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRemoveOrganizationDomain(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -645,53 +724,40 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
InstanceID: instance.ID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
domain2 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
InstanceID: instance.ID,
|
||||
OrgID: organization.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
@@ -700,50 +766,70 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "remove by domain name",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
||||
expected: 1,
|
||||
name: "remove by domain name",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
|
||||
),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove by primary condition",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
name: "remove by primary condition",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.IsPrimaryCondition(false),
|
||||
),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
},
|
||||
{
|
||||
name: "remove by org ID and domain",
|
||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
expected: 1,
|
||||
name: "remove by org ID and domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
|
||||
),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
expected: 0,
|
||||
name: "remove non-existent domain",
|
||||
condition: database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
),
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
snapshot, err := tx.Begin(ctx)
|
||||
snapshot, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, snapshot.Rollback(ctx))
|
||||
err = snapshot.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
orgRepo := repository.OrganizationRepository(snapshot)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
// count before removal
|
||||
beforeCount, err := domainRepo.List(ctx)
|
||||
beforeCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
)))
|
||||
require.NoError(t, err)
|
||||
|
||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
||||
rowsAffected, err := domainRepo.Remove(t.Context(), snapshot, test.condition)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify removal
|
||||
afterCount, err := domainRepo.List(ctx)
|
||||
afterCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
|
||||
domainRepo.InstanceIDCondition(instance.ID),
|
||||
domainRepo.OrgIDCondition(organization.ID),
|
||||
)))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||
})
|
||||
@@ -751,8 +837,7 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOrganizationDomainConditions(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -811,8 +896,7 @@ func TestOrganizationDomainConditions(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOrganizationDomainChanges(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -851,8 +935,7 @@ func TestOrganizationDomainChanges(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOrganizationDomainColumns(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +15,18 @@ import (
|
||||
)
|
||||
|
||||
func TestCreateOrganization(t *testing.T) {
|
||||
beforeCreate := time.Now()
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
organizationRepo := repository.OrganizationRepository()
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -26,13 +38,12 @@ func TestCreateOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
||||
testFunc func(t *testing.T, client database.QueryExecutor) *domain.Organization
|
||||
organization domain.Organization
|
||||
err error
|
||||
}{
|
||||
@@ -67,8 +78,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "adding org with same id twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
@@ -79,7 +89,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
// change the name to make sure same only the id clashes
|
||||
org.Name = gofakeit.Name()
|
||||
@@ -89,8 +99,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "adding org with same name twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
@@ -101,7 +110,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
// change the id to make sure same name+instance causes an error
|
||||
org.ID = gofakeit.Name()
|
||||
@@ -111,7 +120,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
},
|
||||
func() struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
|
||||
organization domain.Organization
|
||||
err error
|
||||
} {
|
||||
@@ -120,12 +129,12 @@ func TestCreateOrganization(t *testing.T) {
|
||||
|
||||
return struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
|
||||
organization domain.Organization
|
||||
err error
|
||||
}{
|
||||
name: "adding org with same name, different instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||
// create instance
|
||||
instId := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -137,12 +146,9 @@ func TestCreateOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(ctx, &instance)
|
||||
err := instanceRepo.Create(t.Context(), tx, &instance)
|
||||
assert.Nil(t, err)
|
||||
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
org := domain.Organization{
|
||||
ID: gofakeit.Name(),
|
||||
Name: organizationName,
|
||||
@@ -150,7 +156,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
err = organizationRepo.Create(ctx, &org)
|
||||
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the id to make it unique
|
||||
@@ -214,19 +220,25 @@ func TestCreateOrganization(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var organization *domain.Organization
|
||||
if tt.testFunc != nil {
|
||||
organization = tt.testFunc(ctx, t)
|
||||
organization = tt.testFunc(t, savepoint)
|
||||
} else {
|
||||
organization = &tt.organization
|
||||
}
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
// create organization
|
||||
beforeCreate := time.Now()
|
||||
err = organizationRepo.Create(ctx, organization)
|
||||
|
||||
err = organizationRepo.Create(t.Context(), savepoint, organization)
|
||||
assert.ErrorIs(t, err, tt.err)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -234,7 +246,7 @@ func TestCreateOrganization(t *testing.T) {
|
||||
afterCreate := time.Now()
|
||||
|
||||
// check organization values
|
||||
organization, err = organizationRepo.Get(ctx,
|
||||
organization, err = organizationRepo.Get(t.Context(), savepoint,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
organizationRepo.IDCondition(organization.ID),
|
||||
@@ -255,6 +267,19 @@ func TestCreateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateOrganization(t *testing.T) {
|
||||
beforeUpdate := time.Now()
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
organizationRepo := repository.OrganizationRepository()
|
||||
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -266,20 +291,18 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
||||
testFunc func(t *testing.T) *domain.Organization
|
||||
update []database.Change
|
||||
rowsAffected int64
|
||||
}{
|
||||
{
|
||||
name: "happy path update name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
@@ -291,7 +314,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
// update with updated value
|
||||
@@ -303,7 +326,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update deleted organization",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
@@ -315,13 +338,15 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
_, err = organizationRepo.Delete(ctx,
|
||||
organizationRepo.IDCondition(org.ID),
|
||||
org.InstanceID,
|
||||
_, err = organizationRepo.Delete(t.Context(), tx,
|
||||
database.And(
|
||||
organizationRepo.InstanceIDCondition(org.InstanceID),
|
||||
organizationRepo.IDCondition(org.ID),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -332,7 +357,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "happy path change state",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
@@ -344,7 +369,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
// update with updated value
|
||||
@@ -356,7 +381,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update non existent organization",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
|
||||
org := domain.Organization{
|
||||
@@ -370,16 +395,14 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
createdOrg := tt.testFunc(ctx, t)
|
||||
createdOrg := tt.testFunc(t)
|
||||
|
||||
// update org
|
||||
beforeUpdate := time.Now()
|
||||
rowsAffected, err := organizationRepo.Update(ctx,
|
||||
organizationRepo.IDCondition(createdOrg.ID),
|
||||
createdOrg.InstanceID,
|
||||
rowsAffected, err := organizationRepo.Update(t.Context(), tx,
|
||||
database.And(
|
||||
organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
|
||||
organizationRepo.IDCondition(createdOrg.ID),
|
||||
),
|
||||
tt.update...,
|
||||
)
|
||||
afterUpdate := time.Now()
|
||||
@@ -392,7 +415,7 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// check organization values
|
||||
organization, err := organizationRepo.Get(ctx,
|
||||
organization, err := organizationRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
organizationRepo.IDCondition(createdOrg.ID),
|
||||
@@ -411,6 +434,19 @@ func TestUpdateOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetOrganization(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
orgDomainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -422,11 +458,9 @@ func TestGetOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create organization
|
||||
// this org is created as an additional org which should NOT
|
||||
@@ -437,13 +471,13 @@ func TestGetOrganization(t *testing.T) {
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
err = orgRepo.Create(t.Context(), &org)
|
||||
err = orgRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
||||
orgIdentifierCondition domain.OrgIdentifierCondition
|
||||
testFunc func(t *testing.T) *domain.Organization
|
||||
orgIdentifierCondition database.Condition
|
||||
err error
|
||||
}
|
||||
|
||||
@@ -452,7 +486,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
organizationId := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
org := domain.Organization{
|
||||
@@ -463,7 +497,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := orgRepo.Create(ctx, &org)
|
||||
err := orgRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &org
|
||||
@@ -475,7 +509,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
organizationId := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using id including domain",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationName := gofakeit.Name()
|
||||
|
||||
org := domain.Organization{
|
||||
@@ -486,7 +520,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := orgRepo.Create(ctx, &org)
|
||||
err := orgRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &domain.AddOrganizationDomain{
|
||||
@@ -496,7 +530,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
}
|
||||
err = orgRepo.Domains(false).Add(ctx, d)
|
||||
err = orgDomainRepo.Add(t.Context(), tx, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
org.Domains = []*domain.OrganizationDomain{
|
||||
@@ -521,7 +555,7 @@ func TestGetOrganization(t *testing.T) {
|
||||
organizationName := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
organizationId := gofakeit.Name()
|
||||
|
||||
org := domain.Organization{
|
||||
@@ -532,39 +566,36 @@ func TestGetOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := orgRepo.Create(ctx, &org)
|
||||
err := orgRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &org
|
||||
},
|
||||
orgIdentifierCondition: orgRepo.NameCondition(organizationName),
|
||||
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||
}
|
||||
}(),
|
||||
{
|
||||
name: "get non existent organization",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
||||
testFunc: func(t *testing.T) *domain.Organization {
|
||||
org := domain.Organization{
|
||||
ID: "non existent org",
|
||||
Name: "non existent org",
|
||||
}
|
||||
return &org
|
||||
},
|
||||
orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
|
||||
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name"),
|
||||
err: new(database.NoRowFoundError),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
var org *domain.Organization
|
||||
if tt.testFunc != nil {
|
||||
org = tt.testFunc(ctx, t)
|
||||
org = tt.testFunc(t)
|
||||
}
|
||||
|
||||
// get org values
|
||||
returnedOrg, err := orgRepo.Get(ctx,
|
||||
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
tt.orgIdentifierCondition,
|
||||
@@ -592,11 +623,17 @@ func TestGetOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListOrganization(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
pool, stop, err := newEmbeddedDB(ctx)
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer stop()
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
defer func() {
|
||||
err := tx.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
organizationRepo := repository.OrganizationRepository()
|
||||
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
@@ -609,33 +646,32 @@ func TestListOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err = instanceRepo.Create(ctx, &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Organization
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Organization
|
||||
conditionClauses []database.Condition
|
||||
noOrganizationReturned bool
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "happy path single organization no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||
noOfOrganizations := 1
|
||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
org := domain.Organization{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -646,20 +682,20 @@ func TestListOrganization(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "happy path multiple organization no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||
noOfOrganizations := 5
|
||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
org := domain.Organization{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -672,7 +708,7 @@ func TestListOrganization(t *testing.T) {
|
||||
organizationId := gofakeit.Name()
|
||||
return test{
|
||||
name: "organization filter on id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||
// create organization
|
||||
// this org is created as an additional org which should NOT
|
||||
// be returned in the results of this test case
|
||||
@@ -682,7 +718,7 @@ func TestListOrganization(t *testing.T) {
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
err = organizationRepo.Create(ctx, &org)
|
||||
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
noOfOrganizations := 1
|
||||
@@ -697,7 +733,7 @@ func TestListOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -705,12 +741,15 @@ func TestListOrganization(t *testing.T) {
|
||||
|
||||
return organizations
|
||||
},
|
||||
conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)},
|
||||
conditionClauses: []database.Condition{
|
||||
organizationRepo.InstanceIDCondition(instanceId),
|
||||
organizationRepo.IDCondition(organizationId),
|
||||
},
|
||||
}
|
||||
}(),
|
||||
{
|
||||
name: "multiple organization filter on state",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||
// create organization
|
||||
// this org is created as an additional org which should NOT
|
||||
// be returned in the results of this test case
|
||||
@@ -720,7 +759,7 @@ func TestListOrganization(t *testing.T) {
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
err = organizationRepo.Create(ctx, &org)
|
||||
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
noOfOrganizations := 5
|
||||
@@ -728,14 +767,14 @@ func TestListOrganization(t *testing.T) {
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
org := domain.Organization{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateInactive,
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -743,13 +782,16 @@ func TestListOrganization(t *testing.T) {
|
||||
|
||||
return organizations
|
||||
},
|
||||
conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)},
|
||||
conditionClauses: []database.Condition{
|
||||
organizationRepo.InstanceIDCondition(instanceId),
|
||||
organizationRepo.StateCondition(domain.OrgStateInactive),
|
||||
},
|
||||
},
|
||||
func() test {
|
||||
instanceId_2 := gofakeit.Name()
|
||||
return test{
|
||||
name: "multiple organization filter on instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||
// create instance 1
|
||||
instanceId_1 := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -761,8 +803,7 @@ func TestListOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err = instanceRepo.Create(ctx, &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// create organization
|
||||
@@ -774,7 +815,7 @@ func TestListOrganization(t *testing.T) {
|
||||
InstanceID: instanceId_1,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
err = organizationRepo.Create(ctx, &org)
|
||||
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create instance 2
|
||||
@@ -787,7 +828,7 @@ func TestListOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
err = instanceRepo.Create(ctx, &instance_2)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance_2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
noOfOrganizations := 5
|
||||
@@ -795,14 +836,14 @@ func TestListOrganization(t *testing.T) {
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
org := domain.Organization{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceId_2,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -816,22 +857,25 @@ func TestListOrganization(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Logf("error during rollback: %v", err)
|
||||
}
|
||||
}()
|
||||
organizations := tt.testFunc(t, savepoint)
|
||||
|
||||
organizations := tt.testFunc(ctx, t)
|
||||
|
||||
var condition database.Condition
|
||||
condition := organizationRepo.InstanceIDCondition(instanceId)
|
||||
if len(tt.conditionClauses) > 0 {
|
||||
condition = database.And(tt.conditionClauses...)
|
||||
}
|
||||
|
||||
// check organization values
|
||||
returnedOrgs, err := organizationRepo.List(ctx,
|
||||
returnedOrgs, err := organizationRepo.List(t.Context(), tx,
|
||||
database.WithCondition(condition),
|
||||
database.WithOrderByAscending(organizationRepo.CreatedAtColumn()),
|
||||
database.WithOrderByAscending(organizationRepo.CreatedAtColumn(), organizationRepo.IDColumn()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if tt.noOrganizationReturned {
|
||||
@@ -851,6 +895,15 @@ func TestListOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteOrganization(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback(t.Context())
|
||||
})
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
organizationRepo := repository.OrganizationRepository()
|
||||
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
@@ -862,24 +915,22 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T)
|
||||
orgIdentifierCondition domain.OrgIdentifierCondition
|
||||
testFunc func(t *testing.T)
|
||||
orgIdentifierCondition database.Condition
|
||||
noOfDeletedRows int64
|
||||
}
|
||||
tests := []test{
|
||||
func() test {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
organizationId := gofakeit.Name()
|
||||
var noOfOrganizations int64 = 1
|
||||
return test{
|
||||
name: "happy path delete organization filter id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
testFunc: func(t *testing.T) {
|
||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
@@ -891,7 +942,7 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
@@ -902,12 +953,11 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
organizationName := gofakeit.Name()
|
||||
var noOfOrganizations int64 = 1
|
||||
return test{
|
||||
name: "happy path delete organization filter name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
testFunc: func(t *testing.T) {
|
||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||
for i := range noOfOrganizations {
|
||||
|
||||
@@ -919,30 +969,28 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
}
|
||||
},
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||
noOfDeletedRows: noOfOrganizations,
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
non_existent_organization_name := gofakeit.Name()
|
||||
return test{
|
||||
name: "delete non existent organization",
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name),
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, non_existent_organization_name),
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
organizationName := gofakeit.Name()
|
||||
return test{
|
||||
name: "deleted already deleted organization",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
testFunc: func(t *testing.T) {
|
||||
noOfOrganizations := 1
|
||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||
for i := range noOfOrganizations {
|
||||
@@ -955,21 +1003,23 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
}
|
||||
|
||||
// create organization
|
||||
err := organizationRepo.Create(ctx, &org)
|
||||
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
organizations[i] = &org
|
||||
}
|
||||
|
||||
// delete organization
|
||||
affectedRows, err := organizationRepo.Delete(ctx,
|
||||
organizationRepo.NameCondition(organizationName),
|
||||
organizations[0].InstanceID,
|
||||
affectedRows, err := organizationRepo.Delete(t.Context(), tx,
|
||||
database.And(
|
||||
organizationRepo.InstanceIDCondition(organizations[0].InstanceID),
|
||||
organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||
),
|
||||
)
|
||||
assert.Equal(t, int64(1), affectedRows)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
|
||||
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||
// this test should return 0 affected rows as the org was already deleted
|
||||
noOfDeletedRows: 0,
|
||||
}
|
||||
@@ -977,23 +1027,22 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
organizationRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
if tt.testFunc != nil {
|
||||
tt.testFunc(ctx, t)
|
||||
tt.testFunc(t)
|
||||
}
|
||||
|
||||
// delete organization
|
||||
noOfDeletedRows, err := organizationRepo.Delete(ctx,
|
||||
tt.orgIdentifierCondition,
|
||||
instanceId,
|
||||
noOfDeletedRows, err := organizationRepo.Delete(t.Context(), tx,
|
||||
database.And(
|
||||
organizationRepo.InstanceIDCondition(instanceId),
|
||||
tt.orgIdentifierCondition,
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||
|
||||
// check organization was deleted
|
||||
organization, err := organizationRepo.Get(ctx,
|
||||
organization, err := organizationRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
tt.orgIdentifierCondition,
|
||||
@@ -1006,3 +1055,99 @@ func TestDeleteOrganization(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOrganizationWithSubResources(t *testing.T) {
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback(t.Context())
|
||||
})
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
orgRepo := repository.OrganizationRepository()
|
||||
|
||||
// create instance
|
||||
instanceId := gofakeit.Name()
|
||||
err = instanceRepo.Create(t.Context(), tx, &domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// create organization
|
||||
org := domain.Organization{
|
||||
ID: "1",
|
||||
Name: "org-name",
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
err = orgRepo.Create(t.Context(), tx, &org)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = orgRepo.Create(t.Context(), tx, &domain.Organization{
|
||||
ID: "without-sub-resources",
|
||||
Name: "org-name-2",
|
||||
InstanceID: instanceId,
|
||||
State: domain.OrgStateActive,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("domains", func(t *testing.T) {
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
domain1 := &domain.AddOrganizationDomain{
|
||||
InstanceID: org.InstanceID,
|
||||
OrgID: org.ID,
|
||||
Domain: "domain1.com",
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
}
|
||||
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||
require.NoError(t, err)
|
||||
|
||||
domain2 := &domain.AddOrganizationDomain{
|
||||
InstanceID: org.InstanceID,
|
||||
OrgID: org.ID,
|
||||
Domain: "domain2.com",
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
}
|
||||
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("org by domain", func(t *testing.T) {
|
||||
orgRepo := orgRepo.LoadDomains()
|
||||
|
||||
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.InstanceIDCondition(instanceId),
|
||||
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, org.ID, returnedOrg.ID)
|
||||
assert.Len(t, returnedOrg.Domains, 2)
|
||||
})
|
||||
|
||||
t.Run("ensure org by domain works without LoadDomains", func(t *testing.T) {
|
||||
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.InstanceIDCondition(instanceId),
|
||||
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, org.ID, returnedOrg.ID)
|
||||
assert.Len(t, returnedOrg.Domains, 0)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,10 +4,6 @@ import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type repository struct {
|
||||
client database.QueryExecutor
|
||||
}
|
||||
|
||||
func writeCondition(
|
||||
builder *database.StatementBuilder,
|
||||
condition database.Condition,
|
||||
|
||||
@@ -12,16 +12,10 @@ const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_a
|
||||
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
|
||||
` FROM users_view users`
|
||||
|
||||
type user struct {
|
||||
repository
|
||||
}
|
||||
type user struct{}
|
||||
|
||||
func UserRepository(client database.QueryExecutor) domain.UserRepository {
|
||||
return &user{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
func UserRepository() domain.UserRepository {
|
||||
return new(user)
|
||||
}
|
||||
|
||||
var _ domain.UserRepository = (*user)(nil)
|
||||
@@ -31,17 +25,17 @@ var _ domain.UserRepository = (*user)(nil)
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Human implements [domain.UserRepository].
|
||||
func (u *user) Human() domain.HumanRepository {
|
||||
func (u user) Human() domain.HumanRepository {
|
||||
return &userHuman{user: u}
|
||||
}
|
||||
|
||||
// Machine implements [domain.UserRepository].
|
||||
func (u *user) Machine() domain.MachineRepository {
|
||||
func (u user) Machine() domain.MachineRepository {
|
||||
return &userMachine{user: u}
|
||||
}
|
||||
|
||||
// List implements [domain.UserRepository].
|
||||
func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
|
||||
func (u user) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (users []*domain.User, err error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
@@ -54,7 +48,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
|
||||
options.WriteLimit(&builder)
|
||||
options.WriteOffset(&builder)
|
||||
|
||||
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
|
||||
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -79,7 +73,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
|
||||
}
|
||||
|
||||
// Get implements [domain.UserRepository].
|
||||
func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
|
||||
func (u user) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.User, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
@@ -92,7 +86,7 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
|
||||
options.WriteLimit(&builder)
|
||||
options.WriteOffset(&builder)
|
||||
|
||||
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
|
||||
return scanUser(client.QueryRow(ctx, builder.String(), builder.Args()...))
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -105,7 +99,7 @@ const (
|
||||
)
|
||||
|
||||
// Create implements [domain.UserRepository].
|
||||
func (u *user) Create(ctx context.Context, user *domain.User) error {
|
||||
func (u user) Create(ctx context.Context, client database.QueryExecutor, user *domain.User) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
|
||||
switch trait := user.Traits.(type) {
|
||||
@@ -116,15 +110,15 @@ func (u *user) Create(ctx context.Context, user *domain.User) error {
|
||||
builder.WriteString(createMachineStmt)
|
||||
builder.AppendArgs(trait.Description)
|
||||
}
|
||||
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
|
||||
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
|
||||
}
|
||||
|
||||
// Delete implements [domain.UserRepository].
|
||||
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
|
||||
func (u user) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString("DELETE FROM users")
|
||||
writeCondition(&builder, condition)
|
||||
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type userHuman struct {
|
||||
*user
|
||||
user
|
||||
}
|
||||
|
||||
var _ domain.HumanRepository = (*userHuman)(nil)
|
||||
@@ -21,14 +21,14 @@ var _ domain.HumanRepository = (*userHuman)(nil)
|
||||
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
|
||||
|
||||
// GetEmail implements [domain.HumanRepository].
|
||||
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
|
||||
func (u *userHuman) GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*domain.Email, error) {
|
||||
var email domain.Email
|
||||
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(userEmailQuery)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
|
||||
err := client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
|
||||
&email.Address,
|
||||
&email.VerifiedAt,
|
||||
)
|
||||
@@ -39,7 +39,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
|
||||
}
|
||||
|
||||
// Update implements [domain.HumanRepository].
|
||||
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
||||
func (h userHuman) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(`UPDATE human_users SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
@@ -47,7 +47,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
|
||||
_, err := client.Exec(ctx, stmt, builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
type userMachine struct {
|
||||
*user
|
||||
user
|
||||
}
|
||||
|
||||
var _ domain.MachineRepository = (*userMachine)(nil)
|
||||
@@ -18,14 +18,14 @@ var _ domain.MachineRepository = (*userMachine)(nil)
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Update implements [domain.MachineRepository].
|
||||
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
||||
func (m userMachine) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString("UPDATE user_machines SET ")
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, condition)
|
||||
m.writeReturning()
|
||||
|
||||
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
@@ -20,8 +21,16 @@ type StatementBuilder struct {
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
type argWriter interface {
|
||||
WriteArg(builder *StatementBuilder)
|
||||
}
|
||||
|
||||
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
|
||||
func (b *StatementBuilder) WriteArg(arg any) {
|
||||
if writer, ok := arg.(argWriter); ok {
|
||||
writer.WriteArg(b)
|
||||
return
|
||||
}
|
||||
b.WriteString(b.AppendArg(arg))
|
||||
}
|
||||
|
||||
@@ -41,7 +50,13 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
|
||||
if b.existingArgs == nil {
|
||||
b.existingArgs = make(map[any]string)
|
||||
}
|
||||
if placeholder, ok := b.existingArgs[arg]; ok {
|
||||
// the key is used to work around the following panic:
|
||||
// runtime error: hash of unhashable type []uint8
|
||||
key := arg
|
||||
if argBytes, ok := arg.([]uint8); ok {
|
||||
key = `\\bytes-` + hex.EncodeToString(argBytes)
|
||||
}
|
||||
if placeholder, ok := b.existingArgs[key]; ok {
|
||||
return placeholder
|
||||
}
|
||||
if instruction, ok := arg.(Instruction); ok {
|
||||
@@ -50,7 +65,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
|
||||
|
||||
b.args = append(b.args, arg)
|
||||
placeholder = "$" + strconv.Itoa(len(b.args))
|
||||
b.existingArgs[arg] = placeholder
|
||||
b.existingArgs[key] = placeholder
|
||||
return placeholder
|
||||
}
|
||||
|
||||
|
||||
144
backend/v3/storage/database/statement_test.go
Normal file
144
backend/v3/storage/database/statement_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStatementBuilder_AppendArg(t *testing.T) {
|
||||
t.Run("same arg returns same placeholder", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
placeholder1 := b.AppendArg("same")
|
||||
placeholder2 := b.AppendArg("same")
|
||||
assert.Equal(t, placeholder1, placeholder2)
|
||||
assert.Len(t, b.Args(), 1)
|
||||
assert.Len(t, b.existingArgs, 1)
|
||||
})
|
||||
|
||||
t.Run("same arg different types", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
placeholder1 := b.AppendArg("same")
|
||||
placeholder2 := b.AppendArg([]byte("same"))
|
||||
placeholder3 := b.AppendArg("same")
|
||||
assert.NotEqual(t, placeholder1, placeholder2)
|
||||
assert.Equal(t, placeholder1, placeholder3)
|
||||
assert.Len(t, b.Args(), 2)
|
||||
assert.Len(t, b.existingArgs, 2)
|
||||
})
|
||||
|
||||
t.Run("Instruction args are always different", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
placeholder1 := b.AppendArg(DefaultInstruction)
|
||||
placeholder2 := b.AppendArg(DefaultInstruction)
|
||||
assert.Equal(t, placeholder1, placeholder2)
|
||||
assert.Len(t, b.Args(), 0)
|
||||
assert.Len(t, b.existingArgs, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatementBuilder_AppendArgs(t *testing.T) {
|
||||
t.Run("same arg returns same placeholder", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
b.AppendArgs("same", "same")
|
||||
assert.Len(t, b.Args(), 1)
|
||||
assert.Len(t, b.existingArgs, 1)
|
||||
})
|
||||
|
||||
t.Run("same arg different types", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
b.AppendArgs("same", []byte("same"), "same")
|
||||
assert.Len(t, b.Args(), 2)
|
||||
assert.Len(t, b.existingArgs, 2)
|
||||
})
|
||||
|
||||
t.Run("Instruction args are always different", func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
b.AppendArgs(DefaultInstruction, DefaultInstruction)
|
||||
assert.Len(t, b.Args(), 0)
|
||||
assert.Len(t, b.existingArgs, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStatementBuilder_WriteArg(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
arg any
|
||||
wantSQL string
|
||||
wantArg []any
|
||||
}{
|
||||
{
|
||||
name: "primitive arg",
|
||||
arg: "test",
|
||||
wantSQL: "$1",
|
||||
wantArg: []any{"test"},
|
||||
},
|
||||
{
|
||||
name: "argWriter arg",
|
||||
arg: SHA256Value("wrapped"),
|
||||
wantSQL: "SHA256($1)",
|
||||
wantArg: []any{"wrapped"},
|
||||
},
|
||||
{
|
||||
name: "Instruction arg",
|
||||
arg: DefaultInstruction,
|
||||
wantSQL: "DEFAULT",
|
||||
wantArg: []any{},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
b.WriteArg(test.arg)
|
||||
assert.Equal(t, test.wantSQL, b.String())
|
||||
require.Len(t, b.Args(), len(test.wantArg))
|
||||
for i := range test.wantArg {
|
||||
assert.Equal(t, test.wantArg[i], b.Args()[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatementBuilder_WriteArgs(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
args []any
|
||||
wantSQL string
|
||||
wantArg []any
|
||||
}{
|
||||
{
|
||||
name: "primitive args",
|
||||
args: []any{"test", 123, true, uint32(123)},
|
||||
wantSQL: "$1, $2, $3, $4",
|
||||
wantArg: []any{"test", 123, true, uint32(123)},
|
||||
},
|
||||
{
|
||||
name: "argWriter args",
|
||||
args: []any{SHA256Value("wrapped"), LowerValue("ASDF")},
|
||||
wantSQL: "SHA256($1), LOWER($2)",
|
||||
wantArg: []any{"wrapped", "ASDF"},
|
||||
},
|
||||
{
|
||||
name: "Instruction args",
|
||||
args: []any{DefaultInstruction, NowInstruction, NullInstruction},
|
||||
wantSQL: "DEFAULT, NOW(), NULL",
|
||||
wantArg: []any{},
|
||||
},
|
||||
{
|
||||
name: "mixed args",
|
||||
args: []any{123, uint32(123), NowInstruction, NullInstruction, SHA256Value("wrapped"), LowerValue("ASDF")},
|
||||
wantSQL: "$1, $2, NOW(), NULL, SHA256($3), LOWER($4)",
|
||||
wantArg: []any{123, uint32(123), "wrapped", "ASDF"},
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var b StatementBuilder
|
||||
b.WriteArgs(test.args...)
|
||||
assert.Equal(t, test.wantSQL, b.String())
|
||||
require.Len(t, b.Args(), len(test.wantArg))
|
||||
for i := range test.wantArg {
|
||||
assert.Equal(t, test.wantArg[i], b.Args()[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event event
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
|
||||
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
|
||||
InstanceID: e.Aggregate().InstanceID,
|
||||
Domain: e.Domain,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
@@ -88,25 +88,15 @@ func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event events
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
condition := database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
domainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
)
|
||||
|
||||
_, err := domainRepo.Update(ctx,
|
||||
condition,
|
||||
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
domainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
domainRepo.SetPrimary(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we need to split the update into two statements because multiple events can have the same creation date
|
||||
// therefore we first do not set the updated_at timestamp
|
||||
_, err = domainRepo.Update(ctx,
|
||||
condition,
|
||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||
)
|
||||
return err
|
||||
@@ -123,8 +113,8 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eve
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
_, err := domainRepo.Remove(ctx,
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
@@ -145,7 +135,7 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event even
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
|
||||
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
|
||||
InstanceID: e.Aggregate().InstanceID,
|
||||
Domain: e.Domain,
|
||||
Type: domain.DomainTypeTrusted,
|
||||
@@ -165,8 +155,8 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event ev
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
_, err := domainRepo.Remove(ctx,
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
|
||||
@@ -74,7 +74,7 @@ func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Even
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{
|
||||
return repository.InstanceRepository().Create(ctx, v3_sql.SQLTx(tx), &domain.Instance{
|
||||
ID: e.Aggregate().ID,
|
||||
Name: e.Name,
|
||||
CreatedAt: e.CreationDate(),
|
||||
@@ -93,8 +93,8 @@ func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Ev
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
||||
return p.updateInstance(ctx, event, repo, repo.SetName(e.Name))
|
||||
repo := repository.InstanceRepository()
|
||||
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetName(e.Name))
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -108,7 +108,7 @@ func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Eve
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
_, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID)
|
||||
_, err := repository.InstanceRepository().Delete(ctx, v3_sql.SQLTx(tx), e.Aggregate().ID)
|
||||
return err
|
||||
}), nil
|
||||
}
|
||||
@@ -124,8 +124,8 @@ func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Even
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
||||
return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID))
|
||||
repo := repository.InstanceRepository()
|
||||
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultOrg(e.OrgID))
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -140,8 +140,8 @@ func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Even
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
||||
return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID))
|
||||
repo := repository.InstanceRepository()
|
||||
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetIAMProject(e.ProjectID))
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -156,8 +156,8 @@ func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event)
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
||||
return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
|
||||
repo := repository.InstanceRepository()
|
||||
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -172,18 +172,18 @@ func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
||||
return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language))
|
||||
repo := repository.InstanceRepository()
|
||||
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultLanguage(e.Language))
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
|
||||
_, err := repo.Update(ctx, event.Aggregate().ID, changes...)
|
||||
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, tx database.Transaction, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
|
||||
_, err := repo.Update(ctx, tx, event.Aggregate().ID, changes...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
|
||||
instance, err := repo.Get(ctx, tx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -192,7 +192,7 @@ func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event
|
||||
}
|
||||
// we need to split the update into two statements because multiple events can have the same creation date
|
||||
// therefore we first do not set the updated_at timestamp
|
||||
_, err = repo.Update(ctx,
|
||||
_, err = repo.Update(ctx, tx,
|
||||
event.Aggregate().ID,
|
||||
repo.SetUpdatedAt(event.CreatedAt()),
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*ha
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{
|
||||
return repository.OrganizationDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddOrganizationDomain{
|
||||
InstanceID: e.Aggregate().InstanceID,
|
||||
OrgID: e.Aggregate().ResourceOwner,
|
||||
Domain: e.Domain,
|
||||
@@ -85,23 +85,14 @@ func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event)
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
condition := database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
)
|
||||
_, err := domainRepo.Update(ctx,
|
||||
condition,
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
),
|
||||
domainRepo.SetPrimary(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we need to split the update into two statements because multiple events can have the same creation date
|
||||
// therefore we first do not set the updated_at timestamp
|
||||
_, err = domainRepo.Update(ctx,
|
||||
condition,
|
||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||
)
|
||||
return err
|
||||
@@ -118,8 +109,8 @@ func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (*
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
_, err := domainRepo.Remove(ctx,
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
@@ -149,24 +140,15 @@ func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
condition := database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
)
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
_, err := domainRepo.Update(ctx,
|
||||
condition,
|
||||
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
),
|
||||
domainRepo.SetValidationType(validationType),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we need to split the update into two statements because multiple events can have the same creation date
|
||||
// therefore we first do not set the updated_at timestamp
|
||||
_, err = domainRepo.Update(ctx,
|
||||
condition,
|
||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||
)
|
||||
return err
|
||||
@@ -183,28 +165,17 @@ func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) (
|
||||
if !ok {
|
||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex)
|
||||
}
|
||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
||||
domainRepo := repository.OrganizationDomainRepository()
|
||||
|
||||
condition := database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
)
|
||||
|
||||
_, err := domainRepo.Update(ctx,
|
||||
condition,
|
||||
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||
database.And(
|
||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||
),
|
||||
domainRepo.SetVerified(),
|
||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// we need to split the update into two statements because multiple events can have the same creation date
|
||||
// therefore we first do not set the updated_at timestamp
|
||||
_, err = domainRepo.Update(ctx,
|
||||
condition,
|
||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||
)
|
||||
return err
|
||||
}), nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user