# -*- coding:utf-8 -*-
import configparser
import ipaddress
import os
import re
import shlex
import time
import zipfile
import tarfile
from functools import wraps

import utils.common.log as logger
from utils.constant.path_constant import SOURCES_ROOT_DIR

from plugins.CSBS.common.constant import DPA_PLUGIN_CONFIG_PATH
from plugins.CSBS.common.ssh_client import SshClient
from plugins.CSBS.common.upgrade.constant import ENV_CONFIG_PATH

# NOTE(review): presumably status codes for success (200) and failure (500);
# neither constant is used in this part of the file - confirm against callers.
MSG_200 = 200
MSG_500 = 500
# Upper bound, in bytes, for the total uncompressed size of an archive: 5 GB.
# Despite the "MB" in its name the value is 5 * 1024 ** 3 bytes.
MB_SIZE = 1024 * 1024 * 1024 * 5
# Upper bound for the number of entries in an archive: 100,000.
MAX_FILE_COUNT = 1 * 1000 * 100


def auto_retry(max_retry_times=5, delay_time=60, step_time=0,
               exceptions=BaseException, sleep=time.sleep,
               callback=None, validate=None):
    """Build a decorator that re-runs a function when it fails.

    The decorated function is called once and, while it keeps failing,
    retried at most ``max_retry_times`` more times. An attempt fails when
    it raises one of ``exceptions`` or when ``validate`` rejects its result.

    :param max_retry_times: maximum number of retries after the first call.
    :param delay_time: base delay between two attempts, in seconds.
    :param step_time: when greater than 0, this amount is added to the delay
        before every wait, so the first wait lasts
        ``delay_time + step_time`` seconds.
    :param exceptions: exception type(s) that trigger a retry: a single
        class, or a tuple or list of classes.
    :param sleep: callable used to wait, defaults to ``time.sleep``. Pass a
        custom callable with the same signature (one argument, the delay in
        seconds) when ``time.sleep`` would block, e.g. in async frameworks.
    :param callback: optional callable that receives every caught exception
        (useful for logging or for aborting the retries). If it returns
        exactly ``True`` the exception is considered handled: retrying stops
        at once, nothing is raised and the last result obtained so far
        (``None`` if there is none) is returned. Any other return value
        lets the retries continue. An exception raised by the callback
        itself stops the retries and propagates immediately.
    :param validate: optional callable that receives the result of every
        attempt that did not raise. A falsy return value means the result
        is rejected and the call is retried; a truthy value ends the
        retries and the result is returned. If ``validate`` raises one of
        ``exceptions`` the attempt counts as failed; any other exception
        propagates.
    :return: the decorator. The decorated function returns the accepted
        result; when every attempt failed it returns the last result
        obtained (``None`` if no call ever returned) - the last exception
        is logged, not re-raised.
    """
    # An ``except`` clause only accepts a class or a tuple of classes, so a
    # list (allowed by the documented contract) has to be converted first.
    caught = tuple(exceptions) if isinstance(exceptions, list) else exceptions

    def wrapper(func):
        @wraps(func)
        def _wrapper(*args, **kwargs):
            delay = delay_time
            result = None
            attempts = 0
            # The loop always performs at least one call, whatever the
            # value of max_retry_times.
            while True:
                attempts += 1
                try:
                    result = func(*args, **kwargs)
                    if not callable(validate) or validate(result):
                        logger.info("Validate result return true.")
                        return result
                    logger.error("Validate result return false.")
                except caught as ex:
                    # The callback may declare the exception handled; in
                    # that case return right away, without waiting.
                    if callable(callback) and callback(ex) is True:
                        return result
                    logger.error(f"Failed to execute {func.__name__} method, "
                                 f"err_msg: {ex}.")
                if attempts > max_retry_times:
                    logger.info("Retry all complete.")
                    return result
                if step_time > 0:
                    delay += step_time
                logger.info(f"Retry after {delay} seconds, "
                            f"retry times: {attempts}.")
                sleep(delay)

        return _wrapper

    return wrapper


def open_file(file_name, mode='r', encoding=None, **kwargs):
    """Open a file, guessing the encoding for plain text reads.

    When the file is opened for reading in text mode and no encoding is
    given, the raw bytes are decoded with UTF-8, GBK and ISO-8859-1 in that
    order and the first encoding that decodes them without error is used.
    In every other case the arguments are passed to ``open`` unchanged.

    :param file_name: path of the file to open.
    :param mode: mode string as accepted by the built-in ``open``.
    :param encoding: explicit encoding; disables the detection when set.
    :param kwargs: additional keyword arguments forwarded to ``open``.
    :return: the opened file object; the caller is responsible for closing
        it (it can be used as a context manager).
    """
    if mode in ('r', 'rt', 'tr') and encoding is None:
        with open(file_name, 'rb') as file:
            raw_content = file.read()
        for candidate in ('UTF-8', 'GBK', 'ISO-8859-1'):
            try:
                raw_content.decode(encoding=candidate)
            except UnicodeDecodeError:
                # Not this encoding: fall through to the next candidate
                # instead of aborting the detection.
                continue
            encoding = candidate
            break
    return open(file_name, mode=mode, encoding=encoding, **kwargs)


def get_config_content(file_path):
    """Return the text content of a file as a single string.

    The file is opened with ``open_file``; errors raised while opening it
    propagate to the caller. Errors raised while reading are only logged,
    and whatever was read before the error is returned.

    :param file_path: path of the file to read.
    :return: the content read from the file.
    """
    collected = []
    with open_file(file_path) as handle:
        try:
            for text_line in handle:
                collected.append(text_line)
        except Exception as err:
            logger.error(f'Get file {file_path} content failed, err_msg:{err}.')
    return "".join(collected)


def to_str(args):
    """Recursively convert bytes to str.

    Bytes are decoded with UTF-8, GBK and ISO-8859-1, in that order, and
    the first successful decoding is returned. Lists are converted element
    by element, dicts key by key and value by value. Any other object is
    returned unchanged.

    :param args: object to convert.
    :return: the converted object.
    """
    if isinstance(args, list):
        return [to_str(item) for item in args]
    if isinstance(args, dict):
        return {to_str(name): to_str(content)
                for name, content in args.items()}
    if not isinstance(args, bytes):
        return args

    for codec in ('UTF-8', 'GBK', 'ISO-8859-1'):
        try:
            decoded = args.decode(encoding=codec)
        except UnicodeDecodeError:
            continue
        return decoded
    # Only reached when none of the encodings could decode the bytes.
    logger.error("transfer bytes to str error: %s" % args)
    return args


class CSBSConfigParser(configparser.ConfigParser):
    """ConfigParser variant that keeps option names exactly as written."""

    def optionxform(self, optionstr):
        """Return the option name unchanged.

        Overrides the hook the parser applies to option names (the stock
        ``ConfigParser`` lower-cases them), so names stay case-sensitive.
        """
        return optionstr


class ConfigIniFile(object):
    """Read and update the values stored in one .ini file."""

    def __init__(self, conf_file):
        """Load the .ini file.

        :param conf_file: path of an existing .ini file.
        :raises Exception: when the path is empty or is not a file.
        """
        self.conf_file = conf_file
        if not conf_file or not os.path.isfile(conf_file):
            raise Exception(f"The file does not exist, file path:{conf_file}.")
        self.config = CSBSConfigParser()
        self.config.read(conf_file)

    def get_params_dict_by_key_name(self, key_name, sub_keys=None):
        """Return the options of one section as a dict.

        :param key_name: section name, parameter type: str
        :param sub_keys: option names to keep, parameter type: list or
            tuple; when empty or omitted every option is returned.
        :return: dict mapping option name to value.
        :raises Exception: when sub_keys is given but is not a list/tuple.
        """
        if sub_keys and not isinstance(sub_keys, (list, tuple)):
            message = (
                f"The parameter type of sub_keys must be list or tuple."
                f"sub_keys is {str(sub_keys)}, type is {type(sub_keys)}")
            raise Exception(message)
        section_items = self.config.items(key_name)
        if not sub_keys:
            return dict(section_items)
        return {option: content
                for option, content in section_items
                if option in sub_keys}

    def get_value_by_key_and_sub_key(self, key, sub_key):
        """Return one option value; lookup errors are logged and re-raised.

        :param key: section name.
        :param sub_key: option name.
        :return: the option value.
        """
        try:
            found = self.config.get(key, sub_key)
        except Exception as err:
            logger.warn(err)
            raise
        return found

    def set_value_by_key_and_sub_key(self, key, sub_key, value):
        """Set one option and write the whole configuration back to the file.

        The section is created when it does not exist yet. The file is
        only rewritten when the value read back equals ``str(value)``.

        :param key: section name.
        :param sub_key: option name.
        :param value: new value, stored as ``str(value)``.
        :return: True when the file was written, False otherwise.
        """
        text = str(value)
        if not self.config.has_section(key):
            self.config.add_section(key)
        self.config.set(key, sub_key, text)
        stored = (self.config.has_option(key, sub_key)
                  and self.config.get(key, sub_key) == text)
        if not stored:
            return False
        with open_file(self.conf_file, 'w') as file:
            self.config.write(file)
        return True


def unzip_file(file_path, target_path, unzip_size_limit_mb=None,
               unzip_file_count_limit_kilo=None):
    """Extract a zip archive after checking its size and entry count.

    Both limits are verified against the archive's table of contents
    before anything is written, so a rejected archive leaves no partially
    extracted files behind.

    :param file_path: path of the zip archive.
    :param target_path: directory to extract into; created (including
        missing parent directories) when it does not exist.
    :param unzip_size_limit_mb: maximum total uncompressed size in MB;
        capped at, and defaulting to, MB_SIZE bytes.
    :param unzip_file_count_limit_kilo: maximum number of entries, in
        thousands; capped at, and defaulting to, MAX_FILE_COUNT.
    :raises Exception: when the archive is missing, unreadable, exceeds a
        limit or cannot be extracted; the original error is chained.
    """
    limit_size = MB_SIZE
    if unzip_size_limit_mb:
        limit_size = min(unzip_size_limit_mb * 1024 * 1024, MB_SIZE)
    limit_count = MAX_FILE_COUNT
    if unzip_file_count_limit_kilo:
        limit_count = min(unzip_file_count_limit_kilo * 1000, MAX_FILE_COUNT)

    try:
        # Check for the file before opening it, so a missing package is
        # reported with the explicit message below.
        if not os.path.isfile(file_path):
            raise Exception(f'The package file {file_path} not existed.')
        os.makedirs(target_path, exist_ok=True)
        with zipfile.ZipFile(file_path) as zip_f:
            members = zip_f.infolist()
            file_count = len(members)
            if file_count > limit_count:
                raise Exception(f"The package {file_path} contains "
                                f"{file_count} "
                                "files, maximum file limit exceeded, "
                                "Check whether the original file is correct.")
            current_size = 0
            for info in members:
                current_size += info.file_size
                if current_size >= limit_size:
                    raise Exception("The size of file to unzip exceeds max "
                                    f"size {limit_size} byte allowed")
            for info in members:
                zip_f.extract(info, path=target_path)
        logger.info(f"Unzip file {file_path} to {target_path} success.")
    except Exception as err:
        raise Exception(f'Unzip {file_path} failed:{str(err)}.') from err


def check_tar_file(tar_files):
    """Check an opened tar archive against the module limits.

    The members are read one at a time and the check stops as soon as a
    limit is exceeded, so an oversized archive is not scanned to its end.

    :param tar_files: an opened ``tarfile.TarFile``.
    :raises Exception: when the archive has more than MAX_FILE_COUNT
        members or their sizes add up to more than MB_SIZE bytes.
    """
    file_size = 0
    for file_count, tar_file in enumerate(tar_files, start=1):
        if file_count > MAX_FILE_COUNT:
            raise Exception(f"tarfile: {tar_files} is too many.")
        file_size += tar_file.size
        if file_size > MB_SIZE:
            raise Exception(f"tarfile: {tar_files} is too large.")


def check_zipfile(file_list):
    """Check the size and entry count of zip and tar.gz files.

    Paths ending in ``.zip`` or ``.tar.gz`` are opened and compared with
    MAX_FILE_COUNT and MB_SIZE; every other path is ignored. Each archive
    is closed again before the next one is examined.

    :param file_list: iterable of file paths.
    :return: None
    :raises Exception: when an archive exceeds one of the limits.
    """
    for file in file_list:
        if file.endswith(".zip"):
            # The context manager releases the file handle even when a
            # limit check raises.
            with zipfile.ZipFile(file, "r") as zip_file:
                members = zip_file.infolist()
            if len(members) > MAX_FILE_COUNT:
                raise Exception(f"zipfile: {file} is too many.")
            file_size = sum(member.file_size for member in members)
            if file_size > MB_SIZE:
                raise Exception("zipfile: %s is too large." % file)
        elif file.endswith(".tar.gz"):
            with tarfile.open(file, "r") as tar_files:
                try:
                    check_tar_file(tar_files)
                except Exception as err:
                    raise Exception(
                        f'tarfile {file} check failed: {str(err)}.') from err


def check_string_param(name, max_len=255, min_len=1,
                       expr=r"^[a-zA-Z0-9.\-_]+$", allow_null=False):
    """Validate a string against a length range and a regular expression.

    :param name: value to check.
    :param max_len: maximum allowed length.
    :param min_len: minimum allowed length.
    :param expr: regular expression the value has to match from its start.
    :param allow_null: result returned for an empty or None value.
    :return: True when the value is valid, False otherwise; ``allow_null``
        when the value is empty.
    """
    if not name:
        return allow_null
    if not min_len <= len(name) <= max_len:
        return False
    matched = re.match(expr, name)
    if matched is None:
        return False
    # "$" also matches just before a trailing newline, so "abc\n" would
    # pass an anchored pattern; reject a value whose only unmatched rest
    # is that newline.
    return name[matched.end():] != "\n"


def check_url_param(endpoint, max_len=255, min_len=1, allow_null=False):
    """Raise when an endpoint contains characters not allowed in a URL.

    The check is delegated to ``check_string_param`` with a pattern that
    accepts letters, digits and the characters ``. / - _ : = ? # %``.

    :param endpoint: value to check.
    :param max_len: maximum allowed length.
    :param min_len: minimum allowed length.
    :param allow_null: whether an empty value is accepted.
    :raises Exception: when the value is rejected.
    """
    url_expr = r"^[a-zA-Z0-9./\-_:=?#%]+$"
    is_valid = check_string_param(endpoint, max_len, min_len,
                                  expr=url_expr, allow_null=allow_null)
    if is_valid:
        return
    raise Exception(f"Invalid param, the parma is {endpoint}.")


def check_ip(tar_ip):
    """Tell whether a value is accepted by ``ipaddress.ip_address``.

    :param tar_ip: candidate address; anything but str or int is rejected
        without being parsed.
    :return: True when the value parses as an IPv4 or IPv6 address, False
        otherwise (a parse failure is logged as a warning).
    """
    if isinstance(tar_ip, (str, int)):
        try:
            ipaddress.ip_address(tar_ip)
        except Exception as err:
            logger.warn(f"The [{tar_ip}] is not an IPv4 or IPv6 address, err_msg:{err}.")
        else:
            return True
    return False


def check_az_param(azs_str):
    """Validate a comma separated list of AZ ids.

    All whitespace is removed first. When the only comma of the string is
    a trailing one, it is dropped. Every remaining item then has to pass
    ``check_string_param`` with its default settings.

    :param azs_str: comma separated AZ ids.
    :raises Exception: for the first item that is not valid.
    """
    compact = re.sub(r"\s+", "", azs_str)
    if compact.endswith(",") and compact.count(",") == 1:
        compact = compact.strip(",")
    for az_id in compact.split(","):
        if check_string_param(az_id):
            continue
        raise Exception(f'The az_id = {az_id} is not legal. Its length range is 1~255, '
                        f'and the characters can only be [a-zA-Z0-9_-.].')


def check_openstack_region(param_name):
    """Validate a region name with the default ``check_string_param`` rules.

    :param param_name: region name to check.
    :return: the result of ``check_string_param(param_name)``.
    """
    is_valid = check_string_param(param_name)
    return is_valid


def compare_version(version1, version2):
    """Tell whether version1 is strictly newer than version2.

    Letters are removed before comparing ("8.1.2.SPC300" is compared as
    "8.1.2.300"), then the dot separated numbers are compared from left to
    right. When one version is a prefix of the other, the longer one is
    considered newer.

    :param version1: first version string.
    :param version2: second version string.
    :return: True when version1 is newer than version2, else False.
    :raises Exception: when one of the versions fails ``check_version``.
    """
    if not (check_version(version1) and check_version(version2)):
        raise Exception("Illegal version params.")

    left = [int(part) for part in remove_alpha(version1).split('.')]
    right = [int(part) for part in remove_alpha(version2).split('.')]

    for left_num, right_num in zip(left, right):
        if left_num != right_num:
            return left_num > right_num

    # All shared positions are equal: the version with more parts wins.
    return len(left) > len(right)


def check_version(version_string, separator="."):
    """Check the basic shape of a version string.

    :param version_string: value to check.
    :param separator: separator expected between the version parts.
    :return: True when the value is a str that contains the separator and
        every part between separators is alphanumeric, else False.
    """
    if not isinstance(version_string, str):
        return False
    if separator not in version_string:
        return False
    return all(part.isalnum() for part in version_string.split(separator))


def remove_alpha(string):
    """Return the string with every ASCII letter removed.

    :param string: text to process.
    :return: the text without the characters a-z and A-Z.
    :raises Exception: when the argument is not a str.
    """
    if isinstance(string, str):
        return re.sub(r'[a-zA-Z]*', '', string)
    raise Exception(f"{type(string)} is not str.")


def check_shell_inject_args(cmd_str):
    """Reject a command line that contains risky tokens or characters.

    The command is split with ``shlex``, joined with single spaces and
    split again; the resulting tokens are what gets checked. A token equal
    to ``-c``, ``--c`` or ``=`` is refused first, then any token containing
    one of the characters ``~ ` < > [ ] | & ! @ # % ^ ( ) ;``.

    :param cmd_str: command line to check.
    :return: True when nothing suspicious was found.
    :raises Exception: when a forbidden token or character is present.
    """
    forbidden_tokens = ["-c", "--c", "="]
    forbidden_chars = {"~", "`", "<", ">", "[", "]", "|", "&", "!", "@",
                       "#", "%", "^", "(", ")", ";"}

    # The second split is intentional: quoted arguments are broken up into
    # the words they contain before the token check.
    cmd_list = shlex.split(" ".join(shlex.split(cmd_str)))

    if any(token in forbidden_tokens for token in cmd_list):
        raise Exception(f"Command {cmd_list} "
                        f"may has shell inject risk2, please fix it")

    used_chars = set().union(*cmd_list)
    if used_chars & forbidden_chars:
        raise Exception(f"Command {cmd_list} "
                        f"may has shell inject risk1, please fix it")
    return True


def set_ebackup_server(karbor_node, ebackup_management_ip, ebackup_az, pwd):
    """Run ``set_ebackup_server --op add`` on a node over SSH.

    The node is first queried with ``get_ebackup_server | grep -w <ip>``.
    Depending on ``is_ssh_cmd_executed`` for that query, the add command is
    either sent once and expected to answer "successfully", or sent as an
    interactive dialogue: password prompt, CA certificate confirmation and
    finally "last cmd: 0".

    :param karbor_node: node description passed to
        ``SshClient.get_ssh_client_user_su_root``.
    :param ebackup_management_ip: value used for ``--ebackup_url`` and for
        the ``grep`` lookup.
    :param ebackup_az: value used for ``--az``.
    :param pwd: text sent in answer to the "password" prompt.
    :return: True when no step raised, False when any step raised (the
        error is logged, not propagated).
    """
    ssh_util = SshClient()
    # NOTE(review): both commands are built by string interpolation; this
    # function does not validate ebackup_management_ip / ebackup_az itself.
    get_cmd = f"source /opt/huawei/dj/inst/utils.sh;get_ebackup_server | grep -w {ebackup_management_ip}"
    ssh_client = None
    try:
        ssh_client = ssh_util.get_ssh_client_user_su_root(karbor_node)
        ret = ssh_util.ssh_exec_command_return(ssh_client, get_cmd)
        # The eBackup server is already connected
        cmd = "source /opt/huawei/dj/inst/utils.sh;" \
              f"set_ebackup_server --op add --ebackup_url {ebackup_management_ip} --az {ebackup_az}"
        if ssh_util.is_ssh_cmd_executed(ret):
            # presumably the grep found the IP, i.e. the server is already
            # registered and no password/certificate dialogue follows
            logger.info("Check eBackup server is exists.")
            ret = ssh_util.ssh_send_command_expect(ssh_client, cmd, expect_str="successfully", timeout=30)
            logger.info(f"Result of set_ebackup_server: {ret}.")
        else:
            # Interactive flow: wait for the password prompt, send the
            # password, then confirm the CA certificate download with "y".
            ssh_util.ssh_send_command_expect(ssh_client, cmd, expect_str="password", timeout=30)
            cert_except = "Are you sure to download the ca cert"
            ssh_util.ssh_send_command_expect(ssh_client, pwd, cert_except, 30)
            ret = ssh_util.ssh_send_command_expect(ssh_client, "y", "last cmd: 0", 30)
            logger.info(f"Result of download ca: {ret}.")
    except Exception as ex:
        logger.error(f"Set ebackup az failed! {str(ex)}")
        return False
    finally:
        # Close the session only when it was actually opened.
        if ssh_client:
            SshClient.ssh_close(ssh_client)
    return True


def get_env_config():
    """Load the environment configuration file.

    :return: a ``ConfigIniFile`` for ENV_CONFIG_PATH under SOURCES_ROOT_DIR.
    """
    return ConfigIniFile(os.path.join(SOURCES_ROOT_DIR, ENV_CONFIG_PATH))


def get_dpa_plugin_config():
    """Load the DPA plugin configuration file.

    :return: a ``ConfigIniFile`` for DPA_PLUGIN_CONFIG_PATH under
        SOURCES_ROOT_DIR.
    """
    plugin_config_path = os.path.join(SOURCES_ROOT_DIR,
                                      DPA_PLUGIN_CONFIG_PATH)
    return ConfigIniFile(plugin_config_path)
