/* **********************************************************
 * Copyright (c) 2011-2015, 2017-2019, 2022 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.internal.bindings;

import static com.vmware.vapi.MessageFactory.getMessage;
import static com.vmware.vapi.bindings.exception.Constants.INTERNAL_SERVER_ERROR;
import static com.vmware.vapi.bindings.exception.Constants.INVALID_ARGUMENT;
import static com.vmware.vapi.bindings.exception.Constants.OPERATION_NOT_FOUND;
import static com.vmware.vapi.bindings.exception.Constants.UNEXPECTED_INPUT;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.CoreException;
import com.vmware.vapi.Message;
import com.vmware.vapi.MessageFactory;
import com.vmware.vapi.bindings.ApiError;
import com.vmware.vapi.bindings.StaticStructure;
import com.vmware.vapi.bindings.exception.Constants;
import com.vmware.vapi.bindings.server.InvocationContext;
import com.vmware.vapi.bindings.type.ErrorType;
import com.vmware.vapi.bindings.type.StructType;
import com.vmware.vapi.bindings.type.Type;
import com.vmware.vapi.bindings.type.TypeReference;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.InterfaceIdentifier;
import com.vmware.vapi.core.MethodDefinition;
import com.vmware.vapi.core.MethodDefinition.TaskSupport;
import com.vmware.vapi.core.MethodIdentifier;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.ConstraintValidationException;
import com.vmware.vapi.data.DataDefinition;
import com.vmware.vapi.data.DataType;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorDefinition;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.data.ListDefinition;
import com.vmware.vapi.data.ListValue;
import com.vmware.vapi.data.OptionalDefinition;
import com.vmware.vapi.data.OptionalValue;
import com.vmware.vapi.data.StructDefinition;
import com.vmware.vapi.data.StructValue;
import com.vmware.vapi.internal.bindings.TypeConverter.ConversionContext;
import com.vmware.vapi.internal.bindings.type.TypeUtil;
import com.vmware.vapi.internal.data.DeepDefinitionVisitor;
import com.vmware.vapi.internal.util.Validate;
import com.vmware.vapi.provider.ApiMethod;
import com.vmware.vapi.std.StandardDataFactory;

/**
 * Base <code>ApiMethod</code> implementation, intended to be subclassed by
 * API methods generated by vAPI Java generator.
 * <p>
 * Before delegating to {@link #doInvoke}, the skeleton checks that the input
 * is a structure, validates it against the declared input type and rejects
 * set fields which are unknown to the input definition. Successful results
 * passed by the implementation to its async handle are validated against the
 * declared output type, and runtime exceptions thrown during the invocation
 * are converted to {@link ErrorValue}s.
 */
public abstract class ApiMethodSkeleton implements ApiMethod {
    private static final Logger logger =
            LoggerFactory.getLogger(ApiMethodSkeleton.class);

    /**
     * Set of errors which are reported by this class.
     *
     * NOTE(review): validateInputTypes() also reports INVALID_REQUEST, which
     * is not listed here - confirm whether that omission is intentional.
     */
    static final Set<ErrorDefinition> API_METHOD_SKELETON_ERRORS =
            Collections.unmodifiableSet(new HashSet<ErrorDefinition>(
                Arrays.asList(
                    StandardDataFactory.createStandardErrorDefinition(INVALID_ARGUMENT),
                    StandardDataFactory.createStandardErrorDefinition(INTERNAL_SERVER_ERROR),
                    StandardDataFactory.createStandardErrorDefinition(UNEXPECTED_INPUT))));

    /**
     * Set of errors that are allowed to be reported by provider
     * implementations even though they are not included in the VMODL2 throws
     * clause.
     *
     * TODO: This constant should be removed once the runtime/bindings check
     * for invocations using API features that are disabled by feature switch
     * and report these errors (so provider implementations don't need to
     * report them).
     */
    static final Set<ErrorDefinition> UNCHECKED_ERRORS =
            Collections.unmodifiableSet(new HashSet<ErrorDefinition>(
                Arrays.asList(
                    StandardDataFactory.createStandardErrorDefinition(
                        OPERATION_NOT_FOUND),
                    StandardDataFactory.createStandardErrorDefinition(
                        UNEXPECTED_INPUT))));

    /** Identifier of this method (interface identifier plus method name). */
    private final MethodIdentifier id;
    /** Definition exposing input, output, task support and reportable errors. */
    private final MethodDefinition definition;
    /** Converter between API runtime values and bindings objects. */
    private final TypeConverter converter;
    /** Maps binding classes of the declared errors to their error types. */
    private final Map<Class<?>, ErrorType> errorClass2Type;
    /** Bindings type of the method input. */
    private final StructType inputType;
    /** Bindings type of the method output. */
    private final Type outputType;

    /**
     * Constructor for a method without task support; equivalent to calling
     * the full constructor with {@link TaskSupport#NONE}.
     *
     * @param ifaceId      identifier of the <code>ApiInterface</code> for
     *                     this <code>ApiMethod</code>. must not be
     *                     <code>null</code>
     * @param name         <code>ApiMethod</code> name. must not be
     *                     <code>null</code>
     * @param inputType    type of the <code>ApiMethod</code> input. must not be
     *                     <code>null</code>
     * @param outputType   type of the <code>ApiMethod</code> output/return type.
     *                     must not be <code>null</code>
     * @param converter    converter between API runtime types and bindings types;
     *                     must not be <code>null</code>
     * @param errorTypes   collection of error types declared for this method;
     *                     can be <code>null</code>. Each element must be a
     *                     {@link TypeReference} resolving to an
     *                     {@link ErrorType}
     */
    public ApiMethodSkeleton(InterfaceIdentifier ifaceId,
                             String name,
                             StructType inputType,
                             Type outputType,
                             TypeConverter converter,
                             Collection<Type> errorTypes) {
        this(ifaceId, name, inputType, outputType, converter, errorTypes,
                TaskSupport.NONE);
    }

    /**
     * Constructor.
     *
     * @param ifaceId      identifier of the <code>ApiInterface</code> for
     *                     this <code>ApiMethod</code>. must not be
     *                     <code>null</code>
     * @param name         <code>ApiMethod</code> name. must not be
     *                     <code>null</code>
     * @param inputType    type of the <code>ApiMethod</code> input. must not be
     *                     <code>null</code>
     * @param outputType   type of the <code>ApiMethod</code> output/return type.
     *                     must not be <code>null</code>
     * @param converter    converter between API runtime types and bindings types;
     *                     must not be <code>null</code>
     * @param errorTypes   collection of error types declared for this method;
     *                     can be <code>null</code>. Each element must be a
     *                     {@link TypeReference} resolving to an
     *                     {@link ErrorType}
     * @param taskSupport  task support of the <code>ApiMethod</code> (for
     *                     example whether it is a task-enabled or task-only
     *                     method); passed through to the
     *                     {@link MethodDefinition}
     */
    public ApiMethodSkeleton(InterfaceIdentifier ifaceId,
                             String name,
                             StructType inputType,
                             Type outputType,
                             TypeConverter converter,
                             Collection<Type> errorTypes,
                             TaskSupport taskSupport) {
        Validate.notNull(ifaceId);
        Validate.notNull(name);
        Validate.notNull(inputType);
        Validate.notNull(outputType);
        Validate.notNull(converter);
        id = new MethodIdentifier(ifaceId, name);
        errorClass2Type = new HashMap<Class<?>, ErrorType>();
        Set<ErrorDefinition> augmentedErrors =
                new HashSet<ErrorDefinition>(API_METHOD_SKELETON_ERRORS);
        /*
         * Allow methods to report operation_not_found or unexpected_input.
         * This special case is currently allowed so that vSphere providers
         * can report these errors when the feature state is disabled.
         *
         * TODO: This statement should be removed once the runtime/bindings
         * check for invocations using API features that are disabled by
         * feature switch and report these errors (so provider implementations
         * don't need to report them).
         */
        augmentedErrors.addAll(UNCHECKED_ERRORS);

        if (errorTypes != null) {
            for (Type type : errorTypes) {
                // TODO: optimize to remove these casts
                //       cast to (TypeReference<?>) is just warning-killer
                augmentedErrors.add((ErrorDefinition) TypeUtil.toDataDefinition(type));
                ErrorType errType = (ErrorType) ((TypeReference<?>) type).resolve();
                errorClass2Type.put(errType.getBindingClass(), errType);
            }
        }
        this.inputType = inputType;
        this.outputType = outputType;
        definition = new MethodDefinition(id,
                                          TypeUtil.toDataDefinition(inputType),
                                          TypeUtil.toDataDefinition(outputType),
                                          augmentedErrors,
                                          taskSupport);
        this.converter = converter;
    }

    @Override
    public MethodIdentifier getIdentifier() {
        return id;
    }

    @Override
    public MethodDefinition getDefinition() {
        return definition;
    }

    @Override
    public DataDefinition getInputDefinition() {
        return definition.getInputDefinition();
    }

    @Override
    public DataDefinition getOutputDefinition() {
        return definition.getOutputDefinition();
    }

    /**
     * Validates the input and, if it is valid, delegates to
     * {@link #doInvoke} with a handle that validates successful results
     * against the declared output type.
     * <p>
     * An invalid input completes <code>asyncHandle</code> with the validation
     * error. A <code>RuntimeException</code> thrown during the invocation is
     * converted with {@link #toErrorValue(RuntimeException, ExecutionContext)}
     * and reported through <code>asyncHandle</code>; if the handle was already
     * completed, the exception is only logged.
     */
    @Override
    public void invoke(InvocationContext invocationContext,
                       DataValue input,
                       AsyncHandle<MethodResult> asyncHandle) {
        try {
            ErrorValue valError = validateInput(input);
            if (valError != null) {
                logger.debug(
                        "Input validation failed. Completing invocation with error\n {}",
                        valError);
                asyncHandle.setResult(MethodResult.newErrorResult(valError));
                return;
            }

            ValidatingAsyncHandle validatingHandle = new ValidatingAsyncHandle(
                    asyncHandle, outputType, id);
            doInvoke(invocationContext, (StructValue) input, validatingHandle);
        } catch (RuntimeException ex) {
            logger.debug("Method {} threw an exception", id, ex);
            ExecutionContext ctx = invocationContext.getExecutionContext();
            MethodResult errorResult =
                    MethodResult.newErrorResult(toErrorValue(ex, ctx));
            try {
                asyncHandle.setResult(errorResult);
            } catch (IllegalStateException alreadyCompleted) {
                /*
                 * Looks like the invocation is already complete and the async
                 * handle complains that we try to complete it again. Just log
                 * the error and forget about it.
                 */
                logger.error(
                        "Method {} threw an exception after completing the invocation",
                        id, ex);
            }
        }
    }

    /**
     * Runs all input checks in order: input type, validation against the
     * declared input type, and detection of set unknown fields.
     *
     * @param input the raw method input
     * @return the first validation error found, or <code>null</code> if the
     *         input passed all checks
     */
    private ErrorValue validateInput(DataValue input) {
        ErrorValue validationError = validateInputTypes(input);
        if (validationError != null) {
            return validationError;
        }

        StructValue inStruct = (StructValue) input;
        validationError = validateInput(inStruct);
        if (validationError != null) {
            return validationError;
        }

        StructDefinition inDef = (StructDefinition) definition.getInputDefinition();
        // null when no unknown fields are found, i.e. everything is fine
        return checkForUnknownFields(inStruct, inDef);
    }

    /**
     * Converts an exception to an {@link ErrorValue}, treating it as detected
     * while processing the client input and using an empty execution context.
     *
     * @param ex exception to convert
     * @return error value representation of the exception
     * @deprecated use {@link #toErrorValue(RuntimeException, ExecutionContext)}
     *             instead
     */
    @Deprecated
    protected ErrorValue toErrorValue(RuntimeException ex) {
        return toErrorValue(ex, ExecutionContext.EMPTY);
    }

    /**
     * Converts an exception to an {@link ErrorValue} using this method's
     * converter, treating it as detected while processing the client input
     * (so constraint violations are reported as INVALID_ARGUMENT).
     *
     * @param ex exception to convert
     * @param ctx execution context of the invocation
     * @return error value representation of the exception
     */
    protected ErrorValue toErrorValue(RuntimeException ex,
                                      ExecutionContext ctx) {
        return toErrorValue(converter,
                            ex,
                            true, // exception while processing input
                            ctx);
    }

    /**
     * Converts an exception detected while processing the server output to an
     * {@link ErrorValue}, using an empty execution context.
     * <p>
     * Use
     * {@link #toErrorValue(TypeConverter, RuntimeException, ExecutionContext)}
     * instead.
     *
     * @param converter bindings type converter
     * @param ex exception to convert
     * @return error value representation of the exception
     */
    public ErrorValue toErrorValue(TypeConverter converter,
                                   RuntimeException ex) {
        return toErrorValue(converter, ex, ExecutionContext.EMPTY);
    }

    /**
     * Converts an exception detected while processing the server output to an
     * {@link ErrorValue} (so constraint violations are reported as
     * INTERNAL_SERVER_ERROR).
     *
     * @param converter bindings type converter
     * @param ex exception to convert
     * @param ctx execution context of the invocation
     * @return error value representation of the exception
     */
    public ErrorValue toErrorValue(TypeConverter converter,
                                   RuntimeException ex,
                                   ExecutionContext ctx) {
        return toErrorValue(converter,
                            ex,
                            false,      // exception while processing output
                            ctx);
    }

    /**
     * Converts an exception to an {@link ErrorValue}. Exceptions which do not
     * represent bindings of standard errors are mapped to internal server
     * error.
     * <p>
     * Use
     * {@link #toErrorValue(TypeConverter, RuntimeException, boolean, ExecutionContext)}
     * instead.
     *
     * @param converter bindings type converter
     * @param ex exception
     * @param onInput whether the exception was detected during processing of
     *        client input (request) or server output (response);
     * @return error value representation of the exception
     */
    public ErrorValue toErrorValue(TypeConverter converter,
                                   RuntimeException ex,
                                   boolean onInput) {
        return toErrorValue(converter, ex, onInput, ExecutionContext.EMPTY);
    }

    /**
     * Converts an exception to an {@link ErrorValue}. Exceptions which do not
     * represent bindings of standard errors are mapped to internal server
     * error.
     * <ul>
     * <li>{@link ConstraintValidationException}: INVALID_ARGUMENT when
     *     <code>onInput</code> is <code>true</code>, otherwise
     *     INTERNAL_SERVER_ERROR;</li>
     * <li>other {@link CoreException}s: INTERNAL_SERVER_ERROR;</li>
     * <li>anything else: bindings of declared errors are converted with
     *     <code>converter</code>, everything else becomes
     *     INTERNAL_SERVER_ERROR.</li>
     * </ul>
     *
     * @param converter bindings type converter
     * @param ex exception
     * @param onInput whether the exception was detected during processing of
     *                client input (request) or server output (response);
     * @param ctx contextual data used for formatting output. This is needed for
     *           localization and other use cases where the request context may
     *           be needed during conversion.
     * @return error value representation of the exception
     */
    public ErrorValue toErrorValue(TypeConverter converter,
                                   RuntimeException ex,
                                   boolean onInput,
                                   ExecutionContext ctx) {
        if (ex instanceof ConstraintValidationException) {
            ConstraintValidationException cve = (ConstraintValidationException) ex;
            logger.error("Validation error", cve);
            return StandardDataFactory.createErrorValueForMessages(
                    onInput ? INVALID_ARGUMENT : INTERNAL_SERVER_ERROR,
                    cve.getExceptionMessages());
        } else if (ex instanceof CoreException) {
            CoreException coreExc = (CoreException) ex;
            // some error in the runtime infrastructure,
            // log it and report INTERNAL_SERVER_ERROR
            logger.error("Invocation error", coreExc);
            return buildInternalServerErrorValue(
                    coreExc.getExceptionMessages());
        } else {
            return convertToErrorValue(converter, ex, ctx);
        }
    }

    /**
     * Checks that the input is a {@link StructValue} (a <code>null</code>
     * input passes this check) and that the method's input definition is a
     * {@link StructDefinition}.
     *
     * @param input the raw method input
     * @return an INVALID_REQUEST error if either check fails; otherwise
     *         <code>null</code>
     */
    private ErrorValue validateInputTypes(DataValue input) {
        if ((input != null && !(input instanceof StructValue)) ||
                !(definition.getInputDefinition() instanceof StructDefinition)) {

            return StandardDataFactory.createErrorValueForMessages(
                    Constants.INVALID_REQUEST,
                    Arrays.asList(MessageFactory.getMessage(
                        "vapi.bindings.method.invalid.input.type")));
        }

        // successful validation
        return null;
    }

    /**
     * Validates the input of the method against the declared input type.
     *
     * @param input method input structure
     * @return an INVALID_ARGUMENT error if validation fails; otherwise
     *         <code>null</code>
     */
    private ErrorValue validateInput(StructValue input) {
        try {
            ValidatorUtil.validate(inputType, input, id);
        } catch (CoreException ex) {
            return toInvalidArgument(ex);
        }

        // successful validation
        return null;
    }

    /**
     * Check recursively for unknown set fields in a structure.
     *
     * @param inStruct cannot be null
     * @param inDef cannot be null
     * @return <code>ErrorValue</code> if problem is found; or
     *         <code>null</code> otherwise
     */
    private ErrorValue checkForUnknownFields(StructValue inStruct,
                                             StructDefinition inDef) {
        assert inStruct != null;

        try {
            CheckForUnknownFieldsVisitor checkVisitor =
                    new CheckForUnknownFieldsVisitor(inStruct);
            inDef.accept(checkVisitor);
            return checkVisitor.getErrorValue();
        } catch (CoreException ex) {
            return toInvalidArgument(ex);
        }
    }

    /**
     * Builds an INVALID_ARGUMENT error from the messages of a validation
     * failure.
     *
     * @param ex the validation failure
     * @return INVALID_ARGUMENT error carrying the exception messages
     */
    private ErrorValue toInvalidArgument(CoreException ex) {
        logger.debug("Validation failed in method {}", id, ex);
        return StandardDataFactory.createErrorValueForMessages(
                Constants.INVALID_ARGUMENT,
                ex.getExceptionMessages());
    }

    /**
     * Converts an exception reported by the method implementation to an
     * {@link ErrorValue}. Bindings of errors declared for this method are
     * converted with <code>converter</code>; undeclared exceptions are mapped
     * as described by {@link #toUncheckedErrorValue(RuntimeException)} and
     * {@link #toUnexpectedExceptionErrorValue(RuntimeException)}.
     */
    private ErrorValue convertToErrorValue(TypeConverter converter,
                                           RuntimeException ex,
                                           ExecutionContext ec) {
        // exact class lookup: only declared error binding classes match
        ErrorType type = errorClass2Type.get(ex.getClass());
        if (type != null) {
            logger.debug("Method implementation threw a VMODL2 error", ex);
            ConversionContext cc = new TypeConverter.ConversionContext(ec);
            return (ErrorValue) converter.convertToVapi(ex, type, cc);
        }

        ErrorValue uncheckedError = toUncheckedErrorValue(ex);
        if (uncheckedError != null) {
            return uncheckedError;
        }
        return toUnexpectedExceptionErrorValue(ex);
    }

    /**
     * Returns the error value of an undeclared {@link ApiError} which
     * implementations are nevertheless allowed to report (operation_not_found
     * or unexpected_input). This special case is currently allowed so that
     * vSphere providers can report these errors when the feature state is
     * disabled.
     * <p>
     * TODO: This special case should be removed once the runtime/bindings
     * check for invocations using API features that are disabled by feature
     * switch and report these errors (so provider implementations don't need
     * to report them).
     *
     * @param ex exception reported by the implementation
     * @return the error value carried by <code>ex</code>, or <code>null</code>
     *         if <code>ex</code> is not one of the allowed unchecked errors
     */
    private static ErrorValue toUncheckedErrorValue(RuntimeException ex) {
        if (!(ex instanceof ApiError)) {
            return null;
        }
        StructValue structValue = ((ApiError) ex)._getDataValue();
        String structName = structValue.getName();
        if (Constants.OPERATION_NOT_FOUND.equals(structName) ||
                Constants.UNEXPECTED_INPUT.equals(structName)) {
            logger.debug("Method implementation threw a VMODL2 error", ex);
            return (ErrorValue) structValue;
        }
        return null;
    }

    /**
     * Maps an unexpected exception from the implementation to an
     * INTERNAL_SERVER_ERROR. The first message names the exception class (for
     * undeclared API errors) or carries the exception message (for anything
     * else); localizable messages carried by a {@link StaticStructure}
     * exception are appended.
     *
     * @param ex exception reported by the implementation
     * @return INTERNAL_SERVER_ERROR error value
     */
    private static ErrorValue toUnexpectedExceptionErrorValue(RuntimeException ex) {
        // unexpected exception from implementation
        // log it and report INTERNAL_SERVER_ERROR
        logger.warn(
                "Implementation method reported unexpected exception: {}",
                ex.getClass().getCanonicalName(), ex);

        String msgArg = (ex instanceof ApiError)
                ? ex.getClass().getName()
                : String.valueOf(ex.getMessage()); // can handle null

        List<Message> msgs = new ArrayList<Message>();
        msgs.add(getMessage("vapi.bindings.method.impl.unexpected",
                 msgArg));

        // check if the unexpected exception contains localizable messages
        // and include them in the response
        if (ex instanceof StaticStructure) {
            msgs.addAll(StandardDataFactory.getMessagesFromErrorValue(
                               ((StaticStructure) ex)._getDataValue()));
        }

        return buildInternalServerErrorValue(msgs);
    }

    /**
     * Builds an INTERNAL_SERVER_ERROR error value with the given messages.
     */
    private static ErrorValue buildInternalServerErrorValue(
            List<Message> messages) {
        return StandardDataFactory.createErrorValueForMessages(
                    INTERNAL_SERVER_ERROR,
                    messages);
    }

    /**
     * Invokes the method represented by this <code>ApiMethod</code> on the
     * implementation of the Java bindings interface.
     *
     * @param invocationContext context for the invocation
     * @param inStruct <code>StructValue</code> representing method input
     * @param asyncHandle used to return result or error
     */
    public abstract void doInvoke(InvocationContext invocationContext,
                                  StructValue inStruct,
                                  AsyncHandle<MethodResult> asyncHandle);

    /**
     * Async handle decorator which validates successful method results
     * against the declared output type before forwarding them. A result that
     * fails validation is replaced by an INTERNAL_SERVER_ERROR error result,
     * so the decorated handle is always completed. Progress updates and
     * errors are forwarded unchanged.
     */
    private static final class ValidatingAsyncHandle extends AsyncHandle<MethodResult> {

        private final AsyncHandle<MethodResult> decoratedHandle;
        private final Type resultType;
        private final MethodIdentifier methodId;

        /**
         * @param decoratedHandle the handle to be decorated with validation.
         *                        must not be <code>null</code>
         * @param outputType the type of the dataValue that will be validated.
         *                   must not be <code>null</code>
         * @param methodId the method invocation in which context the dataValue
         *                 will be returned. must not be <code>null</code>
         */
        public ValidatingAsyncHandle(AsyncHandle<MethodResult> decoratedHandle,
                                     Type outputType,
                                     MethodIdentifier methodId) {
            Validate.notNull(decoratedHandle);
            Validate.notNull(outputType);
            Validate.notNull(methodId);
            this.decoratedHandle = decoratedHandle;
            this.resultType = outputType;
            this.methodId = methodId;
        }

        @Override
        public void updateProgress(DataValue progress) {
            decoratedHandle.updateProgress(progress);
        }

        @Override
        public void setResult(MethodResult result) {
            ErrorValue error = validateResult(result);
            if (error == null) {
                decoratedHandle.setResult(result);
            } else {
                decoratedHandle.setResult(MethodResult.newErrorResult(error));
            }
        }

        /**
         * Validates the output of a successful result.
         *
         * @param result method result reported by the implementation
         * @return INTERNAL_SERVER_ERROR error value if the output does not
         *         comply with the declared result type or cannot be
         *         validated; <code>null</code> otherwise (including for error
         *         results, which are not validated)
         */
        private ErrorValue validateResult(MethodResult result) {
            if (!result.success()) {
                return null;
            }
            try {
                ValidatorUtil.validate(resultType, result.getOutput(), methodId);
                return null;
            } catch (ConstraintValidationException e) {
                logger.debug("Method {} returned value which does not comply " +
                             "with the declared result type", methodId, e);
                return StandardDataFactory.createErrorValueForMessages(
                        Constants.INTERNAL_SERVER_ERROR,
                        e.getExceptionMessages());
            } catch (CoreException e) {
                // any other validation failure must still complete the
                // decorated handle; otherwise an asynchronous caller would
                // never receive a result
                logger.error("Failed to validate the result of method {}",
                             methodId, e);
                return StandardDataFactory.createErrorValueForMessages(
                        Constants.INTERNAL_SERVER_ERROR,
                        e.getExceptionMessages());
            }
        }

        @Override
        public void setError(RuntimeException error) {
            decoratedHandle.setError(error);
        }
    }

    /**
     * Visitor for traversing a data value and checking for unknown fields in
     * structures (including errors), lists and optional fields. Traversal
     * stops at the first problem found, which is available through
     * {@link #getErrorValue()}.
     */
    private static final class CheckForUnknownFieldsVisitor extends DeepDefinitionVisitor {
        private final DataValue dataValue;
        private ErrorValue error;

        CheckForUnknownFieldsVisitor(DataValue dataValue) {
            this.dataValue = Objects.requireNonNull(dataValue);
        }

        @Override
        public void visit(StructDefinition def) {
            assertType(dataValue.getType(), DataType.STRUCTURE, DataType.ERROR);

            // server bindings should reject method calls containing set
            // (Optional) unknown fields (but still allow unset optional fields)
            StructValue value = (StructValue) dataValue;
            Set<String> unexpected = findUnknownSetOptionalFields(def, value);
            // if there are unexpected fields, set error and stop processing
            if (!unexpected.isEmpty()) {
                error = StandardDataFactory.createErrorValueForMessages(
                        Constants.UNEXPECTED_INPUT,
                        Arrays.asList(MessageFactory.getMessage(
                                "vapi.data.structure.field.unexpected",
                                unexpected.toString(),
                                value.getName())));
                return;
            }
            // process fields recursively
            for (String fieldName : def.getFieldNames()) {
                DataDefinition fieldDef = def.getField(fieldName);

                CheckForUnknownFieldsVisitor checkVisitor =
                        new CheckForUnknownFieldsVisitor(value.getField(fieldName));
                fieldDef.accept(checkVisitor);
                error = checkVisitor.getErrorValue();
                if (error != null) {
                    return;
                }
            }
        }

        @Override
        public void visit(OptionalDefinition def) {
            assertType(dataValue.getType(), DataType.OPTIONAL);

            // unset optionals carry no value to check
            OptionalValue value = (OptionalValue) dataValue;
            if (value.isSet()) {
                CheckForUnknownFieldsVisitor checkVisitor =
                        new CheckForUnknownFieldsVisitor(value.getValue());
                def.getElementType().accept(checkVisitor);
                error = checkVisitor.getErrorValue();
            }
        }

        @Override
        public void visit(ListDefinition fieldDef) {
            assertType(dataValue.getType(), DataType.LIST);

            ListValue value = (ListValue) dataValue;
            for (DataValue elemValue : value) {
                CheckForUnknownFieldsVisitor checkVisitor =
                        new CheckForUnknownFieldsVisitor(elemValue);
                fieldDef.getElementType().accept(checkVisitor);
                error = checkVisitor.getErrorValue();
                if (error != null) {
                    return;
                }
            }
        }

        /**
         * Throws a {@link CoreException} if <code>type</code> is not one of
         * <code>allowed</code>.
         */
        private static void assertType(DataType type, DataType... allowed) {
            if (!Arrays.asList(allowed).contains(type)) {
                throw new CoreException("vapi.bindings.typeconverter.unexpected.type",
                                        Arrays.toString(allowed),
                                        type.toString());
            }
        }

        private ErrorValue getErrorValue() {
            return error;
        }

        /**
         * Returns the names of the fields of <code>value</code> which are
         * unknown to <code>def</code> and actually carry data: unknown fields
         * whose value is not an {@link OptionalValue}, plus unknown fields
         * holding a set {@link OptionalValue}. Unknown fields holding an unset
         * optional are tolerated.
         *
         * @param def definition the structure is checked against
         * @param value structure value to check
         * @return names of the offending fields; empty if there are none
         */
        private static Set<String> findUnknownSetOptionalFields(StructDefinition def,
                                                                StructValue value) {
            Set<String> unexpected = new HashSet<String>();
            for (String fieldName : value.getFieldNames()) {
                if (!def.hasField(fieldName)) {
                    DataValue fieldValue = value.getField(fieldName);
                    if (!(fieldValue instanceof OptionalValue)
                        || ((OptionalValue) fieldValue).isSet()) {
                        unexpected.add(fieldName);
                    }
                }
            }

            return unexpected;
        }
    }
}
