﻿# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.


import os
from common import context_util
import defusedxml.ElementTree as ET
from cbb.frame.base import baseUtil

# File name of the XML configuration document listing the restricted check items;
# it is resolved relative to this module's directory by get_xml_file_path().
DEFAULT_XML_FILE = "checklist_innerversion_base.xml"


def execute(context):
    """
    @summary      : Load the check-item XML configuration, filter it for the
                    current upgrade mode and return it as an XML string.
    @param context: context
    @return       : (flag, cliRet, errMsg) -- flag is always True, cliRet is the
                    filtered checklist serialized by get_check_list(), errMsg is
                    always ''
    """
    log = baseUtil.getLogger(context_util.get_logger(context), __file__)

    cur_version = context_util.get_cur_version(context)
    tgt_version = context_util.get_target_version(context)
    dev_type = context_util.get_dev_type(context)
    version_line = "current Version=[%s],target version=[%s],product model=[%s]" % (
        cur_version, tgt_version, dev_type)
    log.logInfo(version_line)

    # The upgrade mode decides which items survive filter_upgrade_manner().
    mode = context_util.get_upgrade_model(context)
    log.logInfo('upgradeMode:' + mode)

    xml_path = get_xml_file_path()
    log.logInfo('filePath:' + xml_path)
    tree = ET.parse(xml_path)

    # Filtering order matters: the "upgrademode" attribute is read by the mode
    # filter before filter_item_name() strips it.
    filter_upgrade_manner(tree, mode)
    tree = filter_item_name(tree)
    filter_no_child_module(tree)

    check_list_xml = get_check_list(tree)
    log.logInfo('check list:' + str(check_list_xml))
    return True, check_list_xml, ''


def get_xml_file_path():
    """
    @summary: Locate the check-item XML configuration used by this module.
    @return : DEFAULT_XML_FILE joined onto the directory containing this module
    """
    return os.path.join(os.path.dirname(__file__), DEFAULT_XML_FILE)


def get_check_list(node):
    """
    @Function name      : getCheckList
    @Function describe  : Serialize the tree into a checklist string prefixed
                          with an XML declaration.
    @Input              : node - tree object exposing getroot()
    @Return             : checklist as a text string
    """
    root = node.getroot()
    checklist = ET.tostring(root, encoding="utf-8")
    # Python 2's tostring() returns str, but Python 3 returns bytes, which
    # cannot be concatenated with the str declaration below. Decode only in
    # the bytes case so the Python 2 return type stays unchanged.
    if not isinstance(checklist, str):
        checklist = checklist.decode("utf-8")
    return '''<?xml version="1.0" encoding="UTF-8"?>\n''' + checklist


def filter_upgrade_manner(checklist, method):
    """
    @Function name      : filterUpgradeManner(checklist)
    @Function describe  : Filter the checklist in place by upgrade mode; every
                          "module" element is handed to check_one_model().
    @Input              : checklist - tree object exposing iter()
                          method    - upgrade mode of the current run
    @Return             : None (the tree is modified in place)
    """
    # Materialize the module list first, as Python 2's getiterator() did, so
    # that removing items inside a module cannot disturb the traversal.
    # (getiterator() itself was removed in Python 3.9.)
    for node in list(checklist.iter("module")):
        check_one_model(method, node)


def check_one_model(method, node):
    """
    @summary: Remove from node every child item whose "upgrademode" does not
              apply to the given upgrade method.
    @param method: upgrade mode of the current run (compared case-insensitively)
    @param node  : element whose direct children are the check items; modified
                   in place
    @note   : an item without "upgrademode", or with upgrademode=ALL, applies
              to every mode and is always kept.
    """
    # Walk a snapshot of the children. The previous index bookkeeping only
    # worked when getchildren() returned the live child list; with a copying
    # implementation it re-removed an already removed element (ValueError).
    for item in list(node):
        item_mode = item.attrib.get("upgrademode")
        if item_mode is None or item_mode.lower() == 'all':
            continue
        if item_mode.lower() != method.lower():
            node.remove(item)


def filter_no_child_module(checklist):
    """
    @summary: Remove every direct child of the root (a 'module' node) that has
              no child nodes left.
    @param checklist: an object which stands for the parsed XML configuration
                      tree; modified in place
    """
    root_node = checklist.getroot()
    # Iterate over a copy. On Python 2, getchildren() returned the live child
    # list, so removing while iterating it skipped the element after each
    # removal and left consecutive empty modules in the tree.
    for node_module in list(root_node):
        if len(node_module) == 0:
            root_node.remove(node_module)


def filter_item_name(checklist):
    """
    @Function name      : filterItemName(checklist)
    @Function describe  : Delete attributes that are not needed in the output:
                          "name" on each non-empty module, and "name",
                          "aftersecure", "upgralimit" and "upgrademode" on its items.
    @Input              : checklist - tree object exposing iter()
    @Return             : the same checklist object, modified in place
    @Raises             : KeyError if a non-empty module or one of its items has
                          no "name" attribute
    """
    for node in checklist.iter("module"):
        # Modules without children are skipped (they are dropped later by
        # filter_no_child_module). len() is explicit; Element truthiness is
        # deprecated.
        if len(node) == 0:
            continue
        del node.attrib["name"]
        for item in node:
            del item.attrib["name"]
            # Optional attributes; "upgrademode" may be absent, meaning ALL.
            item.attrib.pop("aftersecure", None)
            item.attrib.pop("upgralimit", None)
            item.attrib.pop("upgrademode", None)
    return checklist


def add_check_items(product, checklist):
    """
    @summary: For each child of product, append its child items to the
              checklist module found at position int(child.attrib["index"])
              among the checklist root's children.
    @param product  : element whose children carry an "index" attribute and
                      hold the items to add
    @param checklist: tree object exposing getroot(); modified in place
    @return         : the same checklist object
    """
    # Only grandchildren are appended, so the root's child list stays stable
    # and can be collected once instead of on every appended item.
    target_modules = list(checklist.getroot())
    for module in product:
        index = int(module.attrib["index"])
        for item in list(module):
            target_modules[index].append(item)
    return checklist


def delete_check_items(product, checklist):
    """
    @summary: Remove from every "module" of the checklist each item whose "id"
              matches the "id" of one of product's children.
    @param product  : element whose children carry the "id"s to delete
    @param checklist: tree object exposing iter(); modified in place
    @return         : the same checklist object
    @raises KeyError: if a product child or a checklist item has no "id"
    """
    # A set gives O(1) membership tests inside the nested loop below.
    delete_ids = set(item.attrib["id"] for item in product)
    # Snapshot both levels because items are removed during the walk.
    for module in list(checklist.iter("module")):
        for item in list(module):
            if item.attrib["id"] in delete_ids:
                module.remove(item)
    return checklist
