#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import traceback
import uuid
import json

from util import ossext

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

# Maximum number of converted rows passed to a single exec_sql batch insert
# in UpgradeSnmpTemplate.to_UpgradePara.
ONE_PACKAGE = 500

class UpgradeSnmpTemplate(BaseTask):
    """
    Migrate SNMP protocol templates from rcaccessconfigdb.tbl_protocol_template
    into neresdb.SnmpTemplate.

    Each source row's encrypted JSON `content` is decrypted and mapped onto the
    SnmpTemplate columns. Credentials (communities, auth/priv passwords) are
    re-encrypted with ossext.Cipher.encrypt0. Rows are inserted in batches of
    at most ONE_PACKAGE.
    """

    def __init__(self, product_name="NCE"):
        """
        Set up table metadata and open cursors on the source and destination DBs.

        :param product_name: product passed to set_product_name and get_zenith_session.

        If a database session cannot be obtained, an error is logged and the
        remaining session/cursor attributes stay None. do() then reports failure
        (returns -1) instead of raising AttributeError.
        """
        super(UpgradeSnmpTemplate, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeSnmpTemplate init product_name is %s" % product_name)

        # Table metadata comes first so that every attribute exists even when a
        # connection below fails and __init__ returns early.
        self.src_table = "tbl_protocol_template"
        self.src_table_cols = ("id", "protocol", "name", "version", "userId", "accessMode", "content",
                               "tenantId", "lastUpdate")
        # Column name -> position in a selected source row.
        self.src_table_cols_index = {col: pos for pos, col in enumerate(self.src_table_cols)}
        self.dst_table = "SnmpTemplate"
        self.dst_table_cols = ("modelId", "tenantId", "templateName", "templateDesc", "accessMode", "version",
                               "port", "userName", "readCommunity", "writeCommunity", "timeout", "retries",
                               "pollInterval", "encryptProtocol", "encryptProtocolPwd", "authProtocol",
                               "authProtocolPwd", "contextId", "contextName", "createAt", "updateAt")

        # Source version string -> destination version string; other versions are skipped.
        self.version_map = {
            "SNMPv1": "SNMPv1",
            "SNMPv2C": "SNMPv2c",
            "SNMPv3": "SNMPv3"
        }

        # Source auth protocol name -> destination auth protocol name.
        self.authProtocol_map = {
            "SHA": "HMAC-SHA",
            "MD5": "HMAC-MD5",
            "SHA2-224": "HMAC-SHA2-224",
            "SHA2-256": "HMAC-SHA2-256",
            "SHA2-384": "HMAC-SHA2-384",
            "SHA2-512": "HMAC-SHA2-512",
        }

        # Connection handles default to None; close_session() and do() rely on this.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('rcaccessconfigdb', 'rcaccessconfigdb', product_name)
        if self.src_db_session is None:
            self.error("rcaccessconfigdb is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def _decrypt_value(self, encrypted):
        """
        Decrypt a value with ossext.Cipher.decrypt.

        :param encrypted: the encrypted value, or None.
        :return: the decrypted value, or None when the input is None or
                 decryption fails (the failure is logged as a warning).
        """
        if encrypted is None:
            return None
        try:
            return ossext.Cipher.decrypt(encrypted)
        except Exception as e:
            # Catch Exception, not BaseException, so that Ctrl-C / SystemExit
            # still stop the migration.
            self.warning(e)
            self.warning("decrypt data failed.")
            return None

    def _encrypt_value(self, text):
        """
        Encrypt a value with ossext.Cipher.encrypt0.

        :param text: the plaintext value, or None.
        :return: the encrypted value, or None when the input is None or
                 encryption fails (the failure is logged as a warning).
        """
        if text is None:
            return None
        try:
            return ossext.Cipher.encrypt0(text)
        except Exception as e:
            # NOTE(review): a failed encryption yields a NULL column, not a
            # skipped row, so the credential is lost. Confirm this is acceptable.
            self.warning(e)
            self.warning("encrypt data failed.")
            return None

    def _convert_v3_security(self, content):
        """
        Derive the SNMPv3 security columns from a decoded template payload.

        :param content: the decoded JSON object of an SNMPv3 template.
        :return: (encryptProtocol, encryptProtocolPwd, authProtocol, authProtocolPwd),
                 with both passwords encrypted.

        level "3" keeps privProtocol/privPassword; any other level yields
        "NO-PRIV" and no priv password. level "1" yields "NO-AUTH" and no auth
        password; any other level maps authProtocol through authProtocol_map
        (None when unknown) and keeps authPassword.
        """
        # str() so that a numeric level (3) and a string level ("3") compare the
        # same way, mirroring the str() normalisation applied to accessMode.
        level = str(content.get("level"))
        if level == "3":
            encryptProtocol = content.get("privProtocol")
            encryptProtocolPwd = content.get("privPassword")
        else:
            encryptProtocol = "NO-PRIV"
            encryptProtocolPwd = None
        if level == "1":
            authProtocol = "NO-AUTH"
            authProtocolPwd = None
        else:
            authProtocol = self.authProtocol_map.get(content.get("authProtocol"))
            authProtocolPwd = content.get("authPassword")
        return (encryptProtocol, self._encrypt_value(encryptProtocolPwd),
                authProtocol, self._encrypt_value(authProtocolPwd))

    def convert_data(self, paras):
        """
        Convert one tbl_protocol_template row into a SnmpTemplate row.

        :param paras: a row selected with the columns of self.src_table_cols, in that order.
        :return: a tuple with one value per column of self.dst_table_cols (each
                 value is str or None), or [] when the row cannot be migrated:
                 unsupported version, undecryptable content, content that is not
                 valid JSON, or JSON that is not an object.
        """
        idx = self.src_table_cols_index
        templateName = paras[idx["name"]]
        # NOTE(review): when the source accessMode is "1", the source userId is
        # written to accessMode, otherwise 0. Confirm this mapping is intended.
        accessMode = paras[idx["userId"]] if str(paras[idx["accessMode"]]) == "1" else 0

        version = self.version_map.get(paras[idx["version"]])
        if not version:
            self.warning("version is None, original templateName is: %s" % str(templateName))
            return []

        content_text = self._decrypt_value(paras[idx["content"]])
        if content_text is None:
            return []

        try:
            oldContent = json.loads(content_text)
        except (TypeError, ValueError, AttributeError) as e:
            # JSONDecodeError is a ValueError subclass; logging the concrete
            # class name keeps the original per-exception messages.
            self.warning("SnmpTemplate convert_data json loads %s, templateName is %s"
                         % (type(e).__name__, str(templateName)))
            return []
        if not isinstance(oldContent, dict):
            # A JSON list/string/number would make the .get() calls below raise
            # and abort the whole migration; skip this row instead.
            self.warning("SnmpTemplate convert_data content is not a JSON object, templateName is %s"
                         % str(templateName))
            return []

        port = oldContent.get("nePort")
        timeout = oldContent.get("timeout")
        retries = oldContent.get("retries")
        pollInterval = oldContent.get("statusPollInterval")

        if version == "SNMPv3":
            userName = oldContent.get("v3User")
            readCommunity = None
            writeCommunity = None
            (encryptProtocol, encryptProtocolPwd,
             authProtocol, authProtocolPwd) = self._convert_v3_security(oldContent)
            contextId = oldContent.get("engineId")
            contextName = oldContent.get("context")
        else:
            userName = None
            readCommunity = self._encrypt_value(oldContent.get("read"))
            writeCommunity = self._encrypt_value(oldContent.get("write"))
            encryptProtocol = None
            encryptProtocolPwd = None
            authProtocol = None
            authProtocolPwd = None
            contextId = None
            contextName = None

        modelId = str(uuid.uuid1())
        now_ms = int(round(time.time() * 1000))
        # Order must match self.dst_table_cols.
        row = (modelId, "default-organization-id", templateName, None, accessMode, version,
               port, userName, readCommunity, writeCommunity, timeout, retries, pollInterval,
               encryptProtocol, encryptProtocolPwd, authProtocol, authProtocolPwd,
               contextId, contextName, now_ms, now_ms)
        return tuple(None if value is None else str(value) for value in row)

    def to_UpgradePara(self, datas):
        """
        Convert source rows and insert them into the destination table.

        :param datas: iterable of source rows (see convert_data); None is treated as empty.

        Rows that convert_data rejects are skipped. Accepted rows are sent to
        exec_sql in batches of at most ONE_PACKAGE.
        """
        col_names = ", ".join("`%s`" % col for col in self.dst_table_cols)
        # Positional bind placeholders :1, :2, ... one per destination column.
        val_ids = ",".join(":%d" % pos for pos in range(1, len(self.dst_table_cols) + 1))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("insert sql stmt: %s" % insert_stmt)

        expected_len = len(self.dst_table_cols)
        batch = []
        migrated = 0
        skipped = 0
        for row in datas or ():
            converted = self.convert_data(row)
            if len(converted) != expected_len:
                skipped += 1
                continue
            batch.append(tuple(converted))
            migrated += 1
            if len(batch) >= ONE_PACKAGE:
                self.exec_sql(insert_stmt, batch)
                batch = []

        if batch:
            self.exec_sql(insert_stmt, batch)
        self.info("SnmpTemplate rows migrated: %d, skipped: %d" % (migrated, skipped))

    def close_session(self):
        """
        Close the destination and source cursors/sessions.

        Handles that are None (never opened, or already closed) are skipped. A
        failure closing one handle is logged and does not prevent closing the
        others. Each handle is reset to None, so calling this twice is safe.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                self.warning("UpgradeSnmpTemplate close %s failed: %s" % (attr, traceback.format_exc()))
            setattr(self, attr, None)

    def do(self):
        """
        Run the migration.

        :return: 0 on success, -1 on failure (missing DB connection or any
                 exception, which is logged). Sessions are closed in every case.
        """
        try:
            self.info('UpgradeSnmpTemplate::do start!')
            if self.src_db_cursor is None or self.dst_db_cursor is None:
                self.error('UpgradeSnmpTemplate database connection is not available')
                return -1
            select_stmt = "select %s from %s" % (",".join(self.src_table_cols), self.src_table)
            datas = self.exec_query_sql(self.src_db_cursor, select_stmt) or []
            self.info('get para data success count:%d' % len(datas))

            self.to_UpgradePara(datas)
            self.info('to_UpgradePara done')
        except Exception:
            self.error('UpgradeSnmpTemplate got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            self.close_session()
        self.info('UpgradeSnmpTemplate::do done')
        return 0


if __name__ == '__main__':
    # Script entry point: migrate SNMP templates for the default "NCE" product,
    # bracketed by start/finish banners on stdout.
    # NOTE(review): the status returned by do() (0 on success, -1 on failure)
    # is discarded, so the process exit code does not reflect a failed run.
    upgrade_task = UpgradeSnmpTemplate(product_name="NCE")
    print('[INFO] UpgradeSnmpTemplate start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    upgrade_task.do()
    print('[INFO] UpgradeSnmpTemplate finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
