test_shape_base.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708
  1. from __future__ import division, absolute_import, print_function
  2. import numpy as np
  3. import warnings
  4. import functools
  5. import sys
  6. import pytest
  7. from numpy.lib.shape_base import (
  8. apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
  9. vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis,
  10. put_along_axis
  11. )
  12. from numpy.testing import (
  13. assert_, assert_equal, assert_array_equal, assert_raises, assert_warns
  14. )
  15. IS_64BIT = sys.maxsize > 2**32
  16. def _add_keepdims(func):
  17. """ hack in keepdims behavior into a function taking an axis """
  18. @functools.wraps(func)
  19. def wrapped(a, axis, **kwargs):
  20. res = func(a, axis=axis, **kwargs)
  21. if axis is None:
  22. axis = 0 # res is now a scalar, so we can insert this anywhere
  23. return np.expand_dims(res, axis=axis)
  24. return wrapped
  25. class TestTakeAlongAxis(object):
  26. def test_argequivalent(self):
  27. """ Test it translates from arg<func> to <func> """
  28. from numpy.random import rand
  29. a = rand(3, 4, 5)
  30. funcs = [
  31. (np.sort, np.argsort, dict()),
  32. (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),
  33. (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),
  34. (np.partition, np.argpartition, dict(kth=2)),
  35. ]
  36. for func, argfunc, kwargs in funcs:
  37. for axis in list(range(a.ndim)) + [None]:
  38. a_func = func(a, axis=axis, **kwargs)
  39. ai_func = argfunc(a, axis=axis, **kwargs)
  40. assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))
  41. def test_invalid(self):
  42. """ Test it errors when indices has too few dimensions """
  43. a = np.ones((10, 10))
  44. ai = np.ones((10, 2), dtype=np.intp)
  45. # sanity check
  46. take_along_axis(a, ai, axis=1)
  47. # not enough indices
  48. assert_raises(ValueError, take_along_axis, a, np.array(1), axis=1)
  49. # bool arrays not allowed
  50. assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1)
  51. # float arrays not allowed
  52. assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1)
  53. # invalid axis
  54. assert_raises(np.AxisError, take_along_axis, a, ai, axis=10)
  55. def test_empty(self):
  56. """ Test everything is ok with empty results, even with inserted dims """
  57. a = np.ones((3, 4, 5))
  58. ai = np.ones((3, 0, 5), dtype=np.intp)
  59. actual = take_along_axis(a, ai, axis=1)
  60. assert_equal(actual.shape, ai.shape)
  61. def test_broadcast(self):
  62. """ Test that non-indexing dimensions are broadcast in both directions """
  63. a = np.ones((3, 4, 1))
  64. ai = np.ones((1, 2, 5), dtype=np.intp)
  65. actual = take_along_axis(a, ai, axis=1)
  66. assert_equal(actual.shape, (3, 2, 5))
  67. class TestPutAlongAxis(object):
  68. def test_replace_max(self):
  69. a_base = np.array([[10, 30, 20], [60, 40, 50]])
  70. for axis in list(range(a_base.ndim)) + [None]:
  71. # we mutate this in the loop
  72. a = a_base.copy()
  73. # replace the max with a small value
  74. i_max = _add_keepdims(np.argmax)(a, axis=axis)
  75. put_along_axis(a, i_max, -99, axis=axis)
  76. # find the new minimum, which should max
  77. i_min = _add_keepdims(np.argmin)(a, axis=axis)
  78. assert_equal(i_min, i_max)
  79. def test_broadcast(self):
  80. """ Test that non-indexing dimensions are broadcast in both directions """
  81. a = np.ones((3, 4, 1))
  82. ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
  83. put_along_axis(a, ai, 20, axis=1)
  84. assert_equal(take_along_axis(a, ai, axis=1), 20)
  85. class TestApplyAlongAxis(object):
  86. def test_simple(self):
  87. a = np.ones((20, 10), 'd')
  88. assert_array_equal(
  89. apply_along_axis(len, 0, a), len(a)*np.ones(a.shape[1]))
  90. def test_simple101(self):
  91. a = np.ones((10, 101), 'd')
  92. assert_array_equal(
  93. apply_along_axis(len, 0, a), len(a)*np.ones(a.shape[1]))
  94. def test_3d(self):
  95. a = np.arange(27).reshape((3, 3, 3))
  96. assert_array_equal(apply_along_axis(np.sum, 0, a),
  97. [[27, 30, 33], [36, 39, 42], [45, 48, 51]])
  98. def test_preserve_subclass(self):
  99. def double(row):
  100. return row * 2
  101. class MyNDArray(np.ndarray):
  102. pass
  103. m = np.array([[0, 1], [2, 3]]).view(MyNDArray)
  104. expected = np.array([[0, 2], [4, 6]]).view(MyNDArray)
  105. result = apply_along_axis(double, 0, m)
  106. assert_(isinstance(result, MyNDArray))
  107. assert_array_equal(result, expected)
  108. result = apply_along_axis(double, 1, m)
  109. assert_(isinstance(result, MyNDArray))
  110. assert_array_equal(result, expected)
  111. def test_subclass(self):
  112. class MinimalSubclass(np.ndarray):
  113. data = 1
  114. def minimal_function(array):
  115. return array.data
  116. a = np.zeros((6, 3)).view(MinimalSubclass)
  117. assert_array_equal(
  118. apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
  119. )
  120. def test_scalar_array(self, cls=np.ndarray):
  121. a = np.ones((6, 3)).view(cls)
  122. res = apply_along_axis(np.sum, 0, a)
  123. assert_(isinstance(res, cls))
  124. assert_array_equal(res, np.array([6, 6, 6]).view(cls))
  125. def test_0d_array(self, cls=np.ndarray):
  126. def sum_to_0d(x):
  127. """ Sum x, returning a 0d array of the same class """
  128. assert_equal(x.ndim, 1)
  129. return np.squeeze(np.sum(x, keepdims=True))
  130. a = np.ones((6, 3)).view(cls)
  131. res = apply_along_axis(sum_to_0d, 0, a)
  132. assert_(isinstance(res, cls))
  133. assert_array_equal(res, np.array([6, 6, 6]).view(cls))
  134. res = apply_along_axis(sum_to_0d, 1, a)
  135. assert_(isinstance(res, cls))
  136. assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls))
  137. def test_axis_insertion(self, cls=np.ndarray):
  138. def f1to2(x):
  139. """produces an asymmetric non-square matrix from x"""
  140. assert_equal(x.ndim, 1)
  141. return (x[::-1] * x[1:,None]).view(cls)
  142. a2d = np.arange(6*3).reshape((6, 3))
  143. # 2d insertion along first axis
  144. actual = apply_along_axis(f1to2, 0, a2d)
  145. expected = np.stack([
  146. f1to2(a2d[:,i]) for i in range(a2d.shape[1])
  147. ], axis=-1).view(cls)
  148. assert_equal(type(actual), type(expected))
  149. assert_equal(actual, expected)
  150. # 2d insertion along last axis
  151. actual = apply_along_axis(f1to2, 1, a2d)
  152. expected = np.stack([
  153. f1to2(a2d[i,:]) for i in range(a2d.shape[0])
  154. ], axis=0).view(cls)
  155. assert_equal(type(actual), type(expected))
  156. assert_equal(actual, expected)
  157. # 3d insertion along middle axis
  158. a3d = np.arange(6*5*3).reshape((6, 5, 3))
  159. actual = apply_along_axis(f1to2, 1, a3d)
  160. expected = np.stack([
  161. np.stack([
  162. f1to2(a3d[i,:,j]) for i in range(a3d.shape[0])
  163. ], axis=0)
  164. for j in range(a3d.shape[2])
  165. ], axis=-1).view(cls)
  166. assert_equal(type(actual), type(expected))
  167. assert_equal(actual, expected)
  168. def test_subclass_preservation(self):
  169. class MinimalSubclass(np.ndarray):
  170. pass
  171. self.test_scalar_array(MinimalSubclass)
  172. self.test_0d_array(MinimalSubclass)
  173. self.test_axis_insertion(MinimalSubclass)
  174. def test_axis_insertion_ma(self):
  175. def f1to2(x):
  176. """produces an asymmetric non-square matrix from x"""
  177. assert_equal(x.ndim, 1)
  178. res = x[::-1] * x[1:,None]
  179. return np.ma.masked_where(res%5==0, res)
  180. a = np.arange(6*3).reshape((6, 3))
  181. res = apply_along_axis(f1to2, 0, a)
  182. assert_(isinstance(res, np.ma.masked_array))
  183. assert_equal(res.ndim, 3)
  184. assert_array_equal(res[:,:,0].mask, f1to2(a[:,0]).mask)
  185. assert_array_equal(res[:,:,1].mask, f1to2(a[:,1]).mask)
  186. assert_array_equal(res[:,:,2].mask, f1to2(a[:,2]).mask)
  187. def test_tuple_func1d(self):
  188. def sample_1d(x):
  189. return x[1], x[0]
  190. res = np.apply_along_axis(sample_1d, 1, np.array([[1, 2], [3, 4]]))
  191. assert_array_equal(res, np.array([[2, 1], [4, 3]]))
  192. def test_empty(self):
  193. # can't apply_along_axis when there's no chance to call the function
  194. def never_call(x):
  195. assert_(False) # should never be reached
  196. a = np.empty((0, 0))
  197. assert_raises(ValueError, np.apply_along_axis, never_call, 0, a)
  198. assert_raises(ValueError, np.apply_along_axis, never_call, 1, a)
  199. # but it's sometimes ok with some non-zero dimensions
  200. def empty_to_1(x):
  201. assert_(len(x) == 0)
  202. return 1
  203. a = np.empty((10, 0))
  204. actual = np.apply_along_axis(empty_to_1, 1, a)
  205. assert_equal(actual, np.ones(10))
  206. assert_raises(ValueError, np.apply_along_axis, empty_to_1, 0, a)
  207. def test_with_iterable_object(self):
  208. # from issue 5248
  209. d = np.array([
  210. [{1, 11}, {2, 22}, {3, 33}],
  211. [{4, 44}, {5, 55}, {6, 66}]
  212. ])
  213. actual = np.apply_along_axis(lambda a: set.union(*a), 0, d)
  214. expected = np.array([{1, 11, 4, 44}, {2, 22, 5, 55}, {3, 33, 6, 66}])
  215. assert_equal(actual, expected)
  216. # issue 8642 - assert_equal doesn't detect this!
  217. for i in np.ndindex(actual.shape):
  218. assert_equal(type(actual[i]), type(expected[i]))
  219. class TestApplyOverAxes(object):
  220. def test_simple(self):
  221. a = np.arange(24).reshape(2, 3, 4)
  222. aoa_a = apply_over_axes(np.sum, a, [0, 2])
  223. assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
  224. class TestExpandDims(object):
  225. def test_functionality(self):
  226. s = (2, 3, 4, 5)
  227. a = np.empty(s)
  228. for axis in range(-5, 4):
  229. b = expand_dims(a, axis)
  230. assert_(b.shape[axis] == 1)
  231. assert_(np.squeeze(b).shape == s)
  232. def test_deprecations(self):
  233. # 2017-05-17, 1.13.0
  234. s = (2, 3, 4, 5)
  235. a = np.empty(s)
  236. with warnings.catch_warnings():
  237. warnings.simplefilter("always")
  238. assert_warns(DeprecationWarning, expand_dims, a, -6)
  239. assert_warns(DeprecationWarning, expand_dims, a, 5)
  240. def test_subclasses(self):
  241. a = np.arange(10).reshape((2, 5))
  242. a = np.ma.array(a, mask=a%3 == 0)
  243. expanded = np.expand_dims(a, axis=1)
  244. assert_(isinstance(expanded, np.ma.MaskedArray))
  245. assert_equal(expanded.shape, (2, 1, 5))
  246. assert_equal(expanded.mask.shape, (2, 1, 5))
  247. class TestArraySplit(object):
  248. def test_integer_0_split(self):
  249. a = np.arange(10)
  250. assert_raises(ValueError, array_split, a, 0)
  251. def test_integer_split(self):
  252. a = np.arange(10)
  253. res = array_split(a, 1)
  254. desired = [np.arange(10)]
  255. compare_results(res, desired)
  256. res = array_split(a, 2)
  257. desired = [np.arange(5), np.arange(5, 10)]
  258. compare_results(res, desired)
  259. res = array_split(a, 3)
  260. desired = [np.arange(4), np.arange(4, 7), np.arange(7, 10)]
  261. compare_results(res, desired)
  262. res = array_split(a, 4)
  263. desired = [np.arange(3), np.arange(3, 6), np.arange(6, 8),
  264. np.arange(8, 10)]
  265. compare_results(res, desired)
  266. res = array_split(a, 5)
  267. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  268. np.arange(6, 8), np.arange(8, 10)]
  269. compare_results(res, desired)
  270. res = array_split(a, 6)
  271. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  272. np.arange(6, 8), np.arange(8, 9), np.arange(9, 10)]
  273. compare_results(res, desired)
  274. res = array_split(a, 7)
  275. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
  276. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  277. np.arange(9, 10)]
  278. compare_results(res, desired)
  279. res = array_split(a, 8)
  280. desired = [np.arange(2), np.arange(2, 4), np.arange(4, 5),
  281. np.arange(5, 6), np.arange(6, 7), np.arange(7, 8),
  282. np.arange(8, 9), np.arange(9, 10)]
  283. compare_results(res, desired)
  284. res = array_split(a, 9)
  285. desired = [np.arange(2), np.arange(2, 3), np.arange(3, 4),
  286. np.arange(4, 5), np.arange(5, 6), np.arange(6, 7),
  287. np.arange(7, 8), np.arange(8, 9), np.arange(9, 10)]
  288. compare_results(res, desired)
  289. res = array_split(a, 10)
  290. desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
  291. np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
  292. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  293. np.arange(9, 10)]
  294. compare_results(res, desired)
  295. res = array_split(a, 11)
  296. desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
  297. np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
  298. np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
  299. np.arange(9, 10), np.array([])]
  300. compare_results(res, desired)
  301. def test_integer_split_2D_rows(self):
  302. a = np.array([np.arange(10), np.arange(10)])
  303. res = array_split(a, 3, axis=0)
  304. tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
  305. np.zeros((0, 10))]
  306. compare_results(res, tgt)
  307. assert_(a.dtype.type is res[-1].dtype.type)
  308. # Same thing for manual splits:
  309. res = array_split(a, [0, 1, 2], axis=0)
  310. tgt = [np.zeros((0, 10)), np.array([np.arange(10)]),
  311. np.array([np.arange(10)])]
  312. compare_results(res, tgt)
  313. assert_(a.dtype.type is res[-1].dtype.type)
  314. def test_integer_split_2D_cols(self):
  315. a = np.array([np.arange(10), np.arange(10)])
  316. res = array_split(a, 3, axis=-1)
  317. desired = [np.array([np.arange(4), np.arange(4)]),
  318. np.array([np.arange(4, 7), np.arange(4, 7)]),
  319. np.array([np.arange(7, 10), np.arange(7, 10)])]
  320. compare_results(res, desired)
  321. def test_integer_split_2D_default(self):
  322. """ This will fail if we change default axis
  323. """
  324. a = np.array([np.arange(10), np.arange(10)])
  325. res = array_split(a, 3)
  326. tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
  327. np.zeros((0, 10))]
  328. compare_results(res, tgt)
  329. assert_(a.dtype.type is res[-1].dtype.type)
  330. # perhaps should check higher dimensions
  331. @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
  332. def test_integer_split_2D_rows_greater_max_int32(self):
  333. a = np.broadcast_to([0], (1 << 32, 2))
  334. res = array_split(a, 4)
  335. chunk = np.broadcast_to([0], (1 << 30, 2))
  336. tgt = [chunk] * 4
  337. for i in range(len(tgt)):
  338. assert_equal(res[i].shape, tgt[i].shape)
  339. def test_index_split_simple(self):
  340. a = np.arange(10)
  341. indices = [1, 5, 7]
  342. res = array_split(a, indices, axis=-1)
  343. desired = [np.arange(0, 1), np.arange(1, 5), np.arange(5, 7),
  344. np.arange(7, 10)]
  345. compare_results(res, desired)
  346. def test_index_split_low_bound(self):
  347. a = np.arange(10)
  348. indices = [0, 5, 7]
  349. res = array_split(a, indices, axis=-1)
  350. desired = [np.array([]), np.arange(0, 5), np.arange(5, 7),
  351. np.arange(7, 10)]
  352. compare_results(res, desired)
  353. def test_index_split_high_bound(self):
  354. a = np.arange(10)
  355. indices = [0, 5, 7, 10, 12]
  356. res = array_split(a, indices, axis=-1)
  357. desired = [np.array([]), np.arange(0, 5), np.arange(5, 7),
  358. np.arange(7, 10), np.array([]), np.array([])]
  359. compare_results(res, desired)
  360. class TestSplit(object):
  361. # The split function is essentially the same as array_split,
  362. # except that it test if splitting will result in an
  363. # equal split. Only test for this case.
  364. def test_equal_split(self):
  365. a = np.arange(10)
  366. res = split(a, 2)
  367. desired = [np.arange(5), np.arange(5, 10)]
  368. compare_results(res, desired)
  369. def test_unequal_split(self):
  370. a = np.arange(10)
  371. assert_raises(ValueError, split, a, 3)
  372. class TestColumnStack(object):
  373. def test_non_iterable(self):
  374. assert_raises(TypeError, column_stack, 1)
  375. def test_1D_arrays(self):
  376. # example from docstring
  377. a = np.array((1, 2, 3))
  378. b = np.array((2, 3, 4))
  379. expected = np.array([[1, 2],
  380. [2, 3],
  381. [3, 4]])
  382. actual = np.column_stack((a, b))
  383. assert_equal(actual, expected)
  384. def test_2D_arrays(self):
  385. # same as hstack 2D docstring example
  386. a = np.array([[1], [2], [3]])
  387. b = np.array([[2], [3], [4]])
  388. expected = np.array([[1, 2],
  389. [2, 3],
  390. [3, 4]])
  391. actual = np.column_stack((a, b))
  392. assert_equal(actual, expected)
  393. def test_generator(self):
  394. with assert_warns(FutureWarning):
  395. column_stack((np.arange(3) for _ in range(2)))
  396. class TestDstack(object):
  397. def test_non_iterable(self):
  398. assert_raises(TypeError, dstack, 1)
  399. def test_0D_array(self):
  400. a = np.array(1)
  401. b = np.array(2)
  402. res = dstack([a, b])
  403. desired = np.array([[[1, 2]]])
  404. assert_array_equal(res, desired)
  405. def test_1D_array(self):
  406. a = np.array([1])
  407. b = np.array([2])
  408. res = dstack([a, b])
  409. desired = np.array([[[1, 2]]])
  410. assert_array_equal(res, desired)
  411. def test_2D_array(self):
  412. a = np.array([[1], [2]])
  413. b = np.array([[1], [2]])
  414. res = dstack([a, b])
  415. desired = np.array([[[1, 1]], [[2, 2, ]]])
  416. assert_array_equal(res, desired)
  417. def test_2D_array2(self):
  418. a = np.array([1, 2])
  419. b = np.array([1, 2])
  420. res = dstack([a, b])
  421. desired = np.array([[[1, 1], [2, 2]]])
  422. assert_array_equal(res, desired)
  423. def test_generator(self):
  424. with assert_warns(FutureWarning):
  425. dstack((np.arange(3) for _ in range(2)))
# array_split has more comprehensive tests of splitting;
# only simple tests are done on hsplit, vsplit, and dsplit.
  428. class TestHsplit(object):
  429. """Only testing for integer splits.
  430. """
  431. def test_non_iterable(self):
  432. assert_raises(ValueError, hsplit, 1, 1)
  433. def test_0D_array(self):
  434. a = np.array(1)
  435. try:
  436. hsplit(a, 2)
  437. assert_(0)
  438. except ValueError:
  439. pass
  440. def test_1D_array(self):
  441. a = np.array([1, 2, 3, 4])
  442. res = hsplit(a, 2)
  443. desired = [np.array([1, 2]), np.array([3, 4])]
  444. compare_results(res, desired)
  445. def test_2D_array(self):
  446. a = np.array([[1, 2, 3, 4],
  447. [1, 2, 3, 4]])
  448. res = hsplit(a, 2)
  449. desired = [np.array([[1, 2], [1, 2]]), np.array([[3, 4], [3, 4]])]
  450. compare_results(res, desired)
  451. class TestVsplit(object):
  452. """Only testing for integer splits.
  453. """
  454. def test_non_iterable(self):
  455. assert_raises(ValueError, vsplit, 1, 1)
  456. def test_0D_array(self):
  457. a = np.array(1)
  458. assert_raises(ValueError, vsplit, a, 2)
  459. def test_1D_array(self):
  460. a = np.array([1, 2, 3, 4])
  461. try:
  462. vsplit(a, 2)
  463. assert_(0)
  464. except ValueError:
  465. pass
  466. def test_2D_array(self):
  467. a = np.array([[1, 2, 3, 4],
  468. [1, 2, 3, 4]])
  469. res = vsplit(a, 2)
  470. desired = [np.array([[1, 2, 3, 4]]), np.array([[1, 2, 3, 4]])]
  471. compare_results(res, desired)
  472. class TestDsplit(object):
  473. # Only testing for integer splits.
  474. def test_non_iterable(self):
  475. assert_raises(ValueError, dsplit, 1, 1)
  476. def test_0D_array(self):
  477. a = np.array(1)
  478. assert_raises(ValueError, dsplit, a, 2)
  479. def test_1D_array(self):
  480. a = np.array([1, 2, 3, 4])
  481. assert_raises(ValueError, dsplit, a, 2)
  482. def test_2D_array(self):
  483. a = np.array([[1, 2, 3, 4],
  484. [1, 2, 3, 4]])
  485. try:
  486. dsplit(a, 2)
  487. assert_(0)
  488. except ValueError:
  489. pass
  490. def test_3D_array(self):
  491. a = np.array([[[1, 2, 3, 4],
  492. [1, 2, 3, 4]],
  493. [[1, 2, 3, 4],
  494. [1, 2, 3, 4]]])
  495. res = dsplit(a, 2)
  496. desired = [np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
  497. np.array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]])]
  498. compare_results(res, desired)
  499. class TestSqueeze(object):
  500. def test_basic(self):
  501. from numpy.random import rand
  502. a = rand(20, 10, 10, 1, 1)
  503. b = rand(20, 1, 10, 1, 20)
  504. c = rand(1, 1, 20, 10)
  505. assert_array_equal(np.squeeze(a), np.reshape(a, (20, 10, 10)))
  506. assert_array_equal(np.squeeze(b), np.reshape(b, (20, 10, 20)))
  507. assert_array_equal(np.squeeze(c), np.reshape(c, (20, 10)))
  508. # Squeezing to 0-dim should still give an ndarray
  509. a = [[[1.5]]]
  510. res = np.squeeze(a)
  511. assert_equal(res, 1.5)
  512. assert_equal(res.ndim, 0)
  513. assert_equal(type(res), np.ndarray)
  514. class TestKron(object):
  515. def test_return_type(self):
  516. class myarray(np.ndarray):
  517. __array_priority__ = 0.0
  518. a = np.ones([2, 2])
  519. ma = myarray(a.shape, a.dtype, a.data)
  520. assert_equal(type(kron(a, a)), np.ndarray)
  521. assert_equal(type(kron(ma, ma)), myarray)
  522. assert_equal(type(kron(a, ma)), np.ndarray)
  523. assert_equal(type(kron(ma, a)), myarray)
  524. class TestTile(object):
  525. def test_basic(self):
  526. a = np.array([0, 1, 2])
  527. b = [[1, 2], [3, 4]]
  528. assert_equal(tile(a, 2), [0, 1, 2, 0, 1, 2])
  529. assert_equal(tile(a, (2, 2)), [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]])
  530. assert_equal(tile(a, (1, 2)), [[0, 1, 2, 0, 1, 2]])
  531. assert_equal(tile(b, 2), [[1, 2, 1, 2], [3, 4, 3, 4]])
  532. assert_equal(tile(b, (2, 1)), [[1, 2], [3, 4], [1, 2], [3, 4]])
  533. assert_equal(tile(b, (2, 2)), [[1, 2, 1, 2], [3, 4, 3, 4],
  534. [1, 2, 1, 2], [3, 4, 3, 4]])
  535. def test_tile_one_repetition_on_array_gh4679(self):
  536. a = np.arange(5)
  537. b = tile(a, 1)
  538. b += 2
  539. assert_equal(a, np.arange(5))
  540. def test_empty(self):
  541. a = np.array([[[]]])
  542. b = np.array([[], []])
  543. c = tile(b, 2).shape
  544. d = tile(a, (3, 2, 5)).shape
  545. assert_equal(c, (2, 0))
  546. assert_equal(d, (3, 2, 0))
  547. def test_kroncompare(self):
  548. from numpy.random import randint
  549. reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
  550. shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
  551. for s in shape:
  552. b = randint(0, 10, size=s)
  553. for r in reps:
  554. a = np.ones(r, b.dtype)
  555. large = tile(b, r)
  556. klarge = kron(a, b)
  557. assert_equal(large, klarge)
  558. class TestMayShareMemory(object):
  559. def test_basic(self):
  560. d = np.ones((50, 60))
  561. d2 = np.ones((30, 60, 6))
  562. assert_(np.may_share_memory(d, d))
  563. assert_(np.may_share_memory(d, d[::-1]))
  564. assert_(np.may_share_memory(d, d[::2]))
  565. assert_(np.may_share_memory(d, d[1:, ::-1]))
  566. assert_(not np.may_share_memory(d[::-1], d2))
  567. assert_(not np.may_share_memory(d[::2], d2))
  568. assert_(not np.may_share_memory(d[1:, ::-1], d2))
  569. assert_(np.may_share_memory(d2[1:, ::-1], d2))
  570. # Utility
  571. def compare_results(res, desired):
  572. for i in range(len(desired)):
  573. assert_array_equal(res[i], desired[i])