# coding: UTF-8
import re

import common
from check_pool_num import CheckPoolNumUtil
from com.huawei.ism.exception import IsmException

# NOTE(review): py_java_env and PY_LOGGER are not defined in this file; they are
# presumably injected into the module namespace by the Java-side runner.
# Language setting passed to common.get_err_msg / CheckPoolNumUtil for messages.
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Handle used to fetch pre-inspection results via getPreInspectResult(ip, item_id).
HANDLE = py_java_env.get("preInspectHandle")
# Presumably the ID of this inspection item (not referenced in the visible code).
ITEM_ID = "check_disk_num_to_add"
# Pre-inspection item whose KEY=value text output is parsed with regexes below.
PRE_ITEM_ID = "getDiskNumAndPlogNum"


class Pool:
    """
    Plain value holder for one pool entry.

    Stores a node management IP, the owning disk pool name, and the origin /
    expansion node counts of that disk pool, exposed through getter methods.
    """

    def __init__(self, pool_ip, pool_name, origin_num, expansion_num):
        (self.pool_ip,
         self.pool_name,
         self.origin_num,
         self.expansion_num) = (pool_ip, pool_name, origin_num, expansion_num)

    def get_pool_ip(self):
        """:return: the node management IP stored for this entry"""
        return self.pool_ip

    def get_pool_name(self):
        """:return: the disk pool name stored for this entry"""
        return self.pool_name

    def get_origin_num(self):
        """:return: the origin node count stored for this entry"""
        return self.origin_num

    def get_expansion_num(self):
        """:return: the expansion node count stored for this entry"""
        return self.expansion_num


def get_storage_software(dev_node):
    """
    :function: Build the storage software version string of the device by
               combining the product version with the hot patch version.
    :param dev_node: device node information object
    :return: the product version when no hot patch is reported or when the
             patch version is already contained in it; the patch version when
             it already contains the product version; otherwise
             "<product_version>.<patch_version>"
    """
    LOGGER.logInfo("check product version.")
    product_version = dev_node.getProductVersion()
    patch_version = dev_node.getHotPatchVersion()

    # Handle "no patch" first: an empty string would otherwise be absorbed by
    # the substring test below (""-in-anything is True), and None would make
    # the `in` test raise TypeError before this guard was ever reached.
    if not patch_version:
        return product_version
    if patch_version in product_version:
        return product_version
    if product_version in patch_version:
        return patch_version
    return product_version + "." + patch_version


def check_storage_software(version):
    """
    :function: Check the storage software version against the list of
               versions known to carry this risk.
    :param version: current storage software version string
    :return: True if the check passes (version not in the risk list);
             False if the version exactly matches a risky release
    """
    risky_versions = (
        "8.1.0.SPH6",
        "8.1.1.HP3",
        "8.1.RC3",
    )
    # Exact-match lookup; no prefix or pattern matching is intended.
    return version not in risky_versions


def check_version(pools, product_version):
    """
    :function: Correct the product version for 8.1.1.HP3, which the device
               interface cannot report; the HP_FLAG value from the
               pre-inspection result of the first pool IP is used instead.
    :param pools: list of Pool objects; only the first entry's IP is queried
    :param product_version: product version obtained from the device
    :return: "8.1.1.HP3" when HP_FLAG is greater than 1, otherwise
             product_version unchanged
    :raises IndexError: when pools is empty or HP_FLAG is missing from the
                        pre-inspection result
    """
    flag_regex = re.compile(r'(?<=HP_FLAG=)\d+\.?\d*')
    first_ip = pools[0].get_pool_ip()
    inspect_output = HANDLE.getPreInspectResult(first_ip, PRE_ITEM_ID)
    flag_values = flag_regex.findall(inspect_output)
    # HP_FLAG is 2 on 8.1.1.HP3 and 0 otherwise, so anything above 1 marks HP3.
    if float(flag_values[0]) > 1:
        return "8.1.1.HP3"
    return product_version


def get_expansion_new_disk_num(disk_pool):
    """
    Count the main-storage disks added to a disk pool by this expansion.

    :param disk_pool: disk pool information object
    :return: sum of the per-node main-storage disk counts over the joined
             cluster nodes, expansion nodes and expansion cluster nodes
    """
    joined_cluster_node = disk_pool.getJoinedClusterNode()
    expansion_cluster_node = disk_pool.getExpansionClusterNode()
    expansion_node = disk_pool.getExpansionNodeList()

    nodes = []
    # Guard all three lists the same way; previously only the two expansion
    # lists were checked, so a None joined-cluster list raised TypeError.
    for node_list in (joined_cluster_node, expansion_node, expansion_cluster_node):
        if node_list:
            nodes.extend(node_list)

    disk_num = 0
    for node in nodes:
        # The helper returns a Java Map (entrySet/iterator API); its values are
        # presumably disk counts per disk type -- confirm in common.
        item_storage = common.get_expansion_main_storage_disk(
            node).entrySet().iterator()
        while item_storage.hasNext():
            entry = item_storage.next()
            disk_num += entry.getValue()
    return disk_num


def check_disk_num(pools, tmp_err_list, tmp_warn_list, expansion_new_disk_num, msg_head):
    """
    Decide, per origin node of one disk pool, whether the share of newly added
    disks requires the PLOG count check, and record the outcome.

    The pre-inspection result of each node IP is parsed for the original disk
    count and the redundancy ratio (data blocks N + parity blocks M). When
    new_disks / (new_disks + original_disks) > 1 / (N + M), the per-OSD PLOG
    count is checked via check_osd_plog_num.

    :param pools: Pool objects (node IPs) of the disk pool being checked
    :param tmp_err_list: list receiving error messages (mutated in place)
    :param tmp_warn_list: list receiving warning messages (mutated in place)
    :param expansion_new_disk_num: number of disks added to the disk pool
    :param msg_head: message prefix identifying the storage pool / disk pool
    :return: None; results are reported through tmp_err_list / tmp_warn_list
    """
    LOGGER.logInfo("check disk num.")
    protect_mode_pattern = re.compile(r'(?<=PROTECT_MODE=)\d+\.?\d*')
    ori_disk_num_pattern = re.compile(r'(?<=DISK_POOL_DISK_NUM=)\d+\.?\d*')
    parity_block_pattern = re.compile(r'(?<=parity_block_num_m=)\d+\.?\d*')
    data_block_pattern = re.compile(r'(?<=data_block_num_n=)\d+\.?\d*')
    for pool in pools:
        pool_ip = pool.get_pool_ip()
        pool_name = pool.get_pool_name()
        LOGGER.logInfo("pool ip[{}],name[{}],new expansion disk num:[{}]"
                       .format(pool_ip, pool_name, expansion_new_disk_num))

        # Output of the pre-inspection script run on the cluster node; it
        # carries the disk pool's total disk count and redundancy parameters.
        pre_inpect_result = HANDLE.getPreInspectResult(pool_ip, PRE_ITEM_ID)

        protect_mode_match = protect_mode_pattern.findall(pre_inpect_result)
        LOGGER.logInfo("protect mode is:[{}](EC:0,REPLICA:1)".format(protect_mode_match))

        ori_disk_num_match = ori_disk_num_pattern.findall(pre_inpect_result)
        parity_block_match = parity_block_pattern.findall(pre_inpect_result)
        data_block_match = data_block_pattern.findall(pre_inpect_result)

        # All three values are indexed below; a missing one is a query
        # failure (warning), not a crash. Previously parity_block_match was
        # never checked, so its absence raised IndexError.
        if not (ori_disk_num_match and parity_block_match and data_block_match):
            LOGGER.logInfo("query result failed from storage node.")
            tmp_warn_list.append("query disk info or redundancy ratio failed.")
            continue

        ori_disk_num = float(ori_disk_num_match[0])
        parity_block_m = float(parity_block_match[0])
        data_block_n = float(data_block_match[0])
        LOGGER.logInfo("ori disk num:[{}],parity block:[{}],data block:[{}]"
                       .format(ori_disk_num, parity_block_m, data_block_n))
        total_disk_num = ori_disk_num + expansion_new_disk_num
        ratio = parity_block_m + data_block_n

        # new / (new + original) > 1 / (N + M), rearranged to avoid division.
        if (expansion_new_disk_num * ratio) > total_disk_num:
            if check_osd_plog_num(pool_ip, tmp_warn_list):
                # One error per disk pool is enough (msg_head names the pool);
                # stop checking its remaining nodes.
                tmp_err_list.append(common.get_err_msg(LANG, "disk.num.and.plog.num.check.failed", msg_head))
                return


def check_osd_plog_num(pool_ip, tmp_warn_list):
    """
    Compare the PLOG count of a single OSD against its PT count.

    :param pool_ip: IP of a node in the pool whose pre-inspection result is read
    :param tmp_warn_list: list receiving a warning when PT_NUM / PLOG_NUM
                          cannot be parsed (mutated in place)
    :return: True when the check FAILS (plog_num * 3 > pt_num * 10240);
             False when it passes or when the values could not be queried
    """
    LOGGER.logInfo("check plog num.")
    # The pre-inspection script reports per-OSD PT and PLOG counts,
    # e.g. "PT_NUM=0,PLOG_NUM=0".
    raw_result = HANDLE.getPreInspectResult(pool_ip, PRE_ITEM_ID)
    pt_values = re.findall(r'(?<=PT_NUM=)\d+\.?\d*', raw_result)
    plog_values = re.findall(r'(?<=PLOG_NUM=)\d+\.?\d*', raw_result)

    if not pt_values or not plog_values:
        LOGGER.logInfo("query result failed from storage node.")
        tmp_warn_list.append("query pt info or plog info failed.")
        return False

    pt_num, plog_num = float(pt_values[0]), float(plog_values[0])
    LOGGER.logInfo("pt num:[{}], plog num:[{}]".format(pt_num, plog_num))
    # Risk threshold: more than 10240 / 3 PLOGs per PT.
    return plog_num * 3 > pt_num * 10240


def check_disk_num_to_add(dev_node, product_version):
    """
    Check every disk pool of every storage pool for the disk-count /
    PLOG-count risk introduced by expansion.

    :param dev_node: device node information object
    :param product_version: storage software version of the device
    :return: (status, detail, error message) tuple:
             INSPECT_UNNORMAL when storage pools cannot be queried or any disk
             pool fails the check; INSPECT_WARNING when only query warnings
             occurred; INSPECT_PASS otherwise (including non-risky versions)
    """
    storage_pools = dev_node.getStoragePools()
    if not storage_pools:
        return common.INSPECT_UNNORMAL, "", common.get_err_msg(LANG, "query.result.abnormal")
    tmp_err_list = []
    tmp_warn_list = []
    LOGGER.logInfo("check disk num an plog num.")
    for storage_pool in storage_pools:
        for disk_pool in storage_pool.getDiskPools():

            msg_head = common.get_err_msg(LANG,
                                          "storage.pool.disk.pool.msg",
                                          (storage_pool.getName(),
                                           disk_pool.getName()))
            pool_util = CheckPoolNumUtil(disk_pool, msg_head, LANG, LOGGER)
            pools = []
            for pool_nodes in pool_util.origin_nodes:
                pool = Pool(pool_nodes.getManagementIp(), disk_pool.getName(),
                            len(pool_util.origin_nodes), len(pool_util.expansion_nodes))
                pools.append(pool)
            if not pools:
                # No origin nodes in this disk pool: nothing to check here.
                # Skip it rather than returning PASS, so that the remaining
                # disk pools are still checked and earlier findings are kept.
                continue
            product_version = check_version(pools, product_version)
            if check_storage_software(product_version):
                # Version is not in the risk list, so the device is not affected.
                # This can only trigger at the first checked disk pool: once a
                # version is risky, check_version never maps it back.
                return common.INSPECT_PASS, "CHECK PASS! product_version = {}".format(product_version), ""
            expansion_new_disk_num = get_expansion_new_disk_num(disk_pool)
            check_disk_num(pools, tmp_err_list, tmp_warn_list, expansion_new_disk_num, msg_head)
    if tmp_err_list:
        return common.INSPECT_UNNORMAL, "\n".join(tmp_err_list), ""
    if tmp_warn_list:
        return common.INSPECT_WARNING, "\n".join(tmp_warn_list), ""
    return common.INSPECT_PASS, "CHECK PASS! the productVersion:{}".format(product_version), ""


def execute(rest):
    """
    Entry point: check the risk caused by the number of added disks and the
    PLOG count in an expansion scenario.

    :param rest: unused by this check
    :return: (status, detail, error message) tuple from check_disk_num_to_add,
             or INSPECT_UNNORMAL with "query.result.abnormal" on any exception
    """
    dev_node = py_java_env.get("devInfo")
    try:
        # The version is resolved before the disk-pool checks run.
        return check_disk_num_to_add(dev_node, get_storage_software(dev_node))
    except (IsmException, Exception) as exception:
        # IsmException is listed explicitly alongside Exception, presumably so
        # Java-side exceptions are caught under Jython as well.
        LOGGER.logException(exception)
        return common.INSPECT_UNNORMAL, "", common.get_err_msg(LANG, "query.result.abnormal")
