# -*- coding: UTF-8 -*-
import cliUtil
import common
import traceback
from common import UnCheckException
import common_utils

# Module-level context for this check.
# NOTE(review): `py_java_env` and `PY_LOGGER` are neither defined nor imported in
# this file; they are presumably injected into the module namespace by the
# hosting framework before this script runs — confirm.
PY_JAVA_ENV = py_java_env
LOGGER = common.getLogger(PY_LOGGER, __file__)
LANG = common.getLang(PY_JAVA_ENV)
# Accumulates the output of every CLI query made during the check; it is
# returned as the second element of execute()'s result tuple.
ALL_CLI_RET = ''


def execute(cli):
    """
    HyperMetro quorum network configuration check.

    Rules (translated from the original specification):
        1. If the query in step 2 returns "No matching records", the check passes.
        2. If neither end of the HyperMetro relationship has a quorum link,
           the result is "suggest optimization".
        3. If only one end has a quorum link configured, the result is "not passed".
        4. If any Link Status value in step 3 or step 6 is not "Link Up",
           the result is "suggest optimization".
        5. If all effective IP network segments in step 3 or step 6 are the same,
           the result is "suggest optimization".

    NOTE(review): every finding collected below is returned as
    cliUtil.RESULT_WARNING; rules 2 and 3 are not distinguished in this
    function — confirm against the specification.

    :param cli: CLI handle supplied by the framework (not referenced here;
                the helpers in this file read CLI output via common.getObjFromFile).
    :return: (result, ALL_CLI_RET, message) tuple.
    """
    global ALL_CLI_RET
    errMsg = ""
    noCheckMsg = ""
    try:
        # SN of the local device.
        locaDevSn = PY_JAVA_ENV.get("devInfo").getDeviceSerialNumber()

        # HyperMetro domains on the local device; without any there is nothing to check.
        domainInfoDict = getDomainInfo(locaDevSn)
        LOGGER.logInfo("domainInfoDict: %s" % domainInfoDict)
        if not domainInfoDict:
            return True, ALL_CLI_RET, ''

        # SNs of the remote devices that were added to this inspection.
        __, addedSnList, errRemoteMsg = common.checkAddedRemoteDevSn(PY_JAVA_ENV, LANG)
        LOGGER.logInfo("addedSnList: %s" % addedSnList)

        # domainId -> remote device SN.
        remoteDeviceInfo = getRemoteDeviceInfo(locaDevSn, domainInfoDict)
        LOGGER.logInfo("remoteDeviceInfo: %s" % remoteDeviceInfo)
        for domainId in domainInfoDict:
            domain_pair_list, pair_info_list = common_utils.get_hyper_metro_pair_id_list(
                locaDevSn, domainId, PY_JAVA_ENV, LOGGER, LANG)
            if not domain_pair_list:
                LOGGER.logInfo("domain {} do not have pair. sn:{}".format(domainId, locaDevSn))
                continue

            # Check the local end of the domain.
            tmp_msg = checkDomainQuoServerLinkIp(locaDevSn, domainId,
                                                 domainInfoDict)
            if tmp_msg not in errMsg:
                errMsg += tmp_msg
            LOGGER.logInfo("checkDomainQuoServerLinkIp errMsg: %s" % errMsg)

            remoteSn = remoteDeviceInfo.get(domainId, "")
            if addedSnList is None:
                # The remote device list is unavailable. Report the reason
                # only once, even when several domains reach this branch.
                if errRemoteMsg not in noCheckMsg:
                    noCheckMsg += errRemoteMsg
                continue

            if remoteSn not in addedSnList:
                # Several domains may share one remote device; report it once.
                tmp_msg = common.getMsg(LANG, "not.add.remote.device.again", remoteSn)
                if tmp_msg not in noCheckMsg:
                    noCheckMsg += tmp_msg
                continue

            ALL_CLI_RET = common.joinLines(ALL_CLI_RET, "ON REMOTE DEVICE(SN:%s):" % remoteSn)
            remote_domain_dict = getDomainInfo(remoteSn)
            # Check the remote end of the domain.
            tmp_msg = checkDomainQuoServerLinkIp(remoteSn, domainId,
                                                 remote_domain_dict)
            if tmp_msg not in errMsg:
                errMsg += tmp_msg

        if errMsg:
            return cliUtil.RESULT_WARNING, ALL_CLI_RET, errMsg + noCheckMsg

        if noCheckMsg:
            return common_utils.get_result_bureau(PY_JAVA_ENV, ALL_CLI_RET, noCheckMsg)

        return (True, ALL_CLI_RET, '')
    except UnCheckException as e:
        LOGGER.logError(str(traceback.format_exc()))
        return cliUtil.RESULT_NOCHECK, e.cliRet, e.errorMsg
    except Exception:
        LOGGER.logError(str(traceback.format_exc()))
        return (cliUtil.RESULT_NOCHECK, ALL_CLI_RET, common.getMsg(LANG, "query.result.abnormal"))


def getRemoteDeviceInfo(devSn, domainDict):
    """
    Map every HyperMetro domain to the SN of its remote device.

    :param devSn: SN of the device whose "show remote_device general" output is read
    :param domainDict: {domainId: {"remoteDeviceId": ..., ...}} as built by getDomainInfo
    :return: {domainId: remoteDeviceSn} for each domain whose remote device ID
             is listed; when several rows match, the last one wins
    :raises UnCheckException: when the remote device query fails
    """
    global ALL_CLI_RET
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn,
                                                 "show remote_device general", LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag is not True:
        LOGGER.logInfo("Failed to get information about remote device. errMsg:%s" % errMsg)
        objName = {"zh": u"远端设备", "en": "remote device"}.get(LANG)
        raise UnCheckException(common.getMsg(LANG, "cannot.get.info", objName), ALL_CLI_RET)

    snByDomain = {}
    for row in cliUtil.getHorizontalCliRet(cliRet):
        rowId = row.get("ID", '')
        rowSn = row.get("SN", '')
        # Domains whose configured remote device is the one on this row.
        matchedDomains = [d for d in domainDict if rowId == domainDict[d]["remoteDeviceId"]]
        snByDomain.update(dict.fromkeys(matchedDomains, rowSn))
    return snByDomain


def checkDomainQuoServerLinkIp(devSn, domainId, domainInfo):
    """
    Check the quorum links of one HyperMetro domain on one device.

    Nothing is checked when the domain has no HyperMetro pair or no quorum
    server. Otherwise the links of the primary quorum server are checked and,
    if a standby quorum server is configured, those of the standby as well.

    :param devSn: SN of the device to query
    :param domainId: HyperMetro domain ID
    :param domainInfo: {domainId: {"quoServerId": ..., ...}} as built by getDomainInfo
    :return: concatenated error messages ('' when nothing is wrong)
    """
    unsetIds = ('', '--')

    # No HyperMetro pair in the domain: nothing to check.
    if not getHyperMetroPairIdList(devSn, domainId):
        return ''

    # No primary quorum server: nothing to check.
    quoServerId = domainInfo.get(domainId).get("quoServerId")
    LOGGER.logInfo("quoServerId : %s" % quoServerId)
    if quoServerId in unsetIds:
        return ''

    domainDetailInfo = getDomainDetailInfo(devSn, domainId)
    LOGGER.logInfo("domainDetailInfo : %s" % domainDetailInfo)
    primaryMsg = checkQuoServerIp(devSn, quoServerId)

    # The standby quorum server is optional.
    standbyQuoServerId = domainDetailInfo.get(domainId).get("quoStandbyServerId", '')
    LOGGER.logInfo("standbyQuoServerId : %s" % standbyQuoServerId)
    if standbyQuoServerId in unsetIds:
        return primaryMsg
    return primaryMsg + checkQuoServerIp(devSn, standbyQuoServerId)


def checkQuoServerIp(devSn, quoServId):
    """
    Check the links of one quorum server on one device.

    - No link at all: only the "no quorum server link" message is returned.
    - Each link whose status is not 'Link Up' adds a message and a suggestion.
    - When at least two links are up, the effective IP networks of their
      local ports are compared with checkEffectiveIpIsSame.

    :param devSn: SN of the device to query
    :param quoServId: quorum server ID
    :return: concatenated error messages ('' when nothing is wrong)
    """
    linkInfoByPort = getQuorumServerLinkInfo(devSn, quoServId)
    # No quorum link under this server: the check fails.
    if not linkInfoByPort:
        return common.getMsg(LANG, "hypermetro.arbnetwork.config.not.exit.quorum.server.link",(devSn, quoServId))

    LOGGER.logInfo("localPortInfoDict : %s" % linkInfoByPort)
    msgParts = []
    upPortIds = []
    for portId in linkInfoByPort:
        linkInfo = linkInfoByPort[portId]
        if linkInfo.get("linkStatus") == 'Link Up':
            upPortIds.append(portId)
            continue
        linkId = linkInfo.get("linkId")
        msgParts.append(common.getMsg(LANG, "hypermetro.arbnetwork.config.link.not.up", (quoServId, linkId, devSn)))
        msgParts.append(common.getMsg(LANG, "hypermetro.arbnetwork.config.link.not.up.sugg", linkId))
    downLinkMsg = "".join(msgParts)

    # Comparing networks needs at least two healthy links.
    if len(upPortIds) < 2:
        return downLinkMsg

    portIpInfo = getLocalPortIpInfo(upPortIds, devSn)
    LOGGER.logInfo("locaPortIpSmInfoDict : %s" % portIpInfo)
    return downLinkMsg + checkEffectiveIpIsSame(devSn, portIpInfo, quoServId)


def getDomainInfo(devSn):
    """
    Read the HyperMetro domain list of a device.

    :param devSn: SN of the device to query
    :return: {domainId: {"remoteDeviceId": ..., "quoServerId": ...}};
             a later row with the same ID overrides an earlier one
    :raises UnCheckException: when "show hyper_metro_domain general" fails
    """
    global ALL_CLI_RET
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn,
                                                 "show hyper_metro_domain general", LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag != True:
        LOGGER.logInfo("Failed to get information about HyperMetro domain. errMsg:%s" % errMsg)
        raise UnCheckException(errMsg, ALL_CLI_RET)

    return {
        row.get("ID", ''): {
            "remoteDeviceId": row.get("Remote Device ID", ""),
            "quoServerId": row.get("Quorum Server ID", ""),
        }
        for row in cliUtil.getHorizontalCliRet(cliRet)
    }


def getDomainDetailInfo(devSn, domainId):
    """
    Read the detailed quorum configuration of one HyperMetro domain.

    :param devSn: SN of the device to query
    :param domainId: HyperMetro domain ID
    :return: {domainId: {"quoServerName", "quoServerId", "quoModel",
             "quoStandbyServerName", "quoStandbyServerId"}}; empty when the
             output has no record, and the last record wins when there are several
    :raises UnCheckException: when the domain query fails
    """
    global ALL_CLI_RET
    cmd = "show hyper_metro_domain general domain_id=%s" % domainId
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn, cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag != True:
        LOGGER.logInfo("Failed to get information about HyperMetro domain. errMsg:%s" % errMsg)
        raise UnCheckException(errMsg, ALL_CLI_RET)

    # (result key, CLI field name) pairs.
    fieldMap = (
        ("quoServerName", "Quorum Server Name"),
        ("quoServerId", "Quorum Server ID"),
        ("quoModel", "Quorum Mode"),
        ("quoStandbyServerName", "Standby Quorum Server Name"),
        ("quoStandbyServerId", "Standby Quorum ID"),
    )
    detailByDomain = {}
    for record in cliUtil.getVerticalCliRet(cliRet):
        detailByDomain[domainId] = dict((key, record.get(field, '')) for key, field in fieldMap)
    return detailByDomain


def getHyperMetroPairIdList(devSn, domainId):
    """
    @summary: Get the IDs of the HyperMetro pairs belonging to a domain.
    @param devSn: SN of the device to query
    @param domainId: HyperMetro domain ID used to filter the pair list
    @return: hyperMetroPairIdList: list of non-empty HyperMetro pair IDs
    @raise UnCheckException: when the pair query fails
    """
    global ALL_CLI_RET

    # Raw strings: the CLI filter must receive "\s" literally (it apparently
    # stands for the space in column names such as "Domain ID"). In a normal
    # literal "\s" is an invalid escape sequence.
    cmd = r"show hyper_metro_pair general |filterRow column=Domain\sID " \
          r"predict=equal_to value=%s|filterColumn include columnList=" \
          r"Local\sID,Remote\sID,ID,Domain\sID" % domainId
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn, cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag is not True:
        LOGGER.logInfo("Failed to get information about HyperMetro Pair List. errMsg:%s" % errMsg)
        raise UnCheckException(
            common.getMsg(LANG, "cannot.get.info", {"zh": u"双活pair", "en": "HyperMetro Pair"}.get(LANG)), ALL_CLI_RET)

    # Skip rows without an ID.
    return [pairInfo.get("ID", '') for pairInfo in cliUtil.getHorizontalCliRet(cliRet)
            if pairInfo.get("ID", '')]


def getQuorumServerLinkInfo(devSn, serverId):
    """
    @summary: Get the links of one quorum server, keyed by local port ID.
    @param devSn: SN of the device to query
    @param serverId: quorum server ID
    @return: {localPortId: {"linkStatus": ..., "linkId": ...}}
    @raise UnCheckException: when the link query fails
    """
    global ALL_CLI_RET
    cmd = "show quorum_server_link general server_id=%s" % serverId
    flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn, cmd, LANG)
    ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)
    if flag != True:
        raise UnCheckException(errMsg, ALL_CLI_RET)

    # NOTE(review): if two links were listed on the same local port, only the
    # last one would be kept here — confirm a port carries at most one link
    # per quorum server.
    return {
        link.get("Local Port", ""): {"linkStatus": link.get("Link Status", ""),
                                     "linkId": link.get("Link ID", "")}
        for link in cliUtil.getHorizontalCliRet(cliRet)
    }


def getLocalPortIpInfo(localPortIdList, devSn):
    """
    Get the IP configuration of each given port.

    :param localPortIdList: IDs of the ports to query
    :param devSn: SN of the device to query
    :return: {portId: {"ipV4", "subnetMask", "ipV6", "iPv6PrefixLength"}};
             a field is '' when that port's output does not provide it
    :raises UnCheckException: when a port query fails
    """
    global ALL_CLI_RET
    portIpSmInfoDict = {}
    for localPortId in localPortIdList:
        cmd = "show port general port_id=%s" % localPortId
        flag, cliRet, errMsg = common.getObjFromFile(py_java_env, LOGGER, devSn, cmd, LANG)
        ALL_CLI_RET = common.joinLines(ALL_CLI_RET, cliRet)

        if flag != True:
            raise UnCheckException(errMsg, ALL_CLI_RET)

        # Reset for every port. Otherwise a port whose output contains no
        # record would inherit the previous port's addresses.
        ipV4 = ""
        subnetMask = ""
        ipV6 = ""
        iPv6PrefixLength = ""
        for portInfoDict in cliUtil.getVerticalCliRet(cliRet):
            ipV4 = portInfoDict.get("IPv4 Address", "")
            subnetMask = portInfoDict.get("Subnet Mask", "")
            ipV6 = portInfoDict.get("IPv6 Address", "")
            iPv6PrefixLength = portInfoDict.get("IPv6 Prefix Length", "")
        portIpSmInfoDict[localPortId] = {"ipV4": ipV4, "subnetMask": subnetMask,
                                         "ipV6": ipV6, "iPv6PrefixLength": iPv6PrefixLength
                                         }
    return portIpSmInfoDict


def checkEffectiveIpIsSame(locaDevSn, locaPortIpInfoDict, quoServId):
    """
    @summary: Decide from IP addresses and subnet masks / prefix lengths whether
              all quorum-link ports share one effective IP network. If they do,
              the result is "suggest optimization".
    @param locaDevSn: SN of the device the ports belong to (used in the message)
    @param locaPortIpInfoDict: {portId: {"ipV4", "subnetMask", "ipV6",
                               "iPv6PrefixLength"}} of the local quorum-link ports
    @param quoServId: quorum server ID (used in the message)
    @return: errMsg: '' when at least one pair of ports is on different
             networks, otherwise the "same network" message plus its suggestion
    """
    portIdList = list(locaPortIpInfoDict.keys())
    # The pairwise comparison is symmetric, so each unordered pair is checked once.
    diffFound = any(
        _isPortPairOnDifferentNetwork(locaPortIpInfoDict[portId],
                                      locaPortIpInfoDict[portIdOther])
        for index, portId in enumerate(portIdList)
        for portIdOther in portIdList[index + 1:]
    )
    LOGGER.logInfo("ipNetworkDiffFlag : %s" % diffFound)
    if diffFound:
        return ""

    err_msg = common.getMsg(LANG, "hypermetro.arbnetwork.config.IP.network.error",
                            (quoServId, ", ".join(locaPortIpInfoDict.keys()), locaDevSn))
    err_msg += common.getMsg(LANG, "hypermetro.arbnetwork.config.IP.network.error.sugg")
    return err_msg


def _isValidIpValue(ipValue):
    """Return True when a CLI IP field holds an address (neither empty nor '--')."""
    return bool(ipValue) and ipValue != "--"


def _isPortPairOnDifferentNetwork(portIpInfo, portIpInfoOther):
    """
    Decide whether two ports are on different effective IP networks.

    Each address family configured on BOTH ports is compared using the common
    (shorter) mask:
    - IPv4: IP AND (subnet mask A AND subnet mask B), via bitWiseAndCalculation.
    - IPv6: the first min(prefix A, prefix B) bits, via getBinIpStr.
    The ports are on different networks when any compared family differs, or
    when they have no address family in common to compare.

    :param portIpInfo: {"ipV4", "subnetMask", "ipV6", "iPv6PrefixLength"} of one port
    :param portIpInfoOther: the same structure for the other port
    :return: True when the two ports are on different networks
    """
    compared = False

    ipV4 = portIpInfo.get("ipV4", "")
    ipV4Other = portIpInfoOther.get("ipV4", "")
    if _isValidIpValue(ipV4) and _isValidIpValue(ipV4Other):
        compared = True
        commonMask = bitWiseAndCalculation(portIpInfo.get("subnetMask", ""),
                                           portIpInfoOther.get("subnetMask", ""))
        ipV4Net = bitWiseAndCalculation(commonMask, ipV4)
        ipV4NetOther = bitWiseAndCalculation(commonMask, ipV4Other)
        LOGGER.logInfo("localIpReal : %s, localIpOtherReal:%s" % (ipV4Net, ipV4NetOther))
        if ipV4Net != ipV4NetOther:
            return True

    ipV6 = portIpInfo.get("ipV6", "")
    ipV6Other = portIpInfoOther.get("ipV6", "")
    if _isValidIpValue(ipV6) and _isValidIpValue(ipV6Other):
        compared = True
        # Shorter prefix, mirroring the common subnet mask used for IPv4.
        commonPrefixLen = min(int(portIpInfo.get("iPv6PrefixLength", "")),
                              int(portIpInfoOther.get("iPv6PrefixLength", "")))
        ipV6Net = getBinIpStr(ipV6, commonPrefixLen)
        ipV6NetOther = getBinIpStr(ipV6Other, commonPrefixLen)
        LOGGER.logInfo("localIpV6 : %s, localIpV6Other:%s" % (ipV6, ipV6Other))
        LOGGER.logInfo("binIpV6Net : %s, binIpV6NetOther:%s" % (ipV6Net, ipV6NetOther))
        if ipV6Net != ipV6NetOther:
            return True

    # Without a common address family the ports cannot share a network.
    return not compared


def getBinIpStr(ipStr, preLen):
    """
    @summary: Return the binary network part of an IPv6 address.
    @param ipStr: IPv6 address, possibly abbreviated with "::"
    @param preLen: prefix length (int or numeric string)
    @return: one '0'/'1' character per address bit, with every bit past
             preLen cleared to '0'
    """
    expandedIp = convertIpV6(ipStr)
    LOGGER.logInfo("convertIpV6[%s]to[%s]" % (ipStr, expandedIp))
    addressBits = "".join([hex2bin(group) for group in expandedIp.split(":")])

    prefixLen = int(preLen)
    maskBits = "1" * prefixLen + "0" * (128 - prefixLen)

    # Keep an address bit only where the mask bit is 1.
    return "".join([str(int(maskBits[pos]) and int(bit)) for pos, bit in enumerate(addressBits)])

def convertIpV6(ipStr):
    """
    @summary: Expand an abbreviated IPv6 address ("1000:F::D") to eight
              colon-separated groups by replacing "::" with the missing
              number of "0" groups, e.g. "1000:F::D" -> "1000:F:0:0:0:0:0:D".
              Addresses without "::" are returned unchanged. An embedded
              IPv4 suffix (e.g. "::ffff:1.2.3.4") is not expanded.
    @param ipStr: IPv6 address text
    @return: expanded address text
    """
    if "::" not in ipStr:
        return ipStr

    head, tail = ipStr.split("::", 1)
    headGroups = head.split(":") if head else []
    tailGroups = tail.split(":") if tail else []
    # "::" stands for at least one zero group; fill up to 8 groups in total.
    missingCount = 8 - len(headGroups) - len(tailGroups)
    return ":".join(headGroups + ["0"] * missingCount + tailGroups)

def bitWiseAndCalculation(strA, strB):
    """
    @summary: AND two dotted-decimal values (subnet masks or IPv4 addresses)
              octet by octet. For example:
              255.255.0.0 AND 255.255.255.0 -> 255.255.0.0 (common subnet mask);
              200.46.1.2 AND 255.255.0.0 -> 200.46.0.0 (effective network).
              Two ports whose effective networks are equal are on the same segment.
              Only as many octets as the shorter input has are combined.
    @return: dotted-decimal result, or "" when an input cannot be parsed
             (the traceback is logged)
    """
    try:
        octetPairs = zip(strA.split("."), strB.split("."))
        return ".".join([str(int(octetA) & int(octetB)) for octetA, octetB in octetPairs])
    except Exception:
        LOGGER.logError(str(traceback.format_exc()))
        return ""


def hex2dec(string_num):
    """
    hex2dec: hexadecimal string -> decimal string.
    :param string_num: hexadecimal digits (case-insensitive), e.g. "ff"
    :return: decimal representation as a string, e.g. "255"; '' when the
             input is not a valid hexadecimal string
    """
    # int(..., 16) accepts both cases. Catch only conversion errors so that
    # KeyboardInterrupt/SystemExit are no longer swallowed by a bare except.
    try:
        return str(int(string_num, 16))
    except (ValueError, TypeError):
        return ''


def dec2bin(string_num):
    """
    dec2bin: decimal value -> binary digit string (without "0b" prefix).
    :param string_num: decimal integer, as a string or number
    :return: binary digits, e.g. "101" for "5"; '' for 0, for negative values
             and for input that is not an integer
    """
    try:
        num = int(string_num)
    except (ValueError, TypeError, OverflowError):
        return ''
    # The previous repeated-divmod loop never terminated for negative numbers
    # (divmod(-1, 2) == (-1, 1)). Zero keeps returning '', as it always did.
    if num <= 0:
        return ''
    return bin(num)[2:]


def hex2bin(string_num):
    """
    hex2bin: one hexadecimal group -> binary string left-padded with '0' to 16 bits.
    An empty/None group, or one the helpers cannot convert, yields sixteen '0's;
    a value wider than 16 bits is returned without padding.
    :param string_num: hexadecimal digits of one IPv6 group
    :return: binary digit string
    """
    binStr = dec2bin(hex2dec(string_num.upper())) if string_num else ''
    # zfill pads on the left up to 16 characters and never truncates.
    return binStr.zfill(16)
