# -*- coding: UTF-8 -*-
import re

from cbb.frame.cli import cliUtil
from cbb.frame.base.exception import UnCheckedException
from cbb.frame.tlv import adTlvUtil
from cbb.frame.context import contextUtil
from cbb.frame.base import baseUtil
from cbb.business.operate.expansion import common
from cbb.common.conf.productConfig import DORADO_NAS, compare_dorado_version
from cbb.frame.cli.cliUtil import get_smart_cache_pool_for_nas


def execute(cli, lang, logger, context,
            free_disk_list=None, check_type="inspect"):
    """
    Run the smart cache pool (SCM) check for engine expansion.

    :param cli: CLI connection used for the device queries
    :param lang: language used for result messages
    :param logger: logger
    :param context: tool context
    :param free_disk_list: extra disk locations counted together with the
        existing SCM cache disks (may be None)
    :param check_type: scenario; "inspect" and "select_disk" run the
        inspection check, any other value runs the expansion check
    :return: (result flag, CLI output of all queries joined by newlines,
              error message)
    """
    checker = SCMCheck(cli, lang, context,
                       logger, free_disk_list, check_type)
    result_flag, message = checker.execute_check()
    cli_output = "\n".join(checker.cli_ret_list)
    return result_flag, cli_output, message


def check_need_expand_scm(context):
    """
    Decide whether the smart cache pool (SCM) has to be expanded as well.

    :param context: tool context
    :return: True only for a Dorado device whose version passes the
             DORADO_NAS comparison and which has a smart cache pool
             (pool ID other than "-1"); False otherwise
    """
    log = common.getLogger(context.get("logger"), __file__)
    device = contextUtil.getDevObj(context)
    model = device.get("type")
    if not common.isDorado(model):
        log.logInfo(
            "the device is not dorado. product model={}".format(model))
        return False

    dev_version = device.get("version")
    language = contextUtil.getLang(context)
    # The second element of compare_dorado_version's result is used as the
    # "is Dorado NAS" answer.
    is_nas_version = compare_dorado_version(dev_version, DORADO_NAS)[1]
    if not is_nas_version:
        log.logInfo(
            "the device is not dorado nas. version={}".format(dev_version))
        return False

    _, pool_id, _, _ = get_smart_cache_pool_for_nas(
        contextUtil.getCli(context), language, log)
    if pool_id == "-1":
        log.logInfo("no smart cache pool.")
        return False
    return True


class SCMCheck:
    """
    Check that SCM cache disks stay symmetric across engines when the
    device is expanded.

    ``check_type`` selects the scenario: "inspect" and "select_disk" run
    ``check_for_inspect``; any other value runs ``check_for_expansion``.
    The CLI output of the queries is collected in ``cli_ret_list``.
    """

    # Slots numbered from this value upward count as the "right" plane in
    # get_disk_info and as the CD quadrant in is_disk_belong_to_cd_quadrant.
    _RIGHT_PLANE_FIRST_SLOT = 7
    # "CTE<single engine digit>.<...><letter><slot>": captures engine, slot.
    _ENGINE_SLOT_PATTERN = re.compile(r"^CTE(\d)\..*[A-Za-z](\d+)$")
    # Same location shape with any engine ID length: captures the slot only.
    _SLOT_PATTERN = re.compile(r"^CTE\d+\..*[A-Za-z](\d+)$")

    def __init__(self, cli, lang, context, logger,
                 free_disk_list, check_type):
        self.cli = cli
        self.lang = lang
        self.context = context
        self.logger = logger
        # Extra disk locations counted together with the existing SCM disks.
        self.free_disk_list = free_disk_list if free_disk_list else []
        # CLI output of every query issued by this checker.
        self.cli_ret_list = []
        self.check_type = check_type

    def execute_check(self):
        """
        Run the check that matches ``check_type``.

        :return: (flag, error message); flag is the check result, or
                 cliUtil.RESULT_NOCHECK when a query failed or an unexpected
                 error occurred
        """
        try:
            if self.check_type in ["inspect", "select_disk"]:
                return self.check_for_inspect()
            return self.check_for_expansion()
        except UnCheckedException as e:
            self.logger.logException(e)
            return cliUtil.RESULT_NOCHECK, e.errorMsg
        except Exception as e:
            self.logger.logException(e)
            err_msg = cliUtil.getMsg(self.lang, "query.result.abnormal")
            return cliUtil.RESULT_NOCHECK, err_msg

    def check_for_inspect(self):
        """
        Inspection / disk-selection check.

        Fails when there is no smart cache pool. Otherwise the existing SCM
        cache disks plus ``free_disk_list`` must be balanced. When they are
        not, "inspect" reports a failure message; any other check type asks
        the user through a warning dialog.

        :return: (True, "") when balanced or confirmed, else (False, message)
        :raises UnCheckedException: when the disk query fails
        """
        if not self.has_smart_cache_pool():
            return False, \
                common.getMsg(self.lang, "expand.cache.no.cache.pool")[0]

        # Query the disks once and reuse the result for both the balance
        # check and the message shown to the user.
        disk_ids = self._get_inspect_disk_ids()
        if self.check_scm_disk_balanced(disk_ids)[0]:
            return True, ""
        if self.check_type == "inspect":
            disk_data = self.get_msg_info(disk_ids)
            return False, common.getMsg(self.lang,
                                        "old.smart.cache.pool.not.equal",
                                        disk_data)[0]
        return self._pop_warning_dialog(disk_ids)

    def check_for_expansion(self):
        """
        Expansion check; passes immediately when there is no smart cache pool.

        :return: (True, "") on success, else (False, message)
        """
        if not self.has_smart_cache_pool():
            return True, ""
        return self._check_scm_expansion()

    def has_smart_cache_pool(self):
        """
        Tell whether the device has a smart cache pool.

        :return: True when a smart cache pool exists, False otherwise
        """
        # NOTE(review): the query flag is not checked, so a failed query is
        # treated like "no pool" -- confirm this is intended.
        flag, has_scm_pool, cli_ret, err_msg = \
            cliUtil.has_smart_cache_pool(self.cli, self.lang, self.logger)
        self.cli_ret_list.append(cli_ret)
        return has_scm_pool

    def get_old_cache_scm_disk(self):
        """
        Query SCM disks of the existing device with "show disk general".

        :return: (flag, [{"location", "type", "capacity"}, ...], cli_ret,
                 err_msg); flag is False when the command returns no record
        """
        disk_info_list = []
        cmd = "show disk general"
        flag, cli_ret, err_msg = \
            cliUtil.excuteCmdInCliMode(self.cli, cmd, True, self.lang)

        if flag is not True:
            return flag, disk_info_list, cli_ret, err_msg

        if cliUtil.queryResultWithNoRecord(cli_ret):
            self.logger.logInfo("Smart Cache Pool does not exist!")
            return False, disk_info_list, cli_ret, err_msg

        records = cliUtil.getHorizontalCliRet(cli_ret) or []
        disk_info_list = [
            {
                "location": record.get("ID"),
                "type": record.get("Type"),
                "capacity": record.get("Capacity"),
            }
            for record in records if self.is_scm_card_or_disk(record)
        ]
        return flag, disk_info_list, cli_ret, err_msg

    def is_scm_card_or_disk(self, record):
        """
        Tell whether a "show disk general" record is an SCM disk to count.

        SCM disks with the "Cache Disk" role always count. During an
        "expansion" check with "expand_ctrl_and_engine" set in the context,
        SCM "Free Disk" entries in the CD quadrant count too.

        :param record: one parsed CLI row (dict)
        :return: True/False
        """
        disk_type = record.get("Type") or ""
        role = record.get("Role") or ""
        if not disk_type.startswith("SCM"):
            return False
        if role.startswith("Cache Disk"):
            return True
        # Short-circuit so the context lookup and location parsing only run
        # for the expansion scenario.
        return (self.check_type == "expansion"
                and bool(contextUtil.getItem(self.context,
                                             "expand_ctrl_and_engine"))
                and self.is_disk_belong_to_cd_quadrant(record.get("ID"))
                and role.startswith("Free Disk"))

    @staticmethod
    def is_disk_belong_to_cd_quadrant(location):
        """
        Tell whether a disk location lies in the CD quadrant (slot >= 7).

        :param location: disk location such as "CTE0.<...><letter><slot>"
        :return: True/False; False for empty or unparseable locations
        """
        if not location:
            return False
        matched = SCMCheck._SLOT_PATTERN.match(location)
        if not matched:
            return False
        return int(matched.group(1)) >= SCMCheck._RIGHT_PLANE_FIRST_SLOT

    def _query_old_scm_disk_ids(self):
        """
        Query the SCM disk locations of the existing device.

        :return: (list of disk locations, raw CLI output)
        :raises UnCheckedException: when the query does not succeed
        """
        flag, disk_info_list, cli_ret, err_msg = \
            self.get_old_cache_scm_disk()
        if flag is not True:
            raise UnCheckedException(err_msg)
        return [disk.get("location") for disk in disk_info_list], cli_ret

    def _get_inspect_disk_ids(self):
        """
        Return the existing SCM disk locations followed by free_disk_list,
        recording the CLI output in cli_ret_list.

        :raises UnCheckedException: when the disk query fails
        """
        disk_ids, cli_ret = self._query_old_scm_disk_ids()
        self.cli_ret_list.append(cli_ret)
        disk_ids.extend(self.free_disk_list)
        return disk_ids

    def _check_scm_inspect(self):
        """
        Tell whether existing SCM disks plus free_disk_list are balanced.

        :return: True/False
        :raises UnCheckedException: when the disk query fails
        """
        return self.check_scm_disk_balanced(self._get_inspect_disk_ids())[0]

    def _check_scm_expansion(self):
        """
        Compare existing SCM disks with those on the new engines.

        Fails when a new engine has no SCM disk. Asks the user through a
        warning dialog when the existing disks are unbalanced, or when a new
        engine has fewer SCM disks than each existing engine.

        :return: (True, "") on success/confirmation, else (False, message)
        :raises UnCheckedException: when the disk query fails
        """
        disk_ids = self._query_old_scm_disk_ids()[0]
        new_scm_dict = self.get_new_engine_scm_disk()
        if any(len(disks) == 0 for disks in new_scm_dict.values()):
            return False, \
                common.getMsg(self.lang, "expand.scm.new.engine.no.scm")[0]

        balanced, old_scm_num = self.check_scm_disk_balanced(disk_ids)
        if not balanced or self.get_scm_min_num(new_scm_dict) < old_scm_num:
            return self._pop_warning_dialog(disk_ids, new_scm_dict)
        return True, ""

    def _pop_warning_dialog(self, disk_ids, new_scm_dict=None):
        """
        Show the SCM layout in a warning dialog and let the user decide.

        :param disk_ids: existing SCM disk locations
        :param new_scm_dict: {enclosure SN: SCM disks} of the new engines,
            or None (0 is also accepted) to describe the existing disks only
        :return: (True, "") when the user confirms, else (False, message)
        """
        dialog_util = self.context['dialogUtil']
        old_disk_data = self.get_msg_info(disk_ids)
        if isinstance(new_scm_dict, dict):
            new_disk_data = self.get_msg_info_on_sn(new_scm_dict)
            msg = common.getMsg(self.lang,
                                "smart.cache.pool.not.equal",
                                (old_disk_data, new_disk_data))[0]
        else:
            msg = common.getMsg(self.lang,
                                "old.smart.cache.pool.not.equal",
                                old_disk_data)[0]
        rec = dialog_util.showWarningDialog(msg)
        self.logger.logInfo("user select value={}".format(rec))
        if not rec:
            return False, common.getMsg(self.lang,
                                        "expand.scm.select.cancel")[0]
        return True, ""

    def get_new_engine_scm_disk(self):
        """
        Collect SCM disks of the new boards, grouped by enclosure SN.

        :return: {enclosure SN: list of SCM disks from adTlvUtil}
        """
        tlv = contextUtil.getTlv(self.context)
        new_boards_list = \
            contextUtil.getItem(self.context, "newBoardsList") or []
        new_scm_dict = {}
        for board in new_boards_list:
            scm_list = adTlvUtil.get_scm_disk(tlv, board)
            new_scm_dict.setdefault(board.get("enclosureSN"), []).extend(
                scm_list)
        return new_scm_dict

    def check_scm_disk_balanced(self, disk_info_list):
        """
        Check SCM balance: per plane on high-end Dorado V6 models, per
        engine otherwise.

        :param disk_info_list: SCM disk locations
        :return: (balanced, SCM disk count per engine when balanced)
        """
        product_model = contextUtil.getProductModel(self.context)
        if baseUtil.isDoradoV6HighEnd(product_model):
            return self._check_disk_balance_by_plane(disk_info_list)
        return self._check_scm_disk_balanced_by_engine(disk_info_list)

    def _check_scm_disk_balanced_by_engine(self, disk_info_list):
        """
        Check that every engine holds the same number of SCM disks.

        The 4th character of each location (e.g. "0" in "CTE0...") is used
        as the engine ID. Every engine reported by get_engine_id must appear.

        :param disk_info_list: SCM disk locations
        :return: (balanced, disk count of one engine); (False, 0) when empty
        """
        self.logger.logInfo("check balanced disk={}".format(disk_info_list))
        if not disk_info_list:
            return False, 0
        engine_ids = [location[3] for location in disk_info_list]
        self.logger.logInfo("num_list={}".format(engine_ids))
        disk_count = {}
        for engine_id in engine_ids:
            disk_count[engine_id] = disk_count.get(engine_id, 0) + 1
        # list() keeps this indexable on Python 3 (dict views are not).
        scm_num_list = list(disk_count.values())
        self.logger.logInfo("scm_num_list={}".format(scm_num_list))
        balanced = (len(set(scm_num_list)) == 1
                    and len(self.get_engine_id()) == len(scm_num_list))
        return balanced, scm_num_list[0]

    def _check_disk_balance_by_plane(self, scm_disks):
        """
        High-end models: check balance per plane, then across engines.

        :param scm_disks: SCM disk locations
        :return: (balanced, SCM disk count per engine when balanced,
                 otherwise 0)
        """
        all_engine_datas = self.get_disk_info(scm_disks)
        # e.g. {"0": {"left": 2, "right": 1}, "1": {"right": 1}}
        engine_scm_disks = []
        for engine_id, plane_info in all_engine_datas.items():
            if not self.is_balance_by_plane(plane_info):
                self.logger.logInfo(
                    "engine {} plane not consistent.".format(engine_id))
                return False, 0
            # A plane with no disks has no key in plane_info; count it as 0.
            engine_scm_disks.append(
                plane_info.get("left", 0) + plane_info.get("right", 0))
        if len(set(engine_scm_disks)) != 1 or \
                len(self.get_engine_id()) != len(engine_scm_disks):
            self.logger.logInfo(
                "between engine not consistent.")
            return False, 0
        return True, engine_scm_disks[0]

    def is_balance_by_plane(self, plane_info):
        """
        Tell whether one engine's planes are balanced.

        Left and right counts must be equal. During an "expansion" check with
        "expand_ctrl_and_engine" set, the right plane may hold more disks.

        :param plane_info: {"left": n, "right": m}; missing keys count as 0
        :return: True/False
        """
        left = plane_info.get("left", 0)
        right = plane_info.get("right", 0)
        if left == right:
            return True
        return (self.check_type == "expansion"
                and bool(contextUtil.getItem(self.context,
                                             "expand_ctrl_and_engine"))
                and right >= left)

    def get_scm_min_num(self, new_scm_dict):
        """
        Return the smallest SCM disk count among the new engines.

        :param new_scm_dict: {enclosure SN: list of SCM disks}
        :return: minimum list length
        :raises ValueError: when new_scm_dict is empty
        """
        num_list = [len(disks) for disks in new_scm_dict.values()]
        self.logger.logInfo("new scm disk num list={}".format(num_list))
        return min(num_list)

    def get_engine_id(self):
        """
        Query the engine IDs of the existing device.

        :return: list of engine IDs
        :raises UnCheckedException: when the list is empty
        """
        engine_list = cliUtil.getEngineIdList(self.cli, self.lang)[-1]
        if not engine_list:
            err_msg = cliUtil.getMsg(self.lang, "query.result.abnormal")
            raise UnCheckedException(err_msg)
        return engine_list

    def get_msg_info(self, disk_ids):
        """
        Build a per-engine SCM summary of the existing device.

        High-end Dorado V6 models report left/right plane counts per engine
        ("scm.card.info"). Other models report one count per engine
        ("scm.disk.info"), using the 4th character of the location as the
        engine ID.

        :param disk_ids: SCM disk locations
        :return: per-engine messages joined by ","
        """
        product_model = contextUtil.getProductModel(self.context)
        if baseUtil.isDoradoV6HighEnd(product_model):
            msg_list = [
                common.getMsg(self.lang,
                              "scm.card.info",
                              (engine_id,
                               planes.get("left", 0),
                               planes.get("right", 0)))[0]
                for engine_id, planes in self.get_disk_info(disk_ids).items()
            ]
            return ",".join(msg_list)

        disk_count = {}
        for location in disk_ids:
            engine_id = location[3]
            disk_count[engine_id] = disk_count.get(engine_id, 0) + 1
        msg_list = [
            common.getMsg(self.lang, "scm.disk.info", (engine_id, count))[0]
            for engine_id, count in disk_count.items()
        ]
        return ",".join(msg_list)

    def get_msg_info_on_sn(self, new_scm_dict):
        """
        Build a per-enclosure SCM summary of the new engines.

        :param new_scm_dict: {enclosure SN: list of SCM disks}
        :return: per-enclosure messages joined by ","
        """
        product_model = contextUtil.getProductModel(self.context)
        if baseUtil.isDoradoV6HighEnd(product_model):
            msg_list = []
            for sn, disks in new_scm_dict.items():
                for planes in self.get_disk_info(disks).values():
                    msg_list.append(
                        common.getMsg(self.lang,
                                      "scm.card.info.on.sn",
                                      (sn,
                                       planes.get("left", 0),
                                       planes.get("right", 0)))[0])
            return ",".join(msg_list)
        msg_list = [
            common.getMsg(self.lang, "scm.disk.info.on.sn",
                          (sn, len(disks)))[0]
            for sn, disks in new_scm_dict.items()
        ]
        return ",".join(msg_list)

    @staticmethod
    def get_disk_info(scm_disks):
        """
        Count SCM disks per engine and plane.

        Slots below 7 count as "left", the others as "right". Locations that
        do not match "CTE<digit>.<...><letter><slot>" are skipped.

        :param scm_disks: SCM disk locations
        :return: {engine ID: {"left": n, "right": m}}; a plane key is only
                 present when that plane has at least one disk
        """
        all_engine_datas = dict()
        for location in scm_disks:
            if not location:
                continue
            matched = SCMCheck._ENGINE_SLOT_PATTERN.match(location)
            if not matched:
                continue
            engine_id = matched.group(1)
            slot = int(matched.group(2))
            plane = "left" if slot < SCMCheck._RIGHT_PLANE_FIRST_SLOT \
                else "right"
            planes = all_engine_datas.setdefault(engine_id, {})
            planes[plane] = planes.get(plane, 0) + 1
        return all_engine_datas
