collects.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # -*- coding: utf-8 -*-
  2. '''stuf collections.'''
  3. import sys
  4. from .deep import getcls
  5. from .base import second, first
  6. from .six import OrderedDict, items
  7. try:
  8. from reprlib import recursive_repr # @UnusedImport
  9. except ImportError:
  10. from .six import get_ident, getdoc, getmod, docit
  11. def recursive_repr(fillvalue='...'):
  12. def decorating_function(user_function):
  13. repr_running = set()
  14. def wrapper(self): # @IgnorePep8
  15. key = id(self), get_ident()
  16. if key in repr_running:
  17. return fillvalue
  18. repr_running.add(key)
  19. try:
  20. result = user_function(self)
  21. finally:
  22. repr_running.discard(key)
  23. return result
  24. wrapper.__module__ = getmod(user_function)
  25. docit(wrapper, getdoc(user_function))
  26. return wrapper
  27. return decorating_function
  28. version = sys.version_info
  29. if version[0] == 3 and version[1] > 1:
  30. from collections import Counter
  31. else:
  32. from heapq import nlargest
  33. from itertools import chain, starmap, repeat
  34. from .deep import clsname
  35. from .base import ismapping
  36. class Counter(dict):
  37. '''dict subclass for counting hashable items'''
  38. def __init__(self, iterable=None, **kw):
  39. '''
  40. If given, count elements from an input iterable. Or, initialize
  41. count from another mapping of elements to their counts.
  42. '''
  43. super(Counter, self).__init__()
  44. self.update(iterable, **kw)
  45. def __missing__(self, key):
  46. '''The count of elements not in the Counter is zero.'''
  47. return 0
  48. def __reduce__(self):
  49. return getcls(self), (dict(self),)
  50. def __delitem__(self, elem):
  51. '''
  52. Like dict.__delitem__() but does not raise KeyError for missing'
  53. values.
  54. '''
  55. if elem in self:
  56. super(Counter, self).__delitem__(elem)
  57. def __repr__(self): # pragma: no coverage
  58. if not self:
  59. return '%s()' % clsname(self)
  60. try:
  61. items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
  62. return '%s({%s})' % (clsname(self), items)
  63. except TypeError:
  64. # handle case where values are not orderable
  65. return '{0}({1!r})'.format(clsname(self), dict(self))
  66. def __add__(self, other):
  67. '''Add counts from two counters.'''
  68. if not isinstance(other, getcls(self)):
  69. return NotImplemented()
  70. result = getcls(self)()
  71. for elem, count in items(self):
  72. newcount = count + other[elem]
  73. if newcount > 0:
  74. result[elem] = newcount
  75. for elem, count in items(other):
  76. if elem not in self and count > 0:
  77. result[elem] = count
  78. return result
  79. def __sub__(self, other):
  80. '''Subtract count, but keep only results with positive counts.'''
  81. if not isinstance(other, getcls(self)):
  82. return NotImplemented()
  83. result = getcls(self)()
  84. for elem, count in items(self):
  85. newcount = count - other[elem]
  86. if newcount > 0:
  87. result[elem] = newcount
  88. for elem, count in items(other):
  89. if elem not in self and count < 0:
  90. result[elem] = 0 - count
  91. return result
  92. def __or__(self, other):
  93. '''Union is the maximum of value in either of the input counters.'''
  94. if not isinstance(other, getcls(self)):
  95. return NotImplemented()
  96. result = getcls(self)()
  97. for elem, count in items(self):
  98. other_count = other[elem]
  99. newcount = other_count if count < other_count else count
  100. if newcount > 0:
  101. result[elem] = newcount
  102. for elem, count in items(other):
  103. if elem not in self and count > 0:
  104. result[elem] = count
  105. return result
  106. def __and__(self, other):
  107. '''Intersection is the minimum of corresponding counts.'''
  108. if not isinstance(other, getcls(self)):
  109. return NotImplemented()
  110. result = getcls(self)()
  111. for elem, count in items(self):
  112. other_count = other[elem]
  113. newcount = count if count < other_count else other_count
  114. if newcount > 0:
  115. result[elem] = newcount
  116. return result
  117. def __pos__(self):
  118. '''
  119. Adds an empty counter, effectively stripping negative and zero
  120. counts.
  121. '''
  122. return self + getcls(self)()
  123. def __neg__(self):
  124. '''
  125. Subtracts from an empty counter. Strips positive and zero counts,
  126. and flips the sign on negative counts.
  127. '''
  128. return getcls(self)() - self
  129. def most_common(self, n=None, nl=nlargest, i=items, g=second):
  130. '''
  131. List the n most common elements and their counts from the most
  132. common to the least. If n is None, then list all element counts.
  133. '''
  134. if n is None:
  135. return sorted(i(self), key=g, reverse=True)
  136. return nl(n, i(self), key=g)
  137. def elements(self):
  138. '''
  139. Iterator over elements repeating each as many times as its count.
  140. '''
  141. return chain.from_iterable(starmap(repeat, items(self)))
  142. @classmethod
  143. def fromkeys(cls, iterable, v=None):
  144. raise NotImplementedError(
  145. 'Counter.fromkeys() undefined. Use Counter(iterable) instead.'
  146. )
  147. def update(self, iterable=None, **kwds):
  148. '''Like dict.update() but add counts instead of replacing them.'''
  149. if iterable is not None:
  150. if ismapping(iterable):
  151. if self:
  152. self_get = self.get
  153. for elem, count in items(iterable):
  154. self[elem] = count + self_get(elem, 0)
  155. else:
  156. super(Counter, self).update(iterable)
  157. else:
  158. mapping_get = self.get
  159. for elem in iterable:
  160. self[elem] = mapping_get(elem, 0) + 1
  161. if kwds:
  162. self.update(kwds)
  163. def subtract(self, iterable=None, **kwds):
  164. '''
  165. Like dict.update() but subtracts counts instead of replacing them.
  166. Counts can be reduced below zero. Both the inputs and outputs are
  167. allowed to contain zero and negative counts.
  168. Source can be an iterable, a dictionary, or another Counter
  169. instance.
  170. '''
  171. if iterable is not None:
  172. self_get = self.get
  173. if ismapping(iterable):
  174. for elem, count in items(iterable):
  175. self[elem] = self_get(elem, 0) - count
  176. else:
  177. for elem in iterable:
  178. self[elem] = self_get(elem, 0) - 1
  179. if kwds:
  180. self.subtract(kwds)
  181. def copy(self):
  182. 'Return a shallow copy.'
  183. return getcls(self)(self)
  184. try:
  185. from collections import ChainMap # @UnusedImport
  186. except ImportError:
  187. # not until Python >= 3.3
  188. from collections import MutableMapping
  189. class ChainMap(MutableMapping):
  190. '''
  191. `ChainMap` groups multiple dicts (or other mappings) together to create
  192. a single, updateable view.
  193. '''
  194. def __init__(self, *maps):
  195. '''
  196. Initialize `ChainMap` by setting *maps* to the given mappings. If no
  197. mappings are provided, a single empty dictionary is used.
  198. '''
  199. # always at least one map
  200. self.maps = list(maps) or [OrderedDict()]
  201. def __missing__(self, key):
  202. raise KeyError(key)
  203. def __getitem__(self, key):
  204. for mapping in self.maps:
  205. try:
  206. # can't use 'key in mapping' with defaultdict
  207. return mapping[key]
  208. except KeyError:
  209. pass
  210. # support subclasses that define __missing__
  211. return self.__missing__(key)
  212. def get(self, key, default=None):
  213. return self[key] if key in self else default
  214. def __len__(self):
  215. # reuses stored hash values if possible
  216. return len(set().union(*self.maps))
  217. def __iter__(self, set=set):
  218. return set().union(*self.maps).__iter__()
  219. def __contains__(self, key, any=any):
  220. return any(key in m for m in self.maps)
  221. def __bool__(self, any=any):
  222. return any(self.maps)
  223. @classmethod
  224. def fromkeys(cls, iterable, *args):
  225. '''
  226. Create a ChainMap with a single dict created from the iterable.
  227. '''
  228. return cls(dict.fromkeys(iterable, *args))
  229. def copy(self):
  230. '''
  231. New ChainMap or subclass with a new copy of maps[0] and refs to
  232. maps[1:]
  233. '''
  234. return getcls(self)(first(self.maps).copy(), *self.maps[1:])
  235. __copy__ = copy
  236. def new_child(self):
  237. '''New ChainMap with a new dict followed by all previous maps.'''
  238. # like Django's Context.push()
  239. return getcls(self)({}, *self.maps)
  240. @property
  241. def parents(self):
  242. '''New ChainMap from maps[1:].'''
  243. # like Django's Context.pop()
  244. return getcls(self)(*self.maps[1:])
  245. def __setitem__(self, key, value):
  246. first(self.maps)[key] = value
  247. def __delitem__(self, key):
  248. try:
  249. del first(self.maps)[key]
  250. except KeyError:
  251. raise KeyError(
  252. 'Key not found in the first mapping: {r}'.format(key)
  253. )
  254. def popitem(self):
  255. '''
  256. Remove and return an item pair from maps[0]. Raise `KeyError` is
  257. maps[0] is empty.
  258. '''
  259. try:
  260. return first(self.maps).popitem()
  261. except KeyError:
  262. raise KeyError('No keys found in the first mapping.')
  263. def pop(self, key, *args):
  264. '''
  265. Remove *key* from maps[0] and return its value. Raise KeyError if
  266. *key* not in maps[0].
  267. '''
  268. try:
  269. return first(self.maps).pop(key, *args)
  270. except KeyError:
  271. raise KeyError(
  272. 'Key not found in the first mapping: {r}'.format(key)
  273. )
  274. def clear(self):
  275. '''Clear maps[0], leaving maps[1:] intact.'''
  276. first(self.maps).clear()