index_tricks.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965
  1. from __future__ import division, absolute_import, print_function
  2. import functools
  3. import sys
  4. import math
  5. import numpy.core.numeric as _nx
  6. from numpy.core.numeric import (
  7. asarray, ScalarType, array, alltrue, cumprod, arange, ndim
  8. )
  9. from numpy.core.numerictypes import find_common_type, issubdtype
  10. import numpy.matrixlib as matrixlib
  11. from .function_base import diff
  12. from numpy.core.multiarray import ravel_multi_index, unravel_index
  13. from numpy.core.overrides import set_module
  14. from numpy.core import overrides, linspace
  15. from numpy.lib.stride_tricks import as_strided
  16. array_function_dispatch = functools.partial(
  17. overrides.array_function_dispatch, module='numpy')
  18. __all__ = [
  19. 'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
  20. 's_', 'index_exp', 'ix_', 'ndenumerate', 'ndindex', 'fill_diagonal',
  21. 'diag_indices', 'diag_indices_from'
  22. ]
  23. def _ix__dispatcher(*args):
  24. return args
  25. @array_function_dispatch(_ix__dispatcher)
  26. def ix_(*args):
  27. """
  28. Construct an open mesh from multiple sequences.
  29. This function takes N 1-D sequences and returns N outputs with N
  30. dimensions each, such that the shape is 1 in all but one dimension
  31. and the dimension with the non-unit shape value cycles through all
  32. N dimensions.
  33. Using `ix_` one can quickly construct index arrays that will index
  34. the cross product. ``a[np.ix_([1,3],[2,5])]`` returns the array
  35. ``[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]``.
  36. Parameters
  37. ----------
  38. args : 1-D sequences
  39. Each sequence should be of integer or boolean type.
  40. Boolean sequences will be interpreted as boolean masks for the
  41. corresponding dimension (equivalent to passing in
  42. ``np.nonzero(boolean_sequence)``).
  43. Returns
  44. -------
  45. out : tuple of ndarrays
  46. N arrays with N dimensions each, with N the number of input
  47. sequences. Together these arrays form an open mesh.
  48. See Also
  49. --------
  50. ogrid, mgrid, meshgrid
  51. Examples
  52. --------
  53. >>> a = np.arange(10).reshape(2, 5)
  54. >>> a
  55. array([[0, 1, 2, 3, 4],
  56. [5, 6, 7, 8, 9]])
  57. >>> ixgrid = np.ix_([0, 1], [2, 4])
  58. >>> ixgrid
  59. (array([[0],
  60. [1]]), array([[2, 4]]))
  61. >>> ixgrid[0].shape, ixgrid[1].shape
  62. ((2, 1), (1, 2))
  63. >>> a[ixgrid]
  64. array([[2, 4],
  65. [7, 9]])
  66. >>> ixgrid = np.ix_([True, True], [2, 4])
  67. >>> a[ixgrid]
  68. array([[2, 4],
  69. [7, 9]])
  70. >>> ixgrid = np.ix_([True, True], [False, False, True, False, True])
  71. >>> a[ixgrid]
  72. array([[2, 4],
  73. [7, 9]])
  74. """
  75. out = []
  76. nd = len(args)
  77. for k, new in enumerate(args):
  78. new = asarray(new)
  79. if new.ndim != 1:
  80. raise ValueError("Cross index must be 1 dimensional")
  81. if new.size == 0:
  82. # Explicitly type empty arrays to avoid float default
  83. new = new.astype(_nx.intp)
  84. if issubdtype(new.dtype, _nx.bool_):
  85. new, = new.nonzero()
  86. new = new.reshape((1,)*k + (new.size,) + (1,)*(nd-k-1))
  87. out.append(new)
  88. return tuple(out)
  89. class nd_grid(object):
  90. """
  91. Construct a multi-dimensional "meshgrid".
  92. ``grid = nd_grid()`` creates an instance which will return a mesh-grid
  93. when indexed. The dimension and number of the output arrays are equal
  94. to the number of indexing dimensions. If the step length is not a
  95. complex number, then the stop is not inclusive.
  96. However, if the step length is a **complex number** (e.g. 5j), then the
  97. integer part of its magnitude is interpreted as specifying the
  98. number of points to create between the start and stop values, where
  99. the stop value **is inclusive**.
  100. If instantiated with an argument of ``sparse=True``, the mesh-grid is
  101. open (or not fleshed out) so that only one-dimension of each returned
  102. argument is greater than 1.
  103. Parameters
  104. ----------
  105. sparse : bool, optional
  106. Whether the grid is sparse or not. Default is False.
  107. Notes
  108. -----
  109. Two instances of `nd_grid` are made available in the NumPy namespace,
  110. `mgrid` and `ogrid`, approximately defined as::
  111. mgrid = nd_grid(sparse=False)
  112. ogrid = nd_grid(sparse=True)
  113. Users should use these pre-defined instances instead of using `nd_grid`
  114. directly.
  115. """
  116. def __init__(self, sparse=False):
  117. self.sparse = sparse
  118. def __getitem__(self, key):
  119. try:
  120. size = []
  121. typ = int
  122. for k in range(len(key)):
  123. step = key[k].step
  124. start = key[k].start
  125. if start is None:
  126. start = 0
  127. if step is None:
  128. step = 1
  129. if isinstance(step, complex):
  130. size.append(int(abs(step)))
  131. typ = float
  132. else:
  133. size.append(
  134. int(math.ceil((key[k].stop - start)/(step*1.0))))
  135. if (isinstance(step, float) or
  136. isinstance(start, float) or
  137. isinstance(key[k].stop, float)):
  138. typ = float
  139. if self.sparse:
  140. nn = [_nx.arange(_x, dtype=_t)
  141. for _x, _t in zip(size, (typ,)*len(size))]
  142. else:
  143. nn = _nx.indices(size, typ)
  144. for k in range(len(size)):
  145. step = key[k].step
  146. start = key[k].start
  147. if start is None:
  148. start = 0
  149. if step is None:
  150. step = 1
  151. if isinstance(step, complex):
  152. step = int(abs(step))
  153. if step != 1:
  154. step = (key[k].stop - start)/float(step-1)
  155. nn[k] = (nn[k]*step+start)
  156. if self.sparse:
  157. slobj = [_nx.newaxis]*len(size)
  158. for k in range(len(size)):
  159. slobj[k] = slice(None, None)
  160. nn[k] = nn[k][tuple(slobj)]
  161. slobj[k] = _nx.newaxis
  162. return nn
  163. except (IndexError, TypeError):
  164. step = key.step
  165. stop = key.stop
  166. start = key.start
  167. if start is None:
  168. start = 0
  169. if isinstance(step, complex):
  170. step = abs(step)
  171. length = int(step)
  172. if step != 1:
  173. step = (key.stop-start)/float(step-1)
  174. stop = key.stop + step
  175. return _nx.arange(0, length, 1, float)*step + start
  176. else:
  177. return _nx.arange(start, stop, step)
  178. class MGridClass(nd_grid):
  179. """
  180. `nd_grid` instance which returns a dense multi-dimensional "meshgrid".
  181. An instance of `numpy.lib.index_tricks.nd_grid` which returns an dense
  182. (or fleshed out) mesh-grid when indexed, so that each returned argument
  183. has the same shape. The dimensions and number of the output arrays are
  184. equal to the number of indexing dimensions. If the step length is not a
  185. complex number, then the stop is not inclusive.
  186. However, if the step length is a **complex number** (e.g. 5j), then
  187. the integer part of its magnitude is interpreted as specifying the
  188. number of points to create between the start and stop values, where
  189. the stop value **is inclusive**.
  190. Returns
  191. ----------
  192. mesh-grid `ndarrays` all of the same dimensions
  193. See Also
  194. --------
  195. numpy.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
  196. ogrid : like mgrid but returns open (not fleshed out) mesh grids
  197. r_ : array concatenator
  198. Examples
  199. --------
  200. >>> np.mgrid[0:5,0:5]
  201. array([[[0, 0, 0, 0, 0],
  202. [1, 1, 1, 1, 1],
  203. [2, 2, 2, 2, 2],
  204. [3, 3, 3, 3, 3],
  205. [4, 4, 4, 4, 4]],
  206. [[0, 1, 2, 3, 4],
  207. [0, 1, 2, 3, 4],
  208. [0, 1, 2, 3, 4],
  209. [0, 1, 2, 3, 4],
  210. [0, 1, 2, 3, 4]]])
  211. >>> np.mgrid[-1:1:5j]
  212. array([-1. , -0.5, 0. , 0.5, 1. ])
  213. """
  214. def __init__(self):
  215. super(MGridClass, self).__init__(sparse=False)
  216. mgrid = MGridClass()
  217. class OGridClass(nd_grid):
  218. """
  219. `nd_grid` instance which returns an open multi-dimensional "meshgrid".
  220. An instance of `numpy.lib.index_tricks.nd_grid` which returns an open
  221. (i.e. not fleshed out) mesh-grid when indexed, so that only one dimension
  222. of each returned array is greater than 1. The dimension and number of the
  223. output arrays are equal to the number of indexing dimensions. If the step
  224. length is not a complex number, then the stop is not inclusive.
  225. However, if the step length is a **complex number** (e.g. 5j), then
  226. the integer part of its magnitude is interpreted as specifying the
  227. number of points to create between the start and stop values, where
  228. the stop value **is inclusive**.
  229. Returns
  230. ----------
  231. mesh-grid `ndarrays` with only one dimension :math:`\\neq 1`
  232. See Also
  233. --------
  234. np.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
  235. mgrid : like `ogrid` but returns dense (or fleshed out) mesh grids
  236. r_ : array concatenator
  237. Examples
  238. --------
  239. >>> from numpy import ogrid
  240. >>> ogrid[-1:1:5j]
  241. array([-1. , -0.5, 0. , 0.5, 1. ])
  242. >>> ogrid[0:5,0:5]
  243. [array([[0],
  244. [1],
  245. [2],
  246. [3],
  247. [4]]), array([[0, 1, 2, 3, 4]])]
  248. """
  249. def __init__(self):
  250. super(OGridClass, self).__init__(sparse=True)
  251. ogrid = OGridClass()
  252. class AxisConcatenator(object):
  253. """
  254. Translates slice objects to concatenation along an axis.
  255. For detailed documentation on usage, see `r_`.
  256. """
  257. # allow ma.mr_ to override this
  258. concatenate = staticmethod(_nx.concatenate)
  259. makemat = staticmethod(matrixlib.matrix)
  260. def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
  261. self.axis = axis
  262. self.matrix = matrix
  263. self.trans1d = trans1d
  264. self.ndmin = ndmin
  265. def __getitem__(self, key):
  266. # handle matrix builder syntax
  267. if isinstance(key, str):
  268. frame = sys._getframe().f_back
  269. mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
  270. return mymat
  271. if not isinstance(key, tuple):
  272. key = (key,)
  273. # copy attributes, since they can be overridden in the first argument
  274. trans1d = self.trans1d
  275. ndmin = self.ndmin
  276. matrix = self.matrix
  277. axis = self.axis
  278. objs = []
  279. scalars = []
  280. arraytypes = []
  281. scalartypes = []
  282. for k, item in enumerate(key):
  283. scalar = False
  284. if isinstance(item, slice):
  285. step = item.step
  286. start = item.start
  287. stop = item.stop
  288. if start is None:
  289. start = 0
  290. if step is None:
  291. step = 1
  292. if isinstance(step, complex):
  293. size = int(abs(step))
  294. newobj = linspace(start, stop, num=size)
  295. else:
  296. newobj = _nx.arange(start, stop, step)
  297. if ndmin > 1:
  298. newobj = array(newobj, copy=False, ndmin=ndmin)
  299. if trans1d != -1:
  300. newobj = newobj.swapaxes(-1, trans1d)
  301. elif isinstance(item, str):
  302. if k != 0:
  303. raise ValueError("special directives must be the "
  304. "first entry.")
  305. if item in ('r', 'c'):
  306. matrix = True
  307. col = (item == 'c')
  308. continue
  309. if ',' in item:
  310. vec = item.split(',')
  311. try:
  312. axis, ndmin = [int(x) for x in vec[:2]]
  313. if len(vec) == 3:
  314. trans1d = int(vec[2])
  315. continue
  316. except Exception:
  317. raise ValueError("unknown special directive")
  318. try:
  319. axis = int(item)
  320. continue
  321. except (ValueError, TypeError):
  322. raise ValueError("unknown special directive")
  323. elif type(item) in ScalarType:
  324. newobj = array(item, ndmin=ndmin)
  325. scalars.append(len(objs))
  326. scalar = True
  327. scalartypes.append(newobj.dtype)
  328. else:
  329. item_ndim = ndim(item)
  330. newobj = array(item, copy=False, subok=True, ndmin=ndmin)
  331. if trans1d != -1 and item_ndim < ndmin:
  332. k2 = ndmin - item_ndim
  333. k1 = trans1d
  334. if k1 < 0:
  335. k1 += k2 + 1
  336. defaxes = list(range(ndmin))
  337. axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2]
  338. newobj = newobj.transpose(axes)
  339. objs.append(newobj)
  340. if not scalar and isinstance(newobj, _nx.ndarray):
  341. arraytypes.append(newobj.dtype)
  342. # Ensure that scalars won't up-cast unless warranted
  343. final_dtype = find_common_type(arraytypes, scalartypes)
  344. if final_dtype is not None:
  345. for k in scalars:
  346. objs[k] = objs[k].astype(final_dtype)
  347. res = self.concatenate(tuple(objs), axis=axis)
  348. if matrix:
  349. oldndim = res.ndim
  350. res = self.makemat(res)
  351. if oldndim == 1 and col:
  352. res = res.T
  353. return res
  354. def __len__(self):
  355. return 0
  356. # separate classes are used here instead of just making r_ = concatentor(0),
  357. # etc. because otherwise we couldn't get the doc string to come out right
  358. # in help(r_)
  359. class RClass(AxisConcatenator):
  360. """
  361. Translates slice objects to concatenation along the first axis.
  362. This is a simple way to build up arrays quickly. There are two use cases.
  363. 1. If the index expression contains comma separated arrays, then stack
  364. them along their first axis.
  365. 2. If the index expression contains slice notation or scalars then create
  366. a 1-D array with a range indicated by the slice notation.
  367. If slice notation is used, the syntax ``start:stop:step`` is equivalent
  368. to ``np.arange(start, stop, step)`` inside of the brackets. However, if
  369. ``step`` is an imaginary number (i.e. 100j) then its integer portion is
  370. interpreted as a number-of-points desired and the start and stop are
  371. inclusive. In other words ``start:stop:stepj`` is interpreted as
  372. ``np.linspace(start, stop, step, endpoint=1)`` inside of the brackets.
  373. After expansion of slice notation, all comma separated sequences are
  374. concatenated together.
  375. Optional character strings placed as the first element of the index
  376. expression can be used to change the output. The strings 'r' or 'c' result
  377. in matrix output. If the result is 1-D and 'r' is specified a 1 x N (row)
  378. matrix is produced. If the result is 1-D and 'c' is specified, then a N x 1
  379. (column) matrix is produced. If the result is 2-D then both provide the
  380. same matrix result.
  381. A string integer specifies which axis to stack multiple comma separated
  382. arrays along. A string of two comma-separated integers allows indication
  383. of the minimum number of dimensions to force each entry into as the
  384. second integer (the axis to concatenate along is still the first integer).
  385. A string with three comma-separated integers allows specification of the
  386. axis to concatenate along, the minimum number of dimensions to force the
  387. entries to, and which axis should contain the start of the arrays which
  388. are less than the specified number of dimensions. In other words the third
  389. integer allows you to specify where the 1's should be placed in the shape
  390. of the arrays that have their shapes upgraded. By default, they are placed
  391. in the front of the shape tuple. The third argument allows you to specify
  392. where the start of the array should be instead. Thus, a third argument of
  393. '0' would place the 1's at the end of the array shape. Negative integers
  394. specify where in the new shape tuple the last dimension of upgraded arrays
  395. should be placed, so the default is '-1'.
  396. Parameters
  397. ----------
  398. Not a function, so takes no parameters
  399. Returns
  400. -------
  401. A concatenated ndarray or matrix.
  402. See Also
  403. --------
  404. concatenate : Join a sequence of arrays along an existing axis.
  405. c_ : Translates slice objects to concatenation along the second axis.
  406. Examples
  407. --------
  408. >>> np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]
  409. array([1, 2, 3, 0, 0, 4, 5, 6])
  410. >>> np.r_[-1:1:6j, [0]*3, 5, 6]
  411. array([-1. , -0.6, -0.2, 0.2, 0.6, 1. , 0. , 0. , 0. , 5. , 6. ])
  412. String integers specify the axis to concatenate along or the minimum
  413. number of dimensions to force entries into.
  414. >>> a = np.array([[0, 1, 2], [3, 4, 5]])
  415. >>> np.r_['-1', a, a] # concatenate along last axis
  416. array([[0, 1, 2, 0, 1, 2],
  417. [3, 4, 5, 3, 4, 5]])
  418. >>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
  419. array([[1, 2, 3],
  420. [4, 5, 6]])
  421. >>> np.r_['0,2,0', [1,2,3], [4,5,6]]
  422. array([[1],
  423. [2],
  424. [3],
  425. [4],
  426. [5],
  427. [6]])
  428. >>> np.r_['1,2,0', [1,2,3], [4,5,6]]
  429. array([[1, 4],
  430. [2, 5],
  431. [3, 6]])
  432. Using 'r' or 'c' as a first string argument creates a matrix.
  433. >>> np.r_['r',[1,2,3], [4,5,6]]
  434. matrix([[1, 2, 3, 4, 5, 6]])
  435. """
  436. def __init__(self):
  437. AxisConcatenator.__init__(self, 0)
  438. r_ = RClass()
  439. class CClass(AxisConcatenator):
  440. """
  441. Translates slice objects to concatenation along the second axis.
  442. This is short-hand for ``np.r_['-1,2,0', index expression]``, which is
  443. useful because of its common occurrence. In particular, arrays will be
  444. stacked along their last axis after being upgraded to at least 2-D with
  445. 1's post-pended to the shape (column vectors made out of 1-D arrays).
  446. See Also
  447. --------
  448. column_stack : Stack 1-D arrays as columns into a 2-D array.
  449. r_ : For more detailed documentation.
  450. Examples
  451. --------
  452. >>> np.c_[np.array([1,2,3]), np.array([4,5,6])]
  453. array([[1, 4],
  454. [2, 5],
  455. [3, 6]])
  456. >>> np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]
  457. array([[1, 2, 3, 0, 0, 4, 5, 6]])
  458. """
  459. def __init__(self):
  460. AxisConcatenator.__init__(self, -1, ndmin=2, trans1d=0)
  461. c_ = CClass()
  462. @set_module('numpy')
  463. class ndenumerate(object):
  464. """
  465. Multidimensional index iterator.
  466. Return an iterator yielding pairs of array coordinates and values.
  467. Parameters
  468. ----------
  469. arr : ndarray
  470. Input array.
  471. See Also
  472. --------
  473. ndindex, flatiter
  474. Examples
  475. --------
  476. >>> a = np.array([[1, 2], [3, 4]])
  477. >>> for index, x in np.ndenumerate(a):
  478. ... print(index, x)
  479. (0, 0) 1
  480. (0, 1) 2
  481. (1, 0) 3
  482. (1, 1) 4
  483. """
  484. def __init__(self, arr):
  485. self.iter = asarray(arr).flat
  486. def __next__(self):
  487. """
  488. Standard iterator method, returns the index tuple and array value.
  489. Returns
  490. -------
  491. coords : tuple of ints
  492. The indices of the current iteration.
  493. val : scalar
  494. The array element of the current iteration.
  495. """
  496. return self.iter.coords, next(self.iter)
  497. def __iter__(self):
  498. return self
  499. next = __next__
  500. @set_module('numpy')
  501. class ndindex(object):
  502. """
  503. An N-dimensional iterator object to index arrays.
  504. Given the shape of an array, an `ndindex` instance iterates over
  505. the N-dimensional index of the array. At each iteration a tuple
  506. of indices is returned, the last dimension is iterated over first.
  507. Parameters
  508. ----------
  509. `*args` : ints
  510. The size of each dimension of the array.
  511. See Also
  512. --------
  513. ndenumerate, flatiter
  514. Examples
  515. --------
  516. >>> for index in np.ndindex(3, 2, 1):
  517. ... print(index)
  518. (0, 0, 0)
  519. (0, 1, 0)
  520. (1, 0, 0)
  521. (1, 1, 0)
  522. (2, 0, 0)
  523. (2, 1, 0)
  524. """
  525. def __init__(self, *shape):
  526. if len(shape) == 1 and isinstance(shape[0], tuple):
  527. shape = shape[0]
  528. x = as_strided(_nx.zeros(1), shape=shape,
  529. strides=_nx.zeros_like(shape))
  530. self._it = _nx.nditer(x, flags=['multi_index', 'zerosize_ok'],
  531. order='C')
  532. def __iter__(self):
  533. return self
  534. def ndincr(self):
  535. """
  536. Increment the multi-dimensional index by one.
  537. This method is for backward compatibility only: do not use.
  538. """
  539. next(self)
  540. def __next__(self):
  541. """
  542. Standard iterator method, updates the index and returns the index
  543. tuple.
  544. Returns
  545. -------
  546. val : tuple of ints
  547. Returns a tuple containing the indices of the current
  548. iteration.
  549. """
  550. next(self._it)
  551. return self._it.multi_index
  552. next = __next__
  553. # You can do all this with slice() plus a few special objects,
  554. # but there's a lot to remember. This version is simpler because
  555. # it uses the standard array indexing syntax.
  556. #
  557. # Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
  558. # last revision: 1999-7-23
  559. #
  560. # Cosmetic changes by T. Oliphant 2001
  561. #
  562. #
  563. class IndexExpression(object):
  564. """
  565. A nicer way to build up index tuples for arrays.
  566. .. note::
  567. Use one of the two predefined instances `index_exp` or `s_`
  568. rather than directly using `IndexExpression`.
  569. For any index combination, including slicing and axis insertion,
  570. ``a[indices]`` is the same as ``a[np.index_exp[indices]]`` for any
  571. array `a`. However, ``np.index_exp[indices]`` can be used anywhere
  572. in Python code and returns a tuple of slice objects that can be
  573. used in the construction of complex index expressions.
  574. Parameters
  575. ----------
  576. maketuple : bool
  577. If True, always returns a tuple.
  578. See Also
  579. --------
  580. index_exp : Predefined instance that always returns a tuple:
  581. `index_exp = IndexExpression(maketuple=True)`.
  582. s_ : Predefined instance without tuple conversion:
  583. `s_ = IndexExpression(maketuple=False)`.
  584. Notes
  585. -----
  586. You can do all this with `slice()` plus a few special objects,
  587. but there's a lot to remember and this version is simpler because
  588. it uses the standard array indexing syntax.
  589. Examples
  590. --------
  591. >>> np.s_[2::2]
  592. slice(2, None, 2)
  593. >>> np.index_exp[2::2]
  594. (slice(2, None, 2),)
  595. >>> np.array([0, 1, 2, 3, 4])[np.s_[2::2]]
  596. array([2, 4])
  597. """
  598. def __init__(self, maketuple):
  599. self.maketuple = maketuple
  600. def __getitem__(self, item):
  601. if self.maketuple and not isinstance(item, tuple):
  602. return (item,)
  603. else:
  604. return item
  605. index_exp = IndexExpression(maketuple=True)
  606. s_ = IndexExpression(maketuple=False)
  607. # End contribution from Konrad.
  608. # The following functions complement those in twodim_base, but are
  609. # applicable to N-dimensions.
  610. def _fill_diagonal_dispatcher(a, val, wrap=None):
  611. return (a,)
  612. @array_function_dispatch(_fill_diagonal_dispatcher)
  613. def fill_diagonal(a, val, wrap=False):
  614. """Fill the main diagonal of the given array of any dimensionality.
  615. For an array `a` with ``a.ndim >= 2``, the diagonal is the list of
  616. locations with indices ``a[i, ..., i]`` all identical. This function
  617. modifies the input array in-place, it does not return a value.
  618. Parameters
  619. ----------
  620. a : array, at least 2-D.
  621. Array whose diagonal is to be filled, it gets modified in-place.
  622. val : scalar
  623. Value to be written on the diagonal, its type must be compatible with
  624. that of the array a.
  625. wrap : bool
  626. For tall matrices in NumPy version up to 1.6.2, the
  627. diagonal "wrapped" after N columns. You can have this behavior
  628. with this option. This affects only tall matrices.
  629. See also
  630. --------
  631. diag_indices, diag_indices_from
  632. Notes
  633. -----
  634. .. versionadded:: 1.4.0
  635. This functionality can be obtained via `diag_indices`, but internally
  636. this version uses a much faster implementation that never constructs the
  637. indices and uses simple slicing.
  638. Examples
  639. --------
  640. >>> a = np.zeros((3, 3), int)
  641. >>> np.fill_diagonal(a, 5)
  642. >>> a
  643. array([[5, 0, 0],
  644. [0, 5, 0],
  645. [0, 0, 5]])
  646. The same function can operate on a 4-D array:
  647. >>> a = np.zeros((3, 3, 3, 3), int)
  648. >>> np.fill_diagonal(a, 4)
  649. We only show a few blocks for clarity:
  650. >>> a[0, 0]
  651. array([[4, 0, 0],
  652. [0, 0, 0],
  653. [0, 0, 0]])
  654. >>> a[1, 1]
  655. array([[0, 0, 0],
  656. [0, 4, 0],
  657. [0, 0, 0]])
  658. >>> a[2, 2]
  659. array([[0, 0, 0],
  660. [0, 0, 0],
  661. [0, 0, 4]])
  662. The wrap option affects only tall matrices:
  663. >>> # tall matrices no wrap
  664. >>> a = np.zeros((5, 3),int)
  665. >>> fill_diagonal(a, 4)
  666. >>> a
  667. array([[4, 0, 0],
  668. [0, 4, 0],
  669. [0, 0, 4],
  670. [0, 0, 0],
  671. [0, 0, 0]])
  672. >>> # tall matrices wrap
  673. >>> a = np.zeros((5, 3),int)
  674. >>> fill_diagonal(a, 4, wrap=True)
  675. >>> a
  676. array([[4, 0, 0],
  677. [0, 4, 0],
  678. [0, 0, 4],
  679. [0, 0, 0],
  680. [4, 0, 0]])
  681. >>> # wide matrices
  682. >>> a = np.zeros((3, 5),int)
  683. >>> fill_diagonal(a, 4, wrap=True)
  684. >>> a
  685. array([[4, 0, 0, 0, 0],
  686. [0, 4, 0, 0, 0],
  687. [0, 0, 4, 0, 0]])
  688. """
  689. if a.ndim < 2:
  690. raise ValueError("array must be at least 2-d")
  691. end = None
  692. if a.ndim == 2:
  693. # Explicit, fast formula for the common case. For 2-d arrays, we
  694. # accept rectangular ones.
  695. step = a.shape[1] + 1
  696. #This is needed to don't have tall matrix have the diagonal wrap.
  697. if not wrap:
  698. end = a.shape[1] * a.shape[1]
  699. else:
  700. # For more than d=2, the strided formula is only valid for arrays with
  701. # all dimensions equal, so we check first.
  702. if not alltrue(diff(a.shape) == 0):
  703. raise ValueError("All dimensions of input must be of equal length")
  704. step = 1 + (cumprod(a.shape[:-1])).sum()
  705. # Write the value out into the diagonal.
  706. a.flat[:end:step] = val
  707. @set_module('numpy')
  708. def diag_indices(n, ndim=2):
  709. """
  710. Return the indices to access the main diagonal of an array.
  711. This returns a tuple of indices that can be used to access the main
  712. diagonal of an array `a` with ``a.ndim >= 2`` dimensions and shape
  713. (n, n, ..., n). For ``a.ndim = 2`` this is the usual diagonal, for
  714. ``a.ndim > 2`` this is the set of indices to access ``a[i, i, ..., i]``
  715. for ``i = [0..n-1]``.
  716. Parameters
  717. ----------
  718. n : int
  719. The size, along each dimension, of the arrays for which the returned
  720. indices can be used.
  721. ndim : int, optional
  722. The number of dimensions.
  723. See also
  724. --------
  725. diag_indices_from
  726. Notes
  727. -----
  728. .. versionadded:: 1.4.0
  729. Examples
  730. --------
  731. Create a set of indices to access the diagonal of a (4, 4) array:
  732. >>> di = np.diag_indices(4)
  733. >>> di
  734. (array([0, 1, 2, 3]), array([0, 1, 2, 3]))
  735. >>> a = np.arange(16).reshape(4, 4)
  736. >>> a
  737. array([[ 0, 1, 2, 3],
  738. [ 4, 5, 6, 7],
  739. [ 8, 9, 10, 11],
  740. [12, 13, 14, 15]])
  741. >>> a[di] = 100
  742. >>> a
  743. array([[100, 1, 2, 3],
  744. [ 4, 100, 6, 7],
  745. [ 8, 9, 100, 11],
  746. [ 12, 13, 14, 100]])
  747. Now, we create indices to manipulate a 3-D array:
  748. >>> d3 = np.diag_indices(2, 3)
  749. >>> d3
  750. (array([0, 1]), array([0, 1]), array([0, 1]))
  751. And use it to set the diagonal of an array of zeros to 1:
  752. >>> a = np.zeros((2, 2, 2), dtype=int)
  753. >>> a[d3] = 1
  754. >>> a
  755. array([[[1, 0],
  756. [0, 0]],
  757. [[0, 0],
  758. [0, 1]]])
  759. """
  760. idx = arange(n)
  761. return (idx,) * ndim
  762. def _diag_indices_from(arr):
  763. return (arr,)
  764. @array_function_dispatch(_diag_indices_from)
  765. def diag_indices_from(arr):
  766. """
  767. Return the indices to access the main diagonal of an n-dimensional array.
  768. See `diag_indices` for full details.
  769. Parameters
  770. ----------
  771. arr : array, at least 2-D
  772. See Also
  773. --------
  774. diag_indices
  775. Notes
  776. -----
  777. .. versionadded:: 1.4.0
  778. """
  779. if not arr.ndim >= 2:
  780. raise ValueError("input array must be at least 2-d")
  781. # For more than d=2, the strided formula is only valid for arrays with
  782. # all dimensions equal, so we check first.
  783. if not alltrue(diff(arr.shape) == 0):
  784. raise ValueError("All dimensions of input must be of equal length")
  785. return diag_indices(arr.shape[0], arr.ndim)