#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2016 Huawei Technologies Co. Ltd. All rights reserved.
"""
WebSocketTerminal and Request classes: JSON-RPC style messaging with the AC
over a websocket connection.
"""

import sched
import threading
import uuid
import time

try:
    from neutron import context
except ImportError:
    from neutron_lib import context

from networking_huawei.drivers.ac.common.neutron_compatible_util import \
    ac_log as logging
from oslo_serialization import jsonutils

from networking_huawei._i18n import _LI, _LW, _LE
from networking_huawei.drivers.ac.common import constants
from networking_huawei.drivers.ac.ac_agent.rpc.websocket. \
    websocket_future import Future
from networking_huawei.drivers.ac.db.compare_result.compare_results_db \
    import CompareDbMixin
from networking_huawei.drivers.ac.db.sync_result.sync_results_db \
    import SyncDbMixin
from networking_huawei.drivers.ac.db.dbif import ACdbInterface
from networking_huawei.drivers.ac.sync import worker

# Module logger, obtained through the project's ``ac_log`` logging wrapper.
LOG = logging.getLogger(__name__)
# Encoding used on incoming websocket text before JSON decoding
# (see WebSocketTerminal._handle_request).
UTF8 = "UTF-8"


class WebSocketTerminal(object):
    """JSON-RPC style terminal over a websocket connection to the AC.

    The terminal works in both directions:

    * Outgoing calls: any attribute not otherwise defined resolves, through
      ``__getattr__``, to a ``Request``; calling it sends the request, and the
      pending request is kept in ``self.requests`` until the AC answers or it
      times out.
    * Incoming calls: messages from the AC whose ``method`` is listed in
      ``constants.WEBSOCKET_REQUEST_LIST`` are served by the first registered
      handler defining that method, and the result is sent back.

    ``self.requests`` maps request id -> ``(request, remaining_ticks)``. It is
    used by several threads (``input_executor`` workers and the timeout
    thread), and every write to it is done under ``self.req_lock``.

    NOTE(review): ``Request`` sends through ``self._write``, which is not
    defined in this class -- presumably a subclass provides it; otherwise
    ``__getattr__`` would turn it into yet another ``Request``.
    """

    def __init__(self, websocket, address, input_executor):
        self.address = address
        # Executor on which incoming messages are processed.
        self.input_executor = input_executor
        # Pending outgoing requests: id -> (request, remaining ticks).
        self.requests = {}
        # Objects whose methods serve RPCs initiated by the AC.
        self.handlers = []
        # Drives the once-per-second aging of pending requests.
        self.schedule = sched.scheduler(time.time, time.sleep)
        # Drives the periodic polling of sync/compare results.
        self.schedule_sync = sched.scheduler(time.time, time.sleep)
        self.req_lock = threading.Lock()
        self.is_req_timeout_running = False
        self.echo_future = Future()
        self.websock = websocket
        self.compare_db = CompareDbMixin()
        self.sync_db = SyncDbMixin()
        self.db_if = ACdbInterface()
        # Used here only for its timed wait() in add_thread_to_check_sync.
        self._thread_event = threading.Event()
        self._thread_pool = worker.ACSyncThreadPool(
            constants.SYNC_THREAD_POOL_DEFAULT_COUNT)
        self.sync_spawn_state = constants.SYNC_THREAD_POOL_NO_SPAWN

    def _add_request(self, request):
        """Track an outgoing request until it is answered or times out.

        :param request: the ``Request`` being sent. It is keyed by its
            ``req_id`` (``id`` is accepted as a fallback for other callers).
        """
        req_id = getattr(request, "req_id", None)
        if req_id is None:
            req_id = request.id
        with self.req_lock:
            self.requests[req_id] = (request, constants.RPC_DEFAULT_TIMEOUT)

    def _handle_request(self, msg):
        """Dispatch one raw text message received from the AC.

        Messages whose ``params`` contain ``sync_plugin_version`` are ignored.
        A message whose ``method`` is in ``constants.WEBSOCKET_REQUEST_LIST``
        is an RPC from the AC and is served by ``_get_service``. A message
        whose ``id`` matches a pending outgoing request is a reply and is
        handled by ``_get_return``. Both run on ``self.input_executor``.

        :param msg: JSON text of the message.
        """
        LOG.info(_LI('Get message from AC :%s'), msg)
        request = jsonutils.loads(msg.encode(UTF8))
        if 'sync_plugin_version' in request.get('params', {}):
            LOG.info(_LI('[AC] no need to handle the request: %s'), request)
            return

        if "method" in request and \
                request["method"] in constants.WEBSOCKET_REQUEST_LIST:
            self.input_executor.submit(self._get_service, request)

        # .get(): a message without an id cannot answer a pending request.
        if request.get("id") in self.requests:
            self.input_executor.submit(self._get_return, request)

    def _get_return(self, req):
        """Deliver the AC's reply to the pending request it answers.

        Requests in "sync" mode are completed through ``future.notify``.
        Otherwise the registered callback object is invoked through
        ``_callback``. The pending entry is always dropped afterwards.
        Exceptions are logged, not raised.

        :param req: decoded reply; ``req["id"]`` is a pending request id.
        """
        try:
            threading.current_thread().name = "rpc_getReturn"
            LOG.info(_LI("[AC]Get rpc return from AC, return: %s, error: %s"),
                     req.get("result", "not provided"),
                     req.get("error", "not provided"))
            with self.req_lock:
                entry = self.requests.get(req["id"])
            request = entry[0] if entry is not None else None
            if request and request.future.callback:
                callback = request.future.callback
                if isinstance(callback, str) and callback == "sync":
                    # A reply may carry only "result" or only "error"; use
                    # .get() so the waiting caller is still notified.
                    request.future.notify(req["id"], req.get("result"),
                                          req.get("error"))
                else:
                    self._callback(request, req)
        except Exception as ex:
            LOG.error(_LE("[AC]Rpc return raise exception: %s"), str(ex))
        finally:
            with self.req_lock:
                self.requests.pop(req["id"], None)

    def _callback(self, request, req):
        """Forward a reply to the request's asynchronous callback object.

        Calls ``returnFail(address, error)`` when the reply has a truthy
        ``error``, and ``returnSuccess(address, result)`` otherwise.
        """
        error = req.get("error")
        callback = request.future.callback
        if error:
            callback.returnFail(self.address, error)
        else:
            callback.returnSuccess(self.address, req.get("result"))

    def _get_service(self, req):
        """Serve an RPC initiated by the AC.

        The first handler in ``self.handlers`` that defines ``req["method"]``
        processes the call. If none does, an error reply is sent. Exceptions
        are logged, not raised.

        :param req: decoded request from the AC.
        """
        try:
            self.set_curr_thread_name(req)
            LOG.debug("[AC]Get rpc service from AC")
            if not self.handlers:
                LOG.warn(_LW("[AC]Rpc service: no handler registered but "
                             "request received"))
                return
            method = req.get("method")
            if method is None:
                LOG.warn(_LW("[AC]Rpc service %s: this msg should "
                             "not be handled here, there must be "
                             "some error in param[id] or this request "
                             "has been time out"), req)
                return
            for handler in self.handlers:
                if hasattr(handler, method):
                    self._get_service_deal_default(handler, req)
                    break
            else:
                LOG.error(_LE("[AC]Rpc service: there is no handler "
                              "for your request"))
                error_info = "no handler for request"
                json_ret = {"id": req["id"], "result": False,
                            "error": error_info}
                self.websock.send(jsonutils.dumps(json_ret))
        except Exception as ex:
            LOG.error(_LE("[AC]Rpc service raise exception: %s"), str(ex))

    def _get_service_deal_default(self, handler, req):
        """Run ``handler`` for ``req``; on failure send an error reply.

        The error reply has the same shape as a success reply for that
        method: ``result_data`` for ``constants.AC_PLUGIN_OPERATE_LIST``
        methods, ``result: False`` for the others.
        """
        try:
            self._get_result_send_to_ac(handler, req)
        except Exception as ex:
            LOG.exception(str(ex))
            if req["method"] in constants.AC_PLUGIN_OPERATE_LIST:
                json_ret = {"id": req["id"], "method": req["method"],
                            "result_data": "",
                            "error": str(ex)}
            else:
                json_ret = {"id": req["id"], "result": False,
                            "error": str(ex)}
            LOG.error(_LE("[AC]Rpc service raise exception: %s"),
                      json_ret)
            self.websock.send(jsonutils.dumps(json_ret))

    def _get_result_send_to_ac(self, handler, req):
        """Call ``handler.<method>`` for ``req`` and send its result to the AC.

        Methods in ``constants.AC_PLUGIN_OPERATE_LIST`` receive the whole
        request and are answered with ``result_data``. Other methods receive
        ``req["params"]`` unpacked as positional arguments and are answered
        with ``result`` plus ``operation: constants.SYNC_MSG``. The handler
        must return a mapping with ``result`` and ``error_msg`` keys. For
        methods in ``constants.ASYNCHRONOUS_RETURN_LIST`` that reported no
        error, a background check of the sync outcome is queued. Exceptions
        propagate to the caller.
        """
        params = req.get('params', [])
        if req["method"] in constants.AC_PLUGIN_OPERATE_LIST:
            obj = getattr(handler, req["method"])(req)
            json_ret = {"id": req["id"], "method": req["method"],
                        "result_data": obj['result'],
                        "error": obj['error_msg']}
        else:
            obj = getattr(handler, req["method"])(*params)
            json_ret = {"id": req["id"], "result": obj['result'],
                        "error": obj['error_msg'],
                        "operation": constants.SYNC_MSG}
        LOG.debug("[AC]Rpc service process result: %s",
                  json_ret)
        self.websock.send(jsonutils.dumps(json_ret))
        if req["method"] in constants.ASYNCHRONOUS_RETURN_LIST:
            if not json_ret["error"]:
                self.add_thread_to_check_sync(req)

    def _register(self, handler):
        """Register an object whose methods serve RPCs from the AC."""
        self.handlers.append(handler)

    def set_requests_keys(self):
        """Age every pending request by one tick and drop expired ones.

        A request whose remaining tick count is 1 is removed; every other
        request is decremented.
        """
        with self.req_lock:
            # Iterate over a snapshot: entries are popped inside the loop,
            # and mutating a dict while iterating it raises RuntimeError,
            # which would kill the timeout thread.
            for key in list(self.requests):
                value = self.requests[key]
                if value[1] == 1:
                    LOG.info(_LI("[AC]Rpc terminal:remove timeout request: %s"),
                             value)
                    self.requests.pop(key)
                else:
                    self.requests[key] = (value[0], value[1] - 1)

    def _handler_timeout_from_dict(self, key):
        """Age the single pending request ``key`` by one tick.

        Unknown keys are ignored. A request with 1 tick left is removed;
        otherwise its count is decremented.
        """
        with self.req_lock:
            entry = self.requests.get(key)
            if entry is None:
                return
            if entry[1] == 1:
                LOG.info(_LI("[AC]Rpc terminal: remove timeout "
                             "request: %s"), entry)
                self.requests.pop(key)
            else:
                self.requests[key] = (entry[0], entry[1] - 1)

    def start_req_timeout(self):
        """Start the request-timeout thread unless one is already running."""
        if not self.is_req_timeout_running:
            self.is_req_timeout_running = True
            threading.Thread(target=self._start_timeout, args=()).start()

    @classmethod
    def set_curr_thread_name(cls, req=None):
        """Name the current thread after the RPC it serves.

        :param req: incoming request; when falsy, the thread is named as the
            request-timeout thread.
        """
        if req:
            name = "rpc_getService_%s" % req.get("method", "not provided")
        else:
            name = "rpc_terminal_req_timeout"
        # current_thread()/name replace the currentThread()/setName()
        # aliases deprecated since Python 3.10.
        threading.current_thread().name = name

    def _start_timeout(self):
        """Thread body: run the once-per-second request aging forever."""
        self.set_curr_thread_name()
        try:
            self.schedule.enter(0, 0, self._start_req_timeout, ())
            self.schedule.run()
        finally:
            # Reset even if a scheduled sweep raised, so that
            # start_req_timeout() can start a new thread.
            self.is_req_timeout_running = False

    def _start_req_timeout(self):
        """Age pending requests, then schedule the next sweep in 1 second."""
        self.set_requests_keys()
        self.schedule.enter(1, 0, self._start_req_timeout, ())

    def __getattr__(self, item):
        """Turn any undefined attribute into an outgoing RPC to the AC.

        ``terminal.some_method(*args)`` builds a ``Request`` for
        ``some_method``. This is only reached for names that normal attribute
        lookup does not find.
        """
        if item.startswith("__") and item.endswith("__"):
            # Protocol probes (pickle/copy __getstate__/__setstate__, hasattr
            # checks on dunders, ...) must fail normally instead of being
            # sent to the AC as RPCs.
            raise AttributeError(item)
        return Request(item, self)

    def add_thread_to_check_sync(self, req):
        """Queue a background check of the sync outcome for ``req``.

        The worker pool's threads are spawned on first use. When the task
        queue is full, this waits in 0.1 s steps until there is room.
        """
        LOG.info("add thread to check sync for :%s", req)
        if not self.sync_spawn_state:
            self._thread_pool. \
                spawn_threads(constants.SYNC_THREAD_POOL_DEFAULT_COUNT)
            # NOTE(review): 1 is assumed to mean "spawned" -- confirm against
            # the SYNC_THREAD_POOL_* constants.
            self.sync_spawn_state = 1
        while self._thread_pool.task_queue.full():
            self._thread_event.wait(.1)
        self._thread_pool.add_task(self.check_sync_result, req)

    def check_sync_result(self, thread_name, request, sec=10):
        """Pool task: poll the sync outcome of ``request`` every ``sec`` s.

        ``thread_name`` is supplied by the pool (presumably the worker's
        name) and is only logged here. Polling happens at most
        ``constants.CHECK_SYNC_TIMES`` times, via ``_check_sync_existence``.
        """
        LOG.info(_LI("Check sync result :%s, and thread name is :%s"),
                 request, thread_name)
        uid = request['id']
        method = request['method']
        self.schedule_sync.enter(sec, 0, self._check_sync_existence,
                                 (uid, sec, method, constants.CHECK_SYNC_TIMES))
        self.schedule_sync.run()

    def _check_sync_existence(self, uid, sec, method, check_times):
        """Poll once for the results of the sync/compare started by ``method``.

        While a neutron sync is still in progress, or no result rows exist
        yet, this re-schedules itself after ``sec`` seconds with
        ``check_times - 1``. It gives up when ``check_times`` reaches 0.
        Otherwise the results are forwarded by ``process_sync_results``.
        Any exception is reported to the AC as an error reply.
        """
        LOG.info(_LI("Check sync is existence, the method is :%s"), method)
        try:
            if check_times == 0:
                LOG.error(_LE("Can not get sync results by 24 hours, return"))
                return
            admin_context = context.get_admin_context()
            session = admin_context.session
            (status, _) = self.db_if.check_is_neutron_sync_in_progress(session)
            LOG.info(_LI("Check sync is processing :%s"), status)

            sync_results = []
            if method in [constants.INCONSISTENCY_CHECK,
                          constants.SINGLE_FEATURE_CHECK]:
                sync_results = self.compare_db. \
                    get_db_compare_results_query(admin_context)
            elif method in [constants.INCONSISTENCY_RECOVER,
                            constants.SINGLE_FEATURE_RECOVER,
                            constants.MULTI_INSTANCE_RECOVER]:
                filters = {'sync_type': [str(method)]}
                sync_results = self.sync_db. \
                    get_db_sync_results_query(admin_context, filters=filters)
            LOG.info(_LI("Get sync results :%s"), sync_results)

            if status or not sync_results:
                self.schedule_sync.enter(sec, 0, self._check_sync_existence,
                                         (uid, sec, method, check_times - 1))
            else:
                self.process_sync_results(uid, method, sync_results)

        except Exception as ex:
            dict_ret = {"id": uid, "method": method,
                        "result_data": "", "error": str(ex)}

            json_ret = jsonutils.dumps(dict_ret)
            LOG.error(_LE("[AC]Query sync result raise exception:%s"), json_ret)
            self.websock.send(json_ret)

    def process_sync_results(self, uid, method, sync_results):
        """Interpret marker rows, then send the sync results to the AC.

        When exactly one row is returned and its numeric ``id`` equals one of
        the ``constants.SYNC_*`` markers: ``SYNC_NO_DIFFERENCE`` turns the
        payload into an empty list; ``SYNC_OP_COMPARE_DATA`` takes the error
        text from the row's ``status``; the ``SYNC_OP_SYNC_*`` markers take
        it from ``error_message``.
        """
        error_msg = ""
        if len(sync_results) == 1:
            sync_id = ''
            if str(sync_results[0].get('id')).isdigit():
                sync_id = int(str(sync_results[0].get('id')))
            # 1. set sync result is [] when no difference data
            # 2. record error message when sync raise exception
            if sync_id == constants.SYNC_NO_DIFFERENCE:
                sync_results = []
            elif sync_id == constants.SYNC_OP_COMPARE_DATA:
                error_msg = sync_results[0].get('status')
            elif sync_id in [constants.SYNC_OP_SYNC_DATA,
                             constants.SYNC_OP_SYNC_AND_COMPARE_DATA,
                             constants.SYNC_OP_SYNC_SINGLE_INSTANCE]:
                error_msg = sync_results[0].get('error_message')
        LOG.info(_LI("[AC]Process sync result to ac: %s, error msg is: %s"),
                 sync_results, error_msg)
        self.segment_transmission(uid, method, sync_results, error_msg)

    def segment_transmission(self, uid, method, results, error_msg):
        """Send ``results`` to the AC in chunks.

        Chunks hold at most ``constants.MAX_SYNC_TRANS_LEN`` items and are
        numbered from 1 in ``index``. Every chunk except the last is tagged
        ``start``/``continue`` and followed by a 1 second pause. The last
        chunk (possibly empty) is tagged ``end``. ``count`` is the total
        number of results.
        """
        count = len(results)
        index = 1
        json_ret = {"id": uid,
                    "method": method,
                    "result_data": "",
                    "status": "",
                    "index": index,
                    "error": error_msg,
                    "count": count}
        while len(results) > constants.MAX_SYNC_TRANS_LEN:
            if index == 1:
                json_ret["status"] = "start"
            else:
                json_ret["status"] = "continue"
            json_ret["index"] = index
            json_ret["result_data"] = results[0:constants.MAX_SYNC_TRANS_LEN]
            results = results[constants.MAX_SYNC_TRANS_LEN:]
            index = index + 1
            request = jsonutils.dumps(json_ret)
            self.websock.send(request)
            time.sleep(1)
        json_ret["result_data"] = results
        json_ret["status"] = "end"
        json_ret["index"] = index
        request = jsonutils.dumps(json_ret)
        self.websock.send(request)


class Request(object):
    """A single outgoing RPC call to the AC.

    Instances are created by ``WebSocketTerminal.__getattr__``. Calling one
    sends ``{"id", "method", "params"}`` through the terminal and returns the
    ``Future`` that receives the outcome.
    """

    def __init__(self, method, terminal):
        self.method = method
        self.terminal = terminal
        self.req_id = str(uuid.uuid4())
        self.future = Future()

    @property
    def id(self):
        """Alias of ``req_id``; the terminal keys pending requests by ``id``."""
        return self.req_id

    def __call__(self, *args, **kwargs):
        """Send the request and return its future.

        :param args: positional RPC parameters, sent as the ``params`` list.
        :param kwargs: optional ``callbackFunc``, registered on the future to
            receive the reply. Without it, the future is put in "sync" mode
            (presumably so the caller can wait for the reply -- see Future).
        :returns: ``self.future``. A local failure is stored on it with
            ``set_error`` instead of being raised.
        """
        try:
            if "callbackFunc" in kwargs:
                self.future.register(kwargs["callbackFunc"])
            else:
                self.future.register("sync")
                self.future.sync(self.req_id, self.method).clear()

            json_request = {"id": self.req_id, "method": self.method,
                            "params": list(args)}
            # Register before writing: the reply can be read and dispatched by
            # another thread before _write() returns, and replies whose id is
            # not pending are discarded. If the write fails, the entry is
            # reaped by the terminal's timeout sweep.
            self.terminal._add_request(self)
            LOG.info(_LI("[AC]Send json rpc request"))
            self.terminal._write(jsonutils.dumps(json_request))
        except Exception as ex:
            # Use a constant msgid for _LE; the exception text is an argument.
            LOG.error(_LE("[AC]Send json rpc request failed: %s"), str(ex))
            self.future.set_error(str(ex))
        return self.future
