# coding: UTF-8
import re
import string

import common
from com.huawei.ism.exception import IsmException
from check_pool_num import CheckPoolNumUtil

# NOTE(review): py_java_env and PY_LOGGER are not defined in this file; they are
# presumably injected into the module namespace by the hosting inspection framework.
# Language setting passed to common.get_err_msg when looking up messages.
LANG = common.getLang(py_java_env)
# Logger for this module.
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Handle used to fetch the output of a pre-inspection item (getPreInspectResult).
HANDLE = py_java_env.get("preInspectHandle")
# Presumably the identifier of this check item (not referenced elsewhere in this file).
ITEM_ID = "check_adjust_recon_band_width"
# Pre-inspection item whose output carries the W2_ORI / W2 / W1_S bandwidth values.
PRE_ITEM_ID = "getReconBandWidth"


class Pool:
    """
    Expansion information for one pool that takes part in a capacity expansion.

    Attributes:
        pool_ip: IP used to query the pre-inspection result for this pool
                 (get_node_num_and_node_ip passes the management IP of the
                 first original node).
        pool_name: name of the storage pool.
        origin_num: number of original nodes in the pool.
        expansion_num: number of newly added nodes in the pool.
    """

    def __init__(self, pool_ip, pool_name, origin_num, expansion_num):
        # Plain attributes; the get_* accessors are kept for existing callers.
        self.pool_ip, self.pool_name = pool_ip, pool_name
        self.origin_num, self.expansion_num = origin_num, expansion_num

    def get_pool_ip(self):
        """Return the IP used to query this pool's pre-inspection result."""
        return self.pool_ip

    def get_pool_name(self):
        """Return the storage pool name."""
        return self.pool_name

    def get_origin_num(self):
        """Return the number of original nodes."""
        return self.origin_num

    def get_expansion_num(self):
        """Return the number of newly added nodes."""
        return self.expansion_num


def get_storage_software(dev_node):
    """
    :function: build the storage software version string of a node
    :param dev_node: device node object providing getProductVersion() and getHotPatchVersion()
    :return: the product version, the patch version, or "<product>.<patch>"
             when neither string contains the other
    """
    product_version = dev_node.getProductVersion()
    patch_version = dev_node.getHotPatchVersion()

    # Handle missing values first: with None, the `in` tests below would raise TypeError.
    if not patch_version:
        return product_version
    if not product_version:
        return patch_version
    # One string already contains the other; return the containing (longer) one as-is.
    if patch_version in product_version:
        return product_version
    if product_version in patch_version:
        return patch_version
    return product_version + "." + patch_version


def check_storage_software(version):
    """
    :function: check whether the storage software version is in the risk list
    :param version: current storage software version string
    :return: True: check passed (not a risky version);
             False: the version is in the risk list and needs the bandwidth check
    """
    # NOTE(review): "8.1.0.SPH3" and "8.1.1.SPH1" were listed twice in this list;
    # confirm that no other version (e.g. a typo) was meant instead.
    risk_versions = frozenset((
        "8.0.1.SPH5", "8.0.1.SPH8", "8.0.1.SPH11", "8.0.1.SPH301", "8.0.1.SPH302", "8.0.1.SPH501", "8.0.1.SPH601",
        "8.0.1.SPC600", "8.0.1.SPH606", "8.0.1.SPH607",
        "8.0.2",
        "8.0.3.SPH1", "8.0.3.SPH2",
        "8.1.0", "8.1.0.SPH1", "8.1.0.SPH2", "8.1.0.SPH3", "8.1.0.SPH5", "8.1.0.SPH6",
        "8.1.1", "8.1.1.SPH1", "8.1.1.SPH2",
    ))
    return version not in risk_versions


def is_node_software_version_before_810(dev_node):
    """
    :function: tell whether the node's product version is older than 8.1.0
    :param dev_node: device node object providing getProductVersion()
    :return: True if major.minor is below 8.1; False for 8.1 and later,
             or when the version string cannot be parsed
    """
    match = re.match(r"^(\d+)\.(\d+)\.(\d+).*$", str(dev_node.getProductVersion()))
    # re.match returns None for an unparsable version; treat that as 8.1.0 or later.
    if match is None:
        return False
    major_version = int(match.group(1))
    minor_version = int(match.group(2))
    return (major_version, minor_version) < (8, 1)


def _build_expansion_pool(check_target, pool_name, msg_head):
    """
    :function: build a Pool record for check_target if it has both original and new nodes
    :param check_target: object handed to CheckPoolNumUtil (a storage pool before 8.1.0,
                         a disk pool from 8.1.0 on)
    :param pool_name: name stored in the Pool record
    :param msg_head: message prefix passed to CheckPoolNumUtil
    :return: Pool, or None when check_target does not take part in the expansion
    """
    pool_util = CheckPoolNumUtil(check_target, msg_head, LANG, LOGGER)
    origin_nodes = pool_util.origin_nodes
    expansion_nodes = pool_util.expansion_nodes
    # An empty node list means this pool is not being expanded; skip it.
    if len(expansion_nodes) == 0 or len(origin_nodes) == 0:
        return None
    return Pool(origin_nodes[0].getManagementIp(), pool_name,
                len(origin_nodes), len(expansion_nodes))


def get_node_num_and_node_ip(storage_pools, before_810):
    """
    :function: collect, for every pool being expanded, the original and new node counts
               and the management IP of one original node
    :param storage_pools: result of dev_node.getStoragePools()
    :param before_810: True if the node software version is older than 8.1.0
    :return: list of Pool records (empty if no pool is being expanded)
    :raises ValueError: if storage_pools is empty or None; execute() reports this
                        as an abnormal query result
    """
    if not storage_pools:
        # Raising keeps the return type a list of Pool. A status tuple here would be
        # iterated by the caller as if it held Pool objects.
        raise ValueError("no storage pool was queried from the device")

    pools = []
    # Before 8.1.0 each item from getStoragePools() is checked directly; from 8.1.0 on
    # the disk pools under each storage pool are checked instead.
    if before_810:
        LOGGER.logInfo("software version is before 810.")
        for storage_pool in storage_pools:
            msg_head = common.get_err_msg(
                LANG, "storage.pool.msg", storage_pool.getName())
            pool = _build_expansion_pool(storage_pool, storage_pool.getName(), msg_head)
            if pool is not None:
                pools.append(pool)
        return pools

    LOGGER.logInfo("software version is after 810.")
    for storage_pool in storage_pools:
        for disk_pool in storage_pool.getDiskPools():
            msg_head = common.get_err_msg(LANG,
                                          "storage.pool.disk.pool.msg",
                                          (storage_pool.getName(),
                                           disk_pool.getName()))
            # The Pool record keeps the storage pool name, not the disk pool name.
            pool = _build_expansion_pool(disk_pool, storage_pool.getName(), msg_head)
            if pool is not None:
                pools.append(pool)
    return pools


# The pre-inspection output has the form "W2_ORI=<n>,W2=<n>,W1_S=<n>"; each lookbehind
# captures the number after its key. Compiled once instead of on every pool iteration.
_W2_PATTERN = re.compile(r'(?<=W2=)\d+\.?\d*')
_W1S_PATTERN = re.compile(r'(?<=W1_S=)\d+\.?\d*')
_W2O_PATTERN = re.compile(r'(?<=W2_ORI=)\d+\.?\d*')


def _parse_band_width_result(pre_inspect_result):
    """
    :function: parse the pre-inspection output into bandwidth values
    :param pre_inspect_result: raw output of the pre-inspection item; may be empty or None
    :return: (W2_ORI, W2, W1_S) as floats, or None if the output is empty
             or any value is missing
    """
    if not pre_inspect_result:
        # findall() would raise TypeError on None.
        LOGGER.logInfo("result[{}] is empty.".format(pre_inspect_result))
        return None
    w2_match = _W2_PATTERN.findall(pre_inspect_result)
    w1s_match = _W1S_PATTERN.findall(pre_inspect_result)
    w2o_match = _W2O_PATTERN.findall(pre_inspect_result)
    LOGGER.logInfo("result[{}].W2_match[{}].W1S_match[{}].W2O_match[{}]".format(
        pre_inspect_result, str(w2_match), str(w1s_match), str(w2o_match)))
    if not (w2_match and w1s_match and w2o_match):
        return None
    return float(w2o_match[0]), float(w2_match[0]), float(w1s_match[0])


def check_new_expansion_nodes_band_width(product_version, dev_node):
    """
    :function: check the reconstruction bandwidth of every pool being expanded and
               suggest a value when it exceeds the network bandwidth threshold
    :param product_version: storage software version string, used in the result messages
    :param dev_node: device node object whose storage pools and product version are queried
    :return: (status, detail message, localized error message from common.get_err_msg)
    """
    check_ret = True
    scale_name = []
    error_info = []
    storage_pools = dev_node.getStoragePools()
    before_810 = is_node_software_version_before_810(dev_node)
    pools = get_node_num_and_node_ip(storage_pools, before_810)
    for pool in pools:
        pool_ip = pool.get_pool_ip()
        pool_name = pool.get_pool_name()
        origin_num = pool.get_origin_num()
        expansion_num = pool.get_expansion_num()
        LOGGER.logInfo("pool ip[{}],name[{}],origin num[{}],expansion num[{}]."
                       .format(pool_ip, pool_name, origin_num, expansion_num))
        # A pre-inspection shell script collects cluster node info and reports the
        # configured value, the reconstruction bandwidth and the network bandwidth threshold.
        pre_inspect_result = HANDLE.getPreInspectResult(pool_ip, PRE_ITEM_ID)
        widths = _parse_band_width_result(pre_inspect_result)
        if widths is None:
            LOGGER.logInfo("query result failed from storage node.")
            check_ret = False
            scale_name.append(pool_name)
            error_info.append("query band width info failed")
            continue

        width_origin, width_2, width_thrd = widths
        LOGGER.logInfo("step1:read:W2_ORI[{}],W2[{}],W1_S[{}]".format(
            str(width_origin), str(width_2), str(width_thrd)))

        # When E > O, reconstruction bandwidth W2 = single-disk capability estimate C
        # * disks per node D * redundancy ratio P * expansion node count E
        # / original node count O, i.e. the reported W2 is scaled by E / O.
        if expansion_num > origin_num:
            width_2 = width_2 * expansion_num / origin_num
        LOGGER.logInfo("step2: get Reconstruction bandwidth W2[{}]. W1_S[{}]".format(
            str(width_2), str(width_thrd)))

        # The check fails when W2 exceeds the network bandwidth threshold
        # W1_S (network bandwidth W1 * threshold ratio P_W1).
        if width_2 > width_thrd:
            check_ret = False
            # Suggested value: the threshold scaled back by O / E when E > O.
            if expansion_num > origin_num:
                suggest_width_3 = width_thrd * origin_num / expansion_num
            else:
                suggest_width_3 = width_thrd
            LOGGER.logInfo("step3: adjust bandwidth W2[{}]. W1_S[{}]".format(
                str(width_2), str(width_thrd)))
            scale_name.append(pool_name)
            error_info.append("now W2[{}], suggest W3[{}]".format(str(width_2), str(suggest_width_3)))

    # Strip Python 2 unicode "u'" prefixes from the list reprs shown in the messages.
    str_scale_name = str(scale_name).replace('u\'', '\'')
    str_error_info = str(error_info).replace('u\'', '\'')
    if check_ret:
        return common.INSPECT_PASS, "CHECK PASS! the productVersion:{}".format(product_version), ""
    return common.INSPECT_UNNORMAL, "productVersion[{}] problem, the pools[{}] error info[{}]".format(
        product_version, str_scale_name, str_error_info), common.get_err_msg(
        LANG, "pool.storage.recon.band.width", (product_version, str_scale_name, str_error_info))


def execute(rest):
    """
    Entry point: check the reconstruction bandwidth of storage pools after expansion.

    :param rest: supplied by the caller; not used by this check
    :return: (status, detail message, localized error message) tuple
    """
    node_info = py_java_env.get("devInfo")
    try:
        software_version = get_storage_software(node_info)
        if not check_storage_software(software_version):
            # Risky version: the actual bandwidth check is required.
            return check_new_expansion_nodes_band_width(software_version, node_info)
        # The software version is not in the risk list, so there is no risk.
        return common.INSPECT_PASS, "product_version = {}".format(software_version), ""
    except (IsmException, Exception) as exception:
        LOGGER.logException(exception)
        return common.INSPECT_UNNORMAL, "", common.get_err_msg(LANG, "query.result.abnormal")
