test_curried.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import toolz
  2. import toolz.curried
  3. from toolz.curried import (take, first, second, sorted, merge_with, reduce,
  4. merge, operator as cop)
  5. from toolz.compatibility import import_module
  6. from collections import defaultdict
  7. from operator import add
  8. def test_take():
  9. assert list(take(2)([1, 2, 3])) == [1, 2]
  10. def test_first():
  11. assert first is toolz.itertoolz.first
  12. def test_merge():
  13. assert merge(factory=lambda: defaultdict(int))({1: 1}) == {1: 1}
  14. assert merge({1: 1}) == {1: 1}
  15. assert merge({1: 1}, factory=lambda: defaultdict(int)) == {1: 1}
  16. def test_merge_with():
  17. assert merge_with(sum)({1: 1}, {1: 2}) == {1: 3}
  18. def test_merge_with_list():
  19. assert merge_with(sum, [{'a': 1}, {'a': 2}]) == {'a': 3}
  20. def test_sorted():
  21. assert sorted(key=second)([(1, 2), (2, 1)]) == [(2, 1), (1, 2)]
  22. def test_reduce():
  23. assert reduce(add)((1, 2, 3)) == 6
  24. def test_module_name():
  25. assert toolz.curried.__name__ == 'toolz.curried'
  26. def test_curried_operator():
  27. for k, v in vars(cop).items():
  28. if not callable(v):
  29. continue
  30. if not isinstance(v, toolz.curry):
  31. try:
  32. # Make sure it is unary
  33. v(1)
  34. except TypeError:
  35. try:
  36. v('x')
  37. except TypeError:
  38. pass
  39. else:
  40. continue
  41. raise AssertionError(
  42. 'toolz.curried.operator.%s is not curried!' % k,
  43. )
  44. # Make sure this isn't totally empty.
  45. assert len(set(vars(cop)) & set(['add', 'sub', 'mul'])) == 3
  46. def test_curried_namespace():
  47. exceptions = import_module('toolz.curried.exceptions')
  48. namespace = {}
  49. def should_curry(func):
  50. if not callable(func) or isinstance(func, toolz.curry):
  51. return False
  52. nargs = toolz.functoolz.num_required_args(func)
  53. if nargs is None or nargs > 1:
  54. return True
  55. return nargs == 1 and toolz.functoolz.has_keywords(func)
  56. def curry_namespace(ns):
  57. return dict(
  58. (name, toolz.curry(f) if should_curry(f) else f)
  59. for name, f in ns.items() if '__' not in name
  60. )
  61. from_toolz = curry_namespace(vars(toolz))
  62. from_exceptions = curry_namespace(vars(exceptions))
  63. namespace.update(toolz.merge(from_toolz, from_exceptions))
  64. namespace = toolz.valfilter(callable, namespace)
  65. curried_namespace = toolz.valfilter(callable, toolz.curried.__dict__)
  66. if namespace != curried_namespace:
  67. missing = set(namespace) - set(curried_namespace)
  68. if missing:
  69. raise AssertionError('There are missing functions in toolz.curried:\n %s'
  70. % ' \n'.join(sorted(missing)))
  71. extra = set(curried_namespace) - set(namespace)
  72. if extra:
  73. raise AssertionError('There are extra functions in toolz.curried:\n %s'
  74. % ' \n'.join(sorted(extra)))
  75. unequal = toolz.merge_with(list, namespace, curried_namespace)
  76. unequal = toolz.valfilter(lambda x: x[0] != x[1], unequal)
  77. messages = []
  78. for name, (orig_func, auto_func) in sorted(unequal.items()):
  79. if name in from_exceptions:
  80. messages.append('%s should come from toolz.curried.exceptions' % name)
  81. elif should_curry(getattr(toolz, name)):
  82. messages.append('%s should be curried from toolz' % name)
  83. else:
  84. messages.append('%s should come from toolz and NOT be curried' % name)
  85. raise AssertionError('\n'.join(messages))