#######################################################################
# Copyright (C) 2006,2007 VMWare, Inc.
# All Rights Reserved
########################################################################
#
import logging
import md5, sha
import sys

from GPG import GPG, Verify

# Module logger; Signature.Check() sends GPG's stderr to it at level 1.
log = logging.getLogger( 'Signature' )

#----------------------------------------------------------------------------
# Module-specific Exceptions
#----------------------------------------------------------------------------
class SigException(Exception):
	""" Base class for the errors this module raises. """

class UnsupportedAlgorithm(SigException):
	""" Raised when a signature names a digest algorithm that has
	no entry in digestAlgs; the message is the algorithm name. """


#----------------------------------------------------------------------------
# Private helper functions
#----------------------------------------------------------------------------
def _MakeMD5( fullText ):
	return md5.new( fullText ).hexdigest()
def _MakeSHA1( fullText ):
	return sha.new( fullText ).hexdigest()
def _MakeMD5SHA1( fullText ):
	return md5.new( fullText ).hexdigest() + sha.new( fullText ).hexdigest()

def _MakePlaintextString( digAlg, sigAlg, digest ):
	# there shouldn't be a space in any of the input strings
	if ' ' in digAlg + sigAlg + digest:
		raise SigException('space was in alg name or digest data')

	plaintext = ' '.join( [digAlg, sigAlg, digest] )
	return plaintext
	

#----------------------------------------------------------------------------
# Module-level variables
#----------------------------------------------------------------------------
# Lower-cased digest algorithm name (as Signature stores it) -> helper
# that returns the hex digest of a complete payload string.
digestAlgs = dict(
	md5=_MakeMD5,
	sha1=_MakeSHA1,
	md5sha1=_MakeMD5SHA1,
)

#-----------------------------------------------------------------------------
class KeyState:
	""" State of the key behind a signature, as derived by
	VerifyResult.ParseGPGVerify.

	When more than one condition applies, the precedence is
	MISSING, then REVOKED, then EXPIRED, so a key that is both
	expired and revoked is reported as REVOKED.
	"""
	UNKNOWN = 0     # default; state not (yet) determined
	MISSING = 1     # no public key found
	EXPIRED = 2     # signature or key is expired
	REVOKED = 3     # key has been revoked
	OK      = 99    # signature valid and key usable
	
#-----------------------------------------------------------------------------
class Trust:
	""" One-character trust codes used by GPG (listed in GnuPG's
	doc/DETAILS).  UNKNOWN is also the value a fresh VerifyResult
	starts with before any GPG output is parsed. """
	UNKNOWN   = '-'
	UNDEFINED = 'q'
	NEVER     = 'n'
	MARGINAL  = 'm'
	FULLY     = 'f'
	ULTIMATE  = 'u'
	
#-----------------------------------------------------------------------------
class VerifyResult:
	""" Container for holding signature verification results.
	errmsg      - one-liner error summary
	stderr      - GPG stderr output. Multi-line.
	valid       - Bool, True if GPG sig and digest are good
	gpgSigValid - GPG sig valid; independent of key state & trust
	              If the above is true, then these are defined:

	keystate    - one of the KeyState values above
	trust_code  - one of the Trust values above
	shortKeyID, longKeyID, fingerprint
	shortKeyID/longKeyID are also filled when keystate is MISSING,
	as long as GPG reported which key made the signature.
	"""
	def __init__(self):
		self.errmsg = "Unknown error"
		self.stderr = ""
		self.valid  = False
		self.gpgSigValid = False
		self.keystate = KeyState.UNKNOWN
		self.trust_code = Trust.UNKNOWN
		self.shortKeyID = ""
		self.longKeyID = ""
		self.fingerprint = None

	def ParseGPGVerify(self, verification):
		""" Fill in this result from a GPG Verify instance.

		Copies stderr, key IDs, fingerprint, trust code and the raw
		GPG signature validity, then derives keystate and errmsg.
		Conditions are tested in priority order: missing key, revoked,
		expired, valid, corrupt.  self.valid is not touched here.
		"""
		assert isinstance(verification, Verify)

		self.stderr = verification.stderr
		if verification.key_id:
			self.longKeyID = verification.key_id
			# The short ID is the last 8 hex digits of the long ID.
			self.shortKeyID = self.longKeyID[-8:]
		self.fingerprint = verification.fingerprint
		self.trust_code = verification.trust_code
		self.gpgSigValid = verification.valid
		#
		# Now, determine the key state
		#
		if verification.nokey:
			self.keystate = KeyState.MISSING
			self.errmsg = "keyMissing: " + self.shortKeyID
		elif verification.revoked:
			self.keystate = KeyState.REVOKED
			self.errmsg = "keyRevoked: " + self.shortKeyID
		elif verification.expired:
			self.keystate = KeyState.EXPIRED
			self.errmsg = "keyExpired: " + self.shortKeyID
		elif verification.valid:
			self.keystate = KeyState.OK
			self.errmsg = "keyValid: " + self.shortKeyID
		elif verification.corrupt:
			self.errmsg = "corruptSignature:"
		else:
			# errmsg is a one-liner, but GPG's stderr is multi-line:
			# collapse whitespace runs so the summary stays on one line.
			# The complete text remains available in self.stderr.
			self.errmsg = "validationError:" + \
				' '.join( (self.stderr or '').split() )
			
		
#-----------------------------------------------------------------------------
class Signature:
	""" One detached signature over a payload digest.

	The GPG-signed plaintext is "<digestAlg> <sigAlg> <digest>".
	Check() recomputes the payload digest locally, verifies the GPG
	signature, and confirms that the signed plaintext matches.
	"""
	def __init__(self, sig, sigAlg, digest, digestAlg ):
		# the preceding _ means "please don't modify"
		self._sig = sig
		self._digest = digest
		# Algorithm names are lower-cased so lookups are case-insensitive.
		self._sigAlgName = sigAlg.lower()
		self._digestAlgName = digestAlg.lower()

	def GetSig(self):
		return self._sig
	def GetDigest(self):
		return self._digest
	def GetSigAlgName(self):
		return self._sigAlgName
	def GetDigestAlgName(self):
		return self._digestAlgName

	def _computeFileDigest(self, digFunc, fileLocation, fileHandle):
		""" Apply digFunc to the full payload contents, read from
		fileLocation if given, otherwise from fileHandle. """
		if fileLocation:
			# Binary mode: text mode may translate line endings and
			# corrupt the digest.  Close the file even if read() fails.
			fp = open( fileLocation, 'rb' )
			try:
				return digFunc( fp.read() )
			finally:
				fp.close()
		return digFunc( fileHandle.read() )

	def Check(self, fileLocation=None, fileHandle=None,
		  keyringDir=None):
		""" Checks the digest, GPG signature.
		Returns True if the GPG signature is valid and the digest
		of the file matches that stored in the signature.
		Will return true even if the key is expired or revoked;
		check results.keystate to verify the key status.

		Details are left in self.result (a VerifyResult).
		Raises TypeError if neither fileLocation nor fileHandle is
		given, UnsupportedAlgorithm for an unknown digest algorithm,
		and SigException if the signed plaintext cannot be built.
		"""
		if not fileLocation and not fileHandle:
			raise TypeError('Check() needs a file loc or handle')
		self.result = VerifyResult()

		# First, make sure we know how to compute the requested digest.
		try:
			digFunc = digestAlgs[ self._digestAlgName ]
		except KeyError:
			self.result.errmsg = "unsupportedDigestAlgo: " + \
			                     self._digestAlgName
			raise UnsupportedAlgorithm(self._digestAlgName)

		fileDigest = self._computeFileDigest( digFunc, fileLocation,
		                                      fileHandle )
		if fileDigest != self._digest:
			self.result.errmsg = 'digestMismatch: ' + self._digestAlgName
			return False

		# Check the signature itself.
		gpg = GPG(gnupghome=keyringDir)
		verification = gpg.verify( self._sig )
		log.log(1, 'Verification stderr : %s', verification.stderr)

		self.result.ParseGPGVerify(verification)
		if not self.result.gpgSigValid:
			return False

		# Check the expected sig data matches what was in the gpg sig.
		try:
			plaintext = _MakePlaintextString( self._digestAlgName,
			                                  self._sigAlgName,
			                                  self._digest )
		except SigException:
			# Otherwise errmsg would still hold the key-state message
			# from ParseGPGVerify (e.g. "keyValid: ...") on a failure.
			exc = sys.exc_info()[1]
			self.result.errmsg = exc.__class__.__name__ + ': ' + str(exc)
			raise

		if self._sigAlgName == 'gpg':
			# WORKAROUND
			# GPG appends a newline for some reason.  Either that
			# or the GPG.py library doesn't import the data
			# properly
			plaintext += '\n'

		if plaintext != verification.data:
			msg = ('dataMismatch: expected [%s] != actual [%s]' %
			       (plaintext, verification.data) )
			self.result.errmsg = msg
			return False

		self.result.valid = True
		return True


#-----------------------------------------------------------------------------
class MultiSignature:
	""" Class for handling the signing and verification of multiple
	detached signatures associated with a payload.  This is a base
	abstract class;  to make it useful, inherit and override the
	_preCheckPolicy and _postCheckPolicy methods.
	"""

	def __init__(self, keyringDir=None):
		self.sigs = {}                  # sigID -> Signature
		self.errmsg = ""                # summary of the last Check()
		self.keyringDir = keyringDir    # handed to each Signature.Check

	def AddSignature(self, sigID, *argv, **kwargs):
		""" Register a new Signature under sigID.  Remaining arguments
		are forwarded to the Signature constructor.
		Raises KeyError if sigID is already registered. """
		if sigID in self.sigs:
			# 1-tuple so a tuple sigID is formatted as a single value
			# rather than being taken as the format-argument list.
			raise KeyError('Signature ID=%s already exists' %
				       (sigID,))
		self.sigs[sigID] = Signature(*argv, **kwargs)

	def _preCheckPolicy(self):
		""" Any pre-verification checks go here, such as for the
		number of signatures. """
		pass

	def _postCheckPolicy(self):
		""" Check the results of each signature verification,
		evaluate the error message, key state, key ID, etc.,
		and return a bool. """
		pass

	def Check(self, fileLocation=None, fileHandle=None):
		""" Verify every registered signature against the payload.

		Returns False if _preCheckPolicy() returns a false value;
		otherwise runs each signature's Check(), sets self.errmsg to one
		"Signature <id>: <errmsg>" line per signature, and returns
		whatever _postCheckPolicy() returns.
		"""
		if not self._preCheckPolicy():
			return False

		self.errmsg = ""
		lines = []
		for (sigID, sig) in self.sigs.items():
			try:
				sig.Check( fileLocation, fileHandle,
					   self.keyringDir )
			except SigException:
				# sys.exc_info() instead of "except X, e" so the
				# handler parses on every Python version.
				exc = sys.exc_info()[1]
				# NOTE(review): VerifyResult() starts with a non-empty
				# errmsg, so this fallback only fires if a Signature
				# clears it -- confirm intent.
				if not sig.result.errmsg:
					msg = exc.__class__.__name__ + ': ' + str(exc)
					sig.result.errmsg = msg

			lines.append("Signature %s: %s" % (sigID, sig.result.errmsg))

		self.errmsg = "\n".join(lines).rstrip()
		return self._postCheckPolicy()
	

#-----------------------------------------------------------------------------
class OneGoodSignature(MultiSignature):
	""" Example MultiSignature policy: passes as soon as one signature
	verified with a key in KeyState.OK (i.e. not missing, expired or
	revoked).  Trust status is not considered. """

	def _preCheckPolicy(self):
		if self.sigs:
			return True
		self.errmsg = "Signature -: notEnoughSignatures:"
		return False

	def _postCheckPolicy(self):
		for candidate in self.sigs.values():
			outcome = candidate.result
			if outcome.valid and outcome.keystate == KeyState.OK:
				return True
		return False

#-----------------------------------------------------------------------------
class TwoGoodSigsOneExpiredNoRevoked(MultiSignature):
	""" Policy requiring at least two signatures.  Every signature must
	verify and none may use a revoked key; at least two must have an
	OK or EXPIRED key, and at most one of the keys may be expired.
	Trust status is not considered. """

	def _preCheckPolicy(self):
		enough = len(self.sigs) >= 2
		if not enough:
			self.errmsg = "Signature: notEnoughSignatures. need 2."
		return enough

	def _postCheckPolicy(self):
		numOK = numExpired = 0
		for candidate in self.sigs.values():
			outcome = candidate.result
			# Any revoked key or failed verification sinks the policy.
			if outcome.keystate == KeyState.REVOKED or not outcome.valid:
				return False
			if outcome.keystate == KeyState.OK:
				numOK += 1
			elif outcome.keystate == KeyState.EXPIRED:
				numExpired += 1
		return numOK + numExpired >= 2 and numExpired <= 1
