#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""dnat db class"""

import sqlalchemy as sa

from sqlalchemy.orm import exc
from oslo_utils import uuidutils

from neutron.db import models_v2
try:
    from neutron.db.l3_db import Router
except ImportError:
    from neutron.db.models.l3 import Router
try:
    from neutron.db import common_db_mixin
except ImportError:
    from networking_huawei.drivers.ac.common import common_db_mixin
try:
    from neutron.db import model_base
except ImportError:
    from neutron_lib.db import model_base
from networking_huawei._i18n import _LI
from networking_huawei.drivers.ac.common import constants
from networking_huawei.drivers.ac.extensions.dnat import dnat as extension
from networking_huawei.drivers.ac.common import neutron_compatible_util as ncu

# Module-level logger obtained through the AC logging wrapper exposed by
# neutron_compatible_util (ncu.ac_log).
LOG = ncu.ac_log.getLogger(__name__)


class DNAT(model_base.BASEV2):
    """ORM model for one DNAT rule, stored in the ``huawei_ac_dnat`` table.

    A row links a floating IP, a protocol and a floating-side port number to
    a fixed Neutron port / fixed IP address / fixed-side port number, plus
    the router associated with them.
    """
    __tablename__ = 'huawei_ac_dnat'
    # Primary key: 36-char string, filled by uuidutils.generate_uuid when the
    # caller does not supply one (the default is applied on INSERT).
    id = sa.Column(sa.String(36), primary_key=True,
                   default=uuidutils.generate_uuid)
    # Floating IP this rule belongs to (FK to floatingips.id).
    floating_ip_id = sa.Column(sa.String(36), sa.ForeignKey('floatingips.id'))
    # Router associated with the fixed IP (FK to routers.id).
    router_id = sa.Column(sa.String(36), sa.ForeignKey('routers.id'))
    # Target Neutron port (FK to ports.id); exposed as 'port_id' in the API
    # dictionaries built by the DB mixin.
    fixed_port_id = sa.Column(sa.String(36), sa.ForeignKey('ports.id'))
    # Fixed IP address on fixed_port_id (64 chars leaves room for IPv6).
    fixed_ip_address = sa.Column(sa.String(64))
    # Protocol string; the permitted values are not constrained at DB level.
    protocol = sa.Column(sa.String(16))
    # Port number on the floating IP side (nullable integer column).
    floating_ip_port = sa.Column(sa.Integer)
    # Port number on the fixed IP side (nullable integer column).
    fixed_ip_port = sa.Column(sa.Integer)
    # Rule status string.
    status = sa.Column(sa.String(16))


class DNATDBMixin(common_db_mixin.CommonDbMixin):
    """Database access mixin for DNAT rules (the ``DNAT`` model).

    Provides create/delete/get/list helpers for DNAT rows together with the
    validation performed before a rule is persisted. Validation failures are
    reported with exceptions from the ``dnat`` extension module. The
    ``_get_by_id``, ``_get_collection`` and ``_fields`` helpers are inherited
    from ``CommonDbMixin``.
    """

    @property
    def _l3_plugin(self):
        """Return the service plugin registered under ``L3_ROUTER_NAT``."""
        return ncu.get_service_plugin()['L3_ROUTER_NAT']

    @staticmethod
    def _port_to_int(value):
        """Convert a stored port number to ``int``, passing ``None`` through.

        The port columns are nullable; calling ``int(None)`` would raise
        ``TypeError`` and make a single incomplete row break serialization
        (including whole-collection listings).
        """
        return None if value is None else int(value)

    def _make_dnat_dict(self, dnat_db, fields=None):
        """Serialize a DNAT row into its API dictionary.

        :param dnat_db: a ``DNAT`` row (anything supporting ``[key]`` access
            with the model's column names).
        :param fields: optional list of keys to keep; handled by ``_fields``.
        :returns: dict. The row's ``fixed_port_id`` is exposed as ``port_id``
            and port numbers are ints (or ``None`` when unset).
        """
        res = {
            'id': dnat_db['id'],
            'floating_ip_id': dnat_db['floating_ip_id'],
            'router_id': dnat_db['router_id'],
            'port_id': dnat_db['fixed_port_id'],
            'fixed_ip_address': dnat_db['fixed_ip_address'],
            'protocol': dnat_db['protocol'],
            'floating_ip_port': self._port_to_int(dnat_db['floating_ip_port']),
            'fixed_ip_port': self._port_to_int(dnat_db['fixed_ip_port']),
            'status': dnat_db['status'],
        }
        return self._fields(res, fields)

    def get_floatingip(self, context, dnat_id, fields=None):
        """Fetch a floating IP through the L3 plugin.

        Despite its name, ``dnat_id`` is forwarded unchanged as the floating
        IP id to the L3 plugin's ``get_floatingip`` (``_validate_floating_ip``
        passes a floating IP id here). The parameter name is kept so keyword
        callers keep working.
        """
        return self._l3_plugin.get_floatingip(context, dnat_id, fields)

    @classmethod
    def _get_router_id(cls, context, port_id, fixed_ip_address):
        """Resolve the router serving ``fixed_ip_address`` on ``port_id``.

        Steps: find the port's IP allocation for that address, find a router
        interface port (``DEVICE_OWNER_ROUTER_INTF``) on the same subnet, and
        check that the interface's owning router row still exists.

        :raises SubnetForPortNotFound: no allocation (or no subnet) matches.
        :raises RouterForPortNotFound: no router interface on the subnet, or
            the router it points to does not exist.
        :returns: the router id.
        """
        ip_allocation = context.session.query(models_v2.IPAllocation).filter(
            models_v2.IPAllocation.port_id == port_id,
            models_v2.IPAllocation.ip_address == fixed_ip_address,
        ).first()
        if not ip_allocation or not ip_allocation.subnet_id:
            raise extension.SubnetForPortNotFound(id=port_id)
        interface_query = context.session.query(models_v2.Port)
        interface_query = interface_query.join(models_v2.IPAllocation)
        router_interface = interface_query.filter(
            models_v2.IPAllocation.subnet_id == ip_allocation.subnet_id,
            models_v2.Port.device_owner == ncu.DEVICE_OWNER_ROUTER_INTF,
        ).first()
        if not router_interface or not router_interface.device_id:
            raise extension.RouterForPortNotFound(id=port_id)
        # A router interface port can outlive its router row; make sure the
        # router really exists before associating the DNAT rule with it.
        router = context.session.query(Router).filter(
            Router.id == router_interface.device_id,
        ).first()
        if not router:
            raise extension.RouterForPortNotFound(id=port_id)
        return router_interface.device_id

    def _validate_floating_ip(self, context, floating_ip_id):
        """Ensure the floating IP is not already bound to a port/fixed IP.

        :raises FloatingIPInUse: the floating IP has a ``port_id`` or
            ``fixed_ip_address`` set.
        """
        floatingip = self.get_floatingip(context, floating_ip_id)
        if floatingip['port_id'] or floatingip['fixed_ip_address']:
            raise extension.FloatingIPInUse(id=floating_ip_id)

    @classmethod
    def _validate_fixed_ip_address(cls, context, port_id, fixed_ip_address):
        """Ensure ``fixed_ip_address`` is one of ``port_id``'s addresses.

        The comparison is done in Python (exact string match) rather than in
        SQL so the result does not depend on the database collation.

        :raises FixedIPForDNATConflict: the port has no such address.
        """
        ip_allocations = context.session.query(models_v2.IPAllocation).filter(
            models_v2.IPAllocation.port_id == port_id,
        ).all()
        if not any(allocation.ip_address == fixed_ip_address
                   for allocation in ip_allocations):
            raise extension.FixedIPForDNATConflict(
                ip=fixed_ip_address, id=port_id)

    @classmethod
    def _validate_router_id(cls, context, floating_ip_id, router_id):
        """Ensure all DNAT rules on a floating IP use the same router.

        :raises RouterForDNATConflict: an existing rule on ``floating_ip_id``
            is associated with a different router (a NULL ``router_id`` in
            an existing row also counts as different).
        """
        dnats = context.session.query(DNAT).filter(
            DNAT.floating_ip_id == floating_ip_id
        ).all()
        for dnat in dnats:
            if dnat.router_id != router_id:
                raise extension.RouterForDNATConflict(
                    router=router_id, dnat=dnat.id)

    @classmethod
    def _validate_dnat_duplication(cls, context, dnat_info):
        """Reject a rule whose (floating IP, protocol, floating port) exists.

        :raises DNATConflict: carrying the id of a conflicting rule.
        """
        # Only one conflicting row is needed for the error, so let the
        # database stop at the first match instead of loading them all.
        duplicate = context.session.query(DNAT).filter(
            DNAT.floating_ip_id == dnat_info['floating_ip_id'],
            DNAT.protocol == dnat_info['protocol'],
            DNAT.floating_ip_port == dnat_info['floating_ip_port'],
        ).first()
        if duplicate is not None:
            raise extension.DNATConflict(id=duplicate.id)

    def create_dnat_db(self, context, dnat):
        """Validate and persist a new DNAT rule.

        :param dnat: request body of the form ``{'dnat': {...}}`` whose inner
            dict provides ``floating_ip_id``, ``port_id``,
            ``fixed_ip_address``, ``protocol``, ``floating_ip_port`` and
            ``fixed_ip_port``.
        :returns: the API dictionary of the stored rule, with the router id
            resolved from the port/fixed IP and status
            ``NEUTRON_STATUS_ACTIVE``.
        :raises: any of the extension exceptions raised by the validators.
        """
        LOG.info(_LI('[AC] Begin to create DNAT in DB: %s'), dnat)
        dnat_info = dnat['dnat']
        self._validate_dnat_duplication(context, dnat_info)
        self._validate_floating_ip(context, dnat_info['floating_ip_id'])
        self._validate_fixed_ip_address(
            context, dnat_info['port_id'], dnat_info['fixed_ip_address'])
        router_id = self._get_router_id(
            context, dnat_info['port_id'], dnat_info['fixed_ip_address'])
        self._validate_router_id(
            context, dnat_info['floating_ip_id'], router_id)
        with context.session.begin(subtransactions=True):
            dnat_db = DNAT(
                # Assign the id explicitly: the column default only runs at
                # INSERT (flush) time, but the returned dict needs it now.
                id=uuidutils.generate_uuid(),
                floating_ip_id=dnat_info['floating_ip_id'],
                router_id=router_id,
                fixed_port_id=dnat_info['port_id'],
                fixed_ip_address=dnat_info['fixed_ip_address'],
                protocol=dnat_info['protocol'],
                floating_ip_port=dnat_info['floating_ip_port'],
                fixed_ip_port=dnat_info['fixed_ip_port'],
                status=constants.NEUTRON_STATUS_ACTIVE,
            )
            context.session.add(dnat_db)
        return self._make_dnat_dict(dnat_db)

    def delete_dnat_db(self, context, dnat_id):
        """Delete the DNAT rule ``dnat_id``.

        :raises DNATNotFound: no rule with that id exists.
        """
        LOG.info(_LI('[AC] Begin to delete DNAT in DB: %s'), dnat_id)
        dnat_db = self._get_dnat_db(context, dnat_id)
        with context.session.begin(subtransactions=True):
            context.session.delete(dnat_db)

    def _get_dnat_db(self, context, dnat_id):
        """Load the ``DNAT`` row for ``dnat_id``.

        :raises DNATNotFound: translated from SQLAlchemy's ``NoResultFound``.
        """
        try:
            dnat = self._get_by_id(context, DNAT, dnat_id)
        except exc.NoResultFound:
            raise extension.DNATNotFound(id=dnat_id)
        return dnat

    def get_dnat_db(self, context, dnat_id, fields=None):
        """Return the API dictionary of one DNAT rule.

        :raises DNATNotFound: no rule with that id exists.
        """
        dnat = self._get_dnat_db(context, dnat_id)
        return self._make_dnat_dict(dnat, fields)

    def get_dnats_db(self, context, filters=None, fields=None):
        """Return API dictionaries for the DNAT rules matching ``filters``."""
        return self._get_collection(
            context, DNAT, self._make_dnat_dict,
            filters=filters, fields=fields)
