# -*- coding: utf-8 -*-
import time
import traceback
from tenacity import stop_after_attempt, wait_fixed
from tenacity import retry as retry_adorn

import utils.common.log as logger
from utils.common.exception import HCCIException
from plugins.DistributedStorage.common.RestClient import StorageSSHClient
from plugins.DistributedStorage.common.UpgradeOperate import UpgradeOperate
from plugins.DistributedStorage.common.UpgradeHotPatchOperate import UpgradeHotPatchOperate
from plugins.DistributedStorage.common.base import TestCase
from plugins.DistributedStorage.Upgrade.scripts.impl.TC_Post_Upgrade_Check import PostUpgradeCheck
from plugins.DistributedStorage.Upgrade.scripts.impl.TC_Upgrade_Confirm import UpgradeConfirm


class HotpatchConfirm(TestCase):
    """Commit (confirm) an upgrade, optionally including a hot-patch commit.

    Without ``hot_patch_tag`` in ``fs_args`` this simply delegates to
    ``UpgradeConfirm``.  With it, the product package and then the hot patch
    are committed on the FSM master/standby pair: for each, repo data and DB
    data are restored, a post-upgrade check is run and the upgrade is
    confirmed.  The index of the step being run is persisted through
    ``UpgradeHotPatchOperate`` so that a rerun starts again from that step
    (presumably ``get_operate_state`` returns what ``update_operate_state``
    stored - confirm against UpgradeHotPatchOperate).
    """

    # Ordered steps of the hot-patch commit flow.  The persisted commit state
    # is an index into this tuple (see update_upgrade_commit_state), so do not
    # reorder or remove entries without migrating already-stored states.
    _HOTPATCH_COMMIT_TASKS = (
        "restore_product_repo_data",
        "restore_product_db_data",
        "post_upgrade_check",
        "upgrade_confirm",
        "restore_hotpatch_repo_data",
        "restore_hotpatch_db_data",
        "post_upgrade_check",
        "upgrade_confirm",
    )

    def __init__(self, project_id, pod_id, fs_args, condition=None,
                 metadata=None, **kwargs):
        super(HotpatchConfirm, self).__init__(project_id, pod_id)
        self.condition = condition
        self.metadata = metadata
        self.more_args = kwargs
        self.opr = UpgradeOperate(fs_args)
        self.fs_args = fs_args
        # Configured FSM node infos; create_fsm_client re-aligns them so that
        # master_node is the node master_client is connected to.
        self.master_node = fs_args.get("master_node")
        self.master_client, self.slaver_client = None, None
        self.slaver_node = fs_args.get("slaver_node")
        # Working directory on the FSM nodes for the uploaded helper scripts.
        self.remote_path = "/tmp/upgrade_tmp_hotpatch"
        self.user_name = fs_args["user_name"]
        self.password = fs_args["password"]
        self.upgrade_type = fs_args["upgrade_type"]
        # Key under which the hot-patch commit progress is persisted.
        self.upgrade_commit_key = fs_args.get("upgrade_commit_key")
        self.region_id = self.fs_args.get("region_id")

    def procedure(self):
        """Log in, then run either the plain or the hot-patch commit flow.

        For a hot patch, the helper scripts are always removed from the FSM
        nodes afterwards.  If the commit itself failed, an error during that
        removal is only logged, so the caller still sees the original failure.

        Raises:
            HCCIException: re-raised unchanged when raised by a step; any
                other exception (including a failed login) is wrapped as
                ``HCCIException(620017, str(e))``.
        """
        logger.info('Start %s commit.' % self.upgrade_type)
        # Read once so setup and cleanup always agree on the flow taken.
        hot_patch = self.fs_args.get("hot_patch_tag")
        commit_ok = False
        try:
            status_code, error_code, error_des = self.opr.try_login(
                self.user_name, self.password)
            if status_code != 200 or error_code != 0:
                err_msg = "Failed to login, Detail:[status:%s,code:%s]%s" % \
                          (status_code, error_code, error_des)
                logger.error(err_msg)
                raise Exception(err_msg)
            if hot_patch:
                # Find the current master FSM node and connect to both nodes.
                self.create_fsm_client()
                # Upload the helper scripts to the master and standby FSM nodes.
                self.upload_script_to_fsm()
                self.hotpatch_upgrade_commit_procedure()
                self.clear_fsm_backup_data()
            else:
                self.upgrade_confirm()
            commit_ok = True
        except HCCIException as e:
            logger.error('failed commit upgrade:{}'.format(e))
            logger.error(traceback.format_exc())
            raise
        except Exception as e:
            logger.error('failed commit upgrade:{}'.format(e))
            logger.error(traceback.format_exc())
            raise HCCIException(620017, str(e))
        finally:
            if hot_patch:
                self._cleanup_fsm_scripts(reraise=commit_ok)

    def _cleanup_fsm_scripts(self, reraise):
        """Remove the FSM helper scripts without masking an in-flight error.

        :param reraise: True when the commit succeeded; a cleanup failure is
            then propagated.  When False (an error is already propagating),
            the cleanup failure is only logged.
        """
        try:
            self.del_script_om_fsm()
        except Exception as e:
            logger.error('failed to delete script on fsm:{}'.format(e))
            logger.error(traceback.format_exc())
            if reraise:
                raise

    def hotpatch_upgrade_commit_procedure(self):
        """Resume the hot-patch commit steps from the persisted state."""
        current_task_id = self.get_upgrade_commit_state()
        self.run_hotpatch_commit_tasks(current_task_id)

    def run_hotpatch_commit_tasks(self, start_task):
        """Run the hot-patch commit steps from index ``start_task`` onwards.

        The step index is persisted *before* each step runs.  The stored
        state therefore names the step in progress, which equals the number
        of steps already completed.  If a step fails, a later run re-executes
        that step.

        :param start_task: index into ``_HOTPATCH_COMMIT_TASKS``.  ``None``
            (assumed to mean "no state stored yet" - TODO confirm) is treated
            as 0, and a numeric string is converted to int.
        """
        start_task = int(start_task or 0)
        tasks = self._HOTPATCH_COMMIT_TASKS[start_task:]
        for index, task in enumerate(tasks, start_task):
            self.update_upgrade_commit_state(index)
            getattr(self, task)()

    def upgrade_confirm(self):
        """Confirm the upgrade by delegating to ``UpgradeConfirm``."""
        UpgradeConfirm(self.project_id, self.pod_id, self.fs_args).procedure()

    def restore_product_repo_data(self):
        """Restore repo data on the master FSM node (product package step)."""
        # NOTE(review): identical call to restore_hotpatch_repo_data; the
        # helper presumably picks the right backup itself - confirm.
        UpgradeHotPatchOperate.restore_repo_data(self.remote_path, self.master_client)

    def restore_product_db_data(self):
        """Restore DB data on the master FSM node with ``step="product"``."""
        UpgradeHotPatchOperate.restore_db_data(self.remote_path, self.master_client, step="product")

    def restore_hotpatch_repo_data(self):
        """Restore repo data on the master FSM node (hot-patch step)."""
        UpgradeHotPatchOperate.restore_repo_data(self.remote_path, self.master_client)

    def restore_hotpatch_db_data(self):
        """Restore DB data on the master FSM node using the default step."""
        UpgradeHotPatchOperate.restore_db_data(self.remote_path, self.master_client)

    def clear_fsm_backup_data(self):
        """Clear backup data under ``remote_path`` on both FSM nodes."""
        UpgradeHotPatchOperate.clear_backup_data(self.remote_path, self.master_client)
        UpgradeHotPatchOperate.clear_backup_data(self.remote_path, self.slaver_client)

    @retry_adorn(stop=stop_after_attempt(5), wait=wait_fixed(120), reraise=True)
    def post_upgrade_check(self):
        """
        Run the post-upgrade check that precedes each upgrade confirm.

        Retried up to 5 attempts, 120 s apart; the last failure is re-raised.
        """
        PostUpgradeCheck(self.project_id, self.pod_id, self.fs_args).procedure()

    def upload_script_to_fsm(self):
        """
        Upload the helper scripts to the master and standby FSM nodes.

        Each client is passed together with the node info it is connected to
        (create_fsm_client keeps the two aligned).
        """
        node_list = [(self.master_client, self.master_node),
                     (self.slaver_client, self.slaver_node)]
        for ssh_client, node in node_list:
            UpgradeHotPatchOperate.upload_script_to_fsm(ssh_client, node, self.remote_path)

    def del_script_om_fsm(self):
        """
        Delete the helper scripts under ``remote_path`` on the FSM nodes.

        Only clients that were actually created are used, so this is safe to
        call after a partial ``create_fsm_client``.
        """
        # The original `del ssh_client` only unbound the loop variable; it did
        # not close the connection (self.*_client still reference it).
        client_list = [client for client in
                       [self.master_client, self.slaver_client] if client]
        for ssh_client in client_list:
            UpgradeHotPatchOperate.del_script_om_fsm(ssh_client, self.remote_path)

    def create_fsm_client(self):
        """
        Connect to the current master FSM node and to the standby node.

        ``get_master_node_client`` picks the master among the two configured
        nodes and returns a client for it.  A second client, switched to root,
        is opened for the other node.  ``master_node``/``slaver_node`` are
        swapped when the configured standby is the actual master, so that each
        node attribute matches its client.
        """
        node_infos = [self.master_node, self.slaver_node]
        self.master_client, node = UpgradeHotPatchOperate.get_master_node_client(node_infos)
        is_configured_master = node == self.master_node
        if not is_configured_master:
            # The configured standby is currently the master: swap the roles.
            self.master_node, self.slaver_node = self.slaver_node, self.master_node
        # Node info is assumed to be a sequence whose first three fields are the
        # StorageSSHClient arguments and which exposes root_pwd - confirm.
        self.slaver_client = StorageSSHClient(*self.slaver_node[:3])
        self.slaver_client.switch_root(self.slaver_node.root_pwd)

    def update_upgrade_commit_state(self, state=0):
        """
        Persist the hot-patch commit progress.

        ``state`` is the index into ``_HOTPATCH_COMMIT_TASKS`` of the step
        about to run, which also equals the number of steps already done:
        0 start (restore product repo data next), 1 product repo data
        restored, 2 product DB data restored, 3 product post-upgrade check
        passed, 4 product upgrade confirmed, 5 hot-patch repo data restored,
        6 hot-patch DB data restored, 7 hot-patch post-upgrade check passed
        (hot-patch upgrade confirm in progress).  No state is written after
        the final confirm completes.
        """
        UpgradeHotPatchOperate.update_operate_state(self.project_id, self.region_id, self.upgrade_commit_key, state)

    def get_upgrade_commit_state(self):
        """Return the persisted hot-patch commit step index."""
        return UpgradeHotPatchOperate.get_operate_state(self.project_id, self.region_id, self.upgrade_commit_key)
