diff --git a/backend/v3/storage/database/dialect/postgres/config.go b/backend/v3/storage/database/dialect/postgres/config.go index b10088bf7c..1e4e5d751b 100644 --- a/backend/v3/storage/database/dialect/postgres/config.go +++ b/backend/v3/storage/database/dialect/postgres/config.go @@ -41,10 +41,10 @@ type Config struct { func (c *Config) Connect(ctx context.Context) (database.Pool, error) { pool, err := c.getPool(ctx) if err != nil { - return nil, err + return nil, wrapError(err) } if err = pool.Ping(ctx); err != nil { - return nil, err + return nil, wrapError(err) } return &pgxPool{Pool: pool}, nil } diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go index aa477dfd51..24c0de92c2 100644 --- a/backend/v3/storage/database/dialect/postgres/conn.go +++ b/backend/v3/storage/database/dialect/postgres/conn.go @@ -25,7 +25,7 @@ func (c *pgxConn) Release(_ context.Context) error { func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts)) if err != nil { - return nil, err + return nil, wrapError(err) } return &pgxTx{tx}, nil } @@ -34,20 +34,26 @@ func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) // Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { rows, err := c.Conn.Query(ctx, sql, args...) - return &Rows{rows}, err + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil } // QueryRow implements sql.Client. // Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return c.Conn.QueryRow(ctx, sql, args...) + return &Row{c.Conn.QueryRow(ctx, sql, args...)} } // Exec implements [database.Pool]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) { res, err := c.Conn.Exec(ctx, sql, args...) - return res.RowsAffected(), err + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil } // Migrate implements [database.Migrator]. @@ -57,5 +63,5 @@ func (c *pgxConn) Migrate(ctx context.Context) error { } err := migration.Migrate(ctx, c.Conn.Conn()) isMigrated = err == nil - return err + return wrapError(err) } diff --git a/backend/v3/storage/database/dialect/postgres/error.go b/backend/v3/storage/database/dialect/postgres/error.go new file mode 100644 index 0000000000..89b3f8837a --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/error.go @@ -0,0 +1,38 @@ +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +func wrapError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, pgx.ErrNoRows) { + return database.NewNoRowFoundError(err) + } + var pgxErr *pgconn.PgError + if !errors.As(err, &pgxErr) { + return database.NewUnknownError(err) + } + switch pgxErr.Code { + // 23514: check_violation - A value violates a CHECK constraint. + case "23514": + return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23505: unique_violation - A value violates a UNIQUE constraint. + case "23505": + return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23503: foreign_key_violation - A value violates a foreign key constraint. + case "23503": + return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23502: not_null_violation - A value violates a NOT NULL constraint. + case "23502": + return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + } + return database.NewUnknownError(err) +} diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go index 4179f9398b..9722750e4c 100644 --- a/backend/v3/storage/database/dialect/postgres/pool.go +++ b/backend/v3/storage/database/dialect/postgres/pool.go @@ -25,7 +25,7 @@ func PGxPool(pool *pgxpool.Pool) *pgxPool { func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { conn, err := c.Pool.Acquire(ctx) if err != nil { - return nil, err + return nil, wrapError(err) } return &pgxConn{Conn: conn}, nil } @@ -34,27 +34,33 @@ func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { // Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { rows, err := c.Pool.Query(ctx, sql, args...) - return &Rows{rows}, err + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil } // QueryRow implements [database.Pool]. // Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return c.Pool.QueryRow(ctx, sql, args...) + return &Row{c.Pool.QueryRow(ctx, sql, args...)} } // Exec implements [database.Pool]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { res, err := c.Pool.Exec(ctx, sql, args...) - return res.RowsAffected(), err + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil } // Begin implements [database.Pool]. func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts)) if err != nil { - return nil, err + return nil, wrapError(err) } return &pgxTx{tx}, nil } @@ -78,7 +84,7 @@ func (c *pgxPool) Migrate(ctx context.Context) error { err = migration.Migrate(ctx, client.Conn()) isMigrated = err == nil - return err + return wrapError(err) } // Migrate implements [database.PoolTest]. diff --git a/backend/v3/storage/database/dialect/postgres/rows.go b/backend/v3/storage/database/dialect/postgres/rows.go index 8dafc88f4f..d151effd59 100644 --- a/backend/v3/storage/database/dialect/postgres/rows.go +++ b/backend/v3/storage/database/dialect/postgres/rows.go @@ -10,10 +10,29 @@ import ( var ( _ database.Rows = (*Rows)(nil) _ database.CollectableRows = (*Rows)(nil) + _ database.Row = (*Row)(nil) ) +type Row struct{ pgx.Row } + +// Scan implements [database.Row]. +// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row. +func (r *Row) Scan(dest ...any) error { + return wrapError(r.Row.Scan(dest...)) +} + type Rows struct{ pgx.Rows } +// Err implements [database.Rows]. +// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows. +func (r *Rows) Err() error { + return wrapError(r.Rows.Err()) +} + +func (r *Rows) Scan(dest ...any) error { + return wrapError(r.Rows.Scan(dest...)) +} + // Collect implements [database.CollectableRows]. // See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. func (r *Rows) Collect(dest any) (err error) { @@ -23,7 +42,7 @@ func (r *Rows) Collect(dest any) (err error) { err = closeErr } }() - return pgxscan.ScanAll(dest, r.Rows) + return wrapError(pgxscan.ScanAll(dest, r.Rows)) } // CollectFirst implements [database.CollectableRows]. @@ -35,7 +54,7 @@ func (r *Rows) CollectFirst(dest any) (err error) { err = closeErr } }() - return pgxscan.ScanRow(dest, r.Rows) + return wrapError(pgxscan.ScanRow(dest, r.Rows)) } // CollectExactlyOneRow implements [database.CollectableRows]. @@ -47,7 +66,7 @@ func (r *Rows) CollectExactlyOneRow(dest any) (err error) { err = closeErr } }() - return pgxscan.ScanOne(dest, r.Rows) + return wrapError(pgxscan.ScanOne(dest, r.Rows)) } // Close implements [database.Rows]. diff --git a/backend/v3/storage/database/dialect/postgres/tx.go b/backend/v3/storage/database/dialect/postgres/tx.go index 6a5e1c9574..6a330c16b9 100644 --- a/backend/v3/storage/database/dialect/postgres/tx.go +++ b/backend/v3/storage/database/dialect/postgres/tx.go @@ -15,12 +15,14 @@ var _ database.Transaction = (*pgxTx)(nil) // Commit implements [database.Transaction]. func (tx *pgxTx) Commit(ctx context.Context) error { - return tx.Tx.Commit(ctx) + err := tx.Tx.Commit(ctx) + return wrapError(err) } // Rollback implements [database.Transaction]. func (tx *pgxTx) Rollback(ctx context.Context) error { - return tx.Tx.Rollback(ctx) + err := tx.Tx.Rollback(ctx) + return wrapError(err) } // End implements [database.Transaction]. @@ -39,20 +41,26 @@ func (tx *pgxTx) End(ctx context.Context, err error) error { // Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { rows, err := tx.Tx.Query(ctx, sql, args...) - return &Rows{rows}, err + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil } // QueryRow implements [database.Transaction]. // Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return tx.Tx.QueryRow(ctx, sql, args...) + return &Row{tx.Tx.QueryRow(ctx, sql, args...)} } // Exec implements [database.Transaction]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) { res, err := tx.Tx.Exec(ctx, sql, args...) - return res.RowsAffected(), err + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil } // Begin implements [database.Transaction]. @@ -60,7 +68,7 @@ func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, erro func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) { savepoint, err := tx.Tx.Begin(ctx) if err != nil { - return nil, err + return nil, wrapError(err) } return &pgxTx{savepoint}, nil } diff --git a/backend/v3/storage/database/errors.go b/backend/v3/storage/database/errors.go new file mode 100644 index 0000000000..a4a091516f --- /dev/null +++ b/backend/v3/storage/database/errors.go @@ -0,0 +1,227 @@ +package database + +import ( + "fmt" +) + +// NoRowFoundError is returned when QueryRow does not find any row. +// It wraps the dialect specific original error to provide more context. +type NoRowFoundError struct { + original error +} + +func NewNoRowFoundError(original error) error { + return &NoRowFoundError{ + original: original, + } +} + +func (e *NoRowFoundError) Error() string { + return "no row found" +} + +func (e *NoRowFoundError) Is(target error) bool { + _, ok := target.(*NoRowFoundError) + return ok +} + +func (e *NoRowFoundError) Unwrap() error { + return e.original +} + +// MultipleRowsFoundError is returned when QueryRow finds multiple rows. +// It wraps the dialect specific original error to provide more context. +type MultipleRowsFoundError struct { + original error + count int +} + +func NewMultipleRowsFoundError(original error, count int) error { + return &MultipleRowsFoundError{ + original: original, + count: count, + } +} + +func (e *MultipleRowsFoundError) Error() string { + return fmt.Sprintf("multiple rows found: %d", e.count) +} + +func (e *MultipleRowsFoundError) Is(target error) bool { + _, ok := target.(*MultipleRowsFoundError) + return ok +} + +func (e *MultipleRowsFoundError) Unwrap() error { + return e.original +} + +type IntegrityType string + +const ( + IntegrityTypeCheck IntegrityType = "check" + IntegrityTypeUnique IntegrityType = "unique" + IntegrityTypeForeign IntegrityType = "foreign" + IntegrityTypeNotNull IntegrityType = "not null" +) + +// IntegrityViolationError represents a generic integrity violation error. +// It wraps the dialect specific original error to provide more context. +type IntegrityViolationError struct { + integrityType IntegrityType + table string + constraint string + original error +} + +func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error { + return &IntegrityViolationError{ + integrityType: typ, + table: table, + constraint: constraint, + original: original, + } +} + +func (e *IntegrityViolationError) Error() string { + return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original) +} + +func (e *IntegrityViolationError) Is(target error) bool { + _, ok := target.(*IntegrityViolationError) + return ok +} + +// CheckError is returned when a check constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a check constraint was violated during an insert or update operation. +type CheckError struct { + IntegrityViolationError +} + +func NewCheckError(table, constraint string, original error) error { + return &CheckError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeCheck, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *CheckError) Is(target error) bool { + _, ok := target.(*CheckError) + return ok +} + +func (e *CheckError) Unwrap() error { + return &e.IntegrityViolationError +} + +// UniqueError is returned when a unique constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a unique constraint was violated during an insert or update operation. +type UniqueError struct { + IntegrityViolationError +} + +func NewUniqueError(table, constraint string, original error) error { + return &UniqueError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeUnique, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *UniqueError) Is(target error) bool { + _, ok := target.(*UniqueError) + return ok +} + +func (e *UniqueError) Unwrap() error { + return &e.IntegrityViolationError +} + +// ForeignKeyError is returned when a foreign key constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a foreign key constraint was violated during an insert or update operation +type ForeignKeyError struct { + IntegrityViolationError +} + +func NewForeignKeyError(table, constraint string, original error) error { + return &ForeignKeyError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeForeign, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *ForeignKeyError) Is(target error) bool { + _, ok := target.(*ForeignKeyError) + return ok +} + +func (e *ForeignKeyError) Unwrap() error { + return &e.IntegrityViolationError +} + +// NotNullError is returned when a not null constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a not null constraint was violated during an insert or update operation. +type NotNullError struct { + IntegrityViolationError +} + +func NewNotNullError(table, constraint string, original error) error { + return &NotNullError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeNotNull, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *NotNullError) Is(target error) bool { + _, ok := target.(*NotNullError) + return ok +} + +func (e *NotNullError) Unwrap() error { + return &e.IntegrityViolationError +} + +// UnknownError is returned when an unknown error occurs. +// It wraps the dialect specific original error to provide more context. +// It is used to indicate that an error occurred that does not fit into any of the other categories. +type UnknownError struct { + original error +} + +func NewUnknownError(original error) error { + return &UnknownError{ + original: original, + } +} + +func (e *UnknownError) Error() string { + return fmt.Sprintf("unknown database error: %v", e.original) +} + +func (e *UnknownError) Is(target error) bool { + _, ok := target.(*UnknownError) + return ok +} + +func (e *UnknownError) Unwrap() error { + return e.original +} diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index cc74b95afd..63f878574c 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -4,8 +4,6 @@ import ( "context" "errors" - "github.com/jackc/pgx/v5/pgconn" - "github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/storage/database" ) @@ -67,28 +65,7 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage) builder.WriteString(createInstanceStmt) - err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) - if err != nil { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - // constraint violation - if pgErr.Code == "23514" { - if pgErr.ConstraintName == "instances_name_check" { - return errors.New("instance name not provided") - } - if pgErr.ConstraintName == "instances_id_check" { - return errors.New("instance id not provided") - } - } - // duplicate - if pgErr.Code == "23505" { - if pgErr.ConstraintName == "instances_pkey" { - return errors.New("instance id already exists") - } - } - } - } - return err + return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) } // Update implements [domain.InstanceRepository]. @@ -107,8 +84,7 @@ func (i instance) Update(ctx context.Context, id string, changes ...database.Cha stmt := builder.String() - rowsAffected, err := i.client.Exec(ctx, stmt, builder.Args()...) - return rowsAffected, err + return i.client.Exec(ctx, stmt, builder.Args()...) } // Delete implements [domain.InstanceRepository]. @@ -203,9 +179,6 @@ func scanInstance(ctx context.Context, querier database.Querier, builder *databa instance := new(domain.Instance) if err := rows.(database.CollectableRows).CollectExactlyOneRow(instance); err != nil { - if err.Error() == "no rows in result set" { - return nil, ErrResourceDoesNotExist - } return nil, err } @@ -219,12 +192,6 @@ func scanInstances(ctx context.Context, querier database.Querier, builder *datab } if err := rows.(database.CollectableRows).Collect(&instances); err != nil { - // if no results returned, this is not a error - // it just means the instance was not found - // the caller should check if the returned instance is nil - if err.Error() == "no rows in result set" { - return nil, nil - } return nil, err } diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go index c3509639a8..bdc914157b 100644 --- a/backend/v3/storage/database/repository/instance_test.go +++ b/backend/v3/storage/database/repository/instance_test.go @@ -2,7 +2,6 @@ package repository_test import ( "context" - "errors" "testing" "time" @@ -55,7 +54,7 @@ func TestCreateInstance(t *testing.T) { } return instance }(), - err: errors.New("instance name not provided"), + err: new(database.CheckError), }, { name: "adding same instance twice", @@ -80,7 +79,7 @@ func TestCreateInstance(t *testing.T) { require.NoError(t, err) return &inst }, - err: errors.New("instance id already exists"), + err: new(database.UniqueError), }, func() struct { name string @@ -146,7 +145,7 @@ func TestCreateInstance(t *testing.T) { } return instance }(), - err: errors.New("instance id not provided"), + err: new(database.CheckError), }, } for _, tt := range tests { @@ -164,7 +163,7 @@ func TestCreateInstance(t *testing.T) { // create instance beforeCreate := time.Now() err := instanceRepo.Create(ctx, instance) - assert.Equal(t, tt.err, err) + assert.ErrorIs(t, err, tt.err) if err != nil { return } @@ -263,7 +262,7 @@ func TestUpdateInstance(t *testing.T) { return &inst }, rowsAffected: 0, - getErr: repository.ErrResourceDoesNotExist, + getErr: new(database.NoRowFoundError), }, } for _, tt := range tests { @@ -342,7 +341,7 @@ func TestGetInstance(t *testing.T) { } return &inst }, - err: repository.ErrResourceDoesNotExist, + err: new(database.NoRowFoundError), }, } for _, tt := range tests { @@ -360,7 +359,7 @@ func TestGetInstance(t *testing.T) { instance.ID, ) if tt.err != nil { - require.Equal(t, tt.err, err) + require.ErrorIs(t, err, tt.err) return } @@ -655,7 +654,7 @@ func TestDeleteInstance(t *testing.T) { instance, err := instanceRepo.Get(ctx, tt.instanceID, ) - require.Equal(t, err, repository.ErrResourceDoesNotExist) + require.ErrorIs(t, err, new(database.NoRowFoundError)) assert.Nil(t, instance) }) } diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go index be2188bd40..e8053aadd9 100644 --- a/backend/v3/storage/database/repository/org.go +++ b/backend/v3/storage/database/repository/org.go @@ -224,9 +224,6 @@ func scanOrganization(ctx context.Context, querier database.Querier, builder *da organization := &domain.Organization{} if err := rows.(database.CollectableRows).CollectExactlyOneRow(organization); err != nil { - if err.Error() == "no rows in result set" { - return nil, ErrResourceDoesNotExist - } return nil, err } @@ -241,12 +238,6 @@ func scanOrganizations(ctx context.Context, querier database.Querier, builder *d organizations := []*domain.Organization{} if err := rows.(database.CollectableRows).Collect(&organizations); err != nil { - // if no results returned, this is not a error - // it just means the organization was not found - // the caller should check if the returned organization is nil - if err.Error() == "no rows in result set" { - return nil, nil - } return nil, err } return organizations, nil diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go index 3c23ba6ca5..a6b5182f8c 100644 --- a/backend/v3/storage/database/repository/org_test.go +++ b/backend/v3/storage/database/repository/org_test.go @@ -2,7 +2,6 @@ package repository_test import ( "context" - "errors" "testing" "time" @@ -29,7 +28,7 @@ func TestCreateOrganization(t *testing.T) { } instanceRepo := repository.InstanceRepository(pool) err := instanceRepo.Create(t.Context(), &instance) - assert.Nil(t, err) + require.NoError(t, err) tests := []struct { name string @@ -64,7 +63,7 @@ func TestCreateOrganization(t *testing.T) { } return organization }(), - err: errors.New("organization name not provided"), + err: new(database.CheckError), }, { name: "adding org with same id twice", @@ -86,7 +85,7 @@ func TestCreateOrganization(t *testing.T) { org.Name = gofakeit.Name() return &org }, - err: errors.New("organization id already exists"), + err: new(database.UniqueError), }, { name: "adding org with same name twice", @@ -108,7 +107,7 @@ func TestCreateOrganization(t *testing.T) { org.ID = gofakeit.Name() return &org }, - err: errors.New("organization name already exists for instance"), + err: new(database.UniqueError), }, func() struct { name string @@ -181,7 +180,7 @@ func TestCreateOrganization(t *testing.T) { } return organization }(), - err: errors.New("organization id not provided"), + err: new(database.CheckError), }, { name: "adding organization with no instance id", @@ -195,7 +194,7 @@ func TestCreateOrganization(t *testing.T) { } return organization }(), - err: errors.New("invalid instance id"), + err: new(database.ForeignKeyError), }, { name: "adding organization with non existent instance id", @@ -210,7 +209,7 @@ func TestCreateOrganization(t *testing.T) { } return organization }(), - err: errors.New("invalid instance id"), + err: new(database.ForeignKeyError), }, } for _, tt := range tests { @@ -228,7 +227,7 @@ func TestCreateOrganization(t *testing.T) { // create organization beforeCreate := time.Now() err = organizationRepo.Create(ctx, organization) - assert.Equal(t, tt.err, err) + assert.ErrorIs(t, err, tt.err) if err != nil { return } @@ -265,7 +264,7 @@ func TestUpdateOrganization(t *testing.T) { } instanceRepo := repository.InstanceRepository(pool) err := instanceRepo.Create(t.Context(), &instance) - assert.Nil(t, err) + require.NoError(t, err) organizationRepo := repository.OrganizationRepository(pool) tests := []struct { @@ -417,7 +416,7 @@ func TestGetOrganization(t *testing.T) { } instanceRepo := repository.InstanceRepository(pool) err := instanceRepo.Create(t.Context(), &instance) - assert.Nil(t, err) + require.NoError(t, err) orgRepo := repository.OrganizationRepository(pool) @@ -497,7 +496,7 @@ func TestGetOrganization(t *testing.T) { return &org }, orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"), - err: repository.ErrResourceDoesNotExist, + err: new(database.NoRowFoundError), }, } for _, tt := range tests { @@ -516,7 +515,7 @@ func TestGetOrganization(t *testing.T) { org.InstanceID, ) if tt.err != nil { - require.Equal(t, tt.err, err) + require.ErrorIs(t, tt.err, err) return } @@ -553,7 +552,7 @@ func TestListOrganization(t *testing.T) { } instanceRepo := repository.InstanceRepository(pool) err = instanceRepo.Create(ctx, &instance) - assert.Nil(t, err) + require.NoError(t, err) type test struct { name string @@ -800,7 +799,7 @@ func TestDeleteOrganization(t *testing.T) { } instanceRepo := repository.InstanceRepository(pool) err := instanceRepo.Create(t.Context(), &instance) - assert.Nil(t, err) + require.NoError(t, err) type test struct { name string @@ -933,7 +932,7 @@ func TestDeleteOrganization(t *testing.T) { tt.orgIdentifierCondition, instanceId, ) - require.Equal(t, err, repository.ErrResourceDoesNotExist) + require.ErrorIs(t, err, new(database.NoRowFoundError)) assert.Nil(t, organization) }) } diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go index 9abf656ccc..c5b9ff81f0 100644 --- a/backend/v3/storage/database/repository/repository.go +++ b/backend/v3/storage/database/repository/repository.go @@ -1,13 +1,9 @@ package repository import ( - "errors" - "github.com/zitadel/zitadel/backend/v3/storage/database" ) -var ErrResourceDoesNotExist = errors.New("resource does not exist") - type repository struct { client database.QueryExecutor }