/* **********************************************************
 * Copyright 2022 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.internal.protocol.server;

import java.io.IOException;
import java.util.Objects;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.vmware.vapi.internal.tracing.TracingScope;
import com.vmware.vapi.internal.tracing.TracingSpan;
import com.vmware.vapi.tracing.Tracer;

/**
 * Base servlet that runs request handling inside a {@link TracingSpan}.
 * <p>
 * For every request, {@link #service(HttpServletRequest, HttpServletResponse)} obtains a server
 * span from the configured {@link Tracer}. It stores that span as the request attribute
 * {@link #ATTR_TRACING_SPAN} and makes it current while
 * {@link #doService(HttpServletRequest, HttpServletResponse)} runs. If the request carries a
 * tracing context and tracing is enabled, that context is extracted and made active.
 * <p>
 * Subclasses customize request handling by overriding
 * {@link #doService(HttpServletRequest, HttpServletResponse)} (or the usual {@code doGet},
 * {@code doPost}, ... methods), not {@code service}.
 */
public abstract class TraceDecoratedServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;

    /**
     * Name of the request attribute that holds the {@link TracingSpan} attached to the current
     * request. Use {@link #getTracingSpan(HttpServletRequest)} to read it.
     */
    public static final String ATTR_TRACING_SPAN = TraceDecoratedServlet.class.getName()
                                                   + ".TRACING_SPAN";

    /**
     * Tracer used to attach a server span to each request; never {@code null}.
     * NOTE(review): HttpServlet is Serializable and this field is not transient. Confirm that
     * Tracer implementations are serializable, or that servlet serialization is never used.
     */
    private final Tracer tracer;

    /**
     * Creates a servlet that uses {@link Tracer#NO_OP} as its tracer.
     */
    public TraceDecoratedServlet() {
        this(Tracer.NO_OP);
    }

    /**
     * Creates a servlet that uses the given tracer.
     *
     * @param tracer the tracer used to attach a server span to each request; must not be
     *        {@code null}
     * @throws NullPointerException if {@code tracer} is {@code null}
     */
    public TraceDecoratedServlet(Tracer tracer) {
        // Name the argument so a misconfiguration produces a self-explanatory NPE.
        this.tracer = Objects.requireNonNull(tracer, "tracer");
    }

    /**
     * Handles a request inside a server {@link TracingSpan}.
     * <p>
     * This method attaches the span to the request and publishes it under
     * {@link #ATTR_TRACING_SPAN}. It then calls
     * {@link #doService(HttpServletRequest, HttpServletResponse)} while the span is current.
     * <p>
     * Do not override this method. Override
     * {@link #doService(HttpServletRequest, HttpServletResponse)} instead.
     *
     * @param req the incoming request
     * @param resp the response
     * @throws ServletException if {@code doService} throws it
     * @throws IOException if {@code doService} throws it
     */
    @Override
    protected void service(HttpServletRequest req, HttpServletResponse resp)
    throws ServletException, IOException {
        TracingSpan span = tracer.attachServerSpan(req);
        // Publish the span before dispatching so handlers can fetch it via getTracingSpan(req).
        req.setAttribute(ATTR_TRACING_SPAN, span);
        // try-with-resources closes the scope whether doService returns normally or throws.
        // NOTE(review): this method does not end the span itself; confirm that ending it is
        // handled elsewhere (e.g. by the Tracer implementation).
        try (TracingScope scope = span.makeCurrent()) {
            doService(req, resp);
        }
    }

    /**
     * Handles the request while the request's {@link TracingSpan} is current.
     * <p>
     * The default implementation delegates to
     * {@link HttpServlet#service(HttpServletRequest, HttpServletResponse)}, which dispatches to
     * {@code doGet}, {@code doPost}, and the other HTTP-method handlers.
     *
     * @param req the incoming request
     * @param resp the response
     * @throws ServletException if request handling fails
     * @throws IOException if an I/O error occurs while handling the request
     */
    protected void doService(HttpServletRequest req, HttpServletResponse resp)
    throws ServletException, IOException {
        super.service(req, resp);
    }

    /**
     * Gets the {@link TracingSpan} attached to the request, if available.
     *
     * @param req the request whose span to look up
     * @return the span stored under {@link #ATTR_TRACING_SPAN}, or {@code null} if no span has
     *         been attached to this request
     * @throws ClassCastException if the attribute holds something other than a
     *         {@link TracingSpan}
     */
    public static TracingSpan getTracingSpan(HttpServletRequest req) {
        return (TracingSpan) req.getAttribute(ATTR_TRACING_SPAN);
    }
}
