'''
common_method
'''

import datetime
import json
import os
import re
import shlex
import shutil
import subprocess
import sys
import time
from copy import deepcopy

sys.path.append(os.path.split(os.path.abspath(__file__))[0])

from set_log import set_log
from util.httpclient import CommonHttpClient
from util import common
from util import ossext

# Directory part of this module's __file__ ("" when __file__ has no directory component).
CUR_PATH = os.path.dirname(__file__)


def get_temporary_path():
    '''
    Get the temporary (scratch) directory by running get_temporary_dir.sh,
    which lives two directories above this module's directory.

    :return: the script's combined stdout/stderr with the trailing newline stripped
    :raises ValueError: if the script exits non-zero or prints nothing
    '''
    # abspath() maps an empty CUR_PATH (module loaded via a bare file name) to the
    # current directory; otherwise the path would degrade to "/../../...".
    script_path = os.path.join(os.path.abspath(CUR_PATH), "..", "..", "get_temporary_dir.sh")
    # Quote the path so directories containing spaces or shell metacharacters work.
    command = "sh %s" % shlex.quote(script_path)
    code, temporary_path = subprocess.getstatusoutput(command)
    if code != 0 or not temporary_path:
        raise ValueError("get temporary_path failed")
    return temporary_path

def get_biggest_version(dir_path, contain="", end="", log=None):
    '''
    Return the path of the entry in dir_path that carries the highest version.

    An entry is considered when its name contains `contain` and ends with `end`.
    Its version string is the first match of the version pattern (digits, a dot,
    anything, a dot, digits). Versions are compared numerically component by
    component, so 2.0.0 > 1.10.0 > 1.9.0.

    :param dir_path: directory whose entries are scanned (not recursive)
    :param contain: substring the entry name must contain ("" matches all)
    :param end: suffix the entry name must end with ("" matches all)
    :param log: logger; a default stream logger is created when None
    :return: full path of the highest-versioned entry, or "" if nothing matched
    '''
    if log is None:
        log = set_log(stream=True)
    log.info("start get_biggest_version, dir_path: %s", dir_path)
    version_re = re.compile(r"\d+\..*\.\d+")
    number_re = re.compile(r"\d+")
    candidates = []
    for name in os.listdir(dir_path):
        if contain not in name or not name.endswith(end):
            continue
        match = version_re.search(name)
        if not match:
            continue
        version_str = match.group()
        # Compare as a tuple of ints. Concatenating the digits into one number
        # ranks 1.10.0 (1100) above 2.0.0 (200) and fails on non-digit parts.
        version_key = tuple(int(part) for part in number_re.findall(version_str))
        candidates.append((version_key, os.path.join(dir_path, name)))
        log.info("name: %s, version_str: %s, version_num: %s", name, version_str, version_key)
    if not candidates:
        return ""
    log.info(candidates)
    # max() returns the first of equal keys, matching a stable descending sort.
    return max(candidates, key=lambda item: item[0])[1]


def get_all_db_name(instance_name, db_type, product_type="fabric", tenant_name="",
                    read_container_list_file=None, container_list_info=None, log=None):
    '''
    Return the database names of an instance as a comma-separated string.

    The container list is taken from `container_list_info` if given, else read
    from `read_container_list_file`, else fetched via
    CommonMethod().get_container_list(product_type). A container matches when
    `instance_name` is a substring of its key, its "containerType" equals
    `db_type`, and it has a "dbList". If several containers match, the last one
    iterated wins.

    :param instance_name: substring to look for in the container keys
    :param db_type: container type to match, e.g. "gauss" or "zenith"
    :param product_type: forwarded to get_container_list when fetching
    :param tenant_name: currently unused
    :param read_container_list_file: JSON file holding a saved container list
    :param container_list_info: already-parsed container list
    :param log: logger; a default stream logger is created when None
    :return: "db1,db2,..." or "" when nothing matches
    '''
    if log is None:
        log = set_log(stream=True)
    log.info("start get_all_db_name, product_type: %s", product_type)
    if container_list_info is None:
        if read_container_list_file is not None:
            with open(read_container_list_file, "rb") as _f:
                container_list_info = json.load(_f)
        else:
            container_list_info = json.loads(CommonMethod().get_container_list(product_type)[1])
    db_names = []
    for key, container in list(container_list_info.items()):
        if instance_name not in key:
            continue
        if container["containerType"] != db_type or "dbList" not in container:
            continue
        db_names = list(container["dbList"])
    return ",".join(db_names)


class CommonMethod:
    '''
    Common helpers for OSS/NCE maintenance scripts: tenant and product queries,
    platform REST calls, and SQL execution on gauss/zenith database instances.
    '''
    def __init__(self):
        '''Create the logger, the return-code constants and the scratch directory.'''
        self.log = set_log(stream=True)
        # Return codes used by send_rest() and get_container_list().
        self.success = 0
        self.failed = 1
        # Scratch directory for query_product output and copied SQL files.
        self.temporary_path = get_temporary_path()

    def get_tenant_name(self):
        '''
        Get the tenant name.

        :return: "NCECOMMONE" if /opt/oss/NCECOMMONE exists, otherwise "NCE"
        '''
        self.log.info("start get_tenant_name")
        if os.path.exists("/opt/oss/NCECOMMONE"):
            tenant_name = "NCECOMMONE"
        else:
            tenant_name = "NCE"
        self.log.info("tenant_name: %s", tenant_name)
        return tenant_name

    def query_product(self, tenant_name="", out_dir="", retry_times=0, sleep_time=6):
        '''
        Export the product information of a tenant with queryproduct.sh.

        :param tenant_name: tenant to query; defaults to get_tenant_name()
        :param out_dir: output directory; defaults to a new timestamped directory
            under self.temporary_path
        :param retry_times: number of extra attempts after the first failure
        :param sleep_time: seconds to wait between attempts
        :return: the output directory (returned even if every attempt failed)
        '''
        self.log.info("start query_product: %s", locals())
        if not tenant_name:
            tenant_name = self.get_tenant_name()
        if not out_dir:
            now_time = self._create_timestamp_string()
            out_dir = os.path.join(self.temporary_path, "product_%s" % now_time)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
        self.log.info("retry_times: %s, sleep_time: %s", retry_times, sleep_time)
        command = "/opt/oss/manager/tools/resmgr/queryproduct.sh -pn %s -output %s" % (
            shlex.quote(tenant_name), shlex.quote(out_dir))
        for retry_time in range(retry_times + 1):
            self.log.info("retry_time: %s", retry_time)
            result_code, result_str = subprocess.getstatusoutput(command)
            self.log.info("command: %s, result_code: %s, result_str: %s", command, result_code, result_str)
            if result_code == 0:
                self.log.info("query_product success")
                break
            # Only wait when another attempt follows.
            if retry_time < retry_times:
                time.sleep(sleep_time)
        else:
            self.log.error("query_product failed after %s attempt(s)", retry_times + 1)
        return out_dir

    def compare_version(self):
        '''
        Compare the iMaster NCE-Fabric version from the DCN soft policy file with
        the product version reported by queryproduct.sh.

        :return: "true" if both versions are non-empty and equal, else "false"
        '''
        tenant_name = self.get_tenant_name()
        out_dir = self.query_product(tenant_name)
        file_path = "/opt/oss/%s/apps/DCNService/pub/dflc/DCN_soft_policy.json" % tenant_name
        dcn_product_version = ""
        with open(file_path, "rb") as _f:
            content_dcn = json.load(_f)
        # If several entries match, the last one wins.
        for info in content_dcn:
            if info["nesName"] == "iMaster NCE-Fabric":
                dcn_product_version = info["serialNumber"]

        with open(os.path.join(out_dir, "product_%s.json" % tenant_name), "rb") as _f2:
            content_query = json.load(_f2)
        query_product_version = content_query["productext"].get("product_version", "")

        self.log.info("dcn_product_version: %s, query_product_version: %s",
                      dcn_product_version, query_product_version)
        if not dcn_product_version or not query_product_version:
            return "false"
        if dcn_product_version != query_product_version:
            return "false"
        return "true"

    def get_local_ip(self):
        '''
        Get the local IP address, supporting both naming styles of util.common.

        :return: ip
        '''
        if hasattr(common, "getLocalIP"):
            return common.getLocalIP()
        return common.get_local_ip()

    def get_ir_port(self):
        '''
        Get the IR access port, supporting both naming styles of util.common.

        :return: port
        '''
        if hasattr(common, "getIRAccessPort"):
            return common.getIRAccessPort()
        return common.get_ir_access_port()

    def get_gmt_port(self):
        '''
        Get the management-plane IR port, supporting both naming styles of util.common.

        :return: port
        '''
        if hasattr(common, "getMgmtIRPort"):
            return common.getMgmtIRPort()
        return common.get_mgmt_ir_port()

    def send_rest(self, _ip, port, url, mode, **params):
        '''
        Call a platform REST interface, retrying on failure.

        :param _ip: target IP address
        :param port: target port
        :param url: request URL path
        :param mode: HTTP method: "GET", "POST", "DELETE" or "PUT"
        :param params: optional keyword arguments:
            body        -- request body for POST/DELETE (default None)
            retry_times -- number of extra attempts after the first (default 0)
            sleep       -- seconds to wait between attempts (default 0)
            content     -- label used in log messages (default "")
        :return: (self.success or self.failed, response of the last attempt)
        :raises ValueError: if mode is not a supported method
        '''
        if mode not in ("GET", "POST", "DELETE", "PUT"):
            # Without this check `status` would be unbound further down.
            raise ValueError("unsupported mode: %s" % mode)
        http = CommonHttpClient(_ip, port, True, False)
        body = params.get("body", None)
        retry_times = params.get("retry_times", 0)
        sleep = params.get("sleep", 0)
        content = params.get("content", "")
        ret = None
        for times in range(retry_times + 1):
            if mode == "GET":
                status, ret = http.get(url)
            elif mode == "POST":
                status, ret = http.post(url, body)
            elif mode == "DELETE":
                status, ret = http.delete(url, body)
            else:
                # NOTE(review): body is not sent for PUT (unchanged behavior);
                # confirm whether CommonHttpClient.put should receive it.
                status, ret = http.put(url)
            if status in (200, 204):
                self.log.info("%s [SUCCESS] %s %s response: %s, retry_times:%s",
                              content, mode, url, status, times)
                return self.success, ret
            self.log.error("%s [FAILED] %s %s response: %s, retry_times:%s",
                           content, mode, url, status, times)
            # The loop counter always advances, so at most retry_times + 1
            # attempts are made; only sleep when another attempt follows.
            if times < retry_times:
                time.sleep(sleep)
        return self.failed, ret

    def get_db_ip_and_port(self, instance_name, db_type):
        '''
        Get the IP, port and instance id used to connect to a database instance,
        by filtering the output of `dbsvc_adm -cmd query-db-instance`.

        When several rows match, the row containing "Master" is used.

        :param instance_name: instance name (used as a `grep -E` pattern)
        :param db_type: "gauss" or "zenith"
        :return: (ip, port, instance_id); three empty strings if not found
        '''
        self.log.info("start get_db_ip_and_port: %s", locals())
        self._check_db_type(db_type)
        not_found = ("", "", "")
        query = ("/opt/oss/manager/apps/DBAgent/bin/dbsvc_adm -cmd query-db-instance |"
                 "grep -E %s | grep %s" % (shlex.quote(instance_name), shlex.quote(db_type)))
        result, result_str = subprocess.getstatusoutput(query + " | wc -l")
        self.log.info(result_str)
        if result != 0 or not result_str:
            return not_found
        try:
            result_num = int(result_str)
        except ValueError:
            return not_found
        if result_num == 0:
            return not_found
        if result_num == 1:
            self.log.info("single db instance")
            command = query
        else:
            self.log.info("multi db instance")
            command = query + " | grep Master"
        self.log.info(command)
        result, result_str = subprocess.getstatusoutput(command)
        self.log.info(result_str)
        if result == 0 and result_str:
            return self._parse_db_instance_line(result_str)
        return not_found

    @staticmethod
    def _parse_db_instance_line(output):
        '''
        Extract (ip, port, instance_id) from the first line of dbsvc_adm output.

        Whitespace-separated column 0 is the instance id, column 5 the IP and
        column 6 the port. Returns three empty strings when the line has fewer
        than 7 columns instead of raising IndexError.
        '''
        lines = output.splitlines()
        fields = lines[0].split() if lines else []
        if len(fields) < 7:
            return "", "", ""
        return fields[5], fields[6], fields[0]

    def get_db_pass(self, instance_name, db_type, pass_type="dbUserPasswd", db_name=""):
        '''
        Get the decrypted password of a database instance.

        The container list is fetched through the REST API. If that fails, or
        the matching container has no password, containerlist.json is read.

        :param instance_name: database instance name (substring of the container key)
        :param db_type: "gauss" or "zenith"
        :param pass_type: "dbUserPasswd" (connection password) or "adminPassword" (sys password)
        :param db_name: optional; when the container-level field is empty the
            field is looked up under the container's entry for db_name.lower()
        :return: the decrypted password, or "" if no container matches
        '''
        self.log.info("start get_db_pass: %s", locals())
        self._check_db_type(db_type)
        try:
            result_code, ret = self.get_container_list()
        except ValueError:
            # get_container_list() raises instead of returning self.failed;
            # catch it so the file fallback below can actually be reached.
            self.log.exception("get_container_list raised")
            result_code, ret = self.failed, None
        if result_code != self.success:
            self.log.info("get_container_list failed,try get db password by file")
            db_un_passed = self._get_db_pass_by_file(instance_name, db_type, pass_type)
            return ossext.Cipher.decrypt(db_un_passed)
        db_info = json.loads(ret)
        for key, container in db_info.items():
            if instance_name not in key or container["containerType"] != db_type \
                    or "dbList" not in container:
                continue
            db_un_passed = container.get(pass_type)
            # With release 19.1 the field returned by the REST API is empty.
            if not db_un_passed and db_name:
                db_un_passed = container.get(db_name.lower(), {}).get(pass_type)
            if not db_un_passed:
                self.log.info("get_container_list result has no password, "
                              "try get db password by file")
                db_un_passed = self._get_db_pass_by_file(instance_name, db_type, pass_type)
            db_pass = ossext.Cipher.decrypt(db_un_passed)
            self.log.info("get_db_pass success")
            return db_pass
        return ""

    def get_container_list(self, product_type="fabric", save_container_list_file=""):
        '''
        Export the database container list of a tenant through the platform REST API.

        :param product_type: "fabric" queries the product tenant (get_tenant_name());
            any other value queries the "manager" tenant
        :param save_container_list_file: if non-empty, the parsed result is also
            written to this file as indented JSON
        :return: (self.success, raw response text)
        :raises ValueError: if the REST call fails; json.loads also raises a
            ValueError subclass for an invalid response body
        '''
        self.log.info("start get_container_list, product_type: %s", product_type)
        tenant_name = self.get_tenant_name() if product_type == "fabric" else "manager"
        input_param = {'tenant': tenant_name}
        _ip = self.get_local_ip()
        port = self.get_gmt_port()
        url = "/rest/plat/dbmgr/v1/main/instances/action?action-id=export-containerlist"
        result_code, ret = self.send_rest(_ip, port, url, mode="POST",
                                          body=input_param, sleep=3, retry_times=3,
                                          content="get_db_json")
        if result_code != self.success:
            self.log.error("get_container_list failed")
            raise ValueError("get_container_list failed")
        # Parsed unconditionally, which also validates the response body.
        result = json.loads(ret)
        if save_container_list_file:
            self.log.info("start save container_list_info to %s", save_container_list_file)
            with open(save_container_list_file, "w") as _f:
                _f.write(json.dumps(result, indent=4, separators=(',', ': ')))
        return result_code, ret

    def exec_sql(self, db_type, instance_name, db_name, sql_file="", sql="", db_pass="", return_content=0,
                 retry_times=0, sleep_time=5, ignore_error_list=None):
        '''
        Execute SQL (inline or from a file) with zsql (zenith) or gsql (gauss)
        as the dbuser OS user, retrying on failure.

        :param db_type: "gauss" or "zenith"
        :param instance_name: database instance name
        :param db_name: database name (gsql -d); for zenith it is the zsql login user
        :param sql_file: SQL file to run; takes precedence over sql
        :param sql: inline SQL statement
        :param db_pass: password to connect with; looked up via get_db_pass() when empty
        :param return_content: 0 -> return code, 1 -> output, anything else -> (code, output)
        :param retry_times: number of extra attempts after the first
        :param sleep_time: seconds to wait between attempts
        :param ignore_error_list: substrings; a failed run whose output contains any
            of them is treated as success (result code 0)
        :return: see return_content; None only if retry_times is negative
        :raises ValueError: if the last attempt cannot resolve ip/port/password or
            raises an unexpected error (the original error is chained)
        '''
        ignore_error_list = ignore_error_list or []
        # Log the arguments without the password.
        self.log.info("start exec_sql: %s", {
            "db_type": db_type, "instance_name": instance_name, "db_name": db_name,
            "sql_file": sql_file, "sql": sql, "return_content": return_content,
            "retry_times": retry_times, "sleep_time": sleep_time,
            "ignore_error_list": ignore_error_list})
        for retry_time in range(retry_times + 1):
            is_last = retry_time == retry_times
            try:
                outcome = self._exec_sql_once(db_type, instance_name, db_name, sql_file, sql, db_pass)
            except Exception as err:
                self.log.exception("Exception log:")
                if is_last:
                    raise ValueError("It's already the maximum number of retries") from err
                time.sleep(sleep_time)
                continue

            if outcome is None:
                if is_last:
                    self.log.info("It's already the maximum number of retries")
                    raise ValueError("db _ip or db port or db pass is empty")
                self.log.info("db _ip or db port or db pass is empty, retry_times: %s", retry_time)
                time.sleep(sleep_time)
                continue

            result_code, result_str = outcome
            if result_code != 0:
                # Ignorable errors count as success on every attempt, including the last.
                matched = next((err_text for err_text in ignore_error_list if err_text in result_str), None)
                if matched is not None:
                    self.log.info("ignore_error '%s' in result_str", matched)
                    result_code = 0
                elif not is_last:
                    self.log.info("exec sql failed, retry_times: %s", retry_time)
                    time.sleep(sleep_time)
                    continue
                else:
                    # Out of retries: fall through and return the failure.
                    self.log.info("It's already the maximum number of retries")
            return self._format_sql_result(return_content, result_code, result_str)
        return None

    @staticmethod
    def _format_sql_result(return_content, result_code, result_str):
        '''Shape the exec_sql result according to return_content (0 code, 1 output, else both).'''
        if return_content == 0:
            return result_code
        if return_content == 1:
            return result_str
        return result_code, result_str

    def _exec_sql_once(self, db_type, instance_name, db_name, sql_file, sql, db_pass):
        '''
        Run the SQL once as the dbuser OS user.

        :param db_pass: password to use; fetched with get_db_pass() when empty
        :return: None when the db ip, port or password cannot be resolved,
            otherwise (result_code, result_str) of the client command
        '''
        exec_sql_param, valid_sql = self._check_sql_type(sql_file, sql)
        copied_sql_file = ""
        if valid_sql == sql_file:
            # Run a mode-666 copy, presumably so the dbuser account can read it.
            copied_sql_file = self._copy_sql_file(sql_file)
            valid_sql = copied_sql_file
        try:
            _ip, port, _ = self.get_db_ip_and_port(instance_name=instance_name, db_type=db_type)
            if not db_pass:
                db_pass = self.get_db_pass(instance_name=instance_name, db_type=db_type)
            if not _ip or not port or not db_pass:
                return None

            param_dict = {"db_name": db_name, "db_pass": db_pass, "_ip": _ip, "port": port,
                          "exec_sql_param": exec_sql_param, "valid_sql": valid_sql}
            # NOTE(review): the password is passed on the client command line and the
            # heredoc delimiter is unquoted, so "$", backticks and double quotes in
            # the SQL or password are interpreted by the shell. Left unchanged for
            # compatibility with existing callers.
            if db_type == "zenith":
                command = '''sudo -s -u dbuser<<EOF
source /home/dbuser/.bashrc
/opt/zenith/app/bin/zsql %(db_name)s/%(db_pass)s@%(_ip)s:%(port)s %(exec_sql_param)s "%(valid_sql)s"
EOF''' % param_dict
            else:
                command = '''sudo -s -u dbuser<<EOF
source /home/dbuser/appgsdb.bashrc
gsql -d %(db_name)s -p %(port)s -U ossdbuser -h %(_ip)s -W %(db_pass)s %(exec_sql_param)s "%(valid_sql)s"
EOF''' % param_dict
            self.log.info(command.replace(db_pass, "****"))
            result_code, result_str = subprocess.getstatusoutput(command)
            self.log.info("result_code: %s\nresult_str: %s", result_code, result_str)
            return result_code, result_str
        finally:
            # Remove the per-attempt copy so repeated calls do not pile up files.
            if copied_sql_file:
                try:
                    os.remove(copied_sql_file)
                except OSError:
                    self.log.error("remove %s failed", copied_sql_file)

    def _copy_sql_file(self, sql_file):
        '''
        Copy sql_file into self.temporary_path under a timestamped name, mode 666.

        :param sql_file: path of the SQL file to copy
        :return: path of the copy
        :raises ValueError: if the permissions cannot be changed
        '''
        base_name = os.path.splitext(os.path.basename(sql_file))[0]
        new_sql_file = os.path.join(self.temporary_path, "%s_%s.sql" % (
            base_name, self._create_timestamp_string()))
        self.log.info(new_sql_file)
        shutil.copy(sql_file, new_sql_file)
        command = "chmod 666 %s" % new_sql_file
        self.log.info(command)
        # NOTE(review): 666 makes the SQL file world-writable; 644 is probably
        # enough for dbuser to read it. Confirm before tightening.
        try:
            os.chmod(new_sql_file, 0o666)
        except OSError as err:
            raise ValueError("call %s failed" % command) from err
        return new_sql_file

    def _get_db_pass_by_file(self, instance_name, db_type, pass_type):
        '''
        Read the encrypted database password from the tenant's containerlist.json.

        :param instance_name: substring of the container key
        :param db_type: container type to match
        :param pass_type: password field to read
        :return: the pass_type value of the first matching container, or ""
        '''
        self.log.info("start get_db_pass_by_file: %s", locals())
        tenant_name = self.get_tenant_name()
        file_path = "/opt/oss/manager/var/tenants/%s/containerlist.json" % tenant_name
        with open(file_path, "rb") as _f:
            db_info = json.load(_f)
        for key, container in db_info.items():
            if instance_name in key and container["containerType"] == db_type and "dbList" in container:
                db_un_passed = container[pass_type]
                self.log.info("get_db_pass_by_file success")
                return db_un_passed
        return ""

    def _check_sql_type(self, sql_file, sql):
        '''
        Decide how the SQL is passed to the client.

        :param sql_file: SQL file path (takes precedence)
        :param sql: inline SQL
        :return: ("-f", sql_file) or ("-c", sql)
        :raises FileNotFoundError: if sql_file is given but does not exist
        :raises TypeError: if neither sql_file nor sql is given
        '''
        if sql_file:
            if not os.path.exists(sql_file):
                raise FileNotFoundError("sql_file %s is not exists" % sql_file)
            return "-f", sql_file
        if sql:
            return "-c", sql
        raise TypeError("sql_file or sql all is None")

    def _create_timestamp_string(self):
        '''
        Create a timestamp string such as "20240101120000_123456".

        :return: the current local time formatted with microseconds
        '''
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S_%f")

    def _check_db_type(self, db_type):
        '''
        Validate the database type.

        :param db_type: value to check
        :raises ValueError: unless db_type is "gauss" or "zenith"
        '''
        if db_type not in ("gauss", "zenith"):
            raise ValueError("db_type must be gauss or zenith, dbtype: %s" % db_type)





