# coding: UTF-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
import os
from defusedxml import ElementTree as ET

import java.lang.System as java_system
import com.huawei.ism.tool.protocol.utils.RestUtil as rest_util
from com.huawei.ism.tool.service.common import DevInfoUtil
from com.huawei.ism.exception import IsmException

import common
from ds_rest_util import CommonRestService

# NOTE(review): py_java_env and PY_LOGGER are not defined in this file; presumably
# they are injected into the module globals by the hosting inspection framework — confirm.
# Session language from the host environment; used to pick localized messages and the
# "enterpriseUrl-<LANG>" / "carrierUrl-<LANG>" keys (presumably "zh" or "en").
LANG = common.getLang(py_java_env)
# Identifier of this check item (not referenced by the code visible in this file).
ITEM_ID = "hotpatch_version"
# Module logger; exceptions caught below are reported through LOGGER.logException.
LOGGER = common.getLogger(PY_LOGGER, __file__)


def execute(rest_conn):
    """
    Check whether the cluster runs the newest applicable hot patch.

    Candidate patches come from two XML descriptions: the toolbox file (located via
    the patchSavePath environment variable) and the file handed over by the
    inspection environment under 'hot_patch_file'. The newer candidate is reported.

    :param rest_conn: REST connection to the cluster
    :return: (result code, raw result text, user-facing message)
    """
    try:
        device = py_java_env.get("devInfo")
        installed_patch = query_product_hotpatch_ver(device, rest_conn)
        complete_version = DevInfoUtil.getCompleteDevVersion(device)
        model = device.getProductModel()

        toolbox_found, toolbox_candidate = get_product_patchinfo(
            model, complete_version, installed_patch, get_toolbox_hotpatch_file())
        inspector_found, inspector_candidate = get_product_patchinfo(
            model, complete_version, installed_patch, py_java_env.get('hot_patch_file'))
        needs_patch, chosen = get_final_patch_info(
            toolbox_found, toolbox_candidate, inspector_found, inspector_candidate)

        if not needs_patch:
            pass_detail = "Pass!\nNo patch is required.\nProduct_hotpatch_ver:{}".format(installed_patch)
            return common.INSPECT_PASS, pass_detail, ""

        newest_patch = chosen.get("patchVersion")
        # Only the first download entry of the chosen patch is shown to the user.
        first_link = chosen.get("downloadUrl-user")[0]
        enterprise_link = first_link.get("enterpriseUrl-{}".format(LANG))
        carrier_link = first_link.get("carrierUrl-{}".format(LANG))
        fail_detail = "Not Pass!\nproduct_hotpatch_ver:{}, recommend_hotpatch_ver:{}".format(
            installed_patch, newest_patch)
        message = common.get_err_msg(LANG, "hotpatch.check.failed").format(
            installed_patch, newest_patch, enterprise_link, carrier_link)
        return common.INSPECT_WARNING, fail_detail, message
    except (IsmException, Exception) as err:
        # IsmException is a Java exception type, listed explicitly next to Exception.
        LOGGER.logException(err)
        return common.INSPECT_UNNORMAL, "Exception", common.get_err_msg(LANG, "query.result.abnormal")


def query_product_hotpatch_ver(dev_node, rest_conn):
    """
    Query the hot patch version currently installed on the cluster.

    :param dev_node: device info object, used to build the DStorage REST URL prefix
    :param rest_conn: REST connection to the cluster
    :return: the installed hot patch version, or "" when the response carries none
    """
    base_uri = rest_util.getDstorageUrlHead(dev_node)
    cmd_str = "{}/api/v2/cluster/product".format(base_uri)
    product_result = CommonRestService.exec_get_gor_big_by_ds(rest_conn, cmd_str)
    # A null "data" object would raise AttributeError, and a null version would leak
    # None into the version comparisons and the report text; normalize both to "".
    product_data = (product_result or {}).get("data") or {}
    return product_data.get("hotpatch_version") or ""


def get_toolbox_hotpatch_file():
    """
    Build the path of the toolbox's hot patch description file.

    :return: path of productHotPatch.xml inside the directory named by the
             patchSavePath environment variable, or "" when that variable is unset or empty
    """
    toolbox_patch_path = java_system.getenv("patchSavePath")
    if not toolbox_patch_path:
        return ""
    # os.path.join avoids a doubled separator when the directory already ends with one.
    return os.path.join(toolbox_patch_path, 'productHotPatch.xml')


def _hotpatch_version_key(version):
    """
    Build a sort key that orders hot patch version strings naturally.

    Runs of digits compare as integers, so "SPH10" sorts after "SPH9". Plain string
    comparison would put it first. None is treated like "" (no patch installed),
    which sorts before every non-empty version.
    """
    import re
    key = []
    for digits, text in re.findall(r"(\d+)|(\D+)", version or ""):
        # The leading tag keeps ints and strings from ever being compared directly.
        key.append((0, int(digits), "") if digits else (1, 0, text))
    return key


def get_product_patchinfo(product_model, product_complete_ver, product_hotpatch_ver, hotpatch_file):
    """
    Find the newest hot patch in hotpatch_file that applies to this product and is
    newer than the installed one.

    :param product_model: product model of the device
    :param product_complete_ver: complete product version of the device
    :param product_hotpatch_ver: currently installed hot patch version ("" or None if none)
    :param hotpatch_file: path of the patch description XML; a falsy value means "no file"
    :return: (True, patch_info) for the newest applicable patch, otherwise (False, {})
    """
    if not hotpatch_file:
        return False, {}
    installed_key = _hotpatch_version_key(product_hotpatch_ver)
    best_info, best_key = None, None
    for patch_info in parse_xml_file(hotpatch_file):
        if patch_info.get("productMode") != product_model or \
                patch_info.get("productVersion") != product_complete_ver:
            continue
        candidate_key = _hotpatch_version_key(patch_info.get("patchVersion"))
        # Keep the newest applicable entry; among equal versions the first listed wins.
        if candidate_key > installed_key and (best_key is None or candidate_key > best_key):
            best_info, best_key = patch_info, candidate_key
    if best_info is None:
        return False, {}
    return True, best_info


def get_final_patch_info(toolbox_patch_flag, toolbox_patch_info, inspector_patch_flag, inspector_patch_info):
    """
    Choose which candidate patch to recommend.

    :param toolbox_patch_flag: whether the toolbox file offered a newer patch
    :param toolbox_patch_info: that patch's info dict
    :param inspector_patch_flag: whether the inspector file offered a newer patch
    :param inspector_patch_info: that patch's info dict
    :return: (True, chosen_info) when any candidate exists, otherwise (False, {});
             with two candidates the newer one wins and ties go to the inspector's
    """
    if toolbox_patch_flag and inspector_patch_flag:
        toolbox_key = _hotpatch_version_key(toolbox_patch_info.get("patchVersion"))
        inspector_key = _hotpatch_version_key(inspector_patch_info.get("patchVersion"))
        if toolbox_key > inspector_key:
            return True, toolbox_patch_info
        return True, inspector_patch_info
    if inspector_patch_flag:
        return True, inspector_patch_info
    if toolbox_patch_flag:
        return True, toolbox_patch_info
    return False, {}


def parse_xml_file(xml_file):
    '''
    @summary: Parse every <patch> node of a hot patch description XML file.
    @param xml_file: path of the XML file to parse (passed straight to ET.parse)
    @return: list of dicts, one per <patch> node, holding productVersion, patchVersion,
             productMode and a "downloadUrl-user" list of download-URL dicts;
             an empty list if the file cannot be read or parsed
    '''
    try:
        xml_element_tree = ET.parse(xml_file)
        root_element = xml_element_tree.getroot()
        hotpatch_version_dict_list = []
        # iter() replaces getiterator(), deprecated since Python 2.7 and removed in 3.9;
        # like getiterator it also yields the root element itself if its tag matches.
        for patch_node in root_element.iter("patch"):
            hotpatch_version_dict = get_product_patch_version_dict(patch_node)
            hotpatch_version_dict["downloadUrl-user"] = get_support_patch_path_list(patch_node)
            hotpatch_version_dict_list.append(hotpatch_version_dict)
        return hotpatch_version_dict_list
    except Exception as exception:
        # Best effort: any unreadable or malformed entry makes the whole file be ignored.
        LOGGER.logException(exception)
        return []


def get_support_patch_path_list(child):
    """
    Parse the Support download links listed under a <patch> node.

    :param child: a <patch> element
    :return: one dict per <downloadUrl> element under the first <downloadUrl-user>,
             holding that element's XML attributes (e.g. product model and patch name)
             merged with its carrierUrl-zh/-en and enterpriseUrl-zh/-en texts
             (None for an empty URL element)
    :raises IndexError: when <downloadUrl-user> or one of the four URL elements is missing
    """
    # iter() replaces getiterator(), deprecated since Python 2.7 and removed in 3.9.
    download_url_user = list(child.iter("downloadUrl-user"))[0]

    def first_text(node, tag):
        # Text of the first element named `tag` in node's subtree (node itself included).
        return list(node.iter(tag))[0].text

    support_patch_path_list = []
    for download_url in download_url_user.iter("downloadUrl"):
        support_patch_path_dict = {
            "carrierUrl-zh": first_text(download_url, "carrierUrl-zh"),
            "carrierUrl-en": first_text(download_url, "carrierUrl-en"),
            "enterpriseUrl-zh": first_text(download_url, "enterpriseUrl-zh"),
            "enterpriseUrl-en": first_text(download_url, "enterpriseUrl-en"),
        }
        # Merge attributes and URLs; URL keys win over same-named attributes.
        support_patch_path_list.append(dict(download_url.attrib, **support_patch_path_dict))
    return support_patch_path_list


def get_product_patch_version_dict(child):
    """
    Read the product version, patch version and product model of a <patch> node.

    :param child: a <patch> element
    :return: {"productVersion": ..., "patchVersion": ..., "productMode": ...} with
             whitespace-stripped texts ("" for an empty element); note the result key is
             "productMode" while the XML tag is "productModel"
    :raises IndexError: when one of the three elements is missing
    """
    def first_text(tag):
        # iter() replaces getiterator() (removed in Python 3.9) and likewise searches the
        # whole subtree; an empty element has text None, which must not crash .strip().
        return (list(child.iter(tag))[0].text or "").strip()

    return {
        "productVersion": first_text("productVersion"),
        "patchVersion": first_text("patchVersion"),
        "productMode": first_text("productModel"),
    }
