#!/usr/bin/python
# -*- coding: utf-8 -*-
import re
import time
import traceback

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

# Number of converted rows accumulated before they are flushed to the
# destination table with a single exec_sql() call.
ONE_PACKAGE = 500


class UpgradeNeBasicInfo(BaseTask):
    def __init__(self, product_name="NCE"):
        """Set up table metadata, caches and the source/destination sessions.

        All attributes are assigned before any connection is attempted, so
        the instance is fully formed even when a session cannot be obtained
        (in which case the corresponding session/cursor attributes stay
        ``None`` and an error is logged).

        :param product_name: product name passed to ``set_product_name`` and
            to ``get_zenith_session``.
        """
        super(UpgradeNeBasicInfo, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeNeBasicInfo init product_name is %s" % product_name)

        # Source table (eamdb) and the columns read from it.
        self.src_table = "tbl_ne"
        self.src_table_cols = ("resId", "dn", "label", "productName", "typeId", "ipaddr", "memo", "parentResId")
        self.src_table_cols_index = {y: x for x, y in enumerate(self.src_table_cols)}
        # Destination table (neresdb) and the columns written to it.
        self.dst_table = "NeBasicInfo"
        self.dst_table_cols = ("resId", "tenantId", "neId", "name", "typeName", "typeId", "uniqueAddress",
                               "remark", "alias", "refParentSubnet", "manageStatus", "createAt", "updateAt")
        # Lookup caches filled by the get_*_mapping() methods.
        self.idmapping_dic = {}
        self.neid_uniqueid_dic = {}
        self.ne_one_dic = {}
        self.one_subnet = {}
        self.create_time_dic = {}
        self.localnm_create_time_dic = {}

        # Columns read from i_fixednetworkelement for NEs missing in tbl_ne.
        self.phy_cols = ("nativeId", "name", "productName", "typeId", "remark", "alias", "createTime")
        self.phy_cols_index = {y: x for x, y in enumerate(self.phy_cols)}

        # Connection handles default to None so that later code can tell
        # "never opened" apart from "opened".
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('eamdb', 'eamdb', product_name)
        if self.src_db_session is None:
            self.error("eamdb is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def get_resid_from_idmapping(self, paras):
        """Cache the resId -> alias mapping from topodb.tbl_node.

        The result is stored in ``self.idmapping_dic``. Keys are converted to
        ``str`` because the cache is looked up with ``str(resId)``; this keeps
        the lookup independent of the database column type.

        :param paras: unused; kept so existing callers keep working.
        :return: None
        """
        topodb_sess = get_zenith_session('topodb', 'topodb', self.product_name)
        if topodb_sess is None:
            self.error("get topodb session fail")
            return
        try:
            topodb_sess.autocommit(True)
            topodb_sess_cursor = topodb_sess.cursor()
            try:
                # Only fetch rows of the physical view (topoid=101) whose alias
                # is not null; this greatly reduces the amount of useless data.
                query_alias_stmt = "select resId, alias from tbl_node where topoid=101 and alias is not null"
                self.debug("exec sql: %s" % query_alias_stmt)
                topodb_sess_cursor.execute(query_alias_stmt)
                rows = topodb_sess_cursor.fetchall()
            finally:
                # Release the cursor even when the query fails.
                topodb_sess_cursor.close()
        finally:
            topodb_sess.close()

        self.idmapping_dic = {str(res_id): alias for res_id, alias in rows}
        self.debug("idmapping_dic: %s" % self.idmapping_dic)

    def get_id_uniqueid_mapping(self):
        """Cache the neId -> uniqueId mapping from neresdb.NeIdAlloc.

        The result is stored in ``self.neid_uniqueid_dic`` with both keys and
        values converted to ``str``.

        :return: None
        """
        neresdb_sess = get_zenith_session('neresdb', 'neresdb', self.product_name)
        if neresdb_sess is None:
            self.error("get neresdb session fail")
            return
        try:
            neresdb_sess.autocommit(True)
            neresdb_cursor = neresdb_sess.cursor()
            try:
                select_stmt = "select `neId`, `uniqueId` from `NeIdAlloc`"
                self.debug("exec sql: %s" % select_stmt)
                neresdb_cursor.execute(select_stmt)
                result = list(neresdb_cursor.fetchall())
            finally:
                # Release the cursor even when the query fails.
                neresdb_cursor.close()
        finally:
            neresdb_sess.close()

        # Force everything to str so the database column types do not matter.
        self.neid_uniqueid_dic = {str(ne_id): str(unique_id) for ne_id, unique_id in result}
        self.debug("neid_uniqueid_dic: %s" % self.neid_uniqueid_dic)

    def get_ne_one_mapping(self):
        """Cache the NE -> optical NE mapping and the optical NE -> subnet mapping.

        ``self.ne_one_dic`` maps ``str(cValue)`` to ``str(cDWDMNEID)`` as read
        from TransCommonDB. ``self.one_subnet`` maps the DN of each optical NE
        in tbl_ne, with its ``ONE=`` prefix removed, to that row's
        ``parentResId``.

        If the TransCommonDB session cannot be obtained, an error is logged
        and neither cache is touched.

        :return: None
        """
        trans_common_db_sess = get_zenith_session('TransCommonDB', 'TransCommonDB', self.product_name)
        if trans_common_db_sess is None:
            self.error("get TransCommonDB session fail")
            return
        try:
            trans_common_db_sess.autocommit(True)
            trans_common_db_cursor = trans_common_db_sess.cursor()
            try:
                select_stmt = "select f.cValue, t.cDWDMNEID from tTEDWDMNE t inner join tTEUnsignedLong f on t.cID = f.PID"
                result = self.exec_query_sql(trans_common_db_cursor, select_stmt)
                # Force everything to str so the database column types do not matter.
                self.ne_one_dic = {str(ne_id): str(one_id) for ne_id, one_id in result}
            finally:
                # Release the cursor even when the query or conversion fails.
                trans_common_db_cursor.close()
        finally:
            trans_common_db_sess.close()
        self.debug("ne_one_dic: %s" % self.ne_one_dic)

        # The already-open source cursor is sufficient for this query.
        sql = "select DN, parentResId from tbl_ne where typeId in " \
              "(123863040,123928576,123994112,124059648,124387328)"
        ret = self.exec_query_sql(self.src_db_cursor, sql)
        prefix = "ONE="
        for dn, parent_resid in ret:
            # Remove the literal prefix only. str.lstrip("ONE=") would strip any
            # leading run of the characters 'O', 'N', 'E', '=' instead.
            one_id = dn[len(prefix):] if dn.startswith(prefix) else dn
            self.one_subnet[one_id] = parent_resid
        self.debug("one_subnet: %s" % self.one_subnet)

    def get_create_time_mapping(self):
        """Cache the nativeId -> createTime mapping from i_fixednetworkelement.

        The result is stored in ``self.create_time_dic`` with both keys and
        values converted to ``str``.

        :return: None
        """
        db_sess = get_zenith_session('cmdbcoresvrdb', 'cmdbcoresvrdb', self.product_name)
        if db_sess is None:
            self.error("get cmdbcoresvrdb session fail")
            return
        try:
            db_sess.autocommit(True)
            db_cursor = db_sess.cursor()
            try:
                select_stmt = "select `nativeId`, `createTime` from i_fixednetworkelement"
                # Sample output:
                #
                # nativeId                                                         createTime
                # ---------------------------------------------------------------- --------------------
                # NE=167772161                                                     1644801354000
                self.debug("exec sql: %s" % select_stmt)
                db_cursor.execute(select_stmt)
                result = list(db_cursor.fetchall())
            finally:
                # Release the cursor even when the query fails.
                db_cursor.close()
        finally:
            db_sess.close()

        # Force everything to str so the database column types do not matter.
        self.create_time_dic = {str(native_id): str(create_time) for native_id, create_time in result}
        self.debug("create_time_dic: %s" % self.create_time_dic)

    def get_localnm_create_time_mapping(self):
        """Cache the neId -> createAt mapping from neresdb.LocalNM.

        The result is stored in ``self.localnm_create_time_dic`` with both
        keys and values converted to ``str``.

        :return: None
        """
        db_sess = get_zenith_session('neresdb', 'neresdb', self.product_name)
        if db_sess is None:
            # The session requested above is neresdb, so report that name.
            self.error("get neresdb session fail")
            return
        try:
            db_sess.autocommit(True)
            db_cursor = db_sess.cursor()
            try:
                select_stmt = "select `neId`, `createAt` from `LocalNM`"
                # Sample output (createAt may be -1):
                #
                # neId                 createAt
                # -------------------- --------------------
                # 4161536              1647925719000
                # 167794948            1648104840429
                # 1                    -1
                self.debug("exec sql: %s" % select_stmt)
                db_cursor.execute(select_stmt)
                result = list(db_cursor.fetchall())
            finally:
                # Release the cursor even when the query fails.
                db_cursor.close()
        finally:
            db_sess.close()

        # Force everything to str so the database column types do not matter.
        self.localnm_create_time_dic = {str(ne_id): str(create_at) for ne_id, create_at in result}
        self.debug("localnm_create_time_dic: %s" % self.localnm_create_time_dic)


    def convert_data(self, paras):
        def covert_ip_addr(c_uniqueId):
            if re.compile(r"([\d]{1,3}\.[\d]{1,3}\.[\d]{1,3}\.[\d]{1,3})").match(c_uniqueId):
                # 能匹配上，就是IPv4
                return ("ipv4://" + c_uniqueId).replace("_", "/")
            else:
                # 不是IPv4就是IPv6,因为IPv6的格式比较灵活，所以只要不是IPv4的就当做IPv6处理
                return ("ipv6://"+c_uniqueId).replace("_", "/")

        resId = paras[self.src_table_cols_index.get("resId")]
        tenantId = "default-organization-id"
        neId = paras[self.src_table_cols_index.get("dn")].replace("NE=","").replace("OS=","")
        if neId and not str(neId).isdigit():
            self.debug("neId is not number, data is ignored: %s" % neId)
            return []
        name = paras[self.src_table_cols_index.get("label")]
        typeName = paras[self.src_table_cols_index.get("productName")]
        typeId = paras[self.src_table_cols_index.get("typeId")]
        uniqueAddress = self.neid_uniqueid_dic.get(str(neId))
        self.debug("uniqueAddress: %s" % uniqueAddress)
        if not uniqueAddress:
            ipaddr = paras[self.src_table_cols_index.get("ipaddr")]
            if ipaddr is None:
                # 特殊情况下，ipaddr可能是None值
                uniqueAddress = None
            else:
                # 这里确保ipaddr一定是字符串类型
                ipaddr = str(ipaddr)
                if re.compile(r"([\d]+\-[\d]+)").match(ipaddr):
                    # 如果是xx-yy的格式，则是网元ID
                    # 物理ID：9-9528
                    # 转化成：phyid://xx-yy
                    # Done: 方案刷新，直接拼接xx-yy格式的网元ID
                    uniqueAddress = "phyid://%s" % ipaddr
                else:
                    # IP的地址格式跟FAN的处理过程一样
                    uniqueAddress = covert_ip_addr(ipaddr)
                    # 对于dummy设备，在uniqueAddress后加_dummy_resId
                    if typeId == "51707905":
                        uniqueAddress = "%s_dummy_%s" % (uniqueAddress, resId)

        remark = paras[self.src_table_cols_index.get("memo")]
        alias = self.idmapping_dic.get(str(resId))
        if self.ne_one_dic.get(neId):
            # 从eamdb.tbl_ne表查询所有带升级的网元信息，拿网元neId到缓存在内存中的ne_one.cValue匹配
            # 若匹配成功，代表网元挂在光网元内
            # 拿ne_one.cDWDMNEID拼上ONE=，到one_subnet中匹配，找到对应的所属子网IDparentResId，赋值
            refParentSubnet = self.one_subnet.get(self.ne_one_dic.get(neId))
        else:
            # 若匹配不成功，则代表网元未挂接在光网元内，直接用网元的parentResId赋值
            refParentSubnet = paras[self.src_table_cols_index.get("parentResId")]
        manageStatus = "0"
        createTime = str(self.create_time_dic.get(paras[self.src_table_cols_index.get("dn")],
                                                  self.localnm_create_time_dic.get(neId, 0)))
        if createTime.isdigit():
            # 这里还涉及到createTime与createAt之间的单位转换，createAt的单位是ms，createTime的单位待确定
            createAt = createTime
        else:
            createAt = str(0)
        updateAt = str(int(time.time() * 1000))

        return tuple(None if x is None else str(x) for x in (
            resId, tenantId, neId, name, typeName, typeId, uniqueAddress, remark, alias,refParentSubnet, manageStatus,
            createAt, updateAt))

    def get_uniqueAddress_res_id(self, neId):
        """Look up uniqueId and resId for one NE in neresdb.NeIdAlloc.

        :param neId: NE id used in the ``where`` clause. It is interpolated
            into the SQL text as-is; the callers in this class only pass
            values that passed their digit check.
        :return: ``(uniqueAddress, resId)`` taken from the first matching row,
            or ``(None, None)`` when no row matches or the neresdb session
            cannot be obtained. The result is always a 2-tuple so callers can
            unpack it unconditionally.
        """
        neresdb_sess = get_zenith_session('neresdb', 'neresdb', self.product_name)
        if neresdb_sess is None:
            self.error("get neresdb session fail")
            return None, None
        try:
            neresdb_sess.autocommit(True)
            neresdb_cursor = neresdb_sess.cursor()
            try:
                select_stmt = "select `neId`, `uniqueId`, `resId` from `NeIdAlloc` where `neId` = '%s'" % str(neId)
                self.debug("exec sql: %s" % select_stmt)
                neresdb_cursor.execute(select_stmt)
                result = list(neresdb_cursor.fetchall())
            finally:
                # Release the cursor on every path, including "no rows".
                neresdb_cursor.close()
        finally:
            neresdb_sess.close()

        if not result:
            return None, None
        uniqueAddress = result[0][1]
        resId = result[0][2]
        self.debug("one neid get NeIdAlloc, uniqueAddress is %s, resId is %s" % (uniqueAddress, resId))
        return uniqueAddress, resId

    def convert_special_data(self, paras):
        """
        转换i_fixednetworkelement的数据
        :param paras:
        :return:
        """
        neId = paras[self.phy_cols_index.get("nativeId")].replace("NE=", "").replace("OS=", "")
        if neId and not str(neId).isdigit():
            self.debug("neId is not number, data is ignored: %s" % neId)
            return []
        uniqueAddress, resId = self.get_uniqueAddress_res_id(str(neId))
        resId = resId
        tenantId = "default-organization-id"
        name = paras[self.phy_cols_index.get("name")]
        typeName = paras[self.phy_cols_index.get("productName")]
        typeId = paras[self.phy_cols_index.get("typeId")]
        # 通过neId从NeIdAlloc表中查询uniqueId
        uniqueAddress = uniqueAddress
        remark = paras[self.phy_cols_index.get("remark")]
        alias = paras[self.phy_cols_index.get("alias")]
        if self.ne_one_dic.get(str(neId)):
            # 从eamdb.tbl_ne表查询所有带升级的网元信息，拿网元neId到缓存在内存中的ne_one.cValue匹配
            # 若匹配成功，代表网元挂在光网元内
            # 拿ne_one.cDWDMNEID拼上ONE=，到one_subnet中匹配，找到对应的所属子网IDparentResId，赋值
            refParentSubnet = self.one_subnet.get(self.ne_one_dic.get(str(neId)))
        else:
            self.error("convert_special_data error: %s not exist in ne_one_dic" % name)
            return []
        manageStatus = "0"
        createTime = paras[self.phy_cols_index.get("createTime")]
        if str(createTime).isdigit():
            # 这里还涉及到createTime与createAt之间的单位转换，createAt的单位是ms，createTime的单位待确定
            createAt = createTime
        else:
            createAt = str(0)
        updateAt = str(int(time.time() * 1000))
        return tuple(None if x is None else str(x) for x in (
            resId, tenantId, neId, name, typeName, typeId, uniqueAddress, remark, alias, refParentSubnet, manageStatus,
            createAt, updateAt))

    def to_UpgradePara(self, datas):
        """Convert tbl_ne rows and insert them into the destination table.

        The lookup caches are loaded first. Failures while loading the
        optical-NE mapping or the create-time mappings are logged as warnings
        and ignored; the rows are then converted with ``convert_data`` and
        written with ``exec_sql`` in batches of ``ONE_PACKAGE``.

        :param datas: iterable of rows ordered like ``self.src_table_cols``.
        :return: None
        """
        self.get_resid_from_idmapping(datas)
        self.get_id_uniqueid_mapping()
        # When the upgraded environment has no TransCommonDB (no transport
        # data), the mapping is simply skipped. Catch Exception rather than
        # BaseException so KeyboardInterrupt/SystemExit still propagate.
        try:
            self.get_ne_one_mapping()
        except Exception as exc:
            self.warning("get_ne_one_mapping exception: %s" % exc)

        try:
            self.get_create_time_mapping()
            self.get_localnm_create_time_mapping()
        except Exception as exc:
            self.warning("get_create_time_mapping exception: %s" % exc)

        col_names = "`" + ("`, `".join(self.dst_table_cols)) + "`"
        val_ids = ":" + (",:".join((str(x + 1) for x in range(len(self.dst_table_cols)))))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("insert sql stmt: %s" % insert_stmt)
        list_datas = []
        for data in datas:
            self.debug("original data is: %s, length is:%s" % (data, len(data)))
            data = self.convert_data(data)
            self.debug("coverted date is: %s, length is:%s" % (data, len(data)))
            # convert_data returns an empty list for ignored rows; only rows
            # with the full column count are queued.
            if len(data) == len(self.dst_table_cols):
                list_datas.append(tuple(data))
            if len(list_datas) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, list_datas)
                list_datas = []

        # Flush the last, partially filled batch.
        if len(list_datas) != 0:
            self.debug("data is:%s" % list_datas)
            self.exec_sql(insert_stmt, list_datas)

    def to_UpgradeSpecialPara(self, datas):
        """Insert NEs that sit under an optical NE but are missing from tbl_ne.

        Each row is converted with ``convert_special_data``; rows with the
        full destination column count are written with ``exec_sql`` in
        batches of ``ONE_PACKAGE``.

        :param datas: rows ordered like ``self.phy_cols``; nothing is done
            when it is empty.
        :return: None
        """
        if not datas:
            self.info("to_UpgradeSpecialPara no data")
            return

        quoted_columns = "`%s`" % "`, `".join(self.dst_table_cols)
        placeholders = ":" + ",:".join(str(pos) for pos in range(1, len(self.dst_table_cols) + 1))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, quoted_columns, placeholders)
        self.debug("insert sql stmt: %s" % insert_stmt)

        pending = []
        for raw_row in datas:
            self.debug("original special data is: %s, length is:%s" % (raw_row, len(raw_row)))
            converted = self.convert_special_data(raw_row)
            self.debug("coverted special date is: %s, length is:%s" % (converted, len(converted)))
            # Ignored rows come back as an empty list and are not queued.
            if len(converted) == len(self.dst_table_cols):
                pending.append(tuple(converted))
            if len(pending) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, pending)
                pending = []

        # Write whatever is left after the last full batch.
        if pending:
            self.debug("data is:%s" % pending)
            self.exec_sql(insert_stmt, pending)

    def split_list(self, data, batch_no=500):
        result = list()
        if len(data) <= batch_no:
            result.append(data)
            return result
        count = int(len(data) / batch_no) + 1
        for x in range(count):
            if x == count - 1:
                result.append(data[x * batch_no:])
                continue
            result.append(data[x * batch_no:(x + 1) * batch_no])
        return result

    def get_ne_by_neid(self, one_not_existed):
        """
        根据网元id  NE='' 从i_fixednetworkelement表中查询网元信息
        :param one_not_existed:
        :return:
        """
        res = list()
        if not one_not_existed:
            self.info("one_not_existed is null")
            return res
        db_sess = get_zenith_session('cmdbcoresvrdb', 'cmdbcoresvrdb', self.product_name)
        if db_sess is None:
            self.error("get cmdbcoresvrdb session fail")
            return
        db_sess.autocommit(True)
        db_cursor = db_sess.cursor()
        for x in self.split_list(one_not_existed):
            select_stmt = "select %s from i_fixednetworkelement where `nativeId` in (%s)" % (
                "`" + "`,`".join(self.phy_cols) + "`", "'" + "','".join(x) + "'")
            self.debug("exec sql: %s" % select_stmt)
            db_cursor.execute(select_stmt)
            result = list(db_cursor.fetchall())
            res.extend(result)
        return res

    def get_special_ne_from_phy(self):
        """
        从 i_fixednetworkelement 表中获取光网元下的网元数据
        :return:
        """
        if not self.ne_one_dic:
            return list()
        one_ne_ids = ["NE=" + str(x) for x in self.ne_one_dic.keys()]
        one_ne_id_batch = self.split_list(one_ne_ids)
        one_existed = list()
        for i in one_ne_id_batch:
            select_stmt = "select dn from %s where dn in (%s)" % (self.src_table, "'" + "','".join(i) + "'")
            datas = self.exec_query_sql(self.src_db_cursor, select_stmt)
            datas = [str(i[0]) for i in datas]
            one_existed.extend(datas)
        one_not_existed = [i for i in one_ne_ids if i not in one_existed]
        return self.get_ne_by_neid(one_not_existed)

    def close_session(self):
        self.dst_db_cursor.close()
        self.dst_db_session.close()
        self.src_db_cursor.close()
        self.src_db_session.close()

    def do(self):
        try:
            self.info ('UpgradeNeBasicInfo::do start!')
            # 查询neresdb数据库DevTypeRegInfo表，匹配typeId字段，能够匹配成功的执行升级；
            # 匹配不上的代表非纳管范围内设备类型，不参与升级
            # 问题修改：不做typeId在DevTypeRegInfo表中的限制，因为业务微服务在升级时没有启动，没有注册到DevTypeRegInfo。
            # 若升级前Dummy设备与华为设备存在IP地址冲突：
            # （1）若升级后华为设备丢失，需要手工删除Dummy设备，使用指定IP手工添加华为设备；
            # （2）若升级后Dummy设备丢失，需要使用其他IP地址创建Dummy设备
            # 优先升级非Dummy类型的设备，再升级Dummy设备
            dummy_type_id = "51707905"
            # 光网元不参与升级
            #附：光网元类型列表
            # 类型ID	类型名称	策略
            # 123863040	WDM_IdleONE	不参与升级
            # 123928576	WDM_OTM	不参与升级
            # 123994112	WDM_OADM	不参与升级
            # 124059648	WDM_OLA	不参与升级
            # 124387328	WDM_OEQ	不参与升级
            one_ids = ["123863040", "123928576", "123994112", "124059648", "124387328"]
            select_stmt_not_dummy = "select %s from %s where typeId not in (%s)" % (
                ",".join(self.src_table_cols), self.src_table, ",".join(one_ids + [dummy_type_id]))
            select_stmt_dummy = "select %s from %s where typeId='%s'" % (",".join(self.src_table_cols), self.src_table,
                                                                    dummy_type_id)
            # 确保Dummy设备在后面，这样入库时如果有冲突，则优先丢弃dummy设备；
            filtered_datas = []
            for select_stmt in (select_stmt_not_dummy, select_stmt_dummy):
                datas = self.exec_query_sql(self.src_db_cursor, select_stmt)
                # 为了提升效率，在查询得到结果之后，再用代码过滤掉不在范围内的数据：
                # eamdb.tbl_ne表中DN为OSS和OS=1的设备不升级，这两个设备为平台默认设备（本机网管设备），不参与升级
                # eamdb.tbl_ne表中DN为ONE开头的设备为光网元，由于纳管不管理光网元，故不参与升级
                for data in datas:
                    dn = data[self.src_table_cols_index.get("dn")]
                    if dn in ("OSS", "OS=1") or dn.startswith("ONE"):
                        continue
                    else:
                        filtered_datas.append(data)

            datas = list(filtered_datas)
            self.info('get para data success count: %d' % len(datas))

            # 升级tbl_ne中的网元
            self.to_UpgradePara(datas)
            # 【背景】传送历史版本缺陷，直接在光网元内创建网元，不会入平台拓扑，在R20C10补丁版本后续才解决。
            # 升级光网元下的网元，且不入tbl_ne表的网元数据
            self.info('upgrade one ne::do start!')
            special_nes = self.get_special_ne_from_phy()
            self.info('get ne not in tbl_ne data success count: %d' % len(special_nes))
            self.to_UpgradeSpecialPara(special_nes)

            self.info('to_UpgradePara done')
            self.close_session()
        except Exception as e:
            self.close_session()
            self.error('UpgradeNeBasicInfo got exception')
            self.error(traceback.format_exc())
            return -1
        self.info('UpgradeNeBasicInfo::do done')
        return 0


if __name__ == '__main__':
    # Script entry point: build the task with its default product name, then
    # run it between two console banners. The return code of do() is not
    # used here.
    upgrade_task = UpgradeNeBasicInfo()
    print('[INFO] UpgradeNeBasicInfo start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    upgrade_task.do()
    print('[INFO] UpgradeNeBasicInfo finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
