#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import traceback
import uuid

from util import ossext

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

ONE_PACKAGE = 500


class UpgradeNetconfTemplate(BaseTask):
    """Migrate NETCONF templates from netconftemplateservicedb to neresdb.

    Rows are read from ``t_netconf_template`` in the source database,
    converted to the ``NetConfTemplate`` column layout, inserted into the
    destination database in batches of at most ``ONE_PACKAGE`` rows via
    ``self.exec_sql``, and finally the source table is emptied.
    """

    # Source ``authMode`` code (compared as a string) -> destination
    # authentication-mode name. Unknown codes map to None instead of raising.
    _AUTH_MODE_MAP = {
        "1": "User",
        "2": "PrivateKey",
        "3": "PrivateKeyPassword",
    }

    def __init__(self, product_name="NCE"):
        """Open the source and destination database sessions.

        :param product_name: product whose databases are opened through
            ``get_zenith_session``; defaults to "NCE".

        If a session cannot be opened, an error is logged and the matching
        session/cursor attributes stay ``None``. ``do()`` then refuses to run
        (returning -1) instead of failing with an AttributeError.
        """
        super(UpgradeNetconfTemplate, self).__init__()
        # Static table metadata is set first so the object is fully formed
        # even when one of the database connections below cannot be opened.
        self.src_table = "t_netconf_template"
        self.src_table_cols = ("templateID", "templateName", "userName", "password", "port", "privateKey",
                               "passwordPhrase", "authMode", "loginTimeout", "responseTimeout", "templateType")
        self.src_table_cols_index = {name: idx for idx, name in enumerate(self.src_table_cols)}
        self.dst_table = "NetConfTemplate"
        self.dst_table_cols = ("modelId", "fModelId", "tenantId", "templateName", "templateDesc", "accessMode",
                               "userName", "password", "port", "userPrivateKey", "userPrivateKeyPwd", "authMode",
                               "loginTimeout", "responseTimeout", "createAt", "updateAt")
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.set_product_name(product_name)
        self.info("UpgradeNetconfTemplate init product_name is %s" % product_name)

        self.src_db_session = get_zenith_session('netconftemplateservicedb', 'netconftemplateservicedb', product_name)
        if self.src_db_session is None:
            self.error("netconftemplateservicedb is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            # The already-open source session is released by do()/close_session().
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def convert_data(self, paras):
        """Convert one source row into a destination row.

        :param paras: one row fetched from the source table, with values in
            the order of ``self.src_table_cols``.
        :return: a tuple in the order of ``self.dst_table_cols``; every
            non-None value is converted with ``str``.
        """
        def encrypt(text):
            """Encrypt ``text`` with ``ossext.Cipher.encrypt0``.

            :param text: data to encrypt; ``None`` is passed through.
            :return: the encrypted data, or ``None`` if encryption failed
                (the failure is logged as a warning).
            """
            if text is None:
                return None
            # Encryption may raise; treat it as best-effort so one bad row
            # does not abort the whole migration. Exception (not
            # BaseException) lets KeyboardInterrupt/SystemExit propagate.
            try:
                return ossext.Cipher.encrypt0(text)
            except Exception as exc:
                self.warning(exc)
                self.warning("encrypt data failed.")
            return None

        def column(name):
            """Return the value of source column ``name`` from this row."""
            return paras[self.src_table_cols_index[name]]

        # A single epoch-millisecond timestamp, so a freshly migrated row has
        # createAt == updateAt.
        now_ms = str(int(time.time() * 1000))

        # Invalid or missing authMode values map to None rather than raising.
        auth_mode = self._AUTH_MODE_MAP.get(str(column("authMode")))

        # NOTE(review): password and passwordPhrase are copied as-is, unlike
        # privateKey which is encrypted here; presumably they are already
        # encrypted in the source table -- confirm.
        row = (
            str(uuid.uuid1()),              # modelId
            None,                           # fModelId
            "default-organization-id",      # tenantId
            column("templateName"),         # templateName
            None,                           # templateDesc
            column("templateType"),         # accessMode
            column("userName"),             # userName
            column("password"),             # password
            column("port"),                 # port
            encrypt(column("privateKey")),  # userPrivateKey
            column("passwordPhrase"),       # userPrivateKeyPwd
            auth_mode,                      # authMode
            column("loginTimeout"),         # loginTimeout
            column("responseTimeout"),      # responseTimeout
            now_ms,                         # createAt
            now_ms,                         # updateAt
        )
        return tuple(None if value is None else str(value) for value in row)

    def to_UpgradePara(self, datas):
        """Convert source rows and batch-insert them into the destination table.

        :param datas: iterable of source rows (see ``convert_data``).

        Rows are written through ``self.exec_sql`` in batches of at most
        ``ONE_PACKAGE`` rows. A converted row whose width does not match
        ``self.dst_table_cols`` is skipped with a warning.
        """
        col_names = "`" + "`, `".join(self.dst_table_cols) + "`"
        # Positional bind placeholders ":1,:2,...,:N", one per destination column.
        val_ids = ",".join(":%d" % (idx + 1) for idx in range(len(self.dst_table_cols)))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("insert sql stmt: %s" % insert_stmt)

        batch = []
        for data in datas:
            row = self.convert_data(data)
            if len(row) != len(self.dst_table_cols):
                self.warning("skip row with unexpected column count: %d" % len(row))
                continue
            batch.append(tuple(row))
            if len(batch) >= ONE_PACKAGE:
                self.exec_sql(insert_stmt, batch)
                # Start a new list rather than clearing: exec_sql may keep a reference.
                batch = []

        if batch:
            self.exec_sql(insert_stmt, batch)

    def close_session(self):
        """Close the destination and then the source cursor/session, best-effort.

        Each handle is closed independently, so a failure on one does not leak
        the others. Handles that were never opened (``None``) are skipped, and
        every handle is reset to ``None`` so repeated calls are harmless.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                self.warning("close %s failed" % attr)
                self.warning(traceback.format_exc())
            setattr(self, attr, None)

    def _fetch_source_rows(self):
        """Read all rows of the source table.

        :return: the fetched rows, or an empty list if the query failed (the
            failure is logged as a warning and treated as "nothing to migrate").
        """
        select_stmt = "select %s from %s" % (",".join(self.src_table_cols), self.src_table)
        self.debug("execute sql: %s" % select_stmt)
        try:
            self.src_db_cursor.execute(select_stmt)
            return self.src_db_cursor.fetchall()
        except Exception as exc:
            self.warning("execute sql failed, err is: %s" % str(exc))
            return []

    def do(self):
        """Run the migration.

        :return: 0 on success; -1 when a database session is unavailable or an
            exception occurs during migration. Database handles are always
            closed before returning.
        """
        self.info('UpgradeNetconfTemplate::do start!')
        if getattr(self, "src_db_cursor", None) is None or getattr(self, "dst_db_cursor", None) is None:
            self.error('UpgradeNetconfTemplate database session is not available')
            self.close_session()
            return -1

        try:
            datas = self._fetch_source_rows()
            final_data = list(datas) if datas else []
            self.info('get para data success count: %d' % len(final_data))

            self.to_UpgradePara(final_data)
            self.info('to_UpgradePara done')
            # Only reached when every insert batch was submitted without raising.
            self.to_clear_old_table(final_data)
        except Exception:
            self.error('UpgradeNetconfTemplate got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            self.close_session()

        self.info('UpgradeNetconfTemplate::do done')
        return 0

    def to_clear_old_table(self, datas):
        """Delete every row from the source table after migration.

        :param datas: the rows that were migrated. When empty, nothing is
            deleted and a warning is logged.

        A failing delete is logged and swallowed (best-effort cleanup).
        """
        if not datas:
            self.warning('no template need to delete')
            return

        self.info('clear %s start' % self.src_table)
        clear_sql = 'delete from ' + self.src_table
        self.debug("execute sql: %s" % clear_sql)
        try:
            self.src_db_cursor.execute(clear_sql)
            self.info('clear %s end' % self.src_table)
        except Exception:
            self.error('clear %s got exception' % self.src_table)
            self.error(traceback.format_exc())


if __name__ == '__main__':
    # Constructing the task opens both database sessions before the banner.
    migration_task = UpgradeNetconfTemplate()
    print('UpgradeNetconfTemplate start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    # Run the migration; its status code is not propagated to the process exit code.
    migration_task.do()
    print('UpgradeNetconfTemplate finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
