#!/usr/bin/python3
import ipaddress
import os
import platform
import re
import shlex
import socket
import subprocess
import tarfile
import zipfile
import time
from functools import wraps

import psutil

try:
    import win32service
    import win32serviceutil
except ImportError:
    pass

from common.common_define import Permission
from common.common_define import ArchiveType
from common.common_define import IPType
from common.common_define import CommonDefine
from common import configreader, exceptions

# Key name whose value is its own identifier; how it is used (env var or
# config key) is decided elsewhere in the project.
AGENTASSIST_INSTALL_PATH = "AGENTASSIST_INSTALL_PATH"

# Allowed file names: one or more word characters, dots or dashes.
# Used by Utils.check_file_name.
file_name_pattern = re.compile(r"^[\w.-]+$")

# CDM data and log directories; the *_BAK variants append "_BAK" to the
# corresponding base path (presumably backup locations).
VAR_CDM_LIB = "/var/lib/CDM"
VAR_CDM_LOG = "/var/log/CDM"
VAR_CDM_LIB_BAK = f"{VAR_CDM_LIB}_BAK"
VAR_CDM_LOG_BAK = f"{VAR_CDM_LOG}_BAK"


def log_with_exception(log):
    """Build a decorator that logs failures and turns them into error codes.

    On success the decorated function's return value is passed through
    unchanged. If it raises, the failure is logged with ``log.error`` and an
    error code is returned instead of propagating the exception:
    ``error.error_code`` for ``exceptions.AgentAssistantException``,
    ``"CSBS.9999"`` for any other exception.

    :param log: logger-like object exposing ``error(msg)``.
    :return: a decorator; the wrapped function accepts positional and
        keyword arguments.
    """
    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as error:
                # Keep the historical log format when only positional
                # arguments were given; include kwargs only when present.
                params = (args, kwargs) if kwargs else args
                log.error(f"The {func} failed, param:[{params}], {error}.")
                if isinstance(error, exceptions.AgentAssistantException):
                    return error.error_code
                # Generic code for unexpected errors; the exception object is
                # no longer mutated just to carry it.
                return "CSBS.9999"

        return inner

    return wrapper


class AgentConf(object):
    """Default values and inclusive [MIN, MAX] ranges for agent tunables.

    Validate a candidate value with ``AgentConf.check_param(MIN, value, MAX)``.
    """

    # Size of a single log file, in MB (per the _MB suffix).
    DEFAULT_LOG_FILE_SIZE_MB = 50
    MIN_LOG_FILE_SIZE_MB = 10
    MAX_LOG_FILE_SIZE_MB = 100

    # Number of log files.
    DEFAULT_LOG_FILE_NUM = 10
    MIN_LOG_FILE_NUM = 5
    MAX_LOG_FILE_NUM = 20

    # Memory limit; unit not stated here (presumably MB — confirm with users).
    DEFAULT_MEMORY_LIMIT = 300
    MIN_MEMORY_LIMIT = 200
    MAX_MEMORY_LIMIT = 2048

    # CPU limit; presumably a percentage given the 100 maximum — confirm.
    DEFAULT_CPU_LIMIT = 90
    MIN_CPU_LIMIT = 50
    MAX_CPU_LIMIT = 100

    # Sleep/timeout limit. Note the default is named *_SLEEP_TIME_* while the
    # bounds are named *_SLEEP_TIMEOUT_*; names kept for compatibility.
    DEFAULT_SLEEP_TIME_LIMIT = 10
    MIN_SLEEP_TIMEOUT_LIMIT = 10
    MAX_SLEEP_TIMEOUT_LIMIT = 60

    @staticmethod
    def check_param(min_limit, value, max_limit):
        """Return True when ``min_limit <= value <= max_limit`` (both ends inclusive)."""
        return bool(min_limit <= value <= max_limit)


class Utils(object):
    """Namespace of static helpers for the agent assistant.

    Covers file permissions, process lookup, archive packing/unpacking,
    command execution and host information. All methods are static.
    """
    rdagent_file = 'rdagent.py'
    monitor_file = 'monitor.py'
    log_package_file = 'log_package.py'
    monitor_service_name = 'AgentAssist_monitor'

    @staticmethod
    def mod_chmod(path, mode):
        """chmod ``path`` to ``mode``; does nothing on Windows."""
        if not CommonDefine.IS_WINDOWS:
            os.chmod(path, mode)

    @staticmethod
    def set_one_file_permission(file_path, mod: Permission):
        """chmod ``file_path`` to ``mod`` if it is a regular file; otherwise do nothing."""
        if os.path.isfile(file_path):
            os.chmod(file_path, mod)

    @staticmethod
    def set_path_permissions(path, mod: Permission):
        """chmod every regular file directly under ``path`` (not recursive)."""
        for file in os.listdir(path):
            file_path = os.path.join(path, file)
            Utils.set_one_file_permission(file_path, mod)

    @staticmethod
    def find_user(user):
        """Return True if ``user`` is a login name listed in /etc/passwd.

        Only the first (login name) field of each entry is compared. A plain
        substring search over the file also matched partial names, home
        directories and shells (e.g. "bin" matched "/bin/bash").
        """
        with open('/etc/passwd', 'r', encoding='utf-8') as fr:
            return any(line.split(':', 1)[0] == user for line in fr)

    @staticmethod
    def get_python_path():
        """Return option ``python_path`` of section ``input_info`` from the common config."""
        return configreader.g_cfg_common.get_option('input_info', 'python_path')

    @staticmethod
    def check_process_status(process_name):
        """Return True if any running process's name contains ``process_name``.

        Substring match on the process name. Processes whose name cannot be
        read (access denied, already exited) are skipped.
        """
        # With an attribute list, process_iter skips vanished processes and
        # reports unreadable names as None instead of raising.
        for proc in psutil.process_iter(['name']):
            name = proc.info['name']
            if name is not None and process_name in name:
                return True
        return False

    @staticmethod
    def check_backup_job():
        """Return True when none of the backup worker processes is running.

        Windows checks for "eefproc.exe"; other systems check for "eefproc"
        and "datasourceproc" (substring match on process names).
        """
        if CommonDefine.IS_WINDOWS:
            return not Utils.check_process_status("eefproc.exe")
        return not (Utils.check_process_status("eefproc")
                    or Utils.check_process_status("datasourceproc"))

    @staticmethod
    def _check_file_size(arc_file, pkg_type, des_path):
        """Reject archives whose uncompressed size looks abnormal or does not fit.

        Sizes are the ones declared in the archive headers. Unknown
        ``pkg_type`` values are not checked.

        :raises IOError: if the total size exceeds 2000 MiB, or is not smaller
            than the free space at ``des_path``.
        """
        # Upper bound for extracted content: a compression ratio of about 20
        # is already high, so anything beyond 100 MiB * 20 is very likely a
        # malformed (or zip-bomb) archive.
        max_file_size = 1024 * 1024 * 100 * 20
        if pkg_type == ArchiveType.ZIP:
            total_size = sum(info.file_size for info in arc_file.infolist())
            label = 'zipfile'
        elif pkg_type == ArchiveType.TAR:
            total_size = sum(member.size for member in arc_file.getmembers())
            label = 'tarfile'
        else:
            return
        if total_size > max_file_size:
            raise IOError(f"The compressed package has an exception, total_size:{total_size}.")
        free_space = psutil.disk_usage(des_path).free
        if total_size >= free_space:
            raise IOError(f'The {label} size({total_size}) exceed remain '
                          f'target disk space({free_space}).')

    @staticmethod
    def _walk_dir(directory):
        """Return the paths of all files below ``directory``, recursively."""
        return [os.path.join(root, name)
                for root, _dirs, files in os.walk(directory)
                for name in files]

    @staticmethod
    def _make_dir(dest_file):
        """Ensure the parent directory of ``dest_file`` exists.

        Directories created here get mode 0o700, including intermediate ones.

        :return: tuple ``(directory, filename)`` as split by os.path.split.
        """
        directory, filename = os.path.split(dest_file)
        # An empty directory means a bare file name relative to the cwd;
        # os.makedirs('') would raise FileNotFoundError.
        if directory and not os.path.exists(directory):
            # makedirs applies ``mode`` only to the leaf; parents follow the
            # umask. A umask of 0 made parents 0o777 (world-writable), so use
            # 0o077 to give every created level 0o700. umask is process-wide,
            # so this is not thread-safe.
            old_umask = os.umask(0o077)
            try:
                os.makedirs(directory, mode=0o700, exist_ok=True)
            finally:
                os.umask(old_umask)
        return directory, filename

    @staticmethod
    def _compress(src, dest_file, mode, compression=zipfile.ZIP_DEFLATED,
                  allowZip64=True, compresslevel=6):
        """Write ``src`` into the archive ``dest_file``; the suffix picks the format.

        :param src: source file or directory
        :param dest_file: target archive path ending in ".tar.gz" or ".zip"
        :param mode: archive open mode (e.g. 'w', 'a', 'w:gz')
        :param compression: zip compression method
        :param allowZip64: allow ZIP64 extensions (needed for large archives)
        :param compresslevel: zip compression level, 1-9 (1 fastest, 9 best)
        :return: True
        :raises Exception: if the suffix is neither ".tar.gz" nor ".zip"
        """
        filename = Utils._make_dir(dest_file)[1]

        if filename.endswith(".tar.gz"):
            with tarfile.open(dest_file, mode) as tar_file:
                tar_file.add(src, arcname=os.path.basename(src))
            return True
        if filename.endswith(".zip"):
            # The context manager closes the archive even if a write fails.
            with zipfile.ZipFile(dest_file, mode,
                                 compression=compression,
                                 allowZip64=allowZip64,
                                 compresslevel=compresslevel) as zip_file:
                if os.path.isfile(src):
                    zip_file.write(src, arcname=os.path.basename(src))
                if os.path.isdir(src):
                    # NOTE(review): entries keep the file's full path as
                    # arcname (ZipFile only strips the drive and leading
                    # separators), unlike the tar branch which roots entries
                    # at basename(src). Kept for compatibility with consumers.
                    for file_name in Utils._walk_dir(src):
                        zip_file.write(file_name, arcname=file_name)
            return True
        raise Exception("Unsupported compress type.")

    @staticmethod
    def pack(src, dest_file, mode, compression=None, allowZip64=None,
             compresslevel=None):
        """Compress a file or directory into a .zip or .tar.gz archive.

        :param src: file or directory to compress.
        :param dest_file: full target path including the archive suffix,
            e.g. c:/dest.zip.
        :param mode: tar.gz: 'w:gz' archives and compresses, 'w' only
            archives. zip: 'a' or 'w'.
        :param compression: zip only. ZIP_STORED (store only), ZIP_DEFLATED
            (deflate, the default when None), ZIP_BZIP2, ZIP_LZMA.
        :param allowZip64: zip only; required for zip files over 2 GiB.
            Defaults to True when None.
        :param compresslevel: zip only; defaults to 6 when None.
        :return: True
        :raises IOError: if ``src`` is neither a file nor a directory.
        :raises TypeError: if ``mode`` is not 'a', 'w' or 'w:gz'.
        """
        mode_range = ['a', 'w', 'w:gz']
        if not os.path.isfile(src) and not os.path.isdir(src):
            raise IOError("The src type error.")

        if mode not in mode_range:
            raise TypeError("Unsupported compress mode.")

        # None means "not specified". Forwarding None verbatim made ZipFile
        # raise NotImplementedError for the compression method and disabled
        # ZIP64, so substitute the documented defaults instead.
        return Utils._compress(
            src, dest_file, mode,
            compression=zipfile.ZIP_DEFLATED if compression is None else compression,
            allowZip64=True if allowZip64 is None else allowZip64,
            compresslevel=6 if compresslevel is None else compresslevel)

    @staticmethod
    def _is_path_within(base, target):
        """Return True if ``target`` resolves (following symlinks) to ``base`` or below it."""
        base = os.path.realpath(base)
        target = os.path.realpath(target)
        try:
            return os.path.commonpath([base, target]) == base
        except ValueError:
            # Raised on Windows for paths on different drives.
            return False

    @staticmethod
    def _check_tar_member(member, des_path):
        """Raise IOError if extracting tar ``member`` could write outside ``des_path``.

        Rejects absolute or ".."-escaping names, symlinks/hardlinks whose
        target escapes ``des_path``, and device/FIFO entries.
        """
        target = os.path.join(des_path, member.name)
        if not Utils._is_path_within(des_path, target):
            raise IOError(f"The tar member path is unsafe, member:{member.name}.")
        if member.isdev():
            raise IOError(f"The tar member type is unsupported, member:{member.name}.")
        if member.issym():
            # Symlink targets are resolved relative to the link's directory.
            link_target = os.path.join(os.path.dirname(target), member.linkname)
        elif member.islnk():
            # tarfile resolves hard-link targets relative to the extract path.
            link_target = os.path.join(des_path, member.linkname)
        else:
            return
        if not Utils._is_path_within(des_path, link_target):
            raise IOError(f"The tar member link is unsafe, member:{member.name}.")

    @staticmethod
    def unpack(pkg_name, des_path, pkg_type=ArchiveType.ZIP):
        """Extract a zip or tar archive into an existing directory.

        :param pkg_name: path of the archive file.
        :param des_path: existing destination directory.
        :param pkg_type: ArchiveType.ZIP or ArchiveType.TAR (tar compression
            is auto-detected by tarfile.open).
        :return: True on success.
        :raises IOError: if the archive or destination is missing, the content
            is too large (see _check_file_size), or a tar member is unsafe
            (see _check_tar_member).
        :raises ValueError: if the archive has more than 5000 entries.
        :raises TypeError: for an unsupported ``pkg_type``.
        """
        if not os.path.isfile(pkg_name) or not os.path.isdir(des_path):
            raise IOError(f"The pkg or path not exists, pkg_name:{pkg_name},path:{des_path}.")
        # Limit the number of archive entries to 5000.
        max_file_num = 5000
        arc_file = None
        try:
            if pkg_type == ArchiveType.ZIP:
                arc_file = zipfile.ZipFile(pkg_name)
                Utils._check_file_size(arc_file, ArchiveType.ZIP, des_path)
                members = arc_file.infolist()
            elif pkg_type == ArchiveType.TAR:
                arc_file = tarfile.open(pkg_name)
                Utils._check_file_size(arc_file, ArchiveType.TAR, des_path)
                members = arc_file.getmembers()
            else:
                raise TypeError("Unsupported archive type.")
            if len(members) > max_file_num:
                raise ValueError(f"The compressed file nums exceed maximum, file_nums:{len(members)}.")
            for member in members:
                if pkg_type == ArchiveType.TAR:
                    # tarfile joins member names and hard-link targets onto
                    # des_path verbatim and recreates symlinks, so validate
                    # each member right before writing it.
                    Utils._check_tar_member(member, des_path)
                # Members are extracted by object, not by a rewritten name
                # (a rewritten name fails the archive lookup). For zip,
                # ZipFile.extract itself strips drive letters, leading
                # slashes and ".." components.
                arc_file.extract(member, des_path)
            return True
        finally:
            if arc_file is not None:
                arc_file.close()

    @staticmethod
    def check_shell_inject_args(cmd):
        """Reject command arguments that look like shell-injection attempts.

        A list/tuple command has its items (bytes decoded as UTF-8) joined with
        spaces and re-tokenized with shlex.split. Each token is then checked.

        NOTE(review): a plain str command is not inspected and always passes,
        although execute_cmd calls this with str commands. Confirm with the
        callers (notably Windows command lines) before tightening.

        :return: True if no risk was found.
        :raises Exception: if a token equals "-c", "--c" or "=", or any token
            contains a blacklisted character.
        """
        cmd_list = []
        if isinstance(cmd, (tuple, list)):
            joined = " ".join(c.decode("utf8") if isinstance(c, bytes) else c for c in cmd)
            cmd_list = shlex.split(joined)

        symbol_black_list = ["-c", "--c", "="]
        # NOTE(review): the backslash and newline entries below are
        # two-character strings, so they can never match a single character.
        # Backslashes and newlines are therefore effectively allowed. Making
        # them effective would reject Windows paths; decide deliberately.
        symbol_black_set = {"~", "`", "<", ">", "[", "]", "|", "&", "!", "@",
                            "#", "%", "^", "(", ")", "$", "\\\\", "\\n", ";"}
        used_chars = set()
        for token in cmd_list:
            if token in symbol_black_list:
                raise Exception("Command may has shell inject risk2, please fix it.")
            used_chars.update(token)
        if symbol_black_set & used_chars:
            raise Exception("Command may has shell inject risk1, please fix it.")
        return True

    @staticmethod
    def execute_cmds(cmd_list):
        """Run commands as a pipeline: each stdout feeds the next command's stdin.

        :param cmd_list: non-empty list of argument lists (e.g. from
            change_cmds_format).
        :return: tuple (return code of the last command, list of its stdout
            lines). A non-zero return code is reported, not raised (e.g. grep
            exits 1 when nothing matches).
        :raises ValueError: if ``cmd_list`` is empty.
        :raises Exception: from check_shell_inject_args for a risky command.
        """
        if not cmd_list:
            raise ValueError("The cmd_list is empty.")
        # Validate every stage before starting any process.
        for cmd in cmd_list:
            Utils.check_shell_inject_args(cmd)

        upstream = []
        cmd_stdin = None
        last_index = len(cmd_list) - 1
        for index, cmd in enumerate(cmd_list):
            is_last = index == last_index
            # Upstream stderr was never read; a PIPE there could fill and
            # block the pipeline, so discard it instead.
            process = subprocess.Popen(cmd,
                                       stdin=cmd_stdin,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE if is_last else subprocess.DEVNULL)
            if cmd_stdin is not None:
                # Drop the parent's copy so the upstream process receives
                # SIGPIPE if this stage exits early.
                cmd_stdin.close()
            cmd_stdin = process.stdout
            if not is_last:
                upstream.append(process)

        out, _err = process.communicate()
        # Reap the upstream stages so they do not linger as zombies.
        for proc in upstream:
            proc.wait()
        return process.returncode, out.decode().strip('\n').splitlines()

    @staticmethod
    def change_cmds_format(args: str):
        """Split a "cmd1 | cmd2" string into a list of shlex-tokenized commands.

        Splits on every '|' character, including any inside quotes.
        """
        return [shlex.split(part) for part in args.split('|')]

    @staticmethod
    def execute_cmd(str_command):
        """Run one command without a shell and capture its standard output.

        Output is decoded as GBK on Windows and UTF-8 elsewhere. Standard
        error is captured and discarded.

        :return: tuple (return code, stdout text).
        :raises Exception: from check_shell_inject_args for a risky command.
        """
        Utils.check_shell_inject_args(str_command)
        encoding = "gbk" if CommonDefine.IS_WINDOWS else "utf-8"
        process = subprocess.Popen(str_command, shell=False,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   encoding=encoding)
        # communicate() drains both pipes while waiting. The previous wait()
        # followed by read() could deadlock once the child filled a pipe.
        res, _err = process.communicate()
        return process.returncode, res

    @staticmethod
    def get_ip_type(str_ip):
        """Return IPType.IPV4 or IPType.IPV6 for ``str_ip``.

        :raises TypeError: if ``str_ip`` is neither a valid IPv4 nor IPv6 address.
        """
        try:
            ipaddress.IPv4Address(str_ip)
            return IPType.IPV4
        except ValueError as error_ipv4:
            try:
                ipaddress.IPv6Address(str_ip)
                return IPType.IPV6
            except ValueError as error_ipv6:
                raise TypeError(f"Convert str to ip failed,{error_ipv4}，{error_ipv6}.") from error_ipv6

    @staticmethod
    def get_install_path():
        """Return the install root (configreader.INSTALL_PATH)."""
        return configreader.INSTALL_PATH

    @staticmethod
    def get_pkg_root_path():
        """Return <install>/PKG."""
        return os.path.join(configreader.INSTALL_PATH, 'PKG')

    @staticmethod
    def get_sub_agent_root_path():
        """Return <install>/SubAgent."""
        return os.path.join(configreader.INSTALL_PATH, 'SubAgent')

    @staticmethod
    def get_sub_agent_backup_path():
        """Return <install>/SubAgentBak."""
        return os.path.join(configreader.INSTALL_PATH, 'SubAgentBak')

    @staticmethod
    def get_agent_assist_root_path():
        """Return <install>/AgentAssist."""
        return os.path.join(configreader.INSTALL_PATH, 'AgentAssist')

    @staticmethod
    def get_log_path():
        """Return <install>/AgentAssist/log."""
        return os.path.join(configreader.INSTALL_PATH, 'AgentAssist/log')

    @staticmethod
    def get_conf_path():
        """Return <install>/AgentAssist/conf."""
        return os.path.join(configreader.INSTALL_PATH, 'AgentAssist/conf')

    @staticmethod
    def get_free_space_mb(folder):
        """Return the free disk space of the filesystem holding ``folder``.

        NOTE: despite the name, the value is in decimal gigabytes
        (bytes / 1000**3), not megabytes. Kept as-is for existing callers.
        """
        return psutil.disk_usage(folder).free / 1000 / 1000 / 1000

    @staticmethod
    def get_host_ip():
        """Return this host's IP address.

        Windows resolves the host name. Other systems return the local address
        of the route towards the configured ``proxy.hwc_ip``. If that fails,
        they fall back to the socket's unbound address (typically "0.0.0.0").
        """
        if CommonDefine.IS_WINDOWS:
            return socket.gethostbyname(socket.gethostname())
        hwc_ip = configreader.g_cfg_agentassist.get_option('proxy', 'hwc_ip')
        # A UDP connect() sends no packets: the kernel only selects the route
        # and binds the matching local address. The previous TCP connect with
        # a 1 ms timeout sent a real SYN to the peer.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sk:
            try:
                sk.connect((hwc_ip, 80))
            except (OSError, TypeError, ValueError):
                # Best effort, as before: report whatever address is bound.
                pass
            return sk.getsockname()[0]

    @staticmethod
    def get_host_active_code():
        """Return option ``activation_code`` of section ``input_info`` from the common config."""
        return configreader.g_cfg_common.get_option('input_info', 'activation_code')

    @staticmethod
    def get_host_name():
        """Return the host name."""
        return socket.gethostname()

    @staticmethod
    def get_linux_kernel_ver():
        """Run bin/assist/hostinfo/check_os_ver.sh and return its stripped output.

        :raises IOError: if the script exits with a non-zero code.
        """
        sh_file = os.path.join(Utils.get_agent_assist_root_path(),
                               'bin/assist/hostinfo/check_os_ver.sh')
        ret_code, kernel_ver = Utils.execute_cmd(sh_file)
        if ret_code != 0:
            raise IOError("Get kernel version failed.")
        return kernel_ver.strip()

    @staticmethod
    def get_host_os_version():
        """Return the host OS version string.

        Windows: platform.platform(). Otherwise the lower-cased output of
        check_os_ver.sh, prefixed with "Linux_" when it starts with "el".
        """
        if CommonDefine.IS_WINDOWS:
            return platform.platform()
        os_version = Utils.get_linux_kernel_ver().lower()
        if isinstance(os_version, bytes):
            os_version = os_version.decode().strip()
        if os_version.startswith('el'):
            os_version = "Linux_" + os_version
        return os_version

    @staticmethod
    def get_host_os_architecture():
        """Return the machine architecture reported by platform.machine()."""
        return platform.machine()

    @staticmethod
    def check_service(service_name):
        """Return whether the Windows service ``service_name`` is running.

        :return: True/False on Windows (False also when the query fails);
            None on other platforms, where no check is performed.
        """
        if not CommonDefine.IS_WINDOWS:
            return None
        try:
            status = win32serviceutil.QueryServiceStatus(service_name)[1]
            return status == win32service.SERVICE_RUNNING
        except Exception:
            return False

    @staticmethod
    def check_process(process_name):
        """Return True if a process's last argv item ends with ``process_name``.

        Processes need at least two argv items. A process whose second-to-last
        item contains "stop.py" is ignored.
        """
        pattern = re.escape(process_name) + '$'
        for process in psutil.process_iter():
            try:
                cmdline = process.cmdline()
            except (psutil.Error, OSError):
                continue
            if (len(cmdline) > 1
                    and re.search(pattern, cmdline[-1])
                    and not re.search(r'stop\.py', cmdline[-2])):
                return True
        return False

    @staticmethod
    def get_process_name_and_pid(process_file_name):
        """Find processes running ``process_file_name``.

        Matching is as in check_process. Additionally, processes whose argv[0]
        starts with "sh" or "su" are excluded (presumably shell/su wrappers).

        :return: (True, [names], [pids]) if any match, else (False, '', '').
        """
        pattern = re.escape(process_file_name) + '$'
        p_name = []
        pid = []
        for process in psutil.process_iter():
            try:
                cmdline = process.cmdline()
                if (len(cmdline) > 1
                        and re.search(pattern, cmdline[-1])
                        and not re.search(r'stop\.py', cmdline[-2])
                        # Was 's[h|u)]', which also matched '|' and ')'.
                        and not re.match('s[hu]', cmdline[0])):
                    name = process.name()
                    p_name.append(name)
                    pid.append(process.pid)
            except (psutil.Error, OSError):
                continue
        if p_name:
            return True, p_name, pid
        return False, '', ''

    @staticmethod
    def check_file_name(file_name):
        """Return True if ``file_name`` is a safe plain file name.

        Requirements: a str of at most 256 characters, made only of word
        characters, dots and dashes, and not "." or "..". fullmatch is used
        because "$" alone also accepts a trailing newline.
        """
        if not isinstance(file_name, str) or len(file_name) > 256:
            return False
        if file_name in ('.', '..'):
            return False
        return file_name_pattern.fullmatch(file_name) is not None

    @staticmethod
    def compare_version(version1, version2):
        """Return True if ``version1`` is strictly newer than ``version2``.

        Letters are stripped and the dotted parts compared numerically
        ("8.1.2.SPC300" -> 8.1.2.300). When one version is a prefix of the
        other, the longer one counts as newer ("1.0.0" > "1.0").

        :raises Exception: if either version fails check_version.
        :raises ValueError: if a part consists only of letters (e.g. "RC").
        """
        if not (Utils.check_version(version1) and Utils.check_version(version2)):
            raise Exception("Illegal version params.")

        # "8.1.2.SPC300" -> "8.1.2.300"
        v1_list = [int(part) for part in Utils.remove_alpha(version1).split('.')]
        v2_list = [int(part) for part in Utils.remove_alpha(version2).split('.')]
        for v1_num, v2_num in zip(v1_list, v2_list):
            if v1_num != v2_num:
                return v1_num > v2_num
        return len(v1_list) > len(v2_list)

    @staticmethod
    def check_version(version_string, separator="."):
        """Return True if ``version_string`` contains ``separator`` and every part is alphanumeric."""
        if not isinstance(version_string, str) or not version_string \
                or separator not in version_string:
            return False
        return all(part.isalnum() for part in version_string.split(separator))

    @staticmethod
    def remove_alpha(string):
        """Return ``string`` with all ASCII letters removed.

        :raises Exception: if ``string`` is not a str.
        """
        if not isinstance(string, str):
            raise Exception(f"{type(string)} is not str.")
        return re.sub(r'[a-zA-Z]+', '', string)
