#!/usr/bin/python
# -*- coding: UTF-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import json
import copy
from collections import defaultdict

import utils.common.log as logger
from utils.common.exception import HCCIException
from utils.common.error.hcci_error_code import get_code_msg
from plugins.DistributedStorage.utils.common.deploy_constant import DeployConstant


class PoolInfoCheck(object):
    """Validate the disk/cache information collected for each storage pool.

    Callers build the collection structure with ``pool_info_check_dict_init``,
    fill it in, and then call ``pool_info_error_check``, which gathers every
    problem found in every pool and raises one aggregated exception.
    """

    @staticmethod
    def pool_info_check_dict_init():
        """
        Create the per-pool collection structure used for error reporting.

        Returns a ``defaultdict`` keyed by pool name; touching a new pool
        yields a fresh deep copy of the template below. Layout of one pool and
        the condition under which ``pool_info_error_check`` reports it:

        {pool_1: {
            "main_storage_info": {
                "disk_status":   {status: {om_ip: [slot, ...]}},     any entry -> error
                "type_info":     {disk_type: {om_ip: [slot, ...]}},  more than one key -> inconsistent
                "capacity_info": {capacity: {om_ip: [slot, ...]}},   more than one key -> inconsistent
                "disk_num":      {num: {om_ip: slot_range}},         num < 4 -> too few disks
            },
            "cache_info": {
                "media_type": {media_type: [om_ip, ...]},            more than one key -> inconsistent
                "media_size": {size: [om_ip, ...]},                  more than one key -> inconsistent
                "cache_num":  {num: [om_ip, ...]},                   key 0 -> some nodes have no cache
            },
            "disk_num_gap": (max_num, min_num),                      max_num - min_num > 2 -> error
            "percentage": 0.3,                                       > 0.3 -> error
            "disk_total_num": 13,                                    < 12 -> error
            "nvme_main_storage": {                                   NVMe all-flash scenario
                "no_disk_error": [...], "disk_nums_error": [...],    non-empty -> error
                "disk_capacity": {capacity: [...]},                  more than one key -> inconsistent
            },
        }}
        """
        def _slot_map():
            # Second level of the status/type/capacity maps: om_ip -> list of slots.
            return defaultdict(list)

        main_storage_info = {
            "disk_status": defaultdict(_slot_map),
            "type_info": defaultdict(_slot_map),
            "capacity_info": defaultdict(_slot_map),
            "disk_num": defaultdict(dict),
        }

        one_pool = {
            "main_storage_info": main_storage_info,
            "cache_info": defaultdict(_slot_map),
            "disk_num_gap": None,
            "percentage": None,
            "disk_total_num": None,
            "nvme_main_storage": {'no_disk_error': [], 'disk_nums_error': [], 'disk_capacity': defaultdict(list)}
        }
        # Each new pool gets its own deep copy so pools never share containers.
        return defaultdict(lambda: copy.deepcopy(one_pool))

    @staticmethod
    def pool_info_error_check(pool_info_check):
        """Run every per-pool check and raise one aggregated error if any failed.

        :param pool_info_check: mapping pool name -> pool info (see
            ``pool_info_check_dict_init``); missing sections are treated as empty.
        :raises Exception: with message ``get_code_msg(626340)`` formatted with
            the numbered list of all collected errors.
        """
        total_error = list()
        for pool, pool_info in pool_info_check.items():
            pool_error = list()
            main_storage_infos = pool_info.get("main_storage_info") or dict()
            cache_infos = pool_info.get("cache_info") or dict()
            PoolInfoCheck.node_disk_num_check(pool, pool_error, pool_info)
            # Main storage: slot/disk status, disk type and capacity consistency, disk count.
            PoolInfoCheck.pool_disk_consistency_check(main_storage_infos, pool, pool_error)
            # Cache media type, capacity and count for this pool.
            PoolInfoCheck.pool_cache_check(cache_infos, pool, pool_error)
            # NVMe all-flash scenario errors.
            PoolInfoCheck.nvme_error_check(pool, pool_error, pool_info)
            total_error.extend(pool_error)

        if total_error:
            error_msg = get_code_msg(626340) % PoolInfoCheck.format_error_msg(total_error)
            raise Exception(error_msg)

    @staticmethod
    def node_disk_num_check(pool, pool_error, pool_info):
        """Check disk-count gap, disk-count percentage and total disk count of a pool.

        Values that are unset (None) or zero are skipped.
        Appends formatted messages to ``pool_error``.
        """
        num = pool_info.get("disk_num_gap")
        percentage = pool_info.get("percentage", 0)
        disk_total_num = pool_info.get("disk_total_num", 12)
        if num and num[0] - num[1] > 2:
            error_msg = get_code_msg(626337) % (pool, num[0], num[1])
            pool_error.append(error_msg)
        if percentage and percentage > 0.3:
            error_msg = get_code_msg(626338) % (pool, percentage * 100)
            pool_error.append(error_msg)
        if disk_total_num and disk_total_num < 12:
            error_msg = get_code_msg(626339) % (pool, disk_total_num)
            pool_error.append(error_msg)

    @staticmethod
    def pool_disk_consistency_check(main_storage_infos, pool, pool_error):
        """Check main-storage disk status, type/capacity consistency and per-node disk count.

        Appends formatted messages to ``pool_error``.
        """
        disk_status = main_storage_infos.get("disk_status", dict())
        type_info = main_storage_infos.get("type_info", dict())
        capacity_info = main_storage_infos.get("capacity_info", dict())
        disk_num = main_storage_infos.get("disk_num", dict())
        if disk_status:
            disk_status_error = PoolInfoCheck.format_main_storage_error(disk_status, "status")
            error_msg = get_code_msg(626116) % (pool, disk_status_error)
            pool_error.append(error_msg)
        if len(type_info) > 1:
            type_info_error = PoolInfoCheck.format_main_storage_error(type_info, "type")
            error_msg = get_code_msg(626117) % (pool, "main storage", "disk type", type_info_error)
            pool_error.append(error_msg)
        if len(capacity_info) > 1:
            capacity_error = PoolInfoCheck.format_main_storage_error(capacity_info, "capacity")
            error_msg = get_code_msg(626117) % (pool, "main storage", "capacity", capacity_error)
            pool_error.append(error_msg)
        # Fewer than 4 main-storage disks on a node is an error.
        disk_num_dict = {num: value for num, value in disk_num.items() if num < 4}
        if disk_num_dict:
            disk_num_error = PoolInfoCheck.format_main_storage_error(disk_num_dict, "nums")
            error_msg = get_code_msg(626118) % (pool, disk_num_error)
            pool_error.append(error_msg)

    @staticmethod
    def pool_cache_check(cache_infos, pool, pool_error):
        """Check cache media type/size consistency and nodes without cache.

        Appends formatted messages to ``pool_error``.
        """
        media_type = cache_infos.get("media_type", dict())
        media_size = cache_infos.get("media_size", dict())
        cache_num = cache_infos.get("cache_num", dict())
        if len(media_type) > 1:
            media_type_error = PoolInfoCheck.format_cache_storage_error(media_type, "type")
            error_msg = get_code_msg(626117) % (pool, "cache", "cache type", media_type_error)
            pool_error.append(error_msg)
        if len(media_size) > 1:
            media_size_error = PoolInfoCheck.format_cache_storage_error(media_size, "capacity")
            error_msg = get_code_msg(626117) % (pool, "cache", "capacity", media_size_error)
            pool_error.append(error_msg)
        if 0 in cache_num:
            error_msg = get_code_msg(626119) % (pool, json.dumps(cache_num.get(0)))
            pool_error.append(error_msg)

    @staticmethod
    def nvme_error_check(pool, pool_error, pool_info):
        """Report NVMe all-flash errors (missing disks, disk counts, mixed capacities).

        A ``pool_info`` without NVMe data produces no errors.
        """
        nvme_storage_infos = pool_info.get("nvme_main_storage") or dict()
        no_disk_error = nvme_storage_infos.get('no_disk_error')
        disk_nums_error = nvme_storage_infos.get('disk_nums_error')
        # Default to an empty mapping; len(None) used to crash when the section was absent.
        disk_capacity = nvme_storage_infos.get('disk_capacity') or dict()
        if no_disk_error:
            error_msg = get_code_msg(626373) % no_disk_error
            pool_error.append(error_msg)
        if disk_nums_error:
            error_msg = get_code_msg(626374) % disk_nums_error
            pool_error.append(error_msg)
        if len(disk_capacity) > 1:
            error_msg = get_code_msg(626379) % (pool, dict(disk_capacity))
            pool_error.append(error_msg)

    @staticmethod
    def format_error_msg(error_info):
        """Number the messages ("1、msg2、msg...") and concatenate them into one string."""
        return "".join(str(index) + "、" + error_msg for index, error_msg in enumerate(error_info, 1))

    @staticmethod
    def format_main_storage_error(error_items, error_type):
        """Render {key: {om_ip: slots}} as ';'-separated per-key descriptions."""
        # Message templates are loop-invariant; fetch them once.
        node_template = get_code_msg(626342).strip("\r\n")
        item_template = get_code_msg(626341).strip("\r\n")
        error_list = list()
        for status, node_info in error_items.items():
            node_error_info = ",".join(node_template % (om_ip, slots) for om_ip, slots in node_info.items())
            error_list.append(item_template % (error_type, status, node_error_info))
        return ";".join(error_list)

    @staticmethod
    def format_cache_storage_error(error_items, error_type):
        """Render {key: [om_ip, ...]} as ';'-separated per-key descriptions."""
        item_template = get_code_msg(626341).strip("\r\n")
        return ";".join(item_template % (error_type, status, ','.join(node_info))
                        for status, node_info in error_items.items())


class PoolLLDParamsCheck(object):
    """Accumulate per-pool LLD parameter checks across nodes and report violations.

    Each ``*_check`` method is called per node with that node's ``osd_info``
    and updates the pool's record. The ``*_result_handle`` methods turn the
    accumulated records into one ``HCCIException`` when anything is wrong.
    """

    def __init__(self):
        # Template of one pool record. Entries shaped [value, ok] hold the first
        # value seen for the pool and whether every check so far passed.
        # 'primary_slot' collects the bmc_ip of nodes with an invalid primary slot.
        # 'node_num', 'ec_data', 'ec_verify' and 'primary_type' are not updated by
        # the check methods in this class (presumably updated by callers — confirm).
        self.pool_params_module = {'node_num': 0,
                                   'cache_type': [None, True],
                                   'pool_type': [None, True],
                                   'redundancy_policy': [None, True],
                                   'ec_data': [None, True],
                                   'ec_verify': [None, True],
                                   'primary_slot': [],
                                   'primary_type': [None, True]
                                   }
        # pool name -> deep copy of pool_params_module
        self.pool_info_record = dict()
        # Node IPs passed to no_pool_name_node (nodes without a pool name).
        self.no_pool_name_lst = []

    def init_record_pool_info(self, pool_name: str) -> dict:
        """
        Ensure ``pool_name`` has a record (a deep copy of ``pool_params_module``).

        :param pool_name: storage pool name.
        :return: the whole ``pool_info_record`` mapping (pool name -> record),
            not only the record of ``pool_name``.
        """
        if not self.pool_info_record.get(pool_name):
            self.pool_info_record[pool_name] = copy.deepcopy(self.pool_params_module)
        return self.pool_info_record

    def no_pool_name_node(self, node_ip):
        """Remember a node that has no storage pool name configured."""
        self.no_pool_name_lst.append(node_ip)

    def _pool_record(self, pool_name):
        """Return the record of ``pool_name``; log and re-raise KeyError if it was never initialised."""
        try:
            return self.pool_info_record[pool_name]
        except KeyError as e:
            logger.error('Key not exist. Detail:%s' % str(e))
            raise e

    def pub_type_check(self, pool_name, src_type, module_str):
        """Check that ``src_type`` matches the first value recorded for ``module_str``.

        The first value seen is stored; any later mismatch marks the entry as failed.
        Once failed, the entry is left untouched.

        :raises KeyError: if ``pool_name`` has no record.
        """
        entry = self._pool_record(pool_name).get(module_str)
        record_type, res = entry
        if not res:
            return
        if record_type is None:
            entry[0] = src_type
            return
        if record_type != src_type:
            entry[1] = False

    def redundancy_policy_check(self, pool_name, osd_info):
        """Check the node's redundancy policy: must be 'ec' or '3Redundancy' and equal across the pool."""
        entry = self._pool_record(pool_name).get('redundancy_policy')
        record_policy, res = entry
        if not res:
            return
        src_policy = osd_info.get('storagepool_redundancy_policy')
        if record_policy is None:
            entry[0] = src_policy
        if src_policy in ['ec', '3Redundancy']:
            self.pub_type_check(pool_name, src_policy, 'redundancy_policy')
        else:
            entry[1] = False

    def cache_type_check(self, pool_name, osd_info):
        """Check the node's cache type: must be supported and equal across the pool."""
        entry = self._pool_record(pool_name).get('cache_type')
        record_type, res = entry
        if not res:
            return
        cache_type = osd_info.get('cache_type')
        if record_type is None:
            # NOTE(review): unlike pool_type_check, the first value recorded for a
            # pool is not validated against SUPPORT_CACHE_TYPE here. This matters
            # for single-node pools and for nodes without cache_type — confirm
            # whether that is intended (e.g. for all-flash pools).
            entry[0] = cache_type
            return
        if record_type != cache_type or cache_type not in DeployConstant.SUPPORT_CACHE_TYPE:
            entry[1] = False

    def pool_type_check(self, pool_name, osd_info):
        """Check the node's pool type: must be 'Normal' or 'Encryption' and equal across the pool."""
        entry = self._pool_record(pool_name).get('pool_type')
        record_type, res = entry
        if not res:
            return
        pool_type = osd_info.get('storage_pool_type')
        if record_type is None:
            entry[0] = pool_type
        if not pool_type or pool_type not in ['Normal', 'Encryption']:
            entry[1] = False
            return
        self.pub_type_check(pool_name, pool_type, 'pool_type')

    def primary_slot_check(self, pool_name, osd_info):
        """Record the node's bmc_ip when its primary slot range is missing, malformed or spans fewer than 4 slots.

        The expected ``primary_slot`` format is "<start>-<end>" with integer bounds.
        """
        error_node = self._pool_record(pool_name).get('primary_slot')
        primary_slot = osd_info.get('primary_slot')
        bmc_ip = osd_info.get('bmc_ip')
        if not primary_slot:
            error_node.append(bmc_ip)
            return
        try:
            start, end = primary_slot.split('-')
            slot_count = int(end) - int(start) + 1
        except ValueError:
            # Not "<start>-<end>" with integer bounds: report the node through the
            # normal primary-slot error instead of aborting the whole check.
            error_node.append(bmc_ip)
            return
        if slot_count < 4:
            error_node.append(bmc_ip)

    def lld_params_check_result_handle(self):
        """Raise HCCIException(626354, ...) listing every LLD parameter problem found in all pools."""
        err_msg = dict()
        node_nums_error = []
        redundancy_policy_error_pool = []
        for pool, res in self.pool_info_record.items():
            if res.get('node_num') < 3:
                node_nums_error.append(pool)
            if not res['cache_type'][1]:
                err_msg[626012] = get_code_msg(626012) % DeployConstant.SUPPORT_CACHE_TYPE
            if not res['pool_type'][1]:
                err_msg[626353] = get_code_msg(626353)
            if not res['redundancy_policy'][1]:
                redundancy_policy_error_pool.append(pool)
            if not res['ec_data'][1]:
                # Was get_code_msg(626353) (the pool-type message) — copy-paste error.
                err_msg[626015] = get_code_msg(626015)
            if not res['ec_verify'][1]:
                err_msg[626017] = get_code_msg(626017) % ''
            if res['primary_slot']:
                err_msg[626213] = get_code_msg(626213)
            if not res['primary_type'][1]:
                err_msg[626345] = get_code_msg(626345) % list(DeployConstant.STORAGE_TYPE_MAP.values())

        if node_nums_error:
            err_msg[626352] = get_code_msg(626352) % node_nums_error
        if redundancy_policy_error_pool:
            err_msg[626013] = get_code_msg(626013) % redundancy_policy_error_pool
        if self.no_pool_name_lst:
            err_msg[626011] = get_code_msg(626011) % self.no_pool_name_lst
        if err_msg:
            logger.error('LLD Params Error:{}'.format(err_msg))
            res_err_msg = PoolInfoCheck.format_error_msg(list(err_msg.values()))
            raise HCCIException(626354, res_err_msg)

    def expand_business_node_check_result_handle(self):
        """Raise HCCIException(626354, ...) for cache type, primary slot or primary type problems (business-node expansion)."""
        err_msg = dict()
        for _pool, res in self.pool_info_record.items():
            if not res['cache_type'][1]:
                err_msg[626012] = get_code_msg(626012) % DeployConstant.SUPPORT_CACHE_TYPE
            if res['primary_slot']:
                err_msg[626213] = get_code_msg(626213)
            if not res['primary_type'][1]:
                err_msg[626345] = get_code_msg(626345) % list(DeployConstant.STORAGE_TYPE_MAP.values())
        if err_msg:
            logger.error('LLD params error:{}'.format(err_msg))
            res_err_msg = PoolInfoCheck.format_error_msg(list(err_msg.values()))
            raise HCCIException(626354, res_err_msg)

    def expand_manager_node_check_result_handle(self):
        """Raise HCCIException(626354, ...) for cache type or primary slot problems (manager-node expansion)."""
        err_msg = dict()
        for _pool, res in self.pool_info_record.items():
            if not res['cache_type'][1]:
                err_msg[626377] = get_code_msg(626377) % DeployConstant.SUPPORT_CACHE_TYPE
            if res['primary_slot']:
                err_msg[626378] = get_code_msg(626378)
        if err_msg:
            logger.error('LLD Params Error:{}'.format(err_msg))
            res_err_msg = PoolInfoCheck.format_error_msg(list(err_msg.values()))
            raise HCCIException(626354, res_err_msg)
