# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.

import os
import re
import stat
import errno
import shutil
import tarfile
import time
import threading

import utils.common.log as logger
from utils.common.exception import HCCIException
from plugins.DistributedStorage.common.upgrade_operate import UpgradeOperate

# Module-level threading lock. It is not referenced within this section of the file;
# NOTE(review): presumably shared by callers to serialize file operations — verify usage.
FILE_LOCK = threading.Lock()


class SystemHandle:
    def __init__(self):
        """
        Common helper class for Distributed block storage packaging and system
        operations; not a public interface.
        """
        pass

    @staticmethod
    def error_remove_read_only(func, path, exc):
        """
        ``shutil.rmtree`` ``onerror`` callback that retries a failed removal once.

        :param func: the function that raised (os.rmdir / os.remove / os.unlink, ...)
        :param path: the path that was passed to ``func``
        :param exc: exception info tuple (type, value, traceback) as passed by shutil.rmtree
        :raises HCCIException: 621003 for any failure this callback does not handle
        """
        excvalue = exc[1]
        # Not every exception carries errno; treat those as "unhandled" instead of
        # crashing with AttributeError inside the callback.
        err_no = getattr(excvalue, "errno", None)
        # Python 3's shutil.rmtree reports file-removal failures with os.unlink
        # (Python 2 used os.remove); the two builtins do not compare equal, so all
        # removal functions must be listed explicitly.
        remove_funcs = (os.rmdir, os.remove, os.unlink)
        if func in remove_funcs and err_no == errno.EACCES:
            # Open up permissions on the path itself, then retry once.
            os.chmod(path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
            func(path)
        elif func in remove_funcs and err_no == errno.ENOTEMPTY:
            # Retry once; presumably the directory may have been emptied meanwhile.
            func(path)
        else:
            raise HCCIException(621003, "remove func is wrong")

    @staticmethod
    def check_pkg_version(fs_args):
        """
        Validate that the configured package is an 8.x package.

        :param fs_args: dict holding the "package_name" key
        :return: 81 when the package name starts with a supported 8.x prefix
        :raises HCCIException: 621003 when the name is missing or not supported
        """
        package_name = fs_args.get("package_name")
        supported_prefixes = ("OceanStor-Pacific_8", "Distributed-Storage_8")
        # A missing/empty name is reported as a version error rather than an AttributeError.
        if package_name and package_name.startswith(supported_prefixes):
            return 81
        err_msg = "pkg {} version is wrong".format(package_name)
        logger.error(err_msg)
        raise HCCIException(621003, err_msg)

    @staticmethod
    def un_tar_pkg(ssh_client, remote_pkg_path, dst_dir):
        """
        Extract a gzip'ed tarball on the remote node through ``ssh_client``.

        :param ssh_client: client whose send_cmd(...) returns output lines
        :param remote_pkg_path: tarball path on the remote node
        :param dst_dir: extraction directory on the remote node
        :raises HCCIException: 621007 when tar does not exit with status 0
        """
        # NOTE(review): paths are interpolated unquoted into the remote shell command;
        # they must not contain spaces or shell metacharacters.
        un_tar_cmd = "tar xfz {} -C {}; echo last_result=$?".format(remote_pkg_path, dst_dir)
        cmd_ret = ''.join(ssh_client.send_cmd(un_tar_cmd, '#'))
        if "last_result=0" not in cmd_ret:
            logger.error("un tar %s failed, ret: %s" % (remote_pkg_path, cmd_ret))
            err_msg = "un tar %s failed" % remote_pkg_path
            raise HCCIException(621007, err_msg)

    @staticmethod
    def _is_within_directory(directory, target):
        """Return True if ``target`` resolves to ``directory`` or a path beneath it."""
        abs_directory = os.path.realpath(directory)
        abs_target = os.path.realpath(target)
        # join(..., "") appends exactly one separator, also for the root directory.
        return abs_target == abs_directory or abs_target.startswith(os.path.join(abs_directory, ""))

    @staticmethod
    def _check_member_safe(tarinfo, dst_dir):
        """
        Reject a tar member whose path (or link target) would land outside ``dst_dir``.

        :raises HCCIException: 621003 for an unsafe member
        """
        target = os.path.join(dst_dir, tarinfo.name)
        unsafe = not SystemHandle._is_within_directory(dst_dir, target)
        if not unsafe and (tarinfo.issym() or tarinfo.islnk()):
            # Symlink targets are relative to the member's directory; hard-link
            # targets are archive paths relative to the extraction root.
            if tarinfo.issym():
                link_target = os.path.join(os.path.dirname(target), tarinfo.linkname)
            else:
                link_target = os.path.join(dst_dir, tarinfo.linkname)
            unsafe = not SystemHandle._is_within_directory(dst_dir, link_target)
        if unsafe:
            err_msg = "unsafe tar member %s for extract dir %s" % (tarinfo.name, dst_dir)
            logger.error(err_msg)
            raise HCCIException(621003, err_msg)

    @staticmethod
    def _deploymanage_pkg(members, dst_dir=None):
        """
        Yield only the deploymanager package members of a tar archive.

        :param members: iterable of tarfile.TarInfo (e.g. an open TarFile)
        :param dst_dir: optional extraction directory; when given, every selected
            member is checked not to escape it before being yielded
        """
        for tarinfo in members:
            logger.info(tarinfo.name)
            if not os.path.split(tarinfo.name)[1].startswith("OceanStor-Pacific_deploymanager_"):
                continue
            if dst_dir is not None:
                SystemHandle._check_member_safe(tarinfo, dst_dir)
            yield tarinfo

    def unzip_file(self, file_path, local_dst_dir):
        """
        Extract the deploymanager package(s) from ``file_path`` into a fresh ``local_dst_dir``.

        :param file_path: local tar archive (any compression tarfile auto-detects)
        :param local_dst_dir: destination directory; removed and recreated first
        :raises HCCIException: 621003 if the archive is larger than 8 GiB or holds an
            unsafe deploymanager member
        """
        # The limit applies to the archive file size, not the extracted size.
        limit_size = 8 * 1024 * 1024 * 1024
        file_size = os.path.getsize(file_path)
        if file_size > limit_size:
            err_msg = "file size %s is larger than 8G" % file_size
            logger.error(err_msg)
            raise HCCIException(621003, err_msg)
        if os.path.exists(local_dst_dir):
            shutil.rmtree(local_dst_dir, onerror=self.error_remove_read_only)
        os.makedirs(local_dst_dir)
        logger.info("start unzip file")
        with tarfile.open(file_path) as tar_file:
            # Members are filtered and validated lazily, so the archive is read once.
            tar_file.extractall(members=self._deploymanage_pkg(tar_file, local_dst_dir),
                                path=local_dst_dir)


class RestPublicMethod(object):
    """
    Helpers built on UpgradeOperate: storage version checks, OS sandbox
    switching, server listing and parsing of failed components.
    """

    def __init__(self, project_id=None, pod_id=None, fs_args=None, **kwargs):
        self.project_id = project_id
        self.pod_id = pod_id
        self.fs_args = fs_args
        # Operation wrapper built from fs_args; every query below goes through it.
        self.opr = UpgradeOperate(fs_args)

    @staticmethod
    def get_failed_component(data_result):
        """
        Collect the names of failed components for every failed node.

        :param data_result: iterable of node dicts with "status", "hostIp" and
            "sequenceInfos"; each sequence info may carry a "componentInfoList"
            of dicts with "name" and "componentStatus"
        :return: dict mapping hostIp -> list of failed component names, containing
            only nodes whose status is "failed"
        """
        failed_component_map = dict()
        for node in data_result:
            if node["status"] != "failed":
                continue
            node_ip = node["hostIp"]
            sequence_infos = node["sequenceInfos"]
            failed_component_list = []
            for sequence in sequence_infos:
                # A sequence without components (missing/None list) contributes nothing.
                for component in sequence.get("componentInfoList") or []:
                    if component.get("componentStatus") == "failed":
                        failed_component_list.append(component.get("name"))
            failed_component_map[node_ip] = failed_component_list
        return failed_component_map

    def _query_version(self):
        """
        Return the version string reported by ``self.opr.get_version()``.

        :raises HCCIException: 621003 when the query fails or reports no version
        """
        ret_result, ret_data = self.opr.get_version()
        if ret_result["code"] != 0:
            err_msg = "get version failed, Detail:[result:%s, data:%s]" % (ret_result, ret_data)
            logger.error(err_msg)
            raise HCCIException(621003, err_msg)
        version = ret_data.get("version")
        logger.info("current version is %s" % version)
        if version is None:
            err_msg = "get version failed, Detail:[result:%s, data:%s]" % (ret_result, ret_data)
            logger.error(err_msg)
            raise HCCIException(621003, err_msg)
        return version

    @staticmethod
    def _version_code(version):
        """Return 81 if ``version`` is 8.1 or later, else 80."""
        match = re.match(r"^(\d+)\.(\d+)", version)
        if match:
            # Numeric comparison so that e.g. "10.0" is not ordered before "8.1".
            is_81_or_later = (int(match.group(1)), int(match.group(2))) >= (8, 1)
        else:
            # Unrecognised format: keep the historic lexicographic comparison.
            is_81_or_later = version >= "8.1"
        return 81 if is_81_or_later else 80

    @staticmethod
    def _is_version_after_813(version):
        """Return True if ``version`` matches an 8.1.3-or-later release pattern."""
        # 8.1.RC5-RC9, 8.1.3-8.1.9 and their SPH/SPC/HP patch releases.
        version_regular1 = r"^8\.1\.(RC[5-9]|[3-9]|([3-9]\.(SPH\d+|SPC\d+|HP\d+)))$"
        # 8.2-8.9 releases: RCn, a single-digit number, and its SPH/SPC/HP patches.
        version_regular2 = r"^8\.[2-9]\.(RC\d+|[0-9]|([0-9]\.(SPH\d+|SPC\d+|HP\d+)))$"
        return bool(re.match(version_regular1, version) or re.match(version_regular2, version))

    def check_storage_version(self):
        """
        Classify the running storage version.

        :return: 81 for 8.1 and later, 80 otherwise
        :raises HCCIException: 621003 when the version cannot be queried
        """
        return self._version_code(self._query_version())

    def storage_version_after_813(self):
        """
        :return: True if the running storage version is 8.1.3 (or an equivalent RC) or later
        :raises HCCIException: 621003 when the version cannot be queried
        """
        return self._is_version_after_813(self._query_version())

    @staticmethod
    def _raise_if_nodes_failed(ret_data, action):
        """
        Raise if any per-node result in ``ret_data`` has a non-zero "code".

        :param ret_data: iterable of per-node result dicts
        :param action: verb used in the error message ("enable" / "disable")
        :raises Exception: listing every failed node result
        """
        failed_nodes_list = [node for node in ret_data if node.get('code') != 0]
        if failed_nodes_list:
            err_msg = 'Failed %s sandbox nodes list: %s' % (action, failed_nodes_list)
            logger.error(err_msg)
            raise Exception(err_msg)

    def disable_sandbox(self, ip, root_pwd, admin_pwd):
        """
        Disable the OS sandbox on the given node(s).

        :raises Exception: if any node reports a non-zero code
        """
        disable_nodes = {'node_ips': ip,
                         'root_password': root_pwd,
                         'admin_password': admin_pwd}
        self._raise_if_nodes_failed(self.opr.disable_os_sandbox(disable_nodes), "disable")

    def enable_sandbox(self, ip):
        """
        Enable the OS sandbox on the given node(s).

        :raises Exception: if any node reports a non-zero code
        """
        self._raise_if_nodes_failed(self.opr.enable_os_sandbox(ip), "enable")

    def get_server_list(self):
        """
        Split the cluster servers by role.

        :return: (management IPs with a "storage" role, management IPs with a "compute" role);
            a server with both roles appears in both lists
        :raises Exception: when the server query fails
        """
        osd_nodes_list = []
        vbs_nodes_list = []
        ret_result, ret_data = self.opr.get_servers()
        if ret_result["code"] != 0:
            err_msg = "get servers failed, Detail:[result:%s, data:%s]" % (ret_result, ret_data)
            logger.error(err_msg)
            raise Exception(err_msg)
        for item in ret_data:
            roles = item["role"]
            if "storage" in roles:
                osd_nodes_list.append(item["management_ip"])
            if "compute" in roles:
                vbs_nodes_list.append(item["management_ip"])
        return osd_nodes_list, vbs_nodes_list


class SshPublicMethod:
    def __init__(self, project_id=None, pod_id=None, fs_args=None, *args, **kwargs):
        """
        ssh public interface
        """
        pass

    @staticmethod
    def check_status_opr(ssh_client):
        """
        Poll the deploymanager tomcat status endpoint on the remote node until it
        reports "success".

        :param ssh_client: client whose send_cmd(...) returns output lines
        :raises HCCIException: 621003 if "success" is not seen within about 600 seconds
        """
        # --noproxy requires a host list; without one curl consumed "-m" as the list and
        # treated "5" as an extra URL, so the max-time limit was never applied.
        get_status_cmd = 'curl -sgkX GET -H "Content-type: application/json" --noproxy 127.0.0.1 ' \
                         '-m 5 --connect-timeout 5 https://127.0.0.1:6098/api/v2/curl/deploy/tomcat_status'
        check_timeout = 600
        interval = 10
        attempts = check_timeout // interval
        for attempt in range(attempts):
            cmd_ret = ''.join(ssh_client.send_cmd(get_status_cmd, "#", 10))
            if "success" in cmd_ret:
                return
            # Sleeping after the final attempt would only delay the failure.
            if attempt < attempts - 1:
                time.sleep(interval)
        raise HCCIException(621003, "deploymanager status failed")

    @staticmethod
    def ssh_cmd_get_arch(ssh_client):
        """
        Run ``uname -r`` on the remote node and return the raw send_cmd output.

        Callers search the returned kernel release text for an architecture
        substring (e.g. "x86_64" / "aarch"); NOTE(review): this relies on the
        distro embedding the arch in the kernel release string.
        """
        arch_cmd = "uname -r"
        cmd_ret = ssh_client.send_cmd(arch_cmd, '#')
        return cmd_ret

    def get_deploymanager_pkg_name(self, ssh_client, local_pkg_dir):
        """
        Find the local deploymanager package matching the remote node's architecture.

        :param ssh_client: client used to query the remote architecture
        :param local_pkg_dir: local directory containing the packages
        :return: file name (not path) of the first matching package in sorted order
        :raises HCCIException: 621003 if the arch is unknown or no package matches
        """
        cmd_ret = ''.join(self.ssh_cmd_get_arch(ssh_client))
        if "x86_64" in cmd_ret:
            arch = "x86_64"
        elif "aarch" in cmd_ret:
            arch = "aarch64"
        else:
            raise HCCIException(621003, "node arch %s is wrong." % cmd_ret)
        pkg_name_like = "OceanStor-Pacific_deploymanager_" + arch
        # Sorted so that the choice is deterministic when several packages match.
        for file_name in sorted(os.listdir(local_pkg_dir)):
            if pkg_name_like in file_name:
                return file_name
        raise HCCIException(621003, "the deploymanager pkg {} is not exist".format(pkg_name_like))
