recipes.py 14 KB


  1. """Imported from the recipes section of the itertools documentation.
  2. All functions taken from the recipes section of the itertools library docs
  3. [1]_.
  4. Some backward-compatible usability improvements have been made.
  5. .. [1] http://docs.python.org/library/itertools.html#recipes
  6. """
  7. from collections import deque
  8. from itertools import (
  9. chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
  10. )
  11. import operator
  12. from random import randrange, sample, choice
  13. from six import PY2
  14. from six.moves import filter, filterfalse, map, range, zip, zip_longest
# Public API of this module: the names exported by a star-import.
# Every entry corresponds to a function defined below.
__all__ = [
    'accumulate',
    'all_equal',
    'consume',
    'dotproduct',
    'first_true',
    'flatten',
    'grouper',
    'iter_except',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pairwise',
    'partition',
    'powerset',
    'quantify',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'tabulate',
    'tail',
    'take',
    'unique_everseen',
    'unique_justseen',
]
  44. def accumulate(iterable, func=operator.add):
  45. """
  46. Return an iterator whose items are the accumulated results of a function
  47. (specified by the optional *func* argument) that takes two arguments.
  48. By default, returns accumulated sums with :func:`operator.add`.
  49. >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum
  50. [1, 3, 6, 10, 15]
  51. >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product
  52. [1, 2, 6]
  53. >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum
  54. [0, 1, 1, 2, 3, 3]
  55. This function is available in the ``itertools`` module for Python 3.2 and
  56. greater.
  57. """
  58. it = iter(iterable)
  59. try:
  60. total = next(it)
  61. except StopIteration:
  62. return
  63. else:
  64. yield total
  65. for element in it:
  66. total = func(total, element)
  67. yield total
  68. def take(n, iterable):
  69. """Return first *n* items of the iterable as a list.
  70. >>> take(3, range(10))
  71. [0, 1, 2]
  72. >>> take(5, range(3))
  73. [0, 1, 2]
  74. Effectively a short replacement for ``next`` based iterator consumption
  75. when you want more than one item, but less than the whole iterator.
  76. """
  77. return list(islice(iterable, n))
  78. def tabulate(function, start=0):
  79. """Return an iterator over the results of ``func(start)``,
  80. ``func(start + 1)``, ``func(start + 2)``...
  81. *func* should be a function that accepts one integer argument.
  82. If *start* is not specified it defaults to 0. It will be incremented each
  83. time the iterator is advanced.
  84. >>> square = lambda x: x ** 2
  85. >>> iterator = tabulate(square, -3)
  86. >>> take(4, iterator)
  87. [9, 4, 1, 0]
  88. """
  89. return map(function, count(start))
  90. def tail(n, iterable):
  91. """Return an iterator over the last *n* items of *iterable*.
  92. >>> t = tail(3, 'ABCDEFG')
  93. >>> list(t)
  94. ['E', 'F', 'G']
  95. """
  96. return iter(deque(iterable, maxlen=n))
  97. def consume(iterator, n=None):
  98. """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
  99. entirely.
  100. Efficiently exhausts an iterator without returning values. Defaults to
  101. consuming the whole iterator, but an optional second argument may be
  102. provided to limit consumption.
  103. >>> i = (x for x in range(10))
  104. >>> next(i)
  105. 0
  106. >>> consume(i, 3)
  107. >>> next(i)
  108. 4
  109. >>> consume(i)
  110. >>> next(i)
  111. Traceback (most recent call last):
  112. File "<stdin>", line 1, in <module>
  113. StopIteration
  114. If the iterator has fewer items remaining than the provided limit, the
  115. whole iterator will be consumed.
  116. >>> i = (x for x in range(3))
  117. >>> consume(i, 5)
  118. >>> next(i)
  119. Traceback (most recent call last):
  120. File "<stdin>", line 1, in <module>
  121. StopIteration
  122. """
  123. # Use functions that consume iterators at C speed.
  124. if n is None:
  125. # feed the entire iterator into a zero-length deque
  126. deque(iterator, maxlen=0)
  127. else:
  128. # advance to the empty slice starting at position n
  129. next(islice(iterator, n, n), None)
  130. def nth(iterable, n, default=None):
  131. """Returns the nth item or a default value.
  132. >>> l = range(10)
  133. >>> nth(l, 3)
  134. 3
  135. >>> nth(l, 20, "zebra")
  136. 'zebra'
  137. """
  138. return next(islice(iterable, n, None), default)
  139. def all_equal(iterable):
  140. """
  141. Returns ``True`` if all the elements are equal to each other.
  142. >>> all_equal('aaaa')
  143. True
  144. >>> all_equal('aaab')
  145. False
  146. """
  147. g = groupby(iterable)
  148. return next(g, True) and not next(g, False)
  149. def quantify(iterable, pred=bool):
  150. """Return the how many times the predicate is true.
  151. >>> quantify([True, False, True])
  152. 2
  153. """
  154. return sum(map(pred, iterable))
  155. def padnone(iterable):
  156. """Returns the sequence of elements and then returns ``None`` indefinitely.
  157. >>> take(5, padnone(range(3)))
  158. [0, 1, 2, None, None]
  159. Useful for emulating the behavior of the built-in :func:`map` function.
  160. See also :func:`padded`.
  161. """
  162. return chain(iterable, repeat(None))
  163. def ncycles(iterable, n):
  164. """Returns the sequence elements *n* times
  165. >>> list(ncycles(["a", "b"], 3))
  166. ['a', 'b', 'a', 'b', 'a', 'b']
  167. """
  168. return chain.from_iterable(repeat(tuple(iterable), n))
  169. def dotproduct(vec1, vec2):
  170. """Returns the dot product of the two iterables.
  171. >>> dotproduct([10, 10], [20, 20])
  172. 400
  173. """
  174. return sum(map(operator.mul, vec1, vec2))
  175. def flatten(listOfLists):
  176. """Return an iterator flattening one level of nesting in a list of lists.
  177. >>> list(flatten([[0, 1], [2, 3]]))
  178. [0, 1, 2, 3]
  179. See also :func:`collapse`, which can flatten multiple levels of nesting.
  180. """
  181. return chain.from_iterable(listOfLists)
  182. def repeatfunc(func, times=None, *args):
  183. """Call *func* with *args* repeatedly, returning an iterable over the
  184. results.
  185. If *times* is specified, the iterable will terminate after that many
  186. repetitions:
  187. >>> from operator import add
  188. >>> times = 4
  189. >>> args = 3, 5
  190. >>> list(repeatfunc(add, times, *args))
  191. [8, 8, 8, 8]
  192. If *times* is ``None`` the iterable will not terminate:
  193. >>> from random import randrange
  194. >>> times = None
  195. >>> args = 1, 11
  196. >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
  197. [2, 4, 8, 1, 8, 4]
  198. """
  199. if times is None:
  200. return starmap(func, repeat(args))
  201. return starmap(func, repeat(args, times))
  202. def pairwise(iterable):
  203. """Returns an iterator of paired items, overlapping, from the original
  204. >>> take(4, pairwise(count()))
  205. [(0, 1), (1, 2), (2, 3), (3, 4)]
  206. """
  207. a, b = tee(iterable)
  208. next(b, None)
  209. return zip(a, b)
  210. def grouper(n, iterable, fillvalue=None):
  211. """Collect data into fixed-length chunks or blocks.
  212. >>> list(grouper(3, 'ABCDEFG', 'x'))
  213. [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
  214. """
  215. args = [iter(iterable)] * n
  216. return zip_longest(fillvalue=fillvalue, *args)
  217. def roundrobin(*iterables):
  218. """Yields an item from each iterable, alternating between them.
  219. >>> list(roundrobin('ABC', 'D', 'EF'))
  220. ['A', 'D', 'E', 'B', 'F', 'C']
  221. This function produces the same output as :func:`interleave_longest`, but
  222. may perform better for some inputs (in particular when the number of
  223. iterables is small).
  224. """
  225. # Recipe credited to George Sakkis
  226. pending = len(iterables)
  227. if PY2:
  228. nexts = cycle(iter(it).next for it in iterables)
  229. else:
  230. nexts = cycle(iter(it).__next__ for it in iterables)
  231. while pending:
  232. try:
  233. for next in nexts:
  234. yield next()
  235. except StopIteration:
  236. pending -= 1
  237. nexts = cycle(islice(nexts, pending))
  238. def partition(pred, iterable):
  239. """
  240. Returns a 2-tuple of iterables derived from the input iterable.
  241. The first yields the items that have ``pred(item) == False``.
  242. The second yields the items that have ``pred(item) == True``.
  243. >>> is_odd = lambda x: x % 2 != 0
  244. >>> iterable = range(10)
  245. >>> even_items, odd_items = partition(is_odd, iterable)
  246. >>> list(even_items), list(odd_items)
  247. ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
  248. """
  249. # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
  250. t1, t2 = tee(iterable)
  251. return filterfalse(pred, t1), filter(pred, t2)
  252. def powerset(iterable):
  253. """Yields all possible subsets of the iterable.
  254. >>> list(powerset([1,2,3]))
  255. [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  256. """
  257. s = list(iterable)
  258. return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
  259. def unique_everseen(iterable, key=None):
  260. """
  261. Yield unique elements, preserving order.
  262. >>> list(unique_everseen('AAAABBBCCDAABBB'))
  263. ['A', 'B', 'C', 'D']
  264. >>> list(unique_everseen('ABBCcAD', str.lower))
  265. ['A', 'B', 'C', 'D']
  266. Sequences with a mix of hashable and unhashable items can be used.
  267. The function will be slower (i.e., `O(n^2)`) for unhashable items.
  268. """
  269. seenset = set()
  270. seenset_add = seenset.add
  271. seenlist = []
  272. seenlist_add = seenlist.append
  273. if key is None:
  274. for element in iterable:
  275. try:
  276. if element not in seenset:
  277. seenset_add(element)
  278. yield element
  279. except TypeError:
  280. if element not in seenlist:
  281. seenlist_add(element)
  282. yield element
  283. else:
  284. for element in iterable:
  285. k = key(element)
  286. try:
  287. if k not in seenset:
  288. seenset_add(k)
  289. yield element
  290. except TypeError:
  291. if k not in seenlist:
  292. seenlist_add(k)
  293. yield element
  294. def unique_justseen(iterable, key=None):
  295. """Yields elements in order, ignoring serial duplicates
  296. >>> list(unique_justseen('AAAABBBCCDAABBB'))
  297. ['A', 'B', 'C', 'D', 'A', 'B']
  298. >>> list(unique_justseen('ABBCcAD', str.lower))
  299. ['A', 'B', 'C', 'A', 'D']
  300. """
  301. return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
  302. def iter_except(func, exception, first=None):
  303. """Yields results from a function repeatedly until an exception is raised.
  304. Converts a call-until-exception interface to an iterator interface.
  305. Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
  306. to end the loop.
  307. >>> l = [0, 1, 2]
  308. >>> list(iter_except(l.pop, IndexError))
  309. [2, 1, 0]
  310. """
  311. try:
  312. if first is not None:
  313. yield first()
  314. while 1:
  315. yield func()
  316. except exception:
  317. pass
  318. def first_true(iterable, default=False, pred=None):
  319. """
  320. Returns the first true value in the iterable.
  321. If no true value is found, returns *default*
  322. If *pred* is not None, returns the first item for which
  323. ``pred(item) == True`` .
  324. >>> first_true(range(10))
  325. 1
  326. >>> first_true(range(10), pred=lambda x: x > 5)
  327. 6
  328. >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
  329. 'missing'
  330. """
  331. return next(filter(pred, iterable), default)
  332. def random_product(*args, **kwds):
  333. """Draw an item at random from each of the input iterables.
  334. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
  335. ('c', 3, 'Z')
  336. If *repeat* is provided as a keyword argument, that many items will be
  337. drawn from each iterable.
  338. >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
  339. ('a', 2, 'd', 3)
  340. This equivalent to taking a random selection from
  341. ``itertools.product(*args, **kwarg)``.
  342. """
  343. pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
  344. return tuple(choice(pool) for pool in pools)
  345. def random_permutation(iterable, r=None):
  346. """Return a random *r* length permutation of the elements in *iterable*.
  347. If *r* is not specified or is ``None``, then *r* defaults to the length of
  348. *iterable*.
  349. >>> random_permutation(range(5)) # doctest:+SKIP
  350. (3, 4, 0, 1, 2)
  351. This equivalent to taking a random selection from
  352. ``itertools.permutations(iterable, r)``.
  353. """
  354. pool = tuple(iterable)
  355. r = len(pool) if r is None else r
  356. return tuple(sample(pool, r))
  357. def random_combination(iterable, r):
  358. """Return a random *r* length subsequence of the elements in *iterable*.
  359. >>> random_combination(range(5), 3) # doctest:+SKIP
  360. (2, 3, 4)
  361. This equivalent to taking a random selection from
  362. ``itertools.combinations(iterable, r)``.
  363. """
  364. pool = tuple(iterable)
  365. n = len(pool)
  366. indices = sorted(sample(range(n), r))
  367. return tuple(pool[i] for i in indices)
  368. def random_combination_with_replacement(iterable, r):
  369. """Return a random *r* length subsequence of elements in *iterable*,
  370. allowing individual elements to be repeated.
  371. >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
  372. (0, 0, 1, 2, 2)
  373. This equivalent to taking a random selection from
  374. ``itertools.combinations_with_replacement(iterable, r)``.
  375. """
  376. pool = tuple(iterable)
  377. n = len(pool)
  378. indices = sorted(randrange(n) for i in range(r))
  379. return tuple(pool[i] for i in indices)
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    A negative *index* counts back from the end, as with sequence indexing.

    Raises :exc:`ValueError` if *r* is negative or greater than the number
    of items in *iterable*, and :exc:`IndexError` if *index* is out of range.

    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError
    # Build c = C(n, r), the total number of r-length combinations, with the
    # multiplicative formula. Using k = min(r, n - r) keeps the loop short.
    c = 1
    k = min(r, n - r)
    for i in range(1, k + 1):
        c = c * (n - k + i) // i
    # Shift a negative index by the total count, then bounds-check it.
    if index < 0:
        index += c
    if (index < 0) or (index >= c):
        raise IndexError
    result = []
    while r:
        # Rescale c to the number of combinations that begin with the current
        # candidate element; one fewer item remains to choose from and to
        # pick.
        c, n, r = c * r // n, n - 1, r - 1
        # While the target lies beyond the current candidate's group, skip
        # that group and move on to the next candidate.
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        # n counts the pool items after the candidate, so index from the end.
        result.append(pool[-1 - n])
    return tuple(result)