# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import traceback

import utils.common.log as logger
from utils.DBAdapter.DBConnector import BaseOps
from utils.common.OpenStackNodeManager import OpenStackNodeManager
from utils.business.project_condition_utils import get_project_condition_boolean
from plugins.DistributedStorage.utils.iterm.parameter_gain import ParamsGain
from plugins.DistributedStorage.utils.common.deploy_constant import DeployConstant
from plugins.DistributedStorage.logic.deploy_operate import DeployOperate


class CheckLicense(object):
    """
    License check for the expansion (reuse) scenario.

    Checks whether a license has been imported on the management storage or the
    service storage, whether it has expired, and whether the pool/node counts
    after expansion comply with the license specification.
    """

    # License edition -> (max number of storage pools, max number of storage nodes)
    LICENSE_SPECIFICATIONS = {
        "Standard": (3, 256),
        "Advanced": (128, 4096),
    }

    def __init__(self, project_id, pod_id, **kwargs):
        self.project_id = project_id
        self.pod_id = pod_id
        self.db = BaseOps()
        self.args = dict()
        self.deploy_constant = DeployConstant()
        # Also fills self.args with the DeviceManager login info and pool names.
        self.bmc_ip_list, self.osd_info = self.get_osd_and_dm_info_list()
        self.fsm_user = self.deploy_constant.DM_LOGIN_USER
        self.float_ip = self.args.get("fsm_float_ip")
        self.fsm_admin_passwd = self.args.get("fsm_admin_passwd")
        self.storage_pool_name = self.args.get("storage_pool_name")
        self.operate = DeployOperate(fs_info=None, float_ip=self.float_ip)
        self.standard_bom_codes = self.deploy_constant.OCEAN_STORAGE_BLOCK_BOM_CODES.get("Standard")
        self.advanced_bom_codes = self.deploy_constant.OCEAN_STORAGE_BLOCK_BOM_CODES.get("Advanced")

    @classmethod
    def _exceeds_license_specification(cls, license_version, pools_number, nodes_number):
        """
        Return True when the pool count or node count exceeds the limits of the
        given license edition (see LICENSE_SPECIFICATIONS).

        :param license_version: "Standard" or "Advanced"
        :param pools_number: total number of storage pools after expansion
        :param nodes_number: total number of storage nodes after expansion
        """
        max_pools, max_nodes = cls.LICENSE_SPECIFICATIONS[license_version]
        return pools_number > max_pools or nodes_number > max_nodes

    def procedure(self):
        """
        Run the license checks.

        Returns immediately when there are no OSD nodes to expand. Otherwise it
        logs in to DeviceManager, validates the imported license, checks the
        pool/node counts against the license edition, and requires the Advanced
        edition when the project has the CSHA condition. The session is always
        logged out afterwards.

        :raises Exception: when any of the checks fails.
        """
        if not self.bmc_ip_list:
            logger.info('Not expansion osd scenario.')
            return
        try:
            logger.info("start to check license")
            self.operate.login(self.fsm_user, self.fsm_admin_passwd)
            current_bom_code = self.check_license_status()
            license_version = "Standard" if current_bom_code in self.standard_bom_codes else "Advanced"
            # Check whether the numbers of pools/nodes after expansion exceed the license limits.
            original_storage_node_ip_list, original_storage_pools_name_list, _ = self.get_exist_pool_node_info()

            total_pool = set(self.storage_pool_name + original_storage_pools_name_list)
            pools_number = len(total_pool)
            # The node IP list is gathered pool by pool, so the same node may be listed
            # more than once; count each existing node IP only once.
            nodes_number = len(self.bmc_ip_list) + len(set(original_storage_node_ip_list))
            logger.info("original node list:{}, current node list:{}"
                        .format(original_storage_node_ip_list, self.bmc_ip_list)
                        )
            # Both limits must be checked against the edition actually licensed; the
            # previous `A and B or C` form applied the node limit to every edition.
            if self._exceeds_license_specification(license_version, pools_number, nodes_number):
                max_pools, max_nodes = self.LICENSE_SPECIFICATIONS[license_version]
                err_msg = "The number of resource pools or nodes exceeds " \
                          "the license specification. license version: " \
                          "'{0}'. your pools and nodes:[{1},{2}], " \
                          "license specification:[{3}, {4}]" \
                    .format(license_version, pools_number, nodes_number, max_pools, max_nodes)
                logger.error(err_msg)
                raise Exception(err_msg)

            # In the CSHA scenario the license must be the Advanced edition.
            support_csha = get_project_condition_boolean(self.project_id, "CSHA")
            logger.info("support CSHA:[{}]".format(support_csha))
            if license_version == "Standard" and support_csha:
                err_msg = "The current license version is Standard, please upgrade to Advanced."
                logger.error(err_msg)
                raise Exception(err_msg)

            logger.info('license check completed!!!')
        finally:
            self.operate.logout()

    def check_license_status(self):
        """
        Validate the license imported on DeviceManager.

        :return: the BOM code of the first entry in the license "bomCodes" list.
        :raises Exception: when no license file exists, the usage info or BOM
            codes are missing, any feature state is not "1", or the BOM code is
            not one of the configured FusionStorage block BOM codes.
        """
        # 1. Query the license feature state (valid, expired or invalid).
        res = self.operate.get_license_information()
        data, result = res.get_license_data()
        logger.info("query result: {}, {}".format(data, result))
        # 2. Check that a license is imported, that it has not expired, and that
        #    the BOM code belongs to FusionStorage block.
        if data.get("FileExist") != "0":
            err_msg = "The license file does not exist, please import " \
                      "the license of device manager:[{}]. Details: {}" \
                .format(self.float_ip, data.get("FileExist"))
            logger.error(err_msg)
            raise Exception(err_msg)
        license_usage_info_list = data.get("LicenseUsageInfo")
        if not license_usage_info_list:
            err_msg = "Failed to get license usage information. Details: {}" \
                .format(license_usage_info_list)
            logger.error(err_msg)
            raise Exception(err_msg)
        for license_feature in license_usage_info_list:
            license_state = license_feature.get("State")
            if license_state != "1":
                err_msg = "The license status of the feature is expired or " \
                          "invalid. please re-import the license of device " \
                          "manager:[{}]. Details: {}" \
                    .format(self.float_ip, license_state)
                logger.error(err_msg)
                raise Exception(err_msg)
        license_bom_codes_list = data.get("bomCodes")
        if not license_bom_codes_list:
            err_msg = "Failed to get license bom codes. Details: {}" \
                .format(license_bom_codes_list)
            logger.error(err_msg)
            raise Exception(err_msg)
        current_bom_code = license_bom_codes_list[0].get("bomCode")
        if current_bom_code not in self.standard_bom_codes + self.advanced_bom_codes:
            err_msg = "This is not a license for FusionStorage block, " \
                      "please re-import the license of device manager:" \
                      "[{}]. Details: {}".format(self.float_ip, current_bom_code)
            logger.error(err_msg)
            raise Exception(err_msg)
        return current_bom_code

    def get_osd_and_dm_info_list(self):
        """
        Collect the OSD node info (BMC info, slots, storage pools) of the
        management or service storage, and the DeviceManager login info.

        Side effect: sets self.args["fsm_float_ip"], self.args["fsm_admin_passwd"]
        and self.args["storage_pool_name"].

        :return: (bmc_ip_list, osd_info)
        """
        params = ParamsGain(self.project_id, self.pod_id, self.db)
        if get_project_condition_boolean(
                self.project_id, "ExpansionMgmtRes_ServiceNode"):
            nodes_info = OpenStackNodeManager.get_manage_az_nodes_info(
                self.db, self.pod_id)
            # Guard against nodes without a 'ref_component' entry.
            osd_info = [node for node in nodes_info
                        if 'osd' in (node.get('ref_component') or '')]
            pool_name_list = [osd["manage_storage_pool_name"]
                              for osd in osd_info
                              if osd.get("manage_storage_pool_name")]
            fs_args = params.get_manage_converge_args()
        else:
            osd_info = self.db.get_install_os_list_info(self.pod_id)
            pool_name_list = [osd["storage_pool_name_and_slot"] for osd in osd_info]
            fs_args = params.get_business_separate_args()
        self.args["fsm_float_ip"] = fs_args.get("float_ip")
        self.args["fsm_admin_passwd"] = fs_args.get("dm_update_pwd")

        self.args["storage_pool_name"] = pool_name_list
        logger.info("osd inforamation: %s" % osd_info)
        bmc_ip_list = [bmc_info["bmc_ip"] for bmc_info in osd_info]
        return bmc_ip_list, osd_info

    def get_exist_pool_node_info(self):
        """
        Query the existing storage pools and their storage nodes.

        :return: (management IPs of the storage nodes of every existing pool,
            names of the existing pools, total capacity of the existing pools
            divided by 1024**2 -- presumably TB from MB; TODO confirm the unit)
        :raises Exception: when no storage pool is reported.
        """
        result = self.operate.query_storage_pool()
        storage_pool_info = result.get_query_data()
        storage_pools_list = storage_pool_info.get("storagePools")
        # A missing key is treated the same as an empty list.
        if not storage_pools_list:
            logger.error('check pool fail.')
            raise Exception("check pool fail.")

        logger.info("original storage pool information: <%s>" % storage_pools_list)
        original_storage_pools_name_list = list()
        original_pool_id_list = list()
        original_storage_pool_total_capacity = 0
        for each_pool_info in storage_pools_list:
            total_capacity = each_pool_info.get("totalCapacity")
            if total_capacity:
                original_storage_pool_total_capacity += int(total_capacity)
            original_pool_id_list.append(str(each_pool_info.get('poolId')))
            original_storage_pools_name_list.append(each_pool_info.get("poolName"))
        storage_pool_total_capacity_tb = original_storage_pool_total_capacity / 1024 ** 2
        logger.info("Get storage pool list %s." % original_pool_id_list)

        storage_node_ip_list = list()
        for pool_id in original_pool_id_list:
            rsp_data = self.operate.query_storage_node_by_pool_id(pool_id)
            pool_data = rsp_data.get_query_data()
            logger.info("pool id: <%s>. pool data: <%s>" % (pool_id, pool_data))
            node_info_list = pool_data.get('nodeInfo')
            node_mgr_ip_list = [node.get('nodeMgrIp') for node in node_info_list]
            storage_node_ip_list.extend(node_mgr_ip_list)
        logger.info("Get storage node ip list[%s] from pool list%s." % (
            storage_node_ip_list, original_pool_id_list))

        return storage_node_ip_list, original_storage_pools_name_list, storage_pool_total_capacity_tb
