zitadel/internal/eventstore/handler/v2/init.go
Tim Möhlmann f680dd934d
refactor: rename package errors to zerrors (#7039)
* chore: rename package errors to zerrors

* rename package errors to gerrors

* fix error related linting issues

* fix zitadel error assertion

* fix gosimple linting issues

* fix deprecated linting issues

* resolve gci linting issues

* fix import structure

---------

Co-authored-by: Elio Bischof <elio@zitadel.com>
2023-12-08 15:30:55 +01:00

414 lines
9.6 KiB
Go

package handler
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Table describes the schema of a projection table: its columns, primary key,
// secondary indices, unique constraints and foreign keys. It is rendered to
// SQL by createTableStatement.
type Table struct {
// columns are rendered in slice order by createColumnsStatement.
columns []*InitColumn
// primaryKey lists the column names used in the PRIMARY KEY clause.
primaryKey PrimaryKey
// indices are emitted as CREATE INDEX IF NOT EXISTS statements.
indices []*Index
// constraints are emitted as named UNIQUE constraints.
constraints []*Constraint
// foreignKeys are emitted as FOREIGN KEY ... ON DELETE CASCADE constraints.
foreignKeys []*ForeignKey
}
// NewTable builds a Table from the given columns and primary key and then
// applies every option in the order it was passed.
func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
	table := new(Table)
	table.columns = columns
	table.primaryKey = key
	for i := range opts {
		opts[i](table)
	}
	return table
}
// SuffixedTable is a Table whose SQL name is derived from another table name
// by appending "_" and suffix (see NewMultiTableCheck and NewViewCheck).
type SuffixedTable struct {
Table
// suffix is appended to the base table name, separated by an underscore.
suffix string
}
// NewSuffixedTable builds a Table via NewTable (applying opts) and stores a
// copy of it together with suffix in a new SuffixedTable.
func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
	base := NewTable(columns, key, opts...)
	return &SuffixedTable{Table: *base, suffix: suffix}
}
// TableOption customizes a Table; NewTable calls each option with the table
// under construction.
type TableOption func(*Table)
// WithIndex returns a TableOption that appends index to the table's indices.
func WithIndex(index *Index) TableOption {
	return func(t *Table) {
		t.indices = append(t.indices, index)
	}
}
// WithConstraint returns a TableOption that appends constraint to the table's
// unique constraints.
func WithConstraint(constraint *Constraint) TableOption {
	return func(t *Table) {
		t.constraints = append(t.constraints, constraint)
	}
}
// WithForeignKey returns a TableOption that appends key to the table's
// foreign keys.
func WithForeignKey(key *ForeignKey) TableOption {
	return func(t *Table) {
		t.foreignKeys = append(t.foreignKeys, key)
	}
}
// InitColumn describes a single column of a Table. It is rendered to a SQL
// column definition by createColumnsStatement.
type InitColumn struct {
// Name is the column name as written into the statement.
Name string
// Type selects the SQL data type (see columnType).
Type ColumnType
// nullable suppresses the NOT NULL clause when true.
nullable bool
// defaultValue, when non-nil, is rendered as a DEFAULT clause.
defaultValue interface{}
// deleteCascade, when non-empty, names the referenced column of a
// REFERENCES ... ON DELETE CASCADE clause.
deleteCascade string
}
// ColumnOption customizes an InitColumn; NewColumn calls each option with the
// column under construction.
type ColumnOption func(*InitColumn)
// NewColumn builds a column with the given name and type. Unless changed by
// an option the column is not nullable and has no default value.
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
	// nullable and defaultValue keep their zero values (false / nil).
	col := &InitColumn{Name: name, Type: columnType}
	for _, apply := range opts {
		apply(col)
	}
	return col
}
// Nullable returns a ColumnOption that marks the column as nullable, which
// omits the NOT NULL clause from its definition.
func Nullable() ColumnOption {
	return func(col *InitColumn) {
		col.nullable = true
	}
}
// Default returns a ColumnOption that sets the value rendered in the column's
// DEFAULT clause.
func Default(value interface{}) ColumnOption {
	return func(col *InitColumn) {
		col.defaultValue = value
	}
}
// DeleteCascade returns a ColumnOption that makes the column reference the
// given column with ON DELETE CASCADE.
func DeleteCascade(column string) ColumnOption {
	return func(col *InitColumn) {
		col.deleteCascade = column
	}
}
// PrimaryKey lists the names of the columns that form a table's primary key.
type PrimaryKey []string

// NewPrimaryKey returns the given column names as a PrimaryKey, keeping their
// order.
func NewPrimaryKey(columnNames ...string) PrimaryKey {
	return PrimaryKey(columnNames)
}
// ColumnType enumerates the column data types supported when creating
// tables; columnType translates each value to its SQL type name.
type ColumnType int32
const (
// ColumnTypeText is rendered as TEXT.
ColumnTypeText ColumnType = iota
// ColumnTypeTextArray is rendered as TEXT[].
ColumnTypeTextArray
// ColumnTypeJSONB is rendered as JSONB.
ColumnTypeJSONB
// ColumnTypeBytes is rendered as BYTEA.
ColumnTypeBytes
// ColumnTypeTimestamp is rendered as TIMESTAMPTZ.
ColumnTypeTimestamp
// ColumnTypeInterval is rendered as INTERVAL.
ColumnTypeInterval
// ColumnTypeEnum is rendered as SMALLINT.
ColumnTypeEnum
// ColumnTypeEnumArray is rendered as SMALLINT[].
ColumnTypeEnumArray
// ColumnTypeInt64 is rendered as BIGINT.
ColumnTypeInt64
// ColumnTypeBool is rendered as BOOLEAN.
ColumnTypeBool
)
// Index describes a secondary index of a table.
type Index struct {
	// Name is combined with the table name to form the index name.
	Name string
	// Columns are the indexed columns, in order.
	Columns []string
	// includes are additional columns listed in the INCLUDE clause.
	includes []string
}

// indexOpts customizes an Index; NewIndex calls each option with the index
// under construction.
type indexOpts func(*Index)

// NewIndex builds an Index over columns with the given name and applies the
// options in the order they were passed.
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
	idx := &Index{Name: name, Columns: columns}
	for _, apply := range opts {
		apply(idx)
	}
	return idx
}

// WithInclude returns an option that sets (replacing any previous value) the
// columns of the index's INCLUDE clause.
func WithInclude(columns ...string) indexOpts {
	return func(idx *Index) {
		idx.includes = columns
	}
}
// Constraint describes a UNIQUE constraint over a set of columns.
type Constraint struct {
	// Name is combined with the table name to form the constraint name.
	Name string
	// Columns are the columns that must be unique together.
	Columns []string
}

// NewConstraint builds a Constraint with the given name over columns.
func NewConstraint(name string, columns []string) *Constraint {
	return &Constraint{Name: name, Columns: columns}
}
// ForeignKey describes a foreign key constraint of a table.
type ForeignKey struct {
	// Name is used to build the constraint name; when empty a name is
	// generated from the table name (see foreignKeyName).
	Name string
	// Columns are the referencing columns; when empty, createTableStatement
	// uses the table's primary key instead.
	Columns []string
	// RefColumns are the referenced columns; when empty, no column list is
	// written after the referenced table.
	RefColumns []string
}

// NewForeignKey builds a ForeignKey with the given name that maps columns to
// refColumns.
func NewForeignKey(name string, columns []string, refColumns []string) *ForeignKey {
	return &ForeignKey{Name: name, Columns: columns, RefColumns: refColumns}
}

// NewForeignKeyOfPublicKeys builds a ForeignKey with an empty name and no
// explicit columns.
func NewForeignKeyOfPublicKeys() *ForeignKey {
	return new(ForeignKey)
}
// initializer is implemented by projections that need their storage set up;
// Handler.Init type-asserts the projection against it and runs the returned
// check.
type initializer interface {
Init() *handler.Check
}
// Init runs the projection's initialization check, if it has one.
//
// Nothing is done when the projection does not implement initializer or when
// its check reports IsNoop. Otherwise all executes of the check run inside a
// single transaction: the loop stops at the first execute that reports that
// no further step is needed, an execute error rolls the transaction back and
// is returned, and the transaction is committed in every other case.
func (h *Handler) Init(ctx context.Context) error {
	initProjection, isInitializer := h.projection.(initializer)
	if !isInitializer {
		return nil
	}
	if initProjection.Init().IsNoop() {
		return nil
	}

	tx, err := h.client.BeginTx(ctx, nil)
	if err != nil {
		return zerrors.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
	}

	for step, execute := range initProjection.Init().Executes {
		logging.WithFields("projection", h.projection.Name(), "execute", step).Debug("executing check")
		proceed, execErr := execute(tx, h.projection.Name())
		if execErr != nil {
			// A failed rollback is only logged; the execute error is what
			// the caller needs to see.
			logging.OnError(tx.Rollback()).Debug("unable to rollback")
			return execErr
		}
		if proceed {
			continue
		}
		logging.WithFields("projection", h.projection.Name(), "execute", step).Debug("projection set up")
		break
	}
	return tx.Commit()
}
// NewTableCheck returns a check whose first execute creates table and whose
// following executes create one index of the table each. Every execute is
// built with execNextIfExists and executeNext set to true.
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
	var config execConfig
	createTable := func(cfg execConfig) string {
		return createTableStatement(table, cfg.tableName, "")
	}

	executes := make([]func(handler.Executer, string) (bool, error), 0, len(table.indices)+1)
	executes = append(executes, execNextIfExists(config, createTable, opts, true))
	for _, index := range table.indices {
		executes = append(executes, execNextIfExists(config, createIndexCheck(index), opts, true))
	}
	return &handler.Check{Executes: executes}
}
// NewMultiTableCheck returns a check with a single execute that creates
// primaryTable followed by every secondary table. Secondary tables are named
// after the primary table with "_" and their suffix appended.
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
	var config execConfig
	createAll := func(cfg execConfig) string {
		var stmt strings.Builder
		stmt.WriteString(createTableStatement(primaryTable, cfg.tableName, ""))
		for _, secondary := range secondaryTables {
			stmt.WriteString(createTableStatement(&secondary.Table, cfg.tableName, "_"+secondary.suffix))
		}
		return stmt.String()
	}

	executes := []func(handler.Executer, string) (bool, error){
		execNextIfExists(config, createAll, nil, true),
	}
	return &handler.Check{Executes: executes}
}
// NewViewCheck returns a check with a single execute that first creates every
// secondary table (named after the view with "_" and the table's suffix
// appended) and then creates the view defined by selectStmt.
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
	var config execConfig
	createAll := func(cfg execConfig) string {
		var stmt strings.Builder
		for _, secondary := range secondaryTables {
			stmt.WriteString(createTableStatement(&secondary.Table, cfg.tableName, "_"+secondary.suffix))
		}
		stmt.WriteString(createViewStatement(cfg.tableName, selectStmt))
		return stmt.String()
	}

	executes := []func(handler.Executer, string) (bool, error){
		execNextIfExists(config, createAll, nil, false),
	}
	return &handler.Check{Executes: executes}
}
// execNextIfExists wraps the statement produced by q into an execute function.
//
// When running the statement fails with an "already exists" error (see
// isErrAlreadyExists) the error is swallowed and executeNext is returned. In
// every other case, including success, the result is false together with the
// (possibly nil) error.
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
	return func(executer handler.Executer, name string) (bool, error) {
		if err := exec(config, q, opts)(executer, name); !isErrAlreadyExists(err) {
			return false, err
		}
		return executeNext, nil
	}
}
// isErrAlreadyExists reports whether err is a ZitadelError whose parent is a
// PostgreSQL error with SQLSTATE 42P07 (duplicate_table).
func isErrAlreadyExists(err error) bool {
	zitadelErr := new(zerrors.ZitadelError)
	if !errors.As(err, &zitadelErr) {
		return false
	}
	pgErr := &pgconn.PgError{}
	return errors.As(zitadelErr.Parent, &pgErr) && pgErr.Code == "42P07"
}
// createTableStatement renders the SQL that creates table under the name
// tableName+suffix, followed by one CREATE INDEX statement per index of the
// table.
//
// Foreign keys and column-level references always point at tableName (without
// the suffix). A foreign key without explicit Columns uses the table's primary
// key as referencing columns; the ForeignKey itself is not modified, so the
// same definition can safely be shared between tables and rendered repeatedly.
func createTableStatement(table *Table, tableName string, suffix string) string {
	var stmt strings.Builder
	stmt.WriteString(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
		tableName+suffix,
		createColumnsStatement(table.columns, tableName),
		strings.Join(table.primaryKey, ", "),
	))

	for _, key := range table.foreignKeys {
		ref := tableName
		if len(key.RefColumns) > 0 {
			ref += fmt.Sprintf("(%s)", strings.Join(key.RefColumns, ","))
		}
		// Resolve the fallback into a local instead of writing it back to
		// key: the key is shared caller state and must not remember the
		// primary key of whichever table happened to be rendered first.
		columns := key.Columns
		if len(columns) == 0 {
			columns = table.primaryKey
		}
		stmt.WriteString(fmt.Sprintf(", CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE CASCADE", foreignKeyName(key.Name, tableName, suffix), strings.Join(columns, ","), ref))
	}

	for _, constraint := range table.constraints {
		stmt.WriteString(fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraintName(constraint.Name, tableName, suffix), strings.Join(constraint.Columns, ",")))
	}
	stmt.WriteString(");")

	for _, index := range table.indices {
		stmt.WriteString(createIndexStatement(index, tableName+suffix))
	}
	return stmt.String()
}
// createViewStatement renders the CREATE VIEW statement that defines viewName
// as selectStmt. No terminating semicolon is appended.
func createViewStatement(viewName string, selectStmt string) string {
	const format = "CREATE VIEW %s AS %s"
	return fmt.Sprintf(format, viewName, selectStmt)
}
// createIndexCheck returns a statement builder that renders the CREATE INDEX
// statement of index for the table named in the passed config.
func createIndexCheck(index *Index) func(config execConfig) string {
	return func(cfg execConfig) string {
		tableName := cfg.tableName
		return createIndexStatement(index, tableName)
	}
}
// createIndexStatement renders a semicolon-terminated
// CREATE INDEX IF NOT EXISTS statement for index on tableName. An INCLUDE
// clause is added only when the index has include columns.
func createIndexStatement(index *Index, tableName string) string {
	var stmt strings.Builder
	stmt.WriteString(fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
		indexName(index.Name, tableName),
		tableName,
		strings.Join(index.Columns, ","),
	))
	if len(index.includes) > 0 {
		stmt.WriteString(" INCLUDE (")
		stmt.WriteString(strings.Join(index.includes, ", "))
		stmt.WriteString(")")
	}
	stmt.WriteString(";")
	return stmt.String()
}
// foreignKeyName builds the constraint name of a foreign key.
//
// With an explicit name the result is "fk_<table+suffix>_<name>"; without one
// it is "fk<suffix>_ref_<table>". The schema prefix of tableName is dropped in
// both forms.
func foreignKeyName(name, tableName, suffix string) string {
	if name != "" {
		return "fk_" + tableNameWithoutSchema(tableName+suffix) + "_" + name
	}
	return "fk" + suffix + "_ref_" + tableNameWithoutSchema(tableName)
}
// constraintName builds the name of a unique constraint as
// "<table+suffix>_<name>_unique", with the schema prefix of tableName dropped.
func constraintName(name, tableName, suffix string) string {
	table := tableNameWithoutSchema(tableName + suffix)
	return table + "_" + name + "_unique"
}
// indexName builds the name of an index as "<table>_<name>_idx", with the
// schema prefix of tableName dropped.
func indexName(name, tableName string) string {
	table := tableNameWithoutSchema(tableName)
	return table + "_" + name + "_idx"
}
// tableNameWithoutSchema returns the part of name after its last ".", or name
// unchanged when it contains no ".".
func tableNameWithoutSchema(name string) string {
	dot := strings.LastIndex(name, ".")
	if dot < 0 {
		return name
	}
	return name[dot+1:]
}
// createColumnsStatement renders the comma-separated column definitions of a
// CREATE TABLE statement.
//
// Each definition consists of the column name and SQL type, followed by
// NOT NULL unless the column is nullable, a DEFAULT clause when a default
// value is set, and a REFERENCES tableName (...) ON DELETE CASCADE clause when
// a delete-cascade column is set.
func createColumnsStatement(cols []*InitColumn, tableName string) string {
	definitions := make([]string, 0, len(cols))
	for _, col := range cols {
		var def strings.Builder
		def.WriteString(col.Name)
		def.WriteString(" ")
		def.WriteString(columnType(col.Type))
		if !col.nullable {
			def.WriteString(" NOT NULL")
		}
		if col.defaultValue != nil {
			def.WriteString(" DEFAULT ")
			def.WriteString(defaultValue(col.defaultValue))
		}
		if col.deleteCascade != "" {
			def.WriteString(fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade))
		}
		definitions = append(definitions, def.String())
	}
	return strings.Join(definitions, ",")
}
// defaultValue renders value as the SQL literal of a DEFAULT clause.
//
// Strings become single-quoted SQL string constants; embedded single quotes
// are doubled so that the literal stays well-formed. Values implementing
// fmt.Stringer are formatted with %#v, which does not call String(), so the
// underlying Go value is written rather than its display text. Everything
// else is formatted with %v.
func defaultValue(value interface{}) string {
	switch v := value.(type) {
	case string:
		// A lone ' inside the value would terminate the SQL string constant
		// early; '' is the standard escape for a literal quote.
		return "'" + strings.ReplaceAll(v, "'", "''") + "'"
	case fmt.Stringer:
		return fmt.Sprintf("%#v", v)
	default:
		return fmt.Sprintf("%v", v)
	}
}
// columnType returns the SQL type name for columnType. It panics with
// "unknown column type" for a value that is not one of the declared
// ColumnType constants.
func columnType(columnType ColumnType) string {
	sqlTypes := map[ColumnType]string{
		ColumnTypeText:      "TEXT",
		ColumnTypeTextArray: "TEXT[]",
		ColumnTypeJSONB:     "JSONB",
		ColumnTypeBytes:     "BYTEA",
		ColumnTypeTimestamp: "TIMESTAMPTZ",
		ColumnTypeInterval:  "INTERVAL",
		ColumnTypeEnum:      "SMALLINT",
		ColumnTypeEnumArray: "SMALLINT[]",
		ColumnTypeInt64:     "BIGINT",
		ColumnTypeBool:      "BOOLEAN",
	}
	sqlType, known := sqlTypes[columnType]
	if !known {
		panic("unknown column type")
	}
	return sqlType
}