# -*- coding: UTF-8 -*-
import os
import re

import cliUtil
import common
from statisticDisk import Statistic
from frameone.util import common as frame_common
from cbb.frame.util import sqlite_util
from cbb.frame.base import product

from query_hyper_metro_info import QueryHyperMetroInfo

from com.huawei.ism.tool.obase.exception import ToolException

# Runtime objects: py_java_env and PY_LOGGER are not defined in this file
# (presumably injected by the hosting tool - confirm).
PY_JAVA_ENV = py_java_env
LANG = common.getLang(PY_JAVA_ENV)
LOGGER = common.getLogger(PY_LOGGER, __file__)
# NOTE(review): the trailing commas make the next two constants 1-tuples
# (e.g. ('Horizontal',)) while CLI_VER_DIRECTION is a plain string - confirm
# whether that is intended.
CLI_HOR_DIRECTION = 'Horizontal',
CLI_HOR_DIRECTION_NO_STAND = 'HorizontalNoStand',
CLI_VER_DIRECTION = 'Vertical'
SSH_LOCK_STATU_OCCUPIED = 1
SSH_LOCK_STATU_FREE = 0
# Localized display name, keyed by language code.
HYPERMETRO_NODE_NAME = {'zh': u"双活一致性", 'en':"HyperMetro information consistency"}
# Product versions for which ensureTlvChannelOpened tries to open the
# external TLV channel.
NEED_OPEN_TLV_CHANNEL_VERSION = ["V300R003", "V300R005"]
# Login names ending with one of these keywords are rejected by
# checkLoginUserName.
USER_NAME_BLACK_LIST = ["developer", "diagnose", "error", "password",
                        "upgrade", "minisystem", "storage"]

# Marker of a hot patch ("SPH") inside version strings.
DIGIT_SPH = "SPH"
# Marker searched for in / appended to Kunpeng version and patch strings.
KUNPENG_FLAG = "Kunpeng"


def execute(devObj):
    '''
    Pre-check script executed before the health inspection.

    Device types listed in default_pass_list pass immediately. Otherwise the
    script initialises the sqlite DB, opens the developer switch, refreshes
    the cached version / patch information, deletes the backup DB, checks the
    device status and the login user, and finally collects the HyperMetro
    information and the disk statistics.

    :param devObj: mapping-like object; its "ssh" entry is the CLI connection
    :return: tuple ('', dict(flag=<bool>, des=<message>)); when the device
             status or login user check fails, the tuple built by that check
             is returned unchanged
    '''
    default_pass_list = ["CDM", "HUAWEI OceanCyber 300", "HUAWEI CyberEngine 300"]
    if str(PY_JAVA_ENV.get("devInfo").getDeviceType()) in default_pass_list:
        return ('', dict(flag=True, des=common.getMsg(LANG, "check.pass")))
    common.refreshProcess(PY_JAVA_ENV, 1, LOGGER)
    __ini_sqlite_db_safely()
    cli = devObj.get("ssh")
    # Only the switch state is kept; the CLI echo is not used.
    _, switch = cliUtil.openDeveloperSwitch(cli, LANG)
    PY_JAVA_ENV['switch'] = switch
    check_mem_version_consistence(cli)
    # Delete the backup DB.
    _delete_backup_db()

    # Check the device status.
    flag, ret = check_device_status(cli)
    if not flag:
        return ret
    try:
        flag, ret = check_login_user(devObj)
        if not flag:
            return ret
        cur_progress = 2.0
        hyper_total_progress = 80.0
        query_obj = QueryHyperMetroInfo(
            PY_JAVA_ENV, LOGGER, cli, cur_progress, hyper_total_progress)
        cur_progress = query_obj.collect_hyper_metro_info()
        # Fall back to the planned HyperMetro share when no progress is
        # reported.
        cur_progress = hyper_total_progress if not cur_progress else cur_progress
        LOGGER.logInfo("collect_hyper_metro_info: progress {}".format(
            cur_progress))
        statistic_progress = 100 - cur_progress
        if statistic_progress < 0:
            statistic_progress = 1

        # Expansion evaluation: collect the disk configuration of the live
        # system.
        s = Statistic(PY_JAVA_ENV, devObj, LOGGER, cur_progress,
                      cur_progress + statistic_progress)
        s.execute()
        common.refreshProcess(PY_JAVA_ENV, 100, LOGGER)
        return ('', dict(flag=True, des=common.getMsg(LANG, "check.pass")))

    # The specific handler comes first so that it stays reachable even if the
    # runtime treats ToolException as a subclass of Exception.
    except ToolException as e:
        LOGGER.logException(e)
        common.refreshProcess(PY_JAVA_ENV, 100, LOGGER)
        return ('', dict(flag=False, des=common.getMsg(
            LANG, "ssh.connect.fail")))
    except Exception as e:
        LOGGER.logException(e)
        common.refreshProcess(PY_JAVA_ENV, 100, LOGGER)
        return ('', dict(flag=False, des=common.getMsg(
            LANG, "query.result.abnormal")))
    finally:
        # Always leave the connection in CLI mode, whatever happened above.
        cliUtil.enterCliModeFromSomeModel(cli, LANG)


def _delete_backup_db():
    """
    Delete the "db.dat" backup file kept under the scene resource path.

    Nothing is done when "sceneResPath" is missing from the environment or the
    file does not exist. A failed removal is only logged.
    """
    scene_res_path = PY_JAVA_ENV.get("sceneResPath")
    if scene_res_path is None:
        return

    backup_db_file = scene_res_path + os.path.sep + "db.dat"
    if not os.path.isfile(backup_db_file):
        return

    try:
        os.remove(backup_db_file)
        LOGGER.logInfo("Backup DB data deleted successfully.")
    except Exception:
        LOGGER.logInfo("Fail to Delete Backup DB data.")


def check_device_status(cli):
    """
    Run "show system general" and verify the reported system status.

    :param cli: CLI connection used to run the command
    :return: (True, ('', {})) when the status is normal, otherwise
             (False, ('', {'flag': False, 'des': <error message>}))
    """
    exec_ok, cli_output, exec_err = cliUtil.excuteCmdInCliMode(
        cli, "show system general", True, LANG)
    if exec_ok is not True:
        LOGGER.logSysAbnormal()
        return False, ('', {'flag': False, 'des': exec_err})

    system_ok, status_err = common.checkSystemStatus(cli_output, LANG)
    if system_ok:
        return True, ('', {})

    LOGGER.logNoPass("The status of system is abnormal")
    common.refreshProcess(PY_JAVA_ENV, 100, LOGGER)
    return False, ('', {'flag': False, 'des': status_err})


def check_login_user(dev_obj):
    """
    Validate the user the tool logged in with.

    The TLV channel is ensured first, then the user's role and the user name
    are checked.

    :param dev_obj: mapping-like object; its "ssh" entry is the CLI connection
    :return: (True, ('', {})) when both checks pass, otherwise
             (False, ('', {'flag': False, 'des': <error message>}))
    """
    ensureTlvChannelOpened(dev_obj)
    ssh_cli = dev_obj.get("ssh")

    # Role check: backup administrators are restricted from collecting.
    user_name = PY_JAVA_ENV.get("devInfoMap").get("userName")
    LOGGER.logInfo("loginUserName is %s" % str(user_name))
    role_ok, role_err = checkLoginAdministrator(ssh_cli, user_name, LANG)
    if not role_ok:
        return False, ('', {'flag': False, 'des': role_err})

    # Name check: some keywords used as a user name would disturb the
    # evaluation of CLI results.
    if not checkLoginUserName(user_name):
        name_err = common.getMsg(LANG, "loginUser.name.check.failure")
        return False, ('', {'flag': False, 'des': name_err})

    common.refreshProcess(PY_JAVA_ENV, 2, LOGGER)
    return True, ('', {})


def checkLoginUserName(loginUserName):
    '''
    @summary: Check whether the tool login user name is a special user name.
    @return: False when the lower-cased name ends with a keyword from
             USER_NAME_BLACK_LIST, True otherwise.
    '''
    lowered_name = loginUserName.lower()
    hits_black_list = any(
        lowered_name.endswith(key_word) for key_word in USER_NAME_BLACK_LIST)
    return not hits_black_list


def checkLoginAdministrator(cli, loginUserName, LANG):
    '''
    @summary: Check whether the login user is a backup administrator
              (collection is restricted for backup administrators).
    @param cli: CLI connection used to query the user
    @param loginUserName: name of the user to query
    @param LANG: language passed to the CLI helper and used for messages
    @return: (True, "") when the user is allowed, otherwise
             (False, <error message>)
    '''
    query_cmd = "show user user_name=%s" % (loginUserName)
    exec_ok, cli_output, exec_err = cliUtil.excuteCmdInCliMode(
        cli, query_cmd, True, LANG)
    if exec_ok != True:
        LOGGER.logSysAbnormal()
        return False, exec_err

    role_id = ""
    for user_row in cliUtil.getHorizontalCliRet(cli_output):
        # Older versions do not report a "Role ID" column.
        if "Role ID" not in user_row:
            continue
        role_id = user_row.get("Role ID")
        LOGGER.logInfo("RoleId is %s" % str(role_id))

    if role_id == "9":
        return False, common.getMsg(LANG, "loginUser.name.level.check.failure")

    # A domain user whose role is not "1" (super administrator) is restricted
    # from collecting as well.
    device_user = py_java_env.get("devInfo").getLoginUser().getUserName()
    if device_user.startswith("domain/") and role_id != "1":
        return False, common.getMsg(LANG, "loginUser.name.domain.level.check.failure")

    return True, ""


@frame_common.wrapAllExceptionLogged(logger=PY_LOGGER)
def ensureTlvChannelOpened(devObj):
    """
    Open the external TLV channel when the product version requires it.

    Only versions containing an entry of NEED_OPEN_TLV_CHANNEL_VERSION are
    handled. When the channel had to be opened, 'needCloseTlv' is stored in
    py_java_env.

    :param devObj: mapping-like object; its "ssh" entry is the CLI connection
    """
    product_version = py_java_env.get("devInfo").getProductVersion()
    if not any(ver in product_version for ver in NEED_OPEN_TLV_CHANNEL_VERSION):
        return

    cli = devObj.get("ssh")
    query_ok, query_output, _ = cliUtil.excuteCmdInDeveloperMode(
        cli, 'show system external_tlv_channel', True, LANG)
    if query_ok != True:
        LOGGER.logError('Query tlv channel status failed.')
        return

    if 'yes' in query_output.lower():
        LOGGER.logInfo('Tlv channel is already opened, need not to open again')
        return
    LOGGER.logInfo('Tlv channel is not opened, need to open')

    open_ok, open_output, _ = cliUtil.excuteCmdInDeveloperMode(
        cli, 'change system external_tlv_channel enabled=yes', True, LANG)
    if open_ok != True:
        LOGGER.logError("Failed to open tlv channel")
        return

    # Answer every confirmation prompt with 'y' until none is left.
    while '(y/n)' in open_output:
        open_ok, open_output, _ = cliUtil.excuteCmdInCliMode(cli, 'y', True, 'en')

    LOGGER.logInfo("Open tlv channel success.")
    py_java_env.put('needCloseTlv', True)


@frame_common.wrapAllExceptionLogged(logger=PY_LOGGER)
def __ini_sqlite_db_safely():
    """
    Initialise the sqlite DB for the current device and log the resulting
    "SQLITE_CONN_DICT" entry of the context.
    """
    py_context = py_java_env.get("objectForPy")
    device_sn = common.get_sn_from_env(py_java_env)
    sqlite_util.ini_sqlite_db_for_dev(py_context, device_sn, LOGGER)
    conn_dict = py_context.get("SQLITE_CONN_DICT")
    LOGGER.logInfo("db context is %s." % str(conn_dict))


@frame_common.wrapAllExceptionLogged(logger=PY_LOGGER)
def check_mem_version_consistence(cli):
    """
    Check whether the SPC / SPH information matches the values held in memory
    and refresh the in-memory values from the queried ones.

    :param cli: CLI connection used to query the actual version
    :return: tuple (True, <version in memory>, <actual version>)
    """
    dev_info = py_java_env.get("devInfo")
    mem_version = str(dev_info.getProductVersion())
    mem_patch = str(dev_info.getHotPatchVersion())

    actual_version, actual_patch = get_actual_version(cli)
    outcome = (True, mem_version, actual_version)

    # A failed query must not block the flow.
    if not (actual_version or actual_patch):
        LOGGER.logError("query actual version fail!do not "
                        "check version consistency.")
        return outcome

    LOGGER.logInfo(
        "ver actual is:{}, version in mem:{}, "
        "patch actual is {}, patch in mem is {}".format(
            actual_version, mem_version, actual_patch, mem_patch
        )
    )

    # No patch reported by the device: clear the patch information.
    if not actual_patch or actual_patch == "--":
        # Digital versions are refreshed to the version without "SPH".
        set_no_patch(dev_info, actual_version, mem_version)
        return outcome

    # A patch exists: update the patch information according to whether it
    # matches the value in memory.
    if check_sph_version_consistence(mem_patch, actual_patch):
        set_complete_patch_version(
            actual_patch, mem_version, actual_version, dev_info)
    else:
        set_patch_version(actual_patch, mem_version, actual_version, dev_info)

    return outcome


def set_complete_patch_version(p_patch, p_version_mem, p_version, dev_info):
    """
    Build the complete hot patch version and store it on the device info.

    :param p_patch: actual hot patch version
    :param p_version_mem: software version held in memory
    :param p_version: actual software version
    :param dev_info: device info object; setHotPatchVersion is called on it
    :return: tuple (True, p_version_mem, p_version)
    :raises IndexError: in the Kunpeng branch when p_patch contains no
        "SPH<digits>" token
    """
    if product.isDigitalVer(p_version) and p_patch.startswith(DIGIT_SPH):
        # Digital versions: prefix the patch with the software version.
        p_patch = p_version + "." + p_patch
    elif KUNPENG_FLAG in p_version_mem and KUNPENG_FLAG not in p_patch:
        # Kunpeng: rebuild the patch as "<version before SPC><SPHn> Kunpeng".
        # Raw string keeps "\d" a regex escape instead of an invalid string
        # escape.
        patch_str = re.compile(r"SPH\d+").findall(p_patch)[0]
        p_patch = "{}{} {}".format(str(p_version.split("SPC")[0]), patch_str, KUNPENG_FLAG)

    dev_info.setHotPatchVersion(p_patch)
    LOGGER.logInfo("set_complete_patch_version update patch version to {}.".format(p_patch))
    return True, p_version_mem, p_version


def set_patch_version(p_patch, p_version_mem, p_version, dev_info):
    """
    Update the hot patch version (and, for digital versions, the product
    version) on the device info.

    :param p_patch: actual hot patch version
    :param p_version_mem: software version held in memory
    :param p_version: actual software version
    :param dev_info: device info object; setProductVersion and
        setHotPatchVersion are called on it
    :return: tuple (True, p_version_mem, p_version)
    :raises IndexError: in the Kunpeng branch when p_patch contains no
        "SPH<digits>" token
    """
    # Kunpeng patch versions are displayed with "Kunpeng" appended.
    if KUNPENG_FLAG in p_version_mem and KUNPENG_FLAG not in p_patch:
        # Raw string keeps "\d" a regex escape instead of an invalid string
        # escape.
        patch_str = re.compile(r"SPH\d+").findall(p_patch)[0]
        p_patch = "{}{} {}".format(str(p_version.split("SPC")[0]), patch_str, KUNPENG_FLAG)
    # Digital versions carrying a patch: add the version information.
    elif product.isDigitalVer(p_version) and p_patch.startswith(DIGIT_SPH):
        # When the queried version already carries the SPH field, the SPH
        # part must not be appended a second time; only the product version
        # is refreshed in that case.
        if DIGIT_SPH in p_version:
            dev_info.setProductVersion(p_version)
            return True, p_version_mem, p_version

        p_patch = p_version + "." + p_patch
        # The product version of a digital version carries the patch too.
        dev_info.setProductVersion(p_patch)
    dev_info.setHotPatchVersion(p_patch)
    LOGGER.logInfo("set_patch_version update patch version to {}.".format(p_patch))
    return True, p_version_mem, p_version


def set_no_patch(dev_info, p_version, p_version_mem):
    """
    Clear the hot patch version; digital versions also get their product
    version refreshed to the queried version.

    :param dev_info: device info object
    :param p_version: actual version
    :param p_version_mem: version held in memory
    :return: tuple (True, p_version_mem, p_version)
    """
    is_digital = product.isDigitalVer(p_version)
    if is_digital:
        dev_info.setProductVersion(p_version)

    dev_info.setHotPatchVersion('')
    LOGGER.logInfo("update patch version to null.")

    outcome = (True, p_version_mem, p_version)
    return outcome


def check_version_consistence(version_a, version_b):
    """
    Compare two versions after normalising both with
    filter_special_version.

    :param version_a: version held in memory
    :param version_b: actual version
    :return: True when the normalised versions are equal, False otherwise
    """
    normalized_mem = filter_special_version(version_a)
    normalized_actual = filter_special_version(version_b)
    return normalized_mem == normalized_actual


def check_sph_version_consistence(sph_a, sph_b):
    """
    Compare two patch versions by their SPH part as extracted by get_sph_str.

    :param sph_a: patch held in memory
    :param sph_b: actual patch
    :return: True when the extracted parts are equal, False otherwise
    """
    mem_sph = get_sph_str(sph_a)
    actual_sph = get_sph_str(sph_b)
    return mem_sph == actual_sph


def get_sph_str(sph_version):
    """
    Extract the SPH token from a patch version.

    A falsy value and '--' are both mapped to '' (per the original note, the
    value in memory is null when no patch exists while the show command
    reports '--').

    :param sph_version: patch version
    :return: the first "SPH<word characters>" token, or the input unchanged
             when no such token is found
    """
    if sph_version == '--' or not sph_version:
        return ''

    matched = re.search(r"{}\w+".format(DIGIT_SPH), sph_version)
    return matched.group() if matched else sph_version


def filter_special_version(tmp_version):
    """
    Normalise a version string for comparison.

    1. A version containing "Kunpeng" has a trailing "Kunpeng" removed.
    2. A digital version has its patch part removed, e.g. 6.0.1.SPH1
       becomes 6.0.1.

    :param tmp_version: version to normalise
    :return: normalised version
    """
    if KUNPENG_FLAG in tmp_version:
        # Remove the flag as a suffix. str.rstrip(chars) strips a character
        # set, so it could also eat trailing characters of the version itself.
        stripped = tmp_version.strip()
        if stripped.endswith(KUNPENG_FLAG):
            stripped = stripped[:-len(KUNPENG_FLAG)].rstrip()
        return stripped

    elif product.isDigitalVer(tmp_version) and DIGIT_SPH in tmp_version:
        return filter_digit_version(tmp_version)

    return tmp_version


def filter_digit_version(tmp_version):
    """
    Remove the patch part of a version, e.g. 6.0.1.SPH1 becomes 6.0.1.

    Everything from the first "SPH" marker onwards is dropped, then trailing
    dots are removed.

    :param tmp_version: version to process
    :return: version without the patch part
    """
    # Cut at the marker. The former rstrip(<sph part>) stripped a character
    # set and could remove digits that belong to the version itself.
    version_part = tmp_version.partition(DIGIT_SPH)[0]
    return version_part.rstrip(".")


def get_actual_version(cli):
    """
    Query the actual version and hot patch version through
    common.getHotPatchVersionAndCurrentVersion.

    :param cli: ssh connection
    :return: tuple (version, patch)
    """
    query_result = common.getHotPatchVersionAndCurrentVersion(cli, LANG)
    # The first element of the result is not used here.
    _, actual_version, actual_patch = query_result
    return actual_version, actual_patch


