/*
 * Copyright (C) 2009 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.conscrypt;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;

/**
 * Supports SSL session caches.
 */
abstract class AbstractSessionContext implements SSLSessionContext {

    /**
     * Default maximum lifetime of a session (in seconds), after which it is considered invalid
     * and should not be used for new connections. Used as the initial value of {@link #timeout}.
     */
    private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60;

    /**
     * Upper bound on the number of cached sessions. The cache's eviction hook only evicts when
     * this value is greater than zero.
     */
    volatile int maximumSize;

    /** Session timeout in seconds; changed through {@link #setSessionTimeout(int)}. */
    volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS;

    /**
     * Value returned by {@code NativeCrypto.SSL_CTX_new()} when this context is created; it is
     * handed to {@code NativeCrypto.SSL_CTX_free} in {@link #finalize()}.
     */
    final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();

    /** Identifies OpenSSL sessions. Serialized form: session data and certificates only. */
    static final int OPEN_SSL = 1;

    /** Identifies OpenSSL sessions with OCSP stapled data. */
    static final int OPEN_SSL_WITH_OCSP = 2;

    /**
     * Identifies OpenSSL sessions with TLS SCT data. This is the type ID that
     * {@link #toBytes(SSLSession)} writes.
     */
    static final int OPEN_SSL_WITH_TLS_SCT = 3;

    /**
     * Session cache, keyed by {@code ByteArray}-wrapped session IDs with {@code SSLSession}
     * values. Methods in this class synchronize on the map itself before touching it.
     */
    //@SuppressWarnings("serial")
    private final Map sessions = new LinkedHashMap() {

        private static final long serialVersionUID = -7489291500279637668L;

        /**
         * Evicts the eldest entry by hand once the cache has grown past a positive
         * {@code maximumSize}, so that {@code sessionRemoved} can be notified. Because the map
         * is modified directly here, the method always reports {@code false} to the superclass.
         */
        //@Override
        protected boolean removeEldestEntry(Map.Entry eldest) {
            if (maximumSize <= 0 || size() <= maximumSize) {
                // Unbounded cache, or still within the limit: nothing to evict.
                return false;
            }
            remove(eldest.getKey());
            sessionRemoved((SSLSession) eldest.getValue());
            return false;
        }
    };

    /**
     * Constructs a new session context.
     *
     * @param maximumSize initial value of the cache size limit; stored as-is without validation
     */
    AbstractSessionContext(int maximumSize) {
        this.maximumSize = maximumSize;
    }

    /**
     * Returns an iterator over a snapshot of the cached sessions, in the map's iteration order
     * (oldest to newest). The snapshot is copied while holding the cache lock, so the iterator
     * itself can be consumed without it.
     */
    private Iterator sessionIterator() {
        final SSLSession[] snapshot;
        synchronized (sessions) {
            SSLSession[] target = new SSLSession[sessions.size()];
            snapshot = (SSLSession[]) sessions.values().toArray(target);
        }
        return Arrays.asList(snapshot).iterator();
    }

    /**
     * Returns the IDs ({@code byte[]}) of the sessions in a snapshot of the cache, skipping any
     * session that {@code Platform.SSLSession_isValid} rejects at enumeration time.
     */
    //@Override
    public final Enumeration getIds() {
        final Iterator snapshot = sessionIterator();
        return new Enumeration() {
            /** Next valid session found by look-ahead, or null if none is buffered. */
            private SSLSession pending;

            //@Override
            public boolean hasMoreElements() {
                // Advance until a valid session is buffered or the snapshot is exhausted.
                while (pending == null && snapshot.hasNext()) {
                    SSLSession candidate = (SSLSession) snapshot.next();
                    if (Platform.SSLSession_isValid(candidate)) {
                        pending = candidate;
                    }
                }
                return pending != null;
            }

            //@Override
            public Object nextElement() {
                if (!hasMoreElements()) {
                    throw new NoSuchElementException();
                }
                byte[] sessionId = pending.getId();
                pending = null;
                return sessionId;
            }
        };
    }

    /**
     * Returns the current value of the cache size limit ({@code maximumSize}).
     */
    //@Override
    public final int getSessionCacheSize() {
        return maximumSize;
    }

    /**
     * Returns the current session timeout, in seconds.
     */
    //@Override
    public final int getSessionTimeout() {
        return timeout;
    }

    /**
     * Evicts the oldest sessions until the cache holds at most {@code maximumSize} entries,
     * calling {@link #sessionRemoved(SSLSession)} for each evicted session.
     *
     * <p>A non-positive {@code maximumSize} means the cache is unbounded (matching the eviction
     * hook of the session map), so nothing is evicted in that case.
     */
    protected void trimToSize() {
        synchronized (sessions) {
            // Read the volatile limit once so the check and the removal count agree.
            int limit = maximumSize;
            if (limit <= 0) {
                return;
            }
            int excess = sessions.size() - limit;
            Iterator i = sessions.values().iterator();
            while (excess-- > 0 && i.hasNext()) {
                SSLSession session = (SSLSession) i.next();
                i.remove();
                sessionRemoved(session);
            }
        }
    }

    /**
     * Updates the session timeout and then purges every cached session that is no longer valid,
     * notifying {@link #sessionRemoved(SSLSession)} for each one.
     *
     * @param seconds the new timeout in seconds
     * @throws IllegalArgumentException if {@code seconds} is negative
     */
    //@Override
    public void setSessionTimeout(int seconds)
            throws IllegalArgumentException {
        if (seconds < 0) {
            throw new IllegalArgumentException("seconds < 0");
        }
        timeout = seconds;

        synchronized (sessions) {
            for (Iterator it = sessions.values().iterator(); it.hasNext();) {
                SSLSession candidate = (SSLSession) it.next();
                // Sessions know their context and consult its timeout as part of their
                // validity check, so the new value takes effect here.
                if (Platform.SSLSession_isValid(candidate)) {
                    continue;
                }
                it.remove();
                sessionRemoved(candidate);
            }
        }
    }

    /**
     * Called when a session is removed. Used by ClientSessionContext
     * to update its host-and-port based cache.
     *
     * <p>Within this class it is invoked from the cache's eviction hook, from
     * {@link #trimToSize()} and from {@link #setSessionTimeout(int)}; the latter two call it
     * while holding the lock on the session map.
     *
     * @param session the session that was just removed from the cache
     */
    protected abstract void sessionRemoved(SSLSession session);

    /**
     * Sets the maximum number of sessions kept in the cache and evicts the oldest sessions if
     * the cache may now exceed that limit.
     *
     * @param size the new limit; zero means there is no limit
     * @throws IllegalArgumentException if {@code size} is negative
     */
    //@Override
    public void setSessionCacheSize(int size)
            throws IllegalArgumentException {
        if (size < 0) {
            throw new IllegalArgumentException("size < 0");
        }

        int oldMaximum = maximumSize;
        maximumSize = size;

        // Zero means "no limit", so there is nothing to evict. Otherwise trim when the limit
        // shrank, or when the cache was previously unbounded and may already hold more entries
        // than the new limit allows.
        if (size > 0 && (oldMaximum <= 0 || size < oldMaximum)) {
            trimToSize();
        }
    }

    /**
     * Serializes the given session into a byte array.
     *
     * <p>Layout, in write order: the {@code OPEN_SSL_WITH_TLS_SCT} type ID, the encoded session,
     * the peer certificate count followed by each encoded certificate, the OCSP response count
     * followed by each response, and finally the TLS SCT data (a zero length when absent).
     * Every byte array is preceded by its length written as an int.
     *
     * @param session the session to serialize
     * @return session data as bytes, or null if the session is not an
     *         {@code OpenSSLSessionImpl} or cannot be converted
     */
    public byte[] toBytes(SSLSession session) {
        // TODO: Support SSLSessionImpl, too.
        if (!(session instanceof OpenSSLSessionImpl)) {
            return null;
        }
        OpenSSLSessionImpl openSslSession = (OpenSSLSessionImpl) session;

        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(buffer);

            // Session type ID, then the session data itself.
            out.writeInt(OPEN_SSL_WITH_TLS_SCT);
            writeChunk(out, openSslSession.getEncoded());

            // Peer certificates.
            Certificate[] peerChain = session.getPeerCertificates();
            out.writeInt(peerChain.length);
            int index = 0;
            while (index < peerChain.length) {
                writeChunk(out, peerChain[index].getEncoded());
                index++;
            }

            // OCSP responses.
            List statusResponses = openSslSession.getStatusResponses();
            out.writeInt(statusResponses.size());
            for (Iterator it = statusResponses.iterator(); it.hasNext();) {
                writeChunk(out, (byte[]) it.next());
            }

            // TLS SCT data; a bare zero length marks its absence.
            byte[] sctData = openSslSession.getTlsSctData();
            if (sctData == null) {
                out.writeInt(0);
            } else {
                writeChunk(out, sctData);
            }

            // TODO: local certificates?

            return buffer.toByteArray();
        } catch (IOException e) {
            System.err.println("Failed to convert saved SSL Session: " + e.getMessage());
            return null;
        } catch (CertificateEncodingException e) {
            log(e);
            return null;
        }
    }

    /** Writes the length of {@code chunk} as an int, followed by its bytes. */
    private static void writeChunk(DataOutputStream out, byte[] chunk) throws IOException {
        out.writeInt(chunk.length);
        out.write(chunk);
    }

    /**
     * Creates a session from the given bytes.
     *
     * <p>Expected layout (all lengths and counts are ints):
     * <ol>
     * <li>type ID: {@code OPEN_SSL}, {@code OPEN_SSL_WITH_OCSP} or
     *     {@code OPEN_SSL_WITH_TLS_SCT};</li>
     * <li>length-prefixed session data;</li>
     * <li>certificate count, then each certificate length-prefixed;</li>
     * <li>for types at or above {@code OPEN_SSL_WITH_OCSP}: OCSP response count, then each
     *     response length-prefixed (only the first one is kept);</li>
     * <li>for {@code OPEN_SSL_WITH_TLS_SCT}: TLS SCT data length, then the data if the length
     *     is positive.</li>
     * </ol>
     * Input with trailing bytes, or with lengths that the remaining input cannot satisfy, is
     * rejected.
     *
     * @param data serialized session, as produced by {@link #toBytes(SSLSession)}
     * @param host peer host passed through to the created session
     * @param port peer port passed through to the created session
     * @return a session or null if the session can't be converted
     */
    OpenSSLSessionImpl toSession(byte[] data, String host, int port) {
        if (data == null) {
            return null;
        }
        ByteArrayInputStream bais = new ByteArrayInputStream(data);
        DataInputStream dais = new DataInputStream(bais);
        try {
            int type = dais.readInt();
            if (type != OPEN_SSL && type != OPEN_SSL_WITH_OCSP && type != OPEN_SSL_WITH_TLS_SCT) {
                throw new IOException("Unexpected type ID: " + type);
            }

            byte[] sessionData = readChunk(dais, dais.readInt());

            int count = dais.readInt();
            // Every certificate needs at least its length prefix, so a count larger than the
            // remaining input can never be valid; reject it before allocating the array.
            if (count < 0 || count > dais.available()) {
                throw new IOException("Invalid certificate count: " + count);
            }
            X509Certificate[] certs = new X509Certificate[count];
            for (int i = 0; i < count; i++) {
                byte[] certData = readChunk(dais, dais.readInt());
                try {
                    certs[i] = OpenSSLX509Certificate.fromX509Der(certData);
                } catch (Exception e) {
                    IOException wrapped =
                            new IOException("Can not read certificate " + i + "/" + count);
                    wrapped.initCause(e);
                    throw wrapped;
                }
            }

            byte[] ocspData = null;
            if (type >= OPEN_SSL_WITH_OCSP) {
                // We only support one OCSP response now, but in the future
                // we may support RFC 6961 which has multiple.
                int countOcspResponses = dais.readInt();
                if (countOcspResponses >= 1) {
                    ocspData = readChunk(dais, dais.readInt());

                    // Skip the rest of the responses.
                    for (int i = 1; i < countOcspResponses; i++) {
                        int skipLength = dais.readInt();
                        if (skipLength < 0 || dais.skipBytes(skipLength) != skipLength) {
                            throw new IOException("Truncated OCSP response " + i + "/"
                                    + countOcspResponses);
                        }
                    }
                }
            }

            byte[] tlsSctData = null;
            if (type == OPEN_SSL_WITH_TLS_SCT) {
                int tlsSctDataLength = dais.readInt();
                if (tlsSctDataLength > 0) {
                    tlsSctData = readChunk(dais, tlsSctDataLength);
                }
            }

            if (dais.available() != 0) {
                log(new AssertionError("Read entire session, but data still remains; rejecting"));
                return null;
            }

            return new OpenSSLSessionImpl(sessionData, host, port, certs, ocspData, tlsSctData,
                    this);
        } catch (Exception e) {
            // Covers IOException from malformed input as well as any runtime failure while
            // building the session; either way the data cannot be converted.
            log(e);
            return null;
        }
    }

    /**
     * Reads exactly {@code length} bytes, first checking that the length is non-negative and
     * does not exceed what the stream reports as available, so that corrupt input cannot
     * trigger an oversized allocation.
     */
    private static byte[] readChunk(DataInputStream in, int length) throws IOException {
        if (length < 0 || length > in.available()) {
            throw new IOException("Invalid length: " + length);
        }
        byte[] chunk = new byte[length];
        in.readFully(chunk);
        return chunk;
    }

    /**
     * Passes {@code AbstractOpenSSLSession} instances through {@code Platform.wrapSSLSession};
     * any other session is returned unchanged.
     */
    protected SSLSession wrapSSLSessionIfNeeded(SSLSession session) {
        if (!(session instanceof AbstractOpenSSLSession)) {
            return session;
        }
        return Platform.wrapSSLSession((AbstractOpenSSLSession) session);
    }

    /**
     * Looks up a cached session by its ID.
     *
     * @param sessionId the session ID to look up
     * @return the (possibly wrapped) session, or null if no session with that ID is cached or
     *         the cached session is no longer valid
     * @throws NullPointerException if {@code sessionId} is null
     */
    //@Override
    public SSLSession getSession(byte[] sessionId) {
        if (sessionId == null) {
            throw new NullPointerException("sessionId == null");
        }
        ByteArray lookupKey = new ByteArray(sessionId);
        SSLSession cached;
        synchronized (sessions) {
            cached = (SSLSession) sessions.get(lookupKey);
        }
        if (cached == null || !Platform.SSLSession_isValid(cached)) {
            return null;
        }
        return wrapSSLSessionIfNeeded(cached);
    }

    /**
     * Adds the session to the cache under its ID. Sessions with an empty ID are ignored.
     */
    void putSession(SSLSession session) {
        byte[] sessionId = session.getId();
        if (sessionId.length != 0) {
            ByteArray cacheKey = new ByteArray(sessionId);
            synchronized (sessions) {
                sessions.put(cacheKey, session);
            }
        }
    }

    /**
     * Prints a one-line description of the failure to standard output, using the throwable's
     * message when it has one and its class name otherwise.
     */
    static void log(Throwable t) {
        String detail = t.getMessage();
        if (detail == null) {
            detail = t.getClass().getName();
        }
        System.out.println("Error inflating SSL session: " + detail);
    }

    /**
     * Hands the native context pointer to {@code NativeCrypto.SSL_CTX_free} and then runs the
     * superclass finalizer, even if freeing throws.
     */
    //@Override
    protected void finalize() throws Throwable {
        final long nativeContext = sslCtxNativePointer;
        try {
            NativeCrypto.SSL_CTX_free(nativeContext);
        } finally {
            super.finalize();
        }
    }
}
