# coding: UTF-8
import re

import common
from check_pool_num import CheckPoolNumUtil
from com.huawei.ism.exception import IsmException

# NOTE(review): py_java_env and PY_LOGGER are not defined in this file; presumably
# they are injected into the script's globals by the Jython host.
# Language setting passed to common.get_err_msg() when building messages.
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Handle used to read the output of the pre-inspection step (getPreInspectResult).
HANDLE = py_java_env.get("preInspectHandle")
# Identifier of this inspection item (not referenced in the code shown here).
ITEM_ID = "check_plog_num_to_add"
# Pre-inspection item whose output is parsed for PT_NUM / PLOG_NUM / HP_FLAG values.
PRE_ITEM_ID = "getDiskNumAndPlogNum"


class Pool:
    """
    Record describing one original node of a pool that is to be inspected.

    Attributes (all stored exactly as passed to the constructor):
        pool_ip: management IP of the node
        pool_name: name of the storage pool / disk pool the node belongs to
        origin_num: length of the pool's original-node list
        expansion_num: length of the pool's expansion-node list
    """

    def __init__(self, pool_ip, pool_name, origin_num, expansion_num):
        (self.pool_ip, self.pool_name, self.origin_num, self.expansion_num) = \
            (pool_ip, pool_name, origin_num, expansion_num)

    def get_pool_ip(self):
        """Return the management IP of the node."""
        ip = self.pool_ip
        return ip

    def get_pool_name(self):
        """Return the name of the pool the node belongs to."""
        name = self.pool_name
        return name

    def get_origin_num(self):
        """Return the number of original nodes recorded for the pool."""
        count = self.origin_num
        return count

    def get_expansion_num(self):
        """Return the number of expansion nodes recorded for the pool."""
        count = self.expansion_num
        return count


def get_storage_software(dev_node):
    """
    :function: Build the storage software version string of a node by combining
               its product version with its hot-patch version.
    :param dev_node: node information object providing getProductVersion()
                     and getHotPatchVersion()
    :return: the product version when there is no hot patch or the patch version
             is already contained in it; the patch version when it already
             contains the product version; otherwise
             product_version + "." + patch_version
    """
    LOGGER.logInfo("check product version.")
    product_version = dev_node.getProductVersion()
    patch_version = dev_node.getHotPatchVersion()

    # Check for a missing hot patch first: "x in y" raises TypeError when the
    # patch version is None (a null returned from the Java side).
    if not patch_version:
        return product_version
    if patch_version in product_version:
        return product_version
    if product_version in patch_version:
        return patch_version
    return product_version + "." + patch_version


def check_storage_software(version):
    """
    :function: Check the storage software version against the known-risky releases.
    :param version: current storage software version string
    :return: True when the version is not a risky release (check passes);
             False when it exactly matches one of the risky releases
    """
    risky_versions = (
        "8.0.0.1",
        "8.0.1", "8.0.1.SPH5", "8.0.1.SPH8", "8.0.1.SPH10", "8.0.1.SPH11", "8.0.1.SPH12", "8.0.1.SPH301",
        "8.0.1.SPH302", "8.0.1.SPH501", "8.0.1.SPH502", "8.0.1.SPC600", "8.0.1.SPH601", "8.0.1.SPH602",
        "8.0.1.SPH606", "8.0.1.SPH607", "8.0.1.SPH608",
        "8.0.1.5",
        "8.0.2",
        "8.0.3", "8.0.3.SPH1", "8.0.3.SPH2", "8.0.3.SPH3", "8.0.3.SPH5", "8.0.3.SPH6",
        "8.1.0", "8.1.0.SPH3",
        "8.1.RC2",
        "8.1.1"
    )
    # Exact-match lookup: only a listed release is considered risky
    return version not in risky_versions


def check_version(pools, product_version):
    """
    :function: Refine the product version using the pre-inspection result.
               8.1.1 HP3 cannot be identified through the device interface, so the
               HP_FLAG value reported by the pre-inspection item is used instead.
    :param pools: non-empty list of Pool objects; only the first one's node IP is queried
    :param product_version: version string returned by get_storage_software()
    :return: "8.1.1.HP3" when HP_FLAG is greater than 1; otherwise (including when
             HP_FLAG is missing from the result) product_version unchanged
    """
    hp_flag_pattern = re.compile(r'(?<=HP_FLAG=)\d+\.?\d*')
    pool_ip = pools[0].get_pool_ip()
    pre_inspect_result = HANDLE.getPreInspectResult(pool_ip, PRE_ITEM_ID)
    # Guard against a missing (None) result so the regex search does not raise
    hp_flag_match = hp_flag_pattern.findall(pre_inspect_result or "")
    if not hp_flag_match:
        # HP_FLAG is unavailable: keep the reported version instead of aborting
        # the whole inspection with an IndexError.
        LOGGER.logInfo("HP_FLAG not found in pre-inspection result of node[{}]."
                       .format(pool_ip))
        return product_version

    # Per the pre-inspection convention, HP_FLAG=2 on 8.1.1 HP3 and 0 otherwise
    if float(hp_flag_match[0]) > 1:
        return "8.1.1.HP3"
    return product_version


def is_node_software_version_before_810(dev_node):
    """
    :function: Determine whether the node's product version is earlier than 8.1.
    :param dev_node: node information object providing getProductVersion()
    :return: True when the version is earlier than 8.1 (major < 8, or major 8 with
             minor 0); False for 8.1 and later, and also when the version string
             does not start with "<major>.<minor>"
    """
    # Only the leading "<major>.<minor>" is inspected, so versions whose later
    # parts are not numeric (e.g. "8.1.RC7") are handled as well.
    match = re.match(r"^(\d+)\.(\d+)", str(dev_node.getProductVersion()))
    if match is None:
        # Unparseable version: treated as 8.1 or later (same fallback the
        # original code intended with its None check)
        return False
    major_version, minor_version = int(match.group(1)), int(match.group(2))
    return (major_version, minor_version) < (8, 1)


def before_810_get_node_ip(storage_pools):
    """
    :function: Build one Pool record per original node of each storage pool.
               Used for software earlier than 8.1, where each storage pool is
               handed to CheckPoolNumUtil directly.
    :param storage_pools: storage pools returned by dev_node.getStoragePools()
    :return: list of Pool objects, one per original node, in pool/node order
    """
    result = []
    for sp in storage_pools:
        header = common.get_err_msg(LANG, "storage.pool.msg", sp.getName())
        util = CheckPoolNumUtil(sp, header, LANG, LOGGER)
        # One record per original node, carrying that node's management IP
        result.extend([
            Pool(node.getManagementIp(), sp.getName(),
                 len(util.origin_nodes), len(util.expansion_nodes))
            for node in util.origin_nodes
        ])
    return result


def after_810_get_node_ip(storage_pools):
    """
    :function: Build one Pool record per original node of each disk pool.
               Used for software 8.1 and later, where every storage pool is
               split into disk pools that are handed to CheckPoolNumUtil.
    :param storage_pools: storage pools returned by dev_node.getStoragePools()
    :return: list of Pool objects (named after the disk pool), one per original node
    """
    result = []
    for sp in storage_pools:
        for dp in sp.getDiskPools():
            header = common.get_err_msg(LANG,
                                        "storage.pool.disk.pool.msg",
                                        (sp.getName(), dp.getName()))
            util = CheckPoolNumUtil(dp, header, LANG, LOGGER)
            # One record per original node, carrying that node's management IP
            result.extend([
                Pool(node.getManagementIp(), dp.getName(),
                     len(util.origin_nodes), len(util.expansion_nodes))
                for node in util.origin_nodes
            ])
    return result


def get_node_num_and_node_ip(storage_pools, before_810):
    """
    :function: Collect, for every pool, its original/expansion node counts and the
               management IPs of all of its original nodes.
    :param storage_pools: pools returned by dev_node.getStoragePools()
    :param before_810: True when the node software is earlier than 8.1
    :return: list of Pool objects; an empty list when there are no storage pools
    """
    # Always return a list. The caller (check_osd_plog_num) indexes the result
    # as Pool objects and reports "check pass" for an empty list, so returning
    # a status tuple here would make it fail with an AttributeError.
    if not storage_pools:
        return []

    # Before 8.1 each storage pool is checked as a whole; from 8.1 on each
    # storage pool contains disk pools that are checked individually.
    if before_810:
        LOGGER.logInfo("software version is before 810.")
        return before_810_get_node_ip(storage_pools)

    LOGGER.logInfo("software version is after 810.")
    return after_810_get_node_ip(storage_pools)


def check_osd_plog_num(dev_node, product_version):
    """
    :function: Check the PT and PLOG counts reported for every original node of
               every pool before capacity expansion.
    :param dev_node: node information object
    :param product_version: storage software version from get_storage_software()
    :return: (status, detail message, error message) tuple
    """
    storage_pools = dev_node.getStoragePools()
    before_810 = is_node_software_version_before_810(dev_node)
    pools = get_node_num_and_node_ip(storage_pools, before_810)
    if not pools:
        # No storage pool exists: nothing to check, the inspection passes
        return common.INSPECT_PASS, "CHECK PASS! product_version = {}".format(product_version), ""
    product_version = check_version(pools, product_version)
    if check_storage_software(product_version):
        # The software version is not in the risky-version list: no risk
        return common.INSPECT_PASS, "CHECK PASS! product_version = {}".format(product_version), ""

    err_list = []
    warn_list = []
    # Names of pools that failed; non-empty means the inspection does not pass
    failed_pool_names = []
    pt_num_pattern = re.compile(r'(?<=PT_NUM=)\d+\.?\d*')
    plog_num_pattern = re.compile(r'(?<=PLOG_NUM=)\d+\.?\d*')
    for pool in pools:
        pool_ip = pool.get_pool_ip()
        pool_name = pool.get_pool_name()
        LOGGER.logInfo("pool ip[{}],name[{}]."
                       .format(pool_ip, pool_name))

        # The pre-inspection step collects node information and reports the PT
        # and PLOG counts of a single OSD, e.g. "PT_NUM=0,PLOG_NUM=0".
        # A missing (None) result is treated as a failed query for this node.
        pre_inspect_result = HANDLE.getPreInspectResult(pool_ip, PRE_ITEM_ID) or ""
        pt_num_match = pt_num_pattern.findall(pre_inspect_result)
        plog_num_match = plog_num_pattern.findall(pre_inspect_result)

        if not pt_num_match or not plog_num_match:
            LOGGER.logInfo("query result failed from storage node.")
            warn_list.append("query result failed from storage node.")
            failed_pool_names.append(pool_name)
            continue

        pt_num = float(pt_num_match[0])
        plog_num = float(plog_num_match[0])
        LOGGER.logInfo("pt num:[{}], plog num:[{}]".format(pt_num, plog_num))
        # The node fails when it has more than 1024 PLOGs per PT.
        # NOTE(review): an earlier comment gave the limit as
        # plog_num * 3 > pt_num * 10240; confirm which threshold is intended.
        if plog_num > (pt_num * 1024):
            failed_pool_names.append(pool_name)
            err_list.append("plog num not meet the requirement, pool name[{}], node ip[{}]"
                            .format(pool_name, pool_ip))

    if not failed_pool_names:
        return common.INSPECT_PASS, "CHECK PASS! the productVersion:{}".format(product_version), ""

    err_list.extend(warn_list)
    # Strip the Python 2 unicode prefix (u'...') from the printed list of names
    str_scale_name = str(failed_pool_names).replace('u\'', '\'')
    return common.INSPECT_UNNORMAL, "pool name[{}], error info[{}]" \
        .format(str_scale_name, err_list), ""


def execute(rest):
    """
    Entry point: check the PLOG count before capacity expansion.
    :param rest: not used by this check (presumably required by the inspection
                 framework's calling convention)
    :return: (status, detail message, error message) tuple; any IsmException or
             Exception is logged and reported as INSPECT_UNNORMAL
    """
    dev_node = py_java_env.get("devInfo")
    try:
        return check_osd_plog_num(dev_node, get_storage_software(dev_node))
    except (IsmException, Exception) as err:
        # NOTE(review): IsmException is presumably a Java exception type, listed
        # explicitly because Jython's `Exception` may not catch Java exceptions.
        LOGGER.logException(err)
        return common.INSPECT_UNNORMAL, "", common.get_err_msg(LANG, "query.result.abnormal")
