#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import uuid
import traceback

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

ONE_PACKAGE = 500

class UpgradeSnmpPara(BaseTask):
    """Migrate SNMP parameters from ``tbl_devsnmppara`` into ``SnmpPara``.

    The source session is obtained via ``get_zenith_session('eamdb', ...)``
    and the destination via ``get_zenith_session('neresdb', ...)``. Source rows
    are converted one by one by :meth:`convert_data` and inserted into the
    destination table in batches of ``ONE_PACKAGE`` rows.
    """

    def __init__(self, product_name="NCE"):
        """Prepare the migration metadata and open both DB sessions.

        If a session cannot be obtained, an error is logged and the related
        handles stay ``None``; :meth:`do` then reports failure (returns -1)
        instead of crashing on a missing attribute.
        """
        super(UpgradeSnmpPara, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeSnmpPara init product_name is %s" % product_name)

        # Connection handles are pre-declared so that close_session() and do()
        # can tell which of them were actually opened.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        # Static table/mapping metadata is set before connecting so the object
        # is fully formed even when a connection attempt below fails.
        self.src_table = "tbl_devsnmppara"
        self.src_table_cols = (
            "dn", "iVersion", "szGetCommunity", "szSetCommunity", "iPort", "iStatusPollInterval",
            "iConfPollInterval", "iTimeout", "iRetries", "iUseGetBulk",
            "iMinRepeat", "iMaxRepeat", "szDevstrUserName", "szContextName", "szContextID",
            "iAuthProto", "szAuthPwd", "iPrivProto", "szPrivPwd", "szLocation", "szEngineID",
            "ownerId", "resId")
        # Column name -> position in a row selected with src_table_cols.
        self.src_table_cols_index = {name: pos for pos, name in enumerate(self.src_table_cols)}

        self.dst_table = "SnmpPara"
        self.dst_table_cols = (
            "modelId", "fModelId", "tenantId", "neResId", "neId", "purpose", "connectStatus", "version", "port",
            "userName", "readCommunity", "writeCommunity", "timeout", "retries", "pollInterval",
            "encryptProtocol", "encryptProtocolPwd", "authProtocol", "authProtocolPwd", "contextId",
            "contextName", "engineId", "usGetBulkFlag", "usMinRepeatCnt", "usMaxRepeatCnt",
            "createAt", "updateAt")

        # Source numeric codes (as str) -> destination names.
        self.version_map = {
            "0": "SNMPv1",
            "1": "SNMPv2c",
            "3": "SNMPv3"
        }

        self.encryptProtocol_map = {
            "1": "NO-PRIV",
            "2": "DES",
            "3": "3DES",
            "4": "AES-128",
            "20": "AES-192",
            "21": "AES-256"
        }

        self.authProtocol_map = {
            "1": "NO-AUTH",
            "2": "HMAC-MD5",
            "3": "HMAC-SHA",
            "4": "HMAC-SHA2-224",
            "5": "HMAC-SHA2-256",
            "6": "HMAC-SHA2-384",
            "7": "HMAC-SHA2-512"
        }

        self.src_db_session = get_zenith_session('eamdb', 'eamdb', product_name)
        if self.src_db_session is None:
            self.error("eamdb_sess is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def convert_data(self, paras):
        """Convert one source row into a destination ``SnmpPara`` row.

        :param paras: a row whose values are ordered as ``self.src_table_cols``.
        :return: a tuple ordered as ``self.dst_table_cols`` in which every
            non-``None`` value is converted to ``str``; or ``[]`` when the row
            is skipped (``dn`` lacks string methods, or the NE id extracted
            from it is not numeric).
        """
        index = self.src_table_cols_index

        def col(name):
            # Value of source column *name* in this row.
            return paras[index[name]]

        modelId = str(uuid.uuid1())
        fModelId = None
        tenantId = "default-organization-id"
        neResId = col("resId")
        dn = col("dn")
        try:
            # Remove the exact "NE=" prefix. str.lstrip("NE=") would instead
            # strip any leading run of 'N', 'E' and '=' characters, so a dn such
            # as "NEN5" would wrongly yield the numeric id "5".
            neId = dn[len("NE="):] if dn.startswith("NE=") else dn
            # NOTE(review): an empty neId (e.g. dn == "NE=") is not rejected
            # here and is migrated as "" -- confirm this is intended.
            if neId and not str(neId).isdigit():
                self.debug("neId is not number, data is ignored: %s" % neId)
                return []
        except AttributeError as ae:
            # dn has no string methods (e.g. it is None).
            self.warning(ae)
            self.warning("the dn is: %s" % dn)
            return []
        purpose = "default"
        connectStatus = -1

        # iVersion is converted to str before the lookup and the map keys are
        # str too, avoiding int/str conversion type errors. Unknown versions
        # map to None.
        version = self.version_map.get(str(col("iVersion")))

        port = col("iPort")
        userName = col("szDevstrUserName")
        # Community strings are only carried over for non-SNMPv3 versions.
        if version == "SNMPv3":
            readCommunity = None
            writeCommunity = None
        else:
            readCommunity = col("szGetCommunity")
            writeCommunity = col("szSetCommunity")
        timeout = col("iTimeout")
        retries = col("iRetries")

        # Prefer the status poll interval, then the configuration poll
        # interval, then 0.
        status_poll = col("iStatusPollInterval")
        conf_poll = col("iConfPollInterval")
        if status_poll != 0:
            pollInterval = status_poll
        elif conf_poll != 0:
            pollInterval = conf_poll
        else:
            pollInterval = 0

        # Protocol codes are looked up as str; unknown codes map to None.
        # Passwords are copied through unchanged.
        encryptProtocol = self.encryptProtocol_map.get(str(col("iPrivProto")))
        encryptProtocolPwd = col("szPrivPwd")
        authProtocol = self.authProtocol_map.get(str(col("iAuthProto")))
        authProtocolPwd = col("szAuthPwd")
        contextId = col("szContextID")
        contextName = col("szContextName")
        engineId = col("szEngineID")
        usGetBulkFlag = col("iUseGetBulk")
        usMinRepeatCnt = col("iMinRepeat")
        usMaxRepeatCnt = col("iMaxRepeat")

        # createAt and updateAt share one timestamp, in milliseconds since the epoch.
        now_ms = int(round(time.time() * 1000))
        createAt = now_ms
        updateAt = now_ms
        return tuple(None if x is None else str(x) for x in (
            modelId, fModelId, tenantId, neResId, neId, purpose, connectStatus, version, port, userName,
            readCommunity, writeCommunity, timeout, retries, pollInterval, encryptProtocol, encryptProtocolPwd,
            authProtocol, authProtocolPwd, contextId, contextName, engineId, usGetBulkFlag, usMinRepeatCnt,
            usMaxRepeatCnt, createAt, updateAt))

    def to_UpgradePara(self, snmp_paras):
        """Convert *snmp_paras* rows and insert them into the destination table.

        Rows are inserted in batches of ``ONE_PACKAGE``; rows whose converted
        form does not have one value per destination column are skipped with
        a warning.
        """
        dst_cols_len = len(self.dst_table_cols)
        col_names = ", ".join("`%s`" % name for name in self.dst_table_cols)
        # Bind placeholders :1, :2, ... -- one per destination column.
        val_ids = ",".join(":%d" % (pos + 1) for pos in range(dst_cols_len))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("%s" % insert_stmt)

        batch = []
        for snmp_para in snmp_paras:
            data = self.convert_data(snmp_para)
            if len(data) == dst_cols_len:
                batch.append(data)
            else:
                self.warning("coverted data length not equals dst table cols, to be ignored.")
            if len(batch) == ONE_PACKAGE:
                # NOTE(review): exec_sql receives no cursor, unlike
                # exec_query_sql; presumably BaseTask targets the destination
                # DB here -- confirm.
                self.exec_sql(insert_stmt, batch)
                batch = []

        if batch:
            self.exec_sql(insert_stmt, batch)

    def close_session(self):
        """Close every cursor/session that was opened, destination first.

        Each handle is closed independently: a failure is logged as a warning
        and does not prevent the remaining handles from being closed. Handles
        are reset to ``None`` afterwards, so calling this again is harmless.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                self.warning("close %s failed: %s" % (attr, traceback.format_exc()))
            setattr(self, attr, None)

    def do(self):
        """Run the migration, then close all sessions.

        :return: 0 on success; -1 if a DB session is unavailable or any step
            raised (the traceback is logged).
        """
        try:
            self.info('UpgradeSnmpPara::do start!')
            if self.src_db_cursor is None or self.dst_db_cursor is None:
                # __init__ could not open one of the sessions.
                self.error('UpgradeSnmpPara db session is not available')
                return -1
            select_stmt = ("select %s from %s where dn !='OS=1' and dn != 'OSS' "
                           % (",".join(self.src_table_cols), self.src_table))
            datas = self.exec_query_sql(self.src_db_cursor, select_stmt)
            self.info('get snmp_paras data success count:%d' % len(datas))

            self.to_UpgradePara(datas)
            self.info('to_UpgradeSnmpPara done')
        except Exception:
            self.error('UpgradeSnmpPara got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            # Runs exactly once on every path, including the early returns.
            self.close_session()
        self.info('UpgradeSnmpPara::do done')
        return 0


if __name__ == '__main__':
    # Constructing the task attempts to open both database sessions;
    # do() runs the migration and closes the sessions afterwards.
    snmp_para_task = UpgradeSnmpPara()
    print('[INFO] UpgradeSnmpPara start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    snmp_para_task.do()
    print('[INFO] UpgradeSnmpPara finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
