zitadel/internal/eventstore/v2/local_crdb_test.go

129 lines
2.8 KiB
Go
Raw Normal View History

2020-10-23 14:16:46 +00:00
package eventstore_test
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"testing"
"github.com/caos/logging"
"github.com/cockroachdb/cockroach-go/v2/testserver"
)
var (
	// testCRDBClient is the shared database handle for this test package.
	// It is nil until TestMain assigns it, and executeMigrations runs the
	// migration scripts through it.
	testCRDBClient *sql.DB

	// migrationsPath is the directory that is walked for ".sql" migration
	// files. It is expanded from ${GOPATH} once, at package initialisation;
	// os.ExpandEnv substitutes an empty string when GOPATH is not set.
	migrationsPath = os.ExpandEnv("${GOPATH}/src/github.com/caos/zitadel/migrations/cockroach")
)
// TestMain starts a local CockroachDB test server, opens testCRDBClient
// against it, applies all migrations and then runs the package's tests.
//
// The database handle and the test server are released explicitly before the
// process exits: os.Exit does not run deferred functions, so a defer-based
// cleanup would never execute and the server would be left running.
func TestMain(m *testing.M) {
	ts, err := testserver.NewTestServer()
	if err != nil {
		logging.LogWithFields("REPOS-RvjLG", "error", err).Fatal("unable to start db")
	}

	// cleanup closes the client (if it was opened) and stops the server.
	// It is invoked on every exit path below.
	cleanup := func() {
		if testCRDBClient != nil {
			// Best effort: a close error must not mask the test result.
			_ = testCRDBClient.Close()
		}
		ts.Stop()
	}

	testCRDBClient, err = sql.Open("postgres", ts.PGURL().String())
	if err != nil {
		// NOTE(review): Fatal is expected to terminate the process, so the
		// cleanup has to happen before the log call — confirm against the
		// logging package.
		cleanup()
		logging.LogWithFields("REPOS-CF6dQ", "error", err).Fatal("unable to connect to db")
	}

	if err = executeMigrations(); err != nil {
		cleanup()
		logging.LogWithFields("REPOS-jehDD", "error", err).Fatal("migrations failed")
	}

	code := m.Run()
	cleanup()
	os.Exit(code)
}
// executeMigrations applies every migration file returned by
// migrationFilePaths to testCRDBClient, in ascending version order.
//
// A file whose content contains "BEGIN;" is assumed to manage its own
// transaction and is executed directly on the client. Every other file is
// executed inside a transaction opened here, which is committed on success
// and rolled back on failure so that the connection is returned to the pool.
//
// The first error aborts the run and is returned, wrapped with the name of
// the offending file.
func executeMigrations() error {
	files, err := migrationFilePaths()
	if err != nil {
		return err
	}
	sort.Sort(files)

	for _, file := range files {
		migration, err := ioutil.ReadFile(file)
		if err != nil {
			return fmt.Errorf("read file: %v || err: %w", file, err)
		}

		transactionInMigration := strings.Contains(string(migration), "BEGIN;")

		exec := testCRDBClient.Exec
		var tx *sql.Tx
		if !transactionInMigration {
			tx, err = testCRDBClient.Begin()
			if err != nil {
				return fmt.Errorf("begin file: %v || err: %w", file, err)
			}
			exec = tx.Exec
		}

		if _, err = exec(string(migration)); err != nil {
			if tx != nil {
				// Best effort: the exec error is the one worth reporting,
				// the rollback only releases the transaction's connection.
				_ = tx.Rollback()
			}
			return fmt.Errorf("exec file: %v || err: %w", file, err)
		}

		if tx != nil {
			if err = tx.Commit(); err != nil {
				return fmt.Errorf("commit file: %v || err: %w", file, err)
			}
		}
	}
	return nil
}
// migrationPaths is a list of migration file paths. Its Len, Swap and Less
// methods let it be passed to sort.Sort, which orders the paths by the
// version encoded in each file name.
type migrationPaths []string
// version is the numeric major/minor pair encoded in a migration file name.
type version struct {
	major int
	minor int
}

// versionFromPath extracts the version from a migration file path whose base
// name has the form "V<major>[.<minor>]__<description>".
//
// Only the base name is inspected, so directory components that happen to
// contain "/V" (for example "/Volumes/...") or "__" cannot confuse the parser.
// A missing minor component is reported as 0 and components after the minor
// one are ignored.
//
// It panics if the base name does not start with "V", contains no "__"
// separator, or has a major/minor component that is not an integer.
func versionFromPath(s string) version {
	name := filepath.Base(s)
	if !strings.HasPrefix(name, "V") {
		panic(fmt.Errorf("migration file %q: name does not start with \"V\"", s))
	}
	end := strings.Index(name, "__")
	if end < 0 {
		panic(fmt.Errorf("migration file %q: name has no \"__\" separator", s))
	}

	// Skip the leading "V"; strings.Split always yields at least one part.
	parts := strings.Split(name[1:end], ".")

	var (
		res version
		err error
	)
	if res.major, err = strconv.Atoi(parts[0]); err != nil {
		panic(fmt.Errorf("migration file %q: invalid major version: %w", s, err))
	}
	if len(parts) >= 2 {
		if res.minor, err = strconv.Atoi(parts[1]); err != nil {
			panic(fmt.Errorf("migration file %q: invalid minor version: %w", s, err))
		}
	}
	return res
}
// Len reports how many paths the list holds.
func (a migrationPaths) Len() int {
	return len(a)
}
// Swap exchanges the paths stored at positions i and j.
func (a migrationPaths) Swap(i, j int) {
	a[j], a[i] = a[i], a[j]
}
// Less reports whether the path at position i carries a lower version than
// the path at position j. Major versions are compared first; the minor
// version only decides when the majors are equal. Both paths are parsed with
// versionFromPath on every call, so a malformed file name panics here.
func (a migrationPaths) Less(i, j int) bool {
	left, right := versionFromPath(a[i]), versionFromPath(a[j])
	if left.major != right.major {
		return left.major < right.major
	}
	return left.minor < right.minor
}
// migrationFilePaths walks migrationsPath recursively and collects the path
// of every non-directory entry whose name ends in ".sql". The paths are
// returned in walk order; callers sort them themselves. An error reported by
// the walk stops the traversal and is returned together with the paths
// gathered so far.
func migrationFilePaths() (migrationPaths, error) {
	paths := migrationPaths{}

	collect := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		if strings.HasSuffix(info.Name(), ".sql") {
			paths = append(paths, path)
		}
		return nil
	}

	walkErr := filepath.Walk(migrationsPath, collect)
	return paths, walkErr
}