# encoding=utf-8
"""
功 能：封装对ospatchdb数据库的操作
版权信息：华为技术有限公司，版权所有(C) 2020-2029
修改记录：2020-06-01 12:00 创建
"""
import json
import sys
from datetime import datetime
import easyhttputil
import pyzenith
from commonlog import Logger
from util import ossext

# Module-wide logger from the project's commonlog.Logger, requested with the
# running script's path (sys.argv[0]) -- presumably used as the logger name;
# confirm against commonlog.
LOG = Logger().getinstance(sys.argv[0])
# SQL templates for TAL_PATCH_INFO, filled in with the % operator.
# PATCHINSTALLSTATUS is set to '2' on install and to '1' on uninstall.
# NOTE(review): values are interpolated into the SQL text rather than bound
# as parameters, so node ids / package names must not contain quotes.

# placeholders: (install_time, node_id, package_name)
UPDATE_INSTALL_SQL = ("UPDATE TAL_PATCH_INFO SET PATCHINSTALLSTATUS='2',INSTALLTIME='%s' "
                      "WHERE NODEID='%s' and PATCHPACKAGENAME='%s'")
# placeholders: (node_id, package_name, install_time)
INSERT_INSTALL_SQL = ("INSERT INTO TAL_PATCH_INFO (NODEID, PATCHPACKAGENAME, "
                      "PATCHINSTALLSTATUS, INSTALLTIME) VALUES ('%s','%s','2','%s')")
# placeholders: (install_time, node_id, package_name)
UPDATE_UNINSTALL_SQL = ("UPDATE TAL_PATCH_INFO SET PATCHINSTALLSTATUS='1',INSTALLTIME='%s' "
                        "WHERE NODEID='%s' and PATCHPACKAGENAME='%s'")


class ModifyFlag:
    """
    Read and update patch-package install records in TAL_PATCH_INFO.

    Every statement runs on a fresh connection opened by get_db_connect();
    send_db_sql always closes it again before returning.

    Result codes (main uses integer results as the process exit code):
      query_pkg_status:     101 not installed, history exists
                            102 not installed, no history
                            the TAL_PATCH_INFO rows when the package is
                            recorded as installed on the queried nodes
                            (the 103 code named in older comments is never
                            returned; callers rely on the non-int result)
                            False when the query fails
      modify_pkg_install:   111 failure, 112 success
      modify_pkg_uninstall: 121 failure, 122 success
    """

    # Parameter file whose "exportParam" entry is posted to export the container list.
    _NCE_PARAM_FILE = "/opt/upgrade/easysuite_upgrade/scripts/common/NCE-Common/nce_params_secure.json"
    # Selected columns: NODEID, PATCHPACKAGENAME, PATCHINSTALLSTATUS, INSTALLTIME.
    _QUERY_SQL = ("select NODEID,PATCHPACKAGENAME,PATCHINSTALLSTATUS,INSTALLTIME"
                  " from TAL_PATCH_INFO")
    # PATCHINSTALLSTATUS values written by the module-level SQL templates.
    _STATUS_UNINSTALLED = "1"
    _STATUS_INSTALLED = "2"

    def __init__(self, params):
        """
        :param params: dict with keys node_id (list of node ids), pkg_name,
                       db_ip, db_name and db_port (converted to str).
        """
        self.node_id = params.get("node_id")
        self.pkg_name = params.get("pkg_name")
        self.db_ip = params.get("db_ip")
        self.db_name = params.get("db_name")
        self.db_port = str(params.get("db_port"))
        self.db_connect = ''  # cursor of the open connection, '' when closed
        self.connect = ''  # open database connection, '' when closed

    def get_db_connect(self):
        """
        Open a database connection and cursor for self.db_name.

        The encrypted password of the database user is read from the container
        list exported through the deploy-proxy REST API (the first
        'managedbsvr-*' entry whose adminPassword is not ""), decrypted with
        ossext.Cipher and used to connect. On success self.connect holds the
        connection and self.db_connect its cursor.

        :return: True when connected; False when no export parameter or
                 password was found, or connecting failed. Errors while reading
                 the parameter file or calling the REST API propagate.
        """
        db_pwd = ""
        with open(self._NCE_PARAM_FILE, "r", encoding='utf-8') as nce_param:
            nce_param_data = json.load(nce_param)
        export_param = nce_param_data.get("exportParam", "")
        if not export_param:
            return False
        # Use whichever retrying POST helper this easyhttputil version provides.
        if hasattr(easyhttputil, 'httppostWithRetry'):
            http_post_func = easyhttputil.httppostWithRetry
        else:
            http_post_func = easyhttputil.http_post_with_retry
        response = http_post_func(
            '/rest/plat/deploy-proxy/v1/containerlist/action?action-id=export-containerlist', export_param, retry=3,
            sleep=5)
        file_dict = json.loads(response.decode('utf-8'))
        for one_key, container in file_dict.items():
            if not one_key.startswith('managedbsvr-'):
                continue
            if container.get("adminPassword") == "":
                continue
            db_pwd = container.get('dbList').get(self.db_name).get('dbUserPasswd')
            break
        if not db_pwd:
            return False
        de_db_pwd = ossext.Cipher.decrypt(db_pwd)
        try:
            self.connect = pyzenith.connect(self.db_ip, self.db_name, de_db_pwd, self.db_port)
            self.connect.autocommit(True)
            self.db_connect = self.connect.cursor()
        except Exception:
            LOG.warning("[get_db_connect]Exception:get_db_connect")
            # Release a connection opened before autocommit()/cursor() failed.
            self._close_db()
            return False
        return True

    def _close_db(self):
        """Close the cursor and connection opened by get_db_connect, if any."""
        for handle in (self.db_connect, self.connect):
            close = getattr(handle, 'close', None)
            if close is None:
                continue
            try:
                close()
            except Exception:
                # Best effort: the statement's outcome is already decided.
                LOG.warning("[send_db_sql]Exception:close db error")
        self.db_connect = ''
        self.connect = ''

    def send_db_sql(self, sql, action='query'):
        """
        Execute one SQL statement on a fresh connection and return its result.

        The connection is always closed before returning, also when the
        statement fails.

        :param sql: SQL statement text, executed as-is.
        :param action: 'query' fetches and returns all rows; any other value
                       commits the statement.
        :return: fetchall() rows (possibly empty) for 'query', True for other
                 actions, False when connecting or executing fails.
        """
        LOG.info("[send_db_sql]sql:%s" % sql)
        if not self.get_db_connect():
            LOG.error("[send_db_sql]sql:%s" % sql)
            return False
        result = True
        try:
            self.db_connect.execute(sql)
            if action == 'query':
                result = self.db_connect.fetchall()
                if isinstance(result, tuple):
                    LOG.info("[send_db_sql]result:%s" % list(result))
            else:
                self.connect.commit()
        except Exception:
            LOG.warning("[send_db_sql]Exception:send_db_sql error")
            LOG.error("[send_db_sql]sql:%s" % sql)
            return False
        finally:
            self._close_db()
        return result

    def _fetch_rows(self):
        """
        Return all TAL_PATCH_INFO rows, or None when the query fails.

        An empty table yields an empty sequence, which is NOT a failure.
        """
        rows = self.send_db_sql(self._QUERY_SQL)
        if rows is False or rows is None:
            LOG.error("[query_pkg_status]:sql:%s" % self._QUERY_SQL)
            return None
        return rows

    @staticmethod
    def _node_status_map(rows, node_ids, pkg_name):
        """
        Map each node in node_ids to its recorded PATCHINSTALLSTATUS for pkg_name.

        :param rows: (NODEID, PATCHPACKAGENAME, PATCHINSTALLSTATUS, ...) tuples.
        :param node_ids: node ids of interest.
        :param pkg_name: package name to match.
        :return: {node_id: "1" | "2"}; nodes without such a record are absent.
                 When a node has several records, the last one in rows wins.
        """
        wanted = set(node_ids)
        status_map = {}
        for row in rows:
            node, name, status = row[0], row[1], row[2]
            if node in wanted and name == pkg_name and status in ("1", "2"):
                status_map[node] = status
        return status_map

    def query_pkg_status(self, node_id=None):
        """
        Query the install status of self.pkg_name.

        :param node_id: query only this node; defaults to every node in
                        self.node_id.
        :return: 102 if none of the nodes has a record (including an empty
                 table), 101 if any node is recorded as uninstalled, otherwise
                 the full list of TAL_PATCH_INFO rows (package installed).
                 False when the query fails.
        """
        node_ids = [node_id] if node_id else self.node_id
        rows = self._fetch_rows()
        if rows is None:
            return False
        status_map = self._node_status_map(rows, node_ids, self.pkg_name)
        if not status_map:
            return 102
        if self._STATUS_UNINSTALLED in status_map.values():
            return 101
        return rows

    def modify_pkg_install(self):
        """
        Record self.pkg_name as installed (status '2') on every node in self.node_id.

        Nodes already recorded as installed are left untouched, nodes recorded
        as uninstalled are updated, and nodes without a record get a new row.
        INSTALLTIME is the current local time formatted as YYYYmmddHHMMSS.

        :return: 112 on success, 111 when the status query or any write fails.
        """
        now_time = datetime.now().strftime('%Y%m%d%H%M%S')
        rows = self._fetch_rows()
        if rows is None:
            return 111
        status_map = self._node_status_map(rows, self.node_id, self.pkg_name)
        for one_node_id in self.node_id:
            status = status_map.get(one_node_id)
            if status == self._STATUS_INSTALLED:
                continue
            if status == self._STATUS_UNINSTALLED:
                sql = UPDATE_INSTALL_SQL % (now_time, one_node_id, self.pkg_name)
            else:
                sql = INSERT_INSTALL_SQL % (one_node_id, self.pkg_name, now_time)
            sql_result = self.send_db_sql(sql, action='commit')
            if not sql_result:
                LOG.error(
                    "[modify_pkg_install]node_id:%s result:%s" % (one_node_id, sql_result))
                return 111
            # A node listed twice must not be inserted twice.
            status_map[one_node_id] = self._STATUS_INSTALLED
        return 112

    def modify_pkg_uninstall(self):
        """
        Record self.pkg_name as uninstalled (status '1') on every node in self.node_id.

        Only nodes currently recorded as installed are updated; nodes that are
        already uninstalled or have no record are left untouched.
        INSTALLTIME is the current local time formatted as YYYYmmddHHMMSS.

        :return: 122 on success, 121 when the status query or any write fails.
        """
        now_time = datetime.now().strftime('%Y%m%d%H%M%S')
        rows = self._fetch_rows()
        if rows is None:
            return 121
        status_map = self._node_status_map(rows, self.node_id, self.pkg_name)
        for one_node_id in self.node_id:
            if status_map.get(one_node_id) != self._STATUS_INSTALLED:
                continue
            sql_result = self.send_db_sql(UPDATE_UNINSTALL_SQL % (now_time, one_node_id,
                                                                  self.pkg_name),
                                          action='commit')
            if not sql_result:
                LOG.error(
                    "[modify_pkg_uninstall]node_id:%s result:%s" % (one_node_id, sql_result))
                return 121
            status_map[one_node_id] = self._STATUS_UNINSTALLED
        return 122


def main(argv):
    """
    Dispatch one patch-status operation from command-line arguments.

    argv layout: [script, action, comma-separated node ids, package name,
    database IP]. action "query" runs query_pkg_status, "install" runs
    modify_pkg_install, and any other value runs modify_pkg_uninstall.

    :return: whatever the selected ModifyFlag method returns.
    """
    action = argv[1]
    params = {
        "node_id": argv[2].split(","),
        "pkg_name": argv[3],
        "db_ip": argv[4],
        "db_name": "ospatchdb",
        "db_port": "32080",
    }
    flag = ModifyFlag(params)
    handlers = {
        "query": flag.query_pkg_status,
        "install": flag.modify_pkg_install,
    }
    # NOTE(review): unrecognised actions fall through to uninstall.
    return handlers.get(action, flag.modify_pkg_uninstall)()


if __name__ == '__main__':
    result = main(sys.argv)
    # bool is a subclass of int: without this check a failed query (False)
    # would exit 0, the same code an installed package produces.
    if result is False:
        sys.exit(1)
    if isinstance(result, int):
        sys.exit(result)
    sys.exit(0)
