#include "msi.h"
#include "pci.h"
#include <common/errno.h>

/**
 * @brief Generate the MSI message (address/data) for a descriptor.
 *
 * NOTE(review): declared extern here; presumably implemented by the
 * architecture layer (name prefix msi_arch_) — confirm in the arch sources.
 *
 * @param msi_desc MSI descriptor
 * @return struct msi_msg_t* pointer to the MSI message (stored inside the descriptor)
 */
extern struct msi_msg_t *msi_arch_get_msg(struct msi_desc_t *msi_desc);

/**
 * @brief 启用 Message Signaled Interrupts
 *
 * @param header 设备header
 * @param vector 中断向量号
 * @param processor 要投递到的处理器
 * @param edge_trigger 是否边缘触发
 * @param assert 是否高电平触发
 *
 * @return 返回码
 */
int pci_enable_msi(struct msi_desc_t *msi_desc)
{
    struct pci_device_structure_header_t *ptr = msi_desc->pci_dev;
    uint32_t cap_ptr;
    uint32_t tmp;
    uint16_t message_control;
    uint64_t message_addr;

    // 先尝试获取msi-x,若不存在,则获取msi capability
    if (msi_desc->pci.msi_attribute.is_msix)
    {
        cap_ptr = pci_enumerate_capability_list(ptr, 0x11);
        if (((int32_t)cap_ptr) < 0)
        {
            cap_ptr = pci_enumerate_capability_list(ptr, 0x05);
            if (((int32_t)cap_ptr) < 0)
                return -ENOSYS;
            msi_desc->pci.msi_attribute.is_msix = 0;
        }
    }
    else
    {
        cap_ptr = pci_enumerate_capability_list(ptr, 0x05);
        if (((int32_t)cap_ptr) < 0)
            return -ENOSYS;
        msi_desc->pci.msi_attribute.is_msix = 0;
    }
    // 获取msi消息
    msi_arch_get_msg(msi_desc);

    if (msi_desc->pci.msi_attribute.is_msix)
    {
    }
    else
    {
        tmp = pci_read_config(ptr->bus, ptr->device, ptr->func, cap_ptr); // 读取cap+0x0处的值
        message_control = (tmp >> 16) & 0xffff;

        // 写入message address
        message_addr = ((((uint64_t)msi_desc->msg.address_hi) << 32) | msi_desc->msg.address_lo); // 获取message address
        pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr + 0x4, (uint32_t)(message_addr & 0xffffffff));

        if (message_control & (1 << 7)) // 64位
            pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr + 0x8, (uint32_t)((message_addr >> 32) & 0xffffffff));

        // 写入message data

        tmp = msi_desc->msg.data;
        if (message_control & (1 << 7)) // 64位
            pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr + 0xc, tmp);
        else
            pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr + 0x8, tmp);

        // 使能msi
        tmp = pci_read_config(ptr->bus, ptr->device, ptr->func, cap_ptr); // 读取cap+0x0处的值
        tmp |= (1 << 16);
        pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr, tmp);
    }

    return 0;
}

/**
 * @brief Locate the MSI capability of a PCI device by walking its capability list
 *
 * @param ptr     PCI device header
 * @param cap_out receives the config-space offset of the MSI capability on success
 *
 * @return 0 on success
 * @return -ENOSYS if the device is a CardBus bridge, has no capability list,
 *         or has no MSI capability
 * @return -EINVAL for an unknown header type
 */
static int pci_msi_find_cap(struct pci_device_structure_header_t *ptr, uint32_t *cap_out)
{
    uint32_t cap_ptr;
    uint32_t cap_header;
    int hops;

    // Bit 7 of the header type only flags a multi-function device;
    // the header layout is encoded in bits 6:0.
    switch (ptr->HeaderType & 0x7f)
    {
    case 0x00: // general device
        cap_ptr = ((struct pci_device_structure_general_device_t *)ptr)->Capabilities_Pointer;
        break;
    case 0x01: // pci to pci bridge
        cap_ptr = ((struct pci_device_structure_pci_to_pci_bridge_t *)ptr)->Capability_Pointer;
        break;
    case 0x02: // pci to card bus bridge
        return -ENOSYS;
    default: // should not be reached
        return -EINVAL;
    }

    // Status bit 4: the device implements a capability list
    if (!(ptr->Status & 0x10))
        return -ENOSYS;

    /*
     * Walk the whole list; MSI is not necessarily the first capability.
     * The two low bits of each pointer are reserved, and capabilities live at
     * offset 0x40 or above. (256 - 0x40) / 4 = 48 bounds the walk so that a
     * malformed (cyclic) list cannot loop forever.
     */
    cap_ptr &= 0xfc;
    for (hops = 0; hops < 48 && cap_ptr >= 0x40; ++hops)
    {
        cap_header = pci_read_config(ptr->bus, ptr->device, ptr->func, cap_ptr);
        if ((cap_header & 0xff) == 0x05) // 0x05: MSI capability ID
        {
            *cap_out = cap_ptr;
            return 0;
        }
        cap_ptr = (cap_header >> 8) & 0xfc; // next-capability pointer
    }
    return -ENOSYS;
}

/**
 * @brief Set or clear the MSI Enable bit of a device's MSI capability
 *
 * @param header PCI device header
 * @param enable non-zero to enable MSI, zero to disable it
 * @return int 0 on success, otherwise the error from pci_msi_find_cap()
 */
static int pci_msi_write_enable(void *header, int enable)
{
    struct pci_device_structure_header_t *ptr = (struct pci_device_structure_header_t *)header;
    uint32_t cap_ptr;
    uint32_t tmp;
    int ret;

    ret = pci_msi_find_cap(ptr, &cap_ptr);
    if (ret != 0)
        return ret;

    // MSI Enable is bit 0 of Message Control, i.e. bit 16 of the capability's first dword
    tmp = pci_read_config(ptr->bus, ptr->device, ptr->func, cap_ptr);
    if (enable)
        tmp |= (1u << 16);
    else
        tmp &= ~(1u << 16);
    pci_write_config(ptr->bus, ptr->device, ptr->func, cap_ptr, tmp);

    return 0;
}

/**
 * @brief Enable MSI on a device whose MSI registers are already configured
 *
 * @param header device header
 * @return int 0 on success; -ENOSYS if the device has no MSI capability (or is
 *         a CardBus bridge); -EINVAL for an unknown header type
 */
int pci_start_msi(void *header)
{
    return pci_msi_write_enable(header, 1);
}

/**
 * @brief Disable MSI on the given device
 *
 * @param header pci header
 * @return int 0 on success; -ENOSYS if the device has no MSI capability (or is
 *         a CardBus bridge); -EINVAL for an unknown header type
 */
int pci_disable_msi(void *header)
{
    return pci_msi_write_enable(header, 0);
}