# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.

from psdk.checkitem.common.base_dsl_check import BaseCheckItem
from psdk.dsl.dsl_common import get_version_info
from psdk.platform.entity.check_status import CheckStatus
from psdk.platform.util.product_util import compare_patch_version

# Base versions on which CheckItem.execute always runs the pair/CG consistency check.
UPGRADE_VERSION = ["6.0.0", "6.1.0"]
# Base versions on which the check runs only if the installed patch compares lower than "SPH21".
PATCH_VERSION = ["6.0.1"]


class CheckItem(BaseCheckItem):
    """Warn when a LUN has only some of its HyperMetro/replication pairs in a consistency group.

    The pair/CG consistency check (``rep_metro_check``) is run when the base
    version is listed in ``UPGRADE_VERSION``, or when it is listed in
    ``PATCH_VERSION`` and the installed patch compares lower than ``SPH21``.
    """

    # Number of records requested per paged REST query (``range=[start-end]``).
    PAGE_SIZE = 100

    @staticmethod
    def set_pair_and_cg(pairs_info, pair_cg_dict):
        """Accumulate per-LUN pair counts and consistency-group membership counts.

        :param pairs_info: iterable of pair records (dict-like, read via ``.get``).
        :param pair_cg_dict: dict updated in place, mapping the local LUN name to
            ``[pair_count, cg_count]``.
        """
        for pair_info in pairs_info:
            # HyperMetro and replication pairs expose the local object name under different keys.
            lun_name = pair_info.get("LOCALOBJNAME") or pair_info.get("LOCALRESNAME")
            # Index 0: number of pairs on this LUN; index 1: how many of those pairs are in a CG.
            counts = pair_cg_dict.setdefault(lun_name, [0, 0])
            counts[0] += 1

            cg_id = pair_info.get("CGID")
            # An empty CGID or "--" is treated as "not in a consistency group".
            if cg_id and cg_id != "--":
                counts[1] += 1

    def execute(self):
        """Run the check.

        :return: ``(CheckStatus.WARNING, suggestion)`` when an affected version has a
            LUN whose pairs are only partially in a consistency group; otherwise
            ``(CheckStatus.PASS, "")``. Any exception raised while collecting data is
            logged and the check falls through to PASS.
        """
        self.logger.info("check omtask expand lun start")
        try:
            version_info = get_version_info(self.dsl)
            cur_version = version_info.get("base_version").get("Current Version")
            patch_version = version_info.get("patch_version").get("Current Version")

            # Suggestion text is looked up per base version.
            err_msg_key = self.get_msg("item.suggestion.{}".format(cur_version))

            if cur_version in UPGRADE_VERSION and not self.rep_metro_check(cur_version):
                return CheckStatus.WARNING, err_msg_key

            if cur_version in PATCH_VERSION:
                # Only installations patched below SPH21 are affected on these versions.
                result = compare_patch_version(patch_version, "SPH21")
                if result < 0 and not self.rep_metro_check(cur_version):
                    return CheckStatus.WARNING, err_msg_key
        except Exception as e:
            # Best-effort check: errors are logged and the item reports PASS.
            self.logger.error("error is {}".format(e))

        return CheckStatus.PASS, ""

    def rep_metro_check(self, cur_version):
        """Check that, for every LUN, the pair count and consistency-group count agree.

        HyperMetro pairs and replication pairs are tallied into the same per-LUN
        counters. A LUN is consistent when either none of its pairs is in a CG
        or all of them are.

        :param cur_version: base version string; "6.1.0" uses a filtered
            replication-pair query.
        :return: False if some LUN has a non-zero CG count that differs from its
            pair count, True otherwise.
        """
        # NOTE(review): HCRESOURCETYPE=1 presumably selects LUN-type HyperMetro pairs — confirm.
        url_metro_count = "exec_rest '/hypermetropair/count?HCRESOURCETYPE=1'"
        url_metro = "exec_rest '/hypermetropair?HCRESOURCETYPE=1&range=[{}-{}]'"
        url_rep_count = "exec_rest '/REPLICATIONPAIR/count'"
        if cur_version == "6.1.0":
            url_rep = "exec_rest '/REPLICATIONPAIR?filter" \
                      "=LOCALRESTYPE%3A%3A11%20and%20resourceSubtype%3A%3A0&range=[{}-{}]'"
        else:
            url_rep = "exec_rest '/REPLICATIONPAIR?range=[{}-{}]'"

        pair_cg_dict = dict()
        self.get_pair_info(url_metro_count, url_metro, pair_cg_dict)
        self.get_pair_info(url_rep_count, url_rep, pair_cg_dict)

        # Consistent when the LUN has no CG-bound pairs or every pair is CG-bound.
        return all(cg_number in (0, pair_number) for pair_number, cg_number in pair_cg_dict.values())

    def get_pair_info(self, url_count, url_pair, pair_cg_dict):
        """Query all pairs page by page and accumulate them into ``pair_cg_dict``.

        :param url_count: dsl command whose result has a ``COUNT`` field with the total.
        :param url_pair: dsl command template with two ``{}`` slots for the page range.
        :param pair_cg_dict: dict updated in place via ``set_pair_and_cg``.
        """
        pair_count = int(self.dsl(url_count).get("COUNT", 0))
        page_size = self.PAGE_SIZE
        # Iterating over page starts guarantees progress even when a page cannot be read;
        # the previous ``continue`` without advancing re-queried the same page forever.
        for start in range(0, pair_count, page_size):
            url = url_pair.format(start, start + page_size)
            tmp_info = self.dsl(url)
            if not isinstance(tmp_info, list):
                self.logger.error("unexpected pair query result, skip page: {}".format(url))
                continue
            self.set_pair_and_cg(tmp_info, pair_cg_dict)
