# -*- coding: UTF-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import time
import traceback

from java.lang import Exception as JException
from com.huawei.ism.tool.protocol.tlv.exception import TLVException

from common import resourceParse
from common import constant
from common.patch_service import RunHelper, RestService
from common.baseFactory import log, finishProcess, set_zone_progress
from cbb.frame.context import contextUtil
from cbb.frame.rest.restUtil import Tlv2Rest
from cbb.frame.rest import restData
from cbb.frame.base import baseUtil
from cbb.business.checkitems.nginx_service_status_check import sph7_nginx_service_check

UPGRADE_PROGRESS_ITEM = "upgrade"


def execute(data_dict):
    return RetryUpgradePatch(data_dict).execute()


class RetryUpgradePatch:
    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.zone_ids = data_dict.get("zone_ids")
        log.info(data_dict, "zone_ids:{}".format(self.zone_ids))
        self.rest_service = RestService(self.data_dict)
        self.run_helper = RunHelper(self.data_dict)
        self.lang = self.data_dict['lang']
        self.resource = resourceParse.execute(self.lang)
        self.rest = contextUtil.getRest(self.data_dict)
        self.dev_type = contextUtil.getDevObj(self.data_dict).get("type", "")
        self.upgrade_res = baseUtil.UpgradeRes(self.dev_type)
        self.check_result = []

    @staticmethod
    def is_all_success(result):
        return all(zone_status == constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_SUCCESS for zone_status in result)

    def execute(self):
        """
        功能说明：执行zone节点重试。
        """
        try:
            flag, err_msg = self.check_upgrade_state()
            if not flag:
                return False, err_msg
            # 上传补丁包
            hot_patch_tgz_pkg = self.data_dict.get('packagePath')
            self.run_helper.upload_service(hot_patch_tgz_pkg)
            # 检查补丁包是否分发完成。
            log.info(self.data_dict, "begin to distribute patch package.")
            distribute_ret = RestService(self.data_dict).check_distribute_pkg_finish()
            if not distribute_ret[0]:
                log.error(self.data_dict, "Distribute package is failed!")
                return False, distribute_ret[1]
            log.info(self.data_dict, "execute upgrade")
            # 执行升级
            return self.upgrade()
        except Exception:
            log.error(self.data_dict, "execute upgrade retry." + traceback.format_exc())
            return False, self.resource.get("upgrade.faild")
        finally:
            finishProcess(self.data_dict, UPGRADE_PROGRESS_ITEM)

    def check_upgrade_state(self):
        """
            功能说明：升级前检查状态是否允许升级
            输入：工具框架上下文
            输出：bool检查结果False/True，str错误信息
        """

        # 重试之前需要检查设备状态是否在升级中
        recs = self.rest_service.query_upgrade_patch_process()
        zone_status = []
        process_params = []
        err_msgs = []
        for zone_info in recs:
            log.info(self.data_dict, "upgrade patch zone process:" + str(zone_info))
            status = Tlv2Rest.getRecordValue(zone_info, restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_TASK_STATUS)
            remain_time = Tlv2Rest.getRecordValue(zone_info,
                                                  restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_REMAINTIME)
            progress = Tlv2Rest.getRecordValue(zone_info,
                                               restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_STEP_PERCENT)
            zone_id = Tlv2Rest.getRecordValue(zone_info, restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_ZONE_ID)
            log.info(self.data_dict, "set zone process start.zone_id:{},progress:{}".format(zone_id, progress))
            progress_dict = {"zone_id": zone_id, "status": status, "progress": progress, "remain_time": remain_time}
            # 不能执行升级的状态有（升级中，回滚中，同步中）
            if status == constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPDING:  # 升级中判断
                err_msg = self.resource.get("upgrade.upding")
                err_msgs.append(err_msg)
                zone_status.append(status)
                process_params.append(progress_dict)
            elif status == constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_ROLLBACKING:  # 回滚中判断
                err_msg = self.resource.get("upgrade.rollbacking")
                err_msgs.append(err_msg)
                zone_status.append(status)
                process_params.append(progress_dict)
        # 不支持重试，刷新进度信息
        if err_msgs:
            set_zone_progress(self.data_dict, process_params, UPGRADE_PROGRESS_ITEM)
            return False, "\n".join(err_msgs)
        return True, ""

    def upgrade(self):
        """
        功能说明：执行升级并轮询升级结果
        输入：工具框架上下文
        输出：bool执行结果False/True，str错误信息
        """
        try:
            # 执行升级命令。
            self.rest_service.execute_upgrade_patch()
            time.sleep(5)
            flag, err_msg = self.query_upgrade_result()
            if not flag:
                return False, err_msg
        except Exception:
            log.error(self.data_dict, "upgrade error!" + str(traceback.format_exc()))
            return False, self.resource.get("upgrade.faild")

        # 补丁安装成功的场景检查nginx服务
        flag, err_msg = sph7_nginx_service_check(self.data_dict, True)
        if not flag:
            log.error(self.data_dict, "check nginx service error..")
            return False, err_msg
        return True, ''

    def query_upgrade_result(self):
        poll_count = 0  # 轮询计数器
        status = None
        # 产品BUG会导致补丁安装状态从10跳变到0，工具每10s查询一次，查到0了也要继续
        upgrade_status_time = 0
        err_msg = ""
        while constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_SUCCESS != status:
            try:
                poll_count = poll_count + 1
                if poll_count > constant.UPGRADE_PACKAGE_POLL_TIMES:
                    return False, self.resource.get("upgrade.faild")
                # 查询升级进度
                recs = self.rest_service.query_upgrade_patch_process()
                zone_status = []
                self.check_upgrade_result(recs, zone_status, upgrade_status_time)
                log.info(self.data_dict, "upgrade result is {}".format(zone_status))
                # 所有重试的zone升级结束，且有升级补丁异常，需查询升级详情。
                if len(zone_status) == len(recs):
                    return self.is_all_success(zone_status), ""
                baseUtil.safeSleep(constant.UPGRADE_PACKAGE_POLL_INTERVAL)
            except (JException, Exception) as e:
                log.error(self.data_dict, "device upgrade query progress error" + str(e))
                # 第一个原子执行时间较长，会抛出该错误码，继续执行查询
                if isinstance(e, TLVException) and constant.OM_OPERATE_FAIL_CODE == e.getErrorId():
                    continue
                flag, err_msg = self.run_helper.check_dev_reboot()
                if flag:
                    continue
                if not err_msg:
                    err_msg = self.resource.get("upgrade.rebooterr")
                return False, err_msg
            finally:
                self.finish_zone_process(self.data_dict, err_msg)
        return True, ""

    def finish_zone_process(self, data_dict, err_msg):
        if not self.check_result:
            log.info(data_dict, "query zone failed CHECK_RESULT:" + str(self.check_result))
            return
        if not err_msg:
            return
        for process in self.check_result:
            process["status"] = constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_FAIL
            process["progress"] = 99
            process["err_msg"] = err_msg
        set_zone_progress(data_dict, self.check_result, UPGRADE_PROGRESS_ITEM)

    def check_upgrade_result(self, recs, zone_status, upgrade_status_time):
        """
        查询A800设备补丁安装的进度及详细请
        """

        err_msgs = []
        for zone_info in recs:
            log.info(self.data_dict, "upgrade patch zone process:" + str(zone_info))
            status = Tlv2Rest.getRecordValue(zone_info, restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_TASK_STATUS)
            remain_time = Tlv2Rest.getRecordValue(zone_info,
                                                  restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_REMAINTIME)
            progress = Tlv2Rest.getRecordValue(zone_info,
                                               restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_STEP_STEP_PERCENT)
            zone_id = Tlv2Rest.getRecordValue(zone_info, restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_ZONE_ID)
            log.info(self.data_dict, "set zone process start.zone_id:{},progress:{}".format(zone_id, progress))
            progress_dict = {"zone_id": zone_id, "status": status, "progress": progress, "remain_time": remain_time}
            if not RunHelper.check_upgrade_status_is_upgrade(status, upgrade_status_time):
                log.info(self.data_dict, "upgrade patch zone status:" + str(status))
                zone_status.append(status)
                err_msg = self.query_upgrade_detail(status)
                progress_dict["err_msg"] = err_msg
                err_msgs.append(err_msg)

            if status == constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_SUCCESS:
                zone_status.append(status)
            self.update_zone_process(progress_dict)
        # 刷新进度信息
        set_zone_progress(self.data_dict, self.check_result, UPGRADE_PROGRESS_ITEM)

    def update_zone_process(self, progress_dict):
        for i, process in enumerate(self.check_result):
            if progress_dict.get("zone_id") == process.get("zone_id"):
                self.check_result[i] = progress_dict
                return
        self.check_result.append(progress_dict)

    def query_upgrade_detail(self, status):
        err_msg = ""
        # 查询升级详情
        log.info(self.data_dict, "system status is not in upgrading, status is:" + str(status))
        if not RunHelper.check_upgrade_status_is_normal(status):
            log.info(self.data_dict, "system status is not normal.")
        # 查询升级详细信息
        recs = self.rest_service.query_upgrade_detail()
        for rec in recs:
            node_id = Tlv2Rest.getRecordValue(rec, restData.Upgrade.UpdListDetaiInfo.NODE_ID)
            name_key = Tlv2Rest.getRecordValue(rec, restData.Upgrade.UpdListDetaiInfo.NAME_KEY)
            item_state = Tlv2Rest.getRecordValue(rec, restData.Upgrade.UpdListDetaiInfo.ITEM_STATE)
            error_key = Tlv2Rest.getRecordValue(rec, restData.Upgrade.UpdListDetaiInfo.ERROY_KEY)
            if item_state != '3':
                continue

            if not error_key:
                err_msg += "%s--%s--%s\n" % (
                    node_id, self.upgrade_res.get_res(name_key), self.upgrade_res.get_res('failed'))
                continue

            error_key_list = error_key.split(":")
            error_key_list_len = len(error_key_list)
            error_key = error_key_list[0]
            try:
                if error_key_list_len >= 2:
                    first_comma_loc = error_key.find(":")  # 第一个冒号前为key值，后面为错误详情参数，此方法避免参数中存在冒号
                    msg_key = error_key[:first_comma_loc]
                    err_param = error_key[first_comma_loc + 1:]  # 冒号之后为错误参数信息
                    parse_error_info_ret = RunHelper.parse_atom_info(err_param)
                    log.info(self.data_dict, "msgKey: %s parseErrorInfoRet:%s" % (msg_key, parse_error_info_ret))
                    err_msg += self.upgrade_res.get_res(msg_key) % parse_error_info_ret + '\n'
                else:
                    err_msg += self.upgrade_res.get_res(error_key) + '\n'
            except Exception:
                log.error(self.data_dict, "parse atom info failed:%s" % err_msg)
                err_msg = self.resource.get("upgrade.faild")
        return err_msg
