mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:27:42 +00:00
feat(eventstore): increase parallel write capabilities (#5940)
This implementation increases parallel write capabilities of the eventstore. Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and [06](https://zitadel.com/docs/support/advisory/a10006). The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`. If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
@@ -1,169 +0,0 @@
|
||||
package cockroach
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin User
|
||||
|
||||
//Additional options to be appended as options=<Options>
|
||||
//The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"crdb", "cockroach"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: c,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool) (*sql.DB, error) {
|
||||
client, err := sql.Open("pgx", c.String(useAdmin))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.SetMaxOpenConns(int(c.MaxOpenConns))
|
||||
client.SetMaxIdleConns(int(c.MaxIdleConns))
|
||||
client.SetConnMaxLifetime(c.MaxConnLifetime)
|
||||
client.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() string {
|
||||
return "cockroach"
|
||||
}
|
||||
|
||||
func (c *Config) Timetravel(d time.Duration) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (c *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"dbname=" + c.Database,
|
||||
"application_name=zitadel",
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
@@ -1,6 +1,15 @@
|
||||
package cockroach
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
@@ -8,3 +17,173 @@ func init() {
|
||||
config := &Config{}
|
||||
dialect.Register(config, config, true)
|
||||
}
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port uint16
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin User
|
||||
// Additional options to be appended as options=<Options>
|
||||
// The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"crdb", "cockroach"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: c,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin, isEventPusher bool, pusherRatio float32, appName string) (*sql.DB, error) {
|
||||
client, err := sql.Open("pgx", c.String(useAdmin, appName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connInfo, err := dialect.NewConnectionInfo(c.MaxOpenConns, c.MaxIdleConns, float64(pusherRatio))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var maxConns, maxIdleConns uint32
|
||||
if isEventPusher {
|
||||
maxConns = connInfo.EventstorePusher.MaxOpenConns
|
||||
maxIdleConns = connInfo.EventstorePusher.MaxIdleConns
|
||||
} else {
|
||||
maxConns = connInfo.ZITADEL.MaxOpenConns
|
||||
maxIdleConns = connInfo.ZITADEL.MaxIdleConns
|
||||
}
|
||||
|
||||
client.SetMaxOpenConns(int(maxConns))
|
||||
client.SetMaxIdleConns(int(maxIdleConns))
|
||||
client.SetConnMaxLifetime(c.MaxConnLifetime)
|
||||
client.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() string {
|
||||
return "cockroach"
|
||||
}
|
||||
|
||||
func (c *Config) Timetravel(d time.Duration) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (c *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool, appName string) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"dbname=" + c.Database,
|
||||
"application_name=" + appName,
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
_ "github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
@@ -14,8 +15,9 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
connector dialect.Connector
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
EventPushConnRatio float32
|
||||
connector dialect.Connector
|
||||
}
|
||||
|
||||
func (c *Config) SetConnector(connector dialect.Connector) {
|
||||
@@ -87,8 +89,18 @@ func (db *DB) QueryRowContext(ctx context.Context, scan func(row *sql.Row) error
|
||||
return row.Err()
|
||||
}
|
||||
|
||||
func Connect(config Config, useAdmin bool) (*DB, error) {
|
||||
client, err := config.connector.Connect(useAdmin)
|
||||
const (
|
||||
zitadelAppName = "zitadel"
|
||||
EventstorePusherAppName = "zitadel_es_pusher"
|
||||
)
|
||||
|
||||
func Connect(config Config, useAdmin, isEventPusher bool) (*DB, error) {
|
||||
appName := zitadelAppName
|
||||
if isEventPusher {
|
||||
appName = EventstorePusherAppName
|
||||
}
|
||||
|
||||
client, err := config.connector.Connect(useAdmin, isEventPusher, config.EventPushConnRatio, appName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -103,20 +115,20 @@ func Connect(config Config, useAdmin bool) (*DB, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DecodeHook(from, to reflect.Value) (interface{}, error) {
|
||||
func DecodeHook(from, to reflect.Value) (_ interface{}, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
configuredDialects, ok := from.Interface().(map[string]interface{})
|
||||
if !ok {
|
||||
return from.Interface(), nil
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configuredDialect := dialect.SelectByConfig(configuredDialects)
|
||||
configs := make([]interface{}, 0, len(configuredDialects)-1)
|
||||
configuredDialect := dialect.SelectByConfig(config.Dialects)
|
||||
configs := make([]interface{}, 0, len(config.Dialects)-1)
|
||||
|
||||
for name, dialectConfig := range configuredDialects {
|
||||
for name, dialectConfig := range config.Dialects {
|
||||
if !configuredDialect.Matcher.MatchName(name) {
|
||||
continue
|
||||
}
|
||||
@@ -124,12 +136,12 @@ func DecodeHook(from, to reflect.Value) (interface{}, error) {
|
||||
configs = append(configs, dialectConfig)
|
||||
}
|
||||
|
||||
connector, err := configuredDialect.Matcher.Decode(configs)
|
||||
config.connector, err = configuredDialect.Matcher.Decode(configs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Config{connector: connector}, nil
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c Config) DatabaseName() string {
|
||||
|
@@ -6,11 +6,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]interface{} `mapstructure:",remain"`
|
||||
Dialect Matcher
|
||||
}
|
||||
|
||||
type Dialect struct {
|
||||
Matcher Matcher
|
||||
Config Connector
|
||||
@@ -29,7 +24,7 @@ type Matcher interface {
|
||||
}
|
||||
|
||||
type Connector interface {
|
||||
Connect(useAdmin bool) (*sql.DB, error)
|
||||
Connect(useAdmin, isEventPusher bool, pusherRatio float32, appName string) (*sql.DB, error)
|
||||
Password() string
|
||||
Database
|
||||
}
|
||||
|
39
internal/database/dialect/connections.go
Normal file
39
internal/database/dialect/connections.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dialect
|
||||
|
||||
import "errors"
|
||||
|
||||
type ConnectionInfo struct {
|
||||
EventstorePusher ConnectionConfig
|
||||
ZITADEL ConnectionConfig
|
||||
}
|
||||
|
||||
type ConnectionConfig struct {
|
||||
MaxOpenConns,
|
||||
MaxIdleConns uint32
|
||||
}
|
||||
|
||||
func NewConnectionInfo(openConns, idleConns uint32, pusherRatio float64) (*ConnectionInfo, error) {
|
||||
if pusherRatio < 0 || pusherRatio > 1 {
|
||||
return nil, errors.New("EventPushConnRatio must be between 0 and 1")
|
||||
}
|
||||
if openConns < 2 {
|
||||
return nil, errors.New("MaxOpenConns of the database must be higher that 1")
|
||||
}
|
||||
|
||||
info := new(ConnectionInfo)
|
||||
|
||||
info.EventstorePusher.MaxOpenConns = uint32(pusherRatio * float64(openConns))
|
||||
info.EventstorePusher.MaxIdleConns = uint32(pusherRatio * float64(idleConns))
|
||||
|
||||
if info.EventstorePusher.MaxOpenConns < 1 && pusherRatio > 0 {
|
||||
info.EventstorePusher.MaxOpenConns = 1
|
||||
}
|
||||
if info.EventstorePusher.MaxIdleConns < 1 && pusherRatio > 0 {
|
||||
info.EventstorePusher.MaxIdleConns = 1
|
||||
}
|
||||
|
||||
info.ZITADEL.MaxOpenConns = openConns - info.EventstorePusher.MaxOpenConns
|
||||
info.ZITADEL.MaxIdleConns = idleConns - info.EventstorePusher.MaxIdleConns
|
||||
|
||||
return info, nil
|
||||
}
|
129
internal/database/mock/sql_mock.go
Normal file
129
internal/database/mock/sql_mock.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
type SQLMock struct {
|
||||
DB *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
type expectation func(m sqlmock.Sqlmock)
|
||||
|
||||
func NewSQLMock(t *testing.T, expectations ...expectation) *SQLMock {
|
||||
db, mock, err := sqlmock.New(
|
||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal("create mock failed", err)
|
||||
}
|
||||
|
||||
for _, expectation := range expectations {
|
||||
expectation(mock)
|
||||
}
|
||||
|
||||
return &SQLMock{
|
||||
DB: db,
|
||||
mock: mock,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SQLMock) Assert(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
if err := m.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("expectations not met: %v", err)
|
||||
}
|
||||
|
||||
m.DB.Close()
|
||||
}
|
||||
|
||||
func ExpectBegin(err error) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectBegin()
|
||||
if err != nil {
|
||||
e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type ExecOpt func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec
|
||||
|
||||
func WithExecArgs(args ...driver.Value) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecErr(err error) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecNoRowsAffected() ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
}
|
||||
|
||||
func WithExecRowsAffected(affected driver.RowsAffected) ExecOpt {
|
||||
return func(e *sqlmock.ExpectedExec) *sqlmock.ExpectedExec {
|
||||
return e.WillReturnResult(affected)
|
||||
}
|
||||
}
|
||||
|
||||
func ExcpectExec(stmt string, opts ...ExecOpt) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectExec(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOpt func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||
|
||||
func WithQueryArgs(args ...driver.Value) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryErr(err error) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
mockedRows := sqlmock.NewRows(columns)
|
||||
for _, row := range rows {
|
||||
mockedRows = mockedRows.AddRow(row...)
|
||||
}
|
||||
return e.WillReturnRows(mockedRows)
|
||||
}
|
||||
}
|
||||
|
||||
func ExpectQuery(stmt string, opts ...QueryOpt) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectQuery(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type AnyType[T interface{}] struct{}
|
||||
|
||||
// Match satisfies sqlmock.Argument interface
|
||||
func (a AnyType[T]) Match(v driver.Value) bool {
|
||||
return reflect.TypeOf(new(T)).Elem().Kind().String() == reflect.TypeOf(v).Kind().String()
|
||||
}
|
@@ -1,172 +0,0 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int32
|
||||
Database string
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin User
|
||||
|
||||
//Additional options to be appended as options=<Options>
|
||||
//The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"pg", "postgres"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: c,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool) (*sql.DB, error) {
|
||||
db, err := sql.Open("pgx", c.String(useAdmin))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(int(c.MaxOpenConns))
|
||||
db.SetMaxIdleConns(int(c.MaxIdleConns))
|
||||
db.SetConnMaxLifetime(c.MaxConnLifetime)
|
||||
db.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (c *Config) Timetravel(time.Duration) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (s *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"application_name=zitadel",
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
} else {
|
||||
fields = append(fields, "dbname=postgres")
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
if user.SSL.RootCert != "" {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
}
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
@@ -1,9 +1,14 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
//sql import
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
@@ -12,3 +17,176 @@ func init() {
|
||||
config := &Config{}
|
||||
dialect.Register(config, config, false)
|
||||
}
|
||||
|
||||
const (
|
||||
sslDisabledMode = "disable"
|
||||
sslRequireMode = "require"
|
||||
sslAllowMode = "allow"
|
||||
sslPreferMode = "prefer"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int32
|
||||
Database string
|
||||
EventPushConnRatio float64
|
||||
MaxOpenConns uint32
|
||||
MaxIdleConns uint32
|
||||
MaxConnLifetime time.Duration
|
||||
MaxConnIdleTime time.Duration
|
||||
User User
|
||||
Admin User
|
||||
// Additional options to be appended as options=<Options>
|
||||
// The value will be taken as is. Multiple options are space separated.
|
||||
Options string
|
||||
}
|
||||
|
||||
func (c *Config) MatchName(name string) bool {
|
||||
for _, key := range []string{"pg", "postgres"} {
|
||||
if strings.TrimSpace(strings.ToLower(name)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: c,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, config := range configs {
|
||||
if err = decoder.Decode(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin, isEventPusher bool, pusherRatio float32, appName string) (*sql.DB, error) {
|
||||
db, err := sql.Open("pgx", c.String(useAdmin, appName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connInfo, err := dialect.NewConnectionInfo(c.MaxOpenConns, c.MaxIdleConns, float64(pusherRatio))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var maxConns, maxIdleConns uint32
|
||||
if isEventPusher {
|
||||
maxConns = connInfo.EventstorePusher.MaxOpenConns
|
||||
maxIdleConns = connInfo.EventstorePusher.MaxIdleConns
|
||||
} else {
|
||||
maxConns = connInfo.ZITADEL.MaxOpenConns
|
||||
maxIdleConns = connInfo.ZITADEL.MaxIdleConns
|
||||
}
|
||||
db.SetMaxOpenConns(int(maxConns))
|
||||
db.SetMaxIdleConns(int(maxIdleConns))
|
||||
db.SetConnMaxLifetime(c.MaxConnLifetime)
|
||||
db.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func (c *Config) DatabaseName() string {
|
||||
return c.Database
|
||||
}
|
||||
|
||||
func (c *Config) Username() string {
|
||||
return c.User.Username
|
||||
}
|
||||
|
||||
func (c *Config) Password() string {
|
||||
return c.User.Password
|
||||
}
|
||||
|
||||
func (c *Config) Type() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
func (c *Config) Timetravel(time.Duration) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
SSL SSL
|
||||
}
|
||||
|
||||
type SSL struct {
|
||||
// type of connection security
|
||||
Mode string
|
||||
// RootCert Path to the CA certificate
|
||||
RootCert string
|
||||
// Cert Path to the client certificate
|
||||
Cert string
|
||||
// Key Path to the client private key
|
||||
Key string
|
||||
}
|
||||
|
||||
func (s *Config) checkSSL(user User) {
|
||||
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
||||
user.SSL = SSL{Mode: sslDisabledMode}
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
||||
return
|
||||
}
|
||||
|
||||
if user.SSL.RootCert == "" {
|
||||
logging.WithFields(
|
||||
"cert set", user.SSL.Cert != "",
|
||||
"key set", user.SSL.Key != "",
|
||||
"rootCert set", user.SSL.RootCert != "",
|
||||
).Fatal("at least ssl root cert has to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func (c Config) String(useAdmin bool, appName string) string {
|
||||
user := c.User
|
||||
if useAdmin {
|
||||
user = c.Admin
|
||||
}
|
||||
c.checkSSL(user)
|
||||
fields := []string{
|
||||
"host=" + c.Host,
|
||||
"port=" + strconv.Itoa(int(c.Port)),
|
||||
"user=" + user.Username,
|
||||
"application_name=" + appName,
|
||||
"sslmode=" + user.SSL.Mode,
|
||||
}
|
||||
if c.Options != "" {
|
||||
fields = append(fields, "options="+c.Options)
|
||||
}
|
||||
if user.Password != "" {
|
||||
fields = append(fields, "password="+user.Password)
|
||||
}
|
||||
if !useAdmin {
|
||||
fields = append(fields, "dbname="+c.Database)
|
||||
} else {
|
||||
fields = append(fields, "dbname=postgres")
|
||||
}
|
||||
if user.SSL.Mode != sslDisabledMode {
|
||||
if user.SSL.RootCert != "" {
|
||||
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
||||
}
|
||||
if user.SSL.Cert != "" {
|
||||
fields = append(fields, "sslcert="+user.SSL.Cert)
|
||||
}
|
||||
if user.SSL.Key != "" {
|
||||
fields = append(fields, "sslkey="+user.SSL.Key)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
@@ -8,22 +8,19 @@ import (
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
type StringArray []string
|
||||
type TextArray[t ~string] []t
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *StringArray) Scan(src any) error {
|
||||
func (s *TextArray[t]) Scan(src any) error {
|
||||
array := new(pgtype.TextArray)
|
||||
if err := array.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := array.AssignTo(s); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return array.AssignTo(s)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s StringArray) Value() (driver.Value, error) {
|
||||
func (s TextArray[t]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -36,37 +33,37 @@ func (s StringArray) Value() (driver.Value, error) {
|
||||
return array.Value()
|
||||
}
|
||||
|
||||
type enumField interface {
|
||||
type arrayField interface {
|
||||
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32
|
||||
}
|
||||
|
||||
type EnumArray[F enumField] []F
|
||||
type Array[F arrayField] []F
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *EnumArray[F]) Scan(src any) error {
|
||||
array := new(pgtype.Int2Array)
|
||||
func (a *Array[F]) Scan(src any) error {
|
||||
array := new(pgtype.Int8Array)
|
||||
if err := array.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
ints := make([]int32, 0, len(array.Elements))
|
||||
if err := array.AssignTo(&ints); err != nil {
|
||||
elements := make([]int64, len(array.Elements))
|
||||
if err := array.AssignTo(&elements); err != nil {
|
||||
return err
|
||||
}
|
||||
*s = make([]F, len(ints))
|
||||
for i, a := range ints {
|
||||
(*s)[i] = F(a)
|
||||
*a = make([]F, len(elements))
|
||||
for i, element := range elements {
|
||||
(*a)[i] = F(element)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s EnumArray[F]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
func (a Array[F]) Value() (driver.Value, error) {
|
||||
if len(a) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
array := pgtype.Int2Array{}
|
||||
if err := array.Set(s); err != nil {
|
||||
array := pgtype.Int8Array{}
|
||||
if err := array.Set(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@@ -117,3 +117,85 @@ func TestMap_Value(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArray_ScanInt32(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res[V arrayField] struct {
|
||||
want Array[V]
|
||||
err bool
|
||||
}
|
||||
type testCase[V arrayField] struct {
|
||||
name string
|
||||
m Array[V]
|
||||
args args
|
||||
res[V]
|
||||
}
|
||||
tests := []testCase[int32]{
|
||||
{
|
||||
"number",
|
||||
Array[int32]{},
|
||||
args{src: "{1,2}"},
|
||||
res[int32]{
|
||||
want: []int32{1, 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArray_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase[V arrayField] struct {
|
||||
name string
|
||||
a Array[V]
|
||||
res res
|
||||
}
|
||||
tests := []testCase[int32]{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
Array[int32]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
Array[int32]([]int32{1, 2}),
|
||||
res{
|
||||
want: driver.Value(string([]byte(`{1,2}`))),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.a.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
if !tt.res.err {
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.res.want, got, "Value()")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user