#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""ac qos plugin"""

from oslo_log import log as logging

try:
    from neutron_lib import exceptions as n_exc
except ImportError:
    from neutron.common import exceptions as n_exc
try:
    from neutron.db import api as db_api
except ImportError:
    from neutron_lib.db import api as db_api

    if not hasattr(db_api, 'autonested_transaction'):
        from networking_huawei.drivers.ac.plugins.qos import api as db_api

from neutron.db import db_base_plugin_common
from neutron.extensions import qos
from neutron.objects import base as base_obj
from neutron.objects.qos import rule as rule_object
from neutron.objects.qos import rule_type as rule_type_object
from neutron.objects.qos import policy as policy_object

try:
    from neutron.services.qos.notification_drivers import manager as driver_mgr
except ImportError:
    from neutron.services.qos.drivers import manager as driver_mgr
try:
    from neutron.services.qos import qos_consts
except ImportError:
    from neutron_lib.services.qos import constants as qos_consts
from networking_huawei.drivers.ac.common import constants as ac_const
from networking_huawei.drivers.ac.plugins.qos import ac_qos_driver
from networking_huawei.drivers.ac.common import neutron_compatible_util as ncu
from networking_huawei._i18n import _LI, _LE

LOG = logging.getLogger(__name__)


class HuaweiAbstractQoSPlugin(object):
    """Placeholder implementations for generic QoS rule hooks.

    The qos extension of OpenStack Train declares ``update_rule``,
    ``delete_rule`` and ``get_rule`` as abstract methods, so a concrete
    plugin class must define them even though this plugin does not use
    them. Each hook below therefore does nothing and returns ``None``.
    Do not remove these methods.
    """

    def update_rule(self, context, rule_cls, rule_id, rule_data):
        """No-op stand-in for the Train ``update_rule`` abstract method."""

    def delete_rule(self, context, rule_cls, rule_id):
        """No-op stand-in for the Train ``delete_rule`` abstract method."""

    def get_rule(self, context, rule_cls, rule_id, fields=None):
        """No-op stand-in for the Train ``get_rule`` abstract method."""


class HuaweiQoSPlugin(HuaweiAbstractQoSPlugin, qos.QoSPluginBase):
    """Implementation of the Huawei AC QoS Service Plugin.

    This class implements a Quality of Service plugin that
    provides quality of service parameters over ports and
    networks.

    Mutating operations are written to the neutron DB first and then pushed
    to the AC through the driver manager chosen in ``__init__``. If the push
    fails, the DB change is reverted and the original exception is
    re-raised. ``delete_policy`` is the exception to this order: it notifies
    the driver first and deletes the DB row only afterwards.
    """
    # The advertised extensions depend on the OpenStack version reported by
    # ncu.get_ops_version() at the time this class body is executed.
    if ncu.get_ops_version() in ac_const.OPS_VERSION_PQRTW_6_21:
        supported_extension_aliases = ['qos',
                                       'qos-bw-limit-direction',
                                       'qos-default',
                                       'qos-rule-type-details']
    else:
        supported_extension_aliases = ['qos']

    def __init__(self):
        """Register the AC QoS driver and create the driver manager.

        Versions in ``ac_const.OPS_VERSION_PQRTW_6_21`` get
        ``self.driver_manager``; every other version gets
        ``self.notification_driver_manager``. This method sets only one of
        the two attributes, so code must branch on ``self.ops_version``
        (see ``_deal_update_policy``) before using either.
        """
        LOG.info(_LI("[AC] Init Huawei QoS plugin."))
        super(HuaweiQoSPlugin, self).__init__()
        ac_qos_driver.register()
        self.ops_version = ncu.get_ops_version()
        if self.ops_version in ac_const.OPS_VERSION_PQRTW_6_21:
            self.driver_manager = driver_mgr.QosServiceDriverManager()
        else:
            self.notification_driver_manager = (
                driver_mgr.QosServiceNotificationDriverManager())
        LOG.info(_LI("[AC] Initialization finished successfully "
                     "for Huawei QoS plugin."))

    def get_plugin_description(self):
        """get plugin description"""
        return "Huawei AC QoS Service Plugin for ports and networks"

    @classmethod
    def _update_policy_db(cls, context, policy_id, policy):
        """Write the attributes of ``policy['policy']`` to the DB policy.

        :param context: neutron request context.
        :param policy_id: ID of the policy to update.
        :param policy: request body of the form ``{'policy': {...}}``.
        :returns: the updated ``QosPolicy`` object.
        :raises Exception: if ``QosPolicy`` exposes neither ``get_by_id``
            nor ``get_object``.
        """
        policy_data = policy['policy']
        if hasattr(policy_object.QosPolicy, "get_by_id"):
            # OPS L version
            obj = policy_object.QosPolicy(context, **policy_data)
            obj.id = policy_id
        elif hasattr(policy_object.QosPolicy, "get_object"):
            # OPS M version: reset the change set so that only the
            # attributes assigned below are tracked as changes.
            obj = policy_object.QosPolicy(context, id=policy_id)
            obj.obj_reset_changes()
            for item_k, item_v in policy_data.items():
                if item_k != 'id':
                    setattr(obj, item_k, item_v)
        else:
            raise Exception("Unsupported QosPolicy API: neither get_by_id "
                            "nor get_object is available")

        obj.update()
        return obj

    def _create_bandwidth_limit_rule(self, context,
                                     policy_id, bandwidth_limit_rule):
        """Create a bandwidth limit rule in the DB.

        :returns: tuple ``(rule, policy)``, where ``policy`` has its rules
            reloaded after the creation.
        """
        LOG.debug("Make sure we will have a policy object"
                  " to push resource update")
        with db_api.autonested_transaction(context.session):
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            rule = rule_object.QosBandwidthLimitRule(
                context, qos_policy_id=policy_id,
                **bandwidth_limit_rule['bandwidth_limit_rule'])
            rule.create()
            self._deal_policy_rules(policy)
        return rule, policy

    def _update_bandwidth_limit_rule(self, context, rule_id, policy_id,
                                     bandwidth_limit_rule):
        """Update a bandwidth limit rule in the DB.

        :param bandwidth_limit_rule: body of the form
            ``{'bandwidth_limit_rule': {...}}``; an ``id`` key inside it is
            ignored on the M-version code path.
        :returns: tuple ``(rule, policy)``, where ``policy`` has its rules
            reloaded after the update.
        :raises Exception: if ``QosPolicy`` exposes neither ``get_by_id``
            nor ``get_object``.
        """
        LOG.debug("Make sure we will have a policy object "
                  "to push resource update")
        rule_data = bandwidth_limit_rule['bandwidth_limit_rule']
        with db_api.autonested_transaction(context.session):
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            # Ensure the rule belongs to the policy before touching it.
            policy.get_rule_by_id(rule_id)
            if hasattr(policy_object.QosPolicy, "get_by_id"):
                # OPS L version
                rule = rule_object.QosBandwidthLimitRule(context, **rule_data)
                rule.id = rule_id
            elif hasattr(policy_object.QosPolicy, "get_object"):
                # OPS M version
                rule = rule_object.QosBandwidthLimitRule(context, id=rule_id)
                rule.obj_reset_changes()
                for key, value in rule_data.items():
                    if key != 'id':
                        setattr(rule, key, value)
            else:
                raise Exception("Unsupported QosPolicy API: neither "
                                "get_by_id nor get_object is available")
            rule.update()
            self._deal_policy_rules(policy)
        return rule, policy

    def _delete_bandwidth_limit_rule(self, context, rule_id, policy_id):
        """Delete a bandwidth limit rule from the DB.

        :returns: the policy object with its rules reloaded after the
            deletion.
        """
        LOG.debug("Make sure we will have a policy object"
                  " to push resource update")
        with db_api.autonested_transaction(context.session):
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            rule = policy.get_rule_by_id(rule_id)
            rule.delete()
            self._deal_policy_rules(policy)
        return policy

    def _rollback_create_bandwidth_limit_rule(self, context, rule_id,
                                              policy_id, original_rule):
        """Re-create a deleted bandwidth limit rule under its original ID.

        :param original_rule: dict of the rule as returned by
            ``get_policy_bandwidth_limit_rule`` before the deletion.
        :returns: the policy object with its rules reloaded, so callers can
            push the restored state to the driver.
        """
        rule_dict = {
            "max_kbps": original_rule['max_kbps'],
            "max_burst_kbps": original_rule['max_burst_kbps'],
            "tenant_id": context.tenant_id
        }
        # Releases that support 'qos-bw-limit-direction' store a direction
        # per rule; keep it so the re-created rule matches the deleted one
        # instead of falling back to the object's default direction.
        if original_rule.get('direction') is not None:
            rule_dict['direction'] = original_rule['direction']
        with db_api.autonested_transaction(context.session):
            # first, validate that we have access to the policy
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            rule = rule_object.QosBandwidthLimitRule(
                context, qos_policy_id=policy_id, **rule_dict)
            rule.id = rule_id
            rule.create()
            self._deal_policy_rules(policy)
        return policy

    @db_base_plugin_common.convert_result_to_dict
    def create_policy(self, context, policy):
        """Create a QoS policy in the DB and push it to the driver.

        :param policy: request body of the form ``{'policy': {...}}``.
        :returns: the created policy.
        :raises Exception: whatever the driver raised; the DB policy is
            deleted before re-raising.
        """
        if self.ops_version in ac_const.OPS_VERSION_O_PQRTW_6_21:
            policy['policy'].pop('tenant_id', None)
        policy = policy_object.QosPolicy(context, **policy['policy'])
        policy.create()
        LOG.info(_LI("[AC] Success to create QoS policy in neutron DB: %s "),
                 policy)
        try:
            if self.ops_version in ac_const.OPS_VERSION_PQRTW_6_21:
                self.driver_manager.call(
                    qos_consts.CREATE_POLICY, context, policy)
            else:
                self.notification_driver_manager.create_policy(context, policy)
        except Exception:
            policy.delete()
            raise
        return policy

    @db_base_plugin_common.convert_result_to_dict
    def update_policy(self, context, policy_id, policy):
        """Update a QoS policy in the DB and push it to the driver.

        On driver failure, ``name``, ``description`` and ``shared`` are
        restored from the pre-update policy, the driver is notified with
        the original policy, and the exception is re-raised.
        """
        original_policy = ac_qos_driver.get_policy_obj(context, policy_id)
        policy = self._update_policy_db(context, policy_id, policy)
        try:
            self._deal_update_policy(context, policy)
        except Exception:
            original_policy_dict = {
                'policy':
                    {
                        'name': original_policy.name,
                        'description': original_policy.description,
                        'shared': original_policy.shared
                    }
            }
            self._update_policy_db(context, policy_id, original_policy_dict)
            self._deal_update_policy(context, original_policy)
            raise
        return policy

    def delete_policy(self, context, policy_id):
        """Notify the driver of the deletion, then delete the DB policy.

        If the driver raises, the DB policy is left untouched.
        """
        policy = policy_object.QosPolicy(context)
        policy.id = policy_id
        if self.ops_version in ac_const.OPS_VERSION_PQRTW_6_21:
            self.driver_manager.call(
                qos_consts.DELETE_POLICY, context, policy)
        else:
            self.notification_driver_manager.delete_policy(context, policy)
        policy.delete()

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policy(self, context, policy_id, fields=None):
        """Return the policy with ID ``policy_id``."""
        LOG.debug("Get policy witch ID = %s.", policy_id)
        return ac_qos_driver.get_policy_obj(context, policy_id)

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policies(self, context, fields=None, sorts=None, filters=None,
                     marker=None, limit=None, page_reverse=False):
        """Return all policies matching ``filters``.

        ``filters`` defaults to ``None``; it is treated as "no filters"
        instead of being expanded with ``**`` (which would raise TypeError).
        """
        filters = filters or {}
        return policy_object.QosPolicy.get_objects(context, **filters)

    @db_base_plugin_common.convert_result_to_dict
    def create_policy_bandwidth_limit_rule(self, context, policy_id,
                                           bandwidth_limit_rule):
        """Create a bandwidth limit rule and push the policy to the driver.

        :param bandwidth_limit_rule: request body of the form
            ``{'bandwidth_limit_rule': {...}}``.
        :returns: the created rule.
        :raises Exception: whatever the driver raised; the rule is deleted
            and the driver is notified with the rolled-back policy before
            re-raising.
        """
        LOG.info(_LI("[AC] Begin to create QoS policy "
                     "bandwidth limit rule: %s."), bandwidth_limit_rule)

        rule, policy = self._create_bandwidth_limit_rule(context,
                                                         policy_id,
                                                         bandwidth_limit_rule)
        LOG.info(_LI("[AC] Success to create QoS bandwidth limit rule "
                     "in neutron DB: %s "), rule)

        try:
            self._deal_update_policy(context, policy)
        except Exception as except_msg:
            LOG.error(_LE("[AC] Failed to create QoS policy bandwidth limit "
                          "rule in huawei driver: %s."), except_msg)
            # Push the policy as re-read after the rollback: the local
            # ``policy`` still lists the rule that has just been deleted.
            policy = self._delete_bandwidth_limit_rule(context, rule['id'],
                                                       policy_id)
            self._deal_update_policy(context, policy)
            raise

        LOG.info(_LI("[AC] Huawei AC create QoS policy "
                     "bandwidth limit rule successfully."))
        return rule

    @db_base_plugin_common.convert_result_to_dict
    def update_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
                                           bandwidth_limit_rule):
        """Update a bandwidth limit rule and push the policy to the driver.

        :returns: the updated rule.
        :raises Exception: whatever the driver raised; the rule's previous
            values are written back and the driver is notified with the
            restored policy before re-raising.
        """
        LOG.info(_LI("[AC] Begin to update QoS policy "
                     "bandwidth limit rule: %s."), bandwidth_limit_rule)

        original_rule = self.get_policy_bandwidth_limit_rule(context,
                                                             rule_id,
                                                             policy_id)
        # Keys that must not be written back as rule attributes.
        original_rule.pop('id', None)
        original_rule.pop('qos_policy_id', None)

        new_rule, policy = \
            self._update_bandwidth_limit_rule(context, rule_id, policy_id,
                                              bandwidth_limit_rule)
        try:
            self._deal_update_policy(context, policy)
        except Exception as except_msg:
            LOG.error(_LE("[AC] Failed to update QoS policy bandwidth limit "
                          "rule in huawei driver: %s."), except_msg)
            # Use the policy reloaded after restoring the old values; the
            # local ``policy`` still carries the rejected values.
            _restored_rule, policy = self._update_bandwidth_limit_rule(
                context, rule_id, policy_id,
                {'bandwidth_limit_rule': original_rule})
            self._deal_update_policy(context, policy)
            raise

        LOG.info(_LI("[AC] Huawei AC update QoS policy "
                     "bandwidth limit rule successfully."))
        return new_rule

    def delete_policy_bandwidth_limit_rule(self, context, rule_id, policy_id):
        """Delete a bandwidth limit rule and push the policy to the driver.

        :raises Exception: whatever the driver raised; the rule is
            re-created under its original ID and the driver is notified
            with the restored policy before re-raising.
        """
        LOG.info(_LI("[AC] Begin to delete QoS policy "
                     "bandwidth limit rule: %s."), rule_id)

        original_rule = self.get_policy_bandwidth_limit_rule(context,
                                                             rule_id,
                                                             policy_id)
        policy = self._delete_bandwidth_limit_rule(context, rule_id, policy_id)
        try:
            self._deal_update_policy(context, policy)
        except Exception as except_msg:
            LOG.error(_LE("[AC] Failed to delete QoS policy bandwidth limit "
                          "rule in huawei driver: %s."), except_msg)
            policy = self._rollback_create_bandwidth_limit_rule(
                context, rule_id, policy_id, original_rule)
            # Same version dispatch as every other driver call; only
            # the manager created in __init__ for this version exists.
            self._deal_update_policy(context, policy)
            raise

        LOG.info(_LI("[AC] Huawei AC delete QoS policy "
                     "bandwidth limit rule successfully."))

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policy_bandwidth_limit_rule(self, context, rule_id, policy_id,
                                        fields=None):
        """Return one bandwidth limit rule of the given policy.

        :raises QosRuleNotFound: if the rule does not exist or no lookup
            API (``get_by_id`` / ``get_object``) is available.
        """
        LOG.debug("Make sure we have access to the policy "
                  "when fetching the rule")
        rule = None
        with db_api.autonested_transaction(context.session):
            ac_qos_driver.get_policy_obj(context, policy_id)
            if hasattr(rule_object.QosBandwidthLimitRule, "get_by_id"):
                # OPS L version
                rule = rule_object.QosBandwidthLimitRule. \
                    get_by_id(context, rule_id)
            elif hasattr(rule_object.QosBandwidthLimitRule, "get_object"):
                # OPS M version
                rule = rule_object.QosBandwidthLimitRule. \
                    get_object(context, id=rule_id)
        if not rule:
            raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id)
        return rule

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policy_bandwidth_limit_rules(self, context, policy_id, filters=None,
                                         fields=None, sorts=None, limit=None,
                                         marker=None, page_reverse=False):
        """Return the bandwidth limit rules of ``policy_id``.

        The caller's ``filters`` dict is copied, not modified.
        """
        with db_api.autonested_transaction(context.session):
            LOG.debug("make sure we have access to the policy"
                      " when fetching the rule")
            # first, validate that we have access to the policy
            ac_qos_driver.get_policy_obj(context, policy_id)
            query_filters = dict(filters) if filters else {}
            query_filters[qos_consts.QOS_POLICY_ID] = policy_id
            return rule_object.QosBandwidthLimitRule.get_objects(
                context, **query_filters)

    def _deal_update_policy(self, context, policy):
        """Push ``policy`` to the driver manager created for this version."""
        if self.ops_version in ac_const.OPS_VERSION_PQRTW_6_21:
            self.driver_manager.call(
                qos_consts.UPDATE_POLICY, context, policy)
        else:
            self.notification_driver_manager.update_policy(context, policy)

    def _deal_policy_rules(self, policy):
        """Reload ``policy``'s rules with the API this version provides."""
        if self.ops_version in ac_const.OPS_VERSION_PQRTW_6_21:
            policy.obj_load_attr('rules')
        else:
            policy.reload_rules()

    @db_base_plugin_common.convert_result_to_dict
    def create_policy_rule(self, context, rule_cls, policy_id, rule_data):
        """Create a rule of type ``rule_cls`` and push the policy.

        :param rule_data: body keyed by ``<rule_type>_rule``.
        :returns: the created rule.
        :raises Exception: whatever the driver raised; the rule is deleted
            from the DB before re-raising.
        """
        LOG.info(_LI("[AC] Begin to create policy rule, type: %(type)s, "
                     "policy id: %(policy_id)s, data: %(rule_data)s"),
                 {"type": rule_cls.rule_type, "policy_id": policy_id,
                  "rule_data": rule_data})
        rule_type = rule_cls.rule_type
        rule_data = rule_data[rule_type + '_rule']

        with db_api.autonested_transaction(context.session):
            LOG.debug("Ensure that we have access to the policy.")
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            rule = rule_cls(context, qos_policy_id=policy_id, **rule_data)
            rule.create()
            self._deal_policy_rules(policy)
        try:
            self._deal_update_policy(context, policy)
        except Exception as ex:
            LOG.error(_LE("[AC] Failed to create policy rule: %(ex)s, "
                          "type: %(type)s, policy id: %(policy_id)s, "
                          "data: %(rule_data)s"),
                      {"type": rule_type, "policy_id": policy_id,
                       "rule_data": rule_data, "ex": ex})
            with db_api.autonested_transaction(context.session):
                rule.delete()
                self._deal_policy_rules(policy)
            raise
        return rule

    @db_base_plugin_common.convert_result_to_dict
    def update_policy_rule(self, context, rule_cls, rule_id, policy_id,
                           rule_data):
        """Update a rule of type ``rule_cls`` and push the policy.

        :param rule_data: body keyed by ``<rule_type>_rule``.
        :returns: the updated rule.
        :raises Exception: whatever the driver raised; the rule's pre-update
            DB values are written back before re-raising.
        """
        import copy

        LOG.info(_LI("[AC] Begin to update policy rule, type: %(type)s, "
                     "rule id: %(rule_id)s, policy id: %(policy_id)s, "
                     "data: %(rule_data)s"),
                 {"type": rule_cls.rule_type, "policy_id": policy_id,
                  "rule_data": rule_data, "rule_id": rule_id})
        rule_type = rule_cls.rule_type
        rule_data = rule_data[rule_type + '_rule']

        with db_api.autonested_transaction(context.session):
            LOG.debug("Ensure we have access to the policy.")
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            # Snapshot the rule BEFORE updating it. After the rules are
            # reloaded below, policy.get_rule_by_id() returns the updated
            # rule, so a later lookup cannot be used for the rollback.
            original_rule = policy.get_rule_by_id(rule_id)
            original_db_obj = copy.deepcopy(original_rule.db_obj)
            rule = rule_cls(context, id=rule_id)
            rule.update_fields(rule_data, reset_changes=True)
            rule.update()
            self._deal_policy_rules(policy)
        try:
            self._deal_update_policy(context, policy)
        except Exception as ex:
            LOG.error(_LE("[AC] Failed to update policy rule: %(ex)s, "
                          "type: %(type)s, rule id: %(rule_id)s, "
                          "policy id: %(policy_id)s, data: %(rule_data)s"),
                      {"type": rule_type, "policy_id": policy_id, "ex": ex,
                       "rule_data": rule_data, "rule_id": rule_id})
            with db_api.autonested_transaction(context.session):
                LOG.info(_LI('[AC] Original rule: %s'), original_rule)
                rule.update_fields(original_db_obj, reset_changes=True)
                rule.update()
                self._deal_policy_rules(policy)
            raise
        return rule

    def delete_policy_rule(self, context, rule_cls, rule_id, policy_id):
        """Delete a rule and push the policy to the driver.

        :raises Exception: whatever the driver raised; the rule is
            re-created from its pre-delete DB values before re-raising.
        """
        import copy

        LOG.info(_LI("[AC] Begin to delete policy rule, rule id: %(rule_id)s, "
                     "policy id: %(policy_id)s"),
                 {"rule_id": rule_id, "policy_id": policy_id})
        with db_api.autonested_transaction(context.session):
            # Ensure we have access to the policy.
            policy = ac_qos_driver.get_policy_obj(context, policy_id)
            rule = policy.get_rule_by_id(rule_id)
            # Deep-copy the DB data now so it survives rule.delete() and can
            # be used to re-create the rule if the driver rejects the change.
            original_db_obj = copy.deepcopy(rule.db_obj)
            rule.delete()
            self._deal_policy_rules(policy)
        try:
            self._deal_update_policy(context, policy)
        except Exception as ex:
            LOG.error(_LE("[AC] Failed to delete policy rule: %(ex)s, "
                          "rule id: %(rule_id)s, policy id: %(policy_id)s"),
                      {"rule_id": rule_id, "policy_id": policy_id, "ex": ex})
            with db_api.autonested_transaction(context.session):
                rule = rule_cls(context, **original_db_obj)
                rule.create()
                self._deal_policy_rules(policy)
            raise

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policy_rule(self, context, rule_cls, rule_id, policy_id,
                        fields=None):
        """Return one rule of type ``rule_cls``.

        :raises QosRuleNotFound: if no such rule exists.
        """
        with db_api.autonested_transaction(context.session):
            # Ensure we have access to the policy.
            ac_qos_driver.get_policy_obj(context, policy_id)
            rule = rule_cls.get_object(context, id=rule_id)
        if not rule:
            raise n_exc.QosRuleNotFound(policy_id=policy_id, rule_id=rule_id)
        return rule

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_policy_rules(self, context, rule_cls, policy_id, filters=None,
                         fields=None, sorts=None, limit=None, marker=None,
                         page_reverse=False):
        """Return the rules of type ``rule_cls`` belonging to ``policy_id``.

        The caller's ``filters`` dict is copied, not modified.
        """
        with db_api.autonested_transaction(context.session):
            # Ensure we have access to the policy.
            ac_qos_driver.get_policy_obj(context, policy_id)
            query_filters = dict(filters) if filters else {}
            query_filters[qos_consts.QOS_POLICY_ID] = policy_id
            pager = base_obj.Pager(sorts, limit, page_reverse, marker)
            return rule_cls.get_objects(context, _pager=pager,
                                        **query_filters)

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_rule_type(self, context, rule_type_name, fields=None):
        """Return the details of ``rule_type_name`` (admin only).

        :raises NotAuthorized: if the context is not an admin context.
        """
        if not context.is_admin:
            from neutron_lib import exceptions as lib_exc
            raise lib_exc.NotAuthorized()
        return rule_type_object.QosRuleType.get_object(rule_type_name)

    @db_base_plugin_common.filter_fields
    @db_base_plugin_common.convert_result_to_dict
    def get_rule_types(self, context, filters=None, fields=None,
                       sorts=None, limit=None,
                       marker=None, page_reverse=False):
        """Return all QoS rule types matching ``filters``."""
        if not filters:
            filters = {}
        return rule_type_object.QosRuleType.get_objects(**filters)

    @property
    def supported_rule_types(self):
        """Rule types supported by the driver manager.

        NOTE(review): only ``driver_manager`` is consulted, which
        ``__init__`` creates solely for ``OPS_VERSION_PQRTW_6_21``; on other
        versions this raises AttributeError — confirm it is never used
        there.
        """
        return self.driver_manager.supported_rule_types

    def supported_rule_type_details(self, rule_type_name):
        """Return driver details for ``rule_type_name``.

        Like ``supported_rule_types``, this requires ``driver_manager``.
        """
        return self.driver_manager.supported_rule_type_details(rule_type_name)