# -*- coding:utf-8 -*-


import re
from collections import Counter
from collections import defaultdict
from time import sleep

import paramiko
import utils.common.log as logger
from utils.common.error.hcci_error_code import get_code_msg

from plugins.eReplication.common.constant import Example
from plugins.eReplication.common.lib.conditions import Condition
from plugins.eReplication.common.lib.utils import check_host_connection
from plugins.eReplication.common.lib.utils import check_param_az_id
from plugins.eReplication.common.lib.utils import check_param_integer
from plugins.eReplication.common.lib.utils import check_param_ip
from plugins.eReplication.common.lib.utils import check_value_null
from plugins.eReplication.common.request_api import RequestApi
from plugins.eReplication.common.request_api import SERVICE_INVALID
from plugins.eReplication.common.request_api import SERVICE_IP_NOT_CONNECTED
from plugins.eReplication.common.request_api import SERVICE_PWD_ERROR


class ParamCheck(object):
    def __init__(self, kwargs, err_msg):
        """
        Keep the validation context shared by all check methods.

        :param kwargs: invocation arguments; ``kwargs["params"]`` holds the
            project parameters (``project_id``, ``param_list``, ...).
        :param err_msg: dict that the check methods update in place with
            ``{param_name: error message}`` entries.
        """
        # Result codes returned by _check_csha_quorum_params.
        self.check_ok, self.ip_or_port_error, self.password_error = 5, 6, 7
        self.kwargs = kwargs
        self.err_msg = err_msg
        self.params = kwargs.get("params")
        self.project_id = self.params.get('project_id')
        self.condition = Condition(self.project_id)

    def check_lld_ips(self):
        """
        Validate the LLD IP addresses of the eReplication servers, the
        floating IP and the VHA quorum server.

        Does nothing unless ``params`` contains ``lld_ip_list``. For each
        known parameter name, the ``ip_address`` of the first entry with that
        ``param_name`` is checked. Names absent from the list are skipped.
        Empty or malformed addresses are recorded in ``self.err_msg`` under
        the parameter name.
        :return: None
        """
        params = self.kwargs["params"]
        if 'lld_ip_list' not in params:
            return
        lld_entries = params['lld_ip_list']
        absent = object()
        for param_name in ("eReplication_Server01", "eReplication_Server02",
                           "eReplication_FloatIP", "VHA_Quorum_Server"):
            # Lazily pick the first matching entry, like a linear search.
            ip_address = next(
                (entry["ip_address"] for entry in lld_entries
                 if entry["param_name"] == param_name),
                absent)
            if ip_address is absent:
                continue
            if check_value_null(ip_address) or not check_param_ip(ip_address):
                self.err_msg.update(
                    {param_name: get_code_msg('663501') % ip_address})
        return

    def check_primary_region_ips(self):
        """
        Validate the primary and secondary eReplication IPs individually,
        then report error 663512 when both are set to the same address.
        :return: None
        """
        primary_key = "eReplication_Primary_IP"
        standby_key = "eReplication_Second_IP"
        primary_ip, standby_ip = (
            self._get_and_check_primary_ip(key)
            for key in (primary_key, standby_key))
        if not (primary_ip and standby_ip):
            return
        if primary_ip == standby_ip:
            self.err_msg.update(
                {f"{primary_key},{standby_key}": get_code_msg('663512')})

    def _get_and_check_primary_ip(self, primary_key):
        """
        Fetch the IP stored under *primary_key* and, unless the DR service is
        already installed, validate its format and reachability.

        :param primary_key: parameter name to look up in ``param_list``.
        :return: the raw parameter value (None when the key is absent).
        """
        found, ip_value = self.get_params_value_by_key(primary_key)
        # Missing keys and already-installed DR deployments skip validation.
        if not found or self.condition.is_current_dr_installed:
            return ip_value
        if check_value_null(ip_value) or not check_param_ip(ip_value):
            self.err_msg.update(
                {primary_key: get_code_msg('663501') % ip_value})
        elif not check_host_connection(ip_value):
            self.err_msg.update(
                {primary_key: get_code_msg('663507') % ip_value})
        return ip_value

    def get_params_value_by_key(self, key):
        """
        Look up *key* in ``kwargs["params"]["param_list"]``.

        :param key: value of the ``"key"`` field to search for.
        :return: ``(True, value)`` for the first matching entry, otherwise
            ``(False, None)``.
        """
        entries = self.kwargs["params"]["param_list"]
        match = next((entry for entry in entries if entry["key"] == key), None)
        if match is None:
            return False, None
        return True, match["value"]

    def check_dns_ntp(self):
        """
        Validate the two DNS/NTP OM IPs. Each one is checked only when its
        key is present. Empty or malformed values produce error 663501.
        :return: None
        """
        for dns_key in ("dns_ntp_om_ip_01", "dns_ntp_om_ip_02"):
            present, dns_ip = self.get_params_value_by_key(dns_key)
            if not present:
                continue
            if check_value_null(dns_ip) or not check_param_ip(dns_ip):
                self.err_msg.update(
                    {dns_key: get_code_msg('663501') % dns_ip})

    def check_quorum(self):
        """
        Validate the arbitration (quorum) IPs for a fresh CSHA installation.

        The key set depends on ``condition.is_upgrade_to_hcs8``. Each key must
        be present (663500), hold a valid IP (663501) and be reachable
        (663507).
        :return: None
        """
        if self.condition.is_upgrade_to_hcs8:
            quorum_key_lst = [
                "Third_site_IP_of_Arbitration_Service",
                "Arbitration_DC1_01_API", "Arbitration_DC1_02_API",
                "Arbitration_DC2_01_API", "Arbitration_DC2_02_API",
                "Arbitration_DC1_01_OM", "Arbitration_DC1_02_OM",
                "Arbitration_DC2_01_OM", "Arbitration_DC2_02_OM"]
        else:
            quorum_key_lst = [
                "Third_site_IP_of_Arbitration_Service", "Arbitration_DC1_01",
                "Arbitration_DC1_02", "Arbitration_DC2_01",
                "Arbitration_DC2_02"]
        for quorum_key in quorum_key_lst:
            present, quorum_ip = self.get_params_value_by_key(quorum_key)
            if not present:
                error_code, format_arg = '663500', quorum_key
            elif check_value_null(quorum_ip) or not check_param_ip(quorum_ip):
                error_code, format_arg = '663501', quorum_ip
            elif not check_host_connection(quorum_ip):
                error_code, format_arg = '663507', quorum_ip
            else:
                continue
            self.err_msg.update(
                {quorum_key: get_code_msg(error_code) % format_arg})

    def _check_csha_quorum_vm_params(self):
        """
        Validate the CSHA quorum VM IP and password, then try an SSH login.

        The IP must be present (663500), non-empty (663502), well-formed
        (663501) and reachable (663507). The password must be present
        (663500) and non-empty (663502). Only when both pass is an SSH login
        as ``root`` on port 22 attempted. Its failures are recorded as 663508
        (connection) or 663509 (login). All errors are keyed by parameter
        name in ``self.err_msg``.
        :return: None
        """
        arb_ip_key = "csha_quorum_ip"
        arb_user_pwd_key = "csha_quorum_vm_sudo_password"
        find_ip, arb_ip = self.get_params_value_by_key(arb_ip_key)
        find_user_pwd, arb_user_pwd = self.get_params_value_by_key(
            arb_user_pwd_key)
        is_ip_ok = False
        if not find_ip:
            self.err_msg.update(
                {arb_ip_key: get_code_msg('663500') % arb_ip_key})
        elif check_value_null(arb_ip):
            self.err_msg.update({arb_ip_key: get_code_msg('663502')})
        elif not check_param_ip(arb_ip):
            self.err_msg.update({arb_ip_key: get_code_msg('663501') % arb_ip})
        elif not check_host_connection(arb_ip):
            self.err_msg.update({arb_ip_key: get_code_msg('663507') % arb_ip})
        else:
            is_ip_ok = True
        # Report a missing password parameter the same way as a missing IP.
        if not find_user_pwd:
            self.err_msg.update(
                {arb_user_pwd_key: get_code_msg('663500') % arb_user_pwd_key})
            return
        if check_value_null(arb_user_pwd):
            self.err_msg.update({arb_user_pwd_key: get_code_msg('663502')})
            return
        if not is_ip_ok:
            return
        check_code, check_result = self._check_csha_quorum_params(
            arb_ip, "22", "root", arb_user_pwd)
        if check_code == self.ip_or_port_error:
            # Key by the parameter name (not the IP value) like every other
            # err_msg entry, so the error maps back to the input field.
            self.err_msg.update(
                {arb_ip_key: get_code_msg('663508') % check_result})
        elif check_code == self.password_error:
            self.err_msg.update({arb_user_pwd_key: get_code_msg(
                '663509') % ("root", check_result)})

    def _check_csha_quorum_params(self, ip, port, user, pwd,
                                  retries=3, retry_interval=2):
        """
        Verify SSH connectivity and credentials of the CSHA quorum VM.

        Opening the transport is retried up to *retries* times, waiting
        *retry_interval* seconds between failed attempts. The login itself is
        attempted once. The transport is always closed before returning.

        :param ip: address of the quorum VM.
        :param port: SSH port (str or int, converted with ``int``).
        :param user: login user name.
        :param pwd: login password.
        :param retries: maximum number of attempts to open the transport.
        :param retry_interval: seconds to sleep between failed attempts.
        :return: ``(self.check_ok, "")`` on success,
            ``(self.ip_or_port_error, reason)`` if no transport could be
            opened, ``(self.password_error, reason)`` if the login failed.
        """
        transport = None
        err_msg = ""
        for attempt in range(retries):
            try:
                transport = paramiko.Transport((ip, int(port)))
            except Exception as e:
                logger.error(f"Check csha quorum params failed: {e}")
                err_msg = str(e)
                # Do not sleep after the final attempt.
                if attempt < retries - 1:
                    sleep(retry_interval)
            else:
                # Only stop retrying once the transport was actually created.
                break
        if transport is None:
            return self.ip_or_port_error, err_msg
        try:
            transport.connect(username=user, password=pwd, pkey=None)
        except Exception as e:
            logger.error(f"Check csha quorum params failed: {e}")
            return self.password_error, str(e)
        finally:
            # Release the connection whether or not the login succeeded.
            transport.close()
        return self.check_ok, ""

    def _check_csha_region_az_map(self):
        """
        Validate the ``csha_region_map_info`` parameter.

        The value is a ``#``-separated list of entries (spaces are ignored).
        Each entry must have exactly two ``|``-separated sides, and each side
        must have exactly two comma-separated fields. The first field must
        pass ``check_param_az_id``. The two sides must differ in both fields.
        A missing parameter records 663500, an empty one 663502, and any
        malformed entry 663503.
        :return: None
        """
        region_map_key = "csha_region_map_info"
        find_flag, region_info = self.get_params_value_by_key(region_map_key)
        if not find_flag:
            self.err_msg.update(
                {region_map_key: get_code_msg('663500') % region_map_key})
            # region_info is None here; parsing it below would raise.
            return
        if check_value_null(region_info):
            self.err_msg.update({region_map_key: get_code_msg('663502')})
            return
        for az_volume in region_info.replace(' ', '').split('#'):
            sides = [side.split(',') for side in az_volume.split('|')]
            is_valid = (
                len(sides) == 2
                and all(len(fields) == 2 for fields in sides)
                and all(check_param_az_id(fields[0], raise_ex=False)
                        for fields in sides)
                # Both sides must differ in the AZ id and in the second field.
                and sides[0][0] != sides[1][0]
                and sides[0][1] != sides[1][1])
            if not is_valid:
                self.err_msg.update(
                    {region_map_key: get_code_msg('663503') % (
                        region_info, Example.CSHA_MAP)})

    def csha_param_check(self):
        """
        Run the CSHA parameter checks. The quorum VM is validated only when
        ``condition.csha_install_quorum`` is set. The region/AZ map is always
        validated.
        :return: None
        """
        install_quorum = self.condition.csha_install_quorum
        if install_quorum:
            self._check_csha_quorum_vm_params()
        self._check_csha_region_az_map()

    def csdr_param_check(self):
        """
        Validate the ``csdr_region_map_info`` parameter.

        The value is a ``#``-separated list of entries (spaces are ignored).
        Each entry has two ``|``-separated sides of three comma-separated
        fields, and the first field of the two sides must differ (663519).
        With more than one entry, every entry must also carry an
        ``&<priority>`` suffix. The priority is an integer in 1..1000, and
        exactly two distinct priorities must be used. Format problems record
        663503 under the parameter name; later errors overwrite earlier ones.
        :return: None
        """
        region_map_key = "csdr_region_map_info"
        find_flag, region_info = self.get_params_value_by_key(region_map_key)
        if not find_flag:
            self.err_msg.update(
                {region_map_key: get_code_msg('663500') % region_map_key})
            # region_info is None here; parsing it below would raise.
            return
        if check_value_null(region_info):
            self.err_msg.update({region_map_key: get_code_msg('663502')})
            return

        def report_format_error(example):
            self.err_msg.update(
                {region_map_key: get_code_msg('663503') % (
                    region_map_key, example)})

        region_arr = region_info.replace(' ', '').split('#')
        is_cycle = len(region_arr) > 1
        priority_set = set()
        for region in region_arr:
            value_info = region.split('&')
            volume_info = value_info[0].split('|')
            has_two_sides = len(region.split('|')) == 2
            if not is_cycle and not has_two_sides:
                report_format_error(Example.CSDR_MAP)
            elif is_cycle and len(value_info) != 2:
                report_format_error(Example.CSDR_MAP_CYCLE)
            if has_two_sides:
                regions = {item.split(",")[0] for item in volume_info}
                if len(regions) <= 1:
                    self.err_msg.update(
                        {region_map_key: get_code_msg('663519') % region_info})
                    break
            self._check_volume_info(
                volume_info, region_map_key, is_cycle=is_cycle)
            # Entries without '&' were already reported above; indexing
            # value_info[1] for them would raise IndexError.
            if is_cycle and len(value_info) > 1:
                priority = value_info[1]
                # isdecimal (not isdigit) guarantees int() will succeed.
                if priority.isdecimal() and 1 <= int(priority) <= 1000:
                    # Normalise so "1" and "01" count as the same priority.
                    priority_set.add(int(priority))
                else:
                    report_format_error(Example.CSDR_MAP_CYCLE)
        # Only judge the priority count when this parameter is otherwise
        # well-formed; unrelated errors in err_msg must not suppress it.
        if (region_map_key not in self.err_msg
                and priority_set and len(priority_set) != 2):
            report_format_error(Example.CSDR_MAP_CYCLE)

    def csdr_param_check_for_nas(self):
        """
        Validate the ``csdr_region_map_info_for_nas`` parameter.

        The value is a ``#``-separated list of entries (spaces are ignored).
        Each entry must have two ``|``-separated sides of three
        comma-separated fields. The first field of the two sides must differ
        (663519); checking stops at the first entry that violates this.
        :return: None
        """
        region_map_key = "csdr_region_map_info_for_nas"
        present, region_info = self.get_params_value_by_key(region_map_key)
        if not present:
            self.err_msg.update(
                {region_map_key: get_code_msg('663500') % region_map_key})
            return
        if check_value_null(region_info):
            self.err_msg.update({region_map_key: get_code_msg('663502')})
            return
        entries = region_info.replace(' ', '').split('#')
        multi_entry = len(entries) > 1
        for entry in entries:
            sides = entry.split('|')
            if not multi_entry and len(sides) != 2:
                self.err_msg.update(
                    {region_map_key: get_code_msg('663503') % (
                        region_map_key, Example.CSDR_MAP_FOR_NAS)})
            self._check_share_type_info(
                entry, region_map_key, is_cycle=multi_entry)
            if len(sides) != 2:
                continue
            if len({side.split(",")[0] for side in sides}) <= 1:
                self.err_msg.update(
                    {region_map_key: get_code_msg('663519') % region_info})
                break

    def _check_share_type_info(self, region, region_map_key, is_cycle=False):
        """
        Check that *region* has two ``|``-separated sides with three
        comma-separated fields each. Otherwise record error 663503.

        :param region: one NAS region-map entry.
        :param region_map_key: parameter name used as the ``err_msg`` key.
        :param is_cycle: use the multi-entry example in the error message.
        """
        if is_cycle:
            message = Example.CSDR_MAP_CYCLE_FOR_NAS
        else:
            message = Example.CSDR_MAP_FOR_NAS
        sides = region.split('|')
        well_formed = len(sides) == 2 and all(
            len(side.split(',')) == 3 for side in sides)
        if not well_formed:
            self.err_msg.update(
                {region_map_key: get_code_msg('663503') % (
                    region_map_key, message)})

    def _check_volume_info(self, volume_info, region_map_key, is_cycle=False):
        """
        Check that *volume_info* holds exactly two entries with three
        comma-separated fields each. Otherwise record error 663503.

        :param volume_info: the ``|``-separated parts of a region entry.
        :param region_map_key: parameter name used as the ``err_msg`` key.
        :param is_cycle: use the multi-entry example in the error message.
        """
        if is_cycle:
            message = Example.CSDR_MAP_CYCLE
        else:
            message = Example.CSDR_MAP
        well_formed = len(volume_info) == 2 and all(
            len(part.split(',')) == 3 for part in volume_info)
        if not well_formed:
            self.err_msg.update(
                {region_map_key: get_code_msg('663503') % (
                    region_map_key, message)})

    def check_service(self):
        """
        Validate the eReplication service settings when the service is
        already installed.

        ``eReplication_ip`` must be a valid IP and should be reachable.
        ``eReplication_port`` must be an integer, and the SyncAdmin password
        must be set. Only when the IP is well-formed, the port is an integer
        and the password is non-empty is the live service queried through
        ``RequestApi.check_dr_service``. Problems are recorded in
        ``self.err_msg``.
        :return: None
        """
        server_addr = "eReplication_ip"
        server_port = "eReplication_port"
        server_admin_pwd = "eReplication_SyncAdmin_password"
        host = self.get_params_value_by_key(server_addr)[1]
        port = self.get_params_value_by_key(server_port)[1]
        pwd = self.get_params_value_by_key(server_admin_pwd)[1]
        # If a DR service already exists, eReplication_ip etc. must be
        # configured, because only the same eReplication deployment can be
        # used.
        blocking_problem = False
        if check_value_null(host) or not check_param_ip(host):
            blocking_problem = True
            self.err_msg.update({server_addr: get_code_msg('663501') % host})
        elif not check_host_connection(host):
            # Reported, but does not prevent the service-level check below.
            self.err_msg.update(
                {server_addr: get_code_msg('663507') % host})
        if check_value_null(port) or not check_param_integer(port):
            blocking_problem = True
            self.err_msg.update({server_port: get_code_msg('663502')})
        if check_value_null(pwd):
            blocking_problem = True
            self.err_msg.update({server_admin_pwd: get_code_msg('663502')})
        if blocking_problem:
            return
        # Verify the parameters against the running service.
        request_api = RequestApi(host, "SyncAdmin", pwd, port, raise_ex=False)
        res_code, res_msg = request_api.check_dr_service()
        if res_code == SERVICE_PWD_ERROR:
            self.err_msg.update({server_admin_pwd: get_code_msg('663505')})
        if res_code == SERVICE_INVALID:
            self.err_msg.update(
                {f"{server_addr},{server_port}": get_code_msg('663506') % (
                    host, port, str(res_msg))})
        if res_code == SERVICE_IP_NOT_CONNECTED:
            self.err_msg.update({server_addr: get_code_msg('663507') % host})

    def _check_param_region_az_comply_with_the_naming_rule(
            self, region_az_lst, param_key):
        """
        Check every ``"region,az"`` string against the naming rules.

        The region must match ``<letters>-<letters>-<N>`` (lowercase, N a
        positive integer without leading zero) and be at most 32 characters;
        otherwise 663541 is recorded. The AZ must be two dot-separated labels
        of 2-50 lowercase alphanumerics/hyphens that start and end with an
        alphanumeric; otherwise 663542 is recorded. Errors go into
        ``self.err_msg`` under *param_key*.

        :param region_az_lst: iterable of ``"region,az"`` strings.
        :param param_key: parameter name used as the ``err_msg`` key.
        """
        region_naming_rule = re.compile(r"^[a-z]+-[a-z]+-[1-9]\d*$")
        az_naming_rule = re.compile(
            r"^[a-z\d][a-z\d-]{0,48}[a-z\d]\."
            r"[a-z\d][a-z\d-]{0,48}[a-z\d]$")
        for region_az in region_az_lst:
            parts = region_az.split(",")
            region = parts[0]
            # A value without a comma has no AZ; report it instead of
            # raising IndexError.
            az = parts[1] if len(parts) > 1 else ""
            # fullmatch: with re.match, '$' would also accept a trailing '\n'.
            if not region_naming_rule.fullmatch(region) or len(region) > 32:
                self.err_msg.update(
                    {param_key: get_code_msg('663541') % region_az})
            if not az_naming_rule.fullmatch(az):
                self.err_msg.update(
                    {param_key: get_code_msg('663542') % region_az})

    @staticmethod
    def _check_element_valid(value_info):
        element_valid = True
        for ele in value_info[0: -1]:
            if len(ele.split(",")) != 2:
                element_valid = False
                break
        return element_valid
