feat: prepare for multiple database types (#4068)

BREAKING CHANGE: the database and admin user config has changed.
This commit is contained in:
Livio Spring 2022-07-28 16:25:42 +02:00 committed by GitHub
parent bc9a85daf3
commit f610d48569
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 354 additions and 161 deletions

View File

@ -45,6 +45,7 @@ HTTP1HostHeader: "host"
WebAuthNName: ZITADEL WebAuthNName: ZITADEL
Database: Database:
cockroach:
Host: localhost Host: localhost
Port: 26257 Port: 26257
Database: zitadel Database: zitadel
@ -60,8 +61,7 @@ Database:
RootCert: "" RootCert: ""
Cert: "" Cert: ""
Key: "" Key: ""
Admin:
AdminUser:
Username: root Username: root
Password: "" Password: ""
SSL: SSL:

View File

@ -10,14 +10,15 @@ import (
type Config struct { type Config struct {
Database database.Config Database database.Config
AdminUser database.User
Machine *id.Config Machine *id.Config
Log *logging.Config Log *logging.Config
} }
func MustNewConfig(v *viper.Viper) *Config { func MustNewConfig(v *viper.Viper) *Config {
config := new(Config) config := new(Config)
err := v.Unmarshal(config) err := v.Unmarshal(config,
viper.DecodeHook(database.DecodeHook),
)
logging.OnError(err).Fatal("unable to read config") logging.OnError(err).Fatal("unable to read config")
err = config.Log.SetLogger() err = config.Log.SetLogger()
@ -25,21 +26,3 @@ func MustNewConfig(v *viper.Viper) *Config {
return config return config
} }
func adminConfig(config *Config) database.Config {
adminConfig := config.Database
adminConfig.Username = config.AdminUser.Username
adminConfig.Password = config.AdminUser.Password
adminConfig.SSL.Cert = config.AdminUser.SSL.Cert
adminConfig.SSL.Key = config.AdminUser.SSL.Key
if config.AdminUser.SSL.RootCert != "" {
adminConfig.SSL.RootCert = config.AdminUser.SSL.RootCert
}
if config.AdminUser.SSL.Mode != "" {
adminConfig.SSL.Mode = config.AdminUser.SSL.Mode
}
//use default database because the zitadel database might not exist
adminConfig.Database = ""
return adminConfig
}

View File

@ -39,10 +39,10 @@ The user provided by flags needs privileges to
func InitAll(config *Config) { func InitAll(config *Config) {
id.Configure(config.Machine) id.Configure(config.Machine)
err := initialise(config, err := initialise(config.Database,
VerifyUser(config.Database.Username, config.Database.Password), VerifyUser(config.Database.Username(), config.Database.Password()),
VerifyDatabase(config.Database.Database), VerifyDatabase(config.Database.Database()),
VerifyGrant(config.Database.Database, config.Database.Username), VerifyGrant(config.Database.Database(), config.Database.Username()),
) )
logging.OnError(err).Fatal("unable to initialize the database") logging.OnError(err).Fatal("unable to initialize the database")
@ -50,10 +50,10 @@ func InitAll(config *Config) {
logging.OnError(err).Fatal("unable to initialize ZITADEL") logging.OnError(err).Fatal("unable to initialize ZITADEL")
} }
func initialise(config *Config, steps ...func(*sql.DB) error) error { func initialise(config database.Config, steps ...func(*sql.DB) error) error {
logging.Info("initialization started") logging.Info("initialization started")
db, err := database.Connect(adminConfig(config)) db, err := database.Connect(config, true)
if err != nil { if err != nil {
return err return err
} }

View File

@ -34,7 +34,7 @@ The user provided by flags needs priviledge to
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.New()) config := MustNewConfig(viper.New())
err := initialise(config, VerifyDatabase(config.Database.Database)) err := initialise(config.Database, VerifyDatabase(config.Database.Database()))
logging.OnError(err).Fatal("unable to initialize the database") logging.OnError(err).Fatal("unable to initialize the database")
}, },
} }

View File

@ -28,7 +28,7 @@ Prereqesits:
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.New()) config := MustNewConfig(viper.New())
err := initialise(config, VerifyGrant(config.Database.Database, config.Database.Username)) err := initialise(config.Database, VerifyGrant(config.Database.Database(), config.Database.Username()))
logging.OnError(err).Fatal("unable to set grant") logging.OnError(err).Fatal("unable to set grant")
}, },
} }

View File

@ -33,7 +33,7 @@ The user provided by flags needs priviledge to
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.New()) config := MustNewConfig(viper.New())
err := initialise(config, VerifyUser(config.Database.Username, config.Database.Password)) err := initialise(config.Database, VerifyUser(config.Database.Username(), config.Database.Password()))
logging.OnError(err).Fatal("unable to init user") logging.OnError(err).Fatal("unable to init user")
}, },
} }

View File

@ -95,7 +95,7 @@ func VerifyZitadel(db *sql.DB) error {
func verifyZitadel(config database.Config) error { func verifyZitadel(config database.Config) error {
logging.WithFields("database", config.Database).Info("verify zitadel") logging.WithFields("database", config.Database).Info("verify zitadel")
db, err := database.Connect(config) db, err := database.Connect(config, false)
if err != nil { if err != nil {
return err return err
} }

View File

@ -124,7 +124,7 @@ func openFile(fileName string) (io.Reader, error) {
} }
func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) { func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) {
db, err := database.Connect(config) db, err := database.Connect(config, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -36,6 +36,7 @@ func MustNewConfig(v *viper.Viper) *Config {
hook.TagToLanguageHookFunc(), hook.TagToLanguageHookFunc(),
mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","), mapstructure.StringToSliceHookFunc(","),
database.DecodeHook,
)), )),
) )
logging.OnError(err).Fatal("unable to read default config") logging.OnError(err).Fatal("unable to read default config")

View File

@ -56,7 +56,7 @@ func Flags(cmd *cobra.Command) {
func Setup(config *Config, steps *Steps, masterKey string) { func Setup(config *Config, steps *Steps, masterKey string) {
logging.Info("setup started") logging.Info("setup started")
dbClient, err := database.Connect(config.Database) dbClient, err := database.Connect(config.Database, false)
logging.OnError(err).Fatal("unable to connect to database") logging.OnError(err).Fatal("unable to connect to database")
eventstoreClient, err := eventstore.Start(dbClient) eventstoreClient, err := eventstore.Start(dbClient)

View File

@ -65,6 +65,7 @@ func MustNewConfig(v *viper.Viper) *Config {
hook.TagToLanguageHookFunc(), hook.TagToLanguageHookFunc(),
mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","), mapstructure.StringToSliceHookFunc(","),
database.DecodeHook,
)), )),
) )
logging.OnError(err).Fatal("unable to read config") logging.OnError(err).Fatal("unable to read config")

View File

@ -81,7 +81,7 @@ Requirements:
func startZitadel(config *Config, masterKey string) error { func startZitadel(config *Config, masterKey string) error {
ctx := context.Background() ctx := context.Background()
dbClient, err := database.Connect(config.Database) dbClient, err := database.Connect(config.Database, false)
if err != nil { if err != nil {
return fmt.Errorf("cannot start client for projection: %w", err) return fmt.Errorf("cannot start client for projection: %w", err)
} }
@ -175,10 +175,10 @@ func startAPIs(ctx context.Context, router *mux.Router, commands *command.Comman
if err != nil { if err != nil {
return fmt.Errorf("error starting admin repo: %w", err) return fmt.Errorf("error starting admin repo: %w", err)
} }
if err := apis.RegisterServer(ctx, system.CreateServer(commands, queries, adminRepo, config.Database.Database, config.DefaultInstance)); err != nil { if err := apis.RegisterServer(ctx, system.CreateServer(commands, queries, adminRepo, config.Database.Database(), config.DefaultInstance)); err != nil {
return err return err
} }
if err := apis.RegisterServer(ctx, admin.CreateServer(config.Database.Database, commands, queries, config.SystemDefaults, adminRepo, config.ExternalSecure, keys.User)); err != nil { if err := apis.RegisterServer(ctx, admin.CreateServer(config.Database.Database(), commands, queries, config.SystemDefaults, adminRepo, config.ExternalSecure, keys.User)); err != nil {
return err return err
} }
if err := apis.RegisterServer(ctx, management.CreateServer(commands, queries, config.SystemDefaults, keys.User, config.ExternalSecure, config.AuditLogRetention)); err != nil { if err := apis.RegisterServer(ctx, management.CreateServer(commands, queries, config.SystemDefaults, keys.User, config.ExternalSecure, config.AuditLogRetention)); err != nil {

2
go.mod
View File

@ -61,7 +61,7 @@ require (
go.opentelemetry.io/otel/sdk/export/metric v0.25.0 go.opentelemetry.io/otel/sdk/export/metric v0.25.0
go.opentelemetry.io/otel/sdk/metric v0.25.0 go.opentelemetry.io/otel/sdk/metric v0.25.0
go.opentelemetry.io/otel/trace v1.2.0 go.opentelemetry.io/otel/trace v1.2.0
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba
golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c

4
go.sum
View File

@ -958,8 +958,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 h1:/pEO3GD/ABYAjuakUS6xSEmmlyVS4kxBNkeA9tLJiTI= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=

View File

@ -0,0 +1,155 @@
package cockroach
import (
"database/sql"
"strconv"
"strings"
"time"
//sql import
_ "github.com/lib/pq"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database/dialect"
)
const (
sslDisabledMode = "disable"
)
type Config struct {
Host string
Port uint16
Database string
MaxOpenConns uint32
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
User User
Admin User
//Additional options to be appended as options=<Options>
//The value will be taken as is. Multiple options are space separated.
Options string
}
func (c *Config) MatchName(name string) bool {
for _, key := range []string{"crdb", "cockroach"} {
if strings.TrimSpace(strings.ToLower(name)) == key {
return true
}
}
return false
}
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
Result: c,
})
if err != nil {
return nil, err
}
for _, config := range configs {
if err = decoder.Decode(config); err != nil {
return nil, err
}
}
return c, nil
}
func (c *Config) Connect(useAdmin bool) (*sql.DB, error) {
client, err := sql.Open("postgres", c.String(useAdmin))
if err != nil {
return nil, err
}
client.SetMaxOpenConns(int(c.MaxOpenConns))
client.SetConnMaxLifetime(c.MaxConnLifetime)
client.SetConnMaxIdleTime(c.MaxConnIdleTime)
return client, nil
}
func (c *Config) DatabaseName() string {
return c.Database
}
func (c *Config) Username() string {
return c.User.Username
}
func (c *Config) Password() string {
return c.User.Password
}
func (c *Config) Type() string {
return "cockroach"
}
type User struct {
Username string
Password string
SSL SSL
}
type SSL struct {
// type of connection security
Mode string
// RootCert Path to the CA certificate
RootCert string
// Cert Path to the client certificate
Cert string
// Key Path to the client private key
Key string
}
func (c *Config) checkSSL(user User) {
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
user.SSL = SSL{Mode: sslDisabledMode}
return
}
if user.SSL.RootCert == "" {
logging.WithFields(
"cert set", user.SSL.Cert != "",
"key set", user.SSL.Key != "",
"rootCert set", user.SSL.RootCert != "",
).Fatal("at least ssl root cert has to be set")
}
}
func (c Config) String(useAdmin bool) string {
user := c.User
if useAdmin {
user = c.Admin
}
c.checkSSL(user)
fields := []string{
"host=" + c.Host,
"port=" + strconv.Itoa(int(c.Port)),
"user=" + user.Username,
"dbname=" + c.Database,
"application_name=zitadel",
"sslmode=" + user.SSL.Mode,
}
if c.Options != "" {
fields = append(fields, "options="+c.Options)
}
if !useAdmin {
fields = append(fields, "dbname="+c.Database)
}
if user.Password != "" {
fields = append(fields, "password="+user.Password)
}
if user.SSL.Mode != sslDisabledMode {
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
if user.SSL.Cert != "" {
fields = append(fields, "sslcert="+user.SSL.Cert)
}
if user.SSL.Key != "" {
fields = append(fields, "sslkey="+user.SSL.Key)
}
}
return strings.Join(fields, " ")
}

View File

@ -0,0 +1,10 @@
package cockroach
import (
"github.com/zitadel/zitadel/internal/database/dialect"
)
func init() {
config := &Config{}
dialect.Register(config, config, true)
}

View File

@ -1,86 +0,0 @@
package database
import (
"strings"
"time"
"github.com/zitadel/logging"
)
const (
sslDisabledMode = "disable"
)
type Config struct {
Host string
Port string
Database string
MaxOpenConns uint32
MaxConnLifetime time.Duration
MaxConnIdleTime time.Duration
User
//Additional options to be appended as options=<Options>
//The value will be taken as is. Multiple options are space separated.
Options string
}
type User struct {
Username string
Password string
SSL SSL
}
type SSL struct {
// type of connection security
Mode string
// RootCert Path to the CA certificate
RootCert string
// Cert Path to the client certificate
Cert string
// Key Path to the client private key
Key string
}
func (s *Config) checkSSL() {
if s.SSL.Mode == sslDisabledMode || s.SSL.Mode == "" {
s.SSL = SSL{Mode: sslDisabledMode}
return
}
if s.SSL.RootCert == "" {
logging.WithFields(
"cert set", s.SSL.Cert != "",
"key set", s.SSL.Key != "",
"rootCert set", s.SSL.RootCert != "",
).Fatal("at least ssl root cert has to be set")
}
}
func (c Config) String() string {
c.checkSSL()
fields := []string{
"host=" + c.Host,
"port=" + c.Port,
"user=" + c.Username,
"dbname=" + c.Database,
"application_name=zitadel",
"sslmode=" + c.SSL.Mode,
}
if c.Options != "" {
fields = append(fields, "options="+c.Options)
}
if c.Password != "" {
fields = append(fields, "password="+c.Password)
}
if c.SSL.Mode != sslDisabledMode {
fields = append(fields, "sslrootcert="+c.SSL.RootCert)
if c.SSL.Cert != "" {
fields = append(fields, "sslcert="+c.SSL.Cert)
}
if c.SSL.Key != "" {
fields = append(fields, "sslkey="+c.SSL.Key)
}
}
return strings.Join(fields, " ")
}

View File

@ -2,25 +2,92 @@ package database
import ( import (
"database/sql" "database/sql"
//sql import "reflect"
_ "github.com/lib/pq"
"github.com/zitadel/zitadel/internal/errors" _ "github.com/zitadel/zitadel/internal/database/cockroach"
"github.com/zitadel/zitadel/internal/database/dialect"
) )
func Connect(config Config) (*sql.DB, error) { type Config struct {
client, err := sql.Open("postgres", config.String()) Dialects map[string]interface{} `mapstructure:",remain"`
connector dialect.Connector
}
func (c *Config) SetConnector(connector dialect.Connector) {
c.connector = connector
}
type User struct {
Username string
Password string
SSL SSL
}
type SSL struct {
// type of connection security
Mode string
// RootCert Path to the CA certificate
RootCert string
// Cert Path to the client certificate
Cert string
// Key Path to the client private key
Key string
}
func Connect(config Config, useAdmin bool) (*sql.DB, error) {
client, err := config.connector.Connect(useAdmin)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.SetMaxOpenConns(int(config.MaxOpenConns))
client.SetConnMaxLifetime(config.MaxConnLifetime)
client.SetConnMaxIdleTime(config.MaxConnIdleTime)
if err := client.Ping(); err != nil { if err := client.Ping(); err != nil {
return nil, errors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed") return nil, err
} }
return client, nil return client, nil
} }
func DecodeHook(from, to reflect.Value) (interface{}, error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
configuredDialects, ok := from.Interface().(map[string]interface{})
if !ok {
return from.Interface(), nil
}
configuredDialect := dialect.SelectByConfig(configuredDialects)
configs := make([]interface{}, 0, len(configuredDialects)-1)
for name, dialectConfig := range configuredDialects {
if !configuredDialect.Matcher.MatchName(name) {
continue
}
configs = append(configs, dialectConfig)
}
connector, err := configuredDialect.Matcher.Decode(configs)
if err != nil {
return nil, err
}
return Config{connector: connector}, nil
}
func (c Config) Database() string {
return c.connector.DatabaseName()
}
func (c Config) Username() string {
return c.connector.Username()
}
func (c Config) Password() string {
return c.connector.Password()
}
func (c Config) Type() string {
return c.connector.Type()
}

View File

@ -0,0 +1,62 @@
package dialect
import (
"database/sql"
"sync"
)
type Config struct {
Dialects map[string]interface{} `mapstructure:",remain"`
Dialect Matcher
}
type Dialect struct {
Matcher Matcher
Config Connector
IsDefault bool
}
var (
dialects []*Dialect
defaultDialect *Dialect
dialectsMu sync.Mutex
)
type Matcher interface {
MatchName(string) bool
Decode([]interface{}) (Connector, error)
}
type Connector interface {
Connect(useAdmin bool) (*sql.DB, error)
DatabaseName() string
Username() string
Password() string
Type() string
}
func Register(matcher Matcher, config Connector, isDefault bool) {
dialectsMu.Lock()
defer dialectsMu.Unlock()
d := &Dialect{Matcher: matcher, Config: config}
if isDefault {
defaultDialect = d
return
}
dialects = append(dialects, d)
}
func SelectByConfig(config map[string]interface{}) *Dialect {
for key := range config {
for _, d := range dialects {
if d.Matcher.MatchName(key) {
return d
}
}
}
return defaultDialect
}