# coding:utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
import os
import re

import com.huawei.ism.tool.protocol.utils.RestUtil as RestUtil

import common
from common import get_err_msg
from ds_rest_util import CommonRestService

# NOTE: py_java_env and PY_LOGGER are not defined in this file; presumably injected by the Jython host.
# Language selector passed to get_err_msg for every user-facing message.
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Host handle used to fetch per-node pre-inspection results (getPreInspectResult).
HANDLE = py_java_env.get("preInspectHandle")
ITEM_ID = "check_permanent_unbalance_capacity"
# Pre-inspection item whose per-node output this check consumes.
PRE_ITEM_ID = "pre_inspect_get_unbalance_capacity"
# Marker in a pre-inspection result indicating the node-side query failed.
QUERY_FAILED_PREFIX = "query failed: "
# Marker in a pre-inspection result indicating the node does not support the query.
NO_SUPPORT_PREFIX = "pre_inspect_get_unbalance_capacity-3"
NOT_INVOLVED = "not.involved"
# Labels preceding the upgrade target version inside py_java_env["version"] (Chinese / English).
TARGET_VERSION = "目标版本："
TARGET_VERSION_EN = "TARGET VERSION: "

# Free-capacity threshold scaled by 100000: 11000 / 100000 = 11%.
FREE_RATIO_THRESHOLD = 11000

COMMON_WORK_PATH = "/opt/dfv/oam/public/script/inspect"
# Installation prefix in front of COMMON_WORK_PATH: empty on a normal node; on a DPC node whose
# opt directory was customized, the custom prefix. Reuse the constant so the two cannot drift apart.
_REAL_FILE_PATH = os.path.realpath(__file__)
CURRENT_PATH = _REAL_FILE_PATH.split(COMMON_WORK_PATH)[0] if COMMON_WORK_PATH in _REAL_FILE_PATH else ""

# Error messages collected when the check does not pass.
err_msg = []
# Raw queried information, kept to help locate the problem when an error is reported.
res_msg = []

# Queried storage pool information, keyed by pool id (filled by get_pool_info).
pools_info = dict()


def is_involved_version(product_version):
    """
    Tell whether this inspection item applies to the given product version.

    Dots are removed from the version string and the remainder must begin
    with "813" (e.g. "8.1.3" or "8.1.3.SPH5").

    :param product_version: product version string.
    :return: True when the version is involved, False otherwise.
    """
    digits_only = product_version.replace(".", "")
    return digits_only.startswith("813")


def parse_result(pre_inspect_result, node_ip_str):
    """
    Parse one node's pre-inspection output and check its storage pool free ratio.

    The output must contain "pool_id:<digits>,unbalance_capacity:<digits>". The
    referenced pool is looked up in the module-level pools_info (filled by
    get_pool_info) and the free ratio, scaled by 100000, is computed as

        (writableCapacity - unbalance_capacity * poolNodeNum - usedCapacityAfterDedup)
            * 100000 / writableCapacity

    The check passes only when this ratio is strictly greater than
    FREE_RATIO_THRESHOLD.

    Side effects: appends a diagnostic line to res_msg once the pool data is
    valid, and appends a line to err_msg on every failing path.

    :param pre_inspect_result: raw pre-inspection result string of the node.
    :param node_ip_str: "node ip:<ip> " prefix used in messages.
    :return: True if the check passed, False otherwise.
    """
    query_res = re.search(r'pool_id:(\d+),unbalance_capacity:(\d+)', pre_inspect_result)
    if not query_res:
        # Parsing failed: the query result is malformed, so the inspection does not pass.
        LOGGER.logError("pre inspect query failed." + node_ip_str + "pre item id:{}".format(PRE_ITEM_ID))
        err_msg.append(node_ip_str + get_err_msg(LANG, "query.storage.pool.error"))
        return False

    # NOTE(review): only the first pool_id/unbalance_capacity pair is evaluated; confirm the
    # pre-inspection output never reports more than one pool per node.
    pool_id = int(query_res.group(1))
    unbalance_capacity = int(query_res.group(2))
    if pool_id not in pools_info:
        LOGGER.logError("pool_id:{} not in pools_info".format(pool_id))
        # Report the failure to the user instead of failing without any message.
        err_msg.append(node_ip_str + get_err_msg(LANG, "query.storage.pool.error"))
        return False

    storage_pool = pools_info[pool_id]
    node_num = storage_pool.get("poolNodeNum")
    total_cap = storage_pool.get("writableCapacity")
    used_cap = storage_pool.get("usedCapacityAfterDedup")
    # A non-positive writable capacity would raise ZeroDivisionError below and abort the whole
    # inspection, so treat it as invalid pool data just like a missing field.
    if node_num is None or total_cap is None or used_cap is None or total_cap <= 0:
        LOGGER.logError("param invalid in {}".format(storage_pool))
        err_msg.append(node_ip_str + get_err_msg(LANG, "query.storage.pool.error"))
        return False

    # Use a float factor so Python 2 / Jython does not floor the ratio through integer division.
    free_ratio = (total_cap - unbalance_capacity * node_num - used_cap) * 100000.0 / total_cap
    LOGGER.logInfo(
        "pool:{} node num:{} totalCapacity:{} usedCapacity:{} unbalance_capacity:{} free_ratio:{}.".format(
            pool_id, node_num, total_cap, used_cap, unbalance_capacity, free_ratio))
    res_msg.append(
        node_ip_str + "pool_id:{}, node_num:{}, total_cap:{}, used_cap:{}, unbalance_capacity:{}"
        .format(pool_id, node_num, total_cap, used_cap, unbalance_capacity))
    if free_ratio <= FREE_RATIO_THRESHOLD:
        LOGGER.logError(get_err_msg(LANG, "check.storage.pool.water.level.not.pass").format(pool_id))
        err_msg.append(get_err_msg(LANG, "check.storage.pool.water.level.not.pass").format(pool_id))
        return False

    return True


def get_pool_info(dev_node, rest):
    """
    Query all storage pools over REST and cache them in the module-level pools_info.

    Entries are keyed by int(poolId). parse_result looks pools up with the integer it
    parses from the pre-inspection output, so the key type must match whether the REST
    layer returned the id as a number or as a numeric string. Entries without a poolId
    are logged and skipped.

    :param dev_node: device node passed to RestUtil.getDstorageUrlHead.
    :param rest: REST connection passed to CommonRestService.exec_get_gor_big_by_ds.
    :return: True if the query returned a non-empty "storagePools" list, False otherwise.
             Any exception is logged and reported as False.
    """
    try:
        base_uri = RestUtil.getDstorageUrlHead(dev_node)
        cmd_str = "{}/dsware/service/resource/queryStoragePool?baseInfo=false".format(base_uri)
        pools_json = CommonRestService.exec_get_gor_big_by_ds(rest, cmd_str)
        storage_pools = pools_json.get("storagePools")
        if not storage_pools:
            LOGGER.logInfo("get storagePools obj failed")
            return False
        for item in storage_pools:
            pool_id = item.get("poolId")
            if pool_id is None:
                LOGGER.logError("storage pool without poolId: {}".format(item))
                continue
            pools_info[int(pool_id)] = item
    except Exception as exception:
        LOGGER.logException(exception)
        return False
    return True


def check_whitelist(new_version):
    """
    Tell whether an upgrade target version is whitelisted; the caller treats a
    whitelisted target as "not involved".

    Whitelisted targets:
      * any version whose dot-free form starts with "821";
      * 8.1.5.SPHxxx with a patch number of at least 39;
      * 8.2.0.SPHxxx with a patch number of at least 26.
    The SPH form is matched case-insensitively and needs exactly three patch digits.

    :param new_version: upgrade target version string.
    :return: True when the target is whitelisted, False otherwise.
    """
    LOGGER.logInfo("targetversion is:{}.".format(new_version))
    if new_version.replace('.', "").startswith("821"):
        return True

    sph_match = re.match(r'(?i)^(8\.(1\.5|2\.0)\.SPH(\d{3}))$', new_version)
    if not sph_match:
        return False

    base_version = sph_match.group(2)
    patch_number = int(sph_match.group(3))
    LOGGER.logInfo("target upgrade version({}) SPH({}).".format(base_version, patch_number))
    # Lowest SPH patch level, per base version, from which the target is whitelisted.
    min_patch = {"1.5": 39, "2.0": 26}.get(base_version)
    return min_patch is not None and patch_number >= min_patch


def _extract_query_failed_detail(pre_inspect_result):
    """
    Return the text after QUERY_FAILED_PREFIX up to the next "]", or up to the end of the
    string when no closing bracket follows. The caller must ensure the prefix is present.
    """
    start = pre_inspect_result.find(QUERY_FAILED_PREFIX) + len(QUERY_FAILED_PREFIX)
    # Search for "]" only after the prefix: a bracket earlier in the text must not end the slice,
    # and a missing bracket must not cut off the last character (find() returning -1).
    end = pre_inspect_result.find("]", start)
    if end == -1:
        end = len(pre_inspect_result)
    return pre_inspect_result[start:end]


def execute(rest):
    """
    Entry point: check on every cluster node that the storage pool free ratio stays above
    FREE_RATIO_THRESHOLD once the unbalance capacity is taken into account.

    Returns NOSUPPORT when:
      * the product version is not involved;
      * the upgrade target is whitelisted;
      * there are no cluster nodes;
      * no node produced a usable pre-inspection result.
    Returns UNNORMAL when the pool information cannot be queried or any node fails.
    Returns PASS otherwise.

    :param rest: REST connection used to query storage pools.
    :return: tuple (status, detail message, error message).
    """
    # The result lists and the pool cache are module-level; clear them so a repeated call in the
    # same interpreter cannot report data left over from a previous run.
    del err_msg[:]
    del res_msg[:]
    pools_info.clear()

    dev_node = py_java_env.get("devInfo")

    # Only involved product versions are inspected.
    product_version = str(dev_node.getProductVersion())
    if not is_involved_version(product_version):
        LOGGER.logInfo("product version({}) is not involved.".format(product_version))
        return common.INSPECT_NOSUPPORT, get_err_msg(LANG, "query.version.na", product_version), get_err_msg(
            LANG, "query.version.na", product_version)

    # The upgrade target follows TARGET_VERSION up to a full-width closing parenthesis (Chinese text)
    # or TARGET_VERSION_EN up to an ASCII ")" (English text). Guard against a missing value.
    version_text = py_java_env.get("version") or ""
    new_version = version_text.split(TARGET_VERSION)[-1].split("）")[0].strip()
    new_version_en = version_text.split(TARGET_VERSION_EN)[-1].split(")")[0].strip()
    if check_whitelist(new_version) or check_whitelist(new_version_en):
        LOGGER.logInfo("target upgrade version({}) is not involved.".format(version_text))
        version_na_msg = get_err_msg(LANG, "query.version.na", version_text)
        return common.INSPECT_NOSUPPORT, version_na_msg, version_na_msg

    dev_nodes = dev_node.getClusterNodes()
    if not dev_nodes or len(dev_nodes) == 0:
        return common.INSPECT_NOSUPPORT, "No Nodes", ""

    if not get_pool_info(dev_node, rest):
        return common.INSPECT_UNNORMAL, "no pool info can get", "no pool info can get"

    is_pass = True
    is_checked = False
    for node in dev_nodes:
        node_ip = common.get_node_ip(node)
        node_ip_str = "node ip:{} ".format(node_ip)

        pre_inspect_result = HANDLE.getPreInspectResult(node_ip, PRE_ITEM_ID)
        # No result, or the node reported "not supported": not involved, move on to the next node.
        if not pre_inspect_result:
            LOGGER.logInfo("not exist pre inspect result. node ip:{} pre item id:{}".format(node_ip, PRE_ITEM_ID))
            res_msg.append(node_ip_str + get_err_msg(LANG, NOT_INVOLVED))
            continue
        if NO_SUPPORT_PREFIX in pre_inspect_result:
            LOGGER.logInfo(
                get_err_msg(LANG, NOT_INVOLVED) + ". node ip:{} pre item id:{}".format(node_ip, PRE_ITEM_ID))
            res_msg.append(node_ip_str + get_err_msg(LANG, NOT_INVOLVED))
            continue

        is_checked = True
        # The node-side query failed: the inspection does not pass.
        if QUERY_FAILED_PREFIX in pre_inspect_result:
            is_pass = False
            LOGGER.logError("pre inspect query failed. node ip:{} pre item id:{}".format(node_ip, PRE_ITEM_ID))
            failed_msg = node_ip_str + get_err_msg(LANG, "query.storage.pool.error") + "({})".format(
                _extract_query_failed_detail(pre_inspect_result))
            res_msg.append(failed_msg)
            err_msg.append(failed_msg)
            continue

        # parse_result is evaluated first so every node is checked even after an earlier failure.
        is_pass = parse_result(pre_inspect_result, node_ip_str) and is_pass

    if not is_checked:
        return common.INSPECT_NOSUPPORT, '\n'.join(res_msg), ""
    if not is_pass:
        return common.INSPECT_UNNORMAL, '\n'.join(res_msg), '\n'.join(err_msg)
    return common.INSPECT_PASS, '\n'.join(res_msg), ""
