
/* **********************************************************
 * Copyright 2021 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.client.util;

import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.DecoratorApiProvider;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.ExecutionContext.SecurityContext;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.std.errors.Unauthenticated;



/**
 * Filter that can be used to acquire on-demand security context for requests.
 * This filter is to be used as a decorator of the client {@link ApiProvider}
 * upon which the {@link com.vmware.vapi.internal.bindings.Stub} classes are
 * instantiated.
 * <p>
 * The filter relies on an externally provided {@link Supplier} of
 * {@link com.vmware.vapi.core.ExecutionContext.SecurityContext}s. To allow for
 * efficient operation a {@link CompletionStage} is expected from the supplier.
 * <p>
 * The supplier is invoked on the first call to obtain a security context, and
 * again whenever an {@link Unauthenticated} error is reported by the remote
 * service.
 * <p>
 * A {@link SecurityContextAcquisitionError} is reported when the filter cannot
 * obtain a security context with the provided {@link Supplier}.
 * <p>
 * Requests failing with {@link Unauthenticated} errors are retried only once
 * with a new credential. If {@link Unauthenticated} is encountered with the
 * new credential, the error is returned to the bindings tier.
 * <p>
 * Calls that wait on a failed acquisition before the refresh timeout elapses
 * fail without retry. If multiple attempts to acquire a credential are needed,
 * they are to be implemented by the future returned from the supplier.
 * <p>
 * The filter is designed to be shared by concurrent callers.
 */
public class DynamicAuthnFilter extends DecoratorApiProvider {
    private static final int MAX_GENERATION = 1_000_000;
    private static final String UNAUTHENTICATED = Unauthenticated
            ._getCanonicalTypeName();
    private static final Logger LOGGER = LoggerFactory.getLogger(DynamicAuthnFilter.class);

    private final Supplier<CompletionStage<SecurityContext>> secCtxFactory;
    /** Current security context holder, or {@code null} when none is cached. */
    private volatile SecurityHolder cache;
    /** Last generation number handed out. Guarded by {@link #lock}. */
    private int generationCounter = 1;
    private final Object lock = new Object();
    /** Delay after a failed acquisition before a refresh is attempted. */
    private final long timeoutNanos;

    /**
     * Creates a new authentication filter that can be used to create stubs
     * that will have all their calls authenticated.
     *
     * @param decoratedProvider {@link ApiProvider} that invokes the real API.
     *        Typically obtained from a connection.
     * @param supplier whenever a new security context is needed this will be
     *        called to obtain a new {@link CompletionStage} that will deliver
     *        the context. The call itself is expected to be relatively fast
     *        (it is made while holding an internal lock), while the completion
     *        stage may take time to acquire the actual {@link SecurityContext}.
     *        Returning {@code null} or throwing is treated as a failed
     *        acquisition.
     *        <p>
     *        The supplier is called on the first request and after
     *        {@link Unauthenticated} errors are reported from the API.
     * @param timeoutMs the interval in milliseconds after a failed token
     *        acquisition until a refresh is attempted. Must be positive.
     * @throws NullPointerException if {@code supplier} is {@code null}
     * @throws IllegalArgumentException if {@code timeoutMs} is not positive
     */
    public DynamicAuthnFilter(ApiProvider decoratedProvider,
                              Supplier<CompletionStage<SecurityContext>> supplier,
                              long timeoutMs) {
        super(decoratedProvider);
        Objects.requireNonNull(supplier, "Supplier cannot be null");
        if (timeoutMs <= 0) {
            throw new IllegalArgumentException("timeoutMs must be positive");
        }
        this.secCtxFactory = supplier;
        // toNanos saturates at Long.MAX_VALUE instead of overflowing.
        this.timeoutNanos = TimeUnit.MILLISECONDS.toNanos(timeoutMs);
    }

    /**
     * Invokes the decorated provider and, in case of an {@link Unauthenticated}
     * error, retries the request once with a refreshed {@link SecurityContext}.
     * See
     * {@link ApiProvider#invoke(String, String, DataValue, ExecutionContext, AsyncHandle)}
     * for details on the parameters and operation.
     */
    @Override
    public void invoke(String service,
                       String operation,
                       DataValue input,
                       final ExecutionContext ctx,
                       AsyncHandle<MethodResult> asyncHandle) {
        SecurityHolder holder = getOrRefreshSecCtx(0);
        UnauthenticatedHandler handler = new UnauthenticatedHandler(service,
                                                                    operation,
                                                                    input,
                                                                    ctx,
                                                                    asyncHandle,
                                                                    holder.getGeneration());
        // The first attempt reports through the handler so that an
        // Unauthenticated error triggers a single retry.
        handler.invokeWith(holder,
                           handler,
                           "Failed to obtain security context",
                           "Cannot obtain security context");
    }

    /**
     * Forwards the call to the decorated provider. A {@link RuntimeException}
     * thrown synchronously by the provider is reported through {@code handle}
     * instead of being propagated.
     */
    protected void invokeNext(String service,
                              String operation,
                              DataValue input,
                              ExecutionContext ctx,
                              AsyncHandle<MethodResult> handle) {
        try {
            decoratedProvider.invoke(service, operation, input, ctx, handle);
        } catch (RuntimeException ex) {
            handle.setError(ex);
        }
    }

    /**
     * Strips the {@link CompletionException} wrapper that dependent completion
     * stages add, so callers see the original acquisition failure as the cause.
     */
    private static Throwable unwrap(Throwable ex) {
        if (ex instanceof CompletionException && ex.getCause() != null) {
            return ex.getCause();
        }
        return ex;
    }

    /**
     * Async handler for {@link Unauthenticated} errors. The first such error is
     * followed up by a refresh of the security context and a retry. Errors from
     * the retry are reported to the caller.
     */
    private class UnauthenticatedHandler extends AsyncHandle<MethodResult> {
        private final String service, operation;
        private final DataValue input;
        private final ExecutionContext ctx;
        private final AsyncHandle<MethodResult> next;
        /** Generation of the security context the first attempt is sent with. */
        private final int generation;

        public UnauthenticatedHandler(String service,
                                      String operation,
                                      DataValue input,
                                      ExecutionContext ctx,
                                      AsyncHandle<MethodResult> asyncHandle,
                                      int generation) {
            this.service = service;
            this.operation = operation;
            this.input = input;
            this.ctx = ctx;
            this.next = asyncHandle;
            this.generation = generation;
        }

        @Override
        public void setResult(MethodResult result) {
            ErrorValue err = result.getError();
            if (err != null && UNAUTHENTICATED.equals(err.getName())) {
                LOGGER.debug("Unauthenticated. Recreating security context");
                if (result.getNext() != null) {
                    // We are done with the request. Close it.
                    result.getNext().accept(null);
                }

                // Refresh only if no other caller already replaced the
                // generation this request was sent with. The retry reports
                // straight to the caller, so it is attempted only once.
                SecurityHolder holder = getOrRefreshSecCtx(generation);
                invokeWith(holder,
                           next,
                           "Failed to refresh security context",
                           "Cannot refresh expired authentication context");
                return;
            }
            next.setResult(result);
        }

        @Override
        public void setError(RuntimeException error) {
            next.setError(error);
        }

        @Override
        public void updateProgress(DataValue progress) {
            next.updateProgress(progress);
        }

        /**
         * Waits for {@code holder}'s security context and then invokes the
         * decorated provider with it, reporting to {@code target}.
         * <p>
         * Only a failure of the acquisition future itself marks the holder as
         * failed and is reported to the caller as a
         * {@link SecurityContextAcquisitionError}. Problems that occur while
         * dispatching the call are reported as-is and do not poison the cached
         * security context.
         *
         * @param holder security context to use
         * @param target handle receiving the outcome of the invocation
         * @param logMessage debug message logged on acquisition failure
         * @param errorMessage message of the reported acquisition error
         */
        private void invokeWith(SecurityHolder holder,
                                AsyncHandle<MethodResult> target,
                                String logMessage,
                                String errorMessage) {
            holder.getSecCtxFuture().handle((security, ex) -> {
                if (ex != null) {
                    holder.signalFailure();
                    LOGGER.debug(logMessage, ex);
                    next.setError(new SecurityContextAcquisitionError(errorMessage,
                                                                      unwrap(ex)));
                } else {
                    dispatch(security, target);
                }
                return null;
            }).exceptionally(ex -> {
                // Only reached if a result handle itself throws. Nobody is left
                // to notify, so at least make the failure visible.
                LOGGER.warn("Unexpected error while dispatching request", ex);
                return null;
            });
        }

        /**
         * Invokes the decorated provider with {@code security} attached to the
         * original execution context, so that the caller's handle is always
         * completed.
         */
        private void dispatch(SecurityContext security,
                              AsyncHandle<MethodResult> target) {
            ExecutionContext securedCtx;
            try {
                securedCtx = ctx.withSecurityContext(security);
            } catch (RuntimeException e) {
                next.setError(e);
                return;
            }
            invokeNext(service, operation, input, securedCtx, target);
        }
    }

    /**
     * Gets the cached security context or initializes a new one. This method
     * ensures only one client acquires a new context in case of concurrent use.
     * <p>
     * Two mechanisms are employed.
     * <p>
     * Locking is used to ensure no racing occurs in the method.
     * <p>
     * To avoid multiple clients acquiring tokens on an {@link Unauthenticated}
     * error, optimistic locking is employed. Clients obtain a cached token only
     * if the generation number they supplied differs from the cached value.
     * <p>
     * When a new request is made, a generation value of zero is to be used.
     * Zero is never used as a valid generation number.
     *
     * @param generation generation number known not to work for the caller. It
     *        is either zero or a value obtained from a previous call.
     * @return handle to the security context to be obtained. If the supplier
     *         throws or returns {@code null}, an uncached holder with an
     *         already failed future is returned.
     */
    private SecurityHolder getOrRefreshSecCtx(int generation) {
        SecurityHolder holder = cache;
        // Valid cache unknown to the caller
        if (holder != null && holder.usable(generation)) {
            LOGGER.trace("Using cached context");
            return holder;
        }
        synchronized (lock) { // Current state is not good; try to fix it.
            holder = cache; // Check if another thread already fixed it
            if (holder != null && holder.usable(generation)) {
                LOGGER.trace("Using newly provisioned cached context");
                return holder;
            }
            cache = null; // Clear the cache to prevent future use
            LOGGER.debug("Acquiring new security context future");
            CompletionStage<SecurityContext> future;
            try {
                future = secCtxFactory.get();
            } catch (RuntimeException ex) {
                LOGGER.debug("Error acquiring future for security context", ex);
                return failedHolder(ex);
            }
            if (future == null) {
                // Caching a null future would make every later call fail with
                // an NPE and a refresh would never happen.
                LOGGER.debug("Security context supplier returned null");
                return failedHolder(
                        new IllegalStateException("Security context supplier returned null"));
            }
            holder = new SecurityHolder(nextGeneration(), future);
            cache = holder;
            return holder;
        }
    }

    /**
     * Creates an uncached holder whose future has already failed with
     * {@code cause}. It uses generation zero, which is never assigned to a
     * cached holder.
     */
    private SecurityHolder failedHolder(Throwable cause) {
        CompletableFuture<SecurityContext> failed = new CompletableFuture<>();
        failed.completeExceptionally(cause);
        return new SecurityHolder(0, failed);
    }

    /**
     * Generates a new generation number in the range
     * {@code 1..MAX_GENERATION}, wrapping around. Must be called while holding
     * {@link #lock}.
     *
     * @return next generation number
     */
    private int nextGeneration() {
        generationCounter = (generationCounter % MAX_GENERATION) + 1;
        return generationCounter;
    }

    /**
     * Holder for a security context future and its generation number. The
     * future and generation are immutable. The only mutable state is the
     * one-time failure timestamp, which is safely published, so instances are
     * safe to share between threads.
     */
    private class SecurityHolder {
        private final int generation;
        private final CompletionStage<SecurityContext> secCtxFuture;
        private final Object failureLock = new Object();
        /** Set once, after {@link #failedAtNanos} has been written. */
        private volatile boolean failed;
        /** {@link System#nanoTime()} of the first reported failure. */
        private volatile long failedAtNanos;

        public SecurityHolder(int generation,
                              CompletionStage<SecurityContext> secCtxFuture) {
            this.generation = generation;
            this.secCtxFuture = secCtxFuture;
        }

        /** Records the time of the first acquisition failure; later calls are no-ops. */
        public void signalFailure() {
            if (!failed) {
                synchronized (failureLock) {
                    if (!failed) {
                        failedAtNanos = System.nanoTime();
                        failed = true; // publish only after the timestamp is set
                    }
                }
            }
        }

        /**
         * Tells whether this holder may be handed to a caller.
         *
         * @param callerGeneration generation the caller already found not to
         *        work, or zero
         * @return {@code false} if the caller already tried this generation,
         *         or if acquisition failed longer ago than the refresh timeout
         */
        public boolean usable(int callerGeneration) {
            if (callerGeneration == generation) {
                return false;
            }
            if (!failed) {
                return true;
            }
            // Monotonic, overflow-safe elapsed time check.
            return System.nanoTime() - failedAtNanos <= timeoutNanos;
        }

        public CompletionStage<SecurityContext> getSecCtxFuture() {
            return secCtxFuture;
        }

        public int getGeneration() {
            return generation;
        }
    }
}
