#######################################################################
# Copyright (C) 2005 VMWare, Inc.
# All Rights Reserved
########################################################################
#
# Rpmlist.py
#    class module for Rpmlist, a collection of Rpm objects
#
import logging
import os
import errno
import glob
try:
    from elementtree.ElementTree import ElementTree, Element, iselement
except ImportError:
    from xml.etree.ElementTree import ElementTree, Element, iselement

from Rpm import Rpm
from utils import *

# set up logging system
log = logging.getLogger("rpmlist")


########################################################################
#
# Class representing a list of RPM packages in the descriptor.
#
# To use:
#   list = Rpmlist(root.find('rpmlist'))   # from rpmlist element
# OR:
#   list = Rpmlist(dir='/RPMS/', tag='rpmlist')  # from directory
#   root.append(list.ToXml())
# OR:
#   list = Rpmlist(my_list_of_rpm_names)   # from list of package names
#
# A key functionality of this class is the ability to do a diff against another
#  rpmlist.
#
# TODO:  Build a dict indexed by package name to speed up lookup/diff functions
# TODO:  A derived class that builds an Rpmlist off of the RPM database?
#        That would allow for a very easy comparison between the rpmlist
#        in a patch and what's on the host.  Ex:
#          hostlist = repository.HostRpmlist()
#          add, chg, delete = patchlist.DiffList(hostlist)
#
class Rpmlist:
    """Class modelling a list of RPM packages.

    The packages are kept, in order, in the ``rpms`` attribute (a list of
    Rpm objects); ``tag`` is the XML tag used by ToXml() and checked when
    building from an element.
    """

    def __init__(self, dir, tag='rpmlist', ts=None):
        """Constructor for Rpmlist instance.

        list = Rpmlist(None)     Create empty rpmlist object
        list = Rpmlist(elem)     Create from Rpmlist element
        list = Rpmlist(dir)      Read package info from directory dir
        list = Rpmlist(pkglist)  Create from list of pkg names

        Extra args:
         tag='rpmlist'     Tag to use for XML element
         ts                Set verification flags in transactionset; passed
                           through to Rpm() for every file read from dir

        Side effects:
         When an element is passed, it is cleared (dir.clear()) once its
         children have been converted to Rpm objects.

        Files in dir whose Rpm() construction raises EnvironmentError are
        logged and skipped; the remaining files are still read.

        Throws:
         AssertionError    - if an element is passed and its tag != tag
         IOError(ENOENT)   - dir not found or contains no *.rpm files
        """
        self.rpms = []
        self.tag = tag

        if iselement(dir):
            # It's an Element: one Rpm per child element.
            assert dir.tag == tag, "Mismatch: dir.tag = [%s], tag = [%s]" % (dir.tag, tag)
            self.rpms = [Rpm(child) for child in dir]
            # The children have been converted; note this also empties the
            # caller's element.
            dir.clear()

        elif IsStringLike(dir):
            # It's a pathname to a directory of RPMs; add a trailing slash
            # if not present (endswith() avoids IndexError on '').
            if not dir.endswith('/'):
                dir = dir + '/'
            log.log(1, "Rpmlist(dir=%s)", dir)

            rpmfiles = glob.glob(dir + '*.rpm')
            if not rpmfiles:
                raise IOError(errno.ENOENT, "Directory or RPMs not found", dir)

            for path in rpmfiles:
                try:
                    rpmobj = Rpm(path, ts=ts)
                except EnvironmentError:
                    # Fetch the exception via sys.exc_info() rather than
                    # unpacking it as "(errno, msg)": an EnvironmentError
                    # raised with a single argument cannot be unpacked, and
                    # the resulting ValueError would abort the whole scan
                    # instead of skipping just this file.
                    import sys
                    err = sys.exc_info()[1]
                    log.error("%s (%s) while reading header from [%s]",
                              err.strerror or err, err.errno, path)
                    log.info("Skipped [%s] due to errors", path)
                    continue

                self.rpms.append(rpmobj)
                log.log(1, "+ added [%s]", rpmobj.GetBasename())

        elif IsListLike(dir):
            # It's a list of package names.
            self.rpms = [Rpm(name=n) for n in dir]

        log.log(1, "rpms created = %4d for tag='%s'", len(self.rpms), self.tag)

    def __len__(self):
        """ Returns the # of packages in this list.  Use like len(instance) """
        return len(self.rpms)

    def AddRpm(self, rpmobj):
        """ Adds an Rpm object to the end of the rpmlist """
        self.rpms.append(rpmobj)

    def GetRpms(self):
        """ Returns the underlying list of Rpm objects (not a copy). """
        return self.rpms

    def GetPkgNames(self):
        """ Returns a list of package names in this rpmlist, or [] """
        return [pkg.GetName() for pkg in self.rpms]

    def GetPkgNAList(self, reflist=None):
        """ Returns this rpmlist as a list of name.arch strings, or [].

        If reflist is given (and non-empty), each of our rpms is first looked
        up by name in reflist and the arch information from that list is
        used.  Useful when the rpmlist only contains name info.  A pkg in our
        list that is not in reflist is dropped, so the result only covers the
        intersection of the two lists.  If reflist holds a name under several
        arches, the first one found is used.
        """
        if reflist:
            rpms = [reflist.GetPkgByName(pkg.GetName()) for pkg in self.rpms]
            return [pkg.GetNameArch() for pkg in rpms if pkg]
        return [pkg.GetNameArch() for pkg in self.rpms]

    def GetPkgByNA(self, NAStr):
        """ Returns the first rpm object matching the 'name.arch' string, as
        returned by Rpm.GetNameArch().  If NAStr is of the form 'name.', ie
        there is no arch, then only an rpm without arch defined will match.
        Returns None if no match was found.
        """
        for rpm in self.rpms:
            if rpm.GetNameArch() == NAStr:
                return rpm
        return None

    def GetPkgByName(self, name):
        """ Returns the first rpm object matching the package name.
        Returns None if no match was found.
        """
        for rpm in self.rpms:
            if rpm.GetName() == name:
                return rpm
        return None

    def SetRpms(self, listOfRpms):
        """ Replaces the list of Rpm objects (the list itself is stored). """
        assert IsListLike(listOfRpms), "Non-list arg passed to SetRpms"
        self.rpms = listOfRpms

    def DiffList(self, olderList):
        """ Returns (added, changed, deleted) pkg lists when diff'ed against another rpmlist.

        Each of our rpms is matched against olderList first by name.arch
        and, failing that, against an arch-less 'name.' entry in olderList.
          added   - our rpms with no match in olderList
          changed - our rpms whose matched older rpm compares != (the first
                    older rpm with that name.arch is the one compared)
          deleted - one older rpm per name.arch that none of ours matched,
                    in olderList order

        Exclusions: packages in our list that do not have an arch are not
        matched well against packages in olderList that do have one.

        What is returned is a tuple, each a list of Rpm objects.
        """
        added = []
        changed = []

        # For every name.arch in olderList (rpms without arch index as
        # 'name.'):
        #   olderCount - starts at 1 and is decremented for each of our rpms
        #                matched to it; still > 0 afterwards means deleted.
        #   olderFirst - the first older rpm with that name.arch (the same
        #                rpm olderList.GetPkgByNA() would return).
        #   olderOrder - name.archs in first-seen order, for a stable result.
        olderCount = {}
        olderFirst = {}
        olderOrder = []
        for rpm in olderList.GetRpms():
            na = rpm.GetNameArch()
            if na not in olderFirst:
                olderFirst[na] = rpm
                olderOrder.append(na)
            olderCount[na] = 1

        for rpm in self.rpms:
            na = rpm.GetNameArch()
            name = rpm.GetName() + '.'
            if na in olderCount:
                key = na
            elif name in olderCount:
                # The older rpm doesn't have an arch defined.
                key = name
            else:
                # What doesn't match at all in olderList must be a new rpm.
                added.append(rpm)
                continue
            # Subtraction handles duplicate packages in our rpmlist.
            olderCount[key] -= 1
            if rpm != olderFirst[key]:
                changed.append(rpm)

        deleted = [olderFirst[na] for na in olderOrder if olderCount[na] > 0]
        return added, changed, deleted

    def GetUpdates(self, baselist):
        """ Compares this Rpmlist to baselist, and returns a tuple of Rpmlists,
        (updates, new, downgrades) of the changes(updates) from baselist.
         updates: the Rpms of > version than those on baselist
         new    : these Rpms are not in baselist
         downgrades: the Rpms of < version than those on baselist

        baselist -- an Rpmlist with the baseline of Rpms for comparison.  This
                    could be all the Rpms on the system or some other baseline.
        """
        updates = Rpmlist(None)
        newlist = Rpmlist(None)
        older = Rpmlist(None)
        added, changed, deleted = self.DiffList(baselist)
        log.log(1, "GetUpdates: added=%s\nchanged=%s\ndeleted=%s",
                added, changed, deleted)
        newlist.SetRpms(added)

        for pkg in changed:
            # Compare against the same baseline rpm DiffList matched: first by
            # name.arch, then an arch-less 'name.' entry.  A lookup by name
            # alone could pick another arch's package (with a different
            # version) when baselist is multi-arch.
            basepkg = baselist.GetPkgByNA(pkg.GetNameArch())
            if not basepkg:
                basepkg = baselist.GetPkgByNA(pkg.GetName() + '.')
            if basepkg:
                if pkg > basepkg:
                    updates.AddRpm(pkg)
                else:         # must be < cuz these are changed
                    older.AddRpm(pkg)

        return updates, newlist, older

    def Union(self, otherlist):
        """ Returns a NEW Rpmlist (tag 'union') of our Rpms that have an
        exactly matching (==) Rpm of the same name in otherlist, i.e. the
        intersection of the two lists.  The instance from own list is used.
        """
        newlist = Rpmlist(None, tag='union')
        # Keep every candidate per name: a name may appear under several
        # arches, and keeping only one would miss exact matches on the others.
        otherByName = {}
        for rpm in otherlist.rpms:
            otherByName.setdefault(rpm.GetName(), []).append(rpm)
        for rpm in self.rpms:
            for candidate in otherByName.get(rpm.GetName(), ()):
                if rpm == candidate:
                    newlist.AddRpm(rpm)
                    break
        return newlist

    def FilterList(self, NAlist):
        """ Returns a NEW Rpmlist with only the Rpm instances whose name.arch is in
        NAlist.  Returns an empty Rpmlist (len == 0) if no instances match NAlist.
        NAlist   - list of name.arch strings or name strings
                   or dict keyed by name.arch or name
        Original list is unchanged.
        """
        newlist = Rpmlist(None)
        for rpm in self.rpms:
            if rpm.GetNameArch() in NAlist or rpm.GetName() in NAlist:
                newlist.AddRpm(rpm)
        log.log(1, 'Filtered list of %d down to %d, with NAlist of %d',
                len(self), len(newlist), len(NAlist))
        return newlist

    def Prune(self, NAlist):
        """ Removes the Rpm instances whose name.arch are in NAlist.
        NAlist can also be a list of names alone.
        Returns the number of instances removed.
        """
        kept = [rpm for rpm in self.rpms
                if not (rpm.GetNameArch() in NAlist or rpm.GetName() in NAlist)]
        numpruned = len(self.rpms) - len(kept)
        # Update in place so references obtained through GetRpms() see
        # the pruned list too.
        self.rpms[:] = kept
        log.log(1, 'Pruned %d rpms from list %s', numpruned, self.tag)
        return numpruned

    def Merge(self, otherlist):
        """ Merge with otherlist, also an Rpmlist.  Packages strictly in otherlist are
        appended (in otherlist order);  packages with same name.arch in otherlist
        clobber ones on this list.  Each name.arch from otherlist replaces at most
        one of our rpms (the first with that name.arch); if otherlist repeats a
        name.arch, its last occurrence is used.
        """
        # Index otherlist by name.arch, remembering first-seen order so the
        # extras are appended deterministically.
        otherNA = {}
        order = []
        for rpm in otherlist.GetRpms():
            na = rpm.GetNameArch()
            if na not in otherNA:
                order.append(na)
            otherNA[na] = rpm

        # Replace our rpms with same name.arch; pop() so each entry is used once.
        changed = 0
        for i, rpm in enumerate(self.rpms):
            na = rpm.GetNameArch()
            if na in otherNA:
                self.rpms[i] = otherNA.pop(na)
                changed += 1

        # Add extras in otherlist to ourselves.
        extras = [otherNA[na] for na in order if na in otherNA]
        self.rpms.extend(extras)
        log.log(1, "Grew list '%s' to %d, added %d, changed %d",
                self.tag, len(self), len(extras), changed)

    def Dump(self):
        """ Prints the index and basename of every rpm element to stdout,
        or a notice if the list is empty.
        """
        if not self.rpms:
            print("self.rpms is empty")
            return
        # enumerate() gives each entry its own position; list.index() would
        # report the first equal (==) entry's index for duplicates.
        for i, rpm in enumerate(self.rpms):
            print("%3d:  %s" % (i, rpm.GetBasename()))

    def ToXml(self):
        """ Serializes the list of Rpms into XML, returning an Element instance.
        """
        return ElementList(self.tag, self.rpms)
