/* **********************************************************
 * Copyright 2012-2013, 2020, 2021 VMware, Inc.  All rights reserved.
 *      -- VMware Confidential
 * **********************************************************/

/*
 * ProviderAggregation.java --
 *
 *      Aggregation of vAPI providers.
 */

package com.vmware.vapi.provider.aggregator;

import static com.vmware.vapi.MessageFactory.getMessage;
import static com.vmware.vapi.bindings.exception.Constants.INTERNAL_SERVER_ERROR;
import static com.vmware.vapi.bindings.exception.Constants.OPERATION_NOT_FOUND;

import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.zip.CRC32;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.ErrorValueFactory;
import com.vmware.vapi.Message;
import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.ApiProviderStubImpl;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.InterfaceDefinition;
import com.vmware.vapi.core.InterfaceIdentifier;
import com.vmware.vapi.core.MethodDefinition;
import com.vmware.vapi.core.MethodIdentifier;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.core.ProviderDefinition;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorDefinition;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.internal.provider.introspection.SyncToAsyncApiIntrospectionAdapter;
import com.vmware.vapi.internal.util.StringUtils;
import com.vmware.vapi.internal.util.async.SetAccumulator;
import com.vmware.vapi.provider.introspection.ApiIntrospection;
import com.vmware.vapi.provider.introspection.SyncApiIntrospection;
import com.vmware.vapi.std.StandardDataFactory;

/**
 * Aggregation of providers (local or remote). Implements aggregated invoke and
 * aggregated {@link ApiIntrospection}. Does not have remote introspection (i.e.
 * introspection services).
 * <p>
 * Calls are routed by service (interface) id: {@code ifaceMap} maps each
 * service id to the provider serving it. Several service ids may map to the
 * same provider; {@link #getDefinition} queries each distinct provider once.
 */
class ProviderAggregation implements ApiIntrospection, ApiProvider {

    private static final Logger logger =
            LoggerFactory.getLogger(ProviderAggregation.class);

    /**
     * Set of errors which are reported by the aggregator implementation.
     */
    static final Set<ErrorDefinition> AGGREGATOR_ERROR_DEFS =
        Collections.unmodifiableSet(new HashSet<ErrorDefinition>(
            Arrays.asList(
                StandardDataFactory.createStandardErrorDefinition(INTERNAL_SERVER_ERROR),
                StandardDataFactory.createStandardErrorDefinition(OPERATION_NOT_FOUND))));

    /** Name reported in the aggregated {@link ProviderDefinition}. */
    private final String name;

    /** Service id -> provider serving it. Held by reference, not copied. */
    private final Map<String, ApiProvider> ifaceMap;

    /** Distinct providers found in {@link #ifaceMap} at construction time. */
    private final Set<ApiProvider> providers;

    /**
     * Creates an aggregation over the given routing table.
     *
     * @param name name reported by {@link #getDefinition}
     * @param ifaceMap mapping from service id to the provider serving it. The
     *        map is kept by reference, but the set of distinct providers used
     *        by {@link #getDefinition} is computed once, here.
     */
    public ProviderAggregation(String name,
            Map<String, ApiProvider> ifaceMap) {
        this.name = name;
        this.ifaceMap = ifaceMap;
        providers = new HashSet<ApiProvider>(this.ifaceMap.values());
    }

    /**
     * Reports the definition of the aggregation: {@link #name} together with a
     * checksum derived from the definitions of all distinct aggregated
     * providers (see {@link #computeCheckSum}).
     */
    @Override
    public void getDefinition(ExecutionContext ctx,
            final AsyncHandle<ProviderDefinition> asyncHandle) {
        // reduces the collected set of provider definitions to a single
        // aggregated definition; progress and errors are forwarded unchanged
        AsyncHandle<Set<ProviderDefinition>> reducer;
        reducer = new AsyncHandle<Set<ProviderDefinition>>() {
            @Override
            public void updateProgress(DataValue progress) {
                asyncHandle.updateProgress(progress);
            }

            @Override
            public void setResult(Set<ProviderDefinition> result) {
                String checkSum = computeCheckSum(result);
                asyncHandle.setResult(new ProviderDefinition(name, checkSum));
            }

            @Override
            public void setError(RuntimeException error) {
                asyncHandle.setError(error);
            }
        };

        if (providers.isEmpty()) {
            // Nothing to collect: complete right away instead of depending on
            // an accumulator that expects zero elements to ever fire.
            reducer.setResult(Collections.<ProviderDefinition>emptySet());
            return;
        }

        // accumulate the definitions of all providers
        SetAccumulator<ProviderDefinition> accumulator =
                new SetAccumulator<ProviderDefinition>(reducer,
                        providers.size());
        for (ApiProvider p : providers) {
            introspect(p).getDefinition(ctx,
                    accumulator.createSlaveForOneElement());
        }
    }

    /**
     * Computes the checksum of the Aggregator based on the definitions of the
     * aggregated providers. The Aggregator checksum must change if the
     * checksum of an aggregated provider changes, if a new provider is
     * registered, or if an existing provider is unregistered (removed from the
     * Aggregator).
     * <p>
     * CRC32 is order-sensitive, and the {@link Set} contract gives no iteration
     * order guarantee. The set is also populated asynchronously. For these
     * reasons the provider checksums are sorted before hashing, so the result
     * depends only on which checksums are present. A list is used rather than a
     * sorted set so that providers with identical checksums are all counted.
     *
     * @param defs definitions of aggregated providers
     * @return checksum of the Aggregator, formatted by
     *         {@link StringUtils#crc32ToHexString}
     */
    private static String computeCheckSum(Set<ProviderDefinition> defs) {
        List<String> providerCheckSums = new ArrayList<String>(defs.size());
        for (ProviderDefinition def : defs) {
            providerCheckSums.add(def.getCheckSum());
        }
        Collections.sort(providerCheckSums);

        CRC32 checksum = new CRC32();
        for (String providerCheckSum : providerCheckSums) {
            checksum.update(providerCheckSum.getBytes(StandardCharsets.UTF_8));
        }
        return StringUtils.crc32ToHexString(checksum);
    }

    /**
     * Reports one {@link InterfaceIdentifier} for every service id currently
     * present in the routing map. Completes synchronously.
     */
    @Override
    public void getInterfaceIdentifiers(ExecutionContext ctx,
            AsyncHandle<Set<InterfaceIdentifier>> asyncHandle) {
        Set<InterfaceIdentifier> result = new HashSet<InterfaceIdentifier>();
        for (String serviceId : ifaceMap.keySet()) {
            result.add(new InterfaceIdentifier(serviceId));
        }
        asyncHandle.setResult(result);
    }

    /**
     * Delegates to the provider serving {@code iface}.
     * <p>
     * Reports {@code null} if {@code iface} is {@code null} or no provider is
     * registered for it.
     */
    @Override
    public void getInterface(ExecutionContext ctx,
                              InterfaceIdentifier iface,
                              AsyncHandle<InterfaceDefinition> asyncHandle) {
        if (iface == null) {
            asyncHandle.setResult(null);
            return;
        }
        ApiProvider p = ifaceMap.get(iface.getName());
        if (p != null) {
            introspect(p).getInterface(ctx, iface, asyncHandle);
        } else {
            asyncHandle.setResult(null);
        }
    }

    /**
     * Delegates to the provider serving the interface of {@code method}.
     * <p>
     * Reports {@code null} if {@code method} is {@code null} or no provider is
     * registered for its interface.
     */
    @Override
    public void getMethod(ExecutionContext ctx,
                           MethodIdentifier method,
                           AsyncHandle<MethodDefinition> asyncHandle) {
        if (method == null) {
            asyncHandle.setResult(null);
            return;
        }
        InterfaceIdentifier iface = method.getInterfaceIdentifier();
        ApiProvider prov = ifaceMap.get(iface.getName());
        if (prov != null) {
            introspect(prov).getMethod(ctx, method, asyncHandle);
        } else {
            asyncHandle.setResult(null);
        }
    }

    /**
     * Returns an asynchronous introspection view of {@code p}.
     * <p>
     * Providers that already implement {@link ApiIntrospection} are used as-is.
     * Providers that implement {@link SyncApiIntrospection} are wrapped in an
     * adapter. Anything else is wrapped in {@link ApiProviderStubImpl}, which
     * presumably introspects through the provider's {@code invoke}; confirm in
     * that class.
     */
    private ApiIntrospection introspect(ApiProvider p) {
        if (p instanceof ApiIntrospection) {
            return (ApiIntrospection) p;
        }
        if (p instanceof SyncApiIntrospection) {
            return new SyncToAsyncApiIntrospectionAdapter(
                    (SyncApiIntrospection) p);
        }
        return new ApiProviderStubImpl(p);
    }

    /**
     * Routes an invocation to the provider registered for {@code serviceId}.
     * <p>
     * If no provider is registered, an {@code OPERATION_NOT_FOUND} error result
     * is reported. Otherwise the call is delegated to that provider. Results
     * and progress are forwarded unchanged. An error reported by the provider
     * through {@code setError} is logged and converted into an
     * {@code INTERNAL_SERVER_ERROR} result (see {@link #invokeMethodError}).
     */
    private void invokeMethodImpl(final ExecutionContext ctx,
                                  final String serviceId,
                                  final String operationId,
                                  final DataValue input,
                                  final AsyncHandle<MethodResult> asyncHandle) {
        ApiProvider prov = ifaceMap.get(serviceId);
        if (prov == null) {
            logger.warn("Could not find provider for service '{}'", serviceId);

            ErrorValue errValue = ErrorValueFactory.buildErrorValue(
                    StandardDataFactory.OPERATION_NOT_FOUND,
                    "vapi.method.input.invalid.interface",
                    serviceId);

            asyncHandle.setResult(MethodResult.newErrorResult(errValue));
            return;
        }

        logger.debug("Invoking operation '{}.{}'", serviceId, operationId);

        AsyncHandle<MethodResult> cb = new AsyncHandle<MethodResult>() {
            @Override
            public void updateProgress(DataValue progress) {
                asyncHandle.updateProgress(progress);
            }

            @Override
            public void setResult(MethodResult result) {
                asyncHandle.setResult(result);
            }

            @Override
            public void setError(RuntimeException ex) {
                // some internal error; or network error trying
                // to invoke the provider - log and report
                // InternalServerError to the client
                logger.error(String.format(
                        "Error while invoking operation '%s.%s'",
                        serviceId,
                        operationId), ex);
                asyncHandle.setResult(invokeMethodError(serviceId, operationId));
            }
        };

        prov.invoke(serviceId,
                    operationId,
                    input,
                    ctx,
                    cb);
    }

    /**
     * Builds the {@code INTERNAL_SERVER_ERROR} result that is reported when
     * invoking {@code serviceId.operationId} on a provider fails.
     *
     * @param serviceId id of the invoked service
     * @param operationId id of the invoked operation
     * @return error result carrying the localizable
     *         {@code vapi.provider.aggregator.invokemethod.exception} message
     */
    private static MethodResult invokeMethodError(String serviceId,
            String operationId) {
        String methodId = serviceId + "." + operationId;
        Message message = getMessage("vapi.provider.aggregator.invokemethod.exception",
                                     methodId);
        return MethodResult.newErrorResult(
                StandardDataFactory.createErrorValueForMessages(
                        INTERNAL_SERVER_ERROR,
                        Arrays.asList(message)));
    }

    /**
     * Invokes {@code serviceId.operationId} on the aggregated provider that
     * serves {@code serviceId}.
     * <p>
     * A {@link RuntimeException} thrown synchronously during dispatch (including
     * one thrown directly by the target provider's {@code invoke}) is logged and
     * reported through {@link AsyncHandle#setError}. Errors the provider
     * reports through its handle are instead converted into error results.
     * NOTE(review): a provider that both completes the handle and then throws
     * would cause {@code asyncHandle} to be completed twice.
     */
    @Override
    public void invoke(String serviceId,
                       String operationId,
                       DataValue input,
                       ExecutionContext ctx,
                       AsyncHandle<MethodResult> asyncHandle) {
        try {
            invokeMethodImpl(ctx, serviceId, operationId, input, asyncHandle);
        } catch (RuntimeException ex) {
            logger.error("Exception thrown in invokeMethod", ex);
            asyncHandle.setError(ex);
        }
    }
}
