# coding:utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.

import json
from memory_inspect.adapter.java_adapter import get_db_adapter
from memory_inspect.utils.constants import SPLIT_SYMBOL, OPERATOR_MAP


class DbService(object):
    """
    Bundle of one accessor object per database table.

    The required tables are already created on the Java side; the Python
    side only uses them.

    Attributes set by the constructor: ``dq_tbl``, ``err_addr_tbl``,
    ``log_err_tbl``, ``task_process_tbl``, ``task_result_tbl`` and
    ``task_check_result``.
    """

    def __init__(self, context):
        """
        Instantiate every table accessor with the same ``context``.

        :param context: shared context object, passed unchanged to each
            accessor's constructor
        """
        # (attribute name, accessor class) pairs, created in this order.
        accessor_classes = (
            ("dq_tbl", ErrorDqTbl),
            ("err_addr_tbl", ErrorAddrTbl),
            ("log_err_tbl", LogErrorTbl),
            ("task_process_tbl", TaskProcessDataTbl),
            ("task_result_tbl", TaskResultTbl),
            ("task_check_result", TaskCheckResultTbl),
        )
        for attr_name, accessor_cls in accessor_classes:
            setattr(self, attr_name, accessor_cls(context))


class BaseDb(object):
    """
    Common base for the per-table accessors.

    Builds SQL text from plain dicts and hands it to the adapter returned by
    ``get_db_adapter()``. Subclasses set ``tbl_name`` to the table they wrap.
    """
    tbl_name = ""
    _conn = None

    # ``basestring`` is a Python 2 / Jython builtin only; fall back to ``str``
    # so string detection keeps working on Python 3 as well.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self, context):
        """
        :param context: dict-like object providing ``logger`` and
            ``db_full_name``; ``get_conn`` also caches the connection in it
        """
        self.context = context
        self.logger = context.get("logger")
        self.db_full_name = context.get("db_full_name")

    def get_conn(self):
        """
        Return the connection for ``db_full_name``, creating it on first use.

        The connection is stored in ``context`` under the ``db_full_name``
        key, so later calls with the same context reuse it.

        :return: the connection object obtained from the adapter
        """
        if self.db_full_name in self.context:
            return self.context.get(self.db_full_name)
        self.logger.info("db_full_name is {}".format(self.db_full_name))
        conn = get_db_adapter().get_connection(self.db_full_name)
        self.context[self.db_full_name] = conn
        return conn

    @classmethod
    def _value_format_func(cls, value):
        """
        Render a Python value as an SQL literal.

        Lists are serialised to JSON text first. Strings are wrapped in
        single quotes with embedded single quotes doubled, so a value such as
        ``it's`` cannot terminate the literal early. ``None`` becomes
        ``NULL``; anything else is rendered with ``str``.

        :param value: the value to render
        :return: SQL literal text
        """
        if isinstance(value, list):
            value = json.dumps(value)
        if isinstance(value, cls._string_types):
            # Standard SQL escaping: a quote inside a literal is written ''.
            return "'{}'".format(value.replace("'", "''"))
        if value is None:
            return "NULL"
        return str(value)

    def _get_set_str(self, data_dict):
        """Build the ``col=value, col=value`` part of an UPDATE statement."""
        return ", ".join(
            "{}={}".format(column, self._value_format_func(value))
            for column, value in data_dict.items()
        )

    def save(self, data_dict, ignore_exists=True):
        """
        Insert one row built from ``data_dict``.

        :param data_dict: mapping of column name to value
        :param ignore_exists: when true the statement is ``INSERT OR IGNORE``,
            otherwise ``INSERT OR REPLACE``
        :return: whatever the adapter's ``insert_or_update_query`` returns
        """
        key_list = list(data_dict.keys())
        value_list = [data_dict[key] for key in key_list]

        # Column names are emitted wrapped in single quotes.
        keys_str = ",".join("'{}'".format(key) for key in key_list)
        value_str = ", ".join(
            self._value_format_func(value) for value in value_list)
        sql_str = "INSERT OR {} INTO {}({}) VALUES ({});".format(
            "IGNORE" if ignore_exists else "REPLACE",
            self.tbl_name,
            keys_str,
            value_str
        )
        return get_db_adapter().insert_or_update_query(
            conn=self.get_conn(),
            sql_str=sql_str
        )

    def _get_filter_str(self, filter_dict):
        """
        Turn ``filter_dict`` into conditions joined with ``AND``.

        A key containing ``SPLIT_SYMBOL`` is split into column name and
        operator name; a plain key uses the ``eq`` operator. The operator
        name is looked up in ``OPERATOR_MAP``.

        An empty ``filter_dict`` yields an empty string, which leaves the
        callers' ``WHERE`` clause without a condition.

        :param filter_dict: mapping of ``column`` or ``column<SPLIT>operator``
            to the value to compare against
        :return: the condition text, without the ``WHERE`` keyword
        """
        filter_list = []
        for k, v in filter_dict.items():
            if SPLIT_SYMBOL in k:
                key, operator = k.split(SPLIT_SYMBOL)
            else:
                key, operator = k, "eq"
            filter_list.append(
                "{} {} {}".format(
                    key,
                    OPERATOR_MAP[operator],
                    self._value_format_func(v)))
        return " AND ".join(filter_list)

    def update(self, obj_id, data_dict):
        """
        Update the row whose ``id`` equals ``obj_id``.

        :param obj_id: row id; interpolated into the statement as-is
        :param data_dict: mapping of column name to new value
        :return: whatever the adapter's ``insert_or_update_query`` returns
        """
        sql_str = "UPDATE {} SET {} WHERE id={};".format(
            self.tbl_name, self._get_set_str(data_dict), obj_id)
        return get_db_adapter().insert_or_update_query(
            conn=self.get_conn(),
            sql_str=sql_str
        )

    def update_by_condition(self, filter_dict, data_dict):
        """
        Update every row matching ``filter_dict``.

        :param filter_dict: conditions, see ``_get_filter_str``
        :param data_dict: mapping of column name to new value
        :return: whatever the adapter's ``insert_or_update_query`` returns
        """
        value_set_str = self._get_set_str(data_dict)
        filter_str = self._get_filter_str(filter_dict)
        sql_str = "UPDATE {} SET {} WHERE {};".format(
            self.tbl_name, value_set_str, filter_str)
        return get_db_adapter().insert_or_update_query(
            conn=self.get_conn(),
            sql_str=sql_str
        )

    def delete_rows(self, filter_dict):
        """
        Delete every row matching ``filter_dict`` and log the statement.

        :param filter_dict: conditions, see ``_get_filter_str``
        :return: whatever the adapter's ``insert_or_update_query`` returns
        """
        filter_str = self._get_filter_str(filter_dict)
        sql_str = "DELETE FROM {} WHERE {};".format(self.tbl_name, filter_str)
        res = get_db_adapter().insert_or_update_query(
            conn=self.get_conn(),
            sql_str=sql_str
        )
        self.logger.info("Delete rows:{} \nres:\n{}".format(sql_str, res))
        return res


class ErrorAddrTbl(BaseDb):
    """Accessor for the ``err_addr_tbl`` table."""
    tbl_name = "err_addr_tbl"

    def query(self, sql_str):
        """
        Run a caller-supplied statement on this accessor's connection.

        :param sql_str: complete SQL text
        :return: whatever the adapter's ``query`` returns
        """
        adapter = get_db_adapter()
        return adapter.query(self.get_conn(), sql_str)

    def fetch_latest_one(self, filter_dict):
        """
        Fetch the matching row with the greatest ``log_time``.

        :param filter_dict: conditions, rendered by ``_get_filter_str``
        :return: whatever the adapter's ``fetch_one`` returns
        """
        where_clause = self._get_filter_str(filter_dict)
        template = "SELECT * FROM {} WHERE {} ORDER BY log_time DESC LIMIT 1;"
        statement = template.format(self.tbl_name, where_clause)
        adapter = get_db_adapter()
        return adapter.fetch_one(conn=self.get_conn(), sql_str=statement)

    def get_latest_err_time_from_db(self, sn, controller):
        """
        Return the newest ``log_time`` stored for ``sn`` / ``controller``.

        :param sn: value compared against the ``sn`` column
        :param controller: value compared against the ``ctrl`` column
        :return: first element of the fetched row, or ``""`` when the
            adapter returns nothing
        """
        # NOTE(review): sn and controller are interpolated without escaping.
        template = (
            "SELECT log_time FROM {} WHERE sn='{}' AND ctrl='{}' "
            "ORDER BY log_time DESC LIMIT 1;"
        )
        statement = template.format(self.tbl_name, sn, controller)
        adapter = get_db_adapter()
        row = adapter.fetch_one(conn=self.get_conn(), sql_str=statement)
        if row:
            return row[0]
        return ""

    def update_batch(self, sectors, params):
        """
        Issue a ``replace into`` statement through the adapter's batch call.

        :param sectors: column names, one ``?`` placeholder is emitted per
            entry
        :param params: passed unchanged to the adapter's ``update_batch``
        """
        placeholders = ["?"] * len(sectors)
        adapter = get_db_adapter()
        conn = self.get_conn()
        statement = "replace into {} ({}) values ({});".format(
            self.tbl_name, ",".join(sectors), ",".join(placeholders))
        adapter.update_batch(conn, statement, params)


class ErrorDqTbl(BaseDb):
    """Accessor for the ``error_dq`` table."""
    tbl_name = "error_dq"

    def fetch_latest_one(self, filter_dict):
        """
        Fetch the matching row with the greatest ``log_time``.

        :param filter_dict: conditions, rendered by ``_get_filter_str``
        :return: whatever the adapter's ``fetch_one`` returns
        """
        where_clause = self._get_filter_str(filter_dict)
        template = "SELECT * FROM {} WHERE {} ORDER BY log_time DESC LIMIT 1;"
        statement = template.format(self.tbl_name, where_clause)
        adapter = get_db_adapter()
        return adapter.fetch_one(conn=self.get_conn(), sql_str=statement)

    def get_latest_err_time_from_db(self, sn, controller):
        """
        Return the newest ``log_time`` stored for ``sn`` / ``controller``.

        :param sn: value compared against the ``sn`` column
        :param controller: value compared against the ``ctrl`` column
        :return: first element of the fetched row, or ``""`` when the
            adapter returns nothing
        """
        # NOTE(review): sn and controller are interpolated without escaping.
        template = (
            "SELECT log_time FROM {} WHERE sn='{}' AND ctrl='{}' "
            "ORDER BY log_time DESC LIMIT 1;"
        )
        statement = template.format(self.tbl_name, sn, controller)
        adapter = get_db_adapter()
        row = adapter.fetch_one(conn=self.get_conn(), sql_str=statement)
        if row:
            return row[0]
        return ""

    def get_latest_data(self, filter_dict, limit=1001):
        """
        Fetch matching rows ordered by descending ``id``.

        :param filter_dict: conditions, rendered by ``_get_filter_str``
        :param limit: bound to the ``LIMIT ?`` placeholder
        :return: whatever the adapter's ``query_with_params`` returns
        """
        where_clause = self._get_filter_str(filter_dict)
        template = "SELECT * FROM {} WHERE {} ORDER BY id DESC LIMIT ?;"
        statement = template.format(self.tbl_name, where_clause)
        adapter = get_db_adapter()
        return adapter.query_with_params(self.get_conn(), statement, limit)


class TaskResultTbl(BaseDb):
    """Accessor for the ``analyze_result`` table."""
    tbl_name = "analyze_result"

    def has_warning(self, sn_number):
        """
        Tell whether the table holds at least one row for ``sn_number``.

        The value is bound through a ``?`` placeholder rather than spliced
        into the SQL text, so quotes inside it cannot alter the statement.

        :param sn_number: value compared against the ``sn`` column
        :return: ``True`` when the row count is non-zero, else ``False``
        """
        sql_str = "SELECT COUNT(*) FROM {} WHERE sn = ?;".format(self.tbl_name)
        # The statement no longer contains the value, so log it separately.
        self.logger.info(
            "has_warning sql:{} sn:{}".format(sql_str, sn_number))
        result = get_db_adapter().query_with_params(
            self.get_conn(),
            sql_str,
            sn_number,
        )
        self.logger.info("has_warning result:{}".format(result))
        return bool(result[0].get("COUNT(*)"))

    def get_err_dimm(self, sn_number):
        """
        List the distinct ``ctrl, node, channel, dimm, level`` combinations
        recorded for ``sn_number``.

        :param sn_number: value compared against the ``sn`` column
        :return: whatever the adapter's ``query_with_params`` returns
        """
        filter_str = self._get_filter_str({"sn": sn_number})
        sql_str = (
            "SELECT DISTINCT ctrl, node, channel, dimm, level "
            "FROM {} WHERE {}"
        ).format(self.tbl_name, filter_str)
        self.logger.info("get_err_dimm sql:{}".format(sql_str))
        result = get_db_adapter().query_with_params(
            self.get_conn(),
            sql_str,
        )
        self.logger.info("get_err_dimm result:{}".format(result))
        return result

    def get_warning_list(self, filter_dict):
        """
        Fetch every matching row ordered by ``id``.

        :param filter_dict: conditions, rendered by ``_get_filter_str``
        :return: whatever the adapter's ``query_with_params`` returns
        """
        filter_str = self._get_filter_str(filter_dict)
        sql_str = "SELECT * FROM {} WHERE {} ORDER BY id".format(
            self.tbl_name, filter_str)
        return get_db_adapter().query_with_params(
            self.get_conn(),
            sql_str,
        )


class TaskCheckResultTbl(BaseDb):
    """Accessor for the ``tbl_task_check_result`` table."""
    tbl_name = "tbl_task_check_result"

    def get_check_result(self, sn):
        """
        Fetch the rows stored for ``sn``.

        :param sn: bound to the ``sn = ?`` placeholder
        :return: whatever the adapter's ``query_with_params`` returns
        """
        statement = "select * from {} WHERE sn = ?".format(self.tbl_name)
        adapter = get_db_adapter()
        return adapter.query_with_params(self.get_conn(), statement, sn)

    def save_check_result(self, sn, status, detail):
        """
        Write ``(sn, status, detail)`` with a ``replace into`` statement.

        :param sn: value for the ``sn`` column
        :param status: value for the ``status`` column
        :param detail: value for the ``detail`` column
        """
        adapter = get_db_adapter()
        conn = self.get_conn()
        statement = (
            "replace into {} (sn, status, detail) values (?, ?, ?)"
        ).format(self.tbl_name)
        adapter.update_with_params(conn, statement, sn, status, detail)


class TaskProcessDataTbl(BaseDb):
    """Accessor for the ``tbl_task_process_data`` table."""
    tbl_name = "tbl_task_process_data"
    keys = ("sn", "ctrl", "collected_log_files")

    def query_collected_log_files(self, sn):
        """
        Fetch the rows stored for ``sn`` and group them by their ``ctrl``
        value.

        :param sn: bound to the ``sn = ?`` placeholder
        :return: the dict produced by ``group_by(rows, "ctrl")``
        """
        self.logger.info("query collected log files, sn={}".format(sn))
        adapter = get_db_adapter()
        conn = self.get_conn()
        statement = "SELECT * FROM {} WHERE sn = ?".format(self.tbl_name)
        rows = adapter.query_with_params(conn, statement, sn)
        self.logger.info("the result = {}".format(rows))
        return group_by(rows, "ctrl")

    def save_collected_log_files(self, data):
        """
        Write one row taken from ``data`` with a ``replace into`` statement.

        :param data: dict-like object read through ``get`` for the keys
            ``sn``, ``ctrl`` and ``collected_log_files``
        """
        adapter = get_db_adapter()
        conn = self.get_conn()
        statement = (
            "replace into {} (sn, ctrl, collected_log_files) values (?, ?, ?)"
        ).format(self.tbl_name)
        sn = data.get("sn")
        ctrl = data.get("ctrl")
        log_files = data.get("collected_log_files")
        adapter.update_with_params(conn, statement, sn, ctrl, log_files)


class LogErrorTbl(BaseDb):
    """Accessor for the ``tbl_log_err_info`` table."""
    tbl_name = "tbl_log_err_info"

    def query(self, sql_str):
        """
        Run a caller-supplied statement on this accessor's connection.

        :param sql_str: complete SQL text
        :return: whatever the adapter's ``query`` returns
        """
        adapter = get_db_adapter()
        return adapter.query(self.get_conn(), sql_str)

    def fetch_latest_one(self, filter_dict):
        """
        Fetch the matching row with the greatest ``log_time``.

        :param filter_dict: conditions, rendered by ``_get_filter_str``
        :return: whatever the adapter's ``fetch_one`` returns
        """
        where_clause = self._get_filter_str(filter_dict)
        template = "SELECT * FROM {} WHERE {} ORDER BY log_time DESC LIMIT 1;"
        statement = template.format(self.tbl_name, where_clause)
        adapter = get_db_adapter()
        return adapter.fetch_one(conn=self.get_conn(), sql_str=statement)

    def update_batch(self, sectors, params):
        """
        Issue a ``replace into`` statement through the adapter's batch call.

        :param sectors: column names, one ``?`` placeholder is emitted per
            entry
        :param params: passed unchanged to the adapter's ``update_batch``
        """
        placeholders = ["?"] * len(sectors)
        adapter = get_db_adapter()
        conn = self.get_conn()
        statement = "replace into {} ({}) values ({});".format(
            self.tbl_name, ",".join(sectors), ",".join(placeholders))
        adapter.update_batch(conn, statement, params)

    def save_log_error_data(self, data_list):
        """
        Store a list of records through ``update_batch``.

        Each record is read with ``get`` for a fixed list of columns; a
        missing key is stored as an empty string.

        :param data_list: iterable of dict-like records
        """
        table_heads = [
            "sn", "ctrl", "time", "err_type",
            "obj_id", "matched_data", "origin_info", "item_id",
        ]
        rows = [
            [record.get(column, "") for column in table_heads]
            for record in data_list
        ]
        self.update_batch(table_heads, rows)

    def get_latest_err_time_from_db(self, sn, ctrl, item_id):
        """
        Return the newest ``time`` stored for ``sn`` / ``ctrl`` / ``item_id``.

        :param sn: value compared against the ``sn`` column
        :param ctrl: value compared against the ``ctrl`` column
        :param item_id: value compared against the ``item_id`` column
        :return: first element of the fetched row, or ``""`` when the
            adapter returns nothing
        """
        # NOTE(review): the three values are interpolated without escaping.
        template = (
            "SELECT time FROM {} WHERE sn='{}' AND ctrl='{}' AND item_id='{}' "
            "ORDER BY time DESC LIMIT 1;"
        )
        statement = template.format(self.tbl_name, sn, ctrl, item_id)
        adapter = get_db_adapter()
        row = adapter.fetch_one(conn=self.get_conn(), sql_str=statement)
        if row:
            return row[0]
        return ""


def group_by(data_list, key):
    """
    Group the items of ``data_list`` by the value each holds under ``key``.

    :param data_list: source records; each is read with ``get(key)``, so a
        record without the key is grouped under ``None``. A falsy
        ``data_list`` gives an empty result.
    :param key: the key to group by
    :return: {key1: [xx, yy], key2: [zz, dd]}, each list in input order
    """
    grouped = {}
    for record in data_list or ():
        grouped.setdefault(record.get(key), []).append(record)
    return grouped
