# -*- coding:utf-8 -*-
import os

import utils.common.log as logger
from utils.common.exception import HCCIException

from plugins.eReplication.common.api.file_api import API as FILE_API
from plugins.eReplication.common.arb_api import API as ARB_API
from plugins.eReplication.common.client.mo_client import API as MO_API
from plugins.eReplication.common.client.ssh_client import API as SSH_API
from plugins.eReplication.common.lib.base import BaseSubJob
from plugins.eReplication.common.lib.conditions import Condition
from plugins.eReplication.common.lib.decorator import handle_task_result
from plugins.eReplication.common.lib.model import Auth, SSHClientInfo
from plugins.eReplication.common.lib.params import Nodes
from plugins.eReplication.common.lib.params import Params
from plugins.eReplication.common.lib.utils import check_host_connection
from plugins.eReplication.common.lib.utils import check_param_ip
from plugins.eReplication.common.os_api import API as OS_API
from plugins.eReplication.common.request_api import RequestApi
from plugins.eReplication.common.request_api import SERVICE_INVALID
from plugins.eReplication.common.request_api import SERVICE_IP_NOT_CONNECTED
from plugins.eReplication.common.request_api import SERVICE_PWD_ERROR
from plugins.eReplication.common.constant import ArbInfo

# Threshold compared against the password-expiry values returned for the ssh
# and sudo users in PreinstallCheck._check_drm_servers_login; values below it
# are reported as "about to expire" (presumably measured in days -- confirm).
EXPIRING_DAY = 7
# NOTE(review): presumably names the log channel used by this module --
# confirm against utils.common.log.
logger.init("PreinstallCheck")


class PreinstallCheck(BaseSubJob):
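    """Validate the installation parameters.

    Covers availability checks of planned IPs, login checks on existing
    IPs, and availability checks of the service.
    """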

    def __init__(self, project_id, pod_id, regionid_list=None):
        """Initialise the sub-job and the helpers the checks rely on.

        :param project_id: project identifier, forwarded to the base class
            and used to build the ``Condition`` helper
        :param pod_id: pod identifier, forwarded to the base class
        :param regionid_list: optional region id list, forwarded unchanged
            to the base class
        """
        super().__init__(project_id, pod_id, regionid_list)
        # Scenario switches consulted by ``execute``.
        self.condition = Condition(project_id)
        # Node/credential parameters; built from the attributes the base
        # class is expected to have stored (``self.project_id``/``pod_id``).
        self.nodes = Nodes(self.project_id, self.pod_id)

    @handle_task_result
    def execute(self, project_id, pod_id, *args, **kwargs):
        """
        Standard entry point: run the pre-install checks that apply to the
        current scenario.

        :param project_id: project identifier (not read in this body; the
            state stored on the instance is used instead)
        :param pod_id: pod identifier (not read in this body)
        :return: nothing explicitly. NOTE(review): the result object is
            presumably produced by ``handle_task_result`` -- confirm.
        """
        flags = self.condition
        # When the DR service is installed on the DR site, the two primary
        # site nodes must be reachable and accept logins.
        if flags.need_check_server_login:
            self.check_dr_host_ip_validity_and_can_login()
        # A DR service already exists in the environment and another one is
        # being installed: reuse/verify what the existing service provides.
        if flags.is_current_dr_installed:
            self.preset_env_config()
        if flags.is_service_installed:
            self.check_replication_ip_port_validity_and_service_ok()
        # Everything below only applies to the CSHA scenario.
        if not flags.is_csha:
            return
        # The third-party quorum server must be reachable.
        if flags.csha_install_quorum:
            self.check_csha_quorum_ip_is_validity_and_can_login()
        if flags.need_check_cloud_quorum_login:
            self.check_cloud_quorum_ip_validity_and_can_login()

    @staticmethod
    def initialize_error_message(check_result, key_map=None):
        if key_map:
            error_message = ", ".join(
                [str({map_key: result_key}) for map_key in key_map for
                 result_key in check_result if
                 ((isinstance(key_map[map_key], str) and key_map[
                     map_key] == result_key) or
                  (isinstance(key_map[map_key], list) and result_key in
                   key_map[map_key])) and not check_result[result_key]])
        else:
            error_message = ", ".join(
                [dict_key for dict_key in check_result.keys()
                 if not check_result[dict_key]])
        return error_message

    def check_ips_validity(self, ips, key_map: dict = None):
        """Check that every address in ``ips`` passes ``check_param_ip``.

        :param ips: iterable of IP address strings
        :param key_map: optional parameter-name mapping used to label the
            failing addresses in the error message
        :return: ``True`` when no address produced a ``False`` result
        :raises HCCIException: code "663501" with the failing addresses
        """
        check_result = {
            ip_address: check_param_ip(ip_address) for ip_address in ips}
        if False in check_result.values():
            raise HCCIException(
                "663501",
                self.initialize_error_message(check_result, key_map))
        logger.info(f"Check ips format return {check_result}.")
        return True

    @staticmethod
    def check_ip_connection(ip_address, is_exist=True):
        """Compare the reachability of ``ip_address`` with the expectation.

        :param ip_address: address handed to ``check_host_connection``
        :param is_exist: truthy when the address is expected to answer,
            falsy when it is expected to be unused
        :return: ``True`` when the probe result matches the expectation,
            ``False`` otherwise
        """
        reachable = check_host_connection(ip_address)
        if is_exist:
            return bool(reachable)
        return not reachable

    def check_ips_connection(
            self, ips, is_exist: bool = True, key_map: dict = None):
        """Run ``check_ip_connection`` for every address in ``ips``.

        :param ips: iterable of IP address strings
        :param is_exist: expectation forwarded to ``check_ip_connection``
        :param key_map: optional parameter-name mapping used to label the
            failing addresses in the error message
        :return: ``True`` when no address produced a ``False`` result
        :raises HCCIException: code "663507" when ``is_exist`` is truthy,
            code "663606" otherwise, with the failing addresses
        """
        check_result = {
            ip_address: self.check_ip_connection(
                ip_address, is_exist=is_exist)
            for ip_address in ips}
        if False in check_result.values():
            error_message = self.initialize_error_message(
                check_result, key_map)
            error_code = "663507" if is_exist else "663606"
            raise HCCIException(error_code, error_message)
        logger.info(f"Check ips connection return {check_result}.")
        return True

    @staticmethod
    def check_server_login(
            host, ssh_user=None, ssh_pwd=None, sudo_user=None, sudo_pwd=None):
        """Try to open, then immediately close, an SSH session to ``host``.

        A sudo session is requested when both ``sudo_user`` and ``sudo_pwd``
        are truthy; otherwise a plain session is requested. Any exception
        raised while opening the session is logged and swallowed.

        :param host: address of the server to log in to
        :param ssh_user: login user
        :param ssh_pwd: login password
        :param sudo_user: optional privileged user
        :param sudo_pwd: optional privileged user password
        :return: ``True`` when a truthy client object was obtained,
            ``False`` otherwise
        """
        use_sudo = bool(sudo_user and sudo_pwd)
        ssh_client = None
        try:
            if use_sudo:
                ssh_client = SSH_API.get_sudo_ssh_client(
                    Auth(host, ssh_user, ssh_pwd, sudo_user, sudo_pwd))
            else:
                ssh_client = SSH_API.get_ssh_client(
                    SSHClientInfo(host, ssh_user, ssh_pwd))
        except Exception as err:
            logger.error(f"Login {host} return: {err}.")
        finally:
            # NOTE(review): on a failed login this receives ``None``;
            # presumably ``close_ssh`` tolerates that -- confirm.
            SSH_API.close_ssh(ssh_client)
        return bool(ssh_client)

    def _check_servers_login(self, ips, ssh_user=None, ssh_pwd=None):
        """Verify that a plain (non-sudo) login works on every server.

        :param ips: iterable of server addresses
        :param ssh_user: login user shared by all servers
        :param ssh_pwd: login password shared by all servers
        :return: ``True`` when no server produced a ``False`` result
        :raises HCCIException: code "663510" with the failing addresses
        """
        check_result = {
            ip_address: self.check_server_login(
                ip_address, ssh_user, ssh_pwd, None, None)
            for ip_address in ips}
        logger.info(f"Check servers login return: {check_result}.")
        if False in check_result.values():
            raise HCCIException(
                "663510", self.initialize_error_message(check_result, None))
        return True

    @staticmethod
    def check_drm_server_login(host, ssh_user, ssh_pwd, sudo_user, sudo_pwd):
        """Query the password-expiry values of the ssh and sudo users.

        :param host: address of the server
        :param ssh_user: login user
        :param ssh_pwd: login password
        :param sudo_user: privileged user
        :param sudo_pwd: privileged user password
        :return: tuple ``(True, ssh_user_value, sudo_user_value)``; the first
            element is always ``True`` and each value is whatever
            ``OS_API.check_os_password_expired_day`` reported for that user
            (``None`` when the user is absent from the reply).
            NOTE(review): a failed login presumably surfaces as an exception
            from ``OS_API`` -- confirm.
        """
        expire_days = OS_API.check_os_password_expired_day(
            Auth(host, ssh_user, ssh_pwd, sudo_user, sudo_pwd),
            [ssh_user, sudo_user])
        return True, expire_days.get(ssh_user), expire_days.get(sudo_user)

    def _check_drm_servers_login(self, ips, key_map: dict = None):
        """Check login and password expiry of the ssh/sudo users on each node.

        Credentials are read from ``self.nodes``. Each node is probed with
        ``check_drm_server_login``, whose result is a tuple
        ``(login_ok, ssh_user_expiry, sudo_user_expiry)``.

        :param ips: iterable of node addresses
        :param key_map: optional parameter-name mapping used to label the
            affected addresses in the error messages
        :raises HCCIException: code "663511" when a node reports a falsy
            login flag; code "663513" when the ssh user's and/or the sudo
            user's expiry value is below ``EXPIRING_DAY`` on some node. The
            "663513" message lists the two users separately as
            ``"DRManager: ..."`` and ``"root: ..."``.
        """
        ssh_user = self.nodes.ssh_user
        ssh_pwd = self.nodes.ssh_pwd
        sudo_user = self.nodes.sudo_user
        sudo_pwd = self.nodes.sudo_pwd
        check_result = {
            ip_address: self.check_drm_server_login(
                ip_address, ssh_user, ssh_pwd, sudo_user, sudo_pwd)
            for ip_address in ips}
        # NOTE(review): an expiry value of None (user missing from the reply)
        # makes the "<" comparisons below raise TypeError -- confirm whether
        # that can happen.
        login_failed_dic = {
            ip_address: result[0]
            for ip_address, result in check_result.items() if not result[0]}
        drm_expiring_dic = {
            ip_address: False
            for ip_address, result in check_result.items()
            if result[1] < EXPIRING_DAY}
        root_expiring_dic = {
            ip_address: False
            for ip_address, result in check_result.items()
            if result[2] < EXPIRING_DAY}
        logger.debug(f"Check servers login return {login_failed_dic}.")
        logger.debug(
            f"Check servers ssh user expire return {drm_expiring_dic}.")
        logger.debug(
            f"Check servers sudo user expire return {root_expiring_dic}.")
        # Some nodes could not be logged in to.
        if login_failed_dic:
            message = self.initialize_error_message(login_failed_dic, key_map)
            raise HCCIException("663511", message)
        # Keep the two users' findings in separate variables so that neither
        # overwrites nor is mislabelled as the other.
        drm_message, root_message = None, None
        # The DRManager (ssh) user's password is about to expire.
        if drm_expiring_dic:
            drm_message = \
                self.initialize_error_message(drm_expiring_dic, key_map)
        # The root (sudo) user's password is about to expire.
        if root_expiring_dic:
            root_message = \
                self.initialize_error_message(root_expiring_dic, key_map)
        message_parts = []
        if drm_message:
            message_parts.append(f"DRManager: {drm_message}")
        if root_message:
            message_parts.append(f"root: {root_message}")
        if message_parts:
            raise HCCIException("663513", ", ".join(message_parts))
        logger.info("Check drm servers login return success.")

    def check_dr_host_ip_validity_and_can_login(self):
        """Validate format, reachability and login of the DR server nodes.

        On the DR site the primary hosts are checked and failures are
        labelled with the primary/second IP parameter names; otherwise all
        hosts are checked without labels.

        :raises Exception: when no host address could be obtained
        :raises HCCIException: propagated from the individual checks
        """
        on_dr_site = self.condition.is_dr_site
        if on_dr_site:
            host_ips = self.nodes.primary_hosts
        else:
            host_ips = self.nodes.all_hosts
        if not host_ips:
            logger.error("Get all server node ips failed, please check.")
            raise Exception("Get all server node ips failed, please check.")
        key_map = None
        if on_dr_site:
            key_map = {
                "eReplication_Primary_IP": self.nodes.primary_ip,
                "eReplication_Second_IP": self.nodes.second_ip
            }
        self.check_ips_validity(host_ips, key_map=key_map)
        self.check_ips_connection(host_ips, is_exist=True, key_map=key_map)
        self._check_drm_servers_login(host_ips, key_map=key_map)
        logger.info("Login dr host ips{} success.".format(host_ips))

    def check_replication_ip_port_validity_and_service_ok(self):
        """Validate the service IP and probe the DR service behind it.

        :raises HCCIException: "663501" for a malformed IP, "663507" when
            the IP is unreachable (either by the connection probe or as
            reported by the service check), "663505" for a password error,
            "663506" when the service is reported invalid
        """
        nodes = self.nodes
        # The IP must be well-formed.
        if not check_param_ip(nodes.service_ip):
            raise HCCIException("663501", "eReplication_ip")
        # The IP must be reachable.
        reachable = self.check_ip_connection(nodes.service_ip, is_exist=True)
        if not reachable:
            raise HCCIException("663507", nodes.service_ip)
        # raise_ex=False: the outcome is inspected through the returned code
        # below.
        service_client = RequestApi(
            nodes.service_ip, nodes.service_name, nodes.service_pwd,
            nodes.service_port, raise_ex=False)
        code, msg = service_client.check_dr_service()
        if code == SERVICE_PWD_ERROR:
            raise HCCIException("663505")
        if code == SERVICE_INVALID:
            raise HCCIException(
                '663506', nodes.service_ip, nodes.service_port, msg)
        if code == SERVICE_IP_NOT_CONNECTED:
            raise HCCIException('663507', nodes.service_ip)
        logger.info("Service check OK.")

    def check_csha_quorum_ip_is_validity_and_can_login(self):
        """Validate the CSHA quorum server IP and verify that login works.

        After a successful login the package definition in inspect.json is
        rewritten through ``_rewrite_sys_config_for_quorum_install``.

        :raises HCCIException: "663501" for a malformed IP, "663507" when
            the IP is unreachable, "663607" when the login fails
        """
        # The IP must be well-formed.
        ip_format_ok = check_param_ip(self.nodes.arb_ip)
        if not ip_format_ok:
            raise HCCIException("663501", ArbInfo.ARB_IP)
        # The IP must be reachable.
        ip_reachable = self.check_ip_connection(
            self.nodes.arb_ip, is_exist=True)
        if not ip_reachable:
            raise HCCIException("663507", self.nodes.arb_ip)
        # The arb sudo credentials are passed as the plain ssh credentials.
        login_ok = self.check_server_login(
            self.nodes.arb_ip, self.nodes.arb_sudo_user,
            self.nodes.arb_sudo_pwd)
        if login_ok:
            self._rewrite_sys_config_for_quorum_install()
            logger.info(f"Login {self.nodes.arb_ip} success.")
            return
        logger.error(f"Failed to login {self.nodes.arb_ip}.")
        raise HCCIException(
            '663607', 'csha_quorum_ip',
            'csha_quorum_vm_sudo_password')

    def _rewrite_sys_config_for_quorum_install(self):
        """Switch the quorum-server package pattern in inspect.json to the
        architecture reported for the quorum VM.

        The file is searched for the line containing the quorum install
        scene marker; the line three below it is taken as the package
        definition and its architecture-specific regex fragments are
        swapped so that they refer to the VM's architecture. When the marker
        is absent the file is left untouched (and an error is logged)
        instead of having its first line overwritten with an empty string.

        :raises IndexError: when the marker is found fewer than three lines
            before the end of the file
        """
        arch = ARB_API(self.project_id, self.pod_id).get_vm_arch()
        file_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(
                os.path.dirname(__file__)))), "scene", "deploy",
            "params", "inspect.json")
        # Any architecture other than X86_64 is treated as the ARM case.
        replaced_to = "ARM_64" if arch == "X86_64" else "X86_64"
        with FILE_API.open_file_manager(file_path) as file:
            inspect_content = file.readlines()
        scene_marker = ("CSHA&(TenantStorIPSAN|TenantStorFCSAN)&"
                        "install_san_quorumserver")
        marker_index = next(
            (index for index, line in enumerate(inspect_content)
             if scene_marker in line), None)
        if marker_index is None:
            # Nothing to rewrite; writing back here would corrupt the file.
            logger.error(
                f"Quorum install scene not found in {file_path}, "
                f"skip rewriting.")
            return
        # The package definition sits three lines below the marker line.
        replace_index = marker_index + 3
        pkg_define = inspect_content[replace_index]
        pkg_define = pkg_define.replace(
            f"(?<!_{arch})", f"(?<!_{replaced_to})")
        pkg_define = pkg_define.replace(f"(_{replaced_to})?", f"(_{arch})?")
        inspect_content[replace_index] = pkg_define
        with FILE_API.open_file_manager(file_path, "w") as f_w:
            f_w.writelines(inspect_content)
        logger.info(f"Rewrite sys config {file_path} success.")

    def check_cloud_quorum_ip_validity_and_can_login(self):
        """Verify that the quorum admin user can log in to every quorum
        server listed under ``'ips'`` in the arb info parameters.

        :raises HCCIException: propagated from ``_check_servers_login``
        """
        params = Params(self.project_id, self.pod_id)
        arb_ips = params.arb_info['ips']
        self._check_servers_login(
            arb_ips,
            ssh_user=params.arb_admin_user,
            ssh_pwd=params.arb_admin_user_pwd)
        logger.info(f"Login quorum server {arb_ips} success.")

    def preset_env_config(self):
        """Copy the supported DR types of the registered cloud service into
        conf/env.ini.

        The first entry returned by ``get_cloud_service_info`` for the
        project region is used; its ``extendInfos`` items with the keys
        ``CSDR_supported_block`` / ``CSDR_supported_nas`` are stored under
        ``CSDR_supported_type`` as ``block`` / ``nas``.
        """
        params = Params(self.project_id, self.pod_id)
        # Four levels above this file, then conf/env.ini.
        base_dir = os.path.dirname(__file__)
        for _ in range(3):
            base_dir = os.path.dirname(base_dir)
        config_path = os.path.join(base_dir, 'conf', 'env.ini')
        env = FILE_API(config_path)
        mo_client = MO_API(self.project_id, self.pod_id)
        service_info = mo_client.get_cloud_service_info(
            params.project_region_id)[0]
        # (extendInfos key, env.ini sub key) pairs, checked in this order.
        sub_key_by_info_key = (
            ("CSDR_supported_block", "block"),
            ("CSDR_supported_nas", "nas"),
        )
        for extend_info in service_info.get("extendInfos"):
            for info_key, sub_key in sub_key_by_info_key:
                if extend_info.get("key") == info_key:
                    env.set_value_by_key_and_sub_key(
                        "CSDR_supported_type", sub_key,
                        extend_info.get("value"))
