# -*- coding: UTF-8 -*-
import time
from cbb.business.operate.expansion import common
from cbb.frame.context import contextUtil
from cbb.frame.rest import restUtil
from cbb.frame.tlv import tlvUtil
from utils import Products

# Upper bound, in seconds, on how long the expansion status is polled; it is
# also handed to common.threadUpProcess together with INTERVAL.
CHECK_TIME = 2400
# Passed to common.threadUpProcess alongside CHECK_TIME (presumably the
# progress refresh interval -- confirm against that helper).
INTERVAL = 5
# Seconds to sleep between two consecutive expansion status queries.
SLEEP_TIME = 10
# Number of consecutive status queries without a usable answer that are
# tolerated before the expansion is reported as failed.
NO_RESPONSE_RETRY_TIMES = 3

# Secondary storage, A series.
OCEAN_PROTECT_A = ["OceanProtect A8000"]


def execute(context):
    """
    Entry point of this step: build an ExpandPod for the context and run it.

    :param context: execution context handed over unchanged to ExpandPod
    :return: None
    """
    pod_expander = ExpandPod(context)
    pod_expander.execute_pod()


class ExpandPod:
    """
    Expand POD.

    Sends the backup / archive / copy plane IP ranges to the
    "system/expandBackupService" REST interface and then polls
    "system/expandControllerStatus" until the expansion succeeds, fails or
    times out. The outcome is reported through contextUtil.
    """

    def __init__(self, context):
        """
        Cache the logger, language, TLV handle and all plane settings that
        the step reads from the context.

        :param context: execution context of the expansion tool
        """
        self.logger = common.getLogger(context.get("logger"), __file__)
        self.result_dict = {"flag": True, "errMsg": "", "suggestion": ""}
        self.lang = contextUtil.getLang(context)
        self.context = context
        self.tlv = contextUtil.getTlv(context)
        self.product_model = contextUtil.getItem(context, "productModel")
        self.product_version = contextUtil.getItem(context, "productVersion")
        # Planes that are already configured (iterated as dict-like objects
        # exposing "startIp" / "endIp").
        self.exist_backup_plane = contextUtil.getItem(context, "exist_backup_plane")
        self.exist_archive_plane = contextUtil.getItem(context, "exist_archive_plane")
        self.exist_copy_plane = contextUtil.getItem(context, "exist_copy_plane")
        # Backup plane ranges; the *_new pair is the additional range.
        self.backup_start_ip = contextUtil.getItem(context, "backup_start_ip")
        self.backup_end_ip = contextUtil.getItem(context, "backup_end_ip")
        self.backup_start_ip_new = contextUtil.getItem(context, "backup_start_ip_new")
        self.backup_end_ip_new = contextUtil.getItem(context, "backup_end_ip_new")
        self.backup_subnet_mask = contextUtil.getItem(context, "backup_subnet_mask")
        # Archive plane ranges.
        self.archive_start_ip = contextUtil.getItem(context, "archive_start_ip")
        self.archive_end_ip = contextUtil.getItem(context, "archive_end_ip")
        self.archive_start_ip_new = contextUtil.getItem(context, "archive_start_ip_new")
        self.archive_end_ip_new = contextUtil.getItem(context, "archive_end_ip_new")
        self.archive_subnet_mask = contextUtil.getItem(context, "archive_subnet_mask")
        # Copy plane ranges.
        self.copy_start_ip = contextUtil.getItem(context, "copy_start_ip")
        self.copy_end_ip = contextUtil.getItem(context, "copy_end_ip")
        self.copy_start_ip_new = contextUtil.getItem(context, "copy_start_ip_new")
        self.copy_end_ip_new = contextUtil.getItem(context, "copy_end_ip_new")
        self.copy_subnet_mask = contextUtil.getItem(context, "copy_subnet_mask")

    def execute_pod(self):
        """
        Run the step: skip when no IP expansion is needed, otherwise issue
        the expand command and wait for its result.

        Any exception is logged and reported through
        contextUtil.handleException; common.finishProcess is always called.

        :return: None
        """
        try:
            common.threadUpProcess(self.context, CHECK_TIME, INTERVAL)
            # Expansion that starts from 2 engines does not need more IPs.
            if contextUtil.getItem(self.context, "ctrl_enclosure_num") >= 2:
                contextUtil.handleSuccess(self.context)
                return
            if not self.is_risk_model_and_version():
                contextUtil.handleSuccess(self.context)
                return
            # Issue the command that expands the network segments.
            if not self.expand_network_service():
                self.set_err_info("expand.command.execute.fail")
                return

            # Query whether the expansion has completed.
            self.check_expand_network_status()

        except Exception as e:
            self.logger.logException(e)
            contextUtil.handleException(self.context, e)
        finally:
            common.finishProcess(self.context)

    def is_risk_model_and_version(self):
        """
        Tell whether the device model/version is one this step applies to.

        :return: True for a model in OCEAN_PROTECT_A whose version compares
                 >= "1.0.0"; otherwise the context item
                 "has_container_feature"
        """
        if self.product_model in OCEAN_PROTECT_A and \
                Products.compareVersion(self.product_version, "1.0.0") >= 0:
            return True
        # Otherwise it depends on whether the container feature is present.
        return contextUtil.getItem(self.context, "has_container_feature")

    def expand_network_service(self):
        """
        Send the expand-service-network request.

        :return: True when the response error code equals "0"
        """
        param_dict = self.get_expand_network_param_dict()
        params_str = restUtil.CommonRest.getParamsJsonStr(param_dict, False)
        uri = "system/expandBackupService"
        record = common.execute_container_rest_cmd(
            self.context, uri, "POST", param=params_str)
        return record.get("error", {}).get("code") == "0"

    def get_expand_network_param_dict(self):
        """
        Build the parameter dictionary of the expand request.

        :return: dict with "storageAuth" and "controller" plus, when
                 configured, "backupPlane" / "archivePlane" / "copyPlane"
        """
        dev = contextUtil.getDevObj(self.context)
        user_info = {"username": dev.get("user"),
                     "password": dev.get("pawd")}
        # Contiguous network segment.
        if not contextUtil.getItem(self.context, "have_subnet"):
            return self.get_param_dict_without_subnet(user_info)

        # Non-contiguous network segments.
        required_ctrl_num = common.get_required_ctrl_num(
            self.context, self.product_model)
        param_dict = {
            "storageAuth": user_info,
            "controller": required_ctrl_num
        }
        # Without a backup plane nothing else is sent (archive and copy
        # planes are skipped as well).
        if not self.exist_backup_plane:
            return param_dict

        param_dict["backupPlane"] = self._build_plane_list(
            self.exist_backup_plane,
            self.backup_start_ip_new,
            self.backup_end_ip_new)
        self.set_archive_param(param_dict)
        self.set_copy_param(param_dict)
        return param_dict

    @staticmethod
    def _build_plane_list(exist_planes, start_ip_new, end_ip_new):
        """
        Build the IP range list of one plane.

        :param exist_planes: iterable of dict-like objects with "startIp"
                             and "endIp"
        :param start_ip_new: first IP of the additional range, may be empty
        :param end_ip_new: last IP of the additional range, may be empty
        :return: list of {"startIp", "endIp"} dicts; the additional range is
                 appended only when BOTH of its bounds are set
        """
        planes = [{"startIp": plane.get("startIp"),
                   "endIp": plane.get("endIp")}
                  for plane in exist_planes]
        # Both bounds are required; a half-defined range is never sent.
        if start_ip_new and end_ip_new:
            planes.append({"startIp": start_ip_new, "endIp": end_ip_new})
        return planes

    def set_archive_param(self, param_dict):
        """
        Add "archivePlane" to param_dict when archive planes exist.

        :param param_dict: request parameters, updated in place
        :return: the same param_dict
        """
        if not self.exist_archive_plane:
            return param_dict

        param_dict["archivePlane"] = self._build_plane_list(
            self.exist_archive_plane,
            self.archive_start_ip_new,
            self.archive_end_ip_new)
        return param_dict

    def set_copy_param(self, param_dict):
        """
        Add "copyPlane" to param_dict when copy planes exist.

        :param param_dict: request parameters, updated in place
        :return: the same param_dict
        """
        if not self.exist_copy_plane:
            return param_dict

        param_dict["copyPlane"] = self._build_plane_list(
            self.exist_copy_plane,
            self.copy_start_ip_new,
            self.copy_end_ip_new)
        return param_dict

    def get_param_dict_without_subnet(self, user_info):
        """
        Build the parameter dictionary for a contiguous network segment.

        :param user_info: dict with the storage "username" and "password"
        :return: dict with "storageAuth", "controller" (fixed to 4 here) and
                 the configured planes
        """
        param_dict = {
            "storageAuth": user_info,
            "controller": 4
        }
        # NOTE(review): the backup plane is skipped only when BOTH bounds
        # are missing, so a half-specified range is still sent as-is, unlike
        # the archive/copy planes below -- confirm this is intended.
        if not self.backup_start_ip and not self.backup_end_ip:
            return param_dict
        param_dict["backupPlane"] = [{"startIp": self.backup_start_ip,
                                      "endIp": self.backup_end_ip}]

        if self.archive_start_ip and self.archive_end_ip:
            param_dict["archivePlane"] = [{"startIp": self.archive_start_ip,
                                           "endIp": self.archive_end_ip}]

        if self.copy_start_ip and self.copy_end_ip:
            param_dict["copyPlane"] = [{"startIp": self.copy_start_ip,
                                        "endIp": self.copy_end_ip}]
        return param_dict

    def check_expand_network_status(self):
        """
        Poll the expansion status until it succeeds, fails or times out.

        Status "0" reports success and "-1" reports the returned error code.
        A missing status or status "2" counts as a failed query;
        NO_RESPONSE_RETRY_TIMES such queries in a row, or CHECK_TIME seconds
        without a final status, report "expand.network.service.fail".

        :return: None
        """
        # Wait before the first query.
        time.sleep(SLEEP_TIME)
        no_response_err_times = 0
        start_time = time.time()
        while time.time() - start_time <= CHECK_TIME:
            status, error_code, error_param = self.get_expand_network_status()
            if str(status) == "0":
                contextUtil.handleSuccess(self.context)
                return
            if str(status) == "-1":
                self.set_err_info("", error_code, error_param)
                return
            # Status 2 is counted as a failed query as well.
            if status is None or str(status) == "2":
                no_response_err_times += 1
            else:
                no_response_err_times = 0
            # Give up after that many consecutive failed queries.
            if no_response_err_times >= NO_RESPONSE_RETRY_TIMES:
                self.logger.logInfo("up to 3 times no response error.")
                self.set_err_info("expand.network.service.fail")
                return
            # Pause before the next query.
            time.sleep(SLEEP_TIME)

        # No final status within CHECK_TIME seconds.
        self.set_err_info("expand.network.service.fail")

    def get_expand_network_status(self):
        """
        Query the expansion status once.

        :return: tuple (status, error code, error parameters) taken from the
                 response; (None, None, None) when the query raised
        """
        try:
            uri = "system/expandControllerStatus"
            record = common.execute_container_rest_cmd(self.context, uri, "GET")
            return (record.get("data", {}).get("status"),
                    record.get("error", {}).get("code"),
                    record.get("error", {}).get("errorParam"))
        except Exception as e:
            self.logger.logException(e)
            return None, None, None

    def set_err_info(self, msg, error_code="", error_param=""):
        """
        Record the failure details and report the failure to the context.

        :param msg: message key resolved through common.getMsg; used when no
                    error code is given
        :param error_code: error code looked up in the context item
                           "errCodeMgrObj"; takes precedence over msg
        :param error_param: values substituted into the error-code message
        :return: None
        """
        self.result_dict["flag"] = False
        if not error_code:
            self.result_dict["errMsg"], self.result_dict["suggestion"] = \
                common.getMsg(self.lang, msg)
        elif self._load_err_code_msg(error_code):
            self.format_err_msg(error_param)
        contextUtil.handleFailure(self.context, self.result_dict)

    def _load_err_code_msg(self, error_code):
        """
        Load message and suggestion of an error code into result_dict.

        An unknown or non-numeric code must not break failure reporting, so
        the generic "expand.network.service.fail" message is used instead.

        :param error_code: error code returned by the device
        :return: True when the code was resolved, False when the generic
                 message was used
        """
        try:
            err_msg, suggestion = contextUtil.getItem(
                self.context, 'errCodeMgrObj', {}).get(int(error_code))
        except Exception as e:
            self.logger.logException(e)
            self.result_dict["errMsg"], self.result_dict["suggestion"] = \
                common.getMsg(self.lang, "expand.network.service.fail")
            return False
        self.result_dict["errMsg"] = err_msg
        self.result_dict["suggestion"] = suggestion
        return True

    def format_err_msg(self, error_param):
        """
        Substitute the error parameters into the message and suggestion.

        Falls back to the generic "expand.network.service.fail" message when
        the substitution fails.

        :param error_param: positional values for str.format; None or empty
                            means no parameters
        :return: None
        """
        # A response without "errorParam" yields None; unpacking None would
        # raise and needlessly discard the code-specific message.
        params = error_param or ()
        try:
            self.result_dict["errMsg"] = \
                self.result_dict["errMsg"].format(*params)
            self.result_dict["suggestion"] = \
                self.result_dict["suggestion"].format(*params)
        except Exception as e:
            self.logger.logException(e)
            self.result_dict["errMsg"], self.result_dict["suggestion"] = \
                common.getMsg(self.lang, "expand.network.service.fail")
