mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 12:16:42 +00:00
refactor: database interaction and error handling (#10762)
This pull request introduces a significant refactoring of the database interaction layer, focusing on improving explicitness, transactional control, and error handling. The core change is the removal of the stateful `QueryExecutor` from repository instances. Instead, it is now passed as an argument to each method that interacts with the database. This change makes transaction management more explicit and flexible, as the same repository instance can be used with a database pool or a specific transaction without needing to be re-instantiated. ### Key Changes - **Explicit `QueryExecutor` Passing:** - All repository methods (`Get`, `List`, `Create`, `Update`, `Delete`, etc.) in `InstanceRepository`, `OrganizationRepository`, `UserRepository`, and their sub-repositories now require a `database.QueryExecutor` (e.g., a `*pgxpool.Pool` or `pgx.Tx`) as the first argument. - Repository constructors no longer accept a `QueryExecutor`. For example, `repository.InstanceRepository(pool)` is now `repository.InstanceRepository()`. - **Enhanced Error Handling:** - A new `database.MissingConditionError` is introduced to enforce required query conditions, such as ensuring an `instance_id` is always present in `UPDATE` and `DELETE` operations. - The database error wrapper in the `postgres` package now correctly identifies and wraps `pgx.ErrTooManyRows` and similar errors from the `scany` library into a `database.MultipleRowsFoundError`. - **Improved Database Conditions:** - The `database.Condition` interface now includes a `ContainsColumn(Column) bool` method. This allows for runtime checks to ensure that critical filters (like `instance_id`) are included in a query, preventing accidental cross-tenant data modification. - A new `database.Exists()` condition has been added to support `EXISTS` subqueries, enabling more complex filtering logic, such as finding an organization that has a specific domain. - **Repository and Interface Refactoring:** - The method for loading related entities (e.g., domains for an organization) has been changed from a boolean flag (`Domains(true)`) to a more explicit, chainable method (`LoadDomains()`). This returns a new repository instance configured to load the sub-resource, promoting immutability. - The custom `OrgIdentifierCondition` has been removed in favor of using the standard `database.Condition` interface, simplifying the API. - **Code Cleanup and Test Updates:** - Unnecessary struct embeddings and metadata have been removed. - All integration and repository tests have been updated to reflect the new method signatures, passing the database pool or transaction object explicitly. - New tests have been added to cover the new `ExistsDomain` functionality and other enhancements. These changes make the data access layer more robust, predictable, and easier to work with, especially in the context of database transactions.
This commit is contained in:
@@ -69,6 +69,8 @@ type instanceConditions interface {
|
|||||||
IDCondition(instanceID string) database.Condition
|
IDCondition(instanceID string) database.Condition
|
||||||
// NameCondition returns a filter on the name field.
|
// NameCondition returns a filter on the name field.
|
||||||
NameCondition(op database.TextOperation, name string) database.Condition
|
NameCondition(op database.TextOperation, name string) database.Condition
|
||||||
|
// ExistsDomain returns a filter on the instance domains.
|
||||||
|
ExistsDomain(cond database.Condition) database.Condition
|
||||||
}
|
}
|
||||||
|
|
||||||
// instanceChanges define all the changes for the instance table.
|
// instanceChanges define all the changes for the instance table.
|
||||||
@@ -99,17 +101,16 @@ type InstanceRepository interface {
|
|||||||
// Member returns the member repository which is a sub repository of the instance repository.
|
// Member returns the member repository which is a sub repository of the instance repository.
|
||||||
// Member() MemberRepository
|
// Member() MemberRepository
|
||||||
|
|
||||||
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
|
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Instance, error)
|
||||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
|
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Instance, error)
|
||||||
|
|
||||||
Create(ctx context.Context, instance *Instance) error
|
Create(ctx context.Context, client database.QueryExecutor, instance *Instance) error
|
||||||
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
|
Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error)
|
||||||
Delete(ctx context.Context, id string) (int64, error)
|
Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error)
|
||||||
|
|
||||||
// Domains returns the domain sub repository for the instance.
|
// LoadDomains loads the domains of the given instance.
|
||||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
// If it is called the [Instance].Domains field will be set on future calls to Get or List.
|
||||||
// If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
|
LoadDomains() InstanceRepository
|
||||||
Domains(shouldLoad bool) InstanceDomainRepository
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateInstance struct {
|
type CreateInstance struct {
|
||||||
|
|||||||
@@ -65,15 +65,15 @@ type InstanceDomainRepository interface {
|
|||||||
// Get returns a single domain based on the criteria.
|
// Get returns a single domain based on the criteria.
|
||||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||||
Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
|
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*InstanceDomain, error)
|
||||||
// List returns a list of domains based on the criteria.
|
// List returns a list of domains based on the criteria.
|
||||||
// If no domains are found, it returns an empty slice.
|
// If no domains are found, it returns an empty slice.
|
||||||
List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
|
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*InstanceDomain, error)
|
||||||
|
|
||||||
// Add adds a new domain to the instance.
|
// Add adds a new domain to the instance.
|
||||||
Add(ctx context.Context, domain *AddInstanceDomain) error
|
Add(ctx context.Context, client database.QueryExecutor, domain *AddInstanceDomain) error
|
||||||
// Update updates an existing domain in the instance.
|
// Update updates an existing domain in the instance.
|
||||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||||
// Remove removes a domain from the instance.
|
// Remove removes a domain from the instance.
|
||||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,13 +26,6 @@ type Organization struct {
|
|||||||
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
|
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
|
||||||
}
|
}
|
||||||
|
|
||||||
// OrgIdentifierCondition is used to help specify a single Organization,
|
|
||||||
// it will either be used as the organization ID or organization name,
|
|
||||||
// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
|
|
||||||
type OrgIdentifierCondition interface {
|
|
||||||
database.Condition
|
|
||||||
}
|
|
||||||
|
|
||||||
// organizationColumns define all the columns of the instance table.
|
// organizationColumns define all the columns of the instance table.
|
||||||
type organizationColumns interface {
|
type organizationColumns interface {
|
||||||
// IDColumn returns the column for the id field.
|
// IDColumn returns the column for the id field.
|
||||||
@@ -52,13 +45,15 @@ type organizationColumns interface {
|
|||||||
// organizationConditions define all the conditions for the instance table.
|
// organizationConditions define all the conditions for the instance table.
|
||||||
type organizationConditions interface {
|
type organizationConditions interface {
|
||||||
// IDCondition returns an equal filter on the id field.
|
// IDCondition returns an equal filter on the id field.
|
||||||
IDCondition(instanceID string) OrgIdentifierCondition
|
IDCondition(instanceID string) database.Condition
|
||||||
// NameCondition returns a filter on the name field.
|
// NameCondition returns a filter on the name field.
|
||||||
NameCondition(name string) OrgIdentifierCondition
|
NameCondition(op database.TextOperation, name string) database.Condition
|
||||||
// InstanceIDCondition returns a filter on the instance id field.
|
// InstanceIDCondition returns a filter on the instance id field.
|
||||||
InstanceIDCondition(instanceID string) database.Condition
|
InstanceIDCondition(instanceID string) database.Condition
|
||||||
// StateCondition returns a filter on the name field.
|
// StateCondition returns a filter on the name field.
|
||||||
StateCondition(state OrgState) database.Condition
|
StateCondition(state OrgState) database.Condition
|
||||||
|
// ExistsDomain returns a filter on the organizations domains.
|
||||||
|
ExistsDomain(cond database.Condition) database.Condition
|
||||||
}
|
}
|
||||||
|
|
||||||
// organizationChanges define all the changes for the instance table.
|
// organizationChanges define all the changes for the instance table.
|
||||||
@@ -75,17 +70,16 @@ type OrganizationRepository interface {
|
|||||||
organizationConditions
|
organizationConditions
|
||||||
organizationChanges
|
organizationChanges
|
||||||
|
|
||||||
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
|
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*Organization, error)
|
||||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
|
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*Organization, error)
|
||||||
|
|
||||||
Create(ctx context.Context, instance *Organization) error
|
Create(ctx context.Context, client database.QueryExecutor, org *Organization) error
|
||||||
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
|
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||||
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
|
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||||
|
|
||||||
// Domains returns the domain sub repository for the organization.
|
// LoadDomains loads the domains of the given organizations.
|
||||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
// If it is called the [Organization].Domains field will be set on future calls to Get or List.
|
||||||
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
|
LoadDomains() OrganizationRepository
|
||||||
Domains(shouldLoad bool) OrganizationDomainRepository
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateOrganization struct {
|
type CreateOrganization struct {
|
||||||
|
|||||||
@@ -70,15 +70,15 @@ type OrganizationDomainRepository interface {
|
|||||||
// Get returns a single domain based on the criteria.
|
// Get returns a single domain based on the criteria.
|
||||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||||
Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
|
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*OrganizationDomain, error)
|
||||||
// List returns a list of domains based on the criteria.
|
// List returns a list of domains based on the criteria.
|
||||||
// If no domains are found, it returns an empty slice.
|
// If no domains are found, it returns an empty slice.
|
||||||
List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
|
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*OrganizationDomain, error)
|
||||||
|
|
||||||
// Add adds a new domain to the organization.
|
// Add adds a new domain to the organization.
|
||||||
Add(ctx context.Context, domain *AddOrganizationDomain) error
|
Add(ctx context.Context, client database.QueryExecutor, domain *AddOrganizationDomain) error
|
||||||
// Update updates an existing domain in the organization.
|
// Update updates an existing domain in the organization.
|
||||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error)
|
||||||
// Remove removes a domain from the organization.
|
// Remove removes a domain from the organization.
|
||||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,13 +57,13 @@ type UserRepository interface {
|
|||||||
userConditions
|
userConditions
|
||||||
userChanges
|
userChanges
|
||||||
// Get returns a user based on the given condition.
|
// Get returns a user based on the given condition.
|
||||||
Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
|
Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*User, error)
|
||||||
// List returns a list of users based on the given condition.
|
// List returns a list of users based on the given condition.
|
||||||
List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
|
List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*User, error)
|
||||||
// Create creates a new user.
|
// Create creates a new user.
|
||||||
Create(ctx context.Context, user *User) error
|
Create(ctx context.Context, client database.QueryExecutor, user *User) error
|
||||||
// Delete removes users based on the given condition.
|
// Delete removes users based on the given condition.
|
||||||
Delete(ctx context.Context, condition database.Condition) error
|
Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error
|
||||||
// Human returns the [HumanRepository].
|
// Human returns the [HumanRepository].
|
||||||
Human() HumanRepository
|
Human() HumanRepository
|
||||||
// Machine returns the [MachineRepository].
|
// Machine returns the [MachineRepository].
|
||||||
@@ -143,9 +143,9 @@ type HumanRepository interface {
|
|||||||
humanChanges
|
humanChanges
|
||||||
|
|
||||||
// Get returns an email based on the given condition.
|
// Get returns an email based on the given condition.
|
||||||
GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
|
GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*Email, error)
|
||||||
// Update updates human users based on the given condition and changes.
|
// Update updates human users based on the given condition and changes.
|
||||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// machineColumns define all the columns of the machine table which inherits the user table.
|
// machineColumns define all the columns of the machine table which inherits the user table.
|
||||||
@@ -172,7 +172,7 @@ type machineChanges interface {
|
|||||||
// MachineRepository is the interface for the machine repository it inherits the user repository.
|
// MachineRepository is the interface for the machine repository it inherits the user repository.
|
||||||
type MachineRepository interface {
|
type MachineRepository interface {
|
||||||
// Update updates machine users based on the given condition and changes.
|
// Update updates machine users based on the given condition and changes.
|
||||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error
|
||||||
|
|
||||||
machineColumns
|
machineColumns
|
||||||
machineConditions
|
machineConditions
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
|
import "slices"
|
||||||
|
|
||||||
// Change represents a change to a column in a database table.
|
// Change represents a change to a column in a database table.
|
||||||
// Its written in the SET clause of an UPDATE statement.
|
// Its written in the SET clause of an UPDATE statement.
|
||||||
type Change interface {
|
type Change interface {
|
||||||
|
// Write writes the change to the given statement builder.
|
||||||
Write(builder *StatementBuilder)
|
Write(builder *StatementBuilder)
|
||||||
|
// IsOnColumn checks if the change is on the given column.
|
||||||
|
IsOnColumn(col Column) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type change[V Value] struct {
|
type change[V Value] struct {
|
||||||
@@ -13,6 +18,8 @@ type change[V Value] struct {
|
|||||||
|
|
||||||
var _ Change = (*change[string])(nil)
|
var _ Change = (*change[string])(nil)
|
||||||
|
|
||||||
|
// NewChange creates a new Change for the given column and value.
|
||||||
|
// If you want to set a column to NULL, use [NewChangePtr].
|
||||||
func NewChange[V Value](col Column, value V) Change {
|
func NewChange[V Value](col Column, value V) Change {
|
||||||
return &change[V]{
|
return &change[V]{
|
||||||
column: col,
|
column: col,
|
||||||
@@ -20,6 +27,8 @@ func NewChange[V Value](col Column, value V) Change {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewChangePtr creates a new Change for the given column and value pointer.
|
||||||
|
// If the value pointer is nil, the column will be set to NULL.
|
||||||
func NewChangePtr[V Value](col Column, value *V) Change {
|
func NewChangePtr[V Value](col Column, value *V) Change {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return NewChange(col, NullInstruction)
|
return NewChange(col, NullInstruction)
|
||||||
@@ -34,19 +43,31 @@ func (c change[V]) Write(builder *StatementBuilder) {
|
|||||||
builder.WriteArg(c.value)
|
builder.WriteArg(c.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOnColumn implements [Change].
|
||||||
|
func (c change[V]) IsOnColumn(col Column) bool {
|
||||||
|
return c.column.Equals(col)
|
||||||
|
}
|
||||||
|
|
||||||
type Changes []Change
|
type Changes []Change
|
||||||
|
|
||||||
func NewChanges(cols ...Change) Change {
|
func NewChanges(cols ...Change) Change {
|
||||||
return Changes(cols)
|
return Changes(cols)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOnColumn implements [Change].
|
||||||
|
func (c Changes) IsOnColumn(col Column) bool {
|
||||||
|
return slices.ContainsFunc(c, func(change Change) bool {
|
||||||
|
return change.IsOnColumn(col)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Write implements [Change].
|
// Write implements [Change].
|
||||||
func (m Changes) Write(builder *StatementBuilder) {
|
func (m Changes) Write(builder *StatementBuilder) {
|
||||||
for i, col := range m {
|
for i, change := range m {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
}
|
}
|
||||||
col.Write(builder)
|
change.Write(builder)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
68
backend/v3/storage/database/change_test.go
Normal file
68
backend/v3/storage/database/change_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChangeWrite(t *testing.T) {
|
||||||
|
type want struct {
|
||||||
|
stmt string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
change Change
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "change",
|
||||||
|
change: NewChange(NewColumn("table", "column"), "value"),
|
||||||
|
want: want{
|
||||||
|
stmt: "column = $1",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change ptr to null",
|
||||||
|
change: NewChangePtr[int](NewColumn("table", "column"), nil),
|
||||||
|
want: want{
|
||||||
|
stmt: "column = NULL",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "change ptr to value",
|
||||||
|
change: NewChangePtr(NewColumn("table", "column"), gu.Ptr(42)),
|
||||||
|
want: want{
|
||||||
|
stmt: "column = $1",
|
||||||
|
args: []any{42},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple changes",
|
||||||
|
change: NewChanges(
|
||||||
|
NewChange(NewColumn("table", "column1"), "value1"),
|
||||||
|
NewChangePtr[int](NewColumn("table", "column2"), nil),
|
||||||
|
NewChange(NewColumn("table", "column3"), 123),
|
||||||
|
),
|
||||||
|
want: want{
|
||||||
|
stmt: "column1 = $1, column2 = NULL, column3 = $2",
|
||||||
|
args: []any{"value1", 123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var builder StatementBuilder
|
||||||
|
test.change.Write(&builder)
|
||||||
|
assert.Equal(t, test.want.stmt, builder.String())
|
||||||
|
require.Len(t, builder.Args(), len(test.want.args))
|
||||||
|
for i, arg := range test.want.args {
|
||||||
|
assert.Equal(t, arg, builder.Args()[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,8 +3,9 @@ package database
|
|||||||
type Columns []Column
|
type Columns []Column
|
||||||
|
|
||||||
// WriteQualified implements [Column].
|
// WriteQualified implements [Column].
|
||||||
func (m Columns) WriteQualified(builder *StatementBuilder) {
|
// Columns are separated by ", ".
|
||||||
for i, col := range m {
|
func (c Columns) WriteQualified(builder *StatementBuilder) {
|
||||||
|
for i, col := range c {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
}
|
}
|
||||||
@@ -13,8 +14,9 @@ func (m Columns) WriteQualified(builder *StatementBuilder) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// WriteUnqualified implements [Column].
|
// WriteUnqualified implements [Column].
|
||||||
func (m Columns) WriteUnqualified(builder *StatementBuilder) {
|
// Columns are separated by ", ".
|
||||||
for i, col := range m {
|
func (c Columns) WriteUnqualified(builder *StatementBuilder) {
|
||||||
|
for i, col := range c {
|
||||||
if i > 0 {
|
if i > 0 {
|
||||||
builder.WriteString(", ")
|
builder.WriteString(", ")
|
||||||
}
|
}
|
||||||
@@ -22,11 +24,33 @@ func (m Columns) WriteUnqualified(builder *StatementBuilder) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Equals implements [Column].
|
||||||
|
func (c Columns) Equals(col Column) bool {
|
||||||
|
if col == nil {
|
||||||
|
return c == nil
|
||||||
|
}
|
||||||
|
other, ok := col.(Columns)
|
||||||
|
if !ok || len(other) != len(c) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, col := range c {
|
||||||
|
if !col.Equals(other[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Column = (Columns)(nil)
|
||||||
|
|
||||||
// Column represents a column in a database table.
|
// Column represents a column in a database table.
|
||||||
type Column interface {
|
type Column interface {
|
||||||
// Write(builder *StatementBuilder)
|
// WriteQualified writes the column with the table name as prefix.
|
||||||
WriteQualified(builder *StatementBuilder)
|
WriteQualified(builder *StatementBuilder)
|
||||||
|
// WriteUnqualified writes the column without the table name as prefix.
|
||||||
WriteUnqualified(builder *StatementBuilder)
|
WriteUnqualified(builder *StatementBuilder)
|
||||||
|
// Equals checks if two columns are equal.
|
||||||
|
Equals(col Column) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type column struct {
|
type column struct {
|
||||||
@@ -35,7 +59,7 @@ type column struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewColumn(table, name string) Column {
|
func NewColumn(table, name string) Column {
|
||||||
return column{table: table, name: name}
|
return &column{table: table, name: name}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteQualified implements [Column].
|
// WriteQualified implements [Column].
|
||||||
@@ -51,35 +75,69 @@ func (c column) WriteUnqualified(builder *StatementBuilder) {
|
|||||||
builder.WriteString(c.name)
|
builder.WriteString(c.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Column = (*column)(nil)
|
// Equals implements [Column].
|
||||||
|
func (c *column) Equals(col Column) bool {
|
||||||
|
if col == nil {
|
||||||
|
return c == nil
|
||||||
|
}
|
||||||
|
toMatch, ok := col.(*column)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.table == toMatch.table && c.name == toMatch.name
|
||||||
|
}
|
||||||
|
|
||||||
// // ignoreCaseColumn represents two database columns, one for the
|
// LowerColumn returns a column that represents LOWER(col).
|
||||||
// // original value and one for the lower case value.
|
func LowerColumn(col Column) Column {
|
||||||
// type ignoreCaseColumn interface {
|
return &functionColumn{fn: functionLower, col: col}
|
||||||
// Column
|
}
|
||||||
// WriteIgnoreCase(builder *StatementBuilder)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
|
// SHA256Column returns a column that represents SHA256(col).
|
||||||
// return ignoreCaseCol{
|
func SHA256Column(col Column) Column {
|
||||||
// column: col,
|
return &functionColumn{fn: functionSHA256, col: col}
|
||||||
// suffix: suffix,
|
}
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// type ignoreCaseCol struct {
|
type functionColumn struct {
|
||||||
// column Column
|
fn function
|
||||||
// suffix string
|
col Column
|
||||||
// }
|
}
|
||||||
|
|
||||||
// // WriteIgnoreCase implements [ignoreCaseColumn].
|
type function string
|
||||||
// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
|
|
||||||
// c.column.WriteQualified(builder)
|
|
||||||
// builder.WriteString(c.suffix)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// // WriteQualified implements [ignoreCaseColumn].
|
const (
|
||||||
// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
|
_ function = ""
|
||||||
// c.column.WriteQualified(builder)
|
functionLower function = "LOWER"
|
||||||
// builder.WriteString(c.suffix)
|
functionSHA256 function = "SHA256"
|
||||||
// }
|
)
|
||||||
|
|
||||||
|
// WriteQualified implements [Column].
|
||||||
|
func (c functionColumn) WriteQualified(builder *StatementBuilder) {
|
||||||
|
builder.Grow(len(c.fn) + 2)
|
||||||
|
builder.WriteString(string(c.fn))
|
||||||
|
builder.WriteRune('(')
|
||||||
|
c.col.WriteQualified(builder)
|
||||||
|
builder.WriteRune(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteUnqualified implements [Column].
|
||||||
|
func (c functionColumn) WriteUnqualified(builder *StatementBuilder) {
|
||||||
|
builder.Grow(len(c.fn) + 2)
|
||||||
|
builder.WriteString(string(c.fn))
|
||||||
|
builder.WriteRune('(')
|
||||||
|
c.col.WriteUnqualified(builder)
|
||||||
|
builder.WriteRune(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equals implements [Column].
|
||||||
|
func (c *functionColumn) Equals(col Column) bool {
|
||||||
|
if col == nil {
|
||||||
|
return c == nil
|
||||||
|
}
|
||||||
|
toMatch, ok := col.(*functionColumn)
|
||||||
|
if !ok || toMatch.fn != c.fn {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.col.Equals(toMatch.col)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Column = (*functionColumn)(nil)
|
||||||
|
|||||||
212
backend/v3/storage/database/column_test.go
Normal file
212
backend/v3/storage/database/column_test.go
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWriteUnqualified(t *testing.T) {
|
||||||
|
for _, tests := range []struct {
|
||||||
|
name string
|
||||||
|
column Column
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "column",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
expected: "column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
expected: "column1, column2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
expected: "SHA256(column)",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tests.name, func(t *testing.T) {
|
||||||
|
var builder StatementBuilder
|
||||||
|
tests.column.WriteUnqualified(&builder)
|
||||||
|
assert.Equal(t, tests.expected, builder.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteQualified(t *testing.T) {
|
||||||
|
for _, tests := range []struct {
|
||||||
|
name string
|
||||||
|
column Column
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "column",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
expected: "table.column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
expected: "table.column1, table.column2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
expected: "SHA256(table.column)",
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tests.name, func(t *testing.T) {
|
||||||
|
var builder StatementBuilder
|
||||||
|
tests.column.WriteQualified(&builder)
|
||||||
|
assert.Equal(t, tests.expected, builder.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEquals(t *testing.T) {
|
||||||
|
for _, tests := range []struct {
|
||||||
|
name string
|
||||||
|
column Column
|
||||||
|
toCheck Column
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "column equal",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
toCheck: NewColumn("table", "column"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column nil check",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
toCheck: nil,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column both nil",
|
||||||
|
column: (*column)(nil),
|
||||||
|
toCheck: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column not equal (different name)",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
toCheck: NewColumn("table", "column2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column not equal (different type)",
|
||||||
|
column: NewColumn("table", "column"),
|
||||||
|
toCheck: SHA256Column(NewColumn("table", "column")),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns equal",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
toCheck: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns nil check",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
toCheck: nil,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns both nil",
|
||||||
|
column: Columns(nil),
|
||||||
|
toCheck: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns not equal (different type)",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
toCheck: NewColumn("table", "column1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns not equal (different length)",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
toCheck: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "columns not equal (different order)",
|
||||||
|
column: Columns{
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
},
|
||||||
|
toCheck: Columns{
|
||||||
|
NewColumn("table", "column2"),
|
||||||
|
NewColumn("table", "column1"),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column equal",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
toCheck: SHA256Column(NewColumn("table", "column")),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function nil check",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
toCheck: nil,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function both nil",
|
||||||
|
column: (*functionColumn)(nil),
|
||||||
|
toCheck: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column not equal (different function)",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
toCheck: LowerColumn(NewColumn("table", "column")),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column not equal (different inner column)",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
toCheck: SHA256Column(NewColumn("table", "column2")),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function column not equal (different type)",
|
||||||
|
column: SHA256Column(NewColumn("table", "column")),
|
||||||
|
toCheck: NewColumn("table", "column2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tests.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tests.expected, tests.column.Equals(tests.toCheck))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,9 @@ package database
|
|||||||
// Its written after the WHERE keyword in a SQL statement.
|
// Its written after the WHERE keyword in a SQL statement.
|
||||||
type Condition interface {
|
type Condition interface {
|
||||||
Write(builder *StatementBuilder)
|
Write(builder *StatementBuilder)
|
||||||
|
// IsRestrictingColumn is used to check if the condition filters for a specific column.
|
||||||
|
// It acts as a save guard database operations that should be specific on the given column.
|
||||||
|
IsRestrictingColumn(col Column) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type and struct {
|
type and struct {
|
||||||
@@ -11,7 +14,7 @@ type and struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write implements [Condition].
|
// Write implements [Condition].
|
||||||
func (a *and) Write(builder *StatementBuilder) {
|
func (a and) Write(builder *StatementBuilder) {
|
||||||
if len(a.conditions) > 1 {
|
if len(a.conditions) > 1 {
|
||||||
builder.WriteString("(")
|
builder.WriteString("(")
|
||||||
defer builder.WriteString(")")
|
defer builder.WriteString(")")
|
||||||
@@ -29,6 +32,16 @@ func And(conditions ...Condition) *and {
|
|||||||
return &and{conditions: conditions}
|
return &and{conditions: conditions}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
func (a and) IsRestrictingColumn(col Column) bool {
|
||||||
|
for _, condition := range a.conditions {
|
||||||
|
if condition.IsRestrictingColumn(col) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var _ Condition = (*and)(nil)
|
var _ Condition = (*and)(nil)
|
||||||
|
|
||||||
type or struct {
|
type or struct {
|
||||||
@@ -36,7 +49,7 @@ type or struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write implements [Condition].
|
// Write implements [Condition].
|
||||||
func (o *or) Write(builder *StatementBuilder) {
|
func (o or) Write(builder *StatementBuilder) {
|
||||||
if len(o.conditions) > 1 {
|
if len(o.conditions) > 1 {
|
||||||
builder.WriteString("(")
|
builder.WriteString("(")
|
||||||
defer builder.WriteString(")")
|
defer builder.WriteString(")")
|
||||||
@@ -54,6 +67,17 @@ func Or(conditions ...Condition) *or {
|
|||||||
return &or{conditions: conditions}
|
return &or{conditions: conditions}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
// It returns true only if all conditions are restricting the given column.
|
||||||
|
func (o or) IsRestrictingColumn(col Column) bool {
|
||||||
|
for _, condition := range o.conditions {
|
||||||
|
if !condition.IsRestrictingColumn(col) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
var _ Condition = (*or)(nil)
|
var _ Condition = (*or)(nil)
|
||||||
|
|
||||||
type isNull struct {
|
type isNull struct {
|
||||||
@@ -61,7 +85,7 @@ type isNull struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write implements [Condition].
|
// Write implements [Condition].
|
||||||
func (i *isNull) Write(builder *StatementBuilder) {
|
func (i isNull) Write(builder *StatementBuilder) {
|
||||||
i.column.WriteQualified(builder)
|
i.column.WriteQualified(builder)
|
||||||
builder.WriteString(" IS NULL")
|
builder.WriteString(" IS NULL")
|
||||||
}
|
}
|
||||||
@@ -71,6 +95,12 @@ func IsNull(column Column) *isNull {
|
|||||||
return &isNull{column: column}
|
return &isNull{column: column}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
// It returns false because it cannot be used for restricting a column.
|
||||||
|
func (i isNull) IsRestrictingColumn(col Column) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var _ Condition = (*isNull)(nil)
|
var _ Condition = (*isNull)(nil)
|
||||||
|
|
||||||
type isNotNull struct {
|
type isNotNull struct {
|
||||||
@@ -78,7 +108,7 @@ type isNotNull struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write implements [Condition].
|
// Write implements [Condition].
|
||||||
func (i *isNotNull) Write(builder *StatementBuilder) {
|
func (i isNotNull) Write(builder *StatementBuilder) {
|
||||||
i.column.WriteQualified(builder)
|
i.column.WriteQualified(builder)
|
||||||
builder.WriteString(" IS NOT NULL")
|
builder.WriteString(" IS NOT NULL")
|
||||||
}
|
}
|
||||||
@@ -88,43 +118,122 @@ func IsNotNull(column Column) *isNotNull {
|
|||||||
return &isNotNull{column: column}
|
return &isNotNull{column: column}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
// It returns false because it cannot be used for restricting a column.
|
||||||
|
func (i isNotNull) IsRestrictingColumn(col Column) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var _ Condition = (*isNotNull)(nil)
|
var _ Condition = (*isNotNull)(nil)
|
||||||
|
|
||||||
type valueCondition func(builder *StatementBuilder)
|
type valueCondition struct {
|
||||||
|
write func(builder *StatementBuilder)
|
||||||
|
col Column
|
||||||
|
}
|
||||||
|
|
||||||
// NewTextCondition creates a condition that compares a text column with a value.
|
// NewTextCondition creates a condition that compares a text column with a value.
|
||||||
func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
|
// If you want to use ignore case operations, consider using [NewTextIgnoreCaseCondition].
|
||||||
return valueCondition(func(builder *StatementBuilder) {
|
func NewTextCondition[T Text](col Column, op TextOperation, value T) Condition {
|
||||||
writeTextOperation(builder, col, op, value)
|
return valueCondition{
|
||||||
})
|
col: col,
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[T](builder, col, op, value)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTextIgnoreCaseCondition creates a condition that compares a text column with a value, ignoring case by lowercasing both.
|
||||||
|
func NewTextIgnoreCaseCondition[T Text](col Column, op TextOperation, value T) Condition {
|
||||||
|
return valueCondition{
|
||||||
|
col: col,
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[T](builder, LowerColumn(col), op, LowerValue(value))
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDateCondition creates a condition that compares a numeric column with a value.
|
// NewDateCondition creates a condition that compares a numeric column with a value.
|
||||||
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
|
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
|
||||||
return valueCondition(func(builder *StatementBuilder) {
|
return valueCondition{
|
||||||
writeNumberOperation(builder, col, op, value)
|
col: col,
|
||||||
})
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[V](builder, col, op, value)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDateCondition creates a condition that compares a boolean column with a value.
|
// NewDateCondition creates a condition that compares a boolean column with a value.
|
||||||
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
|
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
|
||||||
return valueCondition(func(builder *StatementBuilder) {
|
return valueCondition{
|
||||||
writeBooleanOperation(builder, col, value)
|
col: col,
|
||||||
})
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBooleanOperation[V](builder, col, value)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBytesCondition creates a condition that compares a BYTEA column with a value.
|
||||||
|
func NewBytesCondition[V Bytes](col Column, op BytesOperation, value any) Condition {
|
||||||
|
return valueCondition{
|
||||||
|
col: col,
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBytesOperation[V](builder, col, op, value)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewColumnCondition creates a condition that compares two columns on equality.
|
// NewColumnCondition creates a condition that compares two columns on equality.
|
||||||
func NewColumnCondition(col1, col2 Column) Condition {
|
func NewColumnCondition(col1, col2 Column) Condition {
|
||||||
return valueCondition(func(builder *StatementBuilder) {
|
return valueCondition{
|
||||||
col1.WriteQualified(builder)
|
col: col1,
|
||||||
builder.WriteString(" = ")
|
write: func(builder *StatementBuilder) {
|
||||||
col2.WriteQualified(builder)
|
col1.WriteQualified(builder)
|
||||||
})
|
builder.WriteString(" = ")
|
||||||
|
col2.WriteQualified(builder)
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write implements [Condition].
|
// Write implements [Condition].
|
||||||
func (c valueCondition) Write(builder *StatementBuilder) {
|
func (c valueCondition) Write(builder *StatementBuilder) {
|
||||||
c(builder)
|
c.write(builder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
func (i valueCondition) IsRestrictingColumn(col Column) bool {
|
||||||
|
return i.col.Equals(col)
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Condition = (*valueCondition)(nil)
|
var _ Condition = (*valueCondition)(nil)
|
||||||
|
|
||||||
|
// existsCondition is a helper to write an EXISTS (SELECT 1 FROM <table> WHERE <condition>) clause.
|
||||||
|
// It implements Condition so it can be composed with other conditions using And/Or.
|
||||||
|
type existsCondition struct {
|
||||||
|
table string
|
||||||
|
condition Condition
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists creates a condition that checks for the existence of rows in a subquery.
|
||||||
|
func Exists(table string, condition Condition) Condition {
|
||||||
|
return &existsCondition{
|
||||||
|
table: table,
|
||||||
|
condition: condition,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements [Condition].
|
||||||
|
func (e existsCondition) Write(builder *StatementBuilder) {
|
||||||
|
builder.WriteString(" EXISTS (SELECT 1 FROM ")
|
||||||
|
builder.WriteString(e.table)
|
||||||
|
builder.WriteString(" WHERE ")
|
||||||
|
e.condition.Write(builder)
|
||||||
|
builder.WriteString(")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRestrictingColumn implements [Condition].
|
||||||
|
func (e existsCondition) IsRestrictingColumn(col Column) bool {
|
||||||
|
// Forward to the inner condition so safety checks (like instance_id presence) can still work.
|
||||||
|
return e.condition.IsRestrictingColumn(col)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Condition = (*existsCondition)(nil)
|
||||||
|
|||||||
248
backend/v3/storage/database/condition_test.go
Normal file
248
backend/v3/storage/database/condition_test.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWrite(t *testing.T) {
|
||||||
|
type want struct {
|
||||||
|
stmt string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
cond Condition
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no and condition",
|
||||||
|
cond: And(),
|
||||||
|
want: want{
|
||||||
|
stmt: "",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one and condition",
|
||||||
|
cond: And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = other_table.column2",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple and condition",
|
||||||
|
cond: And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||||
|
),
|
||||||
|
want: want{
|
||||||
|
stmt: "(table.column1 = other_table.column2 AND table.column3 = other_table.column4)",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no or condition",
|
||||||
|
cond: Or(),
|
||||||
|
want: want{
|
||||||
|
stmt: "",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "one or condition",
|
||||||
|
cond: Or(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = other_table.column2",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple or condition",
|
||||||
|
cond: Or(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||||
|
),
|
||||||
|
want: want{
|
||||||
|
stmt: "(table.column1 = other_table.column2 OR table.column3 = other_table.column4)",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is null condition",
|
||||||
|
cond: IsNull(NewColumn("table", "column1")),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 IS NULL",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is not null condition",
|
||||||
|
cond: IsNotNull(NewColumn("table", "column1")),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 IS NOT NULL",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text condition",
|
||||||
|
cond: NewTextCondition(NewColumn("table", "column1"), TextOperationEqual, "some text"),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = $1",
|
||||||
|
args: []any{"some text"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text ignore case condition",
|
||||||
|
cond: NewTextIgnoreCaseCondition(NewColumn("table", "column1"), TextOperationNotEqual, "some TEXT"),
|
||||||
|
want: want{
|
||||||
|
stmt: "LOWER(table.column1) <> LOWER($1)",
|
||||||
|
args: []any{"some TEXT"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number condition",
|
||||||
|
cond: NewNumberCondition(NewColumn("table", "column1"), NumberOperationEqual, 42),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = $1",
|
||||||
|
args: []any{42},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean condition",
|
||||||
|
cond: NewBooleanCondition(NewColumn("table", "column1"), true),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = $1",
|
||||||
|
args: []any{true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes condition",
|
||||||
|
cond: NewBytesCondition[[]byte](NewColumn("table", "column1"), BytesOperationEqual, []byte{0x01, 0x02, 0x03}),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = $1",
|
||||||
|
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "column condition",
|
||||||
|
cond: NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column1 = other_table.column2",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exists condition",
|
||||||
|
cond: Exists("table", And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column3"), NewColumn("other_table", "column4")),
|
||||||
|
)),
|
||||||
|
want: want{
|
||||||
|
stmt: " EXISTS (SELECT 1 FROM table WHERE (table.column1 = other_table.column2 AND table.column3 = other_table.column4))",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var builder StatementBuilder
|
||||||
|
test.cond.Write(&builder)
|
||||||
|
assert.Equal(t, test.want.stmt, builder.String())
|
||||||
|
require.Len(t, builder.Args(), len(test.want.args))
|
||||||
|
for i, arg := range test.want.args {
|
||||||
|
assert.Equal(t, arg, builder.Args()[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsRestrictingColumn(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
col Column
|
||||||
|
cond Condition
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "and with restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||||
|
),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "and without restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column2"), NewColumn("other_table", "column3")),
|
||||||
|
IsNull(NewColumn("table", "column4")),
|
||||||
|
IsNotNull(NewColumn("table", "column5")),
|
||||||
|
),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "or with restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: Or(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||||
|
),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "or without restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: Or(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||||
|
IsNotNull(NewColumn("table", "column4")),
|
||||||
|
IsNull(NewColumn("table", "column5")),
|
||||||
|
),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is null never restricts",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: IsNull(NewColumn("table", "column1")),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "is not null never restricts",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: IsNotNull(NewColumn("table", "column1")),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exists with restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: Exists("table", And(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column2")),
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||||
|
)),
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exists without restricting column",
|
||||||
|
col: NewColumn("table", "column1"),
|
||||||
|
cond: Exists("table", Or(
|
||||||
|
NewColumnCondition(NewColumn("table", "column1"), NewColumn("other_table", "column3")),
|
||||||
|
IsNotNull(NewColumn("table", "column4")),
|
||||||
|
IsNull(NewColumn("table", "column5")),
|
||||||
|
)),
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
isRestricting := test.cond.IsRestrictingColumn(test.col)
|
||||||
|
assert.Equal(t, test.want, isRestricting)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ type Pool interface {
|
|||||||
QueryExecutor
|
QueryExecutor
|
||||||
Migrator
|
Migrator
|
||||||
|
|
||||||
Acquire(ctx context.Context) (Client, error)
|
Acquire(ctx context.Context) (Connection, error)
|
||||||
Close(ctx context.Context) error
|
Close(ctx context.Context) error
|
||||||
|
|
||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
@@ -22,8 +22,8 @@ type PoolTest interface {
|
|||||||
MigrateTest(ctx context.Context) error
|
MigrateTest(ctx context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client is a single database connection which can be released back to the pool.
|
// Connection is a single database connection which can be released back to the pool.
|
||||||
type Client interface {
|
type Connection interface {
|
||||||
Beginner
|
Beginner
|
||||||
QueryExecutor
|
QueryExecutor
|
||||||
Migrator
|
Migrator
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
// Code generated by MockGen. DO NOT EDIT.
|
// Code generated by MockGen. DO NOT EDIT.
|
||||||
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction)
|
// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Connection,Row,Rows,Transaction)
|
||||||
//
|
//
|
||||||
// Generated by this command:
|
// Generated by this command:
|
||||||
//
|
//
|
||||||
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
|
||||||
//
|
//
|
||||||
|
|
||||||
// Package dbmock is a generated GoMock package.
|
// Package dbmock is a generated GoMock package.
|
||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
type MockPool struct {
|
type MockPool struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockPoolMockRecorder
|
recorder *MockPoolMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockPoolMockRecorder is the mock recorder for MockPool.
|
// MockPoolMockRecorder is the mock recorder for MockPool.
|
||||||
@@ -41,18 +42,18 @@ func (m *MockPool) EXPECT() *MockPoolMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Acquire mocks base method.
|
// Acquire mocks base method.
|
||||||
func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) {
|
func (m *MockPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Acquire", arg0)
|
ret := m.ctrl.Call(m, "Acquire", ctx)
|
||||||
ret0, _ := ret[0].(database.Client)
|
ret0, _ := ret[0].(database.Connection)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Acquire indicates an expected call of Acquire.
|
// Acquire indicates an expected call of Acquire.
|
||||||
func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall {
|
func (mr *MockPoolMockRecorder) Acquire(ctx any) *MockPoolAcquireCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), ctx)
|
||||||
return &MockPoolAcquireCall{Call: call}
|
return &MockPoolAcquireCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,36 +63,36 @@ type MockPoolAcquireCall struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall {
|
func (c *MockPoolAcquireCall) Return(arg0 database.Connection, arg1 error) *MockPoolAcquireCall {
|
||||||
c.Call = c.Call.Return(arg0, arg1)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
|
func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall {
|
func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Connection, error)) *MockPoolAcquireCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin mocks base method.
|
// Begin mocks base method.
|
||||||
func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
|
func (m *MockPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
|
ret := m.ctrl.Call(m, "Begin", ctx, opts)
|
||||||
ret0, _ := ret[0].(database.Transaction)
|
ret0, _ := ret[0].(database.Transaction)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin indicates an expected call of Begin.
|
// Begin indicates an expected call of Begin.
|
||||||
func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall {
|
func (mr *MockPoolMockRecorder) Begin(ctx, opts any) *MockPoolBeginCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), ctx, opts)
|
||||||
return &MockPoolBeginCall{Call: call}
|
return &MockPoolBeginCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,17 +120,17 @@ func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.Transa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close mocks base method.
|
// Close mocks base method.
|
||||||
func (m *MockPool) Close(arg0 context.Context) error {
|
func (m *MockPool) Close(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Close", arg0)
|
ret := m.ctrl.Call(m, "Close", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close indicates an expected call of Close.
|
// Close indicates an expected call of Close.
|
||||||
func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall {
|
func (mr *MockPoolMockRecorder) Close(ctx any) *MockPoolCloseCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), ctx)
|
||||||
return &MockPoolCloseCall{Call: call}
|
return &MockPoolCloseCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,10 +158,10 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec mocks base method.
|
// Exec mocks base method.
|
||||||
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
func (m *MockPool) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||||
@@ -170,9 +171,9 @@ func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec indicates an expected call of Exec.
|
// Exec indicates an expected call of Exec.
|
||||||
func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall {
|
func (mr *MockPoolMockRecorder) Exec(ctx, stmt any, args ...any) *MockPoolExecCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...)
|
||||||
return &MockPoolExecCall{Call: call}
|
return &MockPoolExecCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -201,17 +202,17 @@ func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Migrate mocks base method.
|
// Migrate mocks base method.
|
||||||
func (m *MockPool) Migrate(arg0 context.Context) error {
|
func (m *MockPool) Migrate(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Migrate", arg0)
|
ret := m.ctrl.Call(m, "Migrate", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate indicates an expected call of Migrate.
|
// Migrate indicates an expected call of Migrate.
|
||||||
func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall {
|
func (mr *MockPoolMockRecorder) Migrate(ctx any) *MockPoolMigrateCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), ctx)
|
||||||
return &MockPoolMigrateCall{Call: call}
|
return &MockPoolMigrateCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,11 +239,49 @@ func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPo
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query mocks base method.
|
// Ping mocks base method.
|
||||||
func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
func (m *MockPool) Ping(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
ret := m.ctrl.Call(m, "Ping", ctx)
|
||||||
for _, a := range arg2 {
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping indicates an expected call of Ping.
|
||||||
|
func (mr *MockPoolMockRecorder) Ping(ctx any) *MockPoolPingCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockPool)(nil).Ping), ctx)
|
||||||
|
return &MockPoolPingCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockPoolPingCall wrap *gomock.Call
|
||||||
|
type MockPoolPingCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockPoolPingCall) Return(arg0 error) *MockPoolPingCall {
|
||||||
|
c.Call = c.Call.Return(arg0)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockPoolPingCall) Do(f func(context.Context) error) *MockPoolPingCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockPoolPingCall) DoAndReturn(f func(context.Context) error) *MockPoolPingCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query mocks base method.
|
||||||
|
func (m *MockPool) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
varargs := []any{ctx, stmt}
|
||||||
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||||
@@ -252,9 +291,9 @@ func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (databa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query indicates an expected call of Query.
|
// Query indicates an expected call of Query.
|
||||||
func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall {
|
func (mr *MockPoolMockRecorder) Query(ctx, stmt any, args ...any) *MockPoolQueryCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...)
|
||||||
return &MockPoolQueryCall{Call: call}
|
return &MockPoolQueryCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -283,10 +322,10 @@ func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow mocks base method.
|
// QueryRow mocks base method.
|
||||||
func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
func (m *MockPool) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||||
@@ -295,9 +334,9 @@ func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) data
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow indicates an expected call of QueryRow.
|
// QueryRow indicates an expected call of QueryRow.
|
||||||
func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall {
|
func (mr *MockPoolMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockPoolQueryRowCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...)
|
||||||
return &MockPoolQueryRowCall{Call: call}
|
return &MockPoolQueryRowCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -325,73 +364,74 @@ func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...an
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClient is a mock of Client interface.
|
// MockConnection is a mock of Connection interface.
|
||||||
type MockClient struct {
|
type MockConnection struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockClientMockRecorder
|
recorder *MockConnectionMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientMockRecorder is the mock recorder for MockClient.
|
// MockConnectionMockRecorder is the mock recorder for MockConnection.
|
||||||
type MockClientMockRecorder struct {
|
type MockConnectionMockRecorder struct {
|
||||||
mock *MockClient
|
mock *MockConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMockClient creates a new mock instance.
|
// NewMockConnection creates a new mock instance.
|
||||||
func NewMockClient(ctrl *gomock.Controller) *MockClient {
|
func NewMockConnection(ctrl *gomock.Controller) *MockConnection {
|
||||||
mock := &MockClient{ctrl: ctrl}
|
mock := &MockConnection{ctrl: ctrl}
|
||||||
mock.recorder = &MockClientMockRecorder{mock}
|
mock.recorder = &MockConnectionMockRecorder{mock}
|
||||||
return mock
|
return mock
|
||||||
}
|
}
|
||||||
|
|
||||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
func (m *MockClient) EXPECT() *MockClientMockRecorder {
|
func (m *MockConnection) EXPECT() *MockConnectionMockRecorder {
|
||||||
return m.recorder
|
return m.recorder
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin mocks base method.
|
// Begin mocks base method.
|
||||||
func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) {
|
func (m *MockConnection) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Begin", arg0, arg1)
|
ret := m.ctrl.Call(m, "Begin", ctx, opts)
|
||||||
ret0, _ := ret[0].(database.Transaction)
|
ret0, _ := ret[0].(database.Transaction)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin indicates an expected call of Begin.
|
// Begin indicates an expected call of Begin.
|
||||||
func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall {
|
func (mr *MockConnectionMockRecorder) Begin(ctx, opts any) *MockConnectionBeginCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockConnection)(nil).Begin), ctx, opts)
|
||||||
return &MockClientBeginCall{Call: call}
|
return &MockConnectionBeginCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientBeginCall wrap *gomock.Call
|
// MockConnectionBeginCall wrap *gomock.Call
|
||||||
type MockClientBeginCall struct {
|
type MockConnectionBeginCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall {
|
func (c *MockConnectionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockConnectionBeginCall {
|
||||||
c.Call = c.Call.Return(arg0, arg1)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
|
func (c *MockConnectionBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall {
|
func (c *MockConnectionBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockConnectionBeginCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec mocks base method.
|
// Exec mocks base method.
|
||||||
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
func (m *MockConnection) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||||
@@ -401,79 +441,117 @@ func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec indicates an expected call of Exec.
|
// Exec indicates an expected call of Exec.
|
||||||
func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall {
|
func (mr *MockConnectionMockRecorder) Exec(ctx, stmt any, args ...any) *MockConnectionExecCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockConnection)(nil).Exec), varargs...)
|
||||||
return &MockClientExecCall{Call: call}
|
return &MockConnectionExecCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientExecCall wrap *gomock.Call
|
// MockConnectionExecCall wrap *gomock.Call
|
||||||
type MockClientExecCall struct {
|
type MockConnectionExecCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
|
func (c *MockConnectionExecCall) Return(arg0 int64, arg1 error) *MockConnectionExecCall {
|
||||||
c.Call = c.Call.Return(arg0, arg1)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
|
func (c *MockConnectionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
|
func (c *MockConnectionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockConnectionExecCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate mocks base method.
|
// Migrate mocks base method.
|
||||||
func (m *MockClient) Migrate(arg0 context.Context) error {
|
func (m *MockConnection) Migrate(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Migrate", arg0)
|
ret := m.ctrl.Call(m, "Migrate", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate indicates an expected call of Migrate.
|
// Migrate indicates an expected call of Migrate.
|
||||||
func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall {
|
func (mr *MockConnectionMockRecorder) Migrate(ctx any) *MockConnectionMigrateCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockConnection)(nil).Migrate), ctx)
|
||||||
return &MockClientMigrateCall{Call: call}
|
return &MockConnectionMigrateCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientMigrateCall wrap *gomock.Call
|
// MockConnectionMigrateCall wrap *gomock.Call
|
||||||
type MockClientMigrateCall struct {
|
type MockConnectionMigrateCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall {
|
func (c *MockConnectionMigrateCall) Return(arg0 error) *MockConnectionMigrateCall {
|
||||||
c.Call = c.Call.Return(arg0)
|
c.Call = c.Call.Return(arg0)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall {
|
func (c *MockConnectionMigrateCall) Do(f func(context.Context) error) *MockConnectionMigrateCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall {
|
func (c *MockConnectionMigrateCall) DoAndReturn(f func(context.Context) error) *MockConnectionMigrateCall {
|
||||||
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping mocks base method.
|
||||||
|
func (m *MockConnection) Ping(ctx context.Context) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "Ping", ctx)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ping indicates an expected call of Ping.
|
||||||
|
func (mr *MockConnectionMockRecorder) Ping(ctx any) *MockConnectionPingCall {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockConnection)(nil).Ping), ctx)
|
||||||
|
return &MockConnectionPingCall{Call: call}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockConnectionPingCall wrap *gomock.Call
|
||||||
|
type MockConnectionPingCall struct {
|
||||||
|
*gomock.Call
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rewrite *gomock.Call.Return
|
||||||
|
func (c *MockConnectionPingCall) Return(arg0 error) *MockConnectionPingCall {
|
||||||
|
c.Call = c.Call.Return(arg0)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do rewrite *gomock.Call.Do
|
||||||
|
func (c *MockConnectionPingCall) Do(f func(context.Context) error) *MockConnectionPingCall {
|
||||||
|
c.Call = c.Call.Do(f)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
|
func (c *MockConnectionPingCall) DoAndReturn(f func(context.Context) error) *MockConnectionPingCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query mocks base method.
|
// Query mocks base method.
|
||||||
func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
func (m *MockConnection) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||||
@@ -483,41 +561,41 @@ func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (data
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query indicates an expected call of Query.
|
// Query indicates an expected call of Query.
|
||||||
func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall {
|
func (mr *MockConnectionMockRecorder) Query(ctx, stmt any, args ...any) *MockConnectionQueryCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockConnection)(nil).Query), varargs...)
|
||||||
return &MockClientQueryCall{Call: call}
|
return &MockConnectionQueryCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientQueryCall wrap *gomock.Call
|
// MockConnectionQueryCall wrap *gomock.Call
|
||||||
type MockClientQueryCall struct {
|
type MockConnectionQueryCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall {
|
func (c *MockConnectionQueryCall) Return(arg0 database.Rows, arg1 error) *MockConnectionQueryCall {
|
||||||
c.Call = c.Call.Return(arg0, arg1)
|
c.Call = c.Call.Return(arg0, arg1)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
|
func (c *MockConnectionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall {
|
func (c *MockConnectionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockConnectionQueryCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow mocks base method.
|
// QueryRow mocks base method.
|
||||||
func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
func (m *MockConnection) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||||
@@ -526,70 +604,70 @@ func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) da
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow indicates an expected call of QueryRow.
|
// QueryRow indicates an expected call of QueryRow.
|
||||||
func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall {
|
func (mr *MockConnectionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockConnectionQueryRowCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockConnection)(nil).QueryRow), varargs...)
|
||||||
return &MockClientQueryRowCall{Call: call}
|
return &MockConnectionQueryRowCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientQueryRowCall wrap *gomock.Call
|
// MockConnectionQueryRowCall wrap *gomock.Call
|
||||||
type MockClientQueryRowCall struct {
|
type MockConnectionQueryRowCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall {
|
func (c *MockConnectionQueryRowCall) Return(arg0 database.Row) *MockConnectionQueryRowCall {
|
||||||
c.Call = c.Call.Return(arg0)
|
c.Call = c.Call.Return(arg0)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
|
func (c *MockConnectionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall {
|
func (c *MockConnectionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockConnectionQueryRowCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release mocks base method.
|
// Release mocks base method.
|
||||||
func (m *MockClient) Release(arg0 context.Context) error {
|
func (m *MockConnection) Release(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Release", arg0)
|
ret := m.ctrl.Call(m, "Release", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Release indicates an expected call of Release.
|
// Release indicates an expected call of Release.
|
||||||
func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall {
|
func (mr *MockConnectionMockRecorder) Release(ctx any) *MockConnectionReleaseCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockConnection)(nil).Release), ctx)
|
||||||
return &MockClientReleaseCall{Call: call}
|
return &MockConnectionReleaseCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockClientReleaseCall wrap *gomock.Call
|
// MockConnectionReleaseCall wrap *gomock.Call
|
||||||
type MockClientReleaseCall struct {
|
type MockConnectionReleaseCall struct {
|
||||||
*gomock.Call
|
*gomock.Call
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return rewrite *gomock.Call.Return
|
// Return rewrite *gomock.Call.Return
|
||||||
func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall {
|
func (c *MockConnectionReleaseCall) Return(arg0 error) *MockConnectionReleaseCall {
|
||||||
c.Call = c.Call.Return(arg0)
|
c.Call = c.Call.Return(arg0)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do rewrite *gomock.Call.Do
|
// Do rewrite *gomock.Call.Do
|
||||||
func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall {
|
func (c *MockConnectionReleaseCall) Do(f func(context.Context) error) *MockConnectionReleaseCall {
|
||||||
c.Call = c.Call.Do(f)
|
c.Call = c.Call.Do(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
// DoAndReturn rewrite *gomock.Call.DoAndReturn
|
||||||
func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall {
|
func (c *MockConnectionReleaseCall) DoAndReturn(f func(context.Context) error) *MockConnectionReleaseCall {
|
||||||
c.Call = c.Call.DoAndReturn(f)
|
c.Call = c.Call.DoAndReturn(f)
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
@@ -598,6 +676,7 @@ func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *Mock
|
|||||||
type MockRow struct {
|
type MockRow struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockRowMockRecorder
|
recorder *MockRowMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockRowMockRecorder is the mock recorder for MockRow.
|
// MockRowMockRecorder is the mock recorder for MockRow.
|
||||||
@@ -618,10 +697,10 @@ func (m *MockRow) EXPECT() *MockRowMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Scan mocks base method.
|
// Scan mocks base method.
|
||||||
func (m *MockRow) Scan(arg0 ...any) error {
|
func (m *MockRow) Scan(dest ...any) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{}
|
varargs := []any{}
|
||||||
for _, a := range arg0 {
|
for _, a := range dest {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Scan", varargs...)
|
ret := m.ctrl.Call(m, "Scan", varargs...)
|
||||||
@@ -630,9 +709,9 @@ func (m *MockRow) Scan(arg0 ...any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Scan indicates an expected call of Scan.
|
// Scan indicates an expected call of Scan.
|
||||||
func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall {
|
func (mr *MockRowMockRecorder) Scan(dest ...any) *MockRowScanCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), dest...)
|
||||||
return &MockRowScanCall{Call: call}
|
return &MockRowScanCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -663,6 +742,7 @@ func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall {
|
|||||||
type MockRows struct {
|
type MockRows struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockRowsMockRecorder
|
recorder *MockRowsMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockRowsMockRecorder is the mock recorder for MockRows.
|
// MockRowsMockRecorder is the mock recorder for MockRows.
|
||||||
@@ -797,10 +877,10 @@ func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Scan mocks base method.
|
// Scan mocks base method.
|
||||||
func (m *MockRows) Scan(arg0 ...any) error {
|
func (m *MockRows) Scan(dest ...any) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{}
|
varargs := []any{}
|
||||||
for _, a := range arg0 {
|
for _, a := range dest {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Scan", varargs...)
|
ret := m.ctrl.Call(m, "Scan", varargs...)
|
||||||
@@ -809,9 +889,9 @@ func (m *MockRows) Scan(arg0 ...any) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Scan indicates an expected call of Scan.
|
// Scan indicates an expected call of Scan.
|
||||||
func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall {
|
func (mr *MockRowsMockRecorder) Scan(dest ...any) *MockRowsScanCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), dest...)
|
||||||
return &MockRowsScanCall{Call: call}
|
return &MockRowsScanCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -842,6 +922,7 @@ func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall {
|
|||||||
type MockTransaction struct {
|
type MockTransaction struct {
|
||||||
ctrl *gomock.Controller
|
ctrl *gomock.Controller
|
||||||
recorder *MockTransactionMockRecorder
|
recorder *MockTransactionMockRecorder
|
||||||
|
isgomock struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockTransactionMockRecorder is the mock recorder for MockTransaction.
|
// MockTransactionMockRecorder is the mock recorder for MockTransaction.
|
||||||
@@ -862,18 +943,18 @@ func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Begin mocks base method.
|
// Begin mocks base method.
|
||||||
func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) {
|
func (m *MockTransaction) Begin(ctx context.Context) (database.Transaction, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Begin", arg0)
|
ret := m.ctrl.Call(m, "Begin", ctx)
|
||||||
ret0, _ := ret[0].(database.Transaction)
|
ret0, _ := ret[0].(database.Transaction)
|
||||||
ret1, _ := ret[1].(error)
|
ret1, _ := ret[1].(error)
|
||||||
return ret0, ret1
|
return ret0, ret1
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin indicates an expected call of Begin.
|
// Begin indicates an expected call of Begin.
|
||||||
func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall {
|
func (mr *MockTransactionMockRecorder) Begin(ctx any) *MockTransactionBeginCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), ctx)
|
||||||
return &MockTransactionBeginCall{Call: call}
|
return &MockTransactionBeginCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -901,17 +982,17 @@ func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Commit mocks base method.
|
// Commit mocks base method.
|
||||||
func (m *MockTransaction) Commit(arg0 context.Context) error {
|
func (m *MockTransaction) Commit(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Commit", arg0)
|
ret := m.ctrl.Call(m, "Commit", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit indicates an expected call of Commit.
|
// Commit indicates an expected call of Commit.
|
||||||
func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall {
|
func (mr *MockTransactionMockRecorder) Commit(ctx any) *MockTransactionCommitCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), ctx)
|
||||||
return &MockTransactionCommitCall{Call: call}
|
return &MockTransactionCommitCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -939,17 +1020,17 @@ func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// End mocks base method.
|
// End mocks base method.
|
||||||
func (m *MockTransaction) End(arg0 context.Context, arg1 error) error {
|
func (m *MockTransaction) End(ctx context.Context, err error) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "End", arg0, arg1)
|
ret := m.ctrl.Call(m, "End", ctx, err)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// End indicates an expected call of End.
|
// End indicates an expected call of End.
|
||||||
func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall {
|
func (mr *MockTransactionMockRecorder) End(ctx, err any) *MockTransactionEndCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), ctx, err)
|
||||||
return &MockTransactionEndCall{Call: call}
|
return &MockTransactionEndCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -977,10 +1058,10 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec mocks base method.
|
// Exec mocks base method.
|
||||||
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
|
func (m *MockTransaction) Exec(ctx context.Context, stmt string, args ...any) (int64, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Exec", varargs...)
|
ret := m.ctrl.Call(m, "Exec", varargs...)
|
||||||
@@ -990,9 +1071,9 @@ func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec indicates an expected call of Exec.
|
// Exec indicates an expected call of Exec.
|
||||||
func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall {
|
func (mr *MockTransactionMockRecorder) Exec(ctx, stmt any, args ...any) *MockTransactionExecCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...)
|
||||||
return &MockTransactionExecCall{Call: call}
|
return &MockTransactionExecCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -1021,10 +1102,10 @@ func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ..
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query mocks base method.
|
// Query mocks base method.
|
||||||
func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) {
|
func (m *MockTransaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "Query", varargs...)
|
ret := m.ctrl.Call(m, "Query", varargs...)
|
||||||
@@ -1034,9 +1115,9 @@ func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Query indicates an expected call of Query.
|
// Query indicates an expected call of Query.
|
||||||
func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall {
|
func (mr *MockTransactionMockRecorder) Query(ctx, stmt any, args ...any) *MockTransactionQueryCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...)
|
||||||
return &MockTransactionQueryCall{Call: call}
|
return &MockTransactionQueryCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -1065,10 +1146,10 @@ func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, .
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow mocks base method.
|
// QueryRow mocks base method.
|
||||||
func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row {
|
func (m *MockTransaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
varargs := []any{arg0, arg1}
|
varargs := []any{ctx, stmt}
|
||||||
for _, a := range arg2 {
|
for _, a := range args {
|
||||||
varargs = append(varargs, a)
|
varargs = append(varargs, a)
|
||||||
}
|
}
|
||||||
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
ret := m.ctrl.Call(m, "QueryRow", varargs...)
|
||||||
@@ -1077,9 +1158,9 @@ func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...an
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow indicates an expected call of QueryRow.
|
// QueryRow indicates an expected call of QueryRow.
|
||||||
func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall {
|
func (mr *MockTransactionMockRecorder) QueryRow(ctx, stmt any, args ...any) *MockTransactionQueryRowCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
varargs := append([]any{arg0, arg1}, arg2...)
|
varargs := append([]any{ctx, stmt}, args...)
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...)
|
||||||
return &MockTransactionQueryRowCall{Call: call}
|
return &MockTransactionQueryRowCall{Call: call}
|
||||||
}
|
}
|
||||||
@@ -1108,17 +1189,17 @@ func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Rollback mocks base method.
|
// Rollback mocks base method.
|
||||||
func (m *MockTransaction) Rollback(arg0 context.Context) error {
|
func (m *MockTransaction) Rollback(ctx context.Context) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Rollback", arg0)
|
ret := m.ctrl.Call(m, "Rollback", ctx)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback indicates an expected call of Rollback.
|
// Rollback indicates an expected call of Rollback.
|
||||||
func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall {
|
func (mr *MockTransactionMockRecorder) Rollback(ctx any) *MockTransactionRollbackCall {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0)
|
call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), ctx)
|
||||||
return &MockTransactionRollbackCall{Call: call}
|
return &MockTransactionRollbackCall{Call: call}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,15 @@ type pgxConn struct {
|
|||||||
*pgxpool.Conn
|
*pgxpool.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ database.Client = (*pgxConn)(nil)
|
var _ database.Connection = (*pgxConn)(nil)
|
||||||
|
|
||||||
// Release implements [database.Client].
|
// Release implements [database.Connection].
|
||||||
func (c *pgxConn) Release(_ context.Context) error {
|
func (c *pgxConn) Release(_ context.Context) error {
|
||||||
c.Conn.Release()
|
c.Conn.Release()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin implements [database.Client].
|
// Begin implements [database.Connection].
|
||||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -30,8 +30,8 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions)
|
|||||||
return &Transaction{tx}, nil
|
return &Transaction{tx}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query implements sql.Client.
|
// Query implements [database.Connection].
|
||||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
// Subtle: this method shadows the method (*Conn).Query of [pgxConn.Conn].
|
||||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -40,14 +40,14 @@ func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.
|
|||||||
return &Rows{rows}, nil
|
return &Rows{rows}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow implements sql.Client.
|
// QueryRow implements [database.Connection].
|
||||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
|
||||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||||
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
|
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec implements [database.Pool].
|
// QueryRow implements [database.Connection].
|
||||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
// Subtle: this method shadows the method (*Conn).QueryRow of [pgxConn.Conn].
|
||||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||||
res, err := c.Conn.Exec(ctx, sql, args...)
|
res, err := c.Conn.Exec(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package postgres
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
@@ -16,6 +17,18 @@ func wrapError(err error) error {
|
|||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
return database.NewNoRowFoundError(err)
|
return database.NewNoRowFoundError(err)
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, pgx.ErrTooManyRows) {
|
||||||
|
return database.NewMultipleRowsFoundError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scany only exports its errors as strings
|
||||||
|
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
|
||||||
|
return database.NewMultipleRowsFoundError(err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
|
||||||
|
return database.NewScanError(err)
|
||||||
|
}
|
||||||
|
|
||||||
var pgxErr *pgconn.PgError
|
var pgxErr *pgconn.PgError
|
||||||
if !errors.As(err, &pgxErr) {
|
if !errors.As(err, &pgxErr) {
|
||||||
return database.NewUnknownError(err)
|
return database.NewUnknownError(err)
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package migration
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed 004_correct_set_updated_at/up.sql
|
||||||
|
up004CorrectSetUpdatedAt string
|
||||||
|
//go:embed 004_correct_set_updated_at/down.sql
|
||||||
|
down004CorrectSetUpdatedAt string
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
registerSQLMigration(4, up004CorrectSetUpdatedAt, down004CorrectSetUpdatedAt)
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||||
|
BEFORE UPDATE ON zitadel.instances
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||||
|
BEFORE UPDATE ON zitadel.organizations
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
|
||||||
|
BEFORE UPDATE ON zitadel.instance_domains
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
|
||||||
|
BEFORE UPDATE ON zitadel.org_domains
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||||
|
BEFORE UPDATE ON zitadel.instances
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (NEW.updated_at IS NULL)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trigger_set_updated_at
|
||||||
|
BEFORE UPDATE ON zitadel.organizations
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (NEW.updated_at IS NULL)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trg_set_updated_at_instance_domains
|
||||||
|
BEFORE UPDATE ON zitadel.instance_domains
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (NEW.updated_at IS NULL)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
|
|
||||||
|
CREATE OR REPLACE TRIGGER trg_set_updated_at_org_domains
|
||||||
|
BEFORE UPDATE ON zitadel.org_domains
|
||||||
|
FOR EACH ROW
|
||||||
|
WHEN (NEW.updated_at IS NULL)
|
||||||
|
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||||
@@ -22,8 +22,8 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Acquire implements [database.Pool].
|
// Acquire implements [database.Pool].
|
||||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
func (p *pgxPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||||
conn, err := c.Pool.Acquire(ctx)
|
conn, err := p.Pool.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -32,8 +32,8 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
|||||||
|
|
||||||
// Query implements [database.Pool].
|
// Query implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
func (p *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
rows, err := p.Pool.Query(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -42,14 +42,14 @@ func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.
|
|||||||
|
|
||||||
// QueryRow implements [database.Pool].
|
// QueryRow implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
func (p *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||||
return &Row{c.Pool.QueryRow(ctx, sql, args...)}
|
return &Row{p.Pool.QueryRow(ctx, sql, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec implements [database.Pool].
|
// Exec implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
func (p *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||||
res, err := c.Pool.Exec(ctx, sql, args...)
|
res, err := p.Pool.Exec(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, wrapError(err)
|
return 0, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -57,8 +57,8 @@ func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Begin implements [database.Pool].
|
// Begin implements [database.Pool].
|
||||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
func (p *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
tx, err := p.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -66,23 +66,23 @@ func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close implements [database.Pool].
|
// Close implements [database.Pool].
|
||||||
func (c *pgxPool) Close(_ context.Context) error {
|
func (p *pgxPool) Close(_ context.Context) error {
|
||||||
c.Pool.Close()
|
p.Pool.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ping implements [database.Pool].
|
// Ping implements [database.Pool].
|
||||||
func (c *pgxPool) Ping(ctx context.Context) error {
|
func (p *pgxPool) Ping(ctx context.Context) error {
|
||||||
return wrapError(c.Pool.Ping(ctx))
|
return wrapError(p.Pool.Ping(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate implements [database.Migrator].
|
// Migrate implements [database.Migrator].
|
||||||
func (c *pgxPool) Migrate(ctx context.Context) error {
|
func (p *pgxPool) Migrate(ctx context.Context) error {
|
||||||
if isMigrated {
|
if isMigrated {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := c.Pool.Acquire(ctx)
|
client, err := p.Pool.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -93,8 +93,8 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Migrate implements [database.PoolTest].
|
// Migrate implements [database.PoolTest].
|
||||||
func (c *pgxPool) MigrateTest(ctx context.Context) error {
|
func (p *pgxPool) MigrateTest(ctx context.Context) error {
|
||||||
client, err := c.Pool.Acquire(ctx)
|
client, err := p.Pool.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,18 +11,18 @@ type sqlConn struct {
|
|||||||
*sql.Conn
|
*sql.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func SQLConn(conn *sql.Conn) database.Client {
|
func SQLConn(conn *sql.Conn) database.Connection {
|
||||||
return &sqlConn{Conn: conn}
|
return &sqlConn{Conn: conn}
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ database.Client = (*sqlConn)(nil)
|
var _ database.Connection = (*sqlConn)(nil)
|
||||||
|
|
||||||
// Release implements [database.Client].
|
// Release implements [database.Connection].
|
||||||
func (c *sqlConn) Release(_ context.Context) error {
|
func (c *sqlConn) Release(_ context.Context) error {
|
||||||
return c.Close()
|
return c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Begin implements [database.Client].
|
// Begin implements [database.Connection].
|
||||||
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package sql
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
@@ -19,6 +20,15 @@ func wrapError(err error) error {
|
|||||||
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
|
||||||
return database.NewNoRowFoundError(err)
|
return database.NewNoRowFoundError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// scany only exports its errors as strings
|
||||||
|
if strings.HasPrefix(err.Error(), "scany: expected 1 row, got: ") {
|
||||||
|
return database.NewMultipleRowsFoundError(err)
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(err.Error(), "scany:") || strings.HasPrefix(err.Error(), "scanning:") {
|
||||||
|
return database.NewScanError(err)
|
||||||
|
}
|
||||||
|
|
||||||
var pgxErr *pgconn.PgError
|
var pgxErr *pgconn.PgError
|
||||||
if !errors.As(err, &pgxErr) {
|
if !errors.As(err, &pgxErr) {
|
||||||
return database.NewUnknownError(err)
|
return database.NewUnknownError(err)
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func SQLPool(db *sql.DB) *sqlPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Acquire implements [database.Pool].
|
// Acquire implements [database.Pool].
|
||||||
func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
|
func (p *sqlPool) Acquire(ctx context.Context) (database.Connection, error) {
|
||||||
conn, err := c.Conn(ctx)
|
conn, err := p.Conn(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -30,9 +30,9 @@ func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
|
|||||||
|
|
||||||
// Query implements [database.Pool].
|
// Query implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||||
func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
func (p *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||||
//nolint:rowserrcheck // Rows.Close is called by the caller
|
//nolint:rowserrcheck // Rows.Close is called by the caller
|
||||||
rows, err := c.QueryContext(ctx, sql, args...)
|
rows, err := p.QueryContext(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -41,14 +41,14 @@ func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.
|
|||||||
|
|
||||||
// QueryRow implements [database.Pool].
|
// QueryRow implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||||
func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
func (p *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||||
return &Row{c.QueryRowContext(ctx, sql, args...)}
|
return &Row{p.QueryRowContext(ctx, sql, args...)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec implements [database.Pool].
|
// Exec implements [database.Pool].
|
||||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||||
func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
func (p *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||||
res, err := c.ExecContext(ctx, sql, args...)
|
res, err := p.ExecContext(ctx, sql, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, wrapError(err)
|
return 0, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -56,8 +56,8 @@ func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Begin implements [database.Pool].
|
// Begin implements [database.Pool].
|
||||||
func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
func (p *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
tx, err := p.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, wrapError(err)
|
return nil, wrapError(err)
|
||||||
}
|
}
|
||||||
@@ -65,16 +65,16 @@ func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ping implements [database.Pool].
|
// Ping implements [database.Pool].
|
||||||
func (c *sqlPool) Ping(ctx context.Context) error {
|
func (p *sqlPool) Ping(ctx context.Context) error {
|
||||||
return wrapError(c.PingContext(ctx))
|
return wrapError(p.PingContext(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close implements [database.Pool].
|
// Close implements [database.Pool].
|
||||||
func (c *sqlPool) Close(_ context.Context) error {
|
func (p *sqlPool) Close(_ context.Context) error {
|
||||||
return c.DB.Close()
|
return p.DB.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate implements [database.Migrator].
|
// Migrate implements [database.Migrator].
|
||||||
func (c *sqlPool) Migrate(ctx context.Context) error {
|
func (p *sqlPool) Migrate(ctx context.Context) error {
|
||||||
return ErrMigrate
|
return ErrMigrate
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,34 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNoChanges = errors.New("update must contain a change")
|
var (
|
||||||
|
ErrNoChanges = errors.New("update must contain a change")
|
||||||
|
)
|
||||||
|
|
||||||
|
type MissingConditionError struct {
|
||||||
|
col Column
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMissingConditionError(col Column) error {
|
||||||
|
return &MissingConditionError{
|
||||||
|
col: col,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *MissingConditionError) Error() string {
|
||||||
|
var builder StatementBuilder
|
||||||
|
builder.WriteString("missing condition for column")
|
||||||
|
if e.col != nil {
|
||||||
|
builder.WriteString(" on ")
|
||||||
|
e.col.WriteQualified(&builder)
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *MissingConditionError) Is(target error) bool {
|
||||||
|
_, ok := target.(*MissingConditionError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
// NoRowFoundError is returned when QueryRow does not find any row.
|
// NoRowFoundError is returned when QueryRow does not find any row.
|
||||||
// It wraps the dialect specific original error to provide more context.
|
// It wraps the dialect specific original error to provide more context.
|
||||||
@@ -20,6 +47,9 @@ func NewNoRowFoundError(original error) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *NoRowFoundError) Error() string {
|
func (e *NoRowFoundError) Error() string {
|
||||||
|
if e.original != nil {
|
||||||
|
return fmt.Sprintf("no row found: %v", e.original)
|
||||||
|
}
|
||||||
return "no row found"
|
return "no row found"
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,18 +66,19 @@ func (e *NoRowFoundError) Unwrap() error {
|
|||||||
// It wraps the dialect specific original error to provide more context.
|
// It wraps the dialect specific original error to provide more context.
|
||||||
type MultipleRowsFoundError struct {
|
type MultipleRowsFoundError struct {
|
||||||
original error
|
original error
|
||||||
count int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMultipleRowsFoundError(original error, count int) error {
|
func NewMultipleRowsFoundError(original error) error {
|
||||||
return &MultipleRowsFoundError{
|
return &MultipleRowsFoundError{
|
||||||
original: original,
|
original: original,
|
||||||
count: count,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *MultipleRowsFoundError) Error() string {
|
func (e *MultipleRowsFoundError) Error() string {
|
||||||
return fmt.Sprintf("multiple rows found: %d", e.count)
|
if e.original != nil {
|
||||||
|
return fmt.Sprintf("multiple rows found: %v", e.original)
|
||||||
|
}
|
||||||
|
return "multiple rows found"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *MultipleRowsFoundError) Is(target error) bool {
|
func (e *MultipleRowsFoundError) Is(target error) bool {
|
||||||
@@ -77,8 +108,8 @@ type IntegrityViolationError struct {
|
|||||||
original error
|
original error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
|
func newIntegrityViolationError(typ IntegrityType, table, constraint string, original error) IntegrityViolationError {
|
||||||
return &IntegrityViolationError{
|
return IntegrityViolationError{
|
||||||
integrityType: typ,
|
integrityType: typ,
|
||||||
table: table,
|
table: table,
|
||||||
constraint: constraint,
|
constraint: constraint,
|
||||||
@@ -87,7 +118,10 @@ func NewIntegrityViolationError(typ IntegrityType, table, constraint string, ori
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *IntegrityViolationError) Error() string {
|
func (e *IntegrityViolationError) Error() string {
|
||||||
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
|
if e.original != nil {
|
||||||
|
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q)", e.integrityType, e.table, e.constraint)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *IntegrityViolationError) Is(target error) bool {
|
func (e *IntegrityViolationError) Is(target error) bool {
|
||||||
@@ -108,12 +142,7 @@ type CheckError struct {
|
|||||||
|
|
||||||
func NewCheckError(table, constraint string, original error) error {
|
func NewCheckError(table, constraint string, original error) error {
|
||||||
return &CheckError{
|
return &CheckError{
|
||||||
IntegrityViolationError: IntegrityViolationError{
|
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeCheck, table, constraint, original),
|
||||||
integrityType: IntegrityTypeCheck,
|
|
||||||
table: table,
|
|
||||||
constraint: constraint,
|
|
||||||
original: original,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,12 +164,7 @@ type UniqueError struct {
|
|||||||
|
|
||||||
func NewUniqueError(table, constraint string, original error) error {
|
func NewUniqueError(table, constraint string, original error) error {
|
||||||
return &UniqueError{
|
return &UniqueError{
|
||||||
IntegrityViolationError: IntegrityViolationError{
|
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeUnique, table, constraint, original),
|
||||||
integrityType: IntegrityTypeUnique,
|
|
||||||
table: table,
|
|
||||||
constraint: constraint,
|
|
||||||
original: original,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,12 +186,7 @@ type ForeignKeyError struct {
|
|||||||
|
|
||||||
func NewForeignKeyError(table, constraint string, original error) error {
|
func NewForeignKeyError(table, constraint string, original error) error {
|
||||||
return &ForeignKeyError{
|
return &ForeignKeyError{
|
||||||
IntegrityViolationError: IntegrityViolationError{
|
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeForeign, table, constraint, original),
|
||||||
integrityType: IntegrityTypeForeign,
|
|
||||||
table: table,
|
|
||||||
constraint: constraint,
|
|
||||||
original: original,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,12 +208,7 @@ type NotNullError struct {
|
|||||||
|
|
||||||
func NewNotNullError(table, constraint string, original error) error {
|
func NewNotNullError(table, constraint string, original error) error {
|
||||||
return &NotNullError{
|
return &NotNullError{
|
||||||
IntegrityViolationError: IntegrityViolationError{
|
IntegrityViolationError: newIntegrityViolationError(IntegrityTypeNotNull, table, constraint, original),
|
||||||
integrityType: IntegrityTypeNotNull,
|
|
||||||
table: table,
|
|
||||||
constraint: constraint,
|
|
||||||
original: original,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,6 +221,31 @@ func (e *NotNullError) Unwrap() error {
|
|||||||
return &e.IntegrityViolationError
|
return &e.IntegrityViolationError
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanError is returned when scanning rows into objects failed.
|
||||||
|
// It wraps the original error to provide more context.
|
||||||
|
type ScanError struct {
|
||||||
|
original error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScanError(original error) error {
|
||||||
|
return &ScanError{
|
||||||
|
original: original,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ScanError) Error() string {
|
||||||
|
return fmt.Sprintf("Scan error: %v", e.original)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ScanError) Is(target error) bool {
|
||||||
|
_, ok := target.(*ScanError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ScanError) Unwrap() error {
|
||||||
|
return e.original
|
||||||
|
}
|
||||||
|
|
||||||
// UnknownError is returned when an unknown error occurs.
|
// UnknownError is returned when an unknown error occurs.
|
||||||
// It wraps the dialect specific original error to provide more context.
|
// It wraps the dialect specific original error to provide more context.
|
||||||
// It is used to indicate that an error occurred that does not fit into any of the other categories.
|
// It is used to indicate that an error occurred that does not fit into any of the other categories.
|
||||||
|
|||||||
304
backend/v3/storage/database/errors_test.go
Normal file
304
backend/v3/storage/database/errors_test.go
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestError(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing condition without column",
|
||||||
|
err: NewMissingConditionError(nil),
|
||||||
|
want: "missing condition for column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing condition with column",
|
||||||
|
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||||
|
want: "missing condition for column on table.column",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no row found without original error",
|
||||||
|
err: NewNoRowFoundError(nil),
|
||||||
|
want: "no row found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no row found with original error",
|
||||||
|
err: NewNoRowFoundError(errors.New("original error")),
|
||||||
|
want: "no row found: original error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rows found without original error",
|
||||||
|
err: NewMultipleRowsFoundError(nil),
|
||||||
|
want: "multiple rows found",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rows found with original error",
|
||||||
|
err: NewMultipleRowsFoundError(errors.New("original error")),
|
||||||
|
want: "multiple rows found: original error",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation without original error",
|
||||||
|
err: NewCheckError("table", "constraint", nil),
|
||||||
|
want: `integrity violation of type "check" on "table" (constraint: "constraint")`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation with original error",
|
||||||
|
err: NewCheckError("table", "constraint", errors.New("original error")),
|
||||||
|
want: `integrity violation of type "check" on "table" (constraint: "constraint"): original error`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation without original error",
|
||||||
|
err: NewUniqueError("table", "constraint", nil),
|
||||||
|
want: `integrity violation of type "unique" on "table" (constraint: "constraint")`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation with original error",
|
||||||
|
err: NewUniqueError("table", "constraint", errors.New("original error")),
|
||||||
|
want: `integrity violation of type "unique" on "table" (constraint: "constraint"): original error`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation without original error",
|
||||||
|
err: NewForeignKeyError("table", "constraint", nil),
|
||||||
|
want: `integrity violation of type "foreign" on "table" (constraint: "constraint")`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation with original error",
|
||||||
|
err: NewForeignKeyError("table", "constraint", errors.New("original error")),
|
||||||
|
want: `integrity violation of type "foreign" on "table" (constraint: "constraint"): original error`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation without original error",
|
||||||
|
err: NewNotNullError("table", "constraint", nil),
|
||||||
|
want: `integrity violation of type "not null" on "table" (constraint: "constraint")`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation with original error",
|
||||||
|
err: NewNotNullError("table", "constraint", errors.New("original error")),
|
||||||
|
want: `integrity violation of type "not null" on "table" (constraint: "constraint"): original error`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown error",
|
||||||
|
err: NewUnknownError(errors.New("original error")),
|
||||||
|
want: `unknown database error: original error`,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.want, test.err.Error())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnwrap(t *testing.T) {
|
||||||
|
originalErr := errors.New("original error")
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing condition without column",
|
||||||
|
err: NewMissingConditionError(nil),
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing condition with column",
|
||||||
|
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no row found without original error",
|
||||||
|
err: NewNoRowFoundError(nil),
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no row found with original error",
|
||||||
|
err: NewNoRowFoundError(errors.New("original error")),
|
||||||
|
want: originalErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rows found without original error",
|
||||||
|
err: NewMultipleRowsFoundError(nil),
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rows found with original error",
|
||||||
|
err: NewMultipleRowsFoundError(originalErr),
|
||||||
|
want: originalErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation without original error",
|
||||||
|
err: NewCheckError("table", "constraint", nil),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeCheck,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation with original error",
|
||||||
|
err: NewCheckError("table", "constraint", originalErr),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeCheck,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: originalErr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation without original error",
|
||||||
|
err: NewUniqueError("table", "constraint", nil),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeUnique,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation with original error",
|
||||||
|
err: NewUniqueError("table", "constraint", originalErr),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeUnique,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: originalErr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation without original error",
|
||||||
|
err: NewForeignKeyError("table", "constraint", nil),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeForeign,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation with original error",
|
||||||
|
err: NewForeignKeyError("table", "constraint", originalErr),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeForeign,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: originalErr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation without original error",
|
||||||
|
err: NewNotNullError("table", "constraint", nil),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeNotNull,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation with original error",
|
||||||
|
err: NewNotNullError("table", "constraint", originalErr),
|
||||||
|
want: &IntegrityViolationError{
|
||||||
|
integrityType: IntegrityTypeNotNull,
|
||||||
|
table: "table",
|
||||||
|
constraint: "constraint",
|
||||||
|
original: originalErr,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unwrap integrity violation error",
|
||||||
|
err: errors.Unwrap(NewNotNullError("table", "constraint", originalErr)),
|
||||||
|
want: originalErr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown error",
|
||||||
|
err: NewUnknownError(originalErr),
|
||||||
|
want: originalErr,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.want, errors.Unwrap(test.err))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIs(t *testing.T) {
|
||||||
|
originalErr := errors.New("original error")
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "missing condition",
|
||||||
|
err: NewMissingConditionError(NewColumn("table", "column")),
|
||||||
|
want: new(MissingConditionError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no row found",
|
||||||
|
err: NewNoRowFoundError(errors.New("original error")),
|
||||||
|
want: new(NoRowFoundError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple rows found",
|
||||||
|
err: NewMultipleRowsFoundError(originalErr),
|
||||||
|
want: new(MultipleRowsFoundError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation is for integrity",
|
||||||
|
err: NewCheckError("table", "constraint", nil),
|
||||||
|
want: new(IntegrityViolationError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "check violation is check violation",
|
||||||
|
err: NewCheckError("table", "constraint", nil),
|
||||||
|
want: new(CheckError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation is for integrity",
|
||||||
|
err: NewUniqueError("table", "constraint", nil),
|
||||||
|
want: new(IntegrityViolationError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unique violation is unique violation",
|
||||||
|
err: NewUniqueError("table", "constraint", nil),
|
||||||
|
want: new(UniqueError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation is for integrity",
|
||||||
|
err: NewForeignKeyError("table", "constraint", nil),
|
||||||
|
want: new(IntegrityViolationError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key violation is foreign key violation",
|
||||||
|
err: NewForeignKeyError("table", "constraint", nil),
|
||||||
|
want: new(ForeignKeyError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation is for integrity",
|
||||||
|
err: NewNotNullError("table", "constraint", nil),
|
||||||
|
want: new(IntegrityViolationError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not null violation is not null violation",
|
||||||
|
err: NewNotNullError("table", "constraint", nil),
|
||||||
|
want: new(NotNullError),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown error",
|
||||||
|
err: NewUnknownError(originalErr),
|
||||||
|
want: new(UnknownError),
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
assert.ErrorIs(t, test.err, test.want)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,8 +21,8 @@ import (
|
|||||||
func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||||
instance := integration.NewInstance(CTX)
|
instance := integration.NewInstance(CTX)
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceRepo := repository.InstanceRepository()
|
||||||
instanceDomainRepo := instanceRepo.Domains(true)
|
instanceDomainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
|
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
|
||||||
@@ -36,7 +36,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Wait for instance to be created
|
// Wait for instance to be created
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
_, err := instanceRepo.Get(CTX,
|
_, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
|
database.WithCondition(instanceRepo.IDCondition(instance.ID())),
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -66,7 +66,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Test that domain add reduces
|
// Test that domain add reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -96,7 +96,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
// first we change the primary domain to something else
|
// first we change the primary domain to something else
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -125,7 +125,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Wait for domain to be created
|
// Wait for domain to be created
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -151,7 +151,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Test that set primary reduces
|
// Test that set primary reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -180,7 +180,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Wait for domain to be created and verify it exists
|
// Wait for domain to be created and verify it exists
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
_, err := instanceDomainRepo.Get(CTX,
|
_, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -202,7 +202,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Test that domain remove reduces
|
// Test that domain remove reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -241,7 +241,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Test that domain add reduces
|
// Test that domain add reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -271,7 +271,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Wait for domain to be created and verify it exists
|
// Wait for domain to be created and verify it exists
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
_, err := instanceDomainRepo.Get(CTX,
|
_, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
@@ -293,7 +293,7 @@ func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
|||||||
// Test that domain remove reduces
|
// Test that domain remove reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := instanceDomainRepo.Get(CTX,
|
domain, err := instanceDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_TestInstanceReduces(t *testing.T) {
|
func TestServer_TestInstanceReduces(t *testing.T) {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
|
||||||
t.Run("test instance add reduces", func(t *testing.T) {
|
t.Run("test instance add reduces", func(t *testing.T) {
|
||||||
instanceName := gofakeit.Name()
|
instanceName := gofakeit.Name()
|
||||||
@@ -46,7 +46,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
instance, err := instanceRepo.Get(CTX,
|
instance, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -92,7 +92,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
|||||||
// check instance exists
|
// check instance exists
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
instance, err := instanceRepo.Get(CTX,
|
instance, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -110,7 +110,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
instance, err := instanceRepo.Get(CTX,
|
instance, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -137,7 +137,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
|||||||
// check instance exists
|
// check instance exists
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
instance, err := instanceRepo.Get(CTX,
|
instance, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -151,7 +151,7 @@ func TestServer_TestInstanceReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
instance, err := instanceRepo.Get(CTX,
|
instance, err := instanceRepo.Get(CTX, pool,
|
||||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||||
)
|
)
|
||||||
// event instance.removed
|
// event instance.removed
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
orgRepo := repository.OrganizationRepository()
|
||||||
orgDomainRepo := orgRepo.Domains(false)
|
orgDomainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
|
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
|
||||||
@@ -37,8 +37,13 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
|||||||
// Wait for org to be created
|
// Wait for org to be created
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
_, err := orgRepo.Get(CTX,
|
_, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(orgRepo.IDCondition(org.GetId())),
|
database.WithCondition(
|
||||||
|
database.And(
|
||||||
|
orgRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||||
|
orgRepo.IDCondition(org.GetId()),
|
||||||
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}, retryDuration, tick)
|
}, retryDuration, tick)
|
||||||
@@ -68,7 +73,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
|||||||
// Test that domain add reduces
|
// Test that domain add reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
gottenDomain, err := orgDomainRepo.Get(CTX,
|
gottenDomain, err := orgDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||||
@@ -107,7 +112,7 @@ func TestServer_TestOrgDomainReduces(t *testing.T) {
|
|||||||
// Test that domain remove reduces
|
// Test that domain remove reduces
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
domain, err := orgDomainRepo.Get(CTX,
|
domain, err := orgDomainRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
|
|
||||||
func TestServer_TestOrganizationReduces(t *testing.T) {
|
func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||||
instanceID := Instance.ID()
|
instanceID := Instance.ID()
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
|
||||||
t.Run("test org add reduces", func(t *testing.T) {
|
t.Run("test org add reduces", func(t *testing.T) {
|
||||||
beforeCreate := time.Now()
|
beforeCreate := time.Now()
|
||||||
@@ -42,7 +42,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
|
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(org.GetId()),
|
orgRepo.IDCondition(org.GetId()),
|
||||||
@@ -92,7 +92,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
@@ -137,7 +137,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
@@ -177,11 +177,11 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
orgRepo := repository.OrganizationRepository()
|
||||||
// 3. check org deactivated
|
// 3. check org deactivated
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
@@ -203,7 +203,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
@@ -230,10 +230,10 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// 2. check org retrievable
|
// 2. check org retrievable
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
orgRepo := repository.OrganizationRepository()
|
||||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
_, err := orgRepo.Get(CTX,
|
_, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
@@ -252,7 +252,7 @@ func TestServer_TestOrganizationReduces(t *testing.T) {
|
|||||||
|
|
||||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
organization, err := orgRepo.Get(CTX,
|
organization, err := orgRepo.Get(CTX, pool,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
orgRepo.IDCondition(organization.Id),
|
orgRepo.IDCondition(organization.Id),
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Connection,Row,Rows,Transaction
|
||||||
|
|||||||
@@ -6,16 +6,40 @@ import (
|
|||||||
"golang.org/x/exp/constraints"
|
"golang.org/x/exp/constraints"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Value interface {
|
type wrappedValue[V Value] struct {
|
||||||
Boolean | Number | Text | Instruction
|
value V
|
||||||
|
fn function
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LowerValue[T Value](v T) wrappedValue[T] {
|
||||||
|
return wrappedValue[T]{value: v, fn: functionLower}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SHA256Value[T Value](v T) wrappedValue[T] {
|
||||||
|
return wrappedValue[T]{value: v, fn: functionSHA256}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b wrappedValue[V]) WriteArg(builder *StatementBuilder) {
|
||||||
|
builder.Grow(len(b.fn) + 5)
|
||||||
|
builder.WriteString(string(b.fn))
|
||||||
|
builder.WriteRune('(')
|
||||||
|
builder.WriteArg(b.value)
|
||||||
|
builder.WriteRune(')')
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ argWriter = (*wrappedValue[string])(nil)
|
||||||
|
|
||||||
|
type Value interface {
|
||||||
|
Boolean | Number | Text | Instruction | Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:generate enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go
|
||||||
type Operation interface {
|
type Operation interface {
|
||||||
BooleanOperation | NumberOperation | TextOperation
|
NumberOperation | TextOperation | BytesOperation
|
||||||
}
|
}
|
||||||
|
|
||||||
type Text interface {
|
type Text interface {
|
||||||
~string | ~[]byte
|
~string | Bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
// TextOperation are operations that can be performed on text values.
|
// TextOperation are operations that can be performed on text values.
|
||||||
@@ -23,60 +47,17 @@ type TextOperation uint8
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// TextOperationEqual compares two strings for equality.
|
// TextOperationEqual compares two strings for equality.
|
||||||
TextOperationEqual TextOperation = iota + 1
|
TextOperationEqual TextOperation = iota + 1 // =
|
||||||
// TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
|
|
||||||
TextOperationEqualIgnoreCase
|
|
||||||
// TextOperationNotEqual compares two strings for inequality.
|
// TextOperationNotEqual compares two strings for inequality.
|
||||||
TextOperationNotEqual
|
TextOperationNotEqual // <>
|
||||||
// TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
|
|
||||||
TextOperationNotEqualIgnoreCase
|
|
||||||
// TextOperationStartsWith checks if the first string starts with the second.
|
// TextOperationStartsWith checks if the first string starts with the second.
|
||||||
TextOperationStartsWith
|
TextOperationStartsWith // LIKE
|
||||||
// TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
|
|
||||||
TextOperationStartsWithIgnoreCase
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var textOperations = map[TextOperation]string{
|
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value any) {
|
||||||
TextOperationEqual: " = ",
|
writeOperation[T](builder, col, op.String(), value)
|
||||||
TextOperationEqualIgnoreCase: " LIKE ",
|
if op == TextOperationStartsWith {
|
||||||
TextOperationNotEqual: " <> ",
|
|
||||||
TextOperationNotEqualIgnoreCase: " NOT LIKE ",
|
|
||||||
TextOperationStartsWith: " LIKE ",
|
|
||||||
TextOperationStartsWithIgnoreCase: " LIKE ",
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
|
|
||||||
switch op {
|
|
||||||
case TextOperationEqual, TextOperationNotEqual:
|
|
||||||
col.WriteQualified(builder)
|
|
||||||
builder.WriteString(textOperations[op])
|
|
||||||
builder.WriteArg(value)
|
|
||||||
case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
|
|
||||||
builder.WriteString("LOWER(")
|
|
||||||
col.WriteQualified(builder)
|
|
||||||
builder.WriteString(")")
|
|
||||||
|
|
||||||
builder.WriteString(textOperations[op])
|
|
||||||
builder.WriteString("LOWER(")
|
|
||||||
builder.WriteArg(value)
|
|
||||||
builder.WriteString(")")
|
|
||||||
case TextOperationStartsWith:
|
|
||||||
col.WriteQualified(builder)
|
|
||||||
builder.WriteString(textOperations[op])
|
|
||||||
builder.WriteArg(value)
|
|
||||||
builder.WriteString(" || '%'")
|
builder.WriteString(" || '%'")
|
||||||
case TextOperationStartsWithIgnoreCase:
|
|
||||||
builder.WriteString("LOWER(")
|
|
||||||
col.WriteQualified(builder)
|
|
||||||
builder.WriteString(")")
|
|
||||||
|
|
||||||
builder.WriteString(textOperations[op])
|
|
||||||
builder.WriteString("LOWER(")
|
|
||||||
builder.WriteArg(value)
|
|
||||||
builder.WriteString(")")
|
|
||||||
builder.WriteString(" || '%'")
|
|
||||||
default:
|
|
||||||
panic("unsupported text operation")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,48 +70,60 @@ type NumberOperation uint8
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
// NumberOperationEqual compares two numbers for equality.
|
// NumberOperationEqual compares two numbers for equality.
|
||||||
NumberOperationEqual NumberOperation = iota + 1
|
NumberOperationEqual NumberOperation = iota + 1 // =
|
||||||
// NumberOperationNotEqual compares two numbers for inequality.
|
// NumberOperationNotEqual compares two numbers for inequality.
|
||||||
NumberOperationNotEqual
|
NumberOperationNotEqual // <>
|
||||||
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
|
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
|
||||||
NumberOperationLessThan
|
NumberOperationLessThan // <
|
||||||
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
|
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
|
||||||
NumberOperationAtLeast
|
NumberOperationAtLeast // <=
|
||||||
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
|
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
|
||||||
NumberOperationGreaterThan
|
NumberOperationGreaterThan // >
|
||||||
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
|
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
|
||||||
NumberOperationAtMost
|
NumberOperationAtMost // >=
|
||||||
)
|
)
|
||||||
|
|
||||||
var numberOperations = map[NumberOperation]string{
|
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value any) {
|
||||||
NumberOperationEqual: " = ",
|
writeOperation[T](builder, col, op.String(), value)
|
||||||
NumberOperationNotEqual: " <> ",
|
|
||||||
NumberOperationLessThan: " < ",
|
|
||||||
NumberOperationAtLeast: " <= ",
|
|
||||||
NumberOperationGreaterThan: " > ",
|
|
||||||
NumberOperationAtMost: " >= ",
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
|
|
||||||
col.WriteQualified(builder)
|
|
||||||
builder.WriteString(numberOperations[op])
|
|
||||||
builder.WriteArg(value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Boolean interface {
|
type Boolean interface {
|
||||||
~bool
|
~bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// BooleanOperation are operations that can be performed on boolean values.
|
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value any) {
|
||||||
type BooleanOperation uint8
|
writeOperation[T](builder, col, "=", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Bytes interface {
|
||||||
|
~[]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// BytesOperation are operations that can be performed on bytea values.
|
||||||
|
type BytesOperation uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
BooleanOperationIsTrue BooleanOperation = iota + 1
|
BytesOperationEqual BytesOperation = iota + 1 // =
|
||||||
BooleanOperationIsFalse
|
BytesOperationNotEqual // <>
|
||||||
)
|
)
|
||||||
|
|
||||||
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
|
func writeBytesOperation[T Bytes](builder *StatementBuilder, col Column, op BytesOperation, value any) {
|
||||||
|
writeOperation[T](builder, col, op.String(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOperation[V Value](builder *StatementBuilder, col Column, op string, value any) {
|
||||||
|
if op == "" {
|
||||||
|
panic("unsupported operation")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.(type) {
|
||||||
|
case V, wrappedValue[V], *wrappedValue[V]:
|
||||||
|
default:
|
||||||
|
panic("unsupported value type")
|
||||||
|
}
|
||||||
col.WriteQualified(builder)
|
col.WriteQualified(builder)
|
||||||
builder.WriteString(" = ")
|
builder.WriteRune(' ')
|
||||||
|
builder.WriteString(op)
|
||||||
|
builder.WriteRune(' ')
|
||||||
builder.WriteArg(value)
|
builder.WriteArg(value)
|
||||||
}
|
}
|
||||||
|
|||||||
241
backend/v3/storage/database/operators_enumer.go
Normal file
241
backend/v3/storage/database/operators_enumer.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
// Code generated by "enumer -type NumberOperation,TextOperation,BytesOperation -linecomment -output ./operators_enumer.go"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const _NumberOperationName = "=<><<=>>="
|
||||||
|
|
||||||
|
var _NumberOperationIndex = [...]uint8{0, 1, 3, 4, 6, 7, 9}
|
||||||
|
|
||||||
|
const _NumberOperationLowerName = "=<><<=>>="
|
||||||
|
|
||||||
|
func (i NumberOperation) String() string {
|
||||||
|
i -= 1
|
||||||
|
if i >= NumberOperation(len(_NumberOperationIndex)-1) {
|
||||||
|
return fmt.Sprintf("NumberOperation(%d)", i+1)
|
||||||
|
}
|
||||||
|
return _NumberOperationName[_NumberOperationIndex[i]:_NumberOperationIndex[i+1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
func _NumberOperationNoOp() {
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[NumberOperationEqual-(1)]
|
||||||
|
_ = x[NumberOperationNotEqual-(2)]
|
||||||
|
_ = x[NumberOperationLessThan-(3)]
|
||||||
|
_ = x[NumberOperationAtLeast-(4)]
|
||||||
|
_ = x[NumberOperationGreaterThan-(5)]
|
||||||
|
_ = x[NumberOperationAtMost-(6)]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _NumberOperationValues = []NumberOperation{NumberOperationEqual, NumberOperationNotEqual, NumberOperationLessThan, NumberOperationAtLeast, NumberOperationGreaterThan, NumberOperationAtMost}
|
||||||
|
|
||||||
|
var _NumberOperationNameToValueMap = map[string]NumberOperation{
|
||||||
|
_NumberOperationName[0:1]: NumberOperationEqual,
|
||||||
|
_NumberOperationLowerName[0:1]: NumberOperationEqual,
|
||||||
|
_NumberOperationName[1:3]: NumberOperationNotEqual,
|
||||||
|
_NumberOperationLowerName[1:3]: NumberOperationNotEqual,
|
||||||
|
_NumberOperationName[3:4]: NumberOperationLessThan,
|
||||||
|
_NumberOperationLowerName[3:4]: NumberOperationLessThan,
|
||||||
|
_NumberOperationName[4:6]: NumberOperationAtLeast,
|
||||||
|
_NumberOperationLowerName[4:6]: NumberOperationAtLeast,
|
||||||
|
_NumberOperationName[6:7]: NumberOperationGreaterThan,
|
||||||
|
_NumberOperationLowerName[6:7]: NumberOperationGreaterThan,
|
||||||
|
_NumberOperationName[7:9]: NumberOperationAtMost,
|
||||||
|
_NumberOperationLowerName[7:9]: NumberOperationAtMost,
|
||||||
|
}
|
||||||
|
|
||||||
|
var _NumberOperationNames = []string{
|
||||||
|
_NumberOperationName[0:1],
|
||||||
|
_NumberOperationName[1:3],
|
||||||
|
_NumberOperationName[3:4],
|
||||||
|
_NumberOperationName[4:6],
|
||||||
|
_NumberOperationName[6:7],
|
||||||
|
_NumberOperationName[7:9],
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberOperationString retrieves an enum value from the enum constants string name.
|
||||||
|
// Throws an error if the param is not part of the enum.
|
||||||
|
func NumberOperationString(s string) (NumberOperation, error) {
|
||||||
|
if val, ok := _NumberOperationNameToValueMap[s]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := _NumberOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%s does not belong to NumberOperation values", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberOperationValues returns all values of the enum
|
||||||
|
func NumberOperationValues() []NumberOperation {
|
||||||
|
return _NumberOperationValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// NumberOperationStrings returns a slice of all String values of the enum
|
||||||
|
func NumberOperationStrings() []string {
|
||||||
|
strs := make([]string, len(_NumberOperationNames))
|
||||||
|
copy(strs, _NumberOperationNames)
|
||||||
|
return strs
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsANumberOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||||
|
func (i NumberOperation) IsANumberOperation() bool {
|
||||||
|
for _, v := range _NumberOperationValues {
|
||||||
|
if i == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const _TextOperationName = "=<>LIKE"
|
||||||
|
|
||||||
|
var _TextOperationIndex = [...]uint8{0, 1, 3, 7}
|
||||||
|
|
||||||
|
const _TextOperationLowerName = "=<>like"
|
||||||
|
|
||||||
|
func (i TextOperation) String() string {
|
||||||
|
i -= 1
|
||||||
|
if i >= TextOperation(len(_TextOperationIndex)-1) {
|
||||||
|
return fmt.Sprintf("TextOperation(%d)", i+1)
|
||||||
|
}
|
||||||
|
return _TextOperationName[_TextOperationIndex[i]:_TextOperationIndex[i+1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
func _TextOperationNoOp() {
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[TextOperationEqual-(1)]
|
||||||
|
_ = x[TextOperationNotEqual-(2)]
|
||||||
|
_ = x[TextOperationStartsWith-(3)]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _TextOperationValues = []TextOperation{TextOperationEqual, TextOperationNotEqual, TextOperationStartsWith}
|
||||||
|
|
||||||
|
var _TextOperationNameToValueMap = map[string]TextOperation{
|
||||||
|
_TextOperationName[0:1]: TextOperationEqual,
|
||||||
|
_TextOperationLowerName[0:1]: TextOperationEqual,
|
||||||
|
_TextOperationName[1:3]: TextOperationNotEqual,
|
||||||
|
_TextOperationLowerName[1:3]: TextOperationNotEqual,
|
||||||
|
_TextOperationName[3:7]: TextOperationStartsWith,
|
||||||
|
_TextOperationLowerName[3:7]: TextOperationStartsWith,
|
||||||
|
}
|
||||||
|
|
||||||
|
var _TextOperationNames = []string{
|
||||||
|
_TextOperationName[0:1],
|
||||||
|
_TextOperationName[1:3],
|
||||||
|
_TextOperationName[3:7],
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextOperationString retrieves an enum value from the enum constants string name.
|
||||||
|
// Throws an error if the param is not part of the enum.
|
||||||
|
func TextOperationString(s string) (TextOperation, error) {
|
||||||
|
if val, ok := _TextOperationNameToValueMap[s]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := _TextOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%s does not belong to TextOperation values", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextOperationValues returns all values of the enum
|
||||||
|
func TextOperationValues() []TextOperation {
|
||||||
|
return _TextOperationValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextOperationStrings returns a slice of all String values of the enum
|
||||||
|
func TextOperationStrings() []string {
|
||||||
|
strs := make([]string, len(_TextOperationNames))
|
||||||
|
copy(strs, _TextOperationNames)
|
||||||
|
return strs
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsATextOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||||
|
func (i TextOperation) IsATextOperation() bool {
|
||||||
|
for _, v := range _TextOperationValues {
|
||||||
|
if i == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const _BytesOperationName = "=<>"
|
||||||
|
|
||||||
|
var _BytesOperationIndex = [...]uint8{0, 1, 3}
|
||||||
|
|
||||||
|
const _BytesOperationLowerName = "=<>"
|
||||||
|
|
||||||
|
func (i BytesOperation) String() string {
|
||||||
|
i -= 1
|
||||||
|
if i >= BytesOperation(len(_BytesOperationIndex)-1) {
|
||||||
|
return fmt.Sprintf("BytesOperation(%d)", i+1)
|
||||||
|
}
|
||||||
|
return _BytesOperationName[_BytesOperationIndex[i]:_BytesOperationIndex[i+1]]
|
||||||
|
}
|
||||||
|
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
func _BytesOperationNoOp() {
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[BytesOperationEqual-(1)]
|
||||||
|
_ = x[BytesOperationNotEqual-(2)]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _BytesOperationValues = []BytesOperation{BytesOperationEqual, BytesOperationNotEqual}
|
||||||
|
|
||||||
|
var _BytesOperationNameToValueMap = map[string]BytesOperation{
|
||||||
|
_BytesOperationName[0:1]: BytesOperationEqual,
|
||||||
|
_BytesOperationLowerName[0:1]: BytesOperationEqual,
|
||||||
|
_BytesOperationName[1:3]: BytesOperationNotEqual,
|
||||||
|
_BytesOperationLowerName[1:3]: BytesOperationNotEqual,
|
||||||
|
}
|
||||||
|
|
||||||
|
var _BytesOperationNames = []string{
|
||||||
|
_BytesOperationName[0:1],
|
||||||
|
_BytesOperationName[1:3],
|
||||||
|
}
|
||||||
|
|
||||||
|
// BytesOperationString retrieves an enum value from the enum constants string name.
|
||||||
|
// Throws an error if the param is not part of the enum.
|
||||||
|
func BytesOperationString(s string) (BytesOperation, error) {
|
||||||
|
if val, ok := _BytesOperationNameToValueMap[s]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, ok := _BytesOperationNameToValueMap[strings.ToLower(s)]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%s does not belong to BytesOperation values", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BytesOperationValues returns all values of the enum
|
||||||
|
func BytesOperationValues() []BytesOperation {
|
||||||
|
return _BytesOperationValues
|
||||||
|
}
|
||||||
|
|
||||||
|
// BytesOperationStrings returns a slice of all String values of the enum
|
||||||
|
func BytesOperationStrings() []string {
|
||||||
|
strs := make([]string, len(_BytesOperationNames))
|
||||||
|
copy(strs, _BytesOperationNames)
|
||||||
|
return strs
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsABytesOperation returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||||
|
func (i BytesOperation) IsABytesOperation() bool {
|
||||||
|
for _, v := range _BytesOperationValues {
|
||||||
|
if i == v {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
244
backend/v3/storage/database/operators_test.go
Normal file
244
backend/v3/storage/database/operators_test.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_writeOperation(t *testing.T) {
|
||||||
|
type want struct {
|
||||||
|
shouldPanic bool
|
||||||
|
stmt string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
write func(builder *StatementBuilder)
|
||||||
|
// col Column
|
||||||
|
// op string
|
||||||
|
// value any
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "unsupported operation panics",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeOperation[string](builder, NewColumn("table", "column"), "", "value")
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
shouldPanic: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported value type panics",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeOperation[string](builder, NewColumn("table", "column"), " = ", struct{}{})
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
shouldPanic: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported wrapped value type panics",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeOperation[string](builder, NewColumn("table", "column"), " = ", SHA256Value(123))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
shouldPanic: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationEqual, "value")
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column = $1",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text not equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationNotEqual, "value")
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column <> $1",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text starts with",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, NewColumn("table", "column"), TextOperationStartsWith, "value")
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column LIKE $1 || '%'",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text equal with wrapped value",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationEqual, LowerValue("value"))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "LOWER(table.column) = LOWER($1)",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text not equal with wrapped value",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationNotEqual, LowerValue("value"))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "LOWER(table.column) <> LOWER($1)",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text starts with with wrapped value",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeTextOperation[string](builder, LowerColumn(NewColumn("table", "column")), TextOperationStartsWith, LowerValue("value"))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "LOWER(table.column) LIKE LOWER($1) || '%'",
|
||||||
|
args: []any{"value"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationEqual, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column = $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number not equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationNotEqual, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column <> $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number less than",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationLessThan, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column < $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number less than or equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtLeast, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column <= $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number greater than",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationGreaterThan, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column > $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "number greater than or equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeNumberOperation[int](builder, NewColumn("table", "column"), NumberOperationAtMost, 123)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column >= $1",
|
||||||
|
args: []any{123},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean is true",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBooleanOperation[bool](builder, NewColumn("table", "column"), true)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column = $1",
|
||||||
|
args: []any{true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "boolean is false",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBooleanOperation[bool](builder, NewColumn("table", "column"), false)
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column = $1",
|
||||||
|
args: []any{false},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationEqual, []byte{0x01, 0x02, 0x03})
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column = $1",
|
||||||
|
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes not equal",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBytesOperation[[]byte](builder, NewColumn("table", "column"), BytesOperationNotEqual, []byte{0x01, 0x02, 0x03})
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "table.column <> $1",
|
||||||
|
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes equal with wrapped value",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "SHA256(table.column) = SHA256($1)",
|
||||||
|
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bytes not equal with wrapped value",
|
||||||
|
write: func(builder *StatementBuilder) {
|
||||||
|
writeBytesOperation[[]byte](builder, SHA256Column(NewColumn("table", "column")), BytesOperationNotEqual, SHA256Value([]byte{0x01, 0x02, 0x03}))
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "SHA256(table.column) <> SHA256($1)",
|
||||||
|
args: []any{[]byte{0x01, 0x02, 0x03}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
assert.Equal(t, tt.want.shouldPanic, r != nil)
|
||||||
|
}()
|
||||||
|
var builder StatementBuilder
|
||||||
|
tt.write(&builder)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.want.stmt, builder.String())
|
||||||
|
assert.Equal(t, tt.want.args, builder.Args())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
28
backend/v3/storage/database/order_test.go
Normal file
28
backend/v3/storage/database/order_test.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_orderBy_Write(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want string
|
||||||
|
order Order
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "order by column",
|
||||||
|
want: " ORDER BY table.column",
|
||||||
|
order: OrderBy(NewColumn("table", "column")),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var builder StatementBuilder
|
||||||
|
tt.order.Write(&builder)
|
||||||
|
assert.Equal(t, tt.want, builder.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
147
backend/v3/storage/database/query_test.go
Normal file
147
backend/v3/storage/database/query_test.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestQueryOptions(t *testing.T) {
|
||||||
|
type want struct {
|
||||||
|
stmt string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
options []QueryOption
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no options",
|
||||||
|
want: want{
|
||||||
|
stmt: "",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "limit option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithLimit(10),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " LIMIT $1",
|
||||||
|
args: []any{uint32(10)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "offset option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithOffset(5),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " OFFSET $1",
|
||||||
|
args: []any{uint32(5)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "order by asc option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithOrderByAscending(NewColumn("table", "column")),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " ORDER BY table.column",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "order by desc option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithOrderByDescending(NewColumn("table", "column")),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " ORDER BY table.column DESC",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "order by option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithOrderBy(OrderDirectionAsc, NewColumn("table", "column1"), NewColumn("table", "column2")),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " ORDER BY table.column1, table.column2",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "condition option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithCondition(NewBooleanCondition(NewColumn("table", "column"), true)),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " WHERE table.column = $1",
|
||||||
|
args: []any{true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "group by option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithGroupBy(NewColumn("table", "column")),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " GROUP BY table.column",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "left join option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "permission check option",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithPermissionCheck("permission"),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: "",
|
||||||
|
args: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple options",
|
||||||
|
options: []QueryOption{
|
||||||
|
WithLeftJoin("other_table", NewColumnCondition(NewColumn("table", "id"), NewColumn("other_table", "table_id"))),
|
||||||
|
WithCondition(NewNumberCondition(NewColumn("table", "column"), NumberOperationEqual, 123)),
|
||||||
|
WithOrderByDescending(NewColumn("table", "column")),
|
||||||
|
WithLimit(10),
|
||||||
|
WithOffset(5),
|
||||||
|
},
|
||||||
|
want: want{
|
||||||
|
stmt: " LEFT JOIN other_table ON table.id = other_table.table_id WHERE table.column = $1 ORDER BY table.column DESC LIMIT $2 OFFSET $3",
|
||||||
|
args: []any{123, uint32(10), uint32(5)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
var opts QueryOpts
|
||||||
|
for _, option := range test.options {
|
||||||
|
option(&opts)
|
||||||
|
}
|
||||||
|
opts.Write(&b)
|
||||||
|
assert.Equal(t, test.want.stmt, b.String())
|
||||||
|
require.Len(t, b.Args(), len(test.want.args))
|
||||||
|
for i := range test.want.args {
|
||||||
|
assert.Equal(t, test.want.args[i], b.Args()[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -3,19 +3,35 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONArray[T any] []*T
|
type JSONArray[T any] []*T
|
||||||
|
|
||||||
func (a JSONArray[T]) Scan(src any) error {
|
var ErrScanSource = errors.New("unsupported scan source")
|
||||||
|
|
||||||
|
func (a *JSONArray[T]) Scan(src any) (err error) {
|
||||||
|
var rawJSON []byte
|
||||||
switch s := src.(type) {
|
switch s := src.(type) {
|
||||||
case string:
|
case string:
|
||||||
return json.Unmarshal([]byte(s), &a)
|
if len(s) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rawJSON = []byte(s)
|
||||||
case []byte:
|
case []byte:
|
||||||
return json.Unmarshal(s, &a)
|
if len(s) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rawJSON = s
|
||||||
case nil:
|
case nil:
|
||||||
return nil
|
return nil
|
||||||
default:
|
default:
|
||||||
return errors.New("unsupported scan source")
|
return ErrScanSource
|
||||||
}
|
}
|
||||||
|
err = json.Unmarshal(rawJSON, a)
|
||||||
|
if err != nil {
|
||||||
|
return database.NewScanError(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
117
backend/v3/storage/database/repository/array_test.go
Normal file
117
backend/v3/storage/database/repository/array_test.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package repository_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||||
|
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testObject struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Number int `json:"number"`
|
||||||
|
Active bool `json:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONArray_Scan(t *testing.T) {
|
||||||
|
type want struct {
|
||||||
|
err error
|
||||||
|
res []*testObject
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
src any
|
||||||
|
want want
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
src: nil,
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "[]byte with data",
|
||||||
|
src: []byte(`[{"id":"1","number":1,"active":true}]`),
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: []*testObject{{ID: "1", Number: 1, Active: true}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "[]byte without data",
|
||||||
|
src: []byte(`[]`),
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil []byte",
|
||||||
|
src: []byte(nil),
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty []byte",
|
||||||
|
src: []byte{},
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with data",
|
||||||
|
src: `[{"id":"1","number":1,"active":true}]`,
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: []*testObject{{ID: "1", Number: 1, Active: true}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string without data",
|
||||||
|
src: string(`[]`),
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
src: "",
|
||||||
|
want: want{
|
||||||
|
err: nil,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong type",
|
||||||
|
src: []int{1, 2, 3},
|
||||||
|
want: want{
|
||||||
|
err: repository.ErrScanSource,
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "badly formatted JSON",
|
||||||
|
src: []byte(`this is not a JSON`),
|
||||||
|
want: want{
|
||||||
|
err: new(database.ScanError),
|
||||||
|
res: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var a repository.JSONArray[testObject]
|
||||||
|
gotErr := a.Scan(tt.src)
|
||||||
|
require.ErrorIs(t, gotErr, tt.want.err)
|
||||||
|
require.Len(t, a, len(tt.want.res))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,17 +13,20 @@ import (
|
|||||||
var _ domain.InstanceRepository = (*instance)(nil)
|
var _ domain.InstanceRepository = (*instance)(nil)
|
||||||
|
|
||||||
type instance struct {
|
type instance struct {
|
||||||
repository
|
|
||||||
shouldLoadDomains bool
|
shouldLoadDomains bool
|
||||||
domainRepo *instanceDomain
|
domainRepo instanceDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
|
func InstanceRepository() domain.InstanceRepository {
|
||||||
return &instance{
|
return new(instance)
|
||||||
repository: repository{
|
}
|
||||||
client: client,
|
|
||||||
},
|
func (instance) qualifiedTableName() string {
|
||||||
}
|
return "zitadel.instances"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (instance) unqualifiedTableName() string {
|
||||||
|
return "instances"
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -37,7 +40,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Get implements [domain.InstanceRepository].
|
// Get implements [domain.InstanceRepository].
|
||||||
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
|
func (i instance) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Instance, error) {
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
i.joinDomains(),
|
i.joinDomains(),
|
||||||
database.WithGroupBy(i.IDColumn()),
|
database.WithGroupBy(i.IDColumn()),
|
||||||
@@ -52,11 +55,11 @@ func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*doma
|
|||||||
builder.WriteString(queryInstanceStmt)
|
builder.WriteString(queryInstanceStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanInstance(ctx, i.client, &builder)
|
return scanInstance(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements [domain.InstanceRepository].
|
// List implements [domain.InstanceRepository].
|
||||||
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
func (i instance) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
i.joinDomains(),
|
i.joinDomains(),
|
||||||
database.WithGroupBy(i.IDColumn()),
|
database.WithGroupBy(i.IDColumn()),
|
||||||
@@ -71,27 +74,11 @@ func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*d
|
|||||||
builder.WriteString(queryInstanceStmt)
|
builder.WriteString(queryInstanceStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanInstances(ctx, i.client, &builder)
|
return scanInstances(ctx, client, &builder)
|
||||||
}
|
|
||||||
|
|
||||||
func (i *instance) joinDomains() database.QueryOption {
|
|
||||||
columns := make([]database.Condition, 0, 2)
|
|
||||||
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
|
|
||||||
|
|
||||||
// If domains should not be joined, we make sure to return null for the domain columns
|
|
||||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
|
||||||
if !i.shouldLoadDomains {
|
|
||||||
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return database.WithLeftJoin(
|
|
||||||
"zitadel.instance_domains",
|
|
||||||
database.And(columns...),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create implements [domain.InstanceRepository].
|
// Create implements [domain.InstanceRepository].
|
||||||
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
|
func (i instance) Create(ctx context.Context, client database.QueryExecutor, instance *domain.Instance) error {
|
||||||
var (
|
var (
|
||||||
builder database.StatementBuilder
|
builder database.StatementBuilder
|
||||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||||
@@ -103,42 +90,46 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
|
|||||||
updatedAt = instance.UpdatedAt
|
updatedAt = instance.UpdatedAt
|
||||||
}
|
}
|
||||||
|
|
||||||
builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
builder.WriteString(`INSERT INTO `)
|
||||||
|
builder.WriteString(i.qualifiedTableName())
|
||||||
|
builder.WriteString(` (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
||||||
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
|
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
|
||||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||||
|
|
||||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update implements [domain.InstanceRepository].
|
// Update implements [domain.InstanceRepository].
|
||||||
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
|
func (i instance) Update(ctx context.Context, client database.QueryExecutor, id string, changes ...database.Change) (int64, error) {
|
||||||
if len(changes) == 0 {
|
if len(changes) == 0 {
|
||||||
return 0, database.ErrNoChanges
|
return 0, database.ErrNoChanges
|
||||||
}
|
}
|
||||||
|
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
|
||||||
|
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
|
|
||||||
builder.WriteString(`UPDATE zitadel.instances SET `)
|
builder.WriteString(`UPDATE zitadel.instances SET `)
|
||||||
|
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
|
|
||||||
idCondition := i.IDCondition(id)
|
idCondition := i.IDCondition(id)
|
||||||
writeCondition(&builder, idCondition)
|
writeCondition(&builder, idCondition)
|
||||||
|
|
||||||
stmt := builder.String()
|
stmt := builder.String()
|
||||||
|
|
||||||
return i.client.Exec(ctx, stmt, builder.Args()...)
|
return client.Exec(ctx, stmt, builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete implements [domain.InstanceRepository].
|
// Delete implements [domain.InstanceRepository].
|
||||||
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
|
func (i instance) Delete(ctx context.Context, client database.QueryExecutor, id string) (int64, error) {
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
|
|
||||||
builder.WriteString(`DELETE FROM zitadel.instances`)
|
builder.WriteString(`DELETE FROM `)
|
||||||
|
builder.WriteString(i.qualifiedTableName())
|
||||||
|
|
||||||
idCondition := i.IDCondition(id)
|
idCondition := i.IDCondition(id)
|
||||||
writeCondition(&builder, idCondition)
|
writeCondition(&builder, idCondition)
|
||||||
|
|
||||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -185,53 +176,77 @@ func (i instance) NameCondition(op database.TextOperation, name string) database
|
|||||||
return database.NewTextCondition(i.NameColumn(), op, name)
|
return database.NewTextCondition(i.NameColumn(), op, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExistsDomain creates a correlated [database.Exists] condition on instance_domains.
|
||||||
|
// Use this filter to make sure the Instance returned contains a specific domain.
|
||||||
|
// of the instance in the aggregated result.
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// domainRepo := instanceRepo.Domains(true) // ensure domains are loaded/aggregated
|
||||||
|
// instance, _ := instanceRepo.Get(ctx,
|
||||||
|
// database.WithCondition(
|
||||||
|
// database.And(
|
||||||
|
// instanceRepo.InstanceIDCondition(instanceID),
|
||||||
|
// instanceRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
|
||||||
|
// ),
|
||||||
|
// ),
|
||||||
|
// )
|
||||||
|
func (i instance) ExistsDomain(cond database.Condition) database.Condition {
|
||||||
|
return database.Exists(
|
||||||
|
i.domainRepo.qualifiedTableName(),
|
||||||
|
database.And(
|
||||||
|
database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()),
|
||||||
|
cond,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
// columns
|
// columns
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// IDColumn implements [domain.instanceColumns].
|
// IDColumn implements [domain.instanceColumns].
|
||||||
func (instance) IDColumn() database.Column {
|
func (i instance) IDColumn() database.Column {
|
||||||
return database.NewColumn("instances", "id")
|
return database.NewColumn(i.unqualifiedTableName(), "id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NameColumn implements [domain.instanceColumns].
|
// NameColumn implements [domain.instanceColumns].
|
||||||
func (instance) NameColumn() database.Column {
|
func (i instance) NameColumn() database.Column {
|
||||||
return database.NewColumn("instances", "name")
|
return database.NewColumn(i.unqualifiedTableName(), "name")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatedAtColumn implements [domain.instanceColumns].
|
// CreatedAtColumn implements [domain.instanceColumns].
|
||||||
func (instance) CreatedAtColumn() database.Column {
|
func (i instance) CreatedAtColumn() database.Column {
|
||||||
return database.NewColumn("instances", "created_at")
|
return database.NewColumn(i.unqualifiedTableName(), "created_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultOrgIdColumn implements [domain.instanceColumns].
|
// DefaultOrgIdColumn implements [domain.instanceColumns].
|
||||||
func (instance) DefaultOrgIDColumn() database.Column {
|
func (i instance) DefaultOrgIDColumn() database.Column {
|
||||||
return database.NewColumn("instances", "default_org_id")
|
return database.NewColumn(i.unqualifiedTableName(), "default_org_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IAMProjectIDColumn implements [domain.instanceColumns].
|
// IAMProjectIDColumn implements [domain.instanceColumns].
|
||||||
func (instance) IAMProjectIDColumn() database.Column {
|
func (i instance) IAMProjectIDColumn() database.Column {
|
||||||
return database.NewColumn("instances", "iam_project_id")
|
return database.NewColumn(i.unqualifiedTableName(), "iam_project_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConsoleClientIDColumn implements [domain.instanceColumns].
|
// ConsoleClientIDColumn implements [domain.instanceColumns].
|
||||||
func (instance) ConsoleClientIDColumn() database.Column {
|
func (i instance) ConsoleClientIDColumn() database.Column {
|
||||||
return database.NewColumn("instances", "console_client_id")
|
return database.NewColumn(i.unqualifiedTableName(), "console_client_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConsoleAppIDColumn implements [domain.instanceColumns].
|
// ConsoleAppIDColumn implements [domain.instanceColumns].
|
||||||
func (instance) ConsoleAppIDColumn() database.Column {
|
func (i instance) ConsoleAppIDColumn() database.Column {
|
||||||
return database.NewColumn("instances", "console_app_id")
|
return database.NewColumn(i.unqualifiedTableName(), "console_app_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultLanguageColumn implements [domain.instanceColumns].
|
// DefaultLanguageColumn implements [domain.instanceColumns].
|
||||||
func (instance) DefaultLanguageColumn() database.Column {
|
func (i instance) DefaultLanguageColumn() database.Column {
|
||||||
return database.NewColumn("instances", "default_language")
|
return database.NewColumn(i.unqualifiedTableName(), "default_language")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatedAtColumn implements [domain.instanceColumns].
|
// UpdatedAtColumn implements [domain.instanceColumns].
|
||||||
func (instance) UpdatedAtColumn() database.Column {
|
func (i instance) UpdatedAtColumn() database.Column {
|
||||||
return database.NewColumn("instances", "updated_at")
|
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -282,19 +297,24 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
|
|||||||
// sub repositories
|
// sub repositories
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// Domains implements [domain.InstanceRepository].
|
func (i *instance) LoadDomains() domain.InstanceRepository {
|
||||||
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
|
return &instance{
|
||||||
if !i.shouldLoadDomains {
|
shouldLoadDomains: true,
|
||||||
i.shouldLoadDomains = shouldLoad
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if i.domainRepo != nil {
|
|
||||||
return i.domainRepo
|
func (i *instance) joinDomains() database.QueryOption {
|
||||||
}
|
columns := make([]database.Condition, 0, 2)
|
||||||
|
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.domainRepo.InstanceIDColumn()))
|
||||||
i.domainRepo = &instanceDomain{
|
|
||||||
repository: i.repository,
|
// If domains should not be joined, we make sure to return null for the domain columns
|
||||||
instance: i,
|
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||||
}
|
if !i.shouldLoadDomains {
|
||||||
return i.domainRepo
|
columns = append(columns, database.IsNull(i.domainRepo.InstanceIDColumn()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return database.WithLeftJoin(
|
||||||
|
i.domainRepo.qualifiedTableName(),
|
||||||
|
database.And(columns...),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,18 @@ import (
|
|||||||
|
|
||||||
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
|
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
|
||||||
|
|
||||||
type instanceDomain struct {
|
type instanceDomain struct{}
|
||||||
repository
|
|
||||||
*instance
|
func InstanceDomainRepository() domain.InstanceDomainRepository {
|
||||||
|
return new(instanceDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (instanceDomain) qualifiedTableName() string {
|
||||||
|
return "zitadel.instance_domains"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (instanceDomain) unqualifiedTableName() string {
|
||||||
|
return "instance_domains"
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -23,8 +32,7 @@ const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_d
|
|||||||
`FROM zitadel.instance_domains`
|
`FROM zitadel.instance_domains`
|
||||||
|
|
||||||
// Get implements [domain.InstanceDomainRepository].
|
// Get implements [domain.InstanceDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
|
func (i instanceDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
|
||||||
func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
|
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
@@ -34,12 +42,11 @@ func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption)
|
|||||||
builder.WriteString(queryInstanceDomainStmt)
|
builder.WriteString(queryInstanceDomainStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanInstanceDomain(ctx, i.client, &builder)
|
return scanInstanceDomain(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements [domain.InstanceDomainRepository].
|
// List implements [domain.InstanceDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
|
func (i instanceDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
|
||||||
func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
|
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
@@ -49,11 +56,11 @@ func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption)
|
|||||||
builder.WriteString(queryInstanceDomainStmt)
|
builder.WriteString(queryInstanceDomainStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanInstanceDomains(ctx, i.client, &builder)
|
return scanInstanceDomains(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add implements [domain.InstanceDomainRepository].
|
// Add implements [domain.InstanceDomainRepository].
|
||||||
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
|
func (i instanceDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddInstanceDomain) error {
|
||||||
var (
|
var (
|
||||||
builder database.StatementBuilder
|
builder database.StatementBuilder
|
||||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||||
@@ -69,33 +76,41 @@ func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDoma
|
|||||||
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
|
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
|
||||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||||
|
|
||||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update implements [domain.InstanceDomainRepository].
|
// Update implements [domain.InstanceDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
|
func (i instanceDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||||
func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
|
||||||
|
}
|
||||||
if len(changes) == 0 {
|
if len(changes) == 0 {
|
||||||
return 0, database.ErrNoChanges
|
return 0, database.ErrNoChanges
|
||||||
}
|
}
|
||||||
var builder database.StatementBuilder
|
if !database.Changes(changes).IsOnColumn(i.UpdatedAtColumn()) {
|
||||||
|
changes = append(changes, database.NewChange(i.UpdatedAtColumn(), database.NullInstruction))
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
|
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
|
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove implements [domain.InstanceDomainRepository].
|
// Remove implements [domain.InstanceDomainRepository].
|
||||||
func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
func (i instanceDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||||
var builder database.StatementBuilder
|
if !condition.IsRestrictingColumn(i.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(i.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
|
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
|
||||||
condition.Write(&builder)
|
condition.Write(&builder)
|
||||||
|
|
||||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -146,40 +161,38 @@ func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition
|
|||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// CreatedAtColumn implements [domain.InstanceDomainRepository].
|
// CreatedAtColumn implements [domain.InstanceDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
|
func (i instanceDomain) CreatedAtColumn() database.Column {
|
||||||
func (instanceDomain) CreatedAtColumn() database.Column {
|
return database.NewColumn(i.unqualifiedTableName(), "created_at")
|
||||||
return database.NewColumn("instance_domains", "created_at")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DomainColumn implements [domain.InstanceDomainRepository].
|
// DomainColumn implements [domain.InstanceDomainRepository].
|
||||||
func (instanceDomain) DomainColumn() database.Column {
|
func (i instanceDomain) DomainColumn() database.Column {
|
||||||
return database.NewColumn("instance_domains", "domain")
|
return database.NewColumn(i.unqualifiedTableName(), "domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstanceIDColumn implements [domain.InstanceDomainRepository].
|
// InstanceIDColumn implements [domain.InstanceDomainRepository].
|
||||||
func (instanceDomain) InstanceIDColumn() database.Column {
|
func (i instanceDomain) InstanceIDColumn() database.Column {
|
||||||
return database.NewColumn("instance_domains", "instance_id")
|
return database.NewColumn(i.unqualifiedTableName(), "instance_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
|
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
|
||||||
func (instanceDomain) IsPrimaryColumn() database.Column {
|
func (i instanceDomain) IsPrimaryColumn() database.Column {
|
||||||
return database.NewColumn("instance_domains", "is_primary")
|
return database.NewColumn(i.unqualifiedTableName(), "is_primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
|
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
|
func (i instanceDomain) UpdatedAtColumn() database.Column {
|
||||||
func (instanceDomain) UpdatedAtColumn() database.Column {
|
return database.NewColumn(i.unqualifiedTableName(), "updated_at")
|
||||||
return database.NewColumn("instance_domains", "updated_at")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
|
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
|
||||||
func (instanceDomain) IsGeneratedColumn() database.Column {
|
func (i instanceDomain) IsGeneratedColumn() database.Column {
|
||||||
return database.NewColumn("instance_domains", "is_generated")
|
return database.NewColumn(i.unqualifiedTableName(), "is_generated")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeColumn implements [domain.InstanceDomainRepository].
|
// TypeColumn implements [domain.InstanceDomainRepository].
|
||||||
func (instanceDomain) TypeColumn() database.Column {
|
func (i instanceDomain) TypeColumn() database.Column {
|
||||||
return database.NewColumn("instance_domains", "type")
|
return database.NewColumn(i.unqualifiedTableName(), "type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package repository_test
|
package repository_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,31 +15,43 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAddInstanceDomain(t *testing.T) {
|
func TestAddInstanceDomain(t *testing.T) {
|
||||||
|
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
|
||||||
|
beforeAdd := time.Now()
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.NewCrypto().UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.BeerName(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleClient",
|
ConsoleClientID: "consoleClient",
|
||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain
|
||||||
instanceDomain domain.AddInstanceDomain
|
instanceDomain domain.AddInstanceDomain
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path custom domain",
|
name: "happy path custom domain",
|
||||||
instanceDomain: domain.AddInstanceDomain{
|
instanceDomain: domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
@@ -50,7 +61,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "happy path trusted domain",
|
name: "happy path trusted domain",
|
||||||
instanceDomain: domain.AddInstanceDomain{
|
instanceDomain: domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
},
|
},
|
||||||
@@ -58,7 +69,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "add primary domain",
|
name: "add primary domain",
|
||||||
instanceDomain: domain.AddInstanceDomain{
|
instanceDomain: domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
IsPrimary: gu.Ptr(true),
|
IsPrimary: gu.Ptr(true),
|
||||||
@@ -68,7 +79,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "add custom domain without domain name",
|
name: "add custom domain without domain name",
|
||||||
instanceDomain: domain.AddInstanceDomain{
|
instanceDomain: domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: "",
|
Domain: "",
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
@@ -79,7 +90,7 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "add trusted domain without domain name",
|
name: "add trusted domain without domain name",
|
||||||
instanceDomain: domain.AddInstanceDomain{
|
instanceDomain: domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: "",
|
Domain: "",
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
},
|
},
|
||||||
@@ -87,23 +98,23 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "add custom domain with same domain twice",
|
name: "add custom domain with same domain twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
|
||||||
domainName := gofakeit.DomainName()
|
domainName := gofakeit.DomainName()
|
||||||
|
|
||||||
instanceDomain := &domain.AddInstanceDomain{
|
instanceDomain := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName,
|
Domain: domainName,
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := domainRepo.Add(ctx, instanceDomain)
|
err := domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// return same domain again
|
// return same domain again
|
||||||
return &domain.AddInstanceDomain{
|
return &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName,
|
Domain: domainName,
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
@@ -114,22 +125,20 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "add trusted domain with same domain twice",
|
name: "add trusted domain with same domain twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddInstanceDomain {
|
||||||
domainName := gofakeit.DomainName()
|
|
||||||
|
|
||||||
instanceDomain := &domain.AddInstanceDomain{
|
instanceDomain := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName,
|
Domain: gofakeit.DomainName(),
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := domainRepo.Add(ctx, instanceDomain)
|
err := domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// return same domain again
|
// return same domain again
|
||||||
return &domain.AddInstanceDomain{
|
return &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName,
|
Domain: instanceDomain.Domain,
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -196,26 +205,23 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
savepoint, err := tx.Begin(t.Context())
|
||||||
|
|
||||||
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
|
|
||||||
beforeAdd := time.Now()
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
err := savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
|
|
||||||
var instanceDomain *domain.AddInstanceDomain
|
var instanceDomain *domain.AddInstanceDomain
|
||||||
if test.testFunc != nil {
|
if test.testFunc != nil {
|
||||||
instanceDomain = test.testFunc(ctx, t, domainRepo)
|
instanceDomain = test.testFunc(t, savepoint)
|
||||||
} else {
|
} else {
|
||||||
instanceDomain = &test.instanceDomain
|
instanceDomain = &test.instanceDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(ctx, instanceDomain)
|
err = domainRepo.Add(t.Context(), savepoint, instanceDomain)
|
||||||
afterAdd := time.Now()
|
afterAdd := time.Now()
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
@@ -232,10 +238,21 @@ func TestAddInstanceDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInstanceDomain(t *testing.T) {
|
func TestGetInstanceDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.NewCrypto().UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -243,38 +260,29 @@ func TestGetInstanceDomain(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
|
||||||
require.NoError(t, err)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domains
|
// add domains
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
domainName1 := gofakeit.DomainName()
|
|
||||||
domainName2 := gofakeit.DomainName()
|
|
||||||
|
|
||||||
domain1 := &domain.AddInstanceDomain{
|
domain1 := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName1,
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(true),
|
IsPrimary: gu.Ptr(true),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
domain2 := &domain.AddInstanceDomain{
|
domain2 := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName2,
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), domain1)
|
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = domainRepo.Add(t.Context(), domain2)
|
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -289,19 +297,19 @@ func TestGetInstanceDomain(t *testing.T) {
|
|||||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||||
},
|
},
|
||||||
expected: &domain.InstanceDomain{
|
expected: &domain.InstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName1,
|
Domain: domain1.Domain,
|
||||||
IsPrimary: gu.Ptr(true),
|
IsPrimary: gu.Ptr(true),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "get by domain name",
|
name: "get by domain name",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain)),
|
||||||
},
|
},
|
||||||
expected: &domain.InstanceDomain{
|
expected: &domain.InstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName2,
|
Domain: domain2.Domain,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -318,7 +326,7 @@ func TestGetInstanceDomain(t *testing.T) {
|
|||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
result, err := domainRepo.Get(ctx, test.opts...)
|
result, err := domainRepo.Get(ctx, tx, test.opts...)
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
return
|
return
|
||||||
@@ -335,10 +343,21 @@ func TestGetInstanceDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListInstanceDomains(t *testing.T) {
|
func TestListInstanceDomains(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.NewCrypto().UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -346,42 +365,35 @@ func TestListInstanceDomains(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add multiple domains
|
// add multiple domains
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
domains := []domain.AddInstanceDomain{
|
domains := []domain.AddInstanceDomain{
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(true),
|
IsPrimary: gu.Ptr(true),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range domains {
|
for i := range domains {
|
||||||
err = domainRepo.Add(t.Context(), &domains[i])
|
err = domainRepo.Add(t.Context(), tx, &domains[i])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -405,7 +417,7 @@ func TestListInstanceDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "list by instance",
|
name: "list by instance",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
|
||||||
},
|
},
|
||||||
expectedCount: 3,
|
expectedCount: 3,
|
||||||
},
|
},
|
||||||
@@ -422,12 +434,12 @@ func TestListInstanceDomains(t *testing.T) {
|
|||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
results, err := domainRepo.List(ctx, test.opts...)
|
results, err := domainRepo.List(ctx, tx, test.opts...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, results, test.expectedCount)
|
assert.Len(t, results, test.expectedCount)
|
||||||
|
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
assert.Equal(t, instanceID, result.InstanceID)
|
assert.Equal(t, instance.ID, result.InstanceID)
|
||||||
assert.NotEmpty(t, result.Domain)
|
assert.NotEmpty(t, result.Domain)
|
||||||
assert.NotEmpty(t, result.CreatedAt)
|
assert.NotEmpty(t, result.CreatedAt)
|
||||||
assert.NotEmpty(t, result.UpdatedAt)
|
assert.NotEmpty(t, result.UpdatedAt)
|
||||||
@@ -437,10 +449,21 @@ func TestListInstanceDomains(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateInstanceDomain(t *testing.T) {
|
func TestUpdateInstanceDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -448,29 +471,19 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domain
|
// add domain
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
domainName := gofakeit.DomainName()
|
|
||||||
instanceDomain := &domain.AddInstanceDomain{
|
instanceDomain := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName,
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), instanceDomain)
|
err = domainRepo.Add(t.Context(), tx, instanceDomain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -481,31 +494,38 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "set primary",
|
name: "set primary",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetPrimary()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 1,
|
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetPrimary()},
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update non-existent domain",
|
name: "update non-existent domain",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetPrimary()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 0,
|
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetPrimary()},
|
||||||
|
expected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no changes",
|
name: "no changes",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 0,
|
domainRepo.DomainCondition(database.TextOperationEqual, instanceDomain.Domain),
|
||||||
err: database.ErrNoChanges,
|
),
|
||||||
|
changes: []database.Change{},
|
||||||
|
expected: 0,
|
||||||
|
err: database.ErrNoChanges,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
|
||||||
|
|
||||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
return
|
return
|
||||||
@@ -516,7 +536,7 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
|||||||
|
|
||||||
// verify changes were applied if rows were affected
|
// verify changes were applied if rows were affected
|
||||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// We know changes were applied since rowsAffected > 0
|
// We know changes were applied since rowsAffected > 0
|
||||||
@@ -529,10 +549,21 @@ func TestUpdateInstanceDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveInstanceDomain(t *testing.T) {
|
func TestRemoveInstanceDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -540,37 +571,28 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domains
|
// add domains
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
domainName1 := gofakeit.DomainName()
|
|
||||||
|
|
||||||
domain1 := &domain.AddInstanceDomain{
|
domain1 := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: domainName1,
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(true),
|
IsPrimary: gu.Ptr(true),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
domain2 := &domain.AddInstanceDomain{
|
domain2 := &domain.AddInstanceDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), domain1)
|
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = domainRepo.Add(t.Context(), domain2)
|
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -579,36 +601,43 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
|||||||
expected int64
|
expected int64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "remove by domain name",
|
name: "remove by domain name",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
condition: database.And(
|
||||||
expected: 1,
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
|
||||||
|
),
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "remove by primary condition",
|
name: "remove by primary condition",
|
||||||
condition: domainRepo.IsPrimaryCondition(false),
|
condition: database.And(
|
||||||
expected: 1, // domain2 should still exist and be non-primary
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.IsPrimaryCondition(false),
|
||||||
|
),
|
||||||
|
expected: 1, // domain2 should still exist and be non-primary
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "remove non-existent domain",
|
name: "remove non-existent domain",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
condition: database.And(
|
||||||
expected: 0,
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||||
|
),
|
||||||
|
expected: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
|
||||||
|
|
||||||
// count before removal
|
// count before removal
|
||||||
beforeCount, err := domainRepo.List(ctx)
|
beforeCount, err := domainRepo.List(t.Context(), tx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
rowsAffected, err := domainRepo.Remove(t.Context(), tx, test.condition)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, test.expected, rowsAffected)
|
assert.Equal(t, test.expected, rowsAffected)
|
||||||
|
|
||||||
// verify removal
|
// verify removal
|
||||||
afterCount, err := domainRepo.List(ctx)
|
afterCount, err := domainRepo.List(t.Context(), tx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||||
})
|
})
|
||||||
@@ -616,8 +645,7 @@ func TestRemoveInstanceDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInstanceDomainConditions(t *testing.T) {
|
func TestInstanceDomainConditions(t *testing.T) {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -671,8 +699,7 @@ func TestInstanceDomainConditions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInstanceDomainChanges(t *testing.T) {
|
func TestInstanceDomainChanges(t *testing.T) {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package repository_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,9 +17,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateInstance(t *testing.T) {
|
func TestCreateInstance(t *testing.T) {
|
||||||
|
beforeCreate := time.Now()
|
||||||
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||||
instance domain.Instance
|
instance domain.Instance
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
@@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "adding same instance twice",
|
name: "adding same instance twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.UUID(),
|
||||||
Name: instanceName,
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
|
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// change the name to make sure same only the id clashes
|
// change the name to make sure same only the id clashes
|
||||||
inst.Name = gofakeit.Name()
|
inst.Name = gofakeit.Name()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
func() struct {
|
func() struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||||
instance domain.Instance
|
instance domain.Instance
|
||||||
err error
|
err error
|
||||||
} {
|
} {
|
||||||
@@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
instanceName := gofakeit.Name()
|
instanceName := gofakeit.Name()
|
||||||
return struct {
|
return struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||||
instance domain.Instance
|
instance domain.Instance
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
name: "adding instance with same name twice",
|
name: "adding instance with same name twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: gofakeit.Name(),
|
ID: gofakeit.Name(),
|
||||||
Name: instanceName,
|
Name: instanceName,
|
||||||
@@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
|
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// change the id
|
// change the id
|
||||||
@@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "adding instance with no id",
|
name: "adding instance with no id",
|
||||||
instance: func() domain.Instance {
|
instance: func() domain.Instance {
|
||||||
// instanceId := gofakeit.Name()
|
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
// ID: instanceId,
|
Name: gofakeit.Name(),
|
||||||
Name: instanceName,
|
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
savepoint, err := tx.Begin(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var instance *domain.Instance
|
var instance *domain.Instance
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
instance = tt.testFunc(ctx, t)
|
instance = tt.testFunc(t, savepoint)
|
||||||
} else {
|
} else {
|
||||||
instance = &tt.instance
|
instance = &tt.instance
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
beforeCreate := time.Now()
|
|
||||||
err := instanceRepo.Create(ctx, instance)
|
err = instanceRepo.Create(t.Context(), tx, instance)
|
||||||
assert.ErrorIs(t, err, tt.err)
|
assert.ErrorIs(t, err, tt.err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
afterCreate := time.Now()
|
afterCreate := time.Now()
|
||||||
|
|
||||||
// check instance values
|
// check instance values
|
||||||
instance, err = instanceRepo.Get(ctx,
|
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
instanceRepo.IDCondition(instance.ID),
|
instanceRepo.IDCondition(instance.ID),
|
||||||
),
|
),
|
||||||
@@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateInstance(t *testing.T) {
|
func TestUpdateInstance(t *testing.T) {
|
||||||
|
beforeUpdate := time.Now()
|
||||||
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||||
rowsAffected int64
|
rowsAffected int64
|
||||||
getErr error
|
getErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path",
|
name: "happy path",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.UUID(),
|
||||||
Name: instanceName,
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return &inst
|
return &inst
|
||||||
},
|
},
|
||||||
@@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update deleted instance",
|
name: "update deleted instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.UUID(),
|
||||||
Name: instanceName,
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// delete instance
|
// delete instance
|
||||||
affectedRows, err := instanceRepo.Delete(ctx,
|
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||||
inst.ID,
|
inst.ID,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update non existent instance",
|
name: "update non existent instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.UUID(),
|
||||||
}
|
}
|
||||||
return &inst
|
return &inst
|
||||||
},
|
},
|
||||||
@@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
instance := tt.testFunc(t, tx)
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
|
|
||||||
instance := tt.testFunc(ctx, t)
|
|
||||||
|
|
||||||
beforeUpdate := time.Now()
|
|
||||||
// update name
|
// update name
|
||||||
newName := "new_" + instance.Name
|
newName := "new_" + instance.Name
|
||||||
rowsAffected, err := instanceRepo.Update(ctx,
|
rowsAffected, err := instanceRepo.Update(t.Context(), tx,
|
||||||
instance.ID,
|
instance.ID,
|
||||||
instanceRepo.SetName(newName),
|
instanceRepo.SetName(newName),
|
||||||
)
|
)
|
||||||
@@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check instance values
|
// check instance values
|
||||||
instance, err = instanceRepo.Get(ctx,
|
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
instanceRepo.IDCondition(instance.ID),
|
instanceRepo.IDCondition(instance.ID),
|
||||||
),
|
),
|
||||||
@@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetInstance(t *testing.T) {
|
func TestGetInstance(t *testing.T) {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
testFunc func(t *testing.T) *domain.Instance
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []test{
|
tests := []test{
|
||||||
func() test {
|
func() test {
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
return test{
|
return test{
|
||||||
name: "happy path get using id",
|
name: "happy path get using id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T) *domain.Instance {
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.UUID(),
|
||||||
Name: instanceName,
|
Name: gofakeit.BeerName(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
return &inst
|
return &inst
|
||||||
},
|
},
|
||||||
@@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) {
|
|||||||
}(),
|
}(),
|
||||||
{
|
{
|
||||||
name: "happy path including domains",
|
name: "happy path including domains",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T) *domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: gofakeit.NewCrypto().UUID(),
|
||||||
Name: instanceName,
|
Name: gofakeit.BeerName(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
@@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
domainRepo := instanceRepo.Domains(false)
|
|
||||||
d := &domain.AddInstanceDomain{
|
d := &domain.AddInstanceDomain{
|
||||||
InstanceID: inst.ID,
|
InstanceID: inst.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
@@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) {
|
|||||||
IsGenerated: gu.Ptr(false),
|
IsGenerated: gu.Ptr(false),
|
||||||
Type: domain.DomainTypeCustom,
|
Type: domain.DomainTypeCustom,
|
||||||
}
|
}
|
||||||
err = domainRepo.Add(ctx, d)
|
err = domainRepo.Add(t.Context(), tx, d)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
|
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
|
||||||
@@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "get non existent instance",
|
name: "get non existent instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
testFunc: func(t *testing.T) *domain.Instance {
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: "get non existent instance",
|
ID: "get non existent instance",
|
||||||
}
|
}
|
||||||
@@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
|
|
||||||
var instance *domain.Instance
|
var instance *domain.Instance
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
instance = tt.testFunc(ctx, t)
|
instance = tt.testFunc(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check instance values
|
// check instance values
|
||||||
returnedInstance, err := instanceRepo.Get(ctx,
|
returnedInstance, err := instanceRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
instanceRepo.IDCondition(instance.ID),
|
instanceRepo.IDCondition(instance.ID),
|
||||||
),
|
),
|
||||||
@@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListInstance(t *testing.T) {
|
func TestListInstance(t *testing.T) {
|
||||||
ctx := context.Background()
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
pool, stop, err := newEmbeddedDB(ctx)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer stop()
|
defer func() {
|
||||||
|
err := tx.Rollback(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
|
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance
|
||||||
conditionClauses []database.Condition
|
conditionClauses []database.Condition
|
||||||
noInstanceReturned bool
|
noInstanceReturned bool
|
||||||
}
|
}
|
||||||
tests := []test{
|
tests := []test{
|
||||||
{
|
{
|
||||||
name: "happy path single instance no filter",
|
name: "happy path single instance no filter",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
noOfInstances := 1
|
noOfInstances := 1
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
instances := make([]*domain.Instance, noOfInstances)
|
||||||
for i := range noOfInstances {
|
for i := range noOfInstances {
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
instances[i] = &inst
|
instances[i] = &inst
|
||||||
@@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "happy path multiple instance no filter",
|
name: "happy path multiple instance no filter",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
noOfInstances := 5
|
noOfInstances := 5
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
instances := make([]*domain.Instance, noOfInstances)
|
||||||
for i := range noOfInstances {
|
for i := range noOfInstances {
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
instances[i] = &inst
|
instances[i] = &inst
|
||||||
@@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
func() test {
|
func() test {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceID := gofakeit.BeerName()
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
return test{
|
return test{
|
||||||
name: "instance filter on id",
|
name: "instance filter on id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||||
noOfInstances := 1
|
noOfInstances := 1
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
instances := make([]*domain.Instance, noOfInstances)
|
||||||
for i := range noOfInstances {
|
for i := range noOfInstances {
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: instanceId,
|
ID: instanceID,
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
instances[i] = &inst
|
instances[i] = &inst
|
||||||
@@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) {
|
|||||||
|
|
||||||
return instances
|
return instances
|
||||||
},
|
},
|
||||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
|
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)},
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceName := gofakeit.BeerName()
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
return test{
|
return test{
|
||||||
name: "multiple instance filter on name",
|
name: "multiple instance filter on name",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||||
noOfInstances := 5
|
noOfInstances := 5
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
instances := make([]*domain.Instance, noOfInstances)
|
||||||
for i := range noOfInstances {
|
for i := range noOfInstances {
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: instanceName,
|
Name: instanceName,
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
instances[i] = &inst
|
instances[i] = &inst
|
||||||
@@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Cleanup(func() {
|
savepoint, err := tx.Begin(t.Context())
|
||||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
defer func() {
|
||||||
})
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
instances := tt.testFunc(ctx, t)
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
}()
|
||||||
|
instances := tt.testFunc(t, savepoint)
|
||||||
|
|
||||||
var condition database.Condition
|
var condition database.Condition
|
||||||
if len(tt.conditionClauses) > 0 {
|
if len(tt.conditionClauses) > 0 {
|
||||||
@@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check instance values
|
// check instance values
|
||||||
returnedInstances, err := instanceRepo.List(ctx,
|
returnedInstances, err := instanceRepo.List(t.Context(), tx,
|
||||||
database.WithCondition(condition),
|
database.WithCondition(condition),
|
||||||
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
|
database.WithOrderByAscending(instanceRepo.IDColumn()),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if tt.noInstanceReturned {
|
if tt.noInstanceReturned {
|
||||||
assert.Nil(t, returnedInstances)
|
assert.Len(t, returnedInstances, 0)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -609,42 +620,45 @@ func TestListInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteInstance(t *testing.T) {
|
func TestDeleteInstance(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T)
|
testFunc func(t *testing.T, tx database.QueryExecutor)
|
||||||
instanceID string
|
instanceID string
|
||||||
noOfDeletedRows int64
|
noOfDeletedRows int64
|
||||||
}
|
}
|
||||||
tests := []test{
|
tests := []test{
|
||||||
func() test {
|
func() test {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceID := gofakeit.NewCrypto().UUID()
|
||||||
instanceId := gofakeit.Name()
|
|
||||||
var noOfInstances int64 = 1
|
|
||||||
return test{
|
return test{
|
||||||
name: "happy path delete single instance filter id",
|
name: "happy path delete single instance filter id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
inst := domain.Instance{
|
||||||
for i := range noOfInstances {
|
ID: instanceID,
|
||||||
|
Name: gofakeit.Name(),
|
||||||
inst := domain.Instance{
|
DefaultOrgID: "defaultOrgId",
|
||||||
ID: instanceId,
|
IAMProjectID: "iamProject",
|
||||||
Name: gofakeit.Name(),
|
ConsoleClientID: "consoleCLient",
|
||||||
DefaultOrgID: "defaultOrgId",
|
ConsoleAppID: "consoleApp",
|
||||||
IAMProjectID: "iamProject",
|
DefaultLanguage: "defaultLanguage",
|
||||||
ConsoleClientID: "consoleCLient",
|
|
||||||
ConsoleAppID: "consoleApp",
|
|
||||||
DefaultLanguage: "defaultLanguage",
|
|
||||||
}
|
|
||||||
|
|
||||||
// create instance
|
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
instances[i] = &inst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create instance
|
||||||
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
instanceID: instanceId,
|
instanceID: instanceID,
|
||||||
noOfDeletedRows: noOfInstances,
|
noOfDeletedRows: 1,
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
@@ -655,40 +669,33 @@ func TestDeleteInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceID := gofakeit.Name()
|
||||||
instanceName := gofakeit.Name()
|
|
||||||
return test{
|
return test{
|
||||||
name: "deleted already deleted instance",
|
name: "deleted already deleted instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||||
noOfInstances := 1
|
|
||||||
instances := make([]*domain.Instance, noOfInstances)
|
|
||||||
for i := range noOfInstances {
|
|
||||||
|
|
||||||
inst := domain.Instance{
|
inst := domain.Instance{
|
||||||
ID: gofakeit.Name(),
|
ID: instanceID,
|
||||||
Name: instanceName,
|
Name: gofakeit.BeerName(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
ConsoleClientID: "consoleCLient",
|
ConsoleClientID: "consoleCLient",
|
||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
|
||||||
|
|
||||||
// create instance
|
|
||||||
err := instanceRepo.Create(ctx, &inst)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
instances[i] = &inst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// create instance
|
||||||
|
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// delete instance
|
// delete instance
|
||||||
affectedRows, err := instanceRepo.Delete(ctx,
|
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||||
instances[0].ID,
|
inst.ID,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, int64(1), affectedRows)
|
assert.Equal(t, int64(1), affectedRows)
|
||||||
},
|
},
|
||||||
instanceID: instanceName,
|
instanceID: instanceID,
|
||||||
// this test should return 0 affected rows as the instance was already deleted
|
// this test should return 0 affected rows as the instance was already deleted
|
||||||
noOfDeletedRows: 0,
|
noOfDeletedRows: 0,
|
||||||
}
|
}
|
||||||
@@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
savepoint, err := tx.Begin(t.Context())
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
tt.testFunc(ctx, t)
|
tt.testFunc(t, savepoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete instance
|
// delete instance
|
||||||
noOfDeletedRows, err := instanceRepo.Delete(ctx,
|
noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID)
|
||||||
tt.instanceID,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||||
|
|
||||||
// check instance was deleted
|
// check instance was deleted
|
||||||
instance, err := instanceRepo.Get(ctx,
|
instance, err := instanceRepo.Get(t.Context(), savepoint,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
instanceRepo.IDCondition(tt.instanceID),
|
instanceRepo.IDCondition(tt.instanceID),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -14,25 +14,24 @@ import (
|
|||||||
var _ domain.OrganizationRepository = (*org)(nil)
|
var _ domain.OrganizationRepository = (*org)(nil)
|
||||||
|
|
||||||
type org struct {
|
type org struct {
|
||||||
repository
|
|
||||||
shouldLoadDomains bool
|
shouldLoadDomains bool
|
||||||
domainRepo domain.OrganizationDomainRepository
|
domainRepo orgDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
|
func (o org) unqualifiedTableName() string {
|
||||||
return &org{
|
return "organizations"
|
||||||
repository: repository{
|
}
|
||||||
client: client,
|
|
||||||
},
|
func OrganizationRepository() domain.OrganizationRepository {
|
||||||
}
|
return new(org)
|
||||||
}
|
}
|
||||||
|
|
||||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||||
` , jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
` , jsonb_agg(json_build_object('instanceId', org_domains.instance_id, 'orgId', org_domains.org_id, 'domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) FILTER (WHERE org_domains.org_id IS NOT NULL) AS domains` +
|
||||||
` FROM zitadel.organizations`
|
` FROM zitadel.organizations`
|
||||||
|
|
||||||
// Get implements [domain.OrganizationRepository].
|
// Get implements [domain.OrganizationRepository].
|
||||||
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
|
func (o org) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.Organization, error) {
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
o.joinDomains(),
|
o.joinDomains(),
|
||||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||||
@@ -43,15 +42,19 @@ func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Or
|
|||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(queryOrganizationStmt)
|
builder.WriteString(queryOrganizationStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanOrganization(ctx, o.client, &builder)
|
return scanOrganization(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements [domain.OrganizationRepository].
|
// List implements [domain.OrganizationRepository].
|
||||||
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
func (o org) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
o.joinDomains(),
|
o.joinDomains(),
|
||||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||||
@@ -62,30 +65,15 @@ func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain
|
|||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(queryOrganizationStmt)
|
builder.WriteString(queryOrganizationStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanOrganizations(ctx, o.client, &builder)
|
return scanOrganizations(ctx, client, &builder)
|
||||||
}
|
|
||||||
|
|
||||||
func (o *org) joinDomains() database.QueryOption {
|
|
||||||
columns := make([]database.Condition, 0, 3)
|
|
||||||
columns = append(columns,
|
|
||||||
database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
|
|
||||||
database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// If domains should not be joined, we make sure to return null for the domain columns
|
|
||||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
|
||||||
if !o.shouldLoadDomains {
|
|
||||||
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
|
|
||||||
}
|
|
||||||
|
|
||||||
return database.WithLeftJoin(
|
|
||||||
"zitadel.org_domains",
|
|
||||||
database.And(columns...),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
|
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
|
||||||
@@ -93,46 +81,48 @@ const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, ins
|
|||||||
` RETURNING created_at, updated_at`
|
` RETURNING created_at, updated_at`
|
||||||
|
|
||||||
// Create implements [domain.OrganizationRepository].
|
// Create implements [domain.OrganizationRepository].
|
||||||
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
|
func (o org) Create(ctx context.Context, client database.QueryExecutor, organization *domain.Organization) error {
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
|
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
|
||||||
builder.WriteString(createOrganizationStmt)
|
builder.WriteString(createOrganizationStmt)
|
||||||
|
|
||||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update implements [domain.OrganizationRepository].
|
// Update implements [domain.OrganizationRepository].
|
||||||
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
|
func (o org) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||||
if len(changes) == 0 {
|
if len(changes) == 0 {
|
||||||
return 0, database.ErrNoChanges
|
return 0, database.ErrNoChanges
|
||||||
}
|
}
|
||||||
builder := database.StatementBuilder{}
|
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
|
||||||
|
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(`UPDATE zitadel.organizations SET `)
|
builder.WriteString(`UPDATE zitadel.organizations SET `)
|
||||||
|
|
||||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
|
||||||
|
|
||||||
conditions := []database.Condition{id, instanceIDCondition}
|
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
writeCondition(&builder, database.And(conditions...))
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
stmt := builder.String()
|
stmt := builder.String()
|
||||||
|
|
||||||
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
|
rowsAffected, err := client.Exec(ctx, stmt, builder.Args()...)
|
||||||
return rowsAffected, err
|
return rowsAffected, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete implements [domain.OrganizationRepository].
|
// Delete implements [domain.OrganizationRepository].
|
||||||
func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
|
func (o org) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||||
builder := database.StatementBuilder{}
|
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(`DELETE FROM zitadel.organizations`)
|
builder.WriteString(`DELETE FROM zitadel.organizations`)
|
||||||
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
|
|
||||||
conditions := []database.Condition{id, instanceIDCondition}
|
|
||||||
writeCondition(&builder, database.And(conditions...))
|
|
||||||
|
|
||||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -154,13 +144,13 @@ func (o org) SetState(state domain.OrgState) database.Change {
|
|||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// IDCondition implements [domain.organizationConditions].
|
// IDCondition implements [domain.organizationConditions].
|
||||||
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
|
func (o org) IDCondition(id string) database.Condition {
|
||||||
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
|
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NameCondition implements [domain.organizationConditions].
|
// NameCondition implements [domain.organizationConditions].
|
||||||
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
|
func (o org) NameCondition(op database.TextOperation, name string) database.Condition {
|
||||||
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
|
return database.NewTextCondition(o.NameColumn(), op, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstanceIDCondition implements [domain.organizationConditions].
|
// InstanceIDCondition implements [domain.organizationConditions].
|
||||||
@@ -173,38 +163,62 @@ func (o org) StateCondition(state domain.OrgState) database.Condition {
|
|||||||
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
|
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExistsDomain creates a correlated [database.Exists] condition on org_domains.
|
||||||
|
// Use this filter to make sure the organization returned contains a specific domain.
|
||||||
|
// Example usage:
|
||||||
|
//
|
||||||
|
// domainRepo := orgRepo.Domains(true) // ensure domains are loaded/aggregated
|
||||||
|
// org, _ := orgRepo.Get(ctx,
|
||||||
|
// database.WithCondition(
|
||||||
|
// database.And(
|
||||||
|
// orgRepo.InstanceIDCondition(instanceID),
|
||||||
|
// orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, "example.com")),
|
||||||
|
// ),
|
||||||
|
// ),
|
||||||
|
// )
|
||||||
|
func (o org) ExistsDomain(cond database.Condition) database.Condition {
|
||||||
|
return database.Exists(
|
||||||
|
o.domainRepo.qualifiedTableName(),
|
||||||
|
database.And(
|
||||||
|
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
|
||||||
|
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
|
||||||
|
cond,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
// columns
|
// columns
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// IDColumn implements [domain.organizationColumns].
|
// IDColumn implements [domain.organizationColumns].
|
||||||
func (org) IDColumn() database.Column {
|
func (o org) IDColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "id")
|
return database.NewColumn(o.unqualifiedTableName(), "id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// NameColumn implements [domain.organizationColumns].
|
// NameColumn implements [domain.organizationColumns].
|
||||||
func (org) NameColumn() database.Column {
|
func (o org) NameColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "name")
|
return database.NewColumn(o.unqualifiedTableName(), "name")
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstanceIDColumn implements [domain.organizationColumns].
|
// InstanceIDColumn implements [domain.organizationColumns].
|
||||||
func (org) InstanceIDColumn() database.Column {
|
func (o org) InstanceIDColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "instance_id")
|
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// StateColumn implements [domain.organizationColumns].
|
// StateColumn implements [domain.organizationColumns].
|
||||||
func (org) StateColumn() database.Column {
|
func (o org) StateColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "state")
|
return database.NewColumn(o.unqualifiedTableName(), "state")
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreatedAtColumn implements [domain.organizationColumns].
|
// CreatedAtColumn implements [domain.organizationColumns].
|
||||||
func (org) CreatedAtColumn() database.Column {
|
func (o org) CreatedAtColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "created_at")
|
return database.NewColumn(o.unqualifiedTableName(), "created_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatedAtColumn implements [domain.organizationColumns].
|
// UpdatedAtColumn implements [domain.organizationColumns].
|
||||||
func (org) UpdatedAtColumn() database.Column {
|
func (o org) UpdatedAtColumn() database.Column {
|
||||||
return database.NewColumn("organizations", "updated_at")
|
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -255,20 +269,27 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
|
|||||||
// sub repositories
|
// sub repositories
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// Domains implements [domain.OrganizationRepository].
|
func (o org) LoadDomains() domain.OrganizationRepository {
|
||||||
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
|
return &org{
|
||||||
if !o.shouldLoadDomains {
|
shouldLoadDomains: true,
|
||||||
o.shouldLoadDomains = shouldLoad
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if o.domainRepo != nil {
|
|
||||||
return o.domainRepo
|
func (o org) joinDomains() database.QueryOption {
|
||||||
}
|
columns := make([]database.Condition, 0, 3)
|
||||||
|
columns = append(columns,
|
||||||
o.domainRepo = &orgDomain{
|
database.NewColumnCondition(o.InstanceIDColumn(), o.domainRepo.InstanceIDColumn()),
|
||||||
repository: o.repository,
|
database.NewColumnCondition(o.IDColumn(), o.domainRepo.OrgIDColumn()),
|
||||||
org: o,
|
)
|
||||||
}
|
|
||||||
|
// If domains should not be joined, we make sure to return null for the domain columns
|
||||||
return o.domainRepo
|
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||||
|
if !o.shouldLoadDomains {
|
||||||
|
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return database.WithLeftJoin(
|
||||||
|
o.domainRepo.qualifiedTableName(),
|
||||||
|
database.And(columns...),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,18 @@ import (
|
|||||||
|
|
||||||
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
|
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
|
||||||
|
|
||||||
type orgDomain struct {
|
type orgDomain struct{}
|
||||||
repository
|
|
||||||
*org
|
func OrganizationDomainRepository() domain.OrganizationDomainRepository {
|
||||||
|
return new(orgDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (orgDomain) qualifiedTableName() string {
|
||||||
|
return "zitadel.org_domains"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (orgDomain) unqualifiedTableName() string {
|
||||||
|
return "org_domains"
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -24,36 +33,44 @@ const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_veri
|
|||||||
|
|
||||||
// Get implements [domain.OrganizationDomainRepository].
|
// Get implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
|
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
|
||||||
func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
|
func (o orgDomain) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(queryOrganizationDomainStmt)
|
builder.WriteString(queryOrganizationDomainStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanOrganizationDomain(ctx, o.client, &builder)
|
return scanOrganizationDomain(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements [domain.OrganizationDomainRepository].
|
// List implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
|
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
|
||||||
func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
|
func (o orgDomain) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !options.Condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return nil, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(queryOrganizationDomainStmt)
|
builder.WriteString(queryOrganizationDomainStmt)
|
||||||
options.Write(&builder)
|
options.Write(&builder)
|
||||||
|
|
||||||
return scanOrganizationDomains(ctx, o.client, &builder)
|
return scanOrganizationDomains(ctx, client, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add implements [domain.OrganizationDomainRepository].
|
// Add implements [domain.OrganizationDomainRepository].
|
||||||
func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
|
func (o orgDomain) Add(ctx context.Context, client database.QueryExecutor, domain *domain.AddOrganizationDomain) error {
|
||||||
var (
|
var (
|
||||||
builder database.StatementBuilder
|
builder database.StatementBuilder
|
||||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||||
@@ -69,33 +86,47 @@ func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomai
|
|||||||
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
|
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
|
||||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||||
|
|
||||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update implements [domain.OrganizationDomainRepository].
|
// Update implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
|
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
|
||||||
func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
func (o orgDomain) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||||
if len(changes) == 0 {
|
if len(changes) == 0 {
|
||||||
return 0, database.ErrNoChanges
|
return 0, database.ErrNoChanges
|
||||||
}
|
}
|
||||||
|
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.OrgIDColumn())
|
||||||
|
}
|
||||||
|
if !database.Changes(changes).IsOnColumn(o.UpdatedAtColumn()) {
|
||||||
|
changes = append(changes, database.NewChange(o.UpdatedAtColumn(), database.NullInstruction))
|
||||||
|
}
|
||||||
|
|
||||||
var builder database.StatementBuilder
|
var builder database.StatementBuilder
|
||||||
|
|
||||||
builder.WriteString(`UPDATE zitadel.org_domains SET `)
|
builder.WriteString(`UPDATE zitadel.org_domains SET `)
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove implements [domain.OrganizationDomainRepository].
|
// Remove implements [domain.OrganizationDomainRepository].
|
||||||
func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
func (o orgDomain) Remove(ctx context.Context, client database.QueryExecutor, condition database.Condition) (int64, error) {
|
||||||
var builder database.StatementBuilder
|
if !condition.IsRestrictingColumn(o.InstanceIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.InstanceIDColumn())
|
||||||
|
}
|
||||||
|
if !condition.IsRestrictingColumn(o.OrgIDColumn()) {
|
||||||
|
return 0, database.NewMissingConditionError(o.OrgIDColumn())
|
||||||
|
}
|
||||||
|
|
||||||
|
var builder database.StatementBuilder
|
||||||
builder.WriteString(`DELETE FROM zitadel.org_domains `)
|
builder.WriteString(`DELETE FROM zitadel.org_domains `)
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
return client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
@@ -158,45 +189,43 @@ func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
|
|||||||
|
|
||||||
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
|
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
|
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
|
||||||
func (orgDomain) CreatedAtColumn() database.Column {
|
func (o orgDomain) CreatedAtColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "created_at")
|
return database.NewColumn(o.unqualifiedTableName(), "created_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
// DomainColumn implements [domain.OrganizationDomainRepository].
|
// DomainColumn implements [domain.OrganizationDomainRepository].
|
||||||
func (orgDomain) DomainColumn() database.Column {
|
func (o orgDomain) DomainColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "domain")
|
return database.NewColumn(o.unqualifiedTableName(), "domain")
|
||||||
}
|
}
|
||||||
|
|
||||||
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
|
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
|
func (o orgDomain) InstanceIDColumn() database.Column {
|
||||||
func (orgDomain) InstanceIDColumn() database.Column {
|
return database.NewColumn(o.unqualifiedTableName(), "instance_id")
|
||||||
return database.NewColumn("org_domains", "instance_id")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
|
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
|
||||||
func (orgDomain) IsPrimaryColumn() database.Column {
|
func (o orgDomain) IsPrimaryColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "is_primary")
|
return database.NewColumn(o.unqualifiedTableName(), "is_primary")
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
|
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
|
||||||
func (orgDomain) IsVerifiedColumn() database.Column {
|
func (o orgDomain) IsVerifiedColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "is_verified")
|
return database.NewColumn(o.unqualifiedTableName(), "is_verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OrgIDColumn implements [domain.OrganizationDomainRepository].
|
// OrgIDColumn implements [domain.OrganizationDomainRepository].
|
||||||
func (orgDomain) OrgIDColumn() database.Column {
|
func (o orgDomain) OrgIDColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "org_id")
|
return database.NewColumn(o.unqualifiedTableName(), "org_id")
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
|
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
|
func (o orgDomain) UpdatedAtColumn() database.Column {
|
||||||
func (orgDomain) UpdatedAtColumn() database.Column {
|
return database.NewColumn(o.unqualifiedTableName(), "updated_at")
|
||||||
return database.NewColumn("org_domains", "updated_at")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
|
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
|
||||||
func (orgDomain) ValidationTypeColumn() database.Column {
|
func (o orgDomain) ValidationTypeColumn() database.Column {
|
||||||
return database.NewColumn("org_domains", "validation_type")
|
return database.NewColumn(o.unqualifiedTableName(), "validation_type")
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package repository_test
|
package repository_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/brianvoe/gofakeit/v6"
|
"github.com/brianvoe/gofakeit/v6"
|
||||||
@@ -15,6 +14,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestAddOrganizationDomain(t *testing.T) {
|
func TestAddOrganizationDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
instanceID := gofakeit.UUID()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -26,8 +34,11 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
instanceRepo := repository.InstanceRepository()
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
@@ -41,7 +52,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain
|
||||||
organizationDomain domain.AddOrganizationDomain
|
organizationDomain domain.AddOrganizationDomain
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
@@ -92,7 +103,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "add domain with same domain twice",
|
name: "add domain with same domain twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.AddOrganizationDomain {
|
||||||
domainName := gofakeit.DomainName()
|
domainName := gofakeit.DomainName()
|
||||||
|
|
||||||
organizationDomain := &domain.AddOrganizationDomain{
|
organizationDomain := &domain.AddOrganizationDomain{
|
||||||
@@ -104,7 +115,7 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
}
|
}
|
||||||
|
|
||||||
err := domainRepo.Add(ctx, organizationDomain)
|
err := domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// return same domain again
|
// return same domain again
|
||||||
@@ -169,28 +180,26 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
savepoint, err := tx.Begin(t.Context())
|
||||||
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(tx)
|
err = orgRepo.Create(t.Context(), savepoint, &organization)
|
||||||
err = orgRepo.Create(t.Context(), &organization)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
|
|
||||||
var organizationDomain *domain.AddOrganizationDomain
|
var organizationDomain *domain.AddOrganizationDomain
|
||||||
if test.testFunc != nil {
|
if test.testFunc != nil {
|
||||||
organizationDomain = test.testFunc(ctx, t, domainRepo)
|
organizationDomain = test.testFunc(t, savepoint)
|
||||||
} else {
|
} else {
|
||||||
organizationDomain = &test.organizationDomain
|
organizationDomain = &test.organizationDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(ctx, organizationDomain)
|
err = domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
return
|
return
|
||||||
@@ -204,6 +213,16 @@ func TestAddOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOrganizationDomain(t *testing.T) {
|
func TestGetOrganizationDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, tx.Rollback(t.Context()))
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
instanceID := gofakeit.UUID()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -225,29 +244,18 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(tx)
|
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||||
err = orgRepo.Create(t.Context(), &organization)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domains
|
// add domains
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
domainName1 := gofakeit.DomainName()
|
|
||||||
domainName2 := gofakeit.DomainName()
|
|
||||||
|
|
||||||
domain1 := &domain.AddOrganizationDomain{
|
domain1 := &domain.AddOrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName1,
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
@@ -255,15 +263,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
domain2 := &domain.AddOrganizationDomain{
|
domain2 := &domain.AddOrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName2,
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: false,
|
IsVerified: false,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), domain1)
|
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = domainRepo.Add(t.Context(), domain2)
|
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -275,12 +283,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "get primary domain",
|
name: "get primary domain",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instanceID),
|
||||||
|
domainRepo.IsPrimaryCondition(true),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expected: &domain.OrganizationDomain{
|
expected: &domain.OrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName1,
|
Domain: domain1.Domain,
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
@@ -289,12 +300,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "get by domain name",
|
name: "get by domain name",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instanceID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expected: &domain.OrganizationDomain{
|
expected: &domain.OrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName2,
|
Domain: domain2.Domain,
|
||||||
IsVerified: false,
|
IsVerified: false,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||||
@@ -303,13 +317,16 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "get by org ID",
|
name: "get by org ID",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
database.WithCondition(database.And(
|
||||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
domainRepo.InstanceIDCondition(instanceID),
|
||||||
|
domainRepo.OrgIDCondition(orgID),
|
||||||
|
domainRepo.IsPrimaryCondition(true),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expected: &domain.OrganizationDomain{
|
expected: &domain.OrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName1,
|
Domain: domain1.Domain,
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
@@ -318,12 +335,15 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "get verified domain",
|
name: "get verified domain",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instanceID),
|
||||||
|
domainRepo.IsVerifiedCondition(true),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expected: &domain.OrganizationDomain{
|
expected: &domain.OrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instanceID,
|
||||||
OrgID: orgID,
|
OrgID: orgID,
|
||||||
Domain: domainName1,
|
Domain: domain1.Domain,
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
@@ -332,7 +352,10 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "get non-existent domain",
|
name: "get non-existent domain",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instanceID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
err: new(database.NoRowFoundError),
|
err: new(database.NoRowFoundError),
|
||||||
},
|
},
|
||||||
@@ -340,9 +363,7 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
result, err := domainRepo.Get(t.Context(), tx, test.opts...)
|
||||||
|
|
||||||
result, err := domainRepo.Get(ctx, test.opts...)
|
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
return
|
return
|
||||||
@@ -362,10 +383,22 @@ func TestGetOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListOrganizationDomains(t *testing.T) {
|
func TestListOrganizationDomains(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -375,50 +408,40 @@ func TestListOrganizationDomains(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
orgID := gofakeit.UUID()
|
|
||||||
organization := domain.Organization{
|
organization := domain.Organization{
|
||||||
ID: orgID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(tx)
|
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||||
err = orgRepo.Create(t.Context(), &organization)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add multiple domains
|
// add multiple domains
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
domains := []domain.AddOrganizationDomain{
|
domains := []domain.AddOrganizationDomain{
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: false,
|
IsVerified: false,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: gofakeit.DomainName(),
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
@@ -427,7 +450,7 @@ func TestListOrganizationDomains(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := range domains {
|
for i := range domains {
|
||||||
err = domainRepo.Add(t.Context(), &domains[i])
|
err = domainRepo.Add(t.Context(), tx, &domains[i])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -437,42 +460,59 @@ func TestListOrganizationDomains(t *testing.T) {
|
|||||||
expectedCount int
|
expectedCount int
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "list all domains",
|
name: "list all domains",
|
||||||
opts: []database.QueryOption{},
|
opts: []database.QueryOption{
|
||||||
|
database.WithCondition(domainRepo.InstanceIDCondition(instance.ID)),
|
||||||
|
},
|
||||||
expectedCount: 3,
|
expectedCount: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list verified domains",
|
name: "list verified domains",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.IsVerifiedCondition(true),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expectedCount: 2,
|
expectedCount: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list primary domains",
|
name: "list primary domains",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.IsPrimaryCondition(true),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expectedCount: 1,
|
expectedCount: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list by organization",
|
name: "list by organization",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expectedCount: 3,
|
expectedCount: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list by instance",
|
name: "list by instance",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expectedCount: 3,
|
expectedCount: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list non-existent organization",
|
name: "list non-existent organization",
|
||||||
opts: []database.QueryOption{
|
opts: []database.QueryOption{
|
||||||
database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
|
database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition("non-existent"),
|
||||||
|
)),
|
||||||
},
|
},
|
||||||
expectedCount: 0,
|
expectedCount: 0,
|
||||||
},
|
},
|
||||||
@@ -482,13 +522,13 @@ func TestListOrganizationDomains(t *testing.T) {
|
|||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
ctx := t.Context()
|
||||||
|
|
||||||
results, err := domainRepo.List(ctx, test.opts...)
|
results, err := domainRepo.List(ctx, tx, test.opts...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, results, test.expectedCount)
|
assert.Len(t, results, test.expectedCount)
|
||||||
|
|
||||||
for _, result := range results {
|
for _, result := range results {
|
||||||
assert.Equal(t, instanceID, result.InstanceID)
|
assert.Equal(t, instance.ID, result.InstanceID)
|
||||||
assert.Equal(t, orgID, result.OrgID)
|
assert.Equal(t, organization.ID, result.OrgID)
|
||||||
assert.NotEmpty(t, result.Domain)
|
assert.NotEmpty(t, result.Domain)
|
||||||
assert.NotEmpty(t, result.CreatedAt)
|
assert.NotEmpty(t, result.CreatedAt)
|
||||||
assert.NotEmpty(t, result.UpdatedAt)
|
assert.NotEmpty(t, result.UpdatedAt)
|
||||||
@@ -498,10 +538,22 @@ func TestListOrganizationDomains(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateOrganizationDomain(t *testing.T) {
|
func TestUpdateOrganizationDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -511,41 +563,30 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
orgID := gofakeit.UUID()
|
|
||||||
organization := domain.Organization{
|
organization := domain.Organization{
|
||||||
ID: orgID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(tx)
|
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||||
err = orgRepo.Create(t.Context(), &organization)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domain
|
// add domain
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
domainName := gofakeit.DomainName()
|
|
||||||
organizationDomain := &domain.AddOrganizationDomain{
|
organizationDomain := &domain.AddOrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: domainName,
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: false,
|
IsVerified: false,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), organizationDomain)
|
err = domainRepo.Add(t.Context(), tx, organizationDomain)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -556,26 +597,42 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "set verified",
|
name: "set verified",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetVerified()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 1,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetVerified()},
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set primary",
|
name: "set primary",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetPrimary()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 1,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetPrimary()},
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "set validation type",
|
name: "set validation type",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 1,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple changes",
|
name: "multiple changes",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
changes: []database.Change{
|
changes: []database.Change{
|
||||||
domainRepo.SetVerified(),
|
domainRepo.SetVerified(),
|
||||||
domainRepo.SetPrimary(),
|
domainRepo.SetPrimary(),
|
||||||
@@ -584,31 +641,41 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
|||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update by org ID and domain",
|
name: "update by org ID and domain",
|
||||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetVerified()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 1,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetVerified()},
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update non-existent domain",
|
name: "update non-existent domain",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
condition: database.And(
|
||||||
changes: []database.Change{domainRepo.SetVerified()},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 0,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||||
|
),
|
||||||
|
changes: []database.Change{domainRepo.SetVerified()},
|
||||||
|
expected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no changes",
|
name: "no changes",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
condition: database.And(
|
||||||
changes: []database.Change{},
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
expected: 0,
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
err: database.ErrNoChanges,
|
domainRepo.DomainCondition(database.TextOperationEqual, organizationDomain.Domain),
|
||||||
|
),
|
||||||
|
changes: []database.Change{},
|
||||||
|
expected: 0,
|
||||||
|
err: database.ErrNoChanges,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
rowsAffected, err := domainRepo.Update(t.Context(), tx, test.condition, test.changes...)
|
||||||
|
|
||||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
|
||||||
if test.err != nil {
|
if test.err != nil {
|
||||||
assert.ErrorIs(t, err, test.err)
|
assert.ErrorIs(t, err, test.err)
|
||||||
return
|
return
|
||||||
@@ -619,7 +686,7 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
|||||||
|
|
||||||
// verify changes were applied if rows were affected
|
// verify changes were applied if rows were affected
|
||||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
result, err := domainRepo.Get(t.Context(), tx, database.WithCondition(test.condition))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// We know changes were applied since rowsAffected > 0
|
// We know changes were applied since rowsAffected > 0
|
||||||
@@ -632,10 +699,22 @@ func TestUpdateOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveOrganizationDomain(t *testing.T) {
|
func TestRemoveOrganizationDomain(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceID := gofakeit.UUID()
|
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
ID: instanceID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
DefaultOrgID: "defaultOrgId",
|
DefaultOrgID: "defaultOrgId",
|
||||||
IAMProjectID: "iamProject",
|
IAMProjectID: "iamProject",
|
||||||
@@ -645,53 +724,40 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
orgID := gofakeit.UUID()
|
|
||||||
organization := domain.Organization{
|
organization := domain.Organization{
|
||||||
ID: orgID,
|
ID: gofakeit.UUID(),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
tx, err := pool.Begin(t.Context(), nil)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, tx.Rollback(t.Context()))
|
|
||||||
}()
|
|
||||||
|
|
||||||
instanceRepo := repository.InstanceRepository(tx)
|
|
||||||
err = instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(tx)
|
err = orgRepo.Create(t.Context(), tx, &organization)
|
||||||
err = orgRepo.Create(t.Context(), &organization)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// add domains
|
// add domains
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
domainName1 := gofakeit.DomainName()
|
|
||||||
domainName2 := gofakeit.DomainName()
|
|
||||||
|
|
||||||
domain1 := &domain.AddOrganizationDomain{
|
domain1 := &domain.AddOrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: domainName1,
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||||
}
|
}
|
||||||
domain2 := &domain.AddOrganizationDomain{
|
domain2 := &domain.AddOrganizationDomain{
|
||||||
InstanceID: instanceID,
|
InstanceID: instance.ID,
|
||||||
OrgID: orgID,
|
OrgID: organization.ID,
|
||||||
Domain: domainName2,
|
Domain: gofakeit.DomainName(),
|
||||||
IsVerified: false,
|
IsVerified: false,
|
||||||
IsPrimary: false,
|
IsPrimary: false,
|
||||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = domainRepo.Add(t.Context(), domain1)
|
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
err = domainRepo.Add(t.Context(), domain2)
|
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -700,50 +766,70 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
|||||||
expected int64
|
expected int64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "remove by domain name",
|
name: "remove by domain name",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
condition: database.And(
|
||||||
expected: 1,
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain),
|
||||||
|
),
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "remove by primary condition",
|
name: "remove by primary condition",
|
||||||
condition: domainRepo.IsPrimaryCondition(false),
|
condition: database.And(
|
||||||
expected: 1, // domain2 should still exist and be non-primary
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.IsPrimaryCondition(false),
|
||||||
|
),
|
||||||
|
expected: 1, // domain2 should still exist and be non-primary
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "remove by org ID and domain",
|
name: "remove by org ID and domain",
|
||||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
condition: database.And(
|
||||||
expected: 1,
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, domain2.Domain),
|
||||||
|
),
|
||||||
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "remove non-existent domain",
|
name: "remove non-existent domain",
|
||||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
condition: database.And(
|
||||||
expected: 0,
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||||
|
),
|
||||||
|
expected: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
ctx := t.Context()
|
snapshot, err := tx.Begin(t.Context())
|
||||||
|
|
||||||
snapshot, err := tx.Begin(ctx)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, snapshot.Rollback(ctx))
|
err = snapshot.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Log("error during rollback:", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(snapshot)
|
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
|
|
||||||
// count before removal
|
// count before removal
|
||||||
beforeCount, err := domainRepo.List(ctx)
|
beforeCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
)))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
rowsAffected, err := domainRepo.Remove(t.Context(), snapshot, test.condition)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, test.expected, rowsAffected)
|
assert.Equal(t, test.expected, rowsAffected)
|
||||||
|
|
||||||
// verify removal
|
// verify removal
|
||||||
afterCount, err := domainRepo.List(ctx)
|
afterCount, err := domainRepo.List(t.Context(), snapshot, database.WithCondition(database.And(
|
||||||
|
domainRepo.InstanceIDCondition(instance.ID),
|
||||||
|
domainRepo.OrgIDCondition(organization.ID),
|
||||||
|
)))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||||
})
|
})
|
||||||
@@ -751,8 +837,7 @@ func TestRemoveOrganizationDomain(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOrganizationDomainConditions(t *testing.T) {
|
func TestOrganizationDomainConditions(t *testing.T) {
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -811,8 +896,7 @@ func TestOrganizationDomainConditions(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOrganizationDomainChanges(t *testing.T) {
|
func TestOrganizationDomainChanges(t *testing.T) {
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -851,8 +935,7 @@ func TestOrganizationDomainChanges(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestOrganizationDomainColumns(t *testing.T) {
|
func TestOrganizationDomainColumns(t *testing.T) {
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
domainRepo := orgRepo.Domains(false)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package repository_test
|
package repository_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,6 +15,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateOrganization(t *testing.T) {
|
func TestCreateOrganization(t *testing.T) {
|
||||||
|
beforeCreate := time.Now()
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
organizationRepo := repository.OrganizationRepository()
|
||||||
// create instance
|
// create instance
|
||||||
instanceId := gofakeit.Name()
|
instanceId := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -26,13 +38,12 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
testFunc func(t *testing.T, client database.QueryExecutor) *domain.Organization
|
||||||
organization domain.Organization
|
organization domain.Organization
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
@@ -67,8 +78,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "adding org with same id twice",
|
name: "adding org with same id twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
@@ -79,7 +89,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// change the name to make sure same only the id clashes
|
// change the name to make sure same only the id clashes
|
||||||
org.Name = gofakeit.Name()
|
org.Name = gofakeit.Name()
|
||||||
@@ -89,8 +99,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "adding org with same name twice",
|
name: "adding org with same name twice",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
@@ -101,7 +110,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
// change the id to make sure same name+instance causes an error
|
// change the id to make sure same name+instance causes an error
|
||||||
org.ID = gofakeit.Name()
|
org.ID = gofakeit.Name()
|
||||||
@@ -111,7 +120,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
func() struct {
|
func() struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
|
||||||
organization domain.Organization
|
organization domain.Organization
|
||||||
err error
|
err error
|
||||||
} {
|
} {
|
||||||
@@ -120,12 +129,12 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
|
|
||||||
return struct {
|
return struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Organization
|
||||||
organization domain.Organization
|
organization domain.Organization
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
name: "adding org with same name, different instance",
|
name: "adding org with same name, different instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Organization {
|
||||||
// create instance
|
// create instance
|
||||||
instId := gofakeit.Name()
|
instId := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -137,12 +146,9 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err := instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err := instanceRepo.Create(ctx, &instance)
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: gofakeit.Name(),
|
ID: gofakeit.Name(),
|
||||||
Name: organizationName,
|
Name: organizationName,
|
||||||
@@ -150,7 +156,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = organizationRepo.Create(ctx, &org)
|
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// change the id to make it unique
|
// change the id to make it unique
|
||||||
@@ -214,19 +220,25 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
savepoint, err := tx.Begin(t.Context())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
var organization *domain.Organization
|
var organization *domain.Organization
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
organization = tt.testFunc(ctx, t)
|
organization = tt.testFunc(t, savepoint)
|
||||||
} else {
|
} else {
|
||||||
organization = &tt.organization
|
organization = &tt.organization
|
||||||
}
|
}
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
beforeCreate := time.Now()
|
|
||||||
err = organizationRepo.Create(ctx, organization)
|
err = organizationRepo.Create(t.Context(), savepoint, organization)
|
||||||
assert.ErrorIs(t, err, tt.err)
|
assert.ErrorIs(t, err, tt.err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@@ -234,7 +246,7 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
afterCreate := time.Now()
|
afterCreate := time.Now()
|
||||||
|
|
||||||
// check organization values
|
// check organization values
|
||||||
organization, err = organizationRepo.Get(ctx,
|
organization, err = organizationRepo.Get(t.Context(), savepoint,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
organizationRepo.IDCondition(organization.ID),
|
organizationRepo.IDCondition(organization.ID),
|
||||||
@@ -255,6 +267,19 @@ func TestCreateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateOrganization(t *testing.T) {
|
func TestUpdateOrganization(t *testing.T) {
|
||||||
|
beforeUpdate := time.Now()
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
organizationRepo := repository.OrganizationRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceId := gofakeit.Name()
|
instanceId := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -266,20 +291,18 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
testFunc func(t *testing.T) *domain.Organization
|
||||||
update []database.Change
|
update []database.Change
|
||||||
rowsAffected int64
|
rowsAffected int64
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path update name",
|
name: "happy path update name",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
@@ -291,7 +314,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// update with updated value
|
// update with updated value
|
||||||
@@ -303,7 +326,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update deleted organization",
|
name: "update deleted organization",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
@@ -315,13 +338,15 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// delete instance
|
// delete instance
|
||||||
_, err = organizationRepo.Delete(ctx,
|
_, err = organizationRepo.Delete(t.Context(), tx,
|
||||||
organizationRepo.IDCondition(org.ID),
|
database.And(
|
||||||
org.InstanceID,
|
organizationRepo.InstanceIDCondition(org.InstanceID),
|
||||||
|
organizationRepo.IDCondition(org.ID),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -332,7 +357,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "happy path change state",
|
name: "happy path change state",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
@@ -344,7 +369,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// update with updated value
|
// update with updated value
|
||||||
@@ -356,7 +381,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "update non existent organization",
|
name: "update non existent organization",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
@@ -370,16 +395,14 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
createdOrg := tt.testFunc(t)
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
createdOrg := tt.testFunc(ctx, t)
|
|
||||||
|
|
||||||
// update org
|
// update org
|
||||||
beforeUpdate := time.Now()
|
rowsAffected, err := organizationRepo.Update(t.Context(), tx,
|
||||||
rowsAffected, err := organizationRepo.Update(ctx,
|
database.And(
|
||||||
organizationRepo.IDCondition(createdOrg.ID),
|
organizationRepo.InstanceIDCondition(createdOrg.InstanceID),
|
||||||
createdOrg.InstanceID,
|
organizationRepo.IDCondition(createdOrg.ID),
|
||||||
|
),
|
||||||
tt.update...,
|
tt.update...,
|
||||||
)
|
)
|
||||||
afterUpdate := time.Now()
|
afterUpdate := time.Now()
|
||||||
@@ -392,7 +415,7 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check organization values
|
// check organization values
|
||||||
organization, err := organizationRepo.Get(ctx,
|
organization, err := organizationRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
organizationRepo.IDCondition(createdOrg.ID),
|
organizationRepo.IDCondition(createdOrg.ID),
|
||||||
@@ -411,6 +434,19 @@ func TestUpdateOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestGetOrganization(t *testing.T) {
|
func TestGetOrganization(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err := tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
orgDomainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceId := gofakeit.Name()
|
instanceId := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -422,11 +458,9 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
// this org is created as an additional org which should NOT
|
// this org is created as an additional org which should NOT
|
||||||
@@ -437,13 +471,13 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
err = orgRepo.Create(t.Context(), &org)
|
err = orgRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) *domain.Organization
|
testFunc func(t *testing.T) *domain.Organization
|
||||||
orgIdentifierCondition domain.OrgIdentifierCondition
|
orgIdentifierCondition database.Condition
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,7 +486,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "happy path get using id",
|
name: "happy path get using id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
@@ -463,7 +497,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := orgRepo.Create(ctx, &org)
|
err := orgRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return &org
|
return &org
|
||||||
@@ -475,7 +509,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "happy path get using id including domain",
|
name: "happy path get using id including domain",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
@@ -486,7 +520,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := orgRepo.Create(ctx, &org)
|
err := orgRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
d := &domain.AddOrganizationDomain{
|
d := &domain.AddOrganizationDomain{
|
||||||
@@ -496,7 +530,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
IsVerified: true,
|
IsVerified: true,
|
||||||
IsPrimary: true,
|
IsPrimary: true,
|
||||||
}
|
}
|
||||||
err = orgRepo.Domains(false).Add(ctx, d)
|
err = orgDomainRepo.Add(t.Context(), tx, d)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
org.Domains = []*domain.OrganizationDomain{
|
org.Domains = []*domain.OrganizationDomain{
|
||||||
@@ -521,7 +555,7 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "happy path get using name",
|
name: "happy path get using name",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
@@ -532,39 +566,36 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := orgRepo.Create(ctx, &org)
|
err := orgRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return &org
|
return &org
|
||||||
},
|
},
|
||||||
orgIdentifierCondition: orgRepo.NameCondition(organizationName),
|
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
{
|
{
|
||||||
name: "get non existent organization",
|
name: "get non existent organization",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Organization {
|
testFunc: func(t *testing.T) *domain.Organization {
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: "non existent org",
|
ID: "non existent org",
|
||||||
Name: "non existent org",
|
Name: "non existent org",
|
||||||
}
|
}
|
||||||
return &org
|
return &org
|
||||||
},
|
},
|
||||||
orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
|
orgIdentifierCondition: orgRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name"),
|
||||||
err: new(database.NoRowFoundError),
|
err: new(database.NoRowFoundError),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
orgRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
var org *domain.Organization
|
var org *domain.Organization
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
org = tt.testFunc(ctx, t)
|
org = tt.testFunc(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// get org values
|
// get org values
|
||||||
returnedOrg, err := orgRepo.Get(ctx,
|
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
tt.orgIdentifierCondition,
|
tt.orgIdentifierCondition,
|
||||||
@@ -592,11 +623,17 @@ func TestGetOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestListOrganization(t *testing.T) {
|
func TestListOrganization(t *testing.T) {
|
||||||
ctx := t.Context()
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
pool, stop, err := newEmbeddedDB(ctx)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer stop()
|
defer func() {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
err := tx.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
organizationRepo := repository.OrganizationRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceId := gofakeit.Name()
|
instanceId := gofakeit.Name()
|
||||||
@@ -609,33 +646,32 @@ func TestListOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err = instanceRepo.Create(ctx, &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Organization
|
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Organization
|
||||||
conditionClauses []database.Condition
|
conditionClauses []database.Condition
|
||||||
noOrganizationReturned bool
|
noOrganizationReturned bool
|
||||||
}
|
}
|
||||||
tests := []test{
|
tests := []test{
|
||||||
{
|
{
|
||||||
name: "happy path single organization no filter",
|
name: "happy path single organization no filter",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||||
noOfOrganizations := 1
|
noOfOrganizations := 1
|
||||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -646,20 +682,20 @@ func TestListOrganization(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "happy path multiple organization no filter",
|
name: "happy path multiple organization no filter",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||||
noOfOrganizations := 5
|
noOfOrganizations := 5
|
||||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -672,7 +708,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "organization filter on id",
|
name: "organization filter on id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||||
// create organization
|
// create organization
|
||||||
// this org is created as an additional org which should NOT
|
// this org is created as an additional org which should NOT
|
||||||
// be returned in the results of this test case
|
// be returned in the results of this test case
|
||||||
@@ -682,7 +718,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
err = organizationRepo.Create(ctx, &org)
|
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
noOfOrganizations := 1
|
noOfOrganizations := 1
|
||||||
@@ -697,7 +733,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -705,12 +741,15 @@ func TestListOrganization(t *testing.T) {
|
|||||||
|
|
||||||
return organizations
|
return organizations
|
||||||
},
|
},
|
||||||
conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)},
|
conditionClauses: []database.Condition{
|
||||||
|
organizationRepo.InstanceIDCondition(instanceId),
|
||||||
|
organizationRepo.IDCondition(organizationId),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
{
|
{
|
||||||
name: "multiple organization filter on state",
|
name: "multiple organization filter on state",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||||
// create organization
|
// create organization
|
||||||
// this org is created as an additional org which should NOT
|
// this org is created as an additional org which should NOT
|
||||||
// be returned in the results of this test case
|
// be returned in the results of this test case
|
||||||
@@ -720,7 +759,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
err = organizationRepo.Create(ctx, &org)
|
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
noOfOrganizations := 5
|
noOfOrganizations := 5
|
||||||
@@ -728,14 +767,14 @@ func TestListOrganization(t *testing.T) {
|
|||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceId,
|
InstanceID: instanceId,
|
||||||
State: domain.OrgStateInactive,
|
State: domain.OrgStateInactive,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -743,13 +782,16 @@ func TestListOrganization(t *testing.T) {
|
|||||||
|
|
||||||
return organizations
|
return organizations
|
||||||
},
|
},
|
||||||
conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)},
|
conditionClauses: []database.Condition{
|
||||||
|
organizationRepo.InstanceIDCondition(instanceId),
|
||||||
|
organizationRepo.StateCondition(domain.OrgStateInactive),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
func() test {
|
func() test {
|
||||||
instanceId_2 := gofakeit.Name()
|
instanceId_2 := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "multiple organization filter on instance",
|
name: "multiple organization filter on instance",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization {
|
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Organization {
|
||||||
// create instance 1
|
// create instance 1
|
||||||
instanceId_1 := gofakeit.Name()
|
instanceId_1 := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -761,8 +803,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err = instanceRepo.Create(ctx, &instance)
|
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
@@ -774,7 +815,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
InstanceID: instanceId_1,
|
InstanceID: instanceId_1,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
err = organizationRepo.Create(ctx, &org)
|
err = organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create instance 2
|
// create instance 2
|
||||||
@@ -787,7 +828,7 @@ func TestListOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
err = instanceRepo.Create(ctx, &instance_2)
|
err = instanceRepo.Create(t.Context(), tx, &instance_2)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
noOfOrganizations := 5
|
noOfOrganizations := 5
|
||||||
@@ -795,14 +836,14 @@ func TestListOrganization(t *testing.T) {
|
|||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
org := domain.Organization{
|
org := domain.Organization{
|
||||||
ID: gofakeit.Name(),
|
ID: strconv.Itoa(i),
|
||||||
Name: gofakeit.Name(),
|
Name: gofakeit.Name(),
|
||||||
InstanceID: instanceId_2,
|
InstanceID: instanceId_2,
|
||||||
State: domain.OrgStateActive,
|
State: domain.OrgStateActive,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -816,22 +857,25 @@ func TestListOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Cleanup(func() {
|
savepoint, err := tx.Begin(t.Context())
|
||||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations")
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
defer func() {
|
||||||
})
|
err = savepoint.Rollback(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("error during rollback: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
organizations := tt.testFunc(t, savepoint)
|
||||||
|
|
||||||
organizations := tt.testFunc(ctx, t)
|
condition := organizationRepo.InstanceIDCondition(instanceId)
|
||||||
|
|
||||||
var condition database.Condition
|
|
||||||
if len(tt.conditionClauses) > 0 {
|
if len(tt.conditionClauses) > 0 {
|
||||||
condition = database.And(tt.conditionClauses...)
|
condition = database.And(tt.conditionClauses...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// check organization values
|
// check organization values
|
||||||
returnedOrgs, err := organizationRepo.List(ctx,
|
returnedOrgs, err := organizationRepo.List(t.Context(), tx,
|
||||||
database.WithCondition(condition),
|
database.WithCondition(condition),
|
||||||
database.WithOrderByAscending(organizationRepo.CreatedAtColumn()),
|
database.WithOrderByAscending(organizationRepo.CreatedAtColumn(), organizationRepo.IDColumn()),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if tt.noOrganizationReturned {
|
if tt.noOrganizationReturned {
|
||||||
@@ -851,6 +895,15 @@ func TestListOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteOrganization(t *testing.T) {
|
func TestDeleteOrganization(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = tx.Rollback(t.Context())
|
||||||
|
})
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
organizationRepo := repository.OrganizationRepository()
|
||||||
|
|
||||||
// create instance
|
// create instance
|
||||||
instanceId := gofakeit.Name()
|
instanceId := gofakeit.Name()
|
||||||
instance := domain.Instance{
|
instance := domain.Instance{
|
||||||
@@ -862,24 +915,22 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
ConsoleAppID: "consoleApp",
|
ConsoleAppID: "consoleApp",
|
||||||
DefaultLanguage: "defaultLanguage",
|
DefaultLanguage: "defaultLanguage",
|
||||||
}
|
}
|
||||||
instanceRepo := repository.InstanceRepository(pool)
|
err = instanceRepo.Create(t.Context(), tx, &instance)
|
||||||
err := instanceRepo.Create(t.Context(), &instance)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
name string
|
name string
|
||||||
testFunc func(ctx context.Context, t *testing.T)
|
testFunc func(t *testing.T)
|
||||||
orgIdentifierCondition domain.OrgIdentifierCondition
|
orgIdentifierCondition database.Condition
|
||||||
noOfDeletedRows int64
|
noOfDeletedRows int64
|
||||||
}
|
}
|
||||||
tests := []test{
|
tests := []test{
|
||||||
func() test {
|
func() test {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
organizationId := gofakeit.Name()
|
organizationId := gofakeit.Name()
|
||||||
var noOfOrganizations int64 = 1
|
var noOfOrganizations int64 = 1
|
||||||
return test{
|
return test{
|
||||||
name: "happy path delete organization filter id",
|
name: "happy path delete organization filter id",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) {
|
testFunc: func(t *testing.T) {
|
||||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
@@ -891,7 +942,7 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
@@ -902,12 +953,11 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
var noOfOrganizations int64 = 1
|
var noOfOrganizations int64 = 1
|
||||||
return test{
|
return test{
|
||||||
name: "happy path delete organization filter name",
|
name: "happy path delete organization filter name",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) {
|
testFunc: func(t *testing.T) {
|
||||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
|
|
||||||
@@ -919,30 +969,28 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
|
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||||
noOfDeletedRows: noOfOrganizations,
|
noOfDeletedRows: noOfOrganizations,
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
non_existent_organization_name := gofakeit.Name()
|
non_existent_organization_name := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "delete non existent organization",
|
name: "delete non existent organization",
|
||||||
orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name),
|
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, non_existent_organization_name),
|
||||||
}
|
}
|
||||||
}(),
|
}(),
|
||||||
func() test {
|
func() test {
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
organizationName := gofakeit.Name()
|
organizationName := gofakeit.Name()
|
||||||
return test{
|
return test{
|
||||||
name: "deleted already deleted organization",
|
name: "deleted already deleted organization",
|
||||||
testFunc: func(ctx context.Context, t *testing.T) {
|
testFunc: func(t *testing.T) {
|
||||||
noOfOrganizations := 1
|
noOfOrganizations := 1
|
||||||
organizations := make([]*domain.Organization, noOfOrganizations)
|
organizations := make([]*domain.Organization, noOfOrganizations)
|
||||||
for i := range noOfOrganizations {
|
for i := range noOfOrganizations {
|
||||||
@@ -955,21 +1003,23 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// create organization
|
// create organization
|
||||||
err := organizationRepo.Create(ctx, &org)
|
err := organizationRepo.Create(t.Context(), tx, &org)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
organizations[i] = &org
|
organizations[i] = &org
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete organization
|
// delete organization
|
||||||
affectedRows, err := organizationRepo.Delete(ctx,
|
affectedRows, err := organizationRepo.Delete(t.Context(), tx,
|
||||||
organizationRepo.NameCondition(organizationName),
|
database.And(
|
||||||
organizations[0].InstanceID,
|
organizationRepo.InstanceIDCondition(organizations[0].InstanceID),
|
||||||
|
organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
assert.Equal(t, int64(1), affectedRows)
|
assert.Equal(t, int64(1), affectedRows)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
orgIdentifierCondition: organizationRepo.NameCondition(organizationName),
|
orgIdentifierCondition: organizationRepo.NameCondition(database.TextOperationEqual, organizationName),
|
||||||
// this test should return 0 affected rows as the org was already deleted
|
// this test should return 0 affected rows as the org was already deleted
|
||||||
noOfDeletedRows: 0,
|
noOfDeletedRows: 0,
|
||||||
}
|
}
|
||||||
@@ -977,23 +1027,22 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := context.Background()
|
|
||||||
organizationRepo := repository.OrganizationRepository(pool)
|
|
||||||
|
|
||||||
if tt.testFunc != nil {
|
if tt.testFunc != nil {
|
||||||
tt.testFunc(ctx, t)
|
tt.testFunc(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete organization
|
// delete organization
|
||||||
noOfDeletedRows, err := organizationRepo.Delete(ctx,
|
noOfDeletedRows, err := organizationRepo.Delete(t.Context(), tx,
|
||||||
tt.orgIdentifierCondition,
|
database.And(
|
||||||
instanceId,
|
organizationRepo.InstanceIDCondition(instanceId),
|
||||||
|
tt.orgIdentifierCondition,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||||
|
|
||||||
// check organization was deleted
|
// check organization was deleted
|
||||||
organization, err := organizationRepo.Get(ctx,
|
organization, err := organizationRepo.Get(t.Context(), tx,
|
||||||
database.WithCondition(
|
database.WithCondition(
|
||||||
database.And(
|
database.And(
|
||||||
tt.orgIdentifierCondition,
|
tt.orgIdentifierCondition,
|
||||||
@@ -1006,3 +1055,99 @@ func TestDeleteOrganization(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOrganizationWithSubResources(t *testing.T) {
|
||||||
|
tx, err := pool.Begin(t.Context(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = tx.Rollback(t.Context())
|
||||||
|
})
|
||||||
|
|
||||||
|
instanceRepo := repository.InstanceRepository()
|
||||||
|
orgRepo := repository.OrganizationRepository()
|
||||||
|
|
||||||
|
// create instance
|
||||||
|
instanceId := gofakeit.Name()
|
||||||
|
err = instanceRepo.Create(t.Context(), tx, &domain.Instance{
|
||||||
|
ID: instanceId,
|
||||||
|
Name: gofakeit.Name(),
|
||||||
|
DefaultOrgID: "defaultOrgId",
|
||||||
|
IAMProjectID: "iamProject",
|
||||||
|
ConsoleClientID: "consoleCLient",
|
||||||
|
ConsoleAppID: "consoleApp",
|
||||||
|
DefaultLanguage: "defaultLanguage",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create organization
|
||||||
|
org := domain.Organization{
|
||||||
|
ID: "1",
|
||||||
|
Name: "org-name",
|
||||||
|
InstanceID: instanceId,
|
||||||
|
State: domain.OrgStateActive,
|
||||||
|
}
|
||||||
|
err = orgRepo.Create(t.Context(), tx, &org)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = orgRepo.Create(t.Context(), tx, &domain.Organization{
|
||||||
|
ID: "without-sub-resources",
|
||||||
|
Name: "org-name-2",
|
||||||
|
InstanceID: instanceId,
|
||||||
|
State: domain.OrgStateActive,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("domains", func(t *testing.T) {
|
||||||
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
|
domain1 := &domain.AddOrganizationDomain{
|
||||||
|
InstanceID: org.InstanceID,
|
||||||
|
OrgID: org.ID,
|
||||||
|
Domain: "domain1.com",
|
||||||
|
IsVerified: true,
|
||||||
|
IsPrimary: true,
|
||||||
|
}
|
||||||
|
err = domainRepo.Add(t.Context(), tx, domain1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
domain2 := &domain.AddOrganizationDomain{
|
||||||
|
InstanceID: org.InstanceID,
|
||||||
|
OrgID: org.ID,
|
||||||
|
Domain: "domain2.com",
|
||||||
|
IsVerified: false,
|
||||||
|
IsPrimary: false,
|
||||||
|
}
|
||||||
|
err = domainRepo.Add(t.Context(), tx, domain2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("org by domain", func(t *testing.T) {
|
||||||
|
orgRepo := orgRepo.LoadDomains()
|
||||||
|
|
||||||
|
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||||
|
database.WithCondition(
|
||||||
|
database.And(
|
||||||
|
orgRepo.InstanceIDCondition(instanceId),
|
||||||
|
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, org.ID, returnedOrg.ID)
|
||||||
|
assert.Len(t, returnedOrg.Domains, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ensure org by domain works without LoadDomains", func(t *testing.T) {
|
||||||
|
returnedOrg, err := orgRepo.Get(t.Context(), tx,
|
||||||
|
database.WithCondition(
|
||||||
|
database.And(
|
||||||
|
orgRepo.InstanceIDCondition(instanceId),
|
||||||
|
orgRepo.ExistsDomain(domainRepo.DomainCondition(database.TextOperationEqual, domain1.Domain)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, org.ID, returnedOrg.ID)
|
||||||
|
assert.Len(t, returnedOrg.Domains, 0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,10 +4,6 @@ import (
|
|||||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type repository struct {
|
|
||||||
client database.QueryExecutor
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeCondition(
|
func writeCondition(
|
||||||
builder *database.StatementBuilder,
|
builder *database.StatementBuilder,
|
||||||
condition database.Condition,
|
condition database.Condition,
|
||||||
|
|||||||
@@ -12,16 +12,10 @@ const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_a
|
|||||||
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
|
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
|
||||||
` FROM users_view users`
|
` FROM users_view users`
|
||||||
|
|
||||||
type user struct {
|
type user struct{}
|
||||||
repository
|
|
||||||
}
|
|
||||||
|
|
||||||
func UserRepository(client database.QueryExecutor) domain.UserRepository {
|
func UserRepository() domain.UserRepository {
|
||||||
return &user{
|
return new(user)
|
||||||
repository: repository{
|
|
||||||
client: client,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ domain.UserRepository = (*user)(nil)
|
var _ domain.UserRepository = (*user)(nil)
|
||||||
@@ -31,17 +25,17 @@ var _ domain.UserRepository = (*user)(nil)
|
|||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// Human implements [domain.UserRepository].
|
// Human implements [domain.UserRepository].
|
||||||
func (u *user) Human() domain.HumanRepository {
|
func (u user) Human() domain.HumanRepository {
|
||||||
return &userHuman{user: u}
|
return &userHuman{user: u}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Machine implements [domain.UserRepository].
|
// Machine implements [domain.UserRepository].
|
||||||
func (u *user) Machine() domain.MachineRepository {
|
func (u user) Machine() domain.MachineRepository {
|
||||||
return &userMachine{user: u}
|
return &userMachine{user: u}
|
||||||
}
|
}
|
||||||
|
|
||||||
// List implements [domain.UserRepository].
|
// List implements [domain.UserRepository].
|
||||||
func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
|
func (u user) List(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (users []*domain.User, err error) {
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
@@ -54,7 +48,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
|
|||||||
options.WriteLimit(&builder)
|
options.WriteLimit(&builder)
|
||||||
options.WriteOffset(&builder)
|
options.WriteOffset(&builder)
|
||||||
|
|
||||||
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
|
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -79,7 +73,7 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get implements [domain.UserRepository].
|
// Get implements [domain.UserRepository].
|
||||||
func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
|
func (u user) Get(ctx context.Context, client database.QueryExecutor, opts ...database.QueryOption) (*domain.User, error) {
|
||||||
options := new(database.QueryOpts)
|
options := new(database.QueryOpts)
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(options)
|
opt(options)
|
||||||
@@ -92,7 +86,7 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
|
|||||||
options.WriteLimit(&builder)
|
options.WriteLimit(&builder)
|
||||||
options.WriteOffset(&builder)
|
options.WriteOffset(&builder)
|
||||||
|
|
||||||
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
|
return scanUser(client.QueryRow(ctx, builder.String(), builder.Args()...))
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -105,7 +99,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Create implements [domain.UserRepository].
|
// Create implements [domain.UserRepository].
|
||||||
func (u *user) Create(ctx context.Context, user *domain.User) error {
|
func (u user) Create(ctx context.Context, client database.QueryExecutor, user *domain.User) error {
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
|
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
|
||||||
switch trait := user.Traits.(type) {
|
switch trait := user.Traits.(type) {
|
||||||
@@ -116,15 +110,15 @@ func (u *user) Create(ctx context.Context, user *domain.User) error {
|
|||||||
builder.WriteString(createMachineStmt)
|
builder.WriteString(createMachineStmt)
|
||||||
builder.AppendArgs(trait.Description)
|
builder.AppendArgs(trait.Description)
|
||||||
}
|
}
|
||||||
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
|
return client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete implements [domain.UserRepository].
|
// Delete implements [domain.UserRepository].
|
||||||
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
|
func (u user) Delete(ctx context.Context, client database.QueryExecutor, condition database.Condition) error {
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.WriteString("DELETE FROM users")
|
builder.WriteString("DELETE FROM users")
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
|
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
type userHuman struct {
|
type userHuman struct {
|
||||||
*user
|
user
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ domain.HumanRepository = (*userHuman)(nil)
|
var _ domain.HumanRepository = (*userHuman)(nil)
|
||||||
@@ -21,14 +21,14 @@ var _ domain.HumanRepository = (*userHuman)(nil)
|
|||||||
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
|
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
|
||||||
|
|
||||||
// GetEmail implements [domain.HumanRepository].
|
// GetEmail implements [domain.HumanRepository].
|
||||||
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
|
func (u *userHuman) GetEmail(ctx context.Context, client database.QueryExecutor, condition database.Condition) (*domain.Email, error) {
|
||||||
var email domain.Email
|
var email domain.Email
|
||||||
|
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.WriteString(userEmailQuery)
|
builder.WriteString(userEmailQuery)
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
|
|
||||||
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
|
err := client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
|
||||||
&email.Address,
|
&email.Address,
|
||||||
&email.VerifiedAt,
|
&email.VerifiedAt,
|
||||||
)
|
)
|
||||||
@@ -39,7 +39,7 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update implements [domain.HumanRepository].
|
// Update implements [domain.HumanRepository].
|
||||||
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
func (h userHuman) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.WriteString(`UPDATE human_users SET `)
|
builder.WriteString(`UPDATE human_users SET `)
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
@@ -47,7 +47,7 @@ func (h userHuman) Update(ctx context.Context, condition database.Condition, cha
|
|||||||
|
|
||||||
stmt := builder.String()
|
stmt := builder.String()
|
||||||
|
|
||||||
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
|
_, err := client.Exec(ctx, stmt, builder.Args()...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type userMachine struct {
|
type userMachine struct {
|
||||||
*user
|
user
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ domain.MachineRepository = (*userMachine)(nil)
|
var _ domain.MachineRepository = (*userMachine)(nil)
|
||||||
@@ -18,14 +18,14 @@ var _ domain.MachineRepository = (*userMachine)(nil)
|
|||||||
// -------------------------------------------------------------
|
// -------------------------------------------------------------
|
||||||
|
|
||||||
// Update implements [domain.MachineRepository].
|
// Update implements [domain.MachineRepository].
|
||||||
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
func (m userMachine) Update(ctx context.Context, client database.QueryExecutor, condition database.Condition, changes ...database.Change) error {
|
||||||
builder := database.StatementBuilder{}
|
builder := database.StatementBuilder{}
|
||||||
builder.WriteString("UPDATE user_machines SET ")
|
builder.WriteString("UPDATE user_machines SET ")
|
||||||
database.Changes(changes).Write(&builder)
|
database.Changes(changes).Write(&builder)
|
||||||
writeCondition(&builder, condition)
|
writeCondition(&builder, condition)
|
||||||
m.writeReturning()
|
m.writeReturning()
|
||||||
|
|
||||||
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
|
_, err := client.Exec(ctx, builder.String(), builder.Args()...)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package database
|
package database
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -20,8 +21,16 @@ type StatementBuilder struct {
|
|||||||
existingArgs map[any]string
|
existingArgs map[any]string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type argWriter interface {
|
||||||
|
WriteArg(builder *StatementBuilder)
|
||||||
|
}
|
||||||
|
|
||||||
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
|
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
|
||||||
func (b *StatementBuilder) WriteArg(arg any) {
|
func (b *StatementBuilder) WriteArg(arg any) {
|
||||||
|
if writer, ok := arg.(argWriter); ok {
|
||||||
|
writer.WriteArg(b)
|
||||||
|
return
|
||||||
|
}
|
||||||
b.WriteString(b.AppendArg(arg))
|
b.WriteString(b.AppendArg(arg))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +50,13 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
|
|||||||
if b.existingArgs == nil {
|
if b.existingArgs == nil {
|
||||||
b.existingArgs = make(map[any]string)
|
b.existingArgs = make(map[any]string)
|
||||||
}
|
}
|
||||||
if placeholder, ok := b.existingArgs[arg]; ok {
|
// the key is used to work around the following panic:
|
||||||
|
// runtime error: hash of unhashable type []uint8
|
||||||
|
key := arg
|
||||||
|
if argBytes, ok := arg.([]uint8); ok {
|
||||||
|
key = `\\bytes-` + hex.EncodeToString(argBytes)
|
||||||
|
}
|
||||||
|
if placeholder, ok := b.existingArgs[key]; ok {
|
||||||
return placeholder
|
return placeholder
|
||||||
}
|
}
|
||||||
if instruction, ok := arg.(Instruction); ok {
|
if instruction, ok := arg.(Instruction); ok {
|
||||||
@@ -50,7 +65,7 @@ func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
|
|||||||
|
|
||||||
b.args = append(b.args, arg)
|
b.args = append(b.args, arg)
|
||||||
placeholder = "$" + strconv.Itoa(len(b.args))
|
placeholder = "$" + strconv.Itoa(len(b.args))
|
||||||
b.existingArgs[arg] = placeholder
|
b.existingArgs[key] = placeholder
|
||||||
return placeholder
|
return placeholder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
144
backend/v3/storage/database/statement_test.go
Normal file
144
backend/v3/storage/database/statement_test.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStatementBuilder_AppendArg(t *testing.T) {
|
||||||
|
t.Run("same arg returns same placeholder", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
placeholder1 := b.AppendArg("same")
|
||||||
|
placeholder2 := b.AppendArg("same")
|
||||||
|
assert.Equal(t, placeholder1, placeholder2)
|
||||||
|
assert.Len(t, b.Args(), 1)
|
||||||
|
assert.Len(t, b.existingArgs, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("same arg different types", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
placeholder1 := b.AppendArg("same")
|
||||||
|
placeholder2 := b.AppendArg([]byte("same"))
|
||||||
|
placeholder3 := b.AppendArg("same")
|
||||||
|
assert.NotEqual(t, placeholder1, placeholder2)
|
||||||
|
assert.Equal(t, placeholder1, placeholder3)
|
||||||
|
assert.Len(t, b.Args(), 2)
|
||||||
|
assert.Len(t, b.existingArgs, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Instruction args are always different", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
placeholder1 := b.AppendArg(DefaultInstruction)
|
||||||
|
placeholder2 := b.AppendArg(DefaultInstruction)
|
||||||
|
assert.Equal(t, placeholder1, placeholder2)
|
||||||
|
assert.Len(t, b.Args(), 0)
|
||||||
|
assert.Len(t, b.existingArgs, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatementBuilder_AppendArgs(t *testing.T) {
|
||||||
|
t.Run("same arg returns same placeholder", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
b.AppendArgs("same", "same")
|
||||||
|
assert.Len(t, b.Args(), 1)
|
||||||
|
assert.Len(t, b.existingArgs, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("same arg different types", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
b.AppendArgs("same", []byte("same"), "same")
|
||||||
|
assert.Len(t, b.Args(), 2)
|
||||||
|
assert.Len(t, b.existingArgs, 2)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Instruction args are always different", func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
b.AppendArgs(DefaultInstruction, DefaultInstruction)
|
||||||
|
assert.Len(t, b.Args(), 0)
|
||||||
|
assert.Len(t, b.existingArgs, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatementBuilder_WriteArg(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
arg any
|
||||||
|
wantSQL string
|
||||||
|
wantArg []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "primitive arg",
|
||||||
|
arg: "test",
|
||||||
|
wantSQL: "$1",
|
||||||
|
wantArg: []any{"test"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "argWriter arg",
|
||||||
|
arg: SHA256Value("wrapped"),
|
||||||
|
wantSQL: "SHA256($1)",
|
||||||
|
wantArg: []any{"wrapped"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Instruction arg",
|
||||||
|
arg: DefaultInstruction,
|
||||||
|
wantSQL: "DEFAULT",
|
||||||
|
wantArg: []any{},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
b.WriteArg(test.arg)
|
||||||
|
assert.Equal(t, test.wantSQL, b.String())
|
||||||
|
require.Len(t, b.Args(), len(test.wantArg))
|
||||||
|
for i := range test.wantArg {
|
||||||
|
assert.Equal(t, test.wantArg[i], b.Args()[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatementBuilder_WriteArgs(t *testing.T) {
|
||||||
|
for _, test := range []struct {
|
||||||
|
name string
|
||||||
|
args []any
|
||||||
|
wantSQL string
|
||||||
|
wantArg []any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "primitive args",
|
||||||
|
args: []any{"test", 123, true, uint32(123)},
|
||||||
|
wantSQL: "$1, $2, $3, $4",
|
||||||
|
wantArg: []any{"test", 123, true, uint32(123)},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "argWriter args",
|
||||||
|
args: []any{SHA256Value("wrapped"), LowerValue("ASDF")},
|
||||||
|
wantSQL: "SHA256($1), LOWER($2)",
|
||||||
|
wantArg: []any{"wrapped", "ASDF"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Instruction args",
|
||||||
|
args: []any{DefaultInstruction, NowInstruction, NullInstruction},
|
||||||
|
wantSQL: "DEFAULT, NOW(), NULL",
|
||||||
|
wantArg: []any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed args",
|
||||||
|
args: []any{123, uint32(123), NowInstruction, NullInstruction, SHA256Value("wrapped"), LowerValue("ASDF")},
|
||||||
|
wantSQL: "$1, $2, NOW(), NULL, SHA256($3), LOWER($4)",
|
||||||
|
wantArg: []any{123, uint32(123), "wrapped", "ASDF"},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
var b StatementBuilder
|
||||||
|
b.WriteArgs(test.args...)
|
||||||
|
assert.Equal(t, test.wantSQL, b.String())
|
||||||
|
require.Len(t, b.Args(), len(test.wantArg))
|
||||||
|
for i := range test.wantArg {
|
||||||
|
assert.Equal(t, test.wantArg[i], b.Args()[i])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -66,7 +66,7 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event event
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
|
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
|
||||||
InstanceID: e.Aggregate().InstanceID,
|
InstanceID: e.Aggregate().InstanceID,
|
||||||
Domain: e.Domain,
|
Domain: e.Domain,
|
||||||
IsPrimary: gu.Ptr(false),
|
IsPrimary: gu.Ptr(false),
|
||||||
@@ -88,25 +88,15 @@ func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event events
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
|
|
||||||
condition := database.And(
|
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
database.And(
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.TypeCondition(domain.DomainTypeCustom),
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
)
|
domainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||||
|
),
|
||||||
_, err := domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetPrimary(),
|
domainRepo.SetPrimary(),
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// we need to split the update into two statements because multiple events can have the same creation date
|
|
||||||
// therefore we first do not set the updated_at timestamp
|
|
||||||
_, err = domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@@ -123,8 +113,8 @@ func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eve
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
_, err := domainRepo.Remove(ctx,
|
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||||
database.And(
|
database.And(
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
@@ -145,7 +135,7 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event even
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{
|
return repository.InstanceDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddInstanceDomain{
|
||||||
InstanceID: e.Aggregate().InstanceID,
|
InstanceID: e.Aggregate().InstanceID,
|
||||||
Domain: e.Domain,
|
Domain: e.Domain,
|
||||||
Type: domain.DomainTypeTrusted,
|
Type: domain.DomainTypeTrusted,
|
||||||
@@ -165,8 +155,8 @@ func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event ev
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.InstanceDomainRepository()
|
||||||
_, err := domainRepo.Remove(ctx,
|
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||||
database.And(
|
database.And(
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Even
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{
|
return repository.InstanceRepository().Create(ctx, v3_sql.SQLTx(tx), &domain.Instance{
|
||||||
ID: e.Aggregate().ID,
|
ID: e.Aggregate().ID,
|
||||||
Name: e.Name,
|
Name: e.Name,
|
||||||
CreatedAt: e.CreationDate(),
|
CreatedAt: e.CreationDate(),
|
||||||
@@ -93,8 +93,8 @@ func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Ev
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
repo := repository.InstanceRepository()
|
||||||
return p.updateInstance(ctx, event, repo, repo.SetName(e.Name))
|
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetName(e.Name))
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,7 +108,7 @@ func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Eve
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
_, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID)
|
_, err := repository.InstanceRepository().Delete(ctx, v3_sql.SQLTx(tx), e.Aggregate().ID)
|
||||||
return err
|
return err
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
@@ -124,8 +124,8 @@ func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Even
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
repo := repository.InstanceRepository()
|
||||||
return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID))
|
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultOrg(e.OrgID))
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,8 +140,8 @@ func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Even
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
repo := repository.InstanceRepository()
|
||||||
return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID))
|
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetIAMProject(e.ProjectID))
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,8 +156,8 @@ func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event)
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
repo := repository.InstanceRepository()
|
||||||
return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
|
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID))
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,18 +172,18 @@ func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
repo := repository.InstanceRepository(v3_sql.SQLTx(tx))
|
repo := repository.InstanceRepository()
|
||||||
return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language))
|
return p.updateInstance(ctx, v3_sql.SQLTx(tx), event, repo, repo.SetDefaultLanguage(e.Language))
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
|
func (p *instanceRelationalProjection) updateInstance(ctx context.Context, tx database.Transaction, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error {
|
||||||
_, err := repo.Update(ctx, event.Aggregate().ID, changes...)
|
_, err := repo.Update(ctx, tx, event.Aggregate().ID, changes...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
|
instance, err := repo.Get(ctx, tx, database.WithCondition(repo.IDCondition(event.Aggregate().ID)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -192,7 +192,7 @@ func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event
|
|||||||
}
|
}
|
||||||
// we need to split the update into two statements because multiple events can have the same creation date
|
// we need to split the update into two statements because multiple events can have the same creation date
|
||||||
// therefore we first do not set the updated_at timestamp
|
// therefore we first do not set the updated_at timestamp
|
||||||
_, err = repo.Update(ctx,
|
_, err = repo.Update(ctx, tx,
|
||||||
event.Aggregate().ID,
|
event.Aggregate().ID,
|
||||||
repo.SetUpdatedAt(event.CreatedAt()),
|
repo.SetUpdatedAt(event.CreatedAt()),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*ha
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{
|
return repository.OrganizationDomainRepository().Add(ctx, v3_sql.SQLTx(tx), &domain.AddOrganizationDomain{
|
||||||
InstanceID: e.Aggregate().InstanceID,
|
InstanceID: e.Aggregate().InstanceID,
|
||||||
OrgID: e.Aggregate().ResourceOwner,
|
OrgID: e.Aggregate().ResourceOwner,
|
||||||
Domain: e.Domain,
|
Domain: e.Domain,
|
||||||
@@ -85,23 +85,14 @@ func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event)
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
condition := database.And(
|
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
database.And(
|
||||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||||
)
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
_, err := domainRepo.Update(ctx,
|
),
|
||||||
condition,
|
|
||||||
domainRepo.SetPrimary(),
|
domainRepo.SetPrimary(),
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// we need to split the update into two statements because multiple events can have the same creation date
|
|
||||||
// therefore we first do not set the updated_at timestamp
|
|
||||||
_, err = domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@@ -118,8 +109,8 @@ func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (*
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
_, err := domainRepo.Remove(ctx,
|
_, err := domainRepo.Remove(ctx, v3_sql.SQLTx(tx),
|
||||||
database.And(
|
database.And(
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||||
@@ -149,24 +140,15 @@ func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
condition := database.And(
|
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
|
||||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
|
||||||
)
|
|
||||||
|
|
||||||
_, err := domainRepo.Update(ctx,
|
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||||
condition,
|
database.And(
|
||||||
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
|
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||||
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
|
),
|
||||||
domainRepo.SetValidationType(validationType),
|
domainRepo.SetValidationType(validationType),
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// we need to split the update into two statements because multiple events can have the same creation date
|
|
||||||
// therefore we first do not set the updated_at timestamp
|
|
||||||
_, err = domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@@ -183,28 +165,17 @@ func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) (
|
|||||||
if !ok {
|
if !ok {
|
||||||
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex)
|
return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex)
|
||||||
}
|
}
|
||||||
domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false)
|
domainRepo := repository.OrganizationDomainRepository()
|
||||||
|
|
||||||
condition := database.And(
|
_, err := domainRepo.Update(ctx, v3_sql.SQLTx(tx),
|
||||||
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
database.And(
|
||||||
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
domainRepo.InstanceIDCondition(e.Aggregate().InstanceID),
|
||||||
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner),
|
||||||
)
|
domainRepo.DomainCondition(database.TextOperationEqual, e.Domain),
|
||||||
|
),
|
||||||
_, err := domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetVerified(),
|
domainRepo.SetVerified(),
|
||||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
domainRepo.SetUpdatedAt(e.CreationDate()),
|
||||||
)
|
)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// we need to split the update into two statements because multiple events can have the same creation date
|
|
||||||
// therefore we first do not set the updated_at timestamp
|
|
||||||
_, err = domainRepo.Update(ctx,
|
|
||||||
condition,
|
|
||||||
domainRepo.SetUpdatedAt(e.CreationDate()),
|
|
||||||
)
|
|
||||||
return err
|
return err
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user