#######################################################################
# Copyright (C) 2005 VMWare, Inc.
# All Rights Reserved
########################################################################
#
# Helper functions for InstallState classes.
#
# Testing:
#   test/testinstall.py

import os, os.path
import stat
import logging
import re
import rpm
import sys
import select
import time
import glob
import shutil
from statvfs import *
from stat import *

from vmware import system
from vmware.descriptor import Rpm, Rpmlist
from vmware.descriptor import Descriptor
from vmware.descriptor import utils
import errors
from Depot import LOCAL_CACHE_ROOT
from vmware.Lock import Lock

#
# Constants/globals
#
log = logging.getLogger('install')

# yum's own log file and update header cache
YUMLOG = '/var/log/yum.log'
YUM_HDR_CACHE = '/var/cache/yum/update/headers'
# Default timeout (seconds) used for yum/rpm child commands; None means
# no timeout is passed on (select() then blocks until output arrives).
RPM_TIMEOUT = None
# vmkernel scheduler proc node; absent when not booted into the vmkernel
VMCPUINFO = '/proc/vmware/sched/cpu'
# Files archived by BackupSystemFiles(): the optional ones are included
# only if present, the mandatory ones must exist or the backup fails.
OPTIONAL_BACKUP_FILES = '/boot/grub/grub.conf /boot/grub/device.map /boot/initrd*vmnix*.img'
MANDATORY_BACKUP_FILES = '/etc/vmware/*'
# yum configuration passed to every yum invocation with -c
YUMCONF_FILE = '/etc/vmware/yum.conf'

########################################################################
#
# Host operations
#

#
# Return the version of the yum package.  This may not work
# in the rare case that multiple yum versions exist.
#
def GetYumVer():
    """ Return the installed yum package's (version, release) tuple.

    Only the first RPM database header named 'yum' is consulted; if no
    such header exists, (None, None) is returned.
    """
    transaction = rpm.TransactionSet()
    yumHeaders = transaction.dbMatch('name', 'yum')
    for hdr in yumHeaders:
        return (hdr['version'], hdr['release'])
    return (None, None)


#
# Return the RPM database headers keyed by name.arch:
#  installed[name.arch] = [ all headers with matching name.arch ]
#
def GetRpmdbHeaders(rpmroot='/'):
    """ Return the RPM database headers grouped by 'name.arch'.

    rpmroot  - root directory; the database read is rpmroot+'/var/lib/rpm'

    Returns a dict mapping each 'name.arch' string to the list of all
    headers with that name and arch, in database iteration order.
    """
    log.debug('Reading from RPM database at root %s...' % (rpmroot))
    byNameArch = {}
    transaction = rpm.TransactionSet(rpmroot)
    for hdr in transaction.dbMatch():
        key = '%s.%s' % (hdr['name'], hdr['arch'])
        if key not in byNameArch:
            byNameArch[key] = []
        byNameArch[key].append(hdr)
    log.debug('  found %s unique packages' % (len(byNameArch)))
    return byNameArch


#
# Map the RPM database to an Rpmlist instance and return it.
# In cases when multiple versions of a package are in the RPMDB,
# we default to the same algo used in yum 2.0.x (rpmdbNevralLoad):
# always keep the package with the highest version.
# NB: This does not cover the case where the truly installed pkg
# version is not the latest, eg if somebody rpm -ihv --force an
# older version.  The better way is to sort by the INSTALLTID
# in the RPM header and use the header with the highest TID.
# OTOH, RPM bugs mess up the TID :(.
# For now yum compatibility is important, but in the future,
# when yum is removed, we can try sorting by TID.
#
def GetRpmlistFromRpmdb(rpmroot='/'):
    """ Return an Rpmlist (tag 'rpmdb') describing the RPM database at rpmroot.

    When several versions of one name.arch are installed, only the one
    that compares greatest (Rpm '>' ordering) is kept; on ties the first
    header seen wins.
    """
    result = Rpmlist(None, tag='rpmdb')
    headersByNA = GetRpmdbHeaders(rpmroot)
    for na, candidates in headersByNA.items():
        best = Rpm(hdr=candidates[0])
        for hdr in candidates[1:]:
            challenger = Rpm(hdr=hdr)
            if challenger > best:
                best = challenger
        result.AddRpm(best)
    return result


#
# Return an Rpmlist of what we should expect to be installed given the
# patches and ISOs recorded in our Patch DB.
#
def ExpectedRpmsFromDb(db):
    """ Return the Rpmlist that should be installed according to the Patch DB.

    db   - a PatchDB instance

    Installed patch records (obsolete ones included) are replayed in the
    order given by db.LasttimeIndex(): for each record, the RPMs it lists
    as installed (looked up in its descriptor's full rpmlist) are merged
    into the result, then the packages it lists as removed are pruned.
    """
    patches = db.GetInstalledPatches(getObsolete=True)
    replayOrder = db.LasttimeIndex(patches)
    expected = Rpmlist(None, tag='unionlist')
    for key in replayOrder:
        record = patches[key]
        desc = record['desc']
        if desc:
            fullList = desc.GetFullRpmlist()
            if fullList:
                installedNAs = record['installed'].keys()
                expected.Merge(fullList.FilterList(installedNAs))
        # Removals always apply, even for records without a descriptor.
        expected.Prune(record.get('removed', []))
    return expected


#
# Diff RPM database headers against an Rpmlist
#  headers    - a dict keyed by name.arch with list of headers or rpm.mi instance
#  rpmlist    - the Rpmlist to compare against, must have no duplicates
#
# Returns:
#  new                = Rpmlist of installed packages not in rpmlist (may have dups)
#  missing            = Rpmlist of pkgs on rpmlist not installed
#  extras[name.arch]  = [ headers of pkgs with same 'name.arch' but wrong version ]
#  matches[name.arch] = [ header of pkg with matching name, arch, and version     ]
#
def DiffHeaders(headers, rpmlist):
    """ Compare installed RPM headers against an Rpmlist.

    headers  - dict mapping 'name.arch' to a list of headers
    rpmlist  - the Rpmlist to compare against; must have no duplicates

    Returns (new, missing, extras, matches):
      new      - Rpmlist built from headers whose name.arch is not in rpmlist
      missing  - Rpmlist of rpmlist entries with no header of that name.arch
      extras   - dict name.arch -> headers whose version differs from rpmlist
      matches  - dict name.arch -> [header whose version equals rpmlist's]
    """
    missing = Rpmlist(None, tag='missing')
    new = Rpmlist(None, tag='newlist')
    extras = {}
    matches = {}
    for wanted in rpmlist.GetRpms():
        na = wanted.GetNameArch()
        if not headers.has_key(na):
            # nothing with this name.arch in the rpmdb => it was deleted
            missing.AddRpm(wanted)
            continue
        for hdr in headers[na]:
            if Rpm(hdr=hdr) == wanted:
                matches[na] = [hdr]
            else:
                extras.setdefault(na, []).append(hdr)

    # Whatever is neither matched nor an extra is new.
    for na in headers:
        if na in matches or na in extras:
            continue
        for hdr in headers[na]:
            new.AddRpm(Rpm(hdr=hdr))

    log.debug('DiffHeaders: %d new pkgs, %d missing, %d matching with %d having extras' % \
              (len(new), len(missing), len(matches), len(extras)))
    return new, missing, extras, matches


#
# Sort the descriptor list in ascending
#   esx version (e.g. 301, 302, 350 ), UTC Timestamp, release name.
# This is needed for handling <removes> and <nodeps> properly.
#
def SortedDescriptorList(desclist):
    """ Return a new list of desclist's descriptors sorted in ascending
    order of (ESX version, UTC timestamp, release name).

    The ESX version comes from the release name when it follows the GA
    convention 'x.y.z-...' (e.g. '3.5.0-64607' -> '350'), otherwise from
    the first upgrade-path required ID that does (e.g. '3.0.2-...' ->
    '302').  Descriptors without such a version sort before all versioned
    ones.  Descriptors that compare equal keep their input order, and
    desclist itself is not modified.
    """
    # The dots are escaped so that only genuine 'x.y.z-' prefixes are
    # recognized; an unescaped '.' also accepted e.g. '3x5y0-'.
    verRE = re.compile(r'^(\d)\.(\d)\.(\d)-')

    def esxVersion(desc):
        mo = verRE.match(desc.GetRelease())
        if mo:
            return ''.join(mo.groups())
        for reqId in desc.GetUpgradePaths().GetRequiredIDs():
            mo = verRE.match(reqId)
            if mo:
                return ''.join(mo.groups())
        return None

    # Decorate-sort-undecorate, so each sort key is computed once per
    # descriptor.  (ver is not None) puts unversioned descriptors first,
    # exactly where cmp(None, '350') < 0 used to put them, without ever
    # comparing None with a string.  The input index breaks remaining
    # ties: it keeps the sort stable and means the descriptor objects
    # themselves are never compared.
    decorated = []
    for index, desc in enumerate(desclist):
        ver = esxVersion(desc)
        decorated.append(((ver is not None, ver or ''),
                          desc.GetUTCTimestamp(), desc.GetRelease(),
                          index, desc))
    decorated.sort()
    return [item[-1] for item in decorated]


#
# Builds a list of Rpm's to install and remove, given a bunch of
# descriptors.  Algorithm details:
#  * In general, the newest rpms win
#  * A <remove> of an rpm takes precedence over any versions of that rpm
#    in earlier descriptors
#  * <nodeps> rpms take precedence over the same rpms in earlier desc's
#  * descriptors are processed in SortedDescriptorList() order: ESX
#    version, then UTC timestamp, then release name
#
# Descriptor requirements:
#  <rpmlist> must have arch set (any descriptor generated by makerepo.py)
#  UTCtimestamp attribute of <descriptor> must be set
#
# Returned:
#    latest    - Rpmlist() of rpms to install
#    removes   - dict[name] of rpms to remove from system
#    nodeps    - dict[name.arch] of rpms to install nodeps
#
def GetRpmBaseline(desclist):
    """ Compute the RPM baseline implied by a set of descriptors.

    Descriptors are processed in SortedDescriptorList() order.  For each:
      * the updates and new packages that its full rpmlist's GetUpdates()
        reports against the running baseline are merged in, and pending
        removes/nodeps entries they supersede are dropped;
      * its <removes> package names are pruned from the baseline and
        recorded in removes;
      * its <nodeps> packages (versions taken from its full rpmlist) are
        merged in regardless of version and recorded in nodeps.  A
        descriptor with <nodeps> but no full rpmlist is logged and its
        <nodeps> ignored.

    Returns (latest, removes, nodeps):
      latest   - Rpmlist of rpms to install
      removes  - dict keyed by package name of rpms to remove from system
      nodeps   - dict keyed by name.arch of rpms to install nodeps
    """
    descriptors = SortedDescriptorList(desclist)
    #
    # The latest list should contain the latest versions of all
    # the rpms in all descriptors, except ones listed in <nodeps>
    # and <removes>
    #
    latest = Rpmlist(None)
    removes = {}
    nodeps = {}
    for desc in descriptors:
        log.log(1, "Processing " + desc.GetRelease())
        rpmlist = desc.GetFullRpmlist()
        if rpmlist:
            updates, newlist, older = rpmlist.GetUpdates(latest)
            latest.Merge(updates)
            latest.Merge(newlist)
            log.log(1, "  + updates = " + str(updates.rpms))
            log.log(1, "  + newlist = " + str(newlist.rpms))

            #
            # Eliminate pkgs in removes/nodeps lists if
            # they have been superseded
            #
            for newrpm in newlist.GetRpms():
                if newrpm.GetName() in removes:
                    del removes[newrpm.GetName()]
            # (named updrpm so the module-level 'rpm' is not shadowed)
            for updrpm in updates.GetRpms():
                if updrpm.GetNameArch() in nodeps:
                    del nodeps[updrpm.GetNameArch()]

        #
        # <removes> takes precedence over previous pkgs
        #
        removelist = desc.GetRemoves()
        if removelist:
            removeNames = removelist.GetPkgNames()
            latest.Prune(removeNames)
            log.log(1, "  - removed = " + str(removeNames))
            for name in removeNames:
                removes[name] = 1

        #
        # <nodeps> get installed regardless of version
        # so must be considered the "latest".  Also, <nodeps>
        # list has no rel-ver info, so must get that info
        # from the full rpmlist.  Without a full rpmlist there
        # is nothing to take versions from, so report and skip
        # instead of failing on rpmlist.FilterList().
        #
        nodeplist = desc.GetNodeps()
        if nodeplist:
            if not rpmlist:
                log.warn('%s: ignoring <nodeps>, no <rpmlist> to take '
                         'versions from' % desc.GetRelease())
            else:
                nodepsRpms = rpmlist.FilterList(nodeplist.GetPkgNames())
                latest.Merge(nodepsRpms)
                log.log(1, "  + nodeps = " + str(nodepsRpms.rpms))
                for na in nodepsRpms.GetPkgNAList():
                    nodeps[na] = 1

    log.log(1, "Latest rpms = " + str(latest.rpms))
    log.log(1, "Rpms removed = " + str(removes.keys()) )
    log.log(1, "nodeps list = " + str(nodeps.keys()) )

    return latest, removes, nodeps


#
# Finds the newest bundles associated with each package in the
# latest, removes lists produced by GetRpmBaseline() above.
# Returns:
#   bundleIDs  : dict of bundle IDs keyed by name.arch for installable
#                packages, or name for packages to be removed
#
def GetLatestBundles(desclist, latest, removes):
    """ Map each baseline package to the newest bundle that provides it.

    desclist  - descriptors to scan, processed in SortedDescriptorList() order
    latest    - Rpmlist of installable rpms (as from GetRpmBaseline())
    removes   - dict of package names to remove (as from GetRpmBaseline())

    Returns a dict of bundle IDs (descriptor release names) keyed by
    name.arch for the entries of each descriptor's rpmlist.Union(latest),
    and by package name for <removes> entries present in removes.  Later
    descriptors overwrite earlier ones, so the newest bundle wins.
    """
    result = {}
    for desc in SortedDescriptorList(desclist):
        release = desc.GetRelease()
        fullList = desc.GetFullRpmlist()
        if fullList:
            for na in fullList.Union(latest).GetPkgNAList():
                result[na] = release
        removelist = desc.GetRemoves()
        if removelist:
            for name in removelist.GetPkgNames():
                if name in removes:
                    result[name] = release
    return result


#
# Determines if a bundle is "fresh", and if not, returns what's newer.
# Returns: (newerBundles, newerPkg)
# newerBundles:  list of the bundles that obsolete this one
# newerPkg    :  dict keyed by name.arch or name for removable pkgs.
#                The bundle ID that obsoleted that particular pkg.
#
# NB: newerPkg cannot compute when an <rpmlist> obsoletes a <removes>,
#     due to lack of arch info in <removes>.  Re-introducing an RPM
#     we have previously removed should never occur though.
#
def SupersedingBundleInfo(desc, latest, removes, bundleIDs, allBundles):
    """ Determine whether a bundle is still "fresh", and if not, what is newer.

    desc        - descriptor of the bundle being examined
    latest      - baseline Rpmlist (as from GetRpmBaseline())
    removes     - removes dict from GetRpmBaseline(); not consulted here
    bundleIDs   - newest-bundle map (as from GetLatestBundles()), keyed by
                  name.arch for rpms and by name for removed packages
    allBundles  - dict bundle ID -> record with a 'desc' entry, passed on
                  to FilterOlderBundles()

    Returns (newerBundles, newerPkg):
      newerBundles - IDs of the bundles obsoleting this one; only filled in
                     when none of this bundle's rpms remain in latest
      newerPkg     - dict keyed by name.arch (or by name, for packages a
                     newer bundle removes) giving the superseding bundle ID
    """
    newerBundles = []
    newerPkg = {}
    bundleID = desc.GetRelease()
    #
    # Figure out superseding RPMs ... compute newerPkg
    #
    rpmlist = desc.GetFullRpmlist()
    if rpmlist:
        # Only bundles applicable to this bundle's x.y.z version may
        # obsolete its rpms.  FilterOlderBundles() returns a dict restricted
        # to this rpmlist's name.arch keys, so it is kept apart from
        # bundleIDs: the name-keyed (removed package) entries must still be
        # looked up in the full map, or they could never match.
        rpmBundleIDs = FilterOlderBundles(desc, bundleIDs, allBundles, rpmlist)
        for na in rpmlist.GetPkgNAList():
            if na in rpmBundleIDs and rpmBundleIDs[na] != bundleID:
                newerPkg[na] = rpmBundleIDs[na]
            name = utils.NameFromNA(na)
            if name in bundleIDs and bundleIDs[name] != bundleID:
                newerPkg[name] = bundleIDs[name]
    removelist = desc.GetRemoves()
    if removelist:
        for name in removelist.GetPkgNames():
            if name in bundleIDs and bundleIDs[name] != bundleID:
                newerPkg[name] = bundleIDs[name]

    #
    # if none of descriptor's RPMs are the latest, feed list
    # of obsoleting bundles
    #
    if rpmlist and len(rpmlist.Union(latest)) == 0:
        for newer in newerPkg.values():
            if newer not in newerBundles:
                newerBundles.append(newer)

    return newerBundles, newerPkg

# 'x.y.z' prefix of an upgrade-path required ID, e.g. '3.0.2-52542'.
# Both dots are escaped: an unescaped one also matched '3.5x1', whose
# int('5x1') conversion then raised ValueError.
_REQID_VERSION_RE = re.compile(r'^(\d)\.(\d)\.(\d+)')


def _ReqIdVersions(desc):
    """ Return (x, y, z) int tuples parsed from the 'x.y.z' prefixes of
        desc's upgrade-path required IDs, in order; other IDs are skipped.
    """
    versions = []
    for reqId in desc.GetUpgradePaths().GetRequiredIDs():
        mo = _REQID_VERSION_RE.match(reqId)
        if mo:
            versions.append(tuple(map(int, mo.groups())))
    return versions


def FilterOlderBundles(currentBundle, bundleIDs, allBundles, rpmlist):
    """ Keep only the bundles that may obsolete currentBundle, i.e. drop
        bundles applicable to an older x.y.z version than currentBundle.

        currentBundle:  descriptor of a bundle that may be obsoleted by
                        one or more bundles in bundleIDs.values()
        bundleIDs:      latest rpm (name.arch) to bundle ID mapping
        allBundles:     dict bundle ID -> record whose 'desc' is that
                        bundle's descriptor
        rpmlist:        currentBundle's rpmlist; only its name.arch keys
                        are kept from bundleIDs

        A bundle's versions are the 'x.y.z' prefixes of its upgrade-path
        required IDs.  currentBundle's version is the greatest of its own;
        another bundle is allowed if any of its versions is >= that.
        Bundles without any such version are dropped.

        returns:
          - bundleIDs itself when rpmlist is empty/None;
          - bundleIDs restricted to rpmlist when currentBundle's version
            cannot be determined;
          - otherwise a new na:bundle dictionary excluding older bundles.
    """
    # If the current bundle doesn't have any RPMs it cannot be
    # implicitly obsoleted.
    if not rpmlist:
        return bundleIDs
    bundleIDs = dict([(na, bundleIDs[na])
                      for na in rpmlist.GetPkgNAList()
                      if na in bundleIDs])

    release = currentBundle.GetRelease()
    # The greatest required x.y.z is the version this bundle applies to.
    currentVersion = (0, 0, 0)
    for version in _ReqIdVersions(currentBundle):
        if version > currentVersion:
            currentVersion = version

    if currentVersion == (0, 0, 0):
        log.debug("Can't determine version of bundle: %s" % release)
        return bundleIDs

    allowedBundles = {}
    for na, bundleID in bundleIDs.items():
        versions = _ReqIdVersions(allBundles[bundleID]['desc'])
        for version in versions:
            if version >= currentVersion:
                allowedBundles[na] = bundleID
                break
        else:
            # Only report bundles that really are excluded for being
            # older (unversioned bundles are dropped silently, as before).
            if versions:
                log.debug("Bundle %s cannot obsolete %s" % \
                          (bundleID, release))

    return allowedBundles

#
# Delete files from a directory for which time1 <= modification time < time2
# time1, time2:  seconds since Epoch (as returned by time())
#
def RmdirByTime(dirname, time1, time2):
    """ Delete files from dirname for which time1 <= modification time < time2.
    """
    log.debug('Removing files %s <= mod time < %s' % (time1, time2))
    log.debug('  iow...       %s <= mod time < %s' % (FormatTime(time1), FormatTime(time2)))
    for f in os.listdir(dirname):
        path = os.path.join(dirname, f)
        mtime = os.stat(path)[stat.ST_MTIME]
        if os.path.isfile(path) and time1 <= mtime < time2:
            try:
                os.unlink(path)
            except EnvironmentError, (enum, msg):
                log.warn('Could not delete %s: %s' % (path, msg))
    

#
# Verify that there is sufficient free space on given file system.
# 
# Returns: nothing on success
# Throws: errors.DiskSpaceError on insufficient space or error.
#
# Requires: path on partition mountpoint, minimum free space in bytes.
# Note that 'partition' can be either a mountpoint, or any sub-directory of the
# partition's mount point.
#
def CheckPartitionFreeSpace(partition, bytes):
    try:
        fsstats = os.statvfs(partition)
    except OSError, e:
        raise errors.DiskSpaceError('Failed getting %s filesystem stats: %s'
                                    % (partition, e))
    spacefree = fsstats[F_BSIZE] * fsstats[F_BAVAIL]
    if spacefree < bytes:
        raise errors.DiskSpaceError('Insufficient free space on %s '
                                    'partition. Try removing unneeded '
                                    'files or re-partitioning the disk to '
                                    'create more free space in %s '
                                    'directory.' % (partition, partition))


#
# Make sure that /bin, /sbin and /usr/sbin are in the PATH variable
#
defPaths = ('/bin', '/sbin', '/usr/sbin')

def CheckEnvPath(paths=defPaths):
    """ Append every directory in `paths` that is missing from $PATH.

    Existing PATH entries and their order are preserved.  If PATH is
    unset or empty, it is built from `paths` alone: splitting '' would
    yield a single empty entry, and POSIX treats an empty PATH entry as
    the current directory.
    """
    current = os.environ.get('PATH', '')
    if current:
        syspath = current.split(':')
    else:
        syspath = []
    for path in paths:
        if path not in syspath:
            syspath.append(path)
    os.environ['PATH'] = ':'.join(syspath)
    log.debug('PATH env variable = %s' % (os.environ['PATH']))


#
# See if any VMs are running.
# Returns FALSE if none are running or not booted into vmkernel.
# Throws LogCommandError if grep return code is 2 or above
#
def HaveRunningVMs():
    """ Check /proc/vmware/sched/cpu for running VMs """
    if os.access(VMCPUINFO, os.F_OK):
        try:
            out = system.LogCommand('grep vmware-vmx %s' % (VMCPUINFO), returnOutput=1)
        except system.LogCommandError, e:
            #
            # Return code from grep is 1 if no matches found
            #
            if e.res>>8 == 1:
                return False
            else:
                raise
        return out != ''
    #
    # vmkernel is not loaded if we can't access cpu proc node
    #
    return False


#
# See if the ESX host is in Maintenance Mode
# Returns:
#   False     - inMaintenanceMode is false or setting cannot be grepped
# Throws:
#   LogCommandError - if vimsh can't be executed
#
def InMaintenanceMode():
    """ See if the ESX Host is in Maintenance Mode """
    try:
        out = system.LogCommand("vimsh -ne 'hostsvc/runtimeinfo' | grep Maintenance",
                                returnOutput=1)
    except system.LogCommandError, e:
        if e.res>>8 == 1:
            return False
        else:
            raise

    setting, value = out.split('=')
    if value.strip() == 'true':
        return True
    else:
        return False
    

########################################################################
#
# Text formatters for InstallSession hashes saved to the PatchDB
# The hash structure is defined by InstallSession.ToHash()
#
# Other supporting functions for the info and query commands
#

def FormatHashVerbose(instHash, printRpms=0):
    """ Return a multiline string formatted from an install hash.

    instHash   - install record (structure defined by InstallSession.ToHash());
                 reads 'desc' and 'options', 'firsttime'/'lasttime' when
                 present, and with printRpms also 'installed' and 'removed'
    printRpms  - if true, also list the RPMs installed, skipped/not yet
                 installed, and removed

    Without a descriptor (or with one lacking a full rpmlist) there is no
    version information, so only the removed RPMs are listed.
    """
    s = ''
    d = instHash['desc']
    if d:
        s += 'Product       : %s\n' % (d.GetProduct())
        s += 'Vendor        : %s (%s)\n' % (d.GetVendor(), d.GetContact())
        s += 'Bundle ID     : %s\n' % (d.GetRelease())
        s += 'Release Date  : %s\n' % (d.GetReleasedate())
        s += 'Summary       : %s\n' % (d.GetSummary())
        s += 'Description   :\n%s\n'   % (d.GetDescription())
        #
        # Obtain dependency info and spit out install flags too
        #
        paths = d.GetUpgradePaths()
        s += 'Requires      : %s\n' % ', '.join(paths.GetRequiredIDs())
        s += 'Conflicts with: %s\n' % ', '.join(paths.GetConflictIDs())
        s += 'Obsoletes     : %s\n' % ', '.join(paths.GetObsoletedIDs())
        s += 'Will reboot after install : %s\n' % FormatBool(paths.IsRebootRequired())
        s += 'Hostd restart required    : %s\n' % FormatBool(paths.RestartHostd())
        s += 'Maintenance Mode required : %s\n' % FormatBool(paths.IsMModeRequired())

    s += 'Bundle URL    : %s\n' % (instHash['options'].get('url', ''))
    if instHash.has_key('firsttime'):
        s += 'Install start : %s\n' % (FormatTime(instHash['firsttime']))
        s += 'Install finish: %s\n' % (FormatTime(instHash['lasttime']))
    #
    # Now, print out list of installed, not installed, and removed
    # RPMs. Reference the version-release in the full rpmlist.
    #
    if printRpms:
        # Previously a missing descriptor raised AttributeError here.
        rpmlist = None
        if d:
            rpmlist = d.GetFullRpmlist()
        if rpmlist:
            instlist = None
            if instHash['installed']:
                instlist = rpmlist.FilterList(instHash['installed'])
                # NOTE(review): Prune() edits rpmlist in place; if
                # GetFullRpmlist() returns the descriptor's own list, this
                # also alters the descriptor -- confirm.
                rpmlist.Prune(instHash['installed'])

            if instlist:
                s += 'RPMs installed:\n'
                s += FormatRpmlist(instlist)

            if rpmlist:
                s += '\nRPMs skipped or not yet installed:\n'
                s += FormatRpmlist(rpmlist)

        removelist = instHash.get('removed')
        if removelist:
            #
            # There is no rpmlist to reference for removed
            # rpms, so just print out name.arch.
            #
            s += '\nRPMs removed:\n'
            s += '\n'.join(['  %s' % (na) for na in removelist])
        s += '\n'

    return s


def FormatHashShort(instHash):
    """ Return a one-line summary of an install hash: bundle ID and install
    finish time (each right-aligned in 20 columns), then at most 40
    characters of the summary.  Missing values show as 'none' / '-' / ''.
    """
    desc = instHash['desc']
    if not desc:
        release, summary = 'none', ''
    else:
        release, summary = desc.GetRelease(), desc.GetSummary()

    finished = '-'
    if 'lasttime' in instHash:
        finished = FormatTime(instHash['lasttime'])

    return '%20s %20s %.40s' % (release, finished, summary)


def FormatShortHeaders():
    """ Return the column-title line laid out like FormatHashShort() rows """
    titles = ('------ Name ------', '--- Install Date ---', '--- Summary ---')
    return '%20s %20s %.40s' % titles


def FormatTime(sec):
    """ Format seconds since the Epoch as local '%X %x' (time, date) text.

    A value that time.localtime() rejects with TypeError (e.g. a string)
    is returned unchanged.
    """
    try:
        local = time.localtime(sec)
    except TypeError:
        return sec
    return time.strftime('%X %x', local)


def FormatBool(val):
    """ Return the string 'False' for a falsy val, 'True' otherwise """
    if not val:
        return 'False'
    return 'True'


def FormatRpmlist(rpmlist):
    """ Return the NVRs of rpmlist's packages, sorted, one per line and
    indented by two spaces (no trailing newline).
    """
    lines = []
    for pkg in rpmlist.GetRpms():
        lines.append('  %s' % pkg.GetNVR())
    lines.sort()
    return '\n'.join(lines)


def FormatExtras(extras, matches, expected):
    """ Describe every wrong-version header in extras, one line each.

    extras    - dict name.arch -> headers with an unexpected version
    matches   - dict name.arch -> header of the expected version, if installed
    expected  - Rpmlist holding the expected version of each package

    A line reads '(duplicate of V-R)' when the expected version is also
    installed, otherwise '(should be V-R)'; V-R is empty when expected
    has no such package.  Returns the concatenated lines.
    """
    lines = []
    for na in extras:
        want = expected.GetPkgByNA(na)
        wantVer = ''
        if want:
            wantVer = '%s-%s' % (want.GetVer(), want.GetRelease())

        if na in matches:
            note = '(duplicate of %s)' % (wantVer)
        else:
            note = '(should be %s)' % (wantVer)

        for hdr in extras[na]:
            nvr = '%s-%s-%s' % (hdr['name'], hdr['version'], hdr['release'])
            lines.append('  %-40s %s\n' % (nvr, note))
    return ''.join(lines)


########################################################################
#
# Helper functions for installation steps.
#
    

#
# A very basic callback for running an RPM transaction.
# Use it by creating an instance and passing it to rpm.ts.run().  E.g.
#     cb = RpmTransactionCb()
#     ts.run(cb, 1)
#
class RpmTransactionCb:
    """ Minimal callback for an rpm TransactionSet.run(): opens a package
    file when rpm asks for it and closes it again afterwards.

    At most one descriptor is open at a time; it is kept in self.fd,
    which is None whenever no file is open.
    """
    def __init__(self):
        self.fd = None

    def __call__(self, reason, amount, total, key, client_data):
        # key is the value given to ts.addInstall(); in TestTransaction()
        # that is the RPM file name.
        if reason == rpm.RPMCALLBACK_INST_OPEN_FILE:
            self._close()
            self.fd = os.open(key, os.O_RDONLY)
            return self.fd
        if reason in (rpm.RPMCALLBACK_INST_CLOSE_FILE,
                      rpm.RPMCALLBACK_INST_START):
            self._close()

    def _close(self):
        """ Close self.fd if a file is open, then forget it. """
        # Compare with None rather than testing truthiness so that file
        # descriptor 0 is closed too; resetting fd makes repeated calls
        # (or a failed os.open) unable to close a stale descriptor.
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None

    def __del__(self):
        self._close()


#
# Do transaction check for list of available RPMs.
#
# Returns: Nothing on success.
# Throws: errors.DiskSpaceError on failure
#
# Requires: rpmlist, a list of RPM filenames to add to the transaction, and
#           removelist, a list of packages to erase in the same transaction.
#
def TestTransaction(rpmlist, removelist, rollback=False,
                    rpmroot='/'):
    ts = rpm.TransactionSet(rpmroot)
    #
    # No verification
    ts.setVSFlags(rpm._RPMVSF_NODIGESTS | rpm._RPMVSF_NOSIGNATURES)
    #
    # Filter everything except diskspace problems.
    probflags = rpm.RPMPROB_FILTER_DISKNODES | rpm.RPMPROB_FILTER_DISKSPACE
    ts.setProbFilter(sys.maxint ^ probflags)
    tsflags = rpm.RPMTRANS_FLAG_TEST
    #
    # Does the disk space check actually take repackaging into consideration?
    # I doubt it.
    if rollback:
        tsflags |= rpm.RPMTRANS_FLAG_REPACKAGE
    ts.setFlags(tsflags)
    for pkg in removelist:
        ts.addErase(pkg)
    for fn in rpmlist:
        try:
            fd = os.open(fn, os.O_RDONLY)
            hdr = ts.hdrFromFdno(fd)
            ts.addInstall(hdr, fn, 'u')
        except IOError, e:
            raise errors.DiskSpaceError('Unable to open RPM file %s' % fn)
        except rpm.error, e:
            raise errors.DiskSpaceError('Error reading RPM header %s: %s' %
                                        (fn, e))
        os.close(fd)
    ts.check()
    ts.order()
    problems = ts.run(RpmTransactionCb(), 1)
    if problems:
        descs = list()
        for problem in problems:
            descs.append(problem[0])
        raise errors.DiskSpaceError(', '.join(descs))


""" Invoke the yum program with command and arguments in a pipe.
Log output to stdout, filter output to log file, and keep relevant
output also: errors (yum doesn't write to stderr), and problems such
as yum taking no action and returning ok error code.
> yum -y options yumcmd pkglist
 yumcmd     : 'upgrade', 'install'
 pkglist    : name.arch list: specific packages to be installed, if any
 xlist      : name.arch list: add '--exclude' option for each of these
 progress   : ProgressBar instance, can be None
 test       : If true, log yum command but do not execute it.
 options    : text options to pass on
 repackage  : bool, if true, add '--repackage' to options
 timeout    : pipe read timeout in seconds
"""
def InvokeYum(yumcmd, pkglist, xlist, progress, test=False, options='',
              repackage=0, timeout=RPM_TIMEOUT):
    log = logging.getLogger('yum')
    
    #
    # Form the exclusion list string using package names -
    #  if name.arch is passed in, select only the name, because
    #  yum 2.0.7 does not like name.arch in the excludes.
    #
    excludestr = ''
    for xpkg in xlist:
        excludestr += '--exclude=%s ' % (utils.NameFromNA(xpkg))
        
    pkgnames = [utils.NameFromNA(na) for na in pkglist]
    pkgnames = ' '.join(pkgnames)
    if repackage:
        options += ' --repackage'
    #
    # Note: -y => assume yes.  -C => run entirely from cache
    # Drop -C flag, which causes error for preinstall, see PR 176929 for detail.
    #
    line = 'yum -c %s -y %s %s %s %s' % (YUMCONF_FILE, excludestr, options, yumcmd, pkgnames)
    log.debug('About to execute [%s]...' % (line))
    if test: return 1

    #
    # TODO: Precompile regex's here for filtering output
    #
    # Here is a list of yum output messages (as of 2.0.7) we may want to filter.
    # They all represent broken deps; except the first comes with kernel package
    # installs.  The goal is to come up with a list of packages we must force install.
    # Also, we should only force install these packages if we have --force defined.
    # 
    ## Errors reported during trial run
    ## file XXXX.so from install of N-V-R conflicts with file from package N-V-R
    #
    needRE = re.compile(r'^Package (\S+) needs (\S+)')
    confRE = re.compile(r'^\s*conflict between (\S+) and ([\w\-]+)')
    errout = ''
    forcelist = []

    #
    # Compiled regexp's (REs) for package statuses
    #
    statusREs = [('Repackaging', re.compile(r'^Repackaging (\S+)') ),
                 ('Updating',    re.compile(r'^(\S+) \d+ % done') ),
                 ('Finishing',   re.compile(r'^Completing update for (\S+)') ),
                 ('Downloading', re.compile(r'Getting (\S+)\.rpm') ),
                 ]

    #
    # Passing a list instead of a string enables us to bypass
    # the shell.  One less layer to worry about, and much more secure.
    # popen4 combines stderr and stdout into one file handle.
    #
    # Use splitlines() string method to break up the \r's (Ctrl-M's).
    # readline() output from yum may have lots of \r's inlined.
    # Act like the console and only keep the part after the final \r.
    #
    yumout = system.Ptypipe(line.split())

    while 1:
        #
        # a SIGCHLD could interrupt the select syscall.  Time to end.
        #
        try:
            inlist, outlist, exlist = select.select([yumout], [], [], timeout)
        except select.error:
            break
        
        if not inlist and not outlist and not exlist:
            msg = '%d sec timeout exceeded running [%s]' % (timeout, line[:40])
            log.error(msg)
            raise errors.TimeoutError(msg, timeout)
        
        try:
            out = yumout.readline()
        # IOError signals EOF
        except IOError:
            break
        else:
            outlast = out.splitlines()[-1]
            
        log.info('| ' + outlast)

        # Filter out package status messages
        if progress:
            for verb, compiledRE in statusREs:
                match = compiledRE.search(outlast)
                if match:
                    progress.DoneStep('%s %s' % (verb, match.group(1)) )

        # Filter out dependency and conflict errors
        confmatch = confRE.search(outlast)
        needmatch = needRE.search(outlast)
        if confmatch:
            errout += outlast + '\n'
            forcelist += confmatch.groups()
        if needmatch:
            errout += outlast + '\n'
            forcelist += needmatch.groups()
            
    #
    # If the child process already died, IOError is returned (PR 92188)
    #
    try:
        res = yumout.close()
    except IOError:
        res = 0
        pass

    system.ClosePtypipe()
    if res is None: res = 0
    if (res >> 8) != 0:
        raise system.LogCommandError(line, res >> 8, 0, errout)
    elif errout:
        raise errors.YumGenericError(line, errout, forcelist)

    return 1


#
# Use the 'rpm' command to directly install a package.
#
# Returns:  None
# Throws:   exception on failure
#
def InvokeRpm(localpath, rpmobj, options):
    """ localpath - path to rpm in local cache """
    rpmlog = logging.getLogger('rpm')
    #
    # When providing a path with spaces in it, instead of simply
    # escaping the spaces or enclosing the entire path in quotes,
    # RPM apparently needs us to do both. (This has been fixed in
    # RPM v4.4.8-0.4.)
    #
    fullPath = '"%s"' % localpath.replace(" ", "\\ ")

    #
    # Call rpm command.  
    #
    cmd = 'rpm -Uv %s %s' % (options, fullPath)
    try:
        out = system.LogCommand(cmd, returnOutput=1, timeout=RPM_TIMEOUT)
    except system.LogCommandError, e:
        #
        # rpm exits with code 256 (1) when installing kernel-vmnix, even
        # though the install succeeds.  Verify the name-rel-ver that's
        # printed alone on one of the lines
        #
        escapedstr = re.escape(rpmobj.GetNVR())
        if re.search(r'^\s*%s\s*$' % (escapedstr), e.output, re.MULTILINE):
            pass
        else:
            rpmlog.error('Did not find %s in rpm -Uv output.' % (rpmobj.GetNVR()))
            raise

    rpmlog.info('Installed %s' % (rpmobj.GetNVR()))

#
# Back up system files with tar before a new install.
#
# The archive is written to spooldirpath/spooldir/tarfile; the spool
# directory is created first if it does not exist.
#
# Returns:  None
# Throws:   OSError if the spool directory cannot be created; whatever
#           system.LogCommand raises if the tar pipeline fails
#
def BackupSystemFiles(spooldirpath, spooldir, tarfile,
                      optional=OPTIONAL_BACKUP_FILES,
                      mandatory=MANDATORY_BACKUP_FILES):
    """ Archive the optional and mandatory backup files into a gzipped tar.

    spooldirpath - parent directory of the spool directory
    spooldir     - name of the spool directory under spooldirpath
    tarfile      - name of the archive created inside the spool directory
    optional     - space-separated shell patterns; missing ones are skipped
    mandatory    - space-separated shell patterns; tar errors if one is missing
    """
    spoolfullpath = os.path.join(spooldirpath, spooldir)

    # Test for an actual directory rather than a bare name in the listing,
    # so a stray regular file with the same name is not mistaken for it.
    if not os.path.isdir(spoolfullpath):
        log.debug("Directory %s not present at %s, so creating it..." % \
                  (spooldir, spooldirpath))
        os.mkdir(spoolfullpath)

    tarfilewithpath = os.path.join(spoolfullpath, tarfile)
    #
    # List the optional files and feed only those that exist to tar, so
    # missing optional files do not cause an error.  The mandatory files
    # are given to tar directly, so an error is raised if any is missing.
    # -P keeps the absolute path information in the archive.
    #
    # NOTE(review): if the optional list ever grew long enough for xargs
    # to split it across several tar invocations, each invocation would
    # recreate the archive; fine for the short default list.
    #
    cmd = 'ls %s | xargs tar -P -czf %s %s' % \
          (optional, tarfilewithpath, mandatory)
    system.LogCommand(cmd, returnOutput=1)
    log.info('All system files backed up...')


#
# Restore the backed-up system files from a tar archive when reverting
# to the old install.
#
# Returns:  None
# Throws:   whatever system.LogCommand raises if tar fails
#
def RestoreSystemFiles(tarfile):
    """ Extract the gzipped tar archive 'tarfile' back onto the system. """
    # -P: the archive holds absolute paths, so extract to those absolute
    # locations instead of relative to the current directory.
    system.LogCommand('tar -P -xzf %s' % (tarfile), returnOutput=1)
    log.info('All system files restored...')


#
# Wrapper to force install a list of packages
#
# Side effects:  Each package in NAlist is installed, in order, with
#                rpm --force --nodeps.  'Nuff said.
#
def ForceInstallRpms(pkgEntry, rpmlist, NAlist, test=0, repackage=0,
                     progressbar=None):
    """ rpm --force --nodeps on each package in NAlist.

    pkgEntry    - dict of DepotEntry instances indexed by name.arch; the
                  entry's localLocation is the rpm file that gets installed
    rpmlist     - Rpmlist instance containing at least the rpms in NAlist
    NAlist      - name.arch strings to install, in this order
    test        - if true, add '--test' to the rpm command
    repackage   - if true, add '--repackage' to the rpm command
    progressbar - optional; StartStep(na) and DoneStep() are called around
                  each package

    Raises AssertionError if a name.arch in NAlist is not in rpmlist, and
    propagates whatever InvokeRpm raises when an install fails.
    """
    rpm_opts = '--force --nodeps'
    if test:
        rpm_opts += ' --test'
    if repackage:
        rpm_opts += ' --repackage'
    for na in NAlist:
        if progressbar:
            progressbar.StartStep(na)
        log.debug('Force installing package %s' % (na))
        pkg = rpmlist.GetPkgByNA(na)
        # Explicit check instead of 'assert': asserts are stripped under -O,
        # which would let a missing package reach InvokeRpm as None.
        if not pkg:
            raise AssertionError('Package %s not in rpmlist!' % (na))
        localpath = pkgEntry[na].localLocation
        InvokeRpm(localpath, pkg, rpm_opts)
        if progressbar:
            progressbar.DoneStep()


#
# Erase a list of packages via rpm -e.
#
# What should happen if other RPMs depend on an RPM in the removes list?
#
# 1) RPM should be removed even if deps broken.
# 2) RPM should not be removed if deps broken.    Print a warning.
# 3) RPM and all RPMs that depend on it should be removed.  (yum erase behavior)
#
# 2) seems like the best and least-intrusive option, but we go with 1) here
# for the sake of compatibility with anaconda and upgrade.pl.  Uniformity
# in upgrade behavior and less code change is better at this point.
#
def RemoveRpms(NAlist, test=0, progressbar=None):
    """ rpm -e --nodeps on a name.arch list of packages.
    If test is true, then add '--test' option.
    """
    rpm_opts = '-e --nodeps'
    if test:
        rpm_opts += ' --test'
    for na in NAlist:
        rpmname = utils.NameFromNA(na)
        if progressbar:
            progressbar.StartStep(rpmname)
        try:
            system.LogCommand('rpm %s %s' % (rpm_opts, rpmname))
        except system.LogCommandError, e:
            #
            # Mimic upgrade.pl behavior - ignore rpm errors
            #
            pass

        if progressbar:
            progressbar.DoneStep()


#XXX: Refactor using DiffRpmdbHeaders above?
#
# Look up, in the host's RPM database, the installed headers for every
# package in rpmlist.
#
# Returns two dicts (installed, matches):
#  installed[name.arch] = [ every installed header whose *name* equals the
#                           package's name, filed under the package's
#                           name.arch ]
#  matches[name.arch]   = the installed header equal to the rpm in rpmlist
#
# len(installed[name.arch]) > 1 means several installed headers share the
# package's name (multiple versions of the same pkg).
# NOTE(review): the database lookup is by name only, so headers of other
# arches may also land under this name.arch -- confirm callers expect that.
#
# TODO: Also match Epoch
#
def GetDbInstalledHeaders(rpmlist):
    """ Return package headers in the host DB that match our package list """
    installed, matches = {}, {}
    if not rpmlist:
        return installed, matches

    ts = rpm.TransactionSet()
    for pkg in rpmlist.GetRpms():
        na = pkg.GetNameArch()
        # Several versions of one package may be installed side by side:
        # keep every one, and remember the one equal to pkg.
        for hdr in ts.dbMatch('name', pkg.GetName()):
            installed.setdefault(na, []).append(hdr)
            if pkg == Rpm(hdr=hdr):
                matches[na] = hdr

    log.debug("Matching package name %d and NAVR %d" % (len(installed), len(matches)) )
    # Drop the transaction set promptly to release the rpmdb lock.
    del ts
    return installed, matches


#
# Initiate a rollback.
#
# XXX:
# The only failsafe method is reinstalling the old RPMs and
# restoring backed-up config files.
# We may also experiment with rollbacks.
#

#
# Arrange for the install to continue in state newState after a reboot.
# This only asks system.SetLinuxBoot() to run this same esxupdate script,
# with newState as its argument, at the next boot.
#
# The script path is made absolute (after expanding '~') from sys.argv[0].
# TODO: XXX: This won't work when esxupdate is installed as an RPM
#     and users can simply type 'esxupdate' from anywhere. Use
#     a constant /usr/sbin/esxupdate and assume rpm install???
#
def SetRebootState(newState):
    """ Have esxupdate called after reboot with arg newState """
    script = os.path.expanduser(sys.argv[0])
    script = os.path.abspath(script)
    system.SetLinuxBoot(script, newState)


#
# Re-start esxupdate via exec().
# The new process replaces the current one and inherits its environment.
# The log level and argv[0] are preserved (the same esxupdate is exec'ed).
# Esxupdate starts with the 'restore' command, which seeks to restore
# a saved session and continue the installation.
#
# Returns:
#        Never returns.
# Throws:
#        OSError if the exec fails.
# Side effects:
#        See above.
#
def RestartProcess():
    """ Restarts esxupdate """
    prog = sys.argv[0]
    log.info('Restarting %s...' % (prog))
    # Build the argument vector directly instead of formatting one string
    # and splitting it on spaces, so a log-level value can never be broken
    # into several arguments.
    argv = [prog, '-v', '%s' % (system.LogGetLevel(),), 'restore']
    os.execvp(prog, argv)


#
# Rebuild the yum 'headers' subdirectory of bundleDir with yum-arch.
# Note: the -l option is not needed but is kept to ensure 100%
# compatibility with makerepo.py
# Side effects:
#        Any existing headers, .newheaders and .oldheaders subdirectories
#        of bundleDir are removed first.
#
def BuildYumHeaders(bundleDir):
    """ Recreate bundleDir/headers by running 'yum-arch -l bundleDir'. """
    for subdir in ('headers', '.newheaders', '.oldheaders'):
        system.LogCommand('rm -rf %s/%s' % (bundleDir, subdir))
    system.LogCommand('yum-arch -l %s' % (bundleDir))


#
# Clean up a yum or rpm hang.
# Returns:
#        None
# Side effects:
#        The yum child process will be killed.  If rpm itself hangs,
#        we cannot find the child pid.
#
def DumpChildren():
    """ Kill the yum child process. """
    if system.GetPtyPid() > 0:
        #
        # Don't use LogCommand here as it could hang from piping/forking
        # weirdness after a hang
        #
        log.debug('Killing yum child process %s...' % (system.GetPtyPid()))
        os.system('kill -9 %s' % (system.GetPtyPid()))


#
# Common install cleanup.
# Side effects:
#        /etc/vmware/esx.conf.WRITELOCK, /var/run/yum.pid, and /var/lib/rpm/__db*
#        are deleted.
#
def Cleanup(removeRpms=False):
    """ Common install cleanup """
    log.debug('Install cleanup...')
    #
    # Remove any reboot commands we put in rc.local
    #
    system.SetLinuxBoot(None)

    #
    # Clean up after any messy installations
    # Also remove RPMs if specified
    #
    try:
        system.LogCommand('rm -f /etc/vmware/esx.conf.WRITELOCK')
        system.LogCommand('rm -f /var/lib/rpm/__db*')
        if removeRpms:
            system.LogCommand('rm -f %s/*.rpm' % (LOCAL_CACHE_ROOT))
            system.LogCommand('rm -f %s/*/*.rpm' % (LOCAL_CACHE_ROOT))
    except system.LogCommandError, e:
        log.debug("Ignored - error (%d) executing [%s]\n%s" % (e.res>>8, e.cmd, e.output))
        pass

    
########################################################################


def FlushYumCache():
    """Remove all the .hdr files from the headers dir"""
    hdrFiles = glob.glob( YUM_HDR_CACHE +'/*.hdr' )

    if not hdrFiles:
        log.warn('Yum cache already empty')

    for fname in hdrFiles:
        try:
            os.remove( fname )
        except OSError, e:
            raise errors.FileError( fname, str(e) )

#
# Check whether the process that created the lock file is still running.
# If it is, wait sec seconds and check again, stopping as soon as the
# owner is gone or after retry checks.
#
# Returns:
#  True     : Lock file is owned by active process
#  False    : Lock file does not exist, invalid lock file,
#               or the owner of the lock file is not running
#
def LockfileTimeout(file, sec=0, retry=1):
    """Check whether lock file is owned by active process

    file  - path of the lock file
    sec   - seconds to sleep after each check that finds the owner alive
    retry - number of checks to make; 0 (or another false value) is
            treated as 1 so the lock is always checked at least once.
            NOTE(review): a negative value keeps polling until the owner
            exits -- confirm whether any caller relies on that.
    """
    lock = Lock(file)

    # With retry == 0 the loop below never ran, so the function reported
    # the lock as held by an active process without ever looking at it.
    if not retry:
        retry = 1

    while retry:
        pid = lock.ReadPID()
        # 'ps -p PID --no-heading' prints a line only if PID is running.
        if not (pid and len(os.popen(
                        'ps -p ' + str(pid) + ' --no-heading').readline())):
            return False

        log.info('Lock file %s is owned by active process %d.'
                    % (file, pid))
        log.info('Check the status after %d seconds...' % (sec))
        # NOTE(review): this also sleeps after the final check before
        # returning True; kept as-is in case callers rely on the delay.
        time.sleep(sec)
        retry = retry - 1

    return True
    
