# encoding=utf-8
"""
功 能：记录日志
版权信息：华为技术有限公司，版本所有(C) 2019-2029
修改记录：2019-12-11 12:00 创建
"""
import datetime
import inspect
import json
import os
import re
import shlex
import sys
import time
from threading import Lock, Thread

from commonlog import Logger
from taskmgr_util import Taskmgrutil

logger = Logger().getinstance(sys.argv[0])

# Message levels passed to UpgradeCustom.print_msg.
INFO, WARN, ERROR = "INFO", "WARNING", "ERROR"

# Fallbacks used when an auto-retry entry omits a field.
DEFAULT_RETRY_TIMES = 0  # no retry
DEFAULT_INTERVAL = 5  # seconds slept before each retry
DEFAULT_OVERTIME = 3600  # seconds allowed per execution (also the default pre/post script timeout)


class UpgradeCustom:
    """
    Dispatch the per-domain upgrade scripts for one upgrade action.

    Phase "concurrency" runs every "<action>*.sh" script found directly in each
    domain directory in parallel; phase "serial" then runs the numbered
    "NN_*.sh" scripts of each domain's "<action>" sub-directory one by one,
    with the common domain last. Progress text is accumulated in ``self.msg``
    and pushed to the task manager; the final status is recorded with
    ``Taskmgrutil.set_task``.
    """

    def __init__(self, args):
        """
        Collect the upgrade options, reset the work directory and export the
        product information into it.

        :param args: option dict (see ``main``); missing keys default to ""
                     except ``cmd_reminder_time`` which defaults to 0.
        """
        self.productname = args.get("productname", "")
        self.src_version = args.get("src_version", "")
        self.des_version = args.get("des_version", "")
        self.middle_version = args.get("middle_version", "")
        self.is_chain_upgrade = args.get("is_chain_upgrade", "")
        self.site = args.get("site", "")
        self.action = args.get("action", "")
        self.scriptid = args.get("scriptid", "")
        # Stored in seconds (the option value is multiplied by 60, i.e. given in minutes).
        self.cmd_reminder_time = int(args.get("cmd_reminder_time", 0)) * 60
        self.is_protection_hot = args.get("is_protection_hot", "")
        self.single_mgr_domain = args.get("single_mgr_domain", "")
        self.nce_microservices_patch = args.get("nce_microservices_patch", "")
        self.os_patch = args.get("os_patch", "")
        self.scriptpath = "/opt/upgrade/easysuite_upgrade/scripts"
        self.workpath = "/opt/upgrade/easysuite_upgrade/workpath/%s-%s/workpath-%s" % (
            self.src_version, self.des_version, self.productname)
        self.taskpath = os.path.join("/opt/upgrade/easysuite_upgrade/taskmgr", self.scriptid)
        self.taskmsg = os.path.join(self.taskpath, "task.log")
        self.taskstatus = os.path.join(self.taskpath, "task.status")
        self.taskprogress = os.path.join(self.taskpath, "task.process")
        self.taskmgr_function = Taskmgrutil()
        self.taskmgr_function.init_e_taskmgr(self.taskpath)
        # Accumulated task message; every print_msg call appends one entry.
        self.msg = ""
        # print_msg is called from several worker threads at once; this lock
        # serializes the read-modify-write of self.msg and the message push.
        self._msg_lock = Lock()
        # Cleared by any worker thread whose script fails.
        self.thread_flag = True
        self._reset_workpath()
        self.get_product_info()
        Taskmgrutil.initial_error_path(self.workpath)

    def _reset_workpath(self):
        """Empty ``workpath`` except its "plandata" entry, creating it if missing."""
        if os.path.isdir(self.workpath):
            # os.listdir yields bare names: join them with workpath, otherwise
            # "rm -rf" would resolve them against the process working directory.
            stale_paths = [shlex.quote(os.path.join(self.workpath, name))
                           for name in os.listdir(self.workpath) if name != "plandata"]
            if stale_paths:
                Taskmgrutil.execute_cmd("rm -rf %s" % " ".join(stale_paths))
        Taskmgrutil.execute_cmd("mkdir -p %s" % shlex.quote(self.workpath))

    @staticmethod
    def get_function_name():
        """
        Return the name of the function that called this helper.

        Reads the caller's frame directly; ``inspect.stack()`` would build
        every frame record (with source context) just to take one name.
        """
        return inspect.currentframe().f_back.f_code.co_name

    def get_product_info(self):
        """
        Export the product information into ``workpath`` with the platform's
        queryproduct.sh, trying at most twice (1 s pause after each failure).

        :return: (True, "") on success, (False, reason) otherwise.
        """
        os.chmod(self.workpath, 0o700)
        cmd = "bash /opt/oss/manager/tools/resmgr/queryproduct.sh -pn %s -output %s" % (
            shlex.quote(self.productname), shlex.quote(self.workpath))
        retcode = 0
        for _ in range(2):
            retcode, _ = Taskmgrutil.execute_cmd(cmd)
            if retcode == 0:
                return True, ""
            time.sleep(1)
        logger.error("[%s] Result:%s" % (self.get_function_name(), retcode))
        return False, "Failed to run queryproduct.sh"

    def gen_param(self):
        """
        Write the upgrade parameters to ``<workpath>/plandata.json``
        (requested mode 0o660 when the file is created).

        :return: True
        """
        param = {
            "productname": self.productname,
            "src_version": self.src_version,
            "des_version": self.des_version,
            "middle_version": self.middle_version,
            "is_chain_upgrade": self.is_chain_upgrade,
            "is_protection_hot": self.is_protection_hot,
            "single_mgr_domain": self.single_mgr_domain,
            "site": self.site,
            "nce_microservices_patch": self.nce_microservices_patch,
            "os_patch": self.os_patch,
        }
        flags = os.O_CREAT | os.O_WRONLY | os.O_TRUNC
        with os.fdopen(os.open(os.path.join(self.workpath, "plandata.json"), flags, mode=0o660),
                       "w", encoding="utf-8") as file_obj:
            json.dump(param, file_obj, indent=4)
        return True

    def one_thread(self, script, script_info, script_retry_info, timeout):
        """
        Run one script and, if it fails with a configured error code, retry it.

        :param script: path of the shell script.
        :param script_info: label used in task messages.
        :param script_retry_info: error code -> {"retry_times", "interval",
            "overtime"} (from the domain's auto_retry.json).
        :param timeout: timeout in seconds for the first execution.
        :return: True if the first run or any retry exits with 0.
        """
        func_name = self.get_function_name()
        logger.info("[%s] Start to run script: %s" % (func_name, script))
        error_path = Taskmgrutil.error_path(self.workpath)
        error_code_file_path = os.path.join(error_path, os.path.basename(script) + ".json")
        cmd = self.format_script_cmd(script, error_path)
        ret_code, ret_msg = Taskmgrutil.execute_cmd(cmd, timeout=timeout)
        if ret_msg.strip():
            self.print_msg(INFO, ret_msg.strip(), form="script_output")
        if ret_code == 0:
            logger.info("[%s] script: %s Result:%s Msg:%s" % (func_name, script, ret_code, ret_msg))
            return True

        logger.error("[%s] script: %s Result:%s Msg:%s" % (func_name, script, ret_code, ret_msg))
        # Only failures reporting an error code configured for retry are retried.
        error_code = self.query_script_error_code(error_code_file_path)
        retry_conf = script_retry_info.get(error_code) if error_code else None
        if not retry_conf:
            return False

        retry_times = retry_conf.get("retry_times", DEFAULT_RETRY_TIMES)
        interval = retry_conf.get("interval", DEFAULT_INTERVAL)
        overtime = retry_conf.get("overtime", DEFAULT_OVERTIME)
        for attempt in range(1, retry_times + 1):
            time.sleep(interval)
            self.print_msg(INFO, f"Start to retry script{script_info} for time {attempt}, Max retry time: "
                                 f"{retry_times}, Code: {error_code}, Interval: {interval}s, Overtime: {overtime}s")
            # Remove the previous run's error code file so it cannot be mistaken for this run's.
            if os.path.isfile(error_code_file_path):
                os.remove(error_code_file_path)
            ret_code, ret_msg = Taskmgrutil.execute_cmd(cmd, timeout=overtime)
            if ret_code == 0:
                logger.info("[%s] script: %s retry %s Result:%s" % (func_name, script, attempt, ret_code))
                return True
            self.print_msg(INFO, "Script: %s Result:%s\nMsg:%s" % (script, ret_code, ret_msg))
            logger.error("[%s] script: %s Result:%s Msg:%s" % (func_name, script, ret_code, ret_msg))
        return False

    def format_script_cmd(self, script, error_path):
        """Build the shell command that runs *script* against this upgrade's workpath."""
        return (f"bash {shlex.quote(script)} -input {shlex.quote(self.workpath)} "
                f"-error_path {shlex.quote(error_path)}")

    def query_script_error_code(self, error_code_file_path):
        """
        Read the error code a failed script wrote to its error-code JSON file.

        :return: the "error_code" value as a string, or "" when the file is
                 missing, unreadable, not a JSON object, or carries no code.
        """
        if not os.path.isfile(error_code_file_path):
            return ""
        try:
            with open(error_code_file_path, "r") as f:
                error_info = json.load(f)
        except (OSError, ValueError):
            # Unusable file: treat as "no error code" so the failure is reported without retry.
            return ""
        if not isinstance(error_info, dict):
            return ""
        error_code = error_info.get("error_code")
        if error_code is None:
            return ""
        # Retry configs are JSON objects whose keys are always strings.
        return str(error_code)

    def print_msg(self, level, msg, form=""):
        """
        Append one entry to the task message and push the whole message to the task manager.

        :param level: INFO/WARN/ERROR (lower-case "info"/"warn"/"error" are upper-cased);
                      only ERROR adds a level tag to the entry.
        :param msg: text to append.
        :param form: "script_output" appends *msg* verbatim, without timestamp/pid prefix.
        :return: True
        """
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if level in ("info", "warn", "error"):
            level = level.upper()
        if form == "script_output":
            str_msg = msg
        elif level == "ERROR":
            str_msg = "[%s] [%s] [%s] %s" % (nowtime, os.getpid(), level, msg)
        else:
            str_msg = "[%s] [%s] %s" % (nowtime, os.getpid(), msg)
        with self._msg_lock:
            self.msg = str_msg if self.msg == "" else "%s\n%s" % (self.msg, str_msg)
            self.taskmgr_function.set_e_taskmsg(self.taskpath, self.msg)
        return True

    def con_exc(self, script_info, script, script_retry_info, timeout):
        """
        Thread target: run one script and clear ``thread_flag`` if it fails.

        An exception counts as a failure; left uncaught it would end the
        thread silently with ``thread_flag`` still True (reported as success).
        """
        self.print_msg(INFO, "Start to run script%s." % script_info)
        try:
            succeeded = self.one_thread(script, script_info, script_retry_info, timeout)
        except Exception as err:  # thread boundary: record, never lose the failure
            logger.error("[%s] script: %s raised: %r" % (self.get_function_name(), script, err))
            succeeded = False
        if not succeeded:
            self.print_msg(INFO, "Failed to run script%s." % script_info)
            self.thread_flag = False
        self.print_msg(INFO, "Finished to run script%s." % script_info)

    def _consume_thread_flag(self):
        """Return whether the last batch of threads succeeded and re-arm ``thread_flag``."""
        succeeded = self.thread_flag
        self.thread_flag = True
        return succeeded

    def con_domain_cmd_process(self):
        """
        Run every domain's "<action>*.sh" script concurrently and wait for all of them.

        Scripts are taken directly from each domain directory ("common" uses its
        "NCE-Common" sub-directory), whose timeout_config.json / auto_retry.json
        provide the timeout and retry settings.

        :return: True if every script succeeded.
        """
        self.gen_param()
        tasks = []
        for domain_path in os.listdir(self.scriptpath):
            domain_script_path = os.path.join(self.scriptpath, domain_path)
            if domain_path == "common":
                domain_script_path = os.path.join(domain_script_path, "NCE-Common")
            if not os.path.isdir(domain_script_path):
                continue
            pre_post_timeout = self.get_domain_timeout_config(domain_script_path)
            scripts_retry_config = self.parse_domain_retry_config(domain_script_path)
            for script_name in os.listdir(domain_script_path):
                if not (script_name.startswith(self.action) and script_name.endswith(".sh")):
                    continue
                tasks.append({
                    "script_name": script_name,
                    "domain": domain_path,
                    "script_path": os.path.join(domain_script_path, script_name),
                    "script_info": "(%s)[solution due to %s]" % (script_name, domain_path),
                    "timeout": pre_post_timeout,
                    "retry_info": scripts_retry_config.get(script_name, {}).get("error_code", {}),
                })

        child_thread_tasks_data = {}
        cmd_start_time = time.time()
        for task in tasks:
            child_thread = Thread(target=self.con_exc,
                                  args=(task["script_info"], task["script_path"],
                                        task["retry_info"], task["timeout"]),
                                  name="child-thread-%s" % task["script_info"])
            child_thread.start()
            # Keyed by full path: the same script name usually exists in several
            # domains, and a name key would drop those threads from monitoring.
            child_thread_tasks_data[task["script_path"]] = {
                "thread": child_thread, "domain": task["domain"],
                "script_name": task["script_name"], "countdown": 0}

        self.monitor_cmd_tasks(child_thread_tasks_data, cmd_start_time)
        return self._consume_thread_flag()

    @staticmethod
    def get_domain_timeout_config(domain_script_path):
        """
        Return "pre_post_timeout" from the domain's timeout_config.json, or
        DEFAULT_OVERTIME when the file or the key is absent.
        """
        timeout_config_file = os.path.join(domain_script_path, "timeout_config.json")
        if not os.path.isfile(timeout_config_file):
            return DEFAULT_OVERTIME
        with open(timeout_config_file, "r") as f:
            timeout_config = json.load(f)
        return timeout_config.get("pre_post_timeout", DEFAULT_OVERTIME)

    def parse_domain_retry_config(self, domain_script_path):
        """
        Parse the domain's auto_retry.json.

        :return: its "scripts" mapping (script name -> {"error_code": {...}}),
                 or {} when the file is absent or not shaped as expected.
        """
        domain_retry_config_file = os.path.join(domain_script_path, "auto_retry.json")
        if not os.path.isfile(domain_retry_config_file):
            return {}
        with open(domain_retry_config_file, "r") as f:
            retry_config = json.load(f)
        if not isinstance(retry_config, dict):
            return {}
        scripts_retry_config = retry_config.get("scripts", {})
        # Callers call .get() on the result, so anything but a dict is unusable.
        return scripts_retry_config if isinstance(scripts_retry_config, dict) else {}

    def invoke_cmd(self, path, scripts_retry_config, domain=""):
        """
        Run the "NN_*.sh" scripts of *path* one at a time in numeric order,
        stopping at the first failure. The timeout comes from the
        timeout_config.json of *path*'s parent directory.

        :return: False if a script failed, True otherwise (also when *path*
                 is not a directory).
        """
        if not os.path.isdir(path):
            return True
        # Only shell scripts starting with "two digits + underscore" are run.
        pattern = re.compile(r'\d{2}_\w*\.sh$')
        script_names = sorted((name for name in os.listdir(path) if pattern.match(name)),
                              key=lambda name: int(name.split('_')[0]))
        script_timeout = self.get_domain_timeout_config(os.path.abspath(os.path.join(path, os.pardir)))
        for script_name in script_names:
            script_info = "(%s)[solution due to %s]" % (script_name, domain)
            cmd_start_time = time.time()
            child_thread = Thread(target=self.con_exc,
                                  args=(script_info, os.path.join(path, script_name),
                                        scripts_retry_config.get(script_name, {}).get("error_code", {}),
                                        script_timeout),
                                  name="child-thread-0")
            child_thread.start()
            self.monitor_cmd_tasks(
                {script_name: {"thread": child_thread, "domain": domain,
                               "script_name": script_name, "countdown": 0}},
                cmd_start_time)
            if not self._consume_thread_flag():
                return False
        return True

    def serial_cmd(self):
        """
        Run each domain's numbered "<action>" scripts serially, domain by
        domain, and the common domain (common/NCE-Common/<action>) last.

        :return: True if every script succeeded; stops at the first failure.
        """
        self.gen_param()
        for domain_path in os.listdir(self.scriptpath):
            if domain_path == "common":
                continue
            domain_script_path = os.path.join(self.scriptpath, domain_path, self.action)
            if not os.path.isdir(domain_script_path):
                continue
            scripts_retry_config = self.parse_domain_retry_config(domain_script_path)
            if not self.invoke_cmd(domain_script_path, scripts_retry_config, domain=domain_path):
                return False
        # The common domain always runs after every other domain.
        common_script_path = os.path.join(self.scriptpath, "common", "NCE-Common", self.action)
        scripts_retry_config = self.parse_domain_retry_config(common_script_path)
        return self.invoke_cmd(common_script_path, scripts_retry_config, domain="common")

    def main_entry(self):
        """
        Entry point: run the concurrent phase, then the serial phase, and
        record the final task status ("error" at the first failed phase).

        :return: True when both phases succeed.
        """
        steps = (("concurrency", self.con_domain_cmd_process), ("serial", self.serial_cmd))
        for action, func_obj in steps:
            self.print_msg(INFO, f"Start to run scripts of domain({action}).")
            try:
                succeeded = func_obj()
            except Exception as err:
                # Top-level boundary: e.g. a malformed domain config file must
                # still end with an error status instead of no status at all.
                logger.error("[%s] phase %s raised: %r" % (self.get_function_name(), action, err))
                succeeded = False
            if not succeeded:
                self.print_msg(ERROR, f"Failed to run scripts of domain({action}).")
                Taskmgrutil.set_task(self.taskpath, status="error", progress='100',
                                     task_error=[Taskmgrutil.error_path(self.workpath)])
                return False
            self.print_msg(INFO, f"Finished to run scripts of domain({action}).")
        Taskmgrutil.set_task(self.taskpath, status="success", progress='100',
                             task_error=[Taskmgrutil.error_path(self.workpath)])
        return True

    def monitor_cmd_tasks(self, child_thread_tasks_data, cmd_start_time):
        """
        Block until every monitored thread has finished.

        Once a script has run longer than ``cmd_reminder_time`` seconds (0
        disables reminders), a WARN entry is printed for it, then again about
        every 60 polls (polls are 1 s apart).

        :param child_thread_tasks_data: key -> {"thread", "domain", "countdown",
            optional "script_name"}; the key is shown when "script_name" is absent.
        :param cmd_start_time: epoch seconds when the threads were started.
        """
        format_cmd_start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(cmd_start_time))
        is_on = True
        while is_on:
            time.sleep(1)
            is_on = False
            cost_time = time.time() - cmd_start_time
            format_cost_time = int(cost_time / 60)
            for task_key, task_data in child_thread_tasks_data.items():
                if not task_data.get("thread").is_alive():
                    continue
                # At least one thread still running: poll again.
                is_on = True
                if self.cmd_reminder_time == 0 or cost_time < self.cmd_reminder_time:
                    continue
                if task_data.get("countdown") > 0:
                    task_data["countdown"] -= 1
                    continue
                task_data["countdown"] = 60
                self.print_msg(WARN,
                               "WARN | Script is %s (solution to %s). Start Time: (%s). Cost Time: (%s)m"
                               % (task_data.get("script_name", task_key), task_data.get("domain"),
                                  format_cmd_start_time, format_cost_time))


def main(argv):
    """
    Build the option dict from the command line and run the upgrade dispatcher.

    Each value is read from a fixed even index (2, 4, ..., 26), i.e. the slot
    right after its flag.

    :param argv: the full command line (``sys.argv``).
    :return: result of ``UpgradeCustom.main_entry()``.
    """
    option_order = ("src_version", "des_version", "middle_version", "is_chain_upgrade",
                    "productname", "site", "is_protection_hot", "single_mgr_domain",
                    "action", "scriptid", "cmd_reminder_time", "nce_microservices_patch",
                    "os_patch")
    options = {name: argv[2 * (slot + 1)] for slot, name in enumerate(option_order)}
    return UpgradeCustom(options).main_entry()


if __name__ == "__main__":
    # Script entry: hand the whole command line to main().
    main(argv=sys.argv)
