/* **********************************************************
 * Copyright 2011-2014, 2018-2022 VMware, Inc. All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.security;


import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.ErrorValueFactory;
import com.vmware.vapi.Message;
import com.vmware.vapi.MessageFactory;
import com.vmware.vapi.bindings.exception.Constants;
import com.vmware.vapi.core.ApiProvider;
import com.vmware.vapi.core.AsyncHandle;
import com.vmware.vapi.core.DecoratorApiProvider;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.ExecutionContext.SecurityContext;
import com.vmware.vapi.core.MethodIdentifier;
import com.vmware.vapi.core.MethodResult;
import com.vmware.vapi.data.DataValue;
import com.vmware.vapi.data.ErrorDefinition;
import com.vmware.vapi.data.ErrorValue;
import com.vmware.vapi.internal.security.SecurityUtil;
import com.vmware.vapi.internal.util.Validate;
import com.vmware.vapi.provider.introspection.ErrorAugmentingFilter;
import com.vmware.vapi.security.AuthenticationConfig.AuthnScheme;
import com.vmware.vapi.security.AuthenticationHandler.AuthenticationResult;
import com.vmware.vapi.std.StandardDataFactory;

/**
 * This class acts as a decorator to the actual provider. It enforces the
 * authentication rules defined in the authentication config file.
 */
public final class AuthenticationFilter extends DecoratorApiProvider {

    /** Separator between the segments of a package/service name. */
    private static final char PACKAGE_DELIMITER = '.';
    private static final Logger logger = LoggerFactory
            .getLogger(AuthenticationFilter.class);
    /** Scheme describing an invocation that carries no authentication data. */
    private static final AuthnScheme NO_AUTHN_SCHEME = AuthnScheme
            .getNoAuthenticationScheme();
    /** Error reported when an invocation is rejected by this filter. */
    private static final ErrorValue UNAUTHENTICATED;
    /** Error reported when no authentication rule exists for an operation. */
    private static final ErrorValue OPERATION_NOT_FOUND;
    /*
     * Fully-qualified names of the operations that are treated as anonymous
     * without being listed in the authentication config. Declared final: the
     * set is assigned exactly once by the static initializer and must never be
     * re-pointed afterwards, since it widens unauthenticated access.
     */
    private static final Set<String> VAPI_ANON_OPERATIONS;

    static {
        Message message = MessageFactory
                .getMessage("vapi.method.authentication.required");
        UNAUTHENTICATED = ErrorValueFactory
                .buildErrorValue(StandardDataFactory.UNAUTHENTICATED, message);

        message = MessageFactory
                .getMessage("vapi.authentication.metadata.required");
        OPERATION_NOT_FOUND = ErrorValueFactory
                .buildErrorValue(StandardDataFactory.OPERATION_NOT_FOUND,
                                 message);

        VAPI_ANON_OPERATIONS = new HashSet<>(Arrays
                .asList("com.vmware.vapi.metadata.routing.component.list",
                        "com.vmware.vapi.metadata.routing.component.get",
                        "com.vmware.vapi.metadata.routing.component.fingerprint",
                        "com.vmware.vapi.metadata.routing.service.operation.list",
                        "com.vmware.vapi.metadata.routing.service.operation.get",
                        "com.vmware.vapi.metadata.routing.package.list",
                        "com.vmware.vapi.metadata.routing.package.get",
                        "com.vmware.vapi.metadata.routing.service.list",
                        "com.vmware.vapi.metadata.routing.service.get",
                        "com.vmware.vapi.metadata.cli.command.list",
                        "com.vmware.vapi.metadata.cli.command.get",
                        "com.vmware.vapi.metadata.cli.command.fingerprint",
                        "com.vmware.vapi.metadata.cli.namespace.list",
                        "com.vmware.vapi.metadata.cli.namespace.get",
                        "com.vmware.vapi.metadata.cli.namespace.fingerprint",
                        "com.vmware.vapi.metadata.privilege.component.list",
                        "com.vmware.vapi.metadata.privilege.component.get",
                        "com.vmware.vapi.metadata.privilege.component.fingerprint",
                        "com.vmware.vapi.metadata.privilege.service.operation.list",
                        "com.vmware.vapi.metadata.privilege.service.operation.get",
                        "com.vmware.vapi.metadata.privilege.package.list",
                        "com.vmware.vapi.metadata.privilege.package.get",
                        "com.vmware.vapi.metadata.privilege.service.list",
                        "com.vmware.vapi.metadata.privilege.service.get",
                        "com.vmware.vapi.metadata.authentication.component.list",
                        "com.vmware.vapi.metadata.authentication.component.get",
                        "com.vmware.vapi.metadata.authentication.component.fingerprint",
                        "com.vmware.vapi.metadata.authentication.service.operation.list",
                        "com.vmware.vapi.metadata.authentication.service.operation.get",
                        "com.vmware.vapi.metadata.authentication.package.list",
                        "com.vmware.vapi.metadata.authentication.package.get",
                        "com.vmware.vapi.metadata.authentication.service.list",
                        "com.vmware.vapi.metadata.authentication.service.get",
                        "com.vmware.vapi.metadata.metamodel.component.list",
                        "com.vmware.vapi.metadata.metamodel.component.get",
                        "com.vmware.vapi.metadata.metamodel.component.fingerprint",
                        "com.vmware.vapi.metadata.metamodel.enumeration.list",
                        "com.vmware.vapi.metadata.metamodel.enumeration.get",
                        "com.vmware.vapi.metadata.metamodel.resource.model.list",
                        "com.vmware.vapi.metadata.metamodel.service.operation.list",
                        "com.vmware.vapi.metadata.metamodel.service.operation.get",
                        "com.vmware.vapi.metadata.metamodel.service.hidden.list",
                        "com.vmware.vapi.metadata.metamodel.package.list",
                        "com.vmware.vapi.metadata.metamodel.package.get",
                        "com.vmware.vapi.metadata.metamodel.resource.list",
                        "com.vmware.vapi.metadata.metamodel.service.list",
                        "com.vmware.vapi.metadata.metamodel.service.get",
                        "com.vmware.vapi.metadata.metamodel.structure.list",
                        "com.vmware.vapi.metadata.metamodel.structure.get",
                        "com.vmware.vapi.rest.navigation.component.list",
                        "com.vmware.vapi.rest.navigation.options.get",
                        "com.vmware.vapi.rest.navigation.resource.get",
                        "com.vmware.vapi.rest.navigation.resource.list",
                        "com.vmware.vapi.rest.navigation.root.get",
                        "com.vmware.vapi.rest.navigation.service.list",
                        "com.vmware.vapi.std.introspection.operation.list",
                        "com.vmware.vapi.std.introspection.operation.get",
                        "com.vmware.vapi.std.introspection.provider.get",
                        "com.vmware.vapi.std.introspection.service.list",
                        "com.vmware.vapi.std.introspection.service.get"));
    }

    /**
     * Error definitions contributed by this filter. The constructor passes
     * them to the {@link ErrorAugmentingFilter} which wraps the decorated
     * provider.
     */
    static final Set<ErrorDefinition> AUTHN_FILTER_ERROR_DEFS =
            Collections.singleton(
                    StandardDataFactory.createStandardErrorDefinition(
                            Constants.UNAUTHENTICATED));

    /** Authentication rules looked up by service identifier. */
    private final Map<String, List<AuthnScheme>> ifaceRulesTable;
    /** Authentication rules looked up by the closest enclosing package name. */
    private final Map<String, List<AuthnScheme>> packageRulesTable;
    /** Authentication rules looked up by fully-qualified operation name. */
    private final Map<String, List<AuthnScheme>> operationRulesTable;
    /** Handlers searched, in list order, for one supporting the wire scheme. */
    private final List<AuthenticationHandler> authnHandlers;

    /**
     * Creates a filter which authenticates every invocation before handing it
     * to the decorated provider.
     * <p>
     * The rules of the configuration and the list of handlers are copied, so
     * later changes made by the caller to those collections do not affect this
     * filter.
     *
     * @param decoratedProvider the provider that is decorated with this
     *                          authentication filter. cannot be null.
     * @param authnConfig the authentication configuration. cannot be null.
     * @param authnHandlers a list of the authentication handlers that will
     *                      be executed during authentication events. cannot be
     *                      null.
     * @throws RuntimeException if an interface or package rule of the
     *         configuration contains the NoAuthentication scheme; anonymous
     *         access may be configured on operation level only
     */
    public AuthenticationFilter(ApiProvider decoratedProvider,
                                AuthenticationConfig authnConfig,
                                List<AuthenticationHandler> authnHandlers) {
        super(new ErrorAugmentingFilter(decoratedProvider,
                                        AUTHN_FILTER_ERROR_DEFS));

        Validate.notNull(authnConfig);
        Validate.notNull(authnHandlers);
        // NoAuthentication is rejected on interface and package level ...
        ifaceRulesTable = copyRules(authnConfig.getIFaceAuthenticationRules(),
                                    true);
        packageRulesTable = copyRules(authnConfig.getPackageAuthenticationRules(),
                                      true);
        // ... and tolerated on operation level only
        operationRulesTable = copyRules(authnConfig.getOperationAuthenticationRules(),
                                        false);
        // Snapshot the handlers just like the rules: the list is iterated on
        // every invocation and must not change under the filter's feet.
        this.authnHandlers = Collections
                .unmodifiableList(new ArrayList<>(authnHandlers));
    }

    /**
     * Enforces the authentication rules for the invoked operation and, when
     * they are satisfied, forwards the invocation to the decorated provider.
     * <ul>
     * <li>No rule found for the operation: the result is an
     * operation-not-found error.</li>
     * <li>The scheme on the wire is not allowed, but anonymous access is: the
     * invocation is forwarded with the security context removed.</li>
     * <li>The scheme on the wire is not allowed and neither is anonymous
     * access: the result is an unauthenticated error.</li>
     * <li>No scheme on the wire and anonymous access is allowed: the
     * invocation is forwarded unchanged.</li>
     * <li>Otherwise the handler supporting the wire scheme authenticates the
     * security context; on success the invocation is forwarded with a security
     * context which carries the authentication result, on failure (or when no
     * handler supports the scheme) the result is an unauthenticated
     * error.</li>
     * </ul>
     */
    @Override
    public void invoke(final String serviceId,
                       final String operationId,
                       final DataValue input,
                       final ExecutionContext ctx,
                       final AsyncHandle<MethodResult> asyncHandle) {

        final List<AuthnScheme> requiredAuthnSchemes =
                getMethodAuthnScheme(serviceId, operationId);
        if (requiredAuthnSchemes == null || requiredAuthnSchemes.isEmpty()) {
            // no rule covers this operation, so it is not exposed at all
            asyncHandle.setResult(MethodResult.newErrorResult(OPERATION_NOT_FOUND));
            return;
        }

        final SecurityContext secCtx = ctx.retrieveSecurityContext();
        final String wireSchemeId = extractWireScheme(secCtx);
        final AuthnScheme wireScheme = createAuthnScheme(wireSchemeId);

        if (!isSchemeAllowed(requiredAuthnSchemes, wireScheme)) {
            if (isSchemeAllowed(requiredAuthnSchemes, NO_AUTHN_SCHEME)) {
                /*
                 * unsupported scheme was received on the wire, but
                 * unauthenticated access is allowed so drop the authentication
                 * data which was not checked by the handlers and let the
                 * invocation happen
                 */
                logger.debug("Unexpected scheme '{}' found in the invocation of "
                        + "method '{}.{}' which allows 'NoAuthentication'",
                             wireScheme,
                             serviceId,
                             operationId);
                decoratedProvider.invoke(serviceId,
                                         operationId,
                                         input,
                                         ctx.withSecurityContext(null),
                                         asyncHandle);
                return;
            }

            // authentication rules not fulfilled
            logger.debug("Invalid authentication scheme '{}' for method "
                    + "{}.{} which allows {}",
                         wireScheme,
                         serviceId,
                         operationId,
                         requiredAuthnSchemes);
            asyncHandle.setResult(MethodResult.newErrorResult(UNAUTHENTICATED));
            return;
        }

        /*
         * shortcut for unauthenticated invocations of operations which allow
         * unauthenticated access. The decision is taken on the wire scheme id
         * (null means "no authentication data", exactly the condition used to
         * build wireScheme) instead of comparing AuthnScheme references, so it
         * does not depend on the scheme factory handing out one shared
         * instance.
         */
        if (wireSchemeId == null) {
            decoratedProvider.invoke(serviceId,
                                     operationId,
                                     input,
                                     ctx,
                                     asyncHandle);
            return;
        }

        // check that the authentication is valid
        final AuthenticationHandler authnHandler = findHandler(wireSchemeId);
        if (authnHandler == null) {
            // no handler is found for the requested authentication
            asyncHandle.setResult(MethodResult.newErrorResult(UNAUTHENTICATED));
            return;
        }

        // secCtx is not null here: a non-null scheme id can only be extracted
        // from a non-null security context
        authnHandler.authenticate(secCtx,
                new AsyncHandle<AuthenticationHandler.AuthenticationResult>() {

            @Override
            public void updateProgress(DataValue progress) {
                // noop
            }

            @Override
            public void setResult(AuthenticationResult result) {
                // NOTE(review): a null result is forwarded as null
                // authentication data and the invocation still proceeds -
                // confirm this matches the AuthenticationHandler contract.
                SecurityContext authnSecCtx = secCtx;
                if (result != null && result.getSecurityContext() != null) {
                    // overwrite the security context as set from the authn handler
                    authnSecCtx = result.getSecurityContext();
                }
                SecurityContextImpl s = new SecurityContextImpl(authnSecCtx,
                                                                result);
                decoratedProvider.invoke(
                        serviceId,
                        operationId,
                        input,
                        ctx.withSecurityContext(s),
                        asyncHandle);
            }

            @Override
            public void setError(RuntimeException error) {
                // the cause is only logged; the caller gets a generic error
                logger.info("Authentication failed", error);
                asyncHandle.setResult(MethodResult.newErrorResult(UNAUTHENTICATED));
            }
        });
    }

    /**
     * Looks up the first registered handler which declares support for the
     * given authentication scheme.
     *
     * @param schemeId the scheme id received on the wire
     * @return the matching handler, or <code>null</code> if none of the
     *         registered handlers supports the scheme
     */
    private AuthenticationHandler findHandler(String schemeId) {
        AuthenticationHandler selected = null;
        for (AuthenticationHandler candidate : authnHandlers) {
            if (candidate.supportedAuthenticationSchemes().contains(schemeId)) {
                selected = candidate;
                break;
            }
        }

        if (selected != null) {
            logger.debug("Selected authentication handler is {}", selected);
        } else {
            logger.debug("No suitable authentication handler found for scheme '{}'", schemeId);
        }
        return selected;
    }

    /**
     * Resolves the authentication schemes which apply to a method. The lookup
     * goes from the most to the least specific rule: operation rules, the
     * built-in set of implicitly anonymous operations, interface rules and
     * finally the rules of the closest enclosing package.
     *
     * @param serviceId must not be <code>null</code>
     * @param operationId must not be <code>null</code>
     * @return the list of authentication schemes, one of which should be
     *         enforced for the given method. can be <code>null</code> if no
     *         rule matches on any of the levels.
     */
    private List<AuthnScheme> getMethodAuthnScheme(String serviceId,
                                                   String operationId) {
        final String methodName =
                MethodIdentifier.getFullyQualifiedName(serviceId, operationId);

        // 1. explicit rule for the operation
        final List<AuthnScheme> operationSchemes =
                operationRulesTable.get(methodName);
        if (operationSchemes != null) {
            return operationSchemes;
        }

        // 2. operations which the runtime treats as anonymous by itself
        if (VAPI_ANON_OPERATIONS.contains(methodName)) {
            return Collections.singletonList(NO_AUTHN_SCHEME);
        }

        // 3. rule for the service, otherwise 4. rule for the closest package
        final List<AuthnScheme> serviceSchemes = ifaceRulesTable.get(serviceId);
        return serviceSchemes != null
                ? serviceSchemes
                : packageRulesTable.get(
                        findClosestPackage(serviceId,
                                           packageRulesTable.keySet()));
    }

    /**
     * Tells whether at least one of the allowed schemes accepts the scheme
     * which was received on the wire.
     *
     * @param allowedList the allowed authentication schemes. not null.
     * @param wireScheme the scheme received on the wire. not null.
     * @return true if the scheme received on the wire can be used for
     *         authentication, false otherwise
     */
    private static boolean isSchemeAllowed(List<AuthnScheme> allowedList,
                                           AuthnScheme wireScheme) {
        boolean accepted = false;
        for (AuthnScheme candidate : allowedList) {
            accepted = candidate.isAllowed(wireScheme);
            if (accepted) {
                // the first accepting scheme settles the question
                break;
            }
        }
        return accepted;
    }

    /**
     * Reads the id of the authentication scheme used in the request.
     *
     * @param ctx the SecurityContext. can be null.
     * @return the value of the
     *         {@link SecurityContext#AUTHENTICATION_SCHEME_ID} property
     *         narrowed to a string, or <code>null</code> when there is no
     *         security context.
     */
    private static String extractWireScheme(SecurityContext ctx) {
        return ctx == null
                ? null
                : SecurityUtil.narrowType(
                        ctx.getProperty(SecurityContext.AUTHENTICATION_SCHEME_ID),
                        String.class);
    }

    /**
     * Wraps a scheme id into an {@link AuthnScheme}.
     *
     * @param schemeId the scheme id that came on the wire. can be null.
     * @return the NoAuthentication scheme when <code>schemeId</code> is null,
     *         otherwise a scheme consisting of that single id. Cannot be null.
     */
    private static AuthnScheme createAuthnScheme(String schemeId) {
        return schemeId == null
                ? AuthnScheme.getNoAuthenticationScheme()
                : new AuthnScheme(Collections.singletonList(schemeId));
    }

    /**
     * Picks the longest registered package which encloses the given service,
     * i.e. whose name followed by the package delimiter is a prefix of the
     * service name.
     *
     * @param invocationService cannot be null
     * @param registeredPackages the packages that are listed in the authn config
     * @return the closest package that contains the invocation package, or
     *         the empty string when no registered package encloses the
     *         service. cannot be null.
     */
    static String findClosestPackage(String invocationService,
                                     Iterable<String> registeredPackages) {
        String closest = "";
        for (String candidate : registeredPackages) {
            if (candidate.length() <= closest.length()) {
                // cannot be a better match than what we already have
                continue;
            }
            // a bare prefix is not enough: "com.vm" must not match "com.vmware.x"
            if (invocationService.startsWith(candidate + PACKAGE_DELIMITER)) {
                closest = candidate;
            }
        }
        return closest;
    }

    /**
     * Makes a copy of the rules map in which every non-null list of schemes is
     * replaced by a fresh list; entries with a null list are kept as they are.
     *
     * @param authenticationRules the rules to copy
     * @param rejectNoAuth specifies whether to fail upon discovering a
     *        NoAuthentication scheme
     * @return a copy of the specified map
     * @throws RuntimeException if <code>rejectNoAuth</code> is set and one of
     *         the rules contains the NoAuthentication scheme
     */
    private static Map<String, List<AuthnScheme>> copyRules(
                             Map<String, List<AuthnScheme>> authenticationRules,
                             boolean rejectNoAuth) {

        final Map<String, List<AuthnScheme>> copy =
                new HashMap<>(authenticationRules);
        for (Entry<String, List<AuthnScheme>> rule : copy.entrySet()) {
            final List<AuthnScheme> configured = rule.getValue();
            if (configured != null) {
                final boolean anonymous = rejectNoAuth
                        && configured.contains(AuthnScheme.getNoAuthenticationScheme());
                if (anonymous) {
                    final String reason = String.format(
                            "%s cannot be marked as anonymous. NoAuthentication"
                                    + " is allowed on operations level only.",
                            rule.getKey());
                    throw new RuntimeException(reason);
                }
                // detach the stored list from the one owned by the config
                rule.setValue(new ArrayList<>(configured));
            }
        }
        return copy;
    }

    /**
     * Exposes the operations which vapi-runtime treats as anonymous even when
     * they are not explicitly listed as such (an exception to the
     * secure-by-default requirement).
     *
     * @return a read-only view of the fully-qualified names of those
     *         operations
     */
    public static Set<String> getImplicitlyAnonymousOperations() {
        final Set<String> readOnlyView =
                Collections.unmodifiableSet(VAPI_ANON_OPERATIONS);
        return readOnlyView;
    }

    /**
     * Security context made of the properties of another context plus the
     * authentication result, which is stored under
     * {@link SecurityContext#AUTHENTICATION_DATA_ID}.
     */
    private static class SecurityContextImpl implements SecurityContext {

        private final Map<String, Object> properties;

        /**
         * @param ctx the context whose properties are copied. must not be null.
         * @param result the authentication result to expose; replaces any
         *        authentication data already present in <code>ctx</code>
         */
        private SecurityContextImpl(SecurityContext ctx,
                                    AuthenticationResult result) {
            final Map<String, Object> merged = new HashMap<>();
            merged.putAll(ctx.getAllProperties());
            merged.put(SecurityContext.AUTHENTICATION_DATA_ID, result);
            this.properties = merged;
        }

        @Override
        public Object getProperty(String key) {
            return this.properties.get(key);
        }

        /** @return a read-only view of all properties */
        @Override
        public Map<String, Object> getAllProperties() {
            return Collections.unmodifiableMap(this.properties);
        }
    }
}
