#!/usr/bin/python
# -*- coding: utf-8 -*-
import re
import traceback
import time

from getDBConnection import get_zenith_session

import common_tasks.const_sql as const_sql
from common_tasks.base_task import BaseTask

# Number of rows accumulated before a batch is flushed to the database
# (used as the executemany/exec_sql batch size throughout this module).
ONE_PACKAGE = 500


class UpgradeNeIdAlloc(BaseTask):
    def __init__(self, product_name="NCE"):
        """Open the source/destination DB sessions and set up conversion state.

        :param product_name: product name passed to ``set_product_name`` and to
            ``get_zenith_session`` when opening the ISDB (source) and neresdb
            (destination) sessions.

        If a session cannot be opened an error is logged and construction
        stops early; the session/cursor attributes that were not opened stay
        ``None``.
        """
        super(UpgradeNeIdAlloc, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeNeIdAlloc init product_name is %s" % product_name)

        # Define every attribute before touching the database so that an early
        # return below never leaves the instance half-initialised (later
        # attribute reads would otherwise raise AttributeError).
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_table = "t_neam_neid"
        self.src_table_cols = ("c_uniqueid", "c_neid")
        # column name -> position of that column in a source row
        self.src_table_cols_index = {y: x for x, y in enumerate(self.src_table_cols)}
        self.dst_table = "NeIdAlloc"
        self.dst_table_cols = ("uniqueId", "resId", "neId", "createAt", "updateAt")
        # NE id (as string) -> resId, filled by get_resid_from_idmapping
        self.idmapping_dic = {}
        # Bookkeeping for converted unique ids that occur more than once, and
        # the mapping between converted and original unique ids.
        self.unique_id_repeat = set()          # converted ids seen at least twice
        self.unique_id_dict = {}               # converted id -> list of original ids
        self.old_unique_id_resid_dict = {}     # original id -> resId
        self.unique_id_records = set()         # every converted id seen so far

        self.src_db_session = get_zenith_session('ISDB', 'ISDB', product_name)
        if self.src_db_session is None:
            self.error("ISDB is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def get_resid_from_idmapping(self, paras):
        """Load the NE id -> resId mapping for ``paras`` into ``self.idmapping_dic``.

        For every source row a native id is built (``ONE=<c_neid>`` when the
        row's ``c_uniqueid`` ends with ``_transone``, otherwise
        ``NE=<c_neid>``), written in batches of ``ONE_PACKAGE`` into a scratch
        table in idmappingdb, and resolved with the
        ``const_sql.INNER_JOIN_TEMP_TABLE`` query. For each result row, column
        0 is stored in ``self.idmapping_dic`` under column 1 with its
        ``NE=`` / ``ONE=`` prefix removed.

        :param paras: iterable of source rows ordered like ``self.src_table_cols``
        :return: None; if the idmappingdb session cannot be opened an error is
            logged and ``self.idmapping_dic`` is left untouched.
        """
        idmappingdb_sess = get_zenith_session('idmappingdb', 'idmappingdb', self.product_name)
        if idmappingdb_sess is None:
            self.error("get idmappingdb_sess session fail")
            return

        uniqueid_idx = self.src_table_cols_index.get("c_uniqueid")
        neid_idx = self.src_table_cols_index.get("c_neid")
        tmp_table_name = "tmp_neid_%s" % self.dst_table
        insert_stmt = "insert into tmp_%s values(:1)" % tmp_table_name

        id_mapping_cursor = None
        try:
            idmappingdb_sess.autocommit(True)
            id_mapping_cursor = idmappingdb_sess.cursor()
            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
            id_mapping_cursor.execute(const_sql.CREATE_TEMP_TABLE % tmp_table_name)
            try:
                nativeIds = []
                for para in paras:
                    if str(para[uniqueid_idx]).endswith("_transone"):
                        # Optical NE: the native id has the form ONE=xxxx
                        nativeId = "ONE=%s" % para[neid_idx]
                    else:
                        nativeId = "NE=%s" % para[neid_idx]
                    nativeIds.append((nativeId,))
                    if len(nativeIds) == ONE_PACKAGE:
                        id_mapping_cursor.executemany(insert_stmt, nativeIds)
                        self.debug("one package: %s" % nativeIds)
                        nativeIds = []

                if nativeIds:
                    id_mapping_cursor.executemany(insert_stmt, nativeIds)
                    self.debug("last package datas: %s" % nativeIds)

                query_stmt = const_sql.INNER_JOIN_TEMP_TABLE % tmp_table_name
                self.debug('[get_resid_from_idmapping]query sql: %s' % query_stmt)
                id_mapping_cursor.execute(query_stmt)

                for row in id_mapping_cursor.fetchall():
                    native_id = row[1]
                    # Remove exactly one leading "ONE=" / "NE=" prefix.
                    # str.lstrip() must not be used here: it strips a *set of
                    # characters*, so it could also eat leading 'O', 'N', 'E'
                    # or '=' characters that belong to the id itself.
                    for prefix in ("ONE=", "NE="):
                        if native_id.startswith(prefix):
                            native_id = native_id[len(prefix):]
                            break
                    self.idmapping_dic[native_id] = row[0]
                    self.debug("result: %s: %s" % (row[1], row[0]))
            finally:
                # Drop the scratch table even when inserting/querying failed.
                id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
        finally:
            # Always release the cursor and the session, whatever happened.
            if id_mapping_cursor is not None:
                id_mapping_cursor.close()
            idmappingdb_sess.close()

    def convert_data(self, paras):
        def covert_ip_addr(prefix, c_uniqueId):
            if re.compile(r"%s([\d]{1,3}\.[\d]{1,3}\.[\d]{1,3}\.[\d]{1,3})" % prefix).match(c_uniqueId):
                # 能匹配上，就是IPv4
                return c_uniqueId.replace(prefix, "ipv4://").replace("_", "/")
            else:
                # 不是IPv4就是IPv6,因为IPv6的格式比较灵活，所以只要不是IPv4的就当做IPv6处理
                return c_uniqueId.replace(prefix, "ipv6://").replace("_", "/")

        c_uniqueId = paras[self.src_table_cols_index.get("c_uniqueid")]
        # c_uniqueId的值可能是None，需要做None值判断
        if c_uniqueId is None:
            uniqueId = c_uniqueId
        else:
            # 这里确保c_uniqueId一定是字符串类型
            c_uniqueId = str(c_uniqueId)
            if c_uniqueId.startswith("access_"):
                #  FAN的地址格式：
                # 1）access_10.1.1.1
                # 2）access_10.1.1.1_0.2.1.14.12（ONU无IP地址场景）
                # 需要转化成
                # 1）ipv4://10.1.1.1; ipv4://10.1.1.1/0.2.1.14.12（ONU无IP地址场景）
                # 2）ipv6://xxxx.xxxx.xxxx.xxxx.xxxx.xxxx.xxxx.xxxx
                uniqueId = covert_ip_addr("access_", c_uniqueId)
            elif c_uniqueId.startswith("bits_"):
                # BITS的地址格式的处理过程一样
                uniqueId = covert_ip_addr("bits_", c_uniqueId)
            elif c_uniqueId.startswith("ip_"):
                # IP的地址格式跟FAN的处理过程一样
                uniqueId = covert_ip_addr("ip_", c_uniqueId)
            elif c_uniqueId.startswith("3rd_"):
                # 3rd的地址格式跟FAN的处理过程一样
                uniqueId = covert_ip_addr("3rd_", c_uniqueId)
            elif c_uniqueId.startswith("em3rd_ip_"):
                # 3rd的地址格式跟FAN的处理过程一样
                uniqueId = covert_ip_addr("em3rd_ip_", c_uniqueId)
            elif c_uniqueId.endswith("_transqx") or c_uniqueId.endswith("_ptnv8"):
                # 物理ID：8978553_transqx，12345678_ptnv8
                # 转化成：phyid://8978553
                # Done： 物理ID：phyid://24-166（计算方法：1573030/65536-1573030%65536）
                # 不管后缀是_transqx还是_ptnv8，全都去除掉
                s_uniqueId = c_uniqueId.replace("_transqx", "").replace("_ptnv8", "")
                if s_uniqueId.isdigit():
                    i_uniqueId = int(s_uniqueId)
                    uniqueId = "phyid://%s-%s" % (int(i_uniqueId / 65536), i_uniqueId % 65536)
                else:
                    self.warning("c_uniqueid invalid: %s" % c_uniqueId)
                    uniqueId = c_uniqueId
            elif c_uniqueId.endswith("_transtl1"):
                # 物理ID：8978553_transtl1
                # 转化成：tl1://8978553
                # 不需要计算
                # 后缀是_transtl1，去掉后缀，改为tl1://xxxx的格式
                uniqueId = "tl1://%s" % c_uniqueId.replace("_transtl1", "")
            elif c_uniqueId.endswith("_transone"):
                # 物理ID：global_8978553_transone
                # 转化成：one://global_8978553_transone
                # 不需要计算
                # _transone，改为one://xxxx的格式
                uniqueId = "one://%s" % c_uniqueId
            else:
                uniqueId = c_uniqueId

        self.unique_id_repeat.add(uniqueId) if uniqueId in self.unique_id_records else self.unique_id_records.add(
            uniqueId)
        if uniqueId not in self.unique_id_dict.keys():
            self.unique_id_dict[uniqueId] = list()
        self.unique_id_dict[uniqueId].append(c_uniqueId)
        resId = self.idmapping_dic.get(str(paras[self.src_table_cols_index.get("c_neid")]))
        self.old_unique_id_resid_dict[c_uniqueId] = resId

        neId = paras[self.src_table_cols_index.get("c_neid")]

        createAt = str(int(time.time() * 1000))
        updateAt = str(int(time.time() * 1000))
        return tuple(None if x is None else str(x) for x in (uniqueId, resId, neId, createAt, updateAt))

    def batch_insert_data(self, data, batch_num=ONE_PACKAGE):
        """
        分批插入数据库
        :param data:
        :return:
        """
        list_datas = []
        col_names = "`" + ("`, `".join(self.dst_table_cols)) + "`"
        val_ids = ":" + (",:".join((str(x + 1) for x in range(len(self.dst_table_cols)))))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("%s" % insert_stmt)
        self.debug("UpgradeNeIdAlloc batch insert data %s" % len(data))
        for para in data:
            list_datas.append(para)
            if len(list_datas) == batch_num:
                self.exec_sql(insert_stmt, list_datas)
                list_datas = []
        if len(list_datas) != 0:
            self.exec_sql(insert_stmt, list_datas)
        list_datas.clear()

    def get_neinfo_by_repeat_id(self, c_uniqueids=None):
        """
        重复的ip需要从t_neam_neinfo表中判断保留哪条
        原则：升级过程中，可能存在uniqueAddress重复情况，例如3rd_11.11.11.11和ip_11.11.11.11，
        如果发现重复，查询neinfo表，以存在的网元为主，
        如果存在多个，优先丢弃3rd的情况。
        :param c_uniqueids:
        :return:
        """
        neinfo_cols = ("c_uniqueid", "c_neid")
        neinfo_table = "t_neam_neinfo"
        select_stmt = "select %s from %s where c_uniqueid in (%s)" % (",".join(neinfo_cols), neinfo_table,
                                                                      "'" + "','".join(c_uniqueids) + "'")
        datas = self.exec_query_sql(self.src_db_cursor, select_stmt)
        self.info('get %s data success count: %d' % (neinfo_table, len(datas)))
        neinfo_data_dict = {i[0]: i[1] for i in datas}
        self.info("neam_neinfo repeat data: %s" % str(neinfo_data_dict))
        return neinfo_data_dict
        pass

    def handle_repeat_unique_ip(self, list_datas):
        """
        对重复的数据处理，按照如下原则做处理
        原则：升级过程中，可能存在uniqueAddress重复情况，例如3rd_11.11.11.11和ip_11.11.11.11，
        如果发现重复，查询neinfo表，以存在的网元为主，
        如果存在多个，优先丢弃3rd的情况。
        :param list_datas:
        :return:
        """
        repeat_c_unique_id = list()
        [repeat_c_unique_id.extend(self.unique_id_dict.get(i)) for i in self.unique_id_repeat]
        neinfo_data = self.get_neinfo_by_repeat_id(repeat_c_unique_id)
        insert_data = list()
        repeat_data = dict()
        col_names = "`" + ("`, `".join(self.dst_table_cols)) + "`"
        val_ids = ":" + (",:".join((str(x + 1) for x in range(len(self.dst_table_cols)))))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)

        for data in list_datas:
            # data字段顺序uniqueId, resId, neId, createAt, updateAt
            if data[0] not in self.unique_id_repeat:
                insert_data.append(data)
            else:
                if data[0] not in repeat_data.keys():
                    repeat_data[data[0]] = list()
                repeat_data[data[0]].append(data)
            if len(insert_data) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, insert_data)
                insert_data = []

        if len(insert_data) != 0:
            self.exec_sql(insert_stmt, insert_data)
            insert_data = []

        for key, val in repeat_data.items():
            self.info("unique_id %s has repeat data %s" % (key, str(val)))
            old_unique_ids = self.unique_id_dict.get(key)
            # 如果发现重复，查询neinfo表，以存在的网元为主，如果存在多个，优先丢弃3rd的情况
            first_final_unique_id = None
            second_final_unique_id = None
            for old_unique_id in old_unique_ids:
                if neinfo_data.get(old_unique_id):
                    second_final_unique_id = old_unique_id
                if neinfo_data.get(old_unique_id) and not str(old_unique_id).startswith("3rd_"):
                    first_final_unique_id = old_unique_id
            final_unique_id = first_final_unique_id or second_final_unique_id or old_unique_ids[0]
            final_res_id = self.old_unique_id_resid_dict.get(final_unique_id)
            final_data = list(filter(lambda x: x[1] == final_res_id, val))
            self.info("unique_id %s final insert data %s" % (key, str(final_data)))
            if not final_data:
                continue
            tuple_data = tuple(final_data[0])
            insert_data.append(tuple_data)
            if len(insert_data) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, insert_data)
                insert_data = []
        if len(insert_data) != 0:
            self.exec_sql(insert_stmt, insert_data)
        insert_data.clear()

    def to_UpgradePara(self, datas):
        self.get_resid_from_idmapping(datas)

        list_datas = []
        dst_cols_len = len(self.dst_table_cols)

        for data in datas:
            self.debug("original data is: %s, length is:%s" % (data, len(data)))
            data = self.convert_data(data)
            self.debug("coverted data is: %s, length is:%s" % (data, len(data)))
            if len(data) == dst_cols_len:
                tuple_data = tuple(data)
                list_datas.append(tuple_data)
            else:
                self.warning("coverted data length not equals dst table cols, to be ignored.")

        datas.clear()

        self.info("unique_id_repeat data:%s" % str(self.unique_id_repeat))
        if not self.unique_id_repeat:
            self.batch_insert_data(list_datas)
        else:
            self.handle_repeat_unique_ip(list_datas)

    def close_session(self):
        self.dst_db_cursor.close()
        self.dst_db_session.close()
        self.src_db_cursor.close()
        self.src_db_session.close()

    def do(self):
        try:
            self.info('UpgradeNeidAlloc::do start!')
            select_stmt = 'select %s from %s' % (",".join(self.src_table_cols), self.src_table)
            datas = self.exec_query_sql(self.src_db_cursor, select_stmt)
            self.info('get para data success count: %d' % len(datas))

            self.to_UpgradePara(datas)
            self.info('to_UpgradePara done')
            self.close_session()
        except Exception as e:
            self.close_session()
            self.error('UpgradeNeIdAlloc got exception')
            self.error(traceback.format_exc())
            return -1
        self.info('UpgradeNeIdAlloc::do done')
        return 0



if __name__ == '__main__':
    # Script entry point: build the task with its default product name and
    # run it, printing a banner on stdout before and after the run.
    upgrade_task = UpgradeNeIdAlloc()
    print('[INFO] UpgradeNeIdAlloc start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    upgrade_task.do()
    print('[INFO] UpgradeNeIdAlloc finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
