# encoding=utf-8
import json
import os
import sys
import time

from datetime import datetime
from commonlog import Logger
from taskmgr_util import Taskmgrutil
from uniep_taskmgr import Unieptaskmgr
from upgrade_custom import UpgradeCustom


LOGGER = Logger().getinstance(sys.argv[0])
# Log level names used by RollBackCheck.record_log
INFO = "INFO"
ERROR = "ERROR"
FINISH= "FINISH"
# REST endpoint that checks whether a backup package exists
BACKUP_CHECK_URL = "/rest/plat/brmgr/v1/main/baseinfo/backupdata/exist"
# File recording the backup timestamps ("key: timestamp" per line)
TIMESTAMP_FILE = "/opt/upgrade/backup/backup_timestamp.properties"
# Error-code resource file, resolved relative to this script's parent directory
local_path = os.path.dirname(os.path.realpath(__file__))
error_code_file = os.path.join(os.path.dirname(local_path), "config/operate_check_error_code.json")


def check_backup_product_data(product_name, time_stamp, timeout=7200, poll_interval=10):
    """
    Check whether the product-data backup package for ``time_stamp`` exists.

    Submits an upload request for the backup package and polls the returned
    task URL until the task leaves the running state or the deadline passes.

    :param product_name: product whose backup package is checked
    :param time_stamp: backup timestamp identifying the package
    :param timeout: overall deadline in seconds (default 7200 s = 2 hours,
        the time budget allowed for the package upload)
    :param poll_interval: seconds to sleep between task-status polls
    :return: True if the task ends in a state that is neither running nor
        failed; False if a request fails, the response carries no task URL or
        an empty task list, the task fails, or the deadline passes
    """
    deadline = time.monotonic() + timeout
    from upload_backup_package import build_upload_param, UPLOAD_BACKUP_PACKAGE_URL, \
        FAILURE_STATE, RUNNING_STATE
    uniep_taskmgr = Unieptaskmgr()
    request_params = build_upload_param(product_name, time_stamp)
    result, response = uniep_taskmgr.send_post_request(UPLOAD_BACKUP_PACKAGE_URL, request_params)
    LOGGER.info(f"{UPLOAD_BACKUP_PACKAGE_URL},result:{result},response{response}")
    # A falsy result means the upload request itself failed.
    if not result:
        return False
    task_url = json.loads(response).get("url", "")
    if not task_url:
        return False
    # Poll until the deadline; give up (False) once it has passed.
    while time.monotonic() <= deadline:
        result, response = uniep_taskmgr.send_get_request(task_url)
        LOGGER.info(f"{task_url},result:{result},response{response}")
        if not result:
            return False
        # Parse the task list once per poll; an empty list is treated as failure.
        task_states = json.loads(response)
        if not task_states:
            return False
        current_state = task_states[0].get("currentState")
        if current_state in FAILURE_STATE:
            return False
        # Neither running nor failed means the task succeeded.
        if current_state not in RUNNING_STATE:
            return True
        time.sleep(poll_interval)
    return False


class RollBackCheck:
    """
    Pre-rollback checks.

    Verifies that the backup packages recorded for this upgrade are still on
    the backup server, and that the product reports it can be rolled back.
    Progress messages, status and error codes are written to the task
    directory through Taskmgrutil.
    """
    def __init__(self, params):
        """
        Initialise the checker from the command-line parameter dict.

        :param params: dict with the keys "scriptid", "productname",
            "src_version", "des_version", "is_backup_db",
            "backup_product_with_diff" and "action"; missing keys default to "".
        """
        self.script_id = params.get("scriptid", "")
        self.product_name = params.get("productname", "")
        self.src_version = params.get("src_version", "")
        self.des_version = params.get("des_version", "")
        self.is_backup_db = params.get("is_backup_db", "")
        self.backup_product_with_diff = params.get("backup_product_with_diff", "")
        self.action = params.get("action", "")

        self.msg_list = []
        self.task_mgr = Taskmgrutil()
        self.uniep_task_mgr = Unieptaskmgr()
        self.task_path = os.path.join("/opt/upgrade/easysuite_upgrade/taskmgr", self.script_id)
        self.task_mgr.init_e_taskmgr(self.task_path)

        self.plandata_path = "/opt/upgrade/easysuite_upgrade/workpath/%s-%s/workpath-%s/plandata" % \
                             (self.src_version, self.des_version, self.product_name)

    def record_log(self, msg, level=INFO):
        """
        Append a timestamped message to the task log.

        ERROR additionally writes the error-code file. ERROR and FINISH also
        set the task status (the lower-cased level) and the progress to 100.
        :param msg: message text
        :param level: INFO, ERROR or FINISH
        :return: always True
        """
        now_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        str_msg = "[%s] [%s] %s | %s" % (now_time, os.getpid(), level, msg)
        self.msg_list.append(str_msg)
        # The whole accumulated log is rewritten each time.
        self.task_mgr.set_e_taskmsg(self.task_path, "\n".join(self.msg_list))
        if level == ERROR:
            self.generate_error_code()
        if level in (ERROR, FINISH):
            self.task_mgr.set_e_taskstatus(self.task_path, level.lower())
            self.task_mgr.set_e_taskprogress(self.task_path, "100")
        return True

    def generate_error_code(self):
        """
        Write the error-code entry for ``self.action`` into the task directory.

        The entry is looked up in error_code_file. Nothing is written when no
        entry exists for the action.
        """
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(error_code_file, "r", encoding="utf-8") as f:
            error_code_details = json.load(f)
        error_code_detail = error_code_details.get(self.action, {})
        if error_code_detail:
            Taskmgrutil.set_e_task_errorcode(self.task_path, error_code_detail)

    def check_backup_packages(self):
        """
        Check that every recorded backup package is on the backup server.

        If the check interface is unavailable, that mode is skipped (logged at
        INFO). For product-data modes whose check does not return "OK", the
        package is verified via check_backup_product_data instead.
        :return: False (and an ERROR log) as soon as one package is missing,
            otherwise True
        """
        self.record_log("start to check whether backup packages is on backup server.")

        # Timestamp-record key -> backup mode
        timestamp_key_dict = self.get_timestamp_key()

        # Backup mode -> (latest) backup timestamp
        back_mode_dict = RollBackCheck.get_backup_model_dict(timestamp_key_dict)

        # Incremental backup: align the full-backup timestamp with the one the
        # diff backup was based on; a mismatched pair would make rollback fail.
        self.update_diff_mode_timestamp(back_mode_dict)

        # Check the backup package for each mode/timestamp on the backup server.
        for mode, timestamp in back_mode_dict.items():
            check_params = self.build_check_params(mode, timestamp)
            status, response = self.uniep_task_mgr.send_post_request(BACKUP_CHECK_URL, check_params)
            LOGGER.info(f"{BACKUP_CHECK_URL}, status: {status}; response: {response}.")
            if not status:
                self.record_log("The interface is not available, skip check.")
                continue
            # json.loads accepts str and bytes alike, so no explicit decode is needed.
            check_result = json.loads(response)
            result = check_result.get("result", "")
            if result == "OK":
                self.record_log(f"mode: {mode}, timestamp: {timestamp} check success.")
            elif mode in ["upgrade", "upgrade_full", "upgrade_diff"] and \
                    check_backup_product_data(self.product_name, timestamp):
                self.record_log(f"mode: {mode}, timestamp: {timestamp} check success.")
            else:
                self.record_log(f"mode: {mode}, timestamp: {timestamp} is lost.", ERROR)
                return False
        return True

    def update_diff_mode_timestamp(self, back_mode_dict):
        """
        Align the full-backup timestamp with the incremental (diff) backup.

        The "full_diff" record holds "<full timestamp>-..."; the part before the
        first "-" is the full backup that the diff was taken against. When a
        diff backup exists, "upgrade_full" is set to that timestamp, so rollback
        uses a matching full/diff pair.

        "full_diff" is only bookkeeping and not a backup mode the check
        interface understands, so it is always removed. Otherwise it would be
        checked as a mode and reported as a lost backup.
        :param back_mode_dict: backup mode -> timestamp, modified in place
        :return: always True
        """
        full_diff_timestamp = back_mode_dict.pop("full_diff", None)
        if full_diff_timestamp is None:
            return True
        LOGGER.info(f"full_diff timestamp: {full_diff_timestamp}")
        if "upgrade_diff" in back_mode_dict:
            back_mode_dict["upgrade_full"] = full_diff_timestamp.split("-")[0]
        return True

    @staticmethod
    def get_backup_model_dict(timestamp_key_dict):
        """
        Map each backup mode to its timestamp from TIMESTAMP_FILE.

        Each line is "key: timestamp". When a key occurs several times, the
        last (most recent) occurrence wins. Lines without ":" (e.g. blank
        lines) are skipped.
        :param timestamp_key_dict: timestamp-record key -> backup mode
        :return: backup mode -> timestamp
        """
        back_mode_dict = {}
        with open(TIMESTAMP_FILE, "r", encoding="utf-8") as timestamp_file_content:
            for line in timestamp_file_content:
                fields = line.split(":")
                # Blank or malformed line: nothing to parse (previously IndexError).
                if len(fields) < 2:
                    continue
                timestamp_key = fields[0].strip()
                timestamp_value = fields[1].strip()
                if timestamp_key in timestamp_key_dict:
                    back_mode_dict[timestamp_key_dict[timestamp_key]] = timestamp_value
        return back_mode_dict

    def get_timestamp_key(self):
        """
        Build the mapping from timestamp-record key to backup mode.

        Modes upgrade_full/upgrade_diff distinguish the incremental-backup
        records; build_check_params sends them as "upgrade". "full_diff" is
        bookkeeping for the diff's base full backup.
        :return: dict of timestamp-record key -> backup mode
        """
        timestamp_key_dict = {
            f"backup_single_uniep_{self.product_name}_{self.src_version}-{self.des_version}": "management"
        }
        if self.is_backup_db == "true":
            timestamp_key_dict.update({
                f"backup_db_{self.product_name}_{self.src_version}-{self.des_version}": "db"
            })
        # Separate record keys for incremental (diff) product backups
        if self.backup_product_with_diff == "false":
            timestamp_key_dict.update({
                f"backup_product_{self.product_name}_{self.src_version}-{self.des_version}": "upgrade"
            })
        else:
            timestamp_key_dict.update({
                f"backup_product_full_{self.product_name}_{self.src_version}-{self.des_version}": "upgrade_full",
                f"backup_product_diff_{self.product_name}_{self.src_version}-{self.des_version}": "upgrade_diff",
                f"backup_product_full_diff_{self.product_name}_{self.src_version}-{self.des_version}": "full_diff"
            })
        return timestamp_key_dict

    def build_check_params(self, mode, timestamp):
        """
        Build the request body for the backup-existence interface.

        mode: {
            db: database application
            management: management plane (sent with productName "manager")
            upgrade: product data (upgrade_full/upgrade_diff are sent as "upgrade")
        }
        :return: dict with productName, mode and timeStamp
        """
        product_name = "manager" if mode == "management" else self.product_name
        if mode in ["upgrade_full", "upgrade_diff"]:
            mode = "upgrade"
        check_params = dict(
            productName=product_name,
            mode=mode,
            timeStamp=timestamp
        )
        return check_params

    def check_rollback_nce(self):
        """
        Check whether the product may be rolled back.

        It checks (1) the upgrade is not older than 30 days and (2) the node
        IPs were not changed, using ipmc_tool with rollback_check.json from
        plandata_path, e.g.
        {
            "productName": "NCE",
            "DeployId": "XXX"
        }
        ipmc_tool exit codes:
        0: check passed
        1: other
        2: rollback window of 30 days exceeded
        3: node IP changed
        Codes other than 2 and 3 are treated as non-blocking.
        :return: False (with an ERROR log) for codes 2 and 3, otherwise True
        """
        self.record_log(f"start to check whether {self.product_name} can rollback.")

        # No upgrade deployment record was generated -> product was not upgraded.
        rollback_check_input_json = os.path.join(self.plandata_path, "rollback_check.json")
        if not os.path.isfile(rollback_check_input_json):
            self.record_log(f"{self.product_name} is not upgraded, skip check.")
            return True

        # Run the command-line check.
        cmd = "ipmc_tool -cmd productmgr -o rollback_check -input %s" % rollback_check_input_json
        cmd_code, _ = Taskmgrutil.execute_cmd(cmd)
        if cmd_code == 0:
            self.record_log(f"{self.product_name} can rollback.")
        elif cmd_code == 2:
            self.record_log(f"{self.product_name} rollback is over 30 days.", ERROR)
            return False
        elif cmd_code == 3:
            self.record_log(f"{self.product_name} node ip is changed.", ERROR)
            return False
        else:
            self.record_log("The interface is not available, skip check.")
        return True


def main(argv):
    """
    Run the pre-rollback checks in order and stop at the first failure.

    Order of checks: backup packages on the backup server, product rollback
    eligibility, then the custom check scripts (UpgradeCustom).
    :param argv: [script, scriptid, productname, src_version, des_version,
        is_backup_db, backup_product_with_diff, os_patch]
    :return: True if every check passes; False if a check fails or an
        exception is raised (its message is then written to the task as an
        error with progress 100)
    """
    params = {
        "scriptid": argv[1],
        "productname": argv[2],
        "src_version": argv[3],
        "des_version": argv[4],
        "is_backup_db": argv[5],
        "backup_product_with_diff": argv[6],
        "os_patch": argv[7],
        "action": "rollback_check"
    }
    task_path = os.path.join("/opt/upgrade/easysuite_upgrade/taskmgr", argv[1])
    try:
        check_function = RollBackCheck(params)
        # Check the backup packages.
        if not check_function.check_backup_packages():
            return False
        # Check whether the product can be rolled back.
        if not check_function.check_rollback_nce():
            return False
        # Run the custom check scripts.
        custom_check = UpgradeCustom(params)
        if not custom_check.main_entry():
            return False
    except Exception as e_msg:
        # Top-level boundary: record the failure on the task and report it to
        # the caller instead of falling through to the success return.
        Taskmgrutil.set_task(task_path, status='error', progress="100", msg=f'Exception:{e_msg}')
        return False
    return True


if __name__ == "__main__":
    # Script entry: run the checks against the raw command line. main()'s
    # boolean result is not turned into a process exit code.
    main(argv=sys.argv)
