#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright 2018 Huawei Technologies Co. Ltd. All rights reserved.
"""
    Build the websocket connection with the AC controller.
"""

import base64
import os
import random
import socket
import ssl
import string
import struct
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import eventlet
import netaddr
import six
from neutron.db.db_base_plugin_v2 import NeutronDbPluginV2
from neutron.plugins.ml2.models import PortBinding
from oslo_config import cfg
from oslo_serialization import jsonutils
from sqlalchemy import not_

from networking_huawei._i18n import _LI, _LE
from networking_huawei.common.exceptions import CrtFileNotFoundException, WrongSANException, CertExpiredException, \
    AuthenticationException, InternalServerException, OtherException, CertRevokedException
from networking_huawei.drivers.ac.ac_agent.rpc.websocket.frame import Frame
from networking_huawei.drivers.ac.client.restclient import ACReSTClient
from networking_huawei.drivers.ac.common import constants
from networking_huawei.drivers.ac.common import neutron_compatible_util as ncu
from networking_huawei.drivers.ac.common import security_util
from networking_huawei.drivers.ac.common.constants import NW_HW_PORTS, OPER_UPDATE
from networking_huawei.drivers.ac.common.fusion_sphere_alarm import ACPluginAlarm
from networking_huawei.drivers.ac.common.neutron_compatible_util import ac_log as logging
from networking_huawei.drivers.ac.common.util import DataFilterUtil
from networking_huawei.drivers.ac.db.dbif import ACdbInterface
from networking_huawei.drivers.ac.db.schema import ACFailedResources
from networking_huawei.drivers.ac.encode_convert import convert_to_bytes, convert_to_str

# Replace blocking stdlib primitives with eventlet's green versions at import
# time; it runs after the imports above, so those module objects get patched.
eventlet.monkey_patch()
# Module logger (the project's ac_log wrapper, imported as ``logging``).
LOG = logging.getLogger(__name__)


def _read_ports():
    """Return the ERROR-status ports that should be re-sent to the controller.

    Ports whose binding profile contains a truthy ``anti_affinity_port`` entry
    are skipped. Ports for which
    ``DataFilterUtil.not_in_white_or_in_black_list`` returns True are also
    skipped.

    :returns: list of port dicts as returned by ``NeutronDbPluginV2.get_ports``
        for ``status == "ERROR"``.
    """
    db_context = ncu.neutron_context.get_admin_context()

    # ``PortBinding.profile is None`` is evaluated by Python, not SQLAlchemy:
    # it is always False, so the former ``not_(...)`` filter never excluded a
    # row. Use column operators so NULL / empty profiles are filtered in SQL.
    anti_ports = db_context.session.query(PortBinding).filter(
        PortBinding.profile.isnot(None), PortBinding.profile != '')
    anti_port_ids = set()
    for elem in anti_ports:
        port_binding = jsonutils.loads(elem.profile) if elem.profile else {}
        if port_binding.get('anti_affinity_port'):
            anti_port_ids.add(elem.port_id)
    LOG.debug("Websocket reconnect, query anti_affinity_port ids:%s", anti_port_ids)

    data_filter = DataFilterUtil()
    err_filters = {'status': ["ERROR"]}
    ports = NeutronDbPluginV2().get_ports(db_context, filters=err_filters)
    result = []
    for elem in ports:
        port_id = elem.get('id')
        if port_id in anti_port_ids:
            continue
        if data_filter.not_in_white_or_in_black_list(db_context, elem, NW_HW_PORTS, port_id):
            continue
        result.append(elem)
    return result


def _retry_error_port_when_reconnect():
    """Record every ERROR-status port as a failed resource after a reconnect.

    Each port from ``_read_ports`` is written through
    ``ACdbInterface.create_or_update_failed_resource`` as an update operation
    with ``retry_count=1``. Presumably the failed-resource retry logic then
    sends it to the controller once more.

    A DB error for one port is logged; the remaining ports are still
    processed.
    """
    LOG.info("after reconnect success,start deal error port.")
    error_ports = _read_ports()
    LOG.debug("port list:%s", error_ports)
    db_helper = ACdbInterface()
    update_operation = '%s_%s' % (OPER_UPDATE, NW_HW_PORTS)
    for port in error_ports:
        # a single retry is requested for ports re-sent after a reconnect
        failed_res = ACFailedResources(
            id=port.get('id'),
            res_type=NW_HW_PORTS,
            retry_count=1,
            operation=update_operation,
        )
        LOG.debug("send error port to ac:%s,port=%s", failed_res, port)
        try:
            db_helper.create_or_update_failed_resource(failed_res)
        except Exception as err:
            LOG.error('record failed resource in db occur error:%s', err)
    LOG.info("deal error port finish.")


class WebSocket(object):
    """
    AC plugin websocket client.

    The client:

    * opens a TLS connection to the AC,
    * performs the RFC 6455 upgrade handshake, using a request path fetched
      from the AC REST API (see ``get_filter``),
    * sends and receives websocket frames.

    Incoming frames are parsed in stages: head, extended length, masking key,
    payload. The parsed head and payload length are kept on the instance, so
    if ``recv`` fails part-way, the next call resumes from the last completed
    stage.
    """

    # Upper bound for a single socket recv() call, so a large declared payload
    # length is read incrementally instead of allocating one huge buffer.
    _RECV_CHUNK_SIZE = 65536

    def __init__(self, remote_address, local_address, ssl_crt_file_path):
        """Initialise the client state.

        :param remote_address: ``(ip, port)`` of the AC websocket endpoint.
        :param local_address: local address bound before connecting. Its IP
            is also used in the filter URL and the ``Host`` header.
        :param ssl_crt_file_path: stored on the instance; not read elsewhere
            in this class.
        """
        self.remote_address = remote_address
        self.local_address = local_address
        self.ssl_crt_file_path = ssl_crt_file_path
        self.input_executor = ThreadPoolExecutor(max_workers=constants.RPC_HANDLE_MAX_WORKERS)
        self.lock = threading.Lock()  # guards connect() / close()
        self.send_lock = threading.Lock()  # serialises writes to the socket
        self.recv_lock = threading.Lock()
        self.ws_url = None
        self.http_url = None
        self.frame = Frame()
        self.connection = None  # the connected TLS socket, or None
        self.recv_buffer = []
        # partially parsed frame state, see recv()
        self.head = None
        self.payload_length = None
        self.mask = None
        self.ac_rest_client = ACReSTClient()

    def _clear_head(self):
        """Reset the partially parsed frame state (head, payload length, mask)."""
        self.head = None
        self.payload_length = None
        self.mask = None

    def _get_stream_name(self, token_id, remote_ip, port):
        """Create a netconf stream subscription on the AC and return its name.

        :param token_id: AC auth token, sent as ``x-auth-token``.
        :param remote_ip: AC address, already bracketed when IPv6.
        :param port: AC REST port.
        """
        json_req = jsonutils.dumps({"input": {"stream": "netconf", "access": "json"}})
        headers = {"Content-type": "application/json", "x-auth-token": token_id}
        url = "https://%s:%s/restconf/operations/huawei-ac-stream-websocket:create-stream-subscription" % (
            remote_ip, port)
        stream_name = self.ac_rest_client.get_stream_name_with_cert(headers, "POST", url, json_req)
        LOG.info(_LI("[AC]get_stream_name result."))
        return stream_name

    def get_filter(self):
        """Fetch the filter info used to build the handshake request path.

        :returns: the result of ``get_filter_info_with_cert``, or -1 when no
            token could be obtained. In the -1 case an alarm is sent on FSP.
        """
        remote_ip = self.remote_address[0]
        ac_port = constants.rest_server_port
        try:
            token_id = self.ac_rest_client._get_token_id(remote_ip, ac_port, is_websocket=True)
        except (AuthenticationException, InternalServerException, OtherException) as ex:
            if ncu.IS_FSP:
                alarm_info = ACPluginAlarm.get_websocket_connection_fail_alarm(str(ex))
                LOG.info("[AC]Websocket connection failed for token issue.Send alarm message.")
                ACPluginAlarm.send_alarm(alarm_info)
            return -1
        if netaddr.valid_ipv6(remote_ip):
            remote_ip = '[' + remote_ip + ']'
        stream_name = self._get_stream_name(token_id, remote_ip, ac_port)
        headers = {"Content-type": "application/json", "x-auth-token": token_id}
        cloud_name = cfg.CONF.huawei_ac_config.cloud_name
        restconf_dir = os.path.join(os.path.realpath("/restconf/"), stream_name)
        url = "https://%s:%s%s?filter=/huawei-ac-dcn-neutron:statusreport/openstack-info=%s-%s" % (
            remote_ip, ac_port, restconf_dir, cloud_name, self.local_address[0].replace(':', '~'))
        filter_info = self.ac_rest_client.get_filter_info_with_cert(headers, "GET", url)
        LOG.info(_LI("[AC]get_filter result."))
        return filter_info

    @classmethod
    def _get_certs(cls):
        """Return the client key, client cert and CA cert paths.

        The files live next to this module.

        :raises CrtFileNotFoundException: if any of the three files is missing.
        """
        dir_path = os.path.dirname(os.path.realpath(__file__))
        key_file_path = os.path.join(dir_path, "client_key.pem")
        cert_file_path = os.path.join(dir_path, "client.cer")
        ca_cert_path = os.path.join(dir_path, "trust.cer")
        if os.path.isfile(key_file_path) and os.path.isfile(cert_file_path) and os.path.isfile(ca_cert_path):
            return {'ca_file': ca_cert_path, 'cert_file': cert_file_path, 'key_file': key_file_path}
        LOG.error(_LE("[AC] %s, %s, %s are not found"), key_file_path, cert_file_path, ca_cert_path)
        raise CrtFileNotFoundException()

    def split_remote_addr_and_get_base_url(self):
        """Return ``(remote_ip, "ip:port")``; IPv6 IPs are bracketed in the URL part."""
        remote_ip = self.remote_address[0]
        remote_port = self.remote_address[1]
        base_url = '%s:%s' % ('[' + remote_ip + ']' if netaddr.valid_ipv6(remote_ip) else remote_ip, remote_port)
        return remote_ip, base_url

    @classmethod
    def get_websocket_key_pwd(cls):
        """Return the client key password.

        The configured value is returned as-is after FSP 6.3.0; otherwise it
        is decrypted first.
        """
        if ncu.after_fsp_6_3_0():
            return cfg.CONF.huawei_ac_config.websocket_key_password
        return security_util.decrypt_data(cfg.CONF.huawei_ac_config.websocket_key_password,
                                          data_type=constants.WEBSOCKET_SECURE_KEY)

    @classmethod
    def _get_ssl_context(cls, certs):
        """Build an SSLContext that requires and verifies the server certificate.

        The context also presents the client certificate from ``certs``.
        """
        websocket_key_pwd = WebSocket.get_websocket_key_pwd()
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        ssl_context.load_verify_locations(certs['ca_file'])
        ssl_context.load_cert_chain(certs['cert_file'], certs['key_file'], websocket_key_pwd)
        ssl_context.set_ciphers(constants.SUPPORT_CIPHERS)
        ssl_context.verify_mode = ssl.CERT_REQUIRED
        return ssl_context

    @classmethod
    def verify_server_cert(cls, new_client):
        """Check the peer certificate of ``new_client``.

        :raises CertRevokedException: if it is listed in the websocket CRL.
        :raises CertExpiredException: if it has expired (only checked when
            ``ncu.SUPPORT_VERIFY_CERT`` is set).
        :raises WrongSANException: if its SAN check fails (only checked when
            ``ncu.SUPPORT_VERIFY_CERT`` is set).
        """
        server_cert = new_client.getpeercert(True)
        pem = ssl.DER_cert_to_PEM_cert(server_cert)
        if ncu.is_revoked(ncu.WEBSOCKET_CRL, pem):
            LOG.error(_LE("AC websocket cert is revoked"))
            raise CertRevokedException(cert="AC websocket cert")
        if ncu.SUPPORT_VERIFY_CERT:
            # get AC server ca cert
            if ncu.is_cert_expired(pem):
                raise CertExpiredException()
            if not ncu.verify_san(server_cert):
                raise WrongSANException()

    @staticmethod
    def _close_sockets(sockets):
        """Best-effort close of the sockets created by a failed connect()."""
        for sock in reversed(sockets):
            try:
                sock.close()
            except socket.error as ex:
                LOG.error("[AC]close socket after failed connect error: %s", ex)

    def connect(self):
        """Connect to the AC over TLS and perform the websocket handshake.

        On success the new socket becomes ``self.connection``. A background
        thread then runs ``_retry_error_port_when_reconnect``.

        Every socket created here is closed on any failure path.

        :returns: 0 on success; -1 if the handshake, or fetching its filter,
            failed.
        :raises Exception: if the remote IP is neither IPv4 nor IPv6.
        :raises CrtFileNotFoundException: if the client cert files are missing.

        Certificate-verification errors from ``verify_server_cert`` and socket
        errors propagate unchanged.
        """
        with self.lock:
            remote_ip, base_url = self.split_remote_addr_and_get_base_url()
            self.http_url = 'wss://' + base_url + "/restconf/stream-websocket/netconf"
            if netaddr.valid_ipv4(remote_ip):
                family = socket.AF_INET
            elif netaddr.valid_ipv6(remote_ip):
                family = socket.AF_INET6
            else:
                raise Exception('[AC]remote_ip is not format: %s' % remote_ip)

            sockets = [socket.socket(family, socket.SOCK_STREAM)]
            established = False
            try:
                certs = self._get_certs()

                ncu.verify_plugin_cert(certs, connection='websocket')

                ssl_context = WebSocket._get_ssl_context(certs)

                # Use the public wrap_socket() API rather than the private
                # ``SSLSocket(_context=...)`` constructor. The stdlib SSLSocket
                # has had no public constructor since Python 3.7.
                new_client = ssl_context.wrap_socket(sockets[0])
                sockets.append(new_client)
                self.ws_url = 'wss://' + base_url
                if self.local_address:
                    LOG.debug("[AC]websocket rpc client bind local address %s", self.local_address)
                    new_client.bind(self.local_address)
                new_client.settimeout(constants.SOCKET_TIMEOUT)
                new_client.connect(self.remote_address)

                self.verify_server_cert(new_client)

                if self.handshake(new_client) == -1:
                    return -1
                established = True
            finally:
                if not established:
                    # never leak the half-open socket, whether we raised or returned -1
                    self._close_sockets(sockets)

            LOG.info("[AC]connect to (%s) success.", self.ws_url)
            new_client.settimeout(30)
            # a new stream must not be parsed with frame state from an old one
            self._clear_head()
            self.connection = new_client

            # handle ERROR ports asynchronously in the background
            port_thread = threading.Thread(target=_retry_error_port_when_reconnect, name='error_port_deal_thread')
            port_thread.start()
            return 0

    @classmethod
    def generate_websocket_key(cls):
        """Return a fresh ``Sec-WebSocket-Key``.

        The key is a base64-encoded random 16-byte nonce, as RFC 6455 section
        4.1 requires.
        """
        return convert_to_str(base64.b64encode(os.urandom(16)))

    def handshake(self, sock):
        """Send the websocket upgrade request on ``sock`` and check the reply.

        :returns: 0 when the reply contains ``Sec-WebSocket-Accept``.
            -1 when any of these happen:

            * the filter could not be fetched;
            * the filter has an unexpected format;
            * the server did not accept the upgrade (an alarm is sent on FSP).
        """
        filter_info = self.get_filter()
        if filter_info == -1:
            return -1

        filter_info = convert_to_str(filter_info) + "&isHeartBeat=true"
        # The request path is everything after the first '18010' (presumably
        # the AC port in the returned URL). partition() keeps the whole
        # remainder. The former split('18010')[1] truncated the path at a
        # second occurrence and raised IndexError when none was present.
        _, separator, request_path = filter_info.partition('18010')
        if not separator:
            LOG.error(_LE("[AC]websocket filter info has unexpected format."))
            return -1
        websocket_key = self.generate_websocket_key()
        hand_msg = "GET %(filter)s HTTP/1.1\r\n" \
                   "Host: %(host)s\r\n" \
                   "Origin: %(http_url)s\r\n" \
                   "Connection: Upgrade\r\n" \
                   "Upgrade: websocket\r\n" \
                   "Sec-WebSocket-Version: 13\r\n" \
                   "Sec-WebSocket-Key: %(websocket_key)s\r\n\r\n" % {'filter': request_path,
                                                                     'host': self.local_address[0],
                                                                     'http_url': self.http_url,
                                                                     'websocket_key': websocket_key}
        sock.sendall(convert_to_bytes(hand_msg))

        shake = sock.recv(1024)
        if "sec-websocket-accept" not in convert_to_str(shake).lower():
            LOG.error(_LE("[AC]websocket shake hands fail."))
            if ncu.IS_FSP:
                alarm_info = ACPluginAlarm.get_websocket_connection_fail_alarm('shake hand failed')
                LOG.info("[AC]Websocket connection failed. Send alarm message.")
                ACPluginAlarm.send_alarm(alarm_info)
            return -1
        return 0

    def send(self, message):
        """Send ``message`` as one text message, fragmented if it is long.

        Messages longer than ``constants.max_payload_length`` are split into
        fragments. The first fragment is TEXT; later ones are CONTINUATION.
        """
        if not self.connection:
            LOG.error(_LE("[AC]websocket connection does not exist"))
            return
        max_len = constants.max_payload_length
        # Hold the send lock for the whole message. Fragments of messages sent
        # from different threads must never interleave on the wire.
        with self.send_lock:
            first = True
            while len(message) > max_len:
                frame = Frame().get_message_frame(message[:max_len], first=first, end=False, mask=True)
                first = False
                self._write_frame(frame)
                message = message[max_len:]
            # the remainder (at most max_len long) is the final fragment
            if message:
                frame = Frame().get_message_frame(message, first=first, end=True, mask=True)
                self._write_frame(frame)

    def send_ping(self):
        """Send a ping frame to the server."""
        frame = Frame().get_ping_frame()
        self.__send_frame(frame)
        LOG.debug(_LI("[AC]send a ping!"))

    def send_pong(self):
        """Reply to a server ping with a pong frame."""
        frame = Frame().get_pong_frame()
        self.__send_frame(frame)

    def send_close(self):
        """Send a close frame with status ``Frame.STATUS_NORMAL``."""
        frame = Frame().get_close_frame(status_code=Frame.STATUS_NORMAL, reason="")
        self.__send_frame(frame)

    def __send_frame(self, frame):
        """Send a single frame under ``send_lock``."""
        with self.send_lock:
            self._write_frame(frame)

    def _write_frame(self, frame):
        """Write one frame completely; the caller must hold ``send_lock``."""
        self.connection.sendall(frame.to_data_frame())

    # recv data from socket.
    # websocket protocol :https://tools.ietf.org/html/rfc6455#section-6.1
    # chapter 5.2 Base Framing Protocol
    def recv(self):
        """Read one websocket frame from the connection.

        :returns: a ``Frame``, or None if the head, extended length or masking
            key could not be read. Stages that were already parsed are kept,
            and the next call resumes from there.
        """
        LOG.debug("[AC]begin recv")
        if not self.head and self.recv_head() == -1:
            return None
        (fin, rsv1, rsv2, rsv3, opcode, mask, _) = self.head
        LOG.debug("[AC]recv head.fin:%d,opcode:%d,mask:%d",
                  fin, opcode, mask)

        if not self.payload_length and self.recv_length() == -1:
            return None
        length = self.payload_length

        if mask:
            mask_key = self.recv_mask()
            if mask_key == -1:
                return None
            payload = self.parse_payload(self.recv_data(length), mask_key)
        else:
            payload = self.recv_data(length)

        self._clear_head()
        return Frame(fin, rsv1, rsv2, rsv3, opcode, mask, length, payload)

    def recv_head(self):
        """Read the 2-byte frame head and store it in ``self.head``.

        ``self.head`` is ``(fin, rsv1, rsv2, rsv3, opcode, mask, length7)``.

        :returns: 0 on success, -1 if two bytes could not be read.
        """
        data = self.recv_data(2)
        if not data or len(data) < 2:
            return -1
        # bytearray yields ints on both Python 2 and 3
        byte1, byte2 = bytearray(data[:2])

        fin = byte1 >> 7 & 1
        rsv1 = byte1 >> 6 & 1
        rsv2 = byte1 >> 5 & 1
        rsv3 = byte1 >> 4 & 1
        opcode = byte1 & 0xf

        mask = byte2 >> 7 & 1
        length = byte2 & 0x7f
        self.head = (fin, rsv1, rsv2, rsv3, opcode, mask, length)
        return 0

    def recv_length(self):
        """Resolve the payload length into ``self.payload_length``.

        A 7-bit value of 126 means a 16-bit extended length follows; 127
        means a 64-bit one follows.

        :returns: 0 on success, -1 if the extended length could not be read.
        """
        length_code = self.head[6]
        if length_code == 0x7e:
            data = self.recv_data(2)
            if not data or len(data) != 2:
                return -1
            self.payload_length = struct.unpack("!H", data)[0]
        elif length_code == 0x7f:
            data = self.recv_data(8)
            if not data or len(data) != 8:
                return -1
            self.payload_length = struct.unpack("!Q", data)[0]
        else:
            self.payload_length = length_code
        return 0

    def recv_mask(self):
        """Read the 4-byte masking key.

        :returns: the key, or -1 if four bytes could not be read.
        """
        data = self.recv_data(4)
        if not data or len(data) != 4:
            return -1
        return data

    def recv_data(self, bufsize):
        """Read exactly ``bufsize`` bytes unless the peer closes or an error occurs.

        A single ``recv`` may return fewer bytes than requested; a TLS socket
        returns at most one record. The reads are therefore repeated until
        the requested size is reached.

        On EOF or a ``ValueError``/``socket.error`` (including timeouts), the
        method sleeps one second and returns whatever was collected so far,
        possibly empty.

        :returns: None when there is no connection or ``bufsize`` is 0;
            otherwise the bytes read.
        """
        if not self.connection:
            return None
        if bufsize == 0:
            return None
        chunks = []
        remaining = bufsize
        while remaining > 0:
            try:
                chunk = self.connection.recv(min(remaining, self._RECV_CHUNK_SIZE))
            except (ValueError, socket.error) as ex:
                time.sleep(1)
                LOG.error(_LE("[AC]recv data error. %s"), ex)
                break
            if not chunk:
                # peer closed the connection; back off before the caller retries
                time.sleep(1)
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    def parse_payload(self, payload, masking_key):
        """Unmask ``payload`` with the 4-byte ``masking_key`` (RFC 6455 section 5.3).

        Logs an error when the payload length differs from
        ``self.payload_length``.

        :returns: the unmasked payload as bytes (``str`` on Python 2), the
            same type as an unmasked payload from ``recv_data``.
        """
        data = bytearray(payload or b'')
        if len(data) != self.payload_length:
            LOG.error(_LE("parse payload error! data_len:%d,payload_length:%d"), len(data), self.payload_length)
        key = bytearray(masking_key)
        for index in range(len(data)):
            data[index] ^= key[index % 4]
        return bytes(data)

    def close(self):
        """Close the current connection and reset the connection state.

        Logs an error if there is no connection. The connection, URLs and
        any half-parsed frame state are reset even if ``close()`` raises.
        """
        with self.lock:
            if not self.connection:
                LOG.error("[AC]connection does not exist,remote addr:%s", self.remote_address)
                return
            try:
                self.connection.close()
            except socket.error as ex:
                LOG.error(_LE("[AC]Close socket with %s, exception: %s"), self.remote_address, str(ex))
            finally:
                self.connection = None
                self.ws_url = None
                self.http_url = None
                self._clear_head()
