#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""DB lock wrapper."""

from inspect import stack
import random
import datetime
import six

from sqlalchemy.orm import exc as orm_exc
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc
from oslo_log import log
from oslo_utils import excutils
from oslo_utils import timeutils

from networking_huawei.drivers.ac.db.lock import exceptions as df_exc
from networking_huawei.drivers.ac.db.lock import lock_model as models
from networking_huawei.drivers.ac.common import constants as ac_constants
from networking_huawei.drivers.ac.db.dbif import ACdbInterface
from networking_huawei.drivers.ac.common import neutron_compatible_util as ncu

# Inclusive upper bound of the random id that identifies each API session
# (see _generate_lock_session_id).
LOCK_SEED = 9876543210

# Retry policy passed to HuaweiWrapDbRetry when waiting for the distributed
# lock: maximum number of retries, then the initial and maximum retry
# intervals (seconds, per oslo.db wrap_db_retry).
LOCK_MAX_RETRIES = 1000
LOCK_INIT_RETRY_INTERVAL = 0.1
LOCK_MAX_RETRY_INTERVAL = 1

# Resource types that can be protected by a lock. Each value selects how
# _get_lock_id_by_resource_type derives the lock id from the call arguments.
RESOURCE_FW_POLICY = 1
RESOURCE_L3_ROUTER = 2
RESOURCE_TOKEN_CONFIG = 3
RESOURCE_STATUS_REPORT = 4
GET_VM_PORT_NAME = 5
SUBNET_OF_PUBLIC_SERVICE = 6
RESOURCE_DHCP_NETWORK = 7
DELETE_PORT_LOCK = 8
DELETE_NETWORK_LOCK = 9
UPDATE_PARENT_PORT = 10
DELETE_SUBNET_LOCK = 11

# Module-level logger.
LOG = log.getLogger(__name__)


class HuaweiWrapDbRetry(oslo_db_api.wrap_db_retry):
    """oslo.db retry decorator whose constructor depends on the OPS version.

    When ``ncu.get_ops_version()`` reports ``ac_constants.OPS_K`` the parent
    is initialised with ``retry_on_request=True`` and ``exception_checker``
    is not forwarded; for every other version ``exception_checker`` is
    forwarded and ``retry_on_request`` is left at the parent's default.
    """

    def __init__(self, max_retries=0, retry_interval=0, inc_retry_interval=0,
                 max_retry_interval=0, retry_on_deadlock=False,
                 exception_checker=lambda exc: False):
        """Forward the retry settings to ``oslo_db.api.wrap_db_retry``.

        :param max_retries: forwarded unchanged to the parent
        :param retry_interval: forwarded unchanged to the parent
        :param inc_retry_interval: forwarded unchanged to the parent
        :param max_retry_interval: forwarded unchanged to the parent
        :param retry_on_deadlock: forwarded unchanged to the parent
        :param exception_checker: forwarded only when the OPS version is
                                  not ``ac_constants.OPS_K``
        """
        # Settings shared by both variants of the parent constructor call.
        parent_kwargs = {
            'max_retries': max_retries,
            'retry_interval': retry_interval,
            'inc_retry_interval': inc_retry_interval,
            'max_retry_interval': max_retry_interval,
            'retry_on_deadlock': retry_on_deadlock,
        }

        # Exactly one version-specific keyword is added on top of them.
        if ncu.get_ops_version() in (ac_constants.OPS_K,):
            parent_kwargs['retry_on_request'] = True
        else:
            parent_kwargs['exception_checker'] = exception_checker

        super(HuaweiWrapDbRetry, self).__init__(**parent_kwargs)


class wrap_db_lock(object):
    """Decorator that runs the wrapped callable under a DB-backed lock.

    The lock id is computed on every call by
    ``_get_lock_id_by_resource_type`` from the resource type given at
    decoration time and the arguments of the call.
    """

    def __init__(self, resource_type):
        """Store the resource type used to derive the lock id.

        :param resource_type: one of the module-level resource constants
        """
        super(wrap_db_lock, self).__init__()
        self.type = resource_type

    def __call__(self, func):
        """Wrap ``func`` so that each call acquires and releases the lock.

        :param func: the callable to protect
        :returns: the wrapping function, carrying ``func``'s metadata
        """
        @six.wraps(func)
        def wrapper(*args, **kwargs):
            """DB lock wrapper."""
            lock_id = _get_lock_id_by_resource_type(self.type, args, kwargs)
            LOG.debug("[AC] Get lock id %s by resource", lock_id)

            # Re-entrancy guard. A caller higher up this call stack that is
            # itself running inside this wrapper with the same lock id
            # already holds the lock, so acquiring it again would only spin
            # against ourselves. Frames are matched by code object rather
            # than by function name: six.wraps changes __name__ but not the
            # code object's name, so a name comparison against
            # 'wrap_db_lock' can never match this wrapper.
            # stack(0) skips loading source context for every frame.
            within_wrapper = False
            for record in stack(0)[1:]:
                caller = record[0]
                if (caller.f_code is wrapper.__code__ and
                        caller.f_locals.get('lock_id') == lock_id):
                    within_wrapper = True
                    break

            if within_wrapper:
                # The enclosing call owns the lock and will release it.
                return func(*args, **kwargs)

            # test and create the lock if necessary
            _test_and_create_object(lock_id)
            lock_session_id = _acquire_lock(lock_id)

            try:
                return func(*args, **kwargs)
            finally:
                # Always give the lock back; a failed release is logged and
                # must not mask the result or exception of ``func``.
                try:
                    _release_lock(lock_id, lock_session_id)
                except Exception as ex:
                    LOG.exception(ex)

        return wrapper


def _get_lock_id_by_resource_type(resource_type, *args):
    """Derive the lock id for a decorated call from its arguments.

    ``args[0]`` is indexed as the positional arguments of the decorated call
    and ``args[1]`` as its keyword arguments, matching how ``wrap_db_lock``
    invokes this helper.

    :param resource_type: one of the module-level resource constants
    :returns: the lock id selected for that resource type
    :raises df_exc.UnknownResourceException: if ``resource_type`` is not one
        of the known constants
    """
    if resource_type in (RESOURCE_FW_POLICY, RESOURCE_L3_ROUTER):
        return args[0][2]

    if resource_type == RESOURCE_TOKEN_CONFIG:
        # The only case that reads the keyword arguments.
        return args[1].get('token_id')

    if resource_type == RESOURCE_STATUS_REPORT:
        return args[0][1]

    if resource_type == GET_VM_PORT_NAME:
        return args[0][0]

    if resource_type == SUBNET_OF_PUBLIC_SERVICE:
        return args[0][4]

    if resource_type == RESOURCE_DHCP_NETWORK:
        # The third positional argument may expose network_id either as an
        # attribute or through a mapping-style ``get``.
        if hasattr(args[0][2], 'network_id'):
            return args[0][2].network_id
        return args[0][2].get("network_id")

    if resource_type == DELETE_PORT_LOCK:
        return args[0][1].current['id'] + '_delete_port'

    if resource_type == DELETE_NETWORK_LOCK:
        return args[0][1].current['id'] + '_delete_network'

    if resource_type == UPDATE_PARENT_PORT:
        # Unlike the delete cases, the id is read from the argument itself
        # and not from its ``current`` attribute.
        return args[0][1]['id'] + '_update_parent_port'

    if resource_type == DELETE_SUBNET_LOCK:
        return args[0][1].current['id'] + '_delete_subnet'

    raise df_exc.UnknownResourceException(resource_type=resource_type)


@HuaweiWrapDbRetry(max_retries=LOCK_MAX_RETRIES,
                   retry_interval=LOCK_INIT_RETRY_INTERVAL,
                   inc_retry_interval=False,
                   max_retry_interval=LOCK_MAX_RETRY_INTERVAL,
                   retry_on_deadlock=True,
                   exception_checker=lambda exc: True)
def _acquire_lock(oid):
    """Take the lock for ``oid`` under a freshly generated session id.

    A new session id is generated on every attempt, then the lock row is
    switched to locked inside its own write transaction.

    :param oid: the lock id (``object_uuid`` of the lock row)
    :returns: the session id that now owns the lock
    :raises db_exc.RetryRequest: from ``_lock_free_update`` when no row could
        be updated; the retry decorator is configured with up to
        ``LOCK_MAX_RETRIES`` retries
    """
    LOG.debug("[AC] acquire lock")
    # temporary session id for this API context
    sid = _generate_lock_session_id()
    log_fields = {'oid': oid, 'sid': sid}

    # NOTE(nick-ma-z): we disallow subtransactions because the
    # retry logic will bust any parent transactions
    db_session = ACdbInterface().get_session('write')
    with db_session.begin():
        LOG.debug("[AC] Try to get lock for object %(oid)s in "
                  "session %(sid)s.", log_fields)
        _lock_free_update(db_session, oid,
                          lock_state=False, lock_session_id=sid)
        LOG.debug("[AC] Lock is acquired for object %(oid)s in "
                  "session %(sid)s.", log_fields)
    return sid


@HuaweiWrapDbRetry(max_retries=LOCK_MAX_RETRIES,
                   retry_interval=LOCK_INIT_RETRY_INTERVAL,
                   inc_retry_interval=False,
                   max_retry_interval=LOCK_MAX_RETRY_INTERVAL,
                   retry_on_deadlock=True,
                   exception_checker=lambda exc: True)
def _release_lock(oid, sid):
    """Release the lock for ``oid`` that is owned by session ``sid``.

    The lock row is switched back to unlocked inside its own write
    transaction.

    :param oid: the lock id (``object_uuid`` of the lock row)
    :param sid: the session id returned by ``_acquire_lock``
    :raises db_exc.RetryRequest: from ``_lock_free_update`` when no row
        locked by ``sid`` could be updated; the retry decorator is
        configured with up to ``LOCK_MAX_RETRIES`` retries
    """
    log_fields = {'oid': oid, 'sid': sid}

    # NOTE(nick-ma-z): we disallow subtransactions because the
    # retry logic will bust any parent transactions
    db_session = ACdbInterface().get_session('write')
    with db_session.begin():
        LOG.debug("[AC] Try to release lock for object %(oid)s in "
                  "session %(sid)s.", log_fields)
        _lock_free_update(db_session, oid,
                          lock_state=True, lock_session_id=sid)
        LOG.debug("[AC] Lock is released for object %(oid)s in "
                  "session %(sid)s.", log_fields)


def _generate_lock_session_id():
    """Return a random integer session id in ``[0, LOCK_SEED]``."""
    # randrange excludes its stop value, hence the + 1 to keep LOCK_SEED
    # itself a possible result.
    return random.randrange(0, LOCK_SEED + 1)


@HuaweiWrapDbRetry(max_retries=LOCK_MAX_RETRIES,
                   retry_interval=LOCK_INIT_RETRY_INTERVAL,
                   inc_retry_interval=False,
                   max_retry_interval=LOCK_MAX_RETRY_INTERVAL,
                   retry_on_deadlock=True,
                   exception_checker=lambda exc: True)
def _test_and_create_object(uuid):
    """Make sure a usable lock row exists for ``uuid``.

    If the row exists, is locked and its ``created_at`` is older than 120
    (seconds, per ``timeutils.is_older_than``), the lock is reset on behalf
    of the session recorded in the row. If no row exists, one is created; a
    duplicate-entry error during creation is ignored because it means the
    row was created concurrently.

    :param uuid: the lock id (``object_uuid`` of the lock row)
    """
    LOG.debug("[AC] test and create object")
    try:
        db_session = ACdbInterface().get_session('write')
        with db_session.begin():
            lookup = db_session.query(models.DBLockedObjects)
            locked_obj = lookup.filter_by(object_uuid=uuid).one()
            # test ttl
            timed_out = locked_obj.lock and timeutils.is_older_than(
                locked_obj.created_at, 120)
            if timed_out:
                # reset the lock if it is timeout
                LOG.debug('[AC] The lock for object %(id)s is reset '
                          'due to timeout.', {'id': uuid})
                _lock_free_update(
                    db_session, uuid, lock_state=True,
                    lock_session_id=locked_obj.lock_session_id)
    except orm_exc.NoResultFound:
        # No row yet: create it in a fresh transaction.
        try:
            db_session = ACdbInterface().get_session('write')
            with db_session.begin():
                _create_db_row(db_session, oid=uuid)
        except db_exc.DBDuplicateEntry:
            # the lock is concurrently created.
            pass


def _lock_free_update(session, uuid, lock_state=False, lock_session_id=0):
    """Flip the lock row for ``uuid`` with a single conditional UPDATE.

    ``lock_state`` is the state the row is expected to be in *before* the
    update; the row is switched to the opposite state:

    * falsy  -- acquire: match the unlocked row, mark it locked, record
      ``lock_session_id`` and stamp ``created_at`` with the current UTC time.
    * truthy -- release or reset: match the row locked by
      ``lock_session_id``, mark it unlocked and set its session id to 0.

    :param session:    the db session
    :type session:     DB Session object
    :param uuid:       the lock uuid
    :type uuid:        string
    :param lock_state: the state the row must currently have
    :type lock_state:  boolean
    :param lock_session_id: the API session ID to record (acquire) or to
                            match (release)
    :type lock_session_id:  integer
    :raises:           RetryRequest() when no row was updated
    """
    LOG.debug("[AC] lock free update")
    if lock_state:
        # release or reset lock
        row_filter = dict(object_uuid=uuid, lock=lock_state,
                          lock_session_id=lock_session_id)
        new_values = dict(lock=not lock_state, lock_session_id=0)
    else:
        # acquire lock
        row_filter = dict(object_uuid=uuid, lock=lock_state)
        new_values = dict(lock=not lock_state,
                          lock_session_id=lock_session_id,
                          created_at=datetime.datetime.utcnow())

    matching_rows = session.query(models.DBLockedObjects).filter_by(
        **row_filter)
    updated_count = matching_rows.update(new_values,
                                         synchronize_session='fetch')
    if updated_count:
        return

    # Nothing matched: the row is not in the expected state.
    LOG.debug('[AC] The lock for object %(id)s in session '
              '%(sid)s cannot be updated.',
              {'id': uuid, 'sid': lock_session_id})
    raise db_exc.RetryRequest(
        df_exc.DBLockFailed(oid=uuid, sid=lock_session_id))


def _create_db_row(session, oid):
    """Add an unlocked lock row for ``oid`` and flush the session.

    The row starts with ``lock=False``, ``lock_session_id=0`` and
    ``created_at`` set to the current UTC time.

    :param session: the db session the row is added to
    :param oid: the lock id stored as ``object_uuid``
    """
    LOG.debug("[AC] create db row")
    session.add(models.DBLockedObjects(
        object_uuid=oid,
        lock=False,
        lock_session_id=0,
        created_at=datetime.datetime.utcnow()))
    session.flush()
