#  coding=UTF-8
#  Copyright (c) Huawei Technologies Co., Ltd. 2020~ All rights reserved.
import os
import zipfile
import codecs
import traceback
from collections import defaultdict

# noinspection PyUnresolvedReferences
import com.alibaba.fastjson.JSON as JSON
import shutil
import time
# noinspection PyUnresolvedReferences
from com.huawei.ism.tool.base.utils import SceneUtils
# noinspection PyUnresolvedReferences
from com.huawei.ism.tool.obase.exception import ToolException
# noinspection PyUnresolvedReferences
from java.io import File
# noinspection PyUnresolvedReferences
from java.lang import System
# noinspection PyUnresolvedReferences
from java.lang import Exception as JException

from cbb.business.collect.hostinfo import const
from cbb.business.collect.hostinfo.collect_host_info_util import (
    HostInfoCollectException,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    notify_storage_device_delete_host_info_pkg,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    notify_storage_start_collect_host_info,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    query_host_info_collect_download_path,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    query_host_info_collect_script_upload_path,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    query_host_info_collect_status,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    check_conn_is_alive,
)
from cbb.business.collect.hostinfo.collect_host_info_util import (
    reconnect_conn,
)
from cbb.common.query.software.host import (
    get_hosts_can_be_collected_via_storage,
)
from cbb.common.query.software.host import get_hyper_hosts, is_version_6_1
from cbb.frame.base.baseUtil import getString
from cbb.frame.context import contextUtil
from cbb.frame.cli import cliUtil
from cbb.frame.adapter.replace_adapter import compare_version
from cbb.frame.adapter.replace_adapter import is_digital_ver
from cbb.frame.base.baseUtil import is_ocean_protect
from cbb.frame.base.baseUtil import is_new_oceanstor


def collect(params_dict):
    """Host information collection entry point.

    Delegates to collect_core() and always releases the REST connection
    held in the context afterwards, whether collection succeeded or not.

    :param params_dict: parameter dict passed through to collect_core().
    :return: the result dict produced by collect_core().
    """
    try:
        result = collect_core(params_dict)
    finally:
        # Releasing the REST connection is very important.
        contextUtil.releaseRest(contextUtil.getContext(params_dict))
    return result


def collect_core(params_dict):
    """Host information collection core flow.

    Triggers collection on the array, downloads the result package, asks the
    array to delete its copy, and parses the local package.

    Raises:
        HostInfoCollectException (per the original contract).

    :param params_dict: parameter dict.
    :return: result dict from parse_collect_result(), plus the key
        "host_pkg_path" holding the local result package path.
    """
    download_path = _collect_4_download_path(params_dict)
    logger = params_dict.get("logger")

    # Check whether the SSH connection timed out before downloading.
    check_conn_is_alive(params_dict, logger)
    local_pkg_path = transfer_file_with_retry(
        params_dict, download_path, upload=False
    )
    delete_pkg_and_ignore_exception(params_dict)
    logger.info("Result package path:{}".format(local_pkg_path))

    # CLI commands still run after the download, so re-check the connection.
    check_conn_is_alive(params_dict, logger)
    parsed = parse_collect_result(params_dict, local_pkg_path)
    parsed["host_pkg_path"] = local_pkg_path
    logger.info("Parsed result dict:{}".format(parsed))
    return parsed


def _collect_4_download_path(params_dict):
    """Upload the collect script, wait for collection, return the remote path.

    A failure while polling the collection status (including the timeout
    raised by loop_query_host_info_collect_status) is only logged; the
    download path is still queried afterwards.

    :param params_dict: must contain "devNode" and "devIp".
    :return: remote download path of the collection result package.
    """
    verify_params_dict(params_dict, ("devNode", "devIp"))
    upload_package_by_node_id(params_dict)
    try:
        loop_query_host_info_collect_status(params_dict)
    except Exception as err:
        params_dict.get("logger").error(
            "_collect_4_download_path error!: {}\n{}".format(
                err, str(traceback.format_exc())
            )
        )
    return query_host_info_collect_download_path(params_dict)


def collect_4_download_path(params_dict):
    """Run host information collection and return the remote download path.

    Any Python or Java exception is logged and swallowed.

    :param params_dict: parameter dict.
    :return: remote download path, or "" if anything failed.
    """
    try:
        download_path = _collect_4_download_path(params_dict)
    except (Exception, JException) as err:
        params_dict.get("logger").error(
            "collect_4_download_path error!: {}".format(err)
        )
        download_path = ""
    return download_path


def after_collect_host_info(params_dict):
    """Post-collection cleanup: ask the array to delete the collect package.

    Errors are logged, not raised. The REST connection is always released.

    :param params_dict: parameter dict.
    :return: None
    """
    try:
        logger = params_dict.get("logger")
        logger.info("delete host collect package.")
        check_conn_is_alive(params_dict, logger)
        delete_pkg_and_ignore_exception(params_dict)
    except (Exception, JException) as err:
        params_dict.get("logger").error(
            "after_collect_host_info error!: {} \n{}".format(
                err, str(traceback.format_exc())
            )
        )
    finally:
        # Releasing the REST connection is very important.
        contextUtil.releaseRest(contextUtil.getContext(params_dict))


def upload_package_by_node_id(params_dict):
    """Upload the collect package to the array for the local node, then start collection.

    The script package is uploaded only when need_upload_host_script_package()
    says the device needs it. On 6.1 devices collection is started with a
    POST request. Otherwise nodeId/workPortId are passed instead.

    :param params_dict: parameter dict; uses "devNode" and (via helpers) "SSH".
    :return: None
    """
    local_node_id = query_local_node_id(params_dict)
    dev_obj = params_dict.get("devNode")

    if need_upload_host_script_package(dev_obj):
        upload_path = query_host_info_collect_script_upload_path(
            params_dict, nodeId=local_node_id
        )
        transfer_file_with_retry(params_dict, upload_path, upload=True)

    notify_kwargs = {
        "scriptName": const.HOST_INFO_COLLECT_UPLOAD_ZIP_FILE_NAME,
    }
    if is_version_6_1(dev_obj):
        notify_kwargs["request_method"] = "post"
    else:
        notify_kwargs["nodeId"] = local_node_id
        # NOTE(review): workPortId is filled with the node id, as before --
        # confirm this is what the array API expects.
        notify_kwargs["workPortId"] = local_node_id
    notify_storage_start_collect_host_info(params_dict, **notify_kwargs)


def need_upload_host_script_package(dev_obj):
    """Decide whether the host collect script package must be uploaded.

    No upload is needed for OceanProtect / new OceanStor models, for digital
    versions >= 6.1.2RC1, for Kunpeng versions >= V500R007C72 Kunpeng, or
    for versions >= V700R001C00. The conditions are checked in that order.

    :param dev_obj: device object exposing getProductVersion() and
        getDeviceModel().
    :return: True if the package must be uploaded, otherwise False.
    """
    version = str(dev_obj.getProductVersion())
    model = str(dev_obj.getDeviceModel())
    upload_not_needed = (
        is_ocean_protect(model)
        or is_new_oceanstor(model)
        or (is_digital_ver(version)
            and compare_version(version, "6.1.2RC1") >= 0)
        or ("Kunpeng" in version
            and compare_version(version, "V500R007C72 Kunpeng") >= 0)
        or compare_version(version, "V700R001C00") >= 0
    )
    return not upload_not_needed


def upload_package_by_port_id(params_dict):
    """Upload the collect package via the ETH port the tool is connected to.

    The port is resolved from the tool-connected IP. The package is uploaded
    to that port's script path, and collection is started on it.

    :param params_dict: parameter dict.
    :return: None
    """
    port_id = get_tool_connected_eth_port_id(params_dict)
    upload_path = query_host_info_collect_script_upload_path(
        params_dict, workPortId=port_id
    )
    transfer_file_with_retry(params_dict, upload_path, upload=True)
    notify_storage_start_collect_host_info(
        params_dict,
        scriptName=const.HOST_INFO_COLLECT_UPLOAD_ZIP_FILE_NAME,
        workPortId=port_id,
    )


def transfer_file_with_retry(
    params_dict,
    remote_path,
    upload=True,
    retry_times=3,
    retry_interval_secs=15,
):
    """Transfer a file over SFTP, retrying on failure.

    Each failed attempt is logged and the connection is re-established with
    reconnect_conn(). Except after the last attempt, the function then waits
    retry_interval_secs before trying again.

    :param params_dict: parameter dict (see transfer_file_once); "logger"
        is used for logging.
    :param remote_path: remote file (or directory) path.
    :param upload: True to upload, False to download.
    :param retry_times: maximum number of attempts.
    :param retry_interval_secs: seconds to wait between attempts.
    :return: whatever transfer_file_once returns (None for upload, the
        local package path for download).
    :raises HostInfoCollectException: when every attempt fails. It carries
        the error id / description of the last failure.
    """
    logger = params_dict.get("logger")
    direction = "upload" if upload else "download"
    error_id = None
    error_des = None
    for attempt in range(1, retry_times + 1):
        try:
            return transfer_file_once(params_dict, remote_path, upload)
        except (ToolException, Exception) as te:
            logger.error("{} error, reconnect!: {}".format(direction, te))
            reconnect_conn(params_dict, logger)
            # Only ToolException provides getErrorId()/getDes(). Calling them
            # unconditionally raised AttributeError for plain Python
            # exceptions and aborted the retry loop from inside the handler.
            get_error_id = getattr(te, "getErrorId", None)
            get_des = getattr(te, "getDes", None)
            error_id = get_error_id() if callable(get_error_id) else -1
            error_des = get_des() if callable(get_des) else str(te)
            # No point in waiting after the final attempt.
            if attempt < retry_times:
                time.sleep(retry_interval_secs)

    # noinspection PyUnresolvedReferences
    err_msg_key = (
        const.ERR_KEY_UPLOAD_SCRIPT_FAILED
        if upload
        else const.ERR_KEY_DOWNLOAD_COLLECTED_FILE_FAILED
    )
    raise HostInfoCollectException(error_id, error_des, getString(err_msg_key))


def verify_params_dict(params_dict, need_keys=("SFTP", "devNode", "toolDir")):
    """Ensure every key in need_keys is present in params_dict.

    :param params_dict: parameter dict to check.
    :param need_keys: keys that must be present.
    :return: None
    :raises HostInfoCollectException: (error id -1) listing the missing keys.
    """
    missing_keys = [key for key in need_keys if key not in params_dict]
    if not missing_keys:
        return
    # noinspection PyUnresolvedReferences
    raise HostInfoCollectException(
        -1,
        "param error, missing keys:{}".format(",".join(missing_keys)),
        getString(const.ERR_KEY_COLLECT_HOST_INFO_PARAMS_ERROR),
    )


def zip_hostinfo_collect_pkg(params_dict):
    """Build the host-info collection upload package.

    Writes the scene parameters for the current sub-scene
    (SceneUtils.getCurrentSubScene()) into the scene-params file
    (scene_params.txt per the original design note). That file is then
    zipped together with the files listed in
    const.HOST_INFO_UPLOAD_ZIP_FILE_NAMES (zip/zip.cms/zip.crl per the
    original note) into the upload package (hostinfo_collect_all.zip).
    This function only creates the package; transfer_file_once uploads it.
    The scene-params file is always deleted afterwards.

    :param params_dict: parameter dict; uses "toolDir" and "logger".
    :return: path of the created zip package, or None if creating it failed
        (the error is logged).
    """
    tool_dir = params_dict.get("toolDir")
    logger = params_dict.get("logger")
    host_info_collect_path = os.path.join(
        tool_dir, const.HOST_INFO_COLLECT_SCRIPT_PATH_NAME
    )
    host_info_collect_pkg_dir = os.path.split(host_info_collect_path)[0]
    smartkit_param = SceneUtils.getCurrentSubScene()
    hostinfo_collect_para = const.HOST_INFO_COLLECT_SCENE_PARAM_MAP.get(
        smartkit_param, ""
    )
    # Multiple parameters must be joined into a single line (see design doc).
    logger.info("Host info collect params:{}".format(hostinfo_collect_para))
    zip_pkg_path_name = os.path.join(
        host_info_collect_pkg_dir, const.HOST_INFO_COLLECT_UPLOAD_ZIP_FILE_NAME
    )
    scene_para_file = os.path.join(
        host_info_collect_pkg_dir,
        const.HOST_INFO_COLLECT_SCENE_PARAMS_FILE_NAME,
    )
    try:
        with codecs.open(scene_para_file, "w", encoding='utf-8') as f:
            f.write(hostinfo_collect_para)

        # The context manager closes the archive even when adding a file
        # fails, so no file handle is leaked on the error path.
        with zipfile.ZipFile(
            zip_pkg_path_name, "w", zipfile.ZIP_DEFLATED
        ) as zip_file_obj:
            for file_name in const.HOST_INFO_UPLOAD_ZIP_FILE_NAMES:
                file_path_name = os.path.join(
                    host_info_collect_pkg_dir, file_name
                )
                zip_file_obj.write(file_path_name, file_name)
        return zip_pkg_path_name
    except Exception as e:
        logger.error(
            "create hostinfo collect upload pkg exception:{}".format(e)
        )
        # Do not leave a half-written package behind.
        if os.path.exists(zip_pkg_path_name):
            clear_temp_file(zip_pkg_path_name, logger)
        return None
    finally:
        clear_temp_file(scene_para_file, logger)


def clear_temp_file(tmp_file, logger):
    """Delete a temporary file, logging (not raising) any failure.

    :param tmp_file: path of the file to delete.
    :param logger: logger used to report a failed deletion.
    :return: None
    """
    try:
        os.remove(tmp_file)
    except Exception as remove_err:
        message = "clear temp file error: {}".format(remove_err)
        logger.error(message)


def transfer_file_once(params_dict, remote_path, upload=True):
    """Perform a single SFTP transfer.

    Upload: builds the package with zip_hostinfo_collect_pkg(), puts it to
    remote_path, and always deletes the local package afterwards (also when
    the transfer fails).
    Download: fetches remote_path into
    "<java.io.tmpdir>/<HOST_INFO_COLLECT_SAVE_DIR>/<device SN>_<HOST_INFO_COLLECT_PKG_NAME>".

    :param params_dict: must contain "SFTP", "devNode" and "toolDir";
        "logger" is used on the upload path.
    :param remote_path: remote file (or directory) path.
    :param upload: True to upload, False to download.
    :return: None for upload; local path of the downloaded package otherwise.
    :raises HostInfoCollectException: when parameters are missing or the
        upload package could not be created.
    """
    verify_params_dict(params_dict, need_keys=("SFTP", "devNode", "toolDir"))

    sftp = params_dict.get("SFTP")
    dev_node = params_dict.get("devNode")

    # noinspection PyUnresolvedReferences
    if upload:
        logger = params_dict.get("logger")
        upload_pkg_pathname = zip_hostinfo_collect_pkg(params_dict)
        if not upload_pkg_pathname:
            # zip_hostinfo_collect_pkg already logged the cause. Fail with a
            # clear error instead of passing None into File().
            raise HostInfoCollectException(
                -1,
                "create hostinfo collect upload pkg failed",
                getString(const.ERR_KEY_UPLOAD_SCRIPT_FAILED),
            )
        try:
            sftp.putFile(File(upload_pkg_pathname), remote_path, None)
        finally:
            # Remove the local package even if the transfer raised.
            clear_temp_file(upload_pkg_pathname, logger)
        return None

    java_tmp_dir = System.getProperty("java.io.tmpdir")
    # noinspection PyUnresolvedReferences
    host_info_local_save_dir = os.path.join(
        java_tmp_dir, const.HOST_INFO_COLLECT_SAVE_DIR
    )
    if not os.path.exists(host_info_local_save_dir):
        os.makedirs(host_info_local_save_dir)

    # noinspection PyUnresolvedReferences
    final_result_pkg_full_name = os.path.join(
        host_info_local_save_dir,
        dev_node.getDeviceSerialNumber()
        + "_"
        + const.HOST_INFO_COLLECT_PKG_NAME,
    )

    # noinspection PyUnresolvedReferences
    sftp.getFile(remote_path, File(final_result_pkg_full_name), None)
    return final_result_pkg_full_name


def query_all_eth_ports(params_dict):
    """Query all ETH ports of the device via the DeviceManager REST API.

    Sends GET https://<devIp>:8088/deviceManager/rest/<SN>/eth_port. IPv6
    addresses are wrapped in brackets.

    :param params_dict: parameter dict; uses "devNode", "devIp", "logger"
        and the REST connection from the context.
    :return: the "data" list of the response (port info dicts).
    :raises HostInfoCollectException: if the response cannot be parsed or
        carries a non-zero error code.
    """
    dev_obj = params_dict.get("devNode")
    logger = params_dict.get("logger")
    rest_connection = contextUtil.getRest(
        contextUtil.getContext(params_dict)
    ).getRest()

    dev_ip = params_dict.get("devIp")
    url_host = "[{}]".format(dev_ip) if ":" in dev_ip else dev_ip
    url = r"https://{ip}:{port}/deviceManager/rest/{dev_sn}/eth_port".format(
        ip=url_host,
        port=str(8088),
        dev_sn=str(dev_obj.getDeviceSerialNumber()),
    )
    response_info = rest_connection.execGet(url)

    content = None
    try:
        content = response_info.getContent()
        response_dict = JSON.parse(content)
    except Exception as parse_err:
        logger.error("Exception:{}".format(parse_err))
        logger.error("Resp content:{}".format(content))
        raise HostInfoCollectException(
            -1,
            "Rest content parse exception",
            getString(const.ERR_KEY_QUERY_ALL_ETH_PORTS_FAILED),
        )

    error_info = response_dict.get("error", {})
    err_code = error_info.get("code", 0)
    if err_code != 0:
        raise HostInfoCollectException(
            err_code,
            "Rest error desc:{}".format(error_info.get("description", 0)),
            getString(const.ERR_KEY_QUERY_ALL_ETH_PORTS_FAILED),
        )
    return response_dict.get("data", [])


def get_tool_connected_eth_port_id(params_dict):
    """Return the ID of the ETH port whose address equals the tool-connected IP.

    IPV6ADDR is compared for IPv6 connections and IPV4ADDR otherwise.

    :param params_dict: parameter dict; uses "devIp".
    :return: the matching port's "ID" value.
    :raises HostInfoCollectException: if no port matches.
    """
    tool_connected_ip = params_dict.get("devIp")
    addr_key = "IPV6ADDR" if ":" in tool_connected_ip else "IPV4ADDR"
    for port_dict in query_all_eth_ports(params_dict):
        if tool_connected_ip == port_dict.get(addr_key, ""):
            return port_dict.get("ID", "")
    raise HostInfoCollectException(
        -1,
        "No eth port matched connected IP.",
        getString(const.ERR_KEY_CREATE_CTRL_CONN_FAILED),
    )


def query_local_node_id(params_dict):
    """Return the ID of the controller node the tool is connected to.

    Uses cliUtil.getControllerEngineTopography(); the node id is element 1
    of the returned topography tuple.

    :param params_dict: parameter dict; uses "SSH", "logger" and "lang".
    :return: local node id.
    :raises HostInfoCollectException: if the query fails or yields no id.
    """
    ssh_cli = params_dict.get("SSH")
    logger = params_dict.get("logger")
    lang = params_dict.get("lang")
    try:
        is_ok, _cli_ret, _err_msg, ctrl_topography = \
            cliUtil.getControllerEngineTopography(ssh_cli, lang)
        node_id = ctrl_topography[1] if is_ok else None
        if node_id:
            return node_id
    except (Exception, JException) as query_err:
        logger.error("query local id exception.{}".format(query_err))
    raise HostInfoCollectException(
        -1,
        "Query local node id error.",
        getString(const.ERR_KEY_QUERY_LOCAL_NODE_ID_FAILED),
    )


def loop_query_host_info_collect_status(
    dev_obj, time_out_secs=70 * 60, qry_interval_secs=15, allow_failed_times=5
):
    """Poll the host-info collection status until it reports done.

    :param dev_obj: despite its name, the argument passed straight to
        query_host_info_collect_status(); callers in this module pass the
        parameter dict.
    :param time_out_secs: overall timeout in seconds.
    :param qry_interval_secs: seconds to sleep between queries.
    :param allow_failed_times: maximum number of consecutive failed queries
        tolerated; one more failure re-raises the error.
    :return: None once the status equals const.COLLECT_STATUS_DONE.
    :raises HostInfoCollectException: on timeout, or when queries fail more
        than allow_failed_times times in a row.
    """
    start_time_secs = time.time()
    except_times = 0
    while True:
        if time.time() - start_time_secs >= time_out_secs:
            # noinspection PyUnresolvedReferences
            raise HostInfoCollectException(
                -1,
                "collect host info timed out, over {} minutes".format(
                    time_out_secs / 60.0
                ),
                getString(const.ERR_KEY_COLLECT_HOST_INFO_TIMED_OUT),
            )
        try:
            _progress, status = query_host_info_collect_status(dev_obj)
        except HostInfoCollectException:
            except_times += 1
            if except_times > allow_failed_times:
                # Bare raise keeps the original traceback; on Python 2
                # `raise e` would restart it here.
                raise
        else:
            # A successful query resets the consecutive-failure counter.
            except_times = 0
            # noinspection PyUnresolvedReferences
            if status == const.COLLECT_STATUS_DONE:
                return

        time.sleep(qry_interval_secs)


def delete_pkg_and_ignore_exception(params_dict):
    """Ask the array to delete the host-info collection result, ignoring errors.

    6.1.* devices require the DELETE request method. Any Python or Java
    exception is logged and swallowed, as the function name promises. Before
    this, nothing was caught, so a cleanup failure aborted collect_core()
    after the result package had already been downloaded.

    :param params_dict: parameter dict; uses "devNode" and "logger".
    :return: None
    """
    dev_obj = params_dict.get("devNode")
    # noinspection PyBroadException
    try:
        if is_version_6_1(dev_obj):
            # 6.1.* versions use the DELETE method.
            notify_storage_device_delete_host_info_pkg(
                params_dict, request_method="delete")
        else:
            notify_storage_device_delete_host_info_pkg(params_dict)
    except (Exception, JException) as e:
        logger = params_dict.get("logger")
        logger.error("delete host info pkg error, ignored: {}\n{}".format(
            e, str(traceback.format_exc())))


def parse_collect_result(params_dict, result_pkg_path):
    """Parse the collection result: unpack the package and each host archive.

    Example return value::

        {'IPv4_or_IPv6': ['xxx\\Linux_XX_202002110159.data',
                          'xxx\\CollectLog_XX_202002110159.txt'
                          ],
         'success': ['IPv4_or_IPv6'],
         'part_success': [],
         'failed': []
         }

    statistic_results() additionally fills "should_collect_host_num",
    "host_sn_to_wwn_or_iqn_list" and "hyper_host".

    :param params_dict: parameter dict; uses "PYENGINE.PY_ZIP" (engine
        decompression helper) and "logger".
    :param result_pkg_path: local path of the downloaded result package.
        Its file name starts with "<device SN>_".
    :return: defaultdict(list) as described above.
    """
    tgz_path, tgz_name = os.path.split(result_pkg_path)
    storage_dev_sn = tgz_name.split("_")[0]
    result_dict = defaultdict(list)

    # The package is extracted into "<package dir>/<device SN>"; start clean.
    zip_dir = os.path.join(tgz_path, storage_dev_sn)
    if os.path.exists(zip_dir):
        shutil.rmtree(zip_dir, ignore_errors=True)

    tar_class = params_dict.get("PYENGINE.PY_ZIP")
    zip_type = get_pkg_zip_type(result_pkg_path)

    # Engine decompression helpers keyed by archive suffix.
    zip_method_dict = {
        ".zip": tar_class.decompressZipFileAndGetFileNames,
        ".tgz": tar_class.decompressTarGzFileAndGetFileNames,
        ".tar.gz": tar_class.decompressTarGzFileAndGetFileNames,
    }
    zip_method = zip_method_dict.get(zip_type)
    zip_method(result_pkg_path, zip_dir)

    host_zip_file_path = os.path.join(zip_dir, const.ZIP_FILE_REL_PATH)
    host_zip_files = os.listdir(host_zip_file_path)
    logger = params_dict.get("logger")
    logger.info("host_zip_files:{}".format(host_zip_files))
    for zip_file_name in host_zip_files:
        # Plain .txt / .data files are not archives; only archives are unpacked.
        if zip_file_name.endswith(("txt", "data")):
            continue

        logger.info("host_zip_file_path:{}".format(host_zip_file_path))
        zip_file_path_name = os.path.join(host_zip_file_path, zip_file_name)
        zip_file_type = get_pkg_zip_type(zip_file_name)
        logger.info("zip_file_path_name:{}".format(zip_file_path_name))
        host_pkg_zip_method = zip_method_dict.get(zip_file_type)
        host_data_files = host_pkg_zip_method(
            zip_file_path_name, host_zip_file_path
        )
        # Archives from old multipath versions may fail to decompress with
        # the engine helper; fall back to Python's zipfile.
        if not host_data_files:
            host_data_files = uncompress_file_other(
                zip_file_path_name, logger, host_zip_file_path
            )
        if not host_data_files:
            continue

        host_data_files_path_names = [
            os.path.join(host_zip_file_path, path) for path in host_data_files
        ]
        # The archive name without its suffix is the host IP. Prefer the IP
        # embedded in the .data file name when it differs and looks like an
        # IP. The suffix is removed as a suffix: str.strip() removed any of
        # its *characters* from both ends (e.g. hex digit "a" of an IPv6
        # address for ".tar.gz").
        host_ip_f = get_ip_from_file_path(host_data_files)
        host_ip = _strip_pkg_suffix(zip_file_name, zip_file_type)
        if (
            host_ip_f
            and host_ip != host_ip_f
            and ("." in host_ip_f or ":" in host_ip_f)
        ):
            host_ip = host_ip_f

        result_dict[host_ip] = host_data_files_path_names

    statistic_results(params_dict, result_dict)
    return result_dict


def _strip_pkg_suffix(file_name, suffix):
    """Return file_name without the trailing suffix (unchanged if absent)."""
    if suffix and file_name.endswith(suffix):
        return file_name[:-len(suffix)]
    return file_name


def uncompress_file_other(zip_file_path_name, logger, host_zip_file_path):
    """Extract a zip archive with Python's zipfile module (fallback path).

    The archive is now always closed; the original never closed the
    ZipFile. Any failure is logged and swallowed.

    :param zip_file_path_name: path of the zip archive.
    :param logger: logger for progress / error messages.
    :param host_zip_file_path: directory to extract into.
    :return: list of member names in the archive, or [] on failure.
    """
    try:
        logger.info("decompress except use zip:{}".format(
            zip_file_path_name))
        with zipfile.ZipFile(zip_file_path_name, "r") as zip_obj:
            host_data_files = zip_obj.namelist()
            zip_obj.extractall(host_zip_file_path)
        logger.info("decompress use py zip res:{}".format(
            str(host_data_files)))
        return host_data_files
    except BaseException as ex:  # noinspection PyBroadException
        # Deliberately best-effort: the caller skips this host on [].
        logger.error("decompress use py zip except!")
        logger.error("decompress exception detail: {}".format(ex))
        return []


def get_ip_from_file_path(host_data_files):
    """Extract the host IP from the first .data file name that carries one.

    The IP inside the archive may differ from the one in the archive path;
    in that case the .data file's IP is used. The IP is the second
    "_"-separated field of the name (e.g. "Linux_<ip>_<time>.data").

    :param host_data_files: iterable of file names from a host archive.
    :return: the IP string, or '' if no .data file has a second field.
    """
    data_files = (name for name in host_data_files if name.endswith(".data"))
    for data_file in data_files:
        fields = data_file.split("_")
        if len(fields) > 1:
            return fields[1]
    return ''


def get_pkg_zip_type(result_pkg_path):
    """Return the archive suffix of a package name.

    :param result_pkg_path: package file name or path.
    :return: ".zip", ".tgz" or ".tar.gz"; defaults to ".zip" when the name
        has none of these suffixes.
    """
    for suffix in (".zip", ".tgz", ".tar.gz"):
        if result_pkg_path.endswith(suffix):
            return suffix
    return ".zip"


def statistic_results(params_dict, result_dict):
    """Classify per-host results and add storage-side statistics.

    Adds to result_dict:
    - "failed" / "part_success" / "success": host keys by outcome;
    - "should_collect_host_num" / "host_sn_to_wwn_or_iqn_list": from
      get_hosts_can_be_collected_via_storage();
    - "hyper_host": via handle_hyper_hosts_for_inspector().

    :param params_dict: parameter dict; uses "logger".
    :param result_dict: defaultdict(list) mapping host key -> file paths.
    :return: the same result_dict, updated in place.
    """
    logger = params_dict.get("logger")
    logger.info("Unzipped result_dict:{}".format(result_dict))
    _classify_host_results(result_dict)

    host_sn_iqn_list = get_hosts_can_be_collected_via_storage(params_dict)
    should_collect_host_num = len(host_sn_iqn_list)
    result_dict["should_collect_host_num"] = should_collect_host_num
    result_dict["host_sn_to_wwn_or_iqn_list"] = host_sn_iqn_list
    handle_hyper_hosts_for_inspector(params_dict, result_dict)

    return result_dict


def _classify_host_results(result_dict):
    """Append each host key to "failed", "part_success" or "success".

    A host without a .data file failed. Otherwise its first .data file
    decides via judge_collect_success().
    """
    # Snapshot the host keys first: the appends below insert new keys into
    # the defaultdict, which must not happen while iterating a live view.
    host_keys = list(result_dict.keys())
    for hostkey in host_keys:
        host_files = result_dict.get(hostkey)
        host_data_files = [
            file_name for file_name in host_files
            if file_name.endswith(".data")
        ]
        if not host_data_files:
            result_dict["failed"].append(hostkey)
        elif judge_collect_success(host_data_files[0]):
            result_dict["success"].append(hostkey)
        else:
            result_dict["part_success"].append(hostkey)


def handle_hyper_hosts_for_inspector(params_dict, result_dict):
    """Record HyperMetro hosts so the inspection tool can skip adding them.

    Sets result_dict["hyper_host"] to the host SNs (keys of
    "all_huawei_initiators") that have at least one initiator whose host_id
    is a hyper host ID. The list is empty when there are no hyper hosts.

    :param params_dict: parameter dict; uses "logger", optionally
        "hyper_host_ids" (queried via get_hyper_hosts() when absent) and
        "all_huawei_initiators".
    :param result_dict: result dict updated in place.
    :return: None
    """
    logger = params_dict.get("logger")
    hyper_host_ids = params_dict.get("hyper_host_ids")
    if hyper_host_ids is None:
        # Not queried before: query now.
        hyper_host_ids = [
            info.get("ID") for info in get_hyper_hosts(params_dict)
        ]

    if not hyper_host_ids:
        # Not a HyperMetro device.
        result_dict["hyper_host"] = []
        logger.info("No hyper host...")
        return

    initiators_by_sn = params_dict.get(
        "all_huawei_initiators", defaultdict(list)
    )
    result_dict["hyper_host"] = [
        host_sn
        for host_sn in initiators_by_sn
        if any(
            ini_obj.host_id in hyper_host_ids
            for ini_obj in initiators_by_sn.get(host_sn)
        )
    ]


def judge_collect_success(host_data_file):
    """Tell whether a host's collection fully succeeded.

    Success means some line of the .data file contains
    const.HOST_INFO_COLLECT_SUCCESS_MARK (the self_define_cmd_checkinfo
    marker written at the end of the file).

    :param host_data_file: path of the host .data file.
    :return: True if the marker is found; False if it is not or the file
        does not exist.
    """
    if os.path.exists(host_data_file):
        with open(host_data_file) as data_fd:
            return any(
                const.HOST_INFO_COLLECT_SUCCESS_MARK in line
                for line in data_fd
            )
    return False
