#!/usr/bin/python
# vim:ts=4:sw=4 expandtab
#############################################################################
#
# Copyright Avaya Inc., All Rights Reserved.
#
# THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE OF Avaya Inc.
#
# The copyright notice above does not evidence any actual or intended
# publication of such source code.
#
# Some third-party source code components may have  been modified from their
# original versions by Avaya Inc.
#
# The modifications are Copyright Avaya Inc., All Rights Reserved.
#
# Avaya - Confidential & Restricted. May not be distributed further without
# written permission of the Avaya owner.
#
#############################################################################
"""

Performs certificate related operations:
    - import p12 ID certificates into trust management
    - import pem trusted certificates into trust management
 
Return Codes:
    - 0: Success
    - 1: Usage
    - 2: File does not exist
    - 3: Import ID failure
    - 4: Incorrect passphrase
    - 5: User entered no for imported trusted certificate
    - 6: User aborted
    - 7: Import trusted certificate failure
"""

import argparse
import codecs
from xml.dom.minidom import parse
import getpass
import grp
import os
import pwd
import shutil
import stat
from subprocess import *
import sys
from tempfile import NamedTemporaryFile

# Base directory of the trust management (TM) installation.
TM_HOME = '/opt/avaya/tm'
# Third-party CRD input / TM files (created from the JEE templates and
# edited in place by this script).
CRD_INPUT_FILE = os.path.join(TM_HOME, 'CRDThirdParty_Input.xml')
CRD_TM_FILE = os.path.join(TM_HOME, 'CRDThirdPartyTM.xml')
# Client used to add certificates to already initialized truststores.
TMCLIENT = '/opt/util/bin/tm/TMClient.sh'


# Short names accepted on the command line -> service IDs used internally.
serviceIdMapping = dict(
    spirit='spiritalias',
    LDAP='LDAP',
    Syslog='syslog',
    all='all',
)

# Service ID -> truststore names initialized for it in the CRD TM file.
# Each service has one or more truststores associated with it.
trustNameMapping = dict(
    spiritalias=['SPIRIT_TRUSTSTORE'],
    LDAP=['LDAP_TRUSTSTORE'],
    syslog=['SYSLOG_TRUSTSTORE'],
)

# Service ID -> truststore ID passed to "TMClient.sh addCertificate".
trustIdMapping = dict(
    spiritalias='SAL_AGENT',
    LDAP='LDAP',
    syslog='SYSLOG',
)

# Scratch directory used while restoring certificates from previous stores.
temp = '/tmp/backup_init_certs'

###############################################################
# Copy p12 file to TM home
###############################################################
"""
The p12 file needs to be in the TM home directory in order
to be picked up during initialization.

Arguments:
    filename -    full path to p12 file
"""
def copyP12File(filename):
    shutil.copy(filename, TM_HOME)
    p12file = TM_HOME + "/" + os.path.basename(filename)
    setPermissions(p12file)

###############################################################
# Set file permissions
###############################################################
"""
Sets the owner, group, and permissions of a file.

Arguments:
    filename - full path of the file to change
"""
def setPermissions(filename):
    
    # set to owner read
    os.chmod(filename, stat.S_IREAD)
 
    # change the owner and group
    uid = pwd.getpwnam("root").pw_uid
    gid = grp.getgrnam("root").gr_gid
    os.chown(filename, uid, gid)

###############################################################
# Extracts CA Certs
###############################################################
"""
Extracts the CA certs so that they can be later imported
into the different truststores.

Arguments:
    filename -    full path to p12 file
"""
def extractCACerts(p12file, passphrase):
    print "Extracting CA certs from %s" % p12file
    
    fullpath = '%s/%s' % (TM_HOME, p12file)
    
    try:
        output = check_output(['/usr/bin/openssl', 'pkcs12', '-in', fullpath,
                '-nokeys', '-cacerts', '-passin', 'pass:%s' % passphrase])
        
    except CalledProcessError:
        print "Failed to extract certs from %s, check passphrase." % p12file
        sys.exit(4)
        
    trustedCertDir = TM_HOME + '/trustedcerts'
    
    if not os.path.exists(trustedCertDir):
        os.mkdir(trustedCertDir)
    
    certBlock = False
    for line in output.split('\n'):
        if 'BEGIN CERTIFICATE' in line:
            tempfile = NamedTemporaryFile(dir=trustedCertDir, prefix='certs', 
                suffix='.pem', delete=False)
            certBlock = True
            tempfile.write('%s\n' % line)
        elif 'END CERTIFICATE' in line:
            certBlock = False
            tempfile.write(line + '\n')
        elif certBlock:
            tempfile.write('%s\n' % line)
            
###############################################################
# Update trusted CRD file
###############################################################
"""
Removes the default CA initialization of the truststores
from the CRD TM file.
"""
def updateCRDTMFile(serviceID):
    
    if serviceID in trustNameMapping.keys():
        name = trustNameMapping[serviceID]
    else:
        print "Failed to update %s" % CRD_TM_FILE
        sys.exit(3) 
           
    doc = parse(CRD_TM_FILE)
 
    trustInitNodes = doc.getElementsByTagName('TRUST_STORE_INIT')
    for trustInitNode in trustInitNodes:
        trustStoreName = trustInitNode.getElementsByTagName('TrustStoreName')[0]
        
        # remove the entire element
        if trustStoreName.firstChild.data in name:
            parent = trustInitNode.parentNode
            parent.removeChild(trustInitNode)
        
    # save the changes
    with codecs.open(CRD_TM_FILE, "w","utf-8") as f:
        doc.writexml(f)
    
###############################################################
# Update CRD Input file
###############################################################
"""
Sets up the CRD Input file for importing ID certificates.
"""
def updateCRDInputFile(serviceID, p12file, passphrase):
    
    doc = parse(CRD_INPUT_FILE)
    root = doc.documentElement
    
    idCertNodes = doc.getElementsByTagName('INPUT_IDCert')
    for idCertNode in idCertNodes:
        serviceIDNode = idCertNode.getElementsByTagName('ServiceID')[0]
        if serviceIDNode.firstChild.data == serviceID:
            
            # Remove any SCEP entries
            for tag in ['SCEP_CA_InternalName', 'SCEP_PASSWORD']:
                scepNodes = idCertNode.getElementsByTagName(tag)
                if len(scepNodes) == 1:
                    idCertNode.removeChild(scepNodes[0])
             
            # now add in the p12 import fields
            pkcs12Node = idCertNode.getElementsByTagName('PKCS12FileName')
            if len(pkcs12Node) != 0:
                # already exists so just update it
                pkcs12Node[0].replaceChild(doc.createTextNode(p12file), 
                    pkcs12Node[0].firstChild)
            else:
                # doesn't exist so add it
                pkcs12Node = doc.createElement('PKCS12FileName')
                pkcs12Node.appendChild(doc.createTextNode(p12file))
                idCertNode.appendChild(pkcs12Node)
            
            passwordNode = idCertNode.getElementsByTagName('Password')
            if len(passwordNode) != 0:
                # already exists
                passwordNode[0].replaceChild(doc.createTextNode(passphrase),
                    passwordNode[0].firstChild)
            else:
                # add it
                passwordNode = doc.createElement('Password');
                passwordNode.appendChild(doc.createTextNode(passphrase));
                idCertNode.appendChild(passwordNode)
     
    # Remove the INPUT_TRUSTSTORE entry because this will cause TMClient to
    # try and reach out to SMGR. If SMGR's CA is disabled then this will
    # cause initTM to fail.
    inputTrustList = doc.getElementsByTagName('INPUT_TRUSTSTORE')
    if inputTrustList != None and inputTrustList.length != 0:
        root.removeChild(inputTrustList.item(0))
        
    # save the changes
    with codecs.open(CRD_INPUT_FILE, "w","utf-8") as f:
        doc.writexml(f)
    
###############################################################
# Create CRD files
###############################################################
"""
Create third party CRD files.
"""
def createCRDFiles():
    ORIG_CRD_INPUT_FILE = TM_HOME + '/CRDJEE_Input.xml'
    shutil.copy(ORIG_CRD_INPUT_FILE, CRD_INPUT_FILE)
    setPermissions(CRD_INPUT_FILE)
    
    ORIG_CRD_TM_FILE = TM_HOME + '/CRDJEETM.xml'
    shutil.copy(ORIG_CRD_TM_FILE, CRD_TM_FILE)
    setPermissions(CRD_TM_FILE)

###############################################################
# Import ID certificate
###############################################################
def importIDCert(serviceIDs, filename, passphrase):
    """
    Stage a p12 ID certificate for import into trust management.

    The p12 file is copied into TM_HOME, and the passphrase is prompted for
    if not supplied. The CA certificates are extracted once, which also
    verifies the passphrase. The CRD TM and Input files are then updated for
    each service.

    Arguments:
        serviceIDs -  list of internal service IDs ('all' is not supported)
        filename -    full path to the p12 file
        passphrase -  p12 passphrase, or None to prompt for it

    Exits with code 6 if the passphrase prompt is aborted. The helpers
    called here exit with 3 or 4 on failure.
    """
    p12file = os.path.basename(filename)
    copyP12File(filename)

    while passphrase is None:
        try:
            entered = getpass.getpass("Passphrase: ")
        except (KeyboardInterrupt, EOFError):
            print('')
            sys.exit(6)
        if entered:
            passphrase = entered
        else:
            print("A passphrase is required!")

    # Extract the CA certs before touching the CRD files, and only once:
    # - a wrong passphrase then exits without leaving them half-updated;
    # - the same CA certs are not written to trustedcerts once per service.
    extractCACerts(p12file, passphrase)

    for serviceID in serviceIDs:
        print("Importing %s for %s" % (p12file, serviceID))

        if not os.path.exists(CRD_INPUT_FILE) or not os.path.exists(CRD_TM_FILE):
            createCRDFiles()

        updateCRDTMFile(serviceID)
        updateCRDInputFile(serviceID, p12file, passphrase)
        
###############################################################
# Add a trusted CA certificate to the TM truststore(s) of a service
###############################################################
def tmClientImport(serviceID, filename):
    """
    Add a trusted CA certificate (PEM file) to the TM truststore(s) of a
    service.

    If TM has not been initialized yet (no TMClientInv.xml in TM_HOME), the
    file is only copied into TM_HOME/trustedcerts so that initTM loads it.
    Otherwise "TMClient.sh addCertificate" is run for the mapped truststore,
    or for every truststore when serviceID is 'all'.

    Arguments:
        serviceID -  internal service ID (key of trustIdMapping) or 'all'
        filename -   path to the PEM certificate

    Exits with code 7 if adding to a single service's truststore fails.
    With 'all', each failure is reported and the remaining truststores are
    still attempted.
    """
    inventoryFile = os.path.join(TM_HOME, "TMClientInv.xml")
    if not os.path.exists(inventoryFile):
        # Trust has not been established yet, so add it to trustedcerts
        # directory so that it will be automatically loaded when initTM
        # is run.
        trustedCertDir = os.path.join(TM_HOME, "trustedcerts")
        # Without the directory, shutil.copy would create a regular *file*
        # named "trustedcerts" and every later copy would overwrite it.
        if not os.path.exists(trustedCertDir):
            os.mkdir(trustedCertDir)
        shutil.copy(filename, trustedCertDir)
        return

    if serviceID == "all":
        # add the trusted cert to all truststores
        for trustId in trustIdMapping.values():
            try:
                check_output([TMCLIENT, 'addCertificate', trustId, filename])
                print("Certificate added to %s truststore" % trustId)
            except CalledProcessError:
                # report it, but continue on to the other truststores
                print("Failed to add certificate to %s truststore" % trustId)
    else:
        trustId = trustIdMapping[serviceID]
        try:
            check_output([TMCLIENT, 'addCertificate', trustId, filename])
        except CalledProcessError:
            print("Failed to add certificate to %s truststore" % trustId)
            sys.exit(7)
        print("Certificate added to %s truststore" % trustId)

###############################################################
# Import a trusted CA certificate
###############################################################
def importTrustedCA(serviceIDs, filename, assumeYes=False, backup=False):
    '''
    Imports trusted certificates to the specified services.

    Parameters
    ----------
    serviceIDs: list
        list of services to import the trusted cert to (can be set to 'all')
    filename: str
        path to trusted cert in PEM format (its first line must contain
        "-----BEGIN CERTIFICATE-----")
    assumeYes: boolean
        when true the user is not prompted to accept the certificate
    backup: boolean
        when true (restore mode) declining the certificate returns instead
        of exiting

    Exits with 4 if the file is not PEM, with 5 if the user declines (unless
    backup), and with 6 if the prompt is interrupted.
    '''
    with open(filename, 'r') as pemFile:
        firstLine = pemFile.readline()
    if '-----BEGIN CERTIFICATE-----' not in firstLine:
        # NOTE(review): exit code 4 is documented as "Incorrect passphrase";
        # kept for compatibility with existing callers.
        print("Certificate must be in PEM format")
        sys.exit(4)

    if assumeYes:
        # add the trusted cert to all files (silent mode)
        for serviceID in serviceIDs:
            tmClientImport(serviceID, filename)
        return

    # Display the certificate so the user can confirm the import. Same
    # openssl binary as the other openssl calls in this script.
    call(['/usr/bin/openssl', 'x509', '-text', '-certopt', 'no_sigdump',
        '-noout', '-in', filename])

    try:
        response = raw_input('\nTrust this certificate? [no]: ')
        if response.strip().lower() in ('y', 'yes'):
            print('')
            for serviceID in serviceIDs:
                tmClientImport(serviceID, filename)
        else:
            print("Certificate not added")
            if not backup:
                sys.exit(5)
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or stdin closed (e.g. run non-interactively without -y)
        print('')
        sys.exit(6)


###############################################################
# Import syslog certificates from syslog
###############################################################
def importSyslogCerts(service_ids, StorePasswd = 'None'):
    '''
    Imports certificates from the previous syslog TLS configuration.

    Two-way TLS case: sender key, sender certificate and CA certificate all
    exist. The key and certificate are concatenated into the TM syslog
    keystore PEM (only when that file already exists), and the CA
    certificate is imported as trusted.

    One-way TLS case: only the CA certificate exists, and only it is
    imported.

    Restored source files are renamed with a '.old' suffix.

    Parameters
    ----------
    service_ids: list
        service IDs to import the CA certificate into
    StorePasswd: str
        unused; kept for interface compatibility
    '''
    sender_key = '/etc/pki/tls/private/sender-key.pem'
    sender_cert = '/etc/pki/tls/private/sender-cert.pem'
    ca_cert = '/etc/pki/tls/private/ca-cert.pem'
    initTM_id = '/opt/avaya/tm/keystore/syslog_keystore.pem'

    if os.path.exists(sender_key) and os.path.exists(sender_cert) and os.path.exists(ca_cert):
        # Fresh scratch dir, accessible by its owner only, because it
        # receives a copy of the private key.
        if os.path.exists(temp):
            shutil.rmtree(temp)
        os.mkdir(temp, 0o700)

        destPath = os.path.join(temp, 'systore_oldkey.pem')
        print('Restoring syslog third party certificate')
        with open(destPath, 'w') as outfile:
            for fname in (sender_key, sender_cert):
                with open(fname) as infile:
                    content = infile.read()
                outfile.write(content)
                # keep the PEM blocks on separate lines even if a source
                # file lacks a trailing newline
                if content and not content.endswith('\n'):
                    outfile.write('\n')
        if os.path.exists(initTM_id):
            # copyfile (unlike copy) keeps initTM_id's existing permission
            # bits instead of applying the scratch file's mode to the key.
            shutil.copyfile(destPath, initTM_id)
            os.remove(destPath)

        importTrustedCA(service_ids, ca_cert, True, True)

        # assume customer has restored certs
        os.rename(sender_key, sender_key + '.old')
        os.rename(sender_cert, sender_cert + '.old')
        os.rename(ca_cert, ca_cert + '.old')
    elif os.path.exists(ca_cert):
        # one-way TLS
        importTrustedCA(service_ids, ca_cert, True, True)
        os.rename(ca_cert, ca_cert + '.old')
        print('Syslog Trust certificate copied')
    else:
        print("Nothing to restore for syslog")


###############################################################
# Import trusted certs from ldap trusted cert
###############################################################
def importLdapTrust(service_ids):
    '''
    Imports the LDAP CA certificate left by the previous installation.

    If /etc/pki/tls/ca_ldap.pem exists, it is imported without prompting
    (restore mode) into the given services and then renamed to
    ca_ldap.pem.old. Otherwise nothing happens.

    Parameters
    ----------
    service_ids: list
        service IDs to import the certificate into
    '''
    ldap_cert = "/etc/pki/tls/ca_ldap.pem"
    if not os.path.exists(ldap_cert):
        return

    print('Restoring ldap third party certificate')
    importTrustedCA(service_ids, ldap_cert, assumeYes=True, backup=True)
    os.rename(ldap_cert, '%s.old' % ldap_cert)

###############################################################
# Import all trusted certs exported to the temp directory
###############################################################
def importTempSpiritTrust(service_ids, StorePasswd = 'None'):
    '''
    Imports every "*.pem" file in the temp directory as a trusted
    certificate, without prompting (restore mode).

    Does nothing if the temp directory does not exist. Files are processed
    in os.listdir order.

    Parameters
    ----------
    service_ids: list
        service IDs to import the certificates into
    StorePasswd: str
        unused; kept for interface compatibility
    '''
    if not os.path.exists(temp):
        return

    pemNames = [entry for entry in os.listdir(temp) if entry.endswith(".pem")]
    for pemName in pemNames:
        importTrustedCA(service_ids, temp + '/' + pemName, True, True)


###############################################################
# Import previous spirit Id certificate
###############################################################
def importPreSpiritId(service_ids, fips_mode, StorePasswd = 'None'):
    '''
    Imports the spirit identity certificate from the previous store.

    The previous identity keystore is listed with keytool. It is JKS when
    fips_mode is 0, and BCFKS for any other value (including None). Each
    alias found is imported into the TM spirit identity keystore under the
    alias "spiritalias". The source store password is prompted for, with up
    to 3 attempts, unless it is given. On success the source store is
    renamed with a '.old' suffix.

    Parameters
    ----------
    service_ids: list
        unused here; kept for interface compatibility
    fips_mode: int
        0 for a JKS source store, anything else for BCFKS
    StorePasswd: str
        source store password, or the string 'None' to prompt for it

    Exits with 6 if the password prompt is aborted. Exits with 4 if the
    initTM password file cannot be read, or after three failed attempts.
    '''
    if fips_mode == 0:
        fullpath = '/opt/spirit/security/identity.jks'
        storeType = 'JKS'
    else:
        fullpath = '/opt/spirit/security/identity.bcfks'
        storeType = 'BCFKS'

    if not os.path.exists(fullpath):
        print("Nothing to restore for spirit identity store")
        return

    print('Restoring spirit keystore third party certificate')
    if os.path.exists(temp):
        shutil.rmtree(temp)
    os.mkdir(temp)

    providerArgs = ['-provider',
            'org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider',
            '-providerpath', '/opt/util/lib/bc-fips-1.0.1.3.jar']
    destAlias = 'spiritalias'
    destStore = '/opt/avaya/tm/keystore/spirit-identity.bcfks'
    pwdFile = '/etc/opt/passwd'

    # Password of the TM (initTM) keystore: the last non-blank line of
    # pwdFile without its newline. The original code discarded the result of
    # line.strip(), so the trailing newline ended up in the password.
    destPwd = 'avaya123'
    try:
        with open(pwdFile, 'r') as pfile:
            for line in pfile:
                if line.strip():
                    destPwd = line.strip()
    except (IOError, OSError):
        print("Failed to get initTM store password, script is exiting ..")
        sys.exit(4)

    # NOTE(review): store passwords are passed on the keytool command line,
    # where the process list exposes them. keytool's "-storepass:env" style
    # options could avoid that.
    for attempt_count in range(1, 4):
        if StorePasswd == 'None':
            StorePasswd = ''
            while not StorePasswd:
                try:
                    StorePasswd = getpass.getpass("Existing spirit identity store Password: ")
                except (KeyboardInterrupt, EOFError):
                    print("Unable to read password, keyboard interaction failed. Exiting script..")
                    sys.exit(6)
                if not StorePasswd:
                    print("A password is required!")

        try:
            listCmd = ['/usr/bin/keytool', '-list', '-v', '-keystore', fullpath,
                    '-storepass', StorePasswd]
            if fips_mode != 0:
                listCmd += ['-storetype', 'BCFKS'] + providerArgs
            listing = check_output(listCmd)

            for line in listing.splitlines():
                if 'Alias' not in line or ':' not in line:
                    continue
                # split once so that aliases containing ':' stay intact
                alias_name = line.split(':', 1)[1].strip()
                # NOTE(review): every source alias is imported under the
                # same destAlias. This presumably expects a single identity
                # in the source store; confirm.
                output = check_output(['/usr/bin/keytool', '-v', '-importkeystore',
                        '-srckeystore', fullpath, '-srcstorepass', StorePasswd,
                        '-srcalias', alias_name, '-destkeystore', destStore,
                        '-destalias', destAlias, '-srcstoretype', storeType,
                        '-deststorepass', destPwd, '-deststoretype', 'BCFKS',
                        '-destkeypass', destPwd] + providerArgs + ['-noprompt'])
                print(output)
        except CalledProcessError:
            print("Failed to extract certs attempt %s of 3, check passphrase." % attempt_count)
            StorePasswd = 'None'
            continue

        os.rename(fullpath, fullpath + '.old')
        return

    print("Failed to extract certs from %s, all attempts failed." % fullpath)
    print("Exiting script, please perform certificate restore manually using SMGR.")
    sys.exit(4)


 
###############################################################
# Import previous third party spirit agent trust certificate
###############################################################
def importPreCert(service_ids, fips_mode, StorePasswd = 'None'):
    '''
    Imports the third party spirit agent trust certificates from the
    previous truststore.

    The previous truststore is JKS when fips_mode is 0, and BCFKS for any
    other value (including None). Every alias in it is exported as a PEM
    file into the temp directory. The files are then imported through
    importTempSpiritTrust. The source store password is prompted for, with
    up to 3 attempts, unless it is given. On success the source store is
    renamed with a '.old' suffix.

    Parameters
    ----------
    service_ids: list
        service IDs to import the certificates into
    fips_mode: int
        0 for a JKS source store, anything else for BCFKS
    StorePasswd: str
        source store password, or the string 'None' to prompt for it

    Exits with 6 if the password prompt is aborted, and with 4 after three
    failed attempts.
    '''
    if fips_mode == 0:
        fullpath = '/opt/spirit/security/trust.jks'
    else:
        fullpath = '/opt/spirit/security/trust.bcfks'

    if not os.path.exists(fullpath):
        print("Nothing to restore for spirit trust store")
        return

    print('Restoring spirit truststore third party certificate')
    if os.path.exists(temp):
        shutil.rmtree(temp)
    os.mkdir(temp)

    # extra keytool arguments needed for a BCFKS store
    if fips_mode == 0:
        storeArgs = []
    else:
        storeArgs = ['-storetype', 'BCFKS', '-provider',
                'org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider',
                '-providerpath', '/opt/util/lib/bc-fips-1.0.1.3.jar']

    for attempt_count in range(1, 4):
        if StorePasswd == 'None':
            StorePasswd = ''
            while not StorePasswd:
                try:
                    StorePasswd = getpass.getpass("Existing spirit trust store Password: ")
                except (KeyboardInterrupt, EOFError):
                    print("Unable to read password")
                    sys.exit(6)
                if not StorePasswd:
                    print("A password is required!")

        try:
            listing = check_output(['/usr/bin/keytool', '-list', '-v', '-keystore',
                    fullpath, '-storepass', StorePasswd] + storeArgs)
            for line in listing.splitlines():
                if 'Alias' not in line or ':' not in line:
                    continue
                # split once so that aliases containing ':' stay intact
                alias_name = line.split(':', 1)[1].strip()
                # a path separator inside an alias would point outside temp
                pemName = alias_name.replace(os.sep, '_') + ".pem"
                check_output(['/usr/bin/keytool', '-export', '-alias', alias_name,
                        '-file', os.path.join(temp, pemName), '-keystore', fullpath,
                        '-storepass', StorePasswd, '-rfc'] + storeArgs)
            print("certificate succesfully copied")

            # import third party certs from tmp
            importTempSpiritTrust(service_ids)
            os.rename(fullpath, fullpath + '.old')
            return
        except CalledProcessError:
            print("Failed to extract certs attempt %s of 3, check passphrase." % attempt_count)
            StorePasswd = 'None'

    print("Failed to extract certs from %s, all attempts failed." % fullpath)
    print("Exiting script, please perform certificate restore manually using SMGR.")
    sys.exit(4)

   
 
###############################################################
# Main
###############################################################

parser = argparse.ArgumentParser()
# -i: certificate file (p12 for ID certs, PEM with -t)
parser.add_argument('-i', '--import', dest='filename', type=str,
                    action="store", required=False,
                    help='pkcs12 file to import')
# -r: restore third party certificates from the previous stores
parser.add_argument('-r', '--restore', dest='pre', action="store_true",
                    default=False,
                    help='cert from previous store to import')
# -t: the file is a trusted CA certificate, not an ID certificate
parser.add_argument('-t', '--trustedca', dest='trustedca',
                    action="store_true", default=False,
                    help='import a trusted certificate')
# -y: do not ask for confirmation of trusted certificates
parser.add_argument('-y', dest='assumeYes', action="store_true",
                    default=False,
                    help='Assume yes when importing trusted certificates')
# -s: one or more services; required by argparse even with -r
parser.add_argument('-s', '--service', dest='id', type=str, nargs="+",
                    action="store", required=True,
                    help='service ID of certificate (spirit, syslog, ldap all)')
# -p: p12 passphrase (prompted for when omitted)
parser.add_argument('-p', '--passphrase', dest='passphrase', type=str,
                    action="store", required=False,
                    help="PKCS #12 passphrase")
# -f: 0 selects JKS stores during restore; any other value selects BCFKS
parser.add_argument('-f', '--fips', dest='fips_mode', type=int,
                    action="store", required=False,
                    help="FIPS mode")
args = parser.parse_args()

service_ids = []

# check if certificate is present in previous store.
if args.pre:
    fips_mode = args.fips_mode
    # The restore needs the initTM store password file; exit quietly
    # with the usage code when it is missing.
    if not os.path.exists("/etc/opt/passwd"):
        sys.exit(1)

    # syslog: previous third party key/certificate and CA certificate
    service_ids.append('syslog')
    importSyslogCerts(service_ids)
    del service_ids[:]

    # spirit: previous truststore first, then the identity store
    service_ids.append(serviceIdMapping['spirit'])
    importPreCert(service_ids, fips_mode)
    importPreSpiritId(service_ids, fips_mode)
    del service_ids[:]

    # LDAP: previous trusted CA certificate
    service_ids.append('LDAP')
    importLdapTrust(service_ids)

    # clean up
    if os.path.exists("/etc/opt/passwd"):
        os.remove("/etc/opt/passwd")
    if os.path.isdir(temp):
        shutil.rmtree(temp)

    # write the "initialized" marker (presumably checked by whoever runs
    # the restore; confirm)
    with open('/etc/opt/tm_initialize.sh', 'w') as out_file:
        out_file.write('initialized')

    sys.exit(0)


# Map the service names given with -s to internal service IDs so the CLI is
# more user friendly. Matching is case-insensitive and accepts either a short
# name (e.g. "spirit") or the service ID itself (e.g. "spiritalias"). The
# original code lower-cased the input but compared it case-sensitively, which
# made "LDAP" impossible to select.
serviceIdLookup = {}
for shortName, mappedId in serviceIdMapping.items():
    serviceIdLookup[shortName.lower()] = mappedId
    serviceIdLookup[mappedId.lower()] = mappedId

for requestedId in args.id:
    mappedId = serviceIdLookup.get(requestedId.lower())
    if mappedId is None:
        print("Invalid service : %s" % requestedId)
        sys.exit(1)
    service_ids.append(mappedId)

# -i is optional for argparse because restore mode (-r) does not use it,
# but every other operation needs it.
if args.filename is None:
    print("A certificate file must be given with -i/--import")
    sys.exit(1)

if not os.path.exists(args.filename):
    print("%s does not exist!" % args.filename)
    sys.exit(2)

if not args.trustedca:
    # check every requested service, not just the last one
    if "all" in service_ids:
        print("The 'all' option cannot be used when importing ID certificates")
        sys.exit(1)

    importIDCert(service_ids, args.filename, args.passphrase)
else:
    if args.passphrase is not None:
        print("The passphrase option doesn't apply to importing trusted certs")
        sys.exit(1)

    importTrustedCA(service_ids, args.filename, args.assumeYes)
