# -*- coding:utf-8 -*-
import os
import re

import time
import utils.common.log as logger
from utils.common.exception import HCCIException
from utils.common.ssh_util import Ssh

from plugins.eBackup.common.util import Utils
from plugins.eBackup.common.model import SshInfo, DmkTaskInfo
from plugins.eBackup.scripts.upgrade.dmk_task import DMKTask
from plugins.eBackup.scripts.upgrade.util import EbackupConfig

# Shell command that asks the HA client tool to perform a switchover of the
# "product" resource (run on the node that should give up its current role).
CHANGE_HA_CMD = (
    'export LD_LIBRARY_PATH=/opt/huawei-data-protection/ebackup/libs/ && '
    '/opt/huawei-data-protection/ebackup/ha/module/hacom/tools/'
    'ha_client_tool --switchover --name=product'
)
# Shell command whose output is parsed line by line by the callers:
#   line 1: the parenthesised part of the "HaLocalName" line printed by
#           `config_omm_ha.sh query` (the live HA role, e.g. "(active)");
#   line 2: the content of /tmp/upgrade/ha_current_role.
QUERY_HA_ROLE = (
    'sh /opt/huawei-data-protection/ebackup/bin/config_omm_ha.sh query|'
    r'grep "HaLocalName"|grep -E -o "\(.*\)";'
    'cat /tmp/upgrade/ha_current_role'
)


class EbackupUpgrader(object):
    def __init__(self, host_ips, db_param_dict, ebackup_type=None):
        """Remember the target nodes and parameters and create the DMK task.

        :param host_ips: IPs of the eBackup nodes this upgrader operates on.
        :param db_param_dict: parameter mapping; the constructor itself reads
            'eBackup_Version', 'dmk_floatIp', 'eBackup_dmk_user' and
            'eBackup_dmk_password' from it.
        :param ebackup_type: forwarded to EbackupConfig when the DMK
            configuration is generated.
        """
        self.__db_param_dict = db_param_dict
        self.__host_ips = host_ips
        self.__ebackup_type = ebackup_type
        self.__version = db_param_dict['eBackup_Version']

        dmk_float_ip = db_param_dict['dmk_floatIp']
        dmk_user = db_param_dict['eBackup_dmk_user']
        dmk_password = db_param_dict['eBackup_dmk_password']
        self.dmk_task = DMKTask(dmk_float_ip, dmk_user, dmk_password)

    def check_tmp_upgrade(self):
        """Tell whether any node is missing the /tmp/upgrade directory.

        Nodes are probed in order and probing stops at the first node without
        the directory. Exceptions raised by the probe propagate to the caller.

        :return: True if at least one node has no /tmp/upgrade, else False.
        """
        return any(not self._is_exist_tmp_upgrade(node_ip)
                   for node_ip in self.__host_ips)

    def precheck(self):
        """Run the DMK "precheck" action unless every node is already upgraded.

        :return: True when all nodes already run the target version,
            otherwise the result of the DMK precheck task.
        """
        if not self._is_all_upgraded():
            return self._deploy_with_dmk("precheck")
        logger.info("All node is upgraded.")
        return True

    def backup(self):
        """Run the DMK "backup" action unless every node is already upgraded.

        If any node lacks /tmp/upgrade, precheck() is executed first; its
        return value is not inspected here.

        :return: True when all nodes already run the target version,
            otherwise the result of the DMK backup task.
        """
        if self.check_tmp_upgrade():
            self.precheck()
        if not self._is_all_upgraded():
            return self._deploy_with_dmk("backup")
        logger.info("All node is upgraded.")
        return True

    def upgrade(self):
        """Restore the HA relationship if needed, then run the DMK "upgrade".

        Nothing is done when every node already runs the target version.

        :return: True when all nodes are already upgraded, otherwise the
            result of the DMK upgrade task.
        """
        if not self._is_all_upgraded():
            # The HA roles must match the recorded ones before upgrading.
            self.adjust_ha_relation()
            return self._deploy_with_dmk("upgrade")
        logger.info("All nodes are upgraded.")
        return True

    def rollback(self):
        """Run the DMK "rollback" action unless no node runs the new version.

        :return: True when no node reports the target version (nothing to
            roll back), otherwise the result of the DMK rollback task.
        """
        if not self._is_all_rollback():
            return self._deploy_with_dmk("rollback")
        logger.info("All nodes are old version,no need to rollback.")
        return True

    def rollback_node(self, rollback_ip):
        """Upload rollback_force.sh to one node and launch it in background.

        Logs in as hcp, switches to root, deletes any previous copy of the
        script, uploads the local shell_tool/rollback_force.sh to /home/hcp/
        and starts it with nohup, passing 'eBackup_Version' as its argument.
        The script is only launched here; this method does not wait for it.

        :param rollback_ip: IP of the node to roll back.
        :return: True if the script was launched; False if the upload failed
            or any step raised an exception (the exception is logged).
        """
        session = None
        try:
            logger.info("Begin to excute rollback force cmd on node %s" % rollback_ip)
            ssh = Ssh()
            launch_cmd = ('cd /home/hcp/;dos2unix rollback_force.sh;'
                          'chmod +x rollback_force.sh;nohup sh '
                          'rollback_force.sh %s > nohup.out &'
                          % self.__db_param_dict['eBackup_Version'])
            local_script = os.path.realpath(__file__ + '/../../../shell_tool/rollback_force.sh')

            session = ssh.ssh_create_client(rollback_ip, 'hcp', self.__db_param_dict['eBackup_hcp_pwd'])
            # Become root inside the interactive session.
            ssh.ssh_send_command(session, 'su - root', 'Password:', 100)
            ssh.ssh_send_command(session, self.__db_param_dict['eBackup_root_pwd'], '#', 100)
            # Drop a stale copy before uploading the fresh script.
            ssh.ssh_send_command(session, 'rm -rf /home/hcp/rollback_force.sh', '#', 100)

            uploaded = ssh.put_file(rollback_ip, "hcp", self.__db_param_dict["eBackup_hcp_pwd"],
                                    local_script, "/home/hcp/")
            if not uploaded:
                logger.error("Upload %s to host(%s) failed." % (local_script, rollback_ip))
                return False
            logger.info("Upload %s to host %s successfully." % (local_script, rollback_ip))

            ssh.ssh_send_command(session, launch_cmd, 'nohup.out', 100)
            logger.info("Excute rollback cmd on %s successfully." % rollback_ip)
            return True
        except Exception as e:
            logger.error("rollback force failed,The reason is:" + str(e))
            return False
        finally:
            if session:
                Utils.close_ssh_clinet(session)

    @staticmethod
    def get_rollback_result(ssh, ssh_client, ebk_ip):
        check_rollback_status_cmd = 'ps -ef |grep "rollback_force" | ' \
                                    'grep -v grep >/dev/null 2>&1 || ' \
                                    'echo "successfully"'
        count = 1
        flag = False
        while count < 50:  # 50 * 15 = 750s
            ret = ssh.ssh_exec_command_return_list(
                ssh_client, check_rollback_status_cmd)
            if -1 != str(ret).find("successfully"):
                flag = True
                break
            logger.info(
                "rollback is still going at the %sst time on %s." %
                (count, ebk_ip))
            count += 1
            time.sleep(15)
        if not flag:
            logger.error("maybe rollback failed at %s." % ebk_ip)
            return False
        check_rollback_cmd = "cat /home/hcp/nohup.out | grep " \
                             "'Rollback and start force " \
                             "successfully.' && echo " \
                             "'successfully'"
        ret = ssh.ssh_exec_command_return_list(ssh_client, check_rollback_cmd)
        if -1 == str(ret).find("successfully"):
            logger.error("rollback failed, please check the %s node "
                         "/home/hcp/nohup.out." % ebk_ip)
            return False
        return True

    def start_service_after_rollback(self, node_ip):
        """Confirm the rollback finished on a node, then start the hcp service.

        Logs in as hcp, switches to root, waits for the rollback script via
        get_rollback_result(), issues "service hcp start force" and polls
        "service hcp status" up to 50 times, 15 seconds apart.

        :param node_ip: IP of the node to start.
        :return: True once the status check succeeds; False if the rollback
            failed, the service did not come up in time, or any step raised
            an exception (the exception is logged).
        """
        launch_cmd = 'nohup service hcp start force &'
        probe_cmd = 'service hcp status >/dev/null 2>&1 && echo "successfully"'
        session = None
        try:
            logger.info(f"Begin to check service status at {node_ip}.")
            ssh = Ssh()
            session = ssh.ssh_create_client(node_ip, "hcp", self.__db_param_dict['eBackup_hcp_pwd'])
            ssh.ssh_send_command(session, 'su - root', 'Password:', 100)
            ssh.ssh_send_command(session, self.__db_param_dict['eBackup_root_pwd'], '#', 100)

            if not self.get_rollback_result(ssh, session, node_ip):
                logger.error(f"rollback failed on node {node_ip}.")
                return False

            ssh.ssh_send_command(session, launch_cmd, '#', 100)
            attempt = 1
            while attempt <= 50:  # 50 * 15 = 750s
                output = ssh.ssh_exec_command_return_list(session, probe_cmd)
                if "successfully" in str(output):
                    logger.info(f"Service status is normal on {node_ip}.")
                    return True
                logger.info(f"Service is still starting at the {attempt}st time on {node_ip}.")
                time.sleep(15)
                attempt += 1

            logger.error("Start Service failed at %s." % node_ip)
            return False
        except Exception as e:
            logger.error("Check service status failed.The reason is:%s", str(e))
            return False
        finally:
            if session:
                Utils.close_ssh_clinet(session)

    def is_need_rollback(self):
        """Decide whether a forced rollback is required.

        Each node is asked for its "System Version" and for the status of the
        hcp service. A rollback is needed as soon as two different versions
        have been seen across the nodes, or a node's status check does not
        succeed. Checking stops at the first node that triggers either case.

        :return: True if a forced rollback is needed, otherwise False.
        :raises Exception: any error raised while checking a node is logged
            and re-raised.
        """
        logger.info("Begin to check whether need to rollback force.")
        check_version_cmd = '''grep "System Version" /opt/huawei-data-''' \
                            '''protection/ebackup/conf/versions.conf | ''' \
                            '''awk -F "=" '{print $2}' '''
        check_status_cmd = 'service hcp status >/dev/null 2>&1 && echo "successfully"'
        ssh = Ssh()
        all_version = []

        def node_requires_rollback(rollback_ip):
            """Return True if this node shows that a rollback is needed."""
            # The session handle is local to this call, so a connection
            # failure never leaves a previous node's (already closed) client
            # around to be closed a second time.
            ssh_client = None
            try:
                ssh_client = ssh.ssh_create_client(rollback_ip, "hcp", self.__db_param_dict['eBackup_hcp_pwd'])
                ssh.ssh_send_command(ssh_client, 'su - root', 'Password:', 100)
                ssh.ssh_send_command(ssh_client, self.__db_param_dict['eBackup_root_pwd'], '#', 100)
                ret = ssh.ssh_exec_command_return_list(ssh_client, check_version_cmd)
                version = ret[0].strip().replace("\n", "")
                logger.info("The version is %s of node[%s]." % (version, rollback_ip))
                if version not in all_version:
                    all_version.append(version)
                if len(all_version) >= 2:
                    logger.info("There are different versions %s in cluster %s,need to rollback." %
                                (str(all_version), str(self.__host_ips)))
                    return True

                ret = ssh.ssh_exec_command_return_list(ssh_client, check_status_cmd)
                if "successfully" not in str(ret):
                    logger.info("The service status is abnormal on node[%s], need to rollback force." % rollback_ip)
                    return True
                logger.info("The service is running on %s" % rollback_ip)
                return False
            finally:
                if ssh_client:
                    Utils.close_ssh_clinet(ssh_client)

        need_rollback = False
        for host_ip in self.__host_ips:
            try:
                if node_requires_rollback(host_ip):
                    need_rollback = True
                    break
            except Exception as err:
                logger.error(f"Check service status and version failed.The reason is:{err}.")
                raise

        logger.info("Check whether need to rollback force, the result: " + str(need_rollback))
        return need_rollback

    def stop_all_nodes(self, stop_nodes_sequence):
        """Stop the hcp service on every given node, in the given order.

        On each node this logs in as hcp, switches to root, removes
        /home/hcp/rollback_force.sh and runs "service hcp stop force".
        Stopping is best-effort: an error on one node is logged at debug
        level and the remaining nodes are still processed.

        :param stop_nodes_sequence: node IPs in the order they must be stopped.
        :return: always True.
        """
        stop_service_cmd = 'service hcp stop force >/dev/null 2>&1'
        remove_script_cmd = 'rm -f /home/hcp/rollback_force.sh'
        ssh = Ssh()
        logger.info("Begin to stop all node.")

        for node_ip in stop_nodes_sequence:
            # Track the session per node and assign it as soon as it exists:
            # a failure after connecting must still close THIS session, and
            # must not close the previous node's session a second time.
            ssh_client = None
            try:
                ssh_client = ssh.ssh_create_client(node_ip, "hcp", self.__db_param_dict['eBackup_hcp_pwd'])
                ssh.ssh_send_command(ssh_client, 'su - root', 'Password:', 100)
                ssh.ssh_send_command(ssh_client, self.__db_param_dict['eBackup_root_pwd'], '#', 100)
                ssh.ssh_send_command(ssh_client, remove_script_cmd, '#', 180)
                ssh.ssh_send_command(ssh_client, stop_service_cmd, '#', 180)
            except Exception as err:
                logger.debug(f"Exception occurs when stop service:{err}.")
            finally:
                if ssh_client:
                    Utils.close_ssh_clinet(ssh_client)
        logger.info("All node stoped.")
        return True

    def rollback_force(self):
        """Force a rollback of all nodes, primary node first.

        Steps: decide whether a forced rollback is needed; find the primary
        node; stop the service on all nodes (primary last); roll back and
        start the primary; then roll back and start the remaining nodes.

        The node list held by this object is never modified: the stop order
        and the list of non-primary nodes are built on copies.

        :return: True if no rollback was needed or every step succeeded,
            False as soon as one step fails.
        """
        logger.info("Begin to rollback:" + str(self.__host_ips))
        # step1:check whether to rollback force
        if not self.is_need_rollback():
            logger.info("All node[%s] have been rollbacked and "
                        "status is normal,No need to rollback force." %
                        str(self.__host_ips))
            return True

        # step2:stop service, primary node last
        primary_ip = self.find_primary_node()
        if not primary_ip:
            logger.error("Find primary node failed.")
            return False
        # Work on a copy: aliasing self.__host_ips here used to reorder the
        # shared list and later drop the primary node from it for good.
        stop_service_sequence = list(self.__host_ips)
        stop_service_sequence.remove(primary_ip)
        stop_service_sequence.append(primary_ip)
        self.stop_all_nodes(stop_service_sequence)

        # step3:rollback and start primary node
        if not self.rollback_node(primary_ip):
            logger.error("Excute rollback operation "
                         "failed on " + str(primary_ip))
            return False
        logger.info("Begin to sleep 200s to wait the primary starting.")
        time.sleep(200)
        if not self.start_service_after_rollback(primary_ip):
            logger.error("Start service failed on " + str(primary_ip))
            return False

        # step4:rollback and start standby and proxy node
        # Everything in the stop order except its last entry (the primary).
        standby_and_proxy = stop_service_sequence[:-1]
        for node_ip in standby_and_proxy:
            if not self.rollback_node(node_ip):
                logger.error("Excute rollback operation failed on " + str(node_ip))
                return False
        logger.info("Begin to sleep 30s to wait the standby and "
                    "proxy starting.")
        time.sleep(30)
        for node_ip in standby_and_proxy:
            if not self.start_service_after_rollback(node_ip):
                logger.error("Start service failed on " + str(node_ip))
                return False
        logger.info("Rollback %s force successfully" % str(self.__host_ips))
        return True

    def find_primary_node(self):
        """Return the IP of the node whose saved HA role is "0".

        Each node is asked for the content of
        /var/ebackup_bak/ebackup/ha_current_role; the first node whose first
        output line is "0" is reported as the primary node.

        :return: the primary node IP, or '' if no node matched or if
            querying any node raised an exception (the error is logged and
            the search stops).
        """
        logger.info("Begin to find the primary node.")
        ssh = Ssh()
        find_primary_node = 'cat /var/ebackup_bak/ebackup/ha_current_role'
        for node_ip in self.__host_ips:
            # Per-node handle: a failed connection must not leave the
            # previous node's closed client behind to be closed again.
            ssh_client = None
            try:
                ssh_client = ssh.ssh_create_client(node_ip, 'hcp', self.__db_param_dict['eBackup_hcp_pwd'])
                ssh.ssh_send_command(ssh_client, 'su - root', 'Password:', 100)
                ssh.ssh_send_command(ssh_client, self.__db_param_dict['eBackup_root_pwd'], '#', 100)
                result = ssh.ssh_exec_command_return_list(ssh_client, find_primary_node)
                ha_role = result[0].strip().replace('\n', '')
            except Exception as err:
                logger.error(f"Find primary node failed.The reason is:{err}.")
                return ''
            finally:
                if ssh_client:
                    Utils.close_ssh_clinet(ssh_client)
            if "0" == ha_role:
                logger.info("The primary ip is:" + node_ip)
                logger.info("Find primary node successfully,the primary ip is:" + node_ip)
                return node_ip

        # No node reported role "0".
        logger.error("Find primary node failed.")
        return ''

    def _is_exist_tmp_upgrade(self, server_ip):
        """Check over SSH whether /tmp/upgrade exists on a node.

        :param server_ip: IP of the node to probe.
        :return: True if the directory exists, otherwise False.
        :raises Exception: errors from the SSH layer propagate; the session
            is closed in every case.
        """
        probe_cmd = '''[ -d /tmp/upgrade ] && echo "tmpdir exist"'''
        session = None
        try:
            hcp_pwd = self.__db_param_dict['eBackup_hcp_pwd']
            root_pwd = self.__db_param_dict['eBackup_root_pwd']
            ssh = Ssh()
            session = ssh.ssh_create_client(server_ip, "hcp", hcp_pwd)
            # Become root before running the probe.
            ssh.ssh_send_command(session, 'su - root', 'Password:', 100)
            ssh.ssh_send_command(session, root_pwd, '#', 100)
            output = ssh.ssh_exec_command_return_list(session, probe_cmd)
            if "tmpdir exist" not in str(output):
                return False
            logger.info("eBackup node " + server_ip + " tmpdir exist.")
            return True
        finally:
            if session:
                Utils.close_ssh_clinet(session)

    def _is_current_version(self, server_ip):
        """Check whether a node already runs the target eBackup version.

        Reads "System Version" from versions.conf on the node and compares
        the first output line with 'eBackup_Version' from the parameters.

        :param server_ip: IP of the node to query.
        :return: True if the node's version equals the target version.
        :raises Exception: errors from the SSH layer or from parsing the
            output propagate; the session is closed in every case.
        """
        session = None
        try:
            hcp_pwd = self.__db_param_dict['eBackup_hcp_pwd']
            root_pwd = self.__db_param_dict['eBackup_root_pwd']
            version_check_cmds = '''grep "System Version" /opt/huawei-data''' \
                                 '''-protection/ebackup/conf/versions.conf ''' \
                                 '''| awk -F "=" '{print $2}' '''
            ssh = Ssh()
            session = ssh.ssh_create_client(server_ip, "hcp", hcp_pwd)
            ssh.ssh_send_command(session, 'su - root', 'Password:', 100)
            ssh.ssh_send_command(session, root_pwd, '#', 100)
            output = ssh.ssh_exec_command_return_list(session, version_check_cmds)

            logger.info("eBackup node " + server_ip + " version is " + output[0])
            node_version = output[0].strip().replace('\n', '')
        finally:
            if session:
                Utils.close_ssh_clinet(session)
        return node_version == self.__db_param_dict['eBackup_Version']

    def _is_all_upgraded(self):
        """Return True only if every node runs the target version.

        Nodes are checked in order and checking stops at the first node that
        is not on the target version. Exceptions from the check propagate.
        """
        return all(self._is_current_version(node_ip)
                   for node_ip in self.__host_ips)

    def _is_all_rollback(self):
        """Return True only if no node runs the target version.

        Nodes are checked in order and checking stops at the first node that
        is on the target version. Exceptions from the check propagate.
        """
        return not any(self._is_current_version(node_ip)
                       for node_ip in self.__host_ips)

    def _deploy_with_dmk(self, action):
        """Run one eBackup DMK action and return the task result.

        :param action: one of "precheck", "backup", "upgrade", "rollback";
            any other value yields an empty DMK action name.
        :return: whatever DMKTask.do_task() returns; a falsy result is
            logged as an error.
        """
        # DMK action names for each supported step.
        dmk_action_names = {
            "precheck": "[upgrade]1.Precheck eBackup",
            "backup": "[backup] Backup eBackup Data",
            "upgrade": "[upgrade]2.Upgrade eBackup",
            "rollback": "[rollback]1.Rollback eBackup",
        }

        ebackup_config, ebackup_host = self.get_all_params()
        if action == "precheck":
            # Precheck does not use the generated config.
            ebackup_config = "no_need_backup: 0"
        ebackup_action = dmk_action_names.get(action, "")

        task_info = DmkTaskInfo(self.__version, ebackup_action, [ebackup_host, ebackup_config])
        result = self.dmk_task.do_task(task_info, "hcp")
        if not result:
            logger.error("Do DMK task %s failed." % ebackup_action)
        return result

    def get_all_params(self):
        """Build the DMK parameters for the managed nodes.

        :return: the value of EbackupConfig.get_config(); _deploy_with_dmk
            unpacks it as a (config, host) pair.
        """
        return EbackupConfig(
            self.__host_ips, self.__db_param_dict, self.__ebackup_type
        ).get_config()

    def judge_ha_changed(self, hcp_pwd, root_pwd):
        """Check whether the HA roles differ from those saved at precheck.

        For each node, QUERY_HA_ROLE returns the live HA role (text between
        parentheses on the first line) and the saved role from
        /tmp/upgrade/ha_current_role (second line). A live "active" node is
        expected to have saved role "0" and a live "standby" node "1".

        :param hcp_pwd: password of the hcp account.
        :param root_pwd: password of the root account.
        :return: (False, "", "") as soon as a node's live role matches its
            saved role; otherwise (is_ha_changed, primary_ip, standby_ip),
            where is_ha_changed is True if a mismatch was seen and None if
            no node reported "active" or "standby".
        :raises Exception: errors from the SSH layer or from parsing the
            output propagate; the session of the node is still closed.
        """
        logger.info("Begin to check if ha has been changed after prechecking.")
        primary_ip = ""
        standby_ip = ""
        is_ha_changed = None
        # Compiled once; extracts the text between the first parentheses.
        rule = re.compile(r'[(](.*?)[)]', re.S)
        for node_ip in self.__host_ips:
            ha_ssh_info = SshInfo(node_ip, 'hcp', hcp_pwd, root_pwd)
            ssh_client = Utils.get_ssh_client(ha_ssh_info)
            try:
                result = Ssh.ssh_exec_command_return_list(ssh_client, QUERY_HA_ROLE)
                ha_role = re.findall(rule, result[0])[0]
                ha_current_role = result[1].strip().replace('\n', '')
            finally:
                # Close even if the command or the parsing above failed.
                Utils.close_ssh_clinet(ssh_client)

            if ha_role == 'active':
                primary_ip = node_ip
                if ha_current_role == '0':
                    logger.info("Ha has not been changed after prechecking.")
                    return False, "", ""
                logger.info("Ha has been changed after prechecking.")
                is_ha_changed = True
            elif ha_role == 'standby':
                standby_ip = node_ip
                if ha_current_role == '1':
                    logger.info("Ha has not been changed after prechecking.")
                    return False, "", ""
                logger.info("Ha has been changed after prechecking.")
                is_ha_changed = True
        return is_ha_changed, primary_ip, standby_ip

    @staticmethod
    def switch_ha(primary_ip, standby_ip, hcp_pwd, root_pwd):
        """Trigger an HA switchover and wait until the peer becomes active.

        Runs CHANGE_HA_CMD on primary_ip, sleeps 2 minutes, then queries
        standby_ip up to 12 times (10 seconds apart) until its live HA role
        is "active".

        :param primary_ip: node on which the switchover command is run.
        :param standby_ip: node expected to become active afterwards.
        :param hcp_pwd: password of the hcp account.
        :param root_pwd: password of the root account.
        :return: True if standby_ip reported "active" in time, else False.
        """
        primary_ssh_info = SshInfo(primary_ip, 'hcp', hcp_pwd, root_pwd)
        standby_ssh_info = SshInfo(standby_ip, 'hcp', hcp_pwd, root_pwd)
        # Compiled once; extracts the text between the first parentheses.
        rule = re.compile(r'[(](.*?)[)]', re.S)

        ssh_client = Utils.get_ssh_client(primary_ssh_info)
        try:
            Ssh.ssh_exec_command_return_list(ssh_client, CHANGE_HA_CMD)
        finally:
            # Do not leak the session if the command raises.
            Utils.close_ssh_clinet(ssh_client)
        logger.info("Excute the command to change ha successfully."
                    "Now sleep 2min to wait for it completed.")
        time.sleep(120)

        for _ in range(0, 12):
            ssh_client = Utils.get_ssh_client(standby_ssh_info)
            try:
                result = Ssh.ssh_exec_command_return_list(ssh_client,
                                                          QUERY_HA_ROLE)
            finally:
                Utils.close_ssh_clinet(ssh_client)
            # While the switch is in progress the query may print nothing
            # usable; treat that as "not active yet" and keep polling
            # instead of failing with an IndexError.
            roles = rule.findall(result[0]) if result else []
            if roles and roles[0] == 'active':
                logger.info("Ha has been changed completed.")
                return True
            time.sleep(10)
        logger.error("Changed the HA timeout.")
        return False

    def adjust_ha_relation(self):
        """Switch HA back if the roles changed since the precheck.

        :return: True if the roles are unchanged or were switched back
            successfully; '' if a change was detected but the primary or
            standby IP could not be determined.
        :raises HCCIException: code 650039 when the switchover fails.
        :raises Exception: any other error is logged and re-raised.
        """
        try:
            hcp_pwd = self.__db_param_dict['eBackup_hcp_pwd']
            root_pwd = self.__db_param_dict['eBackup_root_pwd']
            ha_changed, primary_ip, standby_ip = self.judge_ha_changed(
                hcp_pwd, root_pwd)
            if not ha_changed:
                return True

            logger.info("Ha has beend changed,Now we need "
                        "to restore the ha relationship.")
            # Both ends must be known before a switchover can be issued.
            if primary_ip == '' or standby_ip == '':
                return ''
            switched = self.switch_ha(primary_ip, standby_ip, hcp_pwd, root_pwd)
            if not switched:
                raise HCCIException(650039, primary_ip + ',' + standby_ip)
            return True
        except HCCIException:
            logger.error("Restore Ha failed.")
            raise
        except Exception as err:
            logger.error(f"Adjust Ha failed, the reason is:{err}.")
            raise
