# -*- coding: utf-8 -*-
import time
import traceback

from collections import defaultdict
from utils.common.fic_base import TestCase
import utils.common.log as logger
from utils.common.message import Message
from utils.common.exception import HCCIException
from utils.business.project_condition_utils import get_project_condition_boolean
from plugins.DistributedStorage.logic.deploy_operate import DeployOperate
from plugins.DistributedStorage.logic.install_operate import InstallOperate
from plugins.DistributedStorage.utils.common.deploy_constant import DeployConstant
from plugins.DistributedStorage.logic.error_msg_handle import PoolInfoCheck, PoolLLDParamsCheck


class CreatePool(TestCase):
    """Create the distributed-storage pools described by the LLD OSD node list via the deploy manager.

    ``procedure`` is the entry point. It runs ``main`` and converts any failure into
    ``Message(500, ...)``. On every exit path it disables the create-pool license switch
    (when ``config_license_switch_condition`` is set) and logs out of the deploy manager.
    """

    def __init__(self, project_id, pod_id, fs_args, **kwargs):
        """
        :param project_id: project identifier (also used for project-condition lookups).
        :param pod_id: pod identifier.
        :param fs_args: dict of storage arguments; this class reads 'fsa_list', 'osd_list',
            'float_ip' and 'dm_update_pwd' from it and passes it to DeployOperate/InstallOperate.
        :param kwargs: extra arguments, stored unchanged in ``self.more_args``.
        """
        super(CreatePool, self).__init__(project_id, pod_id)
        self.more_args = kwargs
        self.opr = DeployOperate(fs_args)
        self.install_operate = InstallOperate(project_id, pod_id, fs_args)
        self.fsa_list = fs_args.get('fsa_list')
        self.osd_list = fs_args.get('osd_list')
        self.float_ip = fs_args.get('float_ip')
        self.update_pwd = fs_args.get('dm_update_pwd')
        # Per-pool disk statistics; reset for every pool in get_create_pool_args.
        self.min_disk_num = 0
        self.max_disk_num = 0
        self.disk_num = 0
        # Whether the create-pool license switch must be toggled around pool creation
        # (the '!' prefix presumably negates the project condition -- confirm in project_condition_utils).
        self.config_license_switch_condition = get_project_condition_boolean(self.project_id, '!TenantStorFBReuse80')
        # Per-pool validation data, collected while building pool args and evaluated
        # all at once by PoolInfoCheck.pool_info_error_check.
        self.pool_info_check = PoolInfoCheck.pool_info_check_dict_init()

    @staticmethod
    def get_storage_type(node_info, storage_type_dict):
        """Fill the pool's main-storage and cache media types from an OSD node, once per pool.

        :param node_info: LLD OSD node dict; reads 'primary_type' and 'cache_type'.
        :param storage_type_dict: dict with 'storageMediaType' and 'cacheMediaType' keys,
            updated in place.
        """
        # Guard on the key this dict actually carries, so the first node that yields
        # a main-storage type fixes the pool's types (mirrors the other get_* helpers).
        if storage_type_dict.get('storageMediaType'):
            return
        storage_type_dict['storageMediaType'] = DeployConstant.STORAGE_TYPE_MAP.get(node_info.get('primary_type'))
        storage_type_dict['cacheMediaType'] = node_info.get('cache_type')
        logger.info('primary and cache type:{}'.format(storage_type_dict))

    @staticmethod
    def get_redundancy_policy(node_info, pool_args):
        """Set the pool's redundancy fields from an OSD node unless already set.

        'ec' yields an EC policy with the node's data/parity fragment counts and a
        fault tolerance of 1; any other value yields 3-way replication.

        :param node_info: LLD OSD node dict.
        :param pool_args: pool parameter dict, updated in place.
        """
        if pool_args.get('redundancyPolicy'):
            return
        redundancy_policy = node_info.get('storagepool_redundancy_policy')
        logger.info('Pool redundancy policy:{}'.format(redundancy_policy))
        if redundancy_policy == 'ec':
            pool_args['redundancyPolicy'] = 'ec'
            pool_args['numParityUnits'] = node_info.get('ec_verify_fragments')
            pool_args['numDataUnits'] = node_info.get('ec_data_fragments')
            pool_args['numFaultTolerance'] = 1
        else:
            pool_args['redundancyPolicy'] = 'replication'
            pool_args['replicaNum'] = 3

    @staticmethod
    def get_encrypt_type(node_info, pool_args):
        """Set 'encryptType' (Normal -> 0, Encryption -> 1, else None) unless already set.

        :param node_info: LLD OSD node dict; reads 'storage_pool_type'.
        :param pool_args: pool parameter dict, updated in place.
        """
        if pool_args.get('encryptType'):
            return
        encrypt_type_map = {'Normal': 0, 'Encryption': 1}
        encrypt_type = node_info.get('storage_pool_type')
        logger.info('Encrypt type:{}'.format(encrypt_type))
        pool_args['encryptType'] = encrypt_type_map.get(encrypt_type)

    @staticmethod
    def check_lld_params_and_get_pool_list(osd_list, exist_pool):
        """Validate the LLD pool parameters of every OSD node and return the pools to create.

        Nodes without a pool name are reported; nodes belonging to an already existing
        pool are skipped. The final result handling is delegated to
        PoolLLDParamsCheck.lld_params_check_result_handle (presumably raises on errors).

        :param osd_list: list of LLD OSD node dicts.
        :param exist_pool: names of pools that already exist.
        :return: list of pool names recorded by the check (i.e. pools not in exist_pool).
        """
        logger.info('Start to check lld params')
        lld_pool_check = PoolLLDParamsCheck()
        for node in osd_list:
            bmc_ip = node.get('bmc_ip')
            cur_pool = node.get('storage_pool_name_and_slot')
            if not cur_pool:
                lld_pool_check.no_pool_name_node(bmc_ip)
                continue
            if cur_pool in exist_pool:
                continue
            lld_pool_check.init_record_pool_info(cur_pool)
            lld_pool_check.pool_info_record[cur_pool]['node_num'] += 1
            lld_pool_check.redundancy_policy_check(cur_pool, node)
            record_policy, check_res = lld_pool_check.pool_info_record[cur_pool]['redundancy_policy']
            # EC fragment counts are only meaningful for EC pools that passed the policy check.
            if check_res and record_policy == 'ec':
                ec_data_fragments = node.get('ec_data_fragments')
                ec_verify_fragments = node.get('ec_verify_fragments')
                lld_pool_check.pub_type_check(cur_pool, ec_data_fragments, 'ec_data')
                lld_pool_check.pub_type_check(cur_pool, ec_verify_fragments, 'ec_verify')
            lld_pool_check.cache_type_check(cur_pool, node)
            lld_pool_check.pool_type_check(cur_pool, node)
            lld_pool_check.primary_slot_check(cur_pool, node)
            primary_type = node.get('primary_type')
            lld_pool_check.pub_type_check(cur_pool, primary_type, 'primary_type')
        logger.info('Result analysis')
        lld_pool_check.lld_params_check_result_handle()
        logger.info('Check Passed')
        return list(lld_pool_check.pool_info_record.keys())

    @staticmethod
    def record_disk_num(min_num, max_num, disk_num):
        """Fold one node's disk count into running (min, max) values.

        A value of 0 means "not set yet" and is replaced by disk_num outright.

        :return: tuple (min_num, max_num).
        """
        min_num = disk_num if min_num == 0 else min(min_num, disk_num)
        max_num = disk_num if max_num == 0 else max(max_num, disk_num)
        return min_num, max_num

    @staticmethod
    def _format_pool_error_info(pool_name, task_status, description, err_msg):
        """Build the error record reported for a pool whose creation task did not succeed."""
        return {
            'pool': pool_name,
            'status': task_status,
            'description': description,
            'detail': err_msg,
        }

    def procedure(self):
        """Entry point: run main() and translate any failure into a Message(500, ...).

        Cleanup (license switch off, deploy-manager logout) runs on every exit path.
        """
        try:
            self.main()
        except HCCIException as e:
            return Message(500, e)
        except Exception as e:
            logger.error(traceback.format_exc())
            return Message(500, HCCIException(626003, str(e)))
        finally:
            if self.config_license_switch_condition:
                self.install_operate.config_create_pool_license_switch(self.pod_id, self.float_ip, delete=True)
            self.opr.login_out(DeployConstant.DM_LOGIN_USER, self.update_pwd)
        return Message()

    def main(self):
        """Build the pool configurations and create every pool, waiting for each task.

        :raises HCCIException: 626003 if a create request is rejected or any pool task
            does not end in 'success'; other codes from the helper checks.
        """
        # Check that there are at least as many FSA nodes as OSD nodes.
        self._check_fsa_and_osd_host_num()
        logger.info("Start to create pool.")
        if self.config_license_switch_condition:
            self.install_operate.config_create_pool_license_switch(self.pod_id, self.float_ip)
        self.login_dm()
        self.check_cluster()
        logger.info('Get data of storage pool')
        all_nodes_disks = self.get_cluster_disk_info()
        exist_pools = self.get_exist_pool_info()
        osd_list = self.db.get_install_os_list_info(pod_id=self.pod_id)
        pool_config_list = self.get_create_pool_args(osd_list, all_nodes_disks, exist_pools)
        logger.info('Creating pool')
        # Two distinct lists: pool names and their failure details must not share storage.
        fail_pool = list()
        fail_pool_detail = list()
        for pool in pool_config_list:
            # The pool name lives in the 'poolPara' section of the config built by get_create_pool_args.
            pool_name = pool.get('poolPara').get('poolName')
            res = self.opr.create_pool(pool)
            status_code, result, task_id, error_code, error_des = res.get_create_pool_code()
            if status_code != 200 or result != 0:
                logger.info("Pool args is: %s" % pool)
                err_msg = "Failed to Create pool, " \
                          "Detail:[status:%s,result:%s,error:%s]%s" % (status_code, result, error_code, error_des)
                logger.error(err_msg)
                raise HCCIException(626003, err_msg)
            query_result = self.query_pool_task_process(task_id, pool_name)
            if query_result:
                fail_pool.append(pool_name)
                fail_pool_detail.append(query_result)
        if len(fail_pool) > 0:
            err_msg = "Failed to create pool%s, Detail:%s" % (fail_pool, fail_pool_detail)
            logger.error(err_msg)
            raise HCCIException(626003, err_msg)
        logger.info('Create pool successful')

    def query_pool_task_process(self, task_id, pool_name, timeout=3600):
        """Poll the task center every 10 seconds until the create-pool task ends or times out.

        :param task_id: id of the create-pool task.
        :param pool_name: name of the pool being created (for logs and error data).
        :param timeout: maximum number of seconds to wait.
        :return: empty dict on 'success'; otherwise an error dict from _format_pool_error_info
            with status 'failed', 'part_success' or 'timeout'.
        """
        logger.info('Query the process of creating pool[%s]' % pool_name)
        current_time = 0
        task_data = dict()
        # Set when the task reaches a terminal state, so a completion seen on the last
        # poll is not misreported as a timeout.
        finished = False
        while current_time < timeout:
            time.sleep(10)
            current_time += 10
            logger.info("Start query the process of pool[%s] task[%s] "
                        "after %s seconds." % (pool_name, task_id, current_time))
            res_query = self.opr.query_task_info()
            task_info = res_query.get_task_by_id(task_id)
            task_status = task_info.get('taskStatus')
            entity_name = task_info.get('entityName')
            if task_status == 'success':
                logger.info("Create the storage pool[%s] success." % entity_name)
                finished = True
                break
            elif task_status == "failed":
                logger.error("Failed to create the storage pool[%s]. Detail:%s" % (entity_name, str(task_info)))
                task_data = self._format_pool_error_info(
                    pool_name, task_status, "Failed to create the storage pool", task_info)
                finished = True
                break
            elif task_status == "part_success":
                err_hint = "The task of creating a pool is partially successful. " \
                           "The OSD process fails to be started or other sub-task fail to be created on some nodes. " \
                           "Make the following operations by Manually: " \
                           "1. Login to FusionStorage web page and go to the task center, " \
                           "find the task of creating a pool, and handle the faulty disk or failed sub-task. " \
                           "2. Delete the created storage pool. " \
                           "If a message is displayed indicating that the storage pool depends on the VBS, " \
                           "enable the VBS first, delete the storage pool, " \
                           "and then disable the VBS. Try again after the pool is deleted."
                logger.error(err_hint)
                task_data = self._format_pool_error_info(pool_name, task_status, err_hint, task_info)
                finished = True
                break
            else:
                task_name = task_info.get('taskName')
                task_progress = task_info.get('progress')
                logger.info("Creating the storage pool[taskID: %s, taskName: %s, taskObject: %s, taskStatus: %s, "
                            "taskProgress: %s]" % (task_id, task_name, entity_name, task_status, task_progress))

        if not finished:
            err_msg = "Waiting for the task[%s] completion times out after %s seconds." % (task_id, current_time)
            logger.error(err_msg)
            task_data = self._format_pool_error_info(
                pool_name, "timeout", "Waiting for the task completion times out.", err_msg)
        return task_data

    def check_cluster(self):
        """Verify that the managed cluster can be queried (has a clusterName).

        :raises HCCIException: 626073 if no cluster name is returned.
        """
        logger.info('Check cluster...')
        response = self.opr.query_manage_cluster()
        cluster = response.get_query_data()
        if not cluster.get('clusterName'):
            logger.error('Query cluster fail...')
            raise HCCIException(626073, cluster)
        logger.info('Query cluster successfully')

    def login_dm(self):
        """Log in to the deploy manager with the updated password.

        :raises HCCIException: 626067 on a non-200 status or non-zero error code.
        """
        status_code, error_code, error_des = self.opr.login(DeployConstant.DM_LOGIN_USER, self.update_pwd)
        if status_code != 200 or error_code != 0:
            err_msg = "Failed to login deploy manager, " \
                      "Detail:[status:%s,code:%s]%s" % (status_code, error_code, error_des)
            logger.error(err_msg)
            raise HCCIException(626067, err_msg)

    def get_create_pool_args(self, osd_list, all_nodes_disks, exist_pools):
        """Build the create-pool request body for every pool that does not exist yet.

        :param osd_list: list of LLD OSD node dicts.
        :param all_nodes_disks: disk query response object from get_cluster_disk_info.
        :param exist_pools: names of pools that already exist.
        :return: list of dicts {'poolPara': {...}, 'serverList': [...]}.
        EC pool_para_args:
             'poolName': storage pool name,
             'redundancyPolicy': redundancy policy,
             'numParityUnits': number of parity units,
             'numDataUnits': number of data units,
             'numFaultTolerance': fault tolerance,
             'encryptType': encryption type,
             'cacheMediaType': cache media type,
             'storageMediaType': storage media type,
             'securityLevel': security level,
        """
        pool_list = self.check_lld_params_and_get_pool_list(osd_list, exist_pools)
        logger.info('Storage pool to be created:{}'.format(pool_list))
        # The rack-HA condition is project-wide, so it is evaluated once for all pools.
        rack_ha = get_project_condition_boolean(self.project_id, 'TenantStorFB_RackHA')
        pool_config_args_list = list()
        for pool in pool_list:
            self.min_disk_num, self.max_disk_num, self.disk_num = 0, 0, 0
            pool_config_args = dict()
            pool_para_args = dict()
            server_list = list()
            cur_storage_type = {'storageMediaType': None, 'cacheMediaType': None}
            for osd in osd_list:
                if pool != osd.get("storage_pool_name_and_slot"):
                    continue
                self.get_storage_type(osd, cur_storage_type)
                self.get_redundancy_policy(osd, pool_para_args)
                self.get_encrypt_type(osd, pool_para_args)
                server_dict = self.get_server_info(osd, all_nodes_disks, cur_storage_type, pool)
                server_list.append(server_dict)
            self.check_disk_num(pool)

            pool_para_args['securityLevel'] = 'rack' if rack_ha else 'server'
            pool_para_args['poolName'] = pool
            pool_para_args.update(cur_storage_type)

            pool_config_args['poolPara'] = pool_para_args
            pool_config_args['serverList'] = server_list
            pool_config_args_list.append(pool_config_args)
            logger.info("Get pool parameters:%s" % str(pool_config_args))
        # Validate all pool parameters together: total disk count, per-node disk-count difference
        # (at most 2), main storage and cache types, capacity, disk-count difference (at least 4).
        PoolInfoCheck.pool_info_error_check(self.pool_info_check)
        return pool_config_args_list

    def get_cluster_disk_info(self):
        """Query all disks of the cluster.

        :return: the raw disk query response object (not the parsed data).
        :raises HCCIException: 626074 if the response contains no 'disks'.
        """
        logger.info('Query all disk info')
        all_nodes_disks = self.opr.query_all_disk()
        all_disk_data = all_nodes_disks.get_query_data()
        disk_info = all_disk_data.get('disks')
        if not disk_info:
            err_msg = 'No disk information exists.'
            logger.error(err_msg)
            raise HCCIException(626074, all_disk_data)
        return all_nodes_disks

    def get_exist_pool_info(self):
        """Return the names of the storage pools that already exist (empty list if none)."""
        logger.info("Query storage pool data")
        exist_pool = list()
        res_pool = self.opr.query_storage_pool()
        pool_info = res_pool.get_query_data()
        logger.info('Pool info:{}'.format(pool_info))
        storage_pools = pool_info.get('storagePools')
        if storage_pools:
            exist_pool = [pool.get('poolName') for pool in storage_pools]
            logger.info("Pool[%s] has been created." % exist_pool)
        return exist_pool

    def get_server_info(self, osd, all_nodes_disks, cur_storage_type, pool):
        """Build the server entry of the create-pool request for one OSD node.

        :param osd: LLD OSD node dict ('manageIp', 'primary_slot' as 'start-end').
        :param all_nodes_disks: disk query response object.
        :param cur_storage_type: dict with 'storageMediaType' and 'cacheMediaType'.
        :param pool: pool name (key into self.pool_info_check).
        :return:
        Server info:
                 'nodeMgrIp': node management IP, '*.*.*.*',
                 'cacheMediaType': cache type, e.g. 'ssd_disk',
                 'mediaList': media list, storage_disk_list + cache_disk_list,
                 'startSlot': start slot (only for sas/sata/ssd disk main storage),
                 'endSlot': end slot (only for sas/sata/ssd disk main storage).
        """
        server = dict()
        logger.info('Start to query the server info')
        main_type = cur_storage_type.get('storageMediaType')
        cache_type = cur_storage_type.get('cacheMediaType')
        start_slot, end_slot = osd.get('primary_slot').split('-')
        om_ip = osd.get('manageIp')
        server['nodeMgrIp'] = om_ip
        server['cacheMediaType'] = cache_type
        logger.info('The type of main storage is {}'.format(main_type))
        if main_type in ['sas_disk', 'sata_disk', 'ssd_disk']:
            server['startSlot'] = start_slot
            server['endSlot'] = end_slot
            storage_disk_list = self.get_main_storage_when_not_nvme(pool, osd, all_nodes_disks, [start_slot, end_slot])
        else:
            # For NVMe main storage the slot range only determines how many disks to take.
            cur_node_num = int(end_slot) - int(start_slot) + 1
            storage_disk_list = self.get_node_main_storage_when_all_nvme(pool, om_ip, cur_node_num, all_nodes_disks)
        cache_disk_list = self.get_cache_disk_list(pool, osd, all_nodes_disks, cur_storage_type)
        server['mediaList'] = storage_disk_list + cache_disk_list
        return server

    def get_node_main_storage_when_all_nvme(self, pool, node_ip, num, all_nodes_disks):
        """Pick up to ``num`` unused NVMe ('ssd_card') disks of one capacity on a node.

        Unused NVMe disks are grouped by capacity; exactly one group of at least 4
        disks must exist, otherwise the node is recorded as an error and [] is returned.

        :param pool: pool name (key into self.pool_info_check).
        :param node_ip: node management IP.
        :param num: number of disks to take.
        :param all_nodes_disks: disk query response object.
        :return: list of media dicts (ordered by slot).
        """
        res_disk_list = list()
        disk_info = all_nodes_disks.get_query_data().get('disks')
        cur_node_disk_info = disk_info.get(node_ip)
        logger.info("All disk on node[%s]: %s" % (node_ip, cur_node_disk_info))
        if not cur_node_disk_info:
            err_msg = 'No disk information exists on the node:[%s]' % node_ip
            logger.error(err_msg)
            self.pool_info_check[pool]['nvme_main_storage']['no_disk_error'].append(node_ip)
            return res_disk_list

        capacity_team = defaultdict(list)
        for slot_disk_info in cur_node_disk_info:
            if slot_disk_info.get('devRole') != 'no_use' or slot_disk_info.get('devType') != 'ssd_card':
                continue
            capacity_team[slot_disk_info.get('devTotalCapacity')].append(slot_disk_info)
        useful_main_storage_info = [group for group in capacity_team.values() if len(group) >= 4]
        logger.info('useful disk:{}'.format(useful_main_storage_info))
        if len(useful_main_storage_info) != 1:
            err_msg = 'This node has no available or different capacity nvme disks. node ip:{}'.format(node_ip)
            logger.error(err_msg)
            self.pool_info_check[pool]['nvme_main_storage']['disk_nums_error'].append(node_ip)
            return res_disk_list

        ordered_useful_main_storage_info = sorted(useful_main_storage_info[0], key=lambda x: x.get('devSlot'))
        logger.info('ordered useful disk:{}'.format(ordered_useful_main_storage_info))
        for slot_disk_info in ordered_useful_main_storage_info:
            cur_slot_disk = {'phySlotId': slot_disk_info.get('devSlot'),
                             'mediaRole': 'main_storage',
                             'mediaType': 'ssd_card',
                             'mediaSize': slot_disk_info.get('devTotalCapacity'),
                             'phyDevEsn': slot_disk_info.get('devEsn')}
            res_disk_list.append(cur_slot_disk)
            if len(res_disk_list) == num:
                break
        # NOTE(review): 'disk_capacity' is keyed by mediaType, which is always 'ssd_card' here;
        # keying by 'mediaSize' looks intended -- confirm against PoolInfoCheck before changing.
        self.pool_info_check[pool]['nvme_main_storage']['disk_capacity'][res_disk_list[0]['mediaType']].append(node_ip)
        disk_num = len(res_disk_list)
        self.disk_num += disk_num
        self.min_disk_num, self.max_disk_num = self.record_disk_num(self.min_disk_num, self.max_disk_num, disk_num)
        return res_disk_list

    def get_main_storage_when_not_nvme(self, pool, osd, all_nodes_disks, slot_range):
        """Collect the unused main-storage disks in the node's slot range (inclusive).

        Disks that are already in use are recorded in pool_info_check and skipped;
        type, capacity and count of the usable disks are recorded for later validation.

        :param pool: pool name (key into self.pool_info_check).
        :param osd: LLD OSD node dict ('manageIp').
        :param all_nodes_disks: disk query response object.
        :param slot_range: [start_slot, end_slot] as strings or ints.
        :return: list of media dicts.
        """
        start_slot, end_slot = slot_range
        storage_disks = list()
        om_ip = osd.get('manageIp')
        for slot in range(int(start_slot), int(end_slot) + 1):
            disk_meta = all_nodes_disks.get_disks_by_node_slot(om_ip, slot)
            if not disk_meta:
                logger.info("There is no disk on slot[%s] at node[%s]" % (slot, om_ip))
                continue
            disk_status = disk_meta.get('devRole')
            if disk_status != 'no_use':
                self.pool_info_check[pool]["main_storage_info"]["disk_status"][disk_status][om_ip].append(slot)
                continue
            disk = dict()
            disk['phySlotId'] = slot
            disk['mediaRole'] = "main_storage"
            disk['mediaType'] = disk_meta.get('devType')
            disk['mediaSize'] = disk_meta.get('devTotalCapacity')
            disk['phyDevEsn'] = disk_meta.get('devEsn')
            try:
                self.pool_info_check[pool]["main_storage_info"]["type_info"][disk['mediaType']][om_ip].append(slot)
            except KeyError as e:
                logger.error("Key not exist. Detail:%s" % str(e))
                raise
            try:
                self.pool_info_check[pool]["main_storage_info"]["capacity_info"][disk['mediaSize']][om_ip].append(slot)
            except KeyError as e:
                logger.error('Key not exist. Detail:%s' % str(e))
                raise
            storage_disks.append(disk)
        disk_num = len(storage_disks)
        slot_range = '%s-%s' % (start_slot, end_slot)
        self.pool_info_check[pool]["main_storage_info"]["disk_num"][disk_num][om_ip] = slot_range
        self.disk_num += disk_num
        self.min_disk_num, self.max_disk_num = self.record_disk_num(self.min_disk_num, self.max_disk_num, disk_num)
        return storage_disks

    def get_cache_disk_list(self, pool, osd, all_nodes_disks, cur_storage_type):
        """Pick up to 4 unused cache disks on a node when the pool needs a cache.

        A cache is used only when a cache type is configured and the main storage is
        neither 'ssd_disk' nor 'ssd_card'; otherwise [] is returned.

        :param pool: pool name (key into self.pool_info_check).
        :param osd: LLD OSD node dict ('manageIp').
        :param all_nodes_disks: disk query response object.
        :param cur_storage_type: dict with 'storageMediaType' and 'cacheMediaType'.
        :return: list of cache media dicts.
        """
        cache_type, main_storage_type = cur_storage_type.get('cacheMediaType'), cur_storage_type.get('storageMediaType')
        cache_disks = list()
        om_ip = osd.get('manageIp')
        if cache_type and main_storage_type not in ['ssd_disk', 'ssd_card']:
            all_disks = all_nodes_disks.get_disks_by_node(om_ip)
            max_cache_num = 4
            for cache in all_disks:
                # Slots >= 5000 are excluded (presumably non-physical/virtual slots -- confirm).
                if cache.get('devType') == cache_type and cache.get('devRole') == 'no_use' \
                        and cache.get('devSlot') < 5000:
                    cache_disk = dict()
                    cache_disk['phySlotId'] = cache.get('devSlot')
                    cache_disk['mediaRole'] = "osd_cache"
                    cache_disk['mediaType'] = cache.get('devType')
                    cache_disk['mediaSize'] = cache.get('devTotalCapacity')
                    cache_disk['phyDevEsn'] = cache.get('devEsn')
                    self.pool_info_check[pool]["cache_info"]["media_type"][cache.get('devType')].append(om_ip)
                    self.pool_info_check[pool]["cache_info"]["media_size"][cache.get('devTotalCapacity')].append(om_ip)
                    cache_disks.append(cache_disk)
                if len(cache_disks) >= max_cache_num:
                    break
            self.pool_info_check[pool]["cache_info"]['cache_num'][len(cache_disks)].append(om_ip)
        return cache_disks

    def check_disk_num(self, pool):
        """Record the pool's disk-count spread and total for the final validation.

        Skipped when no node contributed a non-zero disk count.
        """
        if self.max_disk_num == 0 or self.min_disk_num == 0:
            logger.info("No available disk num[max:%s, min:%s]. Skip.." % (self.max_disk_num, self.min_disk_num))
            return
        self.pool_info_check[pool]["disk_num_gap"] = (self.max_disk_num, self.min_disk_num)
        # max_disk_num is non-zero here (guarded above), so the division is safe.
        self.pool_info_check[pool]["percentage"] = (self.max_disk_num -
                                                    self.min_disk_num) / float(self.max_disk_num)
        self.pool_info_check[pool]["disk_total_num"] = self.disk_num

    def _check_fsa_and_osd_host_num(self):
        """
        Check that the current project has at least as many FSA nodes as OSD nodes.
        Only the counts are compared, although the error message speaks of a subset.
        :raises HCCIException: 626003 if there are fewer FSA nodes than OSD nodes.
        """
        if len(self.fsa_list) < len(self.osd_list):
            osd_ip = [osd_node.get('bmc_ip') for osd_node in self.osd_list]
            fsa_ip = [fsa_node.get('bmc_ip') for fsa_node in self.fsa_list]
            msg = "The osd list is not the subset of fsa list. osd iBMC ip:{},fsa iBMC ip:{}".format(osd_ip, fsa_ip)
            logger.error(msg)
            raise HCCIException(626003, msg)