#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import traceback

from getDBConnection import get_zenith_session
from common_tasks.base_task import BaseTask

# Maximum number of rows buffered before one exec_sql batch insert in to_UpgradePara.
ONE_PACKAGE = 500


class UpgradeProtocolSecurityConfig(BaseTask):
    """Upgrade task that forces the ``telnet_config`` protocol-security entry to ``"true"``.

    If the ``ProtocolSecurityConfig`` table in ``neresdb`` has no row whose
    ``configKey`` is ``telnet_config``, one is inserted for
    ``default-organization-id``; otherwise the existing row(s) get
    ``configValue = "true"``.
    """

    def __init__(self, product_name="NCE"):
        """Prepare table metadata and open the source/destination ``neresdb`` sessions.

        :param product_name: product name passed to ``set_product_name`` and
            ``get_zenith_session`` (default ``"NCE"``).

        If a session cannot be opened, an error is logged and the related
        session/cursor attributes stay ``None``. ``do()`` then reports failure
        instead of crashing on a missing attribute.
        """
        super(UpgradeProtocolSecurityConfig, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeProtocolSecurityConfig init product_name is %s" % product_name)

        # Table metadata is set before the sessions are opened so the instance
        # is fully populated even when a session fails to open below.
        self.src_table = "ProtocolSecurityConfig"
        # The source query must return columns in exactly this order.
        self.src_table_cols = ("tenantId", "configKey", "configValue")
        self.src_table_cols_index = {col: idx for idx, col in enumerate(self.src_table_cols)}
        self.dst_table = "ProtocolSecurityConfig"
        self.dst_table_cols = ("tenantId", "configKey", "configValue", "createAt", "updateAt")
        self.idmapping_dic = {}
        self.logicid_to_addr = {}

        # Always define these so close_session()/do() can tell what was opened.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.src_db_session is None:
            self.error("neresdb is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def convert_data(self, paras):
        """Build the destination row for the default ``telnet_config`` entry.

        :param paras: source row; ignored, because the inserted values are fixed.
        :return: tuple in ``dst_table_cols`` order (tenantId, configKey,
            configValue, createAt, updateAt). Every value is converted to ``str``;
            ``None`` would be kept as ``None``.
        """
        # Take a single timestamp (epoch milliseconds) so a freshly created row
        # has createAt == updateAt. Two time.time() calls could straddle a ms tick.
        now_ms = int(time.time() * 1000)
        row = ("default-organization-id", "telnet_config", "true", now_ms, now_ms)
        return tuple(None if value is None else str(value) for value in row)

    def to_UpgradePara(self, datas):
        """Convert ``datas`` with ``convert_data`` and insert them into ``dst_table``.

        :param datas: iterable of source rows.

        Rows are sent to ``exec_sql`` in batches of ``ONE_PACKAGE``, followed by
        one final partial batch. A converted row whose length does not match
        ``dst_table_cols`` is skipped.
        """
        col_names = "`" + ("`, `".join(self.dst_table_cols)) + "`"
        # Positional bind placeholders ":1, :2, ..." (one per destination column).
        val_ids = ":" + (",:".join(str(x + 1) for x in range(len(self.dst_table_cols))))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("insert sql stmt: %s" % insert_stmt)
        list_datas = []
        for data in datas:
            self.debug("original data is: %s, length is:%s" % (data, len(data)))
            data = self.convert_data(data)
            self.debug("coverted date is: %s, length is:%s" % (data, len(data)))
            if len(data) == len(self.dst_table_cols):
                list_datas.append(data)
            if len(list_datas) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, list_datas)
                list_datas = []

        if list_datas:
            self.debug("data is:%s" % list_datas)
            self.exec_sql(insert_stmt, list_datas)

    def close_session(self):
        """Close both cursors and sessions, skipping any that were never opened.

        Each resource is closed independently, so a failure closing one does not
        leak the others. Close failures are logged instead of raised. Each
        attribute is reset to ``None`` afterwards, so repeated calls are harmless.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            resource = getattr(self, attr, None)
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                self.error("UpgradeProtocolSecurityConfig close %s failed" % attr)
                self.error(traceback.format_exc())
            setattr(self, attr, None)

    def do(self):
        """Run the upgrade: insert ``telnet_config`` if absent, otherwise set it to ``"true"``.

        :return: 0 on success. Returns -1 if a database session is unavailable
            or any step raised (the traceback is logged). Sessions are always
            closed before returning.
        """
        try:
            self.info('UpgradeProtocolSecurityConfig::do start!')
            if self.src_db_cursor is None or self.dst_db_cursor is None:
                self.error('UpgradeProtocolSecurityConfig db session is unavailable')
                return -1

            select_stmt = "select `%s` from `%s` where `configKey`='telnet_config'" \
                          % ("`,`".join(self.src_table_cols), self.src_table)
            templates = self.exec_query_sql(self.src_db_cursor, select_stmt)
            self.info('get stelnet para data success count:%d' % len(templates))

            if not templates:
                # No telnet_config row exists: build one placeholder source row
                # so convert_data()/to_UpgradePara() insert a new record.
                self.to_UpgradePara([("tenantId", "configKey", "configValue")])
            else:
                # telnet_config already exists: update its value in place.
                # NOTE(review): this matches every tenant's telnet_config row
                # (no tenantId filter) and does not refresh updateAt. It also
                # uses a '?' placeholder while the insert uses ':N'. Confirm
                # all three are intended and that exec_sql's driver accepts both styles.
                update_stmt = "update `%s` set `configValue`=? where `configKey`='telnet_config'" % self.dst_table
                self.exec_sql(update_stmt, [("true",)])

            self.info('to_UpgradeProtocolSecurityConfig done')
        except Exception:
            self.error('UpgradeProtocolSecurityConfig got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            # Single cleanup point for the success, failure and early-return paths.
            self.close_session()
        self.info('UpgradeProtocolSecurityConfig::do done')
        return 0

        
if __name__ == '__main__':
    import sys

    # Run the upgrade standalone and report the outcome through the exit status:
    # do() returns 0 on success and -1 on failure.
    tool = UpgradeProtocolSecurityConfig()
    print('[INFO] UpgradeProtocolSecurityConfig start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    ret = tool.do()
    print('[INFO] UpgradeProtocolSecurityConfig finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
    # Previously the result was discarded, so a failed upgrade exited 0.
    sys.exit(0 if ret == 0 else 1)
