Merge branch 'error-handling' into rt-domains

This commit is contained in:
adlerhurst
2025-07-17 09:38:51 +02:00
12 changed files with 352 additions and 92 deletions

View File

@@ -41,10 +41,10 @@ type Config struct {
func (c *Config) Connect(ctx context.Context) (database.Pool, error) { func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := c.getPool(ctx) pool, err := c.getPool(ctx)
if err != nil { if err != nil {
return nil, err return nil, wrapError(err)
} }
if err = pool.Ping(ctx); err != nil { if err = pool.Ping(ctx); err != nil {
return nil, err return nil, wrapError(err)
} }
return &pgxPool{Pool: pool}, nil return &pgxPool{Pool: pool}, nil
} }

View File

@@ -25,7 +25,7 @@ func (c *pgxConn) Release(_ context.Context) error {
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts)) tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil { if err != nil {
return nil, err return nil, wrapError(err)
} }
return &pgxTx{tx}, nil return &pgxTx{tx}, nil
} }
@@ -34,20 +34,26 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions)
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. // Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...) rows, err := c.Conn.Query(ctx, sql, args...)
return &Rows{rows}, err if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
} }
// QueryRow implements sql.Client. // QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. // Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Conn.QueryRow(ctx, sql, args...) return &Row{c.Conn.QueryRow(ctx, sql, args...)}
} }
// Exec implements [database.Pool]. // Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) { func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Conn.Exec(ctx, sql, args...) res, err := c.Conn.Exec(ctx, sql, args...)
return res.RowsAffected(), err if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
} }
// Migrate implements [database.Migrator]. // Migrate implements [database.Migrator].
@@ -57,5 +63,5 @@ func (c *pgxConn) Migrate(ctx context.Context) error {
} }
err := migration.Migrate(ctx, c.Conn.Conn()) err := migration.Migrate(ctx, c.Conn.Conn())
isMigrated = err == nil isMigrated = err == nil
return err return wrapError(err)
} }

View File

@@ -0,0 +1,38 @@
package postgres
import (
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
func wrapError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, pgx.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)
}
switch pgxErr.Code {
// 23514: check_violation - A value violates a CHECK constraint.
case "23514":
return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23505: unique_violation - A value violates a UNIQUE constraint.
case "23505":
return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23503: foreign_key_violation - A value violates a foreign key constraint.
case "23503":
return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23502: not_null_violation - A value violates a NOT NULL constraint.
case "23502":
return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
}
return database.NewUnknownError(err)
}

View File

@@ -25,7 +25,7 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool {
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx) conn, err := c.Pool.Acquire(ctx)
if err != nil { if err != nil {
return nil, err return nil, wrapError(err)
} }
return &pgxConn{Conn: conn}, nil return &pgxConn{Conn: conn}, nil
} }
@@ -34,27 +34,33 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. // Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...) rows, err := c.Pool.Query(ctx, sql, args...)
return &Rows{rows}, err if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
} }
// QueryRow implements [database.Pool]. // QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. // Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Pool.QueryRow(ctx, sql, args...) return &Row{c.Pool.QueryRow(ctx, sql, args...)}
} }
// Exec implements [database.Pool]. // Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Pool.Exec(ctx, sql, args...) res, err := c.Pool.Exec(ctx, sql, args...)
return res.RowsAffected(), err if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
} }
// Begin implements [database.Pool]. // Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts)) tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil { if err != nil {
return nil, err return nil, wrapError(err)
} }
return &pgxTx{tx}, nil return &pgxTx{tx}, nil
} }
@@ -78,7 +84,7 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
err = migration.Migrate(ctx, client.Conn()) err = migration.Migrate(ctx, client.Conn())
isMigrated = err == nil isMigrated = err == nil
return err return wrapError(err)
} }
// Migrate implements [database.PoolTest]. // Migrate implements [database.PoolTest].

View File

@@ -10,10 +10,29 @@ import (
var ( var (
_ database.Rows = (*Rows)(nil) _ database.Rows = (*Rows)(nil)
_ database.CollectableRows = (*Rows)(nil) _ database.CollectableRows = (*Rows)(nil)
_ database.Row = (*Row)(nil)
) )
type Row struct{ pgx.Row }
// Scan implements [database.Row].
// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row.
func (r *Row) Scan(dest ...any) error {
return wrapError(r.Row.Scan(dest...))
}
type Rows struct{ pgx.Rows } type Rows struct{ pgx.Rows }
// Err implements [database.Rows].
// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows.
func (r *Rows) Err() error {
return wrapError(r.Rows.Err())
}
func (r *Rows) Scan(dest ...any) error {
return wrapError(r.Rows.Scan(dest...))
}
// Collect implements [database.CollectableRows]. // Collect implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. // See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) Collect(dest any) (err error) { func (r *Rows) Collect(dest any) (err error) {
@@ -23,7 +42,7 @@ func (r *Rows) Collect(dest any) (err error) {
err = closeErr err = closeErr
} }
}() }()
return pgxscan.ScanAll(dest, r.Rows) return wrapError(pgxscan.ScanAll(dest, r.Rows))
} }
// CollectFirst implements [database.CollectableRows]. // CollectFirst implements [database.CollectableRows].
@@ -35,7 +54,7 @@ func (r *Rows) CollectFirst(dest any) (err error) {
err = closeErr err = closeErr
} }
}() }()
return pgxscan.ScanRow(dest, r.Rows) return wrapError(pgxscan.ScanRow(dest, r.Rows))
} }
// CollectExactlyOneRow implements [database.CollectableRows]. // CollectExactlyOneRow implements [database.CollectableRows].
@@ -47,7 +66,7 @@ func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
err = closeErr err = closeErr
} }
}() }()
return pgxscan.ScanOne(dest, r.Rows) return wrapError(pgxscan.ScanOne(dest, r.Rows))
} }
// Close implements [database.Rows]. // Close implements [database.Rows].

View File

@@ -15,12 +15,14 @@ var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction]. // Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error { func (tx *pgxTx) Commit(ctx context.Context) error {
return tx.Tx.Commit(ctx) err := tx.Tx.Commit(ctx)
return wrapError(err)
} }
// Rollback implements [database.Transaction]. // Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error { func (tx *pgxTx) Rollback(ctx context.Context) error {
return tx.Tx.Rollback(ctx) err := tx.Tx.Rollback(ctx)
return wrapError(err)
} }
// End implements [database.Transaction]. // End implements [database.Transaction].
@@ -39,20 +41,26 @@ func (tx *pgxTx) End(ctx context.Context, err error) error {
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. // Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...) rows, err := tx.Tx.Query(ctx, sql, args...)
return &Rows{rows}, err if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
} }
// QueryRow implements [database.Transaction]. // QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. // Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return tx.Tx.QueryRow(ctx, sql, args...) return &Row{tx.Tx.QueryRow(ctx, sql, args...)}
} }
// Exec implements [database.Transaction]. // Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) { func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := tx.Tx.Exec(ctx, sql, args...) res, err := tx.Tx.Exec(ctx, sql, args...)
return res.RowsAffected(), err if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
} }
// Begin implements [database.Transaction]. // Begin implements [database.Transaction].
@@ -60,7 +68,7 @@ func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, erro
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) { func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx) savepoint, err := tx.Tx.Begin(ctx)
if err != nil { if err != nil {
return nil, err return nil, wrapError(err)
} }
return &pgxTx{savepoint}, nil return &pgxTx{savepoint}, nil
} }

View File

@@ -0,0 +1,228 @@
package database
import (
"fmt"
)
// NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context.
type NoRowFoundError struct {
original error
}
func NewNoRowFoundError(original error) error {
return &NoRowFoundError{
original: original,
}
}
func (e *NoRowFoundError) Error() string {
return "no row found"
}
func (e *NoRowFoundError) Is(target error) bool {
_, ok := target.(*NoRowFoundError)
return ok
}
func (e *NoRowFoundError) Unwrap() error {
return e.original
}
// MultipleRowsFoundError is returned when QueryRow finds multiple rows.
// It wraps the dialect specific original error to provide more context.
type MultipleRowsFoundError struct {
original error
count int
}
func NewMultipleRowsFoundError(original error, count int) error {
return &MultipleRowsFoundError{
original: original,
count: count,
}
}
func (e *MultipleRowsFoundError) Error() string {
return fmt.Sprintf("multiple rows found: %d", e.count)
}
func (e *MultipleRowsFoundError) Is(target error) bool {
_, ok := target.(*MultipleRowsFoundError)
return ok
}
func (e *MultipleRowsFoundError) Unwrap() error {
return e.original
}
type IntegrityType string
const (
IntegrityTypeCheck IntegrityType = "check"
IntegrityTypeUnique IntegrityType = "unique"
IntegrityTypeForeign IntegrityType = "foreign"
IntegrityTypeNotNull IntegrityType = "not null"
IntegrityTypeUnknown IntegrityType = "unknown"
)
// IntegrityViolationError represents a generic integrity violation error.
// It wraps the dialect specific original error to provide more context.
type IntegrityViolationError struct {
integrityType IntegrityType
table string
constraint string
original error
}
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
return &IntegrityViolationError{
integrityType: typ,
table: table,
constraint: constraint,
original: original,
}
}
func (e *IntegrityViolationError) Error() string {
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
}
func (e *IntegrityViolationError) Is(target error) bool {
_, ok := target.(*IntegrityViolationError)
return ok
}
// CheckError is returned when a check constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a check constraint was violated during an insert or update operation.
type CheckError struct {
IntegrityViolationError
}
func NewCheckError(table, constraint string, original error) error {
return &CheckError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeCheck,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *CheckError) Is(target error) bool {
_, ok := target.(*CheckError)
return ok
}
func (e *CheckError) Unwrap() error {
return &e.IntegrityViolationError
}
// UniqueError is returned when a unique constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a unique constraint was violated during an insert or update operation.
type UniqueError struct {
IntegrityViolationError
}
func NewUniqueError(table, constraint string, original error) error {
return &UniqueError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeUnique,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *UniqueError) Is(target error) bool {
_, ok := target.(*UniqueError)
return ok
}
func (e *UniqueError) Unwrap() error {
return &e.IntegrityViolationError
}
// ForeignKeyError is returned when a foreign key constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a foreign key constraint was violated during an insert or update operation
type ForeignKeyError struct {
IntegrityViolationError
}
func NewForeignKeyError(table, constraint string, original error) error {
return &ForeignKeyError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeForeign,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *ForeignKeyError) Is(target error) bool {
_, ok := target.(*ForeignKeyError)
return ok
}
func (e *ForeignKeyError) Unwrap() error {
return &e.IntegrityViolationError
}
// NotNullError is returned when a not null constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a not null constraint was violated during an insert or update operation.
type NotNullError struct {
IntegrityViolationError
}
func NewNotNullError(table, constraint string, original error) error {
return &NotNullError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeNotNull,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *NotNullError) Is(target error) bool {
_, ok := target.(*NotNullError)
return ok
}
func (e *NotNullError) Unwrap() error {
return &e.IntegrityViolationError
}
// UnknownError is returned when an unknown error occurs.
// It wraps the dialect specific original error to provide more context.
// It is used to indicate that an error occurred that does not fit into any of the other categories.
type UnknownError struct {
original error
}
func NewUnknownError(original error) error {
return &UnknownError{
original: original,
}
}
func (e *UnknownError) Error() string {
return fmt.Sprintf("unknown database error: %v", e.original)
}
func (e *UnknownError) Is(target error) bool {
_, ok := target.(*UnknownError)
return ok
}
func (e *UnknownError) Unwrap() error {
return e.original
}

View File

@@ -4,8 +4,6 @@ import (
"context" "context"
"errors" "errors"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
@@ -69,28 +67,7 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error
builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage) builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage)
builder.WriteString(createInstanceStmt) builder.WriteString(createInstanceStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
// constraint violation
if pgErr.Code == "23514" {
if pgErr.ConstraintName == "instances_name_check" {
return errors.New("instance name not provided")
}
if pgErr.ConstraintName == "instances_id_check" {
return errors.New("instance id not provided")
}
}
// duplicate
if pgErr.Code == "23505" {
if pgErr.ConstraintName == "instances_pkey" {
return errors.New("instance id already exists")
}
}
}
}
return err
} }
// Update implements [domain.InstanceRepository]. // Update implements [domain.InstanceRepository].
@@ -109,8 +86,7 @@ func (i instance) Update(ctx context.Context, id string, changes ...database.Cha
stmt := builder.String() stmt := builder.String()
rowsAffected, err := i.client.Exec(ctx, stmt, builder.Args()...) return i.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
} }
// Delete implements [domain.InstanceRepository]. // Delete implements [domain.InstanceRepository].
@@ -205,9 +181,6 @@ func scanInstance(ctx context.Context, querier database.Querier, builder *databa
instance := new(domain.Instance) instance := new(domain.Instance)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil {
if err.Error() == "no rows in result set" {
return nil, ErrResourceDoesNotExist
}
return nil, err return nil, err
} }
@@ -221,12 +194,6 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab
} }
if err := rows.(database.CollectableRows).Collect(&instances); err != nil { if err := rows.(database.CollectableRows).Collect(&instances); err != nil {
// if no results returned, this is not a error
// it just means the instance was not found
// the caller should check if the returned instance is nil
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err return nil, err
} }

View File

@@ -2,7 +2,6 @@ package repository_test
import ( import (
"context" "context"
"errors"
"testing" "testing"
"time" "time"
@@ -55,7 +54,7 @@ func TestCreateInstance(t *testing.T) {
} }
return instance return instance
}(), }(),
err: errors.New("instance name not provided"), err: new(database.CheckError),
}, },
{ {
name: "adding same instance twice", name: "adding same instance twice",
@@ -80,7 +79,7 @@ func TestCreateInstance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return &inst return &inst
}, },
err: errors.New("instance id already exists"), err: new(database.UniqueError),
}, },
func() struct { func() struct {
name string name string
@@ -146,7 +145,7 @@ func TestCreateInstance(t *testing.T) {
} }
return instance return instance
}(), }(),
err: errors.New("instance id not provided"), err: new(database.CheckError),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -164,7 +163,7 @@ func TestCreateInstance(t *testing.T) {
// create instance // create instance
beforeCreate := time.Now() beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance) err := instanceRepo.Create(ctx, instance)
assert.Equal(t, tt.err, err) assert.ErrorIs(t, err, tt.err)
if err != nil { if err != nil {
return return
} }
@@ -263,7 +262,7 @@ func TestUpdateInstance(t *testing.T) {
return &inst return &inst
}, },
rowsAffected: 0, rowsAffected: 0,
getErr: repository.ErrResourceDoesNotExist, getErr: new(database.NoRowFoundError),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -342,7 +341,7 @@ func TestGetInstance(t *testing.T) {
} }
return &inst return &inst
}, },
err: repository.ErrResourceDoesNotExist, err: new(database.NoRowFoundError),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -360,7 +359,7 @@ func TestGetInstance(t *testing.T) {
instance.ID, instance.ID,
) )
if tt.err != nil { if tt.err != nil {
require.Equal(t, tt.err, err) require.ErrorIs(t, err, tt.err)
return return
} }
@@ -655,7 +654,7 @@ func TestDeleteInstance(t *testing.T) {
instance, err := instanceRepo.Get(ctx, instance, err := instanceRepo.Get(ctx,
tt.instanceID, tt.instanceID,
) )
require.Equal(t, err, repository.ErrResourceDoesNotExist) require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, instance) assert.Nil(t, instance)
}) })
} }

View File

@@ -244,12 +244,6 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d
organizations := []*domain.Organization{} organizations := []*domain.Organization{}
if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { if err := rows.(database.CollectableRows).Collect(&organizations); err != nil {
// if no results returned, this is not a error
// it just means the organization was not found
// the caller should check if the returned organization is nil
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err return nil, err
} }
return organizations, nil return organizations, nil

View File

@@ -2,7 +2,6 @@ package repository_test
import ( import (
"context" "context"
"errors"
"testing" "testing"
"time" "time"
@@ -29,7 +28,7 @@ func TestCreateOrganization(t *testing.T) {
} }
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance) err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err) require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
@@ -64,7 +63,7 @@ func TestCreateOrganization(t *testing.T) {
} }
return organization return organization
}(), }(),
err: errors.New("organization name not provided"), err: new(database.CheckError),
}, },
{ {
name: "adding org with same id twice", name: "adding org with same id twice",
@@ -86,7 +85,7 @@ func TestCreateOrganization(t *testing.T) {
org.Name = gofakeit.Name() org.Name = gofakeit.Name()
return &org return &org
}, },
err: errors.New("organization id already exists"), err: new(database.UniqueError),
}, },
{ {
name: "adding org with same name twice", name: "adding org with same name twice",
@@ -108,7 +107,7 @@ func TestCreateOrganization(t *testing.T) {
org.ID = gofakeit.Name() org.ID = gofakeit.Name()
return &org return &org
}, },
err: errors.New("organization name already exists for instance"), err: new(database.UniqueError),
}, },
func() struct { func() struct {
name string name string
@@ -181,7 +180,7 @@ func TestCreateOrganization(t *testing.T) {
} }
return organization return organization
}(), }(),
err: errors.New("organization id not provided"), err: new(database.CheckError),
}, },
{ {
name: "adding organization with no instance id", name: "adding organization with no instance id",
@@ -195,7 +194,7 @@ func TestCreateOrganization(t *testing.T) {
} }
return organization return organization
}(), }(),
err: errors.New("invalid instance id"), err: new(database.ForeignKeyError),
}, },
{ {
name: "adding organization with non existent instance id", name: "adding organization with non existent instance id",
@@ -210,7 +209,7 @@ func TestCreateOrganization(t *testing.T) {
} }
return organization return organization
}(), }(),
err: errors.New("invalid instance id"), err: new(database.ForeignKeyError),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -228,7 +227,7 @@ func TestCreateOrganization(t *testing.T) {
// create organization // create organization
beforeCreate := time.Now() beforeCreate := time.Now()
err = organizationRepo.Create(ctx, organization) err = organizationRepo.Create(ctx, organization)
assert.Equal(t, tt.err, err) assert.ErrorIs(t, err, tt.err)
if err != nil { if err != nil {
return return
} }
@@ -265,7 +264,7 @@ func TestUpdateOrganization(t *testing.T) {
} }
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance) err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err) require.NoError(t, err)
organizationRepo := repository.OrganizationRepository(pool) organizationRepo := repository.OrganizationRepository(pool)
tests := []struct { tests := []struct {
@@ -417,7 +416,7 @@ func TestGetOrganization(t *testing.T) {
} }
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance) err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err) require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool) orgRepo := repository.OrganizationRepository(pool)
@@ -497,7 +496,7 @@ func TestGetOrganization(t *testing.T) {
return &org return &org
}, },
orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"), orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"),
err: repository.ErrResourceDoesNotExist, err: new(database.NoRowFoundError),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -516,7 +515,7 @@ func TestGetOrganization(t *testing.T) {
org.InstanceID, org.InstanceID,
) )
if tt.err != nil { if tt.err != nil {
require.Equal(t, tt.err, err) require.ErrorIs(t, tt.err, err)
return return
} }
@@ -553,7 +552,7 @@ func TestListOrganization(t *testing.T) {
} }
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
err = instanceRepo.Create(ctx, &instance) err = instanceRepo.Create(ctx, &instance)
assert.Nil(t, err) require.NoError(t, err)
type test struct { type test struct {
name string name string
@@ -800,7 +799,7 @@ func TestDeleteOrganization(t *testing.T) {
} }
instanceRepo := repository.InstanceRepository(pool) instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance) err := instanceRepo.Create(t.Context(), &instance)
assert.Nil(t, err) require.NoError(t, err)
type test struct { type test struct {
name string name string
@@ -933,7 +932,7 @@ func TestDeleteOrganization(t *testing.T) {
tt.orgIdentifierCondition, tt.orgIdentifierCondition,
instanceId, instanceId,
) )
require.Equal(t, err, repository.ErrResourceDoesNotExist) require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, organization) assert.Nil(t, organization)
}) })
} }

View File

@@ -1,13 +1,9 @@
package repository package repository
import ( import (
"errors"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
var ErrResourceDoesNotExist = errors.New("resource does not exist")
type repository struct { type repository struct {
client database.QueryExecutor client database.QueryExecutor
} }