# -*- coding:utf-8 -*-

import utils.common.log as logger
from utils.business.param_util import ParamUtil
from utils.common.check_result import CheckResult
from utils.common.exception import HCCIException
from utils.common.ssh_util import Ssh

from plugins.eReplication.common.client.ssh_client import API as SSH_API
from plugins.eReplication.common.lib.base import BaseSubJob
from plugins.eReplication.common.lib.decorator import handle_task_check_result
from plugins.eReplication.common.lib.params import Nodes

# Initialise the shared project logger at import time, passing "eReplication"
# (presumably the log name/tag used for this plugin -- confirm in utils.common.log).
logger.init("eReplication")


class CheckCommunicationMatrix(BaseSubJob):
    """Check that the communication matrix required by the upgrade is open.

    The check logs in to every Server host over SSH and, from there, probes
    port 22 of each DMK (OM-SRV) node with ``wget --spider``.
    """

    # Item names shown in every CheckResult produced by this job.
    _ITEM_NAME_CH = "通信矩阵校验"
    _ITEM_NAME_EN = "Communication matrix check"

    # VM name pairs used to look up the DMK nodes, tried in this order.
    _DMK_VM_NAME_PAIRS = (
        ("OM-SRV-DC1", "OM-SRV-DC2"),
        ("OM-SRV-01", "OM-SRV-02"),
    )

    @classmethod
    def _failure_result(cls, *error_args):
        """Build a failed CheckResult whose message is HCCIException(*error_args)."""
        return CheckResult(
            itemname_ch=cls._ITEM_NAME_CH,
            itemname_en=cls._ITEM_NAME_EN,
            status="failure", error_msg_cn=HCCIException(*error_args))

    @handle_task_check_result
    def execute(self, project_id, pod_id, *args, **kwargs):
        """Standard entry point: run the communication matrix check.

        :param project_id: project identifier (not read here; the helpers use
            ``self.project_id``, presumably set by the base class)
        :param pod_id: pod identifier (not read here; see ``self.pod_id``)
        :return: list of CheckResult objects; on success it ends with a
            single "success" entry, on failure it holds the failure entry
            produced by the step that failed
        """
        results = list()
        found, dmk_ips, results = self.get_dmk_node_info(results)
        if not found:
            return results
        # Logging in to each Server over SSH implicitly verifies port 22
        # between the executor and the Server; the DMK probe follows.
        check_server_res, results = self.check_server_common_port(dmk_ips)
        if not check_server_res:
            return results
        results.append(CheckResult(
            itemname_ch=self._ITEM_NAME_CH,
            itemname_en=self._ITEM_NAME_EN, status="success"))
        return results

    def get_dmk_node_info(self, results):
        """Look up the IP addresses of the DMK (OM-SRV) nodes.

        The "OM-SRV-DC1"/"OM-SRV-DC2" names are tried first; only when
        neither is found are "OM-SRV-01"/"OM-SRV-02" tried.

        :param results: list that receives a failure CheckResult when no
            usable DMK node is found
        :return: tuple ``(found, dmk_ips, results)``; ``found`` is False and
            ``dmk_ips`` is empty when no DMK IP could be determined
        """
        names, infos = (), []
        for names in self._DMK_VM_NAME_PAIRS:
            infos = [
                ParamUtil().get_vm_info_by_vm_name(self.project_id, name)
                for name in names
            ]
            if any(infos):
                break

        dmk_ips = []
        for name, info in zip(names, infos):
            # A node may be missing on its own, or its record may carry no
            # IP; skip it instead of indexing into an empty lookup result or
            # probing the literal address "None".
            node_ip = info[0].get("ip", None) if info else None
            if node_ip:
                dmk_ips.append(node_ip)
            else:
                logger.error(
                    f"DMK server {name} or its IP address was not found, "
                    f"skip it.")

        if not dmk_ips:
            logger.error("DMK server not found, please check.")
            results.append(self._failure_result(675027, "DMK"))
            return False, [], results
        return True, dmk_ips, results

    def check_server_common_port(self, dmk_ips):
        """Probe the DMK nodes from every Server host, stopping at the first failure.

        :param dmk_ips: IP addresses of the DMK nodes to probe
        :return: tuple ``(passed, results)``; ``results`` holds the failure
            CheckResult when ``passed`` is False.  When ``nodes.all_hosts``
            is empty the result is ``(False, [])``.
        """
        nodes = Nodes(self.project_id, self.pod_id)
        check_server_res = False
        results = list()
        for host in nodes.all_hosts:
            try:
                check_server_res, results = self._do_check_server_common_port(
                    dmk_ips, host, nodes)
            except Exception as err_msg:
                # Any error while logging in or running the probe is reported
                # as an SSH connection failure for this host.
                logger.error(
                    f"Failed to connect to the Server node using SSH, "
                    f"reason is: {str(err_msg)}.")
                results.append(self._failure_result(675032, str(err_msg)))
                return False, results
            if not check_server_res:
                return check_server_res, results
        return check_server_res, results

    @staticmethod
    def _do_check_server_common_port(dmk_ips, host, nodes):
        """Probe port 22 of every DMK node from a single Server host.

        :param dmk_ips: IP addresses of the DMK nodes to probe
        :param host: Server host to log in to
        :param nodes: object providing ``ssh_user``, ``ssh_pwd``,
            ``sudo_user`` and ``sudo_pwd`` for the login
        :return: tuple ``(passed, results)``; ``passed`` is False as soon as
            one probe reports "failed", and ``results`` then holds the
            matching failure CheckResult
        """
        results = list()
        # NOTE(review): the client is never closed here; no close API is
        # visible in this file -- confirm whether SSH_API offers one.
        ssh_client = SSH_API.get_sudo_ssh_client(
            host, nodes.ssh_user, nodes.ssh_pwd, nodes.sudo_user,
            nodes.sudo_pwd)
        # Check port 22 between DMK and the Server (needed to upgrade the
        # Server): run wget on the Server node; --spider avoids downloading
        # anything, so only the connection attempt matters.
        for om_srv_ip in dmk_ips:
            result = SSH_API.send_command(
                ssh_client,
                f"wget --spider -T 10 -t 3 {om_srv_ip}:22", "#")
            output = ",".join(result)
            if "connected" in output:
                logger.info(
                    f"Port 22 between {host} and "
                    f"{om_srv_ip} is enabled.")
                continue
            if "failed" in output:
                logger.error(
                    f"Port 22 between {host} and "
                    f"{om_srv_ip} is disabled.")
                results.append(CheckCommunicationMatrix._failure_result(
                    675028, host, om_srv_ip, "22"))
                return False, results
            # Neither marker was printed (for example wget is unavailable).
            # The check result stays unchanged, but the inconclusive probe
            # is logged instead of being silently ignored.
            logger.error(
                f"Unable to determine whether port 22 between {host} and "
                f"{om_srv_ip} is enabled: wget reported neither "
                f"'connected' nor 'failed'.")
        return True, results
