#!/usr/bin/python
# -*- coding: utf-8 -*-
import json
import re
import sys
import getopt
import os
import yaml

from baseline import rpc_opcode
from baseline.return_code import INVALID_NID
from comm_check_func import check_sys_and_hotpatch_version
from config import db as cfg_db
from check_item.check_util.extend_param_mgr import ExtendParamMgr
from config import path as cfg_path
from config.rpc import RpcRet
from infra.debug.log import swm_logger as log
from infra.rpc.rpc_client import RpcClient
from infra.util import shell
from plat.fw.fw_base import fw_cfg
from plat.host.host import IpInfo
from plat.host.host_mgr import HostMgr
from plat.origin_diagnose.diagnose import SwmDiagnose, Diagnose
from service.version.ver_mgr import VerMgr


# Root directory whose numeric sub-directories are the packages inspected by
# fool_proof_mem_cut().
IMAGE_DISK = '/startup_disk/image'


class DiskUsageInfo(object):
    """
    One disk record parsed from the disk usage query output.

    :param disk_user_id: disk id as printed by the query
    :param logic_id: logical id of the disk
    :param usage: usage value; converted with int(), so a non-numeric value
                  raises ValueError
    :param disk_type: disk type string (compared against the checkers'
                      _CHECK_DISK_TYPE lists)
    :param args: any additional columns; accepted and discarded
    """
    def __init__(self, disk_user_id, logic_id, usage, disk_type, *args):
        self.disk_user_id, self.logic_id = disk_user_id, logic_id
        self.usage = int(usage)
        self.disk_type = disk_type

    def __str__(self):
        template = "disk id({}), logic id({}), usage({}), disk_type({})."
        return template.format(self.disk_user_id, self.logic_id,
                               self.usage, self.disk_type)


class CheckResult(object):
    """
    Outcome of one disk usage check.

    :param max_usage: highest usage among the checked disks
    :param max_usage_threshold: limit applied to every single disk
    :param average_usage: average usage of the checked disks
    :param average_usage_threshold: limit applied to the average usage
    :param over_threshold_list: ids of the disks above max_usage_threshold
                                (callers also pass "" when nothing is over)
    """
    def __init__(self, max_usage, max_usage_threshold, average_usage,
                 average_usage_threshold, over_threshold_list):
        # Measured peak value and its limit.
        self.max_usage, self.max_usage_threshold = max_usage, max_usage_threshold
        # Measured average value and its limit.
        self.average_usage, self.average_usage_threshold = \
            average_usage, average_usage_threshold
        self.over_threshold_list = over_threshold_list


class SASDiskUsageChecker(object):
    """
    Usage checker for mechanical disks (SAS / SATA / NL_SAS).

    The usage of all disks is read with the diagnose command
    "ld getalldiskuserate". Disks of the types in _CHECK_DISK_TYPE whose usage
    exceeds _IGNORE_CHECK_THRESHOLD are kept, optionally narrowed down to the
    disk ids listed in cfg_path.DISK_LOCATION_TMP_PATH, and then compared with
    the max (per disk) and average usage thresholds.
    """
    _DISK_AVERAGE_USAGE_THRESHOLD = 60
    _DISK_MAX_USAGE_THRESHOLD = 80
    # Disks whose usage does not exceed this value are left out of the check.
    _IGNORE_CHECK_THRESHOLD = 10
    _CHECK_DISK_TYPE = ['SAS', 'SATA', 'NL_SAS']
    _DISK_LIST_FILE_PATH = cfg_path.DISK_LOCATION_TMP_PATH
    _DISK_USAGE_CHECK_DIAGNOSE_CMD = "ld getalldiskuserate"
    # Feed the diagnose command through a here-document and redirect its output
    # into the usage tmp file, which is parsed afterwards.
    _DISK_USAGE_CHECK_CMD = "diagnose --auto << EOF > {0}\n{1}\nquit\nEOF".\
        format(cfg_path.DISK_LOCATION_USAGE_TMP_PATH, _DISK_USAGE_CHECK_DIAGNOSE_CMD)
    _CMD_EXEC_TIMEOUT = 180
    # A usage line must carry at least: disk id, logic id, usage, disk type.
    _MIN_USAGE_FIELDS = 4

    def __init__(self):
        self.all_disk_usage_list = []
        self.need_check_disk_id_list = []
        self.average_usage = 0
        self.max_usage = 0
        self.over_threshold_disk_list = []

    def get_check_result(self):
        """
        Check the usage of the disks this checker is responsible for.

        A check fails when any disk's usage exceeds _DISK_MAX_USAGE_THRESHOLD
        or the integer average usage exceeds _DISK_AVERAGE_USAGE_THRESHOLD.
        :return: tuple (passed, CheckResult). On success (or when there is no
                 disk to check) the CheckResult carries 0 usages and an empty
                 over-threshold list; on failure it carries the measured max /
                 average usage and the ids of the disks above the max threshold.
        """
        self._get_disk_for_check()

        disks = self.need_check_disk_id_list
        if not disks:
            log.info("There have no any disk, no need to check usage.")
            return True, CheckResult(0, self._DISK_MAX_USAGE_THRESHOLD,
                                     0, self._DISK_AVERAGE_USAGE_THRESHOLD, "")

        usages = [int(disk.usage) for disk in disks]
        self.max_usage = max(usages)
        self.average_usage = sum(usages) // len(usages)
        self.over_threshold_disk_list = [disk.disk_user_id for disk in disks
                                         if disk.usage > self._DISK_MAX_USAGE_THRESHOLD]

        if self.max_usage > self._DISK_MAX_USAGE_THRESHOLD \
                or self.average_usage > self._DISK_AVERAGE_USAGE_THRESHOLD:
            return False, CheckResult(self.max_usage, self._DISK_MAX_USAGE_THRESHOLD,
                                      self.average_usage, self._DISK_AVERAGE_USAGE_THRESHOLD,
                                      self.over_threshold_disk_list)
        return True, CheckResult(0, self._DISK_MAX_USAGE_THRESHOLD,
                                 0, self._DISK_AVERAGE_USAGE_THRESHOLD, "")

    @classmethod
    def _parse_usage_line(cls, line):
        """
        Convert one line of the usage command output into a DiskUsageInfo.

        :param line: raw text line read from the usage tmp file
        :return: DiskUsageInfo, or None for command echo / title / blank lines
                 and for lines with too few fields or a non-integer usage
        """
        # Filter the command echo lines, the title line and blank lines.
        if 'DiskId' in line or 'diagnose' in line or not line.strip():
            return None

        # Collapse runs of spaces, drop line endings (CRLF tolerated), then turn
        # the remaining tab/space separators into commas.
        line = re.sub(' +', ' ', line)
        line = line.replace("\r", "").replace("\t\n", "").replace("\n", "") \
            .replace("\t", ",").replace(" ", ",")
        disk_usage_info = line.split(',')
        log.info("disk usage info(%s).", disk_usage_info)

        if len(disk_usage_info) < cls._MIN_USAGE_FIELDS:
            log.warning("Skip disk usage line with too few fields(%s).", disk_usage_info)
            return None
        try:
            return DiskUsageInfo(*disk_usage_info)
        except ValueError:
            # The usage column is not an integer.
            log.warning("Skip disk usage line with invalid usage(%s).", disk_usage_info)
            return None

    def _get_all_disk_usage_info(self):
        """
        Run the usage diagnose command and collect the disks to consider.

        Disks whose type is in _CHECK_DISK_TYPE and whose usage exceeds
        _IGNORE_CHECK_THRESHOLD are appended to all_disk_usage_list. Lines that
        cannot be parsed are logged and skipped, so one bad line no longer
        drops every disk listed after it.
        :return: True on success; False if the shell command reported failure
                 or the output file could not be read
        """
        # Remove any previous output before regenerating it.
        if os.path.exists(cfg_path.DISK_LOCATION_USAGE_TMP_PATH):
            os.remove(cfg_path.DISK_LOCATION_USAGE_TMP_PATH)
        ret = shell.call_shell_cmd(self._DISK_USAGE_CHECK_CMD)
        if ret:
            return False

        try:
            with open(cfg_path.DISK_LOCATION_USAGE_TMP_PATH, 'r') as file_to_read:
                for line in file_to_read:
                    disk = self._parse_usage_line(line)
                    if disk is None:
                        continue
                    # Keep only the configured disk types; disks whose usage does
                    # not exceed the ignore threshold are skipped.
                    if disk.disk_type in self._CHECK_DISK_TYPE \
                            and disk.usage > self._IGNORE_CHECK_THRESHOLD:
                        self.all_disk_usage_list.append(disk)
        except Exception:
            log.exception("Exception")
            return False

        log.info("Get disk(%s) usage infos successfully, total disk(%s).",
                 self._CHECK_DISK_TYPE, len(self.all_disk_usage_list))
        return True

    def _get_disk_for_check(self):
        """
        Fill need_check_disk_id_list with the disks that must be checked.

        If the disk list file cfg_path.DISK_LOCATION_TMP_PATH is not present
        locally, it is requested from another node. When it is still missing,
        every collected disk is checked; otherwise only the disks whose id
        appears in the file (ids separated by ';' or ',') are checked.
        :return: None
        """
        self._get_all_disk_usage_info()
        if not os.path.exists(cfg_path.DISK_LOCATION_TMP_PATH):
            ExtendParamMgr.get_file_from_other_node(
                cfg_path.DISK_LOCATION_TMP_PATH)
            if not os.path.exists(cfg_path.DISK_LOCATION_TMP_PATH):
                log.info("There have no (%s) file, no need to filter disks.", cfg_path.DISK_LOCATION_TMP_PATH)
                self.need_check_disk_id_list = self.all_disk_usage_list
                return

        disk_id_list = []
        with open(cfg_path.DISK_LOCATION_TMP_PATH) as disk_list_file:
            for line in disk_list_file:
                # Strip surrounding whitespace and drop empty ids produced by
                # blank lines or trailing separators.
                disk_id_list.extend(disk_id.strip() for disk_id in re.split("[;,]", line)
                                    if disk_id.strip())
        log.info("Line read from disk info temp file(%s).", disk_id_list)

        self.need_check_disk_id_list = [disk for disk in self.all_disk_usage_list if
                                        disk.disk_user_id in disk_id_list]
        log.info("The disks need to check are(%s).", [disk.disk_user_id for disk in self.need_check_disk_id_list])
        return


class SSDDiskUsageChecker(SASDiskUsageChecker):
    """
    Usage checker for SSD disks.

    Reuses the complete check flow of SASDiskUsageChecker and only overrides
    the class-level configuration: the SSD disk types and an ignore threshold
    of 0, so only SSDs whose usage is 0 or lower are left out.
    """
    # Disk types handled by this checker.
    _CHECK_DISK_TYPE = ['SSD', 'SLC_SSD', 'MLC_SSD', 'NVME_SSD']
    # Usage limits (same values as the parent, declared here explicitly).
    _DISK_MAX_USAGE_THRESHOLD = 80
    _DISK_AVERAGE_USAGE_THRESHOLD = 60
    # Disks whose usage does not exceed this value are skipped.
    _IGNORE_CHECK_THRESHOLD = 0


class HotUpgChecker(object):
    """
    Decide whether the front-end interface cards can be upgraded in hot mode.
    """
    _GET_ISCSI_HOST_LINK_CMD = "hostlink show iscsi"
    # First output line of _GET_ISCSI_HOST_LINK_CMD when no iSCSI link exists.
    _NO_ISCSI_LINK_OUTPUT = 'No iscsi links'

    @staticmethod
    def _match_hot_upg_path(cfg_yml, cur_version, dst_version):
        """
        Look up the (cur_version -> dst_version) path in the parsed config.

        :param cfg_yml: object loaded from the hot upgrade config file; expected
                        to be a mapping whose "FrontHotUpgradePath" value is a
                        list of mappings with 'src_version' / 'dst_version' keys
        :param cur_version: current (source) version
        :param dst_version: target version
        :return: True if a matching entry is found before any malformed entry;
                 False for no match or malformed content
        """
        if not isinstance(cfg_yml, dict):
            return False
        paths = cfg_yml.get("FrontHotUpgradePath")
        if not paths or not isinstance(paths, list):
            return False
        for path in paths:
            if not isinstance(path, dict):
                # A malformed entry makes the config untrustworthy: not supported.
                return False
            if path.get('src_version', '') == cur_version and \
                    path.get('dst_version', '') == dst_version:
                return True
        return False

    @staticmethod
    def _is_path_support_hot_upg(cur_version, dst_version):
        """
        Check whether the hot upgrade config allows cur_version -> dst_version.

        :return: True if allowed; False if the config file is missing,
                 unreadable, malformed, or lists no matching path
        """
        cfg_file = cfg_path.HOT_UPG_PATH_CONFIG_FILE
        if not os.path.exists(cfg_file):
            log.info("HOT_UPG_CHECK: Hot upg config file(%s) not existing.", cfg_file)
            return False
        try:
            # open() is inside the try as well, so a read or parse error means
            # "not supported" instead of escaping to the caller.
            with open(cfg_file, 'r') as cfg_fd:
                cfg_yml = yaml.safe_load(cfg_fd)
        except Exception:
            log.exception("HOT_UPG_CHECK: Failed to read path from yml, exception.")
            return False

        log.info("HOT_UPG_CHECK: Yml content(%s).", cfg_yml)
        if not HotUpgChecker._match_hot_upg_path(cfg_yml, cur_version, dst_version):
            return False
        log.info("src version(%s), dst version(%s) support hot upgrade.",
                 cur_version, dst_version)
        return True

    @classmethod
    def _iscsi_links_present(cls, return_code, iscsi_info):
        """
        Interpret the result of the iSCSI host link query.

        :param return_code: return code of the diagnose command (non-zero = failed)
        :param iscsi_info: list of output lines
        :return: False only when the query succeeded and reported no link;
                 True when links are listed or the query failed (unknown state
                 must not qualify for hot upgrade)
        """
        if return_code:
            return True
        if not iscsi_info:
            return False
        return str(iscsi_info[0]).replace('\n', '') != cls._NO_ISCSI_LINK_OUTPUT

    def _have_iscsi_host_link(self):
        """
        Check whether iSCSI host links exist.

        :return: True if links exist or their state could not be queried,
                 False if the query reported no iSCSI link
        """
        return_code, iscsi_info = SwmDiagnose().exec_diagnose(self._GET_ISCSI_HOST_LINK_CMD)
        log.info("Diagnose cmd exec ret(%s), output info(%s).", return_code, iscsi_info)
        has_links = self._iscsi_links_present(return_code, iscsi_info)
        if return_code:
            log.warning("HOT_UPG_CHECK: Failed to query iscsi links, treat as having links.")
        elif not has_links:
            log.info("HOT_UPG_CHECK: No iscsi links.")
        return has_links

    def is_front_support_hot_upg(self):
        """
        Check whether the front-end interface card supports hot upgrade.

        All of the following must hold:
        1) the device is a high-end device type (fw_cfg.is_high_device());
        2) the upgrade path (current -> targetVersion extend param) is listed
           in the hot upgrade config;
        3) no iSCSI host link exists.
        :return: True if hot upgrade is supported, otherwise False
        """
        if not fw_cfg.is_high_device():
            log.info("It is not high device type, not support hot upg.")
            return False

        if not os.path.exists(cfg_path.CHECK_EXTEND_PARAM_TMP_FILE):
            ExtendParamMgr.get_file_from_other_node(
                cfg_path.CHECK_EXTEND_PARAM_TMP_FILE)
            if not os.path.exists(cfg_path.CHECK_EXTEND_PARAM_TMP_FILE):
                log.info("No extend param tmp file(%s), no need to check hot upg.",
                         cfg_path.CHECK_EXTEND_PARAM_TMP_FILE)
                return False
        res, dst_version = ExtendParamMgr.get_extend_param("targetVersion")
        if not res:
            log.info("No extend param in tmp file(%s), "
                     "no need to check hot upg.",
                     cfg_path.CHECK_EXTEND_PARAM_TMP_FILE)
            return False

        cur_version = VerMgr.get_src_version()[1]

        if not dst_version or not cur_version:
            log.error("Dst version(%s), cur version(%s) is not correct.",
                      dst_version, cur_version)
            return False

        # Supported only when the path allows hot upgrade and no iSCSI host
        # link is present.
        if self._is_path_support_hot_upg(cur_version, dst_version) and \
                not self._have_iscsi_host_link():
            return True
        log.info("Check not pass, may be the path not support hot upgrade or have iSCSI initiators.")
        return False


class ExtLunLinkCheck(object):
    """
    External LUN single-link check; only controller nodes are queried.
    """
    _name = 'A_ExternalLunLinkCheck'

    CHECK_CMD = "ld checkextlunsinglelink {nodeList:s}"
    GET_BATCH_CMD = "upgcheck get_batch"

    @staticmethod
    def _parse_batch_list(ret, batch_info):
        """
        Extract the upgrade batches from the GET_BATCH_CMD diagnose output.

        :param ret: diagnose return code (0 = success)
        :param batch_info: list of output lines
        :return: the decoded "batch_list" value of the first line mentioning it,
                 or None when the command failed or no line carries it
        :raises ValueError: if that line is not valid JSON
        """
        if ret != 0 or not batch_info:
            return None
        batch_lines = [line for line in batch_info if "batch_list" in line]
        if not batch_lines:
            return None
        return json.loads(batch_lines[0].strip())["batch_list"]

    @classmethod
    def check_one_batch(cls, nid_list):
        """
        Run the external LUN single-link diagnose command for one batch.

        :param nid_list: node ids of the batch
        :return: True when the command output is exactly "False" (presumably
                 "no single-link external LUN", mirroring this class's own output
                 convention); False for any other output or on error
        """
        exec_command = None
        val = None
        try:
            nid_str = ','.join([str(nid) for nid in nid_list])
            exec_command = cls.CHECK_CMD.format(nodeList=nid_str)
            val, _, _ = Diagnose().exec_cmd_and_get_res(exec_command)
            if val != "False":
                log.error("ExtLun: Check ext lun link failed, "
                          "cmd(%s), val(%s).", exec_command, val)
                return False
            log.info("ExtLun: Check ext lun link succeed, "
                     "cmd(%s), val(%s).", exec_command, val)
            return True
        except Exception:
            log.exception("Inner error cmd(%s), val(%s).", exec_command, val)
            return False

    @classmethod
    def check(cls):
        """
        Print whether any upgrade batch has an external LUN with a single link.

        Prints "False" when every batch passes check_one_batch (no single link)
        and "True" otherwise. Any failure (batch query, parsing, unexpected
        error) also prints "True".
        :return: 0 always; the result is conveyed through stdout
        """
        try:
            ret, batch_info = SwmDiagnose().exec_diagnose(cls.GET_BATCH_CMD)
            batch_list = cls._parse_batch_list(ret, batch_info)
            if batch_list is None:
                log.error("ExtLun: Get batch info failed, ret(%s), "
                          "batch_info(%s).", ret, batch_info)
                print("True")
                return 0

            for each_batch in batch_list:
                # Only controller nodes take part in the external LUN check.
                nid_list = [nid for nid in each_batch
                            if HostMgr.is_ctrl_node(nid)]
                if not cls.check_one_batch(nid_list):
                    log.error("ExtLun: Check ext lun failed, all_batch(%s), "
                              "batch(%s).", batch_list, nid_list)
                    print("True")
                    return 0

            log.info("ExtLun: Check ext lun succeed for all batch, "
                     "all_batch(%s).", batch_list)
            print("False")
            return 0
        except Exception:
            # Keep the traceback: the old message hid what actually failed.
            log.exception("ExtLun: Unknown error.")
            print("True")
            return 0


def format_print(usage_print_info, threshold_print_info, check_result):
    """
    Append one check result to the accumulated output fragments.

    :param usage_print_info: accumulated usage text; ";<max>%;<max threshold>%;
                             <average>%;<average threshold>%" is appended
    :param threshold_print_info: accumulated over-threshold text; "," followed
                                 by the comma-joined over-threshold ids is appended
    :param check_result: object exposing the CheckResult attributes
    :return: tuple (new usage text, new over-threshold text)
    """
    usage_values = (check_result.max_usage,
                    check_result.max_usage_threshold,
                    check_result.average_usage,
                    check_result.average_usage_threshold)
    usage_suffix = ''.join(';{0}%'.format(value) for value in usage_values)
    new_usage_info = "{0}{1}".format(usage_print_info, usage_suffix)

    over_threshold_ids = ','.join(check_result.over_threshold_list)
    new_threshold_info = "{0},{1}".format(threshold_print_info, over_threshold_ids)
    return new_usage_info, new_threshold_info


def check_disk_usage():
    """
    Check SSD and mechanical disk usage and print the result.

    Output lines:
      - both checks pass: "True", "", ""
      - otherwise: "False", "10004", then
        "<SAS max>%;<SAS max threshold>%;<SAS avg>%;<SAS avg threshold>%;
         <same four fields for SSD>;<over-threshold disk ids, or NA>"
    The disk list tmp file is removed once both checks have run, including
    when a check raises.
    :return: 0
    """
    try:
        ssd_result, ssd_check_result = SSDDiskUsageChecker().get_check_result()
        sas_result, sas_check_result = SASDiskUsageChecker().get_check_result()
    finally:
        # Remove the tmp file after using, also on failure, so a later run
        # does not reuse it.
        if os.path.exists(cfg_path.DISK_LOCATION_TMP_PATH):
            os.remove(cfg_path.DISK_LOCATION_TMP_PATH)

    if ssd_result and sas_result:
        print('True')
        print('')
        print('')
        return 0

    usage_print_info, threshold_print_info = \
        format_print("", "", sas_check_result)
    usage_print_info, threshold_print_info = \
        format_print(usage_print_info, threshold_print_info, ssd_check_result)

    usage_text = usage_print_info.strip(';')
    over_threshold_text = threshold_print_info.strip(',')
    print('False')
    print('10004')
    if over_threshold_text.strip() == '':
        print("{0};{1}".format(usage_text, 'NA'))
    else:
        print("{0};{1}".format(usage_text, over_threshold_text))
    return 0


def check_front_hot_upg():
    """
    Print whether the front-end interface cards support hot upgrade.

    First printed line: 'True' or 'False'. Second printed line, in both cases:
    {'hostCompaEvaluCheck': 'True'}, which marks this check item as one used
    for host compatibility evaluation; the hostCompaEvaluCheck field is
    mandatory for the consumer (the original doc places this line as the
    fourth line of the final output and requires JSON).
    NOTE(review): the printed text uses single quotes, which is not strict
    JSON; confirm how the consumer parses it before changing it.
    :return: 0
    """
    supported = HotUpgChecker().is_front_support_hot_upg()
    print('True' if supported else 'False')
    # Printed regardless of the result.
    print("{'hostCompaEvaluCheck': 'True'}")
    return 0


def parse_master_id(output):
    """
    Extract the master node id from "cls show" diagnose output.

    :param output: list of output lines (str or bytes); the id is read from
                   the second line, formatted like "master_id : <n>"
    :return: the master id as int, or None if the line is missing or does not
             contain a master id
    """
    if not output or len(output) < 2:
        return None
    line = output[1]
    # Accept both bytes and text so the same pattern works on either.
    if isinstance(line, bytes):
        line = line.decode('utf-8', 'ignore')
    match = re.search(r"master_id\s*:\s*(\d+)", line)
    if match is None:
        return None
    return int(match.group(1))


def get_master_id():
    """
    Query the cluster master node id via the "cls show" diagnose command.

    :return: the master node id, or INVALID_NID if the command failed or its
             output did not contain a master id
    """
    ret, output = SwmDiagnose().exec_diagnose('cls show')
    if ret != 0:
        return INVALID_NID
    master_id = parse_master_id(output)
    return INVALID_NID if master_id is None else master_id


def check_cluster_sn():
    """
    Check whether the master controller holds the cluster's largest sn.

    Runs "select max(sn)" on cfg_db.UPGRADE_RECORD_TBL on every controller via
    RPC and prints "True" when the master's sn is not smaller than the largest
    sn found, otherwise "False". The check is skipped (prints "True") when
    check_sys_and_hotpatch_version() matches resolve_patch, presumably
    because that version/patch already resolves the issue (TODO confirm).
    :return: 0 when the check ran to completion (result printed);
             1 when the master id or a controller's sn could not be obtained
             (prints "False")
    """
    resolve_patch = {
        '6200803153': 'SPH311',
    }
    if check_sys_and_hotpatch_version(resolve_patch):
        print("True")
        return 0
    master_id = get_master_id()
    if master_id == INVALID_NID:
        log.error("Failed to get master id")
        print("False")
        return 1
    log.info("Master id(%s)", master_id)

    # -1 stays when the master is not among the controllers, failing the check.
    master_sn = -1
    max_sn = 0
    max_sn_nid = None
    sql = "select max(sn) from %s" % cfg_db.UPGRADE_RECORD_TBL
    for host in HostMgr.get_ctl_host_list():
        host_ips = [ip.addr for ip in host.host_ip if ip.state == IpInfo.STATE_VALID]
        rpc_result, user_result = RpcClient.start_rpc_call(
            host_ips, rpc_opcode.RPC_OPCODE_EXEC_SQL, sql, timeout=40)
        if rpc_result != RpcRet.OK or not user_result:
            log.warning("Failed to get host(%s) sn, (rpc_result: %s, user_result: %s)",
                        host, rpc_result, user_result)
            print("False")
            return 1
        # An empty/falsy max(sn) (presumably an empty table) counts as 0.
        raw_sn = user_result[0][0]
        host_sn = int(raw_sn) if raw_sn else 0
        if host_sn > max_sn:
            max_sn = host_sn
            max_sn_nid = host.host_id
        if host.host_id == master_id:
            master_sn = host_sn
            if raw_sn:
                log.info("Master sn is %s", master_sn)
            else:
                log.info("Master sn is default 0.")

    if master_sn >= max_sn:
        print("True")
        return 0
    log.error("Cluster max sn(%s) belongs to nid(%s), but master nid(%s) sn is(%s).",
              max_sn, max_sn_nid, master_id, master_sn)
    print("False")
    return 0


def check_enclosure_num():
    """
    Print the number of IP enclosures.

    The printed value is the number of data nodes divided by two, rounded
    down (presumably two data nodes per enclosure; confirm). Reporting an
    error when the count exceeds 25 is decided in check_xml, not here.
    :return: 0
    """
    data_hosts = HostMgr.get_data_host_list()
    if data_hosts:
        print(len(data_hosts) // 2)
        return 0
    log.info("not exist data node.")
    print(0)
    return 0


def fool_proof_mem_cut():
    """
    Foolproof check that the packages under IMAGE_DISK are consistent.

    Every memory-cut package contains a 'system/mem_cut_flag' marker and no
    other package does. Prints 'True' when all numeric package directories
    agree (all have the marker, or none has it), otherwise 'False'. This
    includes the case where no numeric package directory exists.
    :return: 0
    """
    flag_states = set()
    for entry in os.listdir(IMAGE_DISK):
        if not entry.isdigit():
            continue
        marker = os.path.join(IMAGE_DISK, entry, 'system', 'mem_cut_flag')
        flag_states.add(os.path.exists(marker))
    print('True' if len(flag_states) == 1 else 'False')
    return 0


def show_help():
    """
    Print the command line usage of this upgrade pre-check script.
    """
    help_lines = [
        "This is used for upgrade pre-check.",
        "  -h, --help             show the cmd help info.",
        "  -d, --diskusage        check disk usage.",
        "  -f, --fronthotupg      check whether the front-end interface card supports hot upgrade.",
        "  -n, --check_sn         check whether the master sn is biggest.",
        "  -e, --external_lun     check external_lun single link.",
        "  -a, --check_enclosure_num         check enclosure num.",
        "  -c, --check_package_compatible    check package compatible.",
    ]
    # Every line, the last one included, ends with a newline; print adds one more.
    print("".join(line + "\n" for line in help_lines))


def main(argv=None):
    """
    Command line entry: run the check selected by the first recognized option.

    :param argv: argument vector including the program name; sys.argv if None
    :return: the selected check's return value; 0 for -h/--help or when no
             option is given; 1 when option parsing or the check raised
    """
    argv = sys.argv if argv is None else argv
    long_options = ["help", "diskusage", "fronthotupg", "external_lun", "check_sn",
                    "check_enclosure_num", "check_package_compatible"]
    try:
        opts, _ = getopt.getopt(argv[1:], "dhfenac", long_options)
        # Option spelling -> check to run.
        handlers = {
            '-d': check_disk_usage, '--diskusage': check_disk_usage,
            '-f': check_front_hot_upg, '--fronthotupg': check_front_hot_upg,
            '-e': ExtLunLinkCheck.check, '--external_lun': ExtLunLinkCheck.check,
            '-n': check_cluster_sn, '--check_sn': check_cluster_sn,
            '-a': check_enclosure_num, '--check_enclosure_num': check_enclosure_num,
            '-c': fool_proof_mem_cut, '--check_package_compatible': fool_proof_mem_cut,
        }
        for opt_name, _value in opts:
            if opt_name in ('-h', '--help'):
                show_help()
                return 0
            handler = handlers.get(opt_name)
            if handler is not None:
                return handler()
    except Exception as e:
        print("exception")
        print("exception")
        print(e)
        log.exception(e)
        return 1
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    sys.exit(main())
