2022-08-31 09:52:43 +02:00
|
|
|
package postgres
|
|
|
|
|
|
|
|
import (
|
2023-10-19 12:19:10 +02:00
|
|
|
"database/sql"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"time"
|
2022-08-31 09:52:43 +02:00
|
|
|
|
2024-03-27 14:48:22 +01:00
|
|
|
_ "github.com/jackc/pgx/v5/stdlib"
|
2023-10-19 12:19:10 +02:00
|
|
|
"github.com/mitchellh/mapstructure"
|
|
|
|
"github.com/zitadel/logging"
|
2022-08-31 09:52:43 +02:00
|
|
|
|
|
|
|
"github.com/zitadel/zitadel/internal/database/dialect"
|
|
|
|
)
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
config := &Config{}
|
|
|
|
dialect.Register(config, config, false)
|
|
|
|
}
|
2023-10-19 12:19:10 +02:00
|
|
|
|
|
|
|
// SSL modes that are accepted by checkSSL without any certificate paths.
// sslDisabledMode is also the mode for which String omits all certificate
// settings.
const (
	sslDisabledMode = "disable"
	sslAllowMode    = "allow"
	sslPreferMode   = "prefer"
	sslRequireMode  = "require"
)
|
|
|
|
|
|
|
|
type Config struct {
|
|
|
|
Host string
|
|
|
|
Port int32
|
|
|
|
Database string
|
|
|
|
EventPushConnRatio float64
|
|
|
|
MaxOpenConns uint32
|
|
|
|
MaxIdleConns uint32
|
|
|
|
MaxConnLifetime time.Duration
|
|
|
|
MaxConnIdleTime time.Duration
|
|
|
|
User User
|
|
|
|
Admin User
|
|
|
|
// Additional options to be appended as options=<Options>
|
|
|
|
// The value will be taken as is. Multiple options are space separated.
|
|
|
|
Options string
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) MatchName(name string) bool {
|
|
|
|
for _, key := range []string{"pg", "postgres"} {
|
|
|
|
if strings.TrimSpace(strings.ToLower(name)) == key {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
|
|
|
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
|
|
|
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
|
|
|
WeaklyTypedInput: true,
|
|
|
|
Result: c,
|
|
|
|
})
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, config := range configs {
|
|
|
|
if err = decoder.Decode(config); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return c, nil
|
|
|
|
}
|
|
|
|
|
2023-12-20 18:13:04 +02:00
|
|
|
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, error) {
|
|
|
|
client, err := sql.Open("pgx", c.String(useAdmin, purpose.AppName()))
|
2023-10-19 12:19:10 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2023-12-20 18:13:04 +02:00
|
|
|
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, spoolerRatio, pusherRatio, purpose)
|
2023-10-19 12:19:10 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2024-02-09 13:43:01 +01:00
|
|
|
client.SetMaxOpenConns(int(connConfig.MaxOpenConns))
|
2023-12-20 18:13:04 +02:00
|
|
|
client.SetMaxIdleConns(int(connConfig.MaxIdleConns))
|
|
|
|
client.SetConnMaxLifetime(c.MaxConnLifetime)
|
|
|
|
client.SetConnMaxIdleTime(c.MaxConnIdleTime)
|
2023-10-19 12:19:10 +02:00
|
|
|
|
2023-12-20 18:13:04 +02:00
|
|
|
return client, nil
|
2023-10-19 12:19:10 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) DatabaseName() string {
|
|
|
|
return c.Database
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) Username() string {
|
|
|
|
return c.User.Username
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) Password() string {
|
|
|
|
return c.User.Password
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) Type() string {
|
|
|
|
return "postgres"
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Config) Timetravel(time.Duration) string {
|
|
|
|
return ""
|
|
|
|
}
|
|
|
|
|
|
|
|
type User struct {
|
|
|
|
Username string
|
|
|
|
Password string
|
|
|
|
SSL SSL
|
|
|
|
}
|
|
|
|
|
|
|
|
// SSL configures how a connection is secured.
type SSL struct {
	// Mode is the type of connection security, emitted as sslmode.
	Mode string
	// RootCert, Cert and Key are paths to the CA certificate, the client
	// certificate and the client private key, respectively.
	RootCert, Cert, Key string
}
|
|
|
|
|
|
|
|
func (s *Config) checkSSL(user User) {
|
|
|
|
if user.SSL.Mode == sslDisabledMode || user.SSL.Mode == "" {
|
|
|
|
user.SSL = SSL{Mode: sslDisabledMode}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if user.SSL.Mode == sslRequireMode || user.SSL.Mode == sslAllowMode || user.SSL.Mode == sslPreferMode {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
if user.SSL.RootCert == "" {
|
|
|
|
logging.WithFields(
|
|
|
|
"cert set", user.SSL.Cert != "",
|
|
|
|
"key set", user.SSL.Key != "",
|
|
|
|
"rootCert set", user.SSL.RootCert != "",
|
|
|
|
).Fatal("at least ssl root cert has to be set")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c Config) String(useAdmin bool, appName string) string {
|
|
|
|
user := c.User
|
|
|
|
if useAdmin {
|
|
|
|
user = c.Admin
|
|
|
|
}
|
|
|
|
c.checkSSL(user)
|
|
|
|
fields := []string{
|
|
|
|
"host=" + c.Host,
|
|
|
|
"port=" + strconv.Itoa(int(c.Port)),
|
|
|
|
"user=" + user.Username,
|
|
|
|
"application_name=" + appName,
|
|
|
|
"sslmode=" + user.SSL.Mode,
|
|
|
|
}
|
|
|
|
if c.Options != "" {
|
|
|
|
fields = append(fields, "options="+c.Options)
|
|
|
|
}
|
|
|
|
if user.Password != "" {
|
|
|
|
fields = append(fields, "password="+user.Password)
|
|
|
|
}
|
|
|
|
if !useAdmin {
|
|
|
|
fields = append(fields, "dbname="+c.Database)
|
|
|
|
} else {
|
|
|
|
fields = append(fields, "dbname=postgres")
|
|
|
|
}
|
|
|
|
if user.SSL.Mode != sslDisabledMode {
|
|
|
|
if user.SSL.RootCert != "" {
|
|
|
|
fields = append(fields, "sslrootcert="+user.SSL.RootCert)
|
|
|
|
}
|
|
|
|
if user.SSL.Cert != "" {
|
|
|
|
fields = append(fields, "sslcert="+user.SSL.Cert)
|
|
|
|
}
|
|
|
|
if user.SSL.Key != "" {
|
|
|
|
fields = append(fields, "sslkey="+user.SSL.Key)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return strings.Join(fields, " ")
|
|
|
|
}
|