# coding: UTF-8

import com.huawei.ism.tool.protocol.utils.RestUtil as RestUtil
from com.huawei.ism.exception import IsmException

import common
from ds_rest_util import CommonRestService

# Language passed to common.get_err_msg when building error messages.
# NOTE(review): py_java_env and PY_LOGGER are not defined in this file --
# presumably injected by the host runtime before the script runs; confirm.
LANG = common.getLang(py_java_env)
# Logger for this script, obtained from the shared common module.
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Key under which this check reports its progress to the progress observer.
ITEM_ID = "fsm_zk_num"


def execute(rest):
    """
    Check that the number of zk nodes is sufficient for the cluster size.

    Queries the cluster server list over REST, counts the zk nodes and the
    storage nodes, and compares the zk count against the minimum required
    for the observed number of storage nodes.

    :param rest: REST connection handed to
        CommonRestService.exec_get_gor_big_by_ds to run the query.
    :return: tuple of (inspection status, collected evidence text joined by
        newlines, error message). The error message is an empty string when
        the check passes.
    """
    evidence = []
    device = py_java_env.get("devInfo")
    progress_observer = py_java_env.get("progressObserver")
    progress = {}

    def report_progress(percent):
        # Publish the current completion percentage for this check item.
        progress[ITEM_ID] = percent
        progress_observer.updateProgress(progress)

    def abnormal(msg_key, *msg_args):
        # Build the "abnormal" result tuple with the evidence gathered so far.
        return (
            common.INSPECT_UNNORMAL,
            "\n".join(evidence),
            common.get_err_msg(LANG, msg_key, *msg_args),
        )

    try:
        report_progress(1)

        url_head = RestUtil.getDstorageUrlHead(device)
        query_url = "{}/api/v2/cluster/servers".format(url_head)
        evidence.append(query_url)

        reply = CommonRestService.exec_get_gor_big_by_ds(rest, query_url)
        evidence.append(str(reply))
        if not check_response(reply):
            return abnormal("query.result.abnormal")

        report_progress(50)

        server_list = reply.get("data")
        if not server_list:
            # A missing or empty server list is treated as an abnormal query result.
            return abnormal("query.result.abnormal")

        zk_total = get_node_count("zk", "usage", server_list)
        storage_total = get_node_count("storage", "role", server_list)

        # Fewer than 128 storage nodes: at least 5 zk nodes are required.
        # 128 or more storage nodes: at least 7 zk nodes are required.
        required_zk = 5 if storage_total < 128 else 7
        if zk_total < required_zk:
            return abnormal("fsm.zk.node.count.not.match", (storage_total, zk_total))

        return common.INSPECT_PASS, "\n".join(evidence), ""
    except (IsmException, Exception) as err:
        # Any failure while querying or parsing is logged and reported as abnormal.
        LOGGER.logException(err)
        return abnormal("query.result.abnormal")


def get_node_count(key_name, group_name, server_list):
    """
    Count the servers whose ``group_name`` field mentions ``key_name``.

    The field value is converted with ``str()`` and tested by substring
    containment, so a missing field is compared as the text ``"None"``.

    :param key_name: text to look for inside the stringified field value.
    :param group_name: key of the field to read from each server entry.
    :param server_list: iterable of server entries exposing ``get``.
    :return: number of matching servers (0 when none match).
    """
    matched = 0
    for entry in server_list:
        field_text = str(entry.get(group_name))
        if key_name in field_text:
            matched += 1
    return matched


def check_response(response):
    """
    Tell whether a REST response should be treated as successful.

    A response with no (or an empty) ``result`` entry is accepted as-is;
    otherwise the ``code`` inside ``result`` must equal 0.

    :param response: response mapping exposing ``get``.
    :return: True when the response is acceptable, False otherwise.
    """
    outcome = response.get("result")
    if not outcome:
        # No result section to inspect: nothing signals a failure.
        return True
    return outcome.get("code") == 0
