/* **********************************************************
 * Copyright (c) 2012-2014, 2020-2022 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/

package com.vmware.vapi.internal.protocol.client.msg.json;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.client.exception.MessageProtocolException;
import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.Consumer;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.ExecutionContext.ApplicationData;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.data.StringValue;
import com.vmware.vapi.diagnostics.LogDiagnosticUtil;
import com.vmware.vapi.internal.core.abort.AbortHandle;
import com.vmware.vapi.internal.core.abort.AbortHandleProvider;
import com.vmware.vapi.internal.protocol.client.rpc.CorrelatingClient;
import com.vmware.vapi.internal.protocol.client.rpc.CorrelatingClient.ResponseCallback;
import com.vmware.vapi.internal.protocol.client.rpc.CorrelatingClient.ResponseCallbackFactory;
import com.vmware.vapi.internal.protocol.client.rpc.CorrelatingClient.ResponseCallbackParams;
import com.vmware.vapi.internal.protocol.client.rpc.CorrelatingClient.SendParams;
import com.vmware.vapi.internal.protocol.client.rpc.ExecutorAsyncHandle;
import com.vmware.vapi.internal.protocol.common.Util;
import com.vmware.vapi.internal.protocol.common.json.JsonApiRequest;
import com.vmware.vapi.internal.protocol.common.json.JsonApiResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonBaseResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonConstants.RequestType;
import com.vmware.vapi.internal.protocol.common.json.JsonDeserializer;
import com.vmware.vapi.internal.protocol.common.json.JsonInvokeParams;
import com.vmware.vapi.internal.protocol.common.json.JsonInvokeRequestParams2;
import com.vmware.vapi.internal.protocol.common.json.JsonMsgDeserializer2;
import com.vmware.vapi.internal.protocol.common.json.JsonMsgSerializer2;
import com.vmware.vapi.internal.protocol.common.json.JsonProgressResponse;
import com.vmware.vapi.internal.protocol.common.msg.JsonMessageProtocolExceptionTranslator;
import com.vmware.vapi.internal.tracing.TracingScope;
import com.vmware.vapi.internal.tracing.TracingSpan;
import com.vmware.vapi.internal.tracing.otel.TracingAttributeKey;
import com.vmware.vapi.internal.util.StringUtils;
import com.vmware.vapi.internal.util.io.IoUtil;
import com.vmware.vapi.internal.util.io.ReleasableInputStream;
import com.vmware.vapi.protocol.ClientConfiguration;
import com.vmware.vapi.protocol.Constants;
import com.vmware.vapi.protocol.RequestProcessor;
import com.vmware.vapi.protocol.common.http.HttpConstants;
import com.vmware.vapi.tracing.Tracer;


/**
 * JSON client-side {@link ApiProvider} implementation.
 * <p>
 * Each invocation is serialized into a JSON-RPC 2.0 request, passed through
 * the configured {@link RequestProcessor} chain and sent via a
 * {@link CorrelatingClient}. Responses (progress frames as well as single or
 * multi-frame results) are deserialized by a {@link JsonDeserializer} chosen
 * by the response content type and delivered to the caller's
 * {@link AsyncHandle}.
 */
public final class JsonApiProvider implements ApiProvider {
    private static final String UTF8_CHARSET = "UTF-8";

    private static final Logger logger =
            LoggerFactory.getLogger(JsonApiProvider.class);

    // serializer for the JSON message format
    private final JsonMsgSerializer2 jsonSerializer2;

    // Map<MIME type, JsonDeserializer>; the key set is also sent with every
    // request as the set of accepted content types
    private final Map<String, JsonDeserializer> deserializers;

    private final List<RequestProcessor> processorChain;

    private final CorrelatingClient client;

    // when null, result handles are used without executor decoration
    private final Executor executor;

    private final Tracer tracer;


    /**
     * Constructor. Registers the default deserializers without the streaming
     * JSON content type, regardless of the streaming setting in
     * {@code config}.
     *
     * @param client Instance of RpcClient to send/receive requests/responses
     * @param config The messaging client configuration object. can be null.
     */
    public JsonApiProvider(CorrelatingClient client,
                           ClientConfiguration config) {
        this(client, config, prepareDeserializers(false));
    }

    /**
     * Constructor.
     *
     * @param client Instance of RpcClient to send/receive requests/responses
     * @param config The messaging client configuration object. can be null,
     *        in which case a default configuration from
     *        {@link ClientConfiguration.Builder} is used.
     * @param deserializers association of accepted mime types to
     *        {@link JsonDeserializer} instances
     */
    public JsonApiProvider(CorrelatingClient client,
                           ClientConfiguration config,
                           Map<String, JsonDeserializer> deserializers) {
        this.client = client;
        if (config == null) {
            config = new ClientConfiguration.Builder().getConfig();
        }
        this.processorChain = Collections.unmodifiableList(config.getRequestProcessors());

        this.jsonSerializer2 = new JsonMsgSerializer2();

        this.deserializers = deserializers;

        this.executor = config.getExecutor();
        // TODO add unit tests for the tracer functionality
        this.tracer = config.getTracer();
    }

    /**
     * Builds the default MIME type to deserializer association.
     *
     * @param clientConfig client configuration; may be {@code null}, in which
     *        case the streaming JSON content type is enabled
     * @return a new mutable map of MIME type to {@link JsonDeserializer}
     */
    public static Map<String, JsonDeserializer> prepareDeserializers(ClientConfiguration clientConfig) {
        boolean streaming = clientConfig == null ? true
                                            : clientConfig.isSteamingEnabled();
        return prepareDeserializers(streaming);
    }

    /**
     * Builds the default MIME type to deserializer association. All
     * registered content types share a single {@link JsonMsgDeserializer2}.
     *
     * @param enableStreaming whether to register the streaming JSON content
     *        type
     * @return a new mutable map of MIME type to {@link JsonDeserializer}
     */
    private static Map<String, JsonDeserializer> prepareDeserializers(boolean enableStreaming) {
        JsonDeserializer defaultJson = new JsonMsgDeserializer2();
        // JsonDeserializer cleanJson = new JsonMsgDeserializer2(new JsonDirectDeserializer());

        Map<String, JsonDeserializer> deserializers = new HashMap<>();

        deserializers.put(HttpConstants.CONTENT_TYPE_JSON, defaultJson);
        deserializers.put(HttpConstants.CONTENT_TYPE_FRAMED, defaultJson);
        if (enableStreaming) {
            deserializers.put(HttpConstants.CONTENT_TYPE_STREAM_JSON,
                              defaultJson);
        }
        // TODO The code below enables clean JSON
        // if (enableCleanJson) {
        //  deserializers.put(HttpConstants.CONTENT_TYPE_CLEAN_JSON,
        //                    cleanJson);
        // }
        // if (enableCleanJson && enableStreaming) {
        //  deserializers.put(HttpConstants.CONTENT_TYPE_CLEAN_STREAM_JSON,
        //                    cleanJson);
        // }
        return deserializers;
    }

    /**
     * Helper method to send request to the underlying HTTP/AMQP RPC layer.
     * <p>
     * The request is serialized, passed through every processor of the
     * configured chain, and handed to the {@link CorrelatingClient}. The abort
     * state is checked before serialization and before each processor; when
     * the request is aborted, processing stops without sending.
     *
     * @param jsonRequestObj  JSON-RPC 2.0 request to be sent
     * @param processorMetaData metadata passed to each {@link RequestProcessor}
     * @param ctx execution context of the invocation
     * @param abortHandle abort handle of the request; may be {@code null}
     * @param cbFactory factory for the response callback
     * @param serviceId identifier of the invoked service
     * @param operationId identifier of the invoked operation
     * @param tracingSpan span tracing this invocation
     */
    private void sendRequest(Object jsonRequestObj,
                             Map<String, Object> processorMetaData,
                             ExecutionContext ctx,
                             AbortHandle abortHandle,
                             ResponseCallbackFactory cbFactory,
                             String serviceId,
                             String operationId,
                             TracingSpan tracingSpan) {
        if (Util.checkRequestAborted(abortHandle, cbFactory)) {
            // Validate that the request is not aborted and if it is - return
            // as we don't need to serialize it.
            return;
        }

        byte[] jsonRequest = jsonSerializer2.serialize(jsonRequestObj);

        // execute post processors
        for (RequestProcessor proc : processorChain) {
            if (Util.checkRequestAborted(abortHandle, cbFactory)) {
                // If the request is aborted - stop all processing.
                return;
            }

            // TODO the last parameter should not be null
            jsonRequest = proc.process(jsonRequest, processorMetaData, null);
        }

        if (logger.isDebugEnabled()
            && Constants.shouldLogRawRequestResponse()) {
            try {
                String jsonRequestString =
                        new String(jsonRequest, UTF8_CHARSET);
                logger.debug("JSON request: {}", jsonRequestString);
            } catch (UnsupportedEncodingException ex) {
                logger.debug("Could not decode JSON request", ex);
            }
        }
        logger.debug("Sending request of size: {}", jsonRequest.length);

        InputStream requestAsByteIS = new ByteArrayInputStream(jsonRequest);
        InputStream releasableRequest = new ReleasableInputStream(requestAsByteIS);

        client.send(new SendParams(serviceId,
                       operationId,
                       releasableRequest,
                       jsonRequest.length,
                       ctx,
                       cbFactory,
                       abortHandle,
                       HttpConstants.CONTENT_TYPE_JSON,
                       deserializers.keySet(),
                       tracingSpan));
    }

    // TODO: !!! Make sure all deserializers use UTF-8 while reading the InputStream !!!

    /**
     * Method to generate random UUID to be used in JSON-RPC 2.0 requests
     *
     * @return Random UUID
     */
    private static String generateUUID() {
        return UUID.randomUUID().toString();
    }

    /**
     * Helper method to verify if the JSON-RPC 2.0 response message has the same
     * ID as the corresponding JSON-RPC 2.0 request.
     *
     * @param response
     *            JSON response
     * @param expectedId
     *            ID sent in the JSON request; must not be {@code null}
     * @throws MessageProtocolException if the response ID is missing or
     *            differs from {@code expectedId}
     */
    private static void checkResponseId(JsonBaseResponse response, String expectedId) {
        String responseId = response.getId();
        // Compare from the known non-null side so that a response without an
        // id is reported as a mismatch rather than failing with an NPE.
        if (!expectedId.equals(responseId)) {
            throw new MessageProtocolException("Response ID mismatch: expected '"
                                               + expectedId + "', received '"
                                               + responseId + "'");
        }
    }


    /**
     * A ResponseCallbackFactory implementation responsible for creating
     * response {@link CorrelatingClient.ResponseCallback} instances with
     * different deserializers depending on the {@link ResponseCallbackParams}.
     */
    private class ResponseCallbackFactoryImpl
            implements CorrelatingClient.ResponseCallbackFactory {

        private final AsyncHandle<MethodResult> asyncHandle;
        private final AbortHandle abortHandle;
        private final String requestId;

        ResponseCallbackFactoryImpl(AsyncHandle<MethodResult> asyncHandle,
                                    AbortHandle abortHandle,
                                    String requestId) {
            this.asyncHandle = asyncHandle;
            this.abortHandle = abortHandle;
            this.requestId = requestId;
        }

        @Override
        public ResponseCallback createResponseCallback(ResponseCallbackParams params) {
            // The default deserializer is "dirty"
            JsonDeserializer deserializer = new JsonMsgDeserializer2();
            String contentType = params.getContentType();
            if (contentType != null) {
                // The last registered MIME type contained in the response
                // content type wins (iteration order of the map).
                for (Map.Entry<String, JsonDeserializer> entry : deserializers.entrySet()) {
                    if (contentType.contains(entry.getKey())) {
                        deserializer = entry.getValue();
                    }
                }
            }

            return decorateCallbackWithExceptionTranslation(
                    new ResponseCallbackImpl(asyncHandle,
                                             abortHandle,
                                             requestId,
                                             deserializer));
        }

        @Override
        public void failed(RuntimeException error) {
            logger.debug("Abort prior callback creation.");
            asyncHandle.setError(JsonMessageProtocolExceptionTranslator.translate(error));
        }
    }

    /**
     * Class for a callbacks that parse JSON responses. Responses come as JSON
     * text and have to be converted to the object model used by
     * {@link ApiProvider}.
     * <p>
     * For streamed responses the current handle is cleared after each result;
     * a new handle is supplied through the {@link Consumer} attached to the
     * result, which also resumes reading from the transport.
     */
    private class ResponseCallbackImpl implements
            CorrelatingClient.ResponseCallback {

        // guarded by 'this'; null between a delivered result and the next
        // handle provided by the consumer
        private AsyncHandle<MethodResult> asyncHandle;
        private final AbortHandle abortHandle;
        private final String requestId;
        private final JsonDeserializer deserializer;

        private volatile boolean terminalFrameReceived = false;

        ResponseCallbackImpl(AsyncHandle<MethodResult> asyncHandle,
                             AbortHandle abortHandle,
                             String requestId,
                             JsonDeserializer deserializer) {
            this.asyncHandle = asyncHandle;
            this.abortHandle = abortHandle;
            this.requestId = requestId;
            this.deserializer = deserializer;
        }

        @Override
        public void failed(RuntimeException error) {
            logger.debug(String.format("Callback - %h, failed.", this));
            setError(JsonMessageProtocolExceptionTranslator.translate(error));
        }

        @Override
        public void completed() {
            logger.debug(String
                    .format("Streaming has been completed in callback - %h.",
                            this));
            if (!terminalFrameReceived) {
                completeWithoutTerminalFrame();
            }
        }

        /**
         * Delivers an empty result when the stream ended without a terminal
         * frame. The handle check happens under the same lock as all other
         * handle accesses.
         */
        private synchronized void completeWithoutTerminalFrame() {
            if (asyncHandle != null) {
                logger.debug(String
                        .format("Terminal frame was not received in callback - %h.",
                                this));
                setResult(MethodResult.EMPTY);
            }
        }

        @Override
        public void received(InputStream response,
                             final CorrelatingClient.TransportControl control) {
            if (logger.isDebugEnabled()
                && Constants.shouldLogRawRequestResponse()) {
                try {
                    byte[] responseBytes = IoUtil.readAll(response);
                    String responseText = new String(responseBytes, UTF8_CHARSET);
                    logger.debug("JSON response: {}", responseText);
                    response = new ByteArrayInputStream(responseBytes);
                } catch (IOException ex) {
                    logger.debug("Could not log JSON response", ex);
                }
            }

            if (Util.checkRequestAborted(abortHandle, this)) {
                // Error is set, no need to process further (no need to
                // deserialize the response).
                return;
            }

            if (terminalFrameReceived) {
                logger.error(String
                        .format("Terminal frame received in callback - %h, but reading is not terminated.",
                                this));
                setError(JsonMessageProtocolExceptionTranslator
                        .translate(new IllegalStateException("The client found a terminal frame, but server keeps transmitting.")));
                return;
            }

            MethodResult result = null;
            DataValue progress = null;
            JsonBaseResponse jsonObj = deserializer.responseDeserialize(response, RequestType.invoke);
            if (jsonObj instanceof JsonProgressResponse) {
                progress = ((JsonProgressResponse) jsonObj).retrieveProgress();
            } else if (jsonObj instanceof JsonApiResponse) {
                result = decodeResponse((JsonApiResponse) jsonObj);
                terminalFrameReceived = ((JsonApiResponse) jsonObj).isLast();
            } else {
                throw new RuntimeException("Unknown JsonBaseResponse sub-type: "
                        + (jsonObj == null ? "null" : jsonObj.getClass().getName()));
            }

            if (Util.checkRequestAborted(abortHandle, this)) {
                // Error is set, no need to set the result.
                return;
            }

            // Simulate back pressure and delegate demand to
            // {@link Subscription} request
            if (!terminalFrameReceived) {
                logger.trace(String
                        .format("Suspending I/O for TransportControl - %h.",
                                control));
                control.suspendRead();
            }

            if (result != null && !terminalFrameReceived) {
                logger.trace("Next handle consumer added.");
                result.setNext(new Consumer<AsyncHandle<MethodResult>>() {
                    private final AtomicBoolean accepted = new AtomicBoolean();

                    @Override
                    public void accept(AsyncHandle<MethodResult> handle) {
                        // Only the first handle is honoured; later calls are ignored.
                        if (!accepted.compareAndSet(false, true)) {
                            logger.debug("Consumer and handle already utilized.");
                            return;
                        }
                        logger.trace(String.format("New handle provided - %h.",
                                                   handle));
                        handle = tryDecorateHandleWithExecutor(handle, executor);
                        updateAsyncHandle(handle);
                        control.resumeRead();
                    }
                });
            }

            if (progress != null) {
                updateProgress(progress);
            } else {
                setResult(result);
            }
        }

        /**
         * Installs the handle for the next streamed frame. If a handle is
         * still present, the existing one is failed with an
         * {@link IllegalStateException}.
         */
        private synchronized void updateAsyncHandle(AsyncHandle<MethodResult> newHandle) {
            if (asyncHandle == null) {
                logger.trace(String
                        .format("Updating async handle with new handle - %h.",
                                newHandle));
                asyncHandle = newHandle;
            } else {
                // Arguments must be passed to String.format itself; otherwise
                // the %h conversions have no arguments and format() throws.
                logger.error(String
                        .format("Two handles exist at the same time for a single server response; "
                                + "old handle - %h, new handle - %h.",
                                asyncHandle,
                                newHandle));
                asyncHandle
                        .setError(new IllegalStateException("Two handles exist at the same time for a single server response."));
            }
        }

        private synchronized void updateProgress(DataValue progress) {
            if (asyncHandle != null) {
                logger.trace(String
                        .format("Update operation progress via handle - %h.",
                                asyncHandle));
                asyncHandle.updateProgress(progress);
            } else {
                logger.error(String
                        .format("Async handle missing in callback - %h, progress could not be updated.",
                                this));
            }
        }

        private synchronized void setResult(MethodResult result) {
            if (asyncHandle != null) {
                logger.trace(String
                        .format("Setting result. The current handle - %h, is cleared.",
                                asyncHandle));
                // We want to be certain that asyncHandle is null,
                // prior the next Consumer#accept
                AsyncHandle<MethodResult> current = asyncHandle;
                asyncHandle = null;
                current.setResult(result);
            } else {
                logger.error(String
                        .format("Async handle missing in callback - %h, result will not be received.",
                                this));
            }
        }

        private synchronized void setError(RuntimeException error) {
            if (asyncHandle != null) {
                logger.trace(String.format("Setting error via handle - %h.",
                                           asyncHandle));
                // Clear before notifying, consistent with setResult, so the
                // handle is released even if setError throws.
                AsyncHandle<MethodResult> current = asyncHandle;
                asyncHandle = null;
                current.setError(error);
            } else {
                logger.error(String
                        .format("Async handle missing in callback - %h, error will not be received.",
                                this));
            }
        }

        /**
         * Converts the response from an {@link ApiProvider} method invocation.
         *
         * @param jsonResponse the response as an object from the JSON
         *        deserializer object model
         * @return the response as an object from the {@link ApiProvider} object
         *         model
         * @throws MessageProtocolException if the response ID does not match
         *         the request ID
         */
        private MethodResult decodeResponse(JsonApiResponse jsonResponse) {
            checkResponseId(jsonResponse, requestId);
            return jsonResponse.getResult();
        }
    }

    /**
     * Wraps a callback so that any exception thrown by its methods is
     * rethrown after translation through
     * {@link JsonMessageProtocolExceptionTranslator#translate}.
     *
     * @param cb the callback to wrap
     * @return the wrapping callback
     */
    private CorrelatingClient.ResponseCallback
            decorateCallbackWithExceptionTranslation(
                    final CorrelatingClient.ResponseCallback cb) {
        return new CorrelatingClient.ResponseCallback() {
            @Override
            public void received(InputStream response,
                                 CorrelatingClient.TransportControl control) {
                try {
                    cb.received(response, control);
                } catch (Exception e) {
                    throw JsonMessageProtocolExceptionTranslator.translate(e);
                }
            }

            @Override
            public void failed(RuntimeException error) {
                try {
                    cb.failed(error);
                } catch (Exception e) {
                    throw JsonMessageProtocolExceptionTranslator.translate(e);
                }
            }

            @Override
            public void completed() {
                try {
                    cb.completed();
                } catch (Exception e) {
                    throw JsonMessageProtocolExceptionTranslator.translate(e);
                }
            }
        };
    }

    @Override
    public void invoke(String serviceId,
                       String operationId,
                       DataValue input,
                       ExecutionContext ctx,
                       final AsyncHandle<MethodResult> asyncHandle) {
        AbortHandle requestAbortHandle =
                (asyncHandle instanceof AbortHandleProvider) ?
                        ((AbortHandleProvider) asyncHandle).getAbortHandle() :
                        null;

        final String requestId = generateUUID();

        AsyncHandle<MethodResult> resultHandle = null;
        boolean isTaskOperation = operationId.endsWith("$task");
        TracingSpan tracingSpan = startTracingSpan(
               ctx, serviceId, operationId, isTaskOperation);
        try (TracingScope tracingScope = tracingSpan.makeCurrent()) {
            resultHandle = tryDecorateHandleWithExecutor(
                                decorateHandleWithTracing(asyncHandle,
                                                          isTaskOperation,
                                                          tracingSpan),
                                executor);

            JsonInvokeParams jsonInvokeParams =
                    new JsonInvokeRequestParams2(serviceId,
                                                 operationId,
                                                 ctx,
                                                 input);
            JsonApiRequest jsonApiRequest = new JsonApiRequest(requestId,
                                                               jsonInvokeParams);

            ResponseCallbackFactory cb = new ResponseCallbackFactoryImpl(resultHandle,
                                                                         requestAbortHandle,
                                                                         requestId);

            Map<String, Object> metadata = new HashMap<>();
            metadata.put(RequestProcessor.SECURITY_CONTEXT_KEY,
                         jsonInvokeParams.getCtx().retrieveSecurityContext());
            sendRequest(jsonApiRequest,
                        metadata,
                        ctx,
                        requestAbortHandle,
                        cb,
                        serviceId,
                        operationId,
                        tracingSpan);

            // We don't end the span here because the invocation may proceed
            // asynchronously past this point. The span is ended by
            // resultHandle (decorated with tracing above) once it receives a
            // result or an error.

        } catch (RuntimeException e) {
            if (resultHandle != null) {
                // the resultHandle will end the tracingSpan because the
                // resultHandle is decorated with tracing (see the invocation
                // of decorateHandleWithTracing(...) above)
                resultHandle.setError(JsonMessageProtocolExceptionTranslator
                        .translate(e));
            } else {
                try {
                    tracingSpan.setStatusError(e);
                } finally {
                    tracingSpan.end();
                }
            }
        }
    }

    /**
     * Creates a client span named {@code serviceId + "." + operationId}. The
     * span is tagged with the operation id from the application data of
     * {@code execCtx} (when present and not blank) and with the task flag for
     * task operations.
     */
    private TracingSpan startTracingSpan(ExecutionContext execCtx,
                                         String serviceId,
                                         String operationId,
                                         boolean isTaskOperation) {

        TracingSpan tracingSpan = tracer.createClientSpan(serviceId + "." + operationId);

        ApplicationData appData = (execCtx != null ? execCtx.retrieveApplicationData() : null);
        if (appData != null) {
            String opId = appData.getProperty(LogDiagnosticUtil.OPERATION_ID);
            if (!StringUtils.isBlank(opId)) {
                tracingSpan.setAttribute(TracingAttributeKey.OP_ID, opId);
            }
        }

        if (isTaskOperation) {
            tracingSpan.setAttribute(TracingAttributeKey.TASK, true);
        }

        return tracingSpan;
    }

    /**
     * Decorates AsyncHandle with executor if both executor and specified handle
     * are present (not null). Otherwise, the non decorated handle is returned.
     *
     * @param ah the async handle to be decorated; may be {@code null}
     * @param executor the executor to dispatch handle calls on; may be
     *        {@code null}
     * @return the decorated handle, or {@code ah} itself if either argument is
     *         {@code null}
     */
    private static AsyncHandle<MethodResult> tryDecorateHandleWithExecutor(AsyncHandle<MethodResult> ah,
                                                                           Executor executor) {
        if (executor != null && ah != null) {
            logger.trace(String
                    .format("Decorating async handle - %h, with executor.",
                            ah));
            return new ExecutorAsyncHandle<MethodResult>(ah, executor);
        }
        return ah;
    }

    /**
     * Wraps {@code asyncHandle} so that delivering a result or an error first
     * records the outcome on {@code span} and ends it. For successful task
     * operations with a string output, that output is recorded as the task id.
     * Progress updates are forwarded unchanged.
     */
    private static AsyncHandle<MethodResult> decorateHandleWithTracing(
            final AsyncHandle<MethodResult> asyncHandle,
            final boolean isTaskOperation,
            final TracingSpan span) {
        return new AsyncHandle<MethodResult>() {

            @Override
            public void updateProgress(DataValue progress) {
                asyncHandle.updateProgress(progress);
            }

            @Override
            public void setResult(MethodResult result) {
                try {
                    DataValue resultOutput = result.getOutput();
                    if (result.success()) {
                        if (isTaskOperation && resultOutput instanceof StringValue) {
                            span.setAttribute(TracingAttributeKey.TASK_ID,
                                              ((StringValue) resultOutput).getValue());
                        }
                        span.setStatusOk();
                    } else {
                        ErrorValue errorValue = result.getError();
                        span.setStatusError(errorValue);
                    }
                } finally {
                    span.end();
                }
                asyncHandle.setResult(result);
            }

            @Override
            public void setError(RuntimeException error) {
                try {
                    span.setStatusError(error);
                } finally {
                    span.end();
                }
                asyncHandle.setError(error);
            }
        };
    }
}
