/* **********************************************************
 * Copyright (c) 2012-2013, 2017, 2019, 2022-2023 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.protocol.server.rpc.http.impl;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

import javax.servlet.AsyncContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.internal.protocol.common.http.FrameSerializer;
import com.vmware.vapi.internal.protocol.common.http.impl.ChunkedTransferEncodingFrameSerializer;
import com.vmware.vapi.internal.protocol.server.TraceDecoratedServlet;
import com.vmware.vapi.internal.tracing.TracingSpan;
import com.vmware.vapi.internal.util.Validate;
import com.vmware.vapi.internal.util.io.IoUtil;
import com.vmware.vapi.protocol.common.http.HttpConstants;
import com.vmware.vapi.protocol.server.rpc.RequestReceiver;
import com.vmware.vapi.protocol.server.rpc.RequestReceiver.RequestContext;
import com.vmware.vapi.protocol.server.rpc.RequestReceiver.TransportContext;
import com.vmware.vapi.protocol.server.rpc.http.MediaTypeResolver;
import com.vmware.vapi.tracing.Tracer;

/**
 * <p>
 * Servlet which enables its {@link RequestReceiver} handlers to send multiple
 * response frames for a single request. It relies on the async features
 * introduced in Servlet 3.0.
 * </p>
 * <p>
 * If there is a single response frame for a given request, the HTTP response
 * will contain a normal Content-Type derived from the payload, e.g.
 * {@link HttpConstants#CONTENT_TYPE_JSON}. If there are multiple response
 * frames, the Content-Type would be prefixed with
 * {@link HttpConstants#CONTENT_TYPE_FRAMED}.
 * </p>
 */
public class HttpStreamingServlet extends TraceDecoratedServlet {

    private static final long serialVersionUID = 0L;
    private static final Logger logger =
            LoggerFactory.getLogger(HttpStreamingServlet.class);

    /** Maps a request media type to the {@link RequestReceiver} handling it. */
    private final MediaTypeResolver mediaTypeResolver;

    /**
     * Timeout in milliseconds applied to every request's {@link AsyncContext};
     * zero or less means no timeout.
     */
    private final long asyncTimeout;

    /**
     * Calls {@link #HttpStreamingServlet(MediaTypeResolver, long)} with timeout
     * -1 (means no timeout).
     *
     * @param mediaTypeResolver resolves a media type to a content handler
     */
    public HttpStreamingServlet(MediaTypeResolver mediaTypeResolver) {
        this(mediaTypeResolver, -1);
    }

    /**
     * Calls {@link #HttpStreamingServlet(MediaTypeResolver, long, Tracer)} with
     * timeout -1 (means no timeout).
     *
     * @param mediaTypeResolver resolves a media type to a content handler
     * @param tracer the tracer used by the vapi stack for distributed tracing
     */
    public HttpStreamingServlet(MediaTypeResolver mediaTypeResolver, Tracer tracer) {
        this(mediaTypeResolver, -1, tracer);
    }

    /**
     * Calls {@link #HttpStreamingServlet(MediaTypeResolver, long, Tracer)} with
     * tracer {@link Tracer#NO_OP} meaning there is no tracing.
     *
     * @param mediaTypeResolver resolves a media type to a content handler
     * @param asyncTimeout time limit in milliseconds for handling a request;
     *        zero or less indicates no timeout; see
     *        {@link javax.servlet.AsyncContext#setTimeout(long)}
     */
    public HttpStreamingServlet(MediaTypeResolver mediaTypeResolver,
                                long asyncTimeout) {
        this(mediaTypeResolver, asyncTimeout, Tracer.NO_OP);
    }

    /**
     * Creates a servlet.
     *
     * @param mediaTypeResolver resolves a media type to a content handler;
     *        must not be <code>null</code>
     * @param asyncTimeout time limit in milliseconds for handling a request;
     *        zero or less indicates no timeout; see
     *        {@link javax.servlet.AsyncContext#setTimeout(long)}
     * @param tracer the tracer used by the vapi stack for distributed tracing
     */
    public HttpStreamingServlet(MediaTypeResolver mediaTypeResolver,
                                long asyncTimeout,
                                Tracer tracer) {
        super(tracer);
        Validate.notNull(mediaTypeResolver);
        this.mediaTypeResolver = mediaTypeResolver;
        this.asyncTimeout = asyncTimeout;
    }

    /**
     * Handles an HTTP POST request by dispatching its body to the
     * {@link RequestReceiver} registered for the request's media type. Any
     * failure is logged and then rethrown unchanged.
     *
     * @param request the servlet request
     * @param response the servlet response
     * @throws ServletException if the Content-Type header is missing or no
     *         handler is registered for its media type
     * @throws IOException on an I/O failure while processing the request
     */
    @Override
    public void doPost(HttpServletRequest request, HttpServletResponse response)
            throws ServletException, IOException {
        try {
            doPostImpl(request, response);
        } catch (ServletException ex) {
            logger.error("Servlet error while processing POST request", ex);
            throw ex;
        } catch (IOException ex) {
            logger.error("I/O error while processing POST request", ex);
            throw ex;
        } catch (RuntimeException ex) {
            logger.error("Unexpected error while processing POST request", ex);
            throw ex;
        }
    }

    /**
     * Resolves the handler for the request's media type, switches the request
     * into async mode and hands the request body to the handler.
     */
    private void doPostImpl(HttpServletRequest request,
                            HttpServletResponse response)
            throws ServletException, IOException {
        logRequest(request);
        String contentType = request.getContentType();
        if (contentType == null) {
            throw new ServletException("Missing Content-Type header");
        }
        String mediaType = getMediaType(contentType);
        RequestReceiver msgProtocol = mediaTypeResolver.getHandler(mediaType);
        if (msgProtocol == null) {
            throw new ServletException("Unexpected Content-Type: " + contentType);
        }

        // From here on the response is completed by the transport, which calls
        // AsyncContext#complete() once the final frame has been sent.
        AsyncContext servletAsyncContext =
                request.startAsync(request, response);
        servletAsyncContext.setTimeout(asyncTimeout);

        TransportContext transport = createTransport(request, servletAsyncContext);
        msgProtocol.requestReceived(request.getInputStream(), transport);
    }

    /**
     * Extracts the media type (type/subtype) portion of the content-type header
     * without the parameters (e.g. charset). For example
     * <code>text/html; charset=ISO-8859-4</code> turns into
     * <code>text/html</code>. Surrounding whitespace is always removed, whether
     * or not parameters are present.
     *
     * @param contentType content-type HTTP header; must not be
     *                    <code>null</code>
     * @return the media type; never <code>null</code>
     */
    static String getMediaType(String contentType) {
        int semicolon = contentType.indexOf(';');
        String mediaType = (semicolon >= 0)
                ? contentType.substring(0, semicolon)
                : contentType;
        // Trim in both branches so "application/json " and
        // "application/json ; charset=UTF-8" resolve to the same handler.
        return mediaType.trim();
    }

    /**
     * Logs the main request headers at debug level.
     *
     * @param request servlet request; must not be <code>null</code>
     */
    private static void logRequest(HttpServletRequest request) {
        // Guard avoids the header lookups when debug logging is off.
        if (!logger.isDebugEnabled()) {
            return;
        }
        logger.debug("Received request from agent '{}' with content-length {}, "
                        + "content-type '{}' and accept header '{}'",
                request.getHeader(HttpConstants.HEADER_USER_AGENT),
                request.getContentLength(),
                request.getContentType(),
                request.getHeader(HttpConstants.HEADER_ACCEPT));
    }

    /**
     * Create transport for the response.
     *
     * @param request servlet request; must not be <code>null</code>
     * @param servletAsyncContext servlet async context for the response; must
     *                            not be <code>null</code>
     * @return transport; never <code>null</code>
     */
    private static TransportContext createTransport(HttpServletRequest request,
                                                    AsyncContext servletAsyncContext) {
        RequestContext requestContext = buildRequestContext(request);
        /*
         * Framed content requires special parsing, so we use it only if the
         * client has explicitly stated that it supports it. Otherwise we would
         * surprise clients that don't support it.
         */
        String acceptHeader = request.getHeader(HttpConstants.HEADER_ACCEPT);
        if (acceptHeader == null) {
            acceptHeader = HttpConstants.CONTENT_TYPE_JSON;
        }
        boolean clientAcceptsSingleFrame = acceptHeader
                .contains(HttpConstants.CONTENT_TYPE_JSON);
        boolean clientAcceptsMultiFrame = acceptHeader
                .contains(HttpConstants.CONTENT_TYPE_FRAMED);
        if (clientAcceptsSingleFrame && clientAcceptsMultiFrame) {
            return new AdaptiveTransport(servletAsyncContext, requestContext);
        }
        if (clientAcceptsMultiFrame) {
            return new MultiFrameTransport(servletAsyncContext, requestContext);
        }
        // Default for clients that did not ask for framed content.
        return new SingleFrameTransport(servletAsyncContext, requestContext);
    }

    /**
     * Builds a {@link RequestContext} view over the servlet request. Header
     * values are read from the request on each call. The tracing span is
     * looked up once, when the context is created.
     *
     * @param servletRequest servlet request; must not be <code>null</code>
     * @return request context; never <code>null</code>
     */
    private static RequestContext buildRequestContext(final HttpServletRequest servletRequest) {
        return new RequestContext() {
            // Qualified call: an unqualified getTracingSpan(...) here would
            // resolve to this anonymous class's no-arg getTracingSpan().
            private final TracingSpan span =
                    HttpStreamingServlet.getTracingSpan(servletRequest);

            @Override
            public String getUserAgent() {
                return servletRequest.getHeader(HttpConstants.HEADER_USER_AGENT);
            }

            @Override
            public String getSession() {
                return servletRequest.getHeader(HttpConstants.HEADER_SESSION_ID);
            }

            @Override
            public String getServiceId() {
                return servletRequest.getHeader(HttpConstants.HEADER_SERVICE_ID);
            }

            @Override
            public String getOperationId() {
                return servletRequest.getHeader(HttpConstants.HEADER_OPERATION_ID);
            }

            @Override
            public String getAcceptLanguage() {
                return servletRequest.getHeader(HttpConstants.HEADER_ACCEPT_LANGUAGE);
            }

            @Override
            public Map<String, String> getAllProperties() {
                return getAllHeaders(servletRequest);
            }

            @Override
            public String getJsonRpcVersion() {
                // The presence of the service-id header selects "jsonrpc1.1".
                if (getServiceId() != null) {
                    return "jsonrpc1.1";
                }
                return "jsonrpc";
            }

            @Override
            public TracingSpan getTracingSpan() {
                return span;
            }
        };
    }

    /**
     * Collects all request headers into one map. Multiple values of the same
     * header are joined with a comma.
     *
     * @param servletRequest servlet request; must not be <code>null</code>
     * @return an immutable map of all request headers with a case-insensitive
     *         key; never <code>null</code>
     */
    static Map<String, String> getAllHeaders(HttpServletRequest servletRequest) {
        // HTTP header names are case-insensitive, so get("content-type") must
        // find "Content-Type"; a plain HashMap would not honor that.
        Map<String, String> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

        Enumeration<String> headerNames = servletRequest.getHeaderNames();
        if (headerNames != null) {
            while (headerNames.hasMoreElements()) {
                String name = headerNames.nextElement();
                if (name == null) {
                    // TreeMap rejects null keys; a nameless header is useless.
                    continue;
                }
                String values = combineHeaderValues(servletRequest.getHeaders(name));
                if (values != null) {
                    headers.put(name, values);
                }
            }
        }

        return Collections.unmodifiableMap(headers);
    }

    /**
     * Joins all values of a header with ','.
     *
     * @param values header values; may be <code>null</code>
     * @return the single value as-is, the comma-joined values, or
     *         <code>null</code> when there are no values
     */
    private static String combineHeaderValues(Enumeration<String> values) {
        if (values == null || !values.hasMoreElements()) {
            return null;
        }

        String firstValue = values.nextElement();
        if (!values.hasMoreElements()) {
            // Fast path: most headers have exactly one value.
            return firstValue;
        }

        StringBuilder result = new StringBuilder(firstValue);
        while (values.hasMoreElements()) {
            result.append(',').append(values.nextElement());
        }
        return result.toString();
    }

    /**
     * HTTP transport which wraps an {@link javax.servlet.AsyncContext}.
     */
    private abstract static class BaseTransport implements TransportContext {

        protected final AsyncContext servletAsyncContext;
        protected final RequestContext requestContext;

        protected BaseTransport(AsyncContext servletAsyncContext,
                                RequestContext requestContext) {
            this.servletAsyncContext = servletAsyncContext;
            this.requestContext = requestContext;
        }

        @Override
        public RequestContext getRequestContext() {
            return requestContext;
        }

        @Override
        public void setHeader(String name, String value) {
            ((HttpServletResponse) this.servletAsyncContext.getResponse())
                    .setHeader(name, value);
        }
    }

    /**
     * HTTP transport for a response that contains only one frame. Non-final
     * frames are discarded; only the final frame is written, with Content-Type
     * {@link HttpConstants#CONTENT_TYPE_JSON}, after which the async context
     * is completed. {@code responseLength} is not used.
     */
    private static final class SingleFrameTransport extends BaseTransport {

        SingleFrameTransport(AsyncContext servletAsyncContext,
                             RequestContext requestContext) {
            super(servletAsyncContext, requestContext);
        }

        @Override
        public synchronized void send(InputStream response,
                                      int responseLength,
                                      boolean isFinal) throws IOException {
            if (!isFinal) {
                return;
            }

            servletAsyncContext.getResponse().setContentType(
                    HttpConstants.CONTENT_TYPE_JSON);

            OutputStream out =
                    servletAsyncContext.getResponse().getOutputStream();

            IoUtil.copy(response, out);
            out.flush();
            servletAsyncContext.complete();
        }
    }

    /**
     * HTTP transport for a response that may contain multiple frames. Each
     * frame is written through a {@link ChunkedTransferEncodingFrameSerializer}
     * and flushed immediately; the async context is completed after the final
     * frame. {@code responseLength} is not used.
     */
    private static final class MultiFrameTransport extends BaseTransport {

        private static final FrameSerializer FRAME_SERIALIZER =
                new ChunkedTransferEncodingFrameSerializer();

        MultiFrameTransport(AsyncContext servletAsyncContext,
                            RequestContext requestContext) {
            super(servletAsyncContext, requestContext);
        }

        @Override
        public synchronized void send(InputStream response,
                                      int responseLength,
                                      boolean isFinal) throws IOException {
            // Only the first call matters: the servlet API ignores
            // setContentType once the response has been committed.
            servletAsyncContext.getResponse().setContentType(
                    HttpConstants.CONTENT_TYPE_FRAMED);

            OutputStream out =
                    servletAsyncContext.getResponse().getOutputStream();
            byte[] responseFrame = IoUtil.readAll(response);
            FRAME_SERIALIZER.writeFrame(out, responseFrame);
            out.flush();

            // no more responses for this request
            if (isFinal) {
                servletAsyncContext.complete();
            }
        }
    }

    /**
     * HTTP transport which works as {@link SingleFrameTransport} if the
     * response contains only one frame or as {@link MultiFrameTransport} if the
     * response contains more than one frame. The choice is made on the first
     * {@link #send} call and kept for the rest of the response.
     */
    private static final class AdaptiveTransport extends BaseTransport {

        /** Chosen on the first send; guarded by {@code this}. */
        private TransportContext decorated;

        AdaptiveTransport(AsyncContext servletAsyncContext,
                          RequestContext requestContext) {
            super(servletAsyncContext, requestContext);
        }

        @Override
        public synchronized void send(InputStream response,
                                      int responseLength,
                                      boolean isFinal) throws IOException {
            if (decorated == null) {
                // A frame that is both the first and the final one means the
                // whole response is a single frame.
                decorated = isFinal
                        ? new SingleFrameTransport(servletAsyncContext,
                                                   getRequestContext())
                        : new MultiFrameTransport(servletAsyncContext,
                                                  getRequestContext());
            }
            decorated.send(response, responseLength, isFinal);
        }
    }
}
