_pset.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. from ._compat import Set, Hashable
  2. import sys
  3. from pyrsistent._pmap import pmap
  4. PY2 = sys.version_info[0] < 3
  5. class PSet(object):
  6. """
  7. Persistent set implementation. Built on top of the persistent map. The set supports all operations
  8. in the Set protocol and is Hashable.
  9. Do not instantiate directly, instead use the factory functions :py:func:`s` or :py:func:`pset`
  10. to create an instance.
  11. Random access and insert is log32(n) where n is the size of the set.
  12. Some examples:
  13. >>> s = pset([1, 2, 3, 1])
  14. >>> s2 = s.add(4)
  15. >>> s3 = s2.remove(2)
  16. >>> s
  17. pset([1, 2, 3])
  18. >>> s2
  19. pset([1, 2, 3, 4])
  20. >>> s3
  21. pset([1, 3, 4])
  22. """
  23. __slots__ = ('_map', '__weakref__')
  24. def __new__(cls, m):
  25. self = super(PSet, cls).__new__(cls)
  26. self._map = m
  27. return self
  28. def __contains__(self, element):
  29. return element in self._map
  30. def __iter__(self):
  31. return iter(self._map)
  32. def __len__(self):
  33. return len(self._map)
  34. def __repr__(self):
  35. if PY2 or not self:
  36. return 'p' + str(set(self))
  37. return 'pset([{0}])'.format(str(set(self))[1:-1])
  38. def __str__(self):
  39. return self.__repr__()
  40. def __hash__(self):
  41. return hash(self._map)
  42. def __reduce__(self):
  43. # Pickling support
  44. return pset, (list(self),)
  45. @classmethod
  46. def _from_iterable(cls, it, pre_size=8):
  47. return PSet(pmap(dict((k, True) for k in it), pre_size=pre_size))
  48. def add(self, element):
  49. """
  50. Return a new PSet with element added
  51. >>> s1 = s(1, 2)
  52. >>> s1.add(3)
  53. pset([1, 2, 3])
  54. """
  55. return self.evolver().add(element).persistent()
  56. def update(self, iterable):
  57. """
  58. Return a new PSet with elements in iterable added
  59. >>> s1 = s(1, 2)
  60. >>> s1.update([3, 4, 4])
  61. pset([1, 2, 3, 4])
  62. """
  63. e = self.evolver()
  64. for element in iterable:
  65. e.add(element)
  66. return e.persistent()
  67. def remove(self, element):
  68. """
  69. Return a new PSet with element removed. Raises KeyError if element is not present.
  70. >>> s1 = s(1, 2)
  71. >>> s1.remove(2)
  72. pset([1])
  73. """
  74. if element in self._map:
  75. return self.evolver().remove(element).persistent()
  76. raise KeyError("Element '%s' not present in PSet" % element)
  77. def discard(self, element):
  78. """
  79. Return a new PSet with element removed. Returns itself if element is not present.
  80. """
  81. if element in self._map:
  82. return self.evolver().remove(element).persistent()
  83. return self
  84. class _Evolver(object):
  85. __slots__ = ('_original_pset', '_pmap_evolver')
  86. def __init__(self, original_pset):
  87. self._original_pset = original_pset
  88. self._pmap_evolver = original_pset._map.evolver()
  89. def add(self, element):
  90. self._pmap_evolver[element] = True
  91. return self
  92. def remove(self, element):
  93. del self._pmap_evolver[element]
  94. return self
  95. def is_dirty(self):
  96. return self._pmap_evolver.is_dirty()
  97. def persistent(self):
  98. if not self.is_dirty():
  99. return self._original_pset
  100. return PSet(self._pmap_evolver.persistent())
  101. def __len__(self):
  102. return len(self._pmap_evolver)
  103. def copy(self):
  104. return self
  105. def evolver(self):
  106. """
  107. Create a new evolver for this pset. For a discussion on evolvers in general see the
  108. documentation for the pvector evolver.
  109. Create the evolver and perform various mutating updates to it:
  110. >>> s1 = s(1, 2, 3)
  111. >>> e = s1.evolver()
  112. >>> _ = e.add(4)
  113. >>> len(e)
  114. 4
  115. >>> _ = e.remove(1)
  116. The underlying pset remains the same:
  117. >>> s1
  118. pset([1, 2, 3])
  119. The changes are kept in the evolver. An updated pmap can be created using the
  120. persistent() function on the evolver.
  121. >>> s2 = e.persistent()
  122. >>> s2
  123. pset([2, 3, 4])
  124. The new pset will share data with the original pset in the same way that would have
  125. been done if only using operations on the pset.
  126. """
  127. return PSet._Evolver(self)
  128. # All the operations and comparisons you would expect on a set.
  129. #
  130. # This is not very beautiful. If we avoid inheriting from PSet we can use the
  131. # __slots__ concepts (which requires a new style class) and hopefully save some memory.
  132. __le__ = Set.__le__
  133. __lt__ = Set.__lt__
  134. __gt__ = Set.__gt__
  135. __ge__ = Set.__ge__
  136. __eq__ = Set.__eq__
  137. __ne__ = Set.__ne__
  138. __and__ = Set.__and__
  139. __or__ = Set.__or__
  140. __sub__ = Set.__sub__
  141. __xor__ = Set.__xor__
  142. issubset = __le__
  143. issuperset = __ge__
  144. union = __or__
  145. intersection = __and__
  146. difference = __sub__
  147. symmetric_difference = __xor__
  148. isdisjoint = Set.isdisjoint
  149. Set.register(PSet)
  150. Hashable.register(PSet)
  151. _EMPTY_PSET = PSet(pmap())
  152. def pset(iterable=(), pre_size=8):
  153. """
  154. Creates a persistent set from iterable. Optionally takes a sizing parameter equivalent to that
  155. used for :py:func:`pmap`.
  156. >>> s1 = pset([1, 2, 3, 2])
  157. >>> s1
  158. pset([1, 2, 3])
  159. """
  160. if not iterable:
  161. return _EMPTY_PSET
  162. return PSet._from_iterable(iterable, pre_size=pre_size)
  163. def s(*elements):
  164. """
  165. Create a persistent set.
  166. Takes an arbitrary number of arguments to insert into the new set.
  167. >>> s1 = s(1, 2, 3, 2)
  168. >>> s1
  169. pset([1, 2, 3])
  170. """
  171. return pset(elements)