#!/usr/bin/python
# vim:ts=4:sw=4 expandtab
#############################################################################
#
# Copyright Avaya Inc., All Rights Reserved.
#
# THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE OF Avaya Inc.
#
# The copyright notice above does not evidence any actual or intended
# publication of such source code.
#
# Some third-party source code components may have  been modified from their
# original versions by Avaya Inc.
#
# The modifications are Copyright Avaya Inc., All Rights Reserved.
#
# Avaya - Confidential & Restricted. May not be distributed further without
# written permission of the Avaya owner.
#
#############################################################################
#
# If the certificates being signed are issued by SMGR, then upgrade by
# re-creating the certificates. This allows for longer certificate life and
# any new settings.
#
# The backup will contain the old keystores in PEM and JKS/BCFKS format and
# it will also contain an X509 text output of each of the keystores.  This
# text format is used to pull the CN and Subject Alternative Name.  These
# values are placed in the CRD input file so that the new SMGR certs will
# match the values of the old ones.
#
# This routine only applies to SMGR signed certificates.  Third party certs
# are later copied over the SMGR generated ones.
#
#############################################################################

import grp
import logging
import logging.handlers
from optparse import OptionParser
import os
import pwd
import re
import shutil
import stat
from subprocess import *
import sys
from xml.etree.ElementTree import *
from _elementtree import register_namespace

UPGRADE_PATH = '/opt/Avaya/upgrade'
TMPATH = '/opt/avaya/tm'
TM_UPG_PATH = UPGRADE_PATH + TMPATH + '/keystore'
SMGR_CA_FILE = UPGRADE_PATH + TMPATH + '/truststore/smgrca'
UPG_CRD_INPUT_FILE = TMPATH + '/CRDUpgrade_Input.xml'
UPG_CRD_TM_FILE = TMPATH + '/CRDUpgradeTM.xml'

###############################################################
# Create CRD files
###############################################################
"""
Creates a CRDUpgrade_Input.xml file
"""
def createUpgradeCRDFiles():
    ORIG_CRD_INPUT_FILE = TMPATH + '/CRDJEE_Input.xml'
    shutil.copy(ORIG_CRD_INPUT_FILE, UPG_CRD_INPUT_FILE)
    
    ORIG_CRD_TM_FILE = TMPATH + '/CRDJEETM.xml'
    shutil.copy(ORIG_CRD_TM_FILE, UPG_CRD_TM_FILE)
    
    # set to owner read
    os.chmod(UPG_CRD_INPUT_FILE, stat.S_IREAD)
    os.chmod(UPG_CRD_TM_FILE, stat.S_IREAD)
 
    # change the owner and group
    uid = pwd.getpwnam("root").pw_uid
    gid = grp.getgrnam("root").gr_gid
    os.chown(UPG_CRD_INPUT_FILE, uid, gid)
    os.chown(UPG_CRD_TM_FILE, uid, gid)
    
###############################################################
# Updates Input CRD file
###############################################################
"""
Updates the key size defined in CRDUpgradeTM.xml
"""
def updateInputCRDFile():
    tree = parse(UPG_CRD_INPUT_FILE)
    root = tree.getroot()

    global keySize
 
    # Get the SMGR CA info to see if third party certificates are
    # being used.
    smgrCA = ""
    if os.path.exists(SMGR_CA_FILE):
        cert = Popen(['openssl', 'x509', '-inform', 'der', '-in', SMGR_CA_FILE],
            stdout=PIPE)
        output = check_output(['openssl', 'x509', '-text', '-noout'], 
            stdin=cert.stdout)
        cert.wait()
            
        for line in output.split('\n'):
            match = re.search('.*Subject: (.*)', line)
            if match: 
                smgrCA = match.group(1)
                log.info("SMGR CA is '%s'", smgrCA)
                break
    else:
        log.warn("SMGR CA file not found in upgrade")
    
    for file in os.listdir(TM_UPG_PATH):
    
        # look for the text files with the certificate information
        if 'txt' not in file:
            continue
        log.debug("Found text file: %s", file)
    
        commonName = None
        subjectAltName = ""
    
        with open(TM_UPG_PATH + '/' + file, mode='r') as textFile:
            lines = textFile.readlines()
            for index, line in enumerate(lines):
            
                # determine the Issuer
                match = re.search('.*Issuer: (.*)', line)
                if match: 
                    issuer = match.group(1)
                    log.debug("Found Issuer text: '%s'", issuer)
                    
                    if issuer != smgrCA:
                        log.info("%s is a third party certificate, skipping",
                            file)
                        break
                
                # Pull out the CN
                match = re.search('.*Subject: CN=(.*), .*', line)
                if match and commonName == None:
                    commonName = match.group(1)
                    log.debug("Found CN text: %s", commonName)
                    end = commonName.find(',')
                    if end > 0:
                        # strip of everything left over after the CN
                        commonName = commonName[0:end]
                    continue
            
                # Handle the case where the CN is only a FQDN
                match = re.search('.*Subject: CN=(.*)', line)
                if match and commonName == None:
                    commonName = match.group(1)
                    log.debug("Found CN text: %s", commonName)
                    continue
                
                # Subject Alt Name fields
                if 'Subject Alternative Name' in line:
                    # the next line contains the values
                    subjectAltName = lines[index+1].strip()
                
                    log.debug("SubjectAltName text: %s", subjectAltName)
                    # parse the format so that is matches the expected CRD format
                    subjectAltName = subjectAltName.replace('DNS:', 'dnsName=')
                    subjectAltName = subjectAltName.replace('IP Address:', 'ipAddress=')
                    subjectAltName = subjectAltName.replace('URI:', 'uniformResourceIdentifier=')
                    subjectAltName = subjectAltName.replace(' ', '')
                    continue
            
                # figure out key size
                match = re.search('.*Public-Key: \((.*) bit\)', line)
                if match:
                    try:
                        newSize = int(match.group(1))
                        log.debug("Found key size: %d", newSize)
                    
                        if newSize > keySize:
                            log.info('Found large key than the default (%d bit)',
                                     newSize)
                            keySize = newSize
                        
                    except ValueError:
                        log.warn('Unable to figure out key size')
                    
        for idCert in root.findall(ns + 'INPUT_IDCert'):
            serviceId = idCert.find(ns + 'ServiceID').text
     
            if getServiceID(file) == serviceId:
                if commonName != None:   
                    idCert.find(ns + 'SS_Cert_CN').text = commonName
            
                subjectAltElement = idCert.find(ns + 'SS_Cert_Subject_Alt_Name')
                if len(subjectAltName) != 0:
                    log.info("Setting Sub Alt name to %s", subjectAltName)
                    subjectAltElement.find(ns + 'Value').text = subjectAltName
                else:
                    log.info("Subject Alt Name is empty for %s", serviceId)
                    idCert.remove(subjectAltElement)

    # write the modified XML out to the CRD Input file
    log.info('Updating CRDUpgrade_Input.xml')
    tree.write(open(UPG_CRD_INPUT_FILE, 'w'))

###############################################################
# Updates the key size
###############################################################
"""
Updates the key size defined in CRDUpgradeTM.xml
"""
def updateKeySize(keySize):
    
    log.debug("Updating key size to %d", keySize)
    
    # need to update the key size in CRDUpgradeTM.xml
    tree = parse(UPG_CRD_TM_FILE)
    root = tree.getroot()
    
    for profile in root.findall('{0}CRDTM/{0}SERVICE_PROFILE'.format(ns)):
        length = profile.find('{0}ServiceCERT_PublicKeyAlg/{0}Length'.format(ns))
        length.text = str(keySize)
        
    log.info('Updating CRDUpgradeTM.xml')
    tree.write(open(UPG_CRD_TM_FILE, 'w'))
    
###############################################################
# Get Matching Service ID
###############################################################
"""
Returns the service ID that matches the given keystore file.
"""
def getServiceID(filename):
    if 'asset_http_keystore' in filename:
        return 'securitymodule_http'
    elif 'asset_keystore' in filename:
        return 'securitymodule_sip'
    elif 'spirit' in filename:
        return 'spiritalias'
    elif 'container_keystore' in filename:
        return 'mgmt'
    elif 'postgres' in filename:
        return 'postgres'
    else:
        return None

###############################################################
# Main
###############################################################

# command line options
usage = "usage: %prog [options]"
parser = OptionParser(usage)
parser.add_option(
    "-d", "--debug",
    action="store_true",
    dest="debug",
    default=False,
    help="turn on debug level logging")
options, args = parser.parse_args()

# 2048 bit keys are used unless an old certificate has a larger key
DEFAULT_KEYSIZE = 2048
keySize = DEFAULT_KEYSIZE

ns = '{http://www.avaya.com/mgmt/trust/data/crd/CRDXMLSchema}'

# the namespace is registered without the surrounding brackets
register_namespace('', ns[1:-1])

# Logging goes to syslog and to the console.  The logger and the console
# share one level, picked by the debug option; syslog takes everything the
# logger lets through.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG if options.debug else logging.INFO)

chandler = logging.StreamHandler()
chandler.setLevel(log.level)

handler = logging.handlers.SysLogHandler('/dev/log')
handler.setLevel(logging.DEBUG)

formatter = logging.Formatter('tmupgrade: %(levelname)s %(message)s')
cformatter = logging.Formatter('%(message)s')
handler.setFormatter(formatter)
chandler.setFormatter(cformatter)

log.addHandler(handler)
log.addHandler(chandler)

# nothing to do without the saved keystores of the old release
if not os.path.exists(TM_UPG_PATH):
    log.error("Upgrade files do not exist")
    sys.exit(1)

createUpgradeCRDFiles()
updateInputCRDFile()

# updateInputCRDFile raises keySize when it finds a larger key
if keySize > DEFAULT_KEYSIZE:
    updateKeySize(keySize)

