# -*- coding: utf-8 -*-
import time
import traceback

import utils.common.log as logger
from utils.common.exception import HCCIException
from plugins.DistributedStorage.common.UpgradeOperate import UpgradeOperate
from plugins.DistributedStorage.common.base import TestCase
from plugins.DistributedStorage.common.PublicHandleNew import RestPublicMethod


class UpgradePkg(TestCase):
    """Roll back a distributed-storage upgrade and wait for it to finish.

    Storage-side operations are delegated to ``UpgradeOperate`` (``self.opr``)
    and the node inventory is fetched through ``RestPublicMethod``
    (``self.rest_opr``).
    """

    # Seconds slept between two polls of the rollback task.
    _POLL_INTERVAL = 10
    # Consecutive query exceptions tolerated before the wait is aborted.
    _MAX_QUERY_FAILURES = 30

    def __init__(self, project_id, pod_id, fs_args, condition=None, metadata=None, **kwargs):
        """Store the task context and build the operation helpers.

        :param project_id: project identifier, forwarded to the base class.
        :param pod_id: pod identifier, forwarded to the base class.
        :param fs_args: storage connection arguments; must contain the keys
            ``"user_name"`` and ``"password"``.
        :param condition: optional condition object, stored as-is.
        :param metadata: optional metadata, stored as-is.
        :param kwargs: extra arguments, kept in ``self.more_args``.
        """
        super(UpgradePkg, self).__init__(project_id, pod_id)
        self.more_args = kwargs
        self.condition = condition
        self.metadata = metadata
        self.opr = UpgradeOperate(fs_args)
        self.user_name = fs_args["user_name"]
        self.password = fs_args["password"]
        self.rest_opr = RestPublicMethod(project_id, pod_id, fs_args)

    def procedure(self):
        """Start the rollback task and block until it completes.

        :raises HCCIException: an ``HCCIException`` raised by any step is
            re-raised unchanged (a rollback timeout surfaces as code 621000);
            any other exception is wrapped as ``HCCIException(620011, ...)``.
        """
        logger.info('Start rollback task.')
        try:
            status_code, error_code, error_des = self.opr.try_login(self.user_name, self.password)
            if status_code != 200 or error_code != 0:
                err_msg = "Failed to login, Detail:[status:%s,code:%s], error:%s" % (status_code, error_code, error_des)
                logger.error(err_msg)
                raise Exception(err_msg)

            self._check_upgrade_status()
            osd_nodes_list, vbs_nodes_list = self.rest_opr.get_server_list()

            self._do_upgrade_task()

            logger.info('get rollback upgrade task info.')
            check_timeout = self._calc_check_timeout(osd_nodes_list, vbs_nodes_list)
            self._wait_upgrade_result(check_timeout)
        except HCCIException as e:
            logger.error('rollback upgrade pkg failed, details:{}'.format(e))
            logger.error(traceback.format_exc())
            raise e
        except Exception as e:
            logger.error('rollback upgrade pkg failed, details:{}'.format(e))
            logger.error(traceback.format_exc())
            raise HCCIException(620011, str(e))
        finally:
            logger.info('procedure end...')

    def retry(self):
        """Resume a rollback after a previous attempt.

        The current task decides the action:

        * a failed rollback task is retried via ``retry_upgrade_task``;
        * a task of type ``"upgrade"`` triggers a fresh rollback;
        * otherwise nothing is started and the existing task is awaited.

        :raises HCCIException: same contract as :meth:`procedure`.
        """
        logger.info('Start retry rollback task.')
        try:
            status_code, error_code, error_des = self.opr.try_login(self.user_name, self.password)
            if status_code != 200 or error_code != 0:
                err_msg = "Failed to login, Details:[status:%s,code:%s]%s" % (status_code, error_code, error_des)
                logger.error(err_msg)
                raise Exception(err_msg)

            self._check_upgrade_status()
            osd_nodes_list, vbs_nodes_list = self.rest_opr.get_server_list()

            logger.info("get rollback task result")
            ret_result, ret_data = self.opr.get_upgrade_task()
            if ret_result["code"] != '0':
                err_msg = "get rollback task failed, " \
                          "Detail:[result:%s, data:%s]" \
                          % (ret_result, ret_data)
                logger.error(err_msg)
                raise Exception(err_msg)
            elif ret_data["status"] == "failed" \
                    and ret_data["taskType"] == "rollback":
                logger.info('retry rollback task.')
                ret_result, ret_data = self.opr.retry_upgrade_task()
                if ret_result["code"] != '0':
                    err_msg = "retry rollback task failed, " \
                              "Detail:[result:%s, data:%s]" \
                              % (ret_result, ret_data)
                    logger.error(err_msg)
                    raise Exception(err_msg)
            elif ret_data["taskType"] == "upgrade":
                self._do_upgrade_task()

            logger.info('Get retry rollback task info.')
            check_timeout = self._calc_check_timeout(osd_nodes_list, vbs_nodes_list)
            self._wait_upgrade_result(check_timeout)
        except HCCIException as e:
            logger.error('rollback pkg failed, details:{}'.format(e))
            logger.error(traceback.format_exc())
            raise e
        except Exception as e:
            logger.error('rollback pkg failed, details:{}'.format(e))
            logger.error(traceback.format_exc())
            raise HCCIException(620011, str(e))
        finally:
            logger.info('procedure end...')

    @staticmethod
    def _calc_check_timeout(osd_nodes_list, vbs_nodes_list):
        """Return the rollback wait budget in seconds.

        The budget is a 90-minute base, plus 300s per OSD node, plus 300s for
        every started batch of 20 VBS nodes (one batch minimum).
        """
        osd_timeout = 300 * len(osd_nodes_list)
        # Floor division keeps the batch count integral on Python 3 as well
        # (plain "/" would yield a fractional number of batches there).
        vbs_timeout = 300 * (len(vbs_nodes_list) // 20 + 1)
        return 90 * 60 + osd_timeout + vbs_timeout

    def _wait_upgrade_result(self, check_timeout):
        """Poll the rollback task until it succeeds or the budget runs out.

        :param check_timeout: polling budget in seconds.
        :raises Exception: the last query exception once more than
            ``_MAX_QUERY_FAILURES`` consecutive queries have failed, or the
            exception raised by :meth:`_check_upgrade_result` for a failed task.
        :raises HCCIException: code 621000 when the budget is exhausted.
        """
        failed_nums = 0
        ret_result = None
        ret_data = None
        while check_timeout > 0:
            try:
                ret_result, ret_data = self.opr.get_upgrade_task()
            except Exception as ex:
                err_msg = "get rollback task failed, Exception: %s" \
                          % (str(ex))
                logger.error(err_msg)
                failed_nums += 1
                if failed_nums > self._MAX_QUERY_FAILURES:
                    logger.error("failed number ({}) too more"
                                 .format(failed_nums))
                    raise
            else:
                failed_nums = 0
                if 0 == self._check_upgrade_result(ret_result, ret_data):
                    return
            time.sleep(self._POLL_INTERVAL)
            check_timeout -= self._POLL_INTERVAL

        # ret_data stays None when no query ever succeeded (or the budget was
        # non-positive from the start); don't let that mask the timeout with a
        # TypeError. ret_data is assumed dict-like, as it is indexed by key above.
        status = progress = None
        if ret_data is not None:
            status = ret_data.get("status")
            progress = ret_data.get("percent")
        err_msg = "upgrade rollback timeout, status:{status}, " \
                  "progress:{progress}. result:{ret_result}, " \
                  "data: {ret_data}" \
            .format(status=status,
                    progress=progress,
                    ret_result=ret_result, ret_data=ret_data)
        logger.error(err_msg)
        raise HCCIException(621000, err_msg)

    def _check_upgrade_status(self):
        """Query the upgrade status and log the current phase.

        This only logs the phase; it does not block on either phase.

        :raises Exception: if the status query returns a non-'0' code.
        """
        logger.info("check upgrade status.")
        ret_result, ret_data = self.opr.get_upgrade_status()
        if ret_result["code"] != '0':
            err_msg = "get upgrade status failed, " \
                      "Detail:[result:%s, data:%s]"\
                      % (ret_result, ret_data)
            logger.error(err_msg)
            raise Exception(err_msg)
        elif ret_data["currentPhase"] == "upgrading":
            logger.info("current is upgrading")
        elif ret_data["currentPhase"] == "rollbacking":
            logger.info("current is rollbacking")

    def _do_upgrade_task(self):
        """Trigger the rollback of the upgrade task.

        :raises Exception: if ``rollback_upgrade_task`` returns a non-'0' code.
        """
        logger.info('start upgrade task.')
        ret_result, ret_data = self.opr.rollback_upgrade_task()
        if ret_result["code"] != '0':
            err_msg = "rollback upgrade task failed, " \
                      "Detail:[result:%s, data:%s]" \
                      % (ret_result, ret_data)
            logger.error(err_msg)
            raise Exception(err_msg)

    def _check_upgrade_result(self, ret_result, ret_data):
        """Interpret one task-query response.

        :param ret_result: query result; its ``"code"`` must be ``'0'``.
        :param ret_data: task data carrying ``"status"`` and, when running,
            ``"percent"``.
        :returns: 0 when the task succeeded; 1 when still in progress or in
            an unrecognized state.
        :raises Exception: on a failed query or a ``"failed"`` task status.
        """
        if ret_result["code"] != '0':
            err_msg = "get rollback task failed, " \
                      "Detail:[result:%s, data:%s]" \
                      % (ret_result, ret_data)
            logger.error(err_msg)
            raise Exception(err_msg)
        if ret_data["status"] == "failed":
            failed_map = RestPublicMethod.get_failed_component(ret_data["result"])
            err_msg = "rollback {} failed".format(failed_map)
            logger.error(err_msg)
            raise Exception(err_msg)
        elif ret_data["status"] == "success":
            logger.info("upgrade success!")
            return 0
        elif ret_data["status"] == "running":
            logger.info("rollback is running. Progress is {}%"
                        .format(ret_data["percent"]))
        else:
            logger.info("get upgrade task status {status}. "
                        "result: {ret_result}, data: {ret_data}"
                        .format(status=ret_data["status"],
                                ret_result=ret_result,
                                ret_data=ret_data))
        return 1
