# -*- coding: UTF-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.

import datetime
import time
import os
import re
import traceback

import common
from cbb.frame.base import baseUtil
from cbb.frame.cli import cliUtil
from cbb.frame.cli import cli_con_mgr
from cbb.frame.cli.exec_on_all_even_mini_sys import (
    ExecuteOnAllControllers, FuncResult, ResultType,
    ExeOnAllCtrlContext
)
from cbb.frame.dsl.adapter import get_sn
from cbb.frame.context import contextUtil
from cbb.frame.adapter.replace_adapter import is_digital_ver, compare_version
from memory_inspect.adapter.java_adapter import UnCheckException
from memory_inspect.db.task_db_service import DbService
from java.io import File
from java.lang import Exception as JException
from com.huawei.ism.tool.obase.connection import SftpTransporter


# CLI command templates.
# NOTE(review): EXPORT_CMD, DELETE_EXPORT_CMD and ENGINE_IP_CMD are not
# referenced in the visible part of this file - confirm they are still used
# elsewhere before removing them.
EXPORT_CMD = 'show file export_path file_type={}'
DELETE_EXPORT_CMD = 'delete file filetype={}'
# Filled in FileExporter.do_transfer_file with
# (remote ip, remote file path, export path on the connected controller).
TRANSFER_CMD = 'getremotefile {} {} {}'
ENGINE_IP_CMD = 'show port general logic_type=Management_Port'
# Export path returned by FileExporter.test_export_path when
# FileExporter._is_version_dorado_610 matches the device.
EXPORT_PATH_FOR_DORADO_610 = "/home/permitdir"


def is_digital_ver_or_v7(ver):
    """
    Tell whether ``ver`` is accepted by ``is_digital_ver`` or starts with
    "V700".

    :param ver: version string of the device
    :return: the truthy result of ``is_digital_ver(ver)`` when there is one,
        otherwise whether ``ver`` starts with "V700"
    """
    digital_result = is_digital_ver(ver)
    if digital_result:
        return digital_result
    return ver.startswith("V700")


class GetRemoteFileAdapter:
    """
    Adapter for the ``getremotefile`` command on Dorado 6.1.2 and later
    versions.

    On matching devices the transfer command takes only the remote IP and the
    remote file path, files land in ``EXPORT_PATH`` and are removed with
    ``CLEAN_CMD``.
    """
    TRANSFER_CMD = 'getremotefile {} {}'
    CLEAN_CMD = 'cleartranstmpfile {}'
    EXPORT_PATH = '/OSM/log/trans_tmp/'

    def __init__(self, dev, logger, lang):
        """
        :param dev: device info mapping (read keys: 'type', 'version', 'pawd')
        :param logger: logger used for progress and error messages
        :param lang: language passed through to the CLI helpers
        """
        self.dev, self.logger, self.lang = dev, logger, lang

    def is_match(self):
        """
        Decide whether the device has to use this adapter.

        :return: True for Dorado >= 6.1.2RC1, for ocean protect products and
            for new oceanstor products; False otherwise
        """
        dev_type = self.dev.get('type')
        version = self.dev.get('version')
        is_dorado_612 = (
            'dorado' in dev_type.lower()
            and is_digital_ver_or_v7(version)
            and compare_version(version, '6.1.2RC1') >= 0
        )
        if is_dorado_612:
            self.logger.info("is dorado 612 product.")
            return True
        if baseUtil.is_ocean_protect(self.dev.get('type')):
            self.logger.info("is ocean protect product.")
            return True
        if baseUtil.is_new_oceanstor(self.dev.get('type')):
            self.logger.info("is new oceanstor product.")
            return True
        return False

    @classmethod
    def get_export_path(cls):
        """Return the directory the transferred files are stored in."""
        return cls.EXPORT_PATH

    @classmethod
    def get_export_cmd(cls, ip, file_path):
        """
        Build the transfer command line.

        :param ip: IP of the controller holding the file
        :param file_path: path of the file on that controller
        :return: the formatted ``getremotefile`` command
        """
        template = cls.TRANSFER_CMD
        return template.format(ip, file_path)

    def clear_tmp_file(self, cli, remote_name):
        """
        Remove a transferred temporary file through the mini-system command.

        A failure is only logged; nothing is returned or raised by this
        method itself.

        :param cli: CLI connection to run the command on
        :param remote_name: file name handed to ``cleartranstmpfile``
        """
        clean_cmd = self.CLEAN_CMD.format(remote_name)
        result = cliUtil.excuteCmdInMinisystemModel(
            cli, clean_cmd, self.lang, self.dev.get('pawd'))
        succeeded, _, err_msg = result
        if not succeeded:
            self.logger.error("delete file error:{}", err_msg)


# Same value as the first candidate path probed by
# FileExporter.test_export_path.
# NOTE(review): not referenced in the visible part of this file - confirm
# whether it is still needed.
DEFAULT_EXPORT_DIR = '/OSM/coffer_data/omm/export_import'

# Module-level dict; nothing in the visible part of this file reads or writes
# it. NOTE(review): presumably shared state for other modules - verify.
GLOBAL_CONTEXT = {}


def execute(context, collected_files):
    """
    Collect the requested log files from every engine of the device.

    :param context: tool context; supplies the logger, language, device info
        and CLI connection
    :param collected_files: description of the files and directories to
        collect (handed to LogCollect as ``need_collect_files``)
    :return: list of engine indexes whose connection switch failed
    :raises UnCheckException: when the device was logged in under mini-system
        mode, when device initialisation failed, or when any other Python
        exception occurred during collection
    """
    logger = contextUtil.getLogger(context)
    lang = contextUtil.getLang(context)
    # Device initialisation
    export_dev = None
    try:
        export_dev = ExportDevice(context)
        flag = export_dev.init_device()
        if export_dev.login_mini_sys:
            raise UnCheckException(common.getMsg(lang, "log.collect.in.mini.mode"), "")

        if flag is not True:
            raise UnCheckException(
                common.getMsg(lang, "log.collect.device.abnormal"), "")
        log_collect = LogCollect(context, export_dev, collected_files)
        log_collect.init_exported_files_list(export_dev.current_cli)
        export_dev.reconnect_con()

        # Export the logs
        exporter = FileExporter(context, export_dev)
        failed_engines = log_collect.export_files_on_all_engines(exporter)
        # Restore the connection mode. ExportDevice.switch_engine leaves
        # current_cli as None when the last engine could not be reached; in
        # that case there is no connection to restore and the partial result
        # is still returned.
        if exporter.device.current_cli is not None:
            cliUtil.enterCliModeFromSomeModel(
                exporter.device.current_cli, exporter.lang)
        logger.info("collect log success.")
        return failed_engines
    except UnCheckException:
        # Bare raise keeps the traceback of the original raise site.
        raise
    except Exception:
        error_detail = traceback.format_exc()
        logger.error(
            "run log collect exception. {}".format(error_detail))
        raise UnCheckException(common.getMsg(lang, "log.collect.failed"), str(error_detail))
    finally:
        # Always release the connections, whatever the outcome.
        if export_dev:
            export_dev.clear_env()


class LogCollect:
    """
    Drives the log collection of one device: scans the controllers for
    history files that still have to be exported, exports current and history
    logs engine by engine, and records the exported history files in the task
    database.
    """

    def __init__(self, context, export_dev, need_collect_files):
        """
        :param context: tool context (logger, language, "start_time",
            "chip_type" are read from it)
        :param export_dev: ExportDevice describing the connected device
        :param need_collect_files: mapping with a "file" entry
            (name -> path or list of candidate paths) and a "dir" entry
            (key -> mapping with "path", "file_pattern" and "find_pattern")
        """
        self.context = context
        self.logger = contextUtil.getLogger(context)
        self.sn = get_sn(context)
        self.lang = contextUtil.getLang(context)
        self.start_time = context.get("start_time")
        self.chip_type = context.get("chip_type")
        self.db_service = DbService(context)
        self.collected_log_files = \
            self.db_service.task_process_tbl.query_collected_log_files(self.sn)
        self.need_collect_files = need_collect_files

        # Files present on each controller that are expected to be exported
        self.expected_export_files = {}

        # Files actually exported from each controller
        self.export_files = {}
        self.export_dev = export_dev

    @ExecuteOnAllControllers
    def scan_available_files(self, exec_context):
        """
        Scan the files that need exporting; runs once per controller.

        :param exec_context: per-controller execution context supplied by the
            ExecuteOnAllControllers decorator
        :return: FuncResult whose ``other_result`` maps each "dir" key to the
            list of valid file paths found on this controller
        """

        def scan_dir(file_key):
            # List the directory in mini-system mode and keep the valid files;
            # a failed listing counts as "no files".
            dir_info = self.need_collect_files.get("dir").get(file_key)
            cmd = 'ls {}'.format(dir_info.get("path"))
            flag, cli_res, error_msg = cliUtil.excuteCmdInMinisystemModel(
                cli, cmd, lang)
            if flag is True:
                return self.__parse_valid_files(
                    cli_res, self.need_collect_files.get("dir").get(file_key), ctrl_id)
            return []

        cli = exec_context.dev_info.cli
        lang = exec_context.lang
        # Remember the starting mode so it can be restored after scanning
        is_in_mini_sys = cliUtil.testIsInMinisystemMode(cli)
        ctrl_id = baseUtil.get_ctrl_id_by_node_id(
            int(exec_context.cur_ctrl_id),
            self.export_dev.count_ctrl_one_engine)
        self.logger.info("node_id={}, ctrl_id={}, ctrl_num={}".format(
            exec_context.cur_ctrl_id, ctrl_id,
            self.export_dev.count_ctrl_one_engine))
        result = {}
        for file_type in self.need_collect_files.get("dir"):
            result[file_type] = scan_dir(file_type)

        if is_in_mini_sys is not True:
            cliUtil.enterCliModeFromSomeModel(cli, lang)
        return FuncResult(
            ResultType.SUCCESS, '', '', other_result=result)

    def __parse_valid_files(self, cli_res, parsed_info, ctrl_id):
        """
        Pick the files of a directory listing that still have to be exported.

        A file is kept when its name matches ``file_pattern``, it is not part
        of the already collected files of the controller, and (if a start
        time is set) group 1 of the match is not smaller than the start date
        formatted as YYYYMMDD (string comparison).

        :param cli_res: ``ls`` output, containing names such as
            0A_messages_0000000001_20201116_154455.tgz
        :param parsed_info: mapping with "path", "file_pattern" and
            "find_pattern"
        :param ctrl_id: controller id used to look up collected files
        :return: list of ``path + file_name`` strings
        """
        start_time_str = self._format_start_time()
        file_path = parsed_info.get("path")
        file_pattern = re.compile(parsed_info.get("file_pattern"))
        find_pattern = re.compile(parsed_info.get("find_pattern"))

        valid_files = []
        has_parsed_files = self.collected_log_files.get(ctrl_id)
        # The "already collected" test is a substring test on the string form
        # of the stored records; compute that string once.
        has_parsed_text = str(has_parsed_files)
        self.logger.info("ctrl={}, has_parsed_files={}".format(
            ctrl_id, has_parsed_files))
        for file_name in find_pattern.findall(cli_res):
            self.logger.info("file_name={}".format(file_name))
            match_obj = file_pattern.match(file_name)
            if not match_obj:
                self.logger.info("not valid file name")
                continue
            if file_name in has_parsed_text:
                self.logger.info("has collected file")
                continue
            if not start_time_str or match_obj.group(
                    1) >= start_time_str:
                self.logger.info("is valid file")
                valid_files.append(file_path + file_name)
        self.logger.info("valid_files={}".format(valid_files))
        return valid_files

    def _format_start_time(self):
        """
        Format ``self.start_time`` (milliseconds since the epoch) as YYYYMMDD.

        :return: the formatted date, or None when no start time is set
        """
        start_time_str = None
        if self.start_time:
            start_time_obj = datetime.datetime.fromtimestamp(
                int(self.start_time) / 1000)
            start_time_str = start_time_obj.strftime("%Y%m%d")
            self.logger.info("start time: {}".format(start_time_str))
        return start_time_str

    def init_exported_files_list(self, cli):
        """
        Initialise the list of files expected to be exported.

        :param cli: CLI connection used to visit the controllers
        :return: None; the scan result is stored in
            ``self.expected_export_files``
        """
        exec_all_context = ExeOnAllCtrlContext(self.context, cli)
        func_res = self.scan_available_files(exec_all_context)
        self.expected_export_files = func_res.other_result

    def export_files_on_all_engines(self, exporter):
        """
        Export the required log files on all engines.

        Each engine is reached through its management IP; inside an engine
        the files of the controllers are pulled to the connected controller
        through internal IPs.

        :param exporter: FileExporter bound to the device
        :return: list of engine indexes that could not be switched to
        """
        count_engine = exporter.device.count_engine
        count_ctrl_one_engine = exporter.device.count_ctrl_one_engine
        current_engine = int(exporter.device.current_engine)
        # Export the currently connected engine first
        self.export_on_one_engine(
            exporter, current_engine, count_ctrl_one_engine)
        failed_engines = []
        for engine_index in range(count_engine):
            if engine_index == current_engine:
                continue
            engine_ips = exporter.device.get_engine_ip(engine_index)
            # Switch the connection to one controller of that engine
            if not self.switch_engine_by_ip(exporter, engine_ips):
                failed_engines.append(engine_index)
                self.logger.info("switch to engine {} failed".format(engine_index))
                continue

            self.export_on_one_engine(
                exporter, engine_index, count_ctrl_one_engine)

        self.save_exported_files_to_db(self.export_files)
        return failed_engines

    def switch_engine_by_ip(self, exporter, engine_ips):
        """
        Try to move the exporter's connection to another engine.

        :param exporter: FileExporter whose device connection is switched
        :param engine_ips: candidate management IPs of the target engine
        :return: True when the switch succeeded, False otherwise
        """
        if not engine_ips:
            self.logger.info("engine ip is none.")
            return False
        try:
            exporter.device.switch_engine(engine_ips)
        except (JException, Exception):
            # Java exceptions are caught explicitly as well, in the same way
            # as ExportDevice._is_svp_disabled does.
            self.logger.error("can not reach to engine")
            self.logger.error(traceback.format_exc())
            return False
        return True

    def save_exported_files_to_db(self, export_files):
        """
        Store the exported history files per controller, keeping the files
        recorded by previous runs.

        :param export_files: mapping controller id -> list of exported file
            names; the lists are not modified
        """
        last_export_files = \
            self.db_service.task_process_tbl.query_collected_log_files(self.sn)
        for ctrl, export_file in export_files.items():
            # A missing or empty record list means nothing was saved before
            last_records = last_export_files.get(ctrl) or [{}]
            last_export_file = last_records[0].get("collected_log_files", "")
            # Work on a copy so self.export_files is left untouched, and skip
            # an empty previous record to avoid a trailing comma.
            all_files = list(export_file)
            if last_export_file:
                all_files.append(last_export_file)
            save_data = {
                "sn": self.sn,
                "ctrl": ctrl,
                "collected_log_files": ",".join(all_files)
            }
            self.logger.info("collected log to save data:{}".format(save_data))
            self.db_service.task_process_tbl.save_collected_log_files(
                save_data)

    def export_on_one_engine(self, exporter, engine_index, count_ctrl):
        """
        Export the logs of every controller of one engine.

        :param exporter: FileExporter connected to the engine
        :param engine_index: index of the engine
        :param count_ctrl: number of controllers per engine
        """
        try:
            engine_his = self.expected_export_files[str(engine_index)]
        except (KeyError, TypeError):
            # No scan result for this engine: still export the current logs,
            # just without history files.
            self.logger.info(
                "no scanned files for engine {}".format(engine_index))
            engine_his = {}
        for ctrl_index in range(count_ctrl):
            logic_ctrl_id = engine_index * count_ctrl + ctrl_index
            ctrl_id = baseUtil.get_ctrl_id_by_node_id(
                logic_ctrl_id,
                count_ctrl)
            # Internal network IP of the controller
            ip = self.build_ip(ctrl_index)
            ctrl_id_str = str(logic_ctrl_id)
            extra_files = []
            if ctrl_id_str in engine_his \
                    and engine_his[ctrl_id_str].other_result:
                for key in self.need_collect_files.get("dir").keys():
                    extra_files += \
                        engine_his[ctrl_id_str].other_result.get(key, [])
            self.export_file_on_one_ctrl(exporter, ip, ctrl_id, extra_files)

    def export_file_on_one_ctrl(
            self,
            exporter,
            inner_ip,
            ctrl_id,
            extra_file=None):
        """
        Export the current logs and the given history logs of one controller.

        Only history files whose download succeeded are remembered in
        ``self.export_files``, so a failed file is picked up again by a later
        run.

        :param exporter: FileExporter connected to the controller's engine
        :param inner_ip: internal IP of the controller
        :param ctrl_id: controller id used as local file name prefix
        :param extra_file: history file paths to export, or None
        """
        self.logger.info("export ctrl={}, inner_ip={}, files={}".format(
            ctrl_id, inner_ip, extra_file))
        # Export the current logs first
        for file_name in self.need_collect_files.get("file"):
            file_path = self.need_collect_files.get("file")[file_name]
            # Move the file into the export directory
            flag, file_path = self.find_actual_path_by_attempt_exporter(file_path, exporter, inner_ip)
            if flag is not True:
                continue
            # Download the file through sftp
            flag = exporter.download_file(
                file_name, ctrl_id)

            if flag is not True:
                self.logger.info("export file failed")

            # Remove the transferred copy after the export
            exporter.delete_file(file_path)

        # Then export the history logs
        if extra_file is not None:
            for file_path in extra_file:
                # Move the file into the export directory
                flag, _, _ = exporter.do_transfer_file(
                    file_path, inner_ip)
                if flag is not True:
                    continue
                # Download the file through sftp
                export_file_name = FileExporter.parse_file_name(file_path)
                flag = exporter.download_file(export_file_name, ctrl_id)
                if flag is True:
                    self.export_files.setdefault(ctrl_id, []).append(
                        export_file_name)
                else:
                    self.logger.info("export file failed")

                # Remove the transferred copy after the export
                exporter.delete_file(file_path)

    @staticmethod
    def build_ip(ctrl_index):
        """
        Build the internal IP of a controller: 127.127.127.(10 + index).

        :param ctrl_index: index of the controller inside its engine
        :return: the IP as a string
        """
        return '127.127.127.' + str(10 + ctrl_index)

    @staticmethod
    def find_actual_path_by_attempt_exporter(file_path, exporter, inner_ip):
        """
        Find the real location of a file by trying to transfer each candidate.

        :param file_path: file location (a string or a list of candidates)
        :param exporter: FileExporter that performs the transfer
        :param inner_ip: internal IP of the controller holding the file
        :return: (True, path) for the first candidate that transferred,
            (False, "") when none did
        """
        file_paths = file_path
        if not isinstance(file_path, list):
            file_paths = [file_path]
        for item in file_paths:
            flag, _, _ = exporter.do_transfer_file(item, inner_ip)
            if flag is True:
                return True, item
        return False, ""


class FileExporter:
    """
    Moves log files from any controller of the connected engine to the
    connected controller (``getremotefile``) and downloads them from there to
    the local log directory through sftp.
    """

    def __init__(self, context, export_device):
        """
        :param context: tool context (logger, language, 'log_path')
        :param export_device: ExportDevice holding the CLI and sftp
            connections
        """
        self.context = context
        self.device = export_device
        self.lang = contextUtil.getLang(self.context)
        self.logger = contextUtil.getLogger(self.context)
        self.collect_adapter = self._get_collect_adapter()

        # Directory on the connected controller that receives the files
        # fetched from the controllers of the same engine via getremotefile
        self.export_path = self.test_export_path()

    def _get_collect_adapter(self):
        """
        :return: a GetRemoteFileAdapter when the device matches it, else None
        """
        collect_adapter_612 = GetRemoteFileAdapter(self.device.dev, self.logger, self.lang)
        if collect_adapter_612.is_match():
            return collect_adapter_612
        return None

    def test_export_path(self):
        """
        Determine an export path that can be accessed.

        :return: the adapter's export path when an adapter is in use, the
            fixed Dorado 6.1.0 path when that version check matches, otherwise
            the first candidate whose ``ls`` output does not contain
            "permission denied"
        :raises Exception: when every candidate path is denied
        """
        if self.collect_adapter:
            return self.collect_adapter.get_export_path()

        # Version 610 is security hardened: files can only be exported to
        # /home/permitdir
        if self._is_version_dorado_610(self.device.dev.get("type"), self.device.dev.get("version")):
            return EXPORT_PATH_FOR_DORADO_610

        possible_path = ('/OSM/coffer_data/omm/export_import',
                         '/OSM/export_import')

        for path in possible_path:
            flag, cli_ret, _ = cliUtil.excuteCmdInMinisystemModel(
                self.device.current_cli,
                "ls {}".format(path),
                self.lang, self.device.dev.get('pawd'))
            if "permission denied" not in cli_ret.lower():
                return path
        raise Exception("no find access path")

    def _is_version_dorado_610(self, dev_type, dev_version):
        """
        :param dev_type: product type string
        :param dev_version: product version string
        :return: True when the type contains "dorado" and the version passes
            the checks below, else False
        """
        # NOTE(review): '6.1.RC1' has no patch digit, unlike the '6.1.2RC1'
        # used by GetRemoteFileAdapter.is_match - confirm this is the value
        # compare_version expects.
        if 'dorado' in dev_type.lower() and is_digital_ver_or_v7(
                dev_version) and compare_version(dev_version, '6.1.RC1') >= 0:
            self.logger.info("is dorado 610 version.")
            return True
        return False

    def do_transfer_file(self, remote_file_path, ip):
        """
        Transfer a file to the currently logged-in controller.

        :param remote_file_path: path of the file on the remote controller
        :param ip: IP of the remote controller
        :return: tuple (flag, cli output, error message); flag is True only
            when the command ran and its output contains neither "fail" nor
            "does not exist"
        """
        if self.collect_adapter:
            export_cmd = self.collect_adapter.get_export_cmd(ip, remote_file_path)
        else:
            export_cmd = TRANSFER_CMD.format(ip, remote_file_path, self.export_path)
        try:
            flag, cli_ret, err_msg = cliUtil.excuteCmdInMinisystemModel(
                self.device.current_cli, export_cmd,
                self.lang, self.device.dev.get('pawd'))
            self.logger.info('try to get remote file:{}, result:{}',
                             remote_file_path, cli_ret)
            if flag is not True:
                return flag, cli_ret, err_msg
            transferred = 'fail' not in cli_ret and "does not exist" not in cli_ret
            return transferred, cli_ret, err_msg
        except Exception as e:
            self.logger.error(
                'error occurs when transfer log:%s from:%s' %
                (remote_file_path, ip), e)
            return False, '', 'fail to transfer log'

    def download_file(self, remote_file, controller_id=None,
                      is_full_path=False):
        """
        Download a file from the connected controller into the local log
        directory (context key 'log_path').

        :param remote_file: file name below the export path, or a full remote
            path when ``is_full_path`` is True
        :param controller_id: when given, prefixed to the local file name as
            ``<controller_id>_<name>``
        :param is_full_path: whether ``remote_file`` already is a full path
        :return: True when the download succeeded, False otherwise
        """
        local_dir = self.context.get('log_path')
        if not is_full_path:
            remote_file = self.export_path + "/" + remote_file
        # Create the local temporary directory when it does not exist yet
        if not os.path.exists(local_dir):
            os.makedirs(local_dir)
        target_file_name = FileExporter.parse_file_name(remote_file)
        if controller_id:
            target_file_name = controller_id + '_' + target_file_name
        target = File(local_dir + '/' + target_file_name)
        return self.__sftp_transfer(remote_file, target)

    def __sftp_transfer(self, remote_name, local_name):
        """
        Fetch a file through sftp, retrying up to 5 times with a 2 second
        pause after each failure.

        :param remote_name: remote file path
        :param local_name: local target passed to ``sftp.getFile``
        :return: True on success, False once the retries are used up
        """
        sftp = self.device.sftp
        retry_times = 0
        while True:
            try:
                sftp.getFile(remote_name, local_name, None)
                return True
            except JException as e:
                if retry_times < 5:
                    self.logger.warn('try download file fail')
                    retry_times += 1
                    time.sleep(2)
                else:
                    self.logger.error(
                        'error occurs when try to download file:{}, '
                        'message:{}'.format(remote_name, str(e)))
                    return False

    def delete_file(self, file_name):
        """
        Remove the transferred copy of a file from the export path.

        With an adapter the adapter's clean command is used, otherwise the
        file is deleted through sftp; an sftp failure is only logged.

        :param file_name: original path (or name) of the file
        """
        remote_name = self.export_path + '/' + \
            FileExporter.parse_file_name(file_name)
        if self.collect_adapter:
            self.collect_adapter.clear_tmp_file(self.device.current_cli, remote_name)
            return
        try:
            self.device.sftp.deleteFile(remote_name)
        except JException as e:
            self.logger.warn('delete file:{} fail'.format(remote_name), e)

    @staticmethod
    def parse_file_name(remote_path):
        """
        Return the part of ``remote_path`` after the last '/'.

        A path without any '/' is returned unchanged instead of raising
        ValueError.

        :param remote_path: path using '/' as separator
        :return: the file name ('' when the path ends with '/')
        """
        return remote_path.rsplit('/', 1)[-1]

    @property
    def sftp(self):
        """The sftp connection of the underlying device."""
        return self.device.sftp


class ExportDevice:
    """
    Holds the connection state of the device logs are exported from: the CLI
    and sftp connections, the engine/controller layout and the management IPs
    of each engine.
    """

    # Prefix pattern of a dotted IPv4 address; the dots are escaped so they
    # only match a literal '.'
    _IPV4_PATTERN = re.compile(r'\d+\.\d+\.\d+\.\d+')
    # Prefix of a management port ID such as "CTE0."; group 1 is the engine id
    _ENGINE_ID_PATTERN = re.compile(r'CTE(\d+)\.')

    def __init__(self, context):
        """
        :param context: tool context; 'dev' (device info with 'ip') and 'cli'
            (source CLI connection) are read from it
        """
        self.context = context
        self.lang = contextUtil.getLang(self.context)
        self.logger = contextUtil.getLogger(self.context)
        self.count_engine = 1
        self.count_ctrl_one_engine = 1
        self.dev = self.context.get('dev')
        self.src_cli = context.get("cli")
        self.current_cli = cli_con_mgr.get_ctrl_cli(self.context, self.src_cli)
        self.current_ctrl = '0'
        self.current_engine = '0'
        # Whether the login already landed in mini-system mode
        self.login_mini_sys = cliUtil.testIsInMinisystemMode(self.current_cli)
        self.login_ip = self.context.get('dev').get('ip')
        self.use_ipv4 = bool(self._IPV4_PATTERN.match(self.login_ip))
        # engine id (string) -> list of management IPs
        self.engine2ip = {}
        self.sftp = None

    def init_device(self):
        """
        Query system time, engine/controller counts and engine IPs.

        :return: True when every step succeeded, otherwise the falsy result
            of the first step that failed
        """
        flag = self.get_system_time() and self.init_engine_ctrl_count() and self.init_engine2ip()
        return flag

    def get_system_time(self):
        """
        Read the system date and store it in the context as "sys_time".

        :return: True when a system time was obtained
        """
        sys_time, _ = cliUtil.getSystemDate(self.current_cli, self.lang)
        self.context["sys_time"] = sys_time
        return bool(sys_time)

    def reconnect_con(self):
        """Fetch the controller CLI connection again and open sftp on it."""
        self.logger.info("reconnect conn")
        self.current_cli = cli_con_mgr.get_ctrl_cli(self.context, self.src_cli)
        self.creat_sftp_conn()

    def creat_sftp_conn(self):
        """Create the sftp transporter on top of the current CLI connection."""
        self.sftp = SftpTransporter(self.current_cli)

    def init_engine_ctrl_count(self):
        """
        Run ``showsysstatus`` in mini-system mode and derive the engine and
        controller counts from its output.

        :return: True on success, False when the command failed
        """
        cmd = 'showsysstatus'
        flag, cli_ret, err_msg = cliUtil.excuteCmdInMinisystemModel(
            self.current_cli, cmd, self.lang)
        # Leave mini-system mode again unless the login started in it
        if not self.login_mini_sys:
            cliUtil.enterCliModeFromSomeModel(self.current_cli, self.lang)
        if flag is not True:
            self.logger.warn('try to query engine ctrl count fail:{}'.
                             format(err_msg))
            return False
        self.parse_engine_ctrl_count_mini_sys(cli_ret)
        return True

    def parse_engine_ctrl_count_mini_sys(self, cli_ret):
        """
        Parse ``showsysstatus`` output into current controller, current
        engine, controllers per engine and engine count.

        :param cli_ret: raw command output
        """
        vertical_rows = cliUtil.getVerticalCliRet(cli_ret)
        # Fall back to the defaults when the vertical section is missing
        vertical_res = vertical_rows[0] if vertical_rows else {}
        self.current_ctrl = vertical_res.get('local node id', '0')
        # Try to work out the number of controllers per engine
        ctrl_rows = cliUtil.getHorizontalCliRet(cli_ret)
        grouped_rows = baseUtil.group_by(ctrl_rows, 'engine')
        engine_max_ctrl = 2
        for item in grouped_rows.values():
            engine_max_ctrl = max(engine_max_ctrl, len(item))
        # Round an odd count up to the next even number
        self.count_ctrl_one_engine = engine_max_ctrl if engine_max_ctrl % 2 == 0 else engine_max_ctrl + 1
        self.count_engine = max(len(grouped_rows.values()), 1)
        for row in ctrl_rows:
            if self.current_ctrl == row.get('id'):
                self.current_engine = row.get('engine', '0')
                break

    def init_engine2ip(self):
        """
        Collect valid management IPs for every engine.

        :return: True on success (or when logged in under mini-system mode,
            where the query is not possible); the failing flag otherwise
        """
        # Nothing can be queried when the login is in mini-system mode, so
        # return directly
        if self.login_mini_sys:
            return True
        cmd = 'show upgrade package'
        flag, cli_ret, err_msg = cliUtil.excuteCmdInCliMode(
            self.current_cli, cmd, True, self.lang)
        if flag is not True:
            self.logger.warn('try to query engine2ip fail:{}'.format(err_msg))
            return flag
        self.engine2ip = self.parse_engine_ip(cli_ret)
        self.init_engine2ip_by_show_port()
        return True

    def init_engine2ip_by_show_port(self):
        """
        Add the IPv4 addresses of the ETH management ports to
        ``self.engine2ip`` and remove duplicates.

        When the SVP is detected as disabled, the IPs collected before are
        discarded first.
        """
        if self._is_svp_disabled():
            self.engine2ip = {}
        cmd = 'show port general logic_type=Management_Port physical_type=ETH'
        flag, cli_ret, err_msg = cliUtil.excuteCmdInCliMode(
            self.current_cli, cmd, True, self.lang)
        if flag is not True:
            self.logger.warn('try to query engine2ip fail:{}'.format(err_msg))
        manage_ip_dict = cliUtil.getHorizontalCliRet(cli_ret)
        self.logger.info("engine to ip:{}".format(self.engine2ip))
        for item in manage_ip_dict:
            location = item.get("ID", '')
            match_obj = self._ENGINE_ID_PATTERN.match(location)
            if not match_obj:
                continue
            engine_id = match_obj.group(1)
            if engine_id not in self.engine2ip:
                self.engine2ip[engine_id] = []
            ipv4 = item.get("IPv4 Address", "")
            if self._IPV4_PATTERN.match(ipv4):
                self.engine2ip[engine_id].append(ipv4)

        self.logger.info("engine to ip:{}".format(self.engine2ip))
        for engine_id, ips in list(self.engine2ip.items()):
            self.engine2ip[engine_id] = list(set(ips))

    def _is_svp_disabled(self):
        """
        :return: True only for product models starting with "18" at version
            "V500R007C60" or above (string comparison) whose SVP module info
            equals "new_toolbox_nosvp"; False otherwise, including on errors
        """
        try:
            dev_type = contextUtil.getProductModel(self.context)
            version = contextUtil.getCurVersion(self.context)
            if not dev_type.startswith("18") or version < "V500R007C60":
                return False
            dev_node = contextUtil.get_base_dev(self.context)
            svp_module_info = dev_node.getHighDevSVPModuleInfo()
            if svp_module_info == "new_toolbox_nosvp":
                return True
        except (JException, Exception):
            self.logger.error("is svp disable error.")
        return False

    def get_engine_ip(self, engine_index):
        """
        Get the management IPs recorded for an engine.

        :param engine_index: engine index (converted to a string key)
        :return: list of IPs, or None when the engine is unknown
        """
        return self.engine2ip.get(str(engine_index))

    def parse_engine_ip(self, cli_ret):
        """
        Parse ``show upgrade package`` output into engine id -> IP list and
        update ``self.count_engine``.

        Only the lines after the one containing 'HotPatch Version' are
        parsed; the engine id is the first character of the 'Name' column.

        :param cli_ret: raw command output
        :return: dict engine id -> list of IPs
        """
        lines = cli_ret.splitlines()
        for index, line in enumerate(lines):
            if 'HotPatch Version' in line:
                lines = lines[index + 1:]
                break
        res_map = cliUtil.getHorizontalCliRet('\n'.join(lines))
        engines = set()
        result = {}
        for item in res_map:
            name = item.get('Name') or ''
            if not name:
                # A row without a name cannot be assigned to an engine
                continue
            engine_id = name[0]
            engines.add(engine_id)
            ip = item.get('IP') or ''
            # '-' marks a missing address; an empty value is unusable too
            if not ip or '-' in ip:
                continue
            if engine_id not in result:
                result[engine_id] = []
            result[engine_id].append(ip)
        # Overwrite with the number of engines actually connected
        self.logger.info('init engine to ip result:{}', str(result))
        self.count_engine = len(engines)
        return result

    @staticmethod
    def get_engine_res(cli_res):
        """
        Extract the rows that belong to the 'ETH port' section of a CLI
        result.

        Indented lines are kept while inside a section started by a line
        beginning with 'ETH port'; lines starting with '-' and empty lines
        are skipped; any other unindented line ends the section.

        :param cli_res: raw command output
        :return: the kept lines parsed by ``cliUtil.getHorizontalCliRet``
        """
        parts = cli_res.splitlines()
        valid_lines = []
        flag = False
        for part in parts:
            if len(part) == 0:
                continue
            if part.startswith(' '):
                if flag:
                    valid_lines.append(part)
            elif part.startswith('-'):
                continue
            elif part.startswith('ETH port'):
                flag = True
            else:
                flag = False
        return cliUtil.getHorizontalCliRet('\n'.join(valid_lines))

    def switch_engine(self, ip_list):
        """
        Close the current connections and connect to the first reachable IP
        of ``ip_list``.

        :param ip_list: candidate management IPs of the target engine
        :raises UnCheckException: when no IP could be connected; in that case
            ``self.current_cli`` is left as None
        """
        self.logger.info('start to switch engine')
        # current_cli is None after a previous switch failed; closing it
        # unguarded would make every later switch fail as well.
        if self.current_cli is not None:
            self.current_cli.close()
        if self.sftp is not None:
            self.sftp.close()
            self.sftp = None
        # Forget the closed connection so the check below also catches an
        # empty ip_list.
        self.current_cli = None
        for ip in ip_list:
            self.dev['ip'] = ip
            self.current_cli = contextUtil.createCliConnection(
                self.context, ip)
            if self.current_cli is not None:
                self.logger.info('switch engine to:{} success', ip)
                break
        if self.current_cli is None:
            raise UnCheckException(common.getMsg(
                self.lang, "log.collect.engine.manage.ip.can.not.reach"), "")

        self.logger.info('finish switch engine')
        self.creat_sftp_conn()

    def clear_env(self):
        """
        Close the source CLI, the current CLI and the sftp connection.

        Each close is attempted independently so that one failing close does
        not leave the remaining connections open.
        """
        self.logger.info('try to clear sftp collect env')
        if self.src_cli:
            self._close_quietly(self.src_cli)
        if self.current_cli:
            self._close_quietly(self.current_cli)
        if self.sftp is not None:
            self._close_quietly(self.sftp)

    def _close_quietly(self, resource):
        """Close ``resource``; a failure is logged instead of raised."""
        try:
            resource.close()
        except (JException, Exception):
            self.logger.warn(
                'close connection fail. {}'.format(traceback.format_exc()))
