import ipaddress
import os
import socket
import ssl
import sys

from message.tcp.agentframework import MAX_MSG_BODY_SIZE
from message.tcp.common import MSG_HEAD_SIZE, CheckArgs
from message.tcp.messagehandler import Request, RequestHandler
from message.tcp.messageparser import MessageParser, CMD_TYPE_CONST

SERVER_HOSTNAME = "KarborProxy"
DATA_RECV_BUFFER = 8 * 1024 * 1024


def check_ip(ip):
    """Tell whether *ip* is a valid IPv4 or IPv6 address.

    Only ``str`` and ``int`` values are considered; any other type is
    rejected without being parsed.

    Returns:
        True when ``ipaddress.ip_address`` accepts *ip*, otherwise False.
    """
    is_candidate = isinstance(ip, (str, int))
    if is_candidate:
        try:
            ipaddress.ip_address(ip)
        except Exception:
            is_candidate = False
    return is_candidate


def check_port(port):
    """Return True if *port* is an ``int`` in the inclusive range 0..65535."""
    return isinstance(port, int) and 0 <= port <= 65535


def read_file(file_path):
    """Return the text content of *file_path* with leading/trailing whitespace stripped."""
    with open(file_path, 'r') as handle:
        content = handle.read()
    return content.strip()


def make_ssl_context():
    """Create a new client-side ``ssl.SSLContext`` (``PROTOCOL_TLS_CLIENT``)."""
    protocol = ssl.PROTOCOL_TLS_CLIENT
    context = ssl.SSLContext(protocol)
    return context


def set_ssl_context(ssl_context, ca_file=None):
    """Require peer-certificate verification and load the trusted CA file.

    Args:
        ssl_context: the ``ssl.SSLContext`` to configure in place.
        ca_file: path of the CA file to trust. When omitted, the module-level
            ``ca`` path (assigned by the ``__main__`` entry point) is used,
            which preserves the original call signature and behaviour.

    Returns:
        The same ``ssl_context`` object, so calls can be chained.

    Raises:
        NameError: if ``ca_file`` is omitted and the global ``ca`` was never set
            (e.g. this module was imported instead of run as a script).
        OSError: if the CA file cannot be read or parsed (``ssl.SSLError`` is
            an ``OSError`` subclass).
    """
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    # Fall back to the script-level global only when no explicit path is given,
    # so the function is usable when the module is imported.
    ssl_context.load_verify_locations(ca if ca_file is None else ca_file)
    return ssl_context


def make_sock():
    """Create a new, unconnected IPv4 (``AF_INET``) TCP (``SOCK_STREAM``) socket."""
    family, kind = socket.AF_INET, socket.SOCK_STREAM
    return socket.socket(family, kind)


def make_ssl_sock(ssl_context, sock):
    """Wrap *sock* in a TLS client socket created from *ssl_context*.

    ``SERVER_HOSTNAME`` is passed as the TLS server name.

    Returns:
        The ``ssl.SSLSocket`` returned by ``ssl_context.wrap_socket``.
    """
    wrapped = ssl_context.wrap_socket(sock=sock, server_hostname=SERVER_HOSTNAME)
    return wrapped


def test_connectivity():
    """Open a TLS connection to the HAProxy endpoint and perform the IAM login.

    Reads the module-level ``haproxy_ip`` / ``haproxy_port`` globals (and,
    through ``set_ssl_context``, ``ca``) assigned by the ``__main__`` entry point.

    Returns:
        True if both the TLS connection and the IAM login succeed, False
        otherwise. Failures are printed to stdout instead of being raised.
    """
    # Pre-bind both handles so the cleanup below never hits an unbound name,
    # whichever step fails.
    sock = None
    ssl_sock = None
    try:
        context = set_ssl_context(make_ssl_context())
        # NOTE: this is a process-wide default; every socket created afterwards,
        # including the one below, gets a 10 second timeout.
        socket.setdefaulttimeout(10)
        sock = make_sock()
        ssl_sock = make_ssl_sock(context, sock)
        ssl_sock.connect((str(haproxy_ip), int(haproxy_port)))
    except Exception as err:
        # close() is safe on already-closed or detached sockets, so close
        # whatever was created.
        for opened in (ssl_sock, sock):
            if opened is not None:
                opened.close()
        print(f"Test connection failed:\n  {err}.")
        return False
    try:
        check_iam_auth(ssl_sock)
        return True
    except Exception as err:
        print(f"Iam login failed:\n  {err}.")
        return False
    finally:
        # The connection is only needed for this probe; always release it.
        ssl_sock.close()


def check_iam_auth(ssl_sock):
    """Send an IAM certification request over *ssl_sock* and validate the reply.

    Args:
        ssl_sock: a connected TLS socket to the proxy.

    Raises:
        Exception: if the peer closes the connection without replying, or if
            ``_handle_recv_msg`` rejects the reply (including a failed login).
    """
    msg = pack_msg(CMD_TYPE_CONST["iam_certification"], Request.generate_msg_body_action)
    # send() may transmit only part of the buffer; sendall() retries until
    # every byte has been written.
    ssl_sock.sendall(msg)
    # NOTE(review): a single recv() assumes the whole reply arrives in one
    # read; a reply split across several reads would be seen truncated.
    data = ssl_sock.recv(DATA_RECV_BUFFER)
    if not data:
        # An empty read means the peer closed the connection; presumably the
        # server does this when the AK/SK credentials are rejected.
        raise Exception("please check ak and sk is right.")
    _handle_recv_msg(data)


def _handle_recv_msg(msg):
    while True:
        if msg and not msg.startswith(b'HWAB'):
            raise Exception("Msg does not start with b'HWAB',could be an attack.")
        msg_head = MessageParser.unpack_msg_head(msg[:MSG_HEAD_SIZE])
        body_size = msg_head.get("body_len")
        if body_size > MAX_MSG_BODY_SIZE:
            raise Exception(f"Body size:{body_size} exceeds the limits.")
        cmd_type = msg_head.get("cmd_type")
        if cmd_type not in MessageParser.cmd_type_all:
            raise Exception(f"Cmd type:{cmd_type} is not in cmd_type set.")
        magic = msg_head.get("magic")
        cmd_version = msg_head.get("cmd_version")
        reserved = msg_head.get("reserved")
        flags = msg_head.get("flags")
        sequence_num = msg_head.get("sequence_num")
        args_check = (magic == b'HWAB' and cmd_version == 1 and reserved == 0)
        if not all((args_check, CheckArgs.check_int_range(flags, [0, 1]), CheckArgs.check_arg_type(sequence_num, int))):
            raise Exception("Illegal args.")
        body_bytes = msg[MSG_HEAD_SIZE:MSG_HEAD_SIZE + body_size]
        msg_body = body_bytes if MessageParser.is_file_msg_body(cmd_type) else body_bytes.decode()
        if cmd_type == CMD_TYPE_CONST["iam_certification_ret"]:
            if 'iam login succ' not in msg_body:
                raise Exception(f"{msg_body}.")
            else:
                return True


def pack_msg(cmd_type, msg_body_dict):
    """Build a complete message: message header + message body.

    Args:
        cmd_type: command type; must be a key of *msg_body_dict*.
        msg_body_dict: mapping from command type to a zero-argument callable
            that produces the body, or to a falsy value for an empty body.

    Returns:
        The packed header bytes followed by the body bytes.

    Raises:
        Exception: if *cmd_type* is not in *msg_body_dict*, or if the body
            builder yields an empty result.
    """
    if cmd_type not in msg_body_dict:
        raise Exception(f"Illegal cmd_type:{cmd_type}.")
    body_builder = msg_body_dict.get(cmd_type)
    if body_builder:
        body = MessageParser.to_bytes(body_builder())
        if not body:
            raise Exception("package msg error.")
    else:
        body = b''
    head_fields = MessageParser.make_msg_head(cmd_type, body=body)
    return MessageParser.pack_msg_head(*head_fields) + body

if __name__ == '__main__':
    # Usage: <script> <install_path> <haproxy_ip> <haproxy_port>
    # Exit status: 0 when the TLS connection and IAM login succeed, 1 otherwise
    # (including invalid arguments, which previously exited 0 silently).
    # haproxy_ip, haproxy_port and ca stay module-level globals because
    # test_connectivity() and set_ssl_context() read them.
    if len(sys.argv) != 4:
        print(f"Usage: {sys.argv[0]} <install_path> <haproxy_ip> <haproxy_port>", file=sys.stderr)
        sys.exit(1)
    install_path = sys.argv[1]
    haproxy_ip = sys.argv[2]
    try:
        haproxy_port = int(sys.argv[3])
    except ValueError:
        haproxy_port = None  # rejected by check_port() below
    if not all([os.path.exists(install_path), check_ip(haproxy_ip), check_port(haproxy_port)]):
        print("Invalid arguments.", file=sys.stderr)
        sys.exit(1)
    ca = os.path.join(install_path, 'AgentAssist/conf/cert/ca.pem')
    sys.exit(0 if test_connectivity() else 1)
