#!/usr/bin/python
# -*- coding: utf-8 -*-
import re
import time
import uuid
import traceback

from getDBConnection import get_zenith_session
import common_tasks.const_sql as const_sql
from common_tasks.base_task import BaseTask

# Batch size: accumulated rows are flushed to the database (executemany /
# exec_sql / insert_idmapping) each time a buffer reaches this many rows.
ONE_PACKAGE = 500


class UpgradeLocalNM(BaseTask):
    def __init__(self, product_name="NCE"):
        """
        Prepare the migration of ``tbl_LocalNM`` (MCDB) into ``LocalNM`` (neresdb).

        Opens one session and one cursor on the source database and on the
        destination database, both in autocommit mode. If a session cannot
        be obtained, an error is logged and construction stops early; the
        handles that were not opened stay ``None`` so that later cleanup code
        can test them instead of hitting a missing attribute.

        :param product_name: product name handed to ``set_product_name`` and
            to ``get_zenith_session``
        """
        super(UpgradeLocalNM, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeLocalNM init product_name is %s" % product_name)

        # Table metadata does not depend on any connection, so define it
        # before anything that may return early.
        self.src_table = "tbl_LocalNM"
        self.src_table_cols = (
            "iPAddress", "iIsLocalNM", "iNMID", "iDevType", "strNMName", "strOwner",
            "strMemo", "strUserLabel", "strUserName", "strCreateTime")
        # column name -> position of that column in a selected source row
        self.src_table_cols_index = {y: x for x, y in enumerate(self.src_table_cols)}
        self.dst_table = "LocalNM"
        self.dst_table_cols = (
            "tenantId", "resId", "neId", "name", "ipAddr", "isLocalNM",
            "typeName", "typeId", "remark", "owner", "userLabel", "userName",
            "createAt", "updateAt")
        self.idmapping_dic = {}

        # Every handle exists from the start, even if opening fails below.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('MCDB', 'MCDB', product_name)
        if self.src_db_session is None:
            self.error("MCDB is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def get_resid_from_idmapping(self, paras):
        """
        Load the id mapping for the given source rows into ``self.idmapping_dic``.

        The native id ``"NE=<iNMID>"`` of every row is written, in batches of
        ``ONE_PACKAGE``, to a temporary table in idmappingdb. The rows returned
        by ``const_sql.INNER_JOIN_TEMP_TABLE`` are then stored as
        ``{native id without the "NE=" prefix: first result column}``.
        The cursor and the session are closed even when a statement fails.

        :param paras: source rows, laid out in ``self.src_table_cols`` order
        :return: None
        """
        idmappingdb_sess = get_zenith_session('idmappingdb', 'idmappingdb', self.product_name)
        if idmappingdb_sess is None:
            self.error("get idmappingdb_sess session fail")
            return

        native_prefix = "NE="
        id_mapping_cursor = None
        try:
            idmappingdb_sess.autocommit(True)
            id_mapping_cursor = idmappingdb_sess.cursor()

            tmp_table_name = "tmp_neid_%s" % self.dst_table
            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
            id_mapping_cursor.execute(const_sql.CREATE_TEMP_TABLE % tmp_table_name)
            # NOTE(review): the insert target is "tmp_" + tmp_table_name, i.e. it
            # carries a second "tmp_" prefix. Whether that matches the table the
            # const_sql templates create cannot be seen from here - TODO confirm.
            insert_stmt = "insert into tmp_%s values(:1)" % tmp_table_name

            nm_id_index = self.src_table_cols_index.get("iNMID")
            native_ids = []
            for para in paras:
                # one-element tuple: a single bind value per row
                native_ids.append(("NE=%s" % para[nm_id_index],))
                if len(native_ids) == ONE_PACKAGE:
                    id_mapping_cursor.executemany(insert_stmt, native_ids)
                    self.debug("one package:%s" % native_ids)
                    native_ids = []

            if native_ids:
                id_mapping_cursor.executemany(insert_stmt, native_ids)
                self.debug("last package datas: %s" % native_ids)

            query_stmt = const_sql.INNER_JOIN_TEMP_TABLE % tmp_table_name
            self.debug("query sql: %s" % query_stmt)
            id_mapping_cursor.execute(query_stmt)
            result = id_mapping_cursor.fetchall()

            for r in result:
                row = list(r)
                native_id = row[1]
                # Remove the literal prefix only. str.lstrip("NE=") would strip
                # every leading 'N', 'E' or '=' character, which corrupts ids
                # that themselves start with one of those characters.
                if native_id.startswith(native_prefix):
                    native_id = native_id[len(native_prefix):]
                self.idmapping_dic[native_id] = row[0]
                self.debug("result: %s: %s" % (row[1], row[0]))

            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
        finally:
            # Release the connection on the error path as well.
            if id_mapping_cursor is not None:
                id_mapping_cursor.close()
            idmappingdb_sess.close()

    def get_res_id_from_ne(self, ne_id):
        """
        Get the resId from tbl_ne so that it stays consistent with the resId
        stored in the naBasicInfo table.

        When eamdb cannot be reached, or tbl_ne has no row for the id, a fresh
        ``uuid1`` string is generated instead and an error is logged. The
        eamdb cursor and session opened here are closed before returning.

        :param ne_id: logical id; looked up as ``DN = 'OS=<ne_id>'``
        :return: resId
        """
        eam_db_session = get_zenith_session('eamdb', 'eamdb', self.product_name)
        if eam_db_session is None:
            res_id = str(uuid.uuid1())
            self.error(
                "UpgradeLocalNM get_res_id_from tbl_ne get eamdb error, ne_id is: %s, resId is %s" % (
                    str(ne_id), res_id))
            return res_id

        eam_db_cursor = None
        try:
            eam_db_session.autocommit(True)
            eam_db_cursor = eam_db_session.cursor()
            select_stmt = "select RESID from tbl_ne where DN = 'OS=%s'" % str(ne_id)
            result = self.exec_query_sql(eam_db_cursor, select_stmt)
        finally:
            # This method runs once per migrated row, so an unclosed session
            # here would leak one connection per row.
            if eam_db_cursor is not None:
                eam_db_cursor.close()
            eam_db_session.close()

        # An empty (or missing) result means tbl_ne does not know this id.
        if result:
            return result[0][0]

        res_id = str(uuid.uuid1())
        self.error(
            "UpgradeLocalNM get_res_id_from tbl_ne is null, ne_id is: %s, resId is %s" % (str(ne_id), res_id))
        return res_id

    def convert_data(self, paras):
        tenantId = "default-organization-id"
        neId = paras[self.src_table_cols_index.get("iNMID")]
        # 从tbl_ne表中取resId,保证和neBasicinfo表中resId一致
        # self.idmapping_dic.get(str(paras[self.src_table_cols_index.get("iNMID")]))
        resId = self.get_res_id_from_ne(neId)
        name = paras[self.src_table_cols_index.get("strNMName")]
        ipAddr = paras[self.src_table_cols_index.get("iPAddress")]
        if ipAddr and str(ipAddr).isdigit():
            iIp = int(ipAddr)
            ipAddr = "%s.%s.%s.%s" % tuple((iIp >> x & 0xFF) for x in range(24, -1, -8))
        else:
            self.warning("iPAddress is not valid, %s" % str(paras))
        is_nm_map = {
            "0": False,
            "1": True
        }
        isLocalNM = is_nm_map.get(str(paras[self.src_table_cols_index.get("iIsLocalNM")]))
        typeName = "NM"
        typeId = paras[self.src_table_cols_index.get("iDevType")]
        if typeId and str(typeId).isdigit():
            typeId = int(typeId) * 65536
        else:
            self.warning("typeId is not valid, %s" % str(paras))
        remark = paras[self.src_table_cols_index.get("strMemo")]
        owner = paras[self.src_table_cols_index.get("strOwner")]
        userLabel = paras[self.src_table_cols_index.get("strUserLabel")]
        userName = paras[self.src_table_cols_index.get("strUserName")]
        createAt = paras[self.src_table_cols_index.get("strCreateTime")]
        if createAt and str(createAt).isdigit():
            createAt = int(createAt) * 1000
        else:
            self.warning("strCreateTime is not valid, %s" % str(paras))
        updateAt = int(time.time() * 1000)

        return tuple(None if x is None else str(x) for x in (
            tenantId, resId, neId, name, ipAddr, isLocalNM, typeName, typeId, remark, owner, userLabel, userName,
            createAt, updateAt))

    def covert_ip_addr(self, c_uniqueId):
        if re.compile(r"([\d]{1,3}\.[\d]{1,3}\.[\d]{1,3}\.[\d]{1,3})").match(c_uniqueId):
            # 能匹配上，就是IPv4
            return ("ipv4://" + c_uniqueId).replace("_", "/")
        else:
            # 不是IPv4就是IPv6,因为IPv6的格式比较灵活，所以只要不是IPv4的就当做IPv6处理
            return ("ipv6://" + c_uniqueId).replace("_", "/")

    def trans_id_alloc_data(self, local_nm_data):
        # dst_table_cols = ("uniqueId", "resId", "neId", "tenantId", "createAt", "updateAt")
        id_alloc_data = list()
        id_alloc_data.append(self.covert_ip_addr(local_nm_data[self.dst_table_cols.index("ipAddr")]))
        id_alloc_data.append(local_nm_data[self.dst_table_cols.index("resId")])
        id_alloc_data.append(local_nm_data[self.dst_table_cols.index("neId")])
        id_alloc_data.append("default-organization-id")
        id_alloc_data.append(int(time.time() * 1000))
        id_alloc_data.append(int(time.time() * 1000))
        return tuple(None if x is None else str(x) for x in id_alloc_data)

    def generate_insert_stmt(self, table_name, col_names, need_symbol=False):
        """
        生成插入语句
        :param table_name: 待插入表名
        :param col_names: 对应列
        :param need_symbol: 是否需要``反引号
        :return: insert语句
        """
        col_stmt = ",".join(col_names) if not need_symbol else "`" + ("`, `".join(col_names)) + "`"
        val_ids = ":" + (",:".join((str(x + 1) for x in range(len(col_names)))))
        table_name = "`%s`" % table_name if need_symbol else table_name
        insert_stmt = "INSERT INTO %s (%s) VALUES(%s)" % (table_name, col_stmt, val_ids)
        return insert_stmt

    def insert_idmapping(self, insert_stmt, datas):
        """
        Insert rows into the t_ids table of idmappingdb (best effort).

        The whole batch is sent with one ``executemany``. If that fails, each
        row is retried on its own so that one bad row does not discard the
        rest; rows that still fail are logged and skipped. Failures that
        happen before a cursor exists are logged once, because there is
        nothing to retry with. Errors derived from ``Exception`` are never
        propagated; ``KeyboardInterrupt`` / ``SystemExit`` are.

        :param insert_stmt: insert statement
        :param datas: rows for the t_ids table
        :return: None
        """
        idmappingdb_sess = None
        id_mapping_cursor = None
        try:
            self.debug("sql: %s" % insert_stmt)
            self.debug("list_datas: %s" % datas)
            idmappingdb_sess = get_zenith_session('idmappingdb', 'idmappingdb', self.product_name)
            if idmappingdb_sess is None:
                self.error("get idmappingdb session fail")
                return
            idmappingdb_sess.autocommit(True)
            id_mapping_cursor = idmappingdb_sess.cursor()
            try:
                id_mapping_cursor.executemany(insert_stmt, datas)
            except Exception as batch_err:
                self.warning(batch_err)
                # Fall back to row-by-row inserts to salvage the good rows.
                for data in datas:
                    try:
                        id_mapping_cursor.executemany(insert_stmt, [data])
                    except Exception as row_err:
                        self.warning("err is: %s" % row_err)
                        self.warning("the sql is: %s" % insert_stmt)
                        self.warning("the data is: %s" % str(data))
        except Exception as setup_err:
            # Raised while logging or while opening the session/cursor:
            # no usable cursor, so a per-row retry would only fail again.
            self.warning(setup_err)
        finally:
            if id_mapping_cursor is not None:
                id_mapping_cursor.close()
            if idmappingdb_sess is not None:
                idmappingdb_sess.close()

    def to_UpgradePara(self, stelnet_templates):
        self.get_resid_from_idmapping(stelnet_templates)
        # LocalNM表插入语句
        insert_stmt = self.generate_insert_stmt(self.dst_table, self.dst_table_cols, True)
        # t_ids表插入语句
        t_ids_cols = ("ID", "NATIVEID", "SOURCETAG", "CREATETIME")
        insert_t_ids_stmt = self.generate_insert_stmt("T_IDS", t_ids_cols)
        # NeIdAlloc 插入语句
        id_alloc_cols = ("uniqueId", "resId", "neId", "tenantId", "createAt", "updateAt")
        insert_id_alloc_stmt = self.generate_insert_stmt("NeIdAlloc", id_alloc_cols, True)
        self.debug("%s" % insert_stmt)
        self.debug("%s" % insert_t_ids_stmt)
        self.debug("%s" % insert_id_alloc_stmt)
        list_datas = []
        # 添加t_ids 和 NeIdAlloc表
        t_ids_data = []
        t_ne_id_alloc_data = []
        default_local_nm = []
        dst_cols_len = len(self.dst_table_cols)
        for tempalte in stelnet_templates:
            self.debug("original data is: %s, length is:%s" % (tempalte, len(tempalte)))
            data = self.convert_data(tempalte)
            self.debug("coverted data is: %s, length is:%s" % (data, len(data)))
            if len(data) == dst_cols_len:
                tuple_data = tuple(data)
                # 默认本机网管会改的字段只有三个：
                # name、remarks、ip address
                # 升级的时候需要执行update语句
                # 其他非默认网管网元，则执行insert语句
                if tuple_data[self.dst_table_cols.index("isLocalNM")] == str(False):
                    list_datas.append(tuple_data)
                    t_ids_data.append(
                        (data[self.dst_table_cols.index("resId")],
                         "NE=%s" % str(data[self.dst_table_cols.index("neId")]),
                         "NeResEntryService", str(int(time.time() * 1000))))
                    t_ne_id_alloc_data.append(self.trans_id_alloc_data(data))
                else:
                    default_local_nm.append(tuple_data)
            else:
                self.warning("coverted data length not equals dst table cols, "
                             "to be ignored. src data: %s, dst data: %s" % (tempalte, data))
            if len(list_datas) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, list_datas)
                list_datas = []
            if len(t_ne_id_alloc_data) == ONE_PACKAGE:
                self.exec_sql(insert_id_alloc_stmt, t_ne_id_alloc_data)
                t_ne_id_alloc_data = []
            if len(t_ids_data) == ONE_PACKAGE:
                self.insert_idmapping(insert_t_ids_stmt, t_ids_data)
                t_ids_data = []

        if len(list_datas) != 0:
            self.exec_sql(insert_stmt, list_datas)
        if len(t_ne_id_alloc_data) != 0:
            self.exec_sql(insert_id_alloc_stmt, t_ne_id_alloc_data)
        if len(t_ids_data) != 0:
            self.insert_idmapping(insert_t_ids_stmt, t_ids_data)

        if default_local_nm:
            cols = ("name", "remark", "ipAddr")
            set_col_val = "`%s` = ?" % ("` = ?, `".join(cols))
            update_stmt = "update `%s` set %s where `isLocalNM` = 'True'" % (self.dst_table, set_col_val)
            datas = list([tuple([x[self.dst_table_cols.index(col)] for col in cols]) for x in default_local_nm])
            self.debug(datas)
            self.exec_sql(update_stmt, datas)

    def close_session(self):
        self.dst_db_cursor.close()
        self.dst_db_session.close()
        self.src_db_cursor.close()
        self.src_db_session.close()

    def do(self):
        try:
            self.info('UpgradeLocalNM::do start!')
            select_stmt = 'select %s from %s' % (",".join(self.src_table_cols), self.src_table)
            templates = self.exec_query_sql(self.src_db_cursor, select_stmt)
            self.info('get para data success count: %d' % len(templates))

            self.to_UpgradePara(templates)
            self.info('to_UpgradeLocalNM done')
            self.close_session()
        except Exception as e:
            self.close_session()
            self.error('UpgradeLocalNM got exception')
            self.error(traceback.format_exc())
            return -1
        self.info('UpgradeLocalNM::do done')
        return 0


if __name__ == '__main__':
    # Script entry point: build the task with the default product name,
    # then run it between a start banner and a finish banner on stdout.
    upgrade_task = UpgradeLocalNM()
    print('[INFO] UpgradeLocalNM start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    upgrade_task.do()
    print('[INFO] UpgradeLocalNM finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
