# -*- coding: utf-8 -*-
import traceback
from utils.common.fic_base import TestCase
import utils.common.log as logger
from utils.common.message import Message
from utils.common.exception import FCDException
from utils.common.ssh_util import Ssh as ssh
from utils.business.hardware_driver_util import DriverApi
from plugins.DistributedStorage.scripts.utils.common.DeployConstant import DeployConstant
from plugins.DistributedStorageReplication.scripts.common_utils.config_params import Params
from plugins.DistributedStorage.scripts.logic.InstallOperate import InstallOperate, InstallDriverOperate


class InstallDriver(TestCase):
    """Install hardware drivers on storage nodes and reboot the nodes that received drivers.

    The target hosts depend on ``process``:
      * ``"rep"``: nodes listed under ``replication_cluster_nodes_info`` in the
        ``FusionStorageBlockReplication`` service parameters;
      * anything else (default ``"norep"``): ``fs_args['osd_list']``.
    """

    def __init__(self, project_id, pod_id, fs_args, condition=None, metadata=None, process="norep", **kwargs):
        super(InstallDriver, self).__init__(project_id, pod_id)
        self.condition = condition
        self.metadata = metadata
        self.more_args = kwargs
        self.process = process
        self.operate = InstallOperate(self.project_id, self.pod_id, fs_args)
        self.service_name = "FusionStorageBlockReplication"
        self.osd_list = fs_args.get('osd_list')
        # Filled by InstallDriverOperate.group_by_account_info during procedure();
        # each value is unpacked as (user, passwd, root_pwd, nodes_ip_and_driver_info).
        self.need_install_driver_hosts = dict()

    def procedure(self):
        """Run the driver installation step.

        Checks that every target host answers ping, collects the drivers each
        host needs, uploads the install scripts, installs the drivers grouped
        by account, then reboots the hosts that had drivers to install.

        :return: ``Message()`` on success (or when there is no host to handle);
                 ``Message(500, FCDException)`` on failure.
        """
        try:
            host_list, host_ip_list = self._get_host_list()
            # No target node (e.g. compute nodes): skip driver installation.
            if not host_list:
                logger.info("No node needs driver installation, skip.")
                return Message()
            self._check_hosts_reachable(host_ip_list)

            all_drivers, reboot_host_list = self._prepare_hosts(host_list)

            self._install_drivers_for_all_node(all_drivers)

            if reboot_host_list:
                logger.info("The drivers take effect after the storage nodes is restarted.")
                self.operate.reboot_host_list(reboot_host_list)

        except FCDException as e:
            logger.error("Failed install driver, details: %s" % str(e))
            logger.error(traceback.format_exc())
            return Message(500, e)
        except Exception as e:
            logger.error("Failed install driver, details: %s" % str(e))
            logger.error(traceback.format_exc())
            return Message(500, FCDException(626109, str(e)))
        return Message()

    def _check_hosts_reachable(self, host_ip_list):
        """Raise ``Exception`` if ``operate.check_host_list_ping`` reports unreachable IPs."""
        check_ret, fail_list = self.operate.check_host_list_ping(host_ip_list)
        if not check_ret:
            # 1-tuple so %-formatting cannot misfire if fail_list is itself a tuple.
            err_msg = "There are some nodes%s that are not up to reach" % (fail_list,)
            logger.error(err_msg)
            raise Exception(err_msg)

    def _prepare_hosts(self, host_list):
        """Collect each host's pending drivers and upload the install scripts to it.

        Also records every host with drivers in ``self.need_install_driver_hosts``.
        All SSH clients opened here are closed before returning, even on error.

        :param host_list: host dicts with keys 'om_ip', 'user', 'passwd', 'root_pwd'
        :return: tuple ``(all_drivers, reboot_host_list)``. ``all_drivers`` is the
                 concatenation of every host's driver list; ``reboot_host_list``
                 holds the hosts that have at least one driver to install.
        """
        logger.info("Get all driver packages")
        # Start from a clean grouping so a re-run does not reuse entries from a previous attempt.
        self.need_install_driver_hosts.clear()
        client_list = list()
        all_drivers = list()
        reboot_host_list = list()
        try:
            for host in host_list:
                host_om_ip = host.get('om_ip')
                user = host.get('user')
                passwd = host.get('passwd')
                root_pwd = host.get('root_pwd')
                ssh_client = self.operate.create_ssh_root_client(host_om_ip, user, passwd, root_pwd)
                # Register the client immediately so it is closed even if a later call raises.
                client_list.append(ssh_client)
                driver_list = self._get_driver_list(ssh_client, host)
                if not driver_list:
                    logger.info("There is no driver to need to install in current node[%s]." % host_om_ip)
                    continue
                logger.info("The driver list of node[%s] is %s" % (host_om_ip, driver_list))
                all_drivers.extend(driver_list)
                reboot_host_list.append(host)
                InstallDriverOperate.group_by_account_info(host, self.need_install_driver_hosts, driver_list)
                self._upload_install_scripts_to_node(ssh_client, host_om_ip, user, passwd)
        finally:
            self.close_client_list(client_list)
        return all_drivers, reboot_host_list

    def close_client_list(self, client_list):
        """Close every non-empty SSH client in ``client_list``.

        Closing is best effort: a failure on one client is logged and does not
        stop the remaining clients from being closed. It also does not mask an
        exception that is already propagating when this is called from ``finally``.
        """
        for client in client_list:
            if not client:
                continue
            try:
                ssh.ssh_close(client)
            except Exception as e:
                logger.error("Failed to close ssh client, details: %s" % str(e))

    def _get_host_list(self):
        """Return the target hosts and their management IPs.

        :return: tuple ``(host_list, host_ip_list)``. ``host_list`` items are dicts
                 with keys 'om_ip', 'user', 'passwd', 'root_pwd'. Both lists are
                 empty when no node is configured.
        """
        if self.process == "rep":
            all_params = Params(self.project_id, self.pod_id, self.service_name).get_params_dict()
            # A missing entry means "no node" instead of a TypeError while iterating None.
            rep_node_list = all_params.get("replication_cluster_nodes_info") or []
            host_list = [
                {
                    'om_ip': rep_node.get('replication_manage_ip'),
                    'user': rep_node.get('replication_manage_ip_common_username'),
                    'passwd': rep_node.get('replication_manage_ip_common_password'),
                    'root_pwd': rep_node.get('replication_manage_ip_root_password'),
                }
                for rep_node in rep_node_list
            ]
        else:
            # osd_list may be absent from fs_args. Treat that as "no node" so
            # procedure() reaches its skip guard instead of failing on None.
            host_list = self.osd_list or []
        host_ip_list = [host.get('om_ip') for host in host_list]
        return host_list, host_ip_list

    @staticmethod
    def _upload_install_scripts_to_node(ssh_client, host_om_ip, user, passwd):
        """
        Upload the driver install scripts to a storage node.

        Recreates /tmp/rpm_install_temp as ``user`` (mode 750). Uploads every
        script in DeployConstant.DRIVER_SCRIPT with ``user``'s credentials, then
        chowns the directory recursively to root:root.
        :param ssh_client: root SSH session on the node
        :param host_om_ip: management IP of the node
        :param user: common (non-root) account used for the upload
        :param passwd: password of ``user``
        :return: None
        """
        remote_path = "/tmp/rpm_install_temp"
        # Remove any leftover directory from a previous run.
        cmd_tmp = "[ -d %s ] && rm -rf %s" % (remote_path, remote_path)
        ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        # Created as ``user``, presumably so the put_file upload (which logs in as
        # ``user``) can write into it.
        cmd_tmp = 'su - %s -s /bin/bash -c "mkdir -p %s; chmod 750 %s"' % (
            user, remote_path, remote_path)
        cmd_res = ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        logger.info("Make an directory on host[%s], result: %s" % (host_om_ip, str(cmd_res)))

        logger.info("Start to upload shell install script to host[%s]" % host_om_ip)
        for script in DeployConstant.DRIVER_SCRIPT:
            res = ssh.put_file(host_om_ip, user, passwd, script, remote_path)
            logger.info("Upload package[%s] complete[%s]" % (script, res))
        cmd_tmp = 'chown -h -R root:root %s' % remote_path
        ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        logger.info("Modify %s right success" % remote_path)

    def _install_drivers_for_all_node(self, all_drivers):
        """
        Install drivers on every node, one account group at a time.

        For each account group in ``self.need_install_driver_hosts``, the
        drivers are ordered by InstallDriverOperate.sort_drivers. Each
        ``(driver, ip_list)`` pair is installed through DriverApi with a
        600-second timeout.
        :param all_drivers: list of drivers to install
        :return: None
        """
        logger.info("Install all package on storage nodes")
        for same_account_nodes in self.need_install_driver_hosts.values():
            user, passwd, root_pwd, nodes_ip_and_driver_info = same_account_nodes
            driver_and_ip_list = InstallDriverOperate.sort_drivers(nodes_ip_and_driver_info, all_drivers)
            # 1-tuple so %-formatting cannot misfire if sort_drivers returns a tuple.
            logger.info("Installing All Drivers and Nodes is : %s" % (driver_and_ip_list,))
            for driver, ip_list in driver_and_ip_list:
                driver_obj = DriverApi(self.project_id, ip_list, user, passwd, root_pwd)
                driver_obj.install_driver(driver, timeout=600)

    def _get_driver_list(self, ssh_client, host):
        """Return the names of the drivers that ``host`` needs installed.

        Queries DriverApi for the node's FusionStorage driver info for its CPU
        architecture. The result is then filtered by ``operate.check_es3000_driver``.

        :param ssh_client: root SSH session on the host
        :param host: host dict with keys 'om_ip', 'user', 'passwd', 'root_pwd'
        :return: driver list (possibly empty)
        :raises Exception: if DriverApi returns no entry for the host
        """
        host_om_ip = host.get('om_ip')
        user = host.get('user')
        passwd = host.get('passwd')
        root_pwd = host.get('root_pwd')
        cpu_arch = self.operate.get_cpu_arch(ssh_client)
        driver_obj = DriverApi(self.project_id,
                               os_ip_list=[host_om_ip],
                               os_username=user, os_password=passwd,
                               root_password=root_pwd)
        driver_info = driver_obj.get_nodes_driver_list_by_os(
            component="fusionstorage", node_type="fusionstorage", arch=cpu_arch).get(host_om_ip)
        logger.info("get node %s driver info: %s" % (host_om_ip, driver_info))
        if driver_info is None:
            # Explicit failure instead of "'NoneType' object has no attribute 'keys'".
            err_msg = "Failed to query the driver info of node[%s]" % host_om_ip
            logger.error(err_msg)
            raise Exception(err_msg)
        driver_list = list(driver_info.keys())
        driver_list = self.operate.check_es3000_driver(ssh_client, cpu_arch, driver_list)
        return driver_list
