feat: Add PKCE Verifier for OIDC (#2314)

* feat: add PKCE verifier for OIDC

* Update CHANGELOG.md
This commit is contained in:
Rorical 2024-12-23 00:46:36 +08:00 committed by GitHub
parent 9313e5b058
commit b81420bef1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 187 additions and 15 deletions

View File

@ -25,6 +25,7 @@ jobs:
- TestOIDCAuthenticationPingAll - TestOIDCAuthenticationPingAll
- TestOIDCExpireNodesBasedOnTokenExpiry - TestOIDCExpireNodesBasedOnTokenExpiry
- TestOIDC024UserCreation - TestOIDC024UserCreation
- TestOIDCAuthenticationWithPKCE
- TestAuthWebFlowAuthenticationPingAll - TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin - TestAuthWebFlowLogoutAndRelogin
- TestUserCommand - TestUserCommand

View File

@ -172,6 +172,7 @@ This will also affect the way you
[#2261](https://github.com/juanfont/headscale/pull/2261) [#2261](https://github.com/juanfont/headscale/pull/2261)
- Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262) - Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262)
- Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046) - Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046)
- Add PKCE Verifier for OIDC [#2314](https://github.com/juanfont/headscale/pull/2314)
## 0.23.0 (2024-09-18) ## 0.23.0 (2024-09-18)

View File

@ -364,6 +364,18 @@ unix_socket_permission: "0770"
# allowed_users: # allowed_users:
# - alice@example.com # - alice@example.com
# #
# # Optional: PKCE (Proof Key for Code Exchange) configuration
# # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
# # by preventing authorization code interception attacks
# # See https://datatracker.ietf.org/doc/html/rfc7636
# pkce:
# # Enable or disable PKCE support (default: false)
# enabled: false
# # PKCE method to use:
# # - plain: Use plain code verifier
# # - S256: Use SHA256 hashed code verifier (default, recommended)
# method: S256
#
# # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users # # Map legacy users from pre-0.24.0 versions of headscale to the new OIDC users
# # by taking the username from the legacy user and matching it with the username # # by taking the username from the legacy user and matching it with the username
# # provided by the OIDC. This is useful when migrating from legacy users to OIDC # # provided by the OIDC. This is useful when migrating from legacy users to OIDC

View File

@ -45,6 +45,18 @@ oidc:
allowed_users: allowed_users:
- alice@example.com - alice@example.com
# Optional: PKCE (Proof Key for Code Exchange) configuration
# PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow
# by preventing authorization code interception attacks
# See https://datatracker.ietf.org/doc/html/rfc7636
pkce:
# Enable or disable PKCE support (default: false)
enabled: false
# PKCE method to use:
# - plain: Use plain code verifier
# - S256: Use SHA256 hashed code verifier (default, recommended)
method: S256
# If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed.
# This will transform `first-name.last-name@example.com` to the user `first-name.last-name` # This will transform `first-name.last-name@example.com` to the user `first-name.last-name`
# If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following

View File

@ -28,12 +28,14 @@ import (
) )
const ( const (
randomByteSize = 16 randomByteSize = 16
defaultOAuthOptionsCount = 3
) )
var ( var (
errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params")
errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback") errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback")
errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache")
errOIDCAllowedDomains = errors.New( errOIDCAllowedDomains = errors.New(
"authenticated principal does not match any allowed domain", "authenticated principal does not match any allowed domain",
) )
@ -47,11 +49,17 @@ var (
errOIDCNodeKeyMissing = errors.New("could not get node key from cache") errOIDCNodeKeyMissing = errors.New("could not get node key from cache")
) )
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
type RegistrationInfo struct {
MachineKey key.MachinePublic
Verifier *string
}
type AuthProviderOIDC struct { type AuthProviderOIDC struct {
serverURL string serverURL string
cfg *types.OIDCConfig cfg *types.OIDCConfig
db *db.HSDatabase db *db.HSDatabase
registrationCache *zcache.Cache[string, key.MachinePublic] registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier notifier *notifier.Notifier
ipAlloc *db.IPAllocator ipAlloc *db.IPAllocator
polMan policy.PolicyManager polMan policy.PolicyManager
@ -87,7 +95,7 @@ func NewAuthProviderOIDC(
Scopes: cfg.Scope, Scopes: cfg.Scope,
} }
registrationCache := zcache.New[string, key.MachinePublic]( registrationCache := zcache.New[string, RegistrationInfo](
registerCacheExpiration, registerCacheExpiration,
registerCacheCleanup, registerCacheCleanup,
) )
@ -157,19 +165,36 @@ func (a *AuthProviderOIDC) RegisterHandler(
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later // Initialize registration info with machine key
a.registrationCache.Set( registrationInfo := RegistrationInfo{
stateStr, MachineKey: machineKey,
machineKey, }
)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount)
extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)) // Add PKCE verification if enabled
if a.cfg.PKCE.Enabled {
verifier := oauth2.GenerateVerifier()
registrationInfo.Verifier = &verifier
extras = append(extras, oauth2.AccessTypeOffline)
switch a.cfg.PKCE.Method {
case types.PKCEMethodS256:
extras = append(extras, oauth2.S256ChallengeOption(verifier))
case types.PKCEMethodPlain:
// oauth2 does not have a plain challenge option, so we add it manually
extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier))
}
}
// Add any extra parameters from configuration
for k, v := range a.cfg.ExtraParams { for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v)) extras = append(extras, oauth2.SetAuthURLParam(k, v))
} }
// Cache the registration info
a.registrationCache.Set(stateStr, registrationInfo)
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...) authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -203,7 +228,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return return
} }
idToken, err := a.extractIDToken(req.Context(), code) idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil { if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest) http.Error(writer, err.Error(), http.StatusBadRequest)
return return
@ -318,8 +343,21 @@ func extractCodeAndStateParamFromRequest(
func (a *AuthProviderOIDC) extractIDToken( func (a *AuthProviderOIDC) extractIDToken(
ctx context.Context, ctx context.Context,
code string, code string,
state string,
) (*oidc.IDToken, error) { ) (*oidc.IDToken, error) {
oauth2Token, err := a.oauth2Config.Exchange(ctx, code) var exchangeOpts []oauth2.AuthCodeOption
if a.cfg.PKCE.Enabled {
regInfo, ok := a.registrationCache.Get(state)
if !ok {
return nil, errNoOIDCRegistrationInfo
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
}
}
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not exchange code for token: %w", err) return nil, fmt.Errorf("could not exchange code for token: %w", err)
} }
@ -394,7 +432,7 @@ func validateOIDCAllowedUsers(
// cache. If the machine key is found, it will try retrieve the // cache. If the machine key is found, it will try retrieve the
// node information from the database. // node information from the database.
func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) { func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *key.MachinePublic) {
machineKey, ok := a.registrationCache.Get(state) regInfo, ok := a.registrationCache.Get(state)
if !ok { if !ok {
return nil, nil return nil, nil
} }
@ -403,9 +441,9 @@ func (a *AuthProviderOIDC) getMachineKeyFromState(state string) (*types.Node, *k
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new node and we will move // exist, then this is a new node and we will move
// on to registration. // on to registration.
node, _ := a.db.GetNodeByMachineKey(machineKey) node, _ := a.db.GetNodeByMachineKey(regInfo.MachineKey)
return node, &machineKey return node, &regInfo.MachineKey
} }
// reauthenticateNode updates the node expiry in the database // reauthenticateNode updates the node expiry in the database

View File

@ -26,11 +26,14 @@ import (
const ( const (
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
maxDuration time.Duration = 1<<63 - 1 maxDuration time.Duration = 1<<63 - 1
PKCEMethodPlain string = "plain"
PKCEMethodS256 string = "S256"
) )
var ( var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
) )
type IPAllocationStrategy string type IPAllocationStrategy string
@ -162,6 +165,11 @@ type LetsEncryptConfig struct {
ChallengeType string ChallengeType string
} }
type PKCEConfig struct {
Enabled bool
Method string
}
type OIDCConfig struct { type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool OnlyStartIfOIDCIsAvailable bool
Issuer string Issuer string
@ -176,6 +184,7 @@ type OIDCConfig struct {
Expiry time.Duration Expiry time.Duration
UseExpiryFromToken bool UseExpiryFromToken bool
MapLegacyUsers bool MapLegacyUsers bool
PKCE PKCEConfig
} }
type DERPConfig struct { type DERPConfig struct {
@ -226,6 +235,13 @@ type Tuning struct {
NodeMapSessionBufferedChanSize int NodeMapSessionBufferedChanSize int
} }
func validatePKCEMethod(method string) error {
if method != PKCEMethodPlain && method != PKCEMethodS256 {
return errInvalidPKCEMethod
}
return nil
}
// LoadConfig prepares and loads the Headscale configuration into Viper. // LoadConfig prepares and loads the Headscale configuration into Viper.
// This means it sets the default values, reads the configuration file and // This means it sets the default values, reads the configuration file and
// environment variables, and handles deprecated configuration options. // environment variables, and handles deprecated configuration options.
@ -293,6 +309,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false) viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.map_legacy_users", true) viper.SetDefault("oidc.map_legacy_users", true)
viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256")
viper.SetDefault("logtail.enabled", false) viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false) viper.SetDefault("randomize_client_port", false)
@ -340,6 +358,12 @@ func validateServerConfig() error {
// after #2170 is cleaned up // after #2170 is cleaned up
// depr.fatal("oidc.strip_email_domain") // depr.fatal("oidc.strip_email_domain")
if viper.GetBool("oidc.enabled") {
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
return err
}
}
depr.Log() depr.Log()
for _, removed := range []string{ for _, removed := range []string{
@ -928,6 +952,10 @@ func LoadServerConfig() (*Config, error) {
// after #2170 is cleaned up // after #2170 is cleaned up
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"), MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
PKCE: PKCEConfig{
Enabled: viper.GetBool("oidc.pkce.enabled"),
Method: viper.GetString("oidc.pkce.method"),
},
}, },
LogTail: logTailConfig, LogTail: logTailConfig,

View File

@ -534,6 +534,86 @@ func TestOIDC024UserCreation(t *testing.T) {
} }
} }
func TestOIDCAuthenticationWithPKCE(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
baseScenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
scenario := AuthOIDCScenario{
Scenario: baseScenario,
}
defer scenario.ShutdownAssertNoPanics(t)
// Single user with one node for testing PKCE flow
spec := map[string]int{
"user1": 1,
}
mockusers := []mockoidc.MockUser{
oidcMockUser("user1", true),
}
oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers)
assertNoErrf(t, "failed to run mock OIDC server: %s", err)
defer scenario.mockOIDC.Close()
oidcMap := map[string]string{
"HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer,
"HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID,
"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret",
"CREDENTIALS_DIRECTORY_TEST": "/tmp",
"HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE
"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0",
"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0",
}
err = scenario.CreateHeadscaleEnv(
spec,
hsic.WithTestName("oidcauthpkce"),
hsic.WithConfigEnv(oidcMap),
hsic.WithTLS(),
hsic.WithHostnameAsServerURL(),
hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)),
)
assertNoErrHeadscaleEnv(t, err)
// Get all clients and verify they can connect
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
allIps, err := scenario.ListTailscaleClientsIPs()
assertNoErrListClientIPs(t, err)
err = scenario.WaitForTailscaleSync()
assertNoErrSync(t, err)
// Verify PKCE was used in authentication
headscale, err := scenario.Headscale()
assertNoErr(t, err)
var listUsers []v1.User
err = executeAndUnmarshal(headscale,
[]string{
"headscale",
"users",
"list",
"--output",
"json",
},
&listUsers,
)
assertNoErr(t, err)
allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
return x.String()
})
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
}
func (s *AuthOIDCScenario) CreateHeadscaleEnv( func (s *AuthOIDCScenario) CreateHeadscaleEnv(
users map[string]int, users map[string]int,
opts ...hsic.Option, opts ...hsic.Option,