# encoding:utf-8

import com.huawei.ism.tool.framework.platform.runtime.SystemProperties as sysProperties
import java.util.MissingResourceException as MissingResourceException

from common_cache import get_ctrl_ip_list_cache
from common_cache import get_manager_ip_list_cache
import cliUtil
from cbb.frame.cli.cli_with_cache import execute_cmd_in_cli_mode_with_cache
from frameone.util import common as frame_common
import common

import os
import threading
import traceback
import time

# Module-wide lock; used below by RefreshProgress.refresh_progress to
# serialise progress updates coming from several worker threads.
MUTEX = threading.Lock()

# Extra connections / multi-threaded querying are enabled only when the total
# number of commands exceeds this value.
NEED_BUILD_CONN_CMD_COUNT = 200


class ExecuteCmdCacheTask:
    """
    Run CLI commands through the caching layer, optionally spreading them over
    extra SSH connections that execute in parallel threads.

    When the combined number of commands exceeds NEED_BUILD_CONN_CMD_COUNT,
    extra CLI connections to the controllers are opened in __init__. Each
    command list is then split into equal slices, one per extra connection,
    and the remainder is executed on the existing connection ``cli``. Progress
    is reported through a shared RefreshProgress instance.
    """

    # Per-controller SSH connection count used when the
    # app.inspect.pre.collect.ssh.conn.number property is missing or invalid.
    DEFAULT_SSH_CONN_NUMBER = 2

    def __init__(self, cmd_list, cli, env, logger, start_progress, total_p, sn, developer_cmd_list):
        """
        :param cmd_list: commands run before switching to developer mode
        :param cli: already-open CLI connection of the current controller;
                    it is used but never closed by this class
        :param env: tool environment; env.get("devInfo").getIp() is the
                    current controller IP
        :param logger: logger exposing logInfo / logError
        :param start_progress: progress value this task starts from
        :param total_p: progress budget this task may consume
        :param sn: forwarded unchanged to execute_cmd_in_cli_mode_with_cache
                   (presumably the device serial number)
        :param developer_cmd_list: commands run in developer mode
        """
        self.cmd_list = cmd_list
        self.developer_cmd_list = developer_cmd_list
        self.env = env
        self.sn = sn
        self.logger = logger
        self.cli = cli
        self.start_p = start_progress
        self.total_p = total_p
        self.current_ctrl_ip = env.get("devInfo").getIp()
        # Extra connections; appended to concurrently by _build_conn_task.
        self.conn_list = list()
        self.lang = common.getLang(env)
        if self._is_need_build_conn():
            self.logger.logInfo("is need build conn!")
            self.get_conn_list()

    def _is_need_build_conn(self):
        """Return True when there are enough commands to justify extra connections."""
        return len(self.cmd_list) + len(self.developer_cmd_list) > NEED_BUILD_CONN_CMD_COUNT

    def get_conn_list(self):
        """
        Open the extra CLI connections used to parallelise command execution.

        The controller IP list is looked up in one of two ways:
        - 18000 devices use the upgrade-package lookup.
        - Other models query management ports first and fall back to the
          upgrade-package lookup when that yields nothing.

        Each IP is repeated get_ctrl_ssh_conn_number() times. The current
        controller IP is skipped, because a connection to it already exists.
        The connections are built concurrently and the successful ones end up
        in self.conn_list.
        """
        if common.is18000(self.env, self.cli):
            # 18000: queried via "show upgrade package".
            ip_list = self.get_ip_list_by_upgrade_package()
        else:
            # Mid/low-end models: management ports via "show port general".
            ip_list, _ = get_manager_ip_list_cache(
                self.cli, self.env, self.logger
            )
            if not ip_list:
                ip_list = self.get_ip_list_by_upgrade_package()

        ctrl_ssh_conn_number = self.get_ctrl_ssh_conn_number()
        # get_ctrl_ssh_conn_number is wrapped by wrapAllExceptionLogged, which
        # presumably returns None when the wrapped call raises (TODO confirm).
        # Guard against that so "None * list" cannot abort __init__.
        if ctrl_ssh_conn_number is None:
            ctrl_ssh_conn_number = self.DEFAULT_SSH_CONN_NUMBER
        ip_list = ctrl_ssh_conn_number * ip_list
        self.logger.logInfo("ip list is :{}".format(ip_list))

        # Building connections can be slow, so build them in parallel threads.
        # The current controller already has a connection; do not rebuild it.
        build_conn_thread_list = [
            threading.Thread(target=self._build_conn_task, args=(ip_address,))
            for ip_address in ip_list
            if ip_address != self.current_ctrl_ip
        ]
        self._start_thread(build_conn_thread_list)

    def get_ip_list_by_upgrade_package(self):
        """
        Flatten the engine -> controller-IP mapping from get_ctrl_ip_list_cache.

        :return: list of IPs across all engines; empty when nothing is reported
        """
        _, _, _, ip_list_dict = get_ctrl_ip_list_cache(
            self.cli, self.env, self.logger
        )
        ip_list = []
        # Tolerate a missing mapping so callers always receive a list.
        for engine in ip_list_dict or {}:
            ip_list.extend(ip_list_dict.get(engine, []))
        return ip_list

    def _build_conn_task(self, ip_address):
        """
        Thread body: open one CLI connection to ip_address.

        On success the connection is appended to self.conn_list. Failures,
        including exceptions, are logged and otherwise ignored.
        """
        try:
            self.logger.logInfo("preinspector build conn start.")
            conn = common.getCilConnectionByIp(ip_address, self.env, self.logger)
            if conn is not None:
                self.conn_list.append(conn)
                self.logger.logInfo(
                    "pre inspector build conn succ:{}".format(ip_address)
                )
            else:
                self.logger.logError(
                    "pre inspector build conn failed:{}".format(ip_address)
                )
        except Exception:
            self.logger.logError(
                "pre inspector build conn failed:{}".format(
                    traceback.format_exc())
            )

    def _build_task(self, thread_list, cmd_list, total_p, r_pro):
        """
        Create one (not yet started) worker thread per extra connection.

        Each thread executes an equal-length slice of cmd_list. Commands from
        the returned start index onwards are left for the caller to run on
        the current connection.

        :param thread_list: list the created threads are appended to
        :param cmd_list: commands to distribute
        :param total_p: progress budget for the whole cmd_list
        :param r_pro: shared RefreshProgress instance
        :return: (index of the first unassigned command, progress per command)
        """
        start_index = 0
        # +1 accounts for the current connection self.cli.
        conn_len = len(self.conn_list) + 1
        # Avoid division by zero for an empty command list.
        total_cmd_len = len(cmd_list) or 1
        # Floor division keeps the slice length integral (identical to "/" on
        # ints in Python 2, and still correct under Python 3 semantics).
        part_len = total_cmd_len // conn_len
        part_p = total_p / total_cmd_len
        self.logger.logInfo("part_p is {}".format(part_p))
        for conn in self.conn_list:
            last_index = start_index + part_len
            cmd_execute_list = cmd_list[start_index:last_index]
            self.logger.logInfo(
                "conn {} pre inspector task_cmd：{}.".format(
                    conn, cmd_execute_list)
            )
            start_index = last_index
            thread_list.append(threading.Thread(
                target=self._execute_cmd_cache_logged,
                args=(conn, cmd_execute_list, r_pro, part_p),
            ))

        return start_index, part_p

    def execute_task(self):
        """
        Distribute the commands over all connections, run them and refresh
        progress.

        The progress budget total_p is split between the command lists in
        proportion to their lengths. cmd_list runs first, then
        developer_cmd_list in developer mode. Extra connections are always
        closed afterwards.

        :return: final progress value, or None when an exception was caught
                 and logged
        """
        thread_list = []
        try:
            cmd_list_length = len(self.cmd_list)
            total_cmd_length = float(cmd_list_length + len(self.developer_cmd_list))
            cmd_ratio = 0.5
            if total_cmd_length != 0:
                cmd_ratio = float(cmd_list_length) / total_cmd_length
            developer_total_process = self.total_p * (1 - cmd_ratio)
            cli_cmd_total_process = self.total_p * cmd_ratio
            r_pro = RefreshProgress(self.start_p, self.logger, self.env)
            # Each extra connection gets one slice of the commands.
            start_index, part_p = self._build_task(
                thread_list, self.cmd_list, cli_cmd_total_process, r_pro)
            # The current connection runs whatever is left.
            last_task = self.cmd_list[start_index:]
            self.logger.logInfo("pre inspector last lun:{}".format(last_task))
            thread_list.append(threading.Thread(
                target=self._execute_cmd_cache_logged,
                args=(self.cli, last_task, r_pro, part_p),
            ))
            self._start_thread(thread_list)
            self.logger.logInfo("admin cli has been operated")
            self.exec_developer_task(developer_total_process, r_pro)
            return r_pro.cur_progress
        except Exception:
            self.logger.logError(
                "except is {}".format(traceback.format_exc())
            )
        finally:
            self._close_conn_list()

    def _close_conn_list(self):
        """Close every extra connection, continuing past individual failures."""
        for conn in self.conn_list:
            try:
                common.closeConnection(conn, self.env, self.logger)
                self.logger.logInfo("pre inspect closed {}".format(conn))
            except Exception:
                # One failed close must not leak the remaining connections.
                self.logger.logError(
                    "pre inspect close conn failed:{}".format(
                        traceback.format_exc())
                )

    def exec_developer_task(self, developer_total_process, r_pro):
        """
        Run developer_cmd_list in developer mode across all connections.

        All connections are switched to developer mode, the commands are
        distributed and executed, and then every connection is switched back
        to CLI mode. The switch back also happens when something fails.

        :param developer_total_process: progress budget for developer commands
        :param r_pro: shared RefreshProgress instance
        """
        thread_list = []
        try:
            # Switch the extra connections, then the current one, to developer mode.
            for conn in self.conn_list:
                self.logger.logInfo("change user mode")
                cliUtil.enterDeveloperMode(conn, self.lang)
            cliUtil.enterDeveloperMode(self.cli, self.lang)

            start_index, part_p = self._build_task(
                thread_list, self.developer_cmd_list, developer_total_process, r_pro)
            # The current connection runs whatever is left.
            last_task = self.developer_cmd_list[start_index:]
            self.logger.logInfo("pre inspector last fs:{}".format(last_task))
            thread_list.append(threading.Thread(
                target=self._execute_cmd_cache_logged,
                args=(self.cli, last_task, r_pro, part_p),
            ))
            self._start_thread(thread_list)
        finally:
            for conn in self.conn_list:
                self.logger.logInfo("change user mode")
                try:
                    cliUtil.enterCliModeFromSomeModel(conn, self.lang)
                except Exception:
                    # Extra connections are closed right afterwards; a failed
                    # switch on one must not prevent restoring self.cli.
                    self.logger.logError(
                        "change user mode failed:{}".format(
                            traceback.format_exc())
                    )
            # self.cli is not closed by this class, so never leave it in
            # developer mode.
            cliUtil.enterCliModeFromSomeModel(self.cli, self.lang)

    @staticmethod
    def _start_thread(thread_list):
        """Start every thread (staggered by 10 ms), then wait for all of them."""
        for check_thread in thread_list:
            time.sleep(0.01)
            check_thread.start()

        # Block the calling thread until every worker has finished.
        for check_thread in thread_list:
            check_thread.join()

    def execute_cmd_cache(self, cli, cmd_list, r_pro, part_p):
        """
        Execute each command via the caching helper and advance progress.

        :param cli: connection to run the commands on
        :param cmd_list: commands to execute
        :param r_pro: RefreshProgress instance to advance
        :param part_p: progress added per command
        :return: None; results are stored by execute_cmd_in_cli_mode_with_cache
        """
        for cmd in cmd_list:
            execute_cmd_in_cli_mode_with_cache(
                self.env, cli, cmd, self.logger, self.sn)
            r_pro.refresh_progress(part_p)

    def _execute_cmd_cache_logged(self, cli, cmd_list, r_pro, part_p):
        """
        Thread target wrapping execute_cmd_cache.

        An exception escaping a worker thread would otherwise be printed only
        by the threading machinery and never reach self.logger. Execution of
        the slice still stops at the first failure, as before.
        """
        try:
            self.execute_cmd_cache(cli, cmd_list, r_pro, part_p)
        except Exception:
            self.logger.logError(
                "pre inspector execute cmd failed:{}".format(
                    traceback.format_exc())
            )

    @frame_common.wrapAllExceptionLogged(logger=None)
    def get_ctrl_ssh_conn_number(self):
        """
        Read how many SSH connections to open per controller.

        The value comes from the key app.inspect.pre.collect.ssh.conn.number in
        the properties file at common.DIR_RELATIVE_CMD (originally described
        as system.properties under the sub-tool's configuration directory).

        :return: the configured non-negative integer (surrounding whitespace
                 ignored), or DEFAULT_SSH_CONN_NUMBER (2) when the key is
                 missing or the value is not a plain non-negative integer
        """
        sys_properties_path = os.path.abspath(common.DIR_RELATIVE_CMD)
        conn_number_key = "app.inspect.pre.collect.ssh.conn.number"
        app_conf = sysProperties(sys_properties_path)
        try:
            conn_number = str(app_conf.getValue(conn_number_key)).strip()
            if conn_number.isdigit():
                return int(conn_number)
            return self.DEFAULT_SSH_CONN_NUMBER
        except MissingResourceException as e:
            self.logger.logError(
                "no key found. use default 2. err msg:{}".format(str(e))
            )
        return self.DEFAULT_SSH_CONN_NUMBER


class RefreshProgress:
    """
    Accumulate (possibly fractional) progress increments from worker threads
    and report them via common.refreshProcess in whole-unit steps.
    """

    def __init__(self, cur_progress, logger, env):
        # Progress value last reported; starts at the caller's start value.
        self.cur_progress = cur_progress
        # Increment accumulated but not yet reported (< 1 right after a report).
        self.tmp_increment = 0
        self.logger = logger
        self.env = env

    def refresh_progress(self, increment):
        """
        Add increment to the pending amount and report any whole units.

        Once at least 1 has accumulated, the whole part is added to
        cur_progress and reported; the fractional remainder stays pending.
        Calls are serialised with the module-level MUTEX because several
        worker threads share one instance.

        :param increment: progress to add (may be fractional)
        """
        with MUTEX:
            self.tmp_increment = self.tmp_increment + increment
            self.logger.logInfo(
                "tmp_incre:%s, cur_pro:%s, incre:%s"
                % (self.tmp_increment, self.cur_progress, increment)
            )
            if self.tmp_increment >= 1:
                # Report only the whole part. Adding the full amount while
                # also keeping the fraction pending would count it twice and
                # make the progress overshoot.
                whole = int(self.tmp_increment)
                self.cur_progress += whole
                common.refreshProcess(self.env, self.cur_progress, self.logger)
                self.tmp_increment = self.tmp_increment - whole
