# _pbag.py -- persistent bag (multiset) implementation.
  1. from ._compat import Container, Iterable, Sized, Hashable
  2. from functools import reduce
  3. from pyrsistent._pmap import pmap
  4. def _add_to_counters(counters, element):
  5. return counters.set(element, counters.get(element, 0) + 1)
  6. class PBag(object):
  7. """
  8. A persistent bag/multiset type.
  9. Requires elements to be hashable, and allows duplicates, but has no
  10. ordering. Bags are hashable.
  11. Do not instantiate directly, instead use the factory functions :py:func:`b`
  12. or :py:func:`pbag` to create an instance.
  13. Some examples:
  14. >>> s = pbag([1, 2, 3, 1])
  15. >>> s2 = s.add(4)
  16. >>> s3 = s2.remove(1)
  17. >>> s
  18. pbag([1, 1, 2, 3])
  19. >>> s2
  20. pbag([1, 1, 2, 3, 4])
  21. >>> s3
  22. pbag([1, 2, 3, 4])
  23. """
  24. __slots__ = ('_counts', '__weakref__')
  25. def __init__(self, counts):
  26. self._counts = counts
  27. def add(self, element):
  28. """
  29. Add an element to the bag.
  30. >>> s = pbag([1])
  31. >>> s2 = s.add(1)
  32. >>> s3 = s.add(2)
  33. >>> s2
  34. pbag([1, 1])
  35. >>> s3
  36. pbag([1, 2])
  37. """
  38. return PBag(_add_to_counters(self._counts, element))
  39. def update(self, iterable):
  40. """
  41. Update bag with all elements in iterable.
  42. >>> s = pbag([1])
  43. >>> s.update([1, 2])
  44. pbag([1, 1, 2])
  45. """
  46. if iterable:
  47. return PBag(reduce(_add_to_counters, iterable, self._counts))
  48. return self
  49. def remove(self, element):
  50. """
  51. Remove an element from the bag.
  52. >>> s = pbag([1, 1, 2])
  53. >>> s2 = s.remove(1)
  54. >>> s3 = s.remove(2)
  55. >>> s2
  56. pbag([1, 2])
  57. >>> s3
  58. pbag([1, 1])
  59. """
  60. if element not in self._counts:
  61. raise KeyError(element)
  62. elif self._counts[element] == 1:
  63. newc = self._counts.remove(element)
  64. else:
  65. newc = self._counts.set(element, self._counts[element] - 1)
  66. return PBag(newc)
  67. def count(self, element):
  68. """
  69. Return the number of times an element appears.
  70. >>> pbag([]).count('non-existent')
  71. 0
  72. >>> pbag([1, 1, 2]).count(1)
  73. 2
  74. """
  75. return self._counts.get(element, 0)
  76. def __len__(self):
  77. """
  78. Return the length including duplicates.
  79. >>> len(pbag([1, 1, 2]))
  80. 3
  81. """
  82. return sum(self._counts.itervalues())
  83. def __iter__(self):
  84. """
  85. Return an iterator of all elements, including duplicates.
  86. >>> list(pbag([1, 1, 2]))
  87. [1, 1, 2]
  88. >>> list(pbag([1, 2]))
  89. [1, 2]
  90. """
  91. for elt, count in self._counts.iteritems():
  92. for i in range(count):
  93. yield elt
  94. def __contains__(self, elt):
  95. """
  96. Check if an element is in the bag.
  97. >>> 1 in pbag([1, 1, 2])
  98. True
  99. >>> 0 in pbag([1, 2])
  100. False
  101. """
  102. return elt in self._counts
  103. def __repr__(self):
  104. return "pbag({0})".format(list(self))
  105. def __eq__(self, other):
  106. """
  107. Check if two bags are equivalent, honoring the number of duplicates,
  108. and ignoring insertion order.
  109. >>> pbag([1, 1, 2]) == pbag([1, 2])
  110. False
  111. >>> pbag([2, 1, 0]) == pbag([0, 1, 2])
  112. True
  113. """
  114. if type(other) is not PBag:
  115. raise TypeError("Can only compare PBag with PBags")
  116. return self._counts == other._counts
  117. def __lt__(self, other):
  118. raise TypeError('PBags are not orderable')
  119. __le__ = __lt__
  120. __gt__ = __lt__
  121. __ge__ = __lt__
  122. # Multiset-style operations similar to collections.Counter
  123. def __add__(self, other):
  124. """
  125. Combine elements from two PBags.
  126. >>> pbag([1, 2, 2]) + pbag([2, 3, 3])
  127. pbag([1, 2, 2, 2, 3, 3])
  128. """
  129. if not isinstance(other, PBag):
  130. return NotImplemented
  131. result = self._counts.evolver()
  132. for elem, other_count in other._counts.iteritems():
  133. result[elem] = self.count(elem) + other_count
  134. return PBag(result.persistent())
  135. def __sub__(self, other):
  136. """
  137. Remove elements from one PBag that are present in another.
  138. >>> pbag([1, 2, 2, 2, 3]) - pbag([2, 3, 3, 4])
  139. pbag([1, 2, 2])
  140. """
  141. if not isinstance(other, PBag):
  142. return NotImplemented
  143. result = self._counts.evolver()
  144. for elem, other_count in other._counts.iteritems():
  145. newcount = self.count(elem) - other_count
  146. if newcount > 0:
  147. result[elem] = newcount
  148. elif elem in self:
  149. result.remove(elem)
  150. return PBag(result.persistent())
  151. def __or__(self, other):
  152. """
  153. Union: Keep elements that are present in either of two PBags.
  154. >>> pbag([1, 2, 2, 2]) | pbag([2, 3, 3])
  155. pbag([1, 2, 2, 2, 3, 3])
  156. """
  157. if not isinstance(other, PBag):
  158. return NotImplemented
  159. result = self._counts.evolver()
  160. for elem, other_count in other._counts.iteritems():
  161. count = self.count(elem)
  162. newcount = max(count, other_count)
  163. result[elem] = newcount
  164. return PBag(result.persistent())
  165. def __and__(self, other):
  166. """
  167. Intersection: Only keep elements that are present in both PBags.
  168. >>> pbag([1, 2, 2, 2]) & pbag([2, 3, 3])
  169. pbag([2])
  170. """
  171. if not isinstance(other, PBag):
  172. return NotImplemented
  173. result = pmap().evolver()
  174. for elem, count in self._counts.iteritems():
  175. newcount = min(count, other.count(elem))
  176. if newcount > 0:
  177. result[elem] = newcount
  178. return PBag(result.persistent())
  179. def __hash__(self):
  180. """
  181. Hash based on value of elements.
  182. >>> m = pmap({pbag([1, 2]): "it's here!"})
  183. >>> m[pbag([2, 1])]
  184. "it's here!"
  185. >>> pbag([1, 1, 2]) in m
  186. False
  187. """
  188. return hash(self._counts)
# Register PBag with each of these base classes through their ``register``
# hook, rather than having PBag inherit from them.
# NOTE(review): presumably these are the standard collection ABCs re-exported
# by ``._compat`` -- confirm there.
Container.register(PBag)
Iterable.register(PBag)
Sized.register(PBag)
Hashable.register(PBag)
  193. def b(*elements):
  194. """
  195. Construct a persistent bag.
  196. Takes an arbitrary number of arguments to insert into the new persistent
  197. bag.
  198. >>> b(1, 2, 3, 2)
  199. pbag([1, 2, 2, 3])
  200. """
  201. return pbag(elements)
  202. def pbag(elements):
  203. """
  204. Convert an iterable to a persistent bag.
  205. Takes an iterable with elements to insert.
  206. >>> pbag([1, 2, 3, 2])
  207. pbag([1, 2, 2, 3])
  208. """
  209. if not elements:
  210. return _EMPTY_PBAG
  211. return PBag(reduce(_add_to_counters, elements, pmap()))
# Shared empty bag; pbag() returns this instance when given falsy input.
_EMPTY_PBAG = PBag(pmap())