# -*- coding:utf-8 -*-
import utils.common.log as logger
from plugins.CSDR_CSHA_VHA.common.AutoRetry import retry
from plugins.CSDR_CSHA_VHA.common.CommonDefine import PathValue
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_all_server_nodes
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_ha_all_server_nodes
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_install_components
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_server_node_region_id
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_server_params
from plugins.CSDR_CSHA_VHA.common.CommonUtil import get_service_type
from plugins.CSDR_CSHA_VHA.common.ServerProcessor import ServerProcessor
from plugins.CSDR_CSHA_VHA.common.ServerProcessor import UpgradeProcess
from plugins.CSDR_CSHA_VHA.common.Validater import judge_upgrade_step_result
from utils.business.manageone_cmdb_util import ManageOneCmdbUtil
from utils.business.param_util import ParamUtil
from utils.common.exception import FCUException
from utils.common.fic_base import StepBaseInterface
from utils.common.message import Message

# Runs at import time. NOTE(review): presumably selects/initialises the
# "eReplication" log channel used by every logger call in this module --
# confirm in utils.common.log.
logger.init("eReplication")


class UpgradeBackService(StepBaseInterface):
    """Upgrade step that upgrades the eReplication service on all server nodes.

    Two topologies are handled by ``execute``:

    * every server node belongs to the current region -- all nodes are
      upgraded together with the current region id;
    * nodes are split between the current (primary) region and another
      (standby) region -- each side is upgraded with its own region id and
      the one-primary/three-standby relationship is configured again.
    """

    # Only CMDB "SDR" deploy nodes with one of these names are searched for
    # the SDR control float IP.
    _SDR_CTRL_NODE_NAMES = frozenset(
        ("PUB-SRV-03", "PUB-SRV-04", "PUB-SRV-DC1-02", "PUB-SRV-DC2-02"))

    def __init__(self, project_id, pod_id, regionid_list):
        """Load the server connection parameters of the first region.

        :param project_id: project identifier.
        :param pod_id: pod identifier.
        :param regionid_list: region ids; only the first entry is used as
            ``self.region``.
        """
        super(UpgradeBackService, self).__init__(
            project_id, pod_id, regionid_list)
        self.project_id = project_id
        self.pod_id = pod_id
        self.regionid_list = regionid_list
        self.region = regionid_list[0]
        self.service_type = get_service_type(self.project_id)

        server_params = get_server_params(
            self.project_id, self.region, self.service_type)
        self.server_ip1 = server_params["server_ip1"]
        self.server_business_user = server_params["server_business_user"]
        self.server_business_user_pwd = server_params[
            "server_business_user_pwd"]
        self.server_root_pwd = server_params["server_root_pwd"]
        self.mo_cmdb_ins = ManageOneCmdbUtil(self.project_id)

    def pre_check(self, project_id, pod_id, regionid_list):
        """Plugin-internal hook for resource pre-checks.

        Nothing is checked for this step.

        :return: always Message(200).
        """

        return Message(200)

    # NOTE: inside the class body, `retry` here still resolves to the
    # module-level AutoRetry decorator because the `retry` *method* below is
    # defined later; keep that method after execute or this breaks.
    # NOTE(review): the meaning of the positional arguments is defined in
    # AutoRetry -- confirm there. execute turns every exception into a
    # Message, so retries are presumably driven by `validate`.
    @retry(3, 20, 20, (FCUException, Exception),
           validate=judge_upgrade_step_result)
    def execute(self, project_id, pod_id, regionid_list):
        """Standard interface: upgrade eReplication on every server node.

        :param project_id: project identifier, forwarded to
            UpgradeProcess.upgrade_service.
        :param pod_id: not used; the pod id captured in __init__ is used.
        :param regionid_list: not used.
        :return: Message(200) on success; Message(500, FCUException) on any
            failure (errors are logged, never raised).
        """

        try:
            logger.info("Upgrade eReplication service process.")
            if not self.server_ip1:
                logger.error("Get server ip failed.")
                raise FCUException("665008")
            # All server IPs of the one-primary/three-standby topology.
            all_ips = get_all_server_nodes(
                self.server_ip1, self.server_business_user,
                self.server_business_user_pwd, self.server_root_pwd)
            if len(all_ips) not in (2, 4) or self.server_ip1 not in all_ips:
                logger.error(
                    f"Query eReplication ip from {self.server_ip1} failed.")
                raise FCUException("665006")
            # Server IPs of the current region only.
            current_region_ips = get_ha_all_server_nodes(
                self.server_ip1, self.server_business_user,
                self.server_business_user_pwd, self.server_root_pwd)
            if not 1 <= len(current_region_ips) <= 2 \
                    or self.server_ip1 not in current_region_ips:
                logger.error(
                    f"Query eReplication ip from {self.server_ip1} failed.")
                raise FCUException("665006")

            process = UpgradeProcess(
                self.pod_id, self.project_id, self.region, self.service_type)
            if set(all_ips) == set(current_region_ips):
                # Every node is in the current region: upgrade them together.
                process.upgrade_service(all_ips, project_id, self.region)
            else:
                all_ips = self._upgrade_primary_and_standby(
                    process, project_id, all_ips, current_region_ips)
            self.init_sys_config(all_ips)
            process.import_mo_log_cert(all_ips, self.project_id, self.pod_id)
            process.modify_ha_config_owner_group(all_ips)
            return Message(200)
        except FCUException as fe:
            logger.error(
                f"Upgrade eReplication failed, err_msg={str(fe)}.")
            return Message(500, fe)
        except Exception as e:
            logger.error(
                f"Upgrade eReplication failed, err_msg={str(e)}.")
            return Message(500, FCUException('665002', str(e)))

    def _upgrade_primary_and_standby(self, process, project_id, all_ips,
                                     current_region_ips):
        """Upgrade a two-region deployment and re-link primary and standby.

        The current region's nodes are the primary side; every other node in
        *all_ips* is the standby side. Neither input list is mutated.

        :param process: the UpgradeProcess used for the upgrade.
        :param project_id: forwarded to UpgradeProcess.upgrade_service.
        :param all_ips: IPs of all server nodes.
        :param current_region_ips: IPs of the current region's nodes.
        :return: combined node list, primary IPs first then standby IPs.
        :raises FCUException: 665006 if a current-region IP is not part of
            *all_ips* (the two node queries disagree).
        """
        primary_ips = list(current_region_ips)
        missing_ips = [ip for ip in primary_ips if ip not in all_ips]
        if missing_ips:
            logger.error(
                f"Query eReplication ip from {self.server_ip1} failed, "
                f"{missing_ips} not found in {all_ips}.")
            raise FCUException("665006")
        primary_set = set(primary_ips)
        standby_ips = [ip for ip in all_ips if ip not in primary_set]

        primary_region_id = self._query_region_id(primary_ips, "primary")
        standby_region_id = self._query_region_id(standby_ips, "standby")
        process.upgrade_service(primary_ips, project_id, primary_region_id)
        process.upgrade_service(standby_ips, project_id, standby_region_id)

        # Configure the one-primary/three-standby relationship again.
        combined_ips = primary_ips + standby_ips
        region_ids = f"{primary_region_id},{standby_region_id}"
        process.config_one_a_three_s(
            combined_ips, project_id=self.project_id, region_ids=region_ids)
        logger.info("Upgrade eReplication successfully in all nodes.")
        return combined_ips

    def _query_region_id(self, ips, role):
        """Return the region id reported by the first node that has one.

        :param ips: node IPs, queried in order.
        :param role: label used only in the log message ("primary"/"standby").
        :return: the region id, or "" if no node reported one.
        """
        for ip in ips:
            region_id = get_server_node_region_id(
                ip, self.server_business_user,
                self.server_business_user_pwd, self.server_root_pwd)
            if region_id:
                logger.info(f"Get {role} region result: {region_id}.")
                return region_id
        return ""

    def rollback(self, project_id, pod_id, regionid_list):
        """Standard interface: roll back this step.

        Nothing is rolled back.

        :return: always Message(200).
        """

        return Message(200)

    def retry(self, project_id, pod_id, regionid_list):
        """Standard interface: retry this step by running execute again.

        :return: the Message returned by execute.
        """

        return self.execute(project_id, pod_id, regionid_list)

    def check(self, project_id, pod_id, regionid_list):
        """Plugin-internal hook for a post-step check.

        Nothing is checked for this step.

        :return: always Message(200).
        """

        return Message(200)

    def init_sys_config(self, all_ips):
        """Rewrite the IAM/SDR/domain settings in each node's config file.

        ServerProcessor.check_before_rewrite_config_file, run against
        server_ip1, decides which sections need rewriting; if none do,
        nothing is changed.

        :param all_ips: IPs of the nodes whose config files are rewritten.
        :return: True if config files were rewritten, False if skipped.
        """
        checker = ServerProcessor(
            self.server_ip1, self.server_business_user,
            self.server_business_user_pwd, self.server_root_pwd)
        write_iam, write_sdr, write_domain = \
            checker.check_before_rewrite_config_file()
        if not (write_iam or write_sdr or write_domain):
            logger.info("No need init sys config, skip.")
            return False

        mo_public_params = ParamUtil().get_service_cloud_param(
            self.project_id, "ManageOne_public_params", self.region) or {}
        global_domain_name = mo_public_params.get(
            "ManageOne_global_domain_name")
        if not global_domain_name:
            # Fall back to the "ManageOne" cloud parameter.
            global_domain_name = ParamUtil().get_value_from_cloud_param(
                self.project_id, "ManageOne",
                "ManageOne_global_domain_name", self.region)
        all_account = self.get_iam_accounts()
        sdr_infos = self.get_sdr_info()
        for ip in all_ips:
            server_ins = ServerProcessor(
                ip, self.server_business_user,
                self.server_business_user_pwd, self.server_root_pwd)
            server_ins.rewrite_config_file(
                iam=write_iam, sdr=write_sdr, domain=write_domain,
                all_account=all_account, all_sdr_info=sdr_infos,
                global_domain_name=global_domain_name)
        logger.info("Init sys config success.")
        return True

    @staticmethod
    def _enabled_service_names(installed_services, name_by_type):
        """Map enabled installed components to names via *name_by_type*.

        :param installed_services: component dicts with a "key" (service
            type) and a "value" (compared case-insensitively with "true").
        :param name_by_type: mapping from upper-cased service type to name.
        :return: names of the enabled, known components in input order.
        """
        names = []
        for service in installed_services or []:
            # A missing key/value counts as "not enabled" instead of raising
            # AttributeError on None.
            service_type = (service.get("key") or "").upper()
            enabled = (service.get("value") or "").lower() == "true"
            if enabled and service_type in name_by_type:
                names.append(name_by_type[service_type])
        return names

    def get_iam_accounts(self):
        """Return the IAM service accounts of the enabled DR services.

        :return: comma-separated subset of csdr_service/csha_service/
            vha_service, in the order the components are listed ("" if
            none is enabled).
        """
        installed_services = get_install_components(
            self.project_id, self.pod_id)
        account_by_type = {
            PathValue.CSDR_SERVICE_TYPE.upper(): "csdr_service",
            PathValue.CSHA_SERVICE_TYPE.upper(): "csha_service",
            PathValue.VHA_SERVICE_TYPE.upper(): "vha_service",
        }
        return ",".join(
            self._enabled_service_names(installed_services, account_by_type))

    def get_sdr_info(self):
        """Describe the DR services of every region running eReplication.

        A region is included when CMDB reports an "eReplication" cloud
        service for it, it has an SDR control float IP, and at least one DR
        service is enabled. Each entry has the form
        ``<region_id>@<sdr_ctrl_float_ip>#<svc>[&<svc>...]`` with ``svc``
        one of csdr/csha/vha.

        :return: the entries joined with "," ("" if there are none).
        """
        tag_by_type = {
            PathValue.CSDR_SERVICE_TYPE.upper(): "csdr",
            PathValue.CSHA_SERVICE_TYPE.upper(): "csha",
            PathValue.VHA_SERVICE_TYPE.upper(): "vha",
        }
        regions = self.mo_cmdb_ins.get_region_info() or []
        sdr_infos = []
        for region in regions:
            region_id = region.get("regionCode")
            if not region_id or not self.mo_cmdb_ins.get_cloud_service_info(
                    region_id, "eReplication"):
                continue
            # Without a float IP the entry is skipped anyway, so check it
            # before querying the installed components.
            sdr_ctrl_float_ip = self.get_sdr_ctrl_float_ip_from_cmdb(
                region_id)
            if not sdr_ctrl_float_ip:
                continue
            installed_services = get_install_components(
                self.project_id, self.pod_id, region_id=region_id)
            service_tags = self._enabled_service_names(
                installed_services, tag_by_type)
            if not service_tags:
                continue
            sdr_infos.append(
                f"{region_id}@{sdr_ctrl_float_ip}#{'&'.join(service_tags)}")
        return ",".join(sdr_infos)

    def get_sdr_ctrl_float_ip_from_cmdb(self, region_id):
        """Return the SDR control float IP of *region_id* from CMDB.

        Searches the region's "SDR" deploy nodes named in
        _SDR_CTRL_NODE_NAMES and returns the first float IP whose
        cloudServiceIndex is "SDR".

        :param region_id: region to look up.
        :return: the float IP, or None if none is found.
        """
        nodes = self.mo_cmdb_ins.get_deploy_node_info(region_id, "SDR") or []
        for node in nodes:
            if node.get("name") not in self._SDR_CTRL_NODE_NAMES:
                continue
            for float_ip_info in node.get("floatIpAddresses") or []:
                if float_ip_info.get("cloudServiceIndex") != "SDR":
                    continue
                sdr_ctrl_float_ip = float_ip_info.get("floatIp")
                logger.info(
                    f"Get sdr ctrl float ip for {region_id} "
                    f"return {sdr_ctrl_float_ip}")
                return sdr_ctrl_float_ip
        return None
