diff --git a/protocol_common_utils.go b/protocol_common_utils.go index 5939b6a4..3eb54715 100644 --- a/protocol_common_utils.go +++ b/protocol_common_utils.go @@ -10,35 +10,87 @@ import ( "tailscale.com/types/key" ) -func (h *Headscale) getMapKeepAliveResponse( - machineKey key.MachinePublic, +func (h *Headscale) getMapResponseData( mapRequest tailcfg.MapRequest, + machine *Machine, + isNoise bool, ) ([]byte, error) { - mapResponse := tailcfg.MapResponse{ + mapResponse, err := h.generateMapResponse(mapRequest, machine) + if err != nil { + return nil, err + } + + if isNoise { + return h.marshalResponse(mapResponse, mapRequest.Compress, key.MachinePublic{}) + } + + var machineKey key.MachinePublic + err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse client key") + + return nil, err + } + + return h.marshalResponse(mapResponse, mapRequest.Compress, machineKey) +} + +func (h *Headscale) getMapKeepAliveResponseData( + mapRequest tailcfg.MapRequest, + machine *Machine, + isNoise bool, +) ([]byte, error) { + keepAliveResponse := tailcfg.MapResponse{ KeepAlive: true, } - var respBody []byte - var err error - if mapRequest.Compress == ZstdCompression { - src, err := json.Marshal(mapResponse) - if err != nil { - log.Error(). - Caller(). - Str("func", "getMapKeepAliveResponse"). - Err(err). - Msg("Failed to marshal keepalive response for the client") - return nil, err - } + if isNoise { + return h.marshalResponse(keepAliveResponse, mapRequest.Compress, key.MachinePublic{}) + } + + var machineKey key.MachinePublic + err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot parse client key") + + return nil, err + } + + return h.marshalResponse(keepAliveResponse, mapRequest.Compress, machineKey) +} + +func (h *Headscale) marshalResponse( + resp interface{}, + compression string, + machineKey key.MachinePublic, +) ([]byte, error) { + jsonBody, err := json.Marshal(resp) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Cannot marshal map response") + } + + var respBody []byte + if compression == ZstdCompression { encoder, _ := zstd.NewWriter(nil) - srcCompressed := encoder.EncodeAll(src, nil) - respBody = h.privateKey.SealTo(machineKey, srcCompressed) + respBody = encoder.EncodeAll(jsonBody, nil) + if !machineKey.IsZero() { // if legacy protocol + respBody = h.privateKey.SealTo(machineKey, respBody) + } } else { - respBody, err = encode(mapResponse, &machineKey, h.privateKey) - if err != nil { - return nil, err + if !machineKey.IsZero() { // if legacy protocol + respBody = h.privateKey.SealTo(machineKey, jsonBody) } } + data := make([]byte, reservedResponseHeaderSize) binary.LittleEndian.PutUint32(data, uint32(len(respBody))) data = append(data, respBody...)