# -*- coding: UTF-8 -*-

import re
#  Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
import time
import traceback

from cbb.frame.dsl.dsl import DslException
from cbb.frame.faulttree.common import NodeStatus, NodeResult
from java.lang import Exception as JException

# Maximum number of raw (origin) info lines copied into a node result.
# Showing more than 3000 is not useful to the reader and may make the
# display freeze.
MAX_ORIGIN = 3000


class CheckNode:
    """A single check node (see cbb.frame.faulttree).

    Runs a check callable and records the outcome on a ``NodeResult``.
    User-facing texts (name, detail, suggestion, error code, help case) are
    looked up in ``self.resource`` under keys prefixed with ``node_id`` and
    formatted with values from ``self.data``.
    """

    # Matches a "[#<data_key>#]" marker. A message containing it is treated
    # as a template that is repeated once per item of self.data[data_key].
    # Compiled once for the class instead of on every message lookup. The
    # capture is greedy, so only one marker per message is supported.
    _SPECIAL_DATA_PATTERN = re.compile(r"\[#(.*)#\]")

    def __init__(self, context, node_id, data):
        """
        :param context: mapping providing at least "resource" and "logger"
        :param node_id: identifier of this node; also the prefix of every
                        message key the node looks up
        :param data: values used to fill placeholders in messages
        """
        self.context = context
        self.node_id = node_id
        self.node_result = None
        # Raw lines collected while checking; at most MAX_ORIGIN of them are
        # copied to node_result.origin_info after the check callable runs.
        self.origin_info = []
        self.resource = context.get("resource")
        self.data = data
        # Copied to node_result.tables after the check callable runs.
        self.tables = {}
        self.logger = context.get("logger")
        # A single code string or a list of codes; selects the
        # "<node_id>.abnormal.<code>" / "<node_id>.normal.<code>" variants.
        self.error_code = ""

    def init_result(self, all_result):
        """Bind this node to its NodeResult, creating it when missing.

        :param all_result: dict mapping node_id -> NodeResult. An existing
                           entry is reused. Otherwise a new result is created,
                           named from "<node_id>.name" and stored in the dict.
        """
        if self.node_id in all_result:
            self.node_result = all_result[self.node_id]
            return
        self.node_result = NodeResult()
        self.node_result.id = self.node_id
        self.node_result.name = self.get_message(self.node_id + ".name")
        all_result[self.node_id] = self.node_result

    def set_status(self, status):
        """Set the status of the bound node result."""
        self.node_result.status = status

    def exec_fun(self, fun_obj, required=False):
        """Execute a check callable and log how long it took.

        :param fun_obj: zero-argument callable to run; may be None
        :param required: when True, a missing (falsy) fun_obj is an error
        :return: the callable's return value, or None when fun_obj is
                 missing and not required
        :raises DslException: "common.not.exist.function" when required is
                 True and fun_obj is missing. Any exception raised by fun_obj
                 is re-raised after the node result is marked NO_CHECK.
        """
        if fun_obj:
            start = time.time()
            self.logger.info("begin to execute fun {}".format(fun_obj))
            result = self.exec_fun_inner(fun_obj)
            self.logger.info("cost second {} by execute fun {}".format(int(time.time() - start), fun_obj))
            return result
        if required:
            raise DslException("common.not.exist.function")
        return None

    def exec_fun_inner(self, fun_obj):
        """Call fun_obj and reflect failures in the node result.

        On failure, the node result is marked NO_CHECK with a detail and a
        "contact engineers" suggestion, and the exception is re-raised.
        Whether or not it fails, the collected origin info and tables are
        copied onto the node result.
        """
        node_result = self.node_result
        try:
            ret = fun_obj()
            self.logger.info("end to execute fun {}".format(fun_obj))
            return ret
        except DslException as e:
            self.logger.error(
                "dsl exception. {}".format(traceback.format_exc()))
            err_key = e.get_code()
            # "%s" instead of "+": a non-string code must not raise a
            # TypeError here and mask the original DslException. It also
            # promotes to unicode safely under Python 2.
            self.logger.error("err_key=%s" % (err_key,))
            node_result.detail = self.get_message(err_key)
            node_result.status = NodeStatus.NO_CHECK
            node_result.suggestion = \
                self.get_message("common.contact.engineers.suggestion")
            raise

        except (Exception, JException):
            self.logger.error(
                "exec exception. {}".format(traceback.format_exc()))
            node_result.detail = self.get_message("common.check.failed")
            node_result.status = NodeStatus.NO_CHECK
            node_result.suggestion = \
                self.get_message("common.contact.engineers.suggestion")
            raise
        finally:
            if len(self.origin_info) > MAX_ORIGIN:
                # Too many lines. Replace any existing origin info with a
                # "truncated" notice, then append only the first MAX_ORIGIN.
                node_result.origin_info = self.get_message("common.origin.info.too.more") + "\n"
            node_result.origin_info += "\n".join(self.origin_info[:MAX_ORIGIN])
            node_result.tables = self.tables

    def set_node_hung(self):
        """Mark the node as HUNG with the generic "check hung" detail."""
        self.node_result.status = NodeStatus.HUNG
        self.node_result.detail = self.get_message("common.check.hung")

    def is_no_hit(self):
        """Return True if the node was checked and found normal."""
        return self.node_result.status == NodeStatus.NO_HIT

    def is_hit(self):
        """Return True if the node was checked and found a problem."""
        return self.node_result.status == NodeStatus.HIT

    def is_not_finish(self):
        """Return True if the node has not started or is hung."""
        return self.node_result.status \
               in (NodeStatus.NO_START, NodeStatus.HUNG)

    def exec_check(self, run_fun_name):
        """Run the (required) check callable and record its outcome.

        :param run_fun_name: the check callable, despite the name
        """
        status = self.exec_fun(run_fun_name, True)
        self.update_node_result(self.node_result, status)

    def update_node_result(self, node_result, status):
        """Fill in node_result from the value returned by the check.

        :param status: True  -> the node has a problem (HIT);
                       False -> the node is normal (NO_HIT);
                       anything else -> SUCCESS with "<node_id>.info".
        """
        node_id = node_result.id
        if status is True:
            self._update_result_when_hit(node_id, node_result)
        elif status is False:
            # NOTE(review): assumes error_code is a string here; a list
            # error_code would make this concatenation raise TypeError.
            if self.error_code:
                key = node_id + ".normal." + self.error_code
            else:
                key = node_id + ".normal"
            node_result.detail = self.get_message(key)
            node_result.status = NodeStatus.NO_HIT
            node_result.suggestion = self.get_message("common.no.suggestion")
        else:
            node_result.detail = self.get_message(node_id + ".info")
            node_result.status = NodeStatus.SUCCESS
            node_result.suggestion = self.get_message("common.no.suggestion")

    def _update_result_when_hit(self, node_id, node_result):
        """Mark node_result as HIT and fill in its error code(s) and case.

        Error codes are comma-joined in first-seen order. The case is the
        first non-empty help text.
        """
        node_result.status = NodeStatus.HIT
        err_codes, cases = self._collect_hit_results(node_id, node_result)
        node_result.err_code = ",".join(err_codes)
        if cases:
            node_result.case = cases[0]

    def _collect_hit_results(self, node_id, node_result):
        """Apply get_result for every configured error code.

        Does not modify self.error_code. A single code is wrapped in a local
        list.

        :return: (error code texts, help/case texts), each de-duplicated,
                 with empty values dropped, in first-seen order (lists rather
                 than sets, so the output is deterministic)
        """
        code_ids = self.error_code
        if not isinstance(code_ids, list):
            code_ids = [code_ids]
        err_codes = []
        cases = []
        for code_id in code_ids:
            err_code, case = self.get_result(node_id, code_id, node_result)
            if err_code and err_code not in err_codes:
                err_codes.append(err_code)
            if case and case not in cases:
                cases.append(case)
        return err_codes, cases

    def get_result(self, node_id, code, node_result):
        """Append the abnormal detail and suggestion for one error code.

        Keys are "<node_id>.abnormal", ".suggestion", ".code" and ".help",
        each suffixed with ".<code>" when code is non-empty. Detail and
        suggestion are appended to node_result, separated by a literal
        backslash-n sequence (presumably rendered as a line break downstream).

        :return: (error code text, help/case text); either may be ""
        """
        code_key = node_id + ".code"
        error_key = node_id + ".abnormal"
        suggestion_key = node_id + '.suggestion'
        help_key = node_id + ".help"
        if code:
            code_key += '.' + code
            error_key += "." + code
            suggestion_key += "." + code
            help_key += "." + code
        node_result.detail = CheckNode.concat_message_ignore_empty(
            node_result.detail, self.get_message(error_key), '\\n')
        node_result.suggestion = CheckNode.concat_message_ignore_empty(
            node_result.suggestion, self.get_message(suggestion_key), '\\n')
        err_code = self.get_message(code_key)
        case = self.get_message(help_key)
        return err_code, case

    @staticmethod
    def concat_message_ignore_empty(message_a, message_b, delimiter):
        """Join two messages with delimiter, skipping an empty side."""
        if not bool(message_a):
            return message_b
        if not bool(message_b):
            return message_a
        return message_a + delimiter + message_b

    def get_message(self, key):
        """Look up key in the resource and format it with self.data.

        :return: the formatted message, or "" when key is empty or the
                 resource has no (or an empty) value for it
        """
        if not key:
            return ''
        message = self.resource.getString(key)
        if not message:
            return ""
        message = self.format_message(message)
        return message

    def format_message(self, message):
        """Fill placeholders in message from self.data.

        A "[#<data_key>#]" marker switches to per-item formatting over
        self.data[data_key]. Otherwise str.format is applied with self.data.
        Any formatting error is logged and the message is returned
        unformatted.
        """
        try:
            matched = self._SPECIAL_DATA_PATTERN.search(message)
            if matched:
                mark = matched.group(0)
                data_key = matched.group(1)
                message = self.format_by_special_data(message, mark, data_key)
            else:
                message = message.format(**self.trans_data(self.data))
        except Exception:
            self.logger.error(
                "parse exception: {}".format(traceback.format_exc()))
        return message

    def format_by_special_data(self, message, mark, data_key):
        """Format message once per item of self.data[data_key].

        The marker is removed from message. The per-item results are joined
        with a literal backslash-n sequence.
        """
        new_message = message.replace(mark, "")
        self.logger.info("special data={}".format(self.data[data_key]))
        self.logger.info("new message={}".format(new_message))
        messages = [new_message.format(**self.trans_data(item))
                    for item in self.data[data_key]]
        return "\\n".join(messages)

    def trans_data(self, source_data):
        """Recursively prepare data for use as str.format keyword arguments.

        - dict: keys are UTF-8 encoded (presumably so they are valid keyword
          names under Jython / Python 2) and values are converted
          recursively;
        - list/set whose items are all text: comma-joined into one string
          (an empty list/set becomes "");
        - other list/set: converted element-wise into a list;
        - anything else: returned unchanged.
        """
        if isinstance(source_data, dict):
            result = {}
            for key in source_data:
                real_key = key.encode("utf-8")
                result[real_key] = self.trans_data(source_data[key])
        elif isinstance(source_data, list) or isinstance(source_data, set):
            if CheckNode.all_match(
                    source_data,
                    lambda x: isinstance(x, str) or 'unicode' in str(
                        type(x))):
                result = ','.join(source_data)
            else:
                result = []
                for val in source_data:
                    result.append(self.trans_data(val))
        else:
            result = source_data
        return result

    @staticmethod
    def all_match(source_list, func):
        """Return True if func(item) is exactly True for every item."""
        return all(func(item) is True for item in source_list)
