#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import traceback

from getDBConnection import get_zenith_session
import common_tasks.const_sql as const_sql
from common_tasks.base_task import BaseTask

ONE_PACKAGE = 500


class UpgradeQxDcn(BaseTask):
    """Migrate QX DCN gateway data from the NE-manager databases into neresdb.QxDcn.

    For every process type in PROCESS_TYPES the instance handles are read from
    MCDB.currenthandle; each per-instance database is queried, every row is
    converted to the QxDcn column layout (dst_table_cols) and inserted through
    self.exec_sql in batches of ONE_PACKAGE rows. For nemgr_v8ptn the rows of
    the companion Java instance (ptn_v8_db_<id>) are merged in as well.
    """

    # Tenant written into every QxDcn row.
    TENANT_ID = "default-organization-id"
    # cPort (as string) -> connectionType for gateways; any other port falls
    # back to DEFAULT_CONNECTION_TYPE.
    CONNECTION_TYPE_MAP = {
        "1400": "Common",
        "5432": "SSL",
    }
    DEFAULT_CONNECTION_TYPE = "Common"
    # Process types handled by do(), in processing order.
    PROCESS_TYPES = ("nemgr_trans", "nemgr_v8trans", "nemgr_marine", "nemgr_v8ptn")
    # process_type -> database instance name pattern, formatted with the instance id.
    DB_INST_PATTERNS = {
        "nemgr_trans": "nemgr_transDB_%s",
        "nemgr_v8trans": "nesvc_v8transDB_%s",
        "nemgr_marine": "nemgr_marineDB_%s",
        # v8ptn additionally has a Java-process database instance (see do_db_inst)
        "nemgr_v8ptn": "nesvc_v8ptnDB_%s",
    }

    def __init__(self, product_name="NCE"):
        """Set up table metadata and open the MCDB (source) and neresdb (target) sessions.

        If a session cannot be opened an error is logged and the corresponding
        session/cursor attributes stay None; do() then reports failure instead
        of crashing, and close_session() only closes what was actually opened.
        """
        super(UpgradeQxDcn, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeQxDcn init product_name is %s" % product_name)

        # Define every attribute up front so a failed connection below never
        # leaves the object half-initialised.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_table = "qxDcn_Cpp"
        # Column order of the rows returned by the get_src_data_* queries.
        self.src_table_cols = ("cLogicNeId", "cNeId", "cGneType", "cMainGneId", "cBackup1GneId", "cBackup2GneId",
                               "cBackup3GneId", "cIp", "cIPV6Address", "cPort", "cSSLDir", "cNEUser",
                               "cEncryptPassword", "iCommuProType")
        # column name -> position in a source row
        self.src_table_cols_index = dict((name, idx) for idx, name in enumerate(self.src_table_cols))
        self.dst_table = "QxDcn"
        # Column order of the rows produced by convert_data / convert_data_java.
        self.dst_table_cols = ("tenantId", "neResId", "neId", "gatewayType", "gatewayProtocol", "masterGateway",
                               "ipAddr", "nsapAddr", "connectionType", "port", "sslCertificateId", "phyNeId",
                               "commuProType", "backup1Gateway", "backup2Gateway", "backup3Gateway",
                               "createAt", "updateAt")
        # logic NE id (string, without "NE=") -> resource id from idmappingdb
        self.idmapping_dic = {}
        # gateway physical NE id (string) -> logic NE id (string), per instance
        self.phyid_to_logicid = {}

        self.src_db_session = get_zenith_session('MCDB', 'MCDB', product_name)
        if self.src_db_session is None:
            self.error("MCDB is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    # ------------------------------------------------------------------ helpers

    def _src_value(self, paras, col_name):
        """Return the value of source column col_name from the source row paras."""
        return paras[self.src_table_cols_index[col_name]]

    @staticmethod
    def _strip_ne_prefix(native_id):
        """Return native_id without its leading literal "NE=" prefix.

        str.lstrip("NE=") is not used because it strips any leading run of the
        characters 'N', 'E' and '=', not the prefix itself.
        """
        prefix = "NE="
        if native_id.startswith(prefix):
            return native_id[len(prefix):]
        return native_id

    def _open_inst_session(self, db_inst):
        """Open an autocommit session to db_inst; return None (with a warning) if unavailable."""
        try:
            db_sess = get_zenith_session(db_inst, db_inst, self.product_name)
        except Exception as e:
            self.warning("can not get db session for db: %s, err is :%s" % (db_inst, str(e)))
            return None
        if db_sess is None:
            # The database does not exist; the current scenario probably needs no upgrade.
            self.warning("%s db session is None" % db_inst)
            return None
        db_sess.autocommit(True)
        return db_sess

    def _gateway_ip_addr(self, paras, neId):
        """Return the ipAddr value for a gateway source row (None if it has no address)."""
        cIp = self._src_value(paras, "cIp")
        cIPV6Address = self._src_value(paras, "cIPV6Address")
        if not (cIp or cIPV6Address):
            # A gateway without any address is still migrated, but flagged in the log.
            self.warning("cIp and cIPV6Address are None: neId is %s" % str(neId))
            return None
        if cIp and str(cIp).isdigit():
            # A purely numeric cIp is rendered as a dotted-quad IPv4 address,
            # most significant byte first.
            iIp = int(cIp)
            return "%s.%s.%s.%s" % tuple((iIp >> shift) & 0xFF for shift in range(24, -1, -8))
        # For ptnv8, cIp and cIPV6Address hold the same already formatted
        # (non-numeric) string, so this branch yields the IP directly.
        return cIPV6Address

    def _build_row(self, paras, neResId, neId, gatewayType, gatewayProtocol, masterGateway,
                   ipAddr, connectionType, backup_gateways):
        """Assemble one QxDcn row, ordered as self.dst_table_cols.

        Values copied verbatim from the source row (port, sslCertificateId,
        phyNeId, commuProType) are read here; every non-None value is converted
        to str. createAt and updateAt share a single millisecond timestamp.
        """
        backup1Gateway, backup2Gateway, backup3Gateway = backup_gateways
        now_ms = int(time.time() * 1000)
        values = (
            self.TENANT_ID, neResId, neId, gatewayType, gatewayProtocol, masterGateway, ipAddr,
            None,  # nsapAddr is always left empty
            connectionType,
            self._src_value(paras, "cPort"),
            self._src_value(paras, "cSSLDir"),
            self._src_value(paras, "cNeId"),
            self._src_value(paras, "iCommuProType"),
            backup1Gateway, backup2Gateway, backup3Gateway,
            now_ms, now_ms)
        return tuple(None if value is None else str(value) for value in values)

    def _build_insert_stmt(self):
        """Return the positional-bind INSERT statement for self.dst_table."""
        col_names = "`" + "`, `".join(self.dst_table_cols) + "`"
        val_ids = ":" + ",:".join(str(pos + 1) for pos in range(len(self.dst_table_cols)))
        return "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)

    def _insert_converted(self, templates, converter):
        """Resolve resource ids, convert every source row with converter and insert in batches."""
        self.get_resid_from_idmapping(templates)
        insert_stmt = self._build_insert_stmt()
        self.debug("%s" % insert_stmt)
        dst_cols_len = len(self.dst_table_cols)
        batch = []
        for template in templates:
            data = converter(template)
            if len(data) == dst_cols_len:
                batch.append(tuple(data))
            else:
                self.warning("coverted data length not equals dst table cols, to be ignored.")
            if len(batch) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, batch)
                batch = []
        if batch:
            self.exec_sql(insert_stmt, batch)

    # --------------------------------------------------------------- public API

    def get_resid_from_idmapping(self, paras):
        """Fill self.idmapping_dic with the resource ids of the NEs in paras.

        The "NE=<cLogicNeId>" native ids are loaded into a temporary table in
        idmappingdb and resolved with const_sql.INNER_JOIN_TEMP_TABLE. Each
        result row is read as (resource id, native id). The idmappingdb cursor
        and session are always closed, even if a statement raises; exceptions
        still propagate to the caller.
        """
        self.debug("get_resid_from_idmapping::start")
        if not paras:
            # Nothing to resolve: skip the idmappingdb round trip.
            return
        idmappingdb_sess = get_zenith_session('idmappingdb', 'idmappingdb', self.product_name)
        if idmappingdb_sess is None:
            self.error("get idmappingdb_sess session fail")
            return
        id_mapping_cursor = None
        try:
            idmappingdb_sess.autocommit(True)
            id_mapping_cursor = idmappingdb_sess.cursor()

            tmp_table_name = "tmp_neid_%s" % self.dst_table
            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
            id_mapping_cursor.execute(const_sql.CREATE_TEMP_TABLE % tmp_table_name)
            # NOTE(review): rows go into "tmp_<tmp_table_name>", so the const_sql
            # templates presumably add the same "tmp_" prefix — confirm.
            insert_stmt = "insert into tmp_%s values(:1)" % tmp_table_name

            logic_ne_id_idx = self.src_table_cols_index["cLogicNeId"]
            native_ids = []
            for para in paras:
                native_ids.append(("NE=%s" % para[logic_ne_id_idx],))
                if len(native_ids) == ONE_PACKAGE:
                    id_mapping_cursor.executemany(insert_stmt, native_ids)
                    self.debug("one package:%s" % native_ids)
                    native_ids = []
            if native_ids:
                id_mapping_cursor.executemany(insert_stmt, native_ids)
                self.debug("last package datas: %s" % native_ids)

            query_stmt = const_sql.INNER_JOIN_TEMP_TABLE % tmp_table_name
            self.debug("query sql: %s" % query_stmt)
            id_mapping_cursor.execute(query_stmt)
            for row in id_mapping_cursor.fetchall():
                res_id, native_id = row[0], row[1]
                self.idmapping_dic[self._strip_ne_prefix(native_id)] = res_id
                self.debug("result: %s: %s" % (native_id, res_id))

            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
        finally:
            if id_mapping_cursor is not None:
                id_mapping_cursor.close()
            idmappingdb_sess.close()

    def convert_data(self, paras):
        """Convert one C++-instance source row into a QxDcn row.

        A row is a gateway when its cNeId equals its cMainGneId. Only gateways
        get gatewayProtocol "IP", an ipAddr and a connectionType. The main and
        backup gateway columns hold physical ids and are mapped to logic ids via
        self.phyid_to_logicid (None when unknown).
        """
        neId = self._src_value(paras, "cLogicNeId")
        neResId = self.idmapping_dic.get(str(neId))
        cMainGneId = self._src_value(paras, "cMainGneId")
        is_gateway = str(self._src_value(paras, "cNeId")) == str(cMainGneId)

        if is_gateway:
            gatewayType = "Gateway"
            gatewayProtocol = "IP"
            ipAddr = self._gateway_ip_addr(paras, neId)
            connectionType = self.CONNECTION_TYPE_MAP.get(
                str(self._src_value(paras, "cPort")), self.DEFAULT_CONNECTION_TYPE)
        else:
            gatewayType = "NonGateway"
            gatewayProtocol = None
            ipAddr = None
            connectionType = None

        masterGateway = self.phyid_to_logicid.get(str(cMainGneId))
        backups = tuple(self.phyid_to_logicid.get(str(self._src_value(paras, col)))
                        for col in ("cBackup1GneId", "cBackup2GneId", "cBackup3GneId"))
        return self._build_row(paras, neResId, neId, gatewayType, gatewayProtocol, masterGateway,
                               ipAddr, connectionType, backups)

    def to_UpgradePara(self, stelnet_templates):
        """Convert the C++-instance rows with convert_data and insert them into QxDcn."""
        self._insert_converted(stelnet_templates, self.convert_data)

    def convert_data_java(self, paras):
        """Convert one v8ptn Java-instance source row into a QxDcn row.

        Every Java-instance row is a gateway. The main and backup gateway values
        are used as-is: the Java query already selects the logicDev* (logic id)
        columns, so no physical-to-logic mapping is needed.
        """
        neId = self._src_value(paras, "cLogicNeId")
        neResId = self.idmapping_dic.get(str(neId))
        connectionType = self.CONNECTION_TYPE_MAP.get(
            str(self._src_value(paras, "cPort")), self.DEFAULT_CONNECTION_TYPE)
        backups = tuple(self._src_value(paras, col)
                        for col in ("cBackup1GneId", "cBackup2GneId", "cBackup3GneId"))
        return self._build_row(paras, neResId, neId, "Gateway", "IP",
                               self._src_value(paras, "cMainGneId"),
                               self._src_value(paras, "cIp"),
                               connectionType, backups)

    def to_UpgradePara_JAVA(self, templates):
        """Convert the Java-only rows with convert_data_java and insert them into QxDcn."""
        self.info("to_UpgradePara_JAVA start.")
        self._insert_converted(templates, self.convert_data_java)
        self.info("to_UpgradePara_JAVA finished.")

    def close_session(self):
        """Close the neresdb and MCDB cursors/sessions that are still open.

        Handles that were never opened are skipped and every closed handle is
        reset to None, so calling this more than once is harmless. A failure
        while closing one handle is logged and does not stop the others.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception as e:
                self.warning("close %s failed, err is: %s" % (attr, str(e)))
            setattr(self, attr, None)

    def get_src_data_by_db_inst(self, db_inst):
        """Read the QX DCN rows of a C++ NE-manager instance.

        Also refreshes self.phyid_to_logicid from tTEEthGne of the same
        instance; it is reset to {} if that query fails, so a previous
        instance's mapping is never reused.

        Returns:
            A list of rows ordered as self.src_table_cols ([] when the instance
            is unavailable or the query fails).
        """
        db_sess = self._open_inst_session(db_inst)
        if db_sess is None:
            return []
        db_cur = None
        try:
            db_cur = db_sess.cursor()
            # The column order of this SELECT must match self.src_table_cols.
            select_stmt = ("select "
                           "t.cLogicNeId,"
                           "t.cNeId,"
                           "g.cGneType,"
                           "t.cMainGneId,"
                           "t.cBackup1GneId,"
                           "t.cBackup2GneId,"
                           "t.cBackup3GneId,"
                           "g.cIp,"
                           "g.cIPV6Address, "
                           "g.cPort, "
                           "g.cSSLDir,"
                           "t.cNEUser,"
                           "t.cEncryptPassword,"
                           "-1 as iCommuProType from tTECommuNe t left join tTEEthGne g on t.cLogicNeId = g.cLogicNeId")
            self.debug("execute sql: %s" % select_stmt)
            try:
                db_cur.execute(select_stmt)
                templates = list(db_cur.fetchall() or [])
            except Exception as be:
                self.warning("execute sql failed, sql: %s, error: %s" % (select_stmt, be))
                return []
            self.info('get qx dcn data success count:%d' % len(templates))

            # Map each gateway's physical id to its logic id.
            select_stmt_phyid_to_logicid = "select cNeId, cLogicNeId from tTEEthGne"
            self.debug("execute sql: %s" % select_stmt_phyid_to_logicid)
            phyid_to_logicid = {}
            try:
                db_cur.execute(select_stmt_phyid_to_logicid)
                phyid_to_logicid = dict(
                    (str(c_neid), str(c_logic_neid)) for c_neid, c_logic_neid in db_cur.fetchall())
            except Exception as be:
                self.warning("execute sql failed, err is: %s" % str(be))
            self.phyid_to_logicid = phyid_to_logicid
            return templates
        finally:
            if db_cur is not None:
                db_cur.close()
            db_sess.close()

    def get_src_data_by_db_inst_for_v8ptn(self, db_inst):
        """Read the QX DCN rows of the v8ptn Java instance, shaped like self.src_table_cols.

        Returns [] when the instance is unavailable or the query fails. Unlike
        the C++ variant, no physical-to-logic id mapping is loaded.
        """
        db_sess = self._open_inst_session(db_inst)
        if db_sess is None:
            return []
        db_cur = None
        try:
            db_cur = db_sess.cursor()
            # The column order of this SELECT must match self.src_table_cols.
            select_stmt = ("select "
                           "t.devID,"
                           "t.devPhyID,"
                           "t.devGneType,"
                           "g.logicDevMainGneId,"
                           "g.logicDevBackup1GneId,"
                           "g.logicDevBackup2GneId,"
                           "g.logicDevBackup3GneId,"
                           "t.devIP,"
                           "t.devIP as cIPV6Address,"
                           "t.devPort,"
                           "t.devSSLDir,"
                           "g.devNEUser,"
                           "g.devEncryptPassword,"
                           "t.iCommuProType from t_dev_ethgne t left join t_dev_commune g on t.devID = g.devID")
            self.debug("execute sql: %s" % select_stmt)
            try:
                db_cur.execute(select_stmt)
                templates = list(db_cur.fetchall() or [])
            except Exception as be:
                self.warning("execute sql failed, err is: %s" % str(be))
                return []
            self.info('get qx dcn data success count:%d' % len(templates))
            return templates
        finally:
            if db_cur is not None:
                db_cur.close()
            db_sess.close()

    def merge_cpp_java_inst_data(self, dcn_cpp, dcn_java):
        """Merge v8ptn Java-instance rows into the C++-instance rows.

        Rows are matched on the cLogicNeId column (devID on the Java side):
          * in both        -> every matching C++ row takes iCommuProType from the Java row;
          * only in Java   -> the Java row is returned in dcn_java_add for separate insertion;
          * only in C++    -> treated as a non-gateway using the Qx protocol,
                              iCommuProType is set to 1 on every such row.
        dcn_cpp is updated in place (rows are replaced by modified tuples), so it
        must be a mutable sequence.

        Returns:
            (dcn_cpp, dcn_java_add)
        """
        logic_ne_id_idx = self.src_table_cols_index["cLogicNeId"]
        iCommuProType_index = self.src_table_cols_index["iCommuProType"]
        cNeId_index = self.src_table_cols_index["cNeId"]

        # NOTE(review): matching relies on both instances returning the id with
        # the same Python type (e.g. both int) — confirm.
        # Java: devId -> row position (if a devId repeats, the last row wins).
        java_inst_devid_index = dict((row[logic_ne_id_idx], i) for i, row in enumerate(dcn_java))
        # C++: devId -> all row positions, so duplicated rows are all updated.
        cpp_rows_by_devid = {}
        for i, row in enumerate(dcn_cpp):
            cpp_rows_by_devid.setdefault(row[logic_ne_id_idx], []).append(i)
        java_inst_dev_id_set = set(java_inst_devid_index)
        cpp_inst_dev_id_set = set(cpp_rows_by_devid)

        # Update first, add afterwards, so the recorded indexes stay valid.
        to_update_devids = cpp_inst_dev_id_set & java_inst_dev_id_set
        self.debug("to_update_devids: %s" % to_update_devids)
        for devid in to_update_devids:
            java_commu_pro_type = dcn_java[java_inst_devid_index[devid]][iCommuProType_index]
            for cpp_idx in cpp_rows_by_devid[devid]:
                temp = list(dcn_cpp[cpp_idx])
                self.debug("before replaced: neId %s, iCommuProType %s"
                           % (str(temp[cNeId_index]), str(temp[iCommuProType_index])))
                temp[iCommuProType_index] = java_commu_pro_type
                dcn_cpp[cpp_idx] = tuple(temp)
                self.debug("after replaced: neId %s, iCommuProType %s"
                           % (str(temp[cNeId_index]), str(temp[iCommuProType_index])))

        # Java-only devIds are returned as additional gateway rows.
        dcn_java_add = []
        for devid in java_inst_dev_id_set - cpp_inst_dev_id_set:
            self.debug("add qxDcn_JAVA db data, neId: %s" % str(devid))
            dcn_java_add.append(dcn_java[java_inst_devid_index[devid]])

        # C++-only devIds are non-gateways: default Qx protocol, iCommuProType = 1.
        for dev_id in cpp_inst_dev_id_set - java_inst_dev_id_set:
            for cpp_idx in cpp_rows_by_devid[dev_id]:
                temp = list(dcn_cpp[cpp_idx])
                temp[iCommuProType_index] = 1
                dcn_cpp[cpp_idx] = tuple(temp)
            self.debug("replace qxDcn iCommuProType 1, data: %s" % str(dev_id))

        return dcn_cpp, dcn_java_add

    def do_db_inst(self, process_type, inst_id):
        """Migrate the data of one database instance of the given process type.

        Unknown process types are logged and skipped. For nemgr_v8ptn the Java
        instance "ptn_v8_db_<inst_id>" is merged in and its Java-only rows are
        inserted before the C++ rows.
        """
        pattern = self.DB_INST_PATTERNS.get(process_type)
        if pattern is None:
            self.warning("not valid process_type: %s" % process_type)
            return
        db_inst = pattern % inst_id
        self.info('UpgradeQxDcn::do_db_inst(%s) start!' % db_inst)
        if process_type == "nemgr_v8ptn":
            dcn_cpp = self.get_src_data_by_db_inst(db_inst)
            dcn_java = self.get_src_data_by_db_inst_for_v8ptn("ptn_v8_db_%s" % inst_id)
            templates, dcn_java_add = self.merge_cpp_java_inst_data(dcn_cpp, dcn_java)
            self.to_UpgradePara_JAVA(dcn_java_add)
        else:
            templates = self.get_src_data_by_db_inst(db_inst)
        self.to_UpgradePara(templates)

        self.info('UpgradeQxDcn::do_db_inst(%s) Done!' % db_inst)

    def do(self):
        """Run the migration for every process type in PROCESS_TYPES.

        The instance handles of each process type are read from MCDB
        currenthandle; when none are recorded, instance 1 is used. The
        sessions are always closed afterwards.

        Returns:
            0 on success, -1 on failure (including missing sessions).
        """
        try:
            self.info('UpgradeQxDcn::do start!')
            if self.src_db_cursor is None or self.dst_db_cursor is None:
                self.error('UpgradeQxDcn db session is not ready')
                return -1
            # Read the database instances from MCDB first, then migrate each one.
            sql_get_db_inst = ("select DISTINCT process_type, handle "
                               "from currenthandle "
                               "where process_type = '%s' order by handle asc")
            for process_type_name in self.PROCESS_TYPES:
                ret = self.exec_query_sql(self.src_db_cursor, sql_get_db_inst % process_type_name)
                if ret:
                    self.debug("process_type, handle: %s" % str(ret))
                    for _process_type, handle in ret:
                        self.do_db_inst(process_type_name, handle)
                else:
                    self.debug("get no record, use default")
                    self.do_db_inst(process_type_name, 1)
            self.info('to_UpgradeQxDcn done')
        except Exception:
            self.error('UpgradeQxDcn got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            self.close_session()
        self.info('UpgradeQxDcn::do done')
        return 0


if __name__ == '__main__':
    # Script entry point: build the task for the default product and run it once.
    qx_dcn_upgrader = UpgradeQxDcn()
    print('[INFO] UpgradeQxDcn start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    qx_dcn_upgrader.do()
    print('[INFO] UpgradeQxDcn finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')