# -*- coding: utf-8 -*-
import json
import os
import platform
import string
import random
import zipfile
import shutil
import time
import traceback
import utils.common.log as logger
from utils.common.exception import FCUException
from plugins.DistributedStorage.common.upgrade_operate import UpgradeOperate
from plugins.DistributedStorage.common.base import TestCase
from plugins.DistributedStorage.common.constants import Constant
from plugins.DistributedStorage.common.public_handle_new import SystemHandle


class UpgradeInspection(TestCase):
    """Run the pre-upgrade inspection (PMI) check and fail on abnormal items.

    Workflow (see ``procedure``):
      1. log in through ``UpgradeOperate``;
      2. collect the management IPs of all servers;
      3. download the PMI resource package, unzip it and read the inspection
         item ids from the configured check-list XML;
      4. start the inspection task and poll it until every item reports
         100% progress or the timeout expires;
      5. if any item is unfinished or any node reports a non-zero status,
         download/unzip the per-node detail archives and raise
         ``FCUException(621005, summary)``.
    Temporary directories are removed whether the check passes or fails.
    """

    def __init__(self, project_id, pod_id, fs_args, **kwargs):
        """Create the inspection task.

        :param project_id: forwarded to ``TestCase``.
        :param pod_id: forwarded to ``TestCase``.
        :param fs_args: mapping handed to ``UpgradeOperate``; this class also
            reads ``user_name``, ``password`` and ``check_list_file`` from it.
        :param kwargs: extra arguments, stored as ``self.more_args`` (not read
            elsewhere in this class).
        """
        super(UpgradeInspection, self).__init__(project_id, pod_id)
        self.more_args = kwargs
        self.opr = UpgradeOperate(fs_args)
        self.user_name = fs_args.get("user_name")
        self.password = fs_args.get("password")
        # Random suffix gives each instance its own temp directory names.
        self.tmp_path_suffix = ''.join(random.choices(string.digits + string.ascii_letters, k=10))
        current_platform = platform.platform()
        if current_platform.startswith("Windows"):
            self.res_file_dir = Constant.WIN_PMI_RET_DIR.format(self.tmp_path_suffix)
            self.detail_file_dir = Constant.WIN_PMI_DETAILS_DIR.format(self.tmp_path_suffix)
        else:
            self.res_file_dir = Constant.UNIX_PMI_RET_DIR.format(self.tmp_path_suffix)
            self.detail_file_dir = Constant.UNIX_PMI_DETAILS_DIR.format(self.tmp_path_suffix)
        self.res_file_name = 'resource.zip'
        self.check_list_file = fs_args.get("check_list_file")
        # Inspection item ids that are never submitted to the inspection task.
        self.pmi_item_black = ["deployVersion"]
        # Name of the most recently downloaded detail archive.
        self.down_file_name = None

    @staticmethod
    def unzip_file(file_path, dst_dir):
        """Extract the zip archive ``file_path`` into a fresh ``dst_dir``.

        ``dst_dir`` is removed (if present) and recreated before extraction.
        Archives whose on-disk size or total uncompressed size exceeds 4 GiB
        are rejected without touching ``dst_dir``.

        :return: ``True`` on success, ``False`` if the archive is too large.
        """
        limit_size = 4 * 1024 * 1024 * 1024
        if os.path.getsize(file_path) > limit_size:
            logger.error("zip file %s exceeds size limit." % file_path)
            return False
        with zipfile.ZipFile(file_path, 'r') as zip_file:
            # The compressed size says little about the extracted size, so also
            # bound the total uncompressed size declared in the archive.
            if sum(info.file_size for info in zip_file.infolist()) > limit_size:
                logger.error("zip file %s uncompressed size exceeds limit." % file_path)
                return False
            if os.path.exists(dst_dir):
                shutil.rmtree(dst_dir, onerror=SystemHandle.error_remove_read_only)
            os.makedirs(dst_dir)
            logger.info("start unzip file")
            zip_file.extractall(dst_dir)
        return True

    @staticmethod
    def _parse_error_detail(result):
        """Build a short summary string of failed inspection items.

        :param result: mapping ``item_id -> list of (node_ip, detail)``.
        :return: ``"item:[...];item:[...]"`` showing at most two items with at
            most two entries each; a trailing ``"..."`` marks truncation.
        """
        max_shown = 2
        parts = []
        for key, value in list(result.items())[:max_shown]:
            item = json.dumps(value[:max_shown])
            if len(value) > max_shown:
                item += '...'
            parts.append(f"{key}:{item}")
        detail = ";".join(parts)
        if len(result) > max_shown:
            detail += "..."
        return detail

    @staticmethod
    def _get_pmi_result(inspection_data):
        """Split inspection status data into messages and failed (item, node) pairs.

        :param inspection_data: iterable of dicts with ``item_id``, ``process``
            and, for finished items, ``node_status`` mapping node IP to a
            status code (non-zero is treated as abnormal).
        :return: ``(fail_results, fail_item)``: human-readable messages for
            every unfinished item and abnormal node, and ``(item_id, node_ip)``
            tuples for the abnormal nodes only.
        """
        fail_results = []
        fail_item = []
        for item in inspection_data:
            pmi_item_id = item["item_id"]
            pmi_item_process = item["process"]
            if pmi_item_process != 100:
                # The item did not finish, so it has no per-node result to download.
                err_msg = "inspection item(%s) process(%s) timeout" \
                          % (pmi_item_id, str(pmi_item_process))
                fail_results.append(err_msg)
                continue
            # Guard against a missing/None node_status on a finished item.
            pmi_item_status = item.get("node_status") or {}
            for node_ip, status in pmi_item_status.items():
                if status != 0:
                    err_msg = "inspection item(%s) in node(%s) " \
                              "status(%s) is abnormal" \
                              % (pmi_item_id, node_ip, str(status))
                    fail_results.append(err_msg)
                    fail_item.append((pmi_item_id, node_ip))
        return fail_results, fail_item

    def procedure(self):
        """Run the full pre-upgrade inspection.

        :raises Exception: on login failure or any failed management API call.
        :raises FCUException: code 621005 with a failure summary when any
            inspection item is unfinished or abnormal on some node.
        """
        logger.info('Start pre upgrade inspection check task.')
        try:
            status_code, error_code, error_des = self.opr.try_login(
                self.user_name, self.password)
            if status_code != 200 or error_code != 0:
                err_msg = "login failed, Detail:[status:%s,code:%s]%s" \
                          % (status_code, error_code, error_des)
                logger.error(err_msg)
                raise Exception(err_msg)

            logger.info('inspection check.')
            nodes_list = self._get_server_list()

            inspections_list = self._get_inspections_list()

            ret_result, ret_data = self.opr.start_inspection(
                nodes_list, inspections_list)
            if ret_result["code"] != 0:
                err_msg = "upgrade inspection failed, " \
                          "Detail:[result:%s, data:%s]" \
                          % (ret_result, ret_data)
                logger.error(err_msg)
                raise Exception(err_msg)
            task_id = ret_data["task_id"]

            # 300 s budget per (started) batch of 128 nodes.
            check_timeout = 300 * (len(nodes_list)/128 + 1)

            inspection_data = self._get_inspect_data(
                task_id, inspections_list, check_timeout)

            fail_results, fail_item = self._get_pmi_result(inspection_data)
            if fail_results:
                if fail_item:
                    self._download_inspect_file(task_id, fail_item)
                    self._unzip_inspect_file()
                    result = self._parase_result(fail_item)
                else:
                    # Only unfinished (timed-out) items: there are no per-node
                    # detail archives, so report the collected messages.
                    result = ";".join(fail_results)
                raise FCUException(621005, result)
        finally:
            logger.info('rm -rf %s.' % self.res_file_dir)
            if os.path.exists(self.res_file_dir):
                shutil.rmtree(self.res_file_dir,
                              onerror=SystemHandle.error_remove_read_only)
            if os.path.exists(self.detail_file_dir):
                shutil.rmtree(self.detail_file_dir,
                              onerror=SystemHandle.error_remove_read_only)
            logger.info('procedure end.')

    def _get_server_list(self):
        """Return the ``management_ip`` of every server reported by ``opr.get_servers``.

        :raises Exception: if the API result code is non-zero.
        """
        ret_result, ret_data = self.opr.get_servers()
        if ret_result["code"] != 0:
            err_msg = "get servers failed, " \
                      "Detail:[result:%s, data:%s]" \
                      % (ret_result, ret_data)
            logger.error(err_msg)
            raise Exception(err_msg)
        return [item["management_ip"] for item in ret_data]

    def _get_inspections_list(self):
        """Download the PMI resource package and return the inspection item ids.

        Ids are read from the ``id`` attribute of every ``cmdtype/cmditem``
        element in ``self.check_list_file`` inside the package; ids in
        ``self.pmi_item_black`` and elements without an id are skipped.

        :raises Exception: if the download fails or the package is too large.
        """
        import xml.etree.ElementTree as ET

        os.makedirs(self.res_file_dir, exist_ok=True)
        res_file_path = os.path.join(self.res_file_dir, self.res_file_name)
        # NOTE(review): get_pmi_resource appears to return an int status
        # (0 == success) rather than a result dict; confirm against UpgradeOperate.
        ret_result, ret_data = self.opr.get_pmi_resource(res_file_path)
        if ret_result != 0:
            err_msg = "get pmi resource file failed."
            logger.error(err_msg)
            raise Exception(err_msg)
        dst_dir = os.path.join(self.res_file_dir, "resource")
        if self.unzip_file(res_file_path, dst_dir) is False:
            err_msg = "unzip pmi resource file failed."
            logger.error(err_msg)
            raise Exception(err_msg)
        dst_res_file = os.path.join(dst_dir, self.check_list_file)

        root = ET.parse(dst_res_file).getroot()
        inspections_list = []
        for cmd_type_item in root.findall("cmdtype"):
            for cmd_item in cmd_type_item.findall("cmditem"):
                item_id = cmd_item.attrib.get('id')
                if item_id and item_id not in self.pmi_item_black:
                    inspections_list.append(item_id)
        return inspections_list

    def _download_inspect_file(self, task_id, fail_item_ip):
        """Download the detail archive of every failed (item_id, node_ip) pair.

        Archives are saved as ``<item_id><node_ip>.zip`` under
        ``self.detail_file_dir``.

        :raises Exception: if any download reports a non-zero result.
        """
        os.makedirs(self.detail_file_dir, exist_ok=True)
        for item_id, node_ip in fail_item_ip:
            self.down_file_name = item_id + node_ip + ".zip"
            res_file_path = os.path.join(self.detail_file_dir, self.down_file_name)
            ret_result, ret_data = self.opr.download_inspection_result(task_id, res_file_path, item_id, node_ip)
            if ret_result != 0:
                err_msg = "get inspection file failed, " \
                          "Detail:[result:%s, data:%s]" \
                          % (ret_result, ret_data)
                logger.error(err_msg)
                raise Exception(err_msg)

    def _get_inspect_data(self, task_id, inspections_list, timeout):
        """Poll the inspection task every 10 s until all items reach 100%.

        :param timeout: total polling budget in seconds.
        :return: the last status data received. On timeout this still holds
            the unfinished items, so ``_get_pmi_result`` reports them instead
            of the check passing silently. Returns ``{}`` if no poll ran.
        :raises Exception: if a status query returns a non-zero code.
        """
        check_timeout = timeout
        inspection_data = {}
        while check_timeout > 0:
            time.sleep(10)
            check_timeout -= 10
            ret_result, ret_data = self.opr.get_inspection_status(
                task_id, inspections_list)
            if ret_result["code"] != 0:
                err_msg = "get inspection status failed, " \
                          "Detail:[result:%s, data:%s]" \
                          % (ret_result, ret_data)
                logger.error(err_msg)
                raise Exception(err_msg)
            inspection_data = ret_data
            if all(item["process"] == 100 for item in ret_data):
                return inspection_data
        logger.error("inspect timeout. data %s" % inspection_data)
        return inspection_data

    def _unzip_inspect_file(self):
        """Extract every downloaded detail archive into ``<detail_file_dir>/result``."""
        dst_dir = os.path.join(self.detail_file_dir, "result")
        for file_name in os.listdir(self.detail_file_dir):
            file_path = os.path.join(self.detail_file_dir, file_name)
            # Skip directories such as a previously created "result" folder.
            if not os.path.isfile(file_path):
                continue
            with zipfile.ZipFile(file_path, 'r') as zip_file:
                zip_file.extractall(dst_dir)

    def _parase_result(self, fail_item_ip):
        """Collect the detail output of every failed (item, node) pair.

        Each detail is read from ``result/<item>/<node>/<item>.sh.rst`` under
        ``self.detail_file_dir`` and logged.

        :return: summary string built by ``_parse_error_detail``.
        """
        result = {}
        for item, node in fail_item_ip:
            file_path = os.path.join(self.detail_file_dir, 'result', item, node, item + '.sh.rst')
            try:
                # Detail files are produced remotely; decode as UTF-8 (assumed)
                # rather than the local default, which differs on Windows.
                with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                    content = f.read()
            except OSError as err:
                # Still report the failed item even if its detail is missing.
                content = "detail unavailable: %s" % err
            result.setdefault(item, []).append((node, content))
            logger.error("upgrade inspection failed!"
                         "Item: %s, Node: %s, Detail: %s "
                         % (item, node, content))
        return self._parse_error_detail(result)
