# coding: UTF-8

import com.huawei.ism.tool.protocol.utils.RestUtil as RestUtil
from com.huawei.ism.exception import IsmException

import common
from ds_rest_util import CommonRestService

# Inspection item id; also used as the key of the progress map.
ITEM_ID = "plogserver_upgradezone"
# py_java_env is not defined in this file (presumably injected by the calling
# tool); keep a module-level alias to it.
PY_JAVA_ENV = py_java_env
# Language used when looking up localized messages.
LANG = common.getLang(PY_JAVA_ENV)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Product versions exposed to the upgrade-zone risk. Stored as a single
# comma-terminated string ("v1,v2,...,vN,"); is_normal_version() looks the
# device version up in it.
VERSION_BLACKLIST = ",".join((
    "8.0.0.1",
    "8.0.0.2",
    "8.0.0",
    "8.0.1",
    "8.0.1.3",
    "8.0.1.5",
    "8.0.1.SPC600",
    "8.0.1.SPH8",
    "8.0.1.SPH10",
    "8.0.1.SPH11",
    "8.0.1.SPH12",
    "8.0.1.SPH13",
    "8.0.1.SPH301",
    "8.0.1.SPH303",
    "8.0.1.SPH501",
    "8.0.1.SPH502",
    "8.0.1.SPH601",
    "8.0.1.SPH602",
    "8.1.0.SPH1",
    "8.1.0.SPH2",
)) + ","
# Id of the pool found to be at risk; written by is_exist_risk().
POOL_ID = ""


def execute(rest):
    """
    Inspect the plogserver upgrade-zone split for upgrade risk.

    Flow: pass immediately when is_normal_version() reports a safe version;
    otherwise query the FSA node topology and the host list over REST and let
    is_exist_risk() decide.

    :param rest: REST connection handed to CommonRestService.exec_get_gor_big_by_ds.
    :return: tuple (status, execution log joined by newlines, error message);
        status is common.INSPECT_PASS or common.INSPECT_UNNORMAL.
    """
    log_lines = []
    dev_info = py_java_env.get("devInfo")
    progress_observer = py_java_env.get("progressObserver")
    progress = {}
    try:
        # A version outside the risky set needs no further inspection.
        safe_version = is_normal_version()
        if safe_version:
            log_lines.append("the product_version %s has not upgrade risk." % str(safe_version))
            return common.INSPECT_PASS, "\n".join(log_lines), ""

        progress[ITEM_ID] = 1
        progress_observer.updateProgress(progress)
        uri_head = RestUtil.getDstorageUrlHead(dev_info)

        topo_uri = "{}/dsware/service/upgrade/getFsaNodeTopoAutoChange".format(uri_head)
        log_lines.append(topo_uri)
        topo_json = CommonRestService.exec_get_gor_big_by_ds(rest, topo_uri, False)
        log_lines.append(str(topo_json))
        # No pool topology reported: nothing to compare.
        if not topo_json.get("poolTopo"):
            return common.INSPECT_PASS, "\n".join(log_lines), ""

        progress[ITEM_ID] = 10
        progress_observer.updateProgress(progress)
        host_uri = "{}/dsware/service/server/queryAllHost".format(uri_head)
        log_lines.append(host_uri)
        host_json = CommonRestService.exec_get_gor_big_by_ds(rest, host_uri)
        log_lines.append(str(host_json))
        hosts = host_json.get("hostInfoList", [])
        if not hosts:
            log_lines.append("no host exist.")
            return common.INSPECT_PASS, "\n".join(log_lines), ""

        if not is_exist_risk(topo_json, hosts, log_lines):
            return common.INSPECT_PASS, "\n".join(log_lines), ""
        # POOL_ID has just been set by is_exist_risk() for the risky pool.
        risk_msg = "".join([common.get_err_msg(LANG, "plogserver.upgradezone.has.risk", POOL_ID)])
        return common.INSPECT_UNNORMAL, "\n".join(log_lines), risk_msg

    except (IsmException, Exception) as exception:
        LOGGER.logException(exception)
        return (
            common.INSPECT_UNNORMAL,
            "\n".join(log_lines),
            common.get_err_msg(LANG, "query.result.abnormal"),
        )


# 检查是否是有风险的版本
def is_normal_version():
    product_version = str(py_java_env.get("devInfo").getProductVersion())
    product_version = product_version + ","
    LOGGER.logInfo("product_version is:%s." % str(product_version))
    result = product_version in VERSION_BLACKLIST
    if result:
        return ""
    # 除了在黑名单里的版本，还有一个product_version为810且patch_version为空的版本，都需要巡检
    if product_version == "8.1.0,":
        patch_version = str(py_java_env.get("devInfo").getHotPatchVersion())
        LOGGER.logInfo("patch_version is:%s." % str(patch_version))
        if patch_version:
            return "8.1.0." + patch_version
        else:
            return ""
    return product_version


# 检查升级域划分风险
def is_exist_risk(nodeTopo_json, all_host, ret_list):
    current_max_zone_num = 0
    current_pool_zone_num = 0
    pool_topo = nodeTopo_json.get("poolTopo", [])
    for pool in pool_topo:
        if pool.get("securityLevel") == "server":
            current_pool_zone_num = len(pool.get("osdIpList"))
            LOGGER.logInfo("the server current_pool_zone_num is:%d." % current_pool_zone_num)
        elif pool.get("securityLevel") == "rack":
            current_pool_zone_num = getRackNumber(pool, all_host)
            LOGGER.logInfo("the rack current_pool_zone_num is:%d." % current_pool_zone_num)
            if current_max_zone_num != 0 and current_pool_zone_num - current_max_zone_num >= 2:
                ret_list.append("the plogserver upgradezone has risk.")
                global POOL_ID
                POOL_ID = pool.get("poolId")
                return True
        current_max_zone_num = max(current_max_zone_num, current_pool_zone_num)
        LOGGER.logInfo("the current_max_zone_num is:%d." % current_max_zone_num)
    return False


# 获取机柜级池的升级域数量
def getRackNumber(pool, all_host):
    pool_servers = pool.get("osdIpList")
    rack_set = set()
    for server_ip in pool_servers:
        for hostInfo in all_host:
            if server_ip == hostInfo.get("manageIp"):
                rack_set.add(hostInfo.get("rackId"))
    return len(rack_set)
