# -*- coding: UTF-8 -*-

from java.io import File

import os
import re
import traceback
import common
import modelManager


def getHeartbeatIp(ssh, devObj):
    """
    # *****************************************************************************************#
    # 函数名称: getHeartbeatIp(ssh)
    # 功能说明: 获取对端心跳IP地址（需要确定使用之前为debug模式）
    # 输入参数: ssh, devObj
    # 返 回 值: 对端心跳IP地址，获取无效时返回为空
    # *****************************************************************************************#
    """
    logger = devObj.get("logger")
    
    #执行命令，获取IP信息
    cliRet = common.execCmd(ssh, "ifconfig")
    cliRetList = cliRet.splitlines()
    canFind = False
    
   #获取设备型号和版本号需要切换到cli模式
    curModel = modelManager.getCurSysMode(ssh)
    flag = modelManager.changeMode2CLI(devObj)
    if not flag:
        logger.error("[getHeartbeatIp] Change model to cli failed!")
        return ""
    
    #获取设备型号和版本号
    heartBeatEthName = ""
    deviceType = common.getDeviceType(ssh)
    deviceVersion = common.getProductVersion(ssh)
    typeAndVer = deviceType + deviceVersion
    logger.error("[getHeartbeatIp] Current type and version is %s." % typeAndVer)

	#切换回原来的模式
    flag = modelManager.changeMode(devObj, curModel)
    if not flag:
        logger.error("[getHeartbeatIp] Change model to %s failed!" % curModel)
        return ""
    
    #产品不同获取心跳IP关键字不同
    if "S2600V100R001" in typeAndVer:  
        heartBeatEthName = "eth0 "
    else:
        heartBeatEthName = "bond0 "
        
    ip = ""
    for lineInfo in cliRetList:
        lineInfo = lineInfo.strip()
        if lineInfo.startswith(heartBeatEthName):
            canFind = True
            continue
        try:
            #可以查找心跳IP了
            if canFind and lineInfo.startswith("inet addr:"):
                ip = lineInfo.split("Bcast")[0].replace("inet addr:", "").strip()
                if ip.endswith(".10"):
                    ip = ip[:ip.rfind(".")] + ".11"
                elif ip.endswith(".11"):
                    ip = ip[:ip.rfind(".")] + ".10"    
                break    
        except:
            return ""
    return ip

def scpPeerFile(ssh, pwd, heartBeat, devicePeerDir, deviceTmpDir):
    """
    # *****************************************************************************************#
    # 函数名称: scpPeerFile(ssh, devObj, heartBeat, devicePeerDir, deviceTmpDir)
    # 功能说明: 通过心跳IP拷贝对端文件到本端指定目录（debug模式下）
    # 输入参数: 
    #     ssh:
    #     pwd：登录密码
    #     heartBeat:对端心跳IP
    #     devicePeerDir:对端源文件（目录）
    #     deviceTmpDir：本端目标文件（目录：需要存在）
    # 返 回 值: True/False
    # *****************************************************************************************#
    """
    try:
        #到对端执行命令拷贝文件
        itemCliRet = common.execCmd(ssh, "scp -rp " + "admin@" + str(heartBeat) + ":" + str(devicePeerDir) + " " + deviceTmpDir)
        cliRet = itemCliRet
        if re.search("Connection timed out", itemCliRet, re.IGNORECASE):
            return False
        
        if re.search("Are you sure you want to continue connecting (yes/no)?", itemCliRet, re.IGNORECASE):
            itemCliRet = common.execCmd(ssh, "yes")
            cliRet += itemCliRet
            
        if re.search("Password:", cliRet, re.IGNORECASE):
            cliRet = common.execCmdNoLog(ssh, pwd)
        
        if re.search("No such file or directory", cliRet, re.IGNORECASE):
            return False

        return True
    except:
        return False


def getPeerFile(ssh, devObj, heartBeat, devicePeerList, deviceTmpDir):
    '''
    # *****************************************************************************************#
    # 函数名称: getPeerFile(ssh, devObj, devicePeerDirDictList, deviceTmpDir)
    # 功能说明: 通过心跳IP拷贝对端文件到本端指定目录（需debug模式下）
    # 输入参数: 
    #     ssh:
    #     devObj：
    #     devicePeerList:对端源文件列表
    #     deviceTmpDir：本端目标目录（需要存在）
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    try:
        logger = devObj.get("logger")
        #获取密码，用于登录对端
        pwd = devObj.get("password")
        
        #判断心跳IP地址
        if not heartBeat:
            logger.error("[getPeerFile] The heartBeat is empty")
            return False
        
        if devicePeerList == []:
           logger.error("[getPeerFile] The fileList is empty")
           return False 
       
        #将对端源文件依次拷贝到本端临时目录
        for sourceFile in devicePeerList:    
            result = scpPeerFile(ssh, pwd, heartBeat, sourceFile, deviceTmpDir)
            if not result:
                logger.error("[getPeerFile] Get the file fialed: " + str(sourceFile))
        return True
    
    except:
        logger.error("[getPeerFile] escept:" + str(traceback.format_exc()))
        return False

def mkTempDir(ssh, tempDir):
    '''
    # *****************************************************************************************#
    # 函数名称: mkTempDir(ssh, tempDir)
    # 功能说明: 在本端创建临时目录（需要确定使用之前为debug模式）
    # 输入参数: 
    #     ssh:
    #     tempDir:临时目录名称
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    cmd = "mkdir -p " + tempDir
    cliRet = common.execCmd(ssh, cmd)
    if re.search("cannot create directory", cliRet, re.IGNORECASE):
        return False
    
    return True

def rmTempDir(ssh, tempDir):
    '''
    # *****************************************************************************************#
    # 函数名称: mkTempDir(ssh, tempDir)
    # 功能说明: 删除临时目录（需要确定使用之前为debug模式）
    # 输入参数: 
    #     ssh:
    #     tempDir:临时目录名称
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    cmd = "rm -rf " + tempDir
    cliRet = common.execCmd(ssh, cmd)
    return True

def copyCurNodeFile(ssh, devObj, fileList, targetDir):
    '''
    # *****************************************************************************************#
    # 函数名称: copyFile(ssh, sourceDir, targetDir)
    # 功能说明: 拷贝本端指定文件到临时目录下（debug模式下）
    # 输入参数: 
    #     ssh:
    #     fileList:源文件（目录）列表
    #     targetDir：目标目录（需要存在）
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    logger = devObj.get("logger")
    
    if fileList == []:
        logger.error("[copyCurNodeFile] The fileList is empty")
        return False

    for sourceFile in fileList:
        #保留源文件目录结构
        remoteDirName = os.path.dirname(sourceFile)
        #拷贝文件到临时目录
        cmd = "cp -rf " + sourceFile + " " + targetDir
        cliRet = common.execCmd(ssh, cmd)
        if re.search("directory does not exist", cliRet, re.IGNORECASE):
            logger.error("[copyCurNodeFile] directory does not exist:" + str(sourceFile))
            return False
        
    return True

def getAllNodeFile(ssh, devObj, fileDictList, tempDir):
    '''
    # *****************************************************************************************#
    # 函数名称: getAllNodeFile(ssh, devObj, fileDictList, tempDir)
    # 功能说明: 将所有控制器需要收集的文件收集到本端指定的临时文件夹下
    # 输入参数: 
    #     ssh:
    #     devObj:
    #     fileDictList:需要收集的文件集合
    #     tempDir：临时目录
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    try:
        logger = devObj.get("logger")
        
        #获取所有控制器的IP，用于创建目录
        ipList = getCtrlIp(ssh)
        if ipList == []:
            logger.error("[getAllNodeFile] Ip list is empty")
            return False
        
        #获取当前的IP
        curIp = devObj.get("devIp")
        
        #切换模式到debug，后续操作均需要在debug模式下执行
        curModel = modelManager.getCurSysMode(ssh)
        flag = modelManager.changeMode2Debug(devObj, True)
        if not flag:
            logger.error("[getAllNodeFile] Change model to debug failed!")
            return False
        
        #先清理再创建临时目录
        rmTempDir(ssh, tempDir)
        mkRet = mkTempDir(ssh, tempDir)
        if not mkRet:
            logger.error("[getAllNodeFile] Make remote temp directory failed! Dir=" + str(tempDir))
            return False
        
        #循环收集每个IP的文档到临时目录
        for ip in ipList:
            tempDir4Ip = tempDir + "/" + str(ip)
            mkRet = mkTempDir(ssh, tempDir4Ip)
            if not mkRet:
                logger.error("[getAllNodeFile] Make remote temp directory failed! Dir=" + str(tempDir4Ip))
                return False
            
            handleRet = handleFileWithIp(ssh, devObj, ip, curIp, fileDictList, tempDir4Ip)
            if not handleRet:
                logger.info("[getAllNodeFile] Handle file failed! IP=" + str(ip))
                continue
            
        return True

    except:
        logger.error("[getAllNodeFile] except trace back:" + str(traceback.format_exc()))
        return False
    finally:
        #先删除临时文件夹，再恢复之前的状态;
        tempHigDir = os.path.dirname(tempDir) + '/'  #删除 /tmp/collect/ 目录
        rmTempDir(ssh, tempHigDir)
        modelManager.changeMode(devObj, curModel) 
    
def handleFileWithIp(ssh, devObj, ip, curIp, fileDictList, tempDir4Ip):
    '''
    # *****************************************************************************************#
    # 函数名称: handleFileWithIp(ssh, devObj, ip, curIp, fileDictList, tempDir4Ip):
    # 功能说明: 将指定IP上需要收集的文件收集到本端指定的临时文件夹下
    # 输入参数: 
    #     ssh:
    #     devObj:
    #         ip:当前处理的IP
    #      curIP：已连接的IP
    #fileDictList:需要收集的文件集合
    #     tempDir4Ip：临时目录
    # 返 回 值: True/False
    # *****************************************************************************************#
    '''
    logger = devObj.get("logger")
    #获取心跳IP地址
    heartBeat = getHeartbeatIp(ssh, devObj)
    if heartBeat == "":
        logger.error("[getAllNodeFile] Get heartbeat failed!")
        return False
    
    #远端已创建的临时文件夹集合
    makedDir = []  
            
    #处理xml中获取的文件列表
    for fileDict in fileDictList:
        
        #获取文件名和是否需要收集最新
        remoteFileName = fileDict["name"]
        needNew = fileDict["isLatest"]
        
        #获取需要下载的文件的目录名称，用于保留路径结构
        remoteFileHigDirName = os.path.dirname(remoteFileName)
        remoteDirName = tempDir4Ip + remoteFileHigDirName
        sourceFileList = []
        
        #创建目录
        if not remoteDirName in makedDir:
            result = mkTempDir(ssh, remoteDirName)
            if result:
                makedDir.append(remoteDirName)
            else:
                logger.error("[getAllNodeFile] Create directory failed! Dir=" + remoteDirName)
                continue
            
        #根据IP选择不同的处理方式
        if ip == curIp:
            #处理文件名带*的文档
            if "*" in remoteFileName and needNew == "true":
                sourceFileList = getFileList(ssh, devObj, remoteFileName, ip)
            elif "*" in remoteFileName:
                sourceFileList = getFileList(ssh, devObj, remoteFileName, ip, True)
            else:
                sourceFileList.append(remoteFileName)
            
            #拷贝文件到指定临时目录
            collectRet = copyCurNodeFile(ssh, devObj, sourceFileList, remoteDirName)
                   
        else:
            #处理文件名带*的文档
            if "*" in remoteFileName and needNew == "true":
                sourceFileList = getFileList(ssh, devObj, remoteFileName, heartBeat, False, True)
            elif "*" in remoteFileName:
                sourceFileList = getFileList(ssh, devObj, remoteFileName, heartBeat, True, True)
            else:
                sourceFileList.append(remoteFileName)
            
            #拷贝文件到指定临时目录    
            collectRet = getPeerFile(ssh, devObj, heartBeat, sourceFileList, remoteDirName)
    
        #下载文件
        localSaveDir = devObj["LocalSaveDir"]
        for fileName in sourceFileList:
            #构造文件结构
            remoteFileName = os.path.basename(fileName)
            localFileName = remoteFileName
            localHigDirName = remoteFileHigDirName.replace('/', os.path.sep)
            
            #保存download临时参数
            devObj["DownRemoteFileDir"] = remoteDirName
            devObj["DownRemoteFileName"] = remoteFileName
            devObj["DownLocalFileDir"] = localSaveDir + os.path.sep + str(ip) + localHigDirName 
            devObj["DownLocalFileName"] = localFileName
            #存在文件下载失败
            if not downloadFileAndDelete(devObj):
                logger.error("[downloadFileAndDelete] Download file failed! fileNmae: " + str(fileName))
                continue
    return True      
                    
def getCtrlIp(ssh):
    '''
    # *****************************************************************************************#
    # 函数名称: gerCtrlIp(ssh)
    # 功能说明: 获取所有控制器Ip
    # 输入参数: 
    #     ssh:
    # 返 回 值: ipList: ip 列表
    # *****************************************************************************************#
    '''
    ipList = []
    
    cmd = "showctrlip"
    cliRet = common.execCmd(ssh, cmd)
    cliRetLines = cliRet.splitlines()
    if len(cliRetLines) < 3:
        return ipList
    
    for line in cliRetLines:
        lineStrip = line.strip()
        if lineStrip.startswith("A") or lineStrip.startswith("B"):
            ipList.append(lineStrip.split()[1])
        
    return ipList

def getFileList(ssh, devObj, sourceFile, ip, getAll=False, getPeer=False):
    '''
    # *****************************************************************************************#
    # 函数名称:  getLatestFile(sourceFile, getPress=False)
    # 功能说明: 获取指定文件的最新（需要扩展的目录下不能含子目录，否则文件列表将出错）
    # 输入参数: 
    # sourceFile：文件名称
    #   getAll:是否需要获取所有扩展出的文件
    #   getPeer：是否为对端
    # 返 回 值: 文件名称列表
    # *****************************************************************************************#
    '''
    fileNameList = []
    if not getPeer:
        if not getAll:
            cmd = "ls -lt %s| head -n 1" % sourceFile
        else:
            cmd = "ls %s" % sourceFile
    else:
        if not getAll:
            cmd = 'ssh admin@%s "ls -lt %s | head -n 1"' % (ip, sourceFile)
        else:
            cmd = 'ssh admin@%s "ls %s"' % (ip,sourceFile)
            
    itemCliRet = common.execCmd(ssh, cmd)
    cliRet = itemCliRet
    if re.search("Are you sure you want to continue connecting (yes/no)?", itemCliRet, re.IGNORECASE):
        itemCliRet = common.execCmd(ssh, "yes")
        cliRet += itemCliRet
        
    if re.search("Password:", cliRet, re.IGNORECASE):
        #获取密码，用于登录对端
        pwd = devObj.get("password")
        cliRet = common.execCmdNoLog(ssh, pwd)
    
    if re.search("No such file or directory", cliRet, re.IGNORECASE):
        return []
    if re.search("Could not resolve hostname",cliRet, re.IGNORECASE):
        return []
    #解析回文，获取文件名称
    cliRetLines = cliRet.splitlines()
    if len(cliRetLines) < 3:
        fileNameList = []
    else:
        if getAll:
            #获取*所有扩展出来的文件
            cliRetLines = cliRetLines[1:-1]
            for line in cliRetLines:
                lineSplit = line.split()
                if '.' in lineSplit:
                    lineSplit.remove(".")
                if '..' in lineSplit:
                    lineSplit.remove("..")
                fileNameList += lineSplit
            
        else:
            fileName = cliRetLines[1].split()[-1] 
            if fileName in ['.', ".."]:
                fileNameList = []
            else:
                fileNameList.append(fileName)
    return fileNameList


#************************************************************************** #
# 函数名称：downloadFileAndDelete
# 功能说明：下载文件后，删除远端阵列临时文件
# 传入参数：downloadDict对象
# 返  回  值：True:执行成功；False:执行失败
# **************************************************************************** #
def downloadFileAndDelete(downloadDict):
    
    try:
        #获取方法
        sftp = downloadDict.get("SFTP")
        logger = downloadDict.get("logger")
        lang = downloadDict.get("lang")
        
        #获取文件信息
        sourceFileDir = downloadDict.get("DownRemoteFileDir")
        sourceFileName = downloadDict.get("DownRemoteFileName")
        destFileDir = downloadDict.get("DownLocalFileDir")
        destFileName = downloadDict.get("DownLocalFileName")
        
        #判断远端阵列文件是否存在
        if not sftp.isFileExist(sourceFileDir, sourceFileName):
            logger.info("[downloadFileAndDelete] remote file does not exist:" + str(sourceFileDir) + "/" + str(sourceFileName))
            return False
        
        #判断本地保存文件夹是否存在，不存在则创建
        if os.path.exists(destFileDir) != True:
            os.makedirs(destFileDir)
        
        #通过SFTP下载文件
        remoteFile = sourceFileDir + '/' + sourceFileName
        localFile = destFileDir + os.path.sep + destFileName
        file = File(localFile)
        sftp.getFile(remoteFile, file, None)  

        #不判断返回值：失败了会打印错误信息，但不影响最终收集结果
        deleteRemoteFile(downloadDict, remoteFile)
        return True

    except:
        logger.error("[downloadFileAndDelete] except trace back:" + str(traceback.format_exc()))
        if lang == "zh":
            downloadDict["py_detail"] = u"通过SFTP下载文件失败"
        else:
            downloadDict["py_detail"] = "Downloading remote file by SFTP failed"
        return False
        
# **************************************************************************** #
# 函数名称：deleteRemoteFile
# 功能说明：删除远端阵列上的文件
# 传入参数：devObj对象；remoteFile远端阵列文件
# 返  回  值：True:执行成功；False:执行失败
# **************************************************************************** #
def deleteRemoteFile(devObj, remoteFile):
    try:
        sftp = devObj.get("SFTP")
        logger = devObj.get("logger")
        #使用sftp自带接口删除远端临时文件
        sftp.deleteFile(remoteFile)
        return True
    except:
        logger.error("[deleteRemoteFile] except trace back:" + str(traceback.format_exc()))
        return False    
    
    
    