# test_recipes.py: unit tests for the more_itertools.recipes module.
import random
from doctest import DocTestSuite
from itertools import combinations
from unittest import TestCase

from six.moves import range

import more_itertools as mi
  6. def load_tests(loader, tests, ignore):
  7. # Add the doctests
  8. tests.addTests(DocTestSuite('more_itertools.recipes'))
  9. return tests
  10. class AccumulateTests(TestCase):
  11. """Tests for ``accumulate()``"""
  12. def test_empty(self):
  13. """Test that an empty input returns an empty output"""
  14. self.assertEqual(list(mi.accumulate([])), [])
  15. def test_default(self):
  16. """Test accumulate with the default function (addition)"""
  17. self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6])
  18. def test_bogus_function(self):
  19. """Test accumulate with an invalid function"""
  20. with self.assertRaises(TypeError):
  21. list(mi.accumulate([1, 2, 3], func=lambda x: x))
  22. def test_custom_function(self):
  23. """Test accumulate with a custom function"""
  24. self.assertEqual(
  25. list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3]
  26. )
  27. class TakeTests(TestCase):
  28. """Tests for ``take()``"""
  29. def test_simple_take(self):
  30. """Test basic usage"""
  31. t = mi.take(5, range(10))
  32. self.assertEqual(t, [0, 1, 2, 3, 4])
  33. def test_null_take(self):
  34. """Check the null case"""
  35. t = mi.take(0, range(10))
  36. self.assertEqual(t, [])
  37. def test_negative_take(self):
  38. """Make sure taking negative items results in a ValueError"""
  39. self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
  40. def test_take_too_much(self):
  41. """Taking more than an iterator has remaining should return what the
  42. iterator has remaining.
  43. """
  44. t = mi.take(10, range(5))
  45. self.assertEqual(t, [0, 1, 2, 3, 4])
  46. class TabulateTests(TestCase):
  47. """Tests for ``tabulate()``"""
  48. def test_simple_tabulate(self):
  49. """Test the happy path"""
  50. t = mi.tabulate(lambda x: x)
  51. f = tuple([next(t) for _ in range(3)])
  52. self.assertEqual(f, (0, 1, 2))
  53. def test_count(self):
  54. """Ensure tabulate accepts specific count"""
  55. t = mi.tabulate(lambda x: 2 * x, -1)
  56. f = (next(t), next(t), next(t))
  57. self.assertEqual(f, (-2, 0, 2))
  58. class TailTests(TestCase):
  59. """Tests for ``tail()``"""
  60. def test_greater(self):
  61. """Length of iterable is greather than requested tail"""
  62. self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])
  63. def test_equal(self):
  64. """Length of iterable is equal to the requested tail"""
  65. self.assertEqual(
  66. list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
  67. )
  68. def test_less(self):
  69. """Length of iterable is less than requested tail"""
  70. self.assertEqual(
  71. list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
  72. )
  73. class ConsumeTests(TestCase):
  74. """Tests for ``consume()``"""
  75. def test_sanity(self):
  76. """Test basic functionality"""
  77. r = (x for x in range(10))
  78. mi.consume(r, 3)
  79. self.assertEqual(3, next(r))
  80. def test_null_consume(self):
  81. """Check the null case"""
  82. r = (x for x in range(10))
  83. mi.consume(r, 0)
  84. self.assertEqual(0, next(r))
  85. def test_negative_consume(self):
  86. """Check that negative consumsion throws an error"""
  87. r = (x for x in range(10))
  88. self.assertRaises(ValueError, lambda: mi.consume(r, -1))
  89. def test_total_consume(self):
  90. """Check that iterator is totally consumed by default"""
  91. r = (x for x in range(10))
  92. mi.consume(r)
  93. self.assertRaises(StopIteration, lambda: next(r))
  94. class NthTests(TestCase):
  95. """Tests for ``nth()``"""
  96. def test_basic(self):
  97. """Make sure the nth item is returned"""
  98. l = range(10)
  99. for i, v in enumerate(l):
  100. self.assertEqual(mi.nth(l, i), v)
  101. def test_default(self):
  102. """Ensure a default value is returned when nth item not found"""
  103. l = range(3)
  104. self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
  105. def test_negative_item_raises(self):
  106. """Ensure asking for a negative item raises an exception"""
  107. self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
  108. class AllEqualTests(TestCase):
  109. """Tests for ``all_equal()``"""
  110. def test_true(self):
  111. """Everything is equal"""
  112. self.assertTrue(mi.all_equal('aaaaaa'))
  113. self.assertTrue(mi.all_equal([0, 0, 0, 0]))
  114. def test_false(self):
  115. """Not everything is equal"""
  116. self.assertFalse(mi.all_equal('aaaaab'))
  117. self.assertFalse(mi.all_equal([0, 0, 0, 1]))
  118. def test_tricky(self):
  119. """Not everything is identical, but everything is equal"""
  120. items = [1, complex(1, 0), 1.0]
  121. self.assertTrue(mi.all_equal(items))
  122. def test_empty(self):
  123. """Return True if the iterable is empty"""
  124. self.assertTrue(mi.all_equal(''))
  125. self.assertTrue(mi.all_equal([]))
  126. def test_one(self):
  127. """Return True if the iterable is singular"""
  128. self.assertTrue(mi.all_equal('0'))
  129. self.assertTrue(mi.all_equal([0]))
  130. class QuantifyTests(TestCase):
  131. """Tests for ``quantify()``"""
  132. def test_happy_path(self):
  133. """Make sure True count is returned"""
  134. q = [True, False, True]
  135. self.assertEqual(mi.quantify(q), 2)
  136. def test_custom_predicate(self):
  137. """Ensure non-default predicates return as expected"""
  138. q = range(10)
  139. self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
  140. class PadnoneTests(TestCase):
  141. """Tests for ``padnone()``"""
  142. def test_happy_path(self):
  143. """wrapper iterator should return None indefinitely"""
  144. r = range(2)
  145. p = mi.padnone(r)
  146. self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])
  147. class NcyclesTests(TestCase):
  148. """Tests for ``nyclces()``"""
  149. def test_happy_path(self):
  150. """cycle a sequence three times"""
  151. r = ["a", "b", "c"]
  152. n = mi.ncycles(r, 3)
  153. self.assertEqual(
  154. ["a", "b", "c", "a", "b", "c", "a", "b", "c"],
  155. list(n)
  156. )
  157. def test_null_case(self):
  158. """asking for 0 cycles should return an empty iterator"""
  159. n = mi.ncycles(range(100), 0)
  160. self.assertRaises(StopIteration, lambda: next(n))
  161. def test_pathalogical_case(self):
  162. """asking for negative cycles should return an empty iterator"""
  163. n = mi.ncycles(range(100), -10)
  164. self.assertRaises(StopIteration, lambda: next(n))
  165. class DotproductTests(TestCase):
  166. """Tests for ``dotproduct()``'"""
  167. def test_happy_path(self):
  168. """simple dotproduct example"""
  169. self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
  170. class FlattenTests(TestCase):
  171. """Tests for ``flatten()``"""
  172. def test_basic_usage(self):
  173. """ensure list of lists is flattened one level"""
  174. f = [[0, 1, 2], [3, 4, 5]]
  175. self.assertEqual(list(range(6)), list(mi.flatten(f)))
  176. def test_single_level(self):
  177. """ensure list of lists is flattened only one level"""
  178. f = [[0, [1, 2]], [[3, 4], 5]]
  179. self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
  180. class RepeatfuncTests(TestCase):
  181. """Tests for ``repeatfunc()``"""
  182. def test_simple_repeat(self):
  183. """test simple repeated functions"""
  184. r = mi.repeatfunc(lambda: 5)
  185. self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
  186. def test_finite_repeat(self):
  187. """ensure limited repeat when times is provided"""
  188. r = mi.repeatfunc(lambda: 5, times=5)
  189. self.assertEqual([5, 5, 5, 5, 5], list(r))
  190. def test_added_arguments(self):
  191. """ensure arguments are applied to the function"""
  192. r = mi.repeatfunc(lambda x: x, 2, 3)
  193. self.assertEqual([3, 3], list(r))
  194. def test_null_times(self):
  195. """repeat 0 should return an empty iterator"""
  196. r = mi.repeatfunc(range, 0, 3)
  197. self.assertRaises(StopIteration, lambda: next(r))
  198. class PairwiseTests(TestCase):
  199. """Tests for ``pairwise()``"""
  200. def test_base_case(self):
  201. """ensure an iterable will return pairwise"""
  202. p = mi.pairwise([1, 2, 3])
  203. self.assertEqual([(1, 2), (2, 3)], list(p))
  204. def test_short_case(self):
  205. """ensure an empty iterator if there's not enough values to pair"""
  206. p = mi.pairwise("a")
  207. self.assertRaises(StopIteration, lambda: next(p))
  208. class GrouperTests(TestCase):
  209. """Tests for ``grouper()``"""
  210. def test_even(self):
  211. """Test when group size divides evenly into the length of
  212. the iterable.
  213. """
  214. self.assertEqual(
  215. list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
  216. )
  217. def test_odd(self):
  218. """Test when group size does not divide evenly into the length of the
  219. iterable.
  220. """
  221. self.assertEqual(
  222. list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]
  223. )
  224. def test_fill_value(self):
  225. """Test that the fill value is used to pad the final group"""
  226. self.assertEqual(
  227. list(mi.grouper(3, 'ABCDE', 'x')),
  228. [('A', 'B', 'C'), ('D', 'E', 'x')]
  229. )
  230. class RoundrobinTests(TestCase):
  231. """Tests for ``roundrobin()``"""
  232. def test_even_groups(self):
  233. """Ensure ordered output from evenly populated iterables"""
  234. self.assertEqual(
  235. list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
  236. ['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
  237. )
  238. def test_uneven_groups(self):
  239. """Ensure ordered output from unevenly populated iterables"""
  240. self.assertEqual(
  241. list(mi.roundrobin('ABCD', [1, 2], range(0))),
  242. ['A', 1, 'B', 2, 'C', 'D']
  243. )
  244. class PartitionTests(TestCase):
  245. """Tests for ``partition()``"""
  246. def test_bool(self):
  247. """Test when pred() returns a boolean"""
  248. lesser, greater = mi.partition(lambda x: x > 5, range(10))
  249. self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
  250. self.assertEqual(list(greater), [6, 7, 8, 9])
  251. def test_arbitrary(self):
  252. """Test when pred() returns an integer"""
  253. divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
  254. self.assertEqual(list(divisibles), [0, 3, 6, 9])
  255. self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
  256. class PowersetTests(TestCase):
  257. """Tests for ``powerset()``"""
  258. def test_combinatorics(self):
  259. """Ensure a proper enumeration"""
  260. p = mi.powerset([1, 2, 3])
  261. self.assertEqual(
  262. list(p),
  263. [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
  264. )
  265. class UniqueEverseenTests(TestCase):
  266. """Tests for ``unique_everseen()``"""
  267. def test_everseen(self):
  268. """ensure duplicate elements are ignored"""
  269. u = mi.unique_everseen('AAAABBBBCCDAABBB')
  270. self.assertEqual(
  271. ['A', 'B', 'C', 'D'],
  272. list(u)
  273. )
  274. def test_custom_key(self):
  275. """ensure the custom key comparison works"""
  276. u = mi.unique_everseen('aAbACCc', key=str.lower)
  277. self.assertEqual(list('abC'), list(u))
  278. def test_unhashable(self):
  279. """ensure things work for unhashable items"""
  280. iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
  281. u = mi.unique_everseen(iterable)
  282. self.assertEqual(list(u), ['a', [1, 2, 3]])
  283. def test_unhashable_key(self):
  284. """ensure things work for unhashable items with a custom key"""
  285. iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
  286. u = mi.unique_everseen(iterable, key=lambda x: x)
  287. self.assertEqual(list(u), ['a', [1, 2, 3]])
  288. class UniqueJustseenTests(TestCase):
  289. """Tests for ``unique_justseen()``"""
  290. def test_justseen(self):
  291. """ensure only last item is remembered"""
  292. u = mi.unique_justseen('AAAABBBCCDABB')
  293. self.assertEqual(list('ABCDAB'), list(u))
  294. def test_custom_key(self):
  295. """ensure the custom key comparison works"""
  296. u = mi.unique_justseen('AABCcAD', str.lower)
  297. self.assertEqual(list('ABCAD'), list(u))
  298. class IterExceptTests(TestCase):
  299. """Tests for ``iter_except()``"""
  300. def test_exact_exception(self):
  301. """ensure the exact specified exception is caught"""
  302. l = [1, 2, 3]
  303. i = mi.iter_except(l.pop, IndexError)
  304. self.assertEqual(list(i), [3, 2, 1])
  305. def test_generic_exception(self):
  306. """ensure the generic exception can be caught"""
  307. l = [1, 2]
  308. i = mi.iter_except(l.pop, Exception)
  309. self.assertEqual(list(i), [2, 1])
  310. def test_uncaught_exception_is_raised(self):
  311. """ensure a non-specified exception is raised"""
  312. l = [1, 2, 3]
  313. i = mi.iter_except(l.pop, KeyError)
  314. self.assertRaises(IndexError, lambda: list(i))
  315. def test_first(self):
  316. """ensure first is run before the function"""
  317. l = [1, 2, 3]
  318. f = lambda: 25
  319. i = mi.iter_except(l.pop, IndexError, f)
  320. self.assertEqual(list(i), [25, 3, 2, 1])
  321. class FirstTrueTests(TestCase):
  322. """Tests for ``first_true()``"""
  323. def test_something_true(self):
  324. """Test with no keywords"""
  325. self.assertEqual(mi.first_true(range(10)), 1)
  326. def test_nothing_true(self):
  327. """Test default return value."""
  328. self.assertEqual(mi.first_true([0, 0, 0]), False)
  329. def test_default(self):
  330. """Test with a default keyword"""
  331. self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
  332. def test_pred(self):
  333. """Test with a custom predicate"""
  334. self.assertEqual(
  335. mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
  336. )
  337. class RandomProductTests(TestCase):
  338. """Tests for ``random_product()``
  339. Since random.choice() has different results with the same seed across
  340. python versions 2.x and 3.x, these tests use highly probably events to
  341. create predictable outcomes across platforms.
  342. """
  343. def test_simple_lists(self):
  344. """Ensure that one item is chosen from each list in each pair.
  345. Also ensure that each item from each list eventually appears in
  346. the chosen combinations.
  347. Odds are roughly 1 in 7.1 * 10e16 that one item from either list will
  348. not be chosen after 100 samplings of one item from each list. Just to
  349. be safe, better use a known random seed, too.
  350. """
  351. nums = [1, 2, 3]
  352. lets = ['a', 'b', 'c']
  353. n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
  354. n, m = set(n), set(m)
  355. self.assertEqual(n, set(nums))
  356. self.assertEqual(m, set(lets))
  357. self.assertEqual(len(n), len(nums))
  358. self.assertEqual(len(m), len(lets))
  359. def test_list_with_repeat(self):
  360. """ensure multiple items are chosen, and that they appear to be chosen
  361. from one list then the next, in proper order.
  362. """
  363. nums = [1, 2, 3]
  364. lets = ['a', 'b', 'c']
  365. r = list(mi.random_product(nums, lets, repeat=100))
  366. self.assertEqual(2 * 100, len(r))
  367. n, m = set(r[::2]), set(r[1::2])
  368. self.assertEqual(n, set(nums))
  369. self.assertEqual(m, set(lets))
  370. self.assertEqual(len(n), len(nums))
  371. self.assertEqual(len(m), len(lets))
  372. class RandomPermutationTests(TestCase):
  373. """Tests for ``random_permutation()``"""
  374. def test_full_permutation(self):
  375. """ensure every item from the iterable is returned in a new ordering
  376. 15 elements have a 1 in 1.3 * 10e12 of appearing in sorted order, so
  377. we fix a seed value just to be sure.
  378. """
  379. i = range(15)
  380. r = mi.random_permutation(i)
  381. self.assertEqual(set(i), set(r))
  382. if i == r:
  383. raise AssertionError("Values were not permuted")
  384. def test_partial_permutation(self):
  385. """ensure all returned items are from the iterable, that the returned
  386. permutation is of the desired length, and that all items eventually
  387. get returned.
  388. Sampling 100 permutations of length 5 from a set of 15 leaves a
  389. (2/3)^100 chance that an item will not be chosen. Multiplied by 15
  390. items, there is a 1 in 2.6e16 chance that at least 1 item will not
  391. show up in the resulting output. Using a random seed will fix that.
  392. """
  393. items = range(15)
  394. item_set = set(items)
  395. all_items = set()
  396. for _ in range(100):
  397. permutation = mi.random_permutation(items, 5)
  398. self.assertEqual(len(permutation), 5)
  399. permutation_set = set(permutation)
  400. self.assertLessEqual(permutation_set, item_set)
  401. all_items |= permutation_set
  402. self.assertEqual(all_items, item_set)
  403. class RandomCombinationTests(TestCase):
  404. """Tests for ``random_combination()``"""
  405. def test_psuedorandomness(self):
  406. """ensure different subsets of the iterable get returned over many
  407. samplings of random combinations"""
  408. items = range(15)
  409. all_items = set()
  410. for _ in range(50):
  411. combination = mi.random_combination(items, 5)
  412. all_items |= set(combination)
  413. self.assertEqual(all_items, set(items))
  414. def test_no_replacement(self):
  415. """ensure that elements are sampled without replacement"""
  416. items = range(15)
  417. for _ in range(50):
  418. combination = mi.random_combination(items, len(items))
  419. self.assertEqual(len(combination), len(set(combination)))
  420. self.assertRaises(
  421. ValueError, lambda: mi.random_combination(items, len(items) + 1)
  422. )
  423. class RandomCombinationWithReplacementTests(TestCase):
  424. """Tests for ``random_combination_with_replacement()``"""
  425. def test_replacement(self):
  426. """ensure that elements are sampled with replacement"""
  427. items = range(5)
  428. combo = mi.random_combination_with_replacement(items, len(items) * 2)
  429. self.assertEqual(2 * len(items), len(combo))
  430. if len(set(combo)) == len(combo):
  431. raise AssertionError("Combination contained no duplicates")
  432. def test_pseudorandomness(self):
  433. """ensure different subsets of the iterable get returned over many
  434. samplings of random combinations"""
  435. items = range(15)
  436. all_items = set()
  437. for _ in range(50):
  438. combination = mi.random_combination_with_replacement(items, 5)
  439. all_items |= set(combination)
  440. self.assertEqual(all_items, set(items))
  441. class NthCombinationTests(TestCase):
  442. def test_basic(self):
  443. iterable = 'abcdefg'
  444. r = 4
  445. for index, expected in enumerate(combinations(iterable, r)):
  446. actual = mi.nth_combination(iterable, r, index)
  447. self.assertEqual(actual, expected)
  448. def test_long(self):
  449. actual = mi.nth_combination(range(180), 4, 2000000)
  450. expected = (2, 12, 35, 126)
  451. self.assertEqual(actual, expected)