# -*- coding: UTF-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import traceback

from com.huawei.ism.tlv.lang import UnsignedInt32
from com.huawei.ism.tool.protocol.tlv.cmds.CommandConstans import UpgradeStatus
from com.huawei.ism.exception import IsmException
from java.lang import Exception as JException

from cbb.frame.context import contextUtil
from common import constant
from cbb.frame.base import baseUtil
from cbb.frame.rest import restData
from cbb.frame.rest.restUtil import Tlv2Rest, REST_CAN_NOT_EXECUTE
from cbb.business.checkitems.nginx_service_status_check import sph7_nginx_service_check

# 分包超时时间
DISTRIBUTE_PKG_POLL_TIME = 450

# 补丁安装超时时间
UPGRADE_REBOOT_AFTER_QUERY_STATUS_TIME = 300


class RunHelper:
    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.logger = data_dict.get('logger')
        self.rest_service = RestService(data_dict)

    @staticmethod
    def check_upgrade_status_is_normal(status):
        """
        功能说明：检查升级中的状态是不是正常的
        输入：升级状态
        输出：bool检查结果False/True
        """

        # 检查升级状态是不是处于回退阶段
        return status not in (constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_ROLLBACKING,
                              constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_ROLLBACK_SUCCESS,
                              constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_ROLLBACK_FAIL,
                              constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_FAIL)

    @staticmethod
    def check_upgrade_status_is_upgrade(status, upgrade_start_status_time=0):
        """
        功能说明：检查升级中的状态是不是升级中和升级成功例外场景：
        产品BUG会导致补丁安装状态从10跳变到0，工具每10s查询一次，查到初始化状态（0）了也要继续
        输入：升级状态
        输入：upgrade_status_time 升级初始化中轮询的次数，如果
        输出：bool检查结果False/True
        """
        # 轮询初始化状态（0）轮询1次或者2次都是初始化状态，需要继续查询。
        if upgrade_start_status_time == 1 or upgrade_start_status_time == 2:
            return True

        # 检查升级状态是否处于升级中和升级成功
        return status in (
            constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPDING,
            constant.OM_MSG_OP_UPD_LST_SYS_PROGRESS.UPD_UPD_SUCCESS)

    @staticmethod
    def parse_atom_info(err_param):
        """
                描述：转换升级原子信息
                原子错误信息格式：参数1,参数2
                参数之间用逗号隔开
                参数：errParam：错误信息字符串
                返回：解析后的错误信息参数元组
        """
        return tuple(err_param.split(","))

    def check_dev_reboot(self):
        """
        功能说明：等待设备重启完成后，服务正常启动，重新建立连接
        输入：工具上下文
        输出：bool连接结果False/True
        """
        # 使用cli命令打开TLV连接通道。
        cli = self.data_dict['ssh']
        # 进行UPGRADE_REBOOT_AFTER_CONNECT_TIMES次重试连接，如果连接成功则返回，如果全部没有连接成功在返回失败
        reboot_after_connect_suc = False
        for times in range(0, constant.UPGRADE_REBOOT_AFTER_CONNECT_TIMES):
            try:
                cli.reConnect()
                contextUtil.getRest(self.data_dict)
                reboot_after_connect_suc = True
                break
            except (Exception, JException) as e:
                # 在rest无法获取进度时，检查nginx服务是否正常
                flag, err_msg = sph7_nginx_service_check(self.data_dict)
                if not flag:
                    # 补丁已经安装完成，但nginx服务存在问题，报错提示
                    self.logger.error("check nginx service error")
                    return False, err_msg
                baseUtil.safeSleep(constant.UPGRADE_REBOOT_AFTER_CONNECT_INTERVAL)
                self.logger.error("dev reboot after connect error " + str(times + 1) + " times" + str(e))

        self.logger.info("device reboot moniter finished . isOK:" + str(reboot_after_connect_suc))
        return reboot_after_connect_suc, ""

    def upload_service(self, pkg_path):
        """
        上传补丁包
        """
        dev = self.data_dict.get("dev")
        sftp = self.data_dict.get("sftp")
        logger = self.data_dict.get("logger")
        try:
            upload_path = self.rest_service.get_upload_path()
            # sftp上传补丁包
            logger.info("Hot patch package upload ip = " + str(dev.getIp()))
            sftp.putFile(dev.getIp(),
                         pkg_path,
                         upload_path,
                         dev.getLoginUser(),
                         None)

            dev_pkg_path = dev.getIp() + ":" + upload_path
            self.rest_service.notify_distribute_pkg(dev_pkg_path)
        except (Exception, JException) as e:
            logger.error("Upload hot patch package exception:" + traceback.format_exc())
            return False, ''
        return True, upload_path


class RestService:
    def __init__(self, data_dict):
        self.data_dict = data_dict
        self.zone_ids = data_dict.get("zone_ids")
        self.logger = data_dict.get('logger')

    def get_upload_path(self):
        """
        获取补丁上传路径，命令字：8599109754
        """
        rest = contextUtil.getRest(self.data_dict)
        param0 = (restData.Upgrade.GetPackageUploadPath.CMO_PACKAGE_TYPE, UnsignedInt32(3))
        param2 = (restData.Upgrade.GetPackageUploadPath.CMO_PACKAGE_SIZE, UnsignedInt32(0))
        params = [param0, param2]
        if self.zone_ids:
            self.logger.info('zone_ids %s' % self.zone_ids)
            zone_param = (restData.Upgrade.GetPackageUploadPath.CMO_PACKAGE_ZONE_IDS, self.zone_ids)
            params.append(zone_param)
        self.logger.info('Execute cmd [code:%d] [%s]' % (restData.TlvCmd.OM_MSG_OP_GET_PACKAGE_UPLOADPATH.get('cmd'),
                                                         params))
        recs = Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.OM_MSG_OP_GET_PACKAGE_UPLOADPATH, params)
        rec = recs[0]
        return Tlv2Rest.getRecordValue(rec, restData.Upgrade.GetPackageUploadPath.CMO_PACKAGE_DIR_FILE)

    def notify_distribute_pkg(self, dev_pkg_path):
        """
        上传完成，通知分发包给指定的zone 命令字：8599117947
        """
        self.logger.info("package upload full path:%s" % dev_pkg_path)
        rest = contextUtil.getRest(self.data_dict)
        params = []
        pkg_path_param = (restData.Upgrade.NotifyPackagePath.CMO_NOTIFY_PACKAGE_PATH, str(dev_pkg_path))
        params.append(pkg_path_param)
        if self.zone_ids:
            self.logger.info('zone_ids %s' % self.zone_ids)
            zone_param = (restData.Upgrade.NotifyPackagePath.CMO_NOTIFY_ZONE_IDS, self.zone_ids)
            params.append(zone_param)
        self.logger.info('Execute cmd [code:%d]' % (restData.TlvCmd.OM_MSG_OP_NOTIFY_PACKAGE_PATH.get('cmd')))
        Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.OM_MSG_OP_NOTIFY_PACKAGE_PATH, params)

    def check_distribute_pkg_finish(self):
        """
        判断是否分包完成 命令字：8599109744
        """
        dev_dld_query_time = 0
        status = None
        rest = contextUtil.getRest(self.data_dict)
        while UpgradeStatus.UPD_DLD_SUCCESS != status:
            # 10分钟还没有分包成功，则认为分包失败
            dev_dld_query_time = dev_dld_query_time + 1
            if dev_dld_query_time > DISTRIBUTE_PKG_POLL_TIME:
                self.logger.error("Distribute package timeout!")
                return False, "upload.distribute.pkg.timeout"
            try:
                self.logger.info(
                    'Execute cmd [code:%d] []' % (restData.TlvCmd.OM_MSG_OP_UPD_LST_SYS_PROGRESS.get('cmd')))
                params = []
                msg_param1 = (
                    restData.Upgrade.UpdLstSysProgress.CMO_UPD_FLOW_ID,
                    restData.Enum.UpgQueryFlowEnum.DISTRIBUTE_PACKAGE
                )
                params.append(msg_param1)
                if self.zone_ids:
                    zone_param = (restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_ZONE_IDS, self.zone_ids)
                    self.logger.info('zone_ids %s' % self.zone_ids)
                    params.append(zone_param)
                recs = Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.OM_MSG_OP_UPD_LST_SYS_PROGRESS, params)
                rec = recs[0]
                upgrade_param = restData.Upgrade.LstSysProgress
                status_index = Tlv2Rest.getRecordValue(rec, upgrade_param.CMO_UPD_SYS_STEP_TASK_STATUS)
                status = UpgradeStatus.values()[status_index]

                if status == UpgradeStatus.UPD_DLD_FAIL:
                    self.logger.error("Distribute package failed!")
                    return False, "preprocess.distribute.pkg.failed"
                baseUtil.safeSleep(10)
            except (Exception, JException) as exp:
                if isinstance(exp, IsmException) and str(exp.getErrorId()) == str(REST_CAN_NOT_EXECUTE):
                    self.logger.error("query upgrade process exception.[%s]" % str(exp))
                    rest = self.get_rest_conect(rest)
                    continue
                self.logger.error("Distribute package exception:%s" % (traceback.format_exc()))
                return False, "preprocess.distribute.pkg.exception"
        return True, ""

    def get_rest_conect(self, rest):
        try:
            baseUtil.safeSleep(10)
            rest = contextUtil.getRest(self.data_dict)
        except (Exception, JException):
            self.logger.error("creat rest connection exception.")
        return rest

    def execute_upgrade_patch(self):
        """
        执行补丁安装命令，命令字：8599117931
        """
        # 升级类型都如果，model_type有值，使用已有值，否则为在线升级5
        model_num = self.data_dict.get("model_type") if self.data_dict.get("model_type") else 5
        params = []
        param2 = (restData.Upgrade.NotifyExcUpgrade.CMO_EXE_UPD_ACTIVETYPE, UnsignedInt32(model_num))
        params.append(param2)
        flow_choice_num = self.data_dict.get('flowchoice')
        if flow_choice_num:
            msg_param11 = (restData.Upgrade.NotifyExcUpgrade.CMO_EXC_ERROR_INFO, flow_choice_num)
            params.append(msg_param11)
        if self.zone_ids:
            self.logger.info('zone_ids %s' % self.zone_ids)
            zone_param = (restData.Upgrade.NotifyExcUpgrade.CMO_EXC_ZONE_IDS, self.zone_ids)
            params.append(zone_param)
        # 执行升级命令
        rest = contextUtil.getRest(self.data_dict)
        Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.EXC_UPD, params)

    def query_upgrade_patch_process(self):
        """
        查询补丁安装状态和进度命令，命令字：8599109744
        """
        rest = contextUtil.getRest(self.data_dict)
        params = []
        msg_param5 = (restData.Upgrade.UpdLstSysProgress.CMO_UPD_FLOW_ID, restData.Enum.UpgQueryFlowEnum.PATCH_INSTALL)
        params.append(msg_param5)
        if self.zone_ids:
            self.logger.info('zone_ids %s' % self.zone_ids)
            zone_param = (restData.Upgrade.UpdLstSysProgress.CMO_UPD_SYS_ZONE_IDS, self.zone_ids)
            params.append(zone_param)
        recs = Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.OM_MSG_OP_UPD_LST_SYS_PROGRESS, params)
        return recs

    def query_upgrade_detail(self):
        """
        查询A800指定zone补丁安装的详情信息，命令字：8599150995
        """
        rest = contextUtil.getRest(self.data_dict)
        params = []
        msg_param5 = (restData.Upgrade.UpdLstSysProgress.CMO_UPD_FLOW_ID, restData.Enum.UpgQueryFlowEnum.PATCH_INSTALL)
        params.append(msg_param5)
        if self.zone_ids:
            zone_param = (restData.Upgrade.UpdListDetaiInfo.CMO_DETAIL_ZONE_ID, self.zone_ids)
            params.append(zone_param)
        recs = Tlv2Rest.execCmdJlist(rest, restData.TlvCmd.OM_MSG_OP_UPD_LIST_DETAILINFO, params)
        return recs
