# -*- coding: UTF-8 -*-
import os
import re
import time
import traceback

import java.lang.Exception as JException
import common
from config import V3_MODEL_VER_PATCH_LIST
import cliUtil
import java.lang.System as javaSystem
from defusedxml import ElementTree as ET
from common import UnCheckException
from cbb.frame.base import product


# CLI echo accumulated for the current check; execute() rebinds it via `global`.
cliRet = ''
# UI language. py_java_env is not defined in this file - presumably injected
# by the hosting inspection runtime.
LANG = common.getLang(py_java_env)
# Model -> system versions for which execute() appends the controller cache
# capacity (e.g. " 16GB") to the model name before looking up the patch table.
SPECIAL_VER = {'2200 V3': ['V300R006C20SPC100', 'V300R006C20'],
               '2200 V3 Enhanced': ['V300R006C20SPC100']}
# PY_LOGGER is likewise provided by the host environment (not defined here).
LOGGER = common.getLogger(PY_LOGGER, __file__)


def execute(cli):
    '''
    Hot patch recommended-version check.

    1. Passes if the product model or system software version is outside the
       checked range, or the recommended hot patch is already installed.
    2. If no hot patch is installed, or the installed one is older than the
       recommended version, the result is "suggest optimization"
       (cliUtil.RESULT_WARNING) in routine inspection and "not passed" (False)
       in opening/delivery inspection. Versions matched by
       check_not_pass_version are always "not passed".

    :param cli: CLI connection to the device
    :return: (result, cliRet, errMsg); result is True, False,
             cliUtil.RESULT_WARNING or cliUtil.RESULT_NOCHECK (or a non-bool
             value passed through from getProductModelWithCliRet /
             UnCheckException.flag)
    '''
    global cliRet
    # Installed-hot-patch values meaning "no hot patch installed".
    hotPatchNotExists = ["", "--"]
    try:
        # Map filled below when a local patch package is found; presumably
        # consumed by the patch tool linkage (its keys must not change).
        patchWarningDevs = common.getPatchWarningDevs(py_java_env)
        # Get the device model
        isQryOk, deviceType, cliRet, errMsg = cliUtil.getProductModelWithCliRet(cli, LANG)
        # getProductModelWithCliRet returning False means "not checked"; any
        # other non-True value is passed through as the result.
        if isQryOk != True:
            if not isQryOk:
                return cliUtil.RESULT_NOCHECK, cliRet, errMsg
            return isQryOk, cliRet, errMsg
        # Dorado NAS models pass directly
        if 'Dorado NAS' in deviceType:
            return True, cliRet, ''

        # Get the device's current product version and hot patch information
        checkRet, currentVersionDictList, hotPatchVersionDictList = common.parse_upgradePackage(cli, LANG)
        cliRet += checkRet[1]
        if checkRet[0] != True:
            LOGGER.logSysAbnormal()
            return cliUtil.RESULT_NOCHECK, cliRet, checkRet[2]

        result, currentVersion, errMsg = common.getCurrentVersion(currentVersionDictList, LANG)
        if not result:
            return cliUtil.RESULT_NOCHECK, cliRet, errMsg  # getCurrentVersion returning False means "not checked"
        product_version = str(py_java_env.get("devInfo").getProductVersion())
        # Tag Kunpeng builds so version keys such as
        # "V500R007C60SPC100 Kunpeng" (see check_need_update_version) can match.
        if 'Kunpeng' in product_version and 'Kunpeng' not in currentVersion:
            currentVersion += ' Kunpeng'
        # No hot patch is recommended for Dorado V300R001C01
        if 'dorado' in deviceType.lower() and currentVersion.startswith(
                'V300R001C01'):
            return True, cliRet, ''

        # For SPECIAL_VER model/version pairs, append the cache capacity to the
        # model name (presumably the patch table lists these per capacity).
        if deviceType in SPECIAL_VER and currentVersion in SPECIAL_VER.get(
                deviceType):
            deviceType = deviceType + get_cache_capacity(cli, LANG, LOGGER)
            LOGGER.logInfo("2200 V3 device type:{}".format(deviceType))

        result, curHotPatchVer, errMsg = common.getCurrentVersion(hotPatchVersionDictList, LANG)
        if not result:
            return cliUtil.RESULT_NOCHECK, cliRet, errMsg  # getCurrentVersion returning False means "not checked"
        # Determine whether the device is a risk version
        # (isQryOk False = recommended patch missing or outdated).
        isQryOk, needHotPatchVer, patchName, carrierpatchPach, EnterprisepatchPach, isLinkageAvaliable = checkRiskVersion(
            deviceType, currentVersion, curHotPatchVer, LANG, LOGGER)
        if isQryOk == cliUtil.RESULT_NOCHECK:
            errMsg = common.getMsg(LANG, "parse.product.hotpatch.file.filed")
            return isQryOk, cliRet, errMsg

        # No download link means nothing to recommend
        if not EnterprisepatchPach and not carrierpatchPach:
            return True, cliRet, ''

        if not isQryOk:
            # Check whether the recommended patch package exists locally
            patchSavePath = javaSystem.getenv("patchSavePath")
            schedluedTaskPathFlag, schedluedTaskPath = getHotPatchXmlPathForSchedluedTask()
            # Package expected in the same directory as the scheduled-task productHotPatch.xml
            patchPath = ''.join([schedluedTaskPath.split('productHotPatch.xml')[0], patchName])
            if (patchSavePath != None and os.path.exists(patchSavePath + os.sep + patchName)) or (
                        schedluedTaskPathFlag and os.path.exists(patchPath)):
                # NOTE(review): if patchSavePath is set but the package was only
                # found at patchPath, the message and the linkage entry below
                # still point at patchSavePath - confirm this is intended.
                if patchSavePath != None:
                    errMsg = common.getMsg(LANG, "exist.hotpatch.error.mseeages", (
                        deviceType, currentVersion, needHotPatchVer, (patchSavePath + os.sep + patchName)))
                else:
                    errMsg = common.getMsg(LANG, "exist.hotpatch.error.mseeages", (
                        deviceType, currentVersion, needHotPatchVer, patchPath))
                # interact interface :this diction's key is not changeable
                # Linkage is skipped only when toPatchTool is explicitly "FALSE".
                if not "FALSE" == isLinkageAvaliable.strip().upper():
                    if patchSavePath != None:
                        patchBasisDict = {"hotPatchPath": str(os.path.join(patchSavePath, patchName)),
                                          "curPatchVersion": curHotPatchVer,
                                          "suggestPatchVersion": needHotPatchVer}
                        patchWarningDevs.put(common.getCurDeviceInfo(py_java_env), patchBasisDict)

            else:
                # Patch directory missing or patch package not present: point the user at the download sites
                supportPath = ""
                # NOTE(review): (x) below is a plain string, not a 1-tuple -
                # confirm common.getMsg accepts that form.
                if carrierpatchPach:
                    supportPath += common.getMsg(LANG, "carrierpatchPach", (carrierpatchPach))
                if EnterprisepatchPach:
                    supportPath += common.getMsg(LANG, "EnterprisepatchPach", (EnterprisepatchPach))
                errMsg = common.getMsg(LANG, "not.exist.hotpatch.error.mseeages",
                                       (deviceType, currentVersion, needHotPatchVer, supportPath))

            # Added 2020-03-29: Dorado 6.0.0 should install SPH6 or later;
            # for 6.0.0 the missing-patch result is raised to "not passed".

            if check_not_pass_version(currentVersion, curHotPatchVer,
                                      hotPatchNotExists, LOGGER):
                return False, cliRet, errMsg
            # Patch missing or older than recommended: "not passed" in opening/delivery
            # inspection, "suggest optimization" in routine inspection.
            result_status = False if common.is_opening_delivery_inspect(py_java_env) else cliUtil.RESULT_WARNING
            return result_status, cliRet, errMsg

        return isQryOk, cliRet, errMsg
    except UnCheckException as unCheckException:
        LOGGER.logError(str(traceback.format_exc()))
        LOGGER.logInfo("UnCheckException, errMsg: %s" % unCheckException.errorMsg)
        # A falsy flag maps to "not checked"; otherwise the flag is the result.
        if not unCheckException.flag:
            return cliUtil.RESULT_NOCHECK, unCheckException.cliRet, unCheckException.errorMsg
        return unCheckException.flag, unCheckException.cliRet, unCheckException.errorMsg
    except Exception as exception:
        LOGGER.logException(exception)
        LOGGER.logError(str(traceback.format_exc()))
        return cliUtil.RESULT_NOCHECK, cliRet, common.getMsg(LANG, "query.result.abnormal")


def is_device_type_oceanstor():
    """
    Check whether the current device is one of the new converged OceanStor models.

    The model is read from the inspection context (py_java_env), not via CLI.

    :return: True if the product model is in the known model set
    """
    # frozenset gives O(1) membership and removes the duplicated
    # "OceanStor 5210" entry the original list carried.
    device_set = frozenset([
        "OceanStor 5210", "OceanStor 18500K", "OceanStor 18510", "OceanStor 18810", "OceanStor 2200",
        "OceanStor 2220", "OceanStor 2600", "OceanStor 2620", "OceanStor 5120",
        "OceanStor 5220", "OceanStor 5300K", "OceanStor 5310", "OceanStor 5500K", "OceanStor 5510",
        "OceanStor 5310 Capacity Flash", "OceanStor 5510 Capacity Flash", "OceanStor A300",
        "OceanStor 5510S", "OceanStor 5610", "OceanStor 6810",
    ])
    device_type = str(common.getProductModeFromContext(py_java_env))
    return device_type in device_set


def check_risk_patch_version(current_version, cur_hot_patch, hot_patch_not_exists):
    """
    Escalate "suggest optimization" to "not passed" when the installed hot
    patch is EXACTLY a known problematic patch level.

    :param current_version: system software version string of the device
    :param cur_hot_patch: installed hot patch version string
    :param hot_patch_not_exists: values of cur_hot_patch meaning "none installed"
    :return: True if some risky-version key is a substring of current_version
             and the installed SPH number equals that key's risky SPH number
    """
    sph_match = re.compile(r"SPH(\d+)", flags=re.IGNORECASE).search(cur_hot_patch)
    if cur_hot_patch in hot_patch_not_exists or not sph_match:
        installed_sph = 0  # no patch, or no SPH number to read
    else:
        installed_sph = int(sph_match.group(1))

    # version substring -> exact SPH number known to be problematic
    risky_patches = {}
    if is_device_type_oceanstor():
        risky_patches["6.1.3"] = 16

    if not risky_patches:
        return False
    return any(version in current_version and installed_sph == risky_sph
               for version, risky_sph in risky_patches.items())


def check_need_update_version(current_version, cur_hot_patch, hot_patch_not_exists):
    """
    Escalate "suggest optimization" to "not passed" when the installed hot
    patch is OLDER than a known minimum for the system version.

    :param current_version: system software version string of the device
    :param cur_hot_patch: installed hot patch version string
    :param hot_patch_not_exists: values of cur_hot_patch meaning "none installed"
    :return: True if some version key is a substring of current_version and
             the installed SPH number is below that key's minimum
    """
    sph_match = re.compile(r"SPH(\d+)", flags=re.IGNORECASE).search(cur_hot_patch)
    if cur_hot_patch in hot_patch_not_exists or not sph_match:
        installed_sph = 0  # no patch, or no SPH number to read
    else:
        installed_sph = int(sph_match.group(1))

    # version substring -> minimum acceptable SPH number
    minimum_sph = {"6.0.0": 11, "6.0.1": 10}
    if is_device_type_oceanstor():
        minimum_sph["6.1.6"] = 2

    # Opening/delivery and capacity-expansion pre-inspection use stricter minimums.
    strict_scene = (common.is_opening_delivery_inspect(py_java_env)
                    or common.isExpansionCapacityScenePreInspect(py_java_env))
    if strict_scene:
        minimum_sph.update({
            "V500R007C60SPC100 Kunpeng": 105,
            "V500R007C60SPC300 Kunpeng": 309,
            "6.1.0": 11,
            "6.0.1": 20,
        })

    return any(version in current_version and installed_sph < threshold
               for version, threshold in minimum_sph.items())


def check_not_pass_version(current_version, cur_hot_patch,
                           hot_patch_not_exists, LOGGER):
    """
    Decide whether a missing/outdated hot patch must be reported as "not passed"
    instead of "suggest optimization".

    :param current_version: system software version string of the device
    :param cur_hot_patch: installed hot patch version string
    :param hot_patch_not_exists: values of cur_hot_patch meaning "none installed"
    :param LOGGER: logger object
    :return: True if the patch is below a required minimum
             (check_need_update_version) or equals a known risky level
             (check_risk_patch_version)
    """
    LOGGER.logInfo("isBureau value is{}".format(py_java_env.get("isBureau")))
    below_minimum = check_need_update_version(current_version, cur_hot_patch, hot_patch_not_exists)
    if below_minimum:
        return True
    return check_risk_patch_version(current_version, cur_hot_patch, hot_patch_not_exists)


def getHotPatchXmlPathForSchedluedTask():
    """
    Locate the hot patch compatibility table (productHotPatch.xml) used by
    scheduled tasks.

    The tool root is the "user.dir" system property truncated at
    "\\tools\\inspector". Looked up in order:
      1. <root>\\data\\Patch\\inspector\\productHotPatch.xml
      2. <root>/packages/inspector/products/productHotPatch.xml

    :return: (True, path) for the first existing regular file, else (False, '')
    """
    # Both backslashes escaped explicitly; the previous '\i' only worked because
    # Python keeps unknown escapes verbatim (deprecated, warns on newer Pythons).
    toolPath = javaSystem.getProperty("user.dir").split('\\tools\\inspector')[0]
    hotPatchXmlPath = '\\'.join([toolPath, 'data', 'Patch', 'inspector', 'productHotPatch.xml'])
    # isfile() already implies exists()
    if os.path.isfile(hotPatchXmlPath):
        return True, hotPatchXmlPath

    inspector_patch_xml = os.path.join(
        toolPath, "packages",
        "inspector", "products", "productHotPatch.xml"
    )
    LOGGER.logInfo(
        "not exist productHotPatch.xml in root smartkit dir, "
        "use inspect's:{}, toolPath:{}".format(inspector_patch_xml, toolPath)
    )
    if os.path.isfile(inspector_patch_xml):
        return True, inspector_patch_xml
    return False, ''


def checkRiskVersion(deviceType, sysSpcVer, curHotPatchVer, lang, LOGGER):
    '''
    @summary: Check whether the device runs a version whose recommended hot
              patch is missing or outdated.
    @param deviceType: device model (may carry a cache-capacity suffix)
    @param sysSpcVer: system software version
    @param curHotPatchVer: currently installed hot patch version ("" / "--" = none)
    @param lang: UI language ("zh" selects the Chinese download URLs)
    @param LOGGER: logger object
    @return: (flag, needHotPatchVer, patchName, carrierpatchPach,
              EnterprisepatchPach, linkageAvaliable)
        flag:
            True: not a risk version, or no patch table was found
            False: risk version (recommended patch missing or older)
            cliUtil.RESULT_NOCHECK: the patch table could not be parsed
    '''
    flag = True
    patchName = ""
    needHotPatchVer = ""
    carrierpatchPach = ""
    EnterprisepatchPach = ""
    linkageAvaliable = ""
    hotPatchNotExists = ["", "--"]
    schedluedTaskPathFlag, schedluedTaskPath = getHotPatchXmlPathForSchedluedTask()

    # Use the table under $patchSavePath only when it actually exists there.
    # When the variable is set but the file is absent, fall back to the
    # scheduled-task/inspector table instead of trying to parse a
    # non-existent file (which ended in retries and "not checked").
    patchSavePath = javaSystem.getenv("patchSavePath")
    savePathXml = None
    if patchSavePath is not None:
        savePathXml = patchSavePath + os.sep + "productHotPatch.xml"
    if savePathXml is not None and os.path.exists(savePathXml):
        hotPatchVersionFilePath = savePathXml
    elif schedluedTaskPathFlag:
        hotPatchVersionFilePath = schedluedTaskPath
    else:
        # No patch compatibility table anywhere: nothing to compare against.
        return flag, needHotPatchVer, patchName, carrierpatchPach, EnterprisepatchPach, linkageAvaliable

    LOGGER.logInfo("exist productHotPatch.xml")
    hotPatchVersionDictList = parse_patch_xml_with_retry(hotPatchVersionFilePath)
    if not hotPatchVersionDictList:
        # Parsing the patch compatibility table failed: report "not checked"
        flag = cliUtil.RESULT_NOCHECK
        return flag, needHotPatchVer, patchName, carrierpatchPach, EnterprisepatchPach, linkageAvaliable

    hot_patch_dict = get_newest_patch(hotPatchVersionDictList, deviceType, sysSpcVer)
    for infoDict in hot_patch_dict:
        versionList = infoDict.get("productVersion", "").split(",")
        needHotPatchVer = infoDict.get("patchVersion", "")
        downloadUrlUser = infoDict.get("downloadUrl-user", "")
        flag, patchName, carrierpatchPach, EnterprisepatchPach, linkageAvaliable = \
            getRiskVersionSupportPath(lang, downloadUrlUser, deviceType, sysSpcVer, versionList, curHotPatchVer,
                                      hotPatchNotExists, needHotPatchVer)
        # First entry that flags a risk wins.
        if not flag:
            break

    return flag, needHotPatchVer, patchName, carrierpatchPach, EnterprisepatchPach, linkageAvaliable


def get_newest_patch(tool_box_patch_list, device_type, sys_spc_ver):
    """
    Choose between the toolbox patch table and the inspector's bundled one.

    The bundled table wins only when it recommends a patch for this
    model/version and that patch is newer (per check_hot_patch_version) than
    the toolbox's recommendation.

    :param tool_box_patch_list: patch table obtained from the toolbox
    :param device_type: device model
    :param sys_spc_ver: system software version
    :return: the patch table to use
    """
    bundled_list = parse_patch_xml_with_retry()
    if not bundled_list:
        return tool_box_patch_list

    bundled_pick = get_need_patch_from_xml(device_type, sys_spc_ver, bundled_list)
    toolbox_pick = get_need_patch_from_xml(device_type, sys_spc_ver, tool_box_patch_list)
    bundled_is_newer = bool(bundled_pick) and check_hot_patch_version(
        device_type, sys_spc_ver, bundled_pick, toolbox_pick)
    if not bundled_is_newer:
        return tool_box_patch_list

    LOGGER.logInfo("Using the patch xml in inspect dir!!!")
    return bundled_list


def get_need_patch_from_xml(device_type, sys_spc_ver, patch_list):
    """
    Return the recommended patch version for a model/version from a patch table.

    :param device_type: device model
    :param sys_spc_ver: system software version
    :param patch_list: parsed patch table (see parse_patch_xml_file)
    :return: "patchVersion" of the first entry whose productVersion list
             contains sys_spc_ver and whose downloadUrl productMode list
             contains device_type; "" if none matches
    """
    for info_dict in patch_list:
        version_list = info_dict.get("productVersion", "").split(",")
        if sys_spc_ver not in version_list:
            continue
        for product_mode_path_dict in info_dict.get("downloadUrl-user", ()):
            # Strip whitespace exactly like getRiskVersionSupportPath does, so
            # productMode="A, B" matches device "B" in both places.
            device_type_list = [mode.strip() for mode in
                                product_mode_path_dict.get("productMode", "").split(",")]
            if device_type in device_type_list:
                return info_dict.get("patchVersion", "")
    return ""


def parse_patch_xml_in_inspect_dir():
    """
    Parse the inspector's bundled productHotPatch.xml, memoized in objectForPy.

    The file is resolved relative to the current working directory
    (./packages/inspector/products/productHotPatch.xml). Access is serialized
    with the "fileReadLock" object taken from py_java_env.

    :return: parsed patch table (see parse_patch_xml_file); [] if the file
             does not exist
    :raises common.UnCheckException: if the file exists but cannot be parsed
    """
    base_path = os.path.abspath(".")
    file_path = os.path.join(base_path, "packages", "inspector", "products", "productHotPatch.xml")
    file_read_lock = py_java_env.get("fileReadLock")
    obj_py = py_java_env.get("objectForPy")
    file_read_lock.lock()
    try:
        cached = obj_py.get("hot_patch_version_xml_parser_res_in_inspect_dir")
        if cached:
            LOGGER.logInfo("local inspect xml parse res from memory!!!")
            return cached
        if not os.path.isfile(file_path):
            # A missing file is not a transient failure: report "no bundled
            # table" at once instead of making parse_patch_xml_with_retry
            # sleep through all of its retries.
            LOGGER.logInfo("inspect dir xml not found:{}".format(file_path))
            return []
        patch_ver_list = parse_patch_xml_file(file_path)
        LOGGER.logInfo("save inspect dir xml parse res to memory!")
        obj_py.put("hot_patch_version_xml_parser_res_in_inspect_dir", patch_ver_list)
        return patch_ver_list
    except (JException, Exception):
        # Log the full traceback (consistent with parseXMLFile), not just str(e).
        LOGGER.logError("parse inspect dir xml error:{}".format(traceback.format_exc()))
        raise common.UnCheckException("parse xml in inspect dir error", "")
    finally:
        file_read_lock.unlock()


def parse_patch_xml_with_retry(xml_path=None):
    """
    Parse a patch table, retrying on UnCheckException.

    Workaround (since 2022-09-08) for null-pointer errors seen when several
    threads read the file concurrently: retry plus a file lock. Keep watching
    whether the problem still occurs in the next release.

    :param xml_path: xml path; falsy means the inspector's bundled table
    :return: parse result, or [] after all attempts failed
    """
    retry_time = 10
    sleep_time = 5
    for attempt in range(retry_time):
        try:
            if not xml_path:
                return parse_patch_xml_in_inspect_dir()
            return parseXMLFile(xml_path)
        except common.UnCheckException as e:
            LOGGER.logError(str(e))
            # Sleeping after the final attempt only delays giving up.
            if attempt < retry_time - 1:
                time.sleep(sleep_time)
    return []


def check_hot_patch_version(device_type, current_version, needHotPatchVer,
                            cur_hot_patch_ver):
    """
    Tell whether the installed hot patch is older than the recommended one.

    For digital versions (per product.isDigitalVer, e.g. Dorado 6.0.0), the SPH
    numbers are compared numerically, so that "SPH6" < "SPH10". Otherwise, or
    when either string lacks an SPH number, plain string comparison is used.

    :param device_type: device model (logged only)
    :param current_version: system software version
    :param needHotPatchVer: recommended hot patch version
    :param cur_hot_patch_ver: installed hot patch version
    :return: True if cur_hot_patch_ver is older than needHotPatchVer
    """
    LOGGER.logInfo("device_type:{}, need path:{}, cur hot patch:{}".format(
        device_type, needHotPatchVer, cur_hot_patch_ver))
    if not product.isDigitalVer(current_version):
        return cur_hot_patch_ver < needHotPatchVer

    sph = re.compile(r'SPH(\d+)', flags=re.IGNORECASE)
    installed = sph.search(cur_hot_patch_ver)
    required = sph.search(needHotPatchVer)
    if installed is None or required is None:
        return cur_hot_patch_ver < needHotPatchVer
    return int(installed.group(1)) < int(required.group(1))



def getRiskVersionSupportPath(lang, downloadUrlUser, deviceType, sysSpcVer, versionList, curHotPatchVer,
                              hotPatchNotExists, needHotPatchVer):
    '''
    @summary: Find the Support download info for a risk version.
    @param lang: UI language; "zh" selects the *-zh URLs, anything else *-en
    @param downloadUrlUser: list of downloadUrl dicts (productMode, patchName,
                            toPatchTool and the four URL keys)
    @param deviceType: device model
    @param sysSpcVer: system software version
    @param versionList: product versions this patch entry applies to
    @param curHotPatchVer: installed hot patch version
    @param hotPatchNotExists: values of curHotPatchVer meaning "none installed"
    @param needHotPatchVer: recommended hot patch version
    @return: (flag, patchName, carrierpatchPach, EnterprisepatchPach, toPatchTool)
        flag:
            True: not a risk version (all other fields "")
            False: risk version; the other fields come from the first matching
                   entry (carrierpatchPach: carrier URL,
                   EnterprisepatchPach: enterprise URL)
    '''
    if lang == "zh":
        carrier_key, enterprise_key = "carrierUrl-zh", "enterpriseUrl-zh"
    else:
        carrier_key, enterprise_key = "carrierUrl-en", "enterpriseUrl-en"

    for entry in downloadUrlUser:
        # productMode is a comma-separated list; tolerate spaces after commas
        supported_models = [model.strip() for model in entry.get("productMode", "").split(",")]
        if deviceType not in supported_models or sysSpcVer not in versionList:
            continue
        patch_missing_or_old = (curHotPatchVer in hotPatchNotExists or
                                check_hot_patch_version(deviceType, sysSpcVer, needHotPatchVer,
                                                        curHotPatchVer))
        if not patch_missing_or_old:
            continue
        return (False,
                entry.get("patchName", ""),
                entry.get(carrier_key, ""),
                entry.get(enterprise_key, ""),
                entry.get("toPatchTool", ""))

    return True, "", "", "", ""


def parseXMLFile(filePath):
    '''
    @summary: Parse a productHotPatch.xml file, memoizing the result in the
              "objectForPy" map from py_java_env under the
              "hot_patch_version_xml_parser_res" key.
    @param filePath: path of the XML file
    @return: parsed patch table (nested list of dicts, see parse_patch_xml_file)
    @raise common.UnCheckException: on any parse error
    '''
    # NOTE(review): the cache key does not include filePath, so once any table
    # is cached, later calls return it regardless of the path passed in.
    cache_key = "hot_patch_version_xml_parser_res"
    lock = py_java_env.get("fileReadLock")
    shared = py_java_env.get("objectForPy")
    lock.lock()
    try:
        cached = shared.get(cache_key)
        if cached:
            LOGGER.logInfo("xml parse res from memory!")
            return cached
        parsed = parse_patch_xml_file(filePath)
        LOGGER.logInfo("save xml parse res to memory!")
        shared.put(cache_key, parsed)
        return parsed
    except (JException, Exception):
        LOGGER.logError("parse xml error:{}".format(traceback.format_exc()))
        raise common.UnCheckException("parse xml error", "")
    finally:
        lock.unlock()


def _first_descendant(element, tag):
    """Return the first element named `tag` in document order (element itself included)."""
    for node in element.iter(tag):
        return node
    # Same exception type the old getiterator(...)[0] indexing raised.
    raise IndexError("no <{}> element under <{}>".format(tag, element.tag))


def parse_patch_xml_file(filePath):
    """
    Parse productHotPatch.xml into a list of patch dicts.

    Each <patch> element is read as:
        <productVersion>   comma-separated system versions (whitespace-stripped)
        <patchVersion>     recommended hot patch version (whitespace-stripped)
        <downloadUrl-user> containing <downloadUrl> elements whose attributes
                           (e.g. productMode, patchName, toPatchTool) are copied
                           as-is, plus child elements carrierUrl-zh/-en and
                           enterpriseUrl-zh/-en (raw text, may be None).

    Uses Element.iter (available since Python 2.7) instead of
    Element.getiterator, which is deprecated and was removed in Python 3.9.

    :param filePath: path of the XML file
    :return: list of {"productVersion": str, "patchVersion": str,
                      "downloadUrl-user": [merged downloadUrl dicts]}
    :raises IndexError: if a required child element is missing
    """
    hot_patch_version_dict_list = []
    root_element = ET.parse(filePath).getroot()
    for patch_elem in root_element.iter("patch"):
        product_version = _first_descendant(patch_elem, "productVersion").text.strip()
        patch_version = _first_descendant(patch_elem, "patchVersion").text.strip()

        # Support download paths with their product models and patch names
        download_url_user = _first_descendant(patch_elem, "downloadUrl-user")
        support_patch_path_list = []
        for child in download_url_user.iter("downloadUrl"):
            entry = dict(child.attrib)
            # URL elements override same-named attributes (as before).
            entry.update({
                "carrierUrl-zh": _first_descendant(child, "carrierUrl-zh").text,
                "carrierUrl-en": _first_descendant(child, "carrierUrl-en").text,
                "enterpriseUrl-zh": _first_descendant(child, "enterpriseUrl-zh").text,
                "enterpriseUrl-en": _first_descendant(child, "enterpriseUrl-en").text,
            })
            support_patch_path_list.append(entry)

        hot_patch_version_dict_list.append({
            "productVersion": product_version,
            "patchVersion": patch_version,
            "downloadUrl-user": support_patch_path_list,
        })
    return hot_patch_version_dict_list


def get_cache_capacity(cli, LANG, LOGGER):
    """
    @summary: Read the controller cache capacity (e.g. " 16GB") via CLI.
    :param cli: CLI connection
    :param LANG: UI language
    :param LOGGER: logger object
    :return: " " + capacity from the first "Cache Capacity" line, with any
             ".000" removed; "" if the command failed or no such line exists
    """
    cmd = (r"show controller general |filterColumn include "
           r"columnList=Cache\sCapacity")
    echo_status, cli_ret, err_msg = cliUtil.excuteCmdInCliMode(cli, cmd, True, LANG)
    command_failed = (not echo_status
                      or echo_status == cliUtil.RESULT_NOSUPPORT
                      or echo_status == cliUtil.RESULT_NOCHECK)
    if command_failed:
        LOGGER.logInfo(
            "show controller general command executed with exception:")
        return ""

    for line in cli_ret.splitlines():
        if "cache capacity" not in line.lower():
            continue
        parts = line.split(":")
        if len(parts) >= 2:
            capacity = parts[1].strip().replace('.000', '')
            LOGGER.logInfo(
                "current device's cache capacity is{}".format(capacity))
            return " " + capacity
    return ""
