import codecs
import os
import shlex
import threading
import time

import paramiko

import utils.common.log as logger

# All durations below are in seconds.
SSH_TIMEOUT = 2 * 60 * 60  # default shell-channel timeout (two hours)
KEEP_ALIVE = 30  # transport keep-alive interval
CMD_TIMEOUT = 2 * 60 * 60  # default wait for expected command output (two hours)


class SshClient(object):
    """Interactive SSH / SFTP helper built on paramiko.

    Most methods operate on an "ssh_client" dict produced by
    create_ssh_client(), holding:

    * ``client``  - the paramiko.Transport,
    * ``channel`` - the interactive shell channel,
    * ``stdout``  - a file object made from the channel,
    * ``prompt``  - the captured shell prompt (only when status_prompt=True).
    """

    # Maximum number of bytes requested per channel recv() call.
    _RECV_SIZE = 65535

    def __init__(self):
        # Both attributes deliberately refer to the SAME re-entrant lock:
        # put_file() recurses into itself while holding it, so it must be
        # an RLock.
        self.r_lock_outer = self.r_lock_inner = threading.RLock()

    @staticmethod
    def _new_decoder():
        """Return a fresh incremental UTF-8 decoder.

        recv() may split a multi-byte character across two chunks. Decoding
        each chunk on its own would then raise UnicodeDecodeError. The
        incremental decoder keeps the partial sequence until the rest
        arrives.
        """
        return codecs.getincrementaldecoder('utf-8')()

    @staticmethod
    def get_prompt(ssh_client):
        """Capture the shell prompt into ``ssh_client['prompt']``.

        Sends an empty line, reads until at least one CRLF has arrived, and
        stores the text after the last CRLF as the prompt.

        :raises Exception: if the channel is closed before a CRLF arrives.
        """
        channel = ssh_client['channel']
        decoder = SshClient._new_decoder()
        channel.send('\n')
        recv_str = decoder.decode(channel.recv(SshClient._RECV_SIZE))
        while '\r\n' not in recv_str:
            chunk = channel.recv(SshClient._RECV_SIZE)
            if not chunk:
                # recv() returns b'' once the remote end has closed the
                # channel; without this check the loop would spin forever.
                raise Exception(
                    'Ssh channel closed before the prompt was received. '
                    'The received string is:{}'.format(recv_str))
            recv_str += decoder.decode(chunk)
            time.sleep(1)
        ssh_client['prompt'] = recv_str.split('\r\n')[-1]

    def create_ssh_client(self, ip, username, passwd, port=22,
                          timeout=SSH_TIMEOUT, pub_key=None,
                          status_prompt=True):
        """Open an interactive shell (with a pty) on ip:port.

        :param pub_key: optional path to an RSA *private* key file. It is
            loaded with paramiko.RSAKey.from_private_key and used together
            with ``passwd``.
        :param timeout: timeout applied to the shell channel.
        :param status_prompt: when True, capture the prompt via get_prompt().
        :return: the ssh_client dict described on the class.
        :raises: whatever paramiko raises. The transport is closed before
            re-raising so that it is not leaked.
        """
        trans = paramiko.Transport((ip, port))
        try:
            if pub_key:
                with open(pub_key, encoding='utf-8') as key_file:
                    key = paramiko.RSAKey.from_private_key(key_file)
                trans.connect(username=username, password=passwd, pkey=key)
            else:
                trans.connect(username=username, password=passwd)
            trans.set_keepalive(KEEP_ALIVE)
            channel = trans.open_session()
            channel.settimeout(timeout)
            channel.get_pty()
            channel.invoke_shell()
            ssh_client = {
                'client': trans,
                'channel': channel,
                'stdout': channel.makefile('r', -1),
            }
            # Give the remote shell time to print its banner/prompt.
            time.sleep(2)
            if status_prompt:
                self.get_prompt(ssh_client)
        except BaseException:
            trans.close()
            raise
        return ssh_client

    @staticmethod
    def ssh_exec_command(ssh_client, cmds):
        """Send ``cmds`` followed by a newline; return Channel.send's result."""
        return ssh_client['channel'].send(cmds + "\n")

    @staticmethod
    def is_expect_recv(expect_list, recvstr):
        """Return True if any string in ``expect_list`` occurs in ``recvstr``."""
        return any(expect in recvstr for expect in expect_list)

    @staticmethod
    def recv_str(recv_str, cmds):
        """Normalise accumulated shell output.

        The steps are:

        * drop one space before each CR and before each LF;
        * remove all CRs;
        * remove the echoed command ``cmds``;
        * cut everything up to and including the first ``$?``.

        The last step discards the echoed command when the terminal wrapped
        it and the plain replace did not match.
        """
        recv_str = recv_str.replace(' \r', '\r')
        recv_str = recv_str.replace(' \n', '\n')
        recv_str = recv_str.replace('\r', '')
        recv_str = recv_str.replace(cmds, '')
        npos = recv_str.find("$?")
        if npos != -1:
            recv_str = recv_str[npos + 2:]
        return recv_str

    def ssh_recv_output(self, ssh_client, cmds,
                        expect_list, timeout=CMD_TIMEOUT):
        """Poll the shell channel until output contains an expected string.

        The channel is polled once per second. Each received chunk is
        appended, and the accumulated text is normalised by recv_str().

        :param ssh_client: dict returned by create_ssh_client().
        :param cmds: the command that was sent; its echo is removed.
        :param expect_list: strings, any of which ends the wait.
        :param timeout: maximum number of one-second polls; a falsy value
            waits indefinitely.
        :return: non-empty output lines with the captured prompt removed.
        :raises Exception: if no expected string was received in time.
        """
        channel = ssh_client['channel']
        decoder = self._new_decoder()
        recv_str = ""
        time_count = 0
        while not self.is_expect_recv(expect_list, recv_str):
            if timeout and time_count >= timeout:
                break
            time.sleep(1)
            time_count += 1
            if channel.closed:
                # Drain whatever is still buffered, then stop polling.
                recv_str = self.recv_str(
                    recv_str + decoder.decode(channel.recv(self._RECV_SIZE),
                                              final=True),
                    cmds)
                break
            if not channel.recv_ready():
                continue
            recv_str = self.recv_str(
                recv_str + decoder.decode(channel.recv(self._RECV_SIZE)),
                cmds)
        if not self.is_expect_recv(expect_list, recv_str):
            raise Exception(
                'Ssh received output cannot find expect in {} seconds. '
                'The received string is:{}'.format(str(timeout), recv_str))

        # 'prompt' exists only when create_ssh_client() captured it.
        # Replacing '' is a no-op, so a missing prompt is simply ignored.
        recv_str = recv_str.replace(ssh_client.get('prompt', ''), '')
        return [line for line in recv_str.split('\n') if line]

    def ssh_exec_command_return(self, ssh_client, cmds, timeout=CMD_TIMEOUT):
        """Run ``cmds`` and wait for the appended exit-status marker line.

        :return: output lines; pass them to is_ssh_cmd_executed() to check
            the exit status.
        """
        new_cmds = "{} ;echo last cmd result: $?".format(cmds)
        self.ssh_exec_command(ssh_client, new_cmds)
        return self.ssh_recv_output(ssh_client, new_cmds,
                                    ['last cmd result:'], timeout)

    def ssh_send_command_expect(self, ssh_client, cmds,
                                expect_str="", timeout=CMD_TIMEOUT):
        """Send ``cmds`` and wait for one of the ';'-separated expect strings.

        With an empty ``expect_str`` the wait ends on "# " or "$ ".
        """
        if not expect_str:
            expect_list = ["# ", "$ "]
        else:
            expect_list = expect_str.split(';')
        self.ssh_exec_command(ssh_client, cmds)
        return self.ssh_recv_output(ssh_client, cmds, expect_list, timeout)

    @staticmethod
    def ssh_close(ssh_client):
        """Close the underlying paramiko transport."""
        ssh_client['client'].close()

    @staticmethod
    def is_ssh_cmd_executed(listout):
        """Return True if the "last cmd result: N" marker reports N == 0.

        Lines that still contain the literal ``$?`` are the echoed command,
        not its output, and are skipped. Text preceding the marker (e.g.
        command output lacking a trailing newline) is ignored. A marker
        whose status cannot be parsed as an integer is skipped as well.
        """
        marker = "last cmd result:"
        for line in listout:
            if marker not in line or "$?" in line:
                continue
            code = line.rsplit(marker, 1)[1].strip()
            try:
                return int(code) == 0
            except ValueError:
                continue
        return False

    @staticmethod
    def is_file_exist(sftp_client, path):
        """Return True if ``path`` can be stat()-ed over SFTP."""
        try:
            sftp_client.stat(path)
        except IOError:
            return False
        return True

    def put_file(self, sftp_client, source, destdir):
        """Upload a local file or directory tree into remote ``destdir``.

        Directories are created remotely if missing and recursed into.
        NOTE(review): a ``source`` that is neither a file nor a directory is
        silently ignored.
        """
        basename = os.path.basename(source)
        remote_path = "{}/{}".format(destdir, basename)
        if os.path.isfile(source):
            with self.r_lock_inner:
                sftp_client.put(source, remote_path)
        elif os.path.isdir(source):
            if not self.is_file_exist(sftp_client, remote_path):
                sftp_client.mkdir(remote_path)
            for child in os.listdir(source):
                with self.r_lock_outer:
                    self.put_file(sftp_client,
                                  "{}/{}".format(source, child),
                                  remote_path)

    def put(self, host, user, passwd, source, destdir, port=22):
        """Upload ``source`` to ``destdir`` on ``host`` over a new SFTP session.

        The SFTP client and transport are always closed, even on error.
        """
        transport = paramiko.Transport(sock=(host, port))
        try:
            transport.connect(username=user, password=passwd)
            sftp_client = paramiko.SFTPClient.from_transport(transport)
            if sftp_client is None:
                raise Exception(
                    'Failed to open sftp session to {}.'.format(host))
            try:
                with self.r_lock_outer:
                    self.put_file(sftp_client, source, destdir)
            finally:
                sftp_client.close()
        finally:
            transport.close()

    def scp_file(self, ssh_client, cpfrom, cpto, password):
        """Run ``scp cpfrom cpto`` in the open shell, answering its prompts.

        Accepts an unknown host key ("(yes/no") and sends ``password`` when
        asked.

        :return: True if scp reported " 100%", otherwise False.
        """
        # shlex.quote also handles paths that contain a single quote.
        scp_cmd = "scp {} {}".format(shlex.quote(cpfrom), shlex.quote(cpto))
        listout = self.ssh_send_command_expect(
            ssh_client,
            scp_cmd,
            expect_str="(yes/no;password:"
        )
        if self.success_to_return(listout, '(yes/no'):
            listout = self.ssh_send_command_expect(ssh_client,
                                                   "yes",
                                                   expect_str="password:")
        if self.success_to_return(listout, 'password:'):
            expect_str = "Permission denied;# ;$ "
            listout = self.ssh_send_command_expect(ssh_client, password,
                                                   expect_str)
        return self.success_to_return(listout, ' 100%')

    def _login_and_switch_root(self, node, switch_cmd, password_prompt,
                               expect_str, error_msg):
        """Log in as node.user, run ``switch_cmd`` and answer with node.root_pwd.

        The shell idle timeout is disabled (TMOUT=0) before and after the
        switch. The connection is closed before any exception propagates,
        so a failed attempt does not leak it.

        :raises Exception: ``error_msg`` if no "# " prompt follows the
            password.
        """
        ssh_client = self.create_ssh_client(node.ip, node.user, node.user_pwd)
        try:
            self.ssh_exec_command_return(ssh_client, "TMOUT=0")
            self.ssh_send_command_expect(ssh_client, switch_cmd,
                                         expect_str=password_prompt)
            result = self.ssh_send_command_expect(ssh_client, node.root_pwd,
                                                  expect_str=expect_str)
            if self.failed_to_return(result, "# ", ssh_client):
                raise Exception(error_msg)
            self.ssh_exec_command_return(ssh_client, "TMOUT=0")
        except BaseException:
            # Closing an already-closed transport is harmless.
            self.ssh_close(ssh_client)
            raise
        return ssh_client

    def get_ssh_client_user_sudo_root(self, node):
        """Log in as the normal user, then switch to root via ``sudo su root``.

        """
        return self._login_and_switch_root(
            node, "sudo su root", "password for root:",
            "Sorry, try again;Permission denied;# ;$ ",
            "Sudo su root failed.")

    def get_ssh_client_user_su_root(self, node):
        """Log in as the normal user, then switch to root via ``su root``.

        """
        return self._login_and_switch_root(
            node, "su root", "Password: ",
            "Permission denied;# ;$ ",
            "Su root failed.")

    def get_ssh_client(self, node, sudo_type="sudo"):
        """Return a root shell on ``node``.

        First tries a normal-user login followed by a switch to root. If
        that fails for any reason, logs a warning and logs in as root
        directly.

        :param node: object providing ip, user, user_pwd and root_pwd.
        :param sudo_type: how to switch accounts: "sudo" (case-insensitive);
            any other value uses "su".
        :return: the ssh_client dict of the root shell.
        """
        try:
            if sudo_type.lower() == "sudo":
                return self.get_ssh_client_user_sudo_root(node)
            return self.get_ssh_client_user_su_root(node)
        except Exception as e:
            logger.warn(f'Normal user login node:{node.ip} failed, '
                        f'return message:{str(e)}, '
                        'use root relogin directly.')
            ssh_client = self.create_ssh_client(node.ip, "root", node.root_pwd)
            self.ssh_exec_command_return(ssh_client, "TMOUT=0")
            return ssh_client

    def failed_to_return(self, result_list, expect_str, ssh_client):
        """Check whether ``expect_str`` is missing from the output.

        Returns False if any line of ``result_list`` contains
        ``expect_str``. Otherwise closes ``ssh_client`` and returns True.
        """
        if self.success_to_return(result_list, expect_str):
            return False
        self.ssh_close(ssh_client)
        return True

    def success_to_return(self, result_list, expect_str):
        """Return True if any line of ``result_list`` contains ``expect_str``."""
        return any(self.is_expect_recv([expect_str], result)
                   for result in result_list)

    @staticmethod
    def bytes_to_str(input_content):
        """Decode bytes with the default (UTF-8) codec; return others as-is."""
        if isinstance(input_content, bytes):
            return input_content.decode()
        return input_content
