#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
任务调度器。
用于对升级子任务进行调度，并监控子任务的状态，当出现子任务执行失败时，终止所有子任务。
"""
import os
import sys
import time
import threading
import traceback

sys.path.append(str(sys.argv[1]))

from common_tasks import COMMON_TASKS
from common_tasks import task_logger as logging

SIGNAL_KILL = 9  # Signal number passed to os.kill to kill this process outright on failure (9 is SIGKILL on POSIX).
THREAD_POLL_INSTERVAL = 5  # Polling interval in seconds (passed to time.sleep) while waiting for prerequisite tasks.
MAX_TIMEOUT = 3600 # Maximum time to wait for prerequisite tasks, in seconds.

def flush():
    """
    Flush the standard output stream and then the standard error stream.
    :return: None
    """
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

def flush_and_exit():
    """
    Flush the buffered output, then send SIGNAL_KILL to the current process.

    The process targets its own PID, so this is how the script terminates
    itself after a fatal task failure.
    :return: None
    """
    flush()
    own_pid = os.getpid()
    os.kill(own_pid, SIGNAL_KILL)


# Exception types that get their own label in the fatal-error log line, checked
# in this order; any other exception is reported with the generic "exception".
_LABELLED_FATAL_ERRORS = (
    (KeyError, "KeyError"),
    (IOError, "IOError"),
    (IndexError, "IndexError"),
    (KeyboardInterrupt, "KeyboardInterrupt"),
    (ValueError, "ValueError"),
)


def _wait_for_pre_tasks(task_name, pre_tasks, dispatcher):
    """
    Poll the dispatcher until all prerequisite tasks are finished or MAX_TIMEOUT elapses.

    :param task_name: name of the waiting task (used for logging and for
        marking the task finished on timeout).
    :param pre_tasks: set of task names that must be finished first.
    :param dispatcher: object providing get_finish_tasks() and add_finish_task().
    :return: True when every prerequisite is finished; False when the wait
        timed out (the task is then recorded as finished so it is skipped).
    """
    start_time = time.time()
    while True:
        finish_tasks = set(dispatcher.get_finish_tasks())
        if finish_tasks.issuperset(pre_tasks):
            logging.info("pre tasks finished.pre_tasks %s, finished tasks: %s" % (pre_tasks, finish_tasks))
            return True
        if time.time() - start_time < MAX_TIMEOUT:
            logging.info("pre tasks NOT finished. pre_tasks %s, finished tasks: %s"% (pre_tasks, finish_tasks))
            flush()
            time.sleep(THREAD_POLL_INSTERVAL)
            continue
        logging.error("waiting for pre tasks timeout! task %s ignored. "
                      "pre_tasks %s, finished tasks: %s"% (task_name, pre_tasks, finish_tasks))
        dispatcher.add_finish_task(task_name)
        return False


def _abort_on_fatal_error(task_name, error):
    """
    Log the traceback of the exception being handled plus a labelled summary, then kill the process.

    Must be called from inside an ``except`` block so that
    traceback.format_exc() still sees the active exception.
    """
    label = "exception"
    for error_type, name in _LABELLED_FATAL_ERRORS:
        if isinstance(error, error_type):
            label = name
            break
    logging.error(traceback.format_exc())
    logging.error("%s got an %s: %s." % (str(task_name), label, str(error)))
    flush_and_exit()


def thread_task(task_info, dispatcher, product_name="NCE"):
    """
    Run one registered task: wait for its prerequisites, instantiate it and call its do().

    Start and finish banners are logged around the task. If the task object
    cannot be constructed, the error is logged, the task is recorded as
    finished and the function returns (the task is skipped). If do() returns
    anything other than 0 or None, or an exception other than SystemExit
    escapes, the error is logged and the process is killed via flush_and_exit().

    :param task_info: mapping describing the task; read keys are "name",
        "pre_tasks" (optional, iterable of task names) and "task" (a callable
        taking product_name and returning an object with a do() method and a
        product_name attribute).
    :param dispatcher: object providing get_finish_tasks() and add_finish_task().
    :param product_name: value passed to the task constructor.
    :return: True if a SystemExit was caught (the caller should stop running
        further tasks); False otherwise.
    """
    task_name = task_info.get("name")
    pre_tasks = set(task_info.get("pre_tasks", []))
    try:
        if not _wait_for_pre_tasks(task_name, pre_tasks, dispatcher):
            return False

        task = task_info.get("task")
        logging.info(("%s start" % task_name).center(78, ">"))
        flush()
        try:
            # To support multiple deployment scenarios: some source tables may
            # not exist in a given scenario. In that case the failure is logged
            # and this upgrade step is skipped: the task is marked as finished
            # and we return.
            # NOTE(review): BaseException is caught on purpose to keep the
            # original behavior, which also swallows KeyboardInterrupt and
            # SystemExit raised during construction.
            task_obj = task(product_name)
            logging.info("==================%s, product_name is==== %s" % (task_name, task_obj.product_name))
        except BaseException as init_error:
            logging.error(traceback.format_exc())
            logging.error("init task %s got an exception: %s, ignored." % (str(task_name), str(init_error)))
            dispatcher.add_finish_task(task_name)
            return False

        # The task was constructed successfully, so it can be executed.
        ret = task_obj.do()
        if ret not in (0, None):
            logging.error("%s return %s." % (str(task_name), ret))
            flush_and_exit()
        dispatcher.add_finish_task(task_name)
        logging.info(("%s finished" % task_name).center(78, "<"))
        flush()
    except SystemExit:
        logging.warning("get an SystemExit, go to exit and ignore the next tasks.")
        return True
    except BaseException as e:
        _abort_on_fatal_error(task_name, e)
    return False


class CommonTaskDispatcher:
    """
    Upgrade-task dispatcher: runs the registered upgrade tasks and keeps a
    lock-protected record of which tasks have finished.
    """
    def __init__(self):
        # Reference the module-level registry of upgrade tasks; do() iterates over it.
        self.common_tasks = COMMON_TASKS

        # task_lock guards every read and write of finish_tasks.
        self.task_lock = threading.Lock()
        self.finish_tasks = []

    def add_finish_task(self, task):
        """
        Record a task as finished.
        :param task: identifier of the finished task (callers pass the task name).
        :return: None
        """
        # "with" guarantees the lock is released even if the append raises.
        with self.task_lock:
            self.finish_tasks.append(task)

    def get_finish_tasks(self):
        """
        Return a snapshot of the finished tasks.
        :return: tuple of the task identifiers recorded so far, in insertion order.
        """
        with self.task_lock:
            return tuple(self.finish_tasks)

    def do(self, product_name="NCE"):
        """
        Run the registered upgrade tasks one after another, in registration order.

        Stops early as soon as thread_task() returns True for a task.
        :param product_name: value forwarded to every task.
        :return: None
        """
        for task in self.common_tasks:
            if thread_task(task, self, product_name):
                break


if __name__ == '__main__':
    # Script entry point. sys.argv[1] was already appended to sys.path at
    # import time; sys.argv[2] carries the product name handed to every task.
    argument_count = len(sys.argv)
    logging.info("Common task dispatcher start".center(78, ">"))
    logging.info("Common task dispatcher params length is %s：" % argument_count)

    product_name = sys.argv[2]
    logging.info("Common task dispatcher PRODUCT_NAME is %s：" % product_name)

    # Build the dispatcher and run every registered task with that product name.
    common_task_dispatcher = CommonTaskDispatcher()
    common_task_dispatcher.do(product_name)

    logging.info("Common task dispatcher finished".center(78, "<"))
