# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
from psdk.checkitem.common.base_dsl_check import BaseCheckItem
from psdk.platform.entity.check_status import CheckStatus
from psdk.dsl.dsl_common import get_version_info
from psdk.platform.util.product_util import compare_patch_version
from psdk.platform.util.product_util import get_ctrl_id_by_node_id
from psdk.dsl.dsl_common import get_engine_height

# Affected base versions -> minimum patch at which the check passes.
# "--" marks a base version for which no patch level passes the check.
PASS_VERSION_DIC = dict([
    ("6.1.2", "--"),
    ("6.1.3", "--"),
    ("6.1.5", "SPH12"),
])
# `rsf showrms` rows whose lowoj id column is shorter than this are collected as
# HyperMetro pair objects (presumably replication objects have shorter ids — TODO confirm).
REP_LOWOJ_ID_LEN = 9
# Number of HyperMetro pairs inspected per node; a node with more pairs than this and no
# flagged pair among those inspected produces a WARNING instead of PASS.
MAX_PAIR_NUM = 200


class CheckItem(BaseCheckItem):
    """Check whether any HyperMetro pair still carries the smooth flag on an affected version.

    Steps (see ``execute``):
      1. Only base versions listed in ``PASS_VERSION_DIC`` are checked, unless the installed
         patch is at or above the patch listed for that base version.
      2. On every node, up to ``MAX_PAIR_NUM + 1`` HyperMetro pairs are queried with
         ``rephc querypair``; a pair whose ``healPairState`` is 0 while its ``local cg ID``
         is still the initial value 4294967295 is reported.
    """

    def execute(self):
        """Run the check.

        :return: ``(CheckStatus, message)``. NOT_PASS carries one message per node that has a
            flagged pair; WARNING is returned when no node has a flagged pair but at least one
            node has more than ``MAX_PAIR_NUM`` pairs; otherwise PASS.
        """
        # Step 1: only base/patch combinations that are still exposed need checking.
        software_version = self.check_patch_version()
        if not software_version:
            return CheckStatus.PASS, ""
        engine_height = self.get_engine_height_rep()
        # Step 2: look for HyperMetro pairs carrying the smooth flag on every node.
        pair_flag = self.dsl("exec_on_all {}", self.get_pair_flag)
        all_msgs = []
        too_many_pairs = False
        for node_id, ab_nodes in (pair_flag or {}).items():
            if ab_nodes == [MAX_PAIR_NUM]:
                # Sentinel from get_pair_flag: nothing flagged, but the node has more pairs
                # than this check inspects.
                too_many_pairs = True
            elif ab_nodes:
                ctrl_id = get_ctrl_id_by_node_id(node_id, engine_height)
                all_msgs.append(self.get_msg("check.not.pass", software_version, ctrl_id))
        # A confirmed flagged pair outranks the "too many pairs" warning, independent of the
        # order in which per-node results are returned.
        if all_msgs:
            return CheckStatus.NOT_PASS, "\n".join(all_msgs)
        if too_many_pairs:
            return CheckStatus.WARNING, self.get_msg("item.suggestion")
        return CheckStatus.PASS, ""

    def check_patch_version(self):
        """Return the base software version if it is exposed, otherwise "".

        A base version is exposed when it is a key of ``PASS_VERSION_DIC`` and the installed
        patch is below the patch listed for it. An entry of "--" means no patch level passes,
        so that base version is always exposed.
        """
        version_info = get_version_info(self.dsl) or {}
        patch_version = (version_info.get("patch_version") or {}).get("Current Version")
        software_version = (version_info.get("base_version") or {}).get("Current Version")
        # NOTE(review): the "pmsLun check pass" wording looks copied from another check item;
        # kept unchanged in case log tooling matches on it.
        self.logger.info("base_version {},patch_version {}. pmsLun check pass".format(software_version, patch_version))
        fix_patch = PASS_VERSION_DIC.get(software_version)
        if fix_patch is None:
            # Base version not in the affected list.
            return ""
        if fix_patch != "--" and compare_patch_version(patch_version, fix_patch) >= 0:
            # Installed patch already contains the fix.
            return ""
        return software_version

    @staticmethod
    def _parse_int(value, default):
        """Convert ``value`` to int, returning ``default`` if it is missing or not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def get_pair_flag(self):
        """Scan this node's HyperMetro pairs for the smooth flag (run on every node).

        :return: ``[index]`` holding the object index (a string) of the first flagged pair;
            ``[MAX_PAIR_NUM]`` when no inspected pair is flagged but more than
            ``MAX_PAIR_NUM`` pairs were found; ``[]`` otherwise.
        """
        cls_infos = self.dsl("exec_diagnose 'rsf showcls' | vertical_parser")
        if not cls_infos:
            return []
        # Local node id of this node, used to list its RMS objects.
        local_node_id = cls_infos[0].get("local node id")
        pair_index = self.get_hyper_pair_index(local_node_id) if local_node_id else []

        # get_hyper_pair_index collects at most MAX_PAIR_NUM + 1 indexes; the extra one only
        # tells us the node has more pairs than this check covers.
        for index in pair_index[:MAX_PAIR_NUM + 1]:
            pair_infos = self.dsl("exec_diagnose 'rephc querypair {}' | vertical_parser".format(index))
            for flag in pair_infos or []:
                # healPairState is initialised to 255 and local cg ID to 4294967295; missing or
                # non-numeric values fall back to those initial values instead of raising.
                heal_state = self._parse_int(flag.get("healPairState", "255"), 255)
                local_cg_id = self._parse_int(flag.get("local cg ID", "4294967295"), 4294967295)
                if heal_state == 0 and local_cg_id == 4294967295:
                    return [index]
        if len(pair_index) > MAX_PAIR_NUM:
            return [MAX_PAIR_NUM]
        return []

    def get_hyper_metro_num(self, rms_info):
        """Return the HyperMetro object count from ``rsf showrms`` output lines.

        Reads the value of the ``type(0)totalObjNum`` line; if several such lines exist the
        last one wins, and 0 is returned when none is present.
        """
        pair_num = 0
        for rms_line in rms_info:
            parts = rms_line.split(":")
            if parts[0].strip() == "type(0)totalObjNum":
                pair_num = int(parts[1].strip())
        return pair_num

    def get_hyper_pair_index(self, cls_info):
        """Collect HyperMetro pair object indexes from ``rsf showrms`` for node ``cls_info``.

        Queries the RMS table page by page (window advanced by 200 each time, at most 199
        queries) and stops when more than ``MAX_PAIR_NUM`` indexes are collected, when the
        reported HyperMetro count is reached, or when a page shows no further objects.

        :param cls_info: local node id reported by ``rsf showcls``.
        :return: deduplicated list of object-index strings, at most ``MAX_PAIR_NUM + 1`` long.
        """
        max_queries = 199
        start_num = 0
        end_num = MAX_PAIR_NUM
        pair_index = []
        pair_num = 0
        for _ in range(max_queries):
            rms_info = self.dsl("exec_diagnose 'rsf showrms {} {} {}' |"
                                " splitlines".format(cls_info, start_num, end_num))
            # Number of HyperMetro objects reported in the command output.
            pair_num = self.get_hyper_metro_num(rms_info)
            self.logger.info("smooth have get pair num {}, pair num{}.".format(pair_index, pair_num))
            no_more_objects = False
            # Walk the table bottom-up; the "modifyFlag" row marks the top of the data rows.
            for rms_line in reversed(rms_info):
                fields = rms_line.split("|")
                first_field = fields[0].strip()
                if first_field.endswith(":/diagnose>"):
                    # CLI prompt line.
                    continue
                # A non-empty first column means this node has no (more) RMS objects: stop.
                if first_field:
                    no_more_objects = True
                    break
                modify_flag = fields[1].strip() if len(fields) > 1 else ""
                # Row whose modifyFlag column is the literal "modifyFlag" (presumably the
                # column header): stop scanning this page.
                if modify_flag == "modifyFlag":
                    break
                # Blank or truncated rows cannot hold columns 6 and 8; skip them.
                if len(fields) < 9:
                    continue
                lowoj_id = fields[6].strip()
                obj_index = fields[8].strip()
                if modify_flag.isdigit() and len(lowoj_id) < REP_LOWOJ_ID_LEN and obj_index not in pair_index:
                    pair_index.append(obj_index)
                if len(pair_index) > MAX_PAIR_NUM or len(pair_index) >= pair_num:
                    break
            if no_more_objects or len(pair_index) > MAX_PAIR_NUM or len(pair_index) >= pair_num:
                break
            start_num += 200
            end_num += 200
        self.logger.info("smooth have get pair num {}, pair num{}.".format(pair_index, pair_num))
        return pair_index

    def get_engine_height_rep(self):
        """Return ``node max / group max`` (truncated to int) from ``sys showcls``.

        Returns 0 when the command yields nothing, a field is missing or non-numeric, or
        ``group max`` is not positive.
        """
        dev_infos = self.dsl("exec_diagnose 'sys showcls' | vertical_parser")
        if not dev_infos:
            return 0
        dev_info = dev_infos[0]
        try:
            node_max = int(dev_info.get("node max"))
            group_max = int(dev_info.get("group max"))
        except (TypeError, ValueError):
            self.logger.info("invalid sys showcls info {}.".format(dev_info))
            return 0
        if group_max <= 0:
            self.logger.info("invalid sys showcls info {}.".format(dev_info))
            return 0
        return int(node_max / group_max)