/* **********************************************************
 * Copyright 2021 VMware, Inc.  All rights reserved. -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.cis.authn.json;

import static com.vmware.vapi.internal.dsig.json.Verifier.DEFAULT_CLOCK_TOLERANCE_SEC;

import java.security.cert.X509Certificate;
import java.util.Map;
import java.util.Objects;

import com.vmware.vapi.cis.authn.SamlTokenAuthnHandler;
import com.vmware.vapi.cis.util.RefreshableCache;
import com.vmware.vapi.dsig.json.SignatureException;
import com.vmware.vapi.dsig.json.StsTrustChain;
import com.vmware.vapi.protocol.RequestProcessor;
import com.vmware.vapi.saml.exception.InvalidTokenException;

/**
 * This is a decorator of {@link JsonSignatureVerificationProcessor} that will retry the processing
 * of the request in case an {@link InvalidTokenException} is received during the SAML token
 * validation<br>
 * New certificates are obtained either by calling the {@link StsTrustChain#getStsTrustChain()}
 * method when an {@link StsTrustChain} is provided in the constructor or by calling the
 * {@link RefreshableCache#refresh()} when a {@link RefreshableCache} is provided directly
 * @see {@link JsonSignatureVerificationProcessor}
 */
public class RetryJsonSignatureVerificationProcessor implements RequestProcessor {
    private final RefreshableCache<X509Certificate[]> certsCache;
    private RequestProcessor decoratedProcessor;


    /**
     * @param trustChain The STS trusted certificates. The method
     *        {@link StsTrustChain#getStsTrustChain()} will be called once during initialization and
     *        subsequently whenever the certificates need to be refreshed
     * @param retryDelayMs Minimum period of time in milliseconds between refreshes of the STS trust
     *        chain
     */
    public RetryJsonSignatureVerificationProcessor(final StsTrustChain trustChain,
                                                   long retryDelayMs) {
        this(trustChain, retryDelayMs, DEFAULT_CLOCK_TOLERANCE_SEC);
    }

    /**
     * @param trustChain The STS trusted certificates. The method
     *        {@link StsTrustChain#getStsTrustChain()} will be called once during initialization and
     *        subsequently whenever the certificates need to be refreshed
     * @param retryDelayMs Minimum period of time in milliseconds between refreshes of the STS trust
     *        chain
     * @param clockToleranceSec The allowed time discrepancy between the client and the server. Must
     *        not be negative
     */
    public RetryJsonSignatureVerificationProcessor(StsTrustChain trustChain,
                                                   long retryDelayMs,
                                                   long clockToleranceSec) {
        this.certsCache = new RefreshableCache<>(() -> {
            return trustChain.getStsTrustChain();
        }, retryDelayMs);
        this.decoratedProcessor = new JsonSignatureVerificationProcessor(new CacheStsTrustChain(),
                                                                         clockToleranceSec);
    }

    /**
     * @param trustChainCache A refreshable cache that provides the trusted root certificates.
     *        Whenever a new trust chain is needed, the refresh method will be called
     */
    public RetryJsonSignatureVerificationProcessor(RefreshableCache<X509Certificate[]> trustChainCache) {
        this.certsCache = trustChainCache;
        this.decoratedProcessor = new JsonSignatureVerificationProcessor(new CacheStsTrustChain());
    }

    @Override
    public byte[] process(byte[] requestBytes, Map<String, Object> metadata, Request request) {
        byte[] processed = decoratedProcessor.process(requestBytes, metadata, request);
        if (isInvalidTokenExceptionPresent(metadata)) {
            processed = retry(requestBytes, metadata, request);
        }
        return processed;
    }

    /**
     * Retries the call possibly with a refreshed trust chain
     */
    private byte[] retry(byte[] requestBytes, Map<String, Object> metadata, Request request) {
        certsCache.refresh();
        clearError(metadata);
        byte[] processed = decoratedProcessor.process(requestBytes, metadata, request);
        return processed;
    }

    private void clearError(Map<String, Object> metadata) {
        Map<String, Object> securityMetadata = getSecurityMetadata(metadata);
        if (securityMetadata != null) {
            securityMetadata.remove(SamlTokenAuthnHandler.ERROR_KEY);
        }
    }

    static boolean isInvalidTokenExceptionPresent(Map<String, Object> metadata) {
        Map<String, Object> securityProcMetadata = getSecurityMetadata(metadata);
        if (securityProcMetadata == null) {
            return false;
        }
        Exception error = (Exception) securityProcMetadata.get(SamlTokenAuthnHandler.ERROR_KEY);
        return (error instanceof SignatureException)
               && error.getCause() instanceof InvalidTokenException;
    }

    private static Map<String, Object> getSecurityMetadata(Map<String, Object> metadata) {
        if (metadata == null) {
            return null;
        }
        @SuppressWarnings("unchecked")
        Map<String, Object> securityProcMetadata = (Map<String, Object>) metadata
        .get(RequestProcessor.SECURITY_PROC_METADATA_KEY);
        return securityProcMetadata;
    }


    void setDecoratedProcessor(RequestProcessor decoratedProcessor) {
        this.decoratedProcessor = decoratedProcessor;
    }

    RequestProcessor getDecoratedProcessor() {
        return this.decoratedProcessor;
    }


    /**
     * {@link StsTrustChain} backed by the {@link RetryJsonSignatureVerificationProcessor#cache}
     * trust chain cache
     */
    class CacheStsTrustChain implements StsTrustChain {
        @Override
        public X509Certificate[] getStsTrustChain() {
            return certsCache.get();
        }
    }
}