zitadel/internal/eventstore/v3/unique_constraints.go
Silvan 2246f9fa30
fix(eventstore): differentiate unique constraint error (#6832)
* fix(eventstore): differentiate unique constraint error format

* docs: add comment to eventstore vars

* fix(eventstore): return correct error type if unique constraint already exists

(cherry picked from commit f8bf8ea2562b388bbd90debdd86bd551499e9f37)
2023-10-27 14:47:55 +02:00

90 lines
3.4 KiB
Go

package eventstore
import (
"context"
"database/sql"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed unique_constraints_delete.sql
deleteConstraintStmt string
//go:embed unique_constraints_add.sql
addConstraintStmt string
)
// handleUniqueConstraints applies the unique constraint changes declared by
// commands within tx.
//
// Statements are built as follows:
//   - every UniqueConstraintRemove and UniqueConstraintInstanceRemove is
//     collected into one statement built from deleteConstraintStmt, with the
//     predicates joined by OR;
//   - every UniqueConstraintAdd is collected into one statement built from
//     addConstraintStmt, with the value tuples joined by commas.
//
// Removals are executed before additions. Constraints with any other action
// are ignored.
//
// Errors:
//   - a failed removal is returned as an internal error (ID "V3-C8l3V");
//   - a failed addition is returned as an already-exists error (ID "V3-DKcYh")
//     when it was caused by a unique violation, and as an internal error with
//     the same ID otherwise.
//
// In every case the message is the ErrorMessage of the constraint whose key
// appears in the database error detail, falling back to "Errors.Internal".
func handleUniqueConstraints(ctx context.Context, tx *sql.Tx, commands []eventstore.Command) error {
	// uniqueViolationSQLState is the PostgreSQL SQLSTATE for unique_violation.
	const uniqueViolationSQLState = "23505"

	deletePlaceholders := make([]string, 0)
	deleteArgs := make([]any, 0)

	addPlaceholders := make([]string, 0)
	addArgs := make([]any, 0)

	// Constraints are indexed by their formatted key so that a failing
	// statement can be traced back to the offending constraint via
	// constraintFromErr, which searches the error detail for these keys.
	addConstraints := map[string]*eventstore.UniqueConstraint{}
	deleteConstraints := map[string]*eventstore.UniqueConstraint{}

	for _, command := range commands {
		for _, constraint := range command.UniqueConstraints() {
			instanceID := command.Aggregate().InstanceID
			key := fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)
			switch constraint.Action {
			case eventstore.UniqueConstraintAdd:
				addPlaceholders = append(addPlaceholders, fmt.Sprintf("($%d, $%d, $%d)", len(addArgs)+1, len(addArgs)+2, len(addArgs)+3))
				addArgs = append(addArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
				addConstraints[key] = constraint
			case eventstore.UniqueConstraintRemove:
				deletePlaceholders = append(deletePlaceholders, fmt.Sprintf("(instance_id = $%d AND unique_type = $%d AND unique_field = $%d)", len(deleteArgs)+1, len(deleteArgs)+2, len(deleteArgs)+3))
				deleteArgs = append(deleteArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
				deleteConstraints[key] = constraint
			case eventstore.UniqueConstraintInstanceRemove:
				// Drops every constraint of the instance; only the instance ID is bound.
				deletePlaceholders = append(deletePlaceholders, fmt.Sprintf("(instance_id = $%d)", len(deleteArgs)+1))
				deleteArgs = append(deleteArgs, instanceID)
				deleteConstraints[key] = constraint
			}
		}
	}

	if len(deletePlaceholders) > 0 {
		_, err := tx.ExecContext(ctx, fmt.Sprintf(deleteConstraintStmt, strings.Join(deletePlaceholders, " OR ")), deleteArgs...)
		if err != nil {
			logging.WithError(err).Warn("delete unique constraint failed")
			errMessage := "Errors.Internal"
			if constraint := constraintFromErr(err, deleteConstraints); constraint != nil {
				errMessage = constraint.ErrorMessage
			}
			return errs.ThrowInternal(err, "V3-C8l3V", errMessage)
		}
	}

	if len(addPlaceholders) > 0 {
		_, err := tx.ExecContext(ctx, fmt.Sprintf(addConstraintStmt, strings.Join(addPlaceholders, ", ")), addArgs...)
		if err != nil {
			logging.WithError(err).Warn("add unique constraint failed")
			errMessage := "Errors.Internal"
			constraint := constraintFromErr(err, addConstraints)
			if constraint != nil {
				errMessage = constraint.ErrorMessage
			}
			// Only a unique violation means the constraint already exists. Any other
			// failure (e.g. a cancelled context or a broken connection) must not be
			// reported to callers as a conflict. The SQLSTATE check also covers
			// unique violations whose detail did not match any collected key.
			var sqlStateErr interface{ SQLState() string }
			if constraint != nil || (errors.As(err, &sqlStateErr) && sqlStateErr.SQLState() == uniqueViolationSQLState) {
				return errs.ThrowAlreadyExists(err, "V3-DKcYh", errMessage)
			}
			return errs.ThrowInternal(err, "V3-DKcYh", errMessage)
		}
	}
	return nil
}
// constraintFromErr maps a database error back to one of the given constraints.
//
// It unwraps err looking for a *pgconn.PgError and returns the first constraint
// whose map key is a substring of that error's Detail field. It returns nil if
// err wraps no *pgconn.PgError or if no key occurs in the detail. If several
// keys match, which constraint is returned depends on map iteration order.
func constraintFromErr(err error, constraints map[string]*eventstore.UniqueConstraint) *eventstore.UniqueConstraint {
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		detail := pgErr.Detail
		for key, candidate := range constraints {
			if strings.Contains(detail, key) {
				return candidate
			}
		}
	}
	return nil
}