OIDC code cleanup and harmonize with regular web auth

This commit is contained in:
Juan Font 2022-11-14 14:05:47 +00:00 committed by Juan Font
parent 46df219ed3
commit f74266f8f8

91
oidc.go
View File

@ -76,20 +76,52 @@ func (h *Headscale) RegisterOIDC(
) { ) {
vars := mux.Vars(req) vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"] nodeKeyStr, ok := vars["nkey"]
if !ok || nodeKeyStr == "" {
log.Error().
Caller().
Msg("Missing node key in URL")
http.Error(writer, "Missing node key in URL", http.StatusBadRequest)
return
}
log.Trace(). log.Trace().
Caller(). Caller().
Str("node_key", nodeKeyStr). Str("node_key", nodeKeyStr).
Msg("Received oidc register call") Msg("Received oidc register call")
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
)
if !ok || nodeKeyStr == "" || err != nil {
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
randomBlob := make([]byte, randomByteSize) randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil { if _, err := rand.Read(randomBlob); err != nil {
log.Error(). log.Error().
@ -103,7 +135,7 @@ func (h *Headscale) RegisterOIDC(
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later // place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request // Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
@ -405,8 +437,8 @@ func (h *Headscale) validateMachineForOIDCCallback(
claims *IDTokenClaims, claims *IDTokenClaims,
) (*key.NodePublic, bool, error) { ) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache // retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.registrationCache.Get(state) nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !nodeKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested machine state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@ -423,16 +455,34 @@ func (h *Headscale) validateMachineForOIDCCallback(
} }
var nodeKey key.NodePublic var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
if !nodeKeyOK {
log.Error().
Msg("requested machine state key is not a string")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return nil, false, errOIDCInvalidMachineState
}
err := nodeKey.UnmarshalText( err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Str("nodeKey", nodeKeyFromCache).
Bool("nodeKeyOK", nodeKeyOK).
Msg("could not parse node public key") Msg("could not parse node public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse public key")) _, werr := writer.Write([]byte("could not parse node public key"))
if werr != nil { if werr != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -443,21 +493,6 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, err return nil, false, err
} }
if !nodeKeyOK {
log.Error().Msg("could not get node key from cache")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("could not get node key from cache"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
// retrieve machine information if it exist // retrieve machine information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new machine and we will move