# -*- coding: UTF-8 -*-
#  Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.

from psdk.dsl.common import DslException, ReturnException
from psdk.dsl.common import DslContext
from psdk.dsl.operator_factory import (
    get_data_source_operators,
    get_parser_operators
)
from expparser import *
import traceback


class Dsl:
    """Run DSL expressions using operator grammars built from a DslContext.

    Each call to ``run`` parses one expression as a data-source operator
    followed by zero or more parser operators. The parse result itself is
    discarded; the value handed back is read from the context afterwards
    (presumably filled in by the operators' parse actions -- see
    ``psdk.dsl.operator_factory``).
    """

    def __init__(self, dsl_context):
        """
        :param dsl_context: the DslContext shared with the operators; its
            ``context.get_logger()`` provides the logger used by this class.
        """
        self.dsl_context = dsl_context
        self.logger = dsl_context.context.get_logger()
        # Refreshed from dsl_context.origin_info at the end of every run().
        self.origin_info = []

    def _build_grammar(self):
        """Combine one data-source operator with any number of parser operators."""
        source_ops = get_data_source_operators(self.dsl_context)
        parser_ops = get_parser_operators(self.dsl_context)
        return source_ops + ZeroOrMore(parser_ops)

    def _unwrap_return(self, signal):
        """Return the payload of a ReturnException, re-raising it if it is an exception."""
        payload = signal.get_ret()
        if isinstance(payload, Exception):
            raise payload
        return payload

    def run(self, express, *args, **kwargs):
        """
        Execute a single DSL command.

        :param express: the DSL expression text.
        :param args: extra positional arguments, stored on the context as
            ``dsl_context.args`` before the grammar is built.
        :param kwargs: extra keyword arguments, stored on the context as
            ``dsl_context.kwargs``; ``need_log=False`` suppresses the
            "run dsl" info log line.
        :return: the payload of a ``ReturnException`` raised while parsing,
            otherwise ``dsl_context.last_data``.
        :raises DslException: "parse.expected" when the expression does not
            match the grammar, "not.expected" for any other failure.
        :raises Exception: the payload of a ``ReturnException`` when that
            payload is itself an exception instance.
        """
        if kwargs.get("need_log", True):
            self.logger.info("run dsl {}".format(express))
        # Targets are assigned left to right: args first, then kwargs.
        self.dsl_context.args, self.dsl_context.kwargs = args, kwargs
        grammar = self._build_grammar()
        try:
            grammar.parseString(express)
        except ParseException:
            # The expression does not fit the operator grammar.
            self.logger.info("dsl parse error:" + express)
            raise DslException("parse.expected")
        except ReturnException as signal:
            # Something raised during parsing (presumably an operator's parse
            # action) ended the run early and supplied the result.
            self.logger.info("return exception")
            return self._unwrap_return(signal)
        except Exception:
            # Any other failure is logged with its traceback and reported to
            # the caller under a generic key.
            self.logger.info("dsl general error:" + traceback.format_exc())
            raise DslException("not.expected")
        finally:
            # Runs on every path, including the early return above.
            self.origin_info = self.dsl_context.origin_info or []
        return self.dsl_context.last_data

    def get_matched_data(self):
        """Return ``matched_data`` from the context (the same object, not a copy)."""
        return self.dsl_context.matched_data
