fix(init): flags (#3192)

This commit is contained in:
Silvan 2022-02-11 11:52:50 +01:00 committed by GitHub
parent e8ab237ada
commit b44b48fa1e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 18 deletions

View File

@ -2,7 +2,6 @@ package initialise
import ( import (
_ "embed" _ "embed"
"fmt"
"github.com/caos/logging" "github.com/caos/logging"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -13,16 +12,33 @@ import (
) )
var ( var (
conn string user string
password string
sslCert string
sslKey string
)
const (
userFlag = "user"
passwordFlag = "password"
sslCertFlag = "ssl-cert"
sslKeyFlag = "ssl-key"
) )
func New() *cobra.Command { func New() *cobra.Command {
cmd := &cobra.Command{ cmd := &cobra.Command{
Use: "init", Use: "init",
Short: "initialize ZITADEL instance", Short: "initialize ZITADEL instance",
Long: `init sets up the minimum requirements to start ZITADEL. Long: `Sets up the minimum requirements to start ZITADEL.
Prereqesits: Prereqesits:
- cockroachdb`, - cockroachdb
The user provided by flags needs priviledge to
- create the database if it does not exist
- see other users and create a new one if the user does not exist
- grant all rights of the ZITADEL database to the user created if not yet set
`,
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
config := new(Config) config := new(Config)
if err := viper.Unmarshal(config); err != nil { if err := viper.Unmarshal(config); err != nil {
@ -32,9 +48,11 @@ Prereqesits:
}, },
} }
// cmd.PersistentFlags().StringArrayVar(&configFiles, "config", nil, "path to config file to overwrite system defaults") cmd.PersistentFlags().StringVar(&password, passwordFlag, "", "password of the the provided user")
//TODO(hust): simplify to multiple flags cmd.PersistentFlags().StringVar(&sslCert, sslCertFlag, "", "ssl cert from the provided user")
cmd.PersistentFlags().StringVar(&conn, "connection", "", "connection string to connect with a user which is allowed to create the database and user") cmd.PersistentFlags().StringVar(&sslKey, sslKeyFlag, "", "ssl key from the provided user")
cmd.PersistentFlags().StringVar(&user, userFlag, "", "(required) the user to check if the database, user and grants exists and create if not")
cmd.MarkPersistentFlagRequired(userFlag)
return cmd return cmd
} }
@ -42,11 +60,7 @@ Prereqesits:
func initialise(config *Config) error { func initialise(config *Config) error {
logging.Info("initialization started") logging.Info("initialization started")
if conn == "" { if err := prepareDB(config.Database, user, password, sslCert, sslKey); err != nil {
return fmt.Errorf("connection not defined")
}
if err := prepareDB(config.Database); err != nil {
return err return err
} }

View File

@ -7,8 +7,14 @@ import (
"github.com/caos/zitadel/internal/database" "github.com/caos/zitadel/internal/database"
) )
func prepareDB(config database.Config) error { func prepareDB(config database.Config, user, password, sslCert, sslKey string) error {
db, err := sql.Open("postgres", conn) adminConfig := config
adminConfig.User = user
adminConfig.Password = password
adminConfig.SSL.Cert = sslCert
adminConfig.SSL.Key = sslKey
db, err := database.Connect(adminConfig)
if err != nil { if err != nil {
return err return err
} }

View File

@ -17,7 +17,7 @@ type Config struct {
User string User string
Password string Password string
Database string Database string
SSL *ssl SSL SSL
MaxOpenConns uint32 MaxOpenConns uint32
MaxConnLifetime types.Duration MaxConnLifetime types.Duration
MaxConnIdleTime types.Duration MaxConnIdleTime types.Duration
@ -27,7 +27,7 @@ type Config struct {
Options string Options string
} }
type ssl struct { type SSL struct {
// type of connection security // type of connection security
Mode string Mode string
// RootCert Path to the CA certificate // RootCert Path to the CA certificate
@ -39,8 +39,8 @@ type ssl struct {
} }
func (s *Config) checkSSL() { func (s *Config) checkSSL() {
if s.SSL == nil || s.SSL.Mode == sslDisabledMode || s.SSL.Mode == "" { if s.SSL.Mode == sslDisabledMode || s.SSL.Mode == "" {
s.SSL = &ssl{Mode: sslDisabledMode} s.SSL = SSL{Mode: sslDisabledMode}
return return
} }
if s.SSL.RootCert == "" { if s.SSL.RootCert == "" {