#######################################################################
# Copyright (C) 2005 VMWare, Inc.
# All Rights Reserved
########################################################################
#
# InstallSession.py  - class module
#
# Unit test: testinstall.py
#

import ConfigParser
import logging
import errors
from vmware.descriptor import Rpmlist
from vmware.descriptor.UpgradePaths import UpgradePaths
import insthelper
import ha
from ProgressBar import ProgressBar
import Depot

#
# Constants/globals
#
# Module logger used by InstallSession for progress and debug output.
log = logging.getLogger('session')

# Session hash versioning:
#   HASH_INTF_VERSION - version stamped into hashes produced by this module
#   HASH_MIN_VERSION  - oldest hash version InstallSession.FromHash() accepts
HASH_MIN_VERSION  = '1.0'
HASH_INTF_VERSION = '2.0'

########################################################################
#
# Installation session management
#

class InstallSession:
    """ Executes an install session and accounts for all state --
    bundles, package state, yum config.  This state is maintained as the
    installation proceeds through different InstallState's and can also
    be serialized to and from the PatchDB.

    One of 'depoturl' or 'bundleurl' must be there:
    depoturl         - create a regular depot at depoturl
    bundleIDs        - list of bundles to install
    bundleurl        - install only one bundle at bundleurl

    options:
    excludes         - rpm package names to exclude from action
    force            - install older packages over existing ones
    enableRollback   - bool, enable this install to be rolled back
    nopathcheck      - don't check for dependencies
    baseurls         - secondary (base) repo. Internal/QA only.
    flushcache       - flush the depot cache
    localCacheRoot   - specify an alternative local caching dir
    keyringDir       - passed through to the depot as keyringDir
    hostagent        - if set, report progress through ha.HAProgressCB
    rpmroot          - passed to insthelper.GetRpmlistFromRpmdb() when
                       checking which 'remove' packages are gone
    """

    def __init__(self, db, depot=None, options=None, dbhash=None):
        """ Create a new install session and initialize state.
        db:      record store; SaveForRestore() and SaveInstallRecord()
                 call its SaveSession(), AddBundleEntry() and Sync().
        depot:   specify a depot to use.  If None, create one based
                 on the 'depoturl'/'bundleurl' keys in options (after
                 any options restored from dbhash are merged in).
        dbhash:  restore the session from a PatchDB hash via FromHash().
                 Raises ValueError if the hash cannot be restored.
        options: key arguments and options.  The dict is used as-is
                 (not copied), so SetOptions() updates it in place.
        """
        self.options = options or {}
        self.steps = []
        self.firsttime = 0
        self.lastState = 'Initialized'
        # UpgradePath instance - combined installation flags
        self.flags = None
        # Descriptor() instances, indexed by bundle ID
        self.descs = {}
        # Ordered list of bundle lists
        self.bundleGroups = []
        # Descriptor list consumed by InitPending(); it must be assigned
        # before InitPending() runs (InitPending() asserts on it).
        self.lastDesclist = None
        # Rpmlist() with "latest" Rpms across all bundles
        self.latestlist = None
        # Bundle ID for each package, indexed by name.arch
        self.pkgbundle = {}
        # DepotDirEntry instances indexed by bundle ID
        self.bundleEntry = {}
        # DepotEntry instance for each pkg, indexed by name.arch
        self.pkgEntry = {}
        # Package state dicts: key is name.arch ('remove' entries use the
        # package name only), value is the mark (see __validMarks).
        self.pending = {}
        self.installed = {}
        self.removed = {}
        self.db = db
        self.repackaged = self.options.get('enableRollback', 0)
        self.obsolete = 0
        self.preinstUpdatesDb = 1
        self.results = {}
        # List of the downgrading rpm tuple (newversionrpm, oldversionrpm)
        self.downgrades = []
        #
        # Restore saved state first: the restored options may be needed
        # to build the depot.  The hash is honored even when an explicit
        # depot is supplied (previously it was silently ignored then).
        #
        if dbhash:
            self.FromHash(dbhash)
        if depot:
            self.depot = depot
        else:
            self.depot = self._createDepot()

    def _createDepot(self):
        """ Returns the right depot instance based on self.options:
        Depot.Depot for 'depoturl', else Depot.DepotInBundle for
        'bundleurl', else None.
        """
        depotArgs = {}
        # (session option name, depot keyword argument name)
        for optName, argName in (('localCacheRoot', 'localCacheRoot'),
                                 ('flushcache', 'flushCache'),
                                 ('bundleIDs', 'scanlist'),
                                 ('keyringDir', 'keyringDir')):
            if optName in self.options:
                depotArgs[argName] = self.options[optName]

        if 'depoturl' in self.options:
            return Depot.Depot(self.options['depoturl'], **depotArgs)
        elif 'bundleurl' in self.options:
            return Depot.DepotInBundle(self.options['bundleurl'],
                                       **depotArgs)
        else:
            return None

    def ToHash(self):
        """ Serializes multi-bundle session state into a hash.
        May be used for saving/restoring session state (see FromHash()).
        """
        myhash = self._commonHash()
        myhash['restorable'] = 1
        myhash['pending'] = self.pending
        myhash['installed'] = self.installed
        myhash['removed'] = self.removed
        myhash['bundleGroups'] = self.bundleGroups
        myhash['desclist'] = self.OrderedBundleList()
        myhash['pkgbundle'] = self.pkgbundle
        #
        # This gives us a way to tell if we're upgrading from a legacy version
        # of esxupdate (which doesn't call UpdatePendingFromHost() in
        # PreInstallState).  See PR 175224 for a detailed explanation.
        #
        myhash['preinstUpdatesDb'] = self.preinstUpdatesDb
        return myhash

    def ToBundleHash(self, bundleID):
        """ Serializes package state for one bundle in a format
        compatible with query/info commands.  May not be
        restored to an InstallSession instance.
        When bundleID has a depot entry, the hash's options carry a
        'url' key with the bundle's remote location; self.options itself
        is left untouched.
        """
        myhash = self._commonHash()
        myhash['restorable'] = 0
        if bundleID in self.bundleEntry:
            # _commonHash() returns a copy of the options, so this does
            # not leak into self.options or into other bundles' hashes.
            myhash['options']['url'] = self.bundleEntry[bundleID].remoteLocation
        #
        # Extract from pending/removed/installed what is relevant to
        # this bundle
        #
        for key, source in (('pending', self.pending),
                            ('removed', self.removed),
                            ('installed', self.installed)):
            subset = {}
            for na in source:
                if self.pkgbundle.get(na, '') == bundleID:
                    subset[na] = source[na]
            myhash[key] = subset

        if bundleID in self.descs:
            myhash['desc'] = self.descs[bundleID]

        return myhash

    def _commonHash(self):
        """ Common session hash data shared by ToHash() and ToBundleHash() """
        myhash = {}
        myhash['__version__'] = HASH_INTF_VERSION
        # Shallow copy: callers such as ToBundleHash() add keys to the
        # returned options, which must not modify the session's own dict.
        myhash['options'] = self.options.copy()
        myhash['repackaged'] = self.repackaged
        myhash['obsolete'] = self.obsolete
        #
        # Store vital info only from list of install states
        #
        myhash['lastState'] = self.GetState()
        if self.steps:
            myhash['firsttime'] = self.GetTXStartTime()
            myhash['lasttime']  = self.steps[-1].GetTime()

        #
        # We don't need to serialize yumconf, as it is not used
        # after ConfigState, and the state would not need to be
        # saved before ConfigState finishes.
        #
        return myhash

    def _versionLess(self, a, b):
        """ Returns true if dotted version string a is older than b.
        Components are compared numerically, so '10.0' is newer than
        '2.0'.  Falls back to plain string comparison if either version
        has a non-numeric component.
        """
        try:
            return ([int(part) for part in str(a).split('.')] <
                    [int(part) for part in str(b).split('.')])
        except ValueError:
            return str(a) < str(b)

    def FromHash(self, myhash):
        """ Deserializes from a hash and fills this instance.
        Hashes older than version 2.0 are first converted with
        FromOldHash().
        Throws ValueError if the hash version < HASH_MIN_VERSION, or
        the hash 'restorable' attribute is false.
        """
        version = myhash['__version__']
        if self._versionLess(version, HASH_MIN_VERSION):
            raise ValueError("Cannot import hash of version %s (Minimum is %s)" % (
                version, HASH_MIN_VERSION))
        if self._versionLess(version, '2.0'):
            log.debug('Converting from older session hash')
            self.FromOldHash(myhash)

        if myhash['restorable'] == 0:
            raise ValueError('Cannot import hash: not restorable')
        #
        # Copy the entries into our own dicts with update()
        # (a shallow copy: the values themselves are shared).
        #
        self.SetOptions(myhash['options'])
        self.pending.update(myhash['pending'])
        self.installed.update(myhash['installed'])
        self.removed.update(myhash.get('removed', {}))

        #
        # It's not necessary to restore the descriptor,
        # as we must restart at ConfigState anyways.
        #

        # Hashes written by legacy esxupdate lack this key (see ToHash()).
        self.preinstUpdatesDb = myhash.get('preinstUpdatesDb', 0)

        # Now deal with the string and bool vars
        self.obsolete  = myhash['obsolete']
        self.lastState = myhash['lastState']
        self.firsttime = myhash.get('firsttime', None)
        self.repackaged = myhash.get('repackaged', 0)

    def FromOldHash(self, myhash):
        """ Converts an old pre-2.0 hash to the new format.
        The old single 'url' option becomes 'depoturl' + 'bundleIDs'
        (when a 'bundleID' option exists) or 'bundleurl' otherwise.
        Throws ValueError if the hash 'restorable' attribute is false.
        """
        if myhash['restorable'] == 0:
            raise ValueError('Cannot import hash: not restorable')
        #
        # Copy the entries into our own dicts with update()
        # (a shallow copy: the values themselves are shared).
        #
        self.SetOptions(myhash['options'])
        self.pending.update(myhash['pending'])
        self.installed.update(myhash['installed'])
        self.removed.update(myhash.get('removed', {}))

        #
        # Convert old options to new options
        #
        if 'bundleID' in self.options:
            self.options['bundleIDs'] = [ self.options['bundleID'] ]
            self.options['depoturl'] = self.options['url']
        else:
            self.options['bundleurl'] = self.options['url']

        #
        # In older esxupdate's, the preinstalls were not recorded
        # in the pending list, so they would not end up in the
        # installed list either.  Help by finding out which of the
        # preinstall RPMs were installed, and put that in pending.
        # NOTE(review): nothing in this method implements that step;
        # confirm whether it is handled elsewhere.
        #

        # Now deal with the string and bool vars
        self.obsolete  = myhash['obsolete']
        self.lastState = myhash['lastState']
        self.firsttime = myhash.get('firsttime', None)
        self.repackaged = myhash.get('repackaged', 0)

    def GetState(self):
        """ Return last install state as string, ie 'ConfigState' """
        if len(self.steps):
            return self.steps[-1].GetState()
        else:
            return self.lastState

    def GetLastErr(self):
        """ Returns the err of the most recent step that has one, or None """
        newestFirst = self.steps[:]
        newestFirst.reverse()
        for step in newestFirst:
            if step.err:
                return step.err
        return None

    def GetTXStartTime(self):
        """ Returns time of first step of this install transaction, or 0.
        A restored firsttime takes precedence over the steps list.
        """
        if self.firsttime:
            return self.firsttime
        elif len(self.steps) > 0:
            return self.steps[0].GetTime()
        else:
            return 0

    def GetOption(self, key):
        """ Returns the option value for key, or None if it is not set """
        return self.options.get(key)

    def SetOptions(self, options):
        """ Merges dict of user options with existing dict """
        self.options.update(options)

    def SetRepo(self):
        """ Writes out URLs to /etc/vmware/yum.conf (insthelper.YUMCONF_FILE).
        One repo section is written per bundle in OrderedBundleList() that
        contains at least one rpm, plus 'baseN' sections for any
        'baseurls' option.  Requires bundleEntry to be populated.
        """
        assert len(self.bundleEntry) > 0
        log.debug('Writing out new /etc/vmware/yum.conf')
        from urllib import quote

        conf = ConfigParser.ConfigParser()
        conf.add_section('main')
        #
        # Don't treat kernel-source and other such packages specially.
        # Always upgrade them instead of installing.
        #
        conf.set('main', 'installonlypkgs', '')

        conf.set('main', 'cachedir', '/var/cache/yum')
        conf.set('main', 'debuglevel', '2')
        conf.set('main', 'logfile', '/var/log/yum.log')
        conf.set('main', 'pkgpolicy', 'newest')
        conf.set('main', 'distroverpkg', 'vmware-release')
        conf.set('main', 'tolerant', '1')
        conf.set('main', 'exactarch', '1')

        # Include relevant bundles only
        for bundle in self.OrderedBundleList():
            bEntry = self.bundleEntry[bundle]
            # skip if no rpms
            if not bEntry.DeepFindFileMatching('*.rpm'):
                continue
            conf.add_section(bundle)
            conf.set(bundle, 'name', 'Bundle ' + bundle)
            conf.set(bundle, 'baseurl',
                     'file://' + quote(bEntry.localLocation))
            #
            # Build yum headers for locally cached downloads
            # to ensure yum metadata is in sync
            if not Depot.LocalUrl(bEntry.remoteLocation):
                insthelper.BuildYumHeaders(bEntry.localLocation)

            #
            # We use detached GPG signatures with an external keyring,
            # not RPM's built in keyring and sig checking.
            #
            conf.set(bundle, 'gpgcheck', '0')

        bases = self.GetOption('baseurls')
        if bases:
            for i in range(len(bases)):
                sectname = 'base%d' % (i + 1)
                conf.remove_section(sectname)
                conf.add_section(sectname)
                conf.set(sectname, 'name', 'Base repository %d' % (i + 1))
                conf.set(sectname, 'baseurl', bases[i] )

        fd = open(insthelper.YUMCONF_FILE, 'w')
        try:
            conf.write(fd)
        finally:
            # Close even if the write fails so the handle is not leaked.
            fd.close()

    def CreateProgressBar(self, steps):
        """ Creates a progress bar instance with the given number of steps.  Points
        to either the host agent-specific callback or the normal callback.
        """
        if self.GetOption('hostagent'):
            self.progress = ProgressBar(ha.HAProgressCB, steps)
        else:
            self.progress = ProgressBar(total_steps = steps)

    def RunStateMachine(self, firstState, lastState=type(None)):
        """ Keep looping through states until done, or until lastState.
        For each state, carry out Action() and call NextState() to return the
        next state, unless IsDone() returns true.
        lastState    - optionally, specify the last state to stop at. Useful
                       for debugging/testing.  Pass in a class ref, not a string.
                       The default (NoneType) never matches a state.
        Modifies steps[].
        """
        state = firstState
        while 1:
            self.steps.append(state)
            state.Action()
            if state.IsDone() or isinstance(state, lastState):
                break
            state = state.NextState()

    def SaveForRestore(self):
        """ Save current installation state for later restoration """
        self.db.SaveSession(self.ToHash())
        self.db.Sync()

    def SaveInstallRecord(self):
        """ Save a record of the current installation for future
        patch info/state/history queries.  Bundles whose rpms are all
        in the 'excludes' option are not recorded.
        """
        xlist = self.GetOption('excludes')
        for bundle in self.OrderedBundleList():
            if xlist:
                rpms = self.descs[bundle].rpmlist.GetPkgNames()
                remaining = [name for name in rpms if name not in xlist]
                if not remaining:
                    # Do not save bundle in patchdb if all rpms are excluded
                    continue

            self.db.AddBundleEntry(self.ToBundleHash(bundle))
        self.db.Sync()


    #
    # Install package state
    # ------------------------------------------------
    # The package state is kept in two dicts, pending
    # and installed, that keep track of packages that
    # still need to be worked on (and what type of work
    # is needed), and packages that have been installed.
    # They are indexed by package 'name.arch' and contain
    # a string to identify the operation needed:
    # 'downgrade', 'nodeps', 'install', etc.
    # ------------------------------------------------
    # This scheme really relies on the descriptor rpmlist
    # having the arch info.  Without this, the name.arch from
    # yum info and from the host rpmdb lookup cannot be
    # matched.
    #

    __validMarks = ('upgrade', 'install', 'downgrade', 'nodeps',
                    'exclude', 'remove', 'preinstall')

    def InitPending(self):
        """ Populate pending list from the rpmlist, preinstall,
        removes, and nodeps of all relevant descriptors.
        Previous pending list will be wiped out.
        This function must populate all the keys in pending;  no other
        function is allowed to expand the pending list.
        Also, packages in the nodeps and excludes lists that are not part of
        the descriptor FullRpmlist are ignored.
        Requires lastDesclist and bundleGroups to be set.
        """
        assert self.lastDesclist
        assert self.bundleGroups

        #
        # Populate self.latestlist, self.pending, self.pkgbundle
        # with preinstall, install, nodeps, remove(name) filled in
        #
        latest, removes, nodeps = insthelper.GetRpmBaseline(
            self.lastDesclist)
        self._mergeRpmPendingLists(latest, removes, nodeps)

        # Also exclude user requested pkgs
        # We must make an Rpmlist out of the list of package names.
        if self.GetOption('excludes'):
            xlist = Rpmlist(self.GetOption('excludes'))
            self.MarkPending(xlist.GetPkgNAList(self.latestlist), 'exclude')

    def MarkPending(self, NAlist, mark, grow=0):
        """ Mark the packages given by NAlist in the pending list with mark.
        NAlist   - list of packages in name.arch format
        mark     - one of the valid marks
        grow     - 1 if new entries should be added to pending, default 0
        """
        assert mark in self.__validMarks, "Invalid mark %s passed to MarkPending" % (mark)
        for nameArch in NAlist:
            if grow or nameArch in self.pending:
                self.pending[nameArch] = mark
            else:
                log.debug('Rejected na=%s, not in pending', nameArch)

    def MarkPendingNames(self, nameslist, mark, grow=0):
        """ A version of MarkPending that takes a list of package names.
        The arch info is looked up from the latestlist.
        """
        assert self.latestlist is not None
        reqlist = Rpmlist(nameslist)
        self.MarkPending(reqlist.GetPkgNAList(self.latestlist), mark, grow)

    def GetPending(self, mark):
        """ Return a name.arch list of pending packages with the given mark """
        assert mark in self.__validMarks, "Invalid mark %s passed to GetPending" % (mark)
        NAlist = [na for na in self.pending if self.pending[na] == mark]
        return NAlist

    def MatchPending(self, marks):
        """ Return a name.arch list of pending packages with any of the given marks """
        NAlist = [na for na in self.pending if self.pending[na] in marks]
        return NAlist

    def OnlyExcludesLeft(self):
        """ Returns true if the only packages left on the pending list are
        excluded packages (also true when pending is empty).
        """
        return len(self.pending) == len(self.GetPending('exclude'))

    def DriverPending(self):
        """ Returns true if a driver package is about to be installed """
        for na in self.pending:
            if self.pending[na] == 'exclude':
                continue
            #
            # TODO: verify if the rpm depends on the vmkernel driver API;
            # the test below is good for the foreseeable future.
            #
            if na.startswith('VMware-esx-drivers'):
                return True

        return False

    def VerifiedPending(self, NAlist):
        """ Move the packages in NAlist from pending to installed list.  Should be
        used only after the NAlist packages have been verified to be installed.
        NAlist   - list in name.arch, or dict with keys in name.arch
        """
        # Iterate over a snapshot so that passing self.pending itself
        # does not fail with "dictionary changed size during iteration".
        for na in list(NAlist):
            if na in self.pending:
                self.installed[na] = self.pending[na]
                del self.pending[na]

    def RemovePending(self, NAlist):
        """ Remove the list of packages in NAlist from the pending list.
        NAlist   - list in name.arch, or dict with keys in name.arch
        """
        # Snapshot for the same reason as in VerifiedPending().
        for na in list(NAlist):
            if na in self.pending:
                del self.pending[na]

    def DumpPkglists(self):
        """ Dump debugging info on pending and installed lists.
        A key present in both lists is shown once, with the installed mark.
        """
        mykeys = list(self.pending.keys())
        for key in self.installed.keys():
            if key not in self.pending:
                mykeys.append(key)
        mykeys.sort()
        for key in mykeys:
            str1 = '    '
            str2 = '    '
            if key in self.pending:
                mark = self.pending[key]
                str1 = 'pend'
            if key in self.installed:
                mark = self.installed[key]
                str2 = 'done'
            log.debug("  %9s %4s%4s: %s", mark, str1, str2, key)

        log.debug('-- Total %d  pending %d  installed %d', len(mykeys),
                  len(self.pending), len(self.installed))

    def UpdatePendingFromHost(self):
        """ Checks the packages installed on the host and transfers
        them to the installed list via VerifiedPending(); also updates
        the removed list via UpdateRemovedFromHost().
        """
        assert self.latestlist is not None
        #
        # What got installed from our patch?
        # Move from pending to installed list
        #
        if len(self.latestlist):
            log.debug('Checking status of %d rpms...', len(self.latestlist))
            # Only the matches (keyed by name.arch) are used here.
            dummy, matches = insthelper.GetDbInstalledHeaders(self.latestlist)
            self.VerifiedPending(matches.keys())

        #
        # Move to removed list the pkgs no longer in RPMDB
        # Match by package name only
        #
        self.UpdateRemovedFromHost()

        # So we can debug at each stage
        self.DumpPkglists()

    def UpdateRemovedFromHost(self):
        """ If a 'remove' package is no longer installed, then
        delete from pending list and put it on removed list.
        Requires the 'rpmroot' option when any 'remove' marks exist.
        """
        removes = self.GetPending('remove')
        if removes:
            dblist = insthelper.GetRpmlistFromRpmdb(self.options['rpmroot'])
            for name in removes:
                if not dblist.GetPkgByName(name):
                    self.removed[name] = self.pending[name]
                    del self.pending[name]

    #
    # ------------------------------------------------
    # Manage session state for working with multiple bundles
    # ------------------------------------------------
    #
    def DownloadDescriptors(self):
        """ Download descriptors from depot and populate
        self.descs[] and self.bundleEntry[]; resets self.flags.
        Throws DepotAccessError, DepotDownloadError.
        """
        self.depot.SetBlacklist(['*.rpm'])
        self.depot.SyncFromRemote()
        self.depot.SanityCheck()
        if len(self.depot.bundles) == 0:
            raise errors.DepotAccessError('No matching bundleIDs found at '
                                          '%s' % (self.depot.url))
        self.flags = UpgradePaths()
        for bundleID, bundleHash in self.depot.bundles.items():
            self.descs[bundleID] = bundleHash['desc']
            self.bundleEntry[bundleID] = self.depot.bundleDirEntry[bundleID]

    def OrderedBundleList(self):
        """ Returns an ordered list of bundles to install: groups in
        bundleGroups order, bundle IDs sorted within each group.
        """
        ordered = []
        for item in self.bundleGroups:
            item = list(item)
            item.sort()
            ordered += item
        return ordered

    def MergeInstallFlags(self):
        """ Merges the install flags of all the installable bundles """
        bundles = self.OrderedBundleList()
        if not bundles:
            return
        # NOTE(review): MergeFlags() mutates the object returned by the
        # first descriptor's GetUpgradePaths(); if that is the descriptor's
        # own instance, the descriptor is modified too -- confirm.
        flags = self.descs[bundles[0]].GetUpgradePaths()
        for bundle in bundles[1:]:
            other = self.descs[bundle].GetUpgradePaths()
            flags.MergeFlags(other)
        self.flags = flags

    def SetEntryInstances(self):
        """ After InitPending() is called, this method fills the
        pkgEntry[] hashes, connecting each rpm to its DepotEntry
        instance.
        DepotAccessError - an entry cannot be found for an rpm
        """
        for pkg in self.latestlist.GetRpms():
            na = pkg.GetNameArch()
            bEntry = self.bundleEntry[self.pkgbundle[na]]
            pkgEntry = bEntry.GetEntry(pkg.GetBasename())
            if pkgEntry:
                self.pkgEntry[na] = pkgEntry
            else:
                raise errors.DepotAccessError(
                    'No depot entry for file %s, package %s in bundle %s' %
                    (pkg.GetBasename(), na, self.pkgbundle[na]))

    def _mergeRpmPendingLists(self, baseline, removes, nodeps):
        """ Narrows the latest RPM list to just the bundles that will
        be installed. Also, compute the corresponding pending list,
        with the preinstall, install, nodeps, and remove marks filled
        in. The inputs are the outputs from insthelper.GetRpmBaseline().

        Inputs: baseline - the latest list of all bundles in the
                            PatchDB and Depot
                removes  - dict of pkg names to remove
                nodeps   - dict of pkg name.arch to install nodeps
        Requires: descs, bundleGroups
        Modifies: latestlist - subset of baseline for all
                               bundles to be installed
                  pending - preinstall, install, remove, nodeps
                  pkgbundle - bundle ID for each name.arch (remove: name only)
        """
        latestlist = Rpmlist(None)
        pending = {}
        pkgbundle = {}
        #
        # Aggregate all of the preinstalls and removes
        # from installable bundles
        #
        all_preinstalls = {}
        all_removes = {}
        removes_ids = {}
        for bundle in self.OrderedBundleList():
            full = self.descs[bundle].GetFullRpmlist()
            if full:
                union = baseline.Union(full)
                latestlist.Merge(union)
                for na in union.GetPkgNAList():
                    # The first bundle (in install order) to list a
                    # package owns it.
                    if na not in pending:
                        pending[na] = 'install'
                        pkgbundle[na] = bundle

            preinst = self.descs[bundle].GetPreinstall()
            if preinst:
                for na in preinst.GetPkgNAList(full):
                    all_preinstalls[na] = 1

            removelist = self.descs[bundle].GetRemoves()
            if removelist:
                for name in removelist.GetPkgNames():
                    all_removes[name] = 1
                    if name not in removes_ids:
                        removes_ids[name] = bundle

        # Lazy %-args: these level-1 messages are only formatted when
        # the logger is actually configured to emit them.
        log.log(1, "latestlist = %s", latestlist.rpms)
        log.log(1, "pkgbundle = %s", pkgbundle)
        log.log(1, "all_preinstalls = %s", all_preinstalls)
        log.log(1, "all_removes = %s", all_removes)
        log.log(1, "removes_ids = %s", removes_ids)
        log.log(1, "pending so far = %s", pending)
        #
        # We have latestlist and pending as the set of union rpms.
        # preinstalls = union of latestlist and all_preinstalls
        # nodeps = union of latestlist and nodeps
        #
        for na in pending:
            if na in all_preinstalls:
                pending[na] = 'preinstall'
            elif na in nodeps:
                pending[na] = 'nodeps'

        #
        # Finally, what gets removed is the union of the removes
        # list from all descriptors & all_removes, which come
        # from just the bundles we are installing.
        #
        for name in removes:
            if name in all_removes:
                pending[name] = 'remove'
                pkgbundle[name] = removes_ids[name]

        self.latestlist = latestlist
        log.log(1, "final pending = %s", pending)
        self.pending = pending
        self.pkgbundle = pkgbundle

    #
    # ------------------------------------------------
    # Write out reports or summaries to the log at the end
    # ------------------------------------------------
    #
    def LogSummary(self):
        """ Log the pending and installed totals, plus any downgrades,
        to the 'summary' logger.  Excluded packages are not counted as
        pending.
        """
        excludelist = self.GetPending('exclude')
        inststr = '%d packages installed' % (len(self.installed))
        pendstr = '%d pending or failed' % (len(self.pending) - len(excludelist))
        rmvstr = '%d removed' % (len(self.removed))
        xstr = '%d excluded' % (len(excludelist))

        sumlog = logging.getLogger('summary')
        sumlog.info('--- TOTALS: %s, %s, %s, %s ---' % (inststr, pendstr,
                                                        rmvstr, xstr))
        if len(self.downgrades):
            sumlog.info('--- WARNING: RPMs downgraded:')
            for new, old in self.downgrades:
                sumlog.info('%s is downgraded to %s' % (new.GetBasename(), old.GetBasename()))

