# -*-coding:utf-8 -*-
import json
import os
import requests
import utils.common.log as logger
from utils.common.ssh_util import Ssh
from plugins.DistributedStorage.scripts.utils.common.DeployConstant import DeployConstant
from requests_toolbelt import MultipartEncoder


class RestClient(object):
    """HTTPS REST client for the storage management API.

    Holds the target address plus the authentication state shared by the
    request helpers: auth token, CSRF token and, after
    ``login(..., keepsession=True)``, the ``requests`` session and the decoded
    login response. Every request is sent with ``verify=False`` (TLS
    certificate verification disabled).
    """

    # HTTP verbs accepted by normal_request (compared case-insensitively).
    _SUPPORTED_METHODS = ('put', 'post', 'get', 'delete')

    def __init__(self, ip=DeployConstant.IP, port=DeployConstant.PORT):
        """
        :param ip: management address; an address containing ':' is treated
                   as IPv6 and wrapped in brackets so it can be used in a URL
        :param port: management port (str or int)
        """
        self.ip = "[" + ip + "]" if ":" in ip else ip
        self.port = port
        self.token = None
        self.csrf_token = None
        # NOTE(review): not read or written anywhere else in this file section.
        self.jsession_id = None
        self.session = None
        self.res_login = None

    def _base_url(self):
        """Return ``DeployConstant.HTTPS + ip + ':' + port`` for building URLs."""
        # str() so that an integer port does not break the concatenation.
        return DeployConstant.HTTPS + self.ip + ":" + str(self.port)

    def make_header(self, content_type='application/json', **kwargs):
        """
        Build request headers.

        :param content_type: value of the 'Content-type' header; when None,
                             neither 'Content-type' nor 'User-Agent' is set
        :param kwargs: extra headers; they override the generated ones
        :return: header dict, with 'X-Auth-Token' / 'X-CSRF-Token' added when
                 the corresponding token is set on this client
        """
        header = {}
        if content_type is not None:
            header = {'Content-type': content_type,
                      'User-Agent': 'Mozilla/5.0'}
        if self.token is not None:
            header['X-Auth-Token'] = self.token
        if self.csrf_token is not None:
            header['X-CSRF-Token'] = self.csrf_token
        header.update(kwargs)
        return header

    def set_cookie(self, token):
        """
        Set the auth token that make_header() sends as 'X-Auth-Token'.

        Despite the name, no HTTP cookie is set.
        :param token: auth token
        :return: None
        """
        self.token = token

    def login(self, username, password, keepsession=False):
        """
        Log in to the management API.

        :param username: user name
        :param password: password
        :param keepsession: when True, store the ``requests`` session in
                            ``self.session`` and the decoded JSON login
                            response in ``self.res_login`` for later
                            ``keepsession`` requests and certificate uploads
        :return: the login ``requests.Response``
        :raises Exception: if the request fails (or, with ``keepsession``,
                           if the response body is not valid JSON)
        """
        url = self._base_url() + DeployConstant.LOGIN
        login_data = {"username": username,
                      "password": password,
                      "scope": 0}
        json_data = json.dumps(login_data)
        login_header = {
            'Content-type': 'application/json',
            'Cookie': '__LANGUAGE_KEY__=zh-CN; __IBASE_LANGUAGE_KEY__=zh-CN'}
        # Silence the warnings caused by skipping TLS certificate verification.
        requests.packages.urllib3.disable_warnings()
        try:
            # NOTE(review): leaving this `with` closes the session that is kept
            # in self.session; requests recreates connection pools on demand, so
            # the session (and its cookies) stays usable, but connections are
            # not reused across calls.
            with requests.session() as session:
                res = session.post(url, data=json_data, headers=login_header,
                                   verify=False)
                if keepsession:
                    self.session = session
                    self.res_login = res.json()
                else:
                    res.close()
        except Exception as e:
            err_msg = "Failed to login [https://%s:%s], Detail:%s" % (
                self.ip, self.port, str(e))
            raise Exception(err_msg)
        return res

    def get_session(self):
        """
        Query the current session.

        The query goes through normal_request() without ``keepsession``, i.e.
        on a fresh ``requests`` session carrying only the header tokens.
        :return: (result, data) taken from the JSON response body
        :raises Exception: if the request or JSON decoding fails
        """
        url = self._base_url() + DeployConstant.CURRENT_SESSION
        try:
            res = self.normal_request(url, 'get')
            resp = res.json()
            result, data = resp.get('result'), resp.get('data')
        except Exception as e:
            err_msg = 'Failed to get session info, detail:%s' % str(e)
            raise Exception(err_msg)
        return result, data

    def normal_request(self, url, method, data=None, timeout=None, keepsession=False):
        """
        Generic request helper; every request except login and file uploads
        goes through here.

        :param url: full request URL
        :param method: 'put', 'post', 'get' or 'delete' (case-insensitive)
        :param data: object to JSON-encode as the body; only sent for put/post
        :param timeout: passed through to ``requests``
        :param keepsession: reuse the session kept by
                            ``login(..., keepsession=True)`` and load the
                            auth/CSRF tokens from its response into this client
        :return: the ``requests.Response`` (closed via ``res.close()``)
        :raises ValueError: if ``method`` is not supported
        """
        # Validate first: an unknown verb used to leave `res` unbound and
        # surface as an UnboundLocalError.
        verb = method.lower()
        if verb not in self._SUPPORTED_METHODS:
            raise ValueError("Unsupported request method: %s" % method)
        # Silence the warnings caused by skipping TLS certificate verification.
        requests.packages.urllib3.disable_warnings()
        json_data = json.dumps(data) if data is not None else None
        if keepsession:
            req = self.session
            login_data = self.res_login.get('data')
            self.csrf_token = login_data.get('x_csrf_token')
            self.token = login_data.get('x_auth_token')
        else:
            req = requests.session()
        headers = self.make_header()
        with req as session:
            send = getattr(session, verb)
            if verb in ('put', 'post'):
                res = send(url, data=json_data, headers=headers,
                           verify=False, timeout=timeout)
            else:
                res = send(url, headers=headers, verify=False, timeout=timeout)
            res.close()
        return res

    def upload_tar_request_stream(self, url, method, file_path):
        """
        Upload a tar package as a streamed multipart POST (form field 'file').

        :param url: upload URL
        :param method: only used for logging; the request is always a POST
        :param file_path: path of the local file to upload
        :return: the ``requests.Response`` (closed via ``res.close()``)
        """
        logger.info("upload request: %s %s " % (method, url))
        # Silence the warnings caused by skipping TLS certificate verification.
        requests.packages.urllib3.disable_warnings()
        # The file must stay open while MultipartEncoder streams it, and be
        # closed afterwards (it previously leaked).
        with open(file_path, "rb") as tar_file:
            m = MultipartEncoder(fields={'file': (os.path.basename(file_path),
                                                  tar_file)})
            headers = self.make_header(content_type=m.content_type)
            res = requests.post(
                url=url,
                data=m,
                headers=headers,
                verify=False)
        res.close()
        return res

    @staticmethod
    def _session_needs_login(result, data):
        """Return True if a get_session() reply shows the session must be renewed."""
        # A missing result/code is treated like a failed query: log in again.
        if (result or {}).get('code') != 0:
            logger.info('Try login again.')
            return True
        for resp in data or []:
            if resp.get('status') != '1':
                logger.info('Session timeout. Try login again.')
                return True
        return False

    def upload_cert_request(self, url, method, file_path, username, password):
        """
        Upload a certificate (form field 'filename') over the kept login
        session, making up to 3 attempts and re-logging in between attempts
        when the session appears to have expired.

        :param url: upload URL
        :param method: only used for logging; the request is always a POST
        :param file_path: path of the local certificate file
        :param username: user name, used to (re-)login
        :param password: password, used to (re-)login
        :return: dict for the last attempt with 'status_code' and, when the
                 response has a body, 'content' (the decoded JSON body)
        """
        logger.info("upload request: %s %s " % (method, url))
        # Silence the warnings caused by skipping TLS certificate verification.
        requests.packages.urllib3.disable_warnings()
        if self.session is None or self.res_login is None:
            # No kept session yet; the credentials are at hand, so log in
            # instead of failing with an AttributeError below.
            self.login(username, password, keepsession=True)
        retry = 0
        task_data = dict()
        while retry < 3:
            csrf_token = self.res_login.get('data').get('x_csrf_token')
            # Re-open per attempt: the encoder consumes the stream, and the
            # handle is closed as soon as the upload finishes.
            with open(file_path, 'rb') as cert_file:
                m = MultipartEncoder(
                    fields={'filename': (os.path.basename(file_path), cert_file)})
                headers = {'Content-Type': m.content_type}
                with self.session as session:
                    session.cookies['X_CSRF_TOKEN'] = csrf_token
                    res = session.post(url=url, data=m, headers=headers, verify=False)
            status_code, body = res.status_code, res.content
            res.close()
            # Rebuilt per attempt so an earlier attempt's body is never
            # reported alongside a later attempt's status code.
            task_data = {'status_code': status_code}
            if body:
                task_data['content'] = json.loads(body)
            if status_code == 200:
                break
            # If the session has expired, log in again (at most once per attempt).
            result, data = self.get_session()
            logger.info('Current session(%s, %s)' % (result, data))
            if self._session_needs_login(result, data):
                self.login(username, password, keepsession=True)
            retry += 1
            logger.info('Try to upload certificate in %s time.' % retry)
        return task_data


class StorageSSHClient:
    """Convenience wrapper around the project ``Ssh`` helper for one host.

    An SSH client is created in the constructor and closed when the object is
    finalised. ``is_root`` records whether the shell currently runs as root
    (initially: whether ``username`` is "root"; set by switch_root()).
    """

    def __init__(self, host_ip, username, password, root_pwd=None):
        """
        :param host_ip: address of the host to connect to
        :param username: login user
        :param password: login password
        :param root_pwd: root password used by switch_root() when none is passed
        """
        self.host_ip = host_ip
        self.username = username
        self.password = password
        self.ssh_client = Ssh.ssh_create_client(host_ip, username, password)
        self.is_root = username == "root"
        self.root_pwd = root_pwd

    def __del__(self):
        # If ssh_create_client() raised in __init__, the instance has no
        # 'ssh_client' attribute; getattr keeps the finaliser from raising
        # AttributeError on such a partially-initialised object.
        ssh_client = getattr(self, "ssh_client", None)
        if ssh_client:
            Ssh.ssh_close(ssh_client)

    def switch_root(self, root_password=None):
        """
        Switch the shell to root with ``su -`` (no-op if already root).

        Also sends ``TMOUT=0`` so the root shell's idle auto-logout is disabled.
        :param root_password: root password; defaults to the constructor's root_pwd
        :raises Exception: re-raises whatever the Ssh helper raised, after logging
        """
        cmd_res = None
        try:
            if self.is_root is False:
                if root_password is None:
                    root_password = self.root_pwd
                # 'assword:' matches both "Password:" and "password:" prompts.
                cmd_res = Ssh.ssh_send_command(
                    self.ssh_client, 'su -', 'assword:', 20)
                cmd_res = Ssh.ssh_send_command(
                    self.ssh_client, root_password, '#', 20)
                cmd_res = Ssh.ssh_send_command(
                    self.ssh_client, 'TMOUT=0', '#', 20)
                self.is_root = True
        except Exception as e:
            logger.error("Failed to su root on host cmd_res:%s, using user "
                         "%s, manager_ip:%s, err:%s"
                         % (str(cmd_res), self.username, self.host_ip, str(e)))
            raise

    def send_cmd(self, cmd, expect, timeout=60, retry_times=0, sensitive=False):
        """
        Run ``cmd`` in the shell and wait for ``expect`` in the output.

        :param cmd: command line to send
        :param expect: text to wait for, passed to Ssh.ssh_send_command
        :param timeout: timeout passed to Ssh.ssh_send_command
        :param retry_times: retry count passed to Ssh.ssh_send_command
        :param sensitive: when True, the command is masked as "****" in the log
        :return: whatever Ssh.ssh_send_command returns
        """
        if not sensitive:
            logger.info("exec cmd: %s, expect: %s" % (cmd, expect))
        else:
            logger.info("exec cmd: %s, expect: %s" % ("****", expect))
        cmd_ret = Ssh.ssh_send_command(
            self.ssh_client, cmd, expect, timeout, retry_times)
        # NOTE(review): the output is logged even when sensitive=True; if the
        # shell echoes the command, it may appear here — confirm.
        logger.debug("cmd ret: %s" % cmd_ret)
        return cmd_ret

    def upload(self, local_file, remote_file):
        """
        Copy ``local_file`` to ``remote_file`` on the host via Ssh.put_file,
        authenticating with the login credentials (not the root password).

        :return: whatever Ssh.put_file returns
        """
        logger.info("upload file: %s -> %s" % (local_file, remote_file))
        return Ssh.put_file(self.host_ip, self.username, self.password,
                            local_file, remote_file)
