#  coding=UTF-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2019-2023. All rights reserved.

"""
@time: 2020/08/18
@file: check_cx_nic_version.py
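@function: Check that the installed CX (Mellanox ConnectX) NIC driver and firmware versions are not lower than the versions configured in the mapping table.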
"""
import re

from Common.base import context_util
from Common.base import entity
from Common.base.constant import MsgKey
from Common.base.entity import DeployException
from Common.base.entity import ResultFactory, Compare
from Common.protocol import ssh_util
from Common.service import check_main_board_uids_service

# Execution environment handed to every helper call. `py_java_env` is not
# defined in this file — presumably injected by the hosting Java/Jython runtime
# (TODO confirm).
PY_JAVA_ENV = py_java_env
# Mapping-table keys for the CX NIC driver version (regular / V6 main boards).
CX_NIC_DRIVER_KEY = "CX_NIC_Driver_ver"
CX_NIC_DRIVER_KEY_V6 = "CX_NIC_Driver_ver_V6"
# Mapping-table key for the 1822 normalized driver version; consulted as a
# fallback when no CX driver version is configured.
CX_1822_NORMALIZED_DRIVER_KEY = "1822_mlnx_normalized_driver_ver"
# Firmware mapping-table key templates; "{}" is filled with a CX generation
# number taken from CX_NIC_NUM.
CX_NIC_FW_KEY = "CX{}_NIC_FW_Ver"
CX_NIC_FW_KEY_V6 = "CX{}_NIC_FW_Ver_V6"
# CX generations whose firmware versions are checked.
CX_NIC_NUM = (4, 5, 6)
# Maps the numeric id in "/dev/mst/mt<id>_..." device paths (as printed by
# `mst status`) to the CX generation number.
CX_NIC_MST_ID_TO_NUM = {"4117": 4, "4119": 5, "4123": 6}

# Mapping versions (matched as substrings of the current mapping version) that
# still query the CX driver with `ofed_info -s`; all others use `rdma_ver`.
CUR_NOT_SUPPORT_VERSION = ['8.0.1', '8.1.0', '8.1.1']


def execute(task):
    """Plugin entry point: run the CX NIC version check for ``task``.

    A DeployException raised while building or running the checker is turned
    into a not-pass result carrying the exception's origin info and message.
    """
    try:
        outcome = CheckCxNicVersion(task).check()
    except DeployException as deploy_err:
        outcome = ResultFactory.create_not_pass(deploy_err.origin_info, err_msg=deploy_err.err_msg)
    return outcome


class CheckCxNicVersion(object):
    """Check the CX (Mellanox) NIC driver and firmware versions on a node.

    The installed versions are compared with the versions configured in the
    mapping table. Two driver flavours are handled:

    * the CX driver (mapping key CX_NIC_DRIVER_KEY / CX_NIC_DRIVER_KEY_V6),
      whose firmware is additionally checked per CX generation via mst/flint;
    * the 1822 normalized driver (CX_1822_NORMALIZED_DRIVER_KEY), used when no
      CX driver version is configured; it bundles its firmware, so no separate
      firmware check is done.
    """

    def __init__(self, task):
        self.task = task
        self.deploy_node = context_util.get_deploy_node(PY_JAVA_ENV)
        self._logger = entity.create_logger(__file__)
        # Raw command outputs collected for the result report.
        self._ssh_rets = list()
        # Messages collected for the result report.
        self._err_msgs = list()
        self.not_support_msg = entity.create_msg(MsgKey.NOT_INVOLVE)
        self._use_normalized_driver = False
        self.login_info = context_util.get_login_info(PY_JAVA_ENV)
        self.is_v6_type = check_main_board_uids_service.is_v6_supported_main_board(self.login_info, self._logger)
        self._server_version = context_util.get_current_mapping_version()
        # Initialised up front so check()'s DeployException handler can always
        # reference them, even if _get_mapping_msg() fails before assigning.
        self._mapping_msg = ''
        self._mapping_driver_ver = ''
        self._mapping_driver_url = ''
        self._mapping_fw_version_dict = {}
        self._mapping_fw_url_dict = {}
        self.driver_ssh_ret = ''

    def check(self):
        """Run the driver and firmware checks and build the result.

        :return: a pass result when no relevant mapping key is configured, when
            no CX driver/card is found, or when both driver and firmware
            versions match; otherwise a not-pass result. A DeployException
            raised during the check is converted into a not-pass result (and
            auto-retry is enabled when the exception reports missing info).
        """
        if not self._contain_check_key():
            return ResultFactory.create_pass()
        try:
            self._get_mapping_msg()
            if not self._need_check_cx_nic():
                return ResultFactory.create_pass(self._ssh_rets,
                                                 "\n".join([entity.create_msg("cx.nic.not.found"), self._mapping_msg]))
            driver_match = self._check_driver_version()
            fw_match = self._check_fw_version()
            if driver_match and fw_match:
                self._err_msgs.append(entity.create_source_file_msg(PY_JAVA_ENV, entity.build_url_error_msg(
                    self._mapping_driver_url, self._mapping_msg)))
                return ResultFactory.create_pass(self._ssh_rets, self._err_msgs)
            self._err_msgs.insert(0, entity.build_driver_tool_tips())
            self._err_msgs.append(entity.create_source_file_msg(PY_JAVA_ENV, entity.build_url_error_msg(
                self._mapping_driver_url, self._mapping_msg)))
            return ResultFactory.create_not_pass(self._ssh_rets, self._err_msgs)
        except DeployException as de:
            self._logger.error(de.message)
            self._ssh_rets.append(de.origin_info)
            self._err_msgs.insert(0, entity.create_source_file_msg(PY_JAVA_ENV, self._mapping_msg))
            self._err_msgs.append(de.err_msg)
            if de.may_info_miss():
                self.task.openAutoRetry()
            return ResultFactory.create_not_pass(self._ssh_rets,
                                                 self._err_msgs)

    def _contain_check_key(self):
        """Return whether any driver or CX4/5/6 firmware mapping key applies.

        The V6 key variants are used when the main board is V6-supported.
        """
        check_keys = [CX_NIC_DRIVER_KEY, CX_1822_NORMALIZED_DRIVER_KEY]
        current_fw_key = CX_NIC_FW_KEY
        if self.is_v6_type:
            check_keys = [CX_NIC_DRIVER_KEY_V6, CX_1822_NORMALIZED_DRIVER_KEY]
            current_fw_key = CX_NIC_FW_KEY_V6
        for num in CX_NIC_NUM:
            check_keys.append(current_fw_key.format(num))
        return context_util.contain_need_check_key(PY_JAVA_ENV, check_keys)

    def _get_mapping_msg(self):
        """Load mapping driver/firmware versions and build ``self._mapping_msg``.

        Also fills ``self._mapping_fw_version_dict`` / ``_mapping_fw_url_dict``
        (keyed by CX generation) and records a "not involved" version on the
        deploy node for entries missing from the mapping table.
        """
        mapping_msgs = list()
        self._obtain_mapping_driver_version()
        driver_msg = self._mapping_driver_ver if self._mapping_driver_ver \
            else self.not_support_msg
        mapping_msgs.append(entity.create_msg(
            "cx.nic.match.version").format(driver_msg))
        if not self._mapping_driver_url:
            self.deploy_node.putVersion(context_util.get_version_key_enum().NIC_DRIVER.getKey(), self.not_support_msg)
        # The normalized driver bundles its own firmware, so no firmware
        # versions are needed for it.
        self._mapping_fw_version_dict = {}

        if not self._use_normalized_driver:
            current_fw_key = CX_NIC_FW_KEY
            if self.is_v6_type:
                current_fw_key = CX_NIC_FW_KEY_V6
            for num in CX_NIC_NUM:
                ver = context_util.get_mapping_attribute(
                    PY_JAVA_ENV, current_fw_key.format(num))
                self._mapping_fw_version_dict[num] = ver
                self._mapping_fw_url_dict[num] = context_util.get_mapping_attribute_url(
                    PY_JAVA_ENV, current_fw_key.format(num))
                fw_msg = ver if ver else self.not_support_msg
                mapping_msgs.append(entity.create_msg("cx.nic.match.fw.version").format(num, fw_msg))
                if not ver:
                    self.deploy_node.putVersion(context_util.get_version_key_enum().NIC_FW.getKey(),
                                                self.not_support_msg)
        self._mapping_msg = ", ".join(mapping_msgs)

    def _obtain_mapping_driver_version(self):
        """Read the mapping driver version and URL.

        The CX driver key (V6 variant on V6 boards) is tried first; if it has
        no version, the 1822 normalized driver key is used instead and
        ``self._use_normalized_driver`` is set when that one has a version.
        """
        current_driver_key = CX_NIC_DRIVER_KEY
        if self.is_v6_type:
            current_driver_key = CX_NIC_DRIVER_KEY_V6
        self._mapping_driver_ver = context_util.get_mapping_attribute(
            PY_JAVA_ENV, current_driver_key)
        self._mapping_driver_url = context_util.get_mapping_attribute_url(
            PY_JAVA_ENV, current_driver_key)
        if self._mapping_driver_ver:
            return
        self._mapping_driver_ver = context_util. \
            get_mapping_attribute(PY_JAVA_ENV, CX_1822_NORMALIZED_DRIVER_KEY)
        self._mapping_driver_url = context_util. \
            get_mapping_attribute_url(PY_JAVA_ENV, CX_1822_NORMALIZED_DRIVER_KEY)
        self._use_normalized_driver = bool(self._mapping_driver_ver)

    def _contains_cx_nic_card(self):
        """Return whether `lspci | grep Mellanox` lists any device.

        The output is treated as containing two non-device lines (presumably
        the echoed command and a trailing prompt line — confirm against
        ssh_util), so more than two lines means at least one device matched.
        """
        ssh_ret = ssh_util.exec_ssh_cmd_nocheck(PY_JAVA_ENV,
                                                "lspci | grep Mellanox")
        self._ssh_rets.append(ssh_ret)
        cmd_and_end_line_num = 2
        return len(ssh_ret.splitlines()) > cmd_and_end_line_num

    def _need_check_cx_nic(self):
        """Query the installed driver version and decide whether to check.

        Runs `rdma_ver` for the normalized driver or for mapping versions not
        listed in CUR_NOT_SUPPORT_VERSION, otherwise `ofed_info -s`; the output
        is kept in ``self.driver_ssh_ret``.

        :return: True when the query command is valid; False when it is invalid
            and no Mellanox card is present.
        :raises DeployException: when the command is invalid although a
            Mellanox card is present.
        """
        if self._use_normalized_driver or self._use_normalized_cmd():
            self.driver_ssh_ret = ssh_util.exec_ssh_cmd_disable_none(
                PY_JAVA_ENV, "rdma_ver")
        else:
            self.driver_ssh_ret = ssh_util.exec_ssh_cmd_disable_none(
                PY_JAVA_ENV, "ofed_info -s")
        self._ssh_rets.append(self.driver_ssh_ret)
        if ssh_util.is_invalid_cmd(self.driver_ssh_ret):
            if self._contains_cx_nic_card():
                raise DeployException("invalid cmd", err_msg=entity.
                                      create_msg("cmd.not.invalid"))
            return False
        return True

    def _check_driver_version(self):
        """Check the installed driver version against the mapping version.

        Two driver kinds are handled: the CX driver and the 1822 normalized
        driver.

        :return: True when the installed version is not lower than the mapping
            version (or no mapping version is configured for the CX driver);
            False otherwise, including when the version cannot be parsed.
        """

        def trans_normalized_driver_version_2_digital_version(driver_version):
            # Example: 4.19-26-pacific-47f-0
            # Normalized driver versions contain platform tokens (e.g. pacific,
            # arm, x86 — see context_util.NORMALIZED_DRIVER_KEYS); replace them
            # with "0" so the string can be compared digitally.
            replace_version = driver_version.strip()
            for driver_key in context_util.NORMALIZED_DRIVER_KEYS:
                replace_version = replace_version.replace(driver_key, "0")
            # The second-to-last segment is hexadecimal; convert it to decimal.
            versions = replace_version.split("-")
            versions[-2] = str(int(versions[-2], 16))
            return "-".join(versions)

        version_match = True
        msg_res = entity.create_msg("cx.nic.driver.current.version")
        if self._use_normalized_driver:
            # The version is on the second output line. The original author
            # notes the command output is long enough for this index to be
            # safe (not enforced here — NOTE(review): confirm).
            ver_info = self.driver_ssh_ret.splitlines()[1]
            self.deploy_node.putVersion(context_util.get_version_key_enum().NIC_DRIVER.getKey(), ver_info)
            if not context_util.contain_normalized_driver_key(ver_info) or Compare.compare_digital_version(
                    trans_normalized_driver_version_2_digital_version(
                        ver_info),
                    trans_normalized_driver_version_2_digital_version(
                        self._mapping_driver_ver)) < 0:
                self.deploy_node.putResult(context_util.get_version_key_enum().NIC_DRIVER.getKey(),
                                           context_util.get_not_pass_key())
                version_match = False
            self._err_msgs.append(msg_res.format(ver_info))
            return version_match
        match = re.findall(r"-([0-9.-]+)", self.driver_ssh_ret)
        if match:
            self._logger.info("driver version: {}".format(match[0]))
            # Only compare when a mapping driver version is configured (it can
            # be empty when only firmware keys exist), mirroring how
            # _check_fw_version skips card types without a mapping version.
            if self._mapping_driver_ver and Compare.compare_digital_version(
                    match[0], self._mapping_driver_ver) < 0:
                version_match = False
            self._err_msgs.append(
                entity.build_url_error_msg(self._mapping_driver_url, msg_res.format(match[0])))
        else:
            self._err_msgs.append(entity.create_msg(
                "match.cx.nic.driver.version.failed"))
            self.task.openAutoRetry()
            version_match = False
        return version_match

    def _check_fw_version(self):
        """Check every CX card's firmware version against its mapping version.

        Skipped (returns True) for the normalized driver. Cards are discovered
        via `mst status`; each card's version is read with `flint`. Card types
        without a mapping firmware version are not compared.

        :return: True when no compared card is below its mapping version.
        :raises DeployException: propagated from _parse_nic_fw_version when a
            card's firmware version cannot be parsed.
        """
        if self._use_normalized_driver:
            return True
        # Load the mst tool.
        self._ssh_rets.append(ssh_util.exec_ssh_cmd_nocheck(
            PY_JAVA_ENV, "mst start"))
        ssh_ret = ssh_util.exec_ssh_cmd(PY_JAVA_ENV, "mst status")
        self._ssh_rets.append(ssh_ret)
        # {cx generation: {device path: firmware version}}
        nic_vers_dict = dict()
        for line in ssh_ret.splitlines():
            nic_num, nic_path = self._get_nic_info(line)
            if nic_num:
                nic_fw_ver = self._parse_nic_fw_version(nic_path)
                nic_vers_dict.setdefault(nic_num, {})[nic_path] = nic_fw_ver

        version_match = True
        fw_vers = set()
        # Iterate every card type (sorted for a deterministic report); the
        # result is only returned after all of them have been compared.
        for nic_num, nic_vers in sorted(nic_vers_dict.items()):
            mapping_ver = self._mapping_fw_version_dict.get(nic_num)
            if not mapping_ver:
                continue
            for nic_path, fw_ver in sorted(nic_vers.items()):
                if Compare.compare_digital_version(fw_ver, mapping_ver) < 0:
                    self.deploy_node.putResult(context_util.get_version_key_enum().NIC_FW.getKey(),
                                               context_util.get_not_pass_key())
                    version_match = False
                msg_res = entity.create_msg("cx.nic.fw.current.version").format(nic_path, nic_num, fw_ver)
                self._err_msgs.append(entity.build_url_error_msg(self._mapping_fw_url_dict.get(nic_num), msg_res))
                fw_vers.add(fw_ver)
        # Record all compared firmware versions once, instead of letting each
        # card type overwrite the previous one.
        if fw_vers:
            self.deploy_node.putVersion(context_util.get_version_key_enum().NIC_FW.getKey(),
                                        ";".join(sorted(fw_vers)))
        return version_match

    def _get_nic_info(self, line):
        """
        Get CX card information from one `mst status` output line.
        :param line: line to parse
        :return: (CX generation number, device path) for a known CX device,
            otherwise ("", "")
        """
        match = re.findall(r"^/dev/mst/mt(\d+)_", line)
        if match and match[0] in CX_NIC_MST_ID_TO_NUM:
            return CX_NIC_MST_ID_TO_NUM[match[0]], line.split()[0]
        return "", ""

    def _parse_nic_fw_version(self, nic_path):
        """Return the firmware version reported by `flint -d <nic_path> q`.

        :raises DeployException: (MAY_INFO_MISS) when no "FW Version" line is
            found in the output.
        """
        ssh_ret = ssh_util.exec_ssh_cmd(PY_JAVA_ENV, "flint -d {} q".format(
            nic_path))
        self._ssh_rets.append(ssh_ret)
        for line in ssh_ret.splitlines():
            if line.startswith("FW Version"):
                fw_ver = line.split()[-1].strip()
                self._logger.info("{} fw version: {}".format(nic_path, fw_ver))
                return fw_ver
        raise DeployException(
            "parse nic:{} fw ver failed".format(nic_path),
            err_code=DeployException.ErrCode.MAY_INFO_MISS)

    def _use_normalized_cmd(self):
        """Return whether the driver should be queried with `rdma_ver`.

        From mapping version 8.1.2 on, the CX NIC driver query command is
        unified as `rdma_ver`; the older versions in CUR_NOT_SUPPORT_VERSION
        (matched as substrings of the current mapping version) still use
        `ofed_info -s`.
        """
        return not any(item in self._server_version for item in CUR_NOT_SUPPORT_VERSION)
