# -*- coding: UTF-8 -*-
import re
import traceback

import cliUtil
import common
from common import UnCheckException
from common_utils import get_err_msg
from common_utils import check_conn_and_mode
# Module-level context shared by the check.
# NOTE(review): ``py_java_env`` and ``PY_LOGGER`` are neither defined nor
# imported in this file; presumably the host framework injects them into the
# module namespace before the script runs -- confirm.
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
PY_JAVA_ENV = py_java_env

# Section headers (formatted with a device serial number) inserted into the
# collected CLI output so the echo of each device can be told apart.
LOCAL_RET = "ON LOCAL DEVICE(SN:{})"
REMOTE_RET = "\n\nON REMOTE DEVICE(SN:{})"


def execute(cli):
    """
    Entry point: check the HyperMetro synchronized LUN SN switch state.

    :param cli: CLI connection to the local device.
    :return: tuple (check flag, newline-joined CLI output collected during
        the check, error/suggestion message).
    """
    checker = SyncLunSnSwitchCheck(cli, LANG, PY_JAVA_ENV, LOGGER)
    result_flag, result_msg = checker.execute_check()
    # The echo is assembled only after the check, so it includes every
    # command output gathered by it.
    cli_echo = "\n".join(checker.all_ret_list)
    return result_flag, cli_echo, result_msg


class SyncLunSnSwitchCheck:
    """
    Check that the HyperMetro "synchronized LUN SN" switch is consistent
    between the local device and every remote device referenced by one of
    its HyperMetro domains.

    Raw CLI output gathered along the way is collected in ``all_ret_list``.
    Switch problems are collected in ``err_msg_list``; remote devices that
    could not be checked are collected in ``no_check_msg_list``.
    """

    # Matches the switch line of "show hyper_metro general". The spelling
    # "Syncronized" is kept exactly as in the original pattern because it has
    # to match the device output.
    # NOTE(review): confirm the CLI really prints "Syncronized".
    # Compiled once with IGNORECASE; the original passed re.IGNORECASE as the
    # ``pos`` argument of Pattern.findall, which skipped the first 2 chars and
    # never enabled case-insensitive matching.
    _SWITCH_PATTERN = re.compile(
        r"Syncronized LUN SN Switch\s*:\s*(on|off)\b", re.IGNORECASE
    )

    def __init__(self, cli, lang, env, logger):
        """
        :param cli: CLI connection to the local device.
        :param lang: language code used to look up messages.
        :param env: framework environment; read via ``get("devInfo")`` and
            ``get("selectDevs")``.
        :param logger: logger exposing ``logInfo`` / ``logError``.
        """
        self.cli = cli
        self.lang = lang
        self.env = env
        self.logger = logger
        self.all_ret_list = []
        self.err_msg_list = []
        self.no_check_msg_list = []
        self.p_version = str(env.get("devInfo").getProductVersion())

    def execute_check(self):
        """
        Run the switch consistency check.

        :return: tuple (flag, message). ``flag`` is True when the check
            passes (or there is no HyperMetro domain), otherwise one of the
            ``cliUtil.RESULT_*`` values.
        """
        try:
            local_dev_sn = self.env.get("devInfo").getDeviceSerialNumber()
            self.all_ret_list.append(LOCAL_RET.format(local_dev_sn))
            domain_info = self.get_domain_info(local_dev_sn)
            # No HyperMetro domain configured: nothing to compare.
            if not domain_info:
                return True, ""
            remote_dev_dict = self.get_remote_dev_info(local_dev_sn)
            self.change_to_developer()
            local_switch = self.get_hyper_switch(local_dev_sn)
            added_sns = self.get_added_remote_dev_sn()
            # Several domains may reference the same remote device; check
            # each remote device only once.
            checked_remote_dev = set()
            for remote_dev_id in domain_info.values():
                if remote_dev_id in checked_remote_dev:
                    continue
                checked_remote_dev.add(remote_dev_id)
                remote_dev_sn = remote_dev_dict.get(remote_dev_id)
                # The remote device was not added to this inspection, so its
                # switch cannot be read; report it as not checked.
                if remote_dev_sn not in added_sns:
                    self.no_check_msg_list.append(
                        common.getMsg(
                            self.lang,
                            "not.add.remote.device.again",
                            remote_dev_sn,
                        )
                    )
                    continue
                self.all_ret_list.append(REMOTE_RET.format(remote_dev_sn))
                remote_switch = self.get_hyper_switch(remote_dev_sn)
                self.logger.logInfo(
                    "local:{},remote:{},local_sn:{},remote_sn:{}".format(
                        local_switch, remote_switch, local_dev_sn,
                        remote_dev_sn,
                    )
                )
                self.check_switch(
                    local_switch, remote_switch, local_dev_sn, remote_dev_sn
                )
            total_error_msg = "\n".join(
                self.err_msg_list + self.no_check_msg_list
            )
            if self.err_msg_list:
                return cliUtil.RESULT_WARNING, total_error_msg + common.getMsg(
                    self.lang, "hyper.metro.sync.lock.diff.sugg"
                )
            if self.no_check_msg_list:
                # Unchecked remote devices escalate to a warning when
                # common.is_opening_delivery_inspect(env) is true.
                if common.is_opening_delivery_inspect(self.env):
                    return cliUtil.RESULT_WARNING, total_error_msg
                return cliUtil.RESULT_NOCHECK, total_error_msg
            return True, ""
        except UnCheckException as e:
            return cliUtil.RESULT_NOCHECK, e.errorMsg
        except Exception:
            self.logger.logError(str(traceback.format_exc()))
            return (
                cliUtil.RESULT_NOCHECK,
                common.getMsg(self.lang, "query.result.abnormal"),
            )
        finally:
            check_conn_and_mode(self.cli, self.lang, self.logger)

    def check_switch(self, local_switch, remote_switch, sn_local, sn_remote):
        """
        Compare the local and remote switch values.

        A value is "on", "off", or "" (switch not reported -- see
        ``get_hyper_switch``). Any problem found is appended to
        ``err_msg_list``.

        :param local_switch: switch value of the local device.
        :param remote_switch: switch value of the remote device.
        :param sn_local: local device serial number.
        :param sn_remote: remote device serial number.
        :return: True if the pair is acceptable, False otherwise.
        """
        # Both switches off: hint that the HyperMetro synchronized LUN SN
        # switch is off on both ends and ask the user to confirm whether it
        # should be enabled.
        if local_switch == remote_switch == "off":
            self.err_msg_list.append(
                get_err_msg(
                    self.lang,
                    "software.hyper.metro.switch.not.support",
                    (sn_local, sn_remote),
                )
            )
            return False
        if local_switch == remote_switch:
            return True
        # One end does not report the switch and the other has it off:
        # treated as consistent.
        if (not local_switch and remote_switch == "off") or (
            local_switch == "off" and not remote_switch
        ):
            return True
        if not local_switch and remote_switch == "on":
            self.err_msg_list.append(
                get_err_msg(
                    self.lang,
                    "software.hyper.metro.switch.local.not.support",
                    (sn_remote, sn_local),
                )
            )
            return False
        if local_switch == "on" and not remote_switch:
            self.err_msg_list.append(
                get_err_msg(
                    self.lang,
                    "software.hyper.metro.switch.remote.not.support",
                    (sn_local, sn_remote),
                )
            )
            return False
        if local_switch != remote_switch:
            self.err_msg_list.append(
                get_err_msg(
                    self.lang,
                    "software.hyper.metro.switch.not.pass",
                    (sn_local, local_switch, sn_remote, remote_switch),
                )
            )
            return False
        return True

    def change_to_developer(self):
        """
        Enter developer mode on the CLI and record its echo.

        :raises UnCheckException: if developer mode cannot be entered.
        """
        flag, cli_ret, err_msg = cliUtil.enterDeveloperMode(
            self.cli, self.lang
        )
        self.all_ret_list.append(cli_ret)
        if flag is not True:
            raise UnCheckException(err_msg, cli_ret)

    def get_hyper_switch(self, sn):
        """
        Query the synchronized LUN SN switch of device ``sn``.

        :param sn: serial number of the device to query.
        :return: "on" or "off"; "" when cliUtil.hasCliExecPrivilege rejects
            the output or the switch line is absent.
        :raises UnCheckException: when the query fails for another reason.
        """
        cmd = "show hyper_metro general"
        flag, ret, err_msg = common.getObjFromFile(
            self.env, self.logger, sn, cmd, self.lang
        )
        self.all_ret_list.append(ret)
        if flag is not True:
            # Presumably the command is unavailable on this device; report
            # the switch as absent instead of failing the whole check.
            if not cliUtil.hasCliExecPrivilege(ret):
                return ""
            raise UnCheckException(err_msg, ret)
        return self._parse_switch(ret)

    @staticmethod
    def _parse_switch(cli_ret):
        """
        Extract the switch value from "show hyper_metro general" output.

        :param cli_ret: raw CLI output (may be empty or None).
        :return: "on"/"off" normalised to lower case, or "" if not found.
        """
        if not cli_ret:
            return ""
        match = SyncLunSnSwitchCheck._SWITCH_PATTERN.search(cli_ret)
        if match is None:
            return ""
        # IGNORECASE may capture "ON"/"Off"; callers compare to lower case.
        return match.group(1).lower()

    def _query_records(self, sn, cmd):
        """
        Run ``cmd`` for device ``sn`` and parse its horizontal table output.

        :return: list of row dicts from cliUtil.getHorizontalCliRet.
        :raises UnCheckException: if the query fails.
        """
        flag, ret, msg = common.getObjFromFile(
            self.env, self.logger, sn, cmd, self.lang
        )
        self.all_ret_list.append(ret)
        if flag is not True:
            raise UnCheckException(msg, ret)
        return cliUtil.getHorizontalCliRet(ret)

    def get_domain_info(self, dev_sn):
        """
        Get the HyperMetro domains of device ``dev_sn``.

        :param dev_sn: device serial number.
        :return: dict mapping domain ID -> remote device ID.
        :raises UnCheckException: if the query fails.
        """
        records = self._query_records(
            dev_sn, "show hyper_metro_domain general"
        )
        return {
            record.get("ID", ""): record.get("Remote Device ID", "")
            for record in records
        }

    def get_remote_dev_info(self, sn):
        """
        Get the remote devices known to device ``sn``.

        :param sn: device serial number.
        :return: dict mapping remote device ID -> remote device SN.
        :raises UnCheckException: if the query fails.
        """
        records = self._query_records(sn, "show remote_device general")
        return {
            record.get("ID", ""): record.get("SN", "") for record in records
        }

    def get_added_remote_dev_sn(self):
        """
        Get the remote SNs that were also selected for this inspection.

        :return: list of SNs from ``devInfo.getRemoteSNs()`` that belong to a
            device in ``selectDevs``, in the original order; [] if there are
            no remote SNs or no selected devices.
        """
        sn_list = self.env.get("devInfo").getRemoteSNs()
        # An empty list means there is no remote HyperMetro device.
        if not sn_list:
            return []
        select_devs = self.env.get("selectDevs") or []
        select_sns = set(node.getDeviceSerialNumber() for node in select_devs)
        return [dev_sn for dev_sn in sn_list if dev_sn in select_sns]
