/*
 * Copyright (C) 2015 Jared Boone, ShareBrained Technology, Inc.
 * Copyright (C) 2016 Furrtek
 *
 * This file is part of PortaPack.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street,
 * Boston, MA 02110-1301, USA.
 */

#include "msgpack.hpp"

// Decodes a MessagePack boolean located at the current seek position.
//   buffer : data being parsed
//   inc    : when true, the seek position moves past the boolean on success
//   value  : receives the decoded boolean
// Returns false when the seek position is at/after the end of the buffer or
// when the byte found is neither MSGPACK_TRUE nor MSGPACK_FALSE.
bool MsgPack::get_bool(const void* buffer, const bool inc, bool* value) {
    if (seek_ptr >= buffer_size) return false;  // Nothing left to read

    const uint8_t marker = static_cast<const uint8_t*>(buffer)[seek_ptr];
    const bool is_true = (marker == MSGPACK_TRUE);

    // Anything other than the two boolean markers is rejected
    if (!is_true && (marker != MSGPACK_FALSE)) return false;

    *value = is_true;

    if (inc) ++seek_ptr;
    return true;
}

// Fetches the byte at the current seek position without interpreting it.
//   buffer : data being parsed
//   inc    : when true, the seek position advances by one byte
//   byte   : receives the byte that was read
// Returns false (leaving *byte untouched) when the seek position is at/after
// the end of the buffer.
bool MsgPack::get_raw_byte(const void* buffer, const bool inc, uint8_t* byte) {
    const bool in_range = (seek_ptr < buffer_size);

    if (in_range) {
        const uint8_t* const bytes = static_cast<const uint8_t*>(buffer);
        *byte = bytes[seek_ptr];
        if (inc) ++seek_ptr;
    }

    return in_range;
}

// Fetches a 16-bit word (most significant byte first) at the current seek
// position without interpreting it.
//   buffer : data being parsed
//   inc    : when true, the seek position advances by two bytes
//   word   : receives the assembled 16-bit value
// Returns false when fewer than two bytes remain in the buffer.
bool MsgPack::get_raw_word(const void* buffer, const bool inc, uint16_t* word) {
    if ((seek_ptr + 1) >= buffer_size) return false;  // Second byte would be out of range

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer);
    const uint8_t high = bytes[seek_ptr];
    const uint8_t low = bytes[seek_ptr + 1];

    *word = (high << 8) | low;

    if (inc) seek_ptr += 2;
    return true;
}

// Decodes an unsigned 8-bit value at the current seek position. Two encodings
// are accepted: a positive fixnum (single byte with the top bit clear) and an
// MSGPACK_TYPE_U8 marker followed by one payload byte.
//   buffer : data being parsed
//   inc    : when true, the seek position moves past the whole value on
//            success (1 byte for a fixnum, 2 bytes for marker + payload)
//   value  : receives the decoded value
// Returns false when the buffer is exhausted, when the payload byte of a u8
// record is missing, or when the byte found is neither encoding. The seek
// position is never moved on failure.
bool MsgPack::get_u8(const void* buffer, const bool inc, uint8_t* value) {
    if (seek_ptr >= buffer_size) return false;  // End of buffer

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer);
    const uint8_t type = bytes[seek_ptr];

    if (!(type & 0x80)) {
        // Positive fixnum: the type byte is the value itself
        *value = type;
        if (inc) seek_ptr += 1;
        return true;
    }

    if (type == MSGPACK_TYPE_U8) {
        // The payload byte must also lie inside the buffer
        if ((seek_ptr + 1) >= buffer_size) return false;
        *value = bytes[seek_ptr + 1];
        if (inc) seek_ptr += 2;  // Marker + payload
        return true;
    }

    return false;  // Value isn't a u8 or fixnum
}

// TODO: Typecheck function

// Decodes an MSGPACK_TYPE_U16 record (marker byte followed by a 16-bit value,
// most significant byte first) at the current seek position.
//   buffer : data being parsed
//   inc    : when true, the seek position moves past the whole record
//            (3 bytes) on success; when false it is left untouched
//   value  : receives the decoded value
// Returns false when fewer than three bytes remain or when the marker is not
// MSGPACK_TYPE_U16. The seek position is never moved on failure.
bool MsgPack::get_u16(const void* buffer, const bool inc, uint16_t* value) {
    // Marker + 2 payload bytes must all be inside the buffer
    if ((seek_ptr + 2) >= buffer_size) return false;  // End of buffer

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer) + seek_ptr;

    if (bytes[0] != MSGPACK_TYPE_U16) return false;  // Value isn't a u16

    *value = static_cast<uint16_t>((bytes[1] << 8) | bytes[2]);

    if (inc) seek_ptr += 3;
    return true;
}

// Decodes an MSGPACK_TYPE_S32 record (marker byte followed by a 32-bit value,
// most significant byte first) at the current seek position.
//   buffer : data being parsed
//   inc    : when true, the seek position moves past the whole record
//            (5 bytes) on success; when false it is left untouched
//   value  : receives the decoded value
// Returns false when fewer than five bytes remain or when the marker is not
// MSGPACK_TYPE_S32. The seek position is never moved on failure.
bool MsgPack::get_s32(const void* buffer, const bool inc, int32_t* value) {
    // Marker + 4 payload bytes must all be inside the buffer
    if ((seek_ptr + 4) >= buffer_size) return false;  // End of buffer

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer) + seek_ptr;

    if (bytes[0] != MSGPACK_TYPE_S32) return false;  // Value isn't a s32

    // Assemble in an unsigned type so the top byte never shifts into the sign
    // bit of a promoted int.
    const uint32_t raw = (static_cast<uint32_t>(bytes[1]) << 24) |
                         (static_cast<uint32_t>(bytes[2]) << 16) |
                         (static_cast<uint32_t>(bytes[3]) << 8) |
                         static_cast<uint32_t>(bytes[4]);
    *value = static_cast<int32_t>(raw);

    if (inc) seek_ptr += 5;
    return true;
}

// Decodes a string record at the current seek position. Accepted encodings:
// fixstr (type byte 0xA0..0xBF, length in the low 5 bits), MSGPACK_TYPE_STR8
// (1 length byte) and MSGPACK_TYPE_STR16 (2 length bytes, most significant
// byte first).
//   buffer : data being parsed
//   inc    : when true, the seek position moves past the header and payload
//            on success; when false it is left untouched
//   value  : replaced with the decoded characters
// Returns false when the header or the payload does not fit inside the
// buffer, or when the record is not one of the string encodings above. On
// failure neither the seek position nor value is modified.
bool MsgPack::get_string(const void* buffer, const bool inc, std::string& value) {
    if (seek_ptr >= buffer_size) return false;  // End of buffer

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer);
    const uint8_t type = bytes[seek_ptr];

    size_t header_size = 0;
    size_t length = 0;

    if ((type & 0xE0) == 0xA0) {
        // Fixstr: length is embedded in the type byte
        header_size = 1;
        length = type & 0x1F;
    } else if (type == MSGPACK_TYPE_STR8) {
        if ((seek_ptr + 1) >= buffer_size) return false;  // Couldn't get str8 length
        header_size = 2;
        length = bytes[seek_ptr + 1];
    } else if (type == MSGPACK_TYPE_STR16) {
        if ((seek_ptr + 2) >= buffer_size) return false;  // Couldn't get str16 length
        header_size = 3;
        length = (static_cast<size_t>(bytes[seek_ptr + 1]) << 8) | bytes[seek_ptr + 2];
    } else {
        return false;  // Value isn't a string
    }

    // The header checks above guarantee payload_start <= buffer_size, so the
    // subtraction below cannot wrap.
    const size_t payload_start = seek_ptr + header_size;
    if (length > (buffer_size - payload_start)) return false;  // Truncated payload

    value.assign(reinterpret_cast<const char*>(bytes + payload_start), length);

    if (inc) seek_ptr = payload_start + length;
    return true;
}

// Prepares a key search over a serialized buffer: records the buffer size,
// rewinds the seek position and consumes the leading map16 header.
//   buffer : data to be searched
//   size   : number of bytes in buffer
// Returns false when size is zero, when the first byte is not
// MSGPACK_TYPE_MAP16, when the 16-bit map size cannot be read, or when that
// map size is zero. On success the seek position sits just after the header.
bool MsgPack::init_search(const void* buffer, const size_t size) {
    if (size == 0) return false;

    // Reset the read state for this buffer
    seek_ptr = 0;
    buffer_size = size;

    uint8_t type;
    const bool got_type = get_raw_byte(buffer, true, &type);
    if (got_type && (type != MSGPACK_TYPE_MAP16)) return false;  // First record isn't a map16

    uint16_t entry_count;
    if (!get_raw_word(buffer, true, &entry_count)) return false;  // Couldn't get map16 size

    return entry_count != 0;
}

bool MsgPack::skip(const void* buffer) {
    uint8_t byte, c;
    size_t length;

    if (!get_raw_byte(buffer, true, &byte)) return false;  // Couldn't get type

    if (!(byte & 0x80)) return true;         // Positive fixnum, already skipped by get_raw_byte
    if ((byte & 0xE0) == 0xE0) return true;  // Negative fixnum, already skipped by get_raw_byte
    if ((byte & 0xE0) == 0xA0) {             // Fixstr
        seek_ptr += (byte & 0x1F);
        return true;
    }
    if ((byte & 0xF0) == 0x80) {  // Fixmap
        length = (byte & 0x0F) * 2;
        for (c = 0; c < length; c++)
            skip(buffer);
        return true;
    }
    if ((byte & 0xF0) == 0x90) {  // Fixarray
        length = byte & 0x0F;
        for (c = 0; c < length; c++)
            skip(buffer);
        return true;
    }

    switch (byte) {
        case MSGPACK_NIL:
        case MSGPACK_FALSE:
        case MSGPACK_TRUE:  // Already skipped by get_raw_byte
            break;
        case MSGPACK_TYPE_U8:
        case MSGPACK_TYPE_S8:
            seek_ptr++;
            break;
        case MSGPACK_TYPE_U16:
        case MSGPACK_TYPE_S16:
            seek_ptr += 2;
            break;
        case MSGPACK_TYPE_U32:
        case MSGPACK_TYPE_S32:
            seek_ptr += 4;
            break;
        case MSGPACK_TYPE_U64:
        case MSGPACK_TYPE_S64:
            seek_ptr += 8;
            break;

        case MSGPACK_TYPE_STR8:
            if (!get_raw_byte(buffer, true, (uint8_t*)&length)) return false;  // Couldn't get str8 length
            seek_ptr += length;
            break;
        case MSGPACK_TYPE_STR16:
            if (!get_raw_word(buffer, true, (uint16_t*)&length)) return false;  // Couldn't get str16 length
            seek_ptr += length;
            break;

        case MSGPACK_TYPE_ARR16:
            if (!get_raw_word(buffer, true, (uint16_t*)&length)) return false;  // Couldn't get arr16 length
            for (c = 0; c < length; c++)
                skip(buffer);
            break;

        case MSGPACK_TYPE_MAP16:
            if (!get_raw_word(buffer, true, (uint16_t*)&length)) return false;  // Couldn't get map16 length
            for (c = 0; c < (length * 2); c++)
                skip(buffer);
            break;

        default:
            return false;  // Type unsupported
    }

    return true;
}

// Walks the key/value pairs starting at the current seek position, looking
// for a u16 key equal to record_id.
//   buffer    : data being searched
//   record_id : key to look for
// Returns true with the seek position just after the matching key (i.e. on
// its value). Returns false when the buffer runs out, when a key cannot be
// read as a u16, or when a value cannot be skipped.
bool MsgPack::search_key(const void* buffer, const MsgPack::RecID record_id) {
    uint8_t peeked;
    uint16_t current_key;

    for (;;) {
        // Peek only: stop once nothing is left to read
        if (!get_raw_byte(buffer, false, &peeked)) break;

        if (!get_u16(buffer, true, &current_key)) return false;  // Couldn't get key
        if (current_key == record_id) return true;               // Found record
        if (!skip(buffer)) return false;                         // Can't skip to next key
    }

    return false;
}

// Looks up record_id in a serialized buffer and reads its value as a boolean.
//   buffer    : serialized data
//   size      : number of bytes in buffer
//   record_id : key of the record to read
//   value     : receives the boolean on success
// Returns false when the buffer header is rejected by init_search, when the
// key is not found, or when the value is not a boolean.
bool MsgPack::msgpack_get(const void* buffer, const size_t size, const RecID record_id, bool* value) {
    // A failed init may leave the search state unset, so do not search after it
    if (!init_search(buffer, size)) return false;
    if (!search_key(buffer, record_id)) return false;   // Record not found
    if (!get_bool(buffer, false, value)) return false;  // Value isn't a bool

    return true;
}

// Looks up record_id in a serialized buffer and reads its value as an
// unsigned 8-bit integer.
//   buffer    : serialized data
//   size      : number of bytes in buffer
//   record_id : key of the record to read
//   value     : receives the integer on success
// Returns false when the buffer header is rejected, when the key is not
// found, or when the value is not a u8. Each step runs only if the previous
// one succeeded.
bool MsgPack::msgpack_get(const void* buffer, const size_t size, const RecID record_id, uint8_t* value) {
    return init_search(buffer, size) &&
           search_key(buffer, record_id) &&  // Record must exist
           get_u8(buffer, false, value);     // Value must be a u8
}

// Looks up record_id in a serialized buffer and reads its value as a signed
// 64-bit integer (MSGPACK_TYPE_S64 marker followed by 8 bytes, most
// significant byte first).
//   buffer    : serialized data
//   size      : number of bytes in buffer
//   record_id : key of the record to read
//   value     : receives the integer on success
// Returns false when the buffer header is rejected by init_search, when the
// key is not found, when fewer than nine bytes remain for the record, or when
// the marker is not MSGPACK_TYPE_S64.
bool MsgPack::msgpack_get(const void* buffer, const size_t size, const RecID record_id, int64_t* value) {
    if (!init_search(buffer, size)) return false;
    if (!search_key(buffer, record_id)) return false;  // Record not found

    // Marker + 8 payload bytes must all be inside the buffer
    if ((seek_ptr + 8) >= buffer_size) return false;  // End of buffer

    const uint8_t* const bytes = static_cast<const uint8_t*>(buffer) + seek_ptr;
    if (bytes[0] != MSGPACK_TYPE_S64) return false;  // Value isn't a s64

    // Accumulate in an unsigned 64-bit value: shifting promoted ints would
    // sign-extend the low word into the high 32 bits.
    uint64_t raw = 0;
    for (size_t i = 1; i <= 8; i++) {
        raw = (raw << 8) | bytes[i];
    }
    *value = static_cast<int64_t>(raw);

    return true;
}

// Looks up record_id in a serialized buffer and reads its value as a string.
//   buffer    : serialized data
//   size      : number of bytes in buffer
//   record_id : key of the record to read
//   value     : receives the string on success
// Returns false when the buffer header is rejected by init_search, when the
// key is not found, or when get_string cannot decode the value.
bool MsgPack::msgpack_get(const void* buffer, const size_t size, const RecID record_id, std::string& value) {
    // A failed init may leave the search state unset, so do not search after it
    if (!init_search(buffer, size)) return false;
    if (!search_key(buffer, record_id)) return false;     // Record not found
    if (!get_string(buffer, false, value)) return false;  // Value isn't a char array

    return true;
}

// Starts a new serialized buffer by writing an empty map16 header
// (MSGPACK_TYPE_MAP16 followed by a zero 16-bit entry count).
//   buffer : destination memory; written to despite the const-qualified
//            pointer type, and must hold at least 3 bytes
//   ptr    : receives the write offset just past the header (3)
void MsgPack::msgpack_init(const void* buffer, size_t* ptr) {
    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));
    size_t offset = 0;

    out[offset++] = MSGPACK_TYPE_MAP16;
    out[offset++] = 0;  // Entry count, high byte
    out[offset++] = 0;  // Entry count, low byte

    *ptr = offset;
}

// Appends a u16 key record for record_id at the write offset and bumps the
// 16-bit entry count stored in bytes 1..2 of the buffer.
//   buffer    : destination memory; written to despite the const-qualified
//               pointer type
//   ptr       : in/out write offset, advanced by the 3 bytes written
//   record_id : key to append (written most significant byte first)
void MsgPack::add_key(const void* buffer, size_t* ptr, const RecID record_id) {
    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));
    size_t& offset = *ptr;

    // Key record: type marker, then the id
    out[offset++] = MSGPACK_TYPE_U16;
    out[offset++] = record_id >> 8;
    out[offset++] = record_id & 0xFF;

    // Auto-inc MAP16 size which should be at the beginning of the buffer
    uint16_t entry_count = (out[1] << 8) | out[2];
    ++entry_count;

    out[1] = entry_count >> 8;
    out[2] = entry_count & 0xFF;
}

// Appends a boolean record: the key for record_id followed by a single
// MSGPACK_TRUE or MSGPACK_FALSE byte.
//   buffer    : destination memory; written to despite the const-qualified
//               pointer type
//   ptr       : in/out write offset, advanced past everything written
//   record_id : key of the record
//   value     : boolean to store
void MsgPack::msgpack_add(const void* buffer, size_t* ptr, const RecID record_id, bool value) {
    add_key(buffer, ptr, record_id);

    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));
    out[(*ptr)++] = value ? MSGPACK_TRUE : MSGPACK_FALSE;
}

// Appends an unsigned 8-bit record: the key for record_id followed by the
// value. Values below 128 are stored as a single fixnum byte; larger values
// get an MSGPACK_TYPE_U8 marker in front of the value byte.
//   buffer    : destination memory; written to despite the const-qualified
//               pointer type
//   ptr       : in/out write offset, advanced past everything written
//   record_id : key of the record
//   value     : integer to store
void MsgPack::msgpack_add(const void* buffer, size_t* ptr, const RecID record_id, uint8_t value) {
    add_key(buffer, ptr, record_id);

    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));
    size_t& offset = *ptr;

    // Only values with the top bit set need the explicit type marker
    if (value >= 128) out[offset++] = MSGPACK_TYPE_U8;

    out[offset++] = value;
}

// Appends a signed 64-bit record: the key for record_id, an MSGPACK_TYPE_S64
// marker, then the 8 value bytes, most significant byte first.
//   buffer    : destination memory; written to despite the const-qualified
//               pointer type
//   ptr       : in/out write offset, advanced past everything written
//   record_id : key of the record
//   value     : integer to store
void MsgPack::msgpack_add(const void* buffer, size_t* ptr, const RecID record_id, int64_t value) {
    add_key(buffer, ptr, record_id);

    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));
    size_t& offset = *ptr;

    out[offset++] = MSGPACK_TYPE_S64;

    // Emit from the top byte (shift 56) down to the bottom byte (shift 0)
    for (int shift = 56; shift >= 0; shift -= 8) {
        out[offset++] = (value >> shift) & 0xFF;
    }
}

// Appends a string record: the key for record_id, a length header, then the
// characters of value. The header is a fixstr byte for lengths below 32,
// MSGPACK_TYPE_STR8 + 1 length byte for lengths below 256, and
// MSGPACK_TYPE_STR16 + 2 length bytes (most significant first) for lengths
// below 65536.
//   buffer    : destination memory; written to despite the const-qualified
//               pointer type
//   ptr       : in/out write offset, advanced past everything written
//   record_id : key of the record
//   value     : string to store
// Returns false, leaving buffer and *ptr untouched, when value is 65536
// characters or longer; returns true otherwise.
bool MsgPack::msgpack_add(const void* buffer, size_t* ptr, const RecID record_id, std::string value) {
    const size_t length = value.size();

    // Validate before add_key() so a rejected string does not leave a key
    // without a value (and a bumped map count) behind.
    if (length >= 65536) return false;

    add_key(buffer, ptr, record_id);

    uint8_t* const out = const_cast<uint8_t*>(static_cast<const uint8_t*>(buffer));

    if (length < 32) {
        out[(*ptr)++] = static_cast<uint8_t>(length | 0xA0);  // Fixstr
    } else if (length < 256) {
        out[(*ptr)++] = MSGPACK_TYPE_STR8;
        out[(*ptr)++] = static_cast<uint8_t>(length);
    } else {
        out[(*ptr)++] = MSGPACK_TYPE_STR16;
        out[(*ptr)++] = static_cast<uint8_t>(length >> 8);
        out[(*ptr)++] = static_cast<uint8_t>(length & 0xFF);
    }

    // Full-width index: an 8-bit counter would wrap for lengths of 256+
    for (size_t i = 0; i < length; i++) {
        out[(*ptr)++] = static_cast<uint8_t>(value[i]);
    }

    return true;
}