#!/usr/bin/python
# -*- coding: utf-8 -*-
import time
import uuid
import traceback

from getDBConnection import get_zenith_session
import common_tasks.const_sql as const_sql
from common_tasks.base_task import BaseTask

# Batch size: maximum number of rows handed to a single executemany()/exec_sql() call.
ONE_PACKAGE = 500


class UpgradeStelnetPara(BaseTask):
    """Migrate Telnet/STelnet login parameters from MCDB to neresdb.

    Rows are read from MCDB ``tbl_TelnetPara``.  Each row's ``DevID`` is
    resolved to a resource id through a temporary table in idmappingdb.
    Rows are then converted to the ``StelnetPara`` column layout and inserted
    in batches of ``ONE_PACKAGE`` rows.
    """

    # Prefix used to build the native id ("NE=<DevID>") that is looked up in idmappingdb.
    _NE_PREFIX = "NE="

    # Source AuthMode code (as str) -> destination authMode name.
    _AUTH_MODE_MAP = {
        "1": "NOAuth",
        "2": "Password",
        "3": "User",
        "4": "PrivateKey",
        "5": "PrivateKeyPassword",
    }

    # Source IsEnable flag (as str) -> destination privilegeState.
    _PRIVILEGE_STATE_MAP = {
        "0": False,
        "1": True,
    }

    # Source ProtocolType code (as str) -> destination protocolType name.
    _PROTOCOL_TYPE_MAP = {
        "0": "Telnet",
        "1": "STelnet",
    }

    def __init__(self, product_name="NCE"):
        """Set up table metadata and open the source (MCDB) and destination (neresdb) sessions.

        If a session cannot be opened, an error is logged and the remaining
        handles stay ``None``.  ``do()`` then reports failure instead of
        crashing on a missing attribute.
        """
        super(UpgradeStelnetPara, self).__init__()
        self.set_product_name(product_name)
        self.info("UpgradeStelnetPara init product_name is %s" % product_name)

        # Static metadata is assigned before any DB access so that every
        # attribute exists even when a session fails to open below.
        self.src_table = "tbl_TelnetPara"
        # NOTE: ProtocolVersion is selected but not used by convert_data().
        self.src_table_cols = (
            "DevID", "AuthMode", "LoginUser", "LoginPwd", "ProtocolVersion", "IsEnable",
            "PrivilegeLevel", "PrivilegePwd", "LoginTimeout", "ResponseTimeout",
            "Port", "ProtocolType", "UserPrivateKey", "UserPrivateKeyPwd")
        # Column name -> position in a row returned by the source SELECT.
        self.src_table_cols_index = {name: pos for pos, name in enumerate(self.src_table_cols)}
        self.dst_table = "StelnetPara"
        self.dst_table_cols = (
            "modelId", "fModelId", "tenantId", "neResId", "neId", "purpose", "connectStatus", "authMode",
            "userName", "password", "privilegeState", "privilegeLevel", "privilegePwd", "loginTimeout",
            "responseTimeout", "port", "protocolType", "userPrivateKey", "userPrivateKeyPwd",
            "createAt", "updateAt")
        # str(DevID) -> neResId, filled by get_resid_from_idmapping().
        self.idmapping_dic = {}

        # Handles default to None; do() checks them and close_session() skips them.
        self.src_db_session = None
        self.src_db_cursor = None
        self.dst_db_session = None
        self.dst_db_cursor = None

        self.src_db_session = get_zenith_session('MCDB', 'MCDB', product_name)
        if self.src_db_session is None:
            self.error("MCDB is None")
            return
        self.src_db_session.autocommit(True)
        self.src_db_cursor = self.src_db_session.cursor()

        self.dst_db_session = get_zenith_session('neresdb', 'neresdb', product_name)
        if self.dst_db_session is None:
            self.error("neresdb_sess is None")
            return
        self.dst_db_session.autocommit(True)
        self.dst_db_cursor = self.dst_db_session.cursor()

    def get_resid_from_idmapping(self, paras):
        """Fill ``self.idmapping_dic`` with ``str(DevID) -> resource id`` pairs from idmappingdb.

        The native ids ("NE=<DevID>") of all rows are loaded into a temporary
        table, which is then joined by ``const_sql.INNER_JOIN_TEMP_TABLE``.
        The idmappingdb session is always closed, even if a statement fails.

        :param paras: source rows, indexed according to ``src_table_cols``.
        """
        idmappingdb_sess = get_zenith_session('idmappingdb', 'idmappingdb', self.product_name)
        if idmappingdb_sess is None:
            self.error("get idmappingdb_sess session fail")
            return

        id_mapping_cursor = None
        try:
            idmappingdb_sess.autocommit(True)
            id_mapping_cursor = idmappingdb_sess.cursor()

            tmp_table_name = "tmp_neid_%s" % self.dst_table
            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
            id_mapping_cursor.execute(const_sql.CREATE_TEMP_TABLE % tmp_table_name)
            # NOTE(review): the insert targets "tmp_" + tmp_table_name; presumably
            # CREATE_TEMP_TABLE applies the same prefix -- confirm in const_sql.
            insert_stmt = "insert into tmp_%s values(:1)" % tmp_table_name

            dev_id_idx = self.src_table_cols_index.get("DevID")
            native_ids = []
            for para in paras:
                native_ids.append(("%s%s" % (self._NE_PREFIX, para[dev_id_idx]),))
                if len(native_ids) == ONE_PACKAGE:
                    id_mapping_cursor.executemany(insert_stmt, native_ids)
                    self.debug("one package:%s" % native_ids)
                    native_ids = []

            if native_ids:
                id_mapping_cursor.executemany(insert_stmt, native_ids)
                self.debug("last package datas: %s" % native_ids)

            query_stmt = const_sql.INNER_JOIN_TEMP_TABLE % tmp_table_name
            self.debug("query sql: %s" % query_stmt)
            id_mapping_cursor.execute(query_stmt)

            prefix_len = len(self._NE_PREFIX)
            for row in id_mapping_cursor.fetchall():
                res_id, native_id = row[0], row[1]
                # Remove the literal "NE=" prefix only.  str.lstrip("NE=") would also
                # eat any leading 'N'/'E'/'=' characters belonging to the DevID itself.
                if native_id.startswith(self._NE_PREFIX):
                    dev_id = native_id[prefix_len:]
                else:
                    dev_id = native_id
                self.idmapping_dic[dev_id] = res_id
                self.debug("result: %s: %s" % (native_id, res_id))

            id_mapping_cursor.execute(const_sql.DROP_TEMP_TABLE % tmp_table_name)
        finally:
            if id_mapping_cursor is not None:
                id_mapping_cursor.close()
            idmappingdb_sess.close()

    def convert_data(self, paras):
        """Convert one source row into a destination row.

        :param paras: a row from ``tbl_TelnetPara``, indexed according to ``src_table_cols``.
        :return: tuple aligned with ``dst_table_cols``; every non-None value is converted with ``str()``.
        """
        idx = self.src_table_cols_index

        def column(name):
            # Value of the named source column in this row.
            return paras[idx.get(name)]

        neId = column("DevID")
        neResId = self.idmapping_dic.get(str(neId))
        authMode = self._AUTH_MODE_MAP.get(str(column("AuthMode")))
        privilegeState = self._PRIVILEGE_STATE_MAP.get(str(column("IsEnable")))
        privilegePwd = column("PrivilegePwd")
        # When the privilege switch is not enabled and the privilege password
        # consists only of whitespace, store it as NULL.
        if not privilegeState and privilegePwd and privilegePwd.isspace():
            privilegePwd = None
        protocolType = self._PROTOCOL_TYPE_MAP.get(str(column("ProtocolType")))
        userPrivateKey = column("UserPrivateKey")
        if userPrivateKey and isinstance(userPrivateKey, bytes):
            self.info("[%s] userPrivateKey is bytes, trans to str" % str(neId))
            userPrivateKey = str(userPrivateKey.decode('utf-8'))
        # Read the clock once so that a freshly migrated row has createAt == updateAt.
        now_ms = int(time.time() * 1000)

        values = (
            str(uuid.uuid1()),            # modelId
            None,                         # fModelId
            "default-organization-id",    # tenantId
            neResId,                      # neResId
            neId,                         # neId
            "default",                    # purpose
            -1,                           # connectStatus
            authMode,                     # authMode
            column("LoginUser"),          # userName
            column("LoginPwd"),           # password
            privilegeState,               # privilegeState
            column("PrivilegeLevel"),     # privilegeLevel
            privilegePwd,                 # privilegePwd
            column("LoginTimeout"),       # loginTimeout
            column("ResponseTimeout"),    # responseTimeout
            column("Port"),               # port
            protocolType,                 # protocolType
            userPrivateKey,               # userPrivateKey
            column("UserPrivateKeyPwd"),  # userPrivateKeyPwd
            now_ms,                       # createAt
            now_ms,                       # updateAt
        )
        return tuple(None if v is None else str(v) for v in values)

    def to_UpgradePara(self, stelnet_templates):
        """Resolve resource ids, convert every source row and insert it into the destination table.

        :param stelnet_templates: rows selected from ``tbl_TelnetPara``.
        """
        if not stelnet_templates:
            # Nothing to migrate; skip the idmappingdb round trip entirely.
            self.info("no stelnet para data to upgrade")
            return

        self.get_resid_from_idmapping(stelnet_templates)
        col_names = ", ".join("`%s`" % name for name in self.dst_table_cols)
        # Positional bind placeholders ":1,:2,...,:N", one per destination column.
        val_ids = ",".join(":%d" % (pos + 1) for pos in range(len(self.dst_table_cols)))
        insert_stmt = "INSERT INTO `%s` (%s) VALUES(%s)" % (self.dst_table, col_names, val_ids)
        self.debug("%s" % insert_stmt)

        dst_cols_len = len(self.dst_table_cols)
        batch = []
        for template in stelnet_templates:
            data = self.convert_data(template)
            if len(data) == dst_cols_len:
                batch.append(data)
            else:
                self.warning("coverted data length not equals dst table cols, to be ignored")
            if len(batch) == ONE_PACKAGE:
                self.exec_sql(insert_stmt, batch)
                batch = []

        if batch:
            self.exec_sql(insert_stmt, batch)

    def close_session(self):
        """Close the destination and source cursors/sessions.

        Handles that were never opened (``None``) are skipped.  A failure to
        close one handle is logged and does not prevent closing the others.
        Each handle is reset to ``None`` afterwards, so repeated calls are
        harmless.
        """
        for attr in ("dst_db_cursor", "dst_db_session", "src_db_cursor", "src_db_session"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                self.warning("close %s failed: %s" % (attr, traceback.format_exc()))
            setattr(self, attr, None)

    def do(self):
        """Run the whole migration; sessions are always closed before returning.

        :return: 0 on success, -1 on failure (errors are logged, not raised).
        """
        try:
            self.info('UpgradeStelnetPara::do start!')
            if getattr(self, 'src_db_cursor', None) is None or getattr(self, 'dst_db_cursor', None) is None:
                # __init__ could not open one of the sessions; the cause was already logged there.
                self.error('UpgradeStelnetPara db session is not ready')
                return -1
            select_stmt = 'select %s from %s' % (",".join(self.src_table_cols), self.src_table)
            templates = self.exec_query_sql(self.src_db_cursor, select_stmt)
            self.info('get stelnet para data success count:%d' % len(templates))

            self.to_UpgradePara(templates)
            self.info('to_UpgradeStelnetPara done')
        except Exception:
            self.error('UpgradeStelnetPara got exception')
            self.error(traceback.format_exc())
            return -1
        finally:
            self.close_session()
        self.info('UpgradeStelnetPara::do done')
        return 0


if __name__ == '__main__':
    # Standalone entry point: build the task for the default product and run it once.
    # NOTE: the 0/-1 status returned by do() is not turned into a process exit code.
    migrator = UpgradeStelnetPara()
    print('[INFO] UpgradeStelnetPara start>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
    migrator.do()
    print('[INFO] UpgradeStelnetPara finished<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
