# -*- coding: utf-8 -*-
import utils.common.log as logger
from utils.common.fic_base import TestCase
from utils.common.ssh_util import Ssh as ssh
from utils.business.hardware_driver_util import DriverApi
from plugins.DistributedStorage.logic.install_operate import InstallOperate
from plugins.DistributedStorage.utils.common.deploy_constant import DeployConstant


class InstallDriver(TestCase):
    """Install the hardware drivers required by every OSD node in ``fs_args``.

    The step checks that all OSD nodes pass
    ``InstallOperate.check_host_list_ping``. It then asks ``DriverApi`` which
    drivers each node needs, and uploads the driver install script to every
    node that needs at least one driver. Finally it installs each driver on
    the nodes that require it.
    """

    def __init__(self, project_id, pod_id, fs_args, **kwargs):
        super(InstallDriver, self).__init__(project_id, pod_id)
        self.fs_args = fs_args
        self.more_args = kwargs
        self.operate = InstallOperate(self.project_id, self.pod_id, self.fs_args)

    @staticmethod
    def _upload_install_scripts_to_node(ssh_client, host_om_ip, user, passwd):
        """Recreate /tmp/rpm_install_temp on a node and upload the driver script.

        :param ssh_client: root SSH client already connected to the node.
        :param host_om_ip: OM IP of the node (used for the upload and for logs).
        :param user: login user; the temp directory is created as this user.
        :param passwd: password of ``user``, used by ``ssh.put_file``.
        """
        remote_path = "/tmp/rpm_install_temp"
        # Start from a clean directory so stale files from earlier runs are removed.
        cmd_tmp = "[ -d %s ] && rm -rf %s" % (remote_path, remote_path)
        ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        # Created as the login user, because the upload below logs in as that user.
        cmd_tmp = 'su - %s -s /bin/bash -c "mkdir -p %s; chmod 750 %s"' % (user, remote_path, remote_path)
        cmd_res = ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        logger.info("Make an directory on host[%s], result: %s" % (host_om_ip, str(cmd_res)))

        logger.info("Start to upload shell install script to host[%s]" % host_om_ip)
        # NOTE(review): presumably DRIVER_SCRIPT[1] is the local script path; confirm in DeployConstant.
        driver_script = DeployConstant.DRIVER_SCRIPT[1]
        # NOTE(review): the return value of put_file is only logged, not checked.
        res = ssh.put_file(host_om_ip, user, passwd, driver_script, remote_path)
        logger.info("Upload package[%s] complete[%s]" % (driver_script, res))
        # Hand the directory back to root once the upload is done.
        cmd_tmp = 'chown -h -R root:root %s' % remote_path
        ssh.ssh_send_command(ssh_client, cmd_tmp, "#", 60, 3)
        logger.info("Modify %s right success" % remote_path)

    def procedure(self):
        """Install the required drivers on all OSD nodes.

        :raises Exception: if any OSD node fails the reachability check.
        """
        osd_list = self.fs_args.get('osd_list') or []
        if not osd_list:
            logger.info("There is no OSD node in fs_args, skip driver installation.")
            return
        osd_ip_list = [osd.get('om_ip') for osd in osd_list]
        check_ret, fail_list = self.operate.check_host_list_ping(osd_ip_list)
        if not check_ret:
            err_msg = "There are some nodes%s that are not up to reach" % fail_list
            logger.error(err_msg)
            raise Exception(err_msg)

        logger.info("Get all driver packages")
        driver_infos = self._collect_driver_infos(osd_list)

        logger.info("Install all package on storage nodes")
        # De-duplicate while keeping first-seen order, so the install sequence
        # is reproducible (set() ordering of strings varies between runs).
        all_drivers = list(dict.fromkeys(
            driver for drivers in driver_infos.values() for driver in drivers))
        driver_and_ip_list = self.sorted_drivers(driver_infos, all_drivers)
        logger.info("Installing All Drivers and Nodes is : %s" % driver_and_ip_list)
        host_by_ip = {host.get('om_ip'): host for host in osd_list}
        for driver, ip_list in driver_and_ip_list:
            self._install_driver_on_hosts(driver, ip_list, host_by_ip)

    def _collect_driver_infos(self, osd_list):
        """Query each node's required drivers and stage the install script on it.

        :param osd_list: node dicts with 'om_ip', 'user', 'passwd', 'root_pwd'.
        :return: dict mapping node OM IP -> list of driver names. Nodes that
            need no driver are omitted.

        Every SSH client opened here is closed before returning, also when an
        exception is raised part-way through.
        """
        driver_infos = dict()
        client_list = list()
        try:
            for host in osd_list:
                host_om_ip = host.get('om_ip')
                user = host.get('user')
                passwd = host.get('passwd')
                root_pwd = host.get('root_pwd')
                ssh_client = self.operate.create_ssh_root_client(host_om_ip, user, passwd, root_pwd)
                # Register the client before using it, so it is closed even if
                # querying or uploading on this node fails.
                client_list.append(ssh_client)
                driver_list = self._get_driver_list(ssh_client, host)
                if not driver_list:
                    logger.info("There is no driver to need to install in current node[%s]." % host_om_ip)
                    continue
                driver_infos[host_om_ip] = driver_list
                self._upload_install_scripts_to_node(ssh_client, host_om_ip, user, passwd)
        finally:
            self.close_client_list(client_list)
        return driver_infos

    def _install_driver_on_hosts(self, driver, ip_list, host_by_ip):
        """Install one driver batch, using each target node's own credentials.

        ``DriverApi`` takes a single user/password/root password per call.
        Target nodes are therefore grouped by identical credentials, and each
        group gets one install call. Previously one node's credentials were
        reused for every node.

        :param driver: driver batch as produced by ``sorted_drivers``.
        :param ip_list: OM IPs of the nodes that need this driver.
        :param host_by_ip: mapping of OM IP -> node dict from ``osd_list``.
        """
        # List of (credentials, ips) in first-seen order. A linear search is
        # used so the credential values need not be hashable.
        groups = []
        for ip in ip_list:
            host = host_by_ip[ip]
            cred = (host.get('user'), host.get('passwd'), host.get('root_pwd'))
            for group_cred, group_ips in groups:
                if group_cred == cred:
                    group_ips.append(ip)
                    break
            else:
                groups.append((cred, [ip]))

        for (user, passwd, root_pwd), group_ips in groups:
            driver_obj = DriverApi(self.project_id,
                                   os_ip_list=group_ips,
                                   os_username=user,
                                   os_password=passwd,
                                   root_password=root_pwd)
            driver_obj.install_driver(driver, timeout=600)

    def close_client_list(self, client_list):
        """Close every non-empty SSH client in ``client_list``."""
        for client in client_list:
            if client:
                ssh.ssh_close(client)

    def sorted_drivers(self, driver_infos, all_drivers):
        """Split the installation into batches, one per driver.

        :param driver_infos: dict mapping node OM IP -> list of driver names.
        :param all_drivers: driver names to schedule, in install order.
        :return: list of ``([driver], ip_list)`` tuples, where ``ip_list``
            holds the nodes whose driver list contains ``driver``. Drivers
            needed by no node are skipped.
        """
        driver_list = list()
        for driver in all_drivers:
            ip_list = [node for node, drivers in driver_infos.items() if driver in drivers]
            if ip_list:
                driver_list.append(([driver], ip_list))
        return driver_list

    def _get_driver_list(self, ssh_client, host):
        """Return the names of the drivers ``DriverApi`` reports for one node.

        :param ssh_client: root SSH client connected to the node (used to read
            its CPU architecture).
        :param host: node dict with 'om_ip', 'user', 'passwd', 'root_pwd'.
        :return: list of driver names (keys of the per-node driver info).
        :raises Exception: if ``DriverApi`` returns no entry for the node.
        """
        host_om_ip = host.get('om_ip')
        user = host.get('user')
        passwd = host.get('passwd')
        root_pwd = host.get('root_pwd')
        cpu_arch = self.operate.get_cpu_arch(ssh_client)
        driver_obj = DriverApi(self.project_id,
                               os_ip_list=[host_om_ip],
                               os_username=user, os_password=passwd,
                               root_password=root_pwd)
        nodes_driver_info = driver_obj.get_nodes_driver_list_by_os(component="fusionstorage",
                                                                   node_type="fusionsphere",
                                                                   arch=cpu_arch)
        driver_info = nodes_driver_info.get(host_om_ip)
        if driver_info is None:
            # Previously this failed with an opaque AttributeError on .keys().
            err_msg = "Failed to get driver info of node[%s]" % host_om_ip
            logger.error(err_msg)
            raise Exception(err_msg)
        driver_list = list(driver_info.keys())
        logger.info("get node %s driver info: %s" % (host_om_ip, driver_list))
        return driver_list
