mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 18:17:35 +00:00
feat: prepare for multiple database types (#4068)
BREAKING CHANGE: the database and admin user config has changed.
This commit is contained in:
155
internal/database/cockroach/config.go
Normal file
155
internal/database/cockroach/config.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package cockroach
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
//sql import
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin User
|
||||
|
||||
//Additional options to be appended as options=<Options>
|
||||
//The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"crdb", "cockroach"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
Result: c,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool) (*sql.DB, error) {
|
||||
client, err := sql.Open("postgres", c.String(useAdmin))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.SetMaxOpenConns(int(c.MaxOpenConns))
|
||||
client.SetConnMaxLifetime(c.MaxConnLifetime)
|
||||
client.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() string {
|
||||
return "cockroach"
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (c *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"dbname=" + c.Database,
|
||||
"application_name=zitadel",
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
10
internal/database/cockroach/crdb.go
Normal file
10
internal/database/cockroach/crdb.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package cockroach
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &Config{}
|
||||
dialect.Register(config, config, true)
|
||||
}
|
@@ -1,86 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User
|
||||
|
||||
//Additional options to be appended as options=<Options>
|
||||
//The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (s *Config) checkSSL() {
|
||||
if s.SSL.Mode == sslDisabledMode || s.SSL.Mode == "" {
|
||||
s.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
if s.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", s.SSL.Cert != "",
|
||||
"key set", s.SSL.Key != "",
|
||||
"rootCert set", s.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String() string {
|
||||
c.checkSSL()
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + c.Port,
|
||||
"user=" + c.Username,
|
||||
"dbname=" + c.Database,
|
||||
"application_name=zitadel",
|
||||
"sslmode=" + c.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if c.Password != "" {
|
||||
fields = append(fields, "password="+c.Password)
|
||||
}
|
||||
if c.SSL.Mode != sslDisabledMode {
|
||||
fields = append(fields, "sslrootcert="+c.SSL.RootCert)
|
||||
if c.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+c.SSL.Cert)
|
||||
}
|
||||
if c.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+c.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
@@ -2,25 +2,92 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
//sql import
|
||||
_ "github.com/lib/pq"
|
||||
"reflect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
_ "github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
func Connect(config Config) (*sql.DB, error) {
|
||||
client, err := sql.Open("postgres", config.String())
|
||||
type Config struct {
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
connector dialect.Connector
|
||||
}
|
||||
|
||||
func (c *Config) SetConnector(connector dialect.Connector) {
|
||||
c.connector = connector
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func Connect(config Config, useAdmin bool) (*sql.DB, error) {
|
||||
client, err := config.connector.Connect(useAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.SetMaxOpenConns(int(config.MaxOpenConns))
|
||||
client.SetConnMaxLifetime(config.MaxConnLifetime)
|
||||
client.SetConnMaxIdleTime(config.MaxConnIdleTime)
|
||||
|
||||
if err := client.Ping(); err != nil {
|
||||
return nil, errors.ThrowPreconditionFailed(err, "DATAB-0pIWD", "Errors.Database.Connection.Failed")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func DecodeHook(from, to reflect.Value) (interface{}, error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
configuredDialects, ok := from.Interface().(map[string]interface{})
|
||||
if !ok {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
configuredDialect := dialect.SelectByConfig(configuredDialects)
|
||||
configs := make([]interface{}, 0, len(configuredDialects)-1)
|
||||
|
||||
for name, dialectConfig := range configuredDialects {
|
||||
if !configuredDialect.Matcher.MatchName(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
configs = append(configs, dialectConfig)
|
||||
}
|
||||
|
||||
connector, err := configuredDialect.Matcher.Decode(configs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Config{connector: connector}, nil
|
||||
}
|
||||
|
||||
func (c Config) Database() string {
|
||||
return c.connector.DatabaseName()
|
||||
}
|
||||
|
||||
func (c Config) Username() string {
|
||||
return c.connector.Username()
|
||||
}
|
||||
|
||||
func (c Config) Password() string {
|
||||
return c.connector.Password()
|
||||
}
|
||||
|
||||
func (c Config) Type() string {
|
||||
return c.connector.Type()
|
||||
}
|
||||
|
62
internal/database/dialect/config.go
Normal file
62
internal/database/dialect/config.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
Dialect Matcher
|
||||
}
|
||||
|
||||
type Dialect struct {
|
||||
Matcher Matcher
|
||||
Config Connector
|
||||
IsDefault bool
|
||||
}
|
||||
|
||||
var (
|
||||
dialects []*Dialect
|
||||
defaultDialect *Dialect
|
||||
dialectsMu sync.Mutex
|
||||
)
|
||||
|
||||
type Matcher interface {
|
||||
MatchName(string) bool
|
||||
Decode([]interface{}) (Connector, error)
|
||||
}
|
||||
|
||||
type Connector interface {
|
||||
Connect(useAdmin bool) (*sql.DB, error)
|
||||
DatabaseName() string
|
||||
Username() string
|
||||
Password() string
|
||||
Type() string
|
||||
}
|
||||
|
||||
func Register(matcher Matcher, config Connector, isDefault bool) {
|
||||
dialectsMu.Lock()
|
||||
defer dialectsMu.Unlock()
|
||||
|
||||
d := &Dialect{Matcher: matcher, Config: config}
|
||||
|
||||
if isDefault {
|
||||
defaultDialect = d
|
||||
return
|
||||
}
|
||||
|
||||
dialects = append(dialects, d)
|
||||
}
|
||||
|
||||
func SelectByConfig(config map[string]interface{}) *Dialect {
|
||||
for key := range config {
|
||||
for _, d := range dialects {
|
||||
if d.Matcher.MatchName(key) {
|
||||
return d
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return defaultDialect
|
||||
}
|
Reference in New Issue
Block a user