#!/usr/bin/python
# -*- coding: UTF-8 -*-
import sys
import os
import logging

# Root logging: INFO and above, each record rendered as
# [time][level][message][file, line].
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s][%(message)s][%(filename)s, %(lineno)d]',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)


def crc16(x, invert):
    """Compute a CRC-16 (poly 0x1021, init 0x0000, MSB-first, no final XOR).

    These are the CRC-16/XMODEM parameters.

    :param x: iterable of bytes, given either as ints (bytes/bytearray/list)
        or as 1-character strings (e.g. a Python 2 ``str``).
    :param invert: if true, return the low byte first; otherwise high byte first.
    :return: exactly four uppercase hex digits without a ``0x`` prefix,
        e.g. ``"31C3"`` for ``"123456789"``.
    """
    crc = 0x0000
    poly = 0x1021
    for byte in x:
        if isinstance(byte, str):
            byte = ord(byte)
        crc ^= (byte << 8)
        for _ in range(8):
            # Mask to 16 bits on every shift: otherwise the register grows
            # without bound (becoming a ``long`` with an "L" hex suffix on
            # Python 2), which broke the hex-string slicing below.
            if crc & 0x8000:
                crc = ((crc << 1) ^ poly) & 0xFFFF
            else:
                crc = (crc << 1) & 0xFFFF

    # Zero-padded formatting keeps exactly two digits per byte even for
    # small CRC values.
    high = "%02X" % (crc >> 8)
    low = "%02X" % (crc & 0xFF)
    return low + high if invert else high + low


def param_check():
    """Verify that 1 to 8 command-line parameters were supplied.

    :return: True when the count is in range; otherwise logs an error and
        returns False.
    """
    param_num = len(sys.argv) - 1
    if not 1 <= param_num <= 8:
        logging.error("param num is %d, should between 1 and 8.", param_num)
        return False
    logging.info("param check pass.")
    return True


def env_check():
    """Verify that the bsptool binary used for all EEPROM access is installed.

    :return: True if /usr/sbin/bsptool exists; otherwise logs an error and
        returns False.
    """
    tool_present = os.path.exists(r"/usr/sbin/bsptool")
    if tool_present:
        logging.info("env check pass.")
    else:
        logging.error("can not find /usr/sbin/bsptool.")
    return tool_present


# Parsed parameter lists, appended in command-line order by the parsers
# below: the work-mode list, then the VF-number list, then one list per
# pass-through argument. write_all_data_2_eeprom() writes value j of list i
# at offset j * 8 + i.
all_param = list()

# Accepted work-mode lists. Each element is a PF type code as produced by
# work_mode_parser(): the index of the name in function_type
# (0 = nvme, 1 = virtio, 2 = network), or 0xff for an "ff" field.
# NOTE(review): "vaild" is a typo for "valid"; the name is kept because
# work_mode_parser() references it.
vaild_work_mode = [
    [2, 2, 0, 0], [2, 2, 1, 1], [2, 2, 0, 1], [2, 2, 1, 0], [0, 0], [1, 1, 1, 1], [0, 1], [1, 0],
    [1, 0xff, 1], [2], [2, 2],
    [2, 2, 0], [2, 2, 1], [0], [1], [1, 1], [1, 1, 1], [1, 0xff]
]
# PF type names accepted by work_mode_parser(); a name's tuple index is its code.
function_type = ("nvme", "virtio", "network")


def work_mode_parser(work_mode_param):
    """Decode the work-mode argument and append the result to all_param.

    The argument looks like "<name>-<type>-<type>-...". The leading name is
    ignored. Each following field (case-insensitive) must be one of the names
    in function_type, decoded to its index, or "ff", decoded to 0xff. The
    decoded list must also be one of the entries in vaild_work_mode.

    :param work_mode_param: raw command-line argument string.
    :return: True on success; otherwise logs an error and returns False.
    """
    fields = work_mode_param.split("-")
    if len(fields) >= 10:
        logging.error("%s param exceed.", work_mode_param)
        return False

    decoded = []
    for field in fields[1:]:
        name = field.lower()
        if name in function_type:
            decoded.append(function_type.index(name))
            continue
        if name == "ff":
            decoded.append(0xff)
            continue
        logging.error("not support pf type %s.", field)
        return False

    logging.info("work mode %s parser pass.", str(decoded))
    if decoded not in vaild_work_mode:
        logging.error("work mode %s check fail.", str(decoded))
        return False
    logging.info("work mode %s check pass.", str(decoded))
    all_param.append(decoded)
    return True


def vf_num_parser(vf_num_param):
    """Decode the VF-number argument and append the result to all_param.

    The argument looks like "<name>-<n1>-<n2>-...". The leading name is
    ignored, and every following field must be a decimal integer in 0..63.

    :param vf_num_param: raw command-line argument string.
    :return: True on success (the list is appended to all_param); otherwise
        logs an error and returns False.
    """
    vf_num_val = list()
    tmp = vf_num_param.split("-")
    if len(tmp) >= 10:
        logging.error("%s param exceed.", vf_num_param)
        return False
    for field in tmp[1:]:
        # Accept ASCII digits only: str.isdigit() also accepts characters
        # such as superscripts that int() cannot parse, which used to crash
        # with ValueError. Since no sign is allowed, only the upper bound
        # needs checking.
        if not field or any(c not in "0123456789" for c in field):
            logging.error("vf num %s is not digit.", field)
            return False
        tmp_val = int(field)
        if tmp_val >= 64:
            logging.error("vf num %s should between 0 and 63.", field)
            return False
        vf_num_val.append(tmp_val)
    logging.info("vf num %s parser pass.", str(vf_num_val))
    all_param.append(vf_num_val)
    return True


def param_parser():
    """Parse sys.argv[1:] into all_param.

    argv[1] is the work mode (see work_mode_parser) and argv[2] the VF numbers
    (see vf_num_parser). Every later argument is passed through as a list of
    byte values. Each argument has the form "<name>-<v1>-<v2>-..." with at
    most 8 values, and the leading name field is ignored.

    :return: True if every argument parsed; otherwise logs an error and
        returns False.
    """
    # Parse the work mode.
    if len(sys.argv) >= 2 and not work_mode_parser(sys.argv[1]):
        return False

    # Parse the VF numbers.
    if len(sys.argv) >= 3 and not vf_num_parser(sys.argv[2]):
        return False

    # Remaining arguments are passed through unchanged as byte values.
    for arg in sys.argv[3:]:
        fields = arg.split("-")
        if len(fields) >= 10:
            # Log the raw argument, as the other parsers do.
            logging.error("%s param exceed.", arg)
            return False
        values = list()
        for field in fields[1:]:
            # ASCII digits only: str.isdigit() also accepts characters that
            # int() rejects.
            if not field or any(c not in "0123456789" for c in field):
                logging.error("%s is not digit.", field)
                return False
            value = int(field)
            # Each value is written to a single EEPROM byte.
            if value > 0xff:
                logging.error("%s should be between 0 and 255.", field)
                return False
            values.append(value)
        all_param.append(values)
    logging.info("all param %s parser pass.", str(all_param))
    return True


# The config data starts at EEPROM address 0x400, and the EEPROM uses 2-byte
# addresses. The bsptool readiic/writeiic helpers below pass these as bus,
# slave address, EEPROM_ADDR + offset and address width.
IIC_BUS = 0  # I2C bus number passed to bsptool
EEPROM_SLAVE_ADDR = 0x51  # I2C slave address of the EEPROM
EEPROM_ADDR = 0x400  # base added to every offset used by the helpers
EEPROM_ADDR_WIDTH = 2  # address width in bytes


def show_eeprom(ofset, length):
    """Dump *length* bytes starting at config offset *ofset* via bsptool readiic.

    The command's output is not captured, so it appears directly on stdout.
    Its exit status is ignored.
    """
    args = (IIC_BUS, EEPROM_SLAVE_ADDR, EEPROM_ADDR + ofset, EEPROM_ADDR_WIDTH, length)
    os.system("bsptool -c readiic " + " ".join("%d" % arg for arg in args))


def read_eeprom_1byte(ofset):
    """Read one byte at config offset *ofset* via bsptool readiic.

    :param ofset: offset relative to EEPROM_ADDR.
    :return: bsptool's stdout with trailing whitespace stripped. The caller
        in this file treats anything not starting with "0x" as a read failure.
    """
    cmd = "bsptool -c readiic %d %d %d %d 1" % (IIC_BUS, EEPROM_SLAVE_ADDR, EEPROM_ADDR + ofset, EEPROM_ADDR_WIDTH)
    pipe = os.popen(cmd)
    try:
        data = pipe.read()
    finally:
        # Close explicitly so the pipe and child process are released now,
        # not whenever the garbage collector gets to the file object.
        pipe.close()
    return data.rstrip()


def write_eeprom_1byte(ofset, val):
    """Write the single byte *val* at config offset *ofset* via bsptool writeiic.

    bsptool's stdout is discarded.

    :return: the os.system() status; callers in this file treat 0 as success.
    """
    target = EEPROM_ADDR + ofset
    return os.system(
        "bsptool -c writeiic %d %d %d %d 1 %d > /dev/null"
        % (IIC_BUS, EEPROM_SLAVE_ADDR, target, EEPROM_ADDR_WIDTH, val)
    )


def write_all_data_2_eeprom():
    """Write all_param, the management bytes and a CRC16 to the EEPROM.

    Layout (offsets relative to EEPROM_ADDR), as written below:
      0..63   parameter matrix: value j of all_param[i] at offset j * 8 + i
      64      0x55 valid marker
      65      len(all_param[0])  ("valid line")
      66      len(all_param)     ("valid column")
      67..69  left at 0xff
      70      CRC16 low byte
      71      CRC16 high byte
    The CRC is computed over offsets 0..69 as read back from the EEPROM.

    :return: True on success. On the first failed read or write, logs an
        error and returns False.
    """
    def _write(offset, val):
        # Write one byte; log the offset and return False on failure.
        if write_eeprom_1byte(offset, val) != 0:
            logging.error("write %d fail.", offset)
            return False
        return True

    if not all_param:
        logging.error("no param to write.")
        return False

    # Reset the whole 72-byte region (offsets 0..71) to 0xff. The original
    # range(0, 73) also clobbered offset 72, which lies outside the region.
    for offset in range(0, 72):
        if not _write(offset, 0xff):
            return False

    # Parameter matrix: column i = i-th parsed argument, row j = j-th value.
    for i, param in enumerate(all_param):
        for j, value in enumerate(param):
            if not _write(j * 8 + i, value):
                return False

    # Management bytes.
    if not _write(64, 0x55):  # valid marker
        return False
    if not _write(65, len(all_param[0])):  # valid line count
        return False
    if not _write(66, len(all_param)):  # valid column count
        return False

    # CRC over offsets 0..69, computed from what the EEPROM actually holds.
    tmp_data = list()
    for offset in range(0, 70):
        read_data = read_eeprom_1byte(offset)
        if not read_data.startswith("0x"):
            logging.error("Read fail, %s.", read_data)
            return False
        # Parse the read that was just validated; the original issued a
        # second, unvalidated bus read here.
        try:
            tmp_data.append(int(read_data, 16))
        except ValueError:
            logging.error("Read fail, %s.", read_data)
            return False
    crc_data = int(crc16(tmp_data, False), 16)
    logging.info("crc data is %#x.", crc_data)
    if not _write(70, crc_data & 0xff):  # offset 70: CRC low byte
        return False
    if not _write(71, (crc_data >> 8) & 0xff):  # offset 71: CRC high byte
        return False
    return True


def show_all_data_in_eeprom():
    """Log a header, then dump the 72-byte config region as nine 8-byte rows."""
    logging.info("show data in hardware.")
    for row_start in range(0, 72, 8):
        show_eeprom(row_start, 8)


# Usage example:
#   python pcie_ep_cfg.py pf-nvme-nvme vf_num-0-0
if __name__ == "__main__":
    # Validate the argument count.
    if not param_check():
        sys.exit(1)
    # Make sure bsptool is available.
    if not env_check():
        sys.exit(1)
    # Parse the arguments into all_param.
    if not param_parser():
        sys.exit(1)
    # Write to the hardware (including the CRC), then read back and display.
    if not write_all_data_2_eeprom():
        sys.exit(1)
    show_all_data_in_eeprom()
    # Done; the new configuration is used after a reboot.
    logging.info("pcie ep cfg complete, reboot to take effect.")
