/* **********************************************************
 * Copyright (c) 2011-2022 VMware, Inc. All rights reserved. -- VMware Confidential
 * **********************************************************/

package com.vmware.vapi.provider.local;

import static com.vmware.vapi.MessageFactory.getMessage;
import static com.vmware.vapi.internal.util.TaskUtil.isTaskInvocation;
import static com.vmware.vapi.std.StandardDataFactory.INTERNAL_SERVER_ERROR;
import static com.vmware.vapi.std.StandardDataFactory.INVALID_ARGUMENT;
import static com.vmware.vapi.std.StandardDataFactory.OPERATION_NOT_FOUND;
import static com.vmware.vapi.std.StandardDataFactory.getMessagesFromErrorValue;

import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.zip.CRC32;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.CoreException;
import com.vmware.vapi.Message;
import com.vmware.vapi.MessageFactory;
import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.InterfaceDefinition;
import com.vmware.vapi.core.InterfaceIdentifier;
import com.vmware.vapi.core.MethodDefinition;
import com.vmware.vapi.core.MethodIdentifier;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.core.ProviderDefinition;
import com.vmware.vapi.data.ConstraintValidationException;
import com.vmware.vapi.data.DataDefinition;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorDefinition;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.data.StringDefinition;
import com.vmware.vapi.diagnostics.LogDiagnosticUtil;
import com.vmware.vapi.diagnostics.LogDiagnosticsConfigurator;
import com.vmware.vapi.diagnostics.Slf4jMDCLogConfigurator;
import com.vmware.vapi.internal.provider.introspection.OperationIntrospectionService;
import com.vmware.vapi.internal.provider.introspection.ProviderIntrospectionService;
import com.vmware.vapi.internal.provider.introspection.ServiceIntrospectionService;
import com.vmware.vapi.internal.util.StringUtils;
import com.vmware.vapi.internal.util.Validate;
import com.vmware.vapi.internal.util.async.StrictAsyncHandle;
import com.vmware.vapi.internal.util.time.Chronometer;
import com.vmware.vapi.provider.ApiInterface;
import com.vmware.vapi.provider.introspection.SyncApiIntrospection;
import com.vmware.vapi.std.StandardDataFactory;

/**
 * The <code>LocalProvider</code> class is a local (in-process) implementation
 * of the <code>ApiProvider</code> interface.
 *
 * <p>
 * Services are registered as {@link ApiInterface}s keyed by their interface
 * name. An invocation is validated, dispatched to the matching interface, and
 * the reported result is checked (e.g. reported errors must be declared by the
 * method definition) before it is delivered to the caller's
 * {@link AsyncHandle}.
 *
 * <p>
 * <i>Thread-safety:</i> This class is thread-safe.
 */
public class LocalProvider implements ApiProvider, SyncApiIntrospection {

    private static final String INVOKEMETHOD_MISSING_INPUT_DEF = "vapi.provider.local.invokemethod.missing.input.def";

    /**
     * Output definition used instead of the declared one when a task-enabled
     * operation is invoked as a task.
     */
    private static final StringDefinition TASK_OUTPUT_DEFINITION = StringDefinition.getInstance();

    /** Provider name used when the supplied name is <code>null</code> or blank. */
    public static final String LOCAL_PROVIDER_DEFAULT_NAME = "LocalProvider";

    private static final Logger logger =
            LoggerFactory.getLogger(LocalProvider.class);

    private final LogDiagnosticsConfigurator logDiag =
        new Slf4jMDCLogConfigurator();

    /**
     * Set of errors which are reported by the local provider implementation.
     */
    static final Set<ErrorDefinition> LOCAL_PROVIDER_ERROR_DEFS =
            Collections.unmodifiableSet(new HashSet<>(
                Arrays.asList(
                    StandardDataFactory.createStandardErrorDefinition(INTERNAL_SERVER_ERROR),
                    StandardDataFactory.createStandardErrorDefinition(INVALID_ARGUMENT),
                    StandardDataFactory.createStandardErrorDefinition(OPERATION_NOT_FOUND))));

    /** Registered services keyed by interface (service) name. */
    private final ConcurrentMap<String, ApiInterface> services;
    private final String name;

    /**
     * Calls {@link #LocalProvider(String, List)} with an empty list of
     * <code>ApiInterface</code>s, automatically deploying introspection
     * services.
     */
    public LocalProvider(String name) {
        this(name, Collections.<ApiInterface>emptyList());
    }

    /**
     * Calls {@link LocalProvider#LocalProvider(String, List, boolean)}
     * automatically deploying introspection services.
     */
    public LocalProvider(String name, List<ApiInterface> ifaces) {
        this(name, ifaces, true);
    }

    /**
     * Constructor. Optionally can deploy a set of introspection services
     * which expose the introspection API of the provider through
     * {@link #invoke(String, String, DataValue, ExecutionContext, AsyncHandle)}.
     *
     * @param name name of the provider; trimmed, and replaced by
     *             {@link #LOCAL_PROVIDER_DEFAULT_NAME} when <code>null</code>
     *             or blank
     * @param ifaces list of <code>ApiInterface</code>s to register
     *                with this provider; may be <code>null</code>
     * @param deployIntrospectionServices whether to automatically deploy
     *                                     introspection services
     */
    public LocalProvider(String name,
                         List<ApiInterface> ifaces,
                         boolean deployIntrospectionServices) {
        services = new ConcurrentHashMap<>();

        if (name != null && !name.trim().isEmpty()) {
            this.name = name.trim();
        } else {
            this.name = LOCAL_PROVIDER_DEFAULT_NAME;
        }

        if (deployIntrospectionServices) {
            addInterface(new ProviderIntrospectionService(this));
            addInterface(new ServiceIntrospectionService(this));
            addInterface(new OperationIntrospectionService(this));
        }

        if (ifaces != null) {
            for (ApiInterface i : ifaces) {
                addInterface(i);
            }
        }
    }

    /**
     * Adds the specified API interface to the provider.
     *
     * @param iface    API interface to add; must not be <code>null</code>
     *
     * @throws IllegalArgumentException if API interface with the same ID as
     *         <code>iface</code> is already registered
     */
    public void addInterface(ApiInterface iface) {
        Validate.notNull(iface);
        String serviceId = iface.getIdentifier().getName();
        // atomic check-and-insert: the first registration for a name wins
        ApiInterface duplicate = services.putIfAbsent(serviceId, iface);
        if (duplicate != null) {
            throw new IllegalArgumentException(
                    "Duplicate service name: " + serviceId);
        }
        logger.info("Registered the service {}", serviceId);
    }

    /**
     * Adds a list of API interfaces to the provider.
     *
     * <p>
     * Interfaces are registered one by one in list order; if a duplicate is
     * encountered, the interfaces preceding it remain registered.
     *
     * @param ifaces list of <code>ApiInterface</code>s to register
     *                with this provider; must not be <code>null</code>
     * @throws IllegalArgumentException if an API interface with the same ID as some
     *         <code>iface</code> is already registered
     */
    public void addInterfaces(List<ApiInterface> ifaces) {
        Validate.notNull(ifaces);
        for (ApiInterface i : ifaces) {
            addInterface(i);
        }
    }

    @Override
    public ProviderDefinition getDefinition(ExecutionContext ctx) {
        return new ProviderDefinition(name, computeCheckSum());
    }

    /**
     * Computes a CRC32 checksum over the names of the registered interfaces
     * and the string form of their method definitions.
     *
     * @return the checksum rendered by {@link StringUtils#crc32ToHexString}
     */
    private String computeCheckSum() {
        CRC32 checksum = new CRC32();
        for (ApiInterface iface : services.values()) {
            // add interface ids
            checksum.update(iface.getIdentifier().getName()
                    .getBytes(StandardCharsets.UTF_8));

            for (MethodIdentifier methodId : iface.getDefinition().getMethodIdentifiers()) {
                MethodDefinition methodDef = iface.getMethodDefinition(methodId);
                // an interface advertising a method it cannot describe must
                // not break the provider definition; skip it
                if (methodDef != null) {
                    // add methods input/output/error definitions
                    checksum.update(methodDef.toString()
                            .getBytes(StandardCharsets.UTF_8));
                }
            }
        }
        return StringUtils.crc32ToHexString(checksum);
    }

    /**
     * Returns the identifiers of all registered interfaces.
     *
     * @return a new, mutable set of interface identifiers
     */
    @Override
    public Set<InterfaceIdentifier> getInterfaceIdentifiers(
            ExecutionContext ctx) {
        Set<InterfaceIdentifier> ids = new LinkedHashSet<>();
        for (ApiInterface i : services.values()) {
            ids.add(i.getIdentifier());
        }
        return ids;
    }

    @Override
    public InterfaceDefinition getInterface(ExecutionContext ctx,
                                            InterfaceIdentifier iface) {
        if (iface == null) {
            return null;
        }
        ApiInterface service = services.get(iface.getName());
        if (service != null) {
            return service.getDefinition();
        }
        return null;
    }

    /**
     * Returns the definition of the specified method, with the errors the
     * local provider itself may report ({@link #LOCAL_PROVIDER_ERROR_DEFS})
     * added to the method's declared errors.
     *
     * @return the augmented definition, or <code>null</code> if
     *         <code>methodId</code>, its interface or the method is unknown
     */
    @Override
    public MethodDefinition getMethod(ExecutionContext ctx,
                                      MethodIdentifier methodId) {
        if (methodId == null) {
            return null;
        }

        ApiInterface iface = services.get(methodId.getInterfaceIdentifier().getName());
        if (iface == null) {
            return null;
        }

        MethodDefinition methodDef = iface.getMethodDefinition(methodId);
        if (methodDef == null) {
            return null;
        }

        // add errors reported by the local provider itself
        Set<ErrorDefinition> augmentedErrors = new HashSet<>(
                methodDef.getErrorDefinitions());
        augmentedErrors.addAll(LOCAL_PROVIDER_ERROR_DEFS);
        return new MethodDefinition(methodDef.getIdentifier(),
                                    methodDef.getInputDefinition(),
                                    methodDef.getOutputDefinition(),
                                    augmentedErrors);
    }

    /**
     * {@inheritDoc}
     *
     * <p>
     * Failures detected while preparing the invocation are reported through
     * <code>asyncHandle</code> as error results rather than thrown.
     *
     * @throws NullPointerException if <code>asyncHandle</code> is
     *         <code>null</code>
     */
    @Override
    public void invoke(String serviceId,
                       String operationId,
                       DataValue input,
                       ExecutionContext ctx,
                       AsyncHandle<MethodResult> asyncHandle) {
        // without a handle there is nowhere to report errors to; fail fast
        Objects.requireNonNull(asyncHandle, "AsyncHandle must not be null");
        try {
            logDiag.configureContext(LogDiagnosticUtil
                    .getDiagnosticContext(ctx));
            invokeMethodInt(serviceId, operationId, input, ctx, asyncHandle);
        } catch (InvocationError ex) {
            MethodResult errorResult = ex.getMethodResult();
            setError(asyncHandle, errorResult, serviceId, operationId, ex);
        } catch (Exception ex) {
            MethodResult errorResult = invokeMethodError(serviceId,
                                                         operationId,
                                                         ex);
            setError(asyncHandle, errorResult, serviceId, operationId, ex);
        } finally {
            logDiag.cleanUpContext(LogDiagnosticUtil.getDiagnosticKeys());
        }
    }

    /**
     * Delivers an error result to the caller's handle.
     *
     * <p>
     * If the handle rejects the result with {@link IllegalStateException}
     * (presumably because it was already completed), the original failure
     * <code>ex</code> is logged instead of being propagated.
     *
     * @param asyncHandle handle to complete
     * @param errorResult error result to deliver
     * @param serviceId   invoked service, for logging
     * @param operationId invoked operation, for logging
     * @param ex          failure that produced <code>errorResult</code>
     */
    private static void setError(AsyncHandle<MethodResult> asyncHandle,
                                 MethodResult errorResult,
                                 String serviceId,
                                 String operationId,
                                 Exception ex) {
        try {
            asyncHandle.setResult(errorResult);
        } catch (IllegalStateException e) {
            /*
             * Looks like the invocation is already complete and the async
             * handle complains that we try to complete it again. Just log
             * the error and forget about it.
             */
            logger.error("Operation '{}' from Service '{}' threw an "
                         + "exception after completing the invocation",
                         operationId, serviceId, ex);
        }
    }

    /**
     * Convert java {@link Exception} to {@link MethodResult} suitable for async
     * operation.
     *
     * <p>
     * {@link ConstraintValidationException} maps to
     * <code>INVALID_ARGUMENT</code>; {@link CoreException} and any other
     * exception map to <code>INTERNAL_SERVER_ERROR</code>.
     *
     * @param serviceId name of the invoked service when the error occurred
     * @param operationId name of the invoked operation when the error occurred
     * @param e the error to be converted
     * @return VAPI style result
     */
    private static MethodResult invokeMethodError(String serviceId,
                                                  String operationId,
                                                  Exception e) {
        if (e instanceof ConstraintValidationException) {
            ConstraintValidationException cve = (ConstraintValidationException) e;
            logger.error("Validation error", cve);
            return createErrorResult(INVALID_ARGUMENT,
                                     cve.getExceptionMessages());
        } else if (e instanceof CoreException) {
            CoreException ex = (CoreException) e;
            // some error in the runtime infrastructure,
            // log it and report INTERNAL_SERVER_ERROR
            logger.error("invokeMethod error:", ex);
            return createErrorResult(INTERNAL_SERVER_ERROR,
                                     ex.getExceptionMessages());
        } else {
            logger.warn("invokeMethod error:", e);
            Message message = getMessage("vapi.provider.local.invoke.exception",
                                         operationId,
                                         serviceId);
            return createErrorResult(INTERNAL_SERVER_ERROR,
                                     Arrays.asList(message));
        }
    }

    /**
     * Invokes the specified method using the provided input and
     * execution context.
     *
     * <p>
     * The invocation timer is started here. If preparing the invocation
     * fails, it is stopped here before the failure is rethrown; once the
     * request is handed to the service, the {@link AsyncHandleAdapter} stops
     * it when the invocation completes.
     *
     * @param serviceId   service identifier
     * @param operationId operation identifier for <code>serviceId</code>
     * @param input       input arguments for the method
     * @param ctx         execution context for the method invocation
     * @param asyncHandle handle used to return result or error; the result may
     *                    be the actual result from the invocation or error
     *                    result if there is some problem with the invocation
     *                    (e.g. failed validation)
     * @throws InvocationError if the request is invalid or the target
     *         service/operation is unknown
     *
     * @see #invoke
     */
    private void invokeMethodInt(final String serviceId,
                                 final String operationId,
                                 final DataValue input,
                                 final ExecutionContext ctx,
                                 final AsyncHandle<MethodResult> asyncHandle) {
        Objects.requireNonNull(asyncHandle, "AsyncHandle must not be null");
        final Chronometer invokeMethodTimer = new Chronometer();
        invokeMethodTimer.start();

        final ApiInterface iface;
        final MethodIdentifier methodId;
        final AsyncHandle<MethodResult> cb;
        try {
            validateInvokeInputs(serviceId, operationId, input, ctx);

            iface = services.get(serviceId);
            if (iface == null) {
                throw new InvocationError(OPERATION_NOT_FOUND,
                                          "vapi.method.input.invalid.interface",
                                          serviceId);
            }

            // TODO: need more refactoring to get rid of this (MethodIdentifier)
            methodId = new MethodIdentifier(iface.getIdentifier(), operationId);

            MethodDefinition methodDef = iface.getMethodDefinition(methodId);
            if (methodDef == null) {
                throw new InvocationError(OPERATION_NOT_FOUND,
                                          "vapi.method.input.invalid.method",
                                          operationId,
                                          serviceId);
            }

            completeMethodInput(methodDef, input);

            logger.debug("call to invoke() for service '{}', operation '{}'",
                         serviceId, operationId);

            boolean taskInvocation = isTaskInvocation(operationId);
            cb = new AsyncHandleAdapter(methodDef, taskInvocation,
                                        asyncHandle, invokeMethodTimer);
        } catch (RuntimeException e) {
            // the adapter never got the timer; stop it here
            invokeMethodTimer.stop();
            throw e;
        }

        // invoke the requested method; from here on the adapter owns the timer
        iface.invoke(ctx, methodId, input, new StrictAsyncHandle<>(cb));
    }

    /**
     * Checks that all mandatory invocation arguments are present.
     *
     * @throws InvocationError with <code>INTERNAL_SERVER_ERROR</code> for the
     *         first missing argument
     */
    private static void validateInvokeInputs(String serviceId,
                                             String operationId,
                                             DataValue input,
                                             ExecutionContext ctx) {
        if (ctx == null) {
            // TODO: but shouldn't this be just the cause for the more general
            // message: vapi.provider.local.invoke.exception
            throw new InvocationError(INTERNAL_SERVER_ERROR,
                      "vapi.provider.local.invoke.missing.execution.context");
        }

        if (serviceId == null) {
            throw new InvocationError(INTERNAL_SERVER_ERROR,
                      "vapi.provider.local.invoke.missing.serviceId");
        }

        if (operationId == null) {
            throw new InvocationError(INTERNAL_SERVER_ERROR,
                      "vapi.provider.local.invoke.missing.operationId");
        }

        if (input == null) {
            throw new InvocationError(INTERNAL_SERVER_ERROR,
                              "vapi.provider.local.invoke.missing.input.value",
                              serviceId,
                              operationId);
        }
    }

    /**
     * Adapter around {@link AsyncHandle} to add additional validation and
     * error handling logic.
     *
     * <p>
     * Results reported by the service are checked with
     * {@link #validateMethodResult} before being forwarded. Any completion,
     * whether a result or an error, stops the invocation timer.
     */
    private static class AsyncHandleAdapter extends AsyncHandle<MethodResult> {
        private final MethodDefinition methodDef;
        private final AsyncHandle<MethodResult> asyncHandle;
        private final Chronometer invokeMethodTimer;
        private final String serviceId;
        private final String operationId;
        private final DataDefinition outputType;

        AsyncHandleAdapter(final MethodDefinition methodDef,
                           final boolean isTaskInvocation,
                           final AsyncHandle<MethodResult> asyncHandle,
                           final Chronometer invokeMethodTimer) {
            this.methodDef = methodDef;
            this.asyncHandle = asyncHandle;
            this.outputType = getOutputType(methodDef, isTaskInvocation);
            this.invokeMethodTimer = invokeMethodTimer;
            MethodIdentifier methodId = methodDef.getIdentifier();
            this.serviceId = methodId.getInterfaceIdentifier().getName();
            this.operationId = methodId.getName();
        }

        @Override
        public void updateProgress(DataValue progress) {
            asyncHandle.updateProgress(progress);
        }

        @Override
        public void setResult(MethodResult result) {
            try {
                asyncHandle.setResult(checkedResult(result));
            } finally {
                invokeMethodTimer.stop();
            }
        }

        @Override
        public void setError(RuntimeException error) {
            try {
                asyncHandle.setResult(
                        invokeMethodError(serviceId, operationId, error));
            } finally {
                invokeMethodTimer.stop();
            }
        }

        /**
         * Validates <code>result</code> and returns what should be delivered
         * to the caller: the result itself, the validation error, or an
         * error result converted from a failure during validation.
         */
        private MethodResult checkedResult(MethodResult result) {
            ErrorValue valError;
            try {
                valError = validateMethodResult(methodDef, result, outputType);
            } catch (RuntimeException ex) {
                return invokeMethodError(serviceId, operationId, ex);
            }
            return valError == null
                    ? result
                    : MethodResult.newErrorResult(valError);
        }
    }

    /**
     * Fills in missing data in the value presented as extraneous fields.
     *
     * @param method    method definition
     * @param input     input value for the method; completed in place by the
     *                  method's input definition
     * @throws CoreException if the method has no input definition
     */
    private static void completeMethodInput(final MethodDefinition method,
                                            final DataValue input) {
        DataDefinition inputType = method.getInputDefinition();
        if (inputType == null) {
            MethodIdentifier methodId = method.getIdentifier();
            throw new CoreException(INVOKEMETHOD_MISSING_INPUT_DEF,
                                    methodId.toString());
        }

        inputType.completeValue(input);
    }

    /**
     * Validate a result returned by the specified method.
     *
     * @param method method definition
     * @param result result returned by the method
     * @param outputType definition of the operation output
     * @return {@code ErrorValue} describing validation problem if such is
     *         found; or {@code null} if validation is successful
     * @throws CoreException if <code>result</code> is <code>null</code> or a
     *         successful result has no output definition to check against
     */
    private static ErrorValue validateMethodResult(final MethodDefinition method,
                                                   final MethodResult result,
                                                   final DataDefinition outputType) {
        if (result == null) {
            throw new CoreException(
                        "vapi.provider.local.invokemethod.missing.result",
                        method.getIdentifier().getFullyQualifiedName());
        }
        if (result.success()) {
            validateMethodOutput(method, result.getOutput(), outputType);
            return null;
        }
        return validateMethodError(method, result);
    }

    /**
     * Determines the definition of an invocation return type. If a
     * task-enabled operation is invoked as task the return type is changed
     * to {@link StringDefinition}.
     *
     * @param method definition of the operation
     * @param isTaskInvocation whether the operation is being invoked as a task
     * @return definition of the output type to be used in validation
     */
    private static DataDefinition getOutputType(final MethodDefinition method,
                                                final boolean isTaskInvocation) {
        /*
         * TODO: [kkaraatanassov] We may want to emit task *ApiMethod entry in
         * the *ApiInterface class to avoid this logic in the runtime.
         */
        if (isTaskInvocation && isTaskEnabledOperation(method)) {
            return TASK_OUTPUT_DEFINITION;
        }
        return method.getOutputDefinition();
    }

    /**
     * Determines if an operation supports Task invocations.
     *
     * @param method definition of the operation
     * @return true if the operation supports tasks
     */
    private static boolean isTaskEnabledOperation(final MethodDefinition method) {
        return MethodDefinition.TaskSupport.NONE != method.getTaskSupport();
    }

    /**
     * Checks the output of a successful invocation of the specified method.
     *
     * <p>
     * Currently this only verifies that an output definition is available;
     * <code>output</code> itself is not checked against
     * <code>outputType</code> here.
     *
     * @param method     method used for debugging purposes
     * @param output     output value returned by the method (unused)
     * @param outputType output type the output value should conform to
     * @throws CoreException if <code>outputType</code> is <code>null</code>
     */
    private static void validateMethodOutput(MethodDefinition method,
                                             DataValue output,
                                             DataDefinition outputType) {
        if (outputType == null) {
            throw new CoreException(
                    "vapi.provider.local.invokemethod.missing.output.def",
                    method.getIdentifier().getFullyQualifiedName());
        }
    }

    /**
     * Validate error reported by the specific method.
     *
     * @param method definition for the method that was just invoked
     * @param result failed result from method invocation
     * @return an <code>INTERNAL_SERVER_ERROR</code> value if the error is not
     *         declared by the method or does not match its declaration;
     *         {@code null} if the error is valid
     */
    private static ErrorValue validateMethodError(MethodDefinition method,
                                                  MethodResult result) {
        ErrorValue error = result.getError();
        ErrorDefinition errorDef = method.getErrorDefinition(error.getName());

        // method definition does not include the error
        if (errorDef == null) {
            logger.error("Undeclared error {} reported from method {}",
                         error.getName(),
                         method.getIdentifier());
            return createErrorForMsgs(INTERNAL_SERVER_ERROR,
                          getMessagesFromErrorValue(error),
                          "vapi.provider.local.invokemethod.errors.undeclared",
                          error.getName(),
                          method.getIdentifier().toString());
        }

        // the error structure does not match the method definition
        List<Message> validationMsgs = errorDef.validate(error);
        if (!validationMsgs.isEmpty()) {
            logger.error("Invalid error {} reported from method {}: {}",
                         error,
                         method.getIdentifier(),
                         validationMsgs);
            List<Message> msgs = getMessagesFromErrorValue(error);
            return createErrorForMsgs(INTERNAL_SERVER_ERROR,
                              joinMsgs(validationMsgs, msgs),
                              "vapi.provider.local.invokemethod.errors.invalid",
                              error.getName(),
                              method.getIdentifier().toString());
        }

        // the error is perfectly valid for the method
        return null;
    }

    /**
     * Joins two lists of messages.
     *
     * @param first first list to include
     * @param second second list to include
     * @return new list with the messages of <code>first</code> followed by
     *         those of <code>second</code>
     */
    private static List<Message> joinMsgs(List<Message> first,
                                          List<Message> second) {
        List<Message> msgs = new LinkedList<>(first);
        msgs.addAll(second);
        return msgs;
    }

    /**
     * Create {@link ErrorValue} from name, set of initial message and a top of
     * the stack message.
     *
     * @param errorName error name
     * @param previousMsgs previous messages such as those from validation or
     *        unexpected error
     * @param msgId identifier of the top of stack message to be added
     * @param msgArgs arguments for the top of stack message
     * @return error value carrying the new message followed by
     *         <code>previousMsgs</code>
     */
    private static ErrorValue createErrorForMsgs(String errorName,
                                                 List<Message> previousMsgs,
                                                 String msgId,
                                                 String... msgArgs) {
        List<Message> allMsgs = appendMsg(previousMsgs, msgId, msgArgs);
        return StandardDataFactory.createErrorValueForMessages(errorName,
                                                               allMsgs);
    }

    /**
     * Wraps the given messages in an error {@link MethodResult}.
     *
     * @param errorName error name
     * @param msgs messages to carry in the error
     * @return error result
     */
    private static MethodResult createErrorResult(String errorName,
                                                  List<Message> msgs) {
        return MethodResult.newErrorResult(
           StandardDataFactory.createErrorValueForMessages(errorName, msgs));
    }

    /**
     * Create {@link ErrorValue} from name and message.
     *
     * @param errorName error name
     * @param msgId identifier of the top of stack message to be added
     * @param msgArgs arguments for the top of stack message
     * @return error value carrying the single message
     */
    private static ErrorValue createErrorForMsgs(String errorName,
                                                 String msgId,
                                                 String... msgArgs) {
        Message msg = MessageFactory.getMessage(msgId, msgArgs);
        List<Message> allMsgs = Collections.singletonList(msg);
        return StandardDataFactory.createErrorValueForMessages(errorName,
                                                               allMsgs);
    }

    /**
     * Builds a new message list with a new message placed in front of the
     * previous messages (the new message becomes the top of the stack).
     *
     * @param previousMsgs previous messages; not modified
     * @param msgId message identifier of the new message
     * @param msgArgs new message string arguments
     * @return new list: the new message followed by <code>previousMsgs</code>
     */
    private static List<Message> appendMsg(List<Message> previousMsgs,
                                           String msgId,
                                           String... msgArgs) {
        Message message = MessageFactory.getMessage(msgId, msgArgs);
        List<Message> allMsgs = new LinkedList<>();
        allMsgs.add(message);
        allMsgs.addAll(previousMsgs);
        return allMsgs;
    }

    /**
     * Utility class allowing use of Java exceptions in the synchronous part of
     * the request. It carries the {@link ErrorValue} to report and is
     * converted into an error result by
     * {@link #invoke(String, String, DataValue, ExecutionContext, AsyncHandle)}.
     * <p/>
     * If and when we embrace Java 8 this class can be re-factored to use
     * functions and we can use the functional approach to error handling too.
     */
    private static class InvocationError extends RuntimeException {
        private static final long serialVersionUID = 1L;
        private final ErrorValue errorValue;

        public InvocationError(String errorName,
                               String msgId,
                               String... msgArgs) {
            this(createErrorForMsgs(errorName, msgId, msgArgs));
        }

        public InvocationError(ErrorValue errorValue) {
            super();
            this.errorValue = Objects.requireNonNull(errorValue);
        }

        public MethodResult getMethodResult() {
            return MethodResult.newErrorResult(errorValue);
        }
    }

}
