diff --git a/app.go b/app.go index 62d284b6..b3725eea 100644 --- a/app.go +++ b/app.go @@ -657,11 +657,10 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { log.Warn().Msg("Listening with TLS but ServerURL does not start with https://") } - clientAuthMode, err := h.GetClientAuthMode() - - if err != nil { - return nil, err - } + clientAuthMode, err := h.GetClientAuthMode() + if err != nil { + return nil, err + } log.Info().Msg(fmt.Sprintf( "Client authentication (mTLS) is \"%s\". See the docs to learn about configuring this setting.", @@ -683,22 +682,20 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { // Look up the TLS constant relative to user-supplied TLS client // authentication mode. func (h *Headscale) GetClientAuthMode() (tls.ClientAuthType, error) { - - switch h.cfg.TLSClientAuthMode { - case DisabledClientAuth: - // Client cert is _not_ required. - return tls.NoClientCert, nil - case RelaxedClientAuth: - // Client cert required, but _not verified_. - return tls.RequireAnyClientCert, nil - case EnforcedClientAuth: - // Client cert is _required and verified_. - return tls.RequireAndVerifyClientCert, nil - default: - return tls.NoClientCert, Error("Invalid tls_client_auth_mode provided: " + - h.cfg.TLSClientAuthMode) - } - + switch h.cfg.TLSClientAuthMode { + case DisabledClientAuth: + // Client cert is _not_ required. + return tls.NoClientCert, nil + case RelaxedClientAuth: + // Client cert required, but _not verified_. + return tls.RequireAnyClientCert, nil + case EnforcedClientAuth: + // Client cert is _required and verified_. + return tls.RequireAndVerifyClientCert, nil + default: + return tls.NoClientCert, Error("Invalid tls_client_auth_mode provided: " + + h.cfg.TLSClientAuthMode) + } } func (h *Headscale) setLastStateChangeToNow(namespace string) { diff --git a/app_test.go b/app_test.go index a53a8802..94b6ef00 100644 --- a/app_test.go +++ b/app_test.go @@ -66,22 +66,19 @@ func (s *Suite) ResetDB(c *check.C) { // Enusre an error is returned when an invalid auth mode // is supplied. -func (s *Suite) TestInvalidClientAuthMode(c *check.C){ - app.cfg.TLSClientAuthMode = "invalid" - _, err := app.GetClientAuthMode() - c.Assert(err, check.NotNil) +func (s *Suite) TestInvalidClientAuthMode(c *check.C) { + app.cfg.TLSClientAuthMode = "invalid" + _, err := app.GetClientAuthMode() + c.Assert(err, check.NotNil) } -// Ensure that all client auth modes return a nil error -func (s *Suite) TestAuthModes(c *check.C){ - - var modes = []string{"disabled", "relaxed", "enforced"} - - for _, v := range modes { - app.cfg.TLSClientAuthMode = v - _, err := app.GetClientAuthMode() - c.Assert(err, check.IsNil) - } +// Ensure that all client auth modes return a nil error. +func (s *Suite) TestAuthModes(c *check.C) { + modes := []string{"disabled", "relaxed", "enforced"} + for _, v := range modes { + app.cfg.TLSClientAuthMode = v + _, err := app.GetClientAuthMode() + c.Assert(err, check.IsNil) + } } -