#  coding=utf-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.

import json
import time

from py.common.entity.check_result import CheckResult
from py.common.entity.check_item import CheckItem
from py.common.util.connection_util import RestService

# Polling constant for the zk expansion task.
# It is used both as the sleep interval between progress queries (seconds)
# and as the polling timeout (minutes, i.e. QUERY_ZK_POLL_TIMES * 60 seconds).
QUERY_ZK_POLL_TIMES = 30
# REST endpoint: create the zk expansion (add zookeeper) task.
CREATE_ZK_EXPANSION_TASK_URL = "/dsware/service/cluster/addZookeeper"
# REST endpoint: query disk information of one node; format with the node IP.
QUERY_DISK_INFO_URL = "/dsware/service/resource/queryDiskInfo?ip={}"
# REST endpoint: query task information / progress.
TASK_PROCESS_URL = "/dsware/service/task/queryTaskInfo"


def execute(env, method):
    """Build the zk expansion check item for ``env`` and run ``method`` on it."""
    check_item = ExpansionZk(env)
    return check_item.execute(method)


class ExpansionZk(CheckItem):
    """
    Check item that expands the ZooKeeper (zk) service onto the expansion nodes.

    It submits an add-zookeeper task to the cluster and then polls the
    task-info interface until the task reports "success" or "failed", or the
    polling budget is exhausted.
    """

    def __init__(self, env):
        super(ExpansionZk, self).__init__(env)
        # Device node that every REST request in this item is sent to.
        self.cluster = self.task_env.get_dev_node()
        # Expansion nodes; each one is read through getIp() and getMetaDisk().
        self.expansion_conf_node_list = self.task_env.get_expansion_conf_nodes()
        # Task cache; "clusterName" is read from it when building the request.
        self.cache_dict = self.task_env.get_task_cache()

    def do_check(self):
        """
        Create the zk expansion task and wait for it to finish.

        After the task is created, the task-progress interface is queried every
        QUERY_ZK_POLL_TIMES seconds. Polling stops after QUERY_ZK_POLL_TIMES
        minutes of accumulated sleep (30 minutes with the default value).

        :return: CheckResult.PASS when the task reports "success";
                 CheckResult.NOT_PASS when task creation fails, the task reports
                 "failed", or polling times out.
        """
        param = self.get_param()
        res_dict = json.loads(RestService(self.cluster).execute_post(CREATE_ZK_EXPANSION_TASK_URL, param))
        if res_dict.get("result") != 0:
            return CheckResult(CheckResult.NOT_PASS, res_dict, self.get_msg("create.zk.task.fail"), True)

        task_id = res_dict.get("taskId")
        # QUERY_ZK_POLL_TIMES is both the interval (seconds) and the timeout (minutes).
        interval_sec = QUERY_ZK_POLL_TIMES
        timeout_sec = QUERY_ZK_POLL_TIMES * 60
        # Only sleep time is counted; time spent in the REST calls is not.
        waited_sec = 0
        while True:
            task_infos = json.loads(RestService(self.cluster).execute_get(TASK_PROCESS_URL))
            task_status = self.query_task_status(task_infos, task_id)
            if task_status == "failed":
                return CheckResult(CheckResult.NOT_PASS, res_dict, self.get_msg("query.zk.task.fail"))
            if task_status == "success":
                return CheckResult(CheckResult.PASS, res_dict, "")
            # Check the budget only after querying, so the last sleep is always
            # followed by one more status query instead of being wasted.
            if waited_sec >= timeout_sec:
                break
            time.sleep(interval_sec)
            waited_sec += interval_sec
        return CheckResult(CheckResult.NOT_PASS, res_dict, self.get_msg("query.zk.task.fail"))

    def get_param(self):
        """
        Build the request body for the add-zookeeper interface.

        A node whose meta disk is "sys_disk" is sent with only its IP and zkType.
        Any other meta-disk value is treated as a disk serial number and resolved
        through query_meta_disk_info().

        :return: dict with "clusterName" and "serverList".
        """
        server_list = []
        for conf_node in self.expansion_conf_node_list:
            meta_disk = conf_node.getMetaDisk()
            node_ip = conf_node.getIp()
            if meta_disk == "sys_disk":
                server_list.append({"nodeMgrIp": node_ip, "zkType": meta_disk})
            else:
                # NOTE(review): an empty dict is appended when the disk is not
                # found on the node; confirm the server rejects such entries.
                server_list.append(self.query_meta_disk_info(meta_disk, node_ip))
        return {"clusterName": self.cache_dict.get("clusterName"), "serverList": server_list}

    def query_meta_disk_info(self, esn, ip):
        """
        Look up the disk with serial number ``esn`` on node ``ip``.

        :param esn: disk serial number, compared against each disk's "diskSn".
        :param ip: management IP of the node to query.
        :return: zk server entry ("zkDiskSlot", "zkType", "nodeMgrIp",
                 "zkDiskEsn"), or {} when no matching disk is reported.
        """
        res_dict = json.loads(RestService(self.cluster).execute_get(QUERY_DISK_INFO_URL.format(ip)))
        # A missing/null "disks" field is treated as "no disks" instead of raising TypeError.
        for disk_info in res_dict.get("disks") or []:
            if disk_info.get("diskSn") == esn:
                return {
                    "zkDiskSlot": disk_info.get("diskSlot"),
                    "zkType": disk_info.get("diskType"),
                    "nodeMgrIp": ip,
                    "zkDiskEsn": esn,
                }
        return {}

    @staticmethod
    def query_task_status(task_infos, task_id):
        """
        Extract the status of task ``task_id`` from a task-info response.

        :param task_infos: parsed response of the task-info interface.
        :param task_id: id of the task to look up.
        :return: the task's "taskStatus" value; "failed" when the query itself
                 failed or the task is not listed.
        """
        if task_infos.get("result") != 0:
            return "failed"
        # A missing/null "taskInfo" field is treated as an empty task list.
        for task_info in task_infos.get("taskInfo") or []:
            if task_info.get("taskId") == task_id:
                return task_info.get("taskStatus")
        return "failed"
