#######################################################################
# Copyright (C) 2006 VMWare, Inc.
# All Rights Reserved
########################################################################
#
# PatchDB Class definition
#
# Testing:
#   testinstall.py

import bsddb
import cPickle
import time
import os, os.path
import fnmatch
import logging
from vmware.descriptor import Descriptor
from InstallSession import HASH_INTF_VERSION
import insthelper
import errors

#
# Constants/globals
#
log = logging.getLogger('db')

DEF_SESSIONDB    = 'patch.db'
SESSIONKEY       = '/current'
DB_MODE_FLAGS    = 0644

DEF_DESC_DIR     = os.path.join(os.path.dirname(Descriptor.__file__),
                                'defaults')

########################################################################
#
# PatchDB Class
# Manage a database of current and previous install state.
#
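# Multiple instances of the same database can exist as long as locking
# is used.  NOTE: locking is not implemented yet (see the TODO in
# __init__), so callers must currently serialize access themselves.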
#
# Database keys:
#   /current                  - reserved for the current install session
#   /sess/<release>
#   /sess/none                - if no release string or descriptor
#
# Use a BTree Berkeley DB.  Assuming our release strings follow a
# monotonically increasing numbering scheme (e.g. a build number)
# and the database is not huge, the BTree's sorted key order will help
# locate ranges of records.  Keys are ordered as byte strings, not
# numerically.
#
# Python 2.2's built in bsddb module only supports the BDB 1.85 API;
# however, the database itself is current and can be accessed by newer
# APIs as well as the Berkeley DB utils.
#
class PatchDB:
    """ A class for accessing the patch state/history database.
    Includes locking, saving and restoring of complete Sessions,
    management of patch history/state data.
    """
    def __init__(self, dbdir, dbfile=DEF_SESSIONDB):
        """ Opens the patch database """
        if not os.path.isdir(dbdir):
            os.mkdir(dbdir)
        self.dbdir = dbdir
        self.dbfile = os.path.join(dbdir, dbfile)
        #
        # TODO: Use "cl" for locking support when we use a Berkeley DB
        # 3 driver:  bsddb3 or bsddb in Python >= 2.3
        #
        self.bt = bsddb.btopen(self.dbfile, "c", DB_MODE_FLAGS)
        log.debug('dbfile = [%s], %d keys' % (self.dbfile, len(self.bt)))
        #
        # Import default descriptors if database is empty
        dbkeys = self.bt.keys()
        if not len(dbkeys) or (len(dbkeys) == 1 and dbkeys[0] == '/current'):
            self.__importdefaultdescriptors()

    def __importdefaultdescriptors(self):
        '''Import default descriptors into the patch database from
           DEF_DESC_DIR.  Returns nothing, and only logs warnings on errors.'''
        if not os.path.isdir(DEF_DESC_DIR):
            return
        try:
            descfns = [os.path.join(DEF_DESC_DIR, fn)
                       for fn in os.listdir(DEF_DESC_DIR)
                       if fn.lower().endswith('.xml')]
            descfns = [fn for fn in descfns if os.path.isfile(fn)]
        except Exception, e:
            log.warn('Error reading directory %s: %s' % (DEF_DESC_DIR, e))
            return
        for fn in descfns:
            try:
                self.AddIsoEntry(Descriptor.Descriptor(fn))
            except Exception, e:
                log.warn('Error importing descriptor %s: %s' % (fn, e))

    def __releasefromkey(self, key):
        """ Returns the patch release string from a database key, or None
        if the key is special (eg saved sessions)
        """
        if key.startswith('/sess/'):
            return key[6:]
        else:
            return None

    def __keyfromrelease(self, release):
        """ Returns a key string from patch release string """
        return '/sess/%s' % (release)

    ########################################################################
    # Lower-level public database methods
    ########################################################################

    def AddDescriptor(self, desc):
        """ Saves the descriptor object to the patch DB """
        fname = os.path.join(self.dbdir, '%s.xml' % (desc.GetRelease()))
        log.debug('Saving descriptor to file %s' % (fname))
        desc.Write(fname)

    def GetDescriptor(self, release):
        """ Returns a descriptor object of the corresponding release, or
        KeyError if such release is not found
        """
        fname = os.path.join(self.dbdir, '%s.xml' % (release))
        log.debug('Trying to load descriptor from %s' % (fname))
        if not os.path.isfile(fname):
            raise KeyError, "Descriptor of release '%s' not found" % (release)
        desc = Descriptor.Descriptor(fname)
        return desc

    def AddHashEntry(self, insthash, release='none', key=None):
        """ Add an install hash to the patch DB """
        if not key:
            key = self.__keyfromrelease(release)
        log.debug('Adding entry with key [%s]' % (key))
        self.bt[key] = cPickle.dumps(insthash)

    def HasEntry(self, release, key=None):
        """ Checks on the presence of a key in the database matching
        key if supplied, or /sess/<release> if key is not supplied.
        Returns True or False.
        """
        if not key:
            key = self.__keyfromrelease(release)
        return self.bt.has_key(key)
    
    def GetEntry(self, release, key=None):
        """ Returns the hash corresponding to key if supplied, or to
        /sess/<release> if key is not supplied.
        None is returned if the key is not found.
        """
        if not key:
            key = self.__keyfromrelease(release)
        try:
            return cPickle.loads(self.bt[key])
        except KeyError:
            return None
            
    def Sync(self):
        """ Write current changes in database to disk """
        self.bt.sync()

    def Close(self):
        """ Close database.  After this the db object is useless. """
        self.bt.sync()
        self.bt.close()


    ########################################################################
    # Higher-level public database methods
    ########################################################################

    def SaveSession(self, sessionHash):
        """ Saves the session hash to the /current key in DB """
        self.AddHashEntry(sessionHash, key=SESSIONKEY)

    def RestoreSession(self):
        """ Returns the last session hash saved, or None """
        sesshash = self.GetEntry('', key=SESSIONKEY)
        return sesshash

    def _addBundleHash(self, bhash, desc, release):
        """ Common method for adding the session hash with given
        descriptor and release to the DB
        """
        #
        # Initialize common fields
        bhash['systemStates'] = []
        
        #
        # Obsolete other entries if needed
        if desc:
            paths = desc.GetUpgradePaths()
            if paths:
                self.MarkPatchesObsolete(paths.GetObsoletedIDs())

        self.AddHashEntry(bhash, release)

    def AddBundleEntry(self, bundleHash):
        """ Adds a bundleHash to the PatchDB.  Important members
        of a bundleHash:
           desc:    Descriptor() instance
        """
        if 'desc' in bundleHash and bundleHash['desc']:
            desc = bundleHash['desc']
            bundleID = desc.GetRelease()
            bundleHash['desclist'] = [ bundleID ]
            self.AddDescriptor(desc)
            del bundleHash['desc']
        else:
            bundleID = 'none'

        # imported bundles should have a unique timestamp. 
        bundleHash['lasttime'] = time.time()
        self._addBundleHash(bundleHash, desc, bundleID)

    def AddIsoEntry(self, desc):
        """ Add an install record based on a descriptor.xml shipped on the
        ISO.  Fill in time, last state, and rpms installed.
        Side effects:  This will open the RPMDB and verify the rpmlist.
        """
        h = {}
        h['__version__'] = HASH_INTF_VERSION
        #
        # Make sure this record cannot be serialized into InstallSession object
        h['restorable']  = 0
        h['obsolete']    = 0
        h['lastState']   = 'SuccessState'
        h['options'] = {'url': '<ISO install>'}
        #
        # Save descriptor too
        self.AddDescriptor(desc)
        h['desclist'] = [ desc.GetRelease() ]
        #
        # The install time is just the time this record was added
        h['firsttime'] = h['lasttime'] = time.time()
        #
        # Which of the rpms are actually installed?
        h['installed'] = {}
        if desc.GetFullRpmlist():
            installed, matches = insthelper.GetDbInstalledHeaders(desc.GetFullRpmlist())
            for na in matches.keys():
                h['installed'][na] = 1

        #
        # Append things to the 'removed' hash.  Note that here the 'removed'
        # are names only.  When we process removes during an actual bundle
        # installation, we know what the arch was, but because here we are
        # trying to guess after the fact, we don't have an easy way to know.
        # This works as far as query -l output is concerned, because the
        # Rpmlist.Prune() method also prunes by name as well as name.arch, and
        # -l info doesn't show the arch in any case.
        h['removed'] = dict()
        if desc.GetRemoves():
            installed = insthelper.GetRpmlistFromRpmdb()
            for r in desc.GetRemoves().GetPkgNames():
                if not installed.GetPkgByName(r):
                    h['removed'][r] = 1

        self._addBundleHash(h, desc, desc.GetRelease())

    def GetPatchInfo(self, release):
        """ Returns the hash for patch release with 'desc' key pointing
        to the Descriptor object for that patch.
        Returns None if no such release exists in PatchDB.
        """
        h = self.GetEntry(release)
        if h:
            h['desc'] = None
            if h.has_key('desclist'):
                h['desc'] = self.GetDescriptor(h['desclist'][0])
        return h

    def IsObsolete(self, release):
        """ Returns True if the given release has been deemed to be
        obsolete by the other installed releases.
        Returns False if the release is not found.
        """
        h = self.GetEntry(release)
        if h:
            return h.get('obsolete', False)
        else:
            return False

    def GetBundleList(self):
        """ Returns a list of all bundles in the PatchDB regardless
        of obsolescence. """
        bundles = []
        for key in self.bt.keys():
            release = self.__releasefromkey(key)
            if release:
                bundles.append(release)
        return bundles

    def GetInstalledPatches(self, getObsolete=False, onlyObsolete=False):
        """ Returns a hash of the currently installed patches, with the
        keys being the release strings, and the items being install hashes
        as returned by GetPatchInfo().
        Returns an empty hash if there are no installed patches.
        If getObsolete is True, then obsolete bundle hashes are returned.
        """
        installed = {}
        #
        # Get all non-obsolete session entries
        #
        for rel in self.GetBundleList():
            patchHash = self.GetPatchInfo(rel)
            if getObsolete or not patchHash.get('obsolete', False):
                if not onlyObsolete or patchHash.get('obsolete', False):
                    installed[rel] = patchHash
        return installed

    def LasttimeIndex(self, installed):
        """ Returns a list of keys in installed, ordered by
        ascending install finish date/time.
        installed    - a hash as returned by GetInstalledPathes()
        """
        def sortByLasttime(a, b):
            return cmp(installed[a]['lasttime'], installed[b]['lasttime'])
        
        index = installed.keys()
        try:
            index.sort(sortByLasttime)
        except Exception, e:
            log.warn('Difficulty sorting: %s' % (e,))
        return index

    def MarkPatchesObsolete(self, IDList):
        """ Given a list of patch IDs, marks their entries as
        obsolete and add on to the obsolete patches list.
        fnmatch() wildcards are allowed in IDList.
        """
        installed = self.GetBundleList()
        for ID in IDList:
            matches = fnmatch.filter(installed, ID)
            if not matches:
                log.debug('MarkPatchesObsolete: [%s] not found in PatchDB' % (ID))
                continue
            for match in matches:
                h = self.GetEntry(match)
                h['obsolete'] = 1
                self.AddHashEntry(h, match)

    def AddSystemState(self, state):
        """ Adds one system state to all the bundles in the DB,
        including obsoleted ones.
        """
        #
        # Loop through each hash and update the systemStates.
        # Even obsoleted bundle records are updated.
        # Only write back to DB if we changed the record.
        #
        for rel in self.GetBundleList():
            h = self.GetPatchInfo(rel)
            if h is None:
                log.warning('Error accessing %s from patch DB' % (rel))
                continue
            if state not in h.setdefault('systemStates', []):
                h['systemStates'].append(state)
                self.AddHashEntry(h, rel)
