# coding: UTF-8
import re
import string

import common
import com.huawei.ism.tool.protocol.utils.RestUtil as RestUtil
from com.huawei.ism.exception import IsmException
from check_pool_num import CheckPoolNumUtil
from ds_rest_util import CommonRestService
from com.huawei.ism.tool.protocol.rest import RestConnectionManager

# `py_java_env` and `PY_LOGGER` are not defined or imported in this file;
# presumably they are injected into the module globals by the host tool.

# Pre-inspection handle; getPreInspectResult(node_ip, PRE_ITEM_ID) is used
# below to read the per-node result collected by the pre-inspection step.
HANDLE = py_java_env.get("preInspectHandle")

# Language used when building localized error messages (common.get_err_msg).
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Identifier of this inspection item.
ITEM_ID = "check_spdk_consistency"
# Pre-inspection item whose per-node result is compared against
# "spdk_switch=1" to decide whether SPDK is on for that node.
PRE_ITEM_ID = "getSpdkSwitch"


class Pool:
    """
    SPDK consistency result for one storage pool (or one disk pool).

    :param spdk_switch: consistency flag as computed by the builders in this
        module: 1 when all nodes share the same SPDK state, 0 otherwise.
    :param pool_name: display name used in logs and in the check report.
    :param open_node: management IPs of nodes whose SPDK switch is on.
    :param close_node: management IPs of nodes whose SPDK switch is not on
        (any pre-inspection result other than the "on" marker).
    """

    def __init__(self, spdk_switch, pool_name, open_node, close_node):
        self.spdk_switch = spdk_switch
        self.pool_name = pool_name
        self.open_node = open_node
        self.close_node = close_node

    def __repr__(self):
        return "Pool(spdk_switch=%r, pool_name=%r, open_node=%r, close_node=%r)" % (
            self.spdk_switch, self.pool_name, self.open_node, self.close_node)

    def get_spdk_switch(self):
        """Return the consistency flag (1 consistent, 0 inconsistent)."""
        return self.spdk_switch

    def get_pool_name(self):
        """Return the pool display name."""
        return self.pool_name

    def get_open_node(self):
        """Return the IPs of nodes whose SPDK switch is on."""
        return self.open_node

    def get_close_node(self):
        """Return the IPs of nodes whose SPDK switch is not on."""
        return self.close_node

    def get_close_nose(self):
        """Misspelled legacy alias of get_close_node(), kept for existing callers."""
        return self.get_close_node()


def _classify_spdk_nodes(nodes_info, log_pool_name, pool_info):
    """
    Split a pool's nodes by SPDK state and decide whether the pool is consistent.

    :param nodes_info: node records from a queryNodeDiskInfo response; the
        "nodeMgrIp" of each record is used to look up its pre-inspection result.
    :param log_pool_name: pool name printed in the per-node log line.
    :param pool_info: display name used in the summary logs and the Pool object.
    :return: Pool whose spdk_switch is 1 when all nodes are on or all nodes are
        not on, and 0 when the states are mixed.
    """
    spdk_open = "spdk_switch=1"
    open_node = []
    close_node = []
    for node in nodes_info:
        node_ip = node.get("nodeMgrIp")
        # Result collected by the pre-inspection step for this node;
        # "spdk_switch=1" means SPDK is enabled.
        spdk_switch = HANDLE.getPreInspectResult(node_ip, PRE_ITEM_ID)
        LOGGER.logInfo("diskPoolsName[{}]-nodeMgrIp[{}]-spdkSwitch[{}]".format(log_pool_name, node_ip, spdk_switch))
        if spdk_switch == spdk_open:
            LOGGER.logInfo("NodeIP:%s spdk: open " % str(node_ip))
            open_node.append(node_ip)
        else:
            # Anything but the exact "on" marker (including no result) counts as closed.
            LOGGER.logInfo("NodeIP:%s spdk: close " % str(node_ip))
            close_node.append(node_ip)

    open_count = len(open_node)
    close_count = len(close_node)
    node_count = len(nodes_info)
    LOGGER.logInfo("Node name:[{}] spdk num: open[{}] close[{}]".format(pool_info, open_count, close_count))
    # Consistent only when every node shares one state: all on or all off.
    if open_count == node_count or close_count == node_count:
        LOGGER.logInfo("Node name:%s spdk open all check pass " % str(pool_info))
        spdk_consistency = 1
    else:
        LOGGER.logInfo(
            "Node name:[{}] spdk status: open[{}] close[{}]".format(pool_info, open_count, close_count))
        spdk_consistency = 0

    return Pool(spdk_consistency, pool_info, open_node, close_node)


def build_pool_by_disk(base_uri, rest, pool_id, pool_name):
    """
    Build the SPDK consistency result of one storage pool whose nodes are
    queried directly by storage pool id (the 8.0 path selected in build_pool).

    :param base_uri: DSware REST URL prefix.
    :param rest: REST connection used for the query.
    :param pool_id: storage pool id.
    :param pool_name: storage pool name.
    :return: Pool object, or None when the pool reports no nodes.
    """
    cmd_str = (
        "{}/dsware/service/cluster/storagepool/"
        "queryNodeDiskInfo?poolId={}".format(base_uri, pool_id)
    )
    node_json = CommonRestService.exec_get_gor_big_by_ds(rest, cmd_str)
    # No member nodes in this pool: nothing to check.
    if not node_json.get("nodeInfo"):
        return None
    nodes_info = node_json.get("nodeInfo", [])

    pool_info = "StoragePoolName[{}]".format(pool_name)
    LOGGER.logInfo("nodeInfo is:%s." % nodes_info)
    return _classify_spdk_nodes(nodes_info, pool_name, pool_info)


def build_disk_pool(disk_pool_record, base_uri, rest, pool_name):
    """
    Build the SPDK consistency result of one disk pool (the post-8.0 path,
    where a storage pool is split into disk pools).

    :param disk_pool_record: disk pool record with "poolName" and "poolId".
    :param base_uri: DSware REST URL prefix.
    :param rest: REST connection used for the query.
    :param pool_name: name of the parent storage pool.
    :return: Pool object, or None when the record has no id or the disk pool
        reports no nodes.
    """
    disk_pools_name = disk_pool_record.get("poolName")
    disk_pool_id = disk_pool_record.get("poolId")
    if disk_pool_id is None:
        return None

    cmd_str = (
        "{}/dsware/service/cluster/diskpool/"
        "queryNodeDiskInfo?diskPoolId={}".format(base_uri, disk_pool_id)
    )
    node_json = CommonRestService.exec_get_gor_big_by_ds(rest, cmd_str)
    if not node_json.get("nodeInfo"):
        return None
    nodes_info = node_json.get("nodeInfo", [])

    LOGGER.logInfo("nodeInfo is:%s." % nodes_info)
    pool_info = "StoragePoolName[{}]-DiskPoolName[{}]".format(pool_name, disk_pools_name)
    return _classify_spdk_nodes(nodes_info, disk_pools_name, pool_info)


def build_pool_by_disk_pool(base_uri, rest, pool_id, pool_name):
    """
    Build SPDK consistency results for every disk pool of one storage pool
    (the post-8.0 layout: storage pool -> disk pools -> disks).

    :param base_uri: DSware REST URL prefix.
    :param rest: REST connection used for the queries.
    :param pool_id: id of the parent storage pool.
    :param pool_name: name of the parent storage pool, embedded in each
        result's display name.
    :return: list of Pool objects (disk pools for which build_disk_pool
        returns None are left out), or None when the storage pool reports
        no disk pool at all.
    """
    url = "{}/api/v2/data_service/diskpool?storagePoolId={}".format(base_uri, pool_id)
    response = CommonRestService.exec_get_gor_big_by_ds(rest, url)

    disk_pools = response.get("diskPools")
    if not disk_pools:
        return None
    LOGGER.logInfo("disk_pool_info_list is:%s." % disk_pools)

    built = [build_disk_pool(record, base_uri, rest, pool_name) for record in disk_pools]
    return [entry for entry in built if entry is not None]


def build_pool(base_uri, dev_node, rest, storage_pools):
    """
    Collect SPDK consistency results for every storage pool of the device.

    Devices whose product version starts with "8.0" use the layout
    storage pool -> disks (one result per storage pool). Other versions use
    storage pool -> disk pools -> disks (one result per disk pool).

    :param base_uri: DSware REST URL prefix.
    :param dev_node: device info; its product version selects the query path.
    :param rest: REST connection used for the queries.
    :param storage_pools: storage pool records, each read for "poolName" and
        "poolId"; records without "poolId" are skipped.
    :return: flat list of Pool objects (possibly empty).
    """
    LOGGER.logInfo("storage_pools is:%s." % storage_pools)

    # The product version is the same for every pool, so read it once
    # instead of once per storage pool.
    product_version = str(dev_node.getProductVersion())
    LOGGER.logInfo("product_version is:{}.".format(product_version))
    is_v80 = product_version.startswith("8.0")

    pools = []
    for record in storage_pools:
        pool_name = record.get("poolName")
        LOGGER.logInfo("poolName is:{}.".format(pool_name))
        pool_id = record.get("poolId")
        LOGGER.logInfo("pool_id is:{}.".format(pool_id))
        if pool_id is None:
            continue

        if is_v80:
            pool = build_pool_by_disk(base_uri, rest, pool_id, pool_name)
            if pool is not None:
                pools.append(pool)
        else:
            disk_pools = build_pool_by_disk_pool(base_uri, rest, pool_id, pool_name)
            if disk_pools:
                pools.extend(disk_pools)
    return pools


def get_node_pool(dev_node):
    """
    Query all storage pools of the device and build their SPDK consistency results.

    The REST connection obtained here is always released before returning.

    :param dev_node: device info used to open the REST connection and build
        the request URL.
    :return: list of Pool objects; None when the device reports no storage
        pool; an empty list when the query fails (the exception is logged).
    """
    pools = []
    try:
        rest = RestConnectionManager.getRestConnection(dev_node)
        base_uri = RestUtil.getDstorageUrlHead(dev_node)
        cmd_str = "{}/dsware/service/resource/queryStoragePool".format(
            base_uri
        )
        pools_json = CommonRestService.exec_get_gor_big_by_ds(rest, cmd_str)
        if not pools_json.get("storagePools"):
            return None

        storage_pools = pools_json.get("storagePools", [])
        pools = build_pool(base_uri, dev_node, rest, storage_pools)
        return pools
    except (IsmException, Exception) as exception:
        # NOTE(review): a failed query returns [] here, which execute() then
        # reports as a pass; consider surfacing it as an abnormal result.
        LOGGER.logException(exception)
        return pools
    finally:
        # Release on every exit path, including the "no storage pool" early return.
        RestConnectionManager.releaseConn(dev_node)


def execute(rest):
    """
    Run the SPDK consistency inspection item.

    Each pool built by get_node_pool must have its SPDK switch either on for
    every member node or off for every member node; a pool mixing both states
    fails the check.

    :param rest: not used by this item; kept for the framework's call signature.
    :return: tuple (inspection status, detail message, localized error message).
    """
    dev_node = py_java_env.get("devInfo")
    try:
        failed_names = []
        failure_details = []

        pools = get_node_pool(dev_node)
        if pools is None:
            # get_node_pool returns None when the device reports no storage pool.
            return common.INSPECT_PASS, "CHECK PASS! No storage pool", ""
        LOGGER.logInfo("pools len:%s." % str(len(pools)))

        for item in pools:
            consistency = item.get_spdk_switch()
            name = item.get_pool_name()
            opened = item.get_open_node()
            closed = item.get_close_nose()
            LOGGER.logInfo("spdk_switch[{}]-get_pool_name[{}].".format(consistency, name))

            if consistency != 1:
                LOGGER.logInfo("Node name:[{}] spdk open check failed".format(name))
                failed_names.append(name)
                failure_details.append("spdk switch: open[{}] close[{}]".format(len(opened), len(closed)))
                continue
            LOGGER.logInfo("Node name:%s spdk open all check pass " % str(name))

        # Strip the u'...' prefix that unicode items get in a Python 2 list repr.
        names_text = str(failed_names).replace('u\'', '\'')
        details_text = str(failure_details).replace('u\'', '\'')
        if not failed_names:
            return common.INSPECT_PASS, "CHECK PASS! SPDK ALL CLOSE or OPEN", ""

        detail_msg = "check failed. the pools[{}] error info[{}]".format(names_text, details_text)
        err_msg = common.get_err_msg(LANG, "pool.storage.spdk.consistency", (names_text, details_text))
        return common.INSPECT_UNNORMAL, detail_msg, err_msg
    except (IsmException, Exception) as exception:
        LOGGER.logException(exception)
        return (
            common.INSPECT_UNNORMAL, "", common.get_err_msg(LANG, "query.result.abnormal"),
        )
