#!/usr/bin/python
# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
import logging
import sys
import subprocess
import shlex
import re
import os
import time
import yaml

# Shell snippets: the first checks for an "admin" user id, the second greps
# /etc/passwd for an ismcli entry owned by root. Neither is referenced by the
# code visible in this file.
USER_EXIST = r'id -u admin > /dev/null 2>&1'
USER_CMD = r"cat /etc/passwd | grep '/ISM/cli/ismcli' | grep ':0:root:'"
# Return codes
RETURN_OK = 0
RETURN_ERROR = 1
# Interval between repeated checks (passed to time.sleep, i.e. seconds)
SLEEP_TIME = 10
# Version information: image root directory and patch manifest file name
IMAGE_DISK = '/startup_disk/image'
PATCH_MANIFEST = 'patch.yml'
# Manifest of the currently running version
CUR_MANIFEST_PATH = '/OSM/conf/manifest.yml'
# Manifest of the upgrade target version
UPG_MANIFEST_PATH = '/startup_disk/image/pkg_upd/manifest.yml'
# Script invoked (with "kernel showversion") to query the current kernel version
CUR_KERNEL_VERSION = '/OSM/script/upgrade.sh'
# Versions bounding the affected range: at or above V616 / at or below V610
# the version check passes without looking at patches
VERSION_V616_NO = "7600513292"
VERSION_V610_NO = "7600503265"
# Initialize the system log (module-level side effect: configures the root logger)
logging.basicConfig(level=logging.INFO,
                    filename="/OSM/log/cur_debug/messages",
                    format='[%(asctime)s][%(levelname)s][%(message)s][%(filename)s, %(lineno)d]',
                    datefmt='%Y-%m-%d %H:%M:%S')


# Run a diagnose command directly
def diagnose(command):
    """Run ``command`` through ``diagsh`` and return its standard output as text.

    The command is substituted into a
    ``timeout -s 9 10 diagsh --attach=*_12 --cmd="..."`` line, tokenized with
    ``shlex.split`` and executed without a shell.

    :param command: text placed inside the ``--cmd`` option.
    :return: captured stdout decoded as UTF-8 (stderr is not captured).
    """
    argv = shlex.split('timeout -s 9 10 diagsh --attach=*_12 --cmd="%s"' % command)
    # The time limit comes from the external ``timeout`` utility; per the
    # original note, the timeout parameter cannot be used on the Euler r8 version.
    stdout_data, _ = subprocess.Popen(argv, stdout=subprocess.PIPE).communicate()
    return stdout_data.decode("utf-8")


# Run a shell command directly
def execute_cmd(cmd):
    """Run ``cmd`` via ``os.popen`` and return its stripped standard output.

    :param cmd: shell command line to execute.
    :return: stdout text with leading and trailing whitespace removed.
    """
    # The context manager closes the pipe even if reading raises.
    with os.popen(cmd) as cmd_fd:
        return cmd_fd.read().strip()


# Get the system version from a manifest file
def get_sys_version(manifest_path):
    """Read ``SYS.Version`` and ``SYS.SpcVersion`` from a YAML manifest.

    :param manifest_path: path of the manifest file.
    :return: tuple ``(version, spc_version)`` of strings; ``("", "")`` when
        the file does not exist or either field cannot be read.
    """
    not_available = ("", "")
    if not os.path.exists(manifest_path):
        logging.warning("sys file(%s) didnt exist.", manifest_path)
        return not_available

    with open(manifest_path) as manifest_file:
        manifest = yaml.safe_load(manifest_file)

    try:
        version = str(manifest.get("SYS")["Version"])
        spc_version = str(manifest.get("SYS")["SpcVersion"])
    except Exception:
        # Any lookup failure (missing section/key, non-mapping document)
        # is reported as "version unknown".
        logging.warning("CHECK_FUNC: Failed to get version.")
        return not_available
    return version, spc_version


# Get the patch version from the patch configuration file
def get_patch_version(conf_file):
    """Read ``PatchConf.patch_version`` from a patch YAML file.

    :param conf_file: path of the patch configuration file.
    :return: the ``patch_version`` value, or ``""`` when the file is missing,
        the document is empty or not a mapping, the ``PatchConf`` section is
        absent/empty, or ``patch_version`` itself is empty.
    """
    if not os.path.exists(conf_file):
        logging.warning("patch conf file(%s) didnt exist.", conf_file)
        return ""

    # ``with`` closes the handle even when YAML parsing raises.
    with open(conf_file) as file_handle:
        patch_conf = yaml.safe_load(file_handle)

    # safe_load yields None for an empty document and may yield a scalar or
    # list; only a mapping can carry the PatchConf section.
    ver_conf = patch_conf.get("PatchConf") if isinstance(patch_conf, dict) else None
    if not ver_conf or not isinstance(ver_conf, dict):
        logging.warning("Patch Conf file didnt exist.")
        return ""
    # A null/empty value is normalized to "" so callers always get a falsy
    # string rather than None.
    return ver_conf.get('patch_version') or ""


# Check the kernel version; per the original note only A0 can have residual connections
def check_kernel_version():
    """Tell whether the kernel version alone lets the check pass.

    Runs ``sh <CUR_KERNEL_VERSION> kernel showversion``, logs the output and
    inspects the third numeric field after ``kernel version:``.

    :return: True when that field is numerically greater than zero; False
        when it is zero or the output does not contain a parsable version.
    """
    kernel_ver = execute_cmd("sh {} kernel showversion".format(CUR_KERNEL_VERSION))
    logging.info("%s", kernel_ver)
    # NOTE(review): the dots are unescaped and therefore match any character;
    # left unchanged because the exact output format is not visible here.
    match = re.search(r"kernel version:\s+\d+.\d+.(\d+).\d+", kernel_ver, re.S)
    if match is None:
        return False
    # Compare as a number: as strings, "00" > "0" would wrongly be True.
    return int(match.group(1)) > 0


def _patch_digit_version(patch_ver):
    """Return the first run of digits in ``patch_ver`` as an int, 0 if there is none."""
    if not patch_ver:
        return 0
    found = re.search(r'\d+', str(patch_ver))
    return int(found.group()) if found else 0


# Check the software version and the patch version
def check_software_and_patch_version(manifest_path):
    """Decide from the manifest and installed patch whether the issue is resolved.

    :param manifest_path: manifest of the system version to evaluate.
    :return: True when the system version is outside the affected range, or
        the installed patch number is at least the required one for that
        version (or, for 6.1.5, the pending patch is older than SPH12);
        False when the version cannot be read, is not in the known table, or
        the patch is too old.
    """
    # System version number -> minimum patch that resolves the issue
    upgrade_patch_cfg = {
        '7600506206': 'SPH60',  # version 6.1.2
        '7600509200': 'SPH30',  # version 6.1.3
        '7600511219': 'SPH18'   # version 6.1.5
    }

    digit_ver, sys_ver = get_sys_version(manifest_path)
    if not digit_ver:
        logging.warning("Get system version failed.")
        return False
    # The issue is already resolved in, or does not affect, this version.
    # NOTE(review): this is a string comparison; it assumes version numbers
    # are digit strings of equal width - confirm.
    if digit_ver >= VERSION_V616_NO or digit_ver <= VERSION_V610_NO:
        return True

    # Manifest of the currently installed patch for this system version
    patch_path = os.path.join(IMAGE_DISK, str(digit_ver), 'hotpatch', 'patch_cur', PATCH_MANIFEST)
    # Manifest of the patch that is about to be installed for this system version
    temp_patch_path = os.path.join(IMAGE_DISK, str(digit_ver), 'hotpatch', 'patch_temp', PATCH_MANIFEST)
    if os.path.exists(temp_patch_path):
        patch_path = temp_patch_path
        temp_patch_ver = get_patch_version(temp_patch_path)
        # Per the original note, installing 6.1.5.SPH12 or later kills the
        # process. Compare the patch number numerically: as strings,
        # "SPH9" < "SPH12" is False and "SPH100" < "SPH12" is True.
        if digit_ver == "7600511219" and _patch_digit_version(temp_patch_ver) < 12:
            logging.info("patch_ver(%s) didnt kill thread.", temp_patch_ver)
            return True

    # Patch version to evaluate (0 when no patch version is available)
    patch_ver = get_patch_version(patch_path)
    patch_digit_ver = _patch_digit_version(patch_ver)
    logging.info("CHECK_VERSION: sys_ver(%s), patch_ver(%s)", sys_ver, patch_ver)

    # Check whether the patch level resolves the issue
    if digit_ver not in upgrade_patch_cfg:
        logging.error("Digital version(%s) is invalid.", digit_ver)
        return False
    require_patch_digit_ver = _patch_digit_version(upgrade_patch_cfg[digit_ver])
    return patch_digit_ver >= require_patch_digit_ver


def check_current_and_target_version():
    """Run the version checks in order and stop at the first that passes.

    Order: kernel version, current manifest, upgrade-target manifest.

    :return: True as soon as one check passes, False when all three fail.
    """
    logging.info("Start check current version.")
    if check_kernel_version():
        return True
    if check_software_and_patch_version(CUR_MANIFEST_PATH):
        return True

    logging.info("Start check target upgrade version.")
    target_ok = check_software_and_patch_version(UPG_MANIFEST_PATH)
    return True if target_ok else False


def check_remain_conn_card(slot_list, card_list, card_len, dev_list, dev_len):
    """Compare kernel-side and user-side DTOE connection ids for each card.

    :param slot_list: output list; the slot of the first card whose kernel
        connection ids are not all known to user space is appended to it.
    :param card_list: sequence of ``(card_number, slot)`` entries.
    :param card_len: number of entries of ``card_list`` to check.
    :param dev_list: sequence of ``(dev, slot)`` entries.
    :param dev_len: number of entries of ``dev_list`` to search.
    :return: False at the first card with kernel connection ids missing from
        user space, True when every checked card passes (cards without a
        matching dev are logged and skipped).
    """
    for card_info in card_list[:card_len]:
        card, slot = card_info[0], card_info[1]

        # Pick the dev that sits in the same slot as this 1822 card.
        for dev_info in dev_list[:dev_len]:
            if dev_info[1] == slot:
                dev = dev_info[0]
                break
        else:
            # No dev matched: skip the card instead of falling back to the
            # last dev in the list.
            logging.error("The card(slot:%s) cannot be found in dev.", slot)
            continue

        # Query the kernel-side connection information
        kernel_conn = execute_cmd("hinicadm toeconn -i hinic{} -l".format(card))
        kernel_conn_xid = set(re.findall(r"DTOE\s+(\d+)", kernel_conn, re.S))
        # No DTOE connection found: this card passes
        if not kernel_conn_xid:
            logging.info("Number of connection in hinic%s is 0.", card)
            continue

        # Query the user-side connection information
        user_conn = str(diagnose("dtoe showdevxidinuse {}".format(dev)))
        # Drop the first line: it carries digits that belong to the command itself
        user_conn = '\n'.join(user_conn.split('\n')[1:])
        user_conn_xid = set(re.findall(r"(\d+)\s+", user_conn, re.S))
        logging.info("Hinic%s: kernel_conn_num=%d, user_conn_num=%d",
                     card, len(kernel_conn_xid), len(user_conn_xid))

        # User space does not know every kernel connection id: connections may
        # be left over or still being set up / torn down, so the caller retries.
        if not kernel_conn_xid.issubset(user_conn_xid):
            slot_list.append(slot)
            return False
    return True


# Check for residual connections
def check_remain_conn(slot_list):
    """Check all 1822 DTOE cards for residual connections, retrying once.

    :param slot_list: output list that receives the slots of cards reported
        by ``check_remain_conn_card`` as having unmatched connections.
    :return: True when DTOE2.0 is in use, no 1822 card is found, or one of
        the attempts finds no unmatched connection; False otherwise.
    """
    all_dev = str(diagnose("dtoe showalldev"))
    all_card = str(diagnose("dtoe showallcard"))

    if "drv_type" in all_card:
        logging.info("Current kernel version is later than or equal to A5, using DTOE2.0.")
        return True
    # Select the 1822 cards as (card number, slot) pairs
    card_list = re.findall(r"hinic(\d+).*?slot:(\d+)", all_card)
    card_list.sort(key=lambda x: x[1])
    card_len = len(card_list)
    if card_len == 0:
        logging.info("The 1822 DTOE card does not exist.")
        return True

    # The sn may be printed in decimal or hexadecimal format
    dev_list = re.findall(r"Dev name.*?sn.*?(?:0x)?([0-9a-fA-F]+).*?Dev slot id.*?(\d+)", all_dev, re.S)
    dev_list.sort(key=lambda x: x[1])
    dev_len = len(dev_list)

    attempts = 2
    for attempt in range(attempts):
        # Check the connections on every card; if user space knows all kernel
        # connection ids, there is no residual connection.
        if check_remain_conn_card(slot_list, card_list, card_len, dev_list, dev_len):
            return True
        # Wait before re-checking, but not after the final attempt: sleeping
        # there would only delay the failure result.
        if attempt < attempts - 1:
            time.sleep(SLEEP_TIME)

    return False


def main():
    """Run the residual DTOE connection check and print the verdict.

    Prints "True" when the version check passes or no residual connection is
    found, "False" otherwise.

    :return: RETURN_OK whenever the check ran to completion (whatever the
        verdict), RETURN_ERROR when an exception was raised.
    """
    try:
        logging.info("Start check remain dtoe connection.")

        # Version check first: a passing version makes the connection check unnecessary
        if check_current_and_target_version():
            logging.info("Check version successfully.")
            print("True")
            return RETURN_OK

        # Residual connection check
        residual_slots = []
        if not check_remain_conn(residual_slots):
            # De-duplicate the slots that still hold DTOE connections
            unique_slots = list(set(residual_slots))
            logging.error("NIC slots with remaining DTOE connections: %s.", unique_slots)
            print("False")
            return RETURN_OK

        logging.info("No remain DTOE connection exists.")
        print("True")
        return RETURN_OK

    except Exception as err:
        logging.exception("check_dtoe_conn_remain: %s", err)
        print("False")
        return RETURN_ERROR


# Script entry point: the process exit status is the code returned by main()
if __name__ == '__main__':
    sys.exit(main())
