#!/usr/bin/env python
# coding=UTF-8
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
import copy
import re
from Common.base import entity, context_util
from Common.base.context_util import get_mapping_attribute, get_mapping_attribute_url
from Common.base.entity import CheckCommon
from Common.protocol import ssh_util

# Environment handle passed to every framework helper in this module.
# NOTE(review): `py_java_env` is not defined in this file -- presumably it is
# injected into the module globals by the hosting framework; confirm.
PY_JAVA_ENV = py_java_env

# Package-name fragments checked by this item. Each one is grepped from
# `rpm -qa` output and combined with BASE_KEY to form its mapping/result key.
MAPPING_VERSION_KEYS = ("pcieinf", "bsp", "cma", "memf")
# Per-package key template, e.g. "bsp" -> "bsp_base_driver_version".
BASE_KEY = "{}_base_driver_version"


def execute(task):
    """Entry point: build the base-driver version checker for *task* and run it.

    :param task: task object handed over by the calling framework
    :return: whatever ``CheckDiskDriverVersion.check()`` returns
    """
    checker = CheckDiskDriverVersion(task)
    return checker.check()


class CheckDiskDriverVersion(CheckCommon):
    """Check installed base-driver packages against the versions in the mapping table.

    For each key in MAPPING_VERSION_KEYS:

    * ``obtain_match_versions`` reads the required version (and its URL) from
      the mapping table.
    * ``check_version`` reads the installed version from ``rpm -qa`` over SSH
      and compares the two.

    Versions and failing results are recorded on the deploy node through
    ``putVersion`` / ``putResult``.
    """

    # Matches "euler-<d>.<d>.<d>-<release>.aarch64" and captures <release>.
    # The dots are escaped so they only match a literal '.'.
    _RPM_VERSION_PATTERN = re.compile(r"euler-\d+\.\d+\.\d+-(.*?)\.aarch64")

    def __init__(self, task):
        """
        :param task: task object handed over by the calling framework
        """
        self.deploy_node = context_util.get_deploy_node(PY_JAVA_ENV)
        # Parallel lists indexed like MAPPING_VERSION_KEYS, filled by obtain_match_versions().
        self._mapping_versions = []
        self._mapping_urls = []
        super(CheckDiskDriverVersion, self).__init__(task)

    @staticmethod
    def cur_version_lt_map_version(cur_version, map_version):
        """
        Compare a base-driver version with the version required by the mapping table.

        Both versions are split on '.' and '-', and sections are compared
        pairwise from left to right.

        * A section that parses as a hexadecimal integer is compared numerically.
        * Any other section is compared as a string.
        * A numeric section orders before a non-numeric one.
        * The first differing section decides the result.

        Only the common prefix is compared, so e.g. "1.2" is not considered
        lower than "1.2.3".

        :param cur_version: currently installed version
        :param map_version: version required by the mapping table
        :return bool: True if the current version is lower than the mapping
                      version, otherwise False
        """

        def section_key(section):
            # Tag sections so ints and strings are never compared directly:
            # that raises TypeError on Python 3 and is arbitrary on Python 2.
            # The numeric tag (0) sorts before the string tag (1).
            try:
                return 0, int(section, 16)
            except ValueError:
                return 1, section

        cur_sections = re.split(r"[.-]", cur_version)
        map_sections = re.split(r"[.-]", map_version)
        for cur_section, map_section in zip(cur_sections, map_sections):
            cur_key = section_key(cur_section)
            map_key = section_key(map_section)
            # Stop at the first differing section. Continuing past a higher
            # section would wrongly report e.g. "2.0" as lower than "1.5".
            if cur_key != map_key:
                return cur_key < map_key
        return False

    def obtain_match_versions(self):
        """Read the required version and its URL for every package from the mapping table.

        Appends to ``self._mapping_versions`` / ``self._mapping_urls`` in the
        order of MAPPING_VERSION_KEYS, and records a "not supported" version
        for packages that have no mapping entry.
        """
        for key in MAPPING_VERSION_KEYS:
            version_key = BASE_KEY.format(key)
            version = get_mapping_attribute(PY_JAVA_ENV, version_key)
            self._mapping_versions.append(version)
            self.set_not_support_msg(key, version)
            self._mapping_urls.append(get_mapping_attribute_url(PY_JAVA_ENV, version_key))

    def set_not_support_msg(self, key, version):
        """Record ``self.not_support_msg`` as the version of *key* when *version* is empty.

        :param key: package-name fragment from MAPPING_VERSION_KEYS
        :param version: version read from the mapping table (may be empty/None)
        """
        if not version:
            self.deploy_node.putVersion(BASE_KEY.format(key), self.not_support_msg)

    def get_match_msg(self):
        """Build the message listing the required versions from the mapping table.

        Missing mapping versions are shown as ``self.not_support_msg``. The text
        is wrapped with the URL of the first mapping entry (empty if there is
        none) and passed through ``entity.create_source_file_msg``.

        :return: the formatted message
        """
        mapping_versions = [version if version else self.not_support_msg
                            for version in self._mapping_versions]
        url = self._mapping_urls[0] if self._mapping_urls else ""
        match_msg = entity.build_url_error_msg(
            url, entity.create_msg("base.driver.version.mapping.match").format(*mapping_versions))
        return entity.create_source_file_msg(PY_JAVA_ENV, match_msg)

    def check_version(self):
        """Compare each installed base-driver package with its mapping-table version.

        Expects ``obtain_match_versions`` to have run first: it indexes
        ``self._mapping_versions`` by the position in MAPPING_VERSION_KEYS.
        Keys not requested by the current task are skipped. For each checked
        key, the installed version is recorded with ``putVersion``. A failing
        result is recorded with ``putResult`` when the package is missing or
        its version is lower than required. The collected messages are
        appended to ``self._err_msgs`` as one newline-joined string.

        :return bool: True if every checked package meets its mapping version
        """
        version_match = True
        msg_list = []
        for index, key in enumerate(MAPPING_VERSION_KEYS):
            version_key = BASE_KEY.format(key)
            if not context_util.contain_need_check_key(PY_JAVA_ENV, [version_key]):
                continue
            ssh_ret = ssh_util.exec_ssh_cmd_nocheck(PY_JAVA_ENV, "rpm -qa |grep {}".format(key))
            self._ssh_rets.append(ssh_ret)
            match = self._RPM_VERSION_PATTERN.search(ssh_ret)
            version = match.group(1) if match else ""
            self.deploy_node.putVersion(version_key, version)

            map_version = self._mapping_versions[index]
            # No required version in the mapping table: only the installed version is recorded.
            if not map_version:
                continue
            if not version:
                msg_list.append(entity.create_msg("no.associated.base.driver.exist").format(key))
                version_match = False
                self.deploy_node.putResult(version_key, context_util.get_not_pass_key())
                continue
            msg_list.append(entity.create_msg("{}.base.driver.current.version".format(key)).format(version))
            if self.cur_version_lt_map_version(version, map_version):
                version_match = False
                self.deploy_node.putResult(version_key, context_util.get_not_pass_key())

        self._err_msgs.append('\n'.join(msg_list))
        return version_match
