# -*- coding: UTF-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
"""
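Overall flow of the memory inspection task: log decompression -> log parsing and storage in the database -> rule analysis.
The Arm and Purely hardware flows are largely the same, with a few differences.
MemoryCheckTask defines the common flow; the differences live in ArmMemoryCheckTask and PurelyMemoryCheckTask.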
"""

import os
import datetime
import collections
from memory_inspect.rule import memory_rule
from memory_inspect.db.task_db_service import DbService
from memory_inspect.adapter.java_adapter import get_db_adapter, get_uncheck_exception, get_msg
from memory_inspect.parser.rasdaemon_log_parser import RasdaemonParser
from memory_inspect.parser.mce_log_parser import MceLogParser
from memory_inspect.parser.imu_log_parser import ImuLogParser
from memory_inspect.parser.log_reset_parser import LogResetParser
from memory_inspect.parser.message_dq_parser import MessageParser
from memory_inspect.parser.bios_uart_print_parser import BiosUartPrintParser
from memory_inspect.parser.memory_info_parser import MemoryInfoParser
from memory_inspect.utils import six_about
from memory_inspect.rule.memory_rule_engine import RuleEngine
from memory_inspect.utils.config_handler import ConfigParserHandler
from memory_inspect.utils.constants import (
    TIME_SEC_FORMATTER, CheckResult
)


class MemoryCheckTask(object):
    """Common memory-inspection flow shared by the hardware-specific tasks.

    The flow is: verify the database is reachable, purge expired rows, parse
    the collected logs into the database, then run the rule engine once for
    every controller that produced new errors.  Subclasses provide the
    hardware-specific steps ``_parse_log_to_db``, ``_add_extra_info`` and
    ``_get_full_addr``.
    """

    def __init__(self, context):
        """Build the task from the inspection ``context`` mapping.

        Keys read here: ``logger``, ``lang``, ``dev`` (must carry ``sn``),
        ``config_path``, ``start_time``, ``sys_time`` and, optionally,
        ``collected_files``.  The whole context is also handed to DbService.
        """
        self.context = context
        self.logger = context.get("logger")
        self.lang = context.get("lang")
        self.sn_number = context.get("dev").get("sn")
        self.rule_configs = ConfigParserHandler(context["config_path"], self.logger).get_all_items()
        self.db_service = DbService(context)
        self.start_time = self.format_start_time(self.context.get("start_time"))
        self.sys_time = self.format_system_time(context.get("sys_time"))
        # Watermark used by _filter_old_err_data; refreshed per controller by
        # _update_last_err_time.
        self.last_err_time = self.start_time
        # Text fragments joined with os.linesep and returned by execute().
        self.origin_info = []
        # Controllers for which at least one new error row was stored.
        self.has_new_err_controllers = set()
        # Mapping of controller id -> {log kind: [file names]}, as consumed by
        # the parsers below.
        self.sorted_logs = context.get("collected_files", {})

    def execute(self):
        """Run the whole check.

        Returns:
            A ``(status, origin_info_text, err_msg)`` tuple.

        Raises:
            The exception built by ``get_uncheck_exception`` when the
            database connection check fails.
        """
        if not self._check_db_connection():
            self.logger.info("db not exist")
            raise get_uncheck_exception(
                get_msg(self.lang, "db.file.not.exist"), "")

        self._clear_db_expired_data()
        status, err_msg = self._do_check()
        return status, os.linesep.join(self.origin_info), err_msg

    def _do_check(self):
        """Parse logs into the database and analyse them; return ``(status, err_msg)``."""
        # Parse the logs and store the results in the database.
        self._parse_log_to_db(self.sorted_logs)
        if not self.has_new_err_controllers:
            self.logger.info("no new error generate. return pass.")
            return CheckResult.PASS, ""

        # Rule analysis.
        status, err_msg = self._analyse_all_rule()

        self._add_extra_info(self.sorted_logs, status)

        return status, err_msg

    def _add_extra_info(self, sorted_logs, status):
        """Hook: append hardware-specific details to ``self.origin_info``."""
        raise NotImplementedError

    def _parse_log_to_db(self, sorted_logs):
        """Hook: parse hardware-specific logs and store new errors in the database."""
        raise NotImplementedError

    def _clear_db_expired_data(self):
        """Delete this device's error rows that are older than 60 days."""
        # The cut-off is computed from the device system time, not the host clock.
        begin_time = six_about.get_time_by_delta_process(self.sys_time, days=-60)
        self.db_service.err_addr_tbl.delete_rows(dict(sn=self.sn_number, log_time__lt=begin_time))

    def _analyse_all_rule(self):
        """Run the rule engine once per controller with new errors.

        Returns:
            ``(status, err_msg)`` where ``err_msg`` joins the per-controller
            messages with ``os.linesep``.
        """
        status = CheckResult.PASS
        err_msgs = []
        all_rule_origin_infos = {}

        # Analyse the rules once per controller.  The set is walked in sorted
        # order so the order of the reported messages is reproducible.
        for ctrl in sorted(self.has_new_err_controllers):
            rule_context = self._get_rule_context(self.sys_time)
            rule_context.analyse_ctrls = [ctrl]
            rule_context.vendor_data = self._get_vendor_data(ctrl)
            self.logger.info("vendor data={}".format(rule_context.vendor_data))
            ctrl_status, ctrl_err_msg, origin_info = RuleEngine(self.rule_configs, self.logger).execute(rule_context)
            all_rule_origin_infos.update(origin_info)
            # A truthy engine status is treated as a failure.
            # NOTE(review): this presumes CheckResult.PASS is falsy - confirm.
            if ctrl_status:
                status = CheckResult.NOT_PASS
                err_msgs.append(ctrl_err_msg)
        self.origin_info.extend(all_rule_origin_infos.values())
        return status, os.linesep.join(err_msgs)

    def _get_vendor_data(self, ctrl):
        """Return the first non-empty vendor data parsed from the controller's ``memory_info`` files.

        Files whose name contains ``"bios"`` go through BiosUartPrintParser and
        files whose name contains ``"patch"`` through MemoryInfoParser.  An
        empty list is returned when nothing usable is found.
        """
        # Tolerate a controller without an entry instead of failing on None.get().
        file_names = (self.sorted_logs.get(ctrl) or {}).get("memory_info", [])
        if not file_names:
            return []
        for file_name in file_names:
            if "bios" in file_name:
                vendor_data = BiosUartPrintParser(file_name).execute()
                if vendor_data:
                    return vendor_data
            if "patch" in file_name:
                vendor_data = MemoryInfoParser(file_name).execute()
                if vendor_data:
                    return vendor_data
        return []

    def _filter_old_err_data(self, datas):
        """Keep only the records whose ``time`` is strictly later than ``self.last_err_time``."""
        return [data for data in datas if data["time"] > self.last_err_time]

    def _get_rule_context(self, current_time):
        """Create a RuleContext populated with this task's shared state."""
        rule_context = memory_rule.RuleContext()
        rule_context.db_service = self.db_service
        rule_context.current_time = current_time
        rule_context.sn = self.sn_number
        rule_context.logger = self.logger
        rule_context.lang = self.lang
        # NOTE(review): this is the watermark of whichever controller was
        # parsed last, not necessarily the one being analysed - confirm intent.
        rule_context.last_err_time = self.last_err_time
        return rule_context

    def _check_db_connection(self):
        """Return True when both the database and the error-address table are reachable."""
        conn = get_db_adapter().get_connection(self.context["db_full_name"])
        if not conn:
            self.logger.error("db conn:{}, abort".format(conn))
            return False
        table_db_err_conn = self.db_service.err_addr_tbl.get_conn()
        if not table_db_err_conn:
            self.logger.error(
                "table_db_err_conn:{}, abort".format(table_db_err_conn))
            return False
        return True

    def _update_last_err_time(self, ctrl):
        """Set ``self.last_err_time`` to the newest stored error time of ``ctrl``.

        When the database has no record for this controller the watermark
        falls back to ``self.start_time``.  Without the fallback a controller
        with no history would inherit the previous controller's watermark and
        its new errors could be filtered out.
        """
        latest = self.db_service.err_addr_tbl.get_latest_err_time_from_db(self.sn_number, ctrl)
        self.last_err_time = latest if latest else self.start_time

    def _get_full_addr(self, data):
        """
        Hook: build the full address (logical address + physical address) of one record.
        """
        raise NotImplementedError

    def _save_data_to_db(self, ctrl, datas, is_multi_dq=None):
        """Write the parsed error records of one controller to the error-address table.

        Each dict in ``datas`` is updated in place with the derived fields
        (``sn``, ``ctrl``, ``_dq``, ``_slot``, ``_ordinal``, ``_addr``) before
        its values are extracted.

        Args:
            ctrl: Controller id; also used as the prefix of the ``ordinal`` column.
            datas: Iterable of record dicts produced by a log parser.
            is_multi_dq: Value stored in the ``dq`` column for every record.
        """
        # Database column -> key in the record dict.  Built from pairs because
        # keyword-argument order is only preserved on Python 3.6+.
        trans_map = collections.OrderedDict([
            ("sn", "sn"), ("ctrl", "ctrl"), ("log_time", "time"), ("node", "cpu"),
            ("channel", "card"), ("dimm", "slot"), ("rank", "rank"), ("bgroup", "device"),
            ("bank", "bank"), ("row", "row"), ("col", "col"), ("err_type", "err_type"),
            ("slot", "_slot"), ("addr", "_addr"), ("ordinal", "_ordinal"), ("dq", "_dq"),
        ])
        params = []
        for data in datas:
            data.update(dict(sn=self.sn_number, ctrl=ctrl,
                             _dq=is_multi_dq,
                             _slot=data.get("cpu", "") + data.get("card", "") + data.get("slot", ""),
                             # NOTE(review): raises TypeError when the record has no
                             # "ordinal" key - confirm every parser supplies it.
                             _ordinal=ctrl + data.get("ordinal"),
                             _addr=self._get_full_addr(data)
                             ))
            # Missing fields are stored as empty strings.
            params.append([data.get(source_key, "") for source_key in trans_map.values()])
        # Pass a real list so the callee can index it on Python 3 as well.
        self.db_service.err_addr_tbl.update_batch(list(trans_map), params)

    @staticmethod
    def format_start_time(timestamp):
        """Convert a millisecond epoch timestamp to a local-time string in TIME_SEC_FORMATTER format."""
        start_time_obj = datetime.datetime.fromtimestamp(int(timestamp) / 1000)
        start_time_str = start_time_obj.strftime(TIME_SEC_FORMATTER)
        return start_time_str

    @staticmethod
    def format_system_time(time_str):
        """Reformat a device time string to ``YYYY-MM-DD HH:MM:SS``.

        Only the first whitespace-separated token is used; it must look like
        ``YYYY-MM-DD/HH:MM:SS``.  Anything after it is ignored.
        """
        time_str = time_str.split()[0]
        time_obj = datetime.datetime.strptime(time_str, "%Y-%m-%d/%H:%M:%S")
        return time_obj.strftime("%Y-%m-%d %H:%M:%S")


class ArmMemoryCheckTask(MemoryCheckTask):
    """Arm variant of the memory check: errors come from rasdaemon logs, DQ information from IMU logs."""

    def __init__(self, context):
        super(ArmMemoryCheckTask, self).__init__(context)
        # Second value returned by ImuLogParser.parse(); appended to
        # origin_info when the check does not pass.
        self.origin_dq_info = ""

    def _parse_log_to_db(self, sorted_logs):
        """Parse the DQ information once, then store every controller's new rasdaemon errors."""
        # Parse the DQ information (whether multi-DQ exists).
        imu_parser = ImuLogParser(self.logger)
        dq_info, self.origin_dq_info = imu_parser.parse(sorted_logs)

        for ctrl_id in sorted(sorted_logs):
            rasdaemon_files = sorted_logs[ctrl_id].get("rasdaemon", [])
            self._parse_rasdaemon_log_to_db(ctrl_id, dq_info, rasdaemon_files)

    def _parse_rasdaemon_log_to_db(self, ctrl_id, dq_info, rasdaemon_files):
        """Store the errors of one controller that are newer than its last stored error.

        Args:
            ctrl_id: Controller whose files are being parsed.
            dq_info: Mapping of controller id -> multi-DQ flag.
            rasdaemon_files: rasdaemon log files of this controller.
        """
        rasdaemon_parser = RasdaemonParser(self.logger)
        self._update_last_err_time(ctrl_id)
        for file_name in rasdaemon_files:
            new_err_datas = self._filter_old_err_data(rasdaemon_parser.parse(file_name))
            if not new_err_datas:
                continue
            self.has_new_err_controllers.add(ctrl_id)
            # Always store the flag as an int (0 when the controller has no DQ
            # entry).  The previous nested get() computed this value and then
            # discarded it whenever the key existed.
            is_multi_dq = int(dq_info.get(ctrl_id, False))
            self._save_data_to_db(ctrl_id, new_err_datas, is_multi_dq)

    def _add_extra_info(self, sorted_logs, status):
        """On NOT_PASS, append the raw DQ information and the LogResetParser output to ``origin_info``."""
        if status == CheckResult.NOT_PASS:
            self.origin_info.append(self.origin_dq_info)
            self.origin_info.append(LogResetParser(self.logger).parse(sorted_logs))

    def _get_full_addr(self, data):
        """Return ``<sys_addr>_<cpu+card+slot+rank+bank+device>_<row>_<col>`` for one record."""
        slot = data.get("cpu", "") + data.get("card", "") + data.get("slot", "")
        rank_addr = slot + data.get("rank", "") + data.get("bank", "") + data.get("device", "")
        return "{}_{}_{}_{}".format(data.get("sys_addr"), rank_addr, data.get("row"), data.get("col"))


class PurelyMemoryCheckTask(MemoryCheckTask):
    """Purely variant of the memory check: errors come from mcelog files, DQ information from messages files."""

    def _parse_log_to_db(self, sorted_logs):
        """Store the new mcelog errors of every controller, then attach DQ and memory ECC information."""
        ctrl_ids = sorted(sorted_logs)

        # Parse mcelog and store the data in the database.
        for ctrl_id in ctrl_ids:
            mce_log_files = sorted_logs[ctrl_id].get("mcelog", [])
            self._parse_one_ctrl_mce_log_to_db(ctrl_id, mce_log_files)

        # Associate the DQ information (matched by time; every slot logged at
        # the matching time is updated).
        for ctrl_id in ctrl_ids:
            message_files = sorted_logs[ctrl_id].get("messages", [])
            parser = MessageParser(self.logger)
            dq_info = parser.parse(message_files)
            self._update_dq_info(ctrl_id, dq_info)
            # NOTE(review): this filters with the watermark left by the last
            # controller of the loop above - confirm that is intended.
            self._update_memory_ecc_info(ctrl_id, parser.memory_info)

    def _update_memory_ecc_info(self, ctrl_id, memory_info):
        """Store the ``memory_info`` records that are newer than the current watermark."""
        new_err_datas = self._filter_old_err_data(memory_info)
        if new_err_datas:
            self.has_new_err_controllers.add(ctrl_id)
            self._save_data_to_db(ctrl_id, new_err_datas)

    def _parse_one_ctrl_mce_log_to_db(self, ctrl_id, mce_log_files):
        """Store the errors of one controller that are newer than its last stored error."""
        # Refresh the watermark first so the parser is built with this
        # controller's value rather than the one left by the previous controller.
        self._update_last_err_time(ctrl_id)
        parser = MceLogParser(self.logger, self.last_err_time)
        for file_name in mce_log_files:
            parse_results = parser.parse(file_name, self.sn_number)
            new_err_datas = self._filter_old_err_data(parse_results)
            if new_err_datas:
                self.has_new_err_controllers.add(ctrl_id)
                self._save_data_to_db(ctrl_id, new_err_datas)

    def _update_dq_info(self, ctrl, dq_info):
        """Set the ``dq`` column of the rows matching this sn, controller and each DQ record's time."""
        for dq in dq_info:
            filter_dict = dict(sn=self.sn_number, ctrl=ctrl, log_time=dq.get("time"))
            data_dict = dict(dq=int(dq.get("multi_dq")))
            self.db_service.err_addr_tbl.update_by_condition(filter_dict, data_dict)

    def _add_extra_info(self, sorted_logs, status):
        """Nothing extra is reported for this hardware; only a log line is written."""
        self.logger.info("not need add extra info")

    def _get_full_addr(self, data):
        """Return the record's ``sys_addr`` as its full address."""
        return data.get("sys_addr")
