test__util.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from __future__ import division, print_function, absolute_import
  2. from multiprocessing import Pool
  3. from multiprocessing.pool import Pool as PWL
  4. import numpy as np
  5. from numpy.testing import assert_equal, assert_
  6. from pytest import raises as assert_raises
  7. from scipy._lib._util import _aligned_zeros, check_random_state, MapWrapper
  8. def test__aligned_zeros():
  9. niter = 10
  10. def check(shape, dtype, order, align):
  11. err_msg = repr((shape, dtype, order, align))
  12. x = _aligned_zeros(shape, dtype, order, align=align)
  13. if align is None:
  14. align = np.dtype(dtype).alignment
  15. assert_equal(x.__array_interface__['data'][0] % align, 0)
  16. if hasattr(shape, '__len__'):
  17. assert_equal(x.shape, shape, err_msg)
  18. else:
  19. assert_equal(x.shape, (shape,), err_msg)
  20. assert_equal(x.dtype, dtype)
  21. if order == "C":
  22. assert_(x.flags.c_contiguous, err_msg)
  23. elif order == "F":
  24. if x.size > 0:
  25. # Size-0 arrays get invalid flags on Numpy 1.5
  26. assert_(x.flags.f_contiguous, err_msg)
  27. elif order is None:
  28. assert_(x.flags.c_contiguous, err_msg)
  29. else:
  30. raise ValueError()
  31. # try various alignments
  32. for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
  33. for n in [0, 1, 3, 11]:
  34. for order in ["C", "F", None]:
  35. for dtype in [np.uint8, np.float64]:
  36. for shape in [n, (1, 2, 3, n)]:
  37. for j in range(niter):
  38. check(shape, dtype, order, align)
  39. def test_check_random_state():
  40. # If seed is None, return the RandomState singleton used by np.random.
  41. # If seed is an int, return a new RandomState instance seeded with seed.
  42. # If seed is already a RandomState instance, return it.
  43. # Otherwise raise ValueError.
  44. rsi = check_random_state(1)
  45. assert_equal(type(rsi), np.random.RandomState)
  46. rsi = check_random_state(rsi)
  47. assert_equal(type(rsi), np.random.RandomState)
  48. rsi = check_random_state(None)
  49. assert_equal(type(rsi), np.random.RandomState)
  50. assert_raises(ValueError, check_random_state, 'a')
  51. class TestMapWrapper(object):
  52. def setup_method(self):
  53. self.input = np.arange(10.)
  54. self.output = np.sin(self.input)
  55. def test_serial(self):
  56. p = MapWrapper(1)
  57. assert_(p._mapfunc is map)
  58. assert_(p.pool is None)
  59. assert_(p._own_pool is False)
  60. out = list(p(np.sin, self.input))
  61. assert_equal(out, self.output)
  62. with assert_raises(RuntimeError):
  63. p = MapWrapper(0)
  64. def test_parallel(self):
  65. with MapWrapper(2) as p:
  66. out = p(np.sin, self.input)
  67. assert_equal(list(out), self.output)
  68. assert_(p._own_pool is True)
  69. assert_(isinstance(p.pool, PWL))
  70. assert_(p._mapfunc is not None)
  71. # the context manager should've closed the internal pool
  72. # check that it has by asking it to calculate again.
  73. with assert_raises(Exception) as excinfo:
  74. p(np.sin, self.input)
  75. # on py27 an AssertionError is raised, on >py27 it's a ValueError
  76. err_type = excinfo.type
  77. assert_((err_type is ValueError) or (err_type is AssertionError))
  78. # can also set a PoolWrapper up with a map-like callable instance
  79. try:
  80. p = Pool(2)
  81. q = MapWrapper(p.map)
  82. assert_(q._own_pool is False)
  83. q.close()
  84. # closing the PoolWrapper shouldn't close the internal pool
  85. # because it didn't create it
  86. out = p.map(np.sin, self.input)
  87. assert_equal(list(out), self.output)
  88. finally:
  89. p.close()