#  coding=utf-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.

import json

from py.common.entity.check_result import CheckResult
from py.common.entity.check_item import CheckItem
from py.common.util.connection_util import RestService

# Names of the processes that must all be reported for every newly expanded node;
# CheckProcesses.check_process flags any name in this list that is absent.
PROCESS_LIST = ["ZK", "CM", "CCDB_MONITOR", "CCDB_SERVER", "DJOB", "DLM_MM", "KMM_SERVER"]
# REST path used to query service processes; "{}" is filled with a comma-separated
# list of the expansion nodes' management IPs.
QUERY_SERVICE_PROCESS_URL = "/api/v2/cluster_service/service_processes?management_ips={}"


def execute(env, method):
    """
    Module entry point: build the expansion-node process check and run it.

    :param env: task environment handed to the CheckProcesses constructor.
    :param method: forwarded as-is to the checker's execute() method.
    :return: whatever the checker's execute(method) returns.
    """
    checker = CheckProcesses(env)
    return checker.execute(method)


class CheckProcesses(CheckItem):
    """
    Check item verifying that every process in PROCESS_LIST is reported on each
    node being added by the expansion.
    """

    def __init__(self, env):
        super(CheckProcesses, self).__init__(env)
        self.cluster = self.task_env.get_dev_node()
        # Failure messages collected by check_process(); reset at the start of do_check().
        self.err_msg = []
        # Management IPs of the nodes being added by this expansion.
        self.expansion_node_ips = [node.getIp() for node in self.task_env.get_expansion_conf_nodes()]

    def do_check(self):
        """
        Check the processes on the newly expanded nodes.

        Queries the REST service for the processes of every expansion node and
        verifies that each node reports all names in PROCESS_LIST. An expansion
        node that does not appear in the response at all (including when the
        response carries no "data") is treated as missing every required
        process, so an incomplete response cannot make the check pass.

        :return: CheckResult.PASS with the parsed response when every node is
                 complete, otherwise CheckResult.NOT_PASS with the parsed
                 response and one message per deficient node, newline-joined.
        """
        # Start fresh so repeated calls on the same instance do not duplicate messages.
        self.err_msg = []
        url = QUERY_SERVICE_PROCESS_URL.format(",".join(self.expansion_node_ips))
        res_dict = json.loads(RestService(self.cluster).execute_get(url))
        # "data" may be absent or null; treat that as "no node reported anything".
        process_info_list = res_dict.get("data") or []
        reported_ips = set()
        for process_info in process_info_list:
            management_ip = process_info.get("management_ip")
            reported_ips.add(management_ip)
            self.check_process(process_info.get("processes"), management_ip)
        # Nodes we asked about but that the service did not report on are missing everything.
        for management_ip in self.expansion_node_ips:
            if management_ip not in reported_ips:
                self.check_process([], management_ip)
        if self.err_msg:
            return CheckResult(CheckResult.NOT_PASS, res_dict, "\n".join(self.err_msg))
        return CheckResult(CheckResult.PASS, res_dict, "")

    def check_process(self, info_list, management_ip):
        """
        Append an error message to self.err_msg if any PROCESS_LIST entry is missing.

        :param info_list: process entries reported for the node; each entry is read
                          via .get("process_name") (presumably a dict — confirm
                          against the API). None is treated as an empty list.
        :param management_ip: management IP of the node, used in the message.
        """
        running_names = {info.get("process_name") for info in info_list or []}
        lack_process_list = [name for name in PROCESS_LIST if name not in running_names]
        if lack_process_list:
            self.err_msg.append(
                self.get_msg("check.expand.node.process.not.pass", (management_ip, ",".join(lack_process_list))))
