123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- from ._compat import Container, Iterable, Sized, Hashable
- from functools import reduce
- from pyrsistent._pmap import pmap
def _add_to_counters(counters, element):
    """
    Return a counter mapping in which the tally for ``element`` is one higher.

    ``counters`` must provide ``get(key, default)`` and a ``set(key, value)``
    whose return value is the updated mapping; within this module it is the
    map created by ``pmap()``. An element that has no entry yet is treated as
    having a tally of zero, so it ends up with a tally of one.

    The argument order (accumulator first, item second) lets this function be
    passed straight to ``reduce``.
    """
    incremented = counters.get(element, 0) + 1
    return counters.set(element, incremented)
class PBag(object):
    """
    A persistent bag/multiset type.

    Requires elements to be hashable, and allows duplicates, but has no
    ordering. Bags are hashable.

    Do not instantiate directly, instead use the factory functions :py:func:`b`
    or :py:func:`pbag` to create an instance.

    Some examples:

    >>> s = pbag([1, 2, 3, 1])
    >>> s2 = s.add(4)
    >>> s3 = s2.remove(1)
    >>> s
    pbag([1, 1, 2, 3])
    >>> s2
    pbag([1, 1, 2, 3, 4])
    >>> s3
    pbag([1, 2, 3, 4])
    """

    __slots__ = ('_counts', '__weakref__')

    def __init__(self, counts):
        # ``counts`` maps each element to the number of times it occurs.
        # The methods below call get/set/remove/evolver/iteritems/itervalues
        # on it; the factory functions in this module pass a ``pmap``.
        self._counts = counts

    def add(self, element):
        """
        Add an element to the bag.

        Returns a new bag; this bag is left untouched.

        >>> s = pbag([1])
        >>> s2 = s.add(1)
        >>> s3 = s.add(2)
        >>> s2
        pbag([1, 1])
        >>> s3
        pbag([1, 2])
        """
        return PBag(_add_to_counters(self._counts, element))

    def update(self, iterable):
        """
        Update bag with all elements in iterable.

        Returns a new bag containing the extra elements, or this very bag
        when ``iterable`` is falsy (e.g. an empty list).

        >>> s = pbag([1])
        >>> s.update([1, 2])
        pbag([1, 1, 2])
        """
        if iterable:
            return PBag(reduce(_add_to_counters, iterable, self._counts))

        return self

    def remove(self, element):
        """
        Remove an element from the bag.

        Exactly one occurrence is removed and a new bag is returned.

        :raises KeyError: if ``element`` is not in the bag.

        >>> s = pbag([1, 1, 2])
        >>> s2 = s.remove(1)
        >>> s3 = s.remove(2)
        >>> s2
        pbag([1, 2])
        >>> s3
        pbag([1, 1])
        """
        if element not in self._counts:
            raise KeyError(element)

        # Fetch the current tally once instead of once per branch.
        current = self._counts[element]
        if current == 1:
            # Last occurrence: drop the key so membership tests stay accurate.
            newc = self._counts.remove(element)
        else:
            newc = self._counts.set(element, current - 1)
        return PBag(newc)

    def count(self, element):
        """
        Return the number of times an element appears.

        Elements that are not in the bag have a count of zero.

        >>> pbag([]).count('non-existent')
        0
        >>> pbag([1, 1, 2]).count(1)
        2
        """
        return self._counts.get(element, 0)

    def __len__(self):
        """
        Return the length including duplicates.

        >>> len(pbag([1, 1, 2]))
        3
        """
        return sum(self._counts.itervalues())

    def __iter__(self):
        """
        Return an iterator of all elements, including duplicates.

        Every element is yielded as many times as its count.

        >>> list(pbag([1, 1, 2]))
        [1, 1, 2]
        >>> list(pbag([1, 2]))
        [1, 2]
        """
        for elt, count in self._counts.iteritems():
            for _ in range(count):
                yield elt

    def __contains__(self, elt):
        """
        Check if an element is in the bag.

        >>> 1 in pbag([1, 1, 2])
        True
        >>> 0 in pbag([1, 2])
        False
        """
        return elt in self._counts

    def __repr__(self):
        return "pbag({0})".format(list(self))

    def __eq__(self, other):
        """
        Check if two bags are equivalent, honoring the number of duplicates,
        and ignoring insertion order.

        :raises TypeError: if ``other`` is not exactly a ``PBag``.

        >>> pbag([1, 1, 2]) == pbag([1, 2])
        False
        >>> pbag([2, 1, 0]) == pbag([0, 1, 2])
        True
        """
        if type(other) is not PBag:
            raise TypeError("Can only compare PBag with PBags")
        return self._counts == other._counts

    def __ne__(self, other):
        """
        Inverse of ``==``.

        Python 2 does not derive ``!=`` from ``__eq__``, so without this
        method two equal bags would still compare as different there. On
        Python 3 this matches the automatically derived behaviour, including
        the ``TypeError`` for operands that are not a ``PBag``.

        >>> pbag([1, 2]) != pbag([2, 1])
        False
        >>> pbag([1, 1, 2]) != pbag([1, 2])
        True
        """
        return not (self == other)

    def __lt__(self, other):
        raise TypeError('PBags are not orderable')

    # All ordering comparisons are rejected in the same way.
    __le__ = __lt__
    __gt__ = __lt__
    __ge__ = __lt__

    # Multiset-style operations similar to collections.Counter

    def __add__(self, other):
        """
        Combine elements from two PBags.

        The count of every element is the sum of its counts in both bags.
        Returns ``NotImplemented`` when ``other`` is not a ``PBag``.

        >>> pbag([1, 2, 2]) + pbag([2, 3, 3])
        pbag([1, 2, 2, 2, 3, 3])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        # Start from our own counts; only elements of ``other`` need updating.
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            result[elem] = self.count(elem) + other_count
        return PBag(result.persistent())

    def __sub__(self, other):
        """
        Remove elements from one PBag that are present in another.

        Counts are subtracted element by element; an element whose count
        would drop to zero or below is left out of the result entirely.
        Returns ``NotImplemented`` when ``other`` is not a ``PBag``.

        >>> pbag([1, 2, 2, 2, 3]) - pbag([2, 3, 3, 4])
        pbag([1, 2, 2])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            newcount = self.count(elem) - other_count
            if newcount > 0:
                result[elem] = newcount
            elif elem in self:
                # Only remove keys we actually hold; elements that exist
                # solely in ``other`` have no entry to delete.
                result.remove(elem)
        return PBag(result.persistent())

    def __or__(self, other):
        """
        Union: Keep elements that are present in either of two PBags.

        The count of every element is the larger of its counts in the two
        bags. Returns ``NotImplemented`` when ``other`` is not a ``PBag``.

        >>> pbag([1, 2, 2, 2]) | pbag([2, 3, 3])
        pbag([1, 2, 2, 2, 3, 3])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        # Elements found only in this bag keep their count untouched.
        result = self._counts.evolver()
        for elem, other_count in other._counts.iteritems():
            count = self.count(elem)
            newcount = max(count, other_count)
            result[elem] = newcount
        return PBag(result.persistent())

    def __and__(self, other):
        """
        Intersection: Only keep elements that are present in both PBags.

        The count of every element is the smaller of its counts in the two
        bags. Returns ``NotImplemented`` when ``other`` is not a ``PBag``.

        >>> pbag([1, 2, 2, 2]) & pbag([2, 3, 3])
        pbag([2])
        """
        if not isinstance(other, PBag):
            return NotImplemented
        # Built up from an empty map: only shared elements are carried over.
        result = pmap().evolver()
        for elem, count in self._counts.iteritems():
            newcount = min(count, other.count(elem))
            if newcount > 0:
                result[elem] = newcount
        return PBag(result.persistent())

    def __hash__(self):
        """
        Hash based on value of elements.

        >>> m = pmap({pbag([1, 2]): "it's here!"})
        >>> m[pbag([2, 1])]
        "it's here!"
        >>> pbag([1, 1, 2]) in m
        False
        """
        return hash(self._counts)
- Container.register(PBag)
- Iterable.register(PBag)
- Sized.register(PBag)
- Hashable.register(PBag)
def b(*elements):
    """
    Build a persistent bag from the given positional arguments.

    Each argument becomes one element of the bag; passing the same value
    several times yields that many occurrences. The collected arguments are
    handed to :py:func:`pbag` unchanged.

    >>> b(1, 2, 3, 2)
    pbag([1, 2, 2, 3])
    """
    return pbag(elements)
def pbag(elements):
    """
    Create a persistent bag holding every item of an iterable.

    Items are counted one by one, so duplicates in ``elements`` are kept.
    When ``elements`` is falsy (for instance an empty list or tuple) the
    module-wide empty bag is returned rather than a new object.

    >>> pbag([1, 2, 3, 2])
    pbag([1, 2, 2, 3])
    """
    if elements:
        counted = reduce(_add_to_counters, elements, pmap())
        return PBag(counted)

    return _EMPTY_PBAG
# Single shared empty bag; pbag() returns it whenever its argument is falsy.
_EMPTY_PBAG = PBag(pmap())
|