#  coding=UTF-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.

"""
@time: 2022/08/01
@file: check_ip_conflict.py
@function:
"""
import re
from Common.base import context_util
from Common.base import entity
from Common.base.entity import DeployException
from Common.base.entity import ResultFactory
from Common.protocol import ssh_util
from Common.service.physical_card_to_logical import PhysicalCardToLogical


def execute(task):
    """
    Entry point invoked by the deploy framework.

    :param task: framework task object handed to CheckIpConflict
    :return: the result produced by CheckIpConflict.check()
    """
    checker = CheckIpConflict(task)
    return checker.check()


class CheckIpConflict(object):
    """
    Check whether any IP configured by the OS network strategies is already
    answered for by another host.

    IPv4 addresses are probed with ``arping -D`` and IPv6 addresses with
    ``ndisc6``, once for every logical port the address is configured on.
    """

    # arping prints the responder's MAC inside square brackets; stop at the
    # first closing bracket so any later bracketed text is not captured.
    _ARPING_MAC_PATTERN = re.compile(r"\[([^\]]*)\]")

    def __init__(self, task):
        self._task = task
        self._java_py_env = task.getJythonContext()
        self._strategies = None
        self._physical_to_logical = None
        # raw command outputs, handed to ResultFactory with the result
        self._ssh_rets = list()
        # ip -> set of logical port names the ip is configured on
        self._ips_v4 = dict()
        self._ips_v6 = dict()
        # conflicting ip -> MAC address reported by the probe
        self._conflict_ip = dict()
        self._slot_ports_relation_dict = None
        self._err_msgs = list()
        self._logger = entity.create_logger(__file__)

    def check(self):
        """
        Run the IP conflict check.

        :return: ResultFactory.create_not_pass(...) when a conflict is found or
                 a DeployException is raised, otherwise
                 ResultFactory.create_pass(...)
        """
        try:
            self._strategies = context_util \
                .get_os_network_strategy(self._java_py_env)
            self._physical_to_logical = \
                PhysicalCardToLogical(self._java_py_env)
            self._slot_ports_relation_dict = \
                self.get_slot_ports_relation_dict()
            # Split the configured IPs into IPv4 / IPv6 first, then probe each
            # family with its own command.
            self.categorize_ips()
            self.check_ip_v4_conflict()
            self.check_ip_v6_conflict()
            if self._conflict_ip:
                for ip, mac in self._conflict_ip.items():
                    self._err_msgs.append(entity.create_msg("ip.conflict.exist")
                                          .format(ip, mac))
                return ResultFactory.create_not_pass(self._ssh_rets, self._err_msgs)
            return ResultFactory.create_pass(self._ssh_rets, self._err_msgs)
        except DeployException as e:
            self._logger.error(e.message)
            self._err_msgs.append(e.err_msg)
            if e.may_info_miss():
                self._task.openAutoRetry()
            return ResultFactory.create_not_pass(e.origin_info, self._err_msgs)

    def categorize_ips(self):
        """
        Sort every configured IP (service IPv4/IPv6 and control IP) into the
        IPv4 or IPv6 table, together with the logical port names it uses.
        """
        from com.huawei.ism.tool.distributeddeploy.entity import IpAddress
        self._logger.info("strategys:{}".format(str(self._strategies)))
        for strategy in self._strategies:
            port_name = self.get_port_name(strategy)
            self._logger.info("Port name:{}".format(str(port_name)))
            if strategy.containsIpv4():
                self.append_v4_port_name(strategy.getIpv4Address().getIp(),
                                         port_name)
            if strategy.containsIpv6():
                self.append_v6_port_name(strategy.getIpv6Address().getIp(),
                                         port_name)
            ctrl_ip = strategy.getCtrlIpAddress()
            if ctrl_ip:
                if IpAddress.IpPattern.IPV6.equals(ctrl_ip.getIpPattern()):
                    self.append_v6_port_name(ctrl_ip.getIp(), port_name)
                else:
                    self.append_v4_port_name(ctrl_ip.getIp(), port_name)

    def append_v4_port_name(self, ip, port_names):
        """
        Merge port_names into the deduplicated port set recorded for an IPv4 ip.

        :param ip: configured ip
        :param port_names: iterable of port names the ip is configured on
        """
        self._merge_port_names(self._ips_v4, ip, port_names)

    def append_v6_port_name(self, ip, port_names):
        """
        Merge port_names into the deduplicated port set recorded for an IPv6 ip.

        :param ip: configured ip
        :param port_names: iterable of port names the ip is configured on
        """
        self._merge_port_names(self._ips_v6, ip, port_names)

    @staticmethod
    def _merge_port_names(ip_table, ip, port_names):
        """Union port_names into the set stored under ip in ip_table."""
        # Copy into a fresh set: the stored value is a set, which has no
        # extend(), so repeated ips must be merged with update().
        merged = set(ip_table.get(ip, ()))
        merged.update(port_names)
        ip_table[ip] = merged

    def get_port_name(self, strategy):
        """
        Resolve the port names for a strategy according to its networking mode.

        :param strategy: configuration strategy
        :return: list of port names
        :raises DeployException: (MAY_INFO_MISS) for an unsupported strategy type
        """
        from com.huawei.ism.tool.distributeddeploy.logic.importfile.entity.osnetwork \
            import PhysicalConfigStrategy, BondConfigStrategy, \
            BondToVlanConfigStrategy, PhysicalToVlanConfigStrategy
        if isinstance(strategy, (PhysicalConfigStrategy,
                                 PhysicalToVlanConfigStrategy)):
            return self.get_single_prot_name(strategy)
        if isinstance(strategy, (BondConfigStrategy, BondToVlanConfigStrategy)):
            return self.get_multi_port_name(strategy)
        raise DeployException(
            "parsed port name failed",
            err_code=DeployException.ErrCode.MAY_INFO_MISS)

    def get_single_prot_name(self, strategy):
        """
        Map the strategy's single physical port to its logical name.

        :param strategy: physical (or physical-to-vlan) strategy
        :return: one-element list with the logical port name
        """
        return self.mapping_logical_port([strategy.getPhysicalPort()])

    def get_multi_port_name(self, strategy):
        """
        Return the bond name plus the logical names of its member ports, so the
        member physical ports are probed as well (compatibility handling).

        :param strategy: bond (or bond-to-vlan) strategy
        :return: list of port names, member ports first, bond name last
        """
        bond_port = strategy.getBondPort()
        port_names = self.mapping_logical_port(bond_port.getPhysicalPorts())
        port_names.append(bond_port.getBondName())
        return port_names

    def mapping_logical_port(self, physical_ports):
        """
        Translate physical (slot, port) pairs into logical port names,
        recording each lookup's command output in _ssh_rets.

        :param physical_ports: iterable of physical port objects
        :return: list of logical port names, in input order
        """
        port_names = list()
        for physical_port in physical_ports:
            port_name, ssh_ret = self._physical_to_logical.get_logical_name(
                physical_port.getSlot(), physical_port.getPort(),
                self._slot_ports_relation_dict)
            self._ssh_rets.append(ssh_ret)
            port_names.append(port_name)
        return port_names

    def get_slot_ports_relation_dict(self):
        """
        Build the slot -> ports table from all strategies, keyed by
        lower-cased slot name.

        :return: dict mapping lower-cased slot name to its port collection
        """
        slot_to_ports_dict = dict()
        for strategy in self._strategies:
            slot_to_ports = strategy.getSlotToPortsRelation()
            for slot in slot_to_ports.keys():
                ports = slot_to_ports[slot]
                # NOTE(review): the first collection seen for a slot is stored
                # as-is (not copied), so ports from later strategies are merged
                # into that same object. Kept to preserve the collection type
                # passed on to get_logical_name.
                slot_to_ports_dict.setdefault(slot.lower(), ports).update(ports)
        return slot_to_ports_dict

    def check_ip_v4_conflict(self):
        """Probe every configured IPv4 address on each of its ports."""
        self._collect_conflicts(self._ips_v4, self.get_mac_address_by_ip_v4)

    def check_ip_v6_conflict(self):
        """Probe every configured IPv6 address on each of its ports."""
        self._collect_conflicts(self._ips_v6, self.get_mac_address_by_ip_v6)

    def _collect_conflicts(self, ip_table, probe):
        """Record in _conflict_ip any non-empty MAC returned by probe(ip, port)."""
        for ip, port_names in ip_table.items():
            for port_name in port_names:
                mac = probe(ip, port_name)
                if mac:
                    self._conflict_ip[ip] = mac

    def _exec_probe_cmd(self, cmd):
        """
        Run a probe command over SSH, recording its output in _ssh_rets.

        :return: command output, or "" when the command itself failed
                 (best effort: the failure's origin_info is still recorded)
        """
        try:
            ssh_ret = ssh_util.exec_ssh_cmd(self._java_py_env, cmd)
        except DeployException as e:
            self._ssh_rets.append(e.origin_info)
            return ""
        self._ssh_rets.append(ssh_ret)
        return ssh_ret

    def get_mac_address_by_ip_v4(self, ip, port_name):
        """
        Use arping duplicate-address detection to see whether another host
        answers for ip; if so, extract its MAC address.

        :param ip: configured ip
        :param port_name: port to probe from
        :return: MAC address of the conflicting host, or "" when none answered
                 or the command failed
        :raises DeployException: (MAY_INFO_MISS) if a reply line has no MAC
        """
        cmd = "arping -D -I {} {} -c 1".format(port_name, ip)
        ssh_ret = self._exec_probe_cmd(cmd)
        # Parsing is outside the SSH try so a parse failure propagates to
        # check() instead of being mistaken for "no conflict".
        for line in ssh_ret.splitlines():
            if "Unicast" in line:
                return self.parse_v4_mac_address(line)
        return ""

    def parse_v4_mac_address(self, ssh_ret):
        """
        Extract the MAC that arping prints inside square brackets.

        :param ssh_ret: one line of arping output
        :return: MAC address
        :raises DeployException: (MAY_INFO_MISS) if no bracketed text is found
        """
        match = self._ARPING_MAC_PATTERN.search(ssh_ret)
        if match:
            return match.group(1)
        raise DeployException(
            "parsed mac address failed",
            err_code=DeployException.ErrCode.MAY_INFO_MISS)

    def get_mac_address_by_ip_v6(self, ip, port_name):
        """
        Use ndisc6 to see whether another host answers for ip. ndisc6 is only
        bundled from version 810 onwards, so a failing command is tolerated.

        :param ip: configured ip
        :param port_name: port to probe from
        :return: MAC address of the conflicting host, or "" when none answered
                 or the command failed
        :raises DeployException: (MAY_INFO_MISS) if a Target line has no address
        """
        cmd = "ndisc6 {} {}".format(ip, port_name)
        ssh_ret = self._exec_probe_cmd(cmd)
        for line in ssh_ret.splitlines():
            if "Target" in line:
                return self.parse_v6_mac_address(line)
        return ""

    def parse_v6_mac_address(self, ssh_ret):
        """
        Extract the MAC that ndisc6 prints after "address: ".

        :param ssh_ret: one line of ndisc6 output
        :return: MAC address
        :raises DeployException: (MAY_INFO_MISS) if the separator is missing
        """
        splits = ssh_ret.split("address: ")
        if len(splits) > 1:
            return splits[1].strip()
        raise DeployException(
            "parsed mac address failed",
            err_code=DeployException.ErrCode.MAY_INFO_MISS)
