# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import utils.common.log as logger
from utils.common.message import Message
from utils.common.exception import HCCIException
from utils.common.fic_base import StepBaseInterface
from plugins.DistributedStorage.logic.deploy_operate import DeployOperate
from plugins.DistributedStorage.Deploy.scripts.PreCheck.sub_job_check_poolname_consistency import get_expand_params


class CheckAddNodeCabinetLevel(StepBaseInterface):
    """
    Framework step: when nodes are added to an existing storage pool, verify that a rack-level
    (cabinet-level) pool still satisfies the rack-level constraints after expansion.
    The actual work is delegated to CheckCabinetLevel.
    """

    def __init__(self, project_id, pod_id):
        super().__init__(project_id, pod_id)
        self.project_id = project_id
        self.pod_id = pod_id
        # Only the first element returned by get_expand_params is used as the expansion arguments.
        self.fs_args_list = get_expand_params(self.project_id, self.pod_id)[0]
        self.implement = CheckCabinetLevel(project_id, pod_id, self.fs_args_list)

    def pre_check(self, project_id, pod_id):
        """
        Plugin-internal pre-check hook; the tool framework does not call it directly.
        There is nothing to pre-check for this step.
        :param project_id:
        :param pod_id:
        :return: Message object, always Message(200)
        """
        return Message(200)

    def execute(self, project_id, pod_id):
        """
        Standard interface: run the cabinet-level check for the pool being expanded.
        :param project_id:
        :param pod_id:
        :return: Message(200) on success, Message(500, error) if the check raises
        """
        try:
            self.implement.procedure()
        except (HCCIException, Exception) as err:
            # Both error kinds were handled identically; keep catching both explicitly and
            # record the failure in the log instead of only returning it.
            logger.error(f"Cabinet level check of the expanded storage pool failed: {err}")
            return Message(500, err)
        return Message(200)

    def rollback(self, project_id, pod_id):
        """
        Standard interface: roll back. The check changes nothing, so there is nothing to undo.
        :param project_id:
        :param pod_id:
        :return: Message object, always Message(200)
        """
        return Message(200)

    def retry(self, project_id, pod_id):
        """
        Standard interface: retry by re-running execute.
        :return: Message object returned by execute
        """
        return self.execute(project_id, pod_id)

    def check(self, project_id, pod_id):
        """
        Standard interface: post-execution check. Nothing extra to verify.
        :param project_id:
        :param pod_id:
        :return: Message object, always Message(200)
        """
        return Message(200)


def check_ec_pool(pool_node_info):
    """
    Validate the rack layout of a rack-level EC storage pool.

    Every rack must hold at least 3 servers, and the pool may span at most 32 racks.
    The per-rack check runs first, in mapping order; the rack-count check runs afterwards.

    :param pool_node_info: mapping of rack name -> server count
    :raises HCCIException: 626063 for the first rack below the minimum server count,
        626061 when the number of racks exceeds the limit
    """
    rack_limit, rack_floor = 32, 3
    undersized = next(
        ((name, count) for name, count in pool_node_info.items() if count < rack_floor),
        None,
    )
    if undersized is not None:
        name, count = undersized
        raise HCCIException(626063, f"rack name: {name}, server number:{count}")
    if len(pool_node_info) > rack_limit:
        raise HCCIException(626061, "rack_id_number")


def check_node_differ(pool_node_info):
    """
    Check that node counts are balanced enough across cabinets (racks).

    The check fails when the largest and smallest cabinet differ by more than 2 nodes,
    or when that difference is 33% or more of the largest cabinet's node count.

    :param pool_node_info: mapping of cabinet (rack) name -> node count
    :return: None
    :raises HCCIException: 626403 with the absolute difference and the formatted percentage
    """
    node_counts = list(pool_node_info.values())
    if not node_counts:
        # Nothing to compare; previously this fell through to inf/nan arithmetic.
        logger.info("No cabinet node information to compare, skip the node difference check.")
        return

    min_node, max_node = min(node_counts), max(node_counts)
    differ_num = max_node - min_node
    # Avoid ZeroDivisionError when every cabinet reports zero nodes.
    differ_percent = differ_num / max_node if max_node else 0.0
    if differ_num > 2 or differ_percent >= 0.33:
        logger.error("The quantity difference is {} and the percentage difference is {:.2f}%".format(
            differ_num, differ_percent * 100))
        raise HCCIException(626403, differ_num, "{:.2f}%".format(differ_percent * 100))
    logger.info("The node quantity difference between cabinets passes the check.")


def check_3redundancy_pool(pool_node_info):
    """
    Validate the rack layout of a rack-level replication (3-redundancy) storage pool.

    Each rack must hold between 3 and 24 servers (inclusive), and the pool may span
    at most 12 racks. Racks are checked first, in mapping order; the rack count afterwards.

    :param pool_node_info: mapping of rack name -> server count
    :raises HCCIException: 626063 for the first rack outside the allowed server range,
        626061 when the number of racks exceeds the limit
    """
    lower, upper = 3, 24
    rack_limit = 12

    def _out_of_range(count):
        return count < lower or count > upper

    for name, count in pool_node_info.items():
        if _out_of_range(count):
            raise HCCIException(626063, f"rack name: {name}, server number:{count}")

    if len(pool_node_info) > rack_limit:
        raise HCCIException(626061, "rack_id_number")


class CheckCabinetLevel(object):
    """
    Verify that adding nodes to an existing rack-level (cabinet-level) storage pool keeps the
    pool within the rack-level limits for its redundancy policy and keeps racks balanced.
    """

    def __init__(self, project_id, pod_id, fs_args):
        """
        :param project_id: project identifier
        :param pod_id: pod identifier
        :param fs_args: expansion parameters; read keys are 'fsm_float_ip', 'fsm_admin_passwd'
            and 'osd_list' (assumed non-empty, the first entry carrying 'storage_pool_name_and_slot')
        """
        self.project_id = project_id
        self.pod_id = pod_id
        self.fs_args = fs_args
        self.fsm_login_user = "admin"
        self.fsm_float_ip = self.fs_args.get('fsm_float_ip')
        self.fsm_admin_passwd = self.fs_args.get('fsm_admin_passwd')
        # Nodes to be added; each entry's 'rack_id' is counted into its target rack.
        self.osd_list = self.fs_args.get('osd_list')
        # The target pool name is taken from the first node only; the preceding pool-name
        # consistency step is expected to ensure the target pool exists and is unique.
        self.pool_name = self.osd_list[0].get("storage_pool_name_and_slot")
        self.operate = DeployOperate(float_ip=self.fsm_float_ip)
        # Raw 'redundancyPolicy' value of the target pool, filled in by procedure().
        self.redundancy_policy = ""

    def procedure(self):
        """
        When nodes are added to an existing pool that is rack-level, check that the pool still
        satisfies the rack-level conditions. Pools that are not rack-level are skipped.

        :raises HCCIException: 626293 for an unrecognized redundancy policy, plus any error raised
            by the per-policy checks or the node-difference check
        """
        logger.info(f"Storage pool to be checked:{self.pool_name}")

        self.operate.login(self.fsm_login_user, self.fsm_admin_passwd)
        try:
            cabinet_level_pool_info = self.get_cabinet_level_pool()
            self.redundancy_policy = cabinet_level_pool_info.get("redundancyPolicy")

            if not cabinet_level_pool_info:
                logger.info(f"The security level of {self.pool_name} is not cabinet level, no need to check.")
                return

            pool_id = cabinet_level_pool_info.get("poolId")
            disk_pool_id = self.operate.get_disk_pool_id(pool_id)
            logger.info(f"Storage pool id is {pool_id}, disk pool id is {disk_pool_id}")

            # Per-rack node counts: nodes already in the disk pool plus the nodes being added.
            pool_node_info = self.get_node_info_by_disk_pool(disk_pool_id)
            self.create_all_node_rack_info(pool_node_info)

            self._check_by_redundancy_policy(pool_node_info)
            check_node_differ(pool_node_info)
        finally:
            # Always log out once login succeeded, even if a check raised.
            self.operate.login_out(self.fsm_login_user, self.fsm_admin_passwd)

    def _check_by_redundancy_policy(self, pool_node_info):
        """Apply the per-rack limits that match the pool's redundancy policy (case-insensitive)."""
        # A missing policy (None) must surface as the "unrecognized policy" error,
        # not as an AttributeError from calling a string method on None.
        policy = str(self.redundancy_policy or "").lower()
        if policy == "ec":
            logger.info("The redundancy policy of the current storage pool is EC.")
            check_ec_pool(pool_node_info)
        elif policy == "replication":
            logger.info("The redundancy policy of the current storage pool is 3Redundancy.")
            check_3redundancy_pool(pool_node_info)
        else:
            err_msg = f"Unrecognized redundancy policy of storage pool {self.pool_name}: [{self.redundancy_policy}]"
            logger.error(err_msg)
            raise HCCIException(626293, err_msg)

    def get_cabinet_level_pool(self):
        """
        Look up the target pool among the storage pools and return it only if it is rack-level.

        :return: the storage-pool entry whose 'poolName' equals self.pool_name and whose
            'securityLevel' is 'rack'; an empty dict when no such pool exists
        :raises Exception: when the query returns no storage pools at all
        """
        logger.info("Query storage pool data.")
        res_pool = self.operate.query_storage_pool()
        # Guard against an empty query result before reading keys from it.
        query_data = res_pool.get_query_data() or {}

        storage_pools = query_data.get("storagePools")
        if not storage_pools:
            logger.error("check pool fail...")
            raise Exception("check pool fail...")

        for storage_pool in storage_pools:
            if storage_pool.get("poolName") == self.pool_name and storage_pool.get("securityLevel") == "rack":
                return storage_pool
        return {}

    def get_node_info_by_disk_pool(self, disk_pool_id):
        """
        Count the nodes of a disk pool per rack.

        :param disk_pool_id: disk pool to query
        :return: {"rack1": num1, "rack2": num2, ...}
        :raises HCCIException: 626078 when the disk pool reports no node information
        """
        # Guard against an empty response before reading 'nodeInfo' from it.
        disk_info = self.operate.query_node_disk_info(disk_pool_id) or {}
        original_nodes_info = disk_info.get("nodeInfo")
        if not original_nodes_info:
            err_msg = 'Failed to query node info in diskPool. disk pool id: {}'.format(disk_pool_id)
            logger.error(err_msg)
            raise HCCIException(626078, err_msg)

        node_info = {}
        for node in original_nodes_info:
            rack_name = node.get("rack")
            node_info[rack_name] = node_info.get(rack_name, 0) + 1
        return node_info

    def create_all_node_rack_info(self, pool_node_info):
        """
        Add the nodes being expanded to the per-rack counts, mutating pool_node_info in place.

        :param pool_node_info: mapping of rack -> node count, as returned by get_node_info_by_disk_pool
        """
        for osd in self.osd_list:
            rack_id = osd.get("rack_id")
            pool_node_info[rack_id] = pool_node_info.get(rack_id, 0) + 1
