# -*- coding: UTF-8 -*-
import cliUtil
import common
import traceback
from common import UnCheckException
from common_cache import not_support_nas_domain
import common_utils
import common_cache

# Globals supplied by the host runtime: py_java_env and PY_LOGGER are not
# defined in this file (presumably injected by the script runner - TODO confirm).
PY_JAVA_ENV = py_java_env
LOGGER = common.getLogger(PY_LOGGER, __file__)
LANG = common.getLang(PY_JAVA_ENV)
# CLI output of every command run by this check. Helper functions append to
# it via `global ALL_CLI_RET`; execute() returns it together with the result.
ALL_CLI_RET = ''

# Commands listing the SAN / NAS HyperMetro domains.
SAN_DOMAIN_CMD = "show hyper_metro_domain general"
NAS_DOMAIN_CMD = "show fs_hyper_metro_domain general"
# Detail commands for a single domain; %s is replaced by the domain ID.
SAN_DOMAIN_DETAIL_CMD = "show hyper_metro_domain general domain_id=%s"
NAS_DOMAIN_DETAIL_CMD = "show fs_hyper_metro_domain general domain_id=%s"


def execute(cli):
    """
        HyperMetro quorum (arbitration) network configuration check.

        Rules (translated from the original specification):
        1. If the domain query in step 2 returns "No matching records",
           the check passes.
        2. If neither end of the HyperMetro has a quorum link, the result
           is "optimization suggested".
        3. If only one end has a quorum link configured, the check fails.
        4. If any Link Status value in step 3 or step 6 is not "Link Up",
           the result is "optimization suggested".
        5. If all effective IP network segments in step 3 or step 6 are
           identical, the result is "optimization suggested".

        :param cli: CLI object passed in by the caller; not referenced in
            this function body.
        :return: a (result, CLI output, message) tuple. The "no domain" and
            exception paths build it here; the remaining paths delegate to
            common_utils.get_result_bureau / common_utils.merge_result.
    """
    global ALL_CLI_RET
    error_msg = []
    no_check_msg = []
    try:
        # Serial number of the local device.
        locaDevSn = PY_JAVA_ENV.get("devInfo").getDeviceSerialNumber()
        # Query the SAN and NAS HyperMetro domains on the local device.
        ALL_CLI_RET += "ON LOCAL DEVICE(SN:%s)" % locaDevSn
        domain_dict = getDomainInfo(locaDevSn, SAN_DOMAIN_CMD)
        nas_domain_dict = getDomainInfo(locaDevSn, NAS_DOMAIN_CMD)
        # The merged dict is only used for the emptiness test below, so a
        # SAN domain and a NAS domain sharing an ID does not matter here.
        total_domain_dict = domain_dict.copy()
        total_domain_dict.update(nas_domain_dict)
        LOGGER.logInfo("total_domain_dict: {}".format(total_domain_dict))
        if not total_domain_dict:
            return True, ALL_CLI_RET, ''
        # Check whether the remote devices have been added (presumably to
        # the inspection tool - see common.checkAddedRemoteDevSn).
        flag, addedSnList, errRemoteMsg = common.checkAddedRemoteDevSn(
            PY_JAVA_ENV, LANG)
        LOGGER.logInfo("addedSnList: %s" % addedSnList)
        if flag is not True:
            return common_utils.get_result_bureau(PY_JAVA_ENV, ALL_CLI_RET, errRemoteMsg)
        # Attach the remote device SN ("remote_sn") to every domain.
        san_remote_dev = getRemoteDeviceInfo(locaDevSn, domain_dict)
        nas_remote_dev = getRemoteDeviceInfo(locaDevSn, nas_domain_dict)
        checked_device = []
        # Check the SAN domains.
        check_domain_quorum_config(
            san_remote_dev, locaDevSn, addedSnList, error_msg, no_check_msg,
            SAN_DOMAIN_DETAIL_CMD, checked_device)
        ALL_CLI_RET += "\n\n----Check SAN END----\n\nON LOCAL DEVICE(SN:{})".\
            format(locaDevSn)
        # Check the NAS domains.
        check_domain_quorum_config(
            nas_remote_dev, locaDevSn, addedSnList, error_msg, no_check_msg,
            NAS_DOMAIN_DETAIL_CMD, checked_device)
        ALL_CLI_RET += "\n\n----Check NAS END----"
        return common_utils.merge_result(
            list(), error_msg, no_check_msg, [ALL_CLI_RET], PY_JAVA_ENV)
    except UnCheckException as e:
        LOGGER.logError(str(traceback.format_exc()))
        return cliUtil.RESULT_NOCHECK, e.cliRet, e.errorMsg
    except Exception:
        LOGGER.logError(str(traceback.format_exc()))
        return (cliUtil.RESULT_NOCHECK, ALL_CLI_RET, common.getMsg(LANG, "query.result.abnormal"))


def check_domain_quorum_config(
        domain_dict, local_sn, added_sn_list,
        error_msg, no_check_msg, domain_cmd, checked_device):
    """
    Check the quorum links of every domain on the local and remote device.

    The local device is always checked. The remote device is checked only
    when its SN is in added_sn_list; otherwise a "not checked" message is
    recorded instead.

    :param domain_dict: domain ID -> domain info dict (may hold "remote_sn")
    :param local_sn: serial number of the local device
    :param added_sn_list: serial numbers of the remote devices that were added
    :param error_msg: list that receives check-failure messages (mutated)
    :param no_check_msg: list that receives "not checked" messages (mutated)
    :param domain_cmd: CLI template used to query one domain's details
    :param checked_device: not used by this function
    """
    global ALL_CLI_RET
    for domain_id in domain_dict:
        # Local end.
        local_problem = checkDomainQuoServerLinkIp(
            local_sn, domain_id, domain_dict, domain_cmd)
        if local_problem:
            error_msg.append(local_problem)

        peer_sn = domain_dict.get(domain_id, {}).get("remote_sn")
        if peer_sn in added_sn_list:
            # Remote end: label its output, then run the same check there.
            ALL_CLI_RET = common.joinLines(
                ALL_CLI_RET, "\nON REMOTE DEVICE(SN:%s):" % peer_sn)
            peer_problem = checkDomainQuoServerLinkIp(
                peer_sn, domain_id, domain_dict, domain_cmd)
            if peer_problem:
                error_msg.append(peer_problem)
        else:
            skipped = common.getMsg(
                LANG, "not.add.remote.device.again", peer_sn)
            if skipped:
                no_check_msg.append(skipped)


def getRemoteDeviceInfo(devSn, domainDict):
    """
    Resolve the remote device serial number of every HyperMetro domain.

    Runs "show remote_device general" on the device. For each remote device
    whose ID equals a domain's "remoteDeviceId", that device's SN is stored
    under the domain's "remote_sn" key. domainDict is updated in place.

    :param devSn: serial number of the device to query
    :param domainDict: domain ID -> info dict containing "remoteDeviceId"
    :return: the same domainDict object
    :raises UnCheckException: when the remote device query fails
    """
    global ALL_CLI_RET
    flag, cliRet, errMsg = common.getObjFromFile(
        py_java_env, LOGGER, devSn, "show remote_device general", LANG)
    # Do not duplicate identical output in the collected CLI log.
    if cliRet not in ALL_CLI_RET:
        ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag is not True:
        LOGGER.logInfo("Failed to get information about remote device. errMsg:%s" % errMsg)
        item_name = {"zh": u"远端设备", "en": "remote device"}.get(LANG)
        raise UnCheckException(
            common.getMsg(LANG, "cannot.get.info", item_name), ALL_CLI_RET)

    for device_row in cliUtil.getHorizontalCliRet(cliRet):
        device_id = device_row.get("ID", '')
        device_sn = device_row.get("SN", '')
        for domain_info in domainDict.values():
            if device_id == domain_info["remoteDeviceId"]:
                domain_info["remote_sn"] = device_sn

    return domainDict


def checkDomainQuoServerLinkIp(devSn, domainId, domainInfo, detail_cmd):
    """
    Check the quorum server links of one HyperMetro domain on one device.

    The domain is skipped (empty message) when it has no HyperMetro pair
    or no primary quorum server. Otherwise the links of the primary quorum
    server are checked, and then those of the standby quorum server if
    one is configured.

    :param devSn: serial number of the device to query
    :param domainId: HyperMetro domain ID
    :param domainInfo: domain summary dict; not used by this function
    :param detail_cmd: CLI template ("...domain_id=%s") for domain details
    :return: concatenated error/suggestion messages, '' when nothing to report
    """
    errMsg = ''
    # Skip domains without any HyperMetro pair.
    # NOTE(review): getHyperMetroPairIdList always runs the SAN command
    # "show hyper_metro_pair general", also for NAS domains - confirm NAS
    # pairs are listed there, otherwise NAS domains are never checked.
    hyperPairList = getHyperMetroPairIdList(devSn, domainId)
    if not hyperPairList:
        return errMsg

    domainDetailInfo = getDomainDetailInfo(devSn, domainId, detail_cmd)
    # getDomainDetailInfo returns {} when the command is not supported or no
    # record is parsed; treat that as "no quorum server" instead of failing
    # with AttributeError on None.
    domainDetail = domainDetailInfo.get(domainId) or {}

    # Primary quorum server.
    quoServerId = domainDetail.get("quoServerId", '')
    LOGGER.logInfo("quoServerId : %s" % quoServerId)
    if not quoServerId or quoServerId == '--':
        return errMsg

    LOGGER.logInfo("domainDetailInfo : %s" % domainDetailInfo)
    errMsg += checkQuoServerIp(devSn, quoServerId)

    # Standby quorum server (optional).
    standbyQuoServerId = domainDetail.get("quoStandbyServerId", '')
    LOGGER.logInfo("standbyQuoServerId : %s" % standbyQuoServerId)
    if not standbyQuoServerId or standbyQuoServerId == '--':
        return errMsg

    errMsg += checkQuoServerIp(devSn, standbyQuoServerId)
    return errMsg


def checkQuoServerIp(devSn, quoServId):
    """
    Check the links between one device and one quorum server.

    The following are reported:
    - a quorum server without any link (failure message only);
    - every link whose status is not "Link Up" (message plus suggestion);
    - when at least two links are up, whether their ports share one IP
      segment (delegated to checkEffectiveIpIsSame).

    :param devSn: serial number of the device to query
    :param quoServId: quorum server ID
    :return: concatenated messages, '' when nothing to report
    """
    links_by_port = getQuorumServerLinkInfo(devSn, quoServId)
    # No quorum link at all counts as a failed check.
    if not links_by_port:
        return common.getMsg(LANG, "hypermetro.arbnetwork.config.not.exit.quorum.server.link",(devSn, quoServId))

    LOGGER.logInfo("localPortInfoDict : %s" % links_by_port)
    messages = []
    up_port_ids = []
    for port_id, link in links_by_port.items():
        if link.get("linkStatus") == 'Link Up':
            up_port_ids.append(port_id)
            continue
        link_id = link.get("linkId")
        messages.append(common.getMsg(
            LANG, "hypermetro.arbnetwork.config.link.not.up",
            (quoServId, link_id, devSn)))
        messages.append(common.getMsg(
            LANG, "hypermetro.arbnetwork.config.link.not.up.sugg",
            link_id))

    # The same-segment comparison needs at least two healthy links.
    if len(up_port_ids) >= 2:
        port_ip_info = getLocalPortIpInfo(up_port_ids, devSn)
        LOGGER.logInfo("locaPortIpSmInfoDict : %s" % port_ip_info)
        messages.append(
            checkEffectiveIpIsSame(devSn, port_ip_info, quoServId))
    return "".join(messages)


def getDomainInfo(devSn, cmd):
    """
    List the HyperMetro domains on a device that need to be checked.

    A domain is skipped when filter_no_pair_domain reports no pair. For NAS
    domains, it is also skipped when common_utils.is_hypermetro_work_mode
    rejects it. The CLI output is appended to ALL_CLI_RET.

    :param devSn: serial number of the device to query
    :param cmd: SAN_DOMAIN_CMD or NAS_DOMAIN_CMD
    :return: dict domain ID -> {"remoteDeviceId": ..., "quoServerId": ...};
        empty when not_support_nas_domain() matches the CLI output
    :raises UnCheckException: when the query fails
    """
    global ALL_CLI_RET
    # Loop-invariant: which kind of domain list this is.
    is_nas = cmd == NAS_DOMAIN_CMD
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn,
                                                 cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)

    domains = {}
    if not_support_nas_domain(cliRet):
        return domains
    if flag is not True:
        LOGGER.logInfo("Failed to get information about HyperMetro domain. errMsg:%s" % errMsg)
        raise UnCheckException(errMsg, ALL_CLI_RET)

    for row in cliUtil.getHorizontalCliRet(cliRet):
        domain_id = row.get("ID", '')
        if is_nas and not common_utils.is_hypermetro_work_mode(row):
            continue
        if not filter_no_pair_domain(is_nas, domain_id, devSn):
            LOGGER.logInfo("domain {} has no pair. sn:{}".format(domain_id, devSn))
            continue

        # Column name differs between outputs; "--" means "not configured".
        quorum_id = row.get("Quorum Server ID", "") or row.get("Quorum ID", "")
        domains[domain_id] = {
            "remoteDeviceId": row.get("Remote Device ID", ""),
            "quoServerId": "" if quorum_id == "--" else quorum_id,
        }

    return domains


def filter_no_pair_domain(is_nas, domain_id, sn):
    """
    Tell whether a HyperMetro domain has at least one pair.

    The caller drops domains for which this returns False.

    :param is_nas: True for a NAS domain, False for a SAN domain
    :param domain_id: domain ID
    :param sn: serial number of the device to query
    :return: True when the domain has at least one pair
    """
    if is_nas:
        query_pairs = get_nas_hyper_metro_pair_id_list
    else:
        query_pairs = common_utils.get_hyper_metro_pair_id_list
    pair_ids, _pair_infos = query_pairs(
        sn, domain_id, PY_JAVA_ENV, LOGGER, LANG)
    return bool(pair_ids)


def get_nas_hyper_metro_pair_id_list(dev_sn, domain_id, env, logger, lang):
    """
    Collect the NAS HyperMetro pairs that belong to one domain.

    The pairs come from common_cache.get_nas_pair_from_cache.

    :param dev_sn: serial number of the device to query
    :param domain_id: domain ID to filter on ("Domain ID" column)
    :param env: host environment object
    :param logger: logger
    :param lang: message language
    :return: (pair ID list, list of the matching pair info dicts)
    :raises UnCheckException: when the pair query failed. Like the other
        queries in this module, it carries the CLI output collected so far,
        so the "not checked" result still shows what was run.
    """
    flag, ret, msg, pair_list = common_cache.get_nas_pair_from_cache(
        env, logger, dev_sn, lang)
    if flag is not True:
        logger.logInfo("Failed to get information about HyperMetro Pair")
        raise UnCheckException(msg, ALL_CLI_RET)

    pair_info_list = [pair_info for pair_info in pair_list
                      if pair_info.get("Domain ID") == domain_id]
    pair_id_list = [pair_info.get("ID") for pair_info in pair_info_list]
    return pair_id_list, pair_info_list


def getDomainDetailInfo(devSn, domainId, detail_cmd):
    """
    Read the quorum configuration of one HyperMetro domain.

    :param devSn: serial number of the device to query
    :param domainId: HyperMetro domain ID
    :param detail_cmd: CLI template with one %s for the domain ID
    :return: {domainId: {"quoServerName", "quoServerId", "quoModel",
        "quoStandbyServerName", "quoStandbyServerId"}}, or {} when
        not_support_nas_domain() matches the output or no record is parsed.
        "--" is turned into "" for both names and the primary ID only.
    :raises UnCheckException: when the query fails
    """
    global ALL_CLI_RET

    def _dash_to_empty(value):
        # The CLI prints "--" for an unset field.
        return "" if value == "--" else value

    flag, cliRet, errMsg = common.getObjFromFile(
        py_java_env, LOGGER, devSn, detail_cmd % domainId, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)

    details = {}
    if not_support_nas_domain(cliRet):
        return details
    if flag is not True:
        LOGGER.logInfo("Failed to get information about HyperMetro domain. errMsg:%s" % errMsg)
        raise UnCheckException(errMsg, ALL_CLI_RET)

    for record in cliUtil.getVerticalCliRet(cliRet):
        # SAN and NAS outputs name the quorum columns differently.
        main_name = record.get("Quorum Server Name",
                               record.get("Quorum Name", ''))
        main_id = record.get("Quorum Server ID",
                             record.get("Quorum ID", ''))
        standby_name = record.get("Standby Quorum Server Name",
                                  record.get("Standby Quorum Name", ''))
        details[domainId] = {
            "quoServerName": _dash_to_empty(main_name),
            "quoServerId": _dash_to_empty(main_id),
            "quoModel": record.get("Quorum Mode", ""),
            "quoStandbyServerName": _dash_to_empty(standby_name),
            "quoStandbyServerId": record.get("Standby Quorum ID", ''),
        }
    return details


def getHyperMetroPairIdList(devSn, domainId):
    """
    Return the IDs of the HyperMetro pairs that belong to one domain.

    Runs "show hyper_metro_pair general" filtered by Domain ID on the device
    and appends the output to ALL_CLI_RET.

    :param devSn: serial number of the device to query
    :param domainId: HyperMetro domain ID to filter on
    :return: list of pair IDs; rows with an empty ID are skipped
    :raises UnCheckException: when the query fails
    """
    global ALL_CLI_RET
    # Raw strings: the "\s" sequences are passed literally to the CLI (they
    # appear to stand for spaces in column names). In a plain literal they
    # are invalid Python escape sequences.
    cmd = r"show hyper_metro_pair general |filterRow column=Domain\sID " \
          r"predict=equal_to value=%s|filterColumn include " \
          r"columnList=Local\sID,Remote\sID,ID,Domain\sID" % domainId
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn, cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag is not True:
        LOGGER.logInfo("Failed to get information about HyperMetro Pair List. errMsg:%s" % errMsg)
        raise UnCheckException(
            common.getMsg(LANG, "cannot.get.info", {"zh": u"双活pair", "en": "HyperMetro Pair"}.get(LANG)), ALL_CLI_RET)

    return [pair_info.get("ID", '')
            for pair_info in cliUtil.getHorizontalCliRet(cliRet)
            if pair_info.get("ID", '')]


def getQuorumServerLinkInfo(devSn, serverId):
    """
    Read the links of one quorum server, keyed by local port.

    :param devSn: serial number of the device to query
    :param serverId: quorum server ID
    :return: dict "Local Port" value -> {"linkStatus": ..., "linkId": ...}.
        If two links share a local port, the later row wins.
    :raises UnCheckException: when the query fails
    """
    global ALL_CLI_RET
    flag, cliRet, errMsg = common.getObjFromFile(
        py_java_env, LOGGER, devSn,
        "show quorum_server_link general server_id=%s" % serverId, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag != True:
        raise UnCheckException(errMsg, ALL_CLI_RET)

    return {
        link.get("Local Port", ""): {
            "linkStatus": link.get("Link Status", ""),
            "linkId": link.get("Link ID", ""),
        }
        for link in cliUtil.getHorizontalCliRet(cliRet)
    }


def getLocalPortIpInfo(local_port_id_list, dev_sn):
    """
    Look up the IPv4/IPv6 addresses and masks of the given logical ports.

    :param local_port_id_list: logical port names to look up
    :param dev_sn: serial number of the device to query
    :return: dict logical port name -> {"ipV4", "subnetMask", "ipV6",
        "iPv6PrefixLength"}; ports not found in the output are absent
    :raises UnCheckException: when the logical port query fails
    """
    global ALL_CLI_RET
    # Raw strings: "\s" is passed literally to the CLI; in a plain literal
    # it is an invalid Python escape sequence.
    cmd = r"show logical_port general|filterColumn include columnList=" \
          r"Logical\sPort\sName,IPv4\sAddress,IPv4\sMask,IPv6\s" \
          r"Address,IPv6\sMask"
    flag, cliRet, msg = common.getObjFromFile(
        py_java_env, LOGGER, dev_sn, cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    # A failed query must not look like "no ports found": an empty result
    # would make the same-segment check report all ports as identical.
    # Report the item as not checked, like the other queries in this module.
    if flag is not True:
        LOGGER.logInfo("Failed to get information about logical port. errMsg:%s" % msg)
        raise UnCheckException(msg, ALL_CLI_RET)

    wanted_ports = set(local_port_id_list)
    port_ip_info_dict = {}
    for ret_info in cliUtil.getHorizontalNostandardCliRet(cliRet):
        logical_port = ret_info.get("Logical Port Name")
        if logical_port not in wanted_ports:
            continue
        port_ip_info_dict[logical_port] = {
            "ipV4": ret_info.get("IPv4 Address", ""),
            "subnetMask": ret_info.get("IPv4 Mask", ""),
            "ipV6": ret_info.get("IPv6 Address", ""),
            "iPv6PrefixLength": ret_info.get("IPv6 Mask", "")
        }
    return port_ip_info_dict


def checkEffectiveIpIsSame(locaDevSn, locaPortIpInfoDict, quoServId):
    """
    Report when all quorum link ports of a device share one IP segment.

    Two ports are on different segments when either of these holds:
    - both have an IPv4 address, and the addresses differ after masking
      both with the bitwise AND of the two IPv4 masks;
    - both have an IPv6 address, and the addresses differ after applying
      each port's own prefix length.

    If any pair of ports is on different segments, the check passes.
    Ports whose address is empty or "--" are not compared for that family.

    :param locaDevSn: device serial number (used in the message)
    :param locaPortIpInfoDict: logical port name -> {"ipV4", "subnetMask",
        "ipV6", "iPv6PrefixLength"}
    :param quoServId: quorum server ID (used in the message)
    :return: '' when the check passes or fewer than two ports are given;
        otherwise the error message followed by the suggestion
    """
    port_ids = list(locaPortIpInfoDict)
    # Nothing to compare with a single port (or none). This matches the
    # caller, which does not run this check for fewer than two links.
    if len(port_ids) < 2:
        return ""

    # The relation is symmetric, so each unordered pair is compared once.
    for index, port_id in enumerate(port_ids):
        for other_id in port_ids[index + 1:]:
            if _is_different_segment(locaPortIpInfoDict[port_id],
                                     locaPortIpInfoDict[other_id]):
                LOGGER.logInfo("ports %s and %s are on different IP segments"
                               % (port_id, other_id))
                return ""

    LOGGER.logInfo("all quorum link ports share one IP segment: %s" % port_ids)
    errMsg = ""
    errMsg += common.getMsg(LANG, "hypermetro.arbnetwork.config.IP.network.error",
                            (quoServId, ", ".join(locaPortIpInfoDict.keys()), locaDevSn))
    errMsg += common.getMsg(LANG, "hypermetro.arbnetwork.config.IP.network.error.sugg")
    return errMsg


def _is_usable_ip(ip):
    """Return True when the CLI reported a real address (not empty / "--")."""
    return bool(ip) and ip != "--"


def _is_different_segment(port_info, other_info):
    """Return True when two ports are on different IPv4 or IPv6 segments."""
    ipv4 = port_info.get("ipV4", "")
    other_ipv4 = other_info.get("ipV4", "")
    if _is_usable_ip(ipv4) and _is_usable_ip(other_ipv4):
        # Mask both addresses with the intersection of the two masks.
        common_mask = bitWiseAndCalculation(port_info.get("subnetMask", ""),
                                            other_info.get("subnetMask", ""))
        net = bitWiseAndCalculation(common_mask, ipv4)
        other_net = bitWiseAndCalculation(common_mask, other_ipv4)
        LOGGER.logInfo("localIpReal : %s, localIpOtherReal:%s" % (net, other_net))
        if net != other_net:
            return True

    ipv6 = port_info.get("ipV6", "")
    other_ipv6 = other_info.get("ipV6", "")
    if _is_usable_ip(ipv6) and _is_usable_ip(other_ipv6):
        net6 = getBinIpStr(ipv6, port_info.get("iPv6PrefixLength", ""))
        other_net6 = getBinIpStr(other_ipv6, other_info.get("iPv6PrefixLength", ""))
        LOGGER.logInfo("localIpV6 : %s, localIpV6Other:%s" % (ipv6, other_ipv6))
        LOGGER.logInfo("binIpV6Net : %s, binIpV6NetOther:%s" % (net6, other_net6))
        if net6 != other_net6:
            return True

    return False


def getBinIpStr(ipStr, preLen):
    """
    Return the network part of an IPv6 address as a string of bits.

    The address is expanded with convertIpV6, and every colon-separated
    group becomes 16 bits via hex2bin. Every bit after the first preLen
    bits is then cleared to "0".

    :param ipStr: IPv6 address, possibly "::"-compressed
    :param preLen: prefix length (int or numeric string)
    :return: bit string with the same length as the expanded address bits
    """
    expanded = convertIpV6(ipStr)
    LOGGER.logInfo("convertIpV6[%s]to[%s]" % (ipStr, expanded))
    bits = "".join(hex2bin(group) for group in expanded.split(":"))

    prefix = int(preLen)
    mask = "1" * prefix + "0" * (128 - prefix)
    # Keep a bit only where the mask holds "1". mask[pos] raises IndexError
    # when the address has more bits than the mask.
    return "".join("0" if mask[pos] == "0" else bits[pos]
                   for pos in range(len(bits)))

def convertIpV6(ipStr):
    """
    Expand a "::"-compressed IPv6 address to eight colon-separated groups.

    Example: "1000:F::D" -> "1000:F:0:0:0:0:0:D". Groups keep their original
    spelling (no zero padding or case change). An address without "::" is
    returned unchanged.

    The "::" is replaced by exactly as many "0" groups as are missing. Edge
    forms such as "::2:3:4:5:6:7:8" and "1:2:3:4:5:6:7::" therefore also
    yield eight groups (128 bits once converted by hex2bin).

    :param ipStr: IPv6 address text
    :return: the address with "::" expanded
    """
    if "::" not in ipStr:
        return ipStr

    head, _, tail = ipStr.partition("::")
    head_groups = head.split(":") if head else []
    tail_groups = tail.split(":") if tail else []
    missing = 8 - len(head_groups) - len(tail_groups)
    # A negative count (malformed input) yields no filler groups.
    return ":".join(head_groups + ["0"] * missing + tail_groups)

def bitWiseAndCalculation(strA, strB):
    """
    AND two dotted-decimal values (IPv4 addresses or masks) octet by octet.

    Only as many octets as the shorter input has are combined.
    Example: masks "255.255.0.0" and "255.255.255.0" give "255.255.0.0".
    Masking "200.46.x.x" with "255.255.0.0" leaves the effective segment
    "200.46.0.0", which is then compared between ports.

    :param strA: dotted-decimal string
    :param strB: dotted-decimal string
    :return: dotted-decimal result, or "" when an octet is not an integer
        (the error is logged)
    """
    try:
        octet_pairs = zip(strA.split("."), strB.split("."))
        return ".".join(str(int(left) & int(right))
                        for left, right in octet_pairs)
    except Exception:
        LOGGER.logError(str(traceback.format_exc()))
        return ""


def hex2dec(string_num):
    """
    Convert a hexadecimal string to its decimal representation.

    :param string_num: hexadecimal digits, case-insensitive (e.g. "ff")
    :return: decimal digits as a string (e.g. "255"), or '' when the input
        is not a valid hexadecimal string
    """
    try:
        # int(..., 16) accepts both upper- and lower-case digits.
        return str(int(string_num, 16))
    except (TypeError, ValueError):
        # Narrow handler: a bare except would also swallow
        # KeyboardInterrupt / SystemExit.
        return ''


def dec2bin(string_num):
    """
    Convert a decimal number (string or int) to binary digits.

    :param string_num: decimal value, e.g. "10"
    :return: binary digits without prefix, e.g. "1010". Returns '' for 0
        (callers left-pad to a fixed width), for negative numbers and for
        input that is not an integer.
    """
    try:
        num = int(string_num)
    except (TypeError, ValueError, OverflowError):
        return ''
    # Negative input is rejected explicitly: repeated divmod by 2 never
    # reaches 0 for it (divmod(-1, 2) == (-1, 1)), so a digit loop would
    # not terminate.
    if num <= 0:
        return ''
    return bin(num)[2:]


def hex2bin(string_num):
    """
    Convert one hexadecimal group (e.g. of an IPv6 address) to binary.

    The result is left-padded with "0" to at least 16 characters. An empty
    or missing group yields sixteen zeros. Longer results are returned
    unpadded.

    :param string_num: hexadecimal digits
    :return: binary string of at least 16 characters
    """
    bits = dec2bin(hex2dec(string_num.upper())) if string_num else ''
    return bits.rjust(16, "0")
