# coding: UTF-8
import re
import string

import common
from com.huawei.ism.exception import IsmException

# Language passed to common.get_err_msg when building messages. py_java_env is
# not defined in this file -- presumably injected by the hosting framework.
LANG = common.getLang(py_java_env)
# Logger for this module; PY_LOGGER is likewise presumably injected by the host.
LOGGER = common.getLogger(PY_LOGGER, __file__)
# Key under which this inspection reports its progress to the progress observer.
ITEM_ID = "expand_node_count_difference"

# Within one main-storage disk type, a group of disk pools is flagged when the
# smallest pool's post-expansion node count is below
# int(largest_pool_node_count * NODE_COUNT_DIFFERENCE_THRESHOLD).
NODE_COUNT_DIFFERENCE_THRESHOLD = 0.7


def execute(rest):
    """
    Check the node-count difference between disk pools after expansion.

    The device and the progress observer are read from the module-level
    ``py_java_env`` ("devInfo" and "progressObserver"). Only product versions
    8.1.0 and 8.1.1 are inspected; any other version is reported as not
    supported.

    :param rest: not used by this function.
    :return: tuple ``(status, ret_list, message)``:
        * ``status`` is ``common.INSPECT_NOSUPPORT``, ``common.INSPECT_PASS``,
          ``common.INSPECT_WARNING`` or ``common.INSPECT_UNNORMAL``;
        * ``ret_list`` is the "expansion.config.info" message;
        * ``message`` is "" on pass / not-supported, the generated warning
          text on warning, or the "query.result.abnormal" message on error.
    """
    ret_list = common.get_err_msg(LANG, "expansion.config.info")
    dev_node = py_java_env.get("devInfo")
    observer = py_java_env.get("progressObserver")

    try:
        if not is_node_software_version_in_inspection_range(dev_node):
            return common.INSPECT_NOSUPPORT, ret_list, ""

        inspection_ok, problematic_disk_pool_groups = execute_inspection(dev_node, observer)
        if inspection_ok:
            return common.INSPECT_PASS, ret_list, ""
        return common.INSPECT_WARNING, ret_list, generate_warning_message(problematic_disk_pool_groups)
    # IsmException is a Java exception class; listing it next to Exception is
    # presumably required so Jython catches it as well -- confirm before removing.
    except (IsmException, Exception) as exception:
        LOGGER.logException(exception)
        return common.INSPECT_UNNORMAL, ret_list, common.get_err_msg(LANG, "query.result.abnormal")


def is_node_software_version_in_inspection_range(dev_node):
    """
    Tell whether the device's product version falls in the inspected range.

    The problem behind this inspection first appears in 8.1.0 and is fixed in
    8.1.2, so only 8.1.0 and 8.1.1 (with any suffix) need to be inspected.

    :param dev_node: object exposing ``getProductVersion()``.
    :return: True for 8.1.0 / 8.1.1 versions, False otherwise (including
        versions that do not start with ``<digits>.<digits>.<digits>``).
    """
    version_text = str(dev_node.getProductVersion())
    found = re.match(r"^(\d+\.\d+\.\d+).*$", version_text)
    if found is None:
        return False
    return found.group(1) in ("8.1.0", "8.1.1")


def is_node_software_version_before_810(dev_node):
    """
    Tell whether the device's product version is older than 8.1.

    Only the major and minor numbers decide the result; the revision must be
    present for the version to parse but is otherwise ignored.

    :param dev_node: object exposing ``getProductVersion()``.
    :return: True when (major, minor) < (8, 1); False for newer versions and
        for versions that do not start with ``<digits>.<digits>.<digits>``.
    """
    found = re.match(r"^(\d+)\.(\d+)\.(\d+).*$", str(dev_node.getProductVersion()))
    if found is None:
        return False
    major_minor = (int(found.group(1)), int(found.group(2)))
    return major_minor < (8, 1)


def generate_warning_message(problematic_disk_pools_list):
    """
    Build the warning text describing every disk-pool group that failed the check.

    Each disk pool of a group is rendered with the
    "expand.node.count.difference.disk.pool.template" message, using its
    1-based position in the group, its name and its post-expansion node count.
    The rendered pools are joined with ", " and wrapped in the
    "expand.node.count.difference.exceeds.threshold" message. The messages of
    all groups are concatenated without a separator.

    :param problematic_disk_pools_list: list of disk-pool lists (one list per
        failing group).
    :return: the concatenated warning text; "" when the input is empty.
    """
    group_messages = []
    for problematic_disk_pools in problematic_disk_pools_list:
        visualized_disk_pools = [
            common.get_err_msg(LANG, "expand.node.count.difference.disk.pool.template",
                               (index + 1,
                                pool.getName(),
                                get_disk_pool_node_count_after_expansion(pool)))
            for index, pool in enumerate(problematic_disk_pools)
        ]
        # str.join replaces the Python-2-only string.join(words, sep); same result.
        group_messages.append(common.get_err_msg(LANG, "expand.node.count.difference.exceeds.threshold",
                                                 ", ".join(visualized_disk_pools)))
    return "".join(group_messages)


def group_by(key_selector, any_iterable):
    """
    Bucket the items of ``any_iterable`` by ``key_selector(item)``.

    :param key_selector: callable mapping an item to a hashable key; it is
        called exactly once per item.
    :param any_iterable: the items to group.
    :return: dict mapping each key to the list of its items, kept in
        iteration order. Only keys that received an item appear, so every
        list is non-empty.
    """
    buckets = {}
    for element in any_iterable:
        buckets.setdefault(key_selector(element), []).append(element)
    return buckets


def get_disk_pool_node_count_after_expansion(disk_pool):
    """
    Count the nodes a disk pool will have once the expansion is applied.

    :param disk_pool: object exposing ``getJoinedClusterNode()``,
        ``getExpansionClusterNode()`` and ``getExpansionNodeList()``, each
        returning a sized collection of nodes.
    :return: the sum of the sizes of those three collections.
    """
    node_groups = (
        disk_pool.getJoinedClusterNode(),
        disk_pool.getExpansionClusterNode(),
        disk_pool.getExpansionNodeList(),
    )
    return sum(len(group) for group in node_groups)


def execute_inspection(dev_node, observer):
    """
    Run the node-count check over the device's disk pools and report progress.

    Progress for ITEM_ID starts at 1. On 8.1.0 and later it is advanced by
    ``99 // len(storage_pools)`` for every storage pool visited, block pools
    included, so the bar does not stall when block pools are skipped. The
    pre-8.1.0 branch reports only the initial progress value.

    :param dev_node: device object exposing ``getProductVersion()`` and
        ``getStoragePools()``.
    :param observer: progress observer; ``updateProgress`` is called with a
        dict mapping ITEM_ID to the current progress value.
    :return: tuple ``(inspection_ok, problematic_disk_pools_list)`` where
        ``inspection_ok`` is False if any group failed and the list holds
        every failing disk-pool group.
    """
    inspection_ok = True
    problematic_disk_pools_list = list()

    progress_map = {ITEM_ID: 1}
    observer.updateProgress(progress_map)

    # Before 8.1.0 getStoragePools() returns a list of disk pools; from 8.1.0 on
    # it returns storage pools, so the two cases are handled separately.
    if is_node_software_version_before_810(dev_node):
        disk_pools = dev_node.getStoragePools()
        (disk_pools_ok, local_problematic_disk_pools_list) = are_disk_pools_node_count_ok(disk_pools)
        inspection_ok = inspection_ok and disk_pools_ok
        problematic_disk_pools_list.extend(local_problematic_disk_pools_list)
    else:
        storage_pools = dev_node.getStoragePools()
        # Loop-invariant step; floor division keeps the progress value an int
        # (same as Python 2 int "/"). max() only guards the empty case, in which
        # the loop body never runs anyway.
        progress_step = 99 // max(len(storage_pools), 1)
        for storage_pool in storage_pools:
            # The underlying problem only affects converged pools; block pools
            # are not checked but still count toward progress.
            if not storage_pool.isBlock():
                (disk_pools_ok, local_problematic_disk_pools_list) = are_disk_pools_node_count_ok(
                    storage_pool.getDiskPools())
                inspection_ok = inspection_ok and disk_pools_ok
                problematic_disk_pools_list.extend(local_problematic_disk_pools_list)

            progress_map[ITEM_ID] += progress_step
            observer.updateProgress(progress_map)

    return inspection_ok, problematic_disk_pools_list


def are_disk_pools_node_count_ok(disk_pools):
    """
    Check, per main-storage disk type, that no disk pool lags too far behind
    the largest one in post-expansion node count.

    Disk pools are grouped by ``getMainStorageDiskType()``. Each group is
    sorted in place, ascending by post-expansion node count. A group fails
    when its smallest count is below
    ``int(largest_count * NODE_COUNT_DIFFERENCE_THRESHOLD)``.

    :param disk_pools: sized iterable of disk-pool objects.
    :return: tuple ``(ok, problematic_groups)``. ``ok`` is False when at least
        one group fails. ``problematic_groups`` holds every failing group as a
        list sorted ascending by node count.
    """
    if len(disk_pools) == 0:
        return True, list()

    inspection_result = True
    problematic_disk_pools_list = list()
    disk_pool_group_by_disk_type = group_by(lambda disk_pool: disk_pool.getMainStorageDiskType(), disk_pools)
    # values() instead of the Python-2-only itervalues(); same order and contents.
    for disk_pools_in_same_type in disk_pool_group_by_disk_type.values():
        disk_pools_in_same_type.sort(key=get_disk_pool_node_count_after_expansion)
        # group_by never creates an empty group, so [-1] and [0] are always valid.
        max_num = get_disk_pool_node_count_after_expansion(disk_pools_in_same_type[-1])
        min_num = get_disk_pool_node_count_after_expansion(disk_pools_in_same_type[0])
        # The threshold is only approximate, so the lower bound is deliberately
        # rounded down (the lenient direction).
        if min_num < int(max_num * NODE_COUNT_DIFFERENCE_THRESHOLD):
            inspection_result = False
            problematic_disk_pools_list.append(disk_pools_in_same_type)

    return inspection_result, problematic_disk_pools_list
