sputils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. """ Utility functions for sparse matrix module
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. import operator
  5. import warnings
  6. import numpy as np
  7. from scipy._lib._version import NumpyVersion
  8. if NumpyVersion(np.__version__) >= '1.17.0':
  9. from scipy._lib._util import _broadcast_arrays
  10. else:
  11. from numpy import broadcast_arrays as _broadcast_arrays
  12. __all__ = ['upcast', 'getdtype', 'isscalarlike', 'isintlike',
  13. 'isshape', 'issequence', 'isdense', 'ismatrix', 'get_sum_dtype']
  14. supported_dtypes = ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc',
  15. 'uintc', 'l', 'L', 'longlong', 'ulonglong', 'single', 'double',
  16. 'longdouble', 'csingle', 'cdouble', 'clongdouble']
  17. supported_dtypes = [np.typeDict[x] for x in supported_dtypes]
  18. _upcast_memo = {}
  19. def upcast(*args):
  20. """Returns the nearest supported sparse dtype for the
  21. combination of one or more types.
  22. upcast(t0, t1, ..., tn) -> T where T is a supported dtype
  23. Examples
  24. --------
  25. >>> upcast('int32')
  26. <type 'numpy.int32'>
  27. >>> upcast('bool')
  28. <type 'numpy.bool_'>
  29. >>> upcast('int32','float32')
  30. <type 'numpy.float64'>
  31. >>> upcast('bool',complex,float)
  32. <type 'numpy.complex128'>
  33. """
  34. t = _upcast_memo.get(hash(args))
  35. if t is not None:
  36. return t
  37. upcast = np.find_common_type(args, [])
  38. for t in supported_dtypes:
  39. if np.can_cast(upcast, t):
  40. _upcast_memo[hash(args)] = t
  41. return t
  42. raise TypeError('no supported conversion for types: %r' % (args,))
  43. def upcast_char(*args):
  44. """Same as `upcast` but taking dtype.char as input (faster)."""
  45. t = _upcast_memo.get(args)
  46. if t is not None:
  47. return t
  48. t = upcast(*map(np.dtype, args))
  49. _upcast_memo[args] = t
  50. return t
  51. def upcast_scalar(dtype, scalar):
  52. """Determine data type for binary operation between an array of
  53. type `dtype` and a scalar.
  54. """
  55. return (np.array([0], dtype=dtype) * scalar).dtype
  56. def downcast_intp_index(arr):
  57. """
  58. Down-cast index array to np.intp dtype if it is of a larger dtype.
  59. Raise an error if the array contains a value that is too large for
  60. intp.
  61. """
  62. if arr.dtype.itemsize > np.dtype(np.intp).itemsize:
  63. if arr.size == 0:
  64. return arr.astype(np.intp)
  65. maxval = arr.max()
  66. minval = arr.min()
  67. if maxval > np.iinfo(np.intp).max or minval < np.iinfo(np.intp).min:
  68. raise ValueError("Cannot deal with arrays with indices larger "
  69. "than the machine maximum address size "
  70. "(e.g. 64-bit indices on 32-bit machine).")
  71. return arr.astype(np.intp)
  72. return arr
  73. def to_native(A):
  74. return np.asarray(A, dtype=A.dtype.newbyteorder('native'))
  75. def getdtype(dtype, a=None, default=None):
  76. """Function used to simplify argument processing. If 'dtype' is not
  77. specified (is None), returns a.dtype; otherwise returns a np.dtype
  78. object created from the specified dtype argument. If 'dtype' and 'a'
  79. are both None, construct a data type out of the 'default' parameter.
  80. Furthermore, 'dtype' must be in 'allowed' set.
  81. """
  82. # TODO is this really what we want?
  83. if dtype is None:
  84. try:
  85. newdtype = a.dtype
  86. except AttributeError:
  87. if default is not None:
  88. newdtype = np.dtype(default)
  89. else:
  90. raise TypeError("could not interpret data type")
  91. else:
  92. newdtype = np.dtype(dtype)
  93. if newdtype == np.object_:
  94. warnings.warn("object dtype is not supported by sparse matrices")
  95. return newdtype
  96. def get_index_dtype(arrays=(), maxval=None, check_contents=False):
  97. """
  98. Based on input (integer) arrays `a`, determine a suitable index data
  99. type that can hold the data in the arrays.
  100. Parameters
  101. ----------
  102. arrays : tuple of array_like
  103. Input arrays whose types/contents to check
  104. maxval : float, optional
  105. Maximum value needed
  106. check_contents : bool, optional
  107. Whether to check the values in the arrays and not just their types.
  108. Default: False (check only the types)
  109. Returns
  110. -------
  111. dtype : dtype
  112. Suitable index data type (int32 or int64)
  113. """
  114. int32min = np.iinfo(np.int32).min
  115. int32max = np.iinfo(np.int32).max
  116. dtype = np.intc
  117. if maxval is not None:
  118. if maxval > int32max:
  119. dtype = np.int64
  120. if isinstance(arrays, np.ndarray):
  121. arrays = (arrays,)
  122. for arr in arrays:
  123. arr = np.asarray(arr)
  124. if not np.can_cast(arr.dtype, np.int32):
  125. if check_contents:
  126. if arr.size == 0:
  127. # a bigger type not needed
  128. continue
  129. elif np.issubdtype(arr.dtype, np.integer):
  130. maxval = arr.max()
  131. minval = arr.min()
  132. if minval >= int32min and maxval <= int32max:
  133. # a bigger type not needed
  134. continue
  135. dtype = np.int64
  136. break
  137. return dtype
  138. def get_sum_dtype(dtype):
  139. """Mimic numpy's casting for np.sum"""
  140. if dtype.kind == 'u' and np.can_cast(dtype, np.uint):
  141. return np.uint
  142. if np.can_cast(dtype, np.int_):
  143. return np.int_
  144. return dtype
  145. def isscalarlike(x):
  146. """Is x either a scalar, an array scalar, or a 0-dim array?"""
  147. return np.isscalar(x) or (isdense(x) and x.ndim == 0)
  148. def isintlike(x):
  149. """Is x appropriate as an index into a sparse matrix? Returns True
  150. if it can be cast safely to a machine int.
  151. """
  152. # Fast-path check to eliminate non-scalar values. operator.index would
  153. # catch this case too, but the exception catching is slow.
  154. if np.ndim(x) != 0:
  155. return False
  156. try:
  157. operator.index(x)
  158. except (TypeError, ValueError):
  159. try:
  160. loose_int = bool(int(x) == x)
  161. except (TypeError, ValueError):
  162. return False
  163. if loose_int:
  164. warnings.warn("Inexact indices into sparse matrices are deprecated",
  165. DeprecationWarning)
  166. return loose_int
  167. return True
  168. def isshape(x, nonneg=False):
  169. """Is x a valid 2-tuple of dimensions?
  170. If nonneg, also checks that the dimensions are non-negative.
  171. """
  172. try:
  173. # Assume it's a tuple of matrix dimensions (M, N)
  174. (M, N) = x
  175. except Exception:
  176. return False
  177. else:
  178. if isintlike(M) and isintlike(N):
  179. if np.ndim(M) == 0 and np.ndim(N) == 0:
  180. if not nonneg or (M >= 0 and N >= 0):
  181. return True
  182. return False
  183. def issequence(t):
  184. return ((isinstance(t, (list, tuple)) and
  185. (len(t) == 0 or np.isscalar(t[0]))) or
  186. (isinstance(t, np.ndarray) and (t.ndim == 1)))
  187. def ismatrix(t):
  188. return ((isinstance(t, (list, tuple)) and
  189. len(t) > 0 and issequence(t[0])) or
  190. (isinstance(t, np.ndarray) and t.ndim == 2))
  191. def isdense(x):
  192. return isinstance(x, np.ndarray)
  193. def validateaxis(axis):
  194. if axis is not None:
  195. axis_type = type(axis)
  196. # In NumPy, you can pass in tuples for 'axis', but they are
  197. # not very useful for sparse matrices given their limited
  198. # dimensions, so let's make it explicit that they are not
  199. # allowed to be passed in
  200. if axis_type == tuple:
  201. raise TypeError(("Tuples are not accepted for the 'axis' "
  202. "parameter. Please pass in one of the "
  203. "following: {-2, -1, 0, 1, None}."))
  204. # If not a tuple, check that the provided axis is actually
  205. # an integer and raise a TypeError similar to NumPy's
  206. if not np.issubdtype(np.dtype(axis_type), np.integer):
  207. raise TypeError("axis must be an integer, not {name}"
  208. .format(name=axis_type.__name__))
  209. if not (-2 <= axis <= 1):
  210. raise ValueError("axis out of range")
  211. def check_shape(args, current_shape=None):
  212. """Imitate numpy.matrix handling of shape arguments"""
  213. if len(args) == 0:
  214. raise TypeError("function missing 1 required positional argument: "
  215. "'shape'")
  216. elif len(args) == 1:
  217. try:
  218. shape_iter = iter(args[0])
  219. except TypeError:
  220. new_shape = (operator.index(args[0]), )
  221. else:
  222. new_shape = tuple(operator.index(arg) for arg in shape_iter)
  223. else:
  224. new_shape = tuple(operator.index(arg) for arg in args)
  225. if current_shape is None:
  226. if len(new_shape) != 2:
  227. raise ValueError('shape must be a 2-tuple of positive integers')
  228. elif new_shape[0] < 0 or new_shape[1] < 0:
  229. raise ValueError("'shape' elements cannot be negative")
  230. else:
  231. # Check the current size only if needed
  232. current_size = np.prod(current_shape, dtype=int)
  233. # Check for negatives
  234. negative_indexes = [i for i, x in enumerate(new_shape) if x < 0]
  235. if len(negative_indexes) == 0:
  236. new_size = np.prod(new_shape, dtype=int)
  237. if new_size != current_size:
  238. raise ValueError('cannot reshape array of size {} into shape {}'
  239. .format(new_size, new_shape))
  240. elif len(negative_indexes) == 1:
  241. skip = negative_indexes[0]
  242. specified = np.prod(new_shape[0:skip] + new_shape[skip+1:])
  243. unspecified, remainder = divmod(current_size, specified)
  244. if remainder != 0:
  245. err_shape = tuple('newshape' if x < 0 else x for x in new_shape)
  246. raise ValueError('cannot reshape array of size {} into shape {}'
  247. ''.format(current_size, err_shape))
  248. new_shape = new_shape[0:skip] + (unspecified,) + new_shape[skip+1:]
  249. else:
  250. raise ValueError('can only specify one unknown dimension')
  251. # Add and remove ones like numpy.matrix.reshape
  252. if len(new_shape) != 2:
  253. new_shape = tuple(arg for arg in new_shape if arg != 1)
  254. if len(new_shape) == 0:
  255. new_shape = (1, 1)
  256. elif len(new_shape) == 1:
  257. new_shape = (1, new_shape[0])
  258. if len(new_shape) > 2:
  259. raise ValueError('shape too large to be a matrix')
  260. return new_shape
  261. def check_reshape_kwargs(kwargs):
  262. """Unpack keyword arguments for reshape function.
  263. This is useful because keyword arguments after star arguments are not
  264. allowed in Python 2, but star keyword arguments are. This function unpacks
  265. 'order' and 'copy' from the star keyword arguments (with defaults) and
  266. throws an error for any remaining.
  267. """
  268. order = kwargs.pop('order', 'C')
  269. copy = kwargs.pop('copy', False)
  270. if kwargs: # Some unused kwargs remain
  271. raise TypeError('reshape() got unexpected keywords arguments: {}'
  272. .format(', '.join(kwargs.keys())))
  273. return order, copy
class IndexMixin(object):
    """
    Mixin that holds the methods necessary for fancy indexing.

    It defines no state of its own.  `_index_to_arrays` reads
    ``self.shape[0]`` and ``self.shape[1]``, so the class it is mixed into
    must provide a ``shape`` attribute.
    """
    def _slicetoarange(self, j, shape):
        """ Given a slice object, use numpy arange to change it to a 1D
        array.

        `shape` is the length of the axis being sliced (an int, e.g.
        ``self.shape[0]``).  ``slice.indices`` uses it to resolve None and
        negative bounds before ``np.arange`` is called.
        """
        start, stop, step = j.indices(shape)
        return np.arange(start, stop, step)

    def _unpack_index(self, index):
        """ Parse index. Always return a tuple of the form (row, col).
        Where row/col is an integer, slice, or array of integers.

        A 2-D boolean ndarray or sparse matrix used as the whole index is
        handled first: its ``nonzero()`` result is returned directly.
        Otherwise Ellipsis is expanded, a missing column index defaults to
        ``slice(None)``, and 1-D boolean ndarrays are converted to integer
        positions.

        Raises
        ------
        IndexError
            If the (expanded) tuple has more than two entries, a component
            is a sparse matrix, or a component is a boolean ndarray with
            ndim > 1.
        """
        # First, check if indexing with single boolean matrix.
        # Imported here rather than at module level -- presumably to avoid a
        # circular import with .base (TODO confirm).
        from .base import spmatrix  # This feels dirty but...
        if (isinstance(index, (spmatrix, np.ndarray)) and
                (index.ndim == 2) and index.dtype.kind == 'b'):
            return index.nonzero()

        # Parse any ellipses.
        index = self._check_ellipsis(index)

        # Next, parse the tuple or object
        if isinstance(index, tuple):
            if len(index) == 2:
                row, col = index
            elif len(index) == 1:
                row, col = index[0], slice(None)
            else:
                raise IndexError('invalid number of indices')
        else:
            # A single non-tuple index selects rows; all columns are kept.
            row, col = index, slice(None)

        # Next, check for validity, or transform the index as needed.
        row, col = self._check_boolean(row, col)
        return row, col

    def _check_ellipsis(self, index):
        """Process indices with Ellipsis. Returns modified index.

        A bare Ellipsis becomes ``(slice(None), slice(None))``.  In a tuple,
        only the first Ellipsis is expanded into full slices.  Enough are
        added to bring the tuple up to two entries (none if it already has
        two or more), and any later Ellipsis entries are dropped.  Indices
        without an Ellipsis are returned unchanged.
        """
        if index is Ellipsis:
            return (slice(None), slice(None))
        elif isinstance(index, tuple):
            # Find first ellipsis
            for j, v in enumerate(index):
                if v is Ellipsis:
                    first_ellipsis = j
                    break
            else:
                # for/else: reached only when the loop found no Ellipsis.
                first_ellipsis = None

            # Expand the first one
            if first_ellipsis is not None:
                # Shortcuts for 1- and 2-element tuples
                if len(index) == 1:
                    return (slice(None), slice(None))
                elif len(index) == 2:
                    if first_ellipsis == 0:
                        if index[1] is Ellipsis:
                            return (slice(None), slice(None))
                        else:
                            return (slice(None), index[1])
                    else:
                        return (index[0], slice(None))

                # General case: collect the non-Ellipsis entries after the
                # first Ellipsis, then pad with full slices so that the
                # result has at least two entries.
                tail = ()
                for v in index[first_ellipsis+1:]:
                    if v is not Ellipsis:
                        tail = tail + (v,)
                nd = first_ellipsis + len(tail)
                nslice = max(0, 2 - nd)
                return index[:first_ellipsis] + (slice(None),)*nslice + tail

        return index

    def _check_boolean(self, row, col):
        """Validate row/col index components.

        1-D boolean ndarrays are converted to integer position arrays.
        Raises IndexError if either component is a sparse matrix.
        """
        from .base import isspmatrix  # ew...
        # Supporting sparse boolean indexing with both row and col does
        # not work because spmatrix.ndim is always 2.
        if isspmatrix(row) or isspmatrix(col):
            raise IndexError(
                "Indexing with sparse matrices is not supported "
                "except boolean indexing where matrix and index "
                "are equal shapes.")
        if isinstance(row, np.ndarray) and row.dtype.kind == 'b':
            row = self._boolean_index_to_array(row)
        if isinstance(col, np.ndarray) and col.dtype.kind == 'b':
            col = self._boolean_index_to_array(col)
        return row, col

    def _boolean_index_to_array(self, i):
        """Return the positions of True entries of a 1-D boolean array.

        Raises IndexError for arrays with more than one dimension.
        """
        if i.ndim > 1:
            raise IndexError('invalid index shape')
        return i.nonzero()[0]

    def _index_to_arrays(self, i, j):
        """Convert a (row, col) index pair into integer index arrays (i, j).

        Slices are expanded against ``self.shape``: a row slice becomes a
        column vector and a column slice becomes a row vector, so the pair
        broadcasts into a grid.  The results are:

        * a 1-D `i` with a scalar `j` gives two ``(k, 1)`` column vectors;
        * otherwise `i` and `j` are broadcast together, and a 1-D result is
          reshaped to ``(1, k)``.

        Raises
        ------
        IndexError
            For combinations that would produce a 3-D result, or index
            arrays with more than two dimensions after broadcasting.
        """
        i, j = self._check_boolean(i, j)

        # Remember whether i was a slice; i is replaced by an array below.
        i_slice = isinstance(i, slice)
        if i_slice:
            # shape (n, 1)
            i = self._slicetoarange(i, self.shape[0])[:, None]
        else:
            i = np.atleast_1d(i)

        if isinstance(j, slice):
            # shape (1, m)
            j = self._slicetoarange(j, self.shape[1])[None, :]
            if i.ndim == 1:
                # Make i a column so it broadcasts against the row j.
                i = i[:, None]
            elif not i_slice:
                raise IndexError('index returns 3-dim structure')
        elif isscalarlike(j):
            # row vector special case (scalar column index)
            j = np.atleast_1d(j)
            if i.ndim == 1:
                # 1-D i with scalar j: return (k, 1) column vectors.
                i, j = _broadcast_arrays(i, j)
                i = i[:, None]
                j = j[:, None]
                return i, j
        else:
            j = np.atleast_1d(j)
            if i_slice and j.ndim > 1:
                raise IndexError('index returns 3-dim structure')

        i, j = _broadcast_arrays(i, j)

        if i.ndim == 1:
            # 1-D results are returned with shape (1, k), i.e. as a single
            # row of (i, j) pairs.
            i = i[None, :]
            j = j[None, :]
        elif i.ndim > 2:
            raise IndexError("Index dimension must be <= 2")

        return i, j