/* **********************************************************
 * Copyright 2012-2014, 2019, 2021 VMware, Inc.  All rights reserved.
 *      -- VMware Confidential
 * **********************************************************/
package com.vmware.vapi.cis.authn.json;

import static com.vmware.vapi.internal.security.SecurityContextConstants.SIGNATURE_ALGORITHM_KEY;

import java.io.UnsupportedEncodingException;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.vmware.vapi.CoreException;
import com.vmware.vapi.Message;
import com.vmware.vapi.MessageFactory;
import com.vmware.vapi.cis.authn.SamlTokenAuthnHandler;
import com.vmware.vapi.core.ExecutionContext;
import com.vmware.vapi.core.ExecutionContext.SecurityContext;
import com.vmware.vapi.data.ConstraintValidationException;
import com.vmware.vapi.dsig.json.SignatureException;
import com.vmware.vapi.dsig.json.StsTrustChain;
import com.vmware.vapi.internal.cis.authn.json.JsonSignerImpl;
import com.vmware.vapi.internal.dsig.json.JsonCanonicalizer;
import com.vmware.vapi.internal.dsig.json.Verifier;
import com.vmware.vapi.internal.security.SecurityContextConstants;
import com.vmware.vapi.internal.security.SecurityUtil;
import com.vmware.vapi.internal.util.DateTimeConverter;
import com.vmware.vapi.protocol.RequestProcessor;
import com.vmware.vapi.security.StdSecuritySchemes;
import com.vmware.vapi.saml.ConfirmationType;
import com.vmware.vapi.saml.DefaultTokenFactory;
import com.vmware.vapi.saml.SamlToken;
import com.vmware.vapi.saml.exception.InvalidTokenException;

/**
 * This processor handles SAML token authentication. In case of HoK token the
 * request signature is validated.
 */
/**
 * This processor handles SAML token authentication. For a holder-of-key (HoK)
 * token ({@link StdSecuritySchemes#SAML_TOKEN}) the request signature and its
 * timestamp are validated. For a bearer token
 * ({@link StdSecuritySchemes#SAML_BEARER_TOKEN}) only the token is parsed.
 * <p>
 * Validation failures are not thrown to the caller. The parsed token and any
 * error are recorded in the request metadata map stored under
 * {@link RequestProcessor#SECURITY_PROC_METADATA_KEY}, using the keys defined
 * by {@link SamlTokenAuthnHandler}.
 */
public final class JsonSignatureVerificationProcessor implements RequestProcessor {

    private static final int MILLIS_PER_SECOND = 1000;
    private static final Message VERIFY_ERROR = MessageFactory
            .getMessage("vapi.signature.verify");
    private static final Message DECODE_ERROR =
            MessageFactory.getMessage("vapi.sso.signproc.decoderequest");
    private static final Logger logger =
            LoggerFactory.getLogger(JsonSignatureVerificationProcessor.class);
    private final DateTimeConverter dateConverter = new DateTimeConverter();

    /** Verifies the JSON request signature of HoK requests. */
    private final Verifier jsonVerifier;
    /** Shared token parser; every use is serialized on this instance. */
    private final DefaultTokenFactory tokenFactory = new DefaultTokenFactory();
    /** Source of the STS certificates used when parsing SAML tokens. */
    final StsTrustChain stsTrustChain;
    /** Allowed client/server clock skew, in seconds; never negative. */
    private final long clockToleranceSec;

    /**
     * Creates a processor that uses the default clock tolerance
     * ({@code Verifier.DEFAULT_CLOCK_TOLERANCE_SEC}).
     *
     * @param stsTrustChain used to retrieve the STS signing certificates for
     *                      validating SAML tokens. Must not be null.
     */
    public JsonSignatureVerificationProcessor(StsTrustChain stsTrustChain) {
        this(stsTrustChain,
             Verifier.DEFAULT_CLOCK_TOLERANCE_SEC);
    }

    /**
     * Creates a processor that verifies signatures with a
     * {@link JsonSignerImpl} backed by a {@link JsonCanonicalizer}.
     *
     * @param stsTrustChain used to retrieve the STS signing certificates for
     *                      validating SAML tokens. Must not be null.
     * @param clockToleranceSec the allowed time discrepancy between the client
     *                          and the server. Must not be negative.
     */
    public JsonSignatureVerificationProcessor(StsTrustChain stsTrustChain,
                                              long clockToleranceSec) {
        this(new JsonSignerImpl(new JsonCanonicalizer(), stsTrustChain),
                stsTrustChain,
                clockToleranceSec);
    }

    /**
     * @param verifier used to verify the signature. must not be null.
     * @param stsTrustChain used to retrieve the STS trust chain for validating
     *                      SAML tokens. Must not be null.
     * @param clockToleranceSec the allowed time discrepancy between the client
     *                          and the server. Must not be negative.
     * @throws NullPointerException if {@code verifier} or
     *         {@code stsTrustChain} is null
     * @throws IllegalArgumentException if {@code clockToleranceSec} is negative
     */
    JsonSignatureVerificationProcessor(Verifier verifier,
                                       StsTrustChain stsTrustChain,
                                       long clockToleranceSec) {
        Objects.requireNonNull(verifier, "verifier");
        Objects.requireNonNull(stsTrustChain, "stsTrustChain");
        if (clockToleranceSec < 0) {
            throw new IllegalArgumentException("Clock tolerance must not be negative: " + clockToleranceSec);
        }
        this.stsTrustChain = stsTrustChain;
        this.clockToleranceSec = clockToleranceSec;
        jsonVerifier = verifier;
    }

    /**
     * Parses the SAML token carried in the request's security context. For
     * HoK tokens, it also verifies the request signature.
     * <p>
     * The request passes through untouched in two cases: it has no security
     * context, or its authentication scheme is neither
     * {@link StdSecuritySchemes#SAML_TOKEN} nor
     * {@link StdSecuritySchemes#SAML_BEARER_TOKEN}. Otherwise the parsed token
     * (or {@code null}) and the validation error (or {@code null}) are put
     * into the security-processing map of {@code metadata}.
     *
     * @return the {@code request} bytes, unmodified
     * @throws NullPointerException if any argument is null
     * @throws CoreException if an HoK request cannot be decoded to a string
     */
    @Override
    public byte[] process(byte[] request,
                          Map<String, Object> metadata,
                          Request vapiRequest) {

        Objects.requireNonNull(request, "request");
        Objects.requireNonNull(metadata, "metadata");
        Objects.requireNonNull(vapiRequest, "vapiRequest");

        // TODO include request ID in every log statement produced by this method

        SecurityContext ctx = getSecurityContext(vapiRequest);
        if (ctx == null) {
            // nothing to process
            return request;
        }

        String schemeId = SecurityUtil.narrowType(
                ctx.getProperty(SecurityContext.AUTHENTICATION_SCHEME_ID), String.class);
        if (!validateSchemeId(schemeId)) {
            // nothing to process
            return request;
        }

        boolean holderOfKey = schemeId.equalsIgnoreCase(StdSecuritySchemes.SAML_TOKEN);

        // Only HoK signature verification needs the textual request, so the
        // bytes are not copied into a String for bearer tokens. Decoding stays
        // outside the try block so a decoding failure still propagates.
        // TODO get the request in some other way because requestToString
        // actually copies the request
        String requestString = holderOfKey ? requestToString(request) : null;

        SamlToken token = null;
        Exception error = null;
        try {
            token = holderOfKey
                    ? validateSignature(ctx, requestString)
                    : parseBearerToken(ctx);
        } catch (Exception e) {
            // Deliberately not rethrown: the failure is handed over via metadata.
            error = e;
        }

        Map<String, Object> addSecProcData = getSecurityProcData(metadata);
        addSecProcData.put(SamlTokenAuthnHandler.SAML_TOKEN_KEY, token);
        addSecProcData.put(SamlTokenAuthnHandler.ERROR_KEY, error);
        metadata.put(RequestProcessor.SECURITY_PROC_METADATA_KEY, addSecProcData);

        return request;
    }

    /**
     * Parses a holder-of-key token and validates the request signature and
     * the signature timestamp.
     * <p>
     * The signature map stored in {@code ctx} is not modified. The signature
     * algorithm is added to a private copy that is passed to the verifier.
     *
     * @param ctx must not be null
     * @param request the request as a string; must not be null
     * @return the parsed HoK token, or {@code null} if the signature's token
     *         entry is not a String
     * @throws SignatureException if the signature or timestamp is missing or
     *         invalid, or the signature does not verify
     * @throws InvalidTokenException if the token cannot be parsed
     */
    private SamlToken validateSignature(SecurityContext ctx, String request)
            throws InvalidTokenException {
        assert ctx != null;

        @SuppressWarnings("unchecked")
        Map<String, Object> signature = SecurityUtil.narrowType(ctx.getProperty(
                SecurityContextConstants.SIGNATURE_KEY), Map.class);

        if (signature == null) {
            logger.debug("Signature not found.");
            throw new SignatureException(VERIFY_ERROR);
        }

        validateSignatureTimestamp(ctx);

        // Copy instead of mutating the map owned by the security context
        // (it is shared request state and may not even be modifiable).
        Map<String, Object> signatureToVerify = new HashMap<String, Object>(signature);
        signatureToVerify.put(SIGNATURE_ALGORITHM_KEY,
                              SecurityUtil.narrowType(ctx.getProperty(SIGNATURE_ALGORITHM_KEY),
                                                      String.class));

        // TODO use ByteArrayInputStream instead of String here (request)?
        if (!jsonVerifier.verifySignature(request,
                                          signatureToVerify,
                                          clockToleranceSec)) {
            // TODO add refetch
            throw new SignatureException(VERIFY_ERROR);
        }
        logger.debug("Signature validated");

        // TODO we have just parsed the SamlToken in verifySignature() - find a
        // way to reuse that
        return parseToken(signatureToVerify.get(SecurityContextConstants.SAML_TOKEN_KEY));
    }

    /**
     * Checks the signature timestamp. The current time, widened by the clock
     * tolerance, must fall within the created/expires dates.
     *
     * @param ctx must not be <code>null</code>
     * @throws SignatureException if the timestamp is missing, unparsable,
     *         has created after expires, or does not cover the current time
     */
    private void validateSignatureTimestamp(SecurityContext ctx) {
        @SuppressWarnings("unchecked")
        Map<String, String> timestamp = SecurityUtil.narrowType(ctx.getProperty(
                SecurityContextConstants.TIMESTAMP_KEY), Map.class);
        if (timestamp == null) {
            logger.debug("Timestamp is missing");
            throw new SignatureException(VERIFY_ERROR);
        }
        String createdStr = timestamp.get(SecurityContextConstants.TS_CREATED_KEY);
        String expiresStr = timestamp.get(SecurityContextConstants.TS_EXPIRES_KEY);
        if (createdStr == null || expiresStr == null) {
            logger.debug("Invalid timestamp: {} {}", createdStr, expiresStr);
            throw new SignatureException(VERIFY_ERROR);
        }

        Calendar created;
        Calendar expires;
        try {
            created = dateConverter.fromStringValue(createdStr);
            expires = dateConverter.fromStringValue(expiresStr);
        } catch (ConstraintValidationException e) {
            logger.debug("Cannot convert timestamp date", e);
            throw new SignatureException(VERIFY_ERROR, e);
        }

        long createdMillis = created.getTimeInMillis();
        long expiresMillis = expires.getTimeInMillis();
        long toleranceMillis = clockToleranceSec * MILLIS_PER_SECOND;

        if (createdMillis > expiresMillis) {
            logger.debug("Invalid timestamp: {} {}", createdStr, expiresStr);
            throw new SignatureException(VERIFY_ERROR);
        }

        long currentTime = System.currentTimeMillis();
        // Created in the future beyond the tolerated skew.
        if (createdMillis > currentTime + toleranceMillis) {
            if (logger.isDebugEnabled()) {
                logger.debug("Invalid timestamp. Created: " + new Date(createdMillis) +
                             " Current time: " + new Date(currentTime));
            }
            throw new SignatureException(VERIFY_ERROR);
        }
        // Expired longer ago than the tolerated skew.
        if (expiresMillis < currentTime - toleranceMillis) {
            if (logger.isDebugEnabled()) {
                logger.debug("Invalid timestamp. Expires: " + new Date(expiresMillis) +
                             " Current time: " + new Date(currentTime));
            }
            throw new SignatureException(VERIFY_ERROR);
        }
        logger.debug("Signature timestamp validated");
    }

    /**
     * Parses a SAML token string against the configured STS trust chain and
     * clock tolerance.
     *
     * @param tokenString can be null.
     * @return the parsed SamlToken object or <code>null</code> if the tokenString
     *         is null or is not a String.
     * @throws InvalidTokenException if the token cannot be parsed or validated
     */
    private SamlToken parseToken(Object tokenString) throws InvalidTokenException {
        if (!(tokenString instanceof String)) {
            return null;
        }
        // The single factory instance is shared between calls, so access to it
        // is serialized.
        // TODO use factory per thread instead of synchronizing
        synchronized (tokenFactory) {
            return tokenFactory.parseToken((String) tokenString,
                                           stsTrustChain.getStsTrustChain(),
                                           clockToleranceSec);
        }
    }

    /**
     * Parses a bearer token authentication scheme.
     *
     * @param ctx not null
     * @return the parsed bearer token
     * @throws InvalidTokenException if the token cannot be parsed or validated
     * @throws RuntimeException if there is no String token in the context or
     *         the token's confirmation type is not {@link ConfirmationType#BEARER}
     */
    private SamlToken parseBearerToken(SecurityContext ctx) throws InvalidTokenException {
        assert ctx != null;

        SamlToken token = parseToken(ctx.getProperty(SecurityContextConstants.SAML_TOKEN_KEY));
        if (token == null || token.getConfirmationType() != ConfirmationType.BEARER) {
            throw new RuntimeException("Cannot parse bearer token: " + token);
        }

        return token;
    }

    /**
     * Validates if this processor works with the schemeId from the request.
     *
     * @param schemeId can be <code>null</code>
     * @return true if the schemeId equals, ignoring case, either the HoK or
     *         the bearer SAML scheme
     */
    private boolean validateSchemeId(String schemeId) {
        return schemeId != null &&
                (schemeId.equalsIgnoreCase(StdSecuritySchemes.SAML_TOKEN) ||
                 schemeId.equalsIgnoreCase(StdSecuritySchemes.SAML_BEARER_TOKEN));
    }

    /**
     * @param request must not be null
     * @return the request's security context, or {@code null} if the request
     *         has no execution context
     */
    private SecurityContext getSecurityContext(Request request) {
        ExecutionContext ctx = request.getCtx();
        if (ctx == null) {
            return null;
        }
        return ctx.retrieveSecurityContext();
    }

    /**
     * Decodes the raw request using {@link RequestProcessor#UTF8_CHARSET}.
     *
     * @param request must not be null
     * @return the decoded request
     * @throws CoreException if the charset is not supported
     */
    private String requestToString(byte[] request) {
        try {
            return new String(request, RequestProcessor.UTF8_CHARSET);
        } catch (UnsupportedEncodingException e) {
            logger.error(e.getMessage(), e);
            throw new CoreException(DECODE_ERROR);
        }
    }

    /**
     * Returns the map stored in {@code metadata} under
     * {@link RequestProcessor#SECURITY_PROC_METADATA_KEY}, or a new empty map
     * when {@code SecurityUtil.narrowType} yields {@code null} for that entry.
     *
     * @param metadata must not be <code>null</code>
     * @return the security context data structure that contains the additional
     *         data that should be appended to the request's security context
     */
    private Map<String, Object> getSecurityProcData(
            Map<String, Object> metadata) {
        assert metadata != null;

        @SuppressWarnings("unchecked")
        Map<String, Object> procDataStruct = SecurityUtil.narrowType(
                metadata.get(RequestProcessor.SECURITY_PROC_METADATA_KEY), Map.class);
        if (procDataStruct == null) {
            procDataStruct = new HashMap<String, Object>();
        }

        return procDataStruct;
    }
}