# coding: utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.

import time
import locale
import datetime
import traceback
import re

from memory_inspect.utils.device_info_table import MACHINE_CONFIG_MAP
from memory_inspect.utils.constants import BANK_NUM_TUPLE, TIME_FORMATTER
from memory_inspect.address_decodes import forwardSadOperation

# Patterns are raw strings so that "\s", "\d" and "\w" reach the regex engine
# verbatim instead of being (invalid) string-literal escape sequences.

# UCE error year rule: captures the trailing year of an "mcelog: TIME ..." line.
UCE_YEAR_REGX = re.compile(r"mcelog:\s+TIME\s+\d+\s+\w+\s+\w+\s+\d+\s+\d+:\d+:\d+\s+(\d+)")

# False-positive rule: a line such as "CPU 0 BANK 9"; the bank number is captured.
BANK_REGEX = re.compile(r"CPU\s+\d+\s+BANK\s+(\d+)")
# False-positive rule: a line containing "ADDR a0300".
ADDR_REGEX = re.compile(r"\s+ADDR\s+a0300")
# False-positive rule: the bank number is one of 9..11 (kept as strings because
# they are compared against regex captures).
ERR_BANK_LIST = ["9", "10", "11"]
# False-positive rule: all conditions occur in one file, within 10 consecutive lines.
LINE_NUMBER_LIMIT = 10


class MceLogParser(object):
    """
    Line-by-line parser for mcelog log files.

    Each call to :meth:`parse` scans one file and collects two kinds of
    records: corrected errors ('HA CE'), whose system address is decoded into
    a physical location, and uncorrected errors ('HA UCE'). Records whose
    timestamp is not later than ``start_time`` are dropped.
    (The original author's note says two log files are involved; one call
    handles one of them.)
    """

    def __init__(self, logger, start_time):
        """
        :param logger: object exposing ``info`` and ``error`` methods
        :param start_time: lower time bound; compared as a string against
            timestamps rendered with TIME_FORMATTER
        """
        self.logger = logger
        self.start_time = start_time
        # Year of the most recent "mcelog: TIME ..." line seen. Initialised
        # here so the time helpers are usable before parse() has run.
        self.latest_uce_error_year = ''

    @staticmethod
    def get_machine_code(sn_number):
        """Return characters 2..9 of the serial number (the MACHINE_CONFIG_MAP key)."""
        return sn_number[2:10]

    @staticmethod
    def get_bank_num(line_str):
        """
        Return the token that follows "BANK" in the line.

        :raises ValueError: the line has no "BANK" token
        :raises IndexError: "BANK" is the last token
        """
        tokens = line_str.split()
        return tokens[tokens.index("BANK") + 1]

    @staticmethod
    def get_addr_year(line_str):
        """
        Return the last whitespace-separated token of the line.

        A blank line yields "" instead of raising, so the caller can fall
        back to another year source rather than abort the whole parse.
        """
        tokens = line_str.split()
        return tokens[-1] if tokens else ""

    @staticmethod
    def has_uncorrected_error_record(line_str):
        """Return True if the line mentions an uncorrected error (case-insensitive)."""
        lowered = line_str.lower()
        key_words = ("Uncorrected patrol scrub error", "Uncorrected error")
        return any(key.lower() in lowered for key in key_words)

    @staticmethod
    def init_uce_record_struct(uce_record_structure):
        """
        Reset the three-slot UCE record structure in place.

        The slots hold the line numbers of the BANK line, the ADDR line and the
        "Uncorrected error" line; they are used to rule out the false positive
        caused by the first power-on. -1 means "not seen yet".

        :param uce_record_structure: list of at least three elements
        :return: None
        """
        for slot in range(3):
            uce_record_structure[slot] = -1

    @staticmethod
    def is_unexpected_error(uce_record_struct):
        """
        Tell whether the recorded lines form the known false-positive pattern:

        1. a "CPU <n> BANK <m>" line with bank 9..11,
        2. followed by an "ADDR a0300" line,
        3. followed by an "Uncorrected error" line,

        all inside the LINE_NUMBER_LIMIT window.

        :param uce_record_struct: [bank line, addr line, uce line] numbers
        :return: True for a false positive, False otherwise
        """
        bank_line, addr_line, uce_line = uce_record_struct[:3]
        # Line numbers are 0-based and -1 is the "unset" marker, so line 0 is
        # a legitimate position for the BANK line.
        # NOTE(review): a span of exactly LINE_NUMBER_LIMIT lines is not
        # accepted by the strict "<" below - confirm that is intended.
        return (bank_line >= 0
                and addr_line > bank_line
                and uce_line > addr_line
                and uce_line - bank_line + 1 < LINE_NUMBER_LIMIT)

    def parse(self, file_path, sn):
        """
        Parse one mcelog log file.

        :param file_path: full path of the mcelog file
        :param sn: device serial number, used to look up the hardware
            configuration needed for address decoding
        :return: list of dicts. CE records carry the keys time, cpu, card,
            slot, err_type ('HA CE'), rank, device, bank, row, col, sys_addr
            and ordinal; UCE records carry err_type ('HA UCE'), time, cpu,
            card, slot and ordinal.
        """
        all_results = []
        # Sliding window over the last three lines: [BANK line, ADDR line, current].
        error_line = ["", "", ""]
        self._set_locale_env()
        self.latest_uce_error_year = ''
        with open(file_path, 'r') as log_file:
            self.logger.info("[MceLogParser] start check file:{}".format(file_path))
            # Line numbers of the BANK / ADDR / "Uncorrected error" lines, used
            # to rule out the false positive caused by the first power-on.
            uce_record_structure = [-1, -1, -1]
            for number, line_str in enumerate(log_file):
                self.get_latest_year(line_str)
                self.pick_uce_key_line_info(line_str, number, uce_record_structure)
                self.check_uce_error(line_str, number, all_results, uce_record_structure)
                error_line = [error_line[1], error_line[2], line_str]
                self.check_ce_error(sn, error_line, number, all_results)
        return all_results

    def get_latest_year(self, line):
        """Remember the year of an "mcelog: TIME ..." line, if this is one."""
        years = UCE_YEAR_REGX.findall(line)
        if years:
            self.latest_uce_error_year = years[0]

    def pick_uce_key_line_info(self, line_str, line_num, uce_record_structure):
        """
        Record the line numbers of the UCE bank / addr / keyword lines.

        :param line_str: current line
        :param line_num: 0-based number of the current line
        :param uce_record_structure: three-slot record, updated in place
        :return: None
        """
        bank_match = BANK_REGEX.search(line_str)
        if bank_match and bank_match.group(1) in ERR_BANK_LIST:
            # A new suspicious BANK line starts a fresh candidate sequence.
            self.init_uce_record_struct(uce_record_structure)
            uce_record_structure[0] = line_num
        if ADDR_REGEX.search(line_str):
            uce_record_structure[1] = line_num
        if "Uncorrected error" in line_str:
            uce_record_structure[2] = line_num

    def check_ce_error(self, sn, error_line, number, all_results):
        """
        Append an 'HA CE' record if the three-line window describes one.

        :param sn: device serial number
        :param error_line: [line n-2, line n-1, line n]
        :param number: 0-based number of the current line
        :param all_results: result list, extended in place
        :return: None
        """
        first_line, second_line, line_str = error_line
        error_info = self.pick_mce_err(first_line, second_line, line_str)
        if not error_info:
            return
        time_to_fmt_str, err_addr = error_info
        self.logger.info("check_ce_error:time_to_fmt_str:{} err_addr:{},  self.start_time:{}".format(
            time_to_fmt_str, err_addr, self.start_time))
        if time_to_fmt_str <= self.start_time:
            return

        try:
            machine_code = self.get_machine_code(sn)
            dev_config_data = MACHINE_CONFIG_MAP.get(machine_code)
            phy_addr = self.get_real_addr(err_addr, dev_config_data)
            all_results.append(
                dict(time=time_to_fmt_str, cpu=phy_addr[0], card=phy_addr[1], slot=phy_addr[2],
                     err_type='HA CE', rank=phy_addr[3], device=phy_addr[4], bank=phy_addr[5], row=phy_addr[6],
                     col=phy_addr[7], sys_addr=err_addr, ordinal="%s[%s]" % (time_to_fmt_str, str(number).zfill(6))
                     ))
        except Exception:
            # Best effort: an address that cannot be decoded (unknown machine
            # code, missing decoder, malformed result) is logged and skipped.
            self.logger.error(
                "Address decode error:{}, err:\n{}".format(err_addr, str(traceback.format_exc()))
            )

    def check_uce_error(self, line_str, number, all_results, uce_record_struct):
        """
        Append an 'HA UCE' record if the line reports an uncorrected error
        that is neither the known false positive nor older than start_time.

        :param line_str: current line
        :param number: 0-based number of the current line
        :param all_results: result list, extended in place
        :param uce_record_struct: three-slot record of key line numbers
        :return: None
        """
        if not self.has_uncorrected_error_record(line_str):
            return
        # A record identified as a false positive is not saved.
        if self.is_unexpected_error(uce_record_struct):
            self.logger.info("Find uncorrected error, but it's unexpected record. Do not record it!")
            return
        # The first three tokens are taken as the "<month> <day> <time>" prefix.
        # join() copes with shorter lines: the time then simply fails to parse
        # and the record is dropped instead of aborting the whole parse.
        log_time_str = self.get_full_time_str(" ".join(line_str.split()[:3]),
                                              "%Y %b %d %H:%M:%S", "")
        self.logger.info("uce log time={}".format(log_time_str))
        if log_time_str > self.start_time:
            self.logger.error("has uce error: {}".format(line_str))
            all_results.append(dict(err_type='HA UCE', time=log_time_str,
                                    cpu='0', card='0', slot='0',
                                    ordinal="%s[%s]" % (log_time_str, str(number).zfill(6))))

    def pick_mce_err(self, before_line, second_line, line_str):
        """
        Extract (timestamp, address) from a three-line window.

        :param before_line: line expected to carry the "BANK <n>" token
        :param second_line: line expected to carry the "ADDR <hex>" token
        :param line_str: following line, whose last token is taken as the year
        :return: (formatted time string, address string), or None when the
            window does not describe an error in one of the BANK_NUM_TUPLE banks
        """
        if " ADDR " not in second_line:
            return
        try:
            bank_num = self.get_bank_num(before_line)
        except Exception as e:
            self.logger.error("ERROR: get_bank_num err:{}, before_line:\n{}".format(e, before_line))
            return
        if bank_num not in BANK_NUM_TUPLE:
            return

        sp_list = second_line.split()
        addr_pos = sp_list.index("ADDR") + 1
        if addr_pos >= len(sp_list):
            # "ADDR" is the last token: there is no address to decode.
            self.logger.error("ERROR: ADDR value missing, second_line:\n{}".format(second_line))
            return

        addr_year = self.get_addr_year(line_str)
        time_to_fmt_str = self.get_full_time_str(" ".join(sp_list[:3]), "%Y %b %d %H:%M:%S",
                                                 addr_year)
        err_code = sp_list[addr_pos]
        self.logger.info("pick_mce_err:{} {}".format(time_to_fmt_str, err_code))
        return time_to_fmt_str, err_code

    def get_full_time_str(self, log_time_str, log_format, addr_year):
        """
        Build a full timestamp from a year-less log time.

        The year is taken from ``addr_year`` if it is numeric, else from the
        last "mcelog: TIME" line seen, else from the current date. If the
        result cannot be parsed or lies in the future, the previous year is
        tried instead.

        :param log_time_str: time text without the year
        :param log_format: strptime format for "<year> <log_time_str>"
        :param addr_year: candidate year string, may be empty or non-numeric
        :return: time rendered with TIME_FORMATTER, or "" if parsing failed
        """
        if addr_year and addr_year.isdigit():
            cur_year = addr_year
        elif self.latest_uce_error_year and self.latest_uce_error_year.isdigit():
            self.logger.info(
                "addr year is invalid. use latest error time:{}.".format(self.latest_uce_error_year)
            )
            cur_year = self.latest_uce_error_year
        else:
            self.logger.info("addr year is invalid. use current time.")
            cur_year = datetime.datetime.now().year
        res_dt_str = self._get_time_str(
            "{} {}".format(cur_year, log_time_str),
            log_format
        )
        if not res_dt_str or datetime.datetime.now().strftime(TIME_FORMATTER) < res_dt_str:
            res_dt_str = self._get_time_str(
                "{} {}".format(int(cur_year) - 1, log_time_str),
                log_format
            )
        return res_dt_str

    def _get_time_str(self, time_str, log_format):
        """Re-render ``time_str`` with TIME_FORMATTER; log and return "" on failure."""
        try:
            time_obj = time.strptime(time_str, log_format)
            return time.strftime(TIME_FORMATTER, time_obj)
        except Exception:
            self.logger.error("strip time error.time_str={}, log_format={}".format(time_str, log_format))
            return ""

    def _set_locale_env(self):
        """Switch the process locale so log dates parse in a Chinese environment."""
        try:
            # This keeps log dates parseable in a Chinese-locale environment.
            locale.setlocale(locale.LC_ALL, "en.GBK")
        except locale.Error:
            self.logger.error("not support set local")

    def get_real_addr(self, logic_addr, config_data):
        """
        Decode a system address through the matching forwardSadOperation function.

        :param logic_addr: address as a hexadecimal string
        :param config_data: machine configuration row, e.g.
            ["18500 V5", "Disable", "2", "8", "32G"]; its last three items
            select the decoder name and the fourth-from-last item is passed on
        :return: items 1..8 of the comma-separated decoder result
        :raises ValueError: the decoder result does not have 12 fields
        """
        func_name = "SAD_Base_{1}_{2}_{0}P".format(*config_data[-3:])
        func_args = (int(logic_addr, base=16), "Open", config_data[-4])

        func_obj = getattr(forwardSadOperation, func_name)
        real_addr_info = func_obj(*func_args)
        real_addr_split_list = real_addr_info.split(",")
        if len(real_addr_split_list) != 12:
            self.logger.info("get real addr info invalid: {} -> {}".format(logic_addr, real_addr_info))
            raise ValueError
        return real_addr_split_list[1:9]
