# test_index_tricks.py (16 KB)
  1. from __future__ import division, absolute_import, print_function
  2. import pytest
  3. import numpy as np
  4. from numpy.testing import (
  5. assert_, assert_equal, assert_array_equal, assert_almost_equal,
  6. assert_array_almost_equal, assert_raises, assert_raises_regex,
  7. assert_warns
  8. )
  9. from numpy.lib.index_tricks import (
  10. mgrid, ogrid, ndenumerate, fill_diagonal, diag_indices, diag_indices_from,
  11. index_exp, ndindex, r_, s_, ix_
  12. )
  13. class TestRavelUnravelIndex(object):
  14. def test_basic(self):
  15. assert_equal(np.unravel_index(2, (2, 2)), (1, 0))
  16. # test backwards compatibility with older dims
  17. # keyword argument; see Issue #10586
  18. with assert_warns(DeprecationWarning):
  19. # we should achieve the correct result
  20. # AND raise the appropriate warning
  21. # when using older "dims" kw argument
  22. assert_equal(np.unravel_index(indices=2,
  23. dims=(2, 2)),
  24. (1, 0))
  25. # test that new shape argument works properly
  26. assert_equal(np.unravel_index(indices=2,
  27. shape=(2, 2)),
  28. (1, 0))
  29. # test that an invalid second keyword argument
  30. # is properly handled
  31. with assert_raises(TypeError):
  32. np.unravel_index(indices=2, hape=(2, 2))
  33. with assert_raises(TypeError):
  34. np.unravel_index(2, hape=(2, 2))
  35. with assert_raises(TypeError):
  36. np.unravel_index(254, ims=(17, 94))
  37. assert_equal(np.ravel_multi_index((1, 0), (2, 2)), 2)
  38. assert_equal(np.unravel_index(254, (17, 94)), (2, 66))
  39. assert_equal(np.ravel_multi_index((2, 66), (17, 94)), 254)
  40. assert_raises(ValueError, np.unravel_index, -1, (2, 2))
  41. assert_raises(TypeError, np.unravel_index, 0.5, (2, 2))
  42. assert_raises(ValueError, np.unravel_index, 4, (2, 2))
  43. assert_raises(ValueError, np.ravel_multi_index, (-3, 1), (2, 2))
  44. assert_raises(ValueError, np.ravel_multi_index, (2, 1), (2, 2))
  45. assert_raises(ValueError, np.ravel_multi_index, (0, -3), (2, 2))
  46. assert_raises(ValueError, np.ravel_multi_index, (0, 2), (2, 2))
  47. assert_raises(TypeError, np.ravel_multi_index, (0.1, 0.), (2, 2))
  48. assert_equal(np.unravel_index((2*3 + 1)*6 + 4, (4, 3, 6)), [2, 1, 4])
  49. assert_equal(
  50. np.ravel_multi_index([2, 1, 4], (4, 3, 6)), (2*3 + 1)*6 + 4)
  51. arr = np.array([[3, 6, 6], [4, 5, 1]])
  52. assert_equal(np.ravel_multi_index(arr, (7, 6)), [22, 41, 37])
  53. assert_equal(
  54. np.ravel_multi_index(arr, (7, 6), order='F'), [31, 41, 13])
  55. assert_equal(
  56. np.ravel_multi_index(arr, (4, 6), mode='clip'), [22, 23, 19])
  57. assert_equal(np.ravel_multi_index(arr, (4, 4), mode=('clip', 'wrap')),
  58. [12, 13, 13])
  59. assert_equal(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)), 1621)
  60. assert_equal(np.unravel_index(np.array([22, 41, 37]), (7, 6)),
  61. [[3, 6, 6], [4, 5, 1]])
  62. assert_equal(
  63. np.unravel_index(np.array([31, 41, 13]), (7, 6), order='F'),
  64. [[3, 6, 6], [4, 5, 1]])
  65. assert_equal(np.unravel_index(1621, (6, 7, 8, 9)), [3, 1, 4, 1])
  66. def test_big_indices(self):
  67. # ravel_multi_index for big indices (issue #7546)
  68. if np.intp == np.int64:
  69. arr = ([1, 29], [3, 5], [3, 117], [19, 2],
  70. [2379, 1284], [2, 2], [0, 1])
  71. assert_equal(
  72. np.ravel_multi_index(arr, (41, 7, 120, 36, 2706, 8, 6)),
  73. [5627771580, 117259570957])
  74. # test overflow checking for too big array (issue #7546)
  75. dummy_arr = ([0],[0])
  76. half_max = np.iinfo(np.intp).max // 2
  77. assert_equal(
  78. np.ravel_multi_index(dummy_arr, (half_max, 2)), [0])
  79. assert_raises(ValueError,
  80. np.ravel_multi_index, dummy_arr, (half_max+1, 2))
  81. assert_equal(
  82. np.ravel_multi_index(dummy_arr, (half_max, 2), order='F'), [0])
  83. assert_raises(ValueError,
  84. np.ravel_multi_index, dummy_arr, (half_max+1, 2), order='F')
  85. def test_dtypes(self):
  86. # Test with different data types
  87. for dtype in [np.int16, np.uint16, np.int32,
  88. np.uint32, np.int64, np.uint64]:
  89. coords = np.array(
  90. [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], dtype=dtype)
  91. shape = (5, 8)
  92. uncoords = 8*coords[0]+coords[1]
  93. assert_equal(np.ravel_multi_index(coords, shape), uncoords)
  94. assert_equal(coords, np.unravel_index(uncoords, shape))
  95. uncoords = coords[0]+5*coords[1]
  96. assert_equal(
  97. np.ravel_multi_index(coords, shape, order='F'), uncoords)
  98. assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
  99. coords = np.array(
  100. [[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]],
  101. dtype=dtype)
  102. shape = (5, 8, 10)
  103. uncoords = 10*(8*coords[0]+coords[1])+coords[2]
  104. assert_equal(np.ravel_multi_index(coords, shape), uncoords)
  105. assert_equal(coords, np.unravel_index(uncoords, shape))
  106. uncoords = coords[0]+5*(coords[1]+8*coords[2])
  107. assert_equal(
  108. np.ravel_multi_index(coords, shape, order='F'), uncoords)
  109. assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
  110. def test_clipmodes(self):
  111. # Test clipmodes
  112. assert_equal(
  113. np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12), mode='wrap'),
  114. np.ravel_multi_index([1, 1, 6, 2], (4, 3, 7, 12)))
  115. assert_equal(np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12),
  116. mode=(
  117. 'wrap', 'raise', 'clip', 'raise')),
  118. np.ravel_multi_index([1, 1, 0, 2], (4, 3, 7, 12)))
  119. assert_raises(
  120. ValueError, np.ravel_multi_index, [5, 1, -1, 2], (4, 3, 7, 12))
  121. def test_writeability(self):
  122. # See gh-7269
  123. x, y = np.unravel_index([1, 2, 3], (4, 5))
  124. assert_(x.flags.writeable)
  125. assert_(y.flags.writeable)
  126. def test_0d(self):
  127. # gh-580
  128. x = np.unravel_index(0, ())
  129. assert_equal(x, ())
  130. assert_raises_regex(ValueError, "0d array", np.unravel_index, [0], ())
  131. assert_raises_regex(
  132. ValueError, "out of bounds", np.unravel_index, [1], ())
  133. class TestGrid(object):
  134. def test_basic(self):
  135. a = mgrid[-1:1:10j]
  136. b = mgrid[-1:1:0.1]
  137. assert_(a.shape == (10,))
  138. assert_(b.shape == (20,))
  139. assert_(a[0] == -1)
  140. assert_almost_equal(a[-1], 1)
  141. assert_(b[0] == -1)
  142. assert_almost_equal(b[1]-b[0], 0.1, 11)
  143. assert_almost_equal(b[-1], b[0]+19*0.1, 11)
  144. assert_almost_equal(a[1]-a[0], 2.0/9.0, 11)
  145. def test_linspace_equivalence(self):
  146. y, st = np.linspace(2, 10, retstep=1)
  147. assert_almost_equal(st, 8/49.0)
  148. assert_array_almost_equal(y, mgrid[2:10:50j], 13)
  149. def test_nd(self):
  150. c = mgrid[-1:1:10j, -2:2:10j]
  151. d = mgrid[-1:1:0.1, -2:2:0.2]
  152. assert_(c.shape == (2, 10, 10))
  153. assert_(d.shape == (2, 20, 20))
  154. assert_array_equal(c[0][0, :], -np.ones(10, 'd'))
  155. assert_array_equal(c[1][:, 0], -2*np.ones(10, 'd'))
  156. assert_array_almost_equal(c[0][-1, :], np.ones(10, 'd'), 11)
  157. assert_array_almost_equal(c[1][:, -1], 2*np.ones(10, 'd'), 11)
  158. assert_array_almost_equal(d[0, 1, :] - d[0, 0, :],
  159. 0.1*np.ones(20, 'd'), 11)
  160. assert_array_almost_equal(d[1, :, 1] - d[1, :, 0],
  161. 0.2*np.ones(20, 'd'), 11)
  162. def test_sparse(self):
  163. grid_full = mgrid[-1:1:10j, -2:2:10j]
  164. grid_sparse = ogrid[-1:1:10j, -2:2:10j]
  165. # sparse grids can be made dense by broadcasting
  166. grid_broadcast = np.broadcast_arrays(*grid_sparse)
  167. for f, b in zip(grid_full, grid_broadcast):
  168. assert_equal(f, b)
  169. @pytest.mark.parametrize("start, stop, step, expected", [
  170. (None, 10, 10j, (200, 10)),
  171. (-10, 20, None, (1800, 30)),
  172. ])
  173. def test_mgrid_size_none_handling(self, start, stop, step, expected):
  174. # regression test None value handling for
  175. # start and step values used by mgrid;
  176. # internally, this aims to cover previously
  177. # unexplored code paths in nd_grid()
  178. grid = mgrid[start:stop:step, start:stop:step]
  179. # need a smaller grid to explore one of the
  180. # untested code paths
  181. grid_small = mgrid[start:stop:step]
  182. assert_equal(grid.size, expected[0])
  183. assert_equal(grid_small.size, expected[1])
  184. class TestConcatenator(object):
  185. def test_1d(self):
  186. assert_array_equal(r_[1, 2, 3, 4, 5, 6], np.array([1, 2, 3, 4, 5, 6]))
  187. b = np.ones(5)
  188. c = r_[b, 0, 0, b]
  189. assert_array_equal(c, [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1])
  190. def test_mixed_type(self):
  191. g = r_[10.1, 1:10]
  192. assert_(g.dtype == 'f8')
  193. def test_more_mixed_type(self):
  194. g = r_[-10.1, np.array([1]), np.array([2, 3, 4]), 10.0]
  195. assert_(g.dtype == 'f8')
  196. def test_complex_step(self):
  197. # Regression test for #12262
  198. g = r_[0:36:100j]
  199. assert_(g.shape == (100,))
  200. def test_2d(self):
  201. b = np.random.rand(5, 5)
  202. c = np.random.rand(5, 5)
  203. d = r_['1', b, c] # append columns
  204. assert_(d.shape == (5, 10))
  205. assert_array_equal(d[:, :5], b)
  206. assert_array_equal(d[:, 5:], c)
  207. d = r_[b, c]
  208. assert_(d.shape == (10, 5))
  209. assert_array_equal(d[:5, :], b)
  210. assert_array_equal(d[5:, :], c)
  211. def test_0d(self):
  212. assert_equal(r_[0, np.array(1), 2], [0, 1, 2])
  213. assert_equal(r_[[0, 1, 2], np.array(3)], [0, 1, 2, 3])
  214. assert_equal(r_[np.array(0), [1, 2, 3]], [0, 1, 2, 3])
  215. class TestNdenumerate(object):
  216. def test_basic(self):
  217. a = np.array([[1, 2], [3, 4]])
  218. assert_equal(list(ndenumerate(a)),
  219. [((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)])
  220. class TestIndexExpression(object):
  221. def test_regression_1(self):
  222. # ticket #1196
  223. a = np.arange(2)
  224. assert_equal(a[:-1], a[s_[:-1]])
  225. assert_equal(a[:-1], a[index_exp[:-1]])
  226. def test_simple_1(self):
  227. a = np.random.rand(4, 5, 6)
  228. assert_equal(a[:, :3, [1, 2]], a[index_exp[:, :3, [1, 2]]])
  229. assert_equal(a[:, :3, [1, 2]], a[s_[:, :3, [1, 2]]])
  230. class TestIx_(object):
  231. def test_regression_1(self):
  232. # Test empty inputs create outputs of indexing type, gh-5804
  233. # Test both lists and arrays
  234. for func in (range, np.arange):
  235. a, = np.ix_(func(0))
  236. assert_equal(a.dtype, np.intp)
  237. def test_shape_and_dtype(self):
  238. sizes = (4, 5, 3, 2)
  239. # Test both lists and arrays
  240. for func in (range, np.arange):
  241. arrays = np.ix_(*[func(sz) for sz in sizes])
  242. for k, (a, sz) in enumerate(zip(arrays, sizes)):
  243. assert_equal(a.shape[k], sz)
  244. assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k))
  245. assert_(np.issubdtype(a.dtype, np.integer))
  246. def test_bool(self):
  247. bool_a = [True, False, True, True]
  248. int_a, = np.nonzero(bool_a)
  249. assert_equal(np.ix_(bool_a)[0], int_a)
  250. def test_1d_only(self):
  251. idx2d = [[1, 2, 3], [4, 5, 6]]
  252. assert_raises(ValueError, np.ix_, idx2d)
  253. def test_repeated_input(self):
  254. length_of_vector = 5
  255. x = np.arange(length_of_vector)
  256. out = ix_(x, x)
  257. assert_equal(out[0].shape, (length_of_vector, 1))
  258. assert_equal(out[1].shape, (1, length_of_vector))
  259. # check that input shape is not modified
  260. assert_equal(x.shape, (length_of_vector,))
  261. def test_c_():
  262. a = np.c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])]
  263. assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]])
  264. class TestFillDiagonal(object):
  265. def test_basic(self):
  266. a = np.zeros((3, 3), int)
  267. fill_diagonal(a, 5)
  268. assert_array_equal(
  269. a, np.array([[5, 0, 0],
  270. [0, 5, 0],
  271. [0, 0, 5]])
  272. )
  273. def test_tall_matrix(self):
  274. a = np.zeros((10, 3), int)
  275. fill_diagonal(a, 5)
  276. assert_array_equal(
  277. a, np.array([[5, 0, 0],
  278. [0, 5, 0],
  279. [0, 0, 5],
  280. [0, 0, 0],
  281. [0, 0, 0],
  282. [0, 0, 0],
  283. [0, 0, 0],
  284. [0, 0, 0],
  285. [0, 0, 0],
  286. [0, 0, 0]])
  287. )
  288. def test_tall_matrix_wrap(self):
  289. a = np.zeros((10, 3), int)
  290. fill_diagonal(a, 5, True)
  291. assert_array_equal(
  292. a, np.array([[5, 0, 0],
  293. [0, 5, 0],
  294. [0, 0, 5],
  295. [0, 0, 0],
  296. [5, 0, 0],
  297. [0, 5, 0],
  298. [0, 0, 5],
  299. [0, 0, 0],
  300. [5, 0, 0],
  301. [0, 5, 0]])
  302. )
  303. def test_wide_matrix(self):
  304. a = np.zeros((3, 10), int)
  305. fill_diagonal(a, 5)
  306. assert_array_equal(
  307. a, np.array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  308. [0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
  309. [0, 0, 5, 0, 0, 0, 0, 0, 0, 0]])
  310. )
  311. def test_operate_4d_array(self):
  312. a = np.zeros((3, 3, 3, 3), int)
  313. fill_diagonal(a, 4)
  314. i = np.array([0, 1, 2])
  315. assert_equal(np.where(a != 0), (i, i, i, i))
  316. def test_low_dim_handling(self):
  317. # raise error with low dimensionality
  318. a = np.zeros(3, int)
  319. with assert_raises_regex(ValueError, "at least 2-d"):
  320. fill_diagonal(a, 5)
  321. def test_hetero_shape_handling(self):
  322. # raise error with high dimensionality and
  323. # shape mismatch
  324. a = np.zeros((3,3,7,3), int)
  325. with assert_raises_regex(ValueError, "equal length"):
  326. fill_diagonal(a, 2)
  327. def test_diag_indices():
  328. di = diag_indices(4)
  329. a = np.array([[1, 2, 3, 4],
  330. [5, 6, 7, 8],
  331. [9, 10, 11, 12],
  332. [13, 14, 15, 16]])
  333. a[di] = 100
  334. assert_array_equal(
  335. a, np.array([[100, 2, 3, 4],
  336. [5, 100, 7, 8],
  337. [9, 10, 100, 12],
  338. [13, 14, 15, 100]])
  339. )
  340. # Now, we create indices to manipulate a 3-d array:
  341. d3 = diag_indices(2, 3)
  342. # And use it to set the diagonal of a zeros array to 1:
  343. a = np.zeros((2, 2, 2), int)
  344. a[d3] = 1
  345. assert_array_equal(
  346. a, np.array([[[1, 0],
  347. [0, 0]],
  348. [[0, 0],
  349. [0, 1]]])
  350. )
  351. class TestDiagIndicesFrom(object):
  352. def test_diag_indices_from(self):
  353. x = np.random.random((4, 4))
  354. r, c = diag_indices_from(x)
  355. assert_array_equal(r, np.arange(4))
  356. assert_array_equal(c, np.arange(4))
  357. def test_error_small_input(self):
  358. x = np.ones(7)
  359. with assert_raises_regex(ValueError, "at least 2-d"):
  360. diag_indices_from(x)
  361. def test_error_shape_mismatch(self):
  362. x = np.zeros((3, 3, 2, 3), int)
  363. with assert_raises_regex(ValueError, "equal length"):
  364. diag_indices_from(x)
  365. def test_ndindex():
  366. x = list(ndindex(1, 2, 3))
  367. expected = [ix for ix, e in ndenumerate(np.zeros((1, 2, 3)))]
  368. assert_array_equal(x, expected)
  369. x = list(ndindex((1, 2, 3)))
  370. assert_array_equal(x, expected)
  371. # Test use of scalars and tuples
  372. x = list(ndindex((3,)))
  373. assert_array_equal(x, list(ndindex(3)))
  374. # Make sure size argument is optional
  375. x = list(ndindex())
  376. assert_equal(x, [()])
  377. x = list(ndindex(()))
  378. assert_equal(x, [()])
  379. # Make sure 0-sized ndindex works correctly
  380. x = list(ndindex(*[0]))
  381. assert_equal(x, [])