# -*- coding: UTF-8 -*-
import traceback
import os
import threading
import time
from java.io import File
from common.util import log, util
from common.constants import COMSYSTEMLOGPATH
from com.huawei.ism.tool.obase.exception import ToolException


class download_file_info:
    """Describe one file transfer as a pair of full paths.

    Both attributes start out as empty strings; the caller fills them in
    before handing the object to the download helpers.
    """

    def __init__(self):
        # Full local path (download target) and full remote path (source).
        self.localFileName, self.remoteFileName = '', ''


def execute(devObj):
    """
    Download the collection result file from the device over SFTP.

    The remote path comes from devObj["collecRemotePath"]; the local target
    is devObj["collectRetDir"] joined with devObj["collectRetFileName"].

    :param devObj: device context providing "SFTP", "SSH" and the path keys.
    :return: (True, "") on success, (False, <error text>) if anything raised.
    """
    try:
        sftpClient = devObj.get("SFTP")
        sshClient = devObj.get("SSH")
        # TMOUT=0 turns off the shell's idle auto-logout for this session.
        sshClient.execCmdWithTimout("TMOUT=0", 60 * 30)
        targetFile = File(os.path.join(devObj.get("collectRetDir"),
                                       devObj.get("collectRetFileName")))
        sftpClient.setRetryTimes(2)
        sftpClient.getFile(devObj.get("collecRemotePath"), targetFile, None)
    except Exception as err:
        return (False, "%s" % str(err).encode("utf-8"))
    return (True, "")


def downComlog(devObj):
    """
    Download the remote log package into collectRetDir over SFTP.

    The local target is devObj["collectRetDir"] joined with
    devObj["collectRemoteFileName"]; the remote source is
    devObj["collecRemotePath"]. The success log reports exactly these values.

    :param devObj: device context providing "SFTP", "SSH" and the path keys.
    :return: (True, "") on success, (False, <generic message>) on failure;
             details are written to the log.
    """
    try:
        sftp = devObj.get("SFTP")
        ssh = devObj.get("SSH")
        # TMOUT=0 turns off the shell's idle auto-logout for this session.
        ssh.execCmdWithTimout("TMOUT=0", 60 * 30)
        localDir = devObj.get("collectRetDir")
        # NOTE(review): unlike execute(), the local name here is taken from
        # "collectRemoteFileName" rather than "collectRetFileName" — confirm
        # this is intended.
        localName = devObj.get("collectRemoteFileName")
        remotePath = devObj.get("collecRemotePath")
        localFile = File(os.path.join(localDir, localName))
        sftp.setRetryTimes(2)
        sftp.getFile(remotePath, localFile, None)
        # Log the file name that was actually used for the download, not a
        # different key, so the message matches what landed on disk.
        log.info(devObj,
                 "download over,localdir name:%s,localfile:%s,remote path:%s"
                 % (localDir, localName, remotePath))
        return (True, "")
    except Exception as e:
        log.error(devObj, "Exception occurred while down package,err info:%s"
                  % str(e).encode("utf-8"))
        log.error(devObj, "Exception occurred while down package,err info:%s"
                  % str(traceback.format_exc()).encode("utf-8"))
        return (False, "Exception occurred while down package.")


def DownLog(devObj, file_info):
    """
    Download one remote file over SFTP and record the outcome.

    This is the worker body started on a thread by DownByThread. The outcome
    is written to "<localFileName>.txt" as "SUCCESS" or "FAIL" so the caller
    can read it after join(). After a successful download the remote file is
    deleted; that cleanup is best-effort and a failure there is only logged.

    :param devObj: device context; "SFTP" is read from it.
    :param file_info: download_file_info with localFileName/remoteFileName.
    :return: (True, "") on success, (False, <generic message>) on failure.
    """
    # The status file is created up front and closed on every path by the
    # with-statement. This avoids the earlier bug where a failure after the
    # handle had been closed tried to write "FAIL" to a closed file and
    # killed the thread.
    with open(file_info.localFileName + ".txt", 'w') as fileHandle:
        try:
            sftp = devObj.get("SFTP")
            util.execCmdWithTimout(devObj, "TMOUT=0", 20)
            file = File(file_info.localFileName)
            log.info(devObj, "download start,local file name:%s,remote file:%s"
                     % (file_info.localFileName, file_info.remoteFileName))
            try:
                sftp.getFile(file_info.remoteFileName, file, None)
            except ToolException:
                # On an SFTP tool error, reconnect once and retry.
                sftp.reConnect()
                sftp.getFile(file_info.remoteFileName, file, None)
            log.info(devObj, "download over, remote file:%s"
                     % (file_info.remoteFileName))
        except Exception as e:
            log.error(devObj, "Exception occurred while down package,err info:%s"
                      % str(e).encode("utf-8"))
            log.error(devObj, "Exception occurred while down package,err info:%s"
                      % str(traceback.format_exc()).encode("utf-8"))
            fileHandle.write("FAIL")
            return (False, "Exception occurred while down package.")
        fileHandle.write("SUCCESS")

    # Delete the downloaded file from the device. The download itself has
    # already succeeded, so a cleanup failure must not turn into a crash.
    cmd = "[ -f %s ] && rm %s" % (file_info.remoteFileName,
                                  file_info.remoteFileName)
    try:
        util.execCmdWithTimout(devObj, cmd, 60 * 30)
    except Exception as e:
        log.error(devObj, "failed to delete remote file(%s),err info:%s"
                  % (file_info.remoteFileName, str(e).encode("utf-8")))
    return (True, "")


def DownByThread(devObj, startProgress, endProgress, file_info):
    """
    Download one remote file on a worker thread while reporting progress.

    DownLog runs in a background thread. While it is alive, the size of the
    local file is polled roughly every 10 seconds and mapped onto the
    [startProgress, endProgress] range. The outcome is read from the
    "<localFileName>.txt" status file that DownLog writes.

    :param devObj: device context (passed to util/log/DownLog).
    :param startProgress: progress value at the start of this file.
    :param endProgress: progress value once this file is complete.
    :param file_info: download_file_info with localFileName/remoteFileName.
    :return: (True, "") if the status file contains "SUCCESS",
             otherwise (False, <reason>).
    """
    log.info(devObj, "start download by thread, %s" % (file_info.remoteFileName))
    fileResult = file_info.localFileName + ".txt"

    # Total size of the remote package, used as the 100% reference.
    fileTotalSize = util.GetFileTotalSize(devObj, file_info.remoteFileName)
    if fileTotalSize <= 0:
        log.error(devObj,
                  "faile to get file(%s) size" % file_info.remoteFileName)
        return (False, "faile to get file size")
    log.info(devObj, "total size:%d" % fileTotalSize)

    # Start the download thread.
    threadDownload = threading.Thread(target=DownLog,
                                      args=(devObj, file_info))
    threadDownload.start()

    # Refresh the download progress while the worker runs.
    lastSize = 0
    maxRate = endProgress - startProgress
    while threadDownload.isAlive():
        if not os.path.exists(file_info.localFileName):
            time.sleep(5)
            continue
        curFileSize = os.path.getsize(file_info.localFileName)  # byte
        # Bytes gained over the ~10 s poll interval / 10240 -> KB/s.
        speed = (curFileSize - lastSize) / 10240
        lastSize = curFileSize
        tmpPercent = int(curFileSize * 1.0 / fileTotalSize * maxRate)
        # Never report beyond this file's share, even if the local file
        # grows past the size reported for the remote file.
        percent = startProgress + min(tmpPercent, maxRate)
        util.RefreshProcess(devObj, percent, "Downloading(%dKB/s)" % (speed))
        time.sleep(10)
    threadDownload.join()
    speed = (fileTotalSize - lastSize) / 10240
    util.RefreshProcess(devObj, endProgress, "Downloading(%dKB/s)" % (speed))

    # Parse the download result written by DownLog. A missing status file
    # is treated as a failed download instead of raising IOError.
    strRet = ""
    if os.path.exists(fileResult):
        with open(fileResult, "r") as fileHandle:
            strRet = fileHandle.read()
        os.remove(fileResult)
    else:
        log.error(devObj, "download result file(%s) not found" % fileResult)

    if "SUCCESS" in strRet:
        return (True, "")

    return (False, "download failed, please check.")


def downTgz(devObj, strRet, nodeProgress, rateETmp,
            localSyslogPath, remoteSyslogPath):
    """
    Download every ".tgz" package named in a remote listing.

    On the first call (nodeProgress == 0), this function:
      * creates the local System_log directory;
      * splits the remaining progress range (rateETmp..100) evenly into
        (node count + 1) shares.

    Each package advances progress by one share, capped at 100. A failed
    download rolls its share back, so progress only reflects successful
    downloads.

    :param devObj: device context.
    :param strRet: remote listing output, one entry per line.
    :param nodeProgress: per-package progress share; 0 on the first call.
    :param rateETmp: progress value reached so far.
    :param localSyslogPath: local directory that receives the packages.
    :param remoteSyslogPath: remote directory holding the packages.
    :return: (progress reached after this call, nodeProgress). If the listing
             contains no ".tgz" entry, the input progress is returned
             unchanged.
    """
    startProgress = rateETmp
    endProgress = 100

    # First call: prepare the local directory and compute the per-node share.
    if nodeProgress == 0:
        log.info(devObj, "start download Syslog data by thread.")
        # Create the local System_log directory.
        isSuccess = util.createDir(devObj, localSyslogPath)
        if not isSuccess:
            log.error(devObj,
                      "create Syslog dir Failed, dataTmpDir=%s"
                      % localSyslogPath)
            # NOTE(review): this returns (False, message), which does not
            # match the (progress, nodeProgress) shape returned below.
            # WaitAndDownByThread unpacks it as progress/nodeProgress —
            # confirm how this failure should be handled.
            return (False, "create Syslog directory Failed.")
        nodenum = util.QueryNodeNum(devObj)
        # Progress share of a single node.
        tmpProgress = endProgress - startProgress
        nodeProgress = tmpProgress / (nodenum + 1.0)
        log.info(devObj, "Node num is %d" % nodenum)

    for line in strRet.splitlines():
        if line.find(".tgz") < 0:
            continue
        # Progress window for this package.
        rateSTmp = rateETmp
        rateETmp = min(rateSTmp + nodeProgress, endProgress)
        log.info(devObj,
                 "fielname=%s, NodeProgress=%d, start=%d, end=%d"
                 % (line, nodeProgress, rateSTmp, rateETmp))

        # Download this package on a worker thread.
        file_info = download_file_info()
        file_info.localFileName = os.path.join(localSyslogPath, line)
        file_info.remoteFileName = remoteSyslogPath + "/" + line
        isSuccess, retInfo = DownByThread(devObj,
                                          rateSTmp,
                                          rateETmp,
                                          file_info)
        if not isSuccess:
            # Roll back this package's share. The remote file is still
            # present, so a later listing picks it up again for re-download.
            rateETmp = rateSTmp
            log.error(devObj,
                      "download normal log failed, ret=%s"
                      % retInfo)
    # rateETmp is always bound and already excludes failed packages, which
    # fixes the unbound/over-counted limitRate of the earlier version.
    return (rateETmp, nodeProgress)


def GetCollectlimitRate(countTime, limitRate):
    """
    Map a collection poll count or reported progress onto the progress bar.

    If the device reports no progress (limitRate == 0), as on old versions,
    the poll count is used instead, capped at 240, which yields at most 60.
    The original comment says this corresponds to roughly one hour.
    Otherwise the reported progress, capped at 100, is used, which yields at
    most 25. Either way the chosen value is scaled by 0.25.

    :param countTime: number of collection polls so far.
    :param limitRate: progress reported by the device (0 if none).
    :return: progress value for the collection phase.
    """
    cappedCount = min(countTime, 240)
    reported = min(limitRate, 100)
    base = cappedCount if reported == 0 else reported
    return base * 0.25


def WaitAndDownByThread(devObj):
    """
    Wait for on-device log collection, downloading packages as they appear.

    The remote package directory (COMSYSTEMLOGPATH) is cleared first. The
    loop then keeps running while the collection process still exists or a
    ".tgz" entry is listed remotely:
      * packages listed -> downTgz downloads them and advances progress;
      * nothing listed, no download started yet -> "Collecting..." progress
        from GetCollectlimitRate, then sleep 15 s;
      * nothing listed after downloads started -> "Waiting...", sleep 15 s.
    Finally, if the local System_log directory exists, the remote directory
    is removed.

    :param devObj: device context.
    :return: the last progress value computed (limitRate).
    """
    log.info(devObj, "start wait download node data by thread.")
    perNodeShare = 0.0  # stays 0 until downTgz has computed the share
    pollCount = 0
    limitRate = 0
    reachedProgress = 0

    # Initial progress.
    util.RefreshProcess(devObj, 1, "Collecting...")
    time.sleep(15)

    localSyslogPath = os.path.join(devObj.get("collectRetDir"), "System_log")
    remoteSyslogPath = COMSYSTEMLOGPATH
    clearAndListCmd = "rm -rf %s; ls -l %s | awk '{print $9}'" \
                      % (remoteSyslogPath, remoteSyslogPath)
    listCmd = "ls -l %s | awk '{print $9}'" % (remoteSyslogPath)

    util.execCmdWithTimout(devObj, "TMOUT=0", 20)
    listing = util.execCmdWithTimout(devObj, clearAndListCmd, 20)
    hasPackage = listing.find(".tgz") >= 0
    procCount = util.CheckCollectProcessExists(devObj)

    # Continue while collection is still running or packages remain.
    while procCount > 0 or hasPackage:
        if hasPackage:
            limitRate, perNodeShare = downTgz(devObj, listing, perNodeShare,
                                              reachedProgress,
                                              localSyslogPath,
                                              remoteSyslogPath)
            reachedProgress = limitRate
        elif perNodeShare == 0:
            # Still in the collection phase: no package downloaded yet.
            pollCount += 1
            # Prefer the device-reported progress when there is one.
            limitRate = util.QueryCollectProcess(devObj)
            log.info(devObj, "limitRate=%d" % limitRate)
            limitRate = GetCollectlimitRate(pollCount, limitRate)
            util.RefreshProcess(devObj, limitRate, "Collecting...")
            reachedProgress = limitRate
            time.sleep(15)
        else:
            util.RefreshProcess(devObj, limitRate, "Waiting...")
            time.sleep(15)

        procCount = util.CheckCollectProcessExists(devObj)
        listing = util.execCmdWithTimout(devObj, listCmd, 20)
        hasPackage = listing.find(".tgz") >= 0

    # Remove the remote package directory once done.
    if os.path.exists(localSyslogPath):
        util.execCmdWithTimout(devObj, "rm -rf %s" % (remoteSyslogPath), 20)
    log.info(devObj, "end download Syslog data by thread.")
    return limitRate
