diff --git a/oidc.go b/oidc.go index 853345a6..44382796 100644 --- a/oidc.go +++ b/oidc.go @@ -76,20 +76,52 @@ func (h *Headscale) RegisterOIDC( ) { vars := mux.Vars(req) nodeKeyStr, ok := vars["nkey"] - if !ok || nodeKeyStr == "" { - log.Error(). - Caller(). - Msg("Missing node key in URL") - http.Error(writer, "Missing node key in URL", http.StatusBadRequest) - - return - } log.Trace(). Caller(). Str("node_key", nodeKeyStr). Msg("Received oidc register call") + if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) { + log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusUnauthorized) + _, err := writer.Write([]byte("Unauthorized")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + + // We need to make sure we dont open for XSS style injections, if the parameter that + // is passed as a key is not parsable/validated as a NodePublic key, then fail to render + // the template and log an error. + var nodeKey key.NodePublic + err := nodeKey.UnmarshalText( + []byte(NodePublicKeyEnsurePrefix(nodeKeyStr)), + ) + + if !ok || nodeKeyStr == "" || err != nil { + log.Warn().Err(err).Msg("Failed to parse incoming nodekey") + + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("Wrong params")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return + } + randomBlob := make([]byte, randomByteSize) if _, err := rand.Read(randomBlob); err != nil { log.Error(). @@ -103,7 +135,7 @@ func (h *Headscale) RegisterOIDC( stateStr := hex.EncodeToString(randomBlob)[:32] // place the node key into the state cache, so it can be retrieved later - h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration) + h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration) // Add any extra parameter provided in the configuration to the Authorize Endpoint request extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) @@ -405,8 +437,8 @@ func (h *Headscale) validateMachineForOIDCCallback( claims *IDTokenClaims, ) (*key.NodePublic, bool, error) { // retrieve machinekey from state cache - machineKeyIf, machineKeyFound := h.registrationCache.Get(state) - if !machineKeyFound { + nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) + if !nodeKeyFound { log.Error(). Msg("requested machine state key expired before authorisation completed") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -423,16 +455,34 @@ func (h *Headscale) validateMachineForOIDCCallback( } var nodeKey key.NodePublic - nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string) + nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) + if !nodeKeyOK { + log.Error(). + Msg("requested machine state key is not a string") + writer.Header().Set("Content-Type", "text/plain; charset=utf-8") + writer.WriteHeader(http.StatusBadRequest) + _, err := writer.Write([]byte("state is invalid")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write response") + } + + return nil, false, errOIDCInvalidMachineState + } + err := nodeKey.UnmarshalText( []byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)), ) if err != nil { log.Error(). + Str("nodeKey", nodeKeyFromCache). + Bool("nodeKeyOK", nodeKeyOK). Msg("could not parse node public key") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("could not parse public key")) + _, werr := writer.Write([]byte("could not parse node public key")) if werr != nil { log.Error(). Caller(). @@ -443,21 +493,6 @@ func (h *Headscale) validateMachineForOIDCCallback( return nil, false, err } - if !nodeKeyOK { - log.Error().Msg("could not get node key from cache") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("could not get node key from cache")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return nil, false, errOIDCNodeKeyMissing - } - // retrieve machine information if it exist // The error is not important, because if it does not // exist, then this is a new machine and we will move