# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import traceback

import utils.common.log as logger
from utils.common.message import Message
from utils.common.fic_base import TestCase
from utils.DBAdapter.DBConnector import BaseOps
from utils.common.exception import HCCIException
from utils.common.fic_base import StepBaseInterface

from utils.business.project_condition_utils import get_project_condition_boolean
from plugins.DistributedStorage.utils.common.deploy_constant import DeployConstant
from plugins.DistributedStorage.utils.interface.DistributedStorage import DistributedStorageTool
from plugins.DistributedStorage.Deploy.scripts.PreCheck.common.device_operate import PreCheckPublicOperate


class RepOSVersionCheckInterface(StepBaseInterface):
    """Step entry point for the replication storage version check.

    Only ``execute`` (and ``retry``, which delegates to it) performs work; the
    other standard hooks return a default ``Message()``.
    """

    def __init__(self, project_id, pod_id):
        super(RepOSVersionCheckInterface, self).__init__(project_id, pod_id)
        self.pod_id = pod_id
        self.project_id = project_id

    def pre_check(self, project_id, pod_id):
        """Plugin-internal hook; nothing to pre-check here.

        :param project_id: project identifier
        :param pod_id: pod identifier
        :return: a default Message object
        """
        return Message()

    def execute(self, project_id, pod_id):
        """Standard interface: run the replication version check.

        :param project_id: project identifier
        :param pod_id: pod identifier
        :return: Message(200) on success; Message(500, err) if building or
            running the checker raised, after logging the traceback
        """
        try:
            # Construction stays inside the try so failures in __init__ are
            # reported the same way as failures in procedure().
            checker = RepVersionCheckImpl(project_id, pod_id)
            checker.procedure()
        except Exception as err:
            logger.error(traceback.format_exc())
            return Message(500, err)
        else:
            return Message(200)

    def rollback(self, project_id, pod_id):
        """Standard interface: rollback; the check changes nothing, so this is a no-op.

        :param project_id: project identifier
        :param pod_id: pod identifier
        :return: a default Message object
        """
        return Message()

    def retry(self, project_id, pod_id):
        """Standard interface: retry by running ``execute`` again.

        :return: the Message returned by ``execute``
        """
        return self.execute(project_id, pod_id)

    def check(self, project_id, pod_id):
        """Standard interface: post-step check; nothing to verify here.

        :param project_id: project identifier
        :param pod_id: pod identifier
        :return: a default Message object
        """
        return Message()


class RepVersionCheckImpl(TestCase):
    """Check that storage product versions are consistent for replication deployments.

    Two scenarios are distinguished by ``is_exp_rep``:

    * Expanding a replication cluster (``is_exp_rep`` is true): the production
      and disaster-recovery storage versions must be identical, and then the
      production version must match the plugin version.
    * Deploying the disaster-recovery side of a brand-new storage cluster
      (``is_exp_rep`` is false): the production storage version must match the
      plugin version.
    """

    def __init__(self, project_id, pod_id):
        super(RepVersionCheckImpl, self).__init__(project_id, pod_id)
        self.db_obj = BaseOps()
        self.storage_tool = DistributedStorageTool(self.project_id, self.pod_id, self.db_obj)
        # Selects the "expand replication cluster" branch of procedure();
        # evaluated once from the project's condition expression.
        self.is_exp_rep = get_project_condition_boolean(
            self.project_id,
            "!DRStorage_TFB_PD&DRStorage_TFB_Sep&(CSHAStorage_TFB|CSDRStorage_TFB)"
            "&(ExpansionAdCloudService|ExpansionServiceStorage&!TenantStorFB80)")

    def procedure(self):
        """Run the version consistency check.

        Expansion scenario: (1) production and DR storage versions must be equal;
        (2) if (1) passes, the production version must equal the plugin version.
        New-deployment scenario: the production version must equal the plugin version.

        :raises HCCIException: 626394 if production and DR storage versions differ;
            626396 (expansion) or 626395 (new deployment) if the production storage
            version differs from the plugin version.
        """
        logger.info('Start to check the product version')

        product_site_product_version = self.get_product_version(dc_name="produce")
        logger.info('Production Site Product Version:{}'.format(product_site_product_version))
        if self.is_exp_rep:
            disaster_site_product_version = self.get_product_version(dc_name="disaster")
            logger.info('Disaster Site Product Version:{}'.format(disaster_site_product_version))

            if product_site_product_version != disaster_site_product_version:
                err_msg = "Disaster Site Product Version:{}, Production Site Product Version:{}".format(
                    disaster_site_product_version, product_site_product_version)
                logger.error(err_msg)
                raise HCCIException(626394, err_msg)

        plugins_version = PreCheckPublicOperate.get_plugins_version()
        logger.info('Plugins Version:{}'.format(plugins_version))

        if product_site_product_version != plugins_version:
            if self.is_exp_rep:
                err_msg = "plugins_version:{}, Local DistributedStorage Version:{}".format(
                    plugins_version, product_site_product_version)
                logger.error(err_msg)
                raise HCCIException(626396, err_msg, product_site_product_version)
            else:
                err_msg = "plugins_version:{}, Production Site DistributedStorage Version:{}".format(
                    plugins_version, product_site_product_version)
                logger.error(err_msg)
                raise HCCIException(626395, err_msg)

        logger.info("Pass")

    def get_product_version(self, dc_name="produce"):
        """Query the storage product version of one site.

        Scans the rep-mode storage float-IP entries and queries the version via
        the first entry whose ``mode`` equals ``dc_name``.

        :param dc_name: site mode to match, e.g. "produce" or "disaster"
        :return: the product version reported for that site
        :raises HCCIException: 626278 if no entry matches ``dc_name`` or the
            query returned an empty version
        """
        logger.info("Query Pacific version")
        product_version = ""
        # Treat a missing (None) result like an empty list so the caller gets
        # the explicit HCCIException below rather than a bare TypeError.
        storage_infos = self.storage_tool.get_storage_float_ip(az_mode="rep") or []
        for storage_info in storage_infos:
            float_ip = storage_info.get("float_ip")
            if storage_info.get("mode") != dc_name:
                logger.info('Not {} float IP:{}, pass'.format(dc_name, float_ip))
                continue
            logger.info('Use {} float IP:{}'.format(dc_name, float_ip))
            product_version = PreCheckPublicOperate.get_product_version(float_ip, storage_info.get("portal_pwd"))
            break
        if not product_version:
            err_msg = 'Failed to query product version'
            logger.error(err_msg)
            raise HCCIException(626278, DeployConstant.GET_PRODUCT, err_msg)
        return product_version
