# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import codecs
import json
import os
import re

from resource.resource import MESSAGES_DICT
from cbb.frame.cli import cliUtil
from cbb.frame.context import contextUtil
from cbb.frame.base import baseUtil
from cbb.frame.base.exception import UnCheckException
from cbb.frame.cli.execute_on_all_controllers import (
    ExecuteOnAllControllers,
    ResultType,
    FuncResult,
    ExeOnAllCtrlException,
    ExeOnAllCtrlContext
)

# 系统盘盘符为/dev/sda, 或者/dev/sda + /dev/sdb
SYS_DISK_LIST = ('/dev/sda', '/dev/sdb')
LIFE_NOT_PASS_VALUE = 80
# 用于标识巡检项是否通过, 2: NO_PASS
RESULT_NOT_PASS_VALUE = 2
# 用于标识巡检项是否通过, 1: WARNING
RESULT_WARNING_VALUE = 1
# 2023.8.31增加3款系统盘model：MD619HXCLDE3TC、ME619HXELDF3TE、ME619DXFNEC6CF
# 240G容量点盘PE不准
CAPACITY_240 = ('ME619GXEHDE3TE', 'ME619GXEHDE3AE', 'ME619HXELDF3TE')


def execute(data_dict):
    """
    :功能描述: 系统盘风险检查
    """
    check_item = CheckSystemDiskRisk(data_dict)
    result = check_item.execute()
    if result == RESULT_NOT_PASS_VALUE:
        not_pass_ctrl_info = ";\n".join(check_item.err_msg_list)
        return False, check_item.get_msg("check.systemdisk.risk.not.pass") + not_pass_ctrl_info, ""
    return True, "", ""


class CheckSystemDiskRisk:
    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.cli = contextUtil.getCli(data_dict)
        self.logger = contextUtil.getLogger(data_dict)
        self.lang = contextUtil.getLang(data_dict)
        self.dev = self.data_dict.get("dev")
        self.msg_list = []
        self.max_program_fail = 0
        self.check_results = {}
        with codecs.open(os.path.join(os.path.dirname(__file__), "check_systemdisk_risk.json"), "r",
                         encoding="utf-8") as config_f:
            self.config_map = json.load(config_f)
        # 系统盘厂商列表及SMART解析项：id为smartID，value_column为取值的列(VALUE为3，RAW_VALUE为-1)
        self.vendor_map = self.config_map.get('vendor_map')
        # 未Trim的系统盘SN
        self.un_trim_list = self.config_map.get('un_trim_list')
        self.err_msg_list = []

    @staticmethod
    def get_rtmm_petimes_value(smart_value):
        # RAW_VALUE值转化成2进制数
        binary = str(bin(int(smart_value))).replace('0b', '')
        # 高位补0,补为48位
        if len(binary) < 48:
            temp_num = 48 - len(binary)
            for _ in range(temp_num):
                binary = '0' + binary
        # 比较3组数据大小，取中间值作为PE次数
        nums = [int(binary[0:16], 2), int(binary[16:32], 2), int(binary[32:48], 2)]
        nums.sort()
        return nums[1]

    def get_msg(self, msg_key):
        return baseUtil.getPyResource(self.lang, msg_key, "", resource=MESSAGES_DICT)

    def execute(self):
        ret = cliUtil.hasSuperAdminPrivilege(self.cli, self.lang)
        if not ret[1]:
            raise Exception("DO_NOT_HAS_SUPER_ADMIN_RIGHTS", "")
        exe_context = ExeOnAllCtrlContext(self.data_dict)
        self.get_drive_risk(exe_context)
        self.logger.info("get_drive_risk result: {}".format(self.check_results))
        return max(self.check_results.values())

    @ExecuteOnAllControllers
    def get_drive_risk(self, data_dict):
        """
        flag: 用于标识巡检项是否通过，0: PASS 1: Warning 2:NOT_PASS
        """
        flag = 0
        current_ctrl_id = data_dict.cur_ctrl_id
        self.cli = data_dict.dev_info.cli
        self.lang = data_dict.lang
        tool_name = self._get_tool_name()
        try:
            self.logger.info("current_ctrl_id is :{}".format(current_ctrl_id))
            if not tool_name:
                self.check_results.update({current_ctrl_id: flag})
                self.logger.info("disktool and disk_repair.sh do not exist")
                return FuncResult(ResultType.SUCCESS, "", "")
            dsl_cmd_info = "{} -s".format(tool_name)
            is_suc, disk_info_lines, err_msg = cliUtil.excuteCmdInMinisystemModel(self.cli, dsl_cmd_info, self.lang)
            disk_info_list = disk_info_lines.splitlines()
            if not disk_info_list:
                self.check_results.update({current_ctrl_id: flag})
                self.logger.info("disk_info_list does not exist")
                return FuncResult(ResultType.SUCCESS, "", "")
            flag = self._get_controller_result(disk_info_list, flag, tool_name, current_ctrl_id)
            self.check_results.update({current_ctrl_id: flag})
            return FuncResult(ResultType.SUCCESS, "", "")
        except UnCheckException as uncheck:
            self.logger.error("check_systemdisk_risk uncheck exception:{}".format(uncheck.errorMsg))
            return FuncResult(ResultType.FAILED, "", "")

    def _get_tool_name(self):
        tool_name = ""
        error_reg = re.compile("^Error[\\s\\S]+", re.I)
        check_tool_version = "{} -v".format("disk_repair.sh")
        is_suc, tool_version, err_msg = cliUtil.excuteCmdInMinisystemModel(self.cli, check_tool_version, self.lang)
        if error_reg.search(tool_version):
            check_tool_version_new = "{} -v".format("disktool")
            is_suc, tool_version_new, err_msg = cliUtil.excuteCmdInMinisystemModel(
                self.cli, check_tool_version_new, self.lang)
            if not error_reg.search(tool_version_new):
                tool_name = "disktool"
        else:
            tool_name = "disk_repair.sh"
        return tool_name

    def _get_controller_result(self, disk_info_list, flag, tool_name, current_ctrl_id):
        for sys_letter in SYS_DISK_LIST:
            for disk_info in disk_info_list:
                if sys_letter not in disk_info:
                    continue
                vendor = self._get_drive_vendor(disk_info)
                dsl_cmd_smart = "{} -f a {}".format(tool_name, sys_letter)
                flag = self._get_drive_smart(dsl_cmd_smart, vendor, flag)
        if flag != 0:
            self._get_not_pass_ctrl_info(current_ctrl_id, flag)
        return flag

    def _get_drive_vendor(self, disk_info):
        for vendor_key in self.vendor_map.keys():
            if vendor_key in disk_info:
                return vendor_key
        return ""

    def _get_drive_smart(self, dsl_cmd_smart, vendor, res_flag):
        if not vendor:
            self.logger.info("Systemdisk Model is not involved")
            return res_flag
        is_suc, dsl_smart_res, err_msg = cliUtil.excuteCmdInMinisystemModel(self.cli, dsl_cmd_smart, self.lang)
        smart = self._get_vendor_smart_info(dsl_smart_res.splitlines(), vendor)
        self.logger.info("Systemdisk Model is: {}, SMART info is: {}.".format(vendor, smart))
        return self._check_smart_res_pass(res_flag, smart, vendor)

    def _get_not_pass_ctrl_info(self, node_id, res_flag):
        if not node_id:
            self.err_msg_list.append("Get controller info failed")
            return
        if res_flag == RESULT_NOT_PASS_VALUE:
            self.err_msg_list.append("local node id: {}  NOT PASS".format(node_id))
        elif res_flag == RESULT_WARNING_VALUE:
            self.err_msg_list.append("local node id: {}  WARNING".format(node_id))

    def _get_ctrl_id_by_node_id(self, node_id):
        try:
            one_engine_ctrl_num = 0
            controller = "showsysstatus"
            is_suc, cli_ret, err_msg = cliUtil.excuteCmdInMinisystemModel(self.cli, controller, self.lang)
            dev_infos = cliUtil.getVerticalCliRet(cli_ret)
            if dev_infos:
                dev_info = dev_infos[0]
                one_engine_ctrl_num = int(int(dev_info.get("node max")) / int(dev_info.get("group max")))
            # 根据one_engine_ctrl_num、node_id获取对应ctrl id
            if one_engine_ctrl_num > 0:
                ctrl_trans_dict = {"0": "A", "1": "B", "2": "C", "3": "D"}
                node_id = int(node_id)
                return str(int(node_id / one_engine_ctrl_num)) + ctrl_trans_dict.get(str(node_id % one_engine_ctrl_num))
            else:
                return "NULL"
        except UnCheckException as uncheck:
            self.logger.error("check_systemdisk_risk uncheck exception:{}".format(uncheck.errorMsg))
            return "NULL"
        finally:
            cliUtil.enterCliModeFromSomeModel(self.cli, self.lang)

    def _get_vendor_smart_info(self, dsl_smart_res, vendor):
        """
        pe准、pe不准盘的smart获取
        """
        vendor_smart_map = self.vendor_map[vendor]
        smart = {
            'PE_Times': -1, 'Spare_Block_Life': -1, 'Reallocated': -1, 'Program_Fail': 0,
            'Erase_Fail': 0, 'Host_Write': -1, 'unc': -1, 'Uncorrectable_Sector_Count': -1,
            'Offline_Uncorrectable_Sector_Count': -1, 'SN': ''
        }
        for smart_info in dsl_smart_res:
            # smart中取SN
            sn_reg = re.compile(r"Serial Number:\s+(\S+)", re.I)
            sn_ret = sn_reg.findall(smart_info)
            if sn_ret:
                smart.update({'SN': sn_ret[0]})
                continue
            # smart中取数字开头的行
            smart_id_reg = re.compile(r"^\s*(\d+)\s+.*")
            smart_id_ret = re.match(smart_id_reg, smart_info)
            if not smart_id_ret:
                continue
            self._update_smart_by_vendor_map(smart, smart_id_ret, vendor, vendor_smart_map)
            # 部分model号有多个Program fail值，取最大
            self._update_max_program_fail(smart, smart_id_ret, vendor_smart_map)
        return smart

    def _update_smart_by_vendor_map(self, smart, smart_id_ret, vendor, vendor_smart_map):
        smart_id = smart_id_ret.group(1)
        for smart_key in smart.keys():
            if smart_key not in vendor_smart_map.keys():
                continue
            if smart_id != vendor_smart_map[smart_key]['id']:
                continue
            smart_value = int(smart_id_ret.group().split()[vendor_smart_map[smart_key]["value_column"]])
            if smart_key == 'PE_Times' and vendor.startswith("RTMM"):
                smart.update({smart_key: self.get_rtmm_petimes_value(smart_value)})
                break
            if smart_key == 'Uncorrectable_Sector_Count' or smart_key == 'Offline_Uncorrectable_Sector_Count':
                unc = smart.get('unc')
                smart.update({smart_key: smart_value})
                smart.update({'unc': smart_value if unc == -1 else smart_value + unc})
                break
            smart.update({smart_key: smart_value})
            break

    def _update_max_program_fail(self, smart, smart_id_ret, vendor_smart_map):
        if "Program_Fail_One" in vendor_smart_map.keys():
            smart_id = smart_id_ret.group(1)
            smart_value = -1
            if smart_id == vendor_smart_map.get("Program_Fail_One")['id']:
                smart_value = int(
                    smart_id_ret.group().split()[vendor_smart_map["Program_Fail_One"]["value_column"]])
            elif smart_id == vendor_smart_map.get("Program_Fail_Two")['id']:
                smart_value = int(
                    smart_id_ret.group().split()[vendor_smart_map["Program_Fail_Two"]["value_column"]])
            elif smart_id == vendor_smart_map.get("Program_Fail_Three")['id']:
                smart_value = int(
                    smart_id_ret.group().split()[vendor_smart_map["Program_Fail_Three"]["value_column"]])
            if smart_value > self.max_program_fail:
                self.max_program_fail = smart_value
                smart.update({"Program_Fail": self.max_program_fail})

    def _check_smart_res_pass(self, res_flag, smart, vendor):
        temp_flag = 0
        # 基于Spare Block的剩余寿命<=80,建议更换
        if not self._check_spare_block_life_pass(smart):
            return RESULT_NOT_PASS_VALUE
        # 获取PE准确盘、PE不准盘的PE次数
        if vendor in CAPACITY_240 and smart.get("Host_Write") != -1:
            if smart.get("SN") in self.un_trim_list:
                # PE不准盘，未Trim：Host写入量 * 20 为估算PE次数
                pe_times = smart.get("Host_Write") * 32 / 1024 / 240 * 20
            else:
                # PE不准盘，已Trim：Host写入量 * 8 为估算PE次数
                pe_times = smart.get("Host_Write") * 32 / 1024 / 240 * 8
        else:
            pe_times = smart.get("PE_Times")
        # 根据PE次数判断是否建议优化
        if 5000 <= pe_times < 6000:
            self.logger.info("A large amount of data is written to the system disk. Optimization is recommended. ")
            temp_flag = RESULT_WARNING_VALUE
        # 根据PE次数、UNC、Program fail、Erase fail判断是否建议更换
        if pe_times >= 6000:
            # PE次数大于6000，建议更换
            temp_flag = RESULT_NOT_PASS_VALUE
        elif pe_times >= 3000:
            # PE次数大于3000小于6000，unc大于1，建议更换
            if smart.get("unc") > 1:
                temp_flag = RESULT_NOT_PASS_VALUE
            elif smart.get('Program_Fail') + smart.get('Erase_Fail') > 1:
                # PE次数大于3000小于6000，unc小于2，program fail + erase fail大于1，建议更换
                temp_flag = RESULT_NOT_PASS_VALUE
        else:
            # PE次数小于3000，unc大于1，建议更换
            if smart.get("unc") > 1:
                temp_flag = RESULT_NOT_PASS_VALUE
        return max(res_flag, temp_flag)

    def _check_spare_block_life_pass(self, smart):
        if -1 < smart.get('Spare_Block_Life') <= 80:
            self.logger.info("Life based on Spare Block is less than or equal to 80%.")
            return False
        return True
