#!/usr/bin/env python
# -!- coding:utf-8 -!-

import argparse
import json
import logging
import os
import re
import shlex
import sys
import telnetlib
import time
from subprocess import PIPE, Popen
from threading import Timer
import select
import stat

# Timeout in seconds for the telnet connection/login and for subprocess calls
# (a subprocess is killed when it expires).
CONNECTION_TIMEOUT = 30
# Suffixes that mark the end of switch CLI output (a prompt or a yes/no question).
# END_SYMBOL[0] is overwritten at runtime with the real prompt after telnet login.
END_SYMBOL = [">", "[Y/N]:", "]"]

# CLI command sequences sent to the switch. "sys" is skipped on the telnet path
# and passed through to the ssh helper on the ssh path.
AI_FABRIC_COMMAND = ["sys", "ai-service", "ai-ecn", "dis this"]
DIS_VERSION_COMMAND = ["sys", "dis version"]
DIS_INTERFACE_COMMAND = ["sys", "dis int bri"]
DIS_PFC_COMMAND = ["sys", "dcb pfc", "dis this"]
DIS_ECN_COMMAND = ["sys", "dis drop-profile"]

# Storage-node network config; its g_roce_flow_type entry selects the "dscp" or
# "pcp" network type (see GetSwConf.get_network_type).
NETWORK_CFG = "/opt/fusionstorage/agent/conf/network.cfg"
# Log file written by this script.
RDMA_CHECK_LOG = "/var/log/rdma_check/check_log.log"
# Output JSON path prefixes; "<switch ip>.json" is appended to each.
RDMA_CHECK_SWITCH_JSON = "/var/log/rdma_check/rdma_check_switch_"
RDMA_SWITCH_PORTS_JSON = "/var/log/rdma_check/switch_ports_"

# Expected global PFC configuration for DSCP networks.
DSCP_GLOBAL_PFC_CONF = '''
    priority 4 5
    priority 4 deadlock-detect time 10
    priority 4 deadlock-recovery time 10
    priority 4 turn-off threshold 10
'''
# Expected global ECN configuration for DSCP networks. Matching is loose:
# specific threshold values are not checked.
DSCP_GLOBAL_ECN_CONF = '''
    ecn buffer-size
'''

# Expected per-interface ECN configuration for DSCP networks. Each line must
# appear as a substring of some line of the interface configuration.
DSCP_INTERFACE_ECN_CONF = '''
    qos drr
    qos queue 4
    qos buffer queue 4
    qos queue 4 ecn
'''

# Expected per-interface PFC configuration for DSCP networks (matched per line,
# as above).
DSCP_INTERFACE_PFC_CONF = '''
    trust dscp
    dcb pfc enable mode manual
    dcb pfc buffer 4
'''

# Expected global PFC configuration for PCP networks.
PCP_GLOBAL_PFC_CONF = '''
    priority 3
    priority 3 deadlock-detect time 10
    priority 3 turn-off threshold 10
'''
# Expected global ECN configuration for PCP networks. Matching is loose:
# specific threshold values are not checked.
PCP_GLOBAL_ECN_CONF = '''
    ecn buffer-size
'''

# Expected per-interface PFC configuration for PCP networks (matched per line).
PCP_INTERFACE_PFC_CONF = '''
    dcb pfc enable mode manual
    dcb pfc buffer 3
'''

# Expected per-interface ECN configuration for PCP networks (matched per line).
PCP_INTERFACE_ECN_CONF = '''
    qos drr
    qos queue 3
    qos buffer queue 3
    qos queue 3 ecn
'''

# Standard configuration set used when the network trust type is "dscp".
SW_DSCP_CONF = {
    "std_global_pfc": DSCP_GLOBAL_PFC_CONF,
    "std_global_ecn": DSCP_GLOBAL_ECN_CONF,
    "std_intf_pfc_conf": DSCP_INTERFACE_PFC_CONF,
    "std_intf_ecn_conf": DSCP_INTERFACE_ECN_CONF
}

# Standard configuration set used when the network trust type is "pcp".
SW_PCP_CONF = {
    "std_global_pfc": PCP_GLOBAL_PFC_CONF,
    "std_global_ecn": PCP_GLOBAL_ECN_CONF,
    "std_intf_pfc_conf": PCP_INTERFACE_PFC_CONF,
    "std_intf_ecn_conf": PCP_INTERFACE_ECN_CONF
}


class LoginClient(object):
    '''
    Switch CLI client that logs in over telnet or ssh.

    Telnet is tried first; if it fails, the client falls back to ssh through
    the ``expectLogin.sh`` helper located next to this script. After a
    successful login ``login_type`` is "telnet" or "ssh".
    '''

    def __init__(self, ip, username, password):
        self.is_init_log = False
        self.ip = ip
        self.tn = None
        self.username = username
        self.password = password
        self.is_login_timeout = False
        # Always defined; login() sets it to "telnet" or "ssh" on success.
        self.login_type = ""
        self.login()

    @property
    def log(self):
        '''Root logger writing to RDMA_CHECK_LOG (re-initialised if the file was deleted).'''
        return self._init_log(RDMA_CHECK_LOG)

    def login(self):
        '''
        Log in to the switch, preferring telnet and falling back to ssh.

        Raises:
            TimeoutError: the ssh helper was killed after CONNECTION_TIMEOUT.
            ValueError: neither telnet nor ssh login succeeded.
        '''
        try:
            self.tn = self._login_telnet()
        except BaseException as e:
            self.log.error("Telnet login failed, err: %s", str(e))
            self.tn = None
        if self.tn:
            self.login_type = "telnet"
            self._exec_command_telnet(self.tn, ["system"])
            return

        res = self._login_ssh()

        if self.is_login_timeout:
            self.log.error("ssh login timeout.")
            raise TimeoutError

        # The ssh helper's output is treated as a successful login only if it
        # contains "Info:".
        if "Info:" in res:
            self.login_type = "ssh"
        else:
            self.log.error(
                "telent and ssh are not supported to connect switch.")
            raise ValueError

    def exec_command(self, cmd_list):
        '''Run cmd_list on the switch over the session type chosen by login().'''
        if self.login_type == "telnet":
            return self._exec_command_telnet(self.tn, cmd_list)
        return self._exec_command_ssh(cmd_list)

    def popen(self, cmdline):
        '''
        Run a '|'-separated pipeline without a shell and return its stdout.

        Each stage is split with shlex and fed from the previous stage's
        stdout. The last stage is killed after CONNECTION_TIMEOUT seconds, in
        which case is_login_timeout is set. Returns "err" if a stage cannot be
        started, or if the last stage writes to stderr (other than a "User
        Authentication" banner). The password is masked in every log line.
        '''
        prev_stdout = PIPE
        result = None
        for cmd in cmdline.split('|'):
            encrypt_log = self._encrypt_sensitive_info(cmd, self.password)
            try:
                result = Popen(shlex.split(cmd), shell=False, stdin=prev_stdout, stdout=PIPE, stderr=PIPE,
                               universal_newlines=True, encoding='unicode-escape')
            except BaseException as e:
                self.log.error("exec %s cmd err, err: %s", encrypt_log, str(e))
                return "err"
            finally:
                # Once the next stage owns the previous stage's stdout, close the
                # parent's copy so the earlier stage gets SIGPIPE if the later
                # one exits early (the documented shell-pipeline pattern).
                if prev_stdout is not PIPE:
                    prev_stdout.close()
            prev_stdout = result.stdout

        timer = Timer(CONNECTION_TIMEOUT, lambda process: process.kill(), [result])
        try:
            timer.start()
            stdout, stderr = result.communicate()
            encrypt_log = self._encrypt_sensitive_info(cmdline, self.password)

            if stderr and "User Authentication" not in stderr:
                self.log.error("cmd({}) communicate failed, err:{}, out:{}".format(encrypt_log, stderr, stdout))
                return "err"

            # If the timer already fired, the process was killed by it.
            if not timer.is_alive():
                self.log.error('Command timed out: {}, out: {}'.format(encrypt_log, stdout))
                self.is_login_timeout = True
            else:
                self.log.info('Command completed successfully: {}'.format(encrypt_log))

            return stdout
        finally:
            timer.cancel()

    def _exec_command_ssh(self, cmd_list=None):
        '''
        Run commands on the switch over ssh through the expectLogin.sh helper.

        The helper is invoked as
        ``expectLogin.sh <ip> <user> <password> "screen-length 0 tem; cmd1; ..."``
        and its stdout is returned ("err" on failure, see popen()).
        '''
        file_name = 'expectLogin.sh'
        # Resolve the helper relative to this script: when SmartKit runs it as an
        # inspection item the script is pushed to /tmp/ while the working
        # directory is /root/.
        file_path = os.path.realpath(os.path.dirname(__file__))

        # NOTE(review): the password is passed on the helper's command line and
        # is therefore visible to other local users via the process list.
        login_cmd = (file_path + '/' + file_name + ' ' + self.ip + ' ' + self.username + ' ' + self.password +
                     ' \"' + 'screen-length 0 tem')
        all_cmd = login_cmd

        if cmd_list:
            all_cmd += '; ' + ''.join(item + '; ' for item in cmd_list)

        all_cmd += '\"'
        res = self.popen(all_cmd)

        # This string only appears when the switch rejects a command. Never log
        # the raw command: it contains the password.
        if "found at '^' position" in res:
            encrypt_log = self._encrypt_sensitive_info(all_cmd, self.password)
            self.log.warning("cmd(%s) are not supported.", encrypt_log)

        return res

    def _login_ssh(self):
        '''Probe ssh login; first disables the local-user security-enhance policy, then runs an empty session.'''
        cmd = ["sys", "aaa", "undo local-user policy security-enhance", "commit"]
        self._exec_command_ssh(cmd)

        return self._exec_command_ssh()

    def _exec_command_telnet(self, tn, cmd_list):
        '''
        Send each command over telnet and return the output of the last one.

        "sys" entries are skipped, because login() already sends "system" on
        the telnet path. Output is read until it ends with one of the
        END_SYMBOL prompts.

        Raises:
            TimeoutError: no prompt was seen within CONNECTION_TIMEOUT seconds.
        '''
        res = ''
        for cmd in cmd_list:
            if cmd == "sys":
                continue

            tn.write(cmd.encode('ascii') + b"\n")
            length = -1 * len(END_SYMBOL[0])
            encrypt_log = self._encrypt_sensitive_info(cmd, self.password)
            deadline = time.time() + CONNECTION_TIMEOUT
            res = ''
            while True:
                chunk = tn.read_very_eager().decode('utf-8')
                res += chunk
                # "[Y/N]:" is six characters long, hence the res[-6:] test.
                if res[-1:] in END_SYMBOL or res[-6:] in END_SYMBOL or res[length:] in END_SYMBOL:
                    break
                if time.time() > deadline:
                    self.log.error("cmd(%s) timed out waiting for prompt, out: %s", encrypt_log, res)
                    raise TimeoutError("telnet command timed out: %s" % encrypt_log)
                if not chunk:
                    # read_very_eager() never blocks; avoid spinning the CPU.
                    time.sleep(0.05)

            # This string only appears when the switch rejects a command.
            if "found at '^' position" in res:
                self.log.warning("cmd(%s) are not supported.", encrypt_log)

        return res

    def _login_telnet(self):
        '''
        Open a telnet session, authenticate and disable output paging.

        On success END_SYMBOL[0] is replaced with the switch's actual prompt
        (the last line of the 'screen-length 0 tem' output) and the Telnet
        object is returned. On any failure the connection is closed and the
        error is re-raised.
        '''
        tn = telnetlib.Telnet(self.ip, timeout=CONNECTION_TIMEOUT)
        try:
            login_res = tn.read_until(b'Username:', 1)
            if b'Username' in login_res:
                tn.write(self.username.encode('ascii') + b"\n")
                # Bounded wait so a missing "Password:" prompt cannot hang login.
                tn.read_until(b'Password:', CONNECTION_TIMEOUT)
            tn.write(self.password.encode('ascii') + b'\n')
            res = ''
            length = -1 * len(END_SYMBOL[0])
            start_time = int(time.time())
            while True:
                chunk = tn.read_very_eager().decode('utf-8')
                res += chunk
                interval = int(time.time()) - start_time
                if interval > CONNECTION_TIMEOUT:
                    self.log.error(res)
                    raise Exception(res)
                if '[Y/N]:' in res:
                    # Answer "N" to a yes/no question shown after login
                    # (presumably a password-change prompt - confirm).
                    res = ''
                    self._exec_command_telnet(tn, ['N'])
                    break
                if res[-1:] in END_SYMBOL or res[length:] in END_SYMBOL:
                    break
                if not chunk:
                    time.sleep(0.05)

            res = self._exec_command_telnet(tn, ['screen-length 0 tem'])
            END_SYMBOL[0] = res.splitlines().pop()
        except BaseException:
            # Do not leak the socket when login() falls back to ssh.
            tn.close()
            raise
        return tn

    def _init_log(self, log_file):
        '''
        Attach a single FileHandler for log_file to the root logger and return it.

        Re-initialises when log_file has been deleted. Any handler already
        attached to the same file is removed first, so records are never
        written twice.
        '''
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        if not os.path.exists(log_file):  # the log file was deleted: re-create it
            self.is_init_log = False

        logger = logging.getLogger()
        if self.is_init_log:
            return logger

        target = os.path.abspath(log_file)
        for handler in list(logger.handlers):
            if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
                logger.removeHandler(handler)
                handler.close()

        fh = logging.FileHandler(log_file)
        logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='[%(asctime)s] [%(filename)s:%(lineno)d] [%(levelname)s]: %(message)s',
                                      datefmt="%Y-%m-%d %H:%M:%S")
        fh.setFormatter(formatter)
        logger.addHandler(fh)
        self.is_init_log = True
        return logger

    def _encrypt_sensitive_info(self, log, password):
        '''Mask sensitive data (currently only the password) in a log string.'''
        # An empty password must not be "masked": ''.replace('', ...) would
        # insert the mask between every character.
        if password and password in log:
            log = log.replace(password, '***')

        return log


class Interface(object):
    """A switch port as listed by `dis int bri`: its name and its status column."""

    def __init__(self, intf_name, intf_status):
        # Port name as used with the `int` command, e.g. "100GE1/4/32".
        self.intf_name = intf_name
        # Raw status text from the second column of the listing.
        self.intf_status = intf_status


class GetSwConf(LoginClient):
    '''
    Collect the switch configuration needed by the RoCE QoS checks.

    The constructor logs in and immediately gathers the switch version and
    board type, whether ai-ecn is enabled, the interfaces to check, the global
    PFC/ECN configuration and the per-interface PFC/ECN configuration.
    '''

    def __init__(self, ip, username, password, check_intfs_list):
        super(GetSwConf, self).__init__(ip, username, password)
        # Interfaces to check, as given by the user (-f); None means all of them.
        self.check_intfs_list = check_intfs_list
        # Per-interface ECN config lines: {'intf1': [ecn_conf1, ecn_conf2], 'intf2': ...}
        self.intf_ecn_conf = {}
        # Per-interface PFC config lines: {'intf1': [pfc_conf1, pfc_conf2], 'intf2': ...}
        self.intf_pfc_conf = {}
        self.intf_name_list = []  # names of the interfaces selected for checking
        self.intf_list = []  # {"portName", "status"} dicts for the JSON report
        self.global_pfc_conf = ""  # global PFC config lines, newline-terminated
        # Global ECN config; several drop-profiles may coexist: {profile_name: ecn_conf}
        self.global_ecn_conf = {}
        self.sw_version = "None Sw Version"
        self.board_type = "None Board Type"
        self.get_switch_type()
        self.ai_fabric_enable()
        self.get_all_interface()
        self.get_global_pfc_conf()
        self.get_global_ecn_conf()
        self.get_interface_conf()

    def ai_fabric_enable(self):
        '''Set self.ai_fabric to whether "ai-ecn enable" is configured under ai-service.'''
        res = self.exec_command(AI_FABRIC_COMMAND)
        self.ai_fabric = "ai-ecn enable" in res

    def get_switch_type(self):
        '''Parse board type and software version from `dis version` output.'''
        res = self.exec_command(DIS_VERSION_COMMAND)
        for line in res.splitlines():
            if "Board  Type" in line:
                self.board_type = line.split()[-1]

            if "(CE" in line:
                self.sw_version = line.split()[-1].strip(')')

            if "(FM" in line:
                self.sw_version = line.split()[-1].strip(')')

    def intf_to_dict(self, intf):
        '''Convert an Interface into the dict written to the JSON report.'''
        return {
            "portName": intf.intf_name,
            "status": intf.intf_status
        }

    def get_all_interface(self):
        """
        Collect the interfaces to check from `dis int bri`.

        Example output rows:
        100GE1/4/32                xxx
        400GE1/1/1:1(200GE)        xxx
        The interface view needs the names 100GE1/4/32 and 400GE1/1/1:1, so any
        "(...)" suffix is dropped.
        """
        res = self.exec_command(DIS_INTERFACE_COMMAND)
        for line in res.splitlines():
            if not re.match(r'^\d', line.strip()):
                continue
            fields = line.split()
            if len(fields) < 2:
                # Digit-leading row without a status column: not an interface row.
                continue
            intf_name = fields[0].split('(')[0]
            if self.check_intfs_list is not None and intf_name not in self.check_intfs_list:
                continue
            intf = Interface(intf_name, fields[1])
            self.intf_name_list.append(intf_name)
            # Stored as a dict so it can be dumped straight into the JSON report.
            self.intf_list.append(self.intf_to_dict(intf))

    def get_global_pfc_conf(self):
        '''Append every "priority" line of the dcb pfc view to global_pfc_conf.'''
        res = self.exec_command(DIS_PFC_COMMAND)
        for line in res.splitlines():
            if 'priority' in line.strip():
                self.global_pfc_conf += (line + "\n")

    def get_global_ecn_conf(self):
        '''
        Parse `dis drop-profile` output into global_ecn_conf.

        The first "ECN" line after each "Drop-profile" header is stored under
        that profile's name, using tokens 2, 3 and the last token as thresholds.
        '''
        res = self.exec_command(DIS_ECN_COMMAND)
        ecn_name = ""
        for line in res.splitlines():
            if "Drop-profile" in line:
                ecn_name = line.split()[-1]

            if "ECN" in line and ecn_name != "":
                data = line.split()
                if len(data) < 4:
                    # Too short to hold the threshold columns; keep looking.
                    continue
                ecn_data = "ecn buffer-size low-limit %s high-limit %s dicard-percentage %s" % \
                           (data[2], data[3], data[-1])
                self.global_ecn_conf.update({ecn_name: ecn_data})
                ecn_name = ""

        self.log.info(self.global_ecn_conf)

    def get_interface_conf(self):
        '''Collect each interface's ECN lines (starting with "qos") and PFC lines (containing "dcb"/"trust").'''
        for intf in self.intf_name_list:
            intf_ecn_data = []
            intf_pfc_data = []
            res = self.exec_command(["sys", 'int %s' % intf, 'dis this'])

            for line in res.splitlines():
                if line.strip().startswith('qos'):
                    intf_ecn_data.append(line)
                if 'dcb' in line or 'trust' in line:
                    intf_pfc_data.append(line)

            self.intf_ecn_conf.update({intf: intf_ecn_data})
            self.intf_pfc_conf.update({intf: intf_pfc_data})

    def get_network_type(self):
        '''
        Read the RoCE network type ("dscp" or "pcp") from NETWORK_CFG.

        Only the g_roce_flow_type line(s) are inspected; "dscp" wins if both
        words appear.

        Raises:
            FileNotFoundError: NETWORK_CFG does not exist.
            LookupError: NETWORK_CFG could not be read.
            ValueError: neither "dscp" nor "pcp" is configured.
        '''
        if not os.path.exists(NETWORK_CFG):
            self.log.error("network config %s doesn't exist", NETWORK_CFG)
            raise FileNotFoundError(NETWORK_CFG)

        # Read the file directly instead of spawning `cat | grep`. No subprocess
        # is needed, and a config value containing "err" can no longer be
        # mistaken for a command failure.
        try:
            with open(NETWORK_CFG, encoding='utf-8', errors='replace') as cfg_file:
                flow_type = ''.join(line for line in cfg_file if 'g_roce_flow_type' in line)
        except OSError as err:
            self.log.error('get network.cfg fauled.')
            raise LookupError(str(err)) from err

        if 'dscp' in flow_type:
            return "dscp"
        if 'pcp' in flow_type:
            return "pcp"
        self.log.error("cur environment doesn't support RoCE.")
        raise ValueError

    def dump_switch_info(self):
        '''Write the switch board type, version and checked ports to RDMA_SWITCH_PORTS_JSON<ip>.json.'''
        code = 0 if self.login_type != "" else 1
        info = {
            "checkItemId": "sw_type",
            "resultCode": code,
            "data": {
                "type": self.board_type,
                "version": self.sw_version,
                "ports":
                    self.intf_list
            }
        }

        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        modes = stat.S_IWUSR | stat.S_IRUSR | stat.S_IWGRP | stat.S_IRGRP | stat.S_IWOTH | stat.S_IROTH
        with os.fdopen(os.open(RDMA_SWITCH_PORTS_JSON + self.ip + ".json", flags, modes), "w") as json_file:
            json.dump(info, json_file, indent=4, separators=(',', ':'))


class QosCheck(GetSwConf):
    '''
    Check the flow-control configuration read from the switch (PFC and ECN)
    against the standard configuration for the network's trust type.
    '''

    # PFC priority that must be enabled for each trust type.
    _EXPECTED_PFC_PRIORITY = {"dscp": 4, "pcp": 3}

    def __init__(self, arguments):
        super(QosCheck, self).__init__(arguments.ip, arguments.username, arguments.password, arguments.interfaces)
        # -t overrides the trust type; otherwise it is read from network.cfg.
        if arguments.trust_type:
            self.trust_type = arguments.trust_type
        else:
            self.trust_type = self.get_network_type()
        if self.trust_type not in ("dscp", "pcp"):
            self.log.error("invalid network type.")
            raise ValueError
        self.select_sw_std_conf()
        self.pfc_error_ports = []
        self.ecn_error_ports = []

    def select_sw_std_conf(self):
        '''Pick the DSCP or PCP standard configuration according to trust_type.'''
        std_conf = SW_DSCP_CONF if self.trust_type == "dscp" else SW_PCP_CONF
        self.std_global_pfc = std_conf["std_global_pfc"]
        self.std_global_ecn = std_conf["std_global_ecn"]
        self.std_intf_pfc_conf = std_conf["std_intf_pfc_conf"]
        self.std_intf_ecn_conf = std_conf["std_intf_ecn_conf"]

    def check_global_pfc(self):
        '''Check the global PFC config: the expected priority is enabled and deadlock detection is configured.'''
        return self._check_pfc_priority() and self._check_deadlock_detect()

    def check_interface_pfc(self):
        '''Check each interface's PFC config against the standard; failing ports go to pfc_error_ports.'''
        self.pfc_error_ports.clear()
        self._check_interface_conf(
            self.intf_pfc_conf, self.std_intf_pfc_conf, "pfc")

    def check_global_ecn(self):
        '''Check the global ECN config (always passes when ai-ecn is enabled).'''
        if self.ai_fabric:
            return True

        return any(self.std_global_ecn.strip() in global_ecn_data
                   for global_ecn_data in self.global_ecn_conf.values())

    def check_interface_ecn(self):
        '''Check each interface's ECN config; skipped (all pass) when ai-ecn is enabled.'''
        self.ecn_error_ports.clear()
        if self.ai_fabric:
            return

        self._check_interface_conf(
            self.intf_ecn_conf, self.std_intf_ecn_conf, "ecn")

    def dump_check_info(self):
        '''Run all checks and write the results to RDMA_CHECK_SWITCH_JSON<ip>.json (resultCode 0 = pass).'''
        global_pfc_code = 0 if self.check_global_pfc() else 1
        global_ecn_code = 0 if self.check_global_ecn() else 1

        self.check_interface_ecn()
        self.check_interface_pfc()
        intf_pfc_code = 0 if not self.pfc_error_ports else 1
        intf_ecn_code = 0 if not self.ecn_error_ports else 1

        info = [
            {
                "checkItemId": "global_pfc_check",
                "resultCode": global_pfc_code,
                "errorPort": [],
            },
            {
                "checkItemId": "global_ecn_check",
                "resultCode": global_ecn_code,
                "errorPort": [],
            },
            {
                "checkItemId": "intf_pfc_check",
                "resultCode": intf_pfc_code,
                "errorPort": self.pfc_error_ports,
            },
            {
                "checkItemId": "intf_ecn_check",
                "resultCode": intf_ecn_code,
                "errorPort": self.ecn_error_ports,
            }
        ]

        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        modes = stat.S_IWUSR | stat.S_IRUSR | stat.S_IWGRP | stat.S_IRGRP | stat.S_IWOTH | stat.S_IROTH
        with os.fdopen(os.open(RDMA_CHECK_SWITCH_JSON + self.ip + ".json", flags, modes), "w") as json_file:
            json.dump(info, json_file, indent=4, separators=(',', ':'))

    def _enabled_pfc_priorities(self):
        '''
        Return the set of priority numbers listed after "priority" in global_pfc_conf.

        "priority 4 5" yields {4, 5}; "priority 4 deadlock-detect time 10"
        yields {4} (parsing stops at the first non-number). "a to b" ranges
        are expanded in case the switch prints them that way (unconfirmed).
        '''
        priorities = set()
        for line in self.global_pfc_conf.splitlines():
            tokens = line.split()
            if 'priority' not in tokens:
                continue
            idx = tokens.index('priority') + 1
            last = None
            while idx < len(tokens):
                token = tokens[idx]
                if token.isdigit():
                    last = int(token)
                    priorities.add(last)
                    idx += 1
                elif (token == 'to' and last is not None and idx + 1 < len(tokens)
                      and tokens[idx + 1].isdigit()):
                    upper = int(tokens[idx + 1])
                    priorities.update(range(last, upper + 1))
                    last = upper
                    idx += 2
                else:
                    break
        return priorities

    def _check_pfc_priority(self):
        '''Check that the PFC priority expected for trust_type is enabled globally.'''
        expected = self._EXPECTED_PFC_PRIORITY.get(self.trust_type)
        if expected is None:
            return False
        # Match parsed priority numbers rather than any '4'/'3' character, which
        # would also match values such as "time 40".
        return expected in self._enabled_pfc_priorities()

    def _check_deadlock_detect(self):
        '''Check that PFC deadlock detection is configured on the switch.'''
        # Inspect the configuration read from the switch. The standard
        # configuration always contains deadlock-detect, so checking it would
        # make this test pass unconditionally.
        return "deadlock-detect" in self.global_pfc_conf

    def _check_interface_conf(self, input_intf_conf, input_std_conf, check_type):
        '''
        Compare each interface's config lines with the standard config.

        An interface passes when every standard line is a substring of at least
        one of its config lines. Failing interfaces are appended to
        pfc_error_ports (check_type == "pfc") or ecn_error_ports (otherwise).
        '''
        error_ports = self.pfc_error_ports if check_type == "pfc" else self.ecn_error_ports
        std_lines = [std_data.strip() for std_data in input_std_conf.strip().splitlines()]
        for intf in self.intf_name_list:
            intf_lines = [intf_data.strip() for intf_data in input_intf_conf[intf]]
            if not all(any(std in data for data in intf_lines) for std in std_lines):
                error_ports.append(intf)


def main(arguments):
    '''
    Run the action selected on the command line.

    -l/--login: collect switch type, version and ports and write them to
    RDMA_SWITCH_PORTS_JSON<ip>.json.
    -c/--check: run the PFC/ECN checks and write the results to
    RDMA_CHECK_SWITCH_JSON<ip>.json.
    If both flags are given only the login action runs. Exits the process with
    status 1 if the switch client cannot be constructed.
    '''
    if arguments.login:
        try:
            conf = GetSwConf(arguments.ip, arguments.username, arguments.password, arguments.interfaces)
        except BaseException:
            # Record the traceback; otherwise the cause of unexpected errors
            # (e.g. output-parsing failures) is lost.
            logging.getLogger().exception("GetSwConf failed, ip(%s).", arguments.ip)
            sys.exit(1)
        conf.log.info("GetSwConf now, ip(%s).", arguments.ip)
        conf.dump_switch_info()
        return

    if arguments.check:
        try:
            sw_qos = QosCheck(arguments)
        except BaseException:
            logging.getLogger().exception("QosCheck failed, ip(%s).", arguments.ip)
            sys.exit(1)
        sw_qos.log.info("QosCheck now, ip(%s).", arguments.ip)
        sw_qos.dump_check_info()


def read_password(timeout=3):
    '''
    Read the switch password from stdin without blocking indefinitely.

    Waits at most ``timeout`` seconds for stdin to become readable. Returns
    the first line with surrounding whitespace stripped, or '' if nothing
    arrived in time.
    '''
    readable = select.select([sys.stdin], [], [], timeout)[0]
    if not readable:
        return ''
    return sys.stdin.readline().strip()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--ip', help='switch ip')
    parser.add_argument('-u', '--username', help='switch username')
    # -f selects the interfaces to check; all interfaces are checked by default.
    parser.add_argument('-f', '--interfaces', nargs='+', help='check interfaces, -f [intf1] [intf2]...')

    # -t is for test platforms, which must supply the network trust type manually.
    parser.add_argument('-t', '--trust_type',
                        help='switch trust type: dscp or pcp')
    parser.add_argument(
        '-l', '--login', help='login switch', action='store_true')
    parser.add_argument(
        '-c', '--check', help='check switch qos', action='store_true')

    args = parser.parse_args()
    # The password is read from stdin rather than the command line.
    args.password = read_password()

    if not args.password:
        sys.exit(1)

    main(args)
    # main() exits with status 1 itself when the switch cannot be queried;
    # returning normally means success.
    sys.exit(0)
