From f610d48569916469767e637b0f6c5d9e6e6afabd Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Thu, 28 Jul 2022 16:25:42 +0200 Subject: [PATCH] feat: prepare for multiple database types (#4068) BREAKING CHANGE: the database and admin user config has changed. --- cmd/defaults.yaml | 48 ++++---- cmd/initialise/config.go | 29 +---- cmd/initialise/init.go | 12 +- cmd/initialise/verify_database.go | 2 +- cmd/initialise/verify_grant.go | 2 +- cmd/initialise/verify_user.go | 2 +- cmd/initialise/verify_zitadel.go | 2 +- cmd/key/key.go | 2 +- cmd/setup/config.go | 1 + cmd/setup/setup.go | 2 +- cmd/start/config.go | 1 + cmd/start/start.go | 6 +- go.mod | 2 +- go.sum | 4 +- internal/database/cockroach/config.go | 155 ++++++++++++++++++++++++++ internal/database/cockroach/crdb.go | 10 ++ internal/database/config.go | 86 -------------- internal/database/database.go | 87 +++++++++++++-- internal/database/dialect/config.go | 62 +++++++++++ 19 files changed, 354 insertions(+), 161 deletions(-) create mode 100644 internal/database/cockroach/config.go create mode 100644 internal/database/cockroach/crdb.go delete mode 100644 internal/database/config.go create mode 100644 internal/database/dialect/config.go diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index a05d1695d1..4c7872b5e4 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -45,30 +45,30 @@ HTTP1HostHeader: "host" WebAuthNName: ZITADEL Database: - Host: localhost - Port: 26257 - Database: zitadel - MaxOpenConns: 20 - MaxConnLifetime: 30m - MaxConnIdleTime: 30m - Options: "" - User: - Username: zitadel - Password: "" - SSL: - Mode: disable - RootCert: "" - Cert: "" - Key: "" - -AdminUser: - Username: root - Password: "" - SSL: - Mode: disable - RootCert: "" - Cert: "" - Key: "" + cockroach: + Host: localhost + Port: 26257 + Database: zitadel + MaxOpenConns: 20 + MaxConnLifetime: 30m + MaxConnIdleTime: 30m + Options: "" + User: + Username: zitadel + Password: "" + SSL: + Mode: disable + RootCert: "" + Cert: "" + Key: "" + Admin: + Username: root + Password: "" + SSL: + Mode: disable + RootCert: "" + Cert: "" + Key: "" Machine: # Cloud hosted VMs need to specify their metadata endpoint so that the machine can be uniquely identified. diff --git a/cmd/initialise/config.go b/cmd/initialise/config.go index dbbd7b5e53..848bece05b 100644 --- a/cmd/initialise/config.go +++ b/cmd/initialise/config.go @@ -9,15 +9,16 @@ import ( ) type Config struct { - Database database.Config - AdminUser database.User - Machine *id.Config - Log *logging.Config + Database database.Config + Machine *id.Config + Log *logging.Config } func MustNewConfig(v *viper.Viper) *Config { config := new(Config) - err := v.Unmarshal(config) + err := v.Unmarshal(config, + viper.DecodeHook(database.DecodeHook), + ) logging.OnError(err).Fatal("unable to read config") err = config.Log.SetLogger() @@ -25,21 +26,3 @@ func MustNewConfig(v *viper.Viper) *Config { return config } - -func adminConfig(config *Config) database.Config { - adminConfig := config.Database - adminConfig.Username = config.AdminUser.Username - adminConfig.Password = config.AdminUser.Password - adminConfig.SSL.Cert = config.AdminUser.SSL.Cert - adminConfig.SSL.Key = config.AdminUser.SSL.Key - if config.AdminUser.SSL.RootCert != "" { - adminConfig.SSL.RootCert = config.AdminUser.SSL.RootCert - } - if config.AdminUser.SSL.Mode != "" { - adminConfig.SSL.Mode = config.AdminUser.SSL.Mode - } - //use default database because the zitadel database might not exist - adminConfig.Database = "" - - return adminConfig -} diff --git a/cmd/initialise/init.go b/cmd/initialise/init.go index 04cf47e0bc..3f06ea571a 100644 --- a/cmd/initialise/init.go +++ b/cmd/initialise/init.go @@ -39,10 +39,10 @@ The user provided by flags needs privileges to func InitAll(config *Config) { id.Configure(config.Machine) - err := initialise(config, - VerifyUser(config.Database.Username, config.Database.Password), - VerifyDatabase(config.Database.Database), - VerifyGrant(config.Database.Database, config.Database.Username), + err := initialise(config.Database, + VerifyUser(config.Database.Username(), config.Database.Password()), + VerifyDatabase(config.Database.Database()), + VerifyGrant(config.Database.Database(), config.Database.Username()), ) logging.OnError(err).Fatal("unable to initialize the database") @@ -50,10 +50,10 @@ func InitAll(config *Config) { logging.OnError(err).Fatal("unable to initialize ZITADEL") } -func initialise(config *Config, steps ...func(*sql.DB) error) error { +func initialise(config database.Config, steps ...func(*sql.DB) error) error { logging.Info("initialization started") - db, err := database.Connect(adminConfig(config)) + db, err := database.Connect(config, true) if err != nil { return err } diff --git a/cmd/initialise/verify_database.go b/cmd/initialise/verify_database.go index 06fb1857a5..1269cdde8d 100644 --- a/cmd/initialise/verify_database.go +++ b/cmd/initialise/verify_database.go @@ -34,7 +34,7 @@ The user provided by flags needs priviledge to Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.New()) - err := initialise(config, VerifyDatabase(config.Database.Database)) + err := initialise(config.Database, VerifyDatabase(config.Database.Database())) logging.OnError(err).Fatal("unable to initialize the database") }, } diff --git a/cmd/initialise/verify_grant.go b/cmd/initialise/verify_grant.go index f9e1e1e248..0abbcae180 100644 --- a/cmd/initialise/verify_grant.go +++ b/cmd/initialise/verify_grant.go @@ -28,7 +28,7 @@ Prereqesits: Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.New()) - err := initialise(config, VerifyGrant(config.Database.Database, config.Database.Username)) + err := initialise(config.Database, VerifyGrant(config.Database.Database(), config.Database.Username())) logging.OnError(err).Fatal("unable to set grant") }, } diff --git a/cmd/initialise/verify_user.go b/cmd/initialise/verify_user.go index 386bab96b4..aa36c723c6 100644 --- a/cmd/initialise/verify_user.go +++ b/cmd/initialise/verify_user.go @@ -33,7 +33,7 @@ The user provided by flags needs priviledge to Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.New()) - err := initialise(config, VerifyUser(config.Database.Username, config.Database.Password)) + err := initialise(config.Database, VerifyUser(config.Database.Username(), config.Database.Password())) logging.OnError(err).Fatal("unable to init user") }, } diff --git a/cmd/initialise/verify_zitadel.go b/cmd/initialise/verify_zitadel.go index 82a61409ec..c1a3c9386a 100644 --- a/cmd/initialise/verify_zitadel.go +++ b/cmd/initialise/verify_zitadel.go @@ -95,7 +95,7 @@ func VerifyZitadel(db *sql.DB) error { func verifyZitadel(config database.Config) error { logging.WithFields("database", config.Database).Info("verify zitadel") - db, err := database.Connect(config) + db, err := database.Connect(config, false) if err != nil { return err } diff --git a/cmd/key/key.go b/cmd/key/key.go index 784f7a1c0b..02fa272a8a 100644 --- a/cmd/key/key.go +++ b/cmd/key/key.go @@ -124,7 +124,7 @@ func openFile(fileName string) (io.Reader, error) { } func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, error) { - db, err := database.Connect(config) + db, err := database.Connect(config, false) if err != nil { return nil, err } diff --git a/cmd/setup/config.go b/cmd/setup/config.go index 5036aabb42..06093002bb 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -36,6 +36,7 @@ func MustNewConfig(v *viper.Viper) *Config { hook.TagToLanguageHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), + database.DecodeHook, )), ) logging.OnError(err).Fatal("unable to read default config") diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 935a241326..8e7b81ba8c 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -56,7 +56,7 @@ func Flags(cmd *cobra.Command) { func Setup(config *Config, steps *Steps, masterKey string) { logging.Info("setup started") - dbClient, err := database.Connect(config.Database) + dbClient, err := database.Connect(config.Database, false) logging.OnError(err).Fatal("unable to connect to database") eventstoreClient, err := eventstore.Start(dbClient) diff --git a/cmd/start/config.go b/cmd/start/config.go index 045ebf16bd..66f97b06c7 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -65,6 +65,7 @@ func MustNewConfig(v *viper.Viper) *Config { hook.TagToLanguageHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), + database.DecodeHook, )), ) logging.OnError(err).Fatal("unable to read config") diff --git a/cmd/start/start.go b/cmd/start/start.go index 92668f4d24..9d152aae8a 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -81,7 +81,7 @@ Requirements: func startZitadel(config *Config, masterKey string) error { ctx := context.Background() - dbClient, err := database.Connect(config.Database) + dbClient, err := database.Connect(config.Database, false) if err != nil { return fmt.Errorf("cannot start client for projection: %w", err) } @@ -175,10 +175,10 @@ func startAPIs(ctx context.Context, router *mux.Router, commands *command.Comman if err != nil { return fmt.Errorf("error starting admin repo: %w", err) } - if err := apis.RegisterServer(ctx, system.CreateServer(commands, queries, adminRepo, config.Database.Database, config.DefaultInstance)); err != nil { + if err := apis.RegisterServer(ctx, system.CreateServer(commands, queries, adminRepo, config.Database.Database(), config.DefaultInstance)); err != nil { return err } - if err := apis.RegisterServer(ctx, admin.CreateServer(config.Database.Database, commands, queries, config.SystemDefaults, adminRepo, config.ExternalSecure, keys.User)); err != nil { + if err := apis.RegisterServer(ctx, admin.CreateServer(config.Database.Database(), commands, queries, config.SystemDefaults, adminRepo, config.ExternalSecure, keys.User)); err != nil { return err } if err := apis.RegisterServer(ctx, management.CreateServer(commands, queries, config.SystemDefaults, keys.User, config.ExternalSecure, config.AuditLogRetention)); err != nil { diff --git a/go.mod b/go.mod index f56d089b5a..db123eb4f1 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( go.opentelemetry.io/otel/sdk/export/metric v0.25.0 go.opentelemetry.io/otel/sdk/metric v0.25.0 go.opentelemetry.io/otel/trace v1.2.0 - golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 + golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/net v0.0.0-20220121210141-e204ce36a2ba golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum index 4110d1ee0a..fc2c7bee5c 100644 --- a/go.sum +++ b/go.sum @@ -958,8 +958,8 @@ golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 h1:/pEO3GD/ABYAjuakUS6xSEmmlyVS4kxBNkeA9tLJiTI= -golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= diff --git a/internal/database/cockroach/config.go b/internal/database/cockroach/config.go new file mode 100644 index 0000000000..e31acc6015 --- /dev/null +++ b/internal/database/cockroach/config.go @@ -0,0 +1,155 @@ +package cockroach + +import ( + "database/sql" + "strconv" + "strings" + "time" + + //sql import + _ "github.com/lib/pq" + + "github.com/mitchellh/mapstructure" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/database/dialect" +) + +const ( + sslDisabledMode = "disable" +) + +type Config struct { + Host string + Port uint16 + Database string + MaxOpenConns uint32 + MaxConnLifetime time.Duration + MaxConnIdleTime time.Duration + User User + Admin User + + //Additional options to be appended as options= + //The value will be taken as is. Multiple options are space separated. + Options string +} + +func (c *Config) MatchName(name string) bool { + for _, key := range []string{"crdb", "cockroach"} { + if strings.TrimSpace(strings.ToLower(name)) == key { + return true + } + } + return false +} + +func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + Result: c, + }) + if err != nil { + return nil, err + } + + for _, config := range configs { + if err = decoder.Decode(config); err != nil { + return nil, err + } + } + return c, nil +} + +func (c *Config) Connect(useAdmin bool) (*sql.DB, error) { + client, err := sql.Open("postgres", c.String(useAdmin)) + if err != nil { + return nil, err + } + client.SetMaxOpenConns(int(c.MaxOpenConns)) + client.SetConnMaxLifetime(c.MaxConnLifetime) + client.SetConnMaxIdleTime(c.MaxConnIdleTime) + return client, nil +} + +func (c *Config) DatabaseName() string { + return c.Database +} + +func (c *Config) Username() string { + return c.User.Username +} + +func (c *Config) Password() string { + return c.User.Password +} + +func (c *Config) Type() string { + return "cockroach" +} + +type User struct { + Username string + Password string + SSL SSL +} + +type SSL struct { + // type of connection security + Mode string + // RootCert Path to the CA certificate + RootCert string + // Cert Path to the client certificate + Cert string + // Key Path to the client private key + Key string +} + +func (c *Config) checkSSL(user User) { + if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" { + user.SSL = SSL{Mode: sslDisabledMode} + return + } + if user.SSL.RootCert == "" { + logging.WithFields( + "cert set", user.SSL.Cert != "", + "key set", user.SSL.Key != "", + "rootCert set", user.SSL.RootCert != "", + ).Fatal("at least ssl root cert has to be set") + } +} + +func (c Config) String(useAdmin bool) string { + user := c.User + if useAdmin { + user = c.Admin + } + c.checkSSL(user) + fields := []string{ + "host=" + c.Host, + "port=" + strconv.Itoa(int(c.Port)), + "user=" + user.Username, + "dbname=" + c.Database, + "application_name=zitadel", + "sslmode=" + user.SSL.Mode, + } + if c.Options != "" { + fields = append(fields, "options="+c.Options) + } + if !useAdmin { + fields = append(fields, "dbname="+c.Database) + } + if user.Password != "" { + fields = append(fields, "password="+user.Password) + } + if user.SSL.Mode != sslDisabledMode { + fields = append(fields, "sslrootcert="+user.SSL.RootCert) + if user.SSL.Cert != "" { + fields = append(fields, "sslcert="+user.SSL.Cert) + } + if user.SSL.Key != "" { + fields = append(fields, "sslkey="+user.SSL.Key) + } + } + + return strings.Join(fields, " ") +} diff --git a/internal/database/cockroach/crdb.go b/internal/database/cockroach/crdb.go new file mode 100644 index 0000000000..da634c9c83 --- /dev/null +++ b/internal/database/cockroach/crdb.go @@ -0,0 +1,10 @@ +package cockroach + +import ( + "github.com/zitadel/zitadel/internal/database/dialect" +) + +func init() { + config := &Config{} + dialect.Register(config, config, true) +} diff --git a/internal/database/config.go b/internal/database/config.go deleted file mode 100644 index 56e84a0417..0000000000 --- a/internal/database/config.go +++ /dev/null @@ -1,86 +0,0 @@ -package database - -import ( - "strings" - "time" - - "github.com/zitadel/logging" -) - -const ( - sslDisabledMode = "disable" -) - -type Config struct { - Host string - Port string - Database string - MaxOpenConns uint32 - MaxConnLifetime time.Duration - MaxConnIdleTime time.Duration - User - - //Additional options to be appended as options= - //The value will be taken as is. Multiple options are space separated. - Options string -} - -type User struct { - Username string - Password string - SSL SSL -} - -type SSL struct { - // type of connection security - Mode string - // RootCert Path to the CA certificate - RootCert string - // Cert Path to the client certificate - Cert string - // Key Path to the client private key - Key string -} - -func (s *Config) checkSSL() { - if s.SSL.Mode == sslDisabledMode || s.SSL.Mode == "" { - s.SSL = SSL{Mode: sslDisabledMode} - return - } - if s.SSL.RootCert == "" { - logging.WithFields( - "cert set", s.SSL.Cert != "", - "key set", s.SSL.Key != "", - "rootCert set", s.SSL.RootCert != "", - ).Fatal("at least ssl root cert has to be set") - } -} - -func (c Config) String() string { - c.checkSSL() - fields := []string{ - "host=" + c.Host, - "port=" + c.Port, - "user=" + c.Username, - "dbname=" + c.Database, - "application_name=zitadel", - "sslmode=" + c.SSL.Mode, - } - if c.Options != "" { - fields = append(fields, "options="+c.Options) - } - if c.Password != "" { - fields = append(fields, "password="+c.Password) - } - if c.SSL.Mode != sslDisabledMode { - fields = append(fields, "sslrootcert="+c.SSL.RootCert) - if c.SSL.Cert != "" { - fields = append(fields, "sslcert="+c.SSL.Cert) - } - if c.SSL.Key != "" { - fields = append(fields, "sslkey="+c.SSL.Key) - } - } - - return strings.Join(fields, " ") -} diff --git a/internal/database/database.go b/internal/database/database.go index fbb40a2449..3f2315b926 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -2,25 +2,92 @@ package database import ( "database/sql" - //sql import - _ "github.com/lib/pq" + "reflect" - "github.com/zitadel/zitadel/internal/errors" + _ "github.com/zitadel/zitadel/internal/database/cockroach" + "github.com/zitadel/zitadel/internal/database/dialect" ) -func Connect(config Config) (*sql.DB, error) { - client, err := sql.Open("postgres", config.String()) +type Config struct { + Dialects map[string]interface{} `mapstructure:",remain"` + connector dialect.Connector +} + +func (c *Config) SetConnector(connector dialect.Connector) { + c.connector = connector +} + +type User struct { + Username string + Password string + SSL SSL +} + +type SSL struct { + // type of connection security + Mode string + // RootCert Path to the CA certificate + RootCert string + // Cert Path to the client certificate + Cert string + // Key Path to the client private key + Key string +} + +func Connect(config Config, useAdmin bool) (*sql.DB, error) { + client, err := config.connector.Connect(useAdmin) if err != nil { return nil, err } - client.SetMaxOpenConns(int(config.MaxOpenConns)) - client.SetConnMaxLifetime(config.MaxConnLifetime) - client.SetConnMaxIdleTime(config.MaxConnIdleTime) - if err := client.Ping(); err != nil { - return nil, errors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed") + return nil, err } return client, nil } + +func DecodeHook(from, to reflect.Value) (interface{}, error) { + if to.Type() != reflect.TypeOf(Config{}) { + return from.Interface(), nil + } + + configuredDialects, ok := from.Interface().(map[string]interface{}) + if !ok { + return from.Interface(), nil + } + + configuredDialect := dialect.SelectByConfig(configuredDialects) + configs := make([]interface{}, 0, len(configuredDialects)-1) + + for name, dialectConfig := range configuredDialects { + if !configuredDialect.Matcher.MatchName(name) { + continue + } + + configs = append(configs, dialectConfig) + } + + connector, err := configuredDialect.Matcher.Decode(configs) + if err != nil { + return nil, err + } + + return Config{connector: connector}, nil +} + +func (c Config) Database() string { + return c.connector.DatabaseName() +} + +func (c Config) Username() string { + return c.connector.Username() +} + +func (c Config) Password() string { + return c.connector.Password() +} + +func (c Config) Type() string { + return c.connector.Type() +} diff --git a/internal/database/dialect/config.go b/internal/database/dialect/config.go new file mode 100644 index 0000000000..b043c9f370 --- /dev/null +++ b/internal/database/dialect/config.go @@ -0,0 +1,62 @@ +package dialect + +import ( + "database/sql" + "sync" +) + +type Config struct { + Dialects map[string]interface{} `mapstructure:",remain"` + Dialect Matcher +} + +type Dialect struct { + Matcher Matcher + Config Connector + IsDefault bool +} + +var ( + dialects []*Dialect + defaultDialect *Dialect + dialectsMu sync.Mutex +) + +type Matcher interface { + MatchName(string) bool + Decode([]interface{}) (Connector, error) +} + +type Connector interface { + Connect(useAdmin bool) (*sql.DB, error) + DatabaseName() string + Username() string + Password() string + Type() string +} + +func Register(matcher Matcher, config Connector, isDefault bool) { + dialectsMu.Lock() + defer dialectsMu.Unlock() + + d := &Dialect{Matcher: matcher, Config: config} + + if isDefault { + defaultDialect = d + return + } + + dialects = append(dialects, d) +} + +func SelectByConfig(config map[string]interface{}) *Dialect { + for key := range config { + for _, d := range dialects { + if d.Matcher.MatchName(key) { + return d + } + } + } + + return defaultDialect +}