#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import traceback
import uuid

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

# Number of converted DevTypeRegInfo rows collected before to_UpgradePara
# hands them to exec_sql as one batched INSERT.
ONE_PACKAGE = 500


class UpgradeDevTypeRegInfo(BaseTask):
    """Back-fill DevTypeRegInfo (neresdb) with NE types from resourcetype (invmetadatadb).

    Network-element type ids present in the source ``resourcetype`` table but
    missing from the destination ``DevTypeRegInfo`` table are read together with
    their ``resourcetype_feature`` key/value pairs. They are then filtered to the
    supported NCE-FAN / NCE-IP type ranges, expanded into one row per protocol,
    and inserted in batches of ``ONE_PACKAGE`` rows.
    """

    # Type-number ranges, inclusive on both ends. The type number is typeId // 65536.
    _FAN_TYPE_RANGES = ((1, 85), (2304, 2500), (2502, 2559))
    # Third-party IP NE manager types; these additionally need NETCONF.
    _IP_THIRD_PARTY_TYPE_RANGES = ((1520, 1529),)
    _IP_TYPE_RANGES = ((167, 173), (1237, 1397), (2672, 2785)) + _IP_THIRD_PARTY_TYPE_RANGES

    def __init__(self, product_name="NCE"):
        super(UpgradeDevTypeRegInfo, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeDevTypeRegInfo init product_name is %s" % product_name)

        # Table metadata does not depend on the DB connections. Setting it up
        # first keeps the object fully formed even if a connection fails below.
        self.src_table = "resourcetype"
        # Only typeId is a real column of the source table. The other names are
        # featureName keys read from the joined resourcetype_feature sub-table.
        # Records built in do() must follow exactly this column order.
        self.src_table_cols = ("typeId", "DevSysOID", "DevTypeTreeName", "label")
        self.src_table_cols_index = {y: x for x, y in enumerate(self.src_table_cols)}
        self.dst_table = "DevTypeRegInfo"
        self.dst_table_cols = ("modelId", "fModelId", "domain", "protocolType", "oid", "serialName", "typeName",
                               "typeId", "supportManualSpecify", "devNamingRule", "devModifyUrl", "priority",
                               "isTakingOverDelete", "createAt", "updateAt")
        self.idmapping_dic = {}

        # These stay None when a connection cannot be obtained. do() checks for
        # that, and close_session() skips handles that were never opened.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('invmetadatadb', 'invmetadatadb', product_name)
        if self.src_db_session is None:
            self.error("invmetadatadb is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    @staticmethod
    def _type_num(type_id):
        """Return the type number of *type_id*: the integer part of typeId / 65536."""
        return int(type_id) // 65536

    @staticmethod
    def _in_ranges(num, ranges):
        """Return True if *num* lies inside any inclusive (low, high) pair in *ranges*."""
        return any(low <= num <= high for low, high in ranges)

    @staticmethod
    def _normalize_type_id(value):
        """Return *value* as str so ids from both databases compare equal; None is kept."""
        return None if value is None else str(value)

    def convert_data(self, paras):
        """Convert one source record into DevTypeRegInfo rows, one per protocol.

        :param paras: sequence ordered like ``self.src_table_cols``
            (typeId, DevSysOID, DevTypeTreeName, label). typeId must be a
            non-negative decimal integer; do() only passes ids that pass ``isdigit``.
        :return: list of tuples ordered like ``self.dst_table_cols``. Every value
            except None is converted to ``str``, and each row gets a fresh uuid1
            modelId.
        """
        index = self.src_table_cols_index
        typeId = paras[index["typeId"]]
        oid = paras[index["DevSysOID"]]
        serialName = paras[index["DevTypeTreeName"]]
        typeName = paras[index["label"]]

        # do() has already dropped ids outside the supported ranges, so anything
        # that is not NCE-FAN is NCE-IP.
        type_num = self._type_num(typeId)
        if self._in_ranges(type_num, self._FAN_TYPE_RANGES):
            domain = "NCE-FAN"
            supportManualSpecify = False
        else:
            domain = "NCE-IP"
            supportManualSpecify = True

        fModelId = None
        devNamingRule = None
        devModifyUrl = "/rest/dam/v1/nebasicinfos/action/batch-update"
        priority = 0
        isTakingOverDelete = False

        # One timestamp (ms since epoch) so createAt and updateAt are identical.
        createAt = updateAt = str(int(time.time() * 1000))

        needed_protocols = self.check_if_need_netconf(type_num)
        return [tuple(x if x is None else str(x) for x in (
            str(uuid.uuid1()), fModelId, domain, protocolType, oid, serialName, typeName,
            typeId, supportManualSpecify, devNamingRule, devModifyUrl, priority,
            isTakingOverDelete, createAt, updateAt)) for protocolType in needed_protocols]

    def check_if_need_netconf(self, type_num):
        """Return the protocols to register for a type number.

        Every type gets SNMP and STELNET. Third-party IP NE manager types
        (1520-1529) additionally get NETCONF.
        """
        if self._in_ranges(type_num, self._IP_THIRD_PARTY_TYPE_RANGES):
            return ["SNMP", "STELNET", "NETCONF"]
        return ["SNMP", "STELNET"]

    def to_UpgradePara(self, datas):
        """Convert source records and insert them into the destination table.

        :param datas: iterable of source records (see convert_data).

        Rows go through ``self.exec_sql`` in slices of exactly ONE_PACKAGE rows.
        Any remainder is inserted at the end.
        """
        col_names = "`" + "`, `".join(self.dst_table_cols) + "`"
        # Positional bind placeholders :1,:2,... with one per destination column.
        val_ids = ",".join(":%d" % (i + 1) for i in range(len(self.dst_table_cols)))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("insert sql stmt: %s" % insert_stmt)
        list_datas = []
        for data in datas:
            self.debug("original data is: %s, length is:%s" % (data, len(data)))
            data = self.convert_data(data)
            self.debug("coverted date is: %s, length is:%s" % (data, len(data)))
            if data and len(data[0]) == len(self.dst_table_cols):
                list_datas.extend(data)
            # Each record expands to 2 or 3 rows, so the buffer can step past
            # ONE_PACKAGE. An equality test would then never flush again.
            while len(list_datas) >= ONE_PACKAGE:
                self.exec_sql(insert_stmt, list_datas[:ONE_PACKAGE])
                list_datas = list_datas[ONE_PACKAGE:]

        if list_datas:
            self.debug("data is:%s" % list_datas)
            self.exec_sql(insert_stmt, list_datas)

    def close_session(self):
        """Close both cursors and sessions.

        Handles that were never opened (None) are skipped. A failure to close
        one handle is logged and does not stop the others from being closed.
        Each closed handle is reset to None, so calling this twice is harmless.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                self.warning("close %s failed: %s" % (attr, traceback.format_exc()))
            setattr(self, attr, None)

    def _query_missing_type_ids(self):
        """Return the source NE type ids not yet registered in the destination table."""
        # 100000 (local NMS) and 301 (virtual NE) are platform built-in device
        # types and are not upgraded.
        select_stmt_src = "select typeId from %s where " \
                          "category = 'NetworkElement' and typeId not in ('100000','301')" % self.src_table
        # The two tables live in different databases. Compare the ids as strings
        # so a differing column type does not make every id look missing.
        src_db_type_ids = set(self._normalize_type_id(x[0])
                              for x in self.exec_query_sql(self.src_db_cursor, select_stmt_src))
        self.debug("src_db_type_ids: %s" % src_db_type_ids)
        select_stmt_dst = "select `typeId` from `%s`" % self.dst_table
        self.dst_db_cursor.execute(select_stmt_dst)
        dst_db_type_ids = set(self._normalize_type_id(x[0]) for x in self.dst_db_cursor.fetchall())
        self.debug("dst_db_type_ids: %s" % dst_db_type_ids)
        missing_type_ids = src_db_type_ids - dst_db_type_ids
        self.debug("src_db_type_ids - dst_db_type_ids = %s" % missing_type_ids)
        return missing_type_ids

    def _collect_reg_infos(self, type_ids):
        """Build source records, ordered like src_table_cols, for the supported ids.

        Ids that are None, non-numeric, or outside the NCE-FAN / NCE-IP type
        ranges are skipped with a warning.
        """
        # The features are stored as rows in a sub-table. Join them in, then
        # fold them into a featureName -> featureValue dict.
        sql = "select featureName, featureValue from resourcetype t " \
              "inner join resourcetype_feature f on t.id = f.exid where t.typeId = '%s'"
        datas = []
        for typeId in type_ids:
            # Validate before building the query: the id is interpolated into SQL.
            if typeId is None:
                self.warning("typeId is None.")
                continue
            if not typeId.isdigit():
                self.warning("typeId is not digit: %s" % typeId)
                continue
            key_values = dict(self.exec_query_sql(self.src_db_cursor, sql % typeId))
            self.debug("key_values: %s" % key_values)
            typeId_num = self._type_num(typeId)
            if not (self._in_ranges(typeId_num, self._FAN_TYPE_RANGES)
                    or self._in_ranges(typeId_num, self._IP_TYPE_RANGES)):
                self.warning("typeId %s,typeId_num %s, not in valid range, ignored. "
                             "the reg info is: %s" % (typeId, typeId_num, key_values))
                continue
            datas.append((typeId,
                          key_values.get("DevSysOID"),
                          key_values.get("DevTypeTreeName"),
                          key_values.get("label")))
        return datas

    def do(self):
        """Register every NE type that is missing from DevTypeRegInfo.

        Steps:
          1. read all NetworkElement type ids from the source table, excluding
             the platform built-ins 100000 and 301;
          2. read all type ids already present in DevTypeRegInfo;
          3. take the ids missing from the destination;
          4. for each id inside a supported NCE-FAN / NCE-IP range, read its
             features and insert the converted rows.

        The DB sessions are closed before returning, whatever the outcome.

        :return: 0 on success, -1 on failure. Errors are logged, not raised.
        """
        try:
            self.info('UpgradeDevTypeRegInfo::do start!')
            if self.src_db_cursor is None or self.dst_db_cursor is None:
                self.error("UpgradeDevTypeRegInfo db session is not available")
                return -1
            missing_type_ids = self._query_missing_type_ids()
            datas = self._collect_reg_infos(missing_type_ids)
            self.info('get para data success count: %d' % len(datas))

            self.to_UpgradePara(datas)
            self.info('to_UpgradePara done')
        except Exception:
            self.error('UpgradeDevTypeRegInfo got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            self.close_session()
        self.info('UpgradeDevTypeRegInfo::do done')
        return 0


if __name__ == '__main__':
    # Manual entry point: build the task (this opens the DB sessions), then run
    # the upgrade once between start/finish banners. The 0 / -1 status returned
    # by do() is not inspected here.
    upgrade_task = UpgradeDevTypeRegInfo()
    start_banner = '[INFO] UpgradeDevTypeRegInfo start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>'
    finish_banner = '[INFO] UpgradeDevTypeRegInfo finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'
    print(start_banner)
    upgrade_task.do()
    print(finish_banner)
