# -*-coding:utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
import json
import os
import pathlib

import requests
from requests_toolbelt import MultipartEncoder
from requests.packages import urllib3
import utils.common.log as logger
from plugins.DistributedStorage.scripts.utils.common.DeployConstant import DeployConstant


class RestClient(object):
    """Minimal HTTPS REST client used by the DistributedStorage deployment scripts.

    Holds the target address and the authentication state (auth token, CSRF
    token and, optionally, a logged-in ``requests`` session plus the parsed
    login response) that the request helpers below share.  Certificate
    verification is disabled (``verify=False``) for login/normal/upload
    requests and is off by default for the download/stream-upload helpers;
    the corresponding urllib3 warning is silenced before each request.
    """

    # HTTP methods accepted by normal_request().
    _SUPPORTED_METHODS = ('put', 'post', 'get', 'delete')

    def __init__(self, ip=DeployConstant.IP, port=DeployConstant.PORT):
        """
        :param ip: server address. A bare IPv6 address (contains ':') is wrapped
                   in brackets so it can be embedded in a URL; an address that is
                   already bracketed is kept unchanged.
        :param port: server port (str or int).
        """
        if ":" in ip and not ip.startswith("["):
            ip = "[" + ip + "]"
        self.ip = ip
        self.port = port
        # Auth headers; filled by set_cookie() or normal_request(keepsession=True).
        self.token = None
        self.csrf_token = None
        self.jsession_id = None
        # Session and parsed JSON login response kept by login(keepsession=True).
        self.session = None
        self.res_login = None

    def _build_url(self, path):
        """Return DeployConstant.HTTPS + '<ip>:<port>' + path (port may be int or str)."""
        return "{}{}:{}{}".format(DeployConstant.HTTPS, self.ip, self.port, path)

    def make_header(self, content_type='application/json', **kwargs):
        """
        Build request headers.

        :param content_type: value of 'Content-type'. When None, neither
                             'Content-type' nor 'User-Agent' is set.
        :param kwargs: extra headers; applied last so they override the defaults.
        :return: header dict, including 'X-Auth-Token' / 'X-CSRF-Token' when
                 the corresponding token is set on this client.
        """
        header = {}
        if content_type is not None:
            header = {'Content-type': content_type,
                      'User-Agent': 'Mozilla/5.0'}
        if self.token is not None:
            header['X-Auth-Token'] = self.token
        if self.csrf_token is not None:
            header['X-CSRF-Token'] = self.csrf_token
        header.update(kwargs)
        return header

    def set_cookie(self, token: tuple):
        """
        Set the authentication tokens used by make_header().

        :param token: pair (auth_token, csrf_token).
        """
        self.token, self.csrf_token = token

    def login(self, username, password, keepsession=False):
        """
        Log in to the server.

        :param username: user name
        :param password: password
        :param keepsession: when True, keep the session in ``self.session`` and
                            the JSON-decoded response in ``self.res_login``.
        :return: the login ``requests.Response``.
        :raises Exception: wrapping any error raised while sending the request
                           or decoding the response.
        """
        url = self._build_url(DeployConstant.LOGIN)
        json_data = json.dumps({"username": username,
                                "password": password,
                                "scope": 0})
        login_header = {
            'Content-type': 'application/json',
            'Cookie': '__LANGUAGE_KEY__=zh-CN; __IBASE_LANGUAGE_KEY__=zh-CN'}
        # Silence the warning caused by skipping SSL certificate verification.
        urllib3.disable_warnings()
        try:
            with requests.session() as session:
                res = session.post(url, data=json_data, headers=login_header,
                                   verify=False)
                if keepsession:
                    self.session = session
                    self.res_login = res.json()
                else:
                    res.close()
        except Exception as e:
            err_msg = "Failed to login [https://%s:%s], Detail:%s" % (
                self.ip, self.port, str(e))
            raise Exception(err_msg) from e
        return res

    def get_session(self):
        """
        Query the current session.

        :return: (result, data) taken from the JSON response; either may be
                 None if the key is absent.
        :raises Exception: wrapping any error raised by the request itself.
        """
        url = self._build_url(DeployConstant.CURRENT_SESSION)
        try:
            res = self.normal_request(url, 'get')
        except Exception as e:
            err_msg = 'Failed to get session info, detail:%s' % str(e)
            raise Exception(err_msg) from e
        resp = res.json()
        return resp.get('result'), resp.get('data')

    def normal_request(self, url, method, data=None, **kwargs):
        """
        Generic request; every call except login and file upload goes here.

        :param url: full request URL
        :param method: 'put', 'post', 'get' or 'delete' (case-insensitive)
        :param data: JSON-serialisable body; only sent for PUT and POST
        :param kwargs: 'timeout' (passed to requests) and 'keepsession' (when
                       truthy, reuse self.session and refresh the tokens from
                       the kept login response)
        :return: the closed ``requests.Response``
        :raises ValueError: if ``method`` is not supported
        """
        method = method.lower()
        if method not in self._SUPPORTED_METHODS:
            raise ValueError("Unsupported request method: %s" % method)
        # Silence the warning caused by skipping SSL certificate verification.
        urllib3.disable_warnings()
        timeout, keep_session = kwargs.get("timeout"), kwargs.get("keepsession")
        json_data = json.dumps(data) if data is not None else None
        if keep_session:
            req = self.session
            login_data = self.res_login.get('data')
            self.csrf_token = login_data.get('x_csrf_token')
            self.token = login_data.get('x_auth_token')
        else:
            req = requests.session()
        headers = self.make_header()
        # NOTE(review): the `with` closes req on exit, including the shared
        # self.session when keepsession is set; confirm that is intended.
        with req as session:
            if method == 'put':
                res = session.put(url, data=json_data, headers=headers,
                                  verify=False, timeout=timeout)
            elif method == 'post':
                res = session.post(url, data=json_data, headers=headers,
                                   verify=False, timeout=timeout)
            elif method == 'get':
                res = session.get(url, headers=headers, verify=False, timeout=timeout)
            else:  # 'delete' -- method was validated above
                res = session.delete(url, headers=headers, verify=False, timeout=timeout)
            res.close()
        return res

    def upload_tar_request_stream(self, url, method, file_path):
        """
        Upload a tar package as a streamed multipart POST (field name 'file').

        :param url: upload URL
        :param method: only logged; the request is always a POST
        :param file_path: path of the file to upload
        :return: the closed ``requests.Response``
        """
        logger.info("upload request: %s %s " % (method, url))
        urllib3.disable_warnings()
        # Keep the file open only while the encoder streams it into the request.
        with open(file_path, "rb") as file_obj:
            m = MultipartEncoder(fields={'file': (os.path.basename(file_path),
                                                  file_obj)})
            headers = self.make_header(content_type=m.content_type)
            res = requests.post(
                url=url,
                data=m,
                headers=headers,
                verify=False)
        res.close()
        return res

    def upload_cert_request(self, data_info):
        """
        Upload a certificate over the kept login session, retrying up to 3 times.

        Requires a prior login(..., keepsession=True): the CSRF token is read
        from ``self.res_login`` and the POST is sent on ``self.session``.

        :param data_info: sequence (url, method, file_path, username, password).
                          ``method`` is only logged; the request is always a
                          POST. username/password are used to log in again when
                          the session query reports an invalid session.
        :return: dict with 'status_code' of the last attempt and, when that
                 response had a body, its JSON-decoded 'content'.
        """
        url, method, file_path, username, password = \
            data_info[0], data_info[1], data_info[2], data_info[3], data_info[4]
        logger.info("upload request: %s %s " % (method, url))
        urllib3.disable_warnings()
        retry = 0
        task_data = dict()
        while retry < 3:
            csrf_token = self.res_login.get('data').get('x_csrf_token')
            # A fresh file handle per attempt, closed once the POST has been sent.
            with open(file_path, 'rb') as cert_file:
                m = MultipartEncoder(fields={'filename': (os.path.basename(file_path), cert_file)})
                headers = {'Content-Type': m.content_type}
                with self.session as session:
                    session.cookies['X_CSRF_TOKEN'] = csrf_token
                    res = session.post(url=url, data=m, headers=headers, verify=False)
            task_data['status_code'] = res.status_code
            # Never report a body from an earlier attempt together with this status.
            task_data.pop('content', None)
            if res.content:
                task_data['content'] = json.loads(res.content)
            res.close()
            if res.status_code == 200:
                break
            # If the session has expired, log in again (at most once per attempt).
            result, data = self.get_session()
            logger.info('Current session(%s, %s)' % (result, data))
            if (result or {}).get('code') != 0:
                logger.info('Try login again.')
                self.login(username, password, keepsession=True)
            elif any(resp.get('status') != '1' for resp in data or []):
                logger.info('Session timeout. Try login again.')
                self.login(username, password, keepsession=True)
            retry += 1
            logger.info('Try to upload certificate in %s time.' % retry)
        return task_data

    def download_file_request(self, url, method, verify=False, timeout=None):
        """
        Send ``method`` to ``url`` and return the response, requiring HTTP 200.

        Uses self.session when set, otherwise a new session.

        :raises requests.RequestException: when the status code is not 200.
        """
        urllib3.disable_warnings()
        headers = self.make_header()
        session = self.session if self.session else requests.sessions.Session()
        with session as _session:
            ret = _session.request(method, url, headers=headers, verify=verify, timeout=timeout)
        if ret.status_code != 200:
            err_msg = 'Request url:{}, Response code:{}'.format(url, ret.status_code)
            logger.error(err_msg)
            raise requests.RequestException(err_msg)
        return ret

    def upload_file_stream_request(self, url, file, verify=False, timeout=None):
        """
        Stream-upload ``file`` as a multipart POST (field name 'file').

        Uses self.session when set, otherwise a new session.

        :param url: upload URL
        :param file: path of the file to upload
        :return: the ``requests.Response``
        :raises requests.RequestException: when the status code is not 200.
        """
        # The original format string had two placeholders for one argument and
        # raised TypeError on every call.
        logger.info("upload request: %s %s " % ("POST", url))
        urllib3.disable_warnings()
        with open(file, 'rb') as file_obj:
            m = MultipartEncoder(fields={'file': (pathlib.Path(file).name, file_obj)})
            headers = self.make_header(content_type=m.content_type)
            session = self.session if self.session else requests.sessions.Session()
            with session as _session:
                ret = _session.post(url=url, data=m, headers=headers, verify=verify, timeout=timeout)
        if ret.status_code != 200:
            err_msg = 'Request url:{}, Response code:{}'.format(url, ret.status_code)
            logger.error(err_msg)
            raise requests.RequestException(err_msg)
        return ret
