test_arrayterator.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from __future__ import division, absolute_import, print_function
  2. from operator import mul
  3. from functools import reduce
  4. import numpy as np
  5. from numpy.random import randint
  6. from numpy.lib import Arrayterator
  7. from numpy.testing import assert_
  8. def test():
  9. np.random.seed(np.arange(10))
  10. # Create a random array
  11. ndims = randint(5)+1
  12. shape = tuple(randint(10)+1 for dim in range(ndims))
  13. els = reduce(mul, shape)
  14. a = np.arange(els)
  15. a.shape = shape
  16. buf_size = randint(2*els)
  17. b = Arrayterator(a, buf_size)
  18. # Check that each block has at most ``buf_size`` elements
  19. for block in b:
  20. assert_(len(block.flat) <= (buf_size or els))
  21. # Check that all elements are iterated correctly
  22. assert_(list(b.flat) == list(a.flat))
  23. # Slice arrayterator
  24. start = [randint(dim) for dim in shape]
  25. stop = [randint(dim)+1 for dim in shape]
  26. step = [randint(dim)+1 for dim in shape]
  27. slice_ = tuple(slice(*t) for t in zip(start, stop, step))
  28. c = b[slice_]
  29. d = a[slice_]
  30. # Check that each block has at most ``buf_size`` elements
  31. for block in c:
  32. assert_(len(block.flat) <= (buf_size or els))
  33. # Check that the arrayterator is sliced correctly
  34. assert_(np.all(c.__array__() == d))
  35. # Check that all elements are iterated correctly
  36. assert_(list(c.flat) == list(d.flat))