# coding:utf-8
"""
@version: Toolkit V200R005C00
@time: 2018/09/22
@file: funcUtils.py
@function:
@modify:
"""
import functools
import threading
import time
import traceback
import re

import java
from java.lang import Exception as JException


def get_error_log_func(logger):
    """Return the error-logging callable exposed by ``logger``.

    The ``error`` attribute is preferred; ``logError`` is used as a fallback.

    :param logger: logger-like object, may be None.
    :return: the attribute found, or None when ``logger`` is falsy or
             exposes neither attribute.
    """
    if not logger:
        return None

    for attr_name in ('error', 'logError'):
        if hasattr(logger, attr_name):
            return getattr(logger, attr_name, None)

    return None


def wrap_all_exception_logged(logger=None):
    """Build a decorator that logs and swallows exceptions of the wrapped call.

    Both Python ``Exception`` subclasses and Java ``java.lang.Exception``
    instances are caught, so the decorated function does not raise them; it
    returns None instead. The traceback is reported through the logger's
    ``error`` (or ``logError``) method when one is available.

    :param logger: optional logger-like object.
    :return: decorator.
    """
    error_logger = get_error_log_func(logger)

    def decorator(target):
        @functools.wraps(target)
        def guarded(*args, **kwargs):
            try:
                result = target(*args, **kwargs)
            except (Exception, JException):
                if error_logger:
                    trace_text = str(traceback.format_exc())
                    error_logger('Exception trace:{}'.format(trace_text))
                return None
            return result

        return guarded

    return decorator


def _calc_progress_percent(elapsed_secs, total_secs, last_percent):
    """Return the percent to display for one fake-progress tick.

    :param elapsed_secs: seconds elapsed since progress started.
    :param total_secs: expected total duration in seconds (must be > 0).
    :param last_percent: percent displayed at the previous tick.
    :return: int percent, never below ``last_percent`` and capped at 99.
    """
    percent = int(elapsed_secs * 100.0 / total_secs)
    if total_secs >= 200 and elapsed_secs < 200:
        # For long runs, advance at least 10% per 100s during the first 200s
        # so the bar does not freeze for the first ~3 minutes. This must be
        # expressed in percent: the former formula (elapsed / 100 * 0.1)
        # yielded a fraction below 0.2 that always truncated to 0.
        percent = max(percent, int(elapsed_secs * 10.0 / 100))
    # Progress never goes backwards.
    percent = max(percent, last_percent)
    return int(percent) if percent < 100 else 99


def setProgress(context, totalSecs, intervalSec, event):
    """Set execution progress and remaining time in ``context``.

    Every ``intervalSec`` seconds, until ``event`` is set or ``totalSecs``
    seconds have elapsed, writes ``curProgressPer`` (0-99) and
    ``remainTime`` (seconds, at least 1 while running) into ``context``.
    On exit it writes ``curProgressPer`` = 100 and ``remainTime`` = 0.

    :param context: mutable mapping (Python dict or java.util.Map); its
        optional "logger" entry is used for info logs. Nothing is done when
        it is empty/None.
    :param totalSecs: expected total execution time in seconds.
    :param intervalSec: refresh interval in seconds.
    :param event: threading.Event-like object; setting it stops the loop.
    :return: None
    """
    if not context:
        return

    logger = context.get("logger")
    if logger:
        logger.info('set progress loop start..')

    # Item assignment is used throughout (not ``context.put``) so that both
    # Python dicts and java.util.Map contexts are supported.
    context["curProgressPer"] = 0
    context["remainTime"] = totalSecs
    startSec = time.time()

    percent = 0
    elapsedSecs = 0
    while (not event.isSet()) and elapsedSecs < totalSecs:
        time.sleep(intervalSec)
        elapsedSecs = time.time() - startSec
        percent = _calc_progress_percent(elapsedSecs, totalSecs, percent)

        remainTime = totalSecs - elapsedSecs
        remainTime = remainTime if remainTime > 0 else 1

        context["curProgressPer"] = int(percent)
        context["remainTime"] = int(remainTime)

    if logger:
        logger.info('setting progress loop end for timeout or event set.')

    context["curProgressPer"] = 100
    context["remainTime"] = 0
    if logger:
        logger.info('set progress loop end..')
    return


def set_multi_progress(context, total_secs, interval_secs, event,
                       display_list=None):
    """Publish progress in several stages, each with its own target percent.

    Each ``display_list`` entry is ``{"time": t, "percent": p}``:
    - ``t`` is the elapsed time since start at which the stage ends. It is
      compared with the total elapsed time and capped at ``total_secs``.
    - ``p`` is the percent reached at that moment, capped at 100.

    Within a stage the percent is interpolated linearly from the previous
    stage's end point. ``remainTime`` is the time left in the current stage.
    On exit (event set or all stages done) ``curProgressPer`` = 100 and
    ``remainTime`` = 0 are written.

    :param context: mutable mapping (Python dict or java.util.Map). Nothing
        is done when it is empty/None.
    :param total_secs: expected total execution time in seconds.
    :param interval_secs: refresh interval in seconds.
    :param event: threading.Event-like object; setting it stops the loop.
    :param display_list: stage list, e.g. [{"time": 10, "percent": 90}].
        None/empty means a single stage reaching 100% at ``total_secs``.
    :return: None
    """
    if not context:
        return

    total_percent = 100.0
    # Item assignment works for both Python dicts and java.util.Map.
    context["curProgressPer"] = 0
    start_time = time.time()

    if not display_list:
        display_list = [{"time": total_secs, "percent": total_percent}]

    context["remainTime"] = display_list[0].get("time")
    last_time = 0
    last_percent = 0
    for dis in display_list:
        max_time = min(dis.get("time"), total_secs)
        max_percent = min(dis.get("percent"), total_percent)
        stage_span = max_time - last_time
        # Use the real elapsed time (not 0) when entering a stage. A stage
        # that is already over is then skipped without an extra sleep, and
        # no division by zero occurs when stages share the same end time.
        used_time = time.time() - start_time
        while (not event.isSet()) and used_time < max_time:
            time.sleep(interval_secs)
            used_time = time.time() - start_time
            if stage_span > 0:
                percent = int((used_time * 1.0 - last_time) / stage_span *
                              (max_percent - last_percent) + last_percent)
            else:
                # Defensive: only reachable if the wall clock moves back.
                percent = int(max_percent)

            remain_time = max_time - used_time
            percent = int(percent) if percent < total_percent else 99
            remain_time = remain_time if remain_time > 0 else 1
            context["curProgressPer"] = int(percent)
            context["remainTime"] = int(remain_time)

        last_time = max_time
        last_percent = max_percent

    # Write an int, consistent with every other progress value.
    context["curProgressPer"] = int(total_percent)
    context["remainTime"] = 0
    return

def set_info_collect_progress(context, total_secs, interval_secs, event):
    """Report progress of the info-collection tool through its observer.

    The observer is ``context["progressObserver"]``.
    - It receives ``updateProgress(1, None)`` at start.
    - Then, every ``interval_secs`` seconds until ``event`` is set or
      ``total_secs`` elapse, it receives ``updateProgress(percent, None)``
      with a percent that never decreases and is capped at 99.
    - Finally it receives ``updateProgress(100, None)``.

    Nothing is done when the context is empty or has no observer with an
    ``updateProgress`` method.

    :param context: mapping holding "progressObserver" and optional "logger".
    :param total_secs: expected total execution time in seconds.
    :param interval_secs: refresh interval in seconds.
    :param event: threading.Event-like object; setting it stops the loop.
    :return: None
    """
    if not context:
        return

    logger = context.get("logger")
    if logger:
        logger.info("set progress loop start..")

    observer = context.get("progressObserver")
    if not observer or not hasattr(observer, "updateProgress"):
        return

    observer.updateProgress(1, None)
    start_sec = time.time()

    # For runs of 200s or more, advance at least 10% per 100s during the
    # first 200s so the progress does not freeze for the first ~3 minutes.
    is_need_adjust_progress = total_secs >= 200
    percent = 1
    elapsed_secs = 0
    while (not event.isSet()) and elapsed_secs < total_secs:
        time.sleep(interval_secs)
        elapsed_secs = time.time() - start_sec
        calc_percent = int(elapsed_secs * 100.0 / total_secs)
        if is_need_adjust_progress and elapsed_secs < 200:
            # Expressed in percent; the former "/ 100 * 0.1" produced a
            # fraction below 0.2 that always truncated to 0.
            adjusted_percent = int(elapsed_secs * 10.0 / 100)
            calc_percent = max(calc_percent, adjusted_percent)

        percent = max(calc_percent, percent)
        percent = int(percent) if percent < 100 else 99
        observer.updateProgress(percent, None)

    if logger:
        logger.info("setting progress loop end for timeout or event set.")

    observer.updateProgress(100, None)
    if logger:
        logger.info("set progress loop end..")
    return


def fakeProgress(totalSecs, intervalSec=5, tool_name=None, display_list=None):
    """Decorator that refreshes simulated progress while the wrapped call runs.

    The first positional argument that is a ``java.util.Map`` or ``dict`` is
    used as the context. A background thread named "refreshProgressThread"
    runs the progress routine selected by ``tool_name``:
    - "infoCollect" -> set_info_collect_progress;
    - "offlineCtrl" / "poweronCtrl" -> set_multi_progress, which also
      receives ``display_list``;
    - anything else -> setProgress.

    When the wrapped function returns or raises, the thread is signalled to
    stop and joined for at most ``intervalSec + 5`` seconds. Errors while
    starting or stopping the thread are logged through the context's
    "logger" (if present) and do not affect the wrapped call.

    :param totalSecs: expected total time in seconds.
    :param intervalSec: refresh interval in seconds.
    :param tool_name: tool name selecting the progress routine.
    :param display_list: stage display info for multi-stage tools, e.g.
        [{"time": 10, "percent": 90}]; None/empty means a single stage.
    :return: decorator.
    """
    targetMap = {"infoCollect": set_info_collect_progress,
                 "offlineCtrl": set_multi_progress,
                 "poweronCtrl": set_multi_progress}

    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            context = None
            event = None
            refreshProgressThread = None
            for arg in args:
                if isinstance(arg, java.util.Map) or isinstance(arg, dict):
                    context = arg
                    break

            logger = context.get("logger") if context else None
            target = targetMap.get(tool_name, setProgress)

            try:
                event = threading.Event()
                params = (context, totalSecs, intervalSec, event)
                if tool_name in ("offlineCtrl", "poweronCtrl"):
                    # Multi-stage routines also take the stage description.
                    params = params + (display_list,)
                refreshProgressThread = threading.Thread(
                    target=target,
                    args=params,
                    name="refreshProgressThread",
                )
                refreshProgressThread.start()
                if logger:
                    logger.info(
                        "refreshProgressThread start..%s" % type(context)
                    )
            except (Exception, JException):
                # Progress display is best effort: never block the real call.
                if logger:
                    logger.error("Trace back:%s" % str(traceback.format_exc()))
            try:
                return func(*args, **kwargs)
            finally:
                try:
                    if event:
                        event.set()
                    if refreshProgressThread:
                        # The progress loop wakes at most every intervalSec.
                        refreshProgressThread.join(timeout=intervalSec + 5)
                except (Exception, JException):
                    if logger:
                        logger.error(
                            "Trace back:%s" % str(traceback.format_exc())
                        )

        return inner

    return wrapper


def change_unit_to_GB(str_val):
    """Convert a size string such as "1.5TB" or "512 MB" to gigabytes.

    Migrated from the inspection tool. Units are matched case-sensitively,
    in the order TB, GB, MB, KB, B, using 1024 as the step between units.
    The number is the text before the first occurrence of the unit's
    leading letter, stripped of whitespace.

    :param str_val: size value as a string.
    :return: tuple (flag, float_val). ``flag`` tells whether the conversion
             succeeded; ``float_val`` is the size in GB, or 0 on failure.
    """
    unit_converters = (
        ("TB", lambda number: number * 1024),
        ("GB", lambda number: number),
        ("MB", lambda number: number / 1024),
        ("KB", lambda number: number / (1024 * 1024)),
        ("B", lambda number: number / (1024 * 1024 * 1024)),
    )

    # noinspection PyBroadException
    try:
        if not str_val:
            return False, 0

        for unit, convert in unit_converters:
            if unit in str_val:
                number = float(str_val.split(unit[0])[0].strip())
                return True, convert(number)

        return False, 0
    except Exception:
        return False, 0


class ThreadWithReturn(threading.Thread):
    """Thread that runs ``func(*args, **kwargs)`` and keeps its return value.

    After the thread has finished, ``get_result()`` returns what ``func``
    returned. It stays None while the thread is still running, or if
    ``func`` raised.
    """

    def __init__(self, func, args=(), kwargs=None):
        """
        :param func: callable to run in the thread.
        :param args: positional arguments for ``func``.
        :param kwargs: keyword arguments for ``func``. None means no keyword
            arguments; previously None made ``run`` fail with a TypeError.
        """
        super(ThreadWithReturn, self).__init__()
        self.func = func
        self.args = args
        self.kwargs = {} if kwargs is None else kwargs
        self.result = None

    def run(self):
        self.result = self.func(*self.args, **self.kwargs)

    def get_result(self):
        """Return the value produced by ``func`` (None if not available)."""
        return self.result


def time_out_decorator(timeout_exception=None, timeout_callback_func=None,
                       time_out=1800, *timeout_return_result):
    """Decorator factory limiting how long the decorated function may block.

    The decorated function runs in a daemon ``ThreadWithReturn`` and the
    caller waits at most ``time_out`` seconds. If the thread finishes in
    time, its result is returned. An exception raised by the function escapes
    only the worker thread, and None is returned to the caller.

    On timeout the following are checked in order:
      1. if ``timeout_return_result`` is non-empty, that tuple is returned;
      2. if ``timeout_exception`` is given, it is raised;
      3. if ``timeout_callback_func`` is given, it is called with the same
         arguments and its result is returned;
      4. otherwise AttributeError("Timeout params error!") is raised.
    The timed-out thread is not stopped; it keeps running as a daemon.

    :param timeout_exception: exception (class or instance) raised on timeout.
    :param timeout_callback_func: callback invoked on timeout.
    :param time_out: timeout in seconds, default 1800 (30 minutes).
    :param timeout_return_result: custom value(s) returned on timeout, as a
        tuple.
    :return: decorator.
    """
    def decorator(func):
        # Preserve the decorated function's name/docstring for callers and logs.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            th = ThreadWithReturn(func, args, kwargs)
            th.setDaemon(True)  # do not keep the process alive on timeout
            th.start()
            th.join(time_out)  # block the caller for at most time_out seconds
            if th.is_alive():
                if timeout_return_result:
                    return timeout_return_result
                if timeout_exception:
                    raise timeout_exception
                if timeout_callback_func:
                    return timeout_callback_func(*args, **kwargs)
                raise AttributeError("Timeout params error!")
            return th.get_result()

        return wrapper

    return decorator
