/* **********************************************************
 * Copyright (c) 2020-2021 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.internal.bindings;

import java.util.Collection;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.bindings.client.InvocationConfig;
import com.vmware.vapi.bindings.type.StructType;
import com.vmware.vapi.bindings.type.Type;
import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.Consumer;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.MethodIdentifier;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.StructValue;
import com.vmware.vapi.diagnostics.LogDiagnosticUtil;
import com.vmware.vapi.diagnostics.LogDiagnosticsConfigurator;
import com.vmware.vapi.diagnostics.Slf4jMDCLogConfigurator;
import com.vmware.vapi.internal.core.abort.AbortHandle;
import com.vmware.vapi.internal.core.abort.AbortHandleImpl;
import com.vmware.vapi.internal.core.abort.AbortableAsyncHandle;

/**
 * Implementation of the {@link Publisher} interface returned by VAPI stubs for
 * {@code @Stream} operations.
 * <p>
 * The operation is started through
 * {@link ApiProvider#invoke(String, String, DataValue, ExecutionContext, AsyncHandle)}
 * when the subscriber first signals demand. Every subsequent chunk is fetched
 * through the continuation consumer delivered with the previous result, and
 * only while the subscriber has outstanding demand.
 * <p>
 * An instance of this class is meant to handle a single {@code @Stream}
 * invocation. Only one {@link Subscriber} is allowed to subscribe; any later
 * subscriber is signalled an {@link IllegalStateException} via
 * {@code onError}.
 *
 * @param <T> the return type of the invoked operation
 */
public class StreamPublisher<T> implements Publisher<T> {
    private static final Logger LOGGER = LoggerFactory
            .getLogger(StreamPublisher.class);

    /**
     * Triggered by {@link #abort(Throwable)}; every invocation handle is
     * wrapped in an {@link AbortableAsyncHandle} bound to it.
     */
    private final AbortHandle abortHandle = new AbortHandleImpl();
    /** Flipped by the first subscribe call; later subscribers are rejected. */
    private final AtomicBoolean subscribedTo = new AtomicBoolean();
    /** The accepted subscriber; cleared by {@link #abort(Throwable)}. */
    private volatile Subscriber<? super T> subscriber;
    /**
     * Continuation received with the most recent chunk. A {@code request}
     * that raises demand from 0 uses it to resume fetching.
     */
    private volatile Consumer<AsyncHandle<MethodResult>> nextChunkConsumer;

    private final Stub stub;
    private final String serviceId;
    private final String operationId;
    private final ExecutionContext execCtx;
    private final Type outputType;
    private final Collection<Type> errorTypes;
    /**
     * Operation input. It is released after the first invocation only when
     * retrying is not configured, because a retry re-sends it.
     */
    private StructValue inputValue;

    /** Set once a terminal state is reached (completed, failed or cancelled). */
    final AtomicBoolean finished = new AtomicBoolean();
    /** Chunks requested by the subscriber and not yet received. */
    final AtomicLong demandCount = new AtomicLong();

    /**
     * Creates a publisher for a single streaming invocation.
     *
     * @param stub the stub whose API provider and configuration are used
     * @param methodId identifies the service ({@code getInterfaceIdentifier})
     *        and operation ({@code getName}) to invoke
     * @param inputValue the operation input
     * @param inputType currently not used by this class
     * @param outputType type used to translate each result chunk
     * @param errorTypes error types used to translate failures
     * @param invocationConfig passed to the stub to obtain the execution
     *        context
     */
    public StreamPublisher(Stub stub,
                           MethodIdentifier methodId,
                           StructValue inputValue,
                           StructType inputType,
                           Type outputType,
                           Collection<Type> errorTypes,
                           InvocationConfig invocationConfig) {
        this.stub = stub;
        this.serviceId = methodId.getInterfaceIdentifier().getName();
        this.operationId = methodId.getName();
        this.inputValue = inputValue;
        execCtx = stub.getExecutionContext(invocationConfig);
        this.outputType = outputType;
        this.errorTypes = errorTypes;
    }

    /**
     * Accepts the first subscriber and hands it an {@link AsyncSubscription}.
     * Any later subscriber is rejected with an {@link IllegalStateException}.
     *
     * @throws NullPointerException if {@code s} is {@code null}
     */
    @Override
    public void subscribe(Subscriber<? super T> s) {
        Objects.requireNonNull(s);
        if (!subscribedTo.compareAndSet(false, true)) {
            rejectSubscription(s);
            return;
        }
        // Publish the subscriber before onSubscribe, because the subscriber may
        // call request() (which reads this field) from inside onSubscribe.
        subscriber = s;
        s.onSubscribe(new AsyncSubscription());
    }

    /**
     * Gives the subscriber a no-op subscription and then signals that this
     * publisher has already been subscribed to. onSubscribe must be the first
     * signal, so it precedes onError.
     */
    private void rejectSubscription(Subscriber<? super T> s) {
        s.onSubscribe(new Subscription() {
            @Override
            public void request(long n) {
            }

            @Override
            public void cancel() {
            }
        });
        s.onError(new IllegalStateException("This instance has already been subscribed to"));
    }

    /**
     * Creates the handle that receives one result chunk and forwards it to
     * {@code s}.
     * <p>
     * On success the chunk is delivered via {@code onNext}, and demand is
     * decremented. Then either the stream completes (no continuation) or the
     * next chunk is requested while demand is still positive. On failure the
     * stream is aborted.
     *
     * @param s the subscriber receiving the chunks
     * @return a new single-use result handle
     */
    public ResultTranslatingHandle<T> createHandle(final Subscriber<? super T> s) {
        ResultTranslatingHandle<T> ah = new ResultTranslatingHandle<T>(stub,
                                                                       outputType,
                                                                       errorTypes) {
            @Override
            public void updateProgress(DataValue progress) {
            }

            @Override
            protected void postProcessResponse(MethodResult result) {}

            @Override
            void onSuccess(T result, Consumer<AsyncHandle<MethodResult>> next) {
                if (finished.get()) {
                    LOGGER.debug("Finished safeguard triggered.");
                    return;
                }

                // Publish the continuation BEFORE decrementing demand. Once
                // demand hits 0, a concurrent request() may immediately call
                // initiateStreaming(), and it must see this continuation
                // rather than the previous, already consumed one.
                nextChunkConsumer = next;

                long demand;
                try {
                    if (result != null) {
                        LOGGER.trace("Executing subscriber#onNext with result - {}",
                                     result);
                        s.onNext(result);
                    }
                    // NOTE(review): a null result still consumes one unit of
                    // demand below; confirm the server never sends empty
                    // chunks mid-stream.
                } finally {
                    demand = demandCount.decrementAndGet();
                }

                if (next == null) {
                    // compareAndSet (not set) makes completion mutually
                    // exclusive with abort(), so the subscriber never receives
                    // both onError and onComplete.
                    if (finished.compareAndSet(false, true)) {
                        LOGGER.trace("Streaming complete - next handle not received.");
                        s.onComplete();
                    }
                    return;
                }

                LOGGER.trace("Demand is {} .", demand);
                if (demand > 0L) {
                    LOGGER.trace("Resume requesting.");
                    next.accept(createHandle(s));
                }
            }

            @Override
            void onFailure(RuntimeException error) {
                abort(error);
            }
        };

        return ah;
    }

    /**
     * Sends the streaming request through the stub's API provider.
     * <p>
     * If retrying is configured, the handle is wrapped in a
     * {@link RetryingHandle} whose retries re-enter this method. The handle is
     * always wrapped in an {@link AbortableAsyncHandle} bound to
     * {@link #abortHandle}. The diagnostic logging context is set up for the
     * duration of the call.
     *
     * @param handle receives the first result chunk
     * @param executionContext context for this attempt
     * @param invocationAttempt zero-based attempt number
     */
    private void invoke(final ResultTranslatingHandle<T> handle,
                        final ExecutionContext executionContext,
                        final int invocationAttempt) {
        final boolean retrying = stub.isRetryingConfigured();
        AsyncHandle<MethodResult> ah = handle;
        if (retrying) {
            LOGGER.trace("Creating retrying handle.");
            ah = new RetryingHandle<T>(stub,
                                       outputType,
                                       errorTypes,
                                       handle,
                                       serviceId,
                                       operationId,
                                       executionContext,
                                       inputValue,
                                       invocationAttempt) {

                @Override
                void onFailure(RuntimeException error) {
                    LOGGER.debug("Error during streaming occurred.", error);
                    if (finished.get()) {
                        LOGGER.debug("Streaming has finished prior receiving the error.",
                                     error);
                        return;
                    }
                    super.onFailure(error);
                }

                @Override
                void onRetry(ExecutionContext retryContext) {
                    LOGGER.debug("Retrying invocation of the streaming request {} {}.",
                                 serviceId,
                                 operationId);
                    invoke(handle, retryContext, invocationAttempt + 1);
                }
            };
        }

        ah = new AbortableAsyncHandle<MethodResult>(ah, abortHandle);

        LogDiagnosticsConfigurator logConfig = new Slf4jMDCLogConfigurator();
        try {
            logConfig.configureContext(LogDiagnosticUtil
                    .getDiagnosticContext(executionContext));
            LOGGER.trace("Starting streaming invocation request {} {}.",
                         serviceId,
                         operationId);
            stub.apiProvider.invoke(serviceId,
                                    operationId,
                                    inputValue,
                                    executionContext,
                                    ah);
            if (!retrying) {
                // No retry can re-enter this method, so the input no longer
                // needs to be kept in memory. With retrying enabled it must be
                // kept: onRetry() calls invoke() again, which re-sends it.
                inputValue = null;
            }
        } finally {
            logConfig.cleanUpContext(LogDiagnosticUtil.getDiagnosticKeys());
        }
    }

    /**
     * Moves the stream to its terminal state (at most once). It triggers
     * {@link #abortHandle}, forwards {@code t} to the subscriber's
     * {@code onError} when non-null, and drops the subscriber reference.
     *
     * @param t the failure to report, or {@code null} for a cancellation
     */
    private void abort(Throwable t) {
        LOGGER.debug("Stream processing failed.", t);
        if (!finished.compareAndSet(false, true)) {
            return;
        }
        abortHandle.abort();
        // Read the volatile field once so the null check and the call agree.
        Subscriber<? super T> target = subscriber;
        if (t != null && target != null) {
            try {
                target.onError(t);
            } catch (RuntimeException e) {
                String message = String
                        .format("Exception while invoking %s.onError for %s.%s",
                                target,
                                serviceId,
                                operationId);
                LOGGER.warn(message, e);
            }
        }
        subscriber = null;
    }

    /**
     * {@link Subscription} handed to the accepted subscriber.
     * <p>
     * It keeps demand in an atomic counter, and an atomic flag ensures the
     * initial {@code StreamPublisher#invoke} call is made at most once. Only
     * the request that raises demand from 0 resumes fetching.
     */
    private class AsyncSubscription implements Subscription {

        private final AtomicBoolean isStreamingInitiated = new AtomicBoolean();

        /**
         * Adds {@code n} to the outstanding demand. If demand was 0 (nothing
         * in flight), it starts or resumes fetching. A non-positive {@code n}
         * aborts the stream with an {@link IllegalArgumentException}.
         */
        @Override
        public void request(long n) {
            LOGGER.debug("Requested demand - {} .", n);
            if (finished.get()) {
                LOGGER.debug("Streaming has finished, no more requests are processed.");
                return;
            }

            if (n <= 0) {
                LOGGER.debug("Invalid request count received - {}.", n);
                abort(new IllegalArgumentException("non-positive subscription request signals are illegal"));
                return;
            }

            long previousDemand = guardedGetAndAdd(n);

            // Demand of 0 means no chunk fetch is in flight, so exactly one
            // request (the one raising demand from 0) restarts the chain.
            // Checking the previous value stays correct when the sum saturates
            // at Long.MAX_VALUE, where "newDemand == n" would also match
            // while a fetch is in flight.
            if (previousDemand == 0L) {
                initiateStreaming();
            }
        }

        /**
         * Performs the initial invocation the first time it is called.
         * Afterwards it resumes via the last stored continuation, if any.
         */
        private void initiateStreaming() {
            LOGGER.debug("Stream initiation requested.");
            if (isStreamingInitiated.compareAndSet(false, true)) {
                LOGGER.trace("Publisher is initiating communication.");
                invoke(createHandle(subscriber), execCtx, 0);
            } else {
                Consumer<AsyncHandle<MethodResult>> next = nextChunkConsumer;
                if (next != null) {
                    if (LOGGER.isTraceEnabled()) {
                        // Guarded: String.format runs eagerly, unlike {} placeholders.
                        LOGGER.trace(String
                                .format("Publisher continues suspended communication via consumer - %h.",
                                        next));
                    }
                    next.accept(createHandle(subscriber));
                }
            }
        }

        /**
         * Atomically adds {@code n} to {@link StreamPublisher#demandCount}.
         * A sum that overflows (becomes negative) is capped at
         * {@link Long#MAX_VALUE}.
         *
         * @param n amount to add; expected to be positive, there is no guard
         *        against {@link Long#MIN_VALUE}
         * @return the demand before the addition
         */
        private long guardedGetAndAdd(long n) {
            long prev;
            long next;
            do {
                prev = demandCount.get();
                next = prev + n;
                if (next < 0) {
                    next = Long.MAX_VALUE;
                }
            } while (!demandCount.compareAndSet(prev, next));
            return prev;
        }

        /** Cancels the stream; no error is signalled to the subscriber. */
        @Override
        public void cancel() {
            abort(null);
        }
    }
}