#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""Snat db"""

import netaddr
import sqlalchemy as sa
from sqlalchemy.orm import exc
from oslo_utils import uuidutils
from neutron.db import models_v2
try:
    from neutron.db import common_db_mixin
except ImportError:
    from networking_huawei.drivers.ac.common import common_db_mixin
try:
    from neutron.db.l3_db import Router
except ImportError:
    from neutron.db.models.l3 import Router
from networking_huawei.drivers.ac.common.neutron_compatible_util import \
    ac_log as logging
try:
    from neutron.db import model_base
except ImportError:
    from neutron_lib.db import model_base

from networking_huawei._i18n import _LI
from networking_huawei.drivers.ac.extensions.snat import snat as extension

LOG = logging.getLogger(__name__)

# Integer codes stored in Snat.type and compared against snat['type'] by the
# SnatDbMixin validators.
# NAT44 rules must carry snat_ip_address and must not carry snat_ip_pool.
NAT44 = 0
# NAT64 rules must carry snat_ip_pool and must not carry snat_ip_address.
NAT64 = 1


class Snat(model_base.BASEV2):
    """SQLAlchemy model for the ``huawei_ac_snat`` table.

    Each row describes one SNAT rule bound to a router. Rows are validated
    and written by ``SnatDbMixin``.
    """
    __tablename__ = 'huawei_ac_snat'
    # Primary key; a UUID is generated when the caller does not supply one.
    id = sa.Column(sa.String(36), primary_key=True,
                   default=uuidutils.generate_uuid)
    name = sa.Column(sa.String(255), nullable=True)
    tenant_id = sa.Column(sa.String(255), nullable=False)
    # Router the rule belongs to (foreign key to routers.id).
    router_id = sa.Column(sa.String(36), sa.ForeignKey('routers.id'))
    # Optional network (foreign key to networks.id).
    snat_network_id = sa.Column(sa.String(36), sa.ForeignKey('networks.id'))
    # Single translated address; required for NAT44, rejected for NAT64.
    snat_ip_address = sa.Column(sa.String(64), nullable=True)
    # Pickled Python object (pickle protocol 2). The mixin validates it as a
    # list of at most 8 mappings with 'begin_ip' / 'end_ip' keys (NAT64 only).
    snat_ip_pool = sa.Column(sa.PickleType(protocol=2), nullable=True)
    # Pickled list of CIDR strings; all entries must share one IP version.
    original_cidrs = sa.Column(sa.PickleType(protocol=2), nullable=True)
    # Rule type: NAT44 (0) or NAT64 (1), see the module-level constants.
    type = sa.Column(sa.Integer, nullable=False)


class SnatDbMixin(common_db_mixin.CommonDbMixin):
    """Snat db

    Validation and CRUD helpers for rows of the ``huawei_ac_snat`` table.
    ``_fields``, ``_get_by_id`` and ``_get_collection`` come from
    ``CommonDbMixin``.
    """

    # Upper bound on the number of entries accepted in ``snat_ip_pool``.
    _MAX_SNAT_IP_POOLS = 8

    def _make_snat_dict(self, snat_db, fields=None):
        """Convert a Snat row into the dict returned to API callers.

        :param snat_db: Snat model instance (item access is used)
        :param fields: optional list of keys to keep, applied by ``_fields``
        :returns: dict of SNAT attributes with ``type`` coerced to ``int``
        """
        res = {
            'id': snat_db['id'],
            'name': snat_db['name'],
            'tenant_id': snat_db['tenant_id'],
            'router_id': snat_db['router_id'],
            'snat_network_id': snat_db['snat_network_id'],
            'snat_ip_address': snat_db['snat_ip_address'],
            'snat_ip_pool': snat_db['snat_ip_pool'],
            'original_cidrs': snat_db['original_cidrs'],
            'type': int(snat_db['type']),
        }
        return self._fields(res, fields)

    @classmethod
    def _validate_tenant_id(cls, context, snat):
        """Check the SNAT's router and resolve the owning tenant.

        :param context: request context providing ``session``
        :param snat: SNAT request dict; must contain ``router_id``
        :returns: the router's ``tenant_id``
        :raises extension.ResourceNotFound: the router does not exist
        :raises extension.TenantIdConflict: ``snat['tenant_id']`` is given
            and differs from the router's tenant
        """
        router = context.session.query(Router).filter(
            Router.id == snat['router_id']
        ).first()
        if not router:
            raise extension.ResourceNotFound(rs='Router', id=snat['router_id'])
        if snat.get('tenant_id') and snat['tenant_id'] != router.tenant_id:
            raise extension.TenantIdConflict(
                tid=snat['tenant_id'], rid=snat['router_id'])
        return router.tenant_id

    @classmethod
    def _validate_network_id(cls, context, network_id):
        """Ensure ``network_id`` (when given) refers to an existing network.

        :raises extension.ResourceNotFound: the network does not exist
        """
        if not network_id:
            return

        network = context.session.query(models_v2.Network).filter(
            models_v2.Network.id == network_id
        ).first()
        if not network:
            raise extension.ResourceNotFound(rs='Network', id=network_id)

    @classmethod
    def _validate_snat_ip_address(cls, nat_type, ip_address):
        """Require an address for NAT44 and forbid one for NAT64.

        :raises extension.SnatIpAddressNotSpecified: NAT44 without address
        :raises extension.SnatIpAddressSpecified: NAT64 with an address
        """
        if nat_type == NAT44 and not ip_address:
            raise extension.SnatIpAddressNotSpecified()

        if nat_type == NAT64 and ip_address:
            raise extension.SnatIpAddressSpecified()

    @classmethod
    def _validate_snat_ip_pool(cls, nat_type, ip_pools):
        """Validate the ``snat_ip_pool`` list against the NAT type.

        :param nat_type: NAT44 or NAT64
        :param ip_pools: list of mappings with ``begin_ip`` / ``end_ip``,
            or a falsy value (``None`` / ``[]``) when no pool is given
        :raises extension.SnatIpPoolNotSpecified: NAT64 without pools
        :raises extension.SnatIpAddressSpecified: NAT44 with pools
        :raises extension.SnatIpPoolOverLength: more than 8 pools
        :raises extension.SnatIpPoolParamNotSpecified: a pool lacks a bound
        :raises extension.SnatIpPoolInvalid: a bound is not a valid address
            or the range is malformed
        :raises extension.SnatIpPoolConflict: any two pools overlap
        """
        if nat_type == NAT64 and not ip_pools:
            raise extension.SnatIpPoolNotSpecified()

        if nat_type == NAT44 and ip_pools:
            # NOTE(review): the offending field is the pool, not the address;
            # the exception type is kept because callers may rely on it.
            raise extension.SnatIpAddressSpecified()

        # Nothing else to check (e.g. NAT44 without a pool). Without this
        # guard, len(None) would raise TypeError below.
        if not ip_pools:
            return

        if len(ip_pools) > cls._MAX_SNAT_IP_POOLS:
            raise extension.SnatIpPoolOverLength()

        # Parse every pool first so malformed input is reported before any
        # overlap, matching the original error priority.
        pool_sets = []
        for ip_pool in ip_pools:
            for param in ('begin_ip', 'end_ip'):
                if param not in ip_pool:
                    raise extension.SnatIpPoolParamNotSpecified(param=param)

            try:
                pool_sets.append(netaddr.IPSet(netaddr.IPRange(
                    ip_pool['begin_ip'], ip_pool['end_ip'])))
            except netaddr.AddrFormatError:
                raise extension.SnatIpPoolInvalid(ip_pool=ip_pool)

        # Pairwise overlap check: compare each pool with the union of all
        # earlier pools. The original cumulative intersection only detected
        # an address common to *every* pool.
        covered = netaddr.IPSet()
        for pool_set in pool_sets:
            if (covered & pool_set).size != 0:
                raise extension.SnatIpPoolConflict(ip_pool=ip_pools)
            covered = covered | pool_set

    @classmethod
    def _validate_original_cidrs(cls, original_cidrs):
        """Ensure every CIDR parses and all share a single IP version.

        :raises extension.SnatOriginalCidrsInvalid: a CIDR does not parse
        :raises extension.SnatOriginalCidrsConflict: mixed IP versions
        """
        if not original_cidrs:
            return

        ip_version = None
        for cidr in original_cidrs:
            try:
                ip_addr = netaddr.IPNetwork(cidr)
            except netaddr.AddrFormatError:
                raise extension.SnatOriginalCidrsInvalid(cidr=cidr)

            if ip_version is None:
                ip_version = ip_addr.version

            if ip_addr.version != ip_version:
                raise extension.SnatOriginalCidrsConflict(cidr=cidr)

    def create_snat_db(self, context, snat):
        """create snat db

        Validate ``snat`` and insert a new Snat row.

        :param context: request context providing ``session``
        :param snat: SNAT request dict; ``router_id`` and ``type`` are
            required, the other keys are optional
        :returns: dict of the created SNAT (see ``_make_snat_dict``)
        """
        LOG.info(_LI('[AC] Begin to create SNAT in DB: %s'), snat)

        snat_id = snat.get('id') or uuidutils.generate_uuid()

        tenant_id = self._validate_tenant_id(context, snat)

        snat_network_id = snat.get('snat_network_id')
        self._validate_network_id(context, snat_network_id)

        nat_type = snat['type']

        snat_ip_address = snat.get('snat_ip_address')
        self._validate_snat_ip_address(nat_type, snat_ip_address)

        snat_ip_pool = snat.get('snat_ip_pool')
        self._validate_snat_ip_pool(nat_type, snat_ip_pool)

        original_cidrs = snat.get('original_cidrs')
        self._validate_original_cidrs(original_cidrs)

        with context.session.begin(subtransactions=True):
            snat_db = Snat(
                id=snat_id,
                # name is nullable, so a missing key stores NULL instead of
                # raising KeyError.
                name=snat.get('name'),
                tenant_id=tenant_id,
                router_id=snat['router_id'],
                snat_network_id=snat_network_id,
                snat_ip_address=snat_ip_address,
                # Store the same value that was validated above.
                snat_ip_pool=snat_ip_pool,
                original_cidrs=original_cidrs,
                type=nat_type,
            )
            context.session.add(snat_db)
            ret = self._make_snat_dict(snat_db)
        LOG.info(_LI('[AC] End of creating SNAT in DB: %s'), ret)
        return ret

    def update_snat_db(self, context, snat_id, snat):
        """update snat db

        Apply the attributes in ``snat`` to an existing row (no validation).

        :raises extension.ResourceNotFound: no SNAT with ``snat_id``
        :returns: dict of the updated SNAT
        """
        LOG.info(_LI('[AC] Begin to update SNAT in DB: %s'), snat)
        with context.session.begin(subtransactions=True):
            snat_db = self._get_snat_db(context, snat_id)
            snat_db.update(snat)
            ret = self._make_snat_dict(snat_db)
        LOG.info(_LI('[AC] End of updating SNAT in DB: %s'), ret)
        return ret

    def delete_snat_db(self, context, snat_id):
        """delete snat db

        :raises extension.ResourceNotFound: no SNAT with ``snat_id``
        """
        LOG.info(_LI('[AC] Begin to delete SNAT in DB: %s'), snat_id)
        # Look up and delete inside one transaction, as update_snat_db does.
        with context.session.begin(subtransactions=True):
            snat_db = self._get_snat_db(context, snat_id)
            context.session.delete(snat_db)
        LOG.info(_LI('[AC] End of deleting SNAT in DB: %s'), snat_id)

    def _get_snat_db(self, context, snat_id):
        """Return the Snat row for ``snat_id``.

        :raises extension.ResourceNotFound: no such row
        """
        try:
            snat = self._get_by_id(context, Snat, snat_id)
        except exc.NoResultFound:
            raise extension.ResourceNotFound(rs='SNAT', id=snat_id)
        return snat

    def get_snat_db(self, context, snat_id, fields=None):
        """get snat db

        :returns: dict of the SNAT, restricted to ``fields`` when given
        :raises extension.ResourceNotFound: no SNAT with ``snat_id``
        """
        snat = self._get_snat_db(context, snat_id)
        return self._make_snat_dict(snat, fields)

    def get_snats_db(self, context, filters=None, fields=None):
        """get snats db

        :returns: list of SNAT dicts matching ``filters``
        """
        return self._get_collection(
            context, Snat, self._make_snat_dict,
            filters=filters, fields=fields)
