# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.

import importlib
import inspect
import os

from psdk.dsl.dsl import DslContext
from psdk.dsl.operator.base_operator import BaseOperator


def get_all_operators():
    """
    Discover the operator classes defined in the ``psdk.dsl.operator`` package.

    Every ``*.py`` file in the ``operator`` directory next to this module is
    imported as ``psdk.dsl.operator.<name>``. Each imported module is scanned
    with ``is_subclass``, which accepts ``BaseOperator`` and its subclasses.
    That includes classes merely imported into a module, so the same class
    can be found several times.

    :return: list of distinct operator classes in a deterministic order:
        file-name order first, then member-name order within each module
        (``inspect.getmembers`` sorts by name).
    """
    operator_dir = os.path.join(os.path.dirname(__file__), "operator")
    seen = set()
    all_sub_classes = []
    # os.listdir order is arbitrary. Sort it so discovery order is stable,
    # because callers fold the result with ``|`` and that may be order-sensitive.
    for operator_file in sorted(os.listdir(operator_dir)):
        if not operator_file.endswith(".py"):
            continue
        operator_module = importlib.import_module("psdk.dsl.operator." + operator_file[:-3])
        for _, member_class in inspect.getmembers(operator_module, predicate=is_subclass):
            # De-duplicate while keeping first-seen order. The previous
            # ``list(set(...))`` ordered classes by id-based hash, which
            # varies between runs.
            if member_class in seen:
                continue
            seen.add(member_class)
            all_sub_classes.append(member_class)
    return all_sub_classes


def get_data_source_operators(context):
    """
    Build the combined DSL expression for operators usable as data sources.

    Every discovered operator class is instantiated with ``context``. Those
    whose ``operator_type`` equals ``"data_source"`` contribute the value of
    their ``get()``. The contributions are folded together with ``|`` in
    discovery order.

    :param context: context object passed to each operator's constructor
    :return: the combined DSL expression, or ``None`` when no data-source
        operator exists
    """
    combined = None
    for operator_cls in get_all_operators():
        operator = operator_cls(context)
        if operator.operator_type != "data_source":
            continue
        expression = operator.get()
        combined = expression if combined is None else combined | expression
    return combined


def is_subclass(o):
    """
    Predicate for ``inspect.getmembers``: is ``o`` an operator class?

    :param o: any object found on a module
    :return: ``True`` if ``o`` is a class that is ``BaseOperator`` or derives
        from it, otherwise ``False``
    """
    if not inspect.isclass(o):
        return False
    return issubclass(o, BaseOperator)


def get_parser_operators(context):
    """
    Build the combined DSL expression for operators usable as parsers.

    Every discovered operator class is instantiated with ``context``. Those
    whose ``operator_type`` equals ``"parser"`` contribute the value of their
    ``get()``. The contributions are folded together with ``|`` in discovery
    order.

    :param context: context object passed to each operator's constructor
    :return: the combined DSL expression, or ``None`` when no parser
        operator exists
    """
    combined = None
    for operator_cls in get_all_operators():
        operator = operator_cls(context)
        if operator.operator_type != "parser":
            continue
        expression = operator.get()
        combined = expression if combined is None else combined | expression
    return combined
