# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.

import re

from psdk.checkitem.common.base_dsl_check import BaseCheckItem
from psdk.platform.entity.check_status import CheckStatus
from psdk.dsl.dsl_common import get_version_info
from psdk.platform.util.product_util import compare_patch_version

# Base version -> patch level the installed patch is compared against.
# CheckItem.execute reports NOT_PASS when the device's base version is a key
# here and compare_patch_version(current_patch, value) is negative.
DORADO_RISK_PATCH = {
    "6.0.1": "SPH25",
    "6.1.0": "SPH15",
    "6.1.2": "SPH7",
}

# Device model -> patch level the installed patch is compared against.
# Looked up by model name only; the base version is not consulted for these.
OCEAN_PROTECT_MODEL_AND_PATCH = {
    "OceanProtect X8000": "SPH1",
    "OceanProtect X9000": "SPH1",
}


class CheckItem(BaseCheckItem):
    """Check whether the device's installed patch / SPC level is high enough.

    Three independent rules are evaluated in order by ``execute``:

    1. Base versions listed in ``DORADO_RISK_PATCH`` need a patch that does
       not compare lower than the mapped patch level.
    2. Models listed in ``OCEAN_PROTECT_MODEL_AND_PATCH`` need a patch that
       does not compare lower than the mapped patch level.
    3. Model ``OceanProtect A8000`` needs a dataprotect container version of
       SPC2 or later.
    """

    # Matches e.g. "global.version   1.0.0.SPC2".
    # Group 1 is the complete version string, group 2 only the SPC number.
    # Capturing the digits in the pattern keeps the extraction consistent with
    # re.IGNORECASE (a case-sensitive split on "SPC" would not split "spc2").
    _A8000_VERSION_PATTERN = re.compile(r"global\.version\s+(1\.0\.0\.SPC(\d+))", re.IGNORECASE)

    def execute(self):
        """Run the check.

        :return: tuple ``(CheckStatus, message)``; the message is an empty
            string when the check passes.
        """
        model = self.context.dev_node.model
        version_info = get_version_info(self.dsl)
        # NOTE(review): assumes both "patch_version" and "base_version" entries
        # are present in version_info; a missing entry would raise
        # AttributeError here - confirm against get_version_info.
        patch_version = version_info.get("patch_version").get("Current Version")
        base_version = version_info.get("base_version").get("Current Version")
        self.logger.info("version is : {}, patch version is:{}".format(base_version, patch_version))

        # Rule 1: keyed by base version.
        if (
            base_version in DORADO_RISK_PATCH
            and compare_patch_version(patch_version, DORADO_RISK_PATCH.get(base_version)) < 0
        ):
            self.logger.info(
                "dorado model:{} version is : {}, patch version is:{}".format(model, base_version, patch_version)
            )
            return CheckStatus.NOT_PASS, self.get_msg("check.not.pass", base_version, patch_version)

        # Rule 2: keyed by model name only (base version is not consulted).
        if (
            model in OCEAN_PROTECT_MODEL_AND_PATCH
            and compare_patch_version(patch_version, OCEAN_PROTECT_MODEL_AND_PATCH.get(model)) < 0
        ):
            self.logger.info(
                "protect model:{} version is : {}, patch version is:{}".format(model, base_version, patch_version)
            )
            return CheckStatus.NOT_PASS, self.get_msg("check.not.pass", base_version, patch_version)

        # Rule 3: A8000 is judged by the SPC number of the container version.
        if model == "OceanProtect A8000":
            spc_version, complete_version = self.get_a8000_spc_version()
            self.logger.info("a8000 version is : {}, spc_version:{}".format(complete_version, spc_version))
            if int(spc_version) >= 2:
                return CheckStatus.PASS, ""
            return CheckStatus.NOT_PASS, self.get_msg("check.not.pass.a8000", complete_version)

        return CheckStatus.PASS, ""

    def get_a8000_spc_version(self):
        """Query the dataprotect container application and extract its version.

        :return: tuple ``(spc_number, complete_version)`` of strings, e.g.
            ``("2", "1.0.0.SPC2")``. When no ``global.version`` line matching
            ``1.0.0.SPC<n>`` is found in the command output, returns
            ``("0", "unknown")``.
        """
        ret = self.dsl("exec_cli 'show container_application general name=dataprotect'", need_log=False)
        # Only the first occurrence is used, same as taking findall()[0].
        match = self._A8000_VERSION_PATTERN.search(str(ret))
        if match is None:
            return "0", "unknown"
        complete_version, spc_version = match.groups()
        return spc_version, complete_version
