# -*- coding: UTF-8 -*-

import time
import utils
import rest_util
import service_utils
from constants import Result
from constants import Timeout
from constants import PoolStatus

from com.huawei.bundleupgrade.entity import NodeUpgradePriorityEnums


def execute(context):
    """
    Entry point of the storage pool status check.

    :param context: check context; "logger" and "lang" are read here and the
        whole context is handed to PoolStatusCheck
    :return: the (result, message) pair produced by PoolStatusCheck.check();
        if anything raises, (Result.NOT_PASS, localized
        "check.pool.status.error" message)
    """
    log = context.get("logger")
    language = context.get("lang")
    try:
        log.info("start check pool status.")
        checker = PoolStatusCheck(context)
        return checker.check()
    except Exception as exc:
        # Any unexpected failure is reported as a failed check rather than
        # propagated to the caller.
        log.error("pool status check exception:%s" % str(exc))
        error_text = utils.get_msg(language, "check.pool.status.error")
        return Result.NOT_PASS, error_text


class PoolStatusCheck:
    """Check that the storage pools relevant to this node are healthy.

    If every pool reported by queryStoragePool is NORMAL, the check passes
    immediately. Otherwise it polls one pool until it becomes NORMAL:
    - the node's disk pool, when the device entity reports disk-pool support;
    - otherwise, the storage pool whose node list contains this node's
      management IP.
    While that pool is DEGRADED, polling continues as long as the reported
    progress keeps changing.
    """

    def __init__(self, context):
        """
        :param context: check context; reads "logger", "lang", "managementIP"
            and "deviceEntity"
        """
        self.context = context
        self.logger = context.get("logger")
        self.lang = context.get("lang")
        self.manage_ip = context.get("managementIP")
        self.device = context.get("deviceEntity")

    def check(self):
        """
        Run the pool status check.

        :return: (Result.PASS, "") or (Result.NOT_PASS, localized error msg)
        """
        if not self.need_check():
            # NOTE(review): a skipped check is reported as NOT_PASS with an
            # empty message; confirm the framework expects this rather than PASS.
            return Result.NOT_PASS, ""

        flag, pool_id_status_dict, err_msg = self.query_pool_infos()
        if not flag:
            return Result.NOT_PASS, err_msg

        # If every storage pool is already normal, the check passes.
        if all(status == PoolStatus.NORMAL
               for status in pool_id_status_dict.values()):
            return Result.PASS, ""

        if self.device.getDeviceEntity().isSupportDiskPool():
            return self.check_disk_pool_status()

        flag, current_node_pool_id, err_msg = self.get_curr_node_pool_id(
            pool_id_status_dict)
        if not flag:
            return Result.NOT_PASS, err_msg
        # Compare against "" explicitly: a pool id of 0 is still a real pool.
        if current_node_pool_id == "":
            self.logger.info("current node pool id is null.")
            return Result.PASS, ""

        # Poll the status of the storage pool that contains the current node.
        # The key spelling ("abnomal") must match the message resources.
        return self._wait_until_pool_normal(
            self.query_single_pool_status, current_node_pool_id,
            "curr.node.pool.status.timeout", "curr.node.pool.status.abnomal")

    def need_check(self):
        """
        Decide whether the pool check should run.

        :return: False when the device's service type is NO_SERVICE, or when
            service_utils reports the node as offline; True otherwise
        """
        if self.device.getDeviceEntity().getServiceType().equals(
                NodeUpgradePriorityEnums.NO_SERVICE):
            self.logger.error("The device has no service:%s" % str(self.manage_ip))
            return False
        return not service_utils.check_node_is_offline(self.context)

    def parse_pool_data_2_dict(self, data):
        """
        Build a {poolId: poolStatus} mapping from a queryStoragePool response.

        :param data: response body; entries are read from data["storagePools"]
        :return: (True, mapping, "") on success; (False, {}, localized error)
            if the payload cannot be parsed
        """
        pool_id_status_dict = {}
        try:
            pool_info_list = data.get("storagePools", [])
            for pool_info_dict in pool_info_list:
                pool_id = pool_info_dict.get("poolId")
                pool_status = pool_info_dict.get("poolStatus")
                pool_id_status_dict[pool_id] = pool_status
        except Exception as e:
            self.logger.error("parse pool status failed:%s" % str(e))
            err_msg = utils.get_msg(self.lang, "check.pool.status.error")
            return False, {}, err_msg

        return True, pool_id_status_dict, ""

    def query_pool_infos(self):
        """
        Query all storage pools, retrying per Timeout.DEFAULT_RETRY.

        :return: (flag, {poolId: poolStatus}, err_msg); the mapping is {} on
            failure
        """
        url = "/dsware/service/resource/queryStoragePool"
        cmd_dict = dict(url=url)
        cmd_dict["retry"] = Timeout.DEFAULT_RETRY
        cmd_dict["interval"] = Timeout.QUERY_POOL_STATUS_INTERVAL_TIME
        flag, data, err_msg = rest_util.exe_rest(self.context, cmd_dict)
        if not flag:
            # Same shape as parse_pool_data_2_dict's failure result.
            return False, {}, err_msg
        return self.parse_pool_data_2_dict(data)

    def query_node_disk_info(self, pool_id):
        """
        Query the node/disk information of one storage pool.

        :return: the (flag, data, err_msg) triple from rest_util.exe_rest
        """
        url = "/dsware/service/cluster/storagepool/queryNodeDiskInfo"
        params = dict(poolId=pool_id)
        cmd_dict = dict(url=url, params=params)
        return rest_util.exe_rest(self.context, cmd_dict)

    def get_curr_node_pool_id(self, pool_id_status_dict):
        """
        Find the storage pool that contains this node's management IP.

        :param pool_id_status_dict: {poolId: poolStatus}; only the keys are used
        :return: (True, pool_id, "") when found; (True, "", "") when no pool
            lists this node; (False, "", err_msg) if a query fails
        """
        for pool_id in pool_id_status_dict:
            flag, data, err_msg = self.query_node_disk_info(pool_id)
            if not flag:
                return False, "", err_msg
            # "nodeInfo" may be missing or null; treat both as "no nodes".
            node_info_list = data.get("nodeInfo") or []
            for node_info in node_info_list:
                if self.manage_ip == node_info.get("nodeMgrIp", ""):
                    return True, pool_id, ""
        return True, "", ""

    def query_single_pool_status(self, pool_id):
        """
        Query the status and progress of one storage pool.

        :return: (True, poolStatus, progress, "") on success;
            (False, "", "", err_msg) when the REST call fails or returns no pool
        """
        url = "/dsware/service/resource/queryStoragePool"
        params = dict(poolId=pool_id)
        cmd_dict = dict(url=url, params=params)
        flag, data, err_msg = rest_util.exe_rest(self.context, cmd_dict)
        if not flag:
            # Report the REST error instead of parsing a failure payload.
            return False, "", "", err_msg
        pools_info_list = data.get("storagePools") or []
        if not pools_info_list:
            err_msg = utils.get_msg(self.lang, "check.pool.status.error")
            return False, "", "", err_msg
        pool_info = pools_info_list[0]
        return True, pool_info.get("poolStatus"), pool_info.get("progress"), ""

    def query_single_disk_pool_status(self, pool_id):
        """
        Query the status and progress of one disk pool.

        :return: (True, poolStatus, int progress, "") on success;
            (False, "", "", err_msg) when the REST call fails or returns no pool
        """
        url = "/api/v2/data_service/diskpool"
        params = dict(diskPoolId=pool_id)
        cmd_dict = dict(url=url, params=params)
        flag, data, err_msg = rest_util.exe_rest(self.context, cmd_dict)
        if not flag:
            return False, "", "", err_msg
        pools_info_list = data.get("diskPools") or []
        if not pools_info_list:
            return False, "", "", utils.get_msg(self.lang, "check.disk.pool.status.error")
        pool_info = pools_info_list[0]
        # A missing/empty progress is treated as 0 rather than crashing int().
        progress = int(pool_info.get("progress") or 0)
        return True, pool_info.get("poolStatus"), progress, ""

    def check_disk_pool_status(self):
        """
        Poll the status of the disk pool the current node belongs to.

        :return: (Result.PASS, "") or (Result.NOT_PASS, localized error msg)
        """
        disk_pool_id = self.device.getDeviceEntity().getDiskPoolId()
        if not disk_pool_id:
            self.logger.info("current node pool id is null.")
            return Result.PASS, ""
        # The key spelling ("abnomal") must match the message resources.
        return self._wait_until_pool_normal(
            self.query_single_disk_pool_status, disk_pool_id,
            "curr.node.disk.pool.status.timeout",
            "curr.node.disk.pool.status.abnomal")

    def _wait_until_pool_normal(self, query_status, pool_id, timeout_key,
                                abnormal_key):
        """
        Poll a pool until it is NORMAL, fails, or stops making progress.

        While the pool is DEGRADED, the query is repeated every
        Timeout.QUERY_POOL_STATUS_INTERVAL_TIME seconds. The timeout window
        (Timeout.QUERY_POOL_STATUS_TIMEOUT) restarts each time the reported
        progress changes, so only a non-progressing degraded pool times out.

        :param query_status: callable(pool_id) -> (flag, status, progress, err_msg)
        :param pool_id: id passed to query_status
        :param timeout_key: message key reported when the window expires
        :param abnormal_key: message key for a status other than NORMAL/DEGRADED
        :return: (Result.PASS, "") or (Result.NOT_PASS, err_msg)
        """
        last_progress = 0
        start_time = time.time()
        while True:
            if time.time() - start_time > Timeout.QUERY_POOL_STATUS_TIMEOUT:
                return Result.NOT_PASS, utils.get_msg(self.lang, timeout_key)
            flag, pool_status, progress, err_msg = query_status(pool_id)
            if not flag:
                return Result.NOT_PASS, err_msg
            if pool_status == PoolStatus.NORMAL:
                return Result.PASS, ""
            if pool_status != PoolStatus.DEGRADED:
                return Result.NOT_PASS, utils.get_msg(self.lang, abnormal_key)
            if last_progress != progress:
                start_time = time.time()
                last_progress = progress
            time.sleep(Timeout.QUERY_POOL_STATUS_INTERVAL_TIME)
