123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550 |
- import itertools
- from itertools import starmap
- from cytoolz.utils import raises
- from functools import partial
- from random import Random
- from pickle import dumps, loads
- from cytoolz.itertoolz import (remove, groupby, merge_sorted,
- concat, concatv, interleave, unique,
- isiterable, getter,
- mapcat, isdistinct, first, second,
- nth, take, tail, drop, interpose, get,
- rest, last, cons, frequencies,
- reduceby, iterate, accumulate,
- sliding_window, count, partition,
- partition_all, take_nth, pluck, join,
- diff, topk, peek, peekn, random_sample)
- from cytoolz.compatibility import range, filter
- from operator import add, mul
# An `is` comparison between this value and no_default will fail: the pickle
# round trip builds a separate object carrying the same text.
no_default2 = loads(dumps("__no__default__"))
def identity(x):
    """Return the argument unchanged."""
    unchanged = x
    return unchanged
def iseven(x):
    """Return whether ``x`` leaves no remainder when divided by two."""
    remainder = x % 2
    return remainder == 0
def isodd(x):
    """Return whether ``x`` leaves a remainder of one when divided by two."""
    remainder = x % 2
    return remainder == 1
def inc(x):
    """Return ``x`` plus one."""
    successor = x + 1
    return successor
def double(x):
    """Return two times ``x``."""
    doubled = 2 * x
    return doubled
def test_remove():
    """remove() drops items matching the predicate and does not return a list."""
    result = remove(iseven, range(5))
    assert type(result) is not list

    odd_numbers = list(filter(isodd, range(5)))
    assert list(result) == odd_numbers
def test_groupby():
    """groupby() with a callable key buckets items by the key's result."""
    grouped = groupby(iseven, [1, 2, 3, 4])
    expected = {True: [2, 4], False: [1, 3]}
    assert grouped == expected
def test_groupby_non_callable():
    """groupby() accepts an index, or a list of indices, in place of a callable."""
    pairs = [(1, 2), (1, 3), (2, 2), (2, 4)]
    ones = [(1, 2), (1, 3)]
    twos = [(2, 2), (2, 4)]

    # A bare index groups by that element; a list of indices groups by a tuple.
    cases = [
        (0, {1: ones, 2: twos}),
        ([0], {(1,): ones, (2,): twos}),
        ([0, 0], {(1, 1): ones, (2, 2): twos}),
    ]
    for key, expected in cases:
        assert groupby(key, pairs) == expected
def test_merge_sorted():
    """Check merge_sorted() on keyed and unkeyed inputs, including empty ones."""
    def negated(x):
        return -x

    def floor_third(x):
        return x // 3

    def merged(*seqs, **kwargs):
        # Collect the merged iterable into a list for comparison.
        return list(merge_sorted(*seqs, **kwargs))

    # Unkeyed merges of integer sequences.
    assert merged([1, 2, 3], [1, 2, 3]) == [1, 1, 2, 2, 3, 3]
    assert merged([1, 3, 5], [2, 4, 6]) == [1, 2, 3, 4, 5, 6]
    assert merged([1], [2, 4], [3], []) == [1, 2, 3, 4]

    # Keyed merges: descending input, then a key shared by several values.
    assert merged([5, 3, 1], [6, 4, 3], [], key=negated) == [6, 5, 4, 3, 3, 1]
    assert merged([2, 1, 3], [1, 2, 3], key=floor_third) == [2, 1, 1, 2, 3, 3]
    assert merged([2, 3], [1, 3], key=floor_third) == [2, 1, 3, 3]

    # String inputs, with and without a key.
    assert ''.join(merge_sorted('abc', 'abc', 'abc')) == 'aaabbbccc'
    assert ''.join(merge_sorted('abc', 'abc', 'abc', key=ord)) == 'aaabbbccc'
    descending = merge_sorted('cba', 'cba', 'cba', key=lambda ch: -ord(ch))
    assert ''.join(descending) == 'cccbbbaaa'

    assert merged([1], [2, 3, 4], key=identity) == [1, 2, 3, 4]

    # Tuples ordered by their second element only.
    data = [
        [(1, 2), (0, 4), (3, 6)],
        [(5, 3), (6, 5), (8, 8)],
        [(9, 1), (9, 8), (9, 9)],
    ]
    by_second = merged(*data, key=lambda pair: pair[1])
    assert by_second == [
        (9, 1), (1, 2), (5, 3), (0, 4), (6, 5), (3, 6), (8, 8), (9, 8), (9, 9)]

    # Zero, one and several inputs.
    assert merged() == []
    assert merged([1, 2, 3]) == [1, 2, 3]
    assert merged([1, 4, 5], [2, 3]) == [1, 2, 3, 4, 5]
    assert merged([1, 4, 5], [2, 3], key=identity) == [1, 2, 3, 4, 5]
    assert merged([1, 5], [2], [4, 7], [3, 6], key=identity) == [
        1, 2, 3, 4, 5, 6, 7]
def test_interleave():
    """interleave() alternates between sequences, including unequal lengths."""
    equal_lengths = interleave(('ABC', '123'))
    assert ''.join(equal_lengths) == 'A1B2C3'

    unequal_lengths = interleave(('ABC', '1'))
    assert ''.join(unequal_lengths) == 'A1BC'
def test_unique():
    """unique() yields first occurrences only, judged by ``key`` when given."""
    already_distinct = tuple(unique((1, 2, 3)))
    assert already_distinct == (1, 2, 3)

    with_repeat = tuple(unique((1, 2, 1, 3)))
    assert with_repeat == (1, 2, 3)

    by_parity = tuple(unique((1, 2, 3), key=iseven))
    assert by_parity == (1, 2)
def test_isiterable():
    """isiterable() returns exactly True for a list and a str, False for an int."""
    cases = [
        ([1, 2, 3], True),
        ('abc', True),
        (5, False),
    ]
    for candidate, expected in cases:
        assert isiterable(candidate) is expected
def test_isdistinct():
    """isdistinct() handles lists, strings and one-shot iterators."""
    distinct_inputs = [[1, 2, 3], "World", iter([1, 2, 3])]
    repeated_inputs = [[1, 2, 1], "Hello", iter([1, 2, 1])]

    for seq in distinct_inputs:
        assert isdistinct(seq) is True
    for seq in repeated_inputs:
        assert isdistinct(seq) is False
def test_nth():
    """nth() indexes sequences, iterators and dicts; bad indices raise."""
    letters = 'ABCDE'

    assert nth(2, letters) == 'C'
    assert nth(2, iter(letters)) == 'C'
    assert nth(1, (3, 2, 1)) == 2
    assert nth(0, {'foo': 'bar'}) == 'foo'

    def past_the_end():
        return nth(10, {10: 'foo'})

    assert raises(StopIteration, past_the_end)

    # A negative index works on a sequence but not on a plain iterator.
    assert nth(-2, letters) == 'D'

    def negative_on_iterator():
        return nth(-2, iter(letters))

    assert raises(ValueError, negative_on_iterator)
def test_first():
    """first() returns the leading item of a str, tuple or dict."""
    assert first('ABCDE') == 'A'
    assert first((3, 2, 1)) == 3

    # For a dict, only the type of the returned key is checked.
    key = first({0: 'zero', 1: 'one'})
    assert isinstance(key, int)
def test_second():
    """second() returns the item at position one of a str, tuple or dict."""
    assert second('ABCDE') == 'B'
    assert second((3, 2, 1)) == 2

    # For a dict, only the type of the returned key is checked.
    key = second({0: 'zero', 1: 'one'})
    assert isinstance(key, int)
def test_last():
    """last() returns the final item of a str, tuple or dict."""
    assert last('ABCDE') == 'E'
    assert last((3, 2, 1)) == 1

    # For a dict, only the type of the returned key is checked.
    key = last({0: 'zero', 1: 'one'})
    assert isinstance(key, int)
def test_rest():
    """rest() yields everything after the first item."""
    after_a = list(rest('ABCDE'))
    assert after_a == list('BCDE')

    after_three = list(rest((3, 2, 1)))
    assert after_three == list((2, 1))
def test_take():
    """take() yields the first n items."""
    first_three = list(take(3, 'ABCDE'))
    assert first_three == list('ABC')

    first_two = list(take(2, (3, 2, 1)))
    assert first_two == list((3, 2))
def test_tail():
    """tail() returns the last n items of a sequence or an iterator."""
    letters = 'ABCDE'
    for source in (letters, iter(letters)):
        assert list(tail(3, source)) == list('CDE')

    last_two = list(tail(2, (3, 2, 1)))
    assert last_two == list((2, 1))
def test_drop():
    """drop() skips the first n items and yields the remainder."""
    after_three = list(drop(3, 'ABCDE'))
    assert after_three == list('DE')

    after_one = list(drop(1, (3, 2, 1)))
    assert after_one == list((2, 1))
def test_take_nth():
    """take_nth() keeps every nth item, starting with the first."""
    every_other = list(take_nth(2, 'ABCDE'))
    assert every_other == list('ACE')
def test_get():
    """Exercise get() with single keys, key lists, defaults and error cases."""
    letters = 'ABCDE'
    abc = {'a': 1, 'b': 2, 'c': 3}

    # Single key versus a list of keys.
    assert get(1, letters) == 'B'
    assert list(get([1, 3], letters)) == list('BD')
    assert get('a', abc) == 1
    assert get(['a', 'b'], abc) == (1, 2)

    # Defaults replace missing or unusable keys.
    assert get('foo', {}, default='bar') == 'bar'
    assert get({}, [1, 2, 3], default='bar') == 'bar'
    assert get([0, 2], 'AB', 'C') == ('A', 'C')

    # A list of keys gives a tuple, even for zero or one key.
    assert get([0], 'AB') == ('A',)
    assert get([], 'AB') == ()

    # Without a usable default the underlying lookup error is raised.
    failing_calls = [
        (IndexError, lambda: get(10, 'ABC')),
        (KeyError, lambda: get(10, {'a': 1})),
        (TypeError, lambda: get({}, [1, 2, 3])),
        (TypeError, lambda: get([1, 2, 3], 1, None)),
        (KeyError, lambda: get('foo', {}, default=no_default2)),
    ]
    for error, call in failing_calls:
        assert raises(error, call)
def test_mapcat():
    """mapcat() maps a function over sequences and concatenates the results."""
    flattened = list(mapcat(identity, [[1, 2, 3], [4, 5, 6]]))
    assert flattened == [1, 2, 3, 4, 5, 6]

    reversed_chunks = list(mapcat(reversed, [[3, 2, 1, 0], [6, 5, 4], [9, 8, 7]]))
    assert reversed_chunks == list(range(10))

    def add_one(i):
        return i + 1

    incremented = list(mapcat(partial(map, add_one), [[3, 4, 5], [6, 7, 8]]))
    assert [4, 5, 6, 7, 8, 9] == incremented
def test_cons():
    """cons() puts an element in front of a sequence."""
    prepended = list(cons(1, [2, 3]))
    assert prepended == [1, 2, 3]
def test_concat():
    """concat() chains a sequence of sequences."""
    assert list(concat([[], [], []])) == []

    # Only five items are taken from a chain ending in a very large range.
    chained = concat([['a', 'b'], range(1000000000)])
    assert list(take(5, chained)) == ['a', 'b', 0, 1, 2]
def test_concatv():
    """concatv() chains the sequences passed as separate arguments."""
    assert list(concatv([], [], [])) == []

    # Only five items are taken from a chain ending in a very large range.
    chained = concatv(['a', 'b'], range(1000000000))
    assert list(take(5, chained)) == ['a', 'b', 0, 1, 2]
def test_interpose():
    """interpose() inserts a separator between consecutive items."""
    # Only the second item is read from a very large interposed range.
    second_item = first(rest(interpose("a", range(1000000000))))
    assert "a" == second_item

    joined = "".join(interpose("X", "tarzan"))
    assert "tXaXrXzXaXn" == joined

    ones = itertools.repeat(1, 4)
    assert list(interpose(0, ones)) == [1, 0, 1, 0, 1, 0, 1]

    dotted = list(interpose('.', ['a', 'b', 'c']))
    assert dotted == ['a', '.', 'b', '.', 'c']
def test_frequencies():
    """frequencies() counts occurrences of each item."""
    animals = ["cat", "pig", "cat", "eel", "pig", "dog", "dog", "dog"]
    assert frequencies(animals) == {"cat": 2, "eel": 1, "pig": 2, "dog": 3}

    assert frequencies([]) == {}

    letter_counts = frequencies("onomatopoeia")
    assert letter_counts == {
        "a": 2, "e": 1, "i": 1, "m": 1, "o": 4, "n": 1, "p": 1, "t": 1}
def test_reduceby():
    """reduceby() groups by a key (callable or plain key) and folds each group."""
    def is_even(x):
        return x % 2 == 0

    numbers = [1, 2, 3, 4, 5]
    assert reduceby(is_even, add, numbers, 0) == {False: 9, True: 6}
    assert reduceby(is_even, mul, numbers, 1) == {False: 15, True: 8}

    projects = [
        {'name': 'build roads', 'state': 'CA', 'cost': 1000000},
        {'name': 'fight crime', 'state': 'IL', 'cost': 100000},
        {'name': 'help farmers', 'state': 'IL', 'cost': 2000000},
        {'name': 'help farmers', 'state': 'CA', 'cost': 200000},
    ]
    cost_by_state = {'CA': 1200000, 'IL': 2100000}

    def state_of(project):
        return project['state']

    def add_cost(acc, project):
        return acc + project['cost']

    # A callable key and the bare field name give the same totals.
    assert reduceby(state_of, add_cost, projects, 0) == cost_by_state
    assert reduceby('state', add_cost, projects, 0) == cost_by_state
def test_reduce_by_init():
    """reduceby() gives the same sums with no init value or with no_default2."""
    numbers = [1, 2, 3, 4]
    expected = {True: 2 + 4, False: 1 + 3}

    assert reduceby(iseven, add, numbers) == expected
    assert reduceby(iseven, add, numbers, no_default2) == expected
def test_reduce_by_callable_default():
    """reduceby() accepts a callable such as ``set`` as the initial value."""
    def set_add(group, item):
        # Mutate and hand back the accumulator, as a binary reducer must.
        group.add(item)
        return group

    by_parity = reduceby(iseven, set_add, [1, 2, 3, 4, 1, 2], set)
    assert by_parity == {True: {2, 4}, False: {1, 3}}
def test_iterate():
    """iterate() yields x, f(x), f(f(x)), ... from a starting value."""
    counting = itertools.islice(iterate(inc, 0), 0, 5)
    assert list(counting) == [0, 1, 2, 3, 4]

    powers_of_two = take(4, iterate(double, 1))
    assert list(powers_of_two) == [1, 2, 4, 8]
def test_accumulate():
    """accumulate() yields running results, with an optional initial value."""
    one_to_five = [1, 2, 3, 4, 5]
    assert list(accumulate(add, one_to_five)) == [1, 3, 6, 10, 15]
    assert list(accumulate(mul, one_to_five)) == [1, 2, 6, 24, 120]
    assert list(accumulate(add, one_to_five, -1)) == [-1, 0, 2, 5, 9, 14]

    def binop(a, b):
        raise AssertionError('binop should not be called')

    # With empty input the reducer is never invoked.
    start = object()
    assert list(accumulate(binop, [], start)) == [start]
    assert list(accumulate(binop, [])) == []

    # no_default2 as the initial value gives the same result as passing none.
    assert list(accumulate(add, [1, 2, 3], no_default2)) == [1, 3, 6]
def test_accumulate_works_on_consumable_iterables():
    """accumulate() works on a one-shot iterator as well as a sequence."""
    one_shot = iter((1, 2, 3))
    running_sums = list(accumulate(add, one_shot))
    assert running_sums == [1, 3, 6]
def test_sliding_window():
    """sliding_window() yields overlapping tuples of the requested width."""
    numbers = [1, 2, 3, 4]

    pairs = list(sliding_window(2, numbers))
    assert pairs == [(1, 2), (2, 3), (3, 4)]

    triples = list(sliding_window(3, numbers))
    assert triples == [(1, 2, 3), (2, 3, 4)]
def test_sliding_window_of_short_iterator():
    """sliding_window() yields nothing when the input is shorter than the window."""
    for width in (3, 7):
        assert list(sliding_window(width, [1, 2])) == []
def test_partition():
    """partition() drops an incomplete final tuple unless ``pad`` is given."""
    even_split = list(partition(2, [1, 2, 3, 4]))
    assert even_split == [(1, 2), (3, 4)]

    leftover_dropped = list(partition(3, range(7)))
    assert leftover_dropped == [(0, 1, 2), (3, 4, 5)]

    padded = list(partition(3, range(4), pad=-1))
    assert padded == [(0, 1, 2), (3, -1, -1)]

    assert list(partition(2, [])) == []
def test_partition_all():
    """partition_all() keeps a short final tuple instead of dropping it."""
    assert list(partition_all(2, [1, 2, 3, 4])) == [(1, 2), (3, 4)]
    assert list(partition_all(3, range(5))) == [(0, 1, 2), (3, 4)]
    assert list(partition_all(2, [])) == []

    # Regression test: https://github.com/pycytoolz/cytoolz/issues/387
    class NoCompare(object):
        """Equal to its own class only; comparing with anything else raises."""

        def __eq__(self, other):
            same_class = self.__class__ == other.__class__
            if not same_class:
                raise ValueError()
            return True

    obj = NoCompare()
    expected = [(obj,) * 4, (obj,) * 3]

    # The same seven items, once as a list and once as a one-shot iterator.
    assert list(partition_all(4, [obj] * 7)) == expected
    assert list(partition_all(4, iter([obj] * 7))) == expected
def test_count():
    """count() returns the number of items in a sequence or an iterator."""
    cases = [
        ((1, 2, 3), 3),
        ([], 0),
        (iter((1, 2, 3, 4)), 4),
        ('hello', 5),
        (iter('hello'), 5),
    ]
    for seq, expected in cases:
        assert count(seq) == expected
def test_pluck():
    """pluck() extracts one index/key, or a list of them, from every item."""
    rows = [[0, 1], [2, 3], [4, 5]]

    # Positional plucking from lists.
    assert list(pluck(0, rows)) == [0, 2, 4]
    assert list(pluck([0, 1], [[0, 1, 2], [3, 4, 5]])) == [(0, 1), (3, 4)]
    assert list(pluck(1, [[0], [0, 1]], None)) == [None, 1]

    # Key-based plucking from dicts, with and without a default.
    data = [{'id': 1, 'name': 'cheese'}, {'id': 2, 'name': 'pies', 'price': 1}]
    assert list(pluck('id', data)) == [1, 2]
    assert list(pluck('price', data, 0)) == [0, 1]
    assert list(pluck(['id', 'name'], data)) == [(1, 'cheese'), (2, 'pies')]
    assert list(pluck(['name'], data)) == [('cheese',), ('pies',)]
    assert list(pluck(['price', 'other'], data, 0)) == [(0, 0), (1, 0)]

    # Missing entries without a default raise once the result is consumed.
    assert raises(IndexError, lambda: list(pluck(1, [[0]])))
    assert raises(KeyError, lambda: list(pluck('name', [{'id': 1}])))

    # no_default2 as the default behaves the same as passing no default.
    assert list(pluck(0, rows, no_default2)) == [0, 2, 4]
    assert raises(IndexError, lambda: list(pluck(1, [[0]], no_default2)))
def test_join():
    """join() pairs items from two sequences whose keys are equal.

    ``names`` is keyed by its first element and ``fruit`` by its second.
    Each matched pair of tuples is concatenated with ``add`` so the result
    can be compared as a set of flat tuples. Name 3 has no matching fruit
    and is absent from the expected result.
    """
    names = [(1, 'one'), (2, 'two'), (3, 'three')]
    fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)]

    expected = {(1, 'one', 'apple', 1),
                (1, 'one', 'orange', 1),
                (2, 'two', 'banana', 2),
                (2, 'two', 'coconut', 2)}

    result = set(starmap(add, join(first, names, second, fruit)))
    assert result == expected

    # Passing no_default2 for both defaults must give the same result as
    # passing no defaults at all.
    result = set(starmap(add, join(first, names, second, fruit,
                                   left_default=no_default2,
                                   right_default=no_default2)))
    assert result == expected
def test_getter():
    """getter() builds a lookup for one index or a list of indices."""
    name = 'Alice'

    single = getter(0)
    assert single(name) == 'A'

    one_element_list = getter([0])
    assert one_element_list(name) == ('A',)

    empty_list = getter([])
    assert empty_list(name) == ()
def test_key_as_getter():
    """join() accepts an index or index list wherever a key callable is allowed."""
    squares = [(i, i**2) for i in range(5)]
    pows = [(i, i**2, i**3) for i in range(5)]

    def joined(leftkey, rightkey):
        # Collect the join into a set so that ordering does not matter.
        return set(join(leftkey, squares, rightkey, pows))

    def first_item(x):
        return x[0]

    def first_two(x):
        return (x[0], x[1])

    def first_as_tuple(x):
        return (x[0],)

    # Each index form must agree with the equivalent explicit callable.
    assert joined(0, 0) == joined(first_item, first_item)
    assert joined([0, 1], [0, 1]) == joined(first_two, first_two)
    assert joined([0], [0]) == joined(first_as_tuple, first_as_tuple)
def test_join_double_repeats():
    """join() yields every pairing when keys repeat on both sides."""
    names = [(1, 'one'), (2, 'two'), (3, 'three'), (1, 'uno'), (2, 'dos')]
    fruit = [('apple', 1), ('orange', 1), ('banana', 2), ('coconut', 2)]

    expected = {
        (1, 'one', 'apple', 1),
        (1, 'one', 'orange', 1),
        (2, 'two', 'banana', 2),
        (2, 'two', 'coconut', 2),
        (1, 'uno', 'apple', 1),
        (1, 'uno', 'orange', 1),
        (2, 'dos', 'banana', 2),
        (2, 'dos', 'coconut', 2),
    }

    matches = join(first, names, second, fruit)
    result = set(starmap(add, matches))
    assert result == expected
def test_join_missing_element():
    """join() without defaults omits items that have no partner."""
    names = [(1, 'one'), (2, 'two'), (3, 'three')]
    fruit = [('apple', 5), ('orange', 1)]

    matches = join(first, names, second, fruit)
    result = set(starmap(add, matches))

    assert result == {(1, 'one', 'orange', 1)}
def test_left_outer_join():
    """With left_default set, unmatched right items pair with the default."""
    pairs = join(identity, [1, 2], identity, [2, 3], left_default=None)
    result = set(pairs)
    assert result == {(2, 2), (None, 3)}
def test_right_outer_join():
    """With right_default set, unmatched left items pair with the default."""
    pairs = join(identity, [1, 2], identity, [2, 3], right_default=None)
    result = set(pairs)
    assert result == {(2, 2), (1, None)}
def test_outer_join():
    """With both defaults set, unmatched items from either side are kept."""
    pairs = join(identity, [1, 2], identity, [2, 3],
                 left_default=None, right_default=None)
    result = set(pairs)
    assert result == {(2, 2), (1, None), (None, 3)}
def test_diff():
    """diff() yields the tuples of items that differ between sequences.

    Covers the variadic form ``diff(a, b, ...)``, the single-argument form
    ``diff([a, b, ...])``, the ``default`` fill value for unequal lengths and
    comparison through a ``key`` function.
    """
    # Fewer than two sequences, or a non-iterable argument, is a TypeError.
    assert raises(TypeError, lambda: list(diff()))
    assert raises(TypeError, lambda: list(diff([1, 2])))
    assert raises(TypeError, lambda: list(diff([1, 2], 3)))

    assert list(diff([1, 2], (1, 2), iter([1, 2]))) == []
    assert list(diff([1, 2, 3], (1, 10, 3), iter([1, 2, 10]))) == [
        (2, 10, 2), (3, 3, 10)]

    # Unequal lengths: the extra item is ignored unless a default is given.
    assert list(diff([1, 2], [10])) == [(1, 10)]
    assert list(diff([1, 2], [10], default=None)) == [(1, 10), (2, None)]

    # non-variadic usage
    assert raises(TypeError, lambda: list(diff([])))
    assert raises(TypeError, lambda: list(diff([[]])))
    assert raises(TypeError, lambda: list(diff([[1, 2]])))
    assert raises(TypeError, lambda: list(diff([[1, 2], 3])))
    assert list(diff([(1, 2), (1, 3)])) == [(2, 3)]

    data1 = [{'cost': 1, 'currency': 'dollar'},
             {'cost': 2, 'currency': 'dollar'}]
    data2 = [{'cost': 100, 'currency': 'yen'},
             {'cost': 300, 'currency': 'yen'}]
    conversions = {'dollar': 1, 'yen': 0.01}

    def indollars(item):
        return conversions[item['currency']] * item['cost']

    # The first pair is worth the same in dollars (1 vs. 100 * 0.01); only
    # the second pair differs. This comparison was previously a bare
    # expression, so its result was discarded and nothing was checked.
    assert list(diff(data1, data2, key=indollars)) == [
        ({'cost': 2, 'currency': 'dollar'}, {'cost': 300, 'currency': 'yen'})]
def test_topk():
    """topk() returns the k largest items, optionally ranked by ``key``."""
    def negated(x):
        return -x

    assert topk(2, [4, 1, 5, 2]) == (5, 4)
    assert topk(2, [4, 1, 5, 2], key=negated) == (1, 2)
    assert topk(2, iter([5, 1, 4, 2]), key=negated) == (1, 2)

    # A non-callable key is used as a field name or index into each item.
    records = [{'a': 1, 'b': 10}, {'a': 2, 'b': 9},
               {'a': 10, 'b': 1}, {'a': 9, 'b': 2}]
    assert topk(2, records, key='a') == ({'a': 10, 'b': 1}, {'a': 9, 'b': 2})
    assert topk(2, records, key='b') == ({'a': 1, 'b': 10}, {'a': 2, 'b': 9})

    pairs = [(0, 4), (1, 3), (2, 2), (3, 1), (4, 0)]
    assert topk(2, pairs, 0) == ((4, 0), (3, 1))
def test_topk_is_stable():
    """topk() keeps input order among items whose keys are all equal."""
    def constant_key(x):
        return 1

    top_four = topk(4, [5, 9, 2, 1, 5, 3], key=constant_key)
    assert top_four == (5, 9, 2, 1)
def test_peek():
    """peek() returns the first item plus an iterable of the full sequence."""
    alist = ["Alice", "Bob", "Carol"]

    head, everything = peek(alist)
    assert head == alist[0]
    assert list(everything) == alist

    # There is nothing to peek at in an empty sequence.
    assert raises(StopIteration, lambda: peek([]))
def test_peekn():
    """peekn() returns the first n items plus an iterable of the full sequence."""
    alist = ("Alice", "Bob", "Carol")

    head, everything = peekn(2, alist)
    assert head == alist[:2]
    assert tuple(everything) == alist

    # Asking for more items than exist returns what is available.
    oversized = len(alist) * 4
    head, everything = peekn(oversized, alist)
    assert head == alist
    assert tuple(everything) == alist
def test_random_sample():
    """random_sample() is reproducible for a given ``random_state``."""
    alist = list(range(100))

    def mk_rsample(rs=1):
        return list(random_sample(prob=0.1, seq=alist, random_state=rs))

    # prob=1 keeps every item.
    assert list(random_sample(prob=1, seq=alist, random_state=2016)) == alist

    # The same seed gives the same sample.
    rsample1 = mk_rsample()
    assert rsample1 == mk_rsample()

    # An int seed and a Random object built from that seed agree.
    rsample2 = mk_rsample(1984)
    randobj = Random(1984)
    assert rsample2 == mk_rsample(randobj)

    assert rsample1 != rsample2

    # Other objects accepted as random_state.
    assert mk_rsample(object) == mk_rsample(object)
    assert mk_rsample(object) != mk_rsample(object())
    assert mk_rsample(b"a") == mk_rsample(u"a")

    # A list is rejected as random_state.
    assert raises(TypeError, lambda: mk_rsample([]))
|