# Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
# -*- coding: utf-8 -*-
"""
检查项节点类
"""
import os
import json
import logging
from abc import ABC, abstractmethod
from utils.ssh_util import SSHCmd
from func.upgrade.common.upgrade_local_ssh import UpgradeLocalSsh
from func.upgrade.upgrade_operation_mgr.common.upgrade_util import UpgradeUtil
from .check_exception import CmdExecutionError

# Module-level logger used by all node classes in this file.
_local_logger = logging.getLogger(__name__)
# Maximum number of stdout characters included in "success" log lines.
_max_stdout_len = 200


def _filter_passwd_expire_warning(message):
    """
    过滤密码临期导致"Warning:"开头的警告信息
    :param message: 待过滤字符串
    """
    char_eol = os.linesep
    keyword = "Warning: your password will expire in"
    while message and message.startswith(keyword):
        idx = message.find(char_eol)
        message = message[idx + 1:]
    return message


class Node(ABC):
    """
    Base class describing a node that check items execute commands on.

    Subclasses supply the node name, which is the prefix used to look up the
    node's connection settings in the global kvs dict.
    """

    def __init__(self, kvs):
        """
        Initialize connection settings from the global dict.

        Note: the default _get_ip_list runs a command on the node, so with the
        base implementation construction performs remote execution.
        :param kvs: global dict holding node IPs, users, passwords and ports
        """
        self.kvs = kvs
        self.node_name = self._get_node_name()
        self.login_ip = kvs.get(f"{self.node_name}_ip")
        self.login_user = kvs.get("standby_login_user")
        self.login_pwd = kvs.get(f"{self.node_name}_{self.login_user}_pwd")
        self.su_user = "ossadm"
        self.su_pwd = kvs.get(f"{self.node_name}_{self.su_user}_pwd")
        self.ssh_port = kvs.get(f"{self.node_name}_ssh_port")
        self.sftp_port = kvs.get(f"{self.node_name}_sftp_port")
        self.ip_list = self._get_ip_list()

    @abstractmethod
    def _get_node_name(self):
        """
        Return the node name used as the key prefix into kvs.
        """
        pass

    def _get_ip_list(self):
        """
        Collect maintenance IPs of the other nodes listed in nodelists.json.

        Reads /opt/oss/manager/etc/sysconf/nodelists.json as su_user and, for each
        entry under "nodeList" except key '0', collects the "IP" of every address
        whose "usage" contains 'maintenance'. Missing "nodeList", "IPAddresses"
        or "usage" fields are treated as empty.
        :return: list of IP values (possibly empty)
        :raises CmdExecutionError: if reading the file fails
        :raises ValueError: if the file content is not valid JSON
        """
        cmd = "cat /opt/oss/manager/etc/sysconf/nodelists.json"
        data = json.loads(self.execute_su_cmd(cmd))
        node_ips = []
        for node_id, node_info in (data.get('nodeList') or {}).items():
            # Skip key '0' (per the original intent: this node's own IP)
            if node_id == '0':
                continue
            for address in node_info.get('IPAddresses') or []:
                if 'maintenance' in (address.get('usage') or ()):
                    node_ips.append(address.get('IP'))
        return node_ips

    def _handle_cmd_result(self, action, cmd, stdcode, stdout, stderr):
        """
        Validate a command result and return its filtered stdout.

        :param action: log/message prefix, e.g. "Execute cmd" or "Execute su cmd"
        :param cmd: the command that was executed
        :param stdcode: exit code; anything other than 0 is treated as failure
        :param stdout: standard output of the command
        :param stderr: standard error of the command
        :return: stdout with leading password-expiry warnings removed
        :raises CmdExecutionError: if stdcode != 0
        """
        if stdcode != 0:
            _local_logger.error("%s failed. Cmd: %s Stdcode: %s Stdout: %s Stderr: %s IP: %s",
                                action, cmd, stdcode, stdout, stderr, self.login_ip)
            raise CmdExecutionError(f"{action} failed. Cmd: {cmd} Stdcode: {stdcode} Stdout: {stdout} "
                                    f"Stderr: {stderr} IP: {self.login_ip}")
        stdout = _filter_passwd_expire_warning(stdout)
        # Only a prefix of stdout is logged to keep log lines short.
        _local_logger.info("%s success. Cmd: %s Stdcode: %s Stdout: %s IP: %s",
                           action, cmd, stdcode, stdout[:_max_stdout_len], self.login_ip)
        return stdout

    def execute_cmd(self, cmd):
        """
        Execute a command on the node as login_user over SSH.

        :param cmd: command text
        :return: str standard output of the command
        :raises CmdExecutionError: if the command exits with a non-zero code
        """
        stdcode, stdout, stderr = SSHCmd.run_cmd(cmd, self.login_ip, self.login_user, self.login_pwd,
                                                 self.ssh_port, path=None, watchers=None, timeout=180,
                                                 disown=False)
        return self._handle_cmd_result("Execute cmd", cmd, stdcode, stdout, stderr)

    def execute_su_cmd(self, cmd):
        """
        Execute a command on the node after switching to su_user.

        :param cmd: command text
        :return: str standard output of the command
        :raises CmdExecutionError: if the command exits with a non-zero code
        """
        stdcode, stdout, stderr = SSHCmd.run_su_cmd(cmd, self.login_ip, self.login_user, self.login_pwd,
                                                    self.ssh_port, self.su_user, self.su_pwd,
                                                    path=None, watchers=None, timeout=180, disown=False)
        return self._handle_cmd_result("Execute su cmd", cmd, stdcode, stdout, stderr)


class MasterNode(Node):
    """Management node of the master site; commands go through UpgradeLocalSsh."""

    def __init__(self, kvs):
        """
        :param kvs: global dict
        """
        super().__init__(kvs)
        self.node_type = 'master'

    def _get_node_name(self):
        """
        Return the kvs key prefix for this node.
        :return: "node_omp01" when kvs["single_mgr_domain"] == "no", otherwise "node_nmsserver"
        """
        if self.kvs.get("single_mgr_domain") == "no":
            return "node_omp01"
        return "node_nmsserver"

    def execute_cmd(self, cmd):
        """
        Execute a command via UpgradeLocalSsh.send_custom_local_cmd against login_ip.

        :param cmd: command text
        :return: str standard output with password-expiry warnings removed
        :raises CmdExecutionError: if the helper reports failure
        """
        ok, output = UpgradeLocalSsh.send_custom_local_cmd(cmd, time_out=180, node_ip=self.login_ip)
        if ok:
            output = _filter_passwd_expire_warning(output)
            _local_logger.info("Execute cmd success. Cmd: %s Stdcode: %s Stdout: %s IP: %s",
                               cmd, ok, output[:_max_stdout_len], self.login_ip)
            return output
        _local_logger.error("Execute cmd failed. Cmd: %s Stdcode: %s Stdout: %s IP: %s",
                            cmd, ok, output, self.login_ip)
        raise CmdExecutionError(f"Execute cmd failed. Cmd: {cmd} Stdcode: {ok} Stdout: {output} "
                                f"IP: {self.login_ip}")

    def execute_su_cmd(self, cmd):
        """
        Delegate to execute_cmd without switching users.

        The original author notes that the master-site management node already
        runs as ossadm, so no su step is needed.
        """
        return self.execute_cmd(cmd)


class StandbyNode(Node):
    """Management node of the standby site; uses the SSH execution inherited from Node."""

    def __init__(self, kvs):
        """
        :param kvs: global dict
        """
        super().__init__(kvs)
        self.node_type = 'standby'

    def _get_node_name(self):
        """
        Return the kvs key prefix for this node.
        :return: "node_standby_omp01" when kvs["single_mgr_domain"] == "no",
                 otherwise "node_standby_nmsserver"
        """
        if self.kvs.get("single_mgr_domain") == "no":
            return "node_standby_omp01"
        return "node_standby_nmsserver"


class ArbitrateNode(Node):
    """Node of the third-party (arbitration) site."""

    def __init__(self, kvs):
        """
        Initialize the arbitration node.

        Node.__init__ already sets node_name (from _get_node_name) and ip_list
        (from _get_ip_list, which returns [] here), so only the settings that
        differ from the base defaults are overridden below.
        :param kvs: global dict
        """
        super().__init__(kvs)
        self.node_type = 'arbitrate'
        # The login IP is fetched through UpgradeUtil rather than read from kvs.
        self.login_ip = UpgradeUtil.get_value_from_main(kvs.get("easysuite.task_id"), "arbitrate_ip")
        self.login_user = kvs.get(f"{self.node_name}_login_user")
        self.login_pwd = kvs.get(f"{self.node_name}_login_pwd")
        self.su_user = "root"
        self.su_pwd = kvs.get(f"{self.node_name}_root_pwd")

    def _get_node_name(self):
        """Return the fixed kvs key prefix "node_arbitrate"."""
        return "node_arbitrate"

    def _get_ip_list(self):
        """Skip the nodelists.json lookup for this node; always return an empty list."""
        return []
