diff --git a/api.go b/api.go index 88c44225..bd39f5b0 100644 --- a/api.go +++ b/api.go @@ -321,33 +321,61 @@ func (h *Headscale) getMapResponse( Msgf("Generated map response: %s", tailMapResponseToString(resp)) var respBody []byte - if req.Compress == "zstd" { - src, err := json.Marshal(resp) + if machineKey.IsZero() { + // The TS2021 protocol does not rely anymore on the machine key to + // encrypt in a NaCl box the map response. We just send it back + // unencrypted via the encrypted Noise channel. + // declare the incoming size on the first 4 bytes + respBody, err := json.Marshal(resp) if err != nil { log.Error(). Caller(). - Str("func", "getMapResponse"). Err(err). - Msg("Failed to marshal response for the client") - - return nil, err + Msg("Cannot marshal map response") } - encoder, _ := zstd.NewWriter(nil) - srcCompressed := encoder.EncodeAll(src, nil) - respBody = h.privateKey.SealTo(machineKey, srcCompressed) + var srcCompressed []byte + if req.Compress == "zstd" { + encoder, _ := zstd.NewWriter(nil) + srcCompressed = encoder.EncodeAll(respBody, nil) + } else { + srcCompressed = respBody + } + + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(srcCompressed))) + data = append(data, srcCompressed...) + + return data, nil } else { - respBody, err = encode(resp, &machineKey, h.privateKey) - if err != nil { - return nil, err - } - } - // declare the incoming size on the first 4 bytes - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) + if req.Compress == "zstd" { + src, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Str("func", "getMapResponse"). + Err(err). + Msg("Failed to marshal response for the client") - return data, nil + return nil, err + } + + encoder, _ := zstd.NewWriter(nil) + srcCompressed := encoder.EncodeAll(src, nil) + respBody = h.privateKey.SealTo(machineKey, srcCompressed) + } else { + respBody, err = encode(resp, &machineKey, h.privateKey) + if err != nil { + return nil, err + } + } + // declare the incoming size on the first 4 bytes + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(respBody))) + data = append(data, respBody...) + + return data, nil + } } func (h *Headscale) getMapKeepAliveResponse( @@ -359,31 +387,36 @@ func (h *Headscale) getMapKeepAliveResponse( } var respBody []byte var err error - if mapRequest.Compress == "zstd" { - src, err := json.Marshal(mapResponse) - if err != nil { - log.Error(). - Caller(). - Str("func", "getMapKeepAliveResponse"). - Err(err). - Msg("Failed to marshal keepalive response for the client") - - return nil, err - } - encoder, _ := zstd.NewWriter(nil) - srcCompressed := encoder.EncodeAll(src, nil) - respBody = h.privateKey.SealTo(machineKey, srcCompressed) + if machineKey.IsZero() { + // The TS2021 protocol does not rely anymore on the machine key. + return json.Marshal(mapResponse) } else { - respBody, err = encode(mapResponse, &machineKey, h.privateKey) - if err != nil { - return nil, err - } - } - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) + if mapRequest.Compress == "zstd" { + src, err := json.Marshal(mapResponse) + if err != nil { + log.Error(). + Caller(). + Str("func", "getMapKeepAliveResponse"). + Err(err). + Msg("Failed to marshal keepalive response for the client") - return data, nil + return nil, err + } + encoder, _ := zstd.NewWriter(nil) + srcCompressed := encoder.EncodeAll(src, nil) + respBody = h.privateKey.SealTo(machineKey, srcCompressed) + } else { + respBody, err = encode(mapResponse, &machineKey, h.privateKey) + if err != nil { + return nil, err + } + } + data := make([]byte, reservedResponseHeaderSize) + binary.LittleEndian.PutUint32(data, uint32(len(respBody))) + data = append(data, respBody...) + + return data, nil + } } func (h *Headscale) handleMachineLogOut( @@ -571,7 +604,7 @@ func (h *Headscale) handleAuthKey( machineKeyStr = MachinePublicKeyStripPrefix(machineKey) } log.Debug(). - Str("func", "handleAuthKey"). + Caller(). Str("machine", registerRequest.Hostinfo.Hostname). Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} @@ -618,7 +651,7 @@ func (h *Headscale) handleAuthKey( } log.Debug(). - Str("func", "handleAuthKey"). + Caller(). Str("machine", registerRequest.Hostinfo.Hostname). Msg("Authentication key was valid, proceeding to acquire IP addresses") @@ -674,6 +707,14 @@ func (h *Headscale) handleAuthKey( resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() + + // TS2021 + if machineKey.IsZero() { + ctx.JSON(http.StatusOK, resp) + + return + } + respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). diff --git a/app.go b/app.go index 7fd1aaa3..d72f56ce 100644 --- a/app.go +++ b/app.go @@ -514,6 +514,7 @@ func (h *Headscale) createNoiseRouter() *gin.Engine { router := gin.Default() router.POST("/machine/register", h.NoiseRegistrationHandler) + router.POST("/machine/map", h.NoisePollNetMapHandler) return router } diff --git a/machine.go b/machine.go index 7afaaf9a..538903a0 100644 --- a/machine.go +++ b/machine.go @@ -518,11 +518,14 @@ func (machine Machine) toNode( } var machineKey key.MachinePublic - err = machineKey.UnmarshalText( - []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), - ) - if err != nil { - return nil, fmt.Errorf("failed to parse machine public key: %w", err) + if machine.MachineKey != "" { + // MachineKey is only used in the legacy protocol + err = machineKey.UnmarshalText( + []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), + ) + if err != nil { + return nil, fmt.Errorf("failed to parse machine public key: %w", err) + } } var discoKey key.DiscoPublic diff --git a/noise_api.go b/noise_api.go index 6db661ee..0914e23e 100644 --- a/noise_api.go +++ b/noise_api.go @@ -31,6 +31,10 @@ func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) { return } + log.Info().Caller(). + Str("nodekey", req.NodeKey.ShortString()). + Str("oldnodekey", req.OldNodeKey.ShortString()).Msg("Nodekys!") + now := time.Now().UTC() machine, err := h.GetMachineByNodeKeys(req.NodeKey, req.OldNodeKey) if errors.Is(err, gorm.ErrRecordNotFound) { @@ -49,7 +53,6 @@ func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) { if err != nil { log.Error(). Caller(). - Str("func", "RegistrationHandler"). Str("hostinfo.name", req.Hostinfo.Hostname). Err(err) @@ -128,6 +131,232 @@ func (h *Headscale) NoiseRegistrationHandler(ctx *gin.Context) { } } +// NoisePollNetMapHandler takes care of /machine/:id/map +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. +// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. +func (h *Headscale) NoisePollNetMapHandler(ctx *gin.Context) { + log.Trace(). + Caller(). + Str("id", ctx.Param("id")). + Msg("PollNetMapHandler called") + body, _ := io.ReadAll(ctx.Request.Body) + + req := tailcfg.MapRequest{} + if err := json.Unmarshal(body, &req); err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse MapRequest") + ctx.String(http.StatusInternalServerError, "Eek!") + + return + } + + machine, err := h.GetMachineByNodeKeys(req.NodeKey, key.NodePublic{}) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + log.Warn().Caller(). + Msgf("Ignoring request, cannot find node with node key %s", req.NodeKey.String()) + ctx.String(http.StatusUnauthorized, "") + + return + } + log.Error(). + Caller(). + Msgf("Failed to fetch machine from the database with NodeKey: %s", req.NodeKey.String()) + ctx.String(http.StatusInternalServerError, "") + + return + } + log.Trace().Caller(). + Str("NodeKey", req.NodeKey.ShortString()). + Str("machine", machine.Name). + Msg("Found machine in database") + + hname, err := NormalizeToFQDNRules( + req.Hostinfo.Hostname, + h.cfg.OIDC.StripEmaildomain, + ) + if err != nil { + log.Error(). + Caller(). + Str("hostinfo.name", req.Hostinfo.Hostname). + Err(err) + } + machine.Name = hname + machine.HostInfo = HostInfo(*req.Hostinfo) + machine.DiscoKey = DiscoPublicKeyStripPrefix(req.DiscoKey) + now := time.Now().UTC() + + // update ACLRules with peer informations (to update server tags if necessary) + if h.aclPolicy != nil { + err = h.UpdateACLRules() + if err != nil { + log.Error(). + Caller(). + Str("func", "handleAuthKey"). + Str("machine", machine.Name). + Err(err) + } + } + // From Tailscale client: + // + // ReadOnly is whether the client just wants to fetch the MapResponse, + // without updating their Endpoints. The Endpoints field will be ignored and + // LastSeen will not be updated and peers will not be notified of changes. + // + // The intended use is for clients to discover the DERP map at start-up + // before their first real endpoint update. + if !req.ReadOnly { + machine.Endpoints = req.Endpoints + machine.LastSeen = &now + } + h.db.Updates(machine) + + data, err := h.getMapResponse(key.MachinePublic{}, req, machine) + if err != nil { + log.Error(). + Caller(). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). + Err(err). + Msg("Failed to get Map response") + ctx.String(http.StatusInternalServerError, ":(") + + return + } + + // We update our peers if the client is not sending ReadOnly in the MapRequest + // so we don't distribute its initial request (it comes with + // empty endpoints to peers) + + // Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696 + log.Debug(). + Caller(). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). + Bool("readOnly", req.ReadOnly). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). + Msg("Client map request processed") + + if req.ReadOnly { + log.Info(). + Caller(). + Str("machine", machine.Name). + Msg("Client is starting up. Probably interested in a DERP map") + // log.Info().Str("machine", machine.Name).Bytes("resp", data).Msg("Sending DERP map to client") + + ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) + + return + } + + // There has been an update to _any_ of the nodes that the other nodes would + // need to know about + h.setLastStateChangeToNow(machine.Namespace.Name) + + // The request is not ReadOnly, so we need to set up channels for updating + // peers via longpoll + + // Only create update channel if it has not been created + log.Trace(). + Caller(). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). + Msg("Loading or creating update channel") + + // TODO: could probably remove all that duplication once generics land. + closeChanWithLog := func(channel interface{}, name string) { + log.Trace(). + Caller(). + Str("machine", machine.Name). + Str("channel", "Done"). + Msg(fmt.Sprintf("Closing %s channel", name)) + + switch c := channel.(type) { + case (chan struct{}): + close(c) + + case (chan []byte): + close(c) + } + } + + const chanSize = 8 + updateChan := make(chan struct{}, chanSize) + defer closeChanWithLog(updateChan, "updateChan") + + pollDataChan := make(chan []byte, chanSize) + defer closeChanWithLog(pollDataChan, "pollDataChan") + + keepAliveChan := make(chan []byte) + defer closeChanWithLog(keepAliveChan, "keepAliveChan") + + if req.OmitPeers && !req.Stream { + log.Info(). + Caller(). + Str("machine", machine.Name). + Msg("Client sent endpoint update and is ok with a response without peer list") + ctx.Data(http.StatusOK, "application/json; charset=utf-8", data) + + // It sounds like we should update the nodes when we have received a endpoint update + // even tho the comments in the tailscale code dont explicitly say so. + updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "endpoint-update"). + Inc() + updateChan <- struct{}{} + + return + } else if req.OmitPeers && req.Stream { + log.Warn(). + Caller(). + Str("machine", machine.Name). + Msg("Ignoring request, don't know how to handle it") + ctx.String(http.StatusBadRequest, "") + + return + } + + log.Info(). + Caller(). + Str("machine", machine.Name). + Msg("Client is ready to access the tailnet") + log.Info(). + Caller(). + Str("machine", machine.Name). + Msg("Sending initial map") + pollDataChan <- data + + log.Info(). + Caller(). + Str("machine", machine.Name). + Msg("Notifying peers") + updateRequestsFromNode.WithLabelValues(machine.Namespace.Name, machine.Name, "full-update"). + Inc() + updateChan <- struct{}{} + + h.PollNetMapStream( + ctx, + machine, + req, + key.MachinePublic{}, + pollDataChan, + keepAliveChan, + updateChan, + ) + log.Trace(). + Caller(). + Str("id", ctx.Param("id")). + Str("machine", machine.Name). + Msg("Finished stream, closing PollNetMap session") +} + func (h *Headscale) handleNoiseNodeValidRegistration( ctx *gin.Context, machine Machine,