diff --git a/app.go b/app.go index 26d7b953..8d2a2b17 100644 --- a/app.go +++ b/app.go @@ -650,21 +650,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") } - var clientAuthMode tls.ClientAuthType - switch h.cfg.TLSClientAuthMode { - case DisabledClientAuth: - // Client cert is _not_ required. - clientAuthMode = tls.NoClientCert - case RelaxedClientAuth: - // Client cert required, but _not verified_. - clientAuthMode = tls.RequireAnyClientCert - case EnforcedClientAuth: - // Client cert is _required and verified_. - clientAuthMode = tls.RequireAndVerifyClientCert - default: - return nil, Error("Invalid tls_client_auth_mode provided: " + - h.cfg.TLSClientAuthMode) - } + clientAuthMode, err := h.GetClientAuthMode() + + if err != nil { + return nil, err + } log.Info().Msg(fmt.Sprintf( "Client authentication (mTLS) is \"%s\". See the docs to learn about configuring this setting.", @@ -683,6 +673,27 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } +// Look up the TLS constant relative to user-supplied TLS client +// authentication mode. +func (h *Headscale) GetClientAuthMode() (tls.ClientAuthType, error) { + + switch h.cfg.TLSClientAuthMode { + case DisabledClientAuth: + // Client cert is _not_ required. + return tls.NoClientCert, nil + case RelaxedClientAuth: + // Client cert required, but _not verified_. + return tls.RequireAnyClientCert, nil + case EnforcedClientAuth: + // Client cert is _required and verified_. + return tls.RequireAndVerifyClientCert, nil + default: + return tls.NoClientCert, Error("Invalid tls_client_auth_mode provided: " + + h.cfg.TLSClientAuthMode) + } + +} + func (h *Headscale) setLastStateChangeToNow(namespace string) { now := time.Now().UTC() lastStateUpdate.WithLabelValues("", "headscale").Set(float64(now.Unix())) diff --git a/app_test.go b/app_test.go index bff13933..a53a8802 100644 --- a/app_test.go +++ b/app_test.go @@ -63,3 +63,25 @@ func (s *Suite) ResetDB(c *check.C) { } app.db = db } + +// Enusre an error is returned when an invalid auth mode +// is supplied. +func (s *Suite) TestInvalidClientAuthMode(c *check.C){ + app.cfg.TLSClientAuthMode = "invalid" + _, err := app.GetClientAuthMode() + c.Assert(err, check.NotNil) +} + +// Ensure that all client auth modes return a nil error +func (s *Suite) TestAuthModes(c *check.C){ + + var modes = []string{"disabled", "relaxed", "enforced"} + + for _, v := range modes { + app.cfg.TLSClientAuthMode = v + _, err := app.GetClientAuthMode() + c.Assert(err, check.IsNil) + } + +} +