# -*- coding: UTF-8 -*-
import re
import time
import datetime
from psdk.checkitem.common.base_dsl_check import BaseCheckItem
from psdk.checkitem.scripts.cert_constants import DEFAULT_CERT_INFO
from psdk.dsl.dsl_common import get_version_info
from psdk.platform.entity.check_status import CheckStatus
from psdk.platform.util.base_util import get_common_msg
from psdk.dsl import fault_mode as ft
from psdk.platform.base.constants import PRODUCT_DORADO_V3

# Column / key names used in CLI output and in the certificate dictionaries built below.
CERT_TYPE = "Type"
EXPIRE_TIME = "Expire Time"
CA_EXPIRE_TIME = "CA Expire Time"
CA_FINGERPRINT = "CA Fingerprint"
CERT_FINGERPRINT = "Fingerprint"
CERT_INFO = "Cert Info"
# When this text appears in a CLI result, check_is_need_deal_ret reports the result as not to be processed.
NOT_SUPPORT_CERTIFICATE = "enter a correct parameter"
# Placeholder used when a field is not available.
DEFAULT_VALUE = "--"
ONE_DAY_SECONDS = 24 * 60 * 60
# Certificate types checked by get_cert_info (types that have preset certificates); others are skipped.
DEFAULT_CERT_TYPE = {
    'device_management',
    'devicemanager_authentication',
    'vm_vnc_server_authentication',
    'vm_vnc_client_authentication',
    'hypermetro_arbitration',
    'https_protocol',
    'ftps_protocol',
    'call_home_authentication',
    'remote_om_link',
}
# Mapping between the type name shown in CLI output ("cli_type") and the internal certificate type.
DEFAULT_CERT_TYPE_CLI_TYPE = (
    {"cli_type": 'Device Management', CERT_TYPE: 'device_management'},
    {"cli_type": 'HTTPS Protocol', CERT_TYPE: 'https_protocol'},
    {"cli_type": 'FTPS Protocol', CERT_TYPE: 'ftps_protocol'},
    {"cli_type": 'Call Home Authentication', CERT_TYPE: 'call_home_authentication'},
)
# Alarm IDs: if any of them is masked on the device, the check passes immediately.
DEFAULT_MASK_ALARM_LIST = [
    "0xF50000001",
    "0xF50000002",
    "0xF50000003",
    "0xF50000004",
    "0xF03370000",
    "0xF03370001",
]


def check_is_need_deal_ret(check_ret):
    """Return True if a CLI query result should be processed further.

    Returns False when the result contains NOT_SUPPORT_CERTIFICATE or a "^" item
    (presumably the CLI's error caret for an invalid parameter -- confirm against the DSL).

    An empty or ``None`` result returns True so the caller's own emptiness check decides
    what to do with it; previously ``None`` raised TypeError from the ``in`` test.
    """
    if not check_ret:
        return True
    return not (NOT_SUPPORT_CERTIFICATE in check_ret or "^" in check_ret)


def cal_time(expire_time, start_time=datetime.datetime(2022, 1, 1),
             end_time=datetime.datetime(2030, 12, 31)):
    """Return True if *expire_time* is a date within ``[start_time, end_time]`` (inclusive).

    The default window runs from 2022-01-01 to 2030-12-31. It still needs tuning: the
    HyperMetro preset certificates expire in 2029 and the VM certificates in 2030.

    :param expire_time: date string in ``%Y-%m-%d`` format; "--" means "not available".
    :param start_time: lower bound of the window (defaults keep the original behavior).
    :param end_time: upper bound of the window.
    :return: False for "--", ``None``, non-strings, or strings not matching ``%Y-%m-%d``.
    """
    if expire_time is None or expire_time == "--":
        return False
    try:
        expire_date = datetime.datetime.strptime(expire_time, "%Y-%m-%d")
    except (TypeError, ValueError):
        # TypeError: non-string input (e.g. None from a parser); ValueError: bad format.
        return False
    return start_time <= expire_date <= end_time


class CheckItem(BaseCheckItem):
    """Check whether preset certificates on the device expire inside the window of ``cal_time``.

    Flow of ``execute``:
      1. PASS immediately if any alarm in DEFAULT_MASK_ALARM_LIST is masked.
      2. Collect certificate information: from the type-less ``show certificate general``
         output for the models/versions matched by ``check_device_version_is_not_surport_type``,
         otherwise via per-type queries.
      3. Compare each certificate with DEFAULT_CERT_INFO; a matched preset certificate whose
         (CA) expiry date lies inside the window makes the result NOT_PASS, with extra hints
         appended for a modified system time, an enabled security rule, or SSH-algorithm
         change events.
    """

    # On V5R7C00/V5R7C10, which have no ``detail=yes`` output, a certificate replaced by the
    # user can still show this preset CA fingerprint, so matching on it alone would still
    # report an error. Certificates carrying it are therefore not checked on those versions.
    _UNTRUSTED_CA_FINGERPRINT = 'F3:73:B3:87:06:5A:28:84:8A:F2:F3:4A:CE:19:2B:DD:C7:8E:9C:AC'

    def execute(self):
        """Run the check.

        :return: ``(CheckStatus, message)``; NOT_CHECK when certificate info cannot be queried.
        """
        if self.check_device_alarm_is_mask_for_cert():
            return CheckStatus.PASS, ""
        if self.check_device_version_is_not_surport_type():
            flag, cert_dictlist = self.get_cert_info_is_not_surport_type()
        else:
            flag, cert_dictlist = self.get_cert_info()
        if flag is not True:
            return CheckStatus.NOT_CHECK, get_common_msg(self.lang, "query.result.abnormal")
        self.logger.info("Get certificate expire time certificateInfoDictList:{}".format(cert_dictlist))
        # Compare with the preset certificates and check the expiry dates.
        return self.check_cert_expire(cert_dictlist)

    def check_device_alarm_is_mask_for_cert(self):
        """Return True if any alarm ID of DEFAULT_MASK_ALARM_LIST is found in ``show alarm_mask``."""
        cmd_template = ("exec_cli 'show alarm_mask|filterRow column=Alarm ID predict=equal_to value=%s' |"
                        "horizontal_parser")
        return any(self.dsl(cmd_template % alarm_id) for alarm_id in DEFAULT_MASK_ALARM_LIST)

    def check_device_sshalgorithm_is_change(self):
        """Return a truthy value if the SSH login algorithm appears to have been changed.

        Queries event 0x200F003100D3 (object_type=49) and, if that is empty, event
        0x200F0332006F (object_type=818); the first non-empty query result is returned.
        """
        return self.dsl("exec_cli 'show event object_type=49|filterRow column=ID "
                        "predict=equal_to value=0x200F003100D3' | horizontal_parser") or \
               self.dsl("exec_cli 'show event object_type=818|filterRow column=ID "
                        "predict=equal_to value=0x200F0332006F' | horizontal_parser")

    def check_devicetime_is_change(self):
        """Return True if the device date is earlier than 2024-04-08.

        The latest initial (issue) date among the preset certificates is 2024-04-07, so a
        correctly set clock should not read an earlier date. If the date cannot be read or
        parsed, this hint is skipped (False) instead of aborting the whole check.
        """
        # Raw string: the DSL regex contains "\d", which is an invalid escape in a normal literal.
        cur_sys_time = self.dsl(r"exec_cli 'show system general' | regex 'Time.*(\d{4}-\d{2}-\d{2})' | get_index(0)")
        try:
            sys_date = datetime.datetime.strptime(cur_sys_time, '%Y-%m-%d')
        except (TypeError, ValueError):
            self.logger.info("Unable to parse system time: {}".format(cur_sys_time))
            return False
        return sys_date < datetime.datetime.strptime('2024-04-08', '%Y-%m-%d')

    def check_device_ip_rule_switch_is_on(self):
        """Return True if the first ``show security_rule`` row reports ``Enabled`` as "Yes".

        An empty query result is treated as "not enabled" instead of raising IndexError.
        """
        rules = self.dsl("exec_cli 'show security_rule' | horizontal_parser")
        if not rules:
            return False
        return rules[0].get("Enabled") == "Yes"

    def get_expire_time_is_not_surport_type(self, cert_type_info):
        """Collect the certificates of one type from the type-less ``show certificate general`` output.

        Used for the models/versions matched by ``check_device_version_is_not_surport_type``.

        :param cert_type_info: one entry of DEFAULT_CERT_TYPE_CLI_TYPE.
        :return: ``(True, {CERT_TYPE: <type>, CERT_INFO: [...]})``. The certificate fingerprint
            and CA expiry are not taken from this output and are set to DEFAULT_VALUE.
        """
        self.logger.info(" get {} expire time".format(cert_type_info))
        cli_type = cert_type_info.get("cli_type")
        cert_info_dict = {CERT_TYPE: cert_type_info.get(CERT_TYPE), CERT_INFO: []}

        all_cert_info = self.dsl("exec_cli 'show certificate general' | horizontal_parser")
        for cli_ret_info in all_cert_info:
            if cli_ret_info.get(CERT_TYPE) != cli_type:
                continue
            ca_fingerprint = cli_ret_info.get(CA_FINGERPRINT, DEFAULT_VALUE)
            # These versions have no detail output, so the F3 fingerprint cannot be verified
            # and is skipped (see _UNTRUSTED_CA_FINGERPRINT).
            if ca_fingerprint == self._UNTRUSTED_CA_FINGERPRINT:
                continue
            cert_info_dict[CERT_INFO].append({
                CA_EXPIRE_TIME: DEFAULT_VALUE,
                CA_FINGERPRINT: ca_fingerprint,
                CERT_FINGERPRINT: DEFAULT_VALUE,
                EXPIRE_TIME: cli_ret_info.get(EXPIRE_TIME, ""),
            })
        return True, cert_info_dict

    def _get_base_version_and_model(self):
        """Return ``(base_version, product_model)`` of the connected device."""
        version_info = get_version_info(self.dsl)
        base_version = version_info.get("base_version").get("Current Version")
        return base_version, self.context.dev_node.model

    def check_device_version_is_not_surport_type(self):
        """Return True for 2800 V5 on V500R007C10 and 2800 V3 on V300R006C20."""
        base_version, product_model = self._get_base_version_and_model()
        return (product_model == "2800 V5" and base_version == "V500R007C10") \
            or (product_model == "2800 V3" and base_version == "V300R006C20")

    def get_cert_info_is_not_surport_type(self):
        """Build the certificate list for every type in DEFAULT_CERT_TYPE_CLI_TYPE.

        :return: ``(True, [cert_info_dict, ...])``.
        """
        cert_info_dict_list = []
        for cert_type in DEFAULT_CERT_TYPE_CLI_TYPE:
            flag, expire_time = self.get_expire_time_is_not_surport_type(cert_type)
            if not flag:
                continue
            if expire_time:
                cert_info_dict_list.append(expire_time)
        return True, cert_info_dict_list

    def get_cert_info(self):
        """Build the certificate list for the types the CLI offers that are in DEFAULT_CERT_TYPE.

        :return: ``(False, [])`` if the type list cannot be queried, else ``(True, [...])``.
        """
        cert_info_dict_list = []
        # Decide from the system version whether the "CA Expire Time" field can be queried.
        check_result = self.check_is_get_ca_expire_time()
        cli_ret = self.dsl("exec_cli 'show certificate general type='", return_if={ft.FindStr("Error:"): "not_check"})
        if not cli_ret or cli_ret == 'not_check':
            return False, cert_info_dict_list
        # Extract every "type=<value>" token from the output (presumably the list of
        # accepted values for the incomplete "type=" parameter).
        regx = re.compile(r"type=([^?\s]\S+)")
        need_check_cert_list = regx.findall(cli_ret)
        self.logger.info("need check certificates list:{}".format(need_check_cert_list))
        for cert_type in need_check_cert_list:
            # Only check types that have preset certificates.
            if cert_type not in DEFAULT_CERT_TYPE:
                continue
            if check_result:
                self.logger.info("Start get Expire Time,CA Expire Time")
                flag, ca_time = self.get_ca_expire_time_and_expire_time(cert_type)
                if not flag:
                    continue
                cert_info_dict_list.append(ca_time)
            else:
                self.logger.info("Start get lower system product:Expire Time")
                flag, expire_time = self.get_expire_time(cert_type)
                if not flag:
                    continue
                if expire_time:
                    cert_info_dict_list.append(expire_time)
        return True, cert_info_dict_list

    def check_is_get_ca_expire_time(self):
        """Return True if the device version supports querying "CA Expire Time".

        Dorado V3 / Dorado NAS: V300R002C00 or later. Other models: V3 from V300R006C30,
        V5 from V500R007C20. Versions are compared as strings.
        """
        base_version, product_model = self._get_base_version_and_model()
        if product_model in PRODUCT_DORADO_V3 or product_model == "Dorado NAS":
            return base_version >= "V300R002C00"
        if base_version.startswith("V3"):
            return base_version >= "V300R006C30"
        if base_version.startswith("V5"):
            return base_version >= "V500R007C20"
        return False

    def get_ca_expire_time_and_expire_time(self, cert_type):
        """Query one certificate type with ``detail=yes``, including CA expiry and fingerprints.

        :return: ``(False, "")`` if the query returns nothing, else
            ``(True, {CERT_TYPE: cert_type, CERT_INFO: [...]})``.
        """
        cert_info_dict = {CERT_TYPE: cert_type, CERT_INFO: []}
        cli_ret_dict_list = self.dsl(
            "exec_cli 'show certificate general type={} detail=yes' | vertical_parser".format(cert_type))
        if not cli_ret_dict_list:
            return False, ""
        for cli_ret_dict in cli_ret_dict_list:
            cert_info_dict[CERT_INFO].append({
                CA_EXPIRE_TIME: cli_ret_dict.get(CA_EXPIRE_TIME, ""),
                CA_FINGERPRINT: cli_ret_dict.get(CA_FINGERPRINT, ""),
                CERT_FINGERPRINT: cli_ret_dict.get(CERT_FINGERPRINT, ""),
                EXPIRE_TIME: cli_ret_dict.get(EXPIRE_TIME, ""),
            })
        return True, cert_info_dict

    def check_time_is_in_onedays(self, cert_detail_time, cert_time):
        """Return True if two ``%Y-%m-%d`` dates are equal or at most one day apart.

        One day of tolerance is allowed because time-zone adjustment can shift the date shown
        in CLI output by a day. Missing or unparseable values give False.
        """
        if cert_detail_time == cert_time:
            return True
        try:
            detail_date = datetime.datetime.strptime(cert_detail_time, "%Y-%m-%d")
            default_date = datetime.datetime.strptime(cert_time, "%Y-%m-%d")
        except (TypeError, ValueError):
            return False
        return abs(default_date - detail_date).total_seconds() <= ONE_DAY_SECONDS

    def check_cert_is_in_default(self, cert_detail, cert):
        """Return True if the queried certificate *cert_detail* matches preset entry *cert*.

        When the CA fingerprint matches but the CA expiry is unavailable (DEFAULT_VALUE), this
        writes the preset CA expiry into ``cert_detail`` so the report can show it.
        """
        ca_matches = cert_detail[CA_FINGERPRINT] == cert[CA_FINGERPRINT]
        # Same CA fingerprint and CA expiry within one day.
        if ca_matches and self.check_time_is_in_onedays(cert_detail[CA_EXPIRE_TIME], cert[CA_EXPIRE_TIME]):
            return True
        # Same CA fingerprint, CA expiry not available (V500R007C00~V500R007C10).
        if ca_matches and cert_detail[CA_EXPIRE_TIME] == DEFAULT_VALUE:
            # Without detail output the F3 fingerprint alone is not trusted (see class constant).
            if cert_detail[CA_FINGERPRINT] == self._UNTRUSTED_CA_FINGERPRINT:
                return False
            # Copy the preset CA expiry so the failure message can show it.
            cert_detail[CA_EXPIRE_TIME] = cert[CA_EXPIRE_TIME]
            return True
        # No CA: compare only the certificate expiry.
        return (cert_detail[CA_FINGERPRINT] == DEFAULT_VALUE
                and self.check_time_is_in_onedays(cert_detail[EXPIRE_TIME], cert[EXPIRE_TIME]))

    def do_check(self, cert_type, cert_detail, default_cert, check_result_set):
        """Check one queried certificate against the preset certificates of its type.

        :param cert_type: certificate type, used in the messages.
        :param cert_detail: one ``CERT_INFO`` entry; may be modified by ``check_cert_is_in_default``.
        :param default_cert: the preset entries for this type from DEFAULT_CERT_INFO.
        :param check_result_set: receives ``True`` when an expiring preset certificate is found.
        :return: the failure message for this certificate, or "" if nothing is reported.
        """
        if cert_detail[CERT_FINGERPRINT] == DEFAULT_VALUE:
            # No certificate fingerprint: match on CA fingerprint and dates (CA needs fingerprint + time).
            matched = any(self.check_cert_is_in_default(cert_detail, cert) for cert in default_cert)
        else:
            # Valid fingerprint: match on it directly.
            matched = any(cert_detail[CERT_FINGERPRINT] == cert[CERT_FINGERPRINT] for cert in default_cert)
        if not matched:
            return ""
        cert_expire_flag = cal_time(cert_detail[EXPIRE_TIME])
        cert_ca_expire_flag = cal_time(cert_detail[CA_EXPIRE_TIME])
        if cert_ca_expire_flag is True and cert_expire_flag is True:
            err_msg = str(self.get_msg("check.cert.expire.no.pass",
                                       cert_type, cert_detail[EXPIRE_TIME], cert_detail[CA_EXPIRE_TIME]))
        elif cert_ca_expire_flag is True:
            err_msg = str(self.get_msg("check.cert.expire.no.pass.only.ca.cert.expire",
                                       cert_type, cert_detail[CA_EXPIRE_TIME]))
        elif cert_expire_flag is True:
            err_msg = str(self.get_msg("check.cert.expire.no.pass.only.cert.expire",
                                       cert_type, cert_detail[EXPIRE_TIME]))
        else:
            return ""
        check_result_set.add(True)
        return err_msg

    def get_expire_cert_time(self, cert_dict, check_result_set):
        """Run ``do_check`` for every certificate in *cert_dict* against the matching preset set.

        :return: the concatenated failure messages ("" if none).
        """
        err_msg = ""
        for default_cert_set in DEFAULT_CERT_INFO:
            if cert_dict[CERT_TYPE] == default_cert_set.get(CERT_TYPE):
                for cert_detail in cert_dict[CERT_INFO]:
                    err_msg += self.do_check(cert_dict[CERT_TYPE], cert_detail, default_cert_set.get(CERT_INFO),
                                             check_result_set)
        return err_msg

    def get_expire_time(self, cert_type):
        """Query one certificate type without ``detail=yes`` (versions without CA expiry support).

        :return: ``(False, "")`` for an empty/``None`` result, ``(True, "")`` when the CLI
            rejects the type, else ``(True, {CERT_TYPE: cert_type, CERT_INFO: [...]})``.
        """
        self.logger.info("get {} expire time".format(cert_type))
        cert_info_dict = {CERT_TYPE: cert_type, CERT_INFO: []}

        check_ret = self.dsl(
            "exec_cli 'show certificate general type={}' | horizontal_parser".format(cert_type))

        # Check emptiness first: a None result used to crash the membership test below.
        if not check_ret:
            return False, ""
        if not check_is_need_deal_ret(check_ret):
            return True, ""

        for cli_ret_info in check_ret:
            ca_fingerprint = cli_ret_info.get(CA_FINGERPRINT, DEFAULT_VALUE)
            self.logger.info("get CA_FINGERPRINT{}".format(ca_fingerprint))
            cert_info_dict[CERT_INFO].append({
                CA_EXPIRE_TIME: DEFAULT_VALUE,
                CA_FINGERPRINT: ca_fingerprint,
                CERT_FINGERPRINT: DEFAULT_VALUE,
                EXPIRE_TIME: cli_ret_info.get(EXPIRE_TIME, ""),
            })

        return True, cert_info_dict

    def check_cert_expire(self, cert_info_dicts):
        """Turn the collected certificate info into the final check result.

        :return: ``(CheckStatus.NOT_PASS, message)`` if any expiring preset certificate is
            found (with extra hints appended), else ``(CheckStatus.PASS, "")``.
        """
        err_msg = ""
        check_result_set = set()
        for cert_dict in cert_info_dicts:
            err_msg += self.get_expire_cert_time(cert_dict, check_result_set)
        if True not in check_result_set:
            return CheckStatus.PASS, ""
        if self.check_devicetime_is_change():
            err_msg += str(self.get_msg("check.cert.expire.no.pass.time.is.change"))
        if self.check_device_ip_rule_switch_is_on():
            err_msg += str(self.get_msg("check.cert.expire.no.pass.ip.rule.switch.is.open"))
        if self.check_device_sshalgorithm_is_change():
            err_msg += str(self.get_msg("check.cert.expire.no.pass.ssh.algorithms.is.change"))
        return CheckStatus.NOT_PASS, err_msg