# encoding=utf-8
"""
功 能：记录日志
版权信息：华为技术有限公司，版本所有(C) 2019-2029
修改记录：2019-12-11 12:00 创建
"""
import datetime
import inspect
import json
import os
import re
import shutil
import sys
import time
from threading import Lock, Thread

from commonlog import Logger
from taskmgr_util import Taskmgrutil

# Module-wide logger obtained from the project's commonlog package, keyed by the
# invoked script path (sys.argv[0]) — presumably used as the log name; confirm in commonlog.
logger = Logger().getinstance(sys.argv[0])

# Message levels passed to UpgradeCustom.print_msg (only ERROR adds a level tag to the text).
INFO = "INFO"
WARN = "WARNING"
ERROR = "ERROR"


class UpgradeCustom:
    """Run the product domain scripts of one upgrade task.

    Domain scripts are looked up under ``self.scriptpath``; the "common" domain
    keeps its scripts in a ``NCE-Common`` sub-directory. Progress text is
    accumulated in ``self.msg`` and pushed to the task manager through
    ``Taskmgrutil.set_e_taskmsg`` after every message.
    """

    # Serializes print_msg: con_exc runs in several threads at once, and the
    # read-modify-write of self.msg plus the task-message write must not interleave.
    _msg_lock = Lock()

    def __init__(self, args):
        """Initialise the task context and prepare the work directory.

        :param args: dict of upgrade parameters; missing keys default to "".
            ``cmd_reminder_time`` is given in minutes and stored in seconds.
        """
        self.productname = args.get("productname", "")
        self.src_version = args.get("src_version", "")
        self.des_version = args.get("des_version", "")
        self.middle_version = args.get("middle_version", "")
        self.is_chain_upgrade = args.get("is_chain_upgrade", "")
        self.site = args.get("site", "")
        self.action = args.get("action", "")
        self.scriptid = args.get("scriptid", "")
        self.cmd_reminder_time = int(args.get("cmd_reminder_time", 0)) * 60
        self.is_protection_hot = args.get("is_protection_hot", "")
        self.single_mgr_domain = args.get("single_mgr_domain", "")
        self.scriptpath = "/opt/upgrade/easysuite_upgrade/scripts"
        self.workpath = "/opt/upgrade/easysuite_upgrade/workpath/%s-%s/workpath-%s" % (
            self.src_version, self.des_version, self.productname)
        self.taskpath = os.path.join("/opt/upgrade/easysuite_upgrade/taskmgr", self.scriptid)
        self.taskmsg = os.path.join(self.taskpath, "task.log")
        self.taskstatus = os.path.join(self.taskpath, "task.status")
        self.taskprogress = os.path.join(self.taskpath, "task.process")
        self.taskmgr_function = Taskmgrutil()
        self.taskmgr_function.init_e_taskmgr(self.taskpath)
        self.msg = ""
        # Set to False by con_exc when any script run in a thread fails.
        self.thread_flag = True
        self._reset_workpath()
        self.get_product_info()

    def _reset_workpath(self):
        """Empty ``self.workpath`` (keeping its ``plandata`` entry) and make sure it exists.

        Best effort: removal or creation failures are logged, not raised.
        """
        if os.path.isdir(self.workpath):
            for entry in os.listdir(self.workpath):
                if entry == "plandata":
                    continue
                # listdir() yields bare names: join them with workpath so the entries
                # are removed from workpath and not from the current directory.
                entry_path = os.path.join(self.workpath, entry)
                try:
                    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                        shutil.rmtree(entry_path)
                    else:
                        os.remove(entry_path)
                except OSError as err:
                    logger.error("[%s] Failed to remove %s: %s" % (self.get_function_name(), entry_path, err))
        try:
            os.makedirs(self.workpath, exist_ok=True)
        except OSError as err:
            logger.error("[%s] Failed to create %s: %s" % (self.get_function_name(), self.workpath, err))

    @staticmethod
    def get_function_name():
        """Return the name of the function (or method) that called this helper."""
        # The caller's frame is one level up; this avoids inspect.stack(), which
        # materializes every frame and reads source files on each call.
        return inspect.currentframe().f_back.f_code.co_name

    def get_product_info(self):
        """Query product information through the platform's queryproduct.sh.

        The script receives ``self.productname`` and ``self.workpath`` as its
        ``-pn`` / ``-output`` arguments. It is attempted at most twice, one second apart.

        :return: ``(True, "")`` on success, ``(False, "Failed to run queryproduct.sh")``
            when both attempts return a non-zero status.
        """
        cmd = "bash /opt/oss/manager/tools/resmgr/queryproduct.sh -pn %s -output %s" % (
            self.productname, self.workpath)
        retcode = 0
        attempts = 2
        for attempt in range(attempts):
            retcode = os.system(cmd)
            if retcode == 0:
                return True, ""
            # Only wait when another attempt follows.
            if attempt < attempts - 1:
                time.sleep(1)
        logger.error("[%s] Result:%s" % (self.get_function_name(), retcode))
        return False, "Failed to run queryproduct.sh"

    def gen_param(self):
        """Write the upgrade parameters to ``<workpath>/plandata.json``.

        The file is truncated if it exists and created with mode 0660 otherwise.

        :return: True
        """
        param = {
            "productname": self.productname,
            "src_version": self.src_version,
            "des_version": self.des_version,
            "middle_version": self.middle_version,
            "is_chain_upgrade": self.is_chain_upgrade,
            "is_protection_hot": self.is_protection_hot,
            "single_mgr_domain": self.single_mgr_domain,
            "site": self.site,
        }
        with os.fdopen(os.open(os.path.join(self.workpath, "plandata.json"),
                               os.O_CREAT | os.O_WRONLY | os.O_TRUNC,
                               mode=0o660), "w") as file_obj:
            json.dump(param, file_obj, indent=4)
        return True

    def one_thread(self, script):
        """Run one shell script with ``-input <workpath>`` and report the result.

        Non-empty script output is appended verbatim to the task message.

        :param script: path of the shell script to run.
        :return: True when Taskmgrutil.execute_cmd reports return code 0, else False.
        """
        logger.info("[%s] Start to run script: %s" % (self.get_function_name(), script))
        ret_code, ret_msg = Taskmgrutil.execute_cmd("bash %s -input %s" % (script, self.workpath))
        if ret_msg.strip():
            self.print_msg(INFO, ret_msg.strip(), form="script_output")
        if ret_code == 0:
            logger.info("[%s] script: %s Result:%s Msg:%s" % (self.get_function_name(), script,
                                                              ret_code, ret_msg))
            return True
        self.print_msg(INFO, "Script: %s Result:%s\nMsg:%s" % (script, ret_code, ret_msg))
        logger.error("[%s] script: %s Result:%s Msg:%s" % (self.get_function_name(), script,
                                                           ret_code, ret_msg))
        return False

    def print_msg(self, level, msg, form=""):
        """Append one line to the task message and push the whole message out.

        :param level: level string; "info"/"warn"/"error" are upper-cased. Only
            "ERROR" adds a ``[level]`` tag to the line.
        :param msg: message text.
        :param form: "script_output" appends *msg* verbatim, without timestamp/pid.
        :return: True
        """
        nowtime = datetime.datetime.strftime(datetime.datetime.now(), '%Y-%m-%d %H:%M:%S')
        if level in ["info", "warn", "error"]:
            level = level.upper()
        if form == "script_output":
            str_msg = msg
        elif level == "ERROR":
            str_msg = "[%s] [%s] [%s] %s" % (nowtime, os.getpid(), level, msg)
        else:
            str_msg = "[%s] [%s] %s" % (nowtime, os.getpid(), msg)
        # Several script threads call this concurrently; keep the append and the
        # write atomic so no line is lost and an older snapshot never overwrites a newer one.
        with self._msg_lock:
            if self.msg == "":
                self.msg = str_msg
            else:
                self.msg = "%s\n%s" % (self.msg, str_msg)
            self.taskmgr_function.set_e_taskmsg(self.taskpath, self.msg)
        return True

    def con_exc(self, script_info, script):
        """Thread target: run *script* and clear ``self.thread_flag`` on failure.

        :param script_info: "(name)[solution due to domain]" label used in messages.
        :param script: path of the script to run.
        """
        self.print_msg(INFO, "Start to run script%s." % script_info)
        try:
            succeeded = self.one_thread(script)
        except Exception as err:  # thread boundary: nothing above us can catch it
            # An uncaught exception would only end this thread and leave
            # thread_flag True, i.e. the failed script would count as a success.
            logger.error("[%s] script: %s raised: %s" % (self.get_function_name(), script, err))
            succeeded = False
        if not succeeded:
            self.print_msg(INFO, "Failed to run script%s." % script_info)
            self.thread_flag = False
        self.print_msg(INFO, "Finished to run script%s." % script_info)

    def _domain_script_dir(self, domain_path):
        """Return the script directory of a domain ("common" keeps its scripts in NCE-Common)."""
        if domain_path == "common":
            return os.path.join(self.scriptpath, domain_path, "NCE-Common")
        return os.path.join(self.scriptpath, domain_path)

    def con_domain_cmd_process(self):
        """Run every ``<action>*.sh`` script of every domain concurrently, one thread each.

        :return: True when all scripts succeeded, otherwise False
            (``self.thread_flag`` is reset to True in that case).
        """
        self.gen_param()
        # Collect (script path, script name, domain) for all matching domain scripts.
        tasks = []
        for domain_path in os.listdir(self.scriptpath):
            domain_script_path = self._domain_script_dir(domain_path)
            if not os.path.isdir(domain_script_path):
                continue
            for script_name in os.listdir(domain_script_path):
                if script_name.startswith(self.action) and script_name.endswith(".sh"):
                    tasks.append((os.path.join(domain_script_path, script_name), script_name, domain_path))
        child_thread_tasks_data = dict()
        cmd_start_time = time.time()
        for script_path, script_name, domain_path in tasks:
            script_info = "(%s)[solution due to %s]" % (script_name, domain_path)
            child_thread = Thread(target=self.con_exc, args=(script_info, script_path),
                                  name='child-thread-%s' % script_info)
            child_thread.start()
            # Key by full path: the same script name may exist in several domains, and a
            # name-keyed dict would stop monitoring (and waiting for) all but one of them.
            child_thread_tasks_data[script_path] = {"thread": child_thread, "script_name": script_name,
                                                    "domain": domain_path, "countdown": 0}

        self.monitor_cmd_tasks(child_thread_tasks_data, cmd_start_time)
        if not self.thread_flag:
            self.thread_flag = True
            return False
        return True

    def invoke_cmd(self, path, domain=""):
        """Run the numbered scripts in *path* one at a time, in numeric order.

        Only files matching ``NN_<word chars>.sh`` (two leading digits) are run.
        Each script runs in its own thread, which is monitored until it ends.

        :param path: directory holding the scripts.
        :param domain: domain name used in messages.
        :return: False as soon as a script fails (``self.thread_flag`` is reset
            to True), otherwise True — also when *path* is not a directory.
        """
        if not os.path.isdir(path):
            return True
        # Only .sh scripts whose names start with two digits and "_" qualify.
        pattern = re.compile(r'\d{2}_\w*\.sh$')
        script_names = [name for name in os.listdir(path) if pattern.match(name)]
        for script_name in sorted(script_names, key=lambda name: int(name.split('_')[0])):
            script_path = os.path.join(path, script_name)
            script_info = "(%s)[solution due to %s]" % (script_name, domain)
            cmd_start_time = time.time()
            child_thread = Thread(target=self.con_exc, args=(script_info, script_path), name="child-thread-0")
            child_thread.start()
            child_thread_tasks_data = {script_path: {"thread": child_thread, "script_name": script_name,
                                                     "domain": domain, "countdown": 0}}
            self.monitor_cmd_tasks(child_thread_tasks_data, cmd_start_time)
            if not self.thread_flag:
                self.thread_flag = True
                return False
        return True

    def serial_cmd(self):
        """Run each domain's ``<action>`` script directory serially, "common" last.

        Within a directory scripts run in numeric order (see invoke_cmd).

        :return: True when every script succeeded; False at the first failure.
        """
        self.gen_param()
        for domain_path in os.listdir(self.scriptpath):
            if domain_path == "common":
                continue
            domain_script_path = os.path.join(self.scriptpath, domain_path, self.action)
            if not os.path.isdir(domain_script_path):
                continue
            if not self.invoke_cmd(domain_script_path, domain=domain_path):
                return False
        # The common domain's scripts always run last.
        domain_script_path = os.path.join(self.scriptpath, "common", "NCE-Common", self.action)
        return self.invoke_cmd(domain_script_path, domain="common")

    def domain_cmd_process(self):
        """Run every ``<action>*.sh`` script of every domain one after another.

        Within a domain the first failure skips that domain's remaining scripts,
        but the following domains still run.

        :return: True when every script succeeded, otherwise False.
        """
        allfinish = True
        self.gen_param()
        for domain_path in os.listdir(self.scriptpath):
            domain_script_path = self._domain_script_dir(domain_path)
            if not os.path.isdir(domain_script_path):
                continue
            allfinish = self.process_single_domain_scripts(allfinish, domain_script_path)
        return allfinish

    def process_single_domain_scripts(self, allfinish, domain_script_path):
        """Run the ``<action>*.sh`` scripts of one domain directory in the current thread.

        :param allfinish: result accumulated so far; returned unchanged on success.
        :param domain_script_path: directory holding the domain's scripts.
        :return: False after the first failing script (the rest are skipped),
            otherwise *allfinish*.
        """
        for script_name in os.listdir(domain_script_path):
            script_path = os.path.join(domain_script_path, script_name)
            if script_name.startswith(self.action) and script_name.endswith(".sh"):
                self.print_msg(INFO, "Start to run script(%s)." % script_name)
                if not self.one_thread(script_path):
                    self.print_msg(INFO, "Failed to run script(%s)." % script_name)
                    allfinish = False
                    break
                self.print_msg(INFO, "Finished to run script(%s)." % script_name)
        return allfinish

    @staticmethod
    def _finish_task(scriptid_path, succeeded):
        """Mark the task finished ("success"/"error", progress 100) and return *succeeded*."""
        Taskmgrutil.set_e_taskstatus(scriptid_path, "success" if succeeded else "error")
        Taskmgrutil.set_e_taskprogress(scriptid_path, "100")
        return succeeded

    def main_entry(self, scriptid_path):
        """Top-level dispatcher: concurrent domain scripts first, then serial ones.

        :param scriptid_path: task-manager directory whose status/progress are updated.
        :return: True when both phases succeeded, otherwise False.
        """
        self.print_msg(INFO, "Start to run scripts of domain(concurrency).")
        if not self.con_domain_cmd_process():
            self.print_msg(ERROR, "Failed to run scripts of domain(concurrency).")
            return self._finish_task(scriptid_path, False)
        self.print_msg(INFO, "Finished to run scripts of domain(concurrency).")
        self.print_msg(INFO, "Start to run scripts of domain(serial).")
        if not self.serial_cmd():
            self.print_msg(ERROR, "Failed to run scripts of domain(serial).")
            return self._finish_task(scriptid_path, False)
        self.print_msg(INFO, "Finished to run scripts of domain(serial).")
        return self._finish_task(scriptid_path, True)

    def monitor_cmd_tasks(self, child_thread_tasks_data, cmd_start_time):
        """Block until every monitored script thread has finished.

        Polls once per second. When ``self.cmd_reminder_time`` is non-zero and that
        many seconds have passed since *cmd_start_time*, a WARN message is emitted
        for each still-running script, then repeated after a 60-poll countdown.

        :param child_thread_tasks_data: dict mapping a task key to
            ``{"thread": Thread, "domain": str, "countdown": int}`` plus an optional
            ``"script_name"`` (the key is shown when it is absent). ``"countdown"``
            is updated in place.
        :param cmd_start_time: ``time.time()`` value taken when the threads started.
        """
        format_cmd_start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(cmd_start_time))
        is_on = True
        while is_on:
            time.sleep(1)
            is_on = False
            cost_time = time.time() - cmd_start_time
            format_cost_time = int(cost_time / 60)
            for task_key, task_data in child_thread_tasks_data.items():
                if not task_data.get("thread").is_alive():
                    continue
                # At least one thread is still running: poll again.
                is_on = True
                # No reminder configured, or the reminder threshold is not reached yet.
                if self.cmd_reminder_time == 0 or cost_time < self.cmd_reminder_time:
                    continue
                if task_data.get("countdown") > 0:
                    task_data["countdown"] -= 1
                    continue
                # Warn now and restart the countdown.
                task_data["countdown"] = 60
                self.print_msg(WARN, "WARN | Script is %s (solution to %s). Start Time: (%s). Cost Time: (%s)m"
                               % (task_data.get("script_name", task_key), task_data.get("domain"),
                                  format_cmd_start_time, format_cost_time))


def main(argv):
    """Build the upgrade parameters from the command line and run the task.

    Values are read from the even indexes 2, 4, ..., 22 of *argv* in the fixed
    key order below (presumably option names sit at the odd indexes — confirm
    with the caller). A too-short *argv* raises IndexError.

    :return: result of UpgradeCustom.main_entry (True on success).
    """
    # The n-th key (counting from 1) takes its value from argv[2 * n].
    arg_keys = ("src_version", "des_version", "middle_version", "is_chain_upgrade",
                "productname", "site", "is_protection_hot", "single_mgr_domain",
                "action", "scriptid", "cmd_reminder_time")
    args = {key: argv[2 * pos] for pos, key in enumerate(arg_keys, start=1)}
    task_dir = os.path.join("/opt/upgrade/easysuite_upgrade/taskmgr", args["scriptid"])
    return UpgradeCustom(args).main_entry(task_dir)


if __name__ == '__main__':
    # Script entry: hand the full command line to main(). Its boolean result
    # is not turned into a process exit status.
    main(argv=sys.argv)
