# -*- coding: UTF-8 -*-
import common
import cliUtil
from common_utils import get_err_msg
from common import UnCheckException


# NOTE(review): py_java_env and PY_LOGGER are neither defined nor imported in
# this file; presumably the hosting framework injects them into the script's
# namespace before it runs -- confirm.
LANG = common.getLang(py_java_env)
LOGGER = common.getLogger(PY_LOGGER, __file__)
PY_JAVA_ENV = py_java_env


# Product version that NfsServiceCheck.check_version treats as affected.
RISK_VERSION = "V500R007C60SPC300"
# Hot patch threshold: patch versions comparing lower than this are affected.
RISK_PATCH_VERSION = "V500R007C60SPH305"
# Patch version value that is also treated as affected; presumably what the
# version query reports when no hot patch is installed -- confirm.
NO_PATCH = "--"


def execute(cli):
    """
    Entry point of the NFS service check.

    Documented procedure:
    Step 1: log in to the device as the admin user;
    Step 2: run "show upgrade package" to obtain the system software
            version and the hot patch version;
    Step 3: run "change user_mode current_mode=developer" to enter
            developer mode;
    Step 4: run "show nfs server_info" to obtain the NFS server
            information.

    :param cli: CLI connection handed to NfsServiceCheck.
    :return: tuple (flag, cli_echo, msg) where flag and msg come from
             NfsServiceCheck.execute_check() and cli_echo is every collected
             CLI echo joined with newlines.
    """
    checker = NfsServiceCheck(cli, LANG, PY_JAVA_ENV, LOGGER)
    result_flag, result_msg = checker.execute_check()
    # Build the echo only after the check ran, so it contains all commands.
    cli_echo = "\n".join(checker.all_cli_ret)
    return result_flag, cli_echo, result_msg


class NfsServiceCheck:
    """
    NFS service check for one specific product version / hot patch range.

    The check only inspects the NFS server information when the device runs
    RISK_VERSION with a hot patch lower than RISK_PATCH_VERSION (or the
    NO_PATCH value); every other version passes immediately.
    """

    def __init__(self, cli, lang, env, logger):
        """
        :param cli: CLI connection passed to the common/cliUtil helpers.
        :param lang: language passed to the message and CLI helpers.
        :param env: environment object; env.get("devInfo") must provide
                    getProductVersion().
        :param logger: logger exposing logInfo() and logError().
        """
        self.cli = cli
        self.lang = lang
        self.env = env
        self.logger = logger
        # Product and hot patch version as reported by the device CLI
        # (filled in by check_version).
        self.pro_version = ""
        self.pat_version = ""
        # Product version as reported by devInfo (filled in by execute_check);
        # used in the failure message.
        self.product_version = ""
        # Raw echo of every CLI command issued so far, in execution order.
        self.all_cli_ret = []

    def execute_check(self):
        """
        Run the complete check.

        :return: tuple (flag, msg):
                 (True, "") when the version is not affected or the NFS
                 server is not "nfsd";
                 (False, msg) when the affected version runs "nfsd";
                 (cliUtil.RESULT_NOCHECK, msg) when the check could not be
                 completed.
        """
        try:
            self.product_version = str(
                self.env.get("devInfo").getProductVersion()
            )

            # Versions outside the affected range pass without further checks.
            if not self.check_version():
                return True, ""

            # Affected version: the result depends on the NFS server info.
            return self.check_nfs_message()

        except UnCheckException as e:
            self.logger.logError(str(e))
            return cliUtil.RESULT_NOCHECK, e.errorMsg
        except Exception as e:
            # Boundary handler: keep returning NOCHECK, but record the cause
            # instead of silently dropping it. repr() is used so that the
            # logging itself cannot fail on non-ASCII exception text.
            self.logger.logError(
                "nfs service check failed unexpectedly: {!r}".format(e)
            )
            return cliUtil.RESULT_NOCHECK, self._abnormal_msg()

    def check_version(self):
        """
        Query the product and hot patch version and decide whether the device
        is in the affected range.

        :return: True when the product version equals RISK_VERSION and the
                 patch version is NO_PATCH or lower than RISK_PATCH_VERSION
                 (i.e. the NFS check is required); False otherwise.
        :raises UnCheckException: when the version query fails.
        """
        (
            flag,
            self.pro_version,
            self.pat_version,
            cli_ret,
            err_msg,
        ) = common.getProductVersionAndHotPatchVersion(
            self.cli, self.logger, self.lang
        )
        self.all_cli_ret.append(cli_ret)
        if flag is not True:
            raise UnCheckException(err_msg, cli_ret)

        self.logger.logInfo(
            "product_version is: {}; patch_version is: {}".format(
                self.pro_version, self.pat_version
            )
        )
        # The patch comparison is a plain string comparison.
        # NOTE(review): this assumes patch versions share the fixed-width
        # format of RISK_PATCH_VERSION -- confirm.
        return self.pro_version == RISK_VERSION and (
            self.pat_version < RISK_PATCH_VERSION
            or self.pat_version == NO_PATCH
        )

    def check_nfs_message(self):
        """
        Query the NFS server information and inspect its 'Server Name' field.

        NOTE(review): developer mode is entered here and never left
        explicitly in this method; confirm that excuteCmdInDeveloperMode or
        the framework restores the CLI mode.

        :return: tuple (flag, msg): (True, "") when there is no record or the
                 server name is not "nfsd" (case-insensitive); (False, msg)
                 when it is "nfsd".
        :raises UnCheckException: when a command fails or the output does not
                 contain a 'Server Name' field.
        """
        flag, cli_ret, err_msg = cliUtil.enterDeveloperMode(
            self.cli, self.lang
        )
        self.all_cli_ret.append(cli_ret)
        if flag is not True:
            raise UnCheckException(err_msg, cli_ret)

        cmd = "show nfs server_info"
        flag, cli_ret, err_msg = cliUtil.excuteCmdInDeveloperMode(
            self.cli, cmd, True, self.lang
        )

        self.all_cli_ret.append(cli_ret)
        if flag is not True:
            raise UnCheckException(err_msg, cli_ret)

        if cliUtil.queryResultWithNoRecord(cli_ret):
            return True, ""

        records = cliUtil.getVerticalCliRet(cli_ret)
        server_name = self._get_server_name(records)
        if server_name is None:
            # Unparseable output: report NOCHECK explicitly (same message as
            # the generic handler) instead of relying on an IndexError or
            # AttributeError being swallowed by the caller.
            raise UnCheckException(self._abnormal_msg(), cli_ret)

        if server_name.lower() == "nfsd":
            err_msg = "software.check.nfs.service.not.pass"
            return False, get_err_msg(
                self.lang,
                err_msg,
                (
                    self.product_version,
                    self.pat_version,
                ),
            )
        return True, ""

    def _abnormal_msg(self):
        """Return the localized generic 'query result abnormal' message."""
        return common.getMsg(self.lang, "query.result.abnormal")

    @staticmethod
    def _get_server_name(records):
        """
        Return the 'Server Name' value of the first parsed record.

        :param records: list of dicts parsed from the CLI output.
        :return: the field value, or None when there is no record or the
                 first record has no 'Server Name' key.
        """
        if not records:
            return None
        return records[0].get("Server Name")
