#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""firewall rule db"""

import copy
import netaddr

try:
    from neutron import context as neutron_context
except ImportError:
    from neutron_lib import context as neutron_context

try:
    from neutron_lib import exceptions as nexception
except ImportError:
    from neutron.common import exceptions as nexception

try:
    from neutron_lib.api.validators import validate_subnet as subnet_validate
except ImportError:
    from neutron.api.v2.attributes import _validate_subnet as \
        subnet_validate

from oslo_log import log as logging

from networking_huawei.drivers.ac.common import validate as validator
from networking_huawei.drivers.ac.db.firewall_rule import schema
from networking_huawei.drivers.ac.common import neutron_compatible_util as ncu
from networking_huawei._i18n import _LI, _LE
from networking_huawei._i18n import _

# Module-level oslo.log logger used by the firewall rule DB mixin below.
LOG = logging.getLogger(__name__)


class FirewallRuleIpAddressDuplicate(nexception.InvalidInput):
    """Invalid input: an ip address list repeats the same entry.

    ``param`` names the attribute holding the duplicated address.
    """
    message = _(
        "The same ip address exists in %(param)s.")


class FirewallRulePropertyConflict(nexception.InvalidInput):
    """Invalid input: mutually exclusive rule attributes were both given.

    ``param`` describes the conflicting pair, e.g. a single port and a
    port list.
    """
    message = _(
        "%(param)s cannot exist at the same time.")


class FirewallRulePortsLengthMax(nexception.InvalidInput):
    """Invalid input: a port list holds more than 15 entries.

    ``param`` is the direction prefix ('source' or 'destination').
    """
    message = _(
        "The number of %(param)s-ports exceeds the maximum size 15.")


class FirewallRuleIpAddressLengthMax(nexception.InvalidInput):
    """Invalid input: an ip address list holds more than 100 entries.

    ``param`` identifies which list (source or destination) is too long.
    """
    message = _("The number of %(param)s ip addresses "
                "exceeds the maximum size 100.")


class FirewallIpAddressConflict(nexception.InvalidInput):
    """Invalid input: a rule mixes IPv4 and IPv6 addresses.

    Takes no format parameters.
    """
    message = _(
        "Invalid input - IP addresses do not agree with IP Version")


class FirewallRuleInvalidPortValue(nexception.InvalidInput):
    """Invalid input: a port or port range value is not acceptable.

    ``port`` carries the offending text, e.g. a 'min:max' range whose lower
    bound exceeds its upper bound.
    """
    message = _(
        "Invalid value for port %(port)s.")


class FirewallRuleInvalidValue(nexception.InvalidInput):
    """Invalid input carrying a caller-supplied explanation.

    The rendered message is the ``param`` text followed by a period.
    """
    message = _(
        "%(param)s.")


class ACFirewallruleAddrsDbMixin(ncu.base_db.CommonDbMixin):
    """DB mixin for the AC-specific address/port columns of firewall rules.

    Records are ``schema.ACFirewallruleSchema`` rows keyed by the firewall
    rule id. Their ``source_ports``, ``destination_ports``,
    ``source_ip_addresses`` and ``destination_ip_addresses`` columns hold
    comma-separated strings. Every DB method opens its own admin context.
    """

    # Limits enforced by _get_ports and _get_ip_addresses. The same numbers
    # are hard-coded in the FirewallRulePortsLengthMax and
    # FirewallRuleIpAddressLengthMax messages; keep them in sync.
    _MAX_PORTS = 15
    _MAX_IP_ADDRESSES = 100

    # Conflict label used by _get_ports when both the single-port and the
    # port-list attribute are supplied, keyed by res_type.
    _PORT_CONFLICT_MSGS = {
        "source": "Source-port and source-ports",
        "destination": "Destination-port and destination-ports",
    }

    @staticmethod
    def _value_or_none(value):
        """Return *value*, or None when it is None or the empty string."""
        if value is None or value == '':
            return None
        return value

    def _get_ac_firewall_rule_resource(self, context, model, v_id):
        """Fetch the *model* row with id *v_id*, or None on any failure.

        Any exception from ``_get_by_id`` (presumably including "not
        found" -- confirm against CommonDbMixin) is logged and swallowed.
        update_db_ac_firewall_rule relies on the None result to fall back
        to creating the record.
        """
        try:
            return self._get_by_id(context, model, v_id)
        except Exception as ex:
            LOG.error(_LE('[ac] Get AC firewall rule addresses resource '
                          'failed : %s'), ex)
            return None

    def _make_ac_firewall_rule_dict(self, ac_firewall_rule_db, fields=None):
        """Convert a DB record into a plain dict.

        :param ac_firewall_rule_db: record exposing ``get``; may be falsy.
        :param fields: optional field filter passed to ``self._fields``.
        :returns: dict with ``id`` and the four address/port columns
            (missing values default to ''), or None for a falsy record.
        """
        LOG.info(_LI('[ac] _make_ac_firewall_rule_dict in Neutron '
                     'DB for device %s.'), ac_firewall_rule_db)
        if not ac_firewall_rule_db:
            return None
        res = {'id': ac_firewall_rule_db.get('id', ''),
               'source_ports': ac_firewall_rule_db.get('source_ports', ''),
               'destination_ports': ac_firewall_rule_db.get(
                   'destination_ports', ''),
               'source_ip_addresses': ac_firewall_rule_db.get(
                   'source_ip_addresses', ''),
               'destination_ip_addresses': ac_firewall_rule_db.get(
                   'destination_ip_addresses', '')}
        return self._fields(res, fields)

    @classmethod
    def get_string_fwr_dto(cls, fwr):
        """Join list-valued address/port fields of *fwr* with commas.

        Only non-empty lists are converted. *fwr* is modified in place and
        also returned for convenience.
        """
        for key in ('source_ip_addresses', 'destination_ip_addresses',
                    'source_ports', 'destination_ports'):
            value = fwr.get(key, None)
            if value and isinstance(value, list):
                fwr[key] = ','.join(value)
        return fwr

    def create_db_ac_firewall_rule(self, firewall_rule, fid):
        """Validate a firewall rule request and insert its AC DB record.

        :param firewall_rule: ``{'firewall_rule': {...}}``. List-valued
            address/port fields of the inner dict are joined into
            comma-separated strings in place.
        :param fid: id under which the new record is stored.
        :returns: ``[record_dict, ip_version]``.
        :raises FirewallRuleInvalidValue, FirewallRuleIpAddressDuplicate,
            FirewallIpAddressConflict, FirewallRulePropertyConflict,
            FirewallRulePortsLengthMax, FirewallRuleIpAddressLengthMax,
            FirewallRuleInvalidPortValue: on invalid input.
        """
        LOG.info(_LI('[ac] Create ac firewall rule  %s.'), firewall_rule)

        fwr = self.get_string_fwr_dto(firewall_rule['firewall_rule'])
        LOG.info(_LI('[achgs] fwr is  %s.'), fwr)

        self._validate_ips_or_subnets_or_none(fwr.get('source_ip_addresses'))
        self._validate_ips_or_subnets_or_none(
            fwr.get('destination_ip_addresses'))
        self._validate_fwr_src_dst_ips_duplicate(fwr)
        ip_version = self._validate_fwr_src_dst_ip_version(fwr)

        source_ports = self._get_ports(fwr, "source")
        destination_ports = self._get_ports(fwr, "destination")
        source_ip_addresses = self._get_ip_addresses(
            fwr.get('source_ip_address', None),
            fwr.get('source_ip_addresses', None),
            "Source-ip-address and source-ip-addresses",
            res_type="source")
        destination_ip_addresses = self._get_ip_addresses(
            fwr.get('destination_ip_address', None),
            fwr.get('destination_ip_addresses', None),
            "Destination-ip-address and destination-ip-addresses",
            res_type="destination")

        context = neutron_context.get_admin_context()
        with context.session.begin(subtransactions=True):
            firewall_rule_addrs = schema.ACFirewallruleSchema(
                id=fid,
                source_ports=source_ports,
                destination_ports=destination_ports,
                source_ip_addresses=source_ip_addresses,
                destination_ip_addresses=destination_ip_addresses)
            context.session.add(firewall_rule_addrs)
            context.session.flush()
            return [self._make_ac_firewall_rule_dict(firewall_rule_addrs),
                    ip_version]

    def rollback_db_ac_firewall_rule(self, original_fw_rule):
        """Restore the AC record of a rule to the values in *original_fw_rule*.

        Best effort: a missing record is skipped, and any exception is
        logged rather than raised.
        """
        LOG.info(_LI('[ac] rollback_db_ac_firewall_rule %s.'), original_fw_rule)
        rule_id = original_fw_rule['id']
        context = neutron_context.get_admin_context()
        try:
            with context.session.begin(subtransactions=True):
                rollback_dto = self._get_ac_firewall_rule_resource(
                    context, schema.ACFirewallruleSchema, rule_id)
                if rollback_dto:
                    rollback_dto.update(original_fw_rule)
        except Exception as ex:
            LOG.error(_LE("[ac] rollback_db_ac_firewall_rule "
                          "failed %s"), str(ex))

    def update_db_ac_firewall_rule(self, firewall_rule, original_fw_rule):
        """Merge an update into *original_fw_rule*, validate and persist it.

        For each single/list attribute pair (e.g. source_port/source_ports)
        present in the update, the supplied form takes the new value and
        the other form is cleared. If both forms are supplied, the list
        form wins. The merged dict is validated the same way as on create.
        If no AC record exists for the rule id, one is created instead.

        Validation errors propagate. DB errors, including those from the
        fallback create, are logged and swallowed.

        :returns: the merged dict (a deep copy of *original_fw_rule*) with
            normalized ports/ip addresses and an ``ip_version`` key.
        """
        fwr = self.get_string_fwr_dto(firewall_rule['firewall_rule'])

        result = copy.deepcopy(original_fw_rule)
        for key1, key2 in [('source_port', 'source_ports'),
                           ('destination_port', 'destination_ports'),
                           ('source_ip_address', 'source_ip_addresses'),
                           ('destination_ip_address',
                            'destination_ip_addresses')]:
            if key1 in fwr:
                result[key1] = fwr.get(key1, '')
                result[key2] = ''
            if key2 in fwr:
                result[key1] = ''
                result[key2] = fwr.get(key2, '')

        if 'action' in fwr:
            result['action'] = fwr['action']

        self._validate_ips_or_subnets_or_none(result.get('source_ip_addresses'))
        self._validate_ips_or_subnets_or_none(
            result.get('destination_ip_addresses'))
        self._validate_fwr_src_dst_ips_duplicate(result)
        result['ip_version'] = self._validate_fwr_src_dst_ip_version(result)
        result['source_ports'] = self._get_ports(result, "source")
        result['destination_ports'] = self._get_ports(result, "destination")
        result['source_ip_addresses'] = self._get_ip_addresses(
            result.get('source_ip_address', None),
            result.get('source_ip_addresses', None),
            "Source-ip-address and source-ip-addresses",
            res_type="source")
        result['destination_ip_addresses'] = self._get_ip_addresses(
            result.get('destination_ip_address', None),
            result.get('destination_ip_addresses', None),
            "Destination-ip-address and destination-ip-addresses",
            res_type="destination")

        LOG.info(_LI('[ac] update ac firewall rule new_firewall_rule %s.'),
                 result)
        rule_id = original_fw_rule['id']
        context = neutron_context.get_admin_context()
        try:
            with context.session.begin(subtransactions=True):
                update_firewall_rule = self._get_ac_firewall_rule_resource(
                    context, schema.ACFirewallruleSchema, rule_id)
                if update_firewall_rule:
                    update_firewall_rule.update(result)
                else:
                    self.create_db_ac_firewall_rule(
                        {'firewall_rule': result}, rule_id)
        except Exception as ex:
            LOG.error(_LE("[ac]update ac firewall rule addresses in Neutron DB "
                          "failed : %s"), str(ex))

        return result

    def delete_db_ac_firewall_rule(self, rule_id):
        """Delete the AC record for *rule_id*; a missing record is skipped."""
        LOG.info(_LI('[ac]Delete ac firewall rule in neutron id is:%s.'),
                 rule_id)
        context = neutron_context.get_admin_context()
        with context.session.begin(subtransactions=True):
            firewall_rule_db = self._get_ac_firewall_rule_resource(
                context, schema.ACFirewallruleSchema, rule_id)
            if firewall_rule_db:
                context.session.delete(firewall_rule_db)

    def get_db_ac_firewall_rule_list(self, filters=None, fields=None):
        """Return all AC records matching *filters* as dicts."""
        LOG.debug(_LI("[ac] Get ac firewall rule addresses from Neutron DB."))
        context = neutron_context.get_admin_context()
        return self._get_collection(context, schema.ACFirewallruleSchema,
                                    self._make_ac_firewall_rule_dict,
                                    filters=filters, fields=fields)

    def get_db_ac_firewall_rule(self, fid):
        """Return the AC record for *fid* as a dict, or None if unavailable."""
        LOG.debug(_LI('[ac] Get ac firewall rule addresses from Neutron '
                      'DB by id: %s.'), fid)
        context = neutron_context.get_admin_context()
        firewall_rule_addresses = self._get_ac_firewall_rule_resource(
            context, schema.ACFirewallruleSchema, fid)
        if firewall_rule_addresses:
            return self._make_ac_firewall_rule_dict(firewall_rule_addresses)
        return None

    def _validate_ips_or_subnets_or_none(self, data):
        """Check each comma-separated entry of *data* is an IP or a subnet.

        None or '' is accepted. Empty entries between commas are accepted,
        as before.

        :raises FirewallRuleInvalidValue: on the first invalid entry.
        """
        if data is None or data == '':
            return None
        for ip_addr in data.split(','):
            # Fail on the first bad entry; previously only the last entry's
            # result was checked, so earlier bad entries slipped through.
            msg = self._validate_ip_or_subnet_or_none(ip_addr)
            if msg is not None:
                raise FirewallRuleInvalidValue(param=msg)
        return None

    @classmethod
    def _validate_ip_or_subnet_or_none(cls, data, valid_values=None):
        """Return None if *data* is empty or is a valid IP or subnet.

        Otherwise return a message combining the errors reported by the IP
        validator and the subnet validator.
        """
        if data is None or data == '':
            return None
        msg_ip = validator.validate_ip_address(data, valid_values)
        if not msg_ip:
            return None
        msg_subnet = subnet_validate(data, valid_values)
        if not msg_subnet:
            return None
        return _("%(msg_ip)s and %(msg_subnet)s") % {'msg_ip': msg_ip,
                                                     'msg_subnet': msg_subnet}

    def _validate_fwr_src_dst_ip_version(self, fwr):
        """Return the rule's IP version (4 or 6), defaulting to 4.

        Versions come from the single address and the address list of each
        side; the list takes precedence. When both sides are set, the
        destination version is returned.

        :raises FirewallIpAddressConflict: if the source and destination
            versions differ, or a list mixes versions.
        Unparsable addresses make ``netaddr.IPNetwork`` raise its own error.
        """
        src_version = dst_version = None

        source_ip_address = self._value_or_none(
            fwr.get('source_ip_address', None))
        if source_ip_address is not None:
            src_version = netaddr.IPNetwork(source_ip_address).version

        destination_ip_address = self._value_or_none(
            fwr.get('destination_ip_address', None))
        if destination_ip_address is not None:
            dst_version = netaddr.IPNetwork(destination_ip_address).version

        # _validate_fwr_ip_addresses_version returns [] for None/'' input.
        source_version_list = self._validate_fwr_ip_addresses_version(
            fwr.get('source_ip_addresses', None))
        dest_version_list = self._validate_fwr_ip_addresses_version(
            fwr.get('destination_ip_addresses', None))

        if source_version_list:
            src_version = source_version_list[0]
        if dest_version_list:
            dst_version = dest_version_list[0]

        if src_version and dst_version and src_version != dst_version:
            raise FirewallIpAddressConflict()

        return dst_version or src_version or 4

    @classmethod
    def _validate_fwr_ip_addresses_version(cls, ipaddrs):
        """Return the IP version of each non-empty entry in *ipaddrs*.

        :raises FirewallIpAddressConflict: if the entries mix versions.
        """
        if not ipaddrs:
            return []
        version_list = [netaddr.IPNetwork(ipaddr).version
                        for ipaddr in ipaddrs.split(',') if ipaddr]
        if len(set(version_list)) > 1:
            raise FirewallIpAddressConflict()
        return version_list

    @classmethod
    def _validate_fwr_src_dst_ips_duplicate(cls, fwr):
        """Reject source/destination ip address lists with repeated entries.

        :raises FirewallRuleIpAddressDuplicate: for the first offending list
            (source is checked before destination).
        """
        for key, label in (('source_ip_addresses', 'source-ip-addresses'),
                           ('destination_ip_addresses',
                            'destination-ip-addresses')):
            value = fwr.get(key, None)
            if value:
                addresses = value.split(',')
                if len(addresses) != len(set(addresses)):
                    raise FirewallRuleIpAddressDuplicate(param=label)

    def _get_ports(self, fwr, res_type):
        """Validate and normalize the port list for one side of the rule.

        :param fwr: rule dict read for ``<res_type>_port`` and
            ``<res_type>_ports``.
        :param res_type: "source" or "destination". Any other value yields
            None.
        :returns: the comma-separated port list with each entry kept as
            'port' or 'min:max', or None when no port list is set.
        :raises FirewallRulePropertyConflict: both single and list forms set.
        :raises FirewallRulePortsLengthMax: more than _MAX_PORTS entries.
        :raises FirewallRuleInvalidValue: an entry is not a valid port/range.
        :raises FirewallRuleInvalidPortValue: a range has min > max.
        """
        port = ports = None
        msg = self._PORT_CONFLICT_MSGS.get(res_type, '')
        if msg:
            port = self._value_or_none(fwr.get(res_type + '_port', None))
            ports = self._value_or_none(fwr.get(res_type + '_ports', None))

        if port is not None and ports is not None:
            raise FirewallRulePropertyConflict(param=msg)

        if ports is None:
            return None

        port_list = ports.split(',')
        if len(port_list) > self._MAX_PORTS:
            raise FirewallRulePortsLengthMax(param=res_type)

        result_list = []
        for single_port in port_list:
            error = self._validate_port_range(single_port)
            if error is not None:
                raise FirewallRuleInvalidValue(param=error)
            # Use a named separator rather than `_` so the i18n `_` function
            # is not shadowed inside this method.
            min_port, _sep, max_port = single_port.partition(":")
            if max_port:
                self._validate_fwr_port_range(min_port, max_port)
                result_list.append(min_port + ":" + max_port)
            else:
                result_list.append(min_port)
        return ','.join(result_list)

    @classmethod
    def _validate_fwr_port_range(cls, min_port, max_port):
        """Raise FirewallRuleInvalidPortValue when int(min) > int(max)."""
        if int(min_port) > int(max_port):
            port_range = '%s:%s' % (min_port, max_port)
            raise FirewallRuleInvalidPortValue(port=port_range)

    @classmethod
    def _validate_port_range(cls, data):
        """Return an error message if *data* is not 'port' or 'min:max'.

        Each part must be a decimal number in 1..65535. Empty *data* is
        accepted. Returns None when valid.
        """
        if not data:
            return None
        ports = data.split(':')
        for port in ports:
            port = str(port)
            if port.isdigit():
                port = int(port)
                msg = None
                if port <= 0 or port > 65535:
                    msg = _("Invalid port '%s'") % port
            else:
                msg = _("Port '%s' is not a valid number") % port
            if msg:
                LOG.debug(msg)
                return msg
        # More than one ':' (e.g. '1:2:3') previously passed here and later
        # crashed with ValueError in _validate_fwr_port_range.
        if len(ports) > 2:
            msg = _("Invalid port range '%s'") % data
            LOG.debug(msg)
            return msg
        return None

    @classmethod
    def _get_ip_addresses(cls, old_ip_addr, old_ip_addr_list, msg,
                          res_type=None):
        """Return the validated ip address list for one side of the rule.

        :param old_ip_addr: the single-address attribute value.
        :param old_ip_addr_list: the comma-separated address list value.
        :param msg: label used in the conflict error, e.g.
            "Source-ip-address and source-ip-addresses".
        :param res_type: direction label for the length error. Defaults to
            the lower-cased text of *msg* before its first '-'.
        :returns: *old_ip_addr_list* unchanged, or None when it is empty.
        :raises FirewallRulePropertyConflict: both forms are set.
        :raises FirewallRuleIpAddressLengthMax: more than _MAX_IP_ADDRESSES
            entries.
        """
        ip_address = cls._value_or_none(old_ip_addr)
        ip_addresses = cls._value_or_none(old_ip_addr_list)

        if ip_address is not None and ip_addresses is not None:
            raise FirewallRulePropertyConflict(param=msg)

        if ip_addresses is None:
            return None

        if len(ip_addresses.split(',')) > cls._MAX_IP_ADDRESSES:
            if res_type is None:
                res_type = msg.split('-', 1)[0].lower()
            # Previously passed the builtin `type` here, producing
            # "The number of <class 'type'> ip addresses ...".
            raise FirewallRuleIpAddressLengthMax(param=res_type)
        return ip_addresses
