import os
import shlex
import socket
import ssl
import stat
import threading
import time

from common.common_define import CommonDefine
from common.configreader import g_cfg_common, g_cfg_agentassist
from common.utils import Utils
from message.tcp import common
from message.tcp.common import msg_queue, logger, FileIOUtil, TimerTask, CheckArgs, MSG_BODY_FILEHEAD_SIZE, \
    MSG_BODY_FILEBODY_SIZE, MSG_HEAD_SIZE, update_kmc_time
from message.tcp.messagehandler import CipherHandler, RequestHandler, Request, CERT_FILE_PATH
from message.tcp.messageparser import MessageParser, CMD_TYPE_CONST

# Cipher type reported by common.get_cipher_type(). Client treats a falsy value as
# "unsupported": it skips TLS setup and the periodic KMC update in that case.
cipher_type = common.get_cipher_type()
# haproxy address from the common config ([input_info] ha_address).
haproxy_ip = g_cfg_common.get_option('input_info', 'ha_address')
# Port from the agentassist config ([proxy] port); used for both the haproxy and
# the proxy connection.
common_port = g_cfg_agentassist.get_int_option('proxy', 'port')

install_path = Utils.get_install_path()
# KMC keygen file handed to CipherHandler.update_kmc.
kmc_keygen_file = os.path.join(install_path, 'AgentAssist/conf/kmc_kegen')

# Reconnect counters shared by Client instances (updated under Client.lock).
# While request_haproxy_count equals REQUEST_HAPROXY_FLAG, Client.connect targets haproxy.
request_haproxy_count = 0
request_proxy_count = 0
# Handles of the periodic timer tasks; created by Client.start_timer_task_after_connect
# and reset to None by Client.stop_timer_task_when_sock_except.
timer_register_host = None
timer_report_host = None
# NOTE(review): timer_report_cert is cancelled/cleared but never started in this chunk.
timer_report_cert = None
timer_update_kmc = None

# TLS / socket settings.
CA_NAME = 'ca.pem'  # CA bundle file name inside CERT_FILE_PATH
SERVER_HOSTNAME = "KarborProxy"  # server_hostname used when wrapping the socket
SOCK_TIMEOUT = 200  # seconds, passed to socket.setdefaulttimeout
DATA_RECV_BUFFER = 8 << 20  # max bytes per recv() call (8 MiB)
MAX_MSG_BODY_SIZE = MSG_BODY_FILEHEAD_SIZE + MSG_BODY_FILEBODY_SIZE

# Reconnect bookkeeping.
COUNT_INIT, COUNT_STEP = 0, 1
REQUEST_PROXY_FLAG, REQUEST_HAPROXY_FLAG = 1, 0
RECONNECT_TIMES = 2

# Timer periods and sleeps (seconds).
REGISTER_HOST_TIME_SEC = 60
REPORT_HOST_TIME_SEC = 180
SLEEP_TIME_SEC = 10
SLEEP_TIME = 2
TIME_OUT = 20
# Backoff delays: 30 s doubled per step -> 30, 60, 120, 240, 480, 960.
SLEEP_TIME_LIST = [30 * 2 ** step for step in range(6)]

# Retry limits.
MAX_COUNT = 10
MAX_DOWN_COUNT = 5

# File open flags (write-only, create) and owner read/write permission bits.
FLAGS = os.O_WRONLY | os.O_CREAT
MODES = stat.S_IWUSR | stat.S_IRUSR


class Client(object):
    """TLS client that talks to haproxy / AgentProxy using framed b'HWAB' messages.

    Connection flow as implemented below:
    1. Connect to the haproxy address and run IAM certification.
    2. Call Request.request_proxy_ip().
    3. On the proxy-IP response, reconnect to that proxy and start the periodic
       timer tasks.

    Reconnect counters live in the module-level ``request_haproxy_count`` /
    ``request_proxy_count`` globals. Timer handles live in the module-level
    ``timer_*`` globals.
    """

    def __init__(self):
        # Update the KMC data once at start-up; start_timer_task_after_connect
        # schedules the periodic update.
        CipherHandler.update_kmc(cipher_type, kmc_keygen_file)
        self._context = None
        self.sock = None
        self.ssl_sock = None
        # Raises ValueError when the configured haproxy address is illegal.
        self._address = self._make_address(haproxy_ip, common_port)
        g_cfg_agentassist.set_option("proxy", "haproxy_ip", self._address[0])
        # Guards the module-level reconnect counters.
        self.lock = threading.Lock()
        # Backoff index (into SLEEP_TIME_LIST) after repeated certificate failures.
        self.connect_sleep_time = 0
        # Certificate download attempts in the current round (up to MAX_COUNT).
        self.download_count = 0
        # Backoff index (into SLEEP_TIME_LIST) after repeated IAM failures.
        self.recv_sleep_time = 0
        # IAM failures in the current round.
        self.recv_count = 0

    @staticmethod
    def create_timer_task(second, task, *args, **kwargs):
        """Create (but do not start) a daemon ``TimerTask`` for ``task``.

        :param second: interval in seconds handed to ``TimerTask``.
        :param task: callable to run.
        :param args: forwarded positionally to the ``TimerTask`` constructor.
        :param kwargs: forwarded as keywords to the ``TimerTask`` constructor.
        :return: the unstarted timer task.
        """
        timer_task = TimerTask(second, task, *args, **kwargs)
        timer_task.daemon = True
        return timer_task

    @staticmethod
    def stop_timer_task(timer_tasks):
        """Cancel every live timer in ``timer_tasks``.

        :param timer_tasks: list or tuple of timer tasks; falsy entries are skipped.
        :return: False if ``timer_tasks`` is not a list/tuple, otherwise True.
        """
        if not isinstance(timer_tasks, (list, tuple)):
            return False
        for timer_task in timer_tasks:
            if timer_task and timer_task.is_alive():
                timer_task.cancel()
                # Fixed grace period for the cancelled timer to wind down.
                # NOTE(review): presumably a sleep rather than join() because this
                # may run on the timer's own thread, where join() would raise.
                time.sleep(SLEEP_TIME_SEC)
                logger.info(f"Cancel timer task succ:{timer_task},is_alive:{timer_task.is_alive()},"
                            f"parent:{threading.current_thread()}.")
                # Only unbinds the loop variable; callers reset the module globals.
                del timer_task
                logger.info("Del timer task succ.")
        return True

    @staticmethod
    def is_sock_alive(sock):
        """Return True if ``sock`` is a socket that has not been closed.

        Relies on the private ``_closed`` flag of ``socket.socket``; objects
        without it (e.g. ``None``) are reported as not alive.
        """
        return not getattr(sock, '_closed', True)

    @staticmethod
    def _make_address(ip, port):
        """Return ``(ip, port)`` after validating both parts.

        :raises ValueError: if ``ip`` or ``port`` fails CheckArgs validation.
        """
        if CheckArgs.check_ip(ip) and CheckArgs.check_port(port):
            return ip, port
        raise ValueError(f"Illegal address:{(ip, port)}.")

    @staticmethod
    def _get_cert(cert_path):
        """Return the path of the CA file ``CA_NAME`` inside ``cert_path``.

        :raises ssl.SSLError: if the CA file does not exist. connect() then routes
            the failure to cert_check_fail, which re-downloads the certificate.
        """
        cert = FileIOUtil.get_path(cert_path, CA_NAME)
        if FileIOUtil.is_file(cert):
            return cert
        raise ssl.SSLError(f"CA certificate not found: {cert}.")

    @staticmethod
    def make_ssl_context():
        """Create a client-side TLS context (``ssl.PROTOCOL_TLS_CLIENT``)."""
        return ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

    @staticmethod
    def set_ssl_context(ssl_context, ca):
        """Require server certificate verification against the CA file ``ca``."""
        ssl_context.verify_mode = ssl.CERT_REQUIRED
        ssl_context.load_verify_locations(ca)
        return ssl_context

    @staticmethod
    def make_sock():
        """Create a plain IPv4 TCP socket."""
        return socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    @staticmethod
    def make_ssl_sock(ssl_context, sock):
        """Wrap ``sock`` with ``ssl_context``, expecting server name SERVER_HOSTNAME."""
        return ssl_context.wrap_socket(sock, server_hostname=SERVER_HOSTNAME)

    def prepare_to_connect(self):
        """Create a fresh SSL context and TLS-wrapped socket in ``self.ssl_sock``.

        Returns without creating anything when no cipher type is configured;
        the previous ``self.ssl_sock`` is left in place.

        :raises ssl.SSLError: if the CA file is missing (see _get_cert).
        """
        if not cipher_type:
            logger.error(f"The cipher type is None. Support type in {common.SUPPORT_CIPHER_TYPES}.")
            return

        ca_cert = self._get_cert(CERT_FILE_PATH)

        self._context = self.set_ssl_context(self.make_ssl_context(), ca_cert)
        logger.info("Make ssl context succ.")

        # NOTE: setdefaulttimeout is process-wide; it applies to every socket
        # created afterwards, not just this one.
        socket.setdefaulttimeout(SOCK_TIMEOUT)
        logger.info("Set sock timeout succ.")
        self.sock = self.make_sock()
        logger.info("Make sock succ.")
        self.ssl_sock = self.make_ssl_sock(self._context, self.sock)
        logger.info("SSL context wrap sock succ.")

    def start_timer_task_after_connect(self):
        """Start the periodic tasks once connected to a proxy.

        Host registration and host-info reporting are always started. The KMC
        update timer is started only when a cipher type is configured.
        """
        global timer_report_host, timer_report_cert, timer_update_kmc, timer_register_host

        # Register host information.
        timer_register_host = self.create_timer_task(REGISTER_HOST_TIME_SEC, Request.register_host)
        timer_register_host.start()

        # Periodically report host information.
        timer_report_host = self.create_timer_task(REPORT_HOST_TIME_SEC, Request.report_host_info)
        timer_report_host.start()

        # Periodically update the KMC password ciphertext.
        if not cipher_type:
            logger.error(f"The cipher type is None. Support type in {common.SUPPORT_CIPHER_TYPES}.")
            return

        # update_kmc_time is presumably in days (converted to seconds here).
        # NOTE(review): the (cipher_type, kmc_keygen_file) tuple is passed as ONE
        # positional argument to TimerTask. This is correct only if TimerTask takes
        # the task's arguments as a sequence (like threading.Timer's ``args``).
        # Confirm, because __init__ calls update_kmc with two separate arguments.
        timer_update_kmc = self.create_timer_task(
            update_kmc_time * 24 * 60 * 60, CipherHandler.update_kmc, (cipher_type, kmc_keygen_file))
        timer_update_kmc.start()

    def stop_timer_task_when_sock_except(self):
        """Cancel all periodic tasks after a socket error and reset their globals.

        Request.clear_register_host_info() is also called once the registration
        timer has been stopped.
        """
        global timer_report_host, timer_report_cert, timer_update_kmc, timer_register_host

        if self.stop_timer_task((timer_register_host,)):
            timer_register_host = None
            Request.clear_register_host_info()
        if self.stop_timer_task((timer_report_host, timer_report_cert, timer_update_kmc)):
            timer_report_host = None
            timer_report_cert = None
            timer_update_kmc = None

    def connect(self):
        """(Re)connect to the current target address if no live socket exists.

        When ``request_haproxy_count`` equals REQUEST_HAPROXY_FLAG, the target is
        reset to the haproxy address.

        ``ssl.SSLError`` (missing CA file, certificate problems) is handed to
        cert_check_fail. Any other exception has these effects:
        - bumps the retry counter, resetting both counters after RECONNECT_TIMES;
        - stops the periodic tasks;
        - closes the socket.
        """
        global request_haproxy_count, request_proxy_count
        try:
            # On reconnect, make sure the send and receive paths only create one socket.
            if self.is_sock_alive(self.ssl_sock):
                return
            # On the first haproxy connection, or after reconnecting to the proxy has
            # hit the max retry count, target the haproxy address again.
            if request_haproxy_count == REQUEST_HAPROXY_FLAG:
                self._address = self._make_address(haproxy_ip, common_port)

            self.prepare_to_connect()
            self.ssl_sock.connect(self._address)

            # Connecting to haproxy requires IAM certification with AK/SK.
            Request.iam_certification()
            # Check the proxy node (request one, or start the periodic tasks).
            self.check_haproxy_connect()
        except ssl.SSLError as err:
            # Certificate expired/invalid, or the CA file is missing.
            self.cert_check_fail(err)
        except Exception as err:
            # Check and update the counters under a single lock acquisition.
            with self.lock:
                if request_haproxy_count <= RECONNECT_TIMES:
                    request_haproxy_count += COUNT_STEP
                else:
                    request_haproxy_count = COUNT_INIT
                    request_proxy_count = COUNT_INIT
            # Connection lost: stop the periodic report tasks.
            self.stop_timer_task_when_sock_except()
            # If ssl_sock was created but the connection failed, close it.
            self.disconnect(self.ssl_sock)
            logger.error(f"Connect err:{err}, sock:{self.ssl_sock}.")

    def check_haproxy_connect(self):
        """After a successful connect, either request a proxy or start the timers.

        While connected to haproxy, Request.request_proxy_ip() is called. The
        response is handled by _handle_recv_msg, which reconnects to the proxy.
        Once connected to a proxy, the periodic tasks are started and all backoff
        counters are reset.
        """
        global request_proxy_count
        if self._address[0] == haproxy_ip:
            with self.lock:
                request_proxy_count += COUNT_STEP
                should_request = request_proxy_count == REQUEST_PROXY_FLAG
            if should_request:
                Request.request_proxy_ip()
            with self.lock:
                request_proxy_count = COUNT_INIT
        else:
            self.start_timer_task_after_connect()
            self.download_count = 0
            self.connect_sleep_time = 0
            self.recv_sleep_time = 0
            self.recv_count = 0

    @staticmethod
    def _backoff_seconds(step):
        """Return the SLEEP_TIME_LIST delay for ``step``, clamped to the last entry."""
        return SLEEP_TIME_LIST[min(step, len(SLEEP_TIME_LIST) - 1)]

    def cert_check_fail(self, err):
        """Handle an SSL/certificate failure raised during connect().

        Always stops the periodic tasks and closes the socket.

        While fewer than MAX_COUNT attempts have been made in this round, it then:
        - sleeps for the current backoff delay;
        - re-downloads the CA certificate;
        - reconnects if the download succeeded.

        Once MAX_COUNT is reached, it logs the error, starts a new round and
        increases the backoff step without reconnecting.

        :param err: the ssl.SSLError that triggered the call (used for logging).
        """
        # The certificate is expired or invalid: stop the periodic report tasks...
        self.stop_timer_task_when_sock_except()
        # ...and close the ssl_sock created for this attempt.
        self.disconnect(self.ssl_sock)
        self.download_count += 1
        if self.download_count < MAX_COUNT:
            time.sleep(self._backoff_seconds(self.connect_sleep_time))
            if self.download_cert():
                # Each nested attempt increments download_count, so this
                # retry chain stops at MAX_COUNT.
                self.connect()
        else:
            logger.exception(f"Connect err:{err},sock:{self.ssl_sock}.")
            self.download_count = 0
            self.connect_sleep_time += 1

    def disconnect(self, sock):
        """Close ``sock`` if it is still open; always returns True."""
        if self.is_sock_alive(sock):
            sock.close()
        return True

    def download_cert(self):
        """Delete the local CA file and re-download it with the platform script.

        :return: True if the download script exits with code 0, else False.
        """
        ca_pem = os.path.realpath(FileIOUtil.get_path(CERT_FILE_PATH, CA_NAME))
        if FileIOUtil.is_file(ca_pem):
            os.remove(ca_pem)
        # Build argv directly instead of shlex-splitting an interpolated string,
        # so an install path containing spaces stays a single argument.
        if CommonDefine.IS_WINDOWS:
            download_cert_path = os.path.realpath(f"{install_path}/AgentAssist/download_cert.bat")
            cmd = [download_cert_path, str(install_path), str(haproxy_ip)]
        else:
            download_cert_path = os.path.realpath(f"{install_path}/AgentAssist/download_cert.sh")
            cmd = ["bash", download_cert_path, str(install_path), str(haproxy_ip)]
        ret_code, res = Utils.execute_cmd(cmd)
        if ret_code != 0:
            logger.error(f"Download cert failed, error:{res}.")
            return False
        return True

    def send_msg(self, msg):
        """Send ``msg`` over the TLS socket and reconnect on failure.

        ``sendall`` is used because ``send`` may write only part of the buffer.
        On error the message is dropped (not retried), and a reconnect is
        attempted after SLEEP_TIME_SEC.
        """
        try:
            self.ssl_sock.sendall(msg)
            logger.info("Send msg succ.")
        except Exception as err:
            logger.error(f"Send msg fail:{err},sock:{self.ssl_sock}.")
            time.sleep(SLEEP_TIME_SEC)
            self.connect()

    def _connect_to_proxy(self, msg_head, msg_body):
        """Switch from haproxy to the proxy IP carried in ``msg_body``.

        Steps:
        - save the proxy IP in the agentassist config;
        - close the current socket;
        - bump ``request_haproxy_count`` so connect() keeps the proxy address;
        - reconnect.

        ``msg_head`` is unused.

        :raises ValueError: if the proxy IP is illegal (see _make_address).
        """
        proxy_ip_from_server = RequestHandler.get_proxy_ip(msg_body)
        g_cfg_agentassist.set_option('proxy', 'proxy_ip', proxy_ip_from_server)
        self._address = self._make_address(proxy_ip_from_server, common_port)
        self.disconnect(self.ssl_sock)
        global request_haproxy_count
        with self.lock:
            request_haproxy_count += COUNT_STEP
        self.connect()
        logger.info("AgentAssist successfully connected to AgentProxy.")

    def _handle_recv_msg(self, msg):
        """Consume every complete frame in ``msg`` and return the unconsumed tail.

        Each frame is a MSG_HEAD_SIZE header (starting with b'HWAB') followed by
        ``body_len`` body bytes. Handling of each frame:
        - unknown command type or illegal header fields: skipped;
        - iam_certification_ret: checked first;
        - response_proxy_ip: triggers a reconnect to that proxy;
        - any other valid frame: pushed onto ``msg_queue``.

        A trailing partial frame is returned so the caller can prepend it to the
        next recv().

        :param msg: bytes buffered from the socket.
        :return: unconsumed bytes (b'' when an oversized body is announced).
        :raises ValueError: if the buffer does not start with b'HWAB'.
        """
        while True:
            if msg and not msg.startswith(b'HWAB'):
                logger.error("Msg does not start with b'HWAB'.")
                raise ValueError("Msg does not start with b'HWAB',could be an attack.")
            if len(msg) < MSG_HEAD_SIZE:
                return msg
            msg_head = MessageParser.unpack_msg_head(msg[:MSG_HEAD_SIZE])
            body_size = msg_head.get("body_len")
            # Reject oversized bodies before buffering them; the whole buffer is dropped.
            if body_size > MAX_MSG_BODY_SIZE:
                logger.error(f"Body size:{body_size} exceeds the limits.")
                return b''
            frame_end = MSG_HEAD_SIZE + body_size
            # Wait for the whole frame before validating or skipping it. Skipping a
            # partially received frame would leave the rest of its body at the front
            # of the next recv() and trip the b'HWAB' check above.
            if len(msg) < frame_end:
                return msg
            cmd_type = msg_head.get("cmd_type")
            if cmd_type not in MessageParser.cmd_type_all:
                logger.error(f"Cmd type:{cmd_type} is not in cmd_type set.")
                msg = msg[frame_end:]
                continue
            magic = msg_head.get("magic")
            cmd_version = msg_head.get("cmd_version")
            reserved = msg_head.get("reserved")
            flags = msg_head.get("flags")
            sequence_num = msg_head.get("sequence_num")
            args_check = (magic == b'HWAB' and cmd_version == 1 and reserved == 0)
            if not args_check or not CheckArgs.check_int_range(flags, [0, 1]) \
                    or not CheckArgs.check_arg_type(sequence_num, int):
                logger.error("Illegal args.")
                msg = msg[frame_end:]
                continue

            body_bytes = msg[MSG_HEAD_SIZE:frame_end]
            msg_body = body_bytes if MessageParser.is_file_msg_body(cmd_type) else body_bytes.decode()
            if cmd_type == CMD_TYPE_CONST["iam_certification_ret"]:
                # Raises (after a backoff sleep) when the IAM login failed.
                self.handle_recv_iam_certification_ret(msg_body)
            if cmd_type == CMD_TYPE_CONST["response_proxy_ip"]:
                # Reconnects to the proxy; remaining bytes go back to the caller.
                self._connect_to_proxy(msg_head, msg_body)
                return msg[frame_end:]
            msg_queue.push_req((msg_head, msg_body))
            msg = msg[frame_end:]

    def handle_recv_iam_certification_ret(self, msg_body):
        """Check an IAM certification result; back off and raise on failure.

        :param msg_body: decoded body of an ``iam_certification_ret`` message.
        :raises Exception: if the body does not contain 'iam login succ'. The
            exception propagates to recv_msg, which tears down and reconnects.
        """
        if 'iam login succ' not in msg_body:
            self.recv_count += 1
            if self.recv_count == 10:
                self.recv_count = 0
                # NOTE(review): a step of 10 jumps straight to the largest
                # SLEEP_TIME_LIST entry, whereas cert_check_fail steps by 1.
                # Confirm this is intended.
                self.recv_sleep_time += 10
            time.sleep(self._backoff_seconds(self.recv_sleep_time))
            raise Exception("IAM authentication failed.")

    def recv_msg(self):
        """Receive loop: read from the TLS socket and dispatch frames forever.

        An empty read is treated as an error. On any error the loop:
        - discards the buffered bytes;
        - stops the periodic tasks;
        - closes the socket;
        - reconnects after SLEEP_TIME_SEC.
        """
        data_buffer = bytes()
        while True:
            try:
                data = self.ssl_sock.recv(DATA_RECV_BUFFER)
                if not data:
                    logger.error(f"The {self.ssl_sock} recv empty bytes, maybe it was suspended, please check.")
                    time.sleep(SLEEP_TIME_SEC)
                    raise Exception("SSL sock recv empty bytes.")
                data_buffer += data
                data_buffer = self._handle_recv_msg(data_buffer)
            except Exception as err:
                logger.error(f"Recv msg err:{err}.")
                data_buffer = bytes()
                # On error (including an empty read), stop the current timer tasks
                # first so the same task never runs in several threads at once.
                self.stop_timer_task_when_sock_except()
                self.disconnect(self.ssl_sock)
                time.sleep(SLEEP_TIME_SEC)
                self.connect()

    def start_recv_msg(self):
        """Run the receive loop (does not return)."""
        self.recv_msg()
