#  coding=utf-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.

import json

from py.common.entity.check_result import CheckResult
from py.common.entity.check_item import CheckItem
from py.common.util.connection_util import RestService, SshService

# Section headers in ds_config.cfg whose expansion/shrink flags are switched on.
NEED_MODIFY_CONFIG_ITEM = (
    "[ccdb_monitor]",
    "[ccdb_server]",
    "[djob]",
    "[dlm_vm]",
    "[kmm_server]",
    "[dlm_mm]",
    "[fdsaCore]",
)
# REST endpoint queried for the management cluster's name and node list.
QUERY_MANAGE_CLUSTER_URL = "/dsware/service/cluster/queryManageCluster"
# Config file read and edited on every processed node.
DS_CONFIG_FILE_PATH = "/opt/dsware/infrastructure/cm/config/ds_config.cfg"


def execute(env, method):
    """Task entry point: build the checker for *env* and run it with *method*."""
    checker = ChangeExpansionFlag(env)
    return checker.execute(method)


class ChangeExpansionFlag(CheckItem):
    """Enable the expansion/shrink flags in ds_config.cfg on cluster nodes.

    The nodes handled are the management-cluster nodes returned by
    QUERY_MANAGE_CLUSTER_URL (called "zk" nodes in this class), followed by
    the configured expansion nodes.
    """

    # (current text, replacement) pairs applied inside every target section.
    _FLAG_SUBSTITUTIONS = (
        ("is_allow_expand_shrink=0", "is_allow_expand_shrink=1"),
        ("force_shrink=0", "force_shrink=1"),
    )
    # Characters special in a sed basic regular expression, plus "/" (the
    # address delimiter); they are backslash-escaped in section addresses.
    _BRE_SPECIAL_CHARS = "\\.[*^$/"

    def __init__(self, env):
        super(ChangeExpansionFlag, self).__init__(env)
        self.cluster = self.task_env.get_dev_node()
        self.expansion_conf_node_list = self.task_env.get_expansion_conf_nodes()
        self.dev_node_dict = self.task_env.get_cluster_node_dict()
        self.cache_dict = self.task_env.get_task_cache()

    def do_check(self):
        """
        Change the expansion flags.
        1. If a node has no ds_config.cfg file, return NOT_INVOLVED.
        2. Change the expansion flags of the listed modules from 0 to 1.

        :return: CheckResult whose detail text is the output collected from
                 each processed node, joined by newlines
        """
        ori_info = []
        node_list = self.get_zk_node_list()
        node_list.extend(self.get_expansion_node_list())
        for node in node_list:
            ssh_service = SshService(node)
            try:
                involve, node_output = self.modify_expand_shrink(ssh_service)
            finally:
                # Release the session even if a remote command raises.
                ssh_service.release_ssh()
            ori_info.append(node_output)
            if not involve:
                # NOTE(review): nodes handled before this one have already been
                # modified; confirm that stopping here is the intended outcome.
                return CheckResult(CheckResult.NOT_INVOLVED, "\n".join(ori_info), "")
        return CheckResult(CheckResult.PASS, "\n".join(ori_info), "")

    def get_expansion_node_list(self):
        """Map each configured expansion node to its cluster node object by IP.

        NOTE(review): an IP absent from dev_node_dict yields None, which is
        passed on unchanged - confirm that cannot happen.
        """
        return [self.dev_node_dict.get(node.getIp()) for node in self.expansion_conf_node_list]

    def get_zk_node_list(self):
        """Return cluster node objects for the management cluster's nodes.

        Side effect: stores the reported cluster name in the task cache under
        "clusterName". Nodes whose "nodeMgrIp" is not a key of dev_node_dict
        are skipped.
        """
        res_dict = json.loads(RestService(self.cluster).execute_get(QUERY_MANAGE_CLUSTER_URL))
        self.cache_dict["clusterName"] = res_dict.get("clusterName")
        zk_node_list = []
        for node_info in res_dict.get("nodeInfo"):
            node_ip = node_info.get("nodeMgrIp")
            if node_ip in self.dev_node_dict:
                zk_node_list.append(self.dev_node_dict.get(node_ip))
        return zk_node_list

    def modify_expand_shrink(self, ssh_service):
        """Set the expansion/shrink flags to 1 inside the target sections.

        A target section starts at a line containing one of
        NEED_MODIFY_CONFIG_ITEM and ends at the next empty line.

        :param ssh_service: SSH session to the node being processed
        :return: (False, output of `cat`) if the config file does not exist,
                 otherwise (True, file content re-read after the edit)
        """
        cat_cmd = "cat {}".format(DS_CONFIG_FILE_PATH)
        cli_ret = ssh_service.execute_cmd(cat_cmd)
        # Without ds_config.cfg the expansion flags need no change.
        if "No such file or directory" in cli_ret:
            return False, cli_ret
        # Only rewrite the file when something actually needs to change.
        if self._has_flag_to_enable(cli_ret.splitlines()):
            ssh_service.execute_cmd(self._build_sed_cmd())
        return True, ssh_service.execute_cmd(cat_cmd)

    @classmethod
    def _has_flag_to_enable(cls, lines):
        """Return True if a target section still contains a flag set to 0."""
        in_section = False
        for line in lines:
            if cls.process_is_need_modify(line):
                in_section = True
            if in_section and any(old in line for old, _ in cls._FLAG_SUBSTITUTIONS):
                return True
            if not line:
                in_section = False
        return False

    @classmethod
    def _build_sed_cmd(cls):
        """Build one in-place sed command that flips the flags in every target section.

        Sections are addressed by content ("/header/,/^$/") rather than by
        line number, so the edit does not depend on how the `cat` output
        returned by the SSH layer is framed (e.g. an echoed command line).
        """
        substitutions = "".join(
            "s/{}/{}/;".format(old, new) for old, new in cls._FLAG_SUBSTITUTIONS)
        expressions = " ".join(
            "-e '/{}/,/^$/{{{}}}'".format(cls._escape_bre(header), substitutions)
            for header in NEED_MODIFY_CONFIG_ITEM)
        return "sed -i {} {}".format(expressions, DS_CONFIG_FILE_PATH)

    @classmethod
    def _escape_bre(cls, text):
        """Escape *text* so sed matches it literally inside a /.../ address."""
        return "".join("\\" + ch if ch in cls._BRE_SPECIAL_CHARS else ch for ch in text)

    @staticmethod
    def process_is_need_modify(line):
        """Return True if *line* contains one of the NEED_MODIFY_CONFIG_ITEM headers."""
        return any(process in line for process in NEED_MODIFY_CONFIG_ITEM)
