mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-27 12:05:26 +00:00
218138afee
expand user, add claims to user This commit expands the user table with additional fields that can be retrieved from OIDC providers (and other places) and uses this data in various tailscale response objects if it is available. This is the beginning of implementing https://docs.google.com/document/d/1X85PMxIaVWDF6T_UPji3OeeUqVBcGj_uHRM5CI-AwlY/edit trying to make OIDC more coherant and maintainable in addition to giving the user a better experience and integration with a provider. remove usernames in magic dns, normalisation of emails this commit removes the option to have usernames as part of MagicDNS domains and headscale will now align with Tailscale, where there is a root domain, and the machine name. In addition, the various normalisation functions for dns names has been made lighter not caring about username and special character that wont occur. Email are no longer normalised as part of the policy processing. untagle oidc and regcache, use typed cache This commits stops reusing the registration cache for oidc purposes and switches the cache to be types and not use any allowing the removal of a bunch of casting. try to make reauth/register branches clearer in oidc Currently there was a function that did a bunch of stuff, finding the machine key, trying to find the node, reauthing the node, returning some status, and it was called validate which was very confusing. This commit tries to split this into what to do if the node exists, if it needs to register etc. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
1024 lines
29 KiB
Go
1024 lines
29 KiB
Go
package types
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io/fs"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/juanfont/headscale/hscontrol/util"
|
|
"github.com/prometheus/common/model"
|
|
"github.com/rs/zerolog"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/spf13/viper"
|
|
"go4.org/netipx"
|
|
"tailscale.com/net/tsaddr"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/dnstype"
|
|
"tailscale.com/util/set"
|
|
)
|
|
|
|
const (
|
|
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
|
|
maxDuration time.Duration = 1<<63 - 1
|
|
)
|
|
|
|
var errOidcMutuallyExclusive = errors.New(
|
|
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
|
|
)
|
|
|
|
type IPAllocationStrategy string
|
|
|
|
const (
|
|
IPAllocationStrategySequential IPAllocationStrategy = "sequential"
|
|
IPAllocationStrategyRandom IPAllocationStrategy = "random"
|
|
)
|
|
|
|
type PolicyMode string
|
|
|
|
const (
|
|
PolicyModeDB = "database"
|
|
PolicyModeFile = "file"
|
|
)
|
|
|
|
// Config contains the initial Headscale configuration.
|
|
type Config struct {
|
|
ServerURL string
|
|
Addr string
|
|
MetricsAddr string
|
|
GRPCAddr string
|
|
GRPCAllowInsecure bool
|
|
EphemeralNodeInactivityTimeout time.Duration
|
|
PrefixV4 *netip.Prefix
|
|
PrefixV6 *netip.Prefix
|
|
IPAllocation IPAllocationStrategy
|
|
NoisePrivateKeyPath string
|
|
BaseDomain string
|
|
Log LogConfig
|
|
DisableUpdateCheck bool
|
|
|
|
Database DatabaseConfig
|
|
|
|
DERP DERPConfig
|
|
|
|
TLS TLSConfig
|
|
|
|
ACMEURL string
|
|
ACMEEmail string
|
|
|
|
DNSConfig *tailcfg.DNSConfig
|
|
|
|
UnixSocket string
|
|
UnixSocketPermission fs.FileMode
|
|
|
|
OIDC OIDCConfig
|
|
|
|
LogTail LogTailConfig
|
|
RandomizeClientPort bool
|
|
|
|
CLI CLIConfig
|
|
|
|
Policy PolicyConfig
|
|
|
|
Tuning Tuning
|
|
}
|
|
|
|
type DNSConfig struct {
|
|
MagicDNS bool `mapstructure:"magic_dns"`
|
|
BaseDomain string `mapstructure:"base_domain"`
|
|
Nameservers Nameservers
|
|
SearchDomains []string `mapstructure:"search_domains"`
|
|
ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"`
|
|
}
|
|
|
|
type Nameservers struct {
|
|
Global []string
|
|
Split map[string][]string
|
|
}
|
|
|
|
type SqliteConfig struct {
|
|
Path string
|
|
WriteAheadLog bool
|
|
}
|
|
|
|
type PostgresConfig struct {
|
|
Host string
|
|
Port int
|
|
Name string
|
|
User string
|
|
Pass string
|
|
Ssl string
|
|
MaxOpenConnections int
|
|
MaxIdleConnections int
|
|
ConnMaxIdleTimeSecs int
|
|
}
|
|
|
|
type GormConfig struct {
|
|
Debug bool
|
|
SlowThreshold time.Duration
|
|
SkipErrRecordNotFound bool
|
|
ParameterizedQueries bool
|
|
PrepareStmt bool
|
|
}
|
|
|
|
type DatabaseConfig struct {
|
|
// Type sets the database type, either "sqlite3" or "postgres"
|
|
Type string
|
|
Debug bool
|
|
|
|
// Type sets the gorm configuration
|
|
Gorm GormConfig
|
|
|
|
Sqlite SqliteConfig
|
|
Postgres PostgresConfig
|
|
}
|
|
|
|
type TLSConfig struct {
|
|
CertPath string
|
|
KeyPath string
|
|
|
|
LetsEncrypt LetsEncryptConfig
|
|
}
|
|
|
|
type LetsEncryptConfig struct {
|
|
Listen string
|
|
Hostname string
|
|
CacheDir string
|
|
ChallengeType string
|
|
}
|
|
|
|
type OIDCConfig struct {
|
|
OnlyStartIfOIDCIsAvailable bool
|
|
Issuer string
|
|
ClientID string
|
|
ClientSecret string
|
|
Scope []string
|
|
ExtraParams map[string]string
|
|
AllowedDomains []string
|
|
AllowedUsers []string
|
|
AllowedGroups []string
|
|
Expiry time.Duration
|
|
UseExpiryFromToken bool
|
|
}
|
|
|
|
type DERPConfig struct {
|
|
ServerEnabled bool
|
|
AutomaticallyAddEmbeddedDerpRegion bool
|
|
ServerRegionID int
|
|
ServerRegionCode string
|
|
ServerRegionName string
|
|
ServerPrivateKeyPath string
|
|
STUNAddr string
|
|
URLs []url.URL
|
|
Paths []string
|
|
AutoUpdate bool
|
|
UpdateFrequency time.Duration
|
|
IPv4 string
|
|
IPv6 string
|
|
}
|
|
|
|
type LogTailConfig struct {
|
|
Enabled bool
|
|
}
|
|
|
|
type CLIConfig struct {
|
|
Address string
|
|
APIKey string
|
|
Timeout time.Duration
|
|
Insecure bool
|
|
}
|
|
|
|
type PolicyConfig struct {
|
|
Path string
|
|
Mode PolicyMode
|
|
}
|
|
|
|
type LogConfig struct {
|
|
Format string
|
|
Level zerolog.Level
|
|
}
|
|
|
|
type Tuning struct {
|
|
NotifierSendTimeout time.Duration
|
|
BatchChangeDelay time.Duration
|
|
NodeMapSessionBufferedChanSize int
|
|
}
|
|
|
|
// LoadConfig prepares and loads the Headscale configuration into Viper.
|
|
// This means it sets the default values, reads the configuration file and
|
|
// environment variables, and handles deprecated configuration options.
|
|
// It has to be called before LoadServerConfig and LoadCLIConfig.
|
|
// The configuration is not validated and the caller should check for errors
|
|
// using a validation function.
|
|
func LoadConfig(path string, isFile bool) error {
|
|
if isFile {
|
|
viper.SetConfigFile(path)
|
|
} else {
|
|
viper.SetConfigName("config")
|
|
if path == "" {
|
|
viper.AddConfigPath("/etc/headscale/")
|
|
viper.AddConfigPath("$HOME/.headscale")
|
|
viper.AddConfigPath(".")
|
|
} else {
|
|
// For testing
|
|
viper.AddConfigPath(path)
|
|
}
|
|
}
|
|
|
|
envPrefix := "headscale"
|
|
viper.SetEnvPrefix(envPrefix)
|
|
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
|
viper.AutomaticEnv()
|
|
|
|
viper.SetDefault("policy.mode", "file")
|
|
|
|
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
|
|
viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType)
|
|
|
|
viper.SetDefault("log.level", "info")
|
|
viper.SetDefault("log.format", TextLogFormat)
|
|
|
|
viper.SetDefault("dns.magic_dns", true)
|
|
viper.SetDefault("dns.base_domain", "")
|
|
viper.SetDefault("dns.nameservers.global", []string{})
|
|
viper.SetDefault("dns.nameservers.split", map[string]string{})
|
|
viper.SetDefault("dns.search_domains", []string{})
|
|
viper.SetDefault("dns.extra_records", []tailcfg.DNSRecord{})
|
|
|
|
viper.SetDefault("derp.server.enabled", false)
|
|
viper.SetDefault("derp.server.stun.enabled", true)
|
|
viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true)
|
|
|
|
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
|
|
viper.SetDefault("unix_socket_permission", "0o770")
|
|
|
|
viper.SetDefault("grpc_listen_addr", ":50443")
|
|
viper.SetDefault("grpc_allow_insecure", false)
|
|
|
|
viper.SetDefault("cli.timeout", "5s")
|
|
viper.SetDefault("cli.insecure", false)
|
|
|
|
viper.SetDefault("database.postgres.ssl", false)
|
|
viper.SetDefault("database.postgres.max_open_conns", 10)
|
|
viper.SetDefault("database.postgres.max_idle_conns", 10)
|
|
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
|
|
|
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
|
|
|
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
|
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
|
viper.SetDefault("oidc.expiry", "180d")
|
|
viper.SetDefault("oidc.use_expiry_from_token", false)
|
|
|
|
viper.SetDefault("logtail.enabled", false)
|
|
viper.SetDefault("randomize_client_port", false)
|
|
|
|
viper.SetDefault("ephemeral_node_inactivity_timeout", "120s")
|
|
|
|
viper.SetDefault("tuning.notifier_send_timeout", "800ms")
|
|
viper.SetDefault("tuning.batch_change_delay", "800ms")
|
|
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
|
|
|
|
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
|
|
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
return fmt.Errorf("fatal error reading config file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateServerConfig() error {
|
|
depr := deprecator{
|
|
warns: make(set.Set[string]),
|
|
fatals: make(set.Set[string]),
|
|
}
|
|
|
|
// Register aliases for backward compatibility
|
|
// Has to be called _after_ viper.ReadInConfig()
|
|
// https://github.com/spf13/viper/issues/560
|
|
|
|
// Alias the old ACL Policy path with the new configuration option.
|
|
depr.fatalIfNewKeyIsNotUsed("policy.path", "acl_policy_path")
|
|
|
|
// Move dns_config -> dns
|
|
depr.warn("dns_config.override_local_dns")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.magic_dns", "dns_config.magic_dns")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.base_domain", "dns_config.base_domain")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.nameservers.global", "dns_config.nameservers")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.nameservers.split", "dns_config.restricted_nameservers")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.search_domains", "dns_config.domains")
|
|
depr.fatalIfNewKeyIsNotUsed("dns.extra_records", "dns_config.extra_records")
|
|
depr.warn("dns_config.use_username_in_magic_dns")
|
|
depr.warn("dns.use_username_in_magic_dns")
|
|
|
|
depr.fatal("oidc.strip_email_domain")
|
|
depr.fatal("dns.use_username_in_musername_in_magic_dns")
|
|
depr.fatal("dns_config.use_username_in_musername_in_magic_dns")
|
|
|
|
depr.Log()
|
|
|
|
for _, removed := range []string{
|
|
"oidc.strip_email_domain",
|
|
"dns_config.use_username_in_musername_in_magic_dns",
|
|
} {
|
|
if viper.IsSet(removed) {
|
|
log.Fatal().
|
|
Msgf("Fatal config error: %s has been removed. Please remove it from your config file", removed)
|
|
}
|
|
}
|
|
|
|
// Collect any validation errors and return them all at once
|
|
var errorText string
|
|
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
|
|
((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
|
|
errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
|
|
}
|
|
|
|
if !viper.IsSet("noise") || viper.GetString("noise.private_key_path") == "" {
|
|
errorText += "Fatal config error: headscale now requires a new `noise.private_key_path` field in the config file for the Tailscale v2 protocol\n"
|
|
}
|
|
|
|
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
|
|
(viper.GetString("tls_letsencrypt_challenge_type") == TLSALPN01ChallengeType) &&
|
|
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
|
|
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
|
|
log.Warn().
|
|
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
|
|
}
|
|
|
|
if (viper.GetString("tls_letsencrypt_challenge_type") != HTTP01ChallengeType) &&
|
|
(viper.GetString("tls_letsencrypt_challenge_type") != TLSALPN01ChallengeType) {
|
|
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
|
|
}
|
|
|
|
if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
|
|
!strings.HasPrefix(viper.GetString("server_url"), "https://") {
|
|
errorText += "Fatal config error: server_url must start with https:// or http://\n"
|
|
}
|
|
|
|
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
|
|
// to avoid races
|
|
minInactivityTimeout, _ := time.ParseDuration("65s")
|
|
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
|
|
errorText += fmt.Sprintf(
|
|
"Fatal config error: ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
|
|
viper.GetString("ephemeral_node_inactivity_timeout"),
|
|
minInactivityTimeout,
|
|
)
|
|
}
|
|
|
|
if errorText != "" {
|
|
// nolint
|
|
return errors.New(strings.TrimSuffix(errorText, "\n"))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func tlsConfig() TLSConfig {
|
|
return TLSConfig{
|
|
LetsEncrypt: LetsEncryptConfig{
|
|
Hostname: viper.GetString("tls_letsencrypt_hostname"),
|
|
Listen: viper.GetString("tls_letsencrypt_listen"),
|
|
CacheDir: util.AbsolutePathFromConfigPath(
|
|
viper.GetString("tls_letsencrypt_cache_dir"),
|
|
),
|
|
ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
|
|
},
|
|
CertPath: util.AbsolutePathFromConfigPath(
|
|
viper.GetString("tls_cert_path"),
|
|
),
|
|
KeyPath: util.AbsolutePathFromConfigPath(
|
|
viper.GetString("tls_key_path"),
|
|
),
|
|
}
|
|
}
|
|
|
|
func derpConfig() DERPConfig {
|
|
serverEnabled := viper.GetBool("derp.server.enabled")
|
|
serverRegionID := viper.GetInt("derp.server.region_id")
|
|
serverRegionCode := viper.GetString("derp.server.region_code")
|
|
serverRegionName := viper.GetString("derp.server.region_name")
|
|
stunAddr := viper.GetString("derp.server.stun_listen_addr")
|
|
privateKeyPath := util.AbsolutePathFromConfigPath(
|
|
viper.GetString("derp.server.private_key_path"),
|
|
)
|
|
ipv4 := viper.GetString("derp.server.ipv4")
|
|
ipv6 := viper.GetString("derp.server.ipv6")
|
|
automaticallyAddEmbeddedDerpRegion := viper.GetBool(
|
|
"derp.server.automatically_add_embedded_derp_region",
|
|
)
|
|
if serverEnabled && stunAddr == "" {
|
|
log.Fatal().
|
|
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
|
|
}
|
|
|
|
urlStrs := viper.GetStringSlice("derp.urls")
|
|
|
|
urls := make([]url.URL, len(urlStrs))
|
|
for index, urlStr := range urlStrs {
|
|
urlAddr, err := url.Parse(urlStr)
|
|
if err != nil {
|
|
log.Error().
|
|
Str("url", urlStr).
|
|
Err(err).
|
|
Msg("Failed to parse url, ignoring...")
|
|
}
|
|
|
|
urls[index] = *urlAddr
|
|
}
|
|
|
|
paths := viper.GetStringSlice("derp.paths")
|
|
|
|
if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 {
|
|
log.Fatal().
|
|
Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths")
|
|
}
|
|
|
|
autoUpdate := viper.GetBool("derp.auto_update_enabled")
|
|
updateFrequency := viper.GetDuration("derp.update_frequency")
|
|
|
|
return DERPConfig{
|
|
ServerEnabled: serverEnabled,
|
|
ServerRegionID: serverRegionID,
|
|
ServerRegionCode: serverRegionCode,
|
|
ServerRegionName: serverRegionName,
|
|
ServerPrivateKeyPath: privateKeyPath,
|
|
STUNAddr: stunAddr,
|
|
URLs: urls,
|
|
Paths: paths,
|
|
AutoUpdate: autoUpdate,
|
|
UpdateFrequency: updateFrequency,
|
|
IPv4: ipv4,
|
|
IPv6: ipv6,
|
|
AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion,
|
|
}
|
|
}
|
|
|
|
func logtailConfig() LogTailConfig {
|
|
enabled := viper.GetBool("logtail.enabled")
|
|
|
|
return LogTailConfig{
|
|
Enabled: enabled,
|
|
}
|
|
}
|
|
|
|
func policyConfig() PolicyConfig {
|
|
policyPath := viper.GetString("policy.path")
|
|
policyMode := viper.GetString("policy.mode")
|
|
|
|
return PolicyConfig{
|
|
Path: policyPath,
|
|
Mode: PolicyMode(policyMode),
|
|
}
|
|
}
|
|
|
|
func logConfig() LogConfig {
|
|
logLevelStr := viper.GetString("log.level")
|
|
logLevel, err := zerolog.ParseLevel(logLevelStr)
|
|
if err != nil {
|
|
logLevel = zerolog.DebugLevel
|
|
}
|
|
|
|
logFormatOpt := viper.GetString("log.format")
|
|
var logFormat string
|
|
switch logFormatOpt {
|
|
case JSONLogFormat:
|
|
logFormat = JSONLogFormat
|
|
case TextLogFormat:
|
|
logFormat = TextLogFormat
|
|
case "":
|
|
logFormat = TextLogFormat
|
|
default:
|
|
log.Error().
|
|
Str("func", "GetLogConfig").
|
|
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
|
|
}
|
|
|
|
return LogConfig{
|
|
Format: logFormat,
|
|
Level: logLevel,
|
|
}
|
|
}
|
|
|
|
func databaseConfig() DatabaseConfig {
|
|
debug := viper.GetBool("database.debug")
|
|
|
|
type_ := viper.GetString("database.type")
|
|
|
|
skipErrRecordNotFound := viper.GetBool("database.gorm.skip_err_record_not_found")
|
|
slowThreshold := viper.GetDuration("database.gorm.slow_threshold") * time.Millisecond
|
|
parameterizedQueries := viper.GetBool("database.gorm.parameterized_queries")
|
|
prepareStmt := viper.GetBool("database.gorm.prepare_stmt")
|
|
|
|
switch type_ {
|
|
case DatabaseSqlite, DatabasePostgres:
|
|
break
|
|
case "sqlite":
|
|
type_ = "sqlite3"
|
|
default:
|
|
log.Fatal().
|
|
Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_)
|
|
}
|
|
|
|
return DatabaseConfig{
|
|
Type: type_,
|
|
Debug: debug,
|
|
Gorm: GormConfig{
|
|
Debug: debug,
|
|
SkipErrRecordNotFound: skipErrRecordNotFound,
|
|
SlowThreshold: slowThreshold,
|
|
ParameterizedQueries: parameterizedQueries,
|
|
PrepareStmt: prepareStmt,
|
|
},
|
|
Sqlite: SqliteConfig{
|
|
Path: util.AbsolutePathFromConfigPath(
|
|
viper.GetString("database.sqlite.path"),
|
|
),
|
|
WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"),
|
|
},
|
|
Postgres: PostgresConfig{
|
|
Host: viper.GetString("database.postgres.host"),
|
|
Port: viper.GetInt("database.postgres.port"),
|
|
Name: viper.GetString("database.postgres.name"),
|
|
User: viper.GetString("database.postgres.user"),
|
|
Pass: viper.GetString("database.postgres.pass"),
|
|
Ssl: viper.GetString("database.postgres.ssl"),
|
|
MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"),
|
|
MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"),
|
|
ConnMaxIdleTimeSecs: viper.GetInt(
|
|
"database.postgres.conn_max_idle_time_secs",
|
|
),
|
|
},
|
|
}
|
|
}
|
|
|
|
func dns() (DNSConfig, error) {
|
|
var dns DNSConfig
|
|
|
|
// TODO: Use this instead of manually getting settings when
|
|
// UnmarshalKey is compatible with Environment Variables.
|
|
// err := viper.UnmarshalKey("dns", &dns)
|
|
// if err != nil {
|
|
// return DNSConfig{}, fmt.Errorf("unmarshaling dns config: %w", err)
|
|
// }
|
|
|
|
dns.MagicDNS = viper.GetBool("dns.magic_dns")
|
|
dns.BaseDomain = viper.GetString("dns.base_domain")
|
|
dns.Nameservers.Global = viper.GetStringSlice("dns.nameservers.global")
|
|
dns.Nameservers.Split = viper.GetStringMapStringSlice("dns.nameservers.split")
|
|
dns.SearchDomains = viper.GetStringSlice("dns.search_domains")
|
|
|
|
if viper.IsSet("dns.extra_records") {
|
|
var extraRecords []tailcfg.DNSRecord
|
|
|
|
err := viper.UnmarshalKey("dns.extra_records", &extraRecords)
|
|
if err != nil {
|
|
return DNSConfig{}, fmt.Errorf("unmarshaling dns extra records: %w", err)
|
|
}
|
|
dns.ExtraRecords = extraRecords
|
|
}
|
|
|
|
return dns, nil
|
|
}
|
|
|
|
// globalResolvers returns the global DNS resolvers
|
|
// defined in the config file.
|
|
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
|
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
|
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
|
func (d *DNSConfig) globalResolvers() []*dnstype.Resolver {
|
|
var resolvers []*dnstype.Resolver
|
|
|
|
for _, nsStr := range d.Nameservers.Global {
|
|
warn := ""
|
|
if _, err := netip.ParseAddr(nsStr); err == nil {
|
|
resolvers = append(resolvers, &dnstype.Resolver{
|
|
Addr: nsStr,
|
|
})
|
|
|
|
continue
|
|
} else {
|
|
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
|
}
|
|
|
|
if _, err := url.Parse(nsStr); err == nil {
|
|
resolvers = append(resolvers, &dnstype.Resolver{
|
|
Addr: nsStr,
|
|
})
|
|
|
|
continue
|
|
} else {
|
|
warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
|
}
|
|
|
|
if warn != "" {
|
|
log.Warn().Msg(warn)
|
|
}
|
|
}
|
|
|
|
return resolvers
|
|
}
|
|
|
|
// splitResolvers returns a map of domain to DNS resolvers.
|
|
// If a nameserver is a valid IP, it will be used as a regular resolver.
|
|
// If a nameserver is a valid URL, it will be used as a DoH resolver.
|
|
// If a nameserver is neither a valid URL nor a valid IP, it will be ignored.
|
|
func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver {
|
|
routes := make(map[string][]*dnstype.Resolver)
|
|
for domain, nameservers := range d.Nameservers.Split {
|
|
var resolvers []*dnstype.Resolver
|
|
for _, nsStr := range nameservers {
|
|
warn := ""
|
|
if _, err := netip.ParseAddr(nsStr); err == nil {
|
|
resolvers = append(resolvers, &dnstype.Resolver{
|
|
Addr: nsStr,
|
|
})
|
|
|
|
continue
|
|
} else {
|
|
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
|
}
|
|
|
|
if _, err := url.Parse(nsStr); err == nil {
|
|
resolvers = append(resolvers, &dnstype.Resolver{
|
|
Addr: nsStr,
|
|
})
|
|
|
|
continue
|
|
} else {
|
|
warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err)
|
|
}
|
|
|
|
if warn != "" {
|
|
log.Warn().Msg(warn)
|
|
}
|
|
}
|
|
routes[domain] = resolvers
|
|
}
|
|
|
|
return routes
|
|
}
|
|
|
|
func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig {
|
|
cfg := tailcfg.DNSConfig{}
|
|
|
|
if dns.BaseDomain == "" && dns.MagicDNS {
|
|
log.Fatal().Msg("dns.base_domain must be set when using MagicDNS (dns.magic_dns)")
|
|
}
|
|
|
|
cfg.Proxied = dns.MagicDNS
|
|
cfg.ExtraRecords = dns.ExtraRecords
|
|
cfg.Resolvers = dns.globalResolvers()
|
|
|
|
routes := dns.splitResolvers()
|
|
cfg.Routes = routes
|
|
if dns.BaseDomain != "" {
|
|
cfg.Domains = []string{dns.BaseDomain}
|
|
}
|
|
cfg.Domains = append(cfg.Domains, dns.SearchDomains...)
|
|
|
|
return &cfg
|
|
}
|
|
|
|
func prefixV4() (*netip.Prefix, error) {
|
|
prefixV4Str := viper.GetString("prefixes.v4")
|
|
|
|
if prefixV4Str == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
prefixV4, err := netip.ParsePrefix(prefixV4Str)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing IPv4 prefix from config: %w", err)
|
|
}
|
|
|
|
builder := netipx.IPSetBuilder{}
|
|
builder.AddPrefix(tsaddr.CGNATRange())
|
|
ipSet, _ := builder.IPSet()
|
|
if !ipSet.ContainsPrefix(prefixV4) {
|
|
log.Warn().
|
|
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
|
prefixV4Str, tsaddr.CGNATRange())
|
|
}
|
|
|
|
return &prefixV4, nil
|
|
}
|
|
|
|
func prefixV6() (*netip.Prefix, error) {
|
|
prefixV6Str := viper.GetString("prefixes.v6")
|
|
|
|
if prefixV6Str == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
prefixV6, err := netip.ParsePrefix(prefixV6Str)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing IPv6 prefix from config: %w", err)
|
|
}
|
|
|
|
builder := netipx.IPSetBuilder{}
|
|
builder.AddPrefix(tsaddr.TailscaleULARange())
|
|
ipSet, _ := builder.IPSet()
|
|
|
|
if !ipSet.ContainsPrefix(prefixV6) {
|
|
log.Warn().
|
|
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
|
|
prefixV6Str, tsaddr.TailscaleULARange())
|
|
}
|
|
|
|
return &prefixV6, nil
|
|
}
|
|
|
|
// LoadCLIConfig returns the needed configuration for the CLI client
|
|
// of Headscale to connect to a Headscale server.
|
|
func LoadCLIConfig() (*Config, error) {
|
|
logConfig := logConfig()
|
|
zerolog.SetGlobalLevel(logConfig.Level)
|
|
|
|
return &Config{
|
|
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
|
|
UnixSocket: viper.GetString("unix_socket"),
|
|
CLI: CLIConfig{
|
|
Address: viper.GetString("cli.address"),
|
|
APIKey: viper.GetString("cli.api_key"),
|
|
Timeout: viper.GetDuration("cli.timeout"),
|
|
Insecure: viper.GetBool("cli.insecure"),
|
|
},
|
|
Log: logConfig,
|
|
}, nil
|
|
}
|
|
|
|
// LoadServerConfig returns the full Headscale configuration to
|
|
// host a Headscale server. This is called as part of `headscale serve`.
|
|
func LoadServerConfig() (*Config, error) {
|
|
if err := validateServerConfig(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logConfig := logConfig()
|
|
zerolog.SetGlobalLevel(logConfig.Level)
|
|
|
|
prefix4, err := prefixV4()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
prefix6, err := prefixV6()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if prefix4 == nil && prefix6 == nil {
|
|
return nil, fmt.Errorf("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
|
}
|
|
|
|
allocStr := viper.GetString("prefixes.allocation")
|
|
var alloc IPAllocationStrategy
|
|
switch allocStr {
|
|
case string(IPAllocationStrategySequential):
|
|
alloc = IPAllocationStrategySequential
|
|
case string(IPAllocationStrategyRandom):
|
|
alloc = IPAllocationStrategyRandom
|
|
default:
|
|
return nil, fmt.Errorf(
|
|
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
|
|
allocStr,
|
|
IPAllocationStrategySequential,
|
|
IPAllocationStrategyRandom,
|
|
)
|
|
}
|
|
|
|
dnsConfig, err := dns()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
derpConfig := derpConfig()
|
|
logTailConfig := logtailConfig()
|
|
randomizeClientPort := viper.GetBool("randomize_client_port")
|
|
|
|
oidcClientSecret := viper.GetString("oidc.client_secret")
|
|
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
|
|
if oidcClientSecretPath != "" && oidcClientSecret != "" {
|
|
return nil, errOidcMutuallyExclusive
|
|
}
|
|
if oidcClientSecretPath != "" {
|
|
secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
oidcClientSecret = strings.TrimSpace(string(secretBytes))
|
|
}
|
|
|
|
serverURL := viper.GetString("server_url")
|
|
|
|
// BaseDomain cannot be the same as the server URL.
|
|
// This is because Tailscale takes over the domain in BaseDomain,
|
|
// causing the headscale server and DERP to be unreachable.
|
|
// For Tailscale upstream, the following is true:
|
|
// - DERP run on their own domains
|
|
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
|
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
|
if dnsConfig.BaseDomain != "" &&
|
|
strings.Contains(serverURL, dnsConfig.BaseDomain) {
|
|
return nil, errors.New(
|
|
"server_url cannot contain the base_domain, this will cause the headscale server and embedded DERP to become unreachable from the Tailscale node.",
|
|
)
|
|
}
|
|
|
|
return &Config{
|
|
ServerURL: serverURL,
|
|
Addr: viper.GetString("listen_addr"),
|
|
MetricsAddr: viper.GetString("metrics_listen_addr"),
|
|
GRPCAddr: viper.GetString("grpc_listen_addr"),
|
|
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
|
|
DisableUpdateCheck: false,
|
|
|
|
PrefixV4: prefix4,
|
|
PrefixV6: prefix6,
|
|
IPAllocation: IPAllocationStrategy(alloc),
|
|
|
|
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
|
|
viper.GetString("noise.private_key_path"),
|
|
),
|
|
BaseDomain: dnsConfig.BaseDomain,
|
|
|
|
DERP: derpConfig,
|
|
|
|
EphemeralNodeInactivityTimeout: viper.GetDuration(
|
|
"ephemeral_node_inactivity_timeout",
|
|
),
|
|
|
|
Database: databaseConfig(),
|
|
|
|
TLS: tlsConfig(),
|
|
|
|
DNSConfig: dnsToTailcfgDNS(dnsConfig),
|
|
|
|
ACMEEmail: viper.GetString("acme_email"),
|
|
ACMEURL: viper.GetString("acme_url"),
|
|
|
|
UnixSocket: viper.GetString("unix_socket"),
|
|
UnixSocketPermission: util.GetFileMode("unix_socket_permission"),
|
|
|
|
OIDC: OIDCConfig{
|
|
OnlyStartIfOIDCIsAvailable: viper.GetBool(
|
|
"oidc.only_start_if_oidc_is_available",
|
|
),
|
|
Issuer: viper.GetString("oidc.issuer"),
|
|
ClientID: viper.GetString("oidc.client_id"),
|
|
ClientSecret: oidcClientSecret,
|
|
Scope: viper.GetStringSlice("oidc.scope"),
|
|
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
|
|
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
|
|
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
|
|
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
|
|
Expiry: func() time.Duration {
|
|
// if set to 0, we assume no expiry
|
|
if value := viper.GetString("oidc.expiry"); value == "0" {
|
|
return maxDuration
|
|
} else {
|
|
expiry, err := model.ParseDuration(value)
|
|
if err != nil {
|
|
log.Warn().Msg("failed to parse oidc.expiry, defaulting back to 180 days")
|
|
|
|
return defaultOIDCExpiryTime
|
|
}
|
|
|
|
return time.Duration(expiry)
|
|
}
|
|
}(),
|
|
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
|
|
},
|
|
|
|
LogTail: logTailConfig,
|
|
RandomizeClientPort: randomizeClientPort,
|
|
|
|
Policy: policyConfig(),
|
|
|
|
CLI: CLIConfig{
|
|
Address: viper.GetString("cli.address"),
|
|
APIKey: viper.GetString("cli.api_key"),
|
|
Timeout: viper.GetDuration("cli.timeout"),
|
|
Insecure: viper.GetBool("cli.insecure"),
|
|
},
|
|
|
|
Log: logConfig,
|
|
|
|
// TODO(kradalby): Document these settings when more stable
|
|
Tuning: Tuning{
|
|
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"),
|
|
BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"),
|
|
NodeMapSessionBufferedChanSize: viper.GetInt(
|
|
"tuning.node_mapsession_buffered_chan_size",
|
|
),
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
type deprecator struct {
|
|
warns set.Set[string]
|
|
fatals set.Set[string]
|
|
}
|
|
|
|
// warnWithAlias will register an alias between the newKey and the oldKey,
|
|
// and log a deprecation warning if the oldKey is set.
|
|
func (d *deprecator) warnWithAlias(newKey, oldKey string) {
|
|
// NOTE: RegisterAlias is called with NEW KEY -> OLD KEY
|
|
viper.RegisterAlias(newKey, oldKey)
|
|
if viper.IsSet(oldKey) {
|
|
d.warns.Add(
|
|
fmt.Sprintf(
|
|
"The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.",
|
|
oldKey,
|
|
newKey,
|
|
oldKey,
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set.
|
|
func (d *deprecator) fatal(oldKey string) {
|
|
if viper.IsSet(oldKey) {
|
|
d.fatals.Add(
|
|
fmt.Sprintf(
|
|
"The %q configuration key has been removed. Please see the changelog for more details.",
|
|
oldKey,
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
// fatalIfNewKeyIsNotUsed deprecates and adds an entry to the fatal list of options if the oldKey is set and the new key is _not_ set.
|
|
// If the new key is set, a warning is emitted instead.
|
|
func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) {
|
|
if viper.IsSet(oldKey) && !viper.IsSet(newKey) {
|
|
d.fatals.Add(
|
|
fmt.Sprintf(
|
|
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
|
oldKey,
|
|
newKey,
|
|
oldKey,
|
|
),
|
|
)
|
|
} else if viper.IsSet(oldKey) {
|
|
d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey))
|
|
}
|
|
}
|
|
|
|
// warn deprecates and adds an option to log a warning if the oldKey is set.
|
|
func (d *deprecator) warnNoAlias(newKey, oldKey string) {
|
|
if viper.IsSet(oldKey) {
|
|
d.warns.Add(
|
|
fmt.Sprintf(
|
|
"The %q configuration key is deprecated. Please use %q instead. %q has been removed.",
|
|
oldKey,
|
|
newKey,
|
|
oldKey,
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
// warn deprecates and adds an entry to the warn list of options if the oldKey is set.
|
|
func (d *deprecator) warn(oldKey string) {
|
|
if viper.IsSet(oldKey) {
|
|
d.warns.Add(
|
|
fmt.Sprintf(
|
|
"The %q configuration key is deprecated and has been removed. Please see the changelog for more details.",
|
|
oldKey,
|
|
),
|
|
)
|
|
}
|
|
}
|
|
|
|
func (d *deprecator) String() string {
|
|
var b strings.Builder
|
|
|
|
for _, w := range d.warns.Slice() {
|
|
fmt.Fprintf(&b, "WARN: %s\n", w)
|
|
}
|
|
|
|
for _, f := range d.fatals.Slice() {
|
|
fmt.Fprintf(&b, "FATAL: %s\n", f)
|
|
}
|
|
|
|
return b.String()
|
|
}
|
|
|
|
func (d *deprecator) Log() {
|
|
if len(d.fatals) > 0 {
|
|
log.Fatal().Msg("\n" + d.String())
|
|
} else if len(d.warns) > 0 {
|
|
log.Warn().Msg("\n" + d.String())
|
|
}
|
|
}
|