# -*- coding: UTF-8 -*-
import traceback

import cliUtil
import common
import config
from common_utils import get_err_msg
from cli_util_cache import (
    get_vertical_cache,
    get_horizontal_no_standard_cache,
)
from storage_obj_constant import StoragePoolAttribute as PoolAttr

PY_JAVA_ENV = py_java_env
LANG = common.getLang(PY_JAVA_ENV)
LOGGER = common.getLogger(PY_LOGGER, __file__)

# Every tier a storage pool can have, expressed as the name of the matching
# disk-domain tier field, ordered tier 0, tier 1, tier 2.
ALL_TIER_IN_POOL = (
    PoolAttr.TIER_ZERO_DISK_NUMBER,
    PoolAttr.TIER_ONE_DISK_NUMBER,
    PoolAttr.TIER_TWO_DISK_NUMBER,
)

# Storage-pool tier name -> name of the corresponding tier in the disk domain.
TIER_NAME_MAP = dict(zip(
    (
        PoolAttr.TIER_NAME_EXTREME_PERFORMANCE,
        PoolAttr.TIER_NAME_PERFORMANCE,
        PoolAttr.TIER_NAME_CAPACITY,
    ),
    ALL_TIER_IN_POOL,
))

# Disk-domain tier name -> disk-domain tier number.
TIER_NUMBER = dict(zip(
    ALL_TIER_IN_POOL,
    (PoolAttr.TIER_ZERO, PoolAttr.TIER_ONE, PoolAttr.TIER_TWO),
))


def execute(cli):
    """
    Entry point: check the disk count of disk domains that host RAID10
    storage-pool tiers.

    :param cli: CLI connection handed to the checker
    :return: tuple ``(flag, raw CLI output joined by newlines, message)``
    """
    checker = PoolRaidTierDiskNumCheck(cli, LANG, PY_JAVA_ENV, LOGGER)
    result_flag, result_msg = checker.execute_check()
    cli_echo = "\n".join(checker.all_ret_list)
    return result_flag, cli_echo, result_msg


class PoolRaidTierDiskNumCheck:
    """
    Check that every disk-domain tier backing a RAID10 storage-pool tier
    ends up with an even number of disks (planned expansion disks included).

    The raw CLI output of every executed command is collected in
    ``all_ret_list`` so the caller can report it.
    """

    def __init__(self, cli, lang, env, logger):
        """
        :param cli: CLI connection passed through to the cache helpers
        :param lang: language code used to look up messages
        :param env: tool environment object (read via ``.get``)
        :param logger: logger exposing ``logInfo`` / ``logError``
        """
        self.cli = cli
        self.lang = lang
        self.env = env
        self.logger = logger
        # Raw CLI echoes of all executed commands, in execution order.
        self.all_ret_list = []
        # NOTE(review): not read by this class; kept for compatibility.
        self.err_msg = ""

    def execute_check(self):
        """
        Run the check.

        :return: ``(flag, msg)``: ``(True, "")`` when no risk is found,
            ``(False, <concatenated error messages>)`` when some RAID10 tier
            would have an odd disk count, and ``(cliUtil.RESULT_NOCHECK,
            msg)`` when the required information could not be queried.
        """
        try:
            pool_domain_info = self.get_pool_info()
            self.logger.logInfo("pool info is:{}".format(pool_domain_info))
            if not pool_domain_info:
                return True, ''

            # {domain_id: {pool_id: [disk-domain tier names using RAID10]}}
            risk_domain_info = {}
            for pool_id, domain_id in pool_domain_info.items():
                tier_level_list = self.get_pool_raid_one_zero_tier_info(
                    pool_id
                )
                if not tier_level_list:
                    continue
                pools_in_domain = risk_domain_info.setdefault(domain_id, {})
                pools_in_domain[pool_id] = tier_level_list
            self.logger.logInfo(
                "risk_domain_info is:{}".format(risk_domain_info)
            )
            if not risk_domain_info:
                return True, ''

            err_msg_list = []
            domain_tier_disk_num = self.get_exp_domain_info(self.env)
            for domain_id, pool_info in risk_domain_info.items():
                err_msg_list.extend(
                    self.get_odd_domain_disk_number(
                        domain_id, pool_info, domain_tier_disk_num
                    )
                )
            if err_msg_list:
                return False, "".join(err_msg_list)

            return True, ""
        except common.UnCheckException as e:
            self.logger.logError(str(e))
            return cliUtil.RESULT_NOCHECK, e.errorMsg
        except Exception:
            self.logger.logError(str(traceback.format_exc()))
            return (
                cliUtil.RESULT_NOCHECK,
                common.getMsg(self.lang, "query.result.abnormal"),
            )

    def is_expansion_scene(self):
        """
        Tell whether the check runs as the pre-expansion inspection.

        Original intent (translated): in the expansion scene it cannot be
        determined which storage pool is expanded, so the check passes.
        NOTE(review): callers in this class only log this result; the
        expansion disks are instead added via ``get_exp_domain_info``.

        :return: True only when ``sceneData`` is present, its ``mainScene``
            is ``"Expansion"`` and its ``toolScene`` is ``"perInspector"``.
        """
        scene_data = self.env.get("sceneData")
        if not scene_data:
            return False

        self.logger.logInfo("scene data :{}".format(scene_data))
        if scene_data.get("mainScene") != "Expansion":
            return False

        # Pre-expansion inspection.
        return scene_data.get("toolScene") == "perInspector"

    def get_pool_raid_one_zero_tier_info(self, pool_id):
        """
        Find the tiers of a storage pool whose RAID level is RAID10.

        :param pool_id: storage pool ID
        :return: disk-domain tier names (values of ``TIER_NAME_MAP``) of the
            pool's RAID10 tiers; tiers with unknown names are ignored
        :raises common.UnCheckException: when the CLI query fails
        """
        tier_level_list = []
        cmd = (
            "show storage_pool tier pool_id={}|filterColumn "
            r"include columnList=RAID\sLevel,Pool\sID,Name"
        )

        flag, cli_ret, tier_info_list, msg = get_vertical_cache(
            self.cli, self.env, self.logger, cmd.format(pool_id)
        )
        self.all_ret_list.append(cli_ret)
        if flag is not True:
            raise common.UnCheckException(msg, cli_ret)
        for tier_info in tier_info_list:
            if tier_info.get(PoolAttr.RAID_LEVEL) != PoolAttr.RAID_10:
                continue

            tier_name = tier_info.get(PoolAttr.NAME, '')
            tier_level_in_domain = TIER_NAME_MAP.get(tier_name)
            if tier_level_in_domain:
                tier_level_list.append(tier_level_in_domain)

        return tier_level_list

    def get_odd_domain_disk_number(
            self, domain_id, pool_info, domain_tier_disk_num):
        """
        Check a disk domain for RAID10 tiers with an odd disk count.

        :param domain_id: disk domain ID
        :param pool_info: ``{pool_id: [disk-domain tier names using RAID10]}``
            for the pools in this domain
        :param domain_tier_disk_num: planned expansion disks,
            ``{domain_id: {str(tier number): disk count}}``
        :return: list of error messages, one per offending pool tier
        :raises common.UnCheckException: when the CLI query fails
        """
        err_msg_list = []
        cmd = (
            "show disk_domain general disk_domain_id={}|filterColumn"
            r" include columnList=Tier0\sDisk\sNumber,"
            r"Tier1\sDisk\sNumber,Tier2\sDisk\sNumber"
        )
        flag, cli_ret, tier_info_list, msg = get_vertical_cache(
            self.cli, self.env, self.logger, cmd.format(domain_id)
        )
        self.all_ret_list.append(cli_ret)
        if flag is not True:
            raise common.UnCheckException(msg, cli_ret)
        if not tier_info_list:
            return err_msg_list

        exp_flag = self.is_expansion_scene()
        self.logger.logInfo(
            "is exp scene:{},domain_id:{},domain_tier_disk_num:{}".format(
                exp_flag, domain_id, domain_tier_disk_num
            )
        )
        domain_tiers = tier_info_list[0]
        # Expansion disks planned for this domain. Loop-invariant, so looked
        # up once. NOTE(review): assumes the expansion context stores the
        # domain ID with the same type/format as the CLI output — confirm.
        tier_disk_num = domain_tier_disk_num.get(domain_id, {})
        for pool_id, tier_level_list in pool_info.items():
            for tier_name in tier_level_list:
                disk_num_str = domain_tiers.get(tier_name, '0')
                err_msg_list.extend(self.check_pool_tier(
                    disk_num_str, tier_disk_num, tier_name, pool_id
                ))
        return err_msg_list

    def check_pool_tier(self, disk_num_str, tier_disk_num, tier_name, pool_id):
        """
        Check one RAID10 tier of one pool for an odd disk count.

        :param disk_num_str: current disk count of the tier in the disk
            domain as reported by the CLI; non-numeric values skip the check
        :param tier_disk_num: expansion disks planned for the disk domain,
            keyed by ``str(tier number)``
        :param tier_name: disk-domain tier field name
        :param pool_id: storage pool ID, used in the error message
        :return: list with one error message when the resulting count is
            odd, otherwise an empty list
        """
        err_key = "software.disk.domain.raid.ten.disk.num.not.pass"
        if not str(disk_num_str).isdigit():
            return []
        disk_num = self.get_expansion_num(
            int(disk_num_str), tier_name, tier_disk_num
        )
        if not self.is_odd_num(disk_num):
            return []
        # Only the leading word of the field name is shown to the user
        # (e.g. "Tier0" for a "Tier0 Disk Number" field).
        show_tier_name = tier_name.split()[0]
        return [
            get_err_msg(
                self.lang, err_key, (pool_id, show_tier_name, disk_num)
            )
        ]

    @staticmethod
    def get_expansion_num(disk_num, tier_name, tier_disk_num):
        """
        Add the planned expansion disks to a tier's current disk count.

        A tier that is not expanded counts as ``0 + current disk count``.

        :param disk_num: current disk count of the tier
        :param tier_name: disk-domain tier field name
        :param tier_disk_num: expansion disks per tier,
            keyed by ``str(tier number)``
        :return: total disk count after expansion
        """
        tier_num = str(TIER_NUMBER.get(tier_name))
        return tier_disk_num.get(tier_num, 0) + disk_num

    @staticmethod
    def get_exp_domain_info(env=None):
        """
        Collect the disks added to each disk domain by the expansion.

        Disk types are mapped to tier numbers through
        ``config.DISK_TYPE_LEVEL_ENUM`` and counts are summed per tier.

        :param env: environment holding the expansion context; defaults to
            the module-level ``PY_JAVA_ENV`` for backward compatibility
        :return: ``{domain_id: {str(tier number): disk count}}``; empty when
            there is no expansion disk list
        """
        if env is None:
            env = PY_JAVA_ENV
        all_exp_list = common.getExpDiskListFromContextFilter(env)
        if not all_exp_list:
            return dict()

        domain_tier_disk_num = dict()
        for disk_info in all_exp_list:
            domain_id = disk_info.get("diskDomain")
            tier_disk_num = domain_tier_disk_num.setdefault(domain_id, {})
            disk_type = disk_info.get("diskModel")
            tier_type = str(config.DISK_TYPE_LEVEL_ENUM.get(disk_type))
            disk_num = int(disk_info.get("diskNum"))
            tier_disk_num[tier_type] = (
                tier_disk_num.get(tier_type, 0) + disk_num
            )

        return domain_tier_disk_num

    @staticmethod
    def is_odd_num(number_str):
        """
        Tell whether a number is odd.

        :param number_str: int or numeric string
        :return: True if odd, False otherwise
        """
        return int(number_str) % 2 == 1

    def get_pool_info(self):
        """
        Map every storage pool to its disk domain.

        :return: ``{pool_id: disk_domain_id}``; pools without either ID are
            skipped
        :raises common.UnCheckException: when the CLI query fails
        """
        pool_domain_info = {}
        cmd = "show storage_pool general"
        flag, ret, pool_info_list, msg = get_horizontal_no_standard_cache(
            self.cli, self.env, self.logger, cmd
        )
        self.all_ret_list.append(ret)
        if flag is not True:
            raise common.UnCheckException(msg, ret)

        for pool_info in pool_info_list:
            pool_id = pool_info.get(PoolAttr.ID, "")
            domain_id = pool_info.get(PoolAttr.DISK_DOMAIN_ID, "")
            if not pool_id or not domain_id:
                continue
            pool_domain_info[pool_id] = domain_id
        return pool_domain_info
