#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2018-2022. All rights reserved.
import argparse
import contextlib
import json
import os
import re
import shutil
import sys
import time

import yaml

from check_item.check_util import util
from check_item.comm_check_func import IMAGE_DISK, get_conf_product
from config import env, path, pkg as pkg_cfg, db as db_cfg
from config.upd import UpgStatus
from infra.db.swm_db import DbService as DbSrv
from infra.debug.log import swm_logger as log
from infra.util import pkg
from infra.util.shell import remount_image
from plat.fw.adapter.dmi_adapter import DmiAdapter
from plat.fw.adapter.dmi_base import DmiFwVerCheckResult
from plat.fw.fw_mgmt_facade import FwObjMgrFacade, FwObjName

# upg_type values in the upgrade-record table that load_record_info_from_db()
# excludes when looking up the latest successful upgrade record.
# NOTE(review): the names suggest rolling / offline apollo-only upgrades — confirm.
ROLL_APOLLO_UPG = 8
OFFLINE_APOLLO_UPG = 9


class UpgradeModeType(object):
    """Firmware upgrade modes (passed to FwObjMgrFacade as ``upgrade_mode``)."""

    # Invalid / unset upgrade mode.
    UPGRADE_MODE_INVALID = 0
    # Hot upgrade: takes effect without resetting the device.
    UPGRADE_HOT_MODE = 1
    # Cold upgrade: the device has to be reset for the upgrade to take effect.
    UPGRADE_COLD_MODE = 2


def init_thrift_interface():
    """
    Create a DmiAdapter and start it.

    ``open_relate_so()`` is called first when the adapter provides that method.

    :return: True if ``start()`` returned a truthy value, otherwise False
    """
    adapter = DmiAdapter()
    if hasattr(adapter, 'open_relate_so'):
        adapter.open_relate_so()
    if adapter.start():
        return True
    log.error("FW_CHECK: interface init failed.")
    return False


def apollo_copy_pkg(src_plat, dst_plat, dst_default_plat):
    """
    Copy into ``dst_plat`` every component of ``src_plat`` that ``dst_plat`` lacks.

    A component's name is the part of the entry name before the first '-'
    (e.g. "kernel" for "kernel-xxx"). Components whose name starts with
    "his_kernel" are never copied, because a copied one would be left behind
    after a downgrade. The "firmware" component is copied from
    ``dst_default_plat`` when ``dst_plat`` no longer resolves to it; every other
    component comes from ``src_plat``.

    :param src_plat: source apollo directory
    :param dst_plat: destination apollo directory
    :param dst_default_plat: real path of the destination apollo link before it
                             was re-pointed (as computed by complete_weak_match_plat)
    :return: True if all missing components were copied, False on the first
             failed copy
    """
    dst_cmpt_names = set(entry.split('-')[0] for entry in os.listdir(dst_plat))
    # Copying into dst_plat does not change where the link resolves, so compute it once.
    dst_real_plat = os.path.realpath(dst_plat)
    for cmpt_pkg in os.listdir(src_plat):
        cmpt_name = cmpt_pkg.split('-')[0]
        if cmpt_name in dst_cmpt_names:
            continue
        if cmpt_name.startswith("his_kernel"):  # never copy his_kernel: it would remain after a downgrade
            continue
        log.info("WEAK_MATCH: The dst plat has no cmpt(%s), copy to dst.", cmpt_name)
        if cmpt_name == 'firmware' and dst_default_plat != dst_real_plat:
            copy_from = dst_default_plat
        else:
            copy_from = src_plat
        # Build the argument list directly instead of splitting a formatted string,
        # so a path containing whitespace stays a single argument.
        cmd = ['cp', '-arf', os.path.join(copy_from, cmpt_pkg), dst_plat + '/']
        log.info("cmd:%s", ' '.join(cmd))
        ret, _ = util.call_system_cmd(cmd, timeout=5)
        if ret:
            log.error("WEAK_MATCH: Copy weak match cmpt(%s) to dst failed.", cmpt_name)
            return False
        log.info("WEAK_MATCH: Copy weak match cmpt(%s) to dst successfully.", cmpt_name)
    return True


def copy_his_kernel(src_plat, dst_plat):
    """
    Copy the source package's kernel component into the destination as his_kernel.

    If ``src_plat`` already holds a ``his_kernel-*`` entry, that file is copied to
    ``dst_plat`` under the same name. Otherwise the ``kernel-*`` entry of ``src_plat``
    is copied to ``dst_plat`` as ``his_<original name>``. If neither exists, an
    error is logged and nothing is copied.
    NOTE(review): the original note says his_kernel is only supported from a certain
    version on and unsupported versions are not copied; this function itself performs
    no version check.

    :param src_plat: source apollo directory
    :param dst_plat: destination apollo directory
    """
    existing_his = get_cmpt_path(src_plat, "his_kernel")
    log.info("CO_KERNEL: src_his_kern_path %s.", existing_his)
    if existing_his:
        # The source already carries a his_kernel: copy it across unchanged.
        source_file = existing_his
        target_name = os.path.basename(existing_his)
    else:
        # Otherwise the source's current kernel becomes the destination's his_kernel.
        kernel_file = get_cmpt_path(src_plat, "kernel")
        if not kernel_file:
            log.error("CO_KERNEL: Src kernel path is not exist.")
            return
        source_file = kernel_file
        target_name = "his_" + os.path.basename(kernel_file)
    shutil.copyfile(source_file, os.path.join(dst_plat, target_name))


def get_apollo_version(apollo_link):
    """
    Return the version part of the directory an apollo symlink resolves to.

    The version is the second '-'-separated field of the resolved basename,
    e.g. a link resolving to ``.../apollo-1.1.5.0`` yields ``"1.1.5.0"``.

    :param apollo_link: path of the apollo symlink
    :return: the version string, or "" when ``apollo_link`` is not a symlink or
             the resolved name has no '-'-separated version field
    """
    if not os.path.islink(apollo_link):
        return ""
    fields = os.path.basename(os.path.realpath(apollo_link)).split('-')
    # An unexpected target name yields "" instead of raising IndexError.
    return fields[1] if len(fields) > 1 else ""


def get_cmpt_path(pkg_dir, cmpt_name):
    """
    Locate a component entry in ``pkg_dir`` whose name begins with ``cmpt_name + "-"``.

    :param pkg_dir: directory holding the component packages
    :param cmpt_name: component name, e.g. "kernel"
    :return: full path of the first match in ``os.listdir`` order, or "" when
             ``pkg_dir`` does not exist or nothing matches
    """
    if os.path.exists(pkg_dir):
        wanted_prefix = cmpt_name + "-"
        candidates = (os.path.join(pkg_dir, entry)
                      for entry in os.listdir(pkg_dir)
                      if entry.startswith(wanted_prefix))
        return next(candidates, "")
    return ""


def complete_weak_match_plat(src_plat, dst_plat):
    """
    Align the destination apollo link with the source one and fill in missing components.

    Steps:
      1. If either path is not a symlink, log a warning and report success.
      2. Re-point ``dst_plat`` to the same link target as ``src_plat`` when the
         targets differ.
      3. If the apollo version of the source differs from the one ``dst_plat``
         originally resolved to, copy the source kernel in as his_kernel.
      4. Copy components present in the source but missing in the destination.

    :param src_plat: apollo symlink in the current package
    :param dst_plat: apollo symlink in the target package
    :return: True on success (or when nothing applies), False if copying failed
    """
    if not os.path.islink(src_plat):
        log.warning("WEAK_MATCH: The src plat(%s) is not link.", src_plat)
        return True
    if not os.path.islink(dst_plat):
        log.warning("WEAK_MATCH: The dst plat(%s) is not link.", dst_plat)
        return True
    # Captured before the link may be re-pointed below.
    dst_default_plat = os.path.realpath(dst_plat)
    src_apollo_version = get_apollo_version(src_plat)
    dst_default_apollo_version = get_apollo_version(dst_plat)
    src_link_target = os.readlink(src_plat)
    # src_plat and dst_plat normally live under different package roots, so their
    # real paths practically always differ. Compare the link targets instead, which
    # is what is actually copied. Re-linking to an identical target would be a no-op.
    if src_link_target != os.readlink(dst_plat):
        log.info(
            "WEAK_MATCH: The src plat(%s) diff with dst plat(%s), "
            "change dst plat link.", src_plat, dst_plat)
        os.unlink(dst_plat)
        os.symlink(src_link_target, dst_plat)
    if src_apollo_version != dst_default_apollo_version:  # default apollo differs: copy his_kernel
        log.info("WEAK_MATCH: The src_apollo_version(%s) diff with dst_default_apollo_version(%s), copy his kernel.",
                 src_apollo_version, dst_default_apollo_version)
        copy_his_kernel(src_plat, dst_plat)
    return apollo_copy_pkg(src_plat, dst_plat, dst_default_plat)


def complete_weak_pkg_with_retry(dst_pkg_root, retry_time=3, interval=2):
    """
    Call complete_weak_pkg up to ``retry_time`` times.

    Between attempts (but not after the last one) sleep ``interval`` seconds and
    log a warning.

    :param dst_pkg_root: target package system directory
    :param retry_time: maximum number of attempts
    :param interval: seconds to wait between attempts
    :return: True as soon as an attempt succeeds, False if every attempt failed
             (or ``retry_time`` is not positive)
    """
    for attempt in range(1, retry_time + 1):
        if complete_weak_pkg(dst_pkg_root):
            return True
        if attempt < retry_time:
            time.sleep(interval)
            log.warning("WEAK_MATCH: Complete weak pkg failed retry.")
    return False


@remount_image
def complete_weak_pkg(dst_pkg_root):
    """
    Complete the weak-match platforms of a target package.

    Each platform in pkg_cfg.WEAK_MATCH_PLAT_LIST under ``dst_pkg_root`` is aligned
    with the same platform in the current package (real path of path.PKG_CUR_DIR)
    by complete_weak_match_plat. Wrapped by the ``remount_image`` decorator.

    :param dst_pkg_root: system directory of the target package
    :return: True if every platform was completed, False on the first failure
    """
    cur_pkg_root = os.path.realpath(path.PKG_CUR_DIR)
    for plat_name in pkg_cfg.WEAK_MATCH_PLAT_LIST:
        plat_ok = complete_weak_match_plat(os.path.join(cur_pkg_root, plat_name),
                                           os.path.join(dst_pkg_root, plat_name))
        if not plat_ok:
            log.error("WEAK_MATCH: Complete weak match plat(%s) failed.", plat_name)
            return False
        log.info("WEAK_MATCH: Complete weak match plat(%s) successfully.", plat_name)
    log.info("WEAK_MATCH: Complete weak match pkg successfully.")
    return True


def check_fast_func():
    """
    Check whether the current system can be fast-upgraded to the uploaded package.

    Steps, stopping at the first failure:
      1. start the DMI interface (init_thrift_interface);
      2. read the target package version from path.TARGET_CONFIG_PATH;
      3. fix initrd read permission if needed (change_initrd_permit_back);
      4. require that the apollo directory the current package links to also exists
         in the target package, then complete the target's weak-match platforms;
      5. check firmware hot-upgrade support and the source product / bay type.

    :return: True if fast upgrade is supported, False otherwise
    """
    if not init_thrift_interface():
        return False

    if not os.path.exists(path.TARGET_CONFIG_PATH):
        log.info("FW_CHECK: Can not get the new version,pkg %s does not exist.", path.UPLOAD_SOFTWARE_PATH)
        return False
    pkg_ver = pkg.get_pkg_version(path.TARGET_CONFIG_PATH)[0]
    if not pkg_ver:
        return False
    if not change_initrd_permit_back():
        return False
    sys_path = os.path.join(path.IMAGE_DISK, pkg_ver, 'system')
    try:
        src_apollo = os.readlink(os.path.join(path.PKG_CUR_DIR, 'apollo'))
        target_entries = os.listdir(sys_path)
    except OSError:
        # A missing link or target directory is a failed check, not a crash.
        log.exception("CHECK_ITEM: Failed to read src apollo link or target package(%s).", sys_path)
        return False
    if src_apollo not in target_entries:
        log.error("CHECK_ITEM: There is no src_apollo in target package.")
        return False
    if not complete_weak_pkg_with_retry(sys_path):
        return False

    fw_dir = os.path.join(sys_path, 'firmware', env.BOARD_TYPE)
    if not check_fw_support_fast_upg(fw_dir):
        return False
    if not check_src_product_support_fast_upg():
        return False
    return True


def get_bay_type(info_file='/OSM/script/proc_osp_bsp.info'):
    """
    Read the product model ("bay type") from the BSP info file.

    Searches for the first ``Model of products is:<word>`` occurrence and returns
    ``<word>``.

    :param info_file: BSP info file to parse (defaults to the system file)
    :return: the model string (e.g. "2200_16G"), or "" when no model is found
    """
    with os.fdopen(os.open(info_file, os.O_RDONLY, 0o644), 'rb') as file_handler:
        bsp_info = file_handler.read()
    # The file is read in binary mode, so match with a raw bytes pattern. ``br''`` is
    # valid on both Python 2 and 3, while a non-raw b'\w' is an invalid escape
    # sequence on Python 3.
    bay_type = re.search(br'Model of products is:(\w+)', bsp_info)
    if not bay_type:
        return ""
    return str(bay_type.group(1).decode(errors='ignore'))


def check_src_product_support_fast_upg():
    """
    Check whether the currently installed product/version may be fast-upgraded.

    A (product, source version) pair can black-list bay types. A device whose bay
    type is black-listed for its product and version cannot be fast-upgraded.
    Missing version or product information is treated as "allowed".

    :return: True if fast upgrade is allowed, False otherwise
    """
    # product -> {source version -> black-listed bay types}
    black_list = {
        "Lite_Converged": {  # 616 TR5 does not support fast upgrade to later versions
            "7600512178": ("2200_16G", "2220_16G", "2620_32G", "2600_32G", "5120_16G", "5120_32G")
        }
    }
    version = get_pkg_version_from_dir(path.PKG_CUR_DIR)[0]
    product = get_conf_product(conf_file=os.path.join(path.PKG_CUR_DIR, "manifest.yml"), conf_type='sys')
    if not (version and product):
        return True
    blocked_bay_types = black_list.get(product, {}).get(version)
    if not blocked_bay_types:
        return True
    bay_type = get_bay_type()
    if bay_type in blocked_bay_types:
        log.error("FAST_UPG_CHECK: product(name: %s, src version:%s, bay type: %s) not support fast upgrade.",
                  product, version, bay_type)
        return False
    log.info("FAST_UPG_CHECK: bay type (%s) not in black list.", bay_type)
    return True


def check_fw_support_fast_upg(fw_dir):
    """
    Check whether the firmware in ``fw_dir`` allows a fast (hot) upgrade.

    BIOS and enclosure must both report hot-upgrade support. Then each hi1822 card
    firmware object is version-checked; any object whose check result is neither
    DMI_FW_VER_CORRECT nor DMI_FW_VER_CO_COMPATIBLE would need a cold upgrade and
    blocks the fast upgrade.

    :param fw_dir: firmware directory of the target package for this board
    :return: True if fast upgrade is supported, False otherwise
    """
    bios_facade = FwObjMgrFacade(obj_type=FwObjName.BIOS, fw_path=fw_dir,
                                 upgrade_mode=UpgradeModeType.UPGRADE_HOT_MODE)
    if not bios_facade.is_support_hot_upgrade():
        log.info("FW_CHECK: bios not support fast upgrade.")
        return False

    enclosure_facade = FwObjMgrFacade(obj_type=FwObjName.ENCLOSURE, fw_path=fw_dir)
    if not enclosure_facade.is_support_hot_upgrade():
        log.info("FW_CHECK: enclosure not support fast upgrade.")
        return False

    # In 615 TR5, mid/low-end models with a 1823 card cannot be fast-upgraded to TR6,
    # so the upgrade check must block it. DmiFwVerCheckResult value 2 used to pass and
    # is now blocked, to avoid mis-operation when testing later versions.
    # Firmware names (rather than FwObjName constants) are listed here for upgrade-check
    # compatibility.
    check_fw_list = ['hi1822_fe_card', 'hi1822_fe_eth_card', 'hi1822_fe_vxlan_card']
    cold_upg_list = []
    for fw_name in check_fw_list:
        try:
            card_facade = FwObjMgrFacade(obj_type=fw_name, fw_path=fw_dir)
        except Exception as exception:
            log.exception("Can't find obj for %s, exception %s", fw_name, exception)
            continue
        check_result = card_facade.check()
        if not check_result.result:
            # The DMI interface query failed.
            log.error("FW_CHECK: Some %s check failed.", fw_name)
            return False
        cold_upg_list.extend(
            op_res.fw_obj for op_res in check_result.suc_obj_operate_res_list
            if op_res.operate_value not in (DmiFwVerCheckResult.DMI_FW_VER_CORRECT,
                                            DmiFwVerCheckResult.DMI_FW_VER_CO_COMPATIBLE))
    if not cold_upg_list:
        return True
    log.error("FW_CHECK: Some objs not support hot upgrade, cold_upg_list(%s) .",
              [fw.name for fw in cold_upg_list])
    return False


def change_initrd_permit_back():
    """
    Restore the other-read permission on initrd when the current apollo predates 615.

    The current apollo version is the second '-'-separated field of the directory
    ``<path.SYS_CUR_PATH>/apollo`` resolves to (e.g. ``apollo-1.1.4.0``). Only
    versions strictly lower than 1.1.5.0 need the fix.

    :return: True if nothing needed doing or the permission was fixed, False if the
             current apollo version cannot be determined or parsed
    """
    cur_apollo_path = os.path.realpath(os.path.join(path.SYS_CUR_PATH, 'apollo'))
    name_fields = os.path.basename(cur_apollo_path).split('-')
    apollo_version = name_fields[1] if len(name_fields) > 1 else ''
    if not apollo_version:
        log.error('FAST_UPG_CHECK: Get pkg current apollo version or number version failed.')
        return False
    apollo_version_615 = '1.1.5.0'
    try:
        cur_fields = [int(x) for x in apollo_version.split('.')]
        base_fields = [int(x) for x in apollo_version_615.split('.')]
    except ValueError:
        log.error('FAST_UPG_CHECK: Get pkg current apollo version or number version failed.')
        return False
    # Compare as ordered versions (zero-padded to equal length). Testing whether ANY
    # field is smaller is not an ordering: it would treat 1.2.0.0 as older than 1.1.5.0.
    width = max(len(cur_fields), len(base_fields))
    cur_fields += [0] * (width - len(cur_fields))
    base_fields += [0] * (width - len(base_fields))
    if cur_fields >= base_fields:
        log.info('FAST_UPG_CHECK: No need to change initrd permit back.')
        return True
    log.info('FAST_UPG_CHECK: Change initrd permit back.')
    return _change_initrd_osp_permit_back()


def _change_initrd_osp_permit_back():
    """
    Ensure the initrd image is readable by "other".

    Covers a node that was fast-upgraded from a pre-615 version to 615, then reset,
    and is now being fast-upgraded to 616. If the file that
    /startup_disk/image/boot/initrd resolves to lacks the other-read bit (0o004),
    the bit is added inside remount_rw_ro().

    :return: True
    """
    initrd_real_path = os.path.realpath("/startup_disk/image/boot/initrd")
    cur_mode = os.stat(initrd_real_path).st_mode
    if cur_mode & 0o004:
        log.info("FAST_UPG_CHECK: No need to give other group read permit for initrd.")
        return True
    log.info("FAST_UPG_CHECK: Give other group read permit for initrd.")
    with remount_rw_ro():
        os.chmod(initrd_real_path, cur_mode | 0o004)
    log.info("FAST_UPG_CHECK: Give other group read permit for initrd success.")
    return True


@contextlib.contextmanager
def remount_rw_ro(disk=IMAGE_DISK):
    """
    Context manager: remount ``disk`` read-write for the body, then read-only again.

    The read-only remount runs in ``finally``, so it also happens when the body
    raises. A non-zero return code from either mount command is logged; the body
    still runs, so a write inside it surfaces its own error.

    :param disk: mount point to remount (defaults to IMAGE_DISK)
    """
    try:
        ret, _ = util.call_system_cmd(['mount', '-o', 'remount,rw,async', disk])
        if ret:
            log.error("FAST_UPG_CHECK: Remount %s rw failed, ret(%s).", disk, ret)
        yield
    finally:
        ret, _ = util.call_system_cmd(['mount', '-o', 'remount,ro,async', disk])
        if ret:
            log.error("FAST_UPG_CHECK: Remount %s ro failed, ret(%s).", disk, ret)


def check_is_support_apollo_switch():
    """
    Decide whether the apollo link of the current package (path.PKG_CUR_DIR) can be switched.

    Return codes:
      0 - can switch
      1 - no need to switch, already the latest
      2 - internal error (no ``apollo-*`` entry in the current package)

    With several apollo packages:
      - the link does not point at the highest apollo version        -> 0
      - it does, but a BIOS / enclosure object is not latest version -> 0
      - it does, and all such objects are the latest version         -> 1
    With a single apollo package:
      - all BIOS / enclosure objects are the latest version          -> 1
      - otherwise                                                    -> 0
    """
    cur_pkg_root = os.path.realpath(path.PKG_CUR_DIR)
    apollo_dirs = [entry for entry in os.listdir(cur_pkg_root) if entry.startswith("apollo-")]
    if not apollo_dirs:
        log.error("APOLLO_PRE_CHECK: system dir has error.")
        return 2
    cur_apollo_dir = os.path.basename(os.path.realpath(os.path.join(cur_pkg_root, 'apollo')))
    max_apollo_dir = pkg.get_max_apollo_ver_dir(apollo_dirs)
    if max_apollo_dir not in ("", cur_apollo_dir):
        log.info("APOLLO_PRE_CHECK: cur not the max apollo version.")
        return 0

    # Firmware check: any BIOS or enclosure object that is not at the latest version
    # means the apollo link may be switched.
    fw_path = os.path.join(path.FW_DIR, env.BOARD_TYPE)
    fw_checks = (
        (FwObjName.BIOS, "APOLLO_PRE_CHECK: bios ver not same, can switch apollo."),
        (FwObjName.ENCLOSURE, "APOLLO_PRE_CHECK: ctrl ver not same, can switch apollo."),
    )
    for obj_type, outdated_msg in fw_checks:
        facade = FwObjMgrFacade(obj_type=obj_type, fw_path=fw_path)
        outdated = [fw_obj for fw_obj in facade.get_obj_list() if not fw_obj.is_ver_latest]
        if outdated:
            log.info(outdated_msg)
            return 0

    log.info("APOLLO_PRE_CHECK: latest, no need switch apollo link.")
    return 1


def compare_version(ver_1, ver_2):
    """
    Compare two dotted numeric version strings, e.g. "1.0.0.9" and "1.0.0.10".

    Missing trailing fields count as 0, so "1.1" == "1.1.0" and "1.1" < "1.1.1".

    :param ver_1: first version string
    :param ver_2: second version string
    :return: -1 if ver_1 < ver_2, 0 if equal, 1 if ver_1 > ver_2; None (after
             logging) when either version cannot be parsed
    """
    try:
        fields_1 = [int(i) for i in ver_1.split('.')]
        fields_2 = [int(i) for i in ver_2.split('.')]
    except (AttributeError, TypeError, ValueError):
        log.exception("Cmpt compare failed.")
        return None
    # Pad to the same length; indexing past the shorter list used to raise (-> None)
    # or silently report equality.
    width = max(len(fields_1), len(fields_2))
    fields_1 += [0] * (width - len(fields_1))
    fields_2 += [0] * (width - len(fields_2))
    if fields_1 < fields_2:
        return -1
    if fields_1 > fields_2:
        return 1
    return 0


def get_pkg_version_from_dir(pkd_dir):
    """
    Read the version information of the package in ``pkd_dir``.

    :param pkd_dir: package system directory containing ``manifest.yml`` and an
                    ``apollo`` symlink
    :return: tuple (Version, SpcVersion, apollo link target), where the first two
             are SYS.Version / SYS.SpcVersion from manifest.yml as strings and the
             last is ``os.readlink`` of ``<pkd_dir>/apollo``; ("", "", "") when the
             manifest is missing or anything fails
    """
    manifest_file = os.path.join(pkd_dir, "manifest.yml")
    if not os.path.exists(manifest_file):
        log.warning("CHECK_ITEM: Can not get the new version, config %s not "
                    "existing.", manifest_file)
        return "", "", ""

    try:
        # Binary mode lets PyYAML detect the encoding (UTF-8/UTF-16) itself rather
        # than decoding with the locale's default codec.
        with open(manifest_file, 'rb') as fd:
            cfg_yml = yaml.safe_load(fd)
        sys_cfg = cfg_yml.get("SYS")
        target_version = str(sys_cfg["Version"])
        target_spc_version = str(sys_cfg["SpcVersion"])
        target_apollo_version = os.readlink(os.path.join(pkd_dir, 'apollo'))
    except Exception:
        # Covers read/parse errors, an empty manifest, missing keys and a missing link.
        log.exception("CHECK_ITEM: Failed to get version.")
        return "", "", ""
    log.info('CHECK_ITEM: Get Version(%s) SPCVersion(%s).', target_version, target_spc_version)
    return target_version, target_spc_version, target_apollo_version


def get_max_target_apollo(path_dir):
    """
    Return the ``apollo-<version>`` entry of ``path_dir`` with the highest version.

    The version is the second '-'-separated field of the entry name, compared
    numerically field by field (so 1.1.10.0 > 1.1.9.0).

    :param path_dir: package system directory
    :return: the winning entry name exactly as it appears in ``path_dir``
    :raises ValueError: if there is no ``apollo-*`` entry or a version field is not
                        an integer
    """
    def _ver_key(entry):
        return tuple(int(field) for field in entry.split("-")[1].strip().split("."))

    apollo_entries = [x for x in os.listdir(path_dir) if x.startswith("apollo-")]
    # Return the real entry rather than a name rebuilt from the parsed integers, so
    # the result always exists on disk (e.g. zero-padded or suffixed names).
    return max(apollo_entries, key=_ver_key)


def apollo_copy_firmware_pkg(src_plat, dst_plat):
    """
    Copy the ``firmware`` entry of ``src_plat`` into ``dst_plat`` if it is missing there.

    :param src_plat: apollo directory to copy firmware from
    :param dst_plat: apollo directory that must contain firmware
    :return: True if ``dst_plat`` already had firmware or the copy succeeded,
             False if the copy failed
    """
    import subprocess  # local import: only this block runs external commands this way

    if 'firmware' in os.listdir(dst_plat):
        return True
    cmd_args = ['cp', '-arf', os.path.join(src_plat, 'firmware'), dst_plat + '/']
    cmd = ' '.join(cmd_args)
    # Run cp with an argument list (no shell), so paths are never re-parsed by a
    # shell: whitespace or metacharacters cannot split or inject arguments.
    try:
        ret = subprocess.call(cmd_args)
    except OSError:
        ret = -1  # cp could not be executed at all
    if ret:
        log.error("WEAK_MATCH: Copy weak match firmware to dst failed with cmd(%s).", cmd)
        return False
    log.info("WEAK_MATCH: Copy weak match firmware to dst successfully with cmd(%s).", cmd)
    return True


@remount_image
def switch_apollo_in_target(path_dir, max_apollo):
    """
    Re-point the target package's ``apollo`` link to ``max_apollo``.

    First the firmware entry is copied from the currently linked apollo directory
    into ``max_apollo`` if it lacks one. The link is switched only when ``max_apollo``
    then holds at least as many entries as there are required components (kernel,
    apollo_patch, euler, firmware).

    :param path_dir: target package system directory
    :param max_apollo: name of the apollo directory to link to, e.g. "apollo-1.1.5.0"
    :return: True on success, False otherwise (the link is left untouched)
    """
    apollo_components = ['kernel', 'apollo_patch', 'euler', 'firmware']
    src_plat = os.path.join(path_dir, "apollo")
    dst_plat = os.path.join(path_dir, max_apollo)
    if not apollo_copy_firmware_pkg(src_plat, dst_plat):
        return False
    # Validate before touching the link, so a failed check cannot leave the package
    # pointing at an incomplete apollo directory.
    if len(os.listdir(dst_plat)) < len(apollo_components):
        log.error("CHECK_ITEM: There are not enough components.")
        return False
    os.unlink(src_plat)
    os.symlink(max_apollo, src_plat)
    return True


def load_record_info_from_db():
    """
    Load the latest successful upgrade record from the upgrade-record table.

    Selects the row with the largest ``sn`` among rows whose state is
    UpgStatus.UPD_SUCCESS and whose upg_type is neither ROLL_APOLLO_UPG nor
    OFFLINE_APOLLO_UPG.

    :return: [upg_type, pre_ver, post_ver, record] of that row, or None when the
             query fails or returns no rows
    """
    query = ("select upg_type, pre_ver, post_ver, record from {0} where sn="
             "(select max(sn) from {0} where state={1} and upg_type not in ({2}, {3}))").format(
        db_cfg.UPGRADE_RECORD_TBL, int(UpgStatus.UPD_SUCCESS), ROLL_APOLLO_UPG, OFFLINE_APOLLO_UPG)
    rows = DbSrv.exec_sql(query)
    if rows is None:
        log.error("CHECK_ITEM: Load progress failed.")
        return None
    if not rows:
        log.info("CHECK_ITEM: No progress info in db.")
        return None
    first_row = rows[0]
    return [first_row[0], first_row[1], first_row[2], first_row[3]]


def get_his_apollo_ver():
    """
    Get the apollo versions stored in the latest successful upgrade record.

    The record column is parsed as JSON and its 'pre_apollo_ver' and
    'post_apollo_ver' keys are read (a missing key gives '').

    :return: (pre_apollo_ver, post_apollo_ver), or None when there is no record,
             the record is empty, or it cannot be parsed
    """
    ver_info = load_record_info_from_db()
    if ver_info is None:
        return None
    try:
        record = json.loads(ver_info[-1])
        if record:
            return record.get('pre_apollo_ver', ''), record.get('post_apollo_ver', '')
        log.info("CHECK_ITEM: Record info in upgrade record table is empty.")
        return None
    except Exception:
        log.error("CHECK_ITEM: Get history apollo ver from database failed.")
        return None


def handle_apollo_before_upgrade():
    """
    Make the target package's apollo link compatible with the current one before upgrading.

    * Upgrade (current package version < target version): if the current apollo is
      newer than the target's default apollo, re-point the target's apollo link to
      the highest apollo shipped in the target package. That apollo must not be
      older than the current one.
    * Otherwise (downgrade / same version): take the pre-upgrade apollo version from
      the latest successful upgrade record. It must exist in the target package; if
      it is newer than the target's default apollo, re-point the link to it.

    :return: True when nothing needs changing or the link was switched, False on error
    """
    if not os.path.exists(path.TARGET_CONFIG_PATH):
        log.info("CHECK_ITEM: Can not get the new version,pkg %s does not exist.", path.UPLOAD_SOFTWARE_PATH)
        return False
    pkg_ver = pkg.get_pkg_version(path.TARGET_CONFIG_PATH)[0]
    if not pkg_ver:
        log.error("CHECK_ITEM: Failed to get dst pkg version.")
        return False
    target_sys_path = os.path.join(path.IMAGE_DISK, pkg_ver, 'system')
    cur_ver, _, cur_apollo = get_pkg_version_from_dir(path.PKG_CUR_DIR)
    _, _, upd_apollo = get_pkg_version_from_dir(target_sys_path)
    if not cur_ver or not cur_apollo or not upd_apollo:
        log.error("CHECK_ITEM: Failed to get version info, cur_ver(%s) cur_apollo(%s) upd_apollo(%s).",
                  cur_ver, cur_apollo, upd_apollo)
        return False
    try:
        is_upgrade = int(cur_ver) < int(pkg_ver)
    except ValueError:
        log.error("CHECK_ITEM: Invalid package version, cur_ver(%s) pkg_ver(%s).", cur_ver, pkg_ver)
        return False

    # compare_version() returns None (and logs) for unparsable versions. Comparing
    # None with 0 raises on Python 3 and silently passes on Python 2, so stop early.
    if is_upgrade:
        cur_apollo_ver = cur_apollo.split('-')[1]
        cmp_ret = compare_version(cur_apollo_ver, upd_apollo.split('-')[1])
        if cmp_ret is None:
            return False
        # The current apollo is not newer than the target default: nothing to change.
        if cmp_ret <= 0:
            return True
        # The current apollo is newer: link the target to its highest apollo, which
        # must not be older than the current one.
        max_apollo = get_max_target_apollo(target_sys_path)
        max_apollo_ver = max_apollo.split('-')[1]
        cmp_ret = compare_version(cur_apollo_ver, max_apollo_ver)
        if cmp_ret is None:
            return False
        if cmp_ret > 0:
            log.error("CHECK_ITEM: Current apollo version(%s) is bigger than max target apollo version(%s).",
                      cur_apollo_ver, max_apollo_ver)
            return False
        return switch_apollo_in_target(target_sys_path, max_apollo)

    # Downgrade (or same version): use the apollo version recorded before the last upgrade.
    his_apollo = get_his_apollo_ver()
    if his_apollo is None:
        return False
    pre_apollo_ver, _ = his_apollo
    ver_list = [x.split("-")[1].strip() for x in os.listdir(target_sys_path) if x.startswith("apollo-")]
    if pre_apollo_ver not in ver_list:
        log.error("CHECK_ITEM: There is no pre_apollo(%s) in target package(%s).", pre_apollo_ver, ver_list)
        return False
    cmp_ret = compare_version(pre_apollo_ver, upd_apollo.split('-')[1])
    if cmp_ret is None:
        return False
    if cmp_ret <= 0:  # the recorded apollo version matches the target default apollo
        log.info("CHECK_ITEM: The pre_apollo(%s) is match with target default apollo.", pre_apollo_ver)
        return True
    return switch_apollo_in_target(target_sys_path, "apollo-{0}".format(pre_apollo_ver))


def main():
    """
    Command-line entry point.

    Options are handled in this order:
      -c/--check          run check_fast_func(); return 1 if it fails, else continue
      -a/--apollo         return check_is_support_apollo_switch()'s code (0/1/2)
      -H/--handle_apollo  return 0 if handle_apollo_before_upgrade() succeeds, else 1

    :return: process exit code (0 when no option produced another value)
    """
    parser = argparse.ArgumentParser(description="This is a fast upgrade check script.")
    parser.add_argument('-c', '--check', dest='check_fast', action="store_true",
                        help='check is support fast upgrade.')
    parser.add_argument('-a', '--apollo', dest='check_apollo', action="store_true",
                        help='check is support apollo switch.')
    parser.add_argument('-H', '--handle_apollo', dest='handle_apollo', action="store_true",
                        help='handle apollo before upgrade.')
    args = parser.parse_args()

    if getattr(args, "check_fast", False) and not check_fast_func():
        return 1
    if getattr(args, "check_apollo", False):
        return check_is_support_apollo_switch()
    if getattr(args, "handle_apollo", False):
        return 0 if handle_apollo_before_upgrade() else 1
    return 0


if __name__ == '__main__':
    # main()'s return value becomes the process exit status.
    sys.exit(main())
