/* **********************************************************
 * Copyright (c) 2012-2015, 2017, 2019, 2022-2023 VMware, Inc. All rights reserved. -- VMware Confidential
 * ********************************************************* */

package com.vmware.vapi.protocol.server.msg.json;

import static com.vmware.vapi.internal.tracing.otel.TracingAttributeKey.WIRE_PROTOCOL;
import static com.vmware.vapi.protocol.RequestProcessor.SECURITY_PROC_METADATA_KEY;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.ExecutionContext.ApplicationData;
import com.vmware.vapi.core.ExecutionContext.SecurityContext;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.internal.protocol.common.json.JsonApiRequest;
import com.vmware.vapi.internal.protocol.common.json.JsonApiResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonBaseResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonContextBuilderRequestParams;
import com.vmware.vapi.internal.protocol.common.json.JsonContextBuilderRequestParams.ExecutionContextBuilder;
import com.vmware.vapi.internal.protocol.common.json.JsonDeserializer;
import com.vmware.vapi.internal.protocol.common.json.JsonError;
import com.vmware.vapi.internal.protocol.common.json.JsonErrorResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonInvalidContext;
import com.vmware.vapi.internal.protocol.common.json.JsonInvalidDataValueException;
import com.vmware.vapi.internal.protocol.common.json.JsonInvalidMethodException;
import com.vmware.vapi.internal.protocol.common.json.JsonInvalidMethodParamsException;
import com.vmware.vapi.internal.protocol.common.json.JsonInvalidRequest;
import com.vmware.vapi.internal.protocol.common.json.JsonInvokeParams;
import com.vmware.vapi.internal.protocol.common.json.JsonMsgDeserializer2;
import com.vmware.vapi.internal.protocol.common.json.JsonMsgSerializer2;
import com.vmware.vapi.internal.protocol.common.json.JsonProgressResponse;
import com.vmware.vapi.internal.protocol.common.json.JsonSerializer;
import com.vmware.vapi.internal.tracing.TracingSpan;
import com.vmware.vapi.internal.util.Validate;
import com.vmware.vapi.internal.util.io.IoUtil;
import com.vmware.vapi.protocol.Constants;
import com.vmware.vapi.protocol.RequestProcessor;
import com.vmware.vapi.protocol.RequestProcessor.Request;
import com.vmware.vapi.protocol.server.rpc.RequestReceiver;
import com.vmware.vapi.security.SessionSecurityContext;
import com.vmware.vapi.security.StdSecuritySchemes;

public final class JsonServerConnection implements RequestReceiver {
    /** Prefix of transport properties that are copied into application data. */
    private static final String APP_DATA_HEADER_PREFIX = "vapi-ctx-";

    /**
     * Application data keys with a canonical mixed-case spelling. Keys copied
     * from transport properties are lower-cased by
     * {@link #processContextApplicationData(Map, Map)} and renamed back to
     * these by {@link #fixApplicationDataKeyCasing(Map)}.
     */
    static final String[] APP_DATA_SPECIAL_KEYS = { "opId", "actId",
         "$showUnreleasedAPIs", "$userAgent", "$doNotRoute",
         "vmwareSessionId", "ActivationId" };

    /** Lower-cased form of {@link #APP_DATA_SPECIAL_KEYS}, index for index. */
    static final String[] APP_DATA_SPECIAL_KEYS_LOWER;

    static {
        APP_DATA_SPECIAL_KEYS_LOWER = new String[APP_DATA_SPECIAL_KEYS.length];
        for (int i = 0; i < APP_DATA_SPECIAL_KEYS.length; i++) {
            APP_DATA_SPECIAL_KEYS_LOWER[i] =
                    APP_DATA_SPECIAL_KEYS[i].toLowerCase(Locale.ENGLISH);
        }
    }

    private static final Logger logger =
        LoggerFactory.getLogger(JsonServerConnection.class);

    private final ApiProvider provider;

    // (de)serializer for the new JSON protocol
    private final JsonSerializer jsonSerializer2;
    private final JsonDeserializer jsonDeserializer2;

    private final List<RequestProcessor> processorChain;

    // JSON-RPC 2.0 Error constants
    private static final int INVALID_PARAMS_CODE = -32602;
    private static final int INVALID_REQUEST_CODE = -32600;
    private static final int INVALID_CONTEXT_CODE = -31001;
    private static final int METHOD_NOT_FOUND_CODE = -32601;
    private static final int INTERNAL_JSONRPC_ERROR_CODE = -32603;
    private static final String INVALID_PARAMS_MSG = "Invalid params";
    private static final String INVALID_REQUEST_MSG = "Invalid Request";
    private static final String INVALID_CONTEXT_MSG = "Invalid context";
    private static final String METHOD_NOT_FOUND_MSG = "Method not found";
    private static final String INTERNAL_JSONRPC_ERROR_MSG = "Internal error";

    /**
     * Creates a handler which translates JSON-RPC 2.0 requests to
     * {@link ApiProvider} method calls.
     *
     * @param provider all {@link ApiProvider} method calls will be delegated
     *                     to this provider; must not be <code>null</code>
     * @param processorChain chain of processors which will be invoked on each
     *                       JSON request; the input of each processor is the
     *                       output of the previous processor in the chain; the
     *                       input to the first processor is the original JSON
     *                       request; processors are executed before the
     *                       request gets translated into an {@link ApiProvider}
     *                       call; must not be <code>null</code>
     */
    public JsonServerConnection(ApiProvider provider,
                                List<RequestProcessor> processorChain) {
        Validate.notNull(provider);
        Validate.notNull(processorChain);

        this.provider = provider;

        this.jsonDeserializer2 = new JsonMsgDeserializer2();
        this.jsonSerializer2 = new JsonMsgSerializer2();

        this.processorChain = Collections.unmodifiableList(processorChain);
    }

    @Override
    public void requestReceived(InputStream request, TransportContext transport)
            throws IOException {
        Validate.notNull(request);
        Validate.notNull(transport);
        processRequest(request, transport);
    }

    /**
     * Reads, deserializes and validates a JSON-RPC request, runs the request
     * processor chain and, on success, dispatches the call to the provider.
     * Any validation/processing failure is answered with a JSON-RPC error
     * response instead of being propagated.
     *
     * @param request raw request stream
     * @param transport transport handle used for the response
     * @throws IOException if reading the request stream fails
     */
    private void processRequest(InputStream request, TransportContext transport)
            throws IOException {
        // TODO: don't copy/buffer, just feed the stream into the
        //       JSON deserializer (once we get rid of the older one)
        byte[] buffer = IoUtil.readAll(request);
        if (buffer.length == 0) {
            // the request is empty, stop processing and return an error
            logger.error("Empty request");
            JsonBaseResponse response = createErrorResponse(null,
                                                            INVALID_REQUEST_CODE,
                                                            INVALID_REQUEST_MSG,
                                                            "Empty request");
            sendResponse(response, transport, true);
            return;
        }
        logger.debug("Received request of size: {}", buffer.length);

        String jsonRpcRequest = new String(buffer, StandardCharsets.UTF_8);
        if (logger.isDebugEnabled() && Constants.shouldLogRawRequestResponse()) {
            logger.debug("JSON request: {}", jsonRpcRequest);
        }
        JsonApiRequest jsonRpcRequestObj = null;
        JsonBaseResponse errorResponse = null;
        Map<String, Object> procMetaData = new HashMap<>();
        try {
            logger.debug("Deserializing JSON request");
            jsonRpcRequestObj = jsonDeserializer2.requestDeserialize(
                    jsonRpcRequest);
            JsonContextBuilderRequestParams params =
                  (JsonContextBuilderRequestParams) (jsonRpcRequestObj.getParams());
            RequestContext context = transport.getRequestContext();
            verifyContext(context, params);
            processRequestContext(context, params);
            processContextApplicationData(context.getAllProperties(),
                                          params.getCtxBuilder().applicationData);
            fixApplicationDataKeyCasing(params.getCtxBuilder().applicationData);
            processContextSession(context.getSession(),
                                  params.getCtxBuilder().security);

            // execute request pre-processors
            // TODO error handling - request processor errors should not be
            // converted to JSON RPC errors
            for (RequestProcessor proc : processorChain) {
                // TODO: need better toString for RequestProcessor impls
                logger.debug("Execution request pre-processor {}", proc.toString());
                buffer = proc.process(buffer, procMetaData, params);
            }

        } catch (JsonInvalidMethodParamsException e) {
            logger.debug("Invalid parameters: {}", e.getMessage(), e);
            errorResponse = createErrorResponse(e.getId(),
                                                INVALID_PARAMS_CODE,
                                                INVALID_PARAMS_MSG,
                                                e.getMessage());
        } catch (JsonInvalidContext e) {
            logger.debug("Invalid JSON context: {}", e.getMessage(), e);
            errorResponse = createErrorResponse(null,
                                                INVALID_CONTEXT_CODE,
                                                INVALID_CONTEXT_MSG,
                                                e.getMessage());
        } catch (JsonInvalidRequest e) {
            logger.debug("Invalid JSON request: {}", e.getMsg(), e);
            errorResponse = createErrorResponse(null,
                                                INVALID_REQUEST_CODE,
                                                INVALID_REQUEST_MSG,
                                                e.getMsg());
        } catch (JsonInvalidDataValueException e) {
            logger.debug("Invalid DataValue JSON detected in request: {}", e.getMessage(), e);
            errorResponse = createErrorResponse(null,
                                                INVALID_REQUEST_CODE,
                                                INVALID_REQUEST_MSG,
                                                e.getMessage());
        } catch (JsonInvalidMethodException e) {
            logger.debug("Method not found: {}", e.getMsg(), e);
            errorResponse = createErrorResponse(e.getId(),
                                                METHOD_NOT_FOUND_CODE,
                                                METHOD_NOT_FOUND_MSG,
                                                e.getMsg());
        } catch (Exception e) {
            // This block catches IOException and JsonParseException that
            // JsonParser might throw as well as any kind of RuntimeException.
            logger.debug("JSON-RPC error", e);
            errorResponse = createErrorResponse(null,
                                                INTERNAL_JSONRPC_ERROR_CODE,
                                                INTERNAL_JSONRPC_ERROR_MSG,
                                                e.getMessage());
        } catch (Throwable t) {
            logger.error("Severe internal error", t);
            errorResponse = createErrorResponse(null,
                                                INTERNAL_JSONRPC_ERROR_CODE,
                                                INTERNAL_JSONRPC_ERROR_MSG,
                                                t.getMessage());
        }

        if (errorResponse != null) {
            // The raw payload may carry security context data (e.g. session
            // identifiers or credentials), so it is only written to the log
            // when raw request/response logging is explicitly enabled, the
            // same gate used for the debug dumps of requests and responses.
            if (Constants.shouldLogRawRequestResponse()) {
                logger.error("Stop processing invalid JSON request: {}",
                        jsonRpcRequest);
            } else {
                logger.error("Stop processing invalid JSON request");
            }
            // the request is invalid, stop processing and return an error
            sendResponse(errorResponse, transport, true);
            return;
        }

        String id = jsonRpcRequestObj.getId();
        logger.debug("Processing JSON request with id {} for method {}",
                id, jsonRpcRequestObj.getMethod());
        processApiRequest(jsonRpcRequestObj,
                          id,
                          procMetaData,
                          transport);
    }

    /**
     * Asserts that important contextual data (serviceId, operationId) matches the one found inside
     * the JSON RPC request itself in order to provide early warnings in case of routing mishaps.
     *
     * <p>
     * Does nothing if both contextual {@code serviceId} and {@code operationId} are missing in
     * order to allow requests with JSON RPC 1.0.
     * </p>
     *
     * @param requestContext the contextual data to check
     * @param request the source of truth
     * @throws JsonInvalidContext upon encountering invalid context
     */
    static void verifyContext(RequestContext requestContext, Request request)
            throws JsonInvalidContext {
        String serviceId = requestContext.getServiceId();
        String operationId = requestContext.getOperationId();
        if (serviceId == null) {
            if (operationId == null) {
                return;
            }
            throw new JsonInvalidContext("Missing vapi-service HTTP header");
        }
        if (operationId == null) {
            throw new JsonInvalidContext("Missing vapi-operation HTTP header");
        }
        if (!serviceId.equals(request.getServiceId())) {
            throw new JsonInvalidContext("Mismatching service identifier in "
                    + "HTTP header and payload");
        }
        if (!operationId.equals(request.getOperationId())) {
            throw new JsonInvalidContext("Mismatching operation identifier in "
                    + "HTTP header and payload");
        }
    }

    /**
     * Applies request-level context information; currently this only updates
     * the tracing span of the request.
     *
     * @param context the transport request context
     * @param params the deserialized request parameters
     */
    static void processRequestContext(RequestContext context,
                                      JsonContextBuilderRequestParams params) {
        updateTracingSpan(context, params);
    }

    /**
     * Renames the request's tracing span to {@code <serviceId>.<operationId>}
     * and records the JSON-RPC version as the wire protocol attribute.
     *
     * @param context the transport request context owning the span
     * @param params the deserialized request parameters
     */
    static void updateTracingSpan(RequestContext context, JsonContextBuilderRequestParams params) {
        String spanName = params.getServiceId() + "." + params.getOperationId();
        context.getTracingSpan().updateName(spanName);
        context.getTracingSpan().setAttribute(WIRE_PROTOCOL, context.getJsonRpcVersion());
    }

    /**
     * Overrides values in the specified {@code applicationData} with values
     * specified in {@code properties}. Only properties with the {@code vapi-ctx-}
     * prefix (matched case-insensitively) are considered; the prefix is
     * stripped and the remaining key is lower-cased.
     *
     * @param properties a map of all context properties
     * @param applicationData the map to write values into
     */
    static void processContextApplicationData(Map<String, String> properties,
                                              Map<String, String> applicationData) {
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            String headerName = entry.getKey();
            if (headerName.regionMatches(true,
                                         0,
                                         APP_DATA_HEADER_PREFIX,
                                         0,
                                         APP_DATA_HEADER_PREFIX.length())) {
                String appDataKey = headerName
                        .substring(APP_DATA_HEADER_PREFIX.length())
                        .toLowerCase(Locale.ENGLISH);
                applicationData.put(appDataKey, entry.getValue());
            }
        }
    }

    /**
     * Replaces the whole security context with a session-based one when the
     * transport supplied a session identifier; otherwise leaves it untouched.
     *
     * @param session session identifier from the transport; may be null
     * @param security the security context map to overwrite
     */
    static void processContextSession(String session,
                                      Map<String, Object> security) {
        if (session == null) {
            return;
        }
        security.clear();
        security.put(SecurityContext.AUTHENTICATION_SCHEME_ID,
                     StdSecuritySchemes.SESSION);
        security.put(SessionSecurityContext.SESSION_ID_KEY, session);
    }

    /**
     * The specified {@code applicationData} is supposed to have all of its keys
     * in lower-case. This method goes over all the keys and replaces some well
     * known ones with their camel-case equivalent, i.e. opid -> opId.
     *
     * @param applicationData the map to process
     */
    static void fixApplicationDataKeyCasing(Map<String, String> applicationData) {
        for (int i = 0; i < APP_DATA_SPECIAL_KEYS.length; i++) {
            String lowerCaseKey = APP_DATA_SPECIAL_KEYS_LOWER[i];
            if (applicationData.containsKey(lowerCaseKey)) {
                applicationData.put(APP_DATA_SPECIAL_KEYS[i],
                                    applicationData.remove(lowerCaseKey));
            }
        }
    }

    /**
     * Serializes the specified JSON response and sends it via the transport
     * handle. For error responses the request's tracing span is marked as
     * failed and ended first. Catches and logs all serialization and transport
     * errors.
     *
     * @param response JSON response object
     * @param transport transport handle
     * @param isFinal whether this is the final response for the request
     */
    private void sendResponse(JsonBaseResponse response,
                              TransportContext transport,
                              boolean isFinal) {
        if (response.isError()) {
            TracingSpan span = transport.getRequestContext().getTracingSpan();
            span.setStatusError("json.rpc.error", null);
            span.end();
        }
        serializeAndSend(response, transport, isFinal);
    }

    /**
     * Serializes the specified JSON response and sends it via the transport
     * handle without touching the tracing span. Used by callers that have
     * already recorded the span status and ended the span themselves.
     * Catches and logs all serialization and transport errors.
     *
     * @param response JSON response object
     * @param transport transport handle
     * @param isFinal whether this is the final response for the request
     */
    private void serializeAndSend(JsonBaseResponse response,
                                  TransportContext transport,
                                  boolean isFinal) {
        byte[] jsonResponse;
        try {
            jsonResponse = jsonSerializer2.serialize(response);
        } catch (RuntimeException t) {
            logger.warn("Failed to serialize JSON response", t);
            return;
        }
        dumpResponse(jsonResponse);
        sendResponse(jsonResponse, transport, isFinal);
    }

    /**
     * Logs the specified JSON response message, if debug logging and raw
     * request/response logging are both enabled.
     *
     * @param jsonResponse JSON response in UTF-8 encoding
     */
    private static void dumpResponse(byte[] jsonResponse) {
        if (logger.isDebugEnabled() && Constants.shouldLogRawRequestResponse()) {
            logger.debug("JSON response: {}",
                         new String(jsonResponse, StandardCharsets.UTF_8));
        }
    }

    /**
     * Helper method to forward a invoke_method request to the provider and
     * return the response
     *
     * @param jsonApiRequest invoke_method request
     * @param id ID of the JSON-RPC 2.0 request/response
     * @param procMetaData metadata produced by the request processor chain;
     *                     its security entry (if any) is merged into the
     *                     execution context's security context
     * @param transport transport handle
     */
    private void processApiRequest(JsonApiRequest jsonApiRequest,
                                   final String id,
                                   Map<String, Object> procMetaData,
                                   final TransportContext transport) {
        JsonInvokeParams invokeParams = jsonApiRequest.getParams();
        JsonContextBuilderRequestParams requestParamsBuilder =
              (JsonContextBuilderRequestParams) invokeParams;
        ExecutionContextBuilder ctxBuilder = requestParamsBuilder.getCtxBuilder();
        updateSecurityContext(procMetaData, ctxBuilder.security);
        addUserAgent(transport.getRequestContext(), ctxBuilder.applicationData);
        addLocalizationSettings(transport.getRequestContext(), ctxBuilder.applicationData);
        ExecutionContext ctx = ctxBuilder.build();

        AsyncHandle<MethodResult> cb = new AsyncHandleImpl<MethodResult>(transport, id, true) {
            @Override
            protected JsonBaseResponse toJson(MethodResult result) {
                TracingSpan span = transport.getRequestContext().getTracingSpan();
                if (!result.success()) {
                    this.transport.setHeader("vapi-error",
                                             result.getError().getName());
                    span.setStatusError(result.getError());
                } else {
                    span.setStatusOk();
                }
                span.end();
                return new JsonApiResponse(id, result);
            }
        };
        provider.invoke(requestParamsBuilder.getServiceId(),
                        requestParamsBuilder.getOperationId(),
                        requestParamsBuilder.getInput(),
                        ctx,
                        cb);
    }

    /**
     * Copies the transport's user agent (when present) into application data.
     */
    private static void addUserAgent(RequestContext reqCtx,
                                     Map<String, String> applicationData) {
        if (reqCtx.getUserAgent() != null) {
            applicationData.put(ApplicationData.USER_AGENT_KEY, reqCtx.getUserAgent());
        }
    }

    /**
     * Copies the transport's accept-language value (when present) into
     * application data.
     */
    private static void addLocalizationSettings(RequestContext reqCtx,
                                                Map<String, String> applicationData) {
        if (reqCtx.getAcceptLanguage() != null) {
            applicationData.put(ApplicationData.ACCEPT_LANGUAGE_KEY, reqCtx.getAcceptLanguage());
        }
    }

    /**
     * Updates the map holding the security context with the additional
     * properties provided by the request processors. A client-supplied value
     * under {@link RequestProcessor#SECURITY_PROC_METADATA_KEY} is removed
     * when no processor provided one, since that key is reserved for
     * processor-to-server communication.
     *
     * @param procMetaData must not be null
     * @param security must not be null (it is dereferenced unconditionally)
     */
    private static void updateSecurityContext(Map<String, Object> procMetaData,
                                              Map<String, Object> security) {
        assert procMetaData != null;

        Object procSecCtx = procMetaData.get(SECURITY_PROC_METADATA_KEY);
        if (procSecCtx instanceof Map) {
            security.put(SECURITY_PROC_METADATA_KEY, procSecCtx);
        } else if (security.get(SECURITY_PROC_METADATA_KEY) != null) {
            // overwrite any data under RequestProcessor.SECURITY_PROC_METADATA_KEY
            // key as this is the request processors communication data field
            // i.e. this field is reserved for internal usage
            if (logger.isDebugEnabled()) {
                logger.debug("Removed {" + SECURITY_PROC_METADATA_KEY + "," +
                        security.get(SECURITY_PROC_METADATA_KEY) +
                        "} key/value pair from the security context");
            }
            security.remove(SECURITY_PROC_METADATA_KEY);
        }
    }

    /**
     * Helper method to send response to the HTTP/AMQP RPC layer.
     *
     * @param resp response to be sent over the wire
     * @param transport transport handle
     * @param isFinal whether this is the final response for the request
     */
    private void sendResponse(byte[] resp,
                              TransportContext transport,
                              boolean isFinal) {
        int respSize = resp.length;
        try {
            logger.debug("Sending JSON response of size {}", respSize);
            transport.send(new ByteArrayInputStream(resp),
                           respSize,
                           isFinal);
        } catch (IOException e) {
            logger.error("Unable to send JSON response", e);
        }
    }

    /**
     * Helper method to create JSON-RPC 2.0 error responses
     *
     * @param id  ID to be used in the JSON-RPC 2.0 error response; null when
     *            the request id could not be determined
     * @param code  JSON-RPC 2.0 Error Code
     * @param message JSON-RPC 2.0 Error Message
     * @param data additional error detail placed in the error's data member
     * @return object encapsulation JSON-RPC 2.0 error response
     */
    private static JsonErrorResponse createErrorResponse(String id, int code,
                                                         String message,
                                                         String data) {
        JsonError jsonError = new JsonError(code, message, data);
        return new JsonErrorResponse(id, jsonError);
    }

    /**
     * Base class for converters which translate {@link ApiProvider} method
     * result to JSON and send it as a response via a transport handle.
     *
     * @param <T> type of the result of an {@link ApiProvider} method
     */
    private abstract class AsyncHandleImpl<T> extends AsyncHandle<T> {

        protected final TransportContext transport;
        private final String id;
        private final boolean sendProgressUpdates;

        /**
         * Wraps the transport handle in an async handle.
         *
         * @param transport transport handle; must not be <code>null</code>
         * @param id JSON-RPC request id; must not be <code>null</code>
         * @param sendProgressUpdates whether to send progress updates over the
         *                            transport
         */
        AsyncHandleImpl(TransportContext transport,
                        String id,
                        boolean sendProgressUpdates) {
            Validate.notNull(transport);
            Validate.notNull(id);
            this.transport = transport;
            this.id = id;
            this.sendProgressUpdates = sendProgressUpdates;
        }

        @Override
        public void updateProgress(DataValue progress) {
            if (sendProgressUpdates) {
                sendResponse(new JsonProgressResponse(id, progress),
                             transport,
                             false);
            } else {
                logger.debug("Skipping progress update for request with id {}", id);
            }
        }

        @Override
        public void setResult(T result) {
            JsonBaseResponse jsonObj;
            try {
                jsonObj = toJson(result);
            } catch (RuntimeException ex) {
                logger.error(
                        "Result for request with id {} could not be converted to JSON",
                        id, ex);
                // The request id is known here, so JSON-RPC 2.0 requires the
                // error response to carry it (null is only for unknown ids).
                jsonObj = createErrorResponse(id,
                                              INTERNAL_JSONRPC_ERROR_CODE,
                                              INTERNAL_JSONRPC_ERROR_MSG,
                                              ex.getMessage());
            }
            sendResponse(jsonObj, transport, true);
        }

        @Override
        public void setError(RuntimeException error) {
            logger.error("Provider returned error for request with id {}", id, error);
            JsonErrorResponse jsonObj = createErrorResponse(id,
                                                            INTERNAL_JSONRPC_ERROR_CODE,
                                                            INTERNAL_JSONRPC_ERROR_MSG,
                                                            error.getMessage());
            TracingSpan span = transport.getRequestContext().getTracingSpan();
            span.setStatusError(error);
            span.end();
            // The span is already finalized with the provider error; bypass the
            // error-response span handling so it is neither overwritten with a
            // generic status nor ended a second time.
            serializeAndSend(jsonObj, transport, true);
        }

        /**
         * Converts from {@link ApiProvider} method result to JSON. Derived
         * classes must override this method.
         *
         * @param result {@link ApiProvider} method result; must not be
         *               <code>null</code>
         * @return JSON representation; must not be <code>null</code>
         */
        protected abstract JsonBaseResponse toJson(T result);
    }
}
