compressed.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206
  1. """Base class for sparse matrix formats using compressed storage."""
  2. from __future__ import division, print_function, absolute_import
  3. __all__ = []
  4. from warnings import warn
  5. import operator
  6. import numpy as np
  7. from scipy._lib.six import zip as izip
  8. from scipy._lib._util import _prune_array
  9. from .base import spmatrix, isspmatrix, SparseEfficiencyWarning
  10. from .data import _data_matrix, _minmax_mixin
  11. from .dia import dia_matrix
  12. from . import _sparsetools
  13. from .sputils import (upcast, upcast_char, to_native, isdense, isshape,
  14. getdtype, isscalarlike, IndexMixin, get_index_dtype,
  15. downcast_intp_index, get_sum_dtype, check_shape)
  16. class _cs_matrix(_data_matrix, _minmax_mixin, IndexMixin):
  17. """base matrix class for compressed row and column oriented matrices"""
    def __init__(self, arg1, shape=None, dtype=None, copy=False):
        """Build a compressed matrix from one of several argument forms.

        Parameters
        ----------
        arg1 : sparse matrix, tuple, or array-like
            * a sparse matrix: converted with ``asformat(self.format)``
              (or copied when it already has this format and `copy` is
              true);
            * a tuple accepted by ``isshape``: an empty matrix of that
              shape is created;
            * a 2-tuple ``(data, ij)``: routed through ``coo_matrix``;
            * a 3-tuple ``(data, indices, indptr)``: the three arrays are
              stored via ``np.array(..., copy=copy)``;
            * anything else: passed to ``np.asarray`` and then through
              ``coo_matrix``.
        shape : tuple, optional
            Explicit matrix dimensions.  When omitted and no shape was set
            by the branches above, it is inferred from `indptr`/`indices`.
        dtype : dtype, optional
            When given, ``self.data`` is converted to this dtype at the end.
        copy : bool, optional
            See `arg1`; only consulted in the sparse-matrix and 3-tuple
            branches.

        Raises
        ------
        ValueError
            If the argument form is not recognised or the shape cannot be
            inferred.
        """
        _data_matrix.__init__(self)

        if isspmatrix(arg1):
            # Note: `copy` only has an effect when the format already
            # matches; otherwise `asformat` is used.
            if arg1.format == self.format and copy:
                arg1 = arg1.copy()
            else:
                arg1 = arg1.asformat(self.format)
            self._set_self(arg1)

        elif isinstance(arg1, tuple):
            if isshape(arg1):
                # It's a tuple of matrix dimensions (M, N)
                # create empty matrix
                self._shape = check_shape(arg1)
                M, N = self.shape
                # Select index dtype large enough to pass array and
                # scalar parameters to sparsetools
                idx_dtype = get_index_dtype(maxval=max(M, N))
                self.data = np.zeros(0, getdtype(dtype, default=float))
                self.indices = np.zeros(0, idx_dtype)
                # one pointer per major-axis slot, plus the trailing end
                # marker; all zero because nothing is stored
                self.indptr = np.zeros(self._swap((M, N))[0] + 1,
                                       dtype=idx_dtype)
            else:
                if len(arg1) == 2:
                    # (data, ij) format
                    from .coo import coo_matrix
                    other = self.__class__(coo_matrix(arg1, shape=shape))
                    self._set_self(other)
                elif len(arg1) == 3:
                    # (data, indices, indptr) format
                    (data, indices, indptr) = arg1

                    # Select index dtype large enough to pass array and
                    # scalar parameters to sparsetools
                    maxval = None
                    if shape is not None:
                        maxval = max(shape)
                    idx_dtype = get_index_dtype((indices, indptr),
                                                maxval=maxval,
                                                check_contents=True)

                    self.indices = np.array(indices, copy=copy,
                                            dtype=idx_dtype)
                    self.indptr = np.array(indptr, copy=copy, dtype=idx_dtype)
                    self.data = np.array(data, copy=copy, dtype=dtype)
                else:
                    raise ValueError("unrecognized {}_matrix "
                                     "constructor usage".format(self.format))

        else:
            # must be dense
            try:
                arg1 = np.asarray(arg1)
            except Exception:
                raise ValueError("unrecognized {}_matrix constructor usage"
                                 "".format(self.format))
            from .coo import coo_matrix
            self._set_self(self.__class__(coo_matrix(arg1, dtype=dtype)))

        # Read matrix dimensions given, if any
        if shape is not None:
            self._shape = check_shape(shape)
        else:
            if self.shape is None:
                # shape not already set, try to infer dimensions
                try:
                    major_dim = len(self.indptr) - 1
                    # largest stored minor index + 1; fails (and is
                    # reported below) when `indices` is empty
                    minor_dim = self.indices.max() + 1
                except Exception:
                    raise ValueError('unable to infer matrix dimensions')
                else:
                    self._shape = check_shape(self._swap((major_dim,
                                                          minor_dim)))

        if dtype is not None:
            self.data = np.asarray(self.data, dtype=dtype)

        # cheap structural validation only
        self.check_format(full_check=False)
  89. def getnnz(self, axis=None):
  90. if axis is None:
  91. return int(self.indptr[-1])
  92. else:
  93. if axis < 0:
  94. axis += 2
  95. axis, _ = self._swap((axis, 1 - axis))
  96. _, N = self._swap(self.shape)
  97. if axis == 0:
  98. return np.bincount(downcast_intp_index(self.indices),
  99. minlength=N)
  100. elif axis == 1:
  101. return np.diff(self.indptr)
  102. raise ValueError('axis out of bounds')
  103. getnnz.__doc__ = spmatrix.getnnz.__doc__
  104. def _set_self(self, other, copy=False):
  105. """take the member variables of other and assign them to self"""
  106. if copy:
  107. other = other.copy()
  108. self.data = other.data
  109. self.indices = other.indices
  110. self.indptr = other.indptr
  111. self._shape = check_shape(other.shape)
  112. def check_format(self, full_check=True):
  113. """check whether the matrix format is valid
  114. Parameters
  115. ----------
  116. full_check : bool, optional
  117. If `True`, rigorous check, O(N) operations. Otherwise
  118. basic check, O(1) operations (default True).
  119. """
  120. # use _swap to determine proper bounds
  121. major_name, minor_name = self._swap(('row', 'column'))
  122. major_dim, minor_dim = self._swap(self.shape)
  123. # index arrays should have integer data types
  124. if self.indptr.dtype.kind != 'i':
  125. warn("indptr array has non-integer dtype ({})"
  126. "".format(self.indptr.dtype.name), stacklevel=3)
  127. if self.indices.dtype.kind != 'i':
  128. warn("indices array has non-integer dtype ({})"
  129. "".format(self.indices.dtype.name), stacklevel=3)
  130. idx_dtype = get_index_dtype((self.indptr, self.indices))
  131. self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
  132. self.indices = np.asarray(self.indices, dtype=idx_dtype)
  133. self.data = to_native(self.data)
  134. # check array shapes
  135. for x in [self.data.ndim, self.indices.ndim, self.indptr.ndim]:
  136. if x != 1:
  137. raise ValueError('data, indices, and indptr should be 1-D')
  138. # check index pointer
  139. if (len(self.indptr) != major_dim + 1):
  140. raise ValueError("index pointer size ({}) should be ({})"
  141. "".format(len(self.indptr), major_dim + 1))
  142. if (self.indptr[0] != 0):
  143. raise ValueError("index pointer should start with 0")
  144. # check index and data arrays
  145. if (len(self.indices) != len(self.data)):
  146. raise ValueError("indices and data should have the same size")
  147. if (self.indptr[-1] > len(self.indices)):
  148. raise ValueError("Last value of index pointer should be less than "
  149. "the size of index and data arrays")
  150. self.prune()
  151. if full_check:
  152. # check format validity (more expensive)
  153. if self.nnz > 0:
  154. if self.indices.max() >= minor_dim:
  155. raise ValueError("{} index values must be < {}"
  156. "".format(minor_name, minor_dim))
  157. if self.indices.min() < 0:
  158. raise ValueError("{} index values must be >= 0"
  159. "".format(minor_name))
  160. if np.diff(self.indptr).min() < 0:
  161. raise ValueError("index pointer values must form a "
  162. "non-decreasing sequence")
  163. # if not self.has_sorted_indices():
  164. # warn('Indices were not in sorted order. Sorting indices.')
  165. # self.sort_indices()
  166. # assert(self.has_sorted_indices())
  167. # TODO check for duplicates?
  168. #######################
  169. # Boolean comparisons #
  170. #######################
  171. def _scalar_binopt(self, other, op):
  172. """Scalar version of self._binopt, for cases in which no new nonzeros
  173. are added. Produces a new spmatrix in canonical form.
  174. """
  175. self.sum_duplicates()
  176. res = self._with_data(op(self.data, other), copy=True)
  177. res.eliminate_zeros()
  178. return res
  179. def __eq__(self, other):
  180. # Scalar other.
  181. if isscalarlike(other):
  182. if np.isnan(other):
  183. return self.__class__(self.shape, dtype=np.bool_)
  184. if other == 0:
  185. warn("Comparing a sparse matrix with 0 using == is inefficient"
  186. ", try using != instead.", SparseEfficiencyWarning,
  187. stacklevel=3)
  188. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  189. inv = self._scalar_binopt(other, operator.ne)
  190. return all_true - inv
  191. else:
  192. return self._scalar_binopt(other, operator.eq)
  193. # Dense other.
  194. elif isdense(other):
  195. return self.todense() == other
  196. # Sparse other.
  197. elif isspmatrix(other):
  198. warn("Comparing sparse matrices using == is inefficient, try using"
  199. " != instead.", SparseEfficiencyWarning, stacklevel=3)
  200. # TODO sparse broadcasting
  201. if self.shape != other.shape:
  202. return False
  203. elif self.format != other.format:
  204. other = other.asformat(self.format)
  205. res = self._binopt(other, '_ne_')
  206. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  207. return all_true - res
  208. else:
  209. return False
  210. def __ne__(self, other):
  211. # Scalar other.
  212. if isscalarlike(other):
  213. if np.isnan(other):
  214. warn("Comparing a sparse matrix with nan using != is"
  215. " inefficient", SparseEfficiencyWarning, stacklevel=3)
  216. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  217. return all_true
  218. elif other != 0:
  219. warn("Comparing a sparse matrix with a nonzero scalar using !="
  220. " is inefficient, try using == instead.",
  221. SparseEfficiencyWarning, stacklevel=3)
  222. all_true = self.__class__(np.ones(self.shape), dtype=np.bool_)
  223. inv = self._scalar_binopt(other, operator.eq)
  224. return all_true - inv
  225. else:
  226. return self._scalar_binopt(other, operator.ne)
  227. # Dense other.
  228. elif isdense(other):
  229. return self.todense() != other
  230. # Sparse other.
  231. elif isspmatrix(other):
  232. # TODO sparse broadcasting
  233. if self.shape != other.shape:
  234. return True
  235. elif self.format != other.format:
  236. other = other.asformat(self.format)
  237. return self._binopt(other, '_ne_')
  238. else:
  239. return True
  240. def _inequality(self, other, op, op_name, bad_scalar_msg):
  241. # Scalar other.
  242. if isscalarlike(other):
  243. if 0 == other and op_name in ('_le_', '_ge_'):
  244. raise NotImplementedError(" >= and <= don't work with 0.")
  245. elif op(0, other):
  246. warn(bad_scalar_msg, SparseEfficiencyWarning)
  247. other_arr = np.empty(self.shape, dtype=np.result_type(other))
  248. other_arr.fill(other)
  249. other_arr = self.__class__(other_arr)
  250. return self._binopt(other_arr, op_name)
  251. else:
  252. return self._scalar_binopt(other, op)
  253. # Dense other.
  254. elif isdense(other):
  255. return op(self.todense(), other)
  256. # Sparse other.
  257. elif isspmatrix(other):
  258. # TODO sparse broadcasting
  259. if self.shape != other.shape:
  260. raise ValueError("inconsistent shapes")
  261. elif self.format != other.format:
  262. other = other.asformat(self.format)
  263. if op_name not in ('_ge_', '_le_'):
  264. return self._binopt(other, op_name)
  265. warn("Comparing sparse matrices using >= and <= is inefficient, "
  266. "using <, >, or !=, instead.", SparseEfficiencyWarning)
  267. all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
  268. res = self._binopt(other, '_gt_' if op_name == '_le_' else '_lt_')
  269. return all_true - res
  270. else:
  271. raise ValueError("Operands could not be compared.")
  272. def __lt__(self, other):
  273. return self._inequality(other, operator.lt, '_lt_',
  274. "Comparing a sparse matrix with a scalar "
  275. "greater than zero using < is inefficient, "
  276. "try using >= instead.")
  277. def __gt__(self, other):
  278. return self._inequality(other, operator.gt, '_gt_',
  279. "Comparing a sparse matrix with a scalar "
  280. "less than zero using > is inefficient, "
  281. "try using <= instead.")
  282. def __le__(self, other):
  283. return self._inequality(other, operator.le, '_le_',
  284. "Comparing a sparse matrix with a scalar "
  285. "greater than zero using <= is inefficient, "
  286. "try using > instead.")
  287. def __ge__(self, other):
  288. return self._inequality(other, operator.ge, '_ge_',
  289. "Comparing a sparse matrix with a scalar "
  290. "less than zero using >= is inefficient, "
  291. "try using < instead.")
  292. #################################
  293. # Arithmetic operator overrides #
  294. #################################
  295. def _add_dense(self, other):
  296. if other.shape != self.shape:
  297. raise ValueError('Incompatible shapes.')
  298. dtype = upcast_char(self.dtype.char, other.dtype.char)
  299. order = self._swap('CF')[0]
  300. result = np.array(other, dtype=dtype, order=order, copy=True)
  301. M, N = self._swap(self.shape)
  302. y = result if result.flags.c_contiguous else result.T
  303. _sparsetools.csr_todense(M, N, self.indptr, self.indices, self.data, y)
  304. return np.matrix(result, copy=False)
  305. def _add_sparse(self, other):
  306. return self._binopt(other, '_plus_')
  307. def _sub_sparse(self, other):
  308. return self._binopt(other, '_minus_')
    def multiply(self, other):
        """Point-wise multiplication by another matrix, vector, or
        scalar.

        The branches below are tried in order: scalar, sparse operand
        (same shape, single element, row/column combinations), then dense
        operand after ``np.atleast_2d``.

        Raises
        ------
        ValueError
            ("inconsistent shapes") when no branch matches the operand
            shapes.
        """
        # Scalar multiplication.
        if isscalarlike(other):
            return self._mul_scalar(other)
        # Sparse matrix or vector.
        if isspmatrix(other):
            if self.shape == other.shape:
                other = self.__class__(other)
                return self._binopt(other, '_elmul_')
            # Single element.
            elif other.shape == (1, 1):
                return self._mul_scalar(other.toarray()[0, 0])
            elif self.shape == (1, 1):
                return other._mul_scalar(self.toarray()[0, 0])
            # A row times a column.
            # (handled as a matrix product of the column with the row)
            elif self.shape[1] == 1 and other.shape[0] == 1:
                return self._mul_sparse_matrix(other.tocsc())
            elif self.shape[0] == 1 and other.shape[1] == 1:
                return other._mul_sparse_matrix(self.tocsc())
            # Row vector times matrix. other is a row.
            # (the vector is placed on the main diagonal of a dia_matrix
            # and applied as a matrix product)
            elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
                other = dia_matrix((other.toarray().ravel(), [0]),
                                   shape=(other.shape[1], other.shape[1]))
                return self._mul_sparse_matrix(other)
            # self is a row.
            elif self.shape[0] == 1 and self.shape[1] == other.shape[1]:
                copy = dia_matrix((self.toarray().ravel(), [0]),
                                  shape=(self.shape[1], self.shape[1]))
                return other._mul_sparse_matrix(copy)
            # Column vector times matrix. other is a column.
            elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
                other = dia_matrix((other.toarray().ravel(), [0]),
                                   shape=(other.shape[0], other.shape[0]))
                return other._mul_sparse_matrix(self)
            # self is a column.
            elif self.shape[1] == 1 and self.shape[0] == other.shape[0]:
                copy = dia_matrix((self.toarray().ravel(), [0]),
                                  shape=(self.shape[0], self.shape[0]))
                return copy._mul_sparse_matrix(other)
            else:
                raise ValueError("inconsistent shapes")

        # Assume other is a dense matrix/array, which produces a single-item
        # object array if other isn't convertible to ndarray.
        other = np.atleast_2d(other)

        # more than two dimensions: fall back to a fully dense product
        if other.ndim != 2:
            return np.multiply(self.toarray(), other)
        # Single element / wrapped object.
        if other.size == 1:
            return self._mul_scalar(other.flat[0])
        # Fast case for trivial sparse matrix.
        elif self.shape == (1, 1):
            return np.multiply(self.toarray()[0, 0], other)

        from .coo import coo_matrix
        # work on the COO form; `ret` is reused as the result in the
        # branches that only rescale the stored entries
        ret = self.tocoo()
        # Matching shapes.
        if self.shape == other.shape:
            data = np.multiply(ret.data, other[ret.row, ret.col])
        # Sparse row vector times...
        elif self.shape[0] == 1:
            if other.shape[1] == 1:  # Dense column vector.
                data = np.multiply(ret.data, other)
            elif other.shape[1] == self.shape[1]:  # Dense matrix.
                data = np.multiply(ret.data, other[:, ret.col])
            else:
                raise ValueError("inconsistent shapes")
            # the stored columns are repeated once per row of `other`
            row = np.repeat(np.arange(other.shape[0]), len(ret.row))
            col = np.tile(ret.col, other.shape[0])
            return coo_matrix((data.view(np.ndarray).ravel(), (row, col)),
                              shape=(other.shape[0], self.shape[1]),
                              copy=False)
        # Sparse column vector times...
        elif self.shape[1] == 1:
            if other.shape[0] == 1:  # Dense row vector.
                data = np.multiply(ret.data[:, None], other)
            elif other.shape[0] == self.shape[0]:  # Dense matrix.
                data = np.multiply(ret.data[:, None], other[ret.row])
            else:
                raise ValueError("inconsistent shapes")
            # each stored row is expanded across every column of `other`
            row = np.repeat(ret.row, other.shape[1])
            col = np.tile(np.arange(other.shape[1]), len(ret.col))
            return coo_matrix((data.view(np.ndarray).ravel(), (row, col)),
                              shape=(self.shape[0], other.shape[1]),
                              copy=False)
        # Sparse matrix times dense row vector.
        elif other.shape[0] == 1 and self.shape[1] == other.shape[1]:
            data = np.multiply(ret.data, other[:, ret.col].ravel())
        # Sparse matrix times dense column vector.
        elif other.shape[1] == 1 and self.shape[0] == other.shape[0]:
            data = np.multiply(ret.data, other[ret.row].ravel())
        else:
            raise ValueError("inconsistent shapes")
        # view as plain ndarray so an np.matrix operand does not leak into
        # the stored data, then flatten
        ret.data = data.view(np.ndarray).ravel()
        return ret
  405. ###########################
  406. # Multiplication handlers #
  407. ###########################
  408. def _mul_vector(self, other):
  409. M, N = self.shape
  410. # output array
  411. result = np.zeros(M, dtype=upcast_char(self.dtype.char,
  412. other.dtype.char))
  413. # csr_matvec or csc_matvec
  414. fn = getattr(_sparsetools, self.format + '_matvec')
  415. fn(M, N, self.indptr, self.indices, self.data, other, result)
  416. return result
  417. def _mul_multivector(self, other):
  418. M, N = self.shape
  419. n_vecs = other.shape[1] # number of column vectors
  420. result = np.zeros((M, n_vecs),
  421. dtype=upcast_char(self.dtype.char, other.dtype.char))
  422. # csr_matvecs or csc_matvecs
  423. fn = getattr(_sparsetools, self.format + '_matvecs')
  424. fn(M, N, n_vecs, self.indptr, self.indices, self.data,
  425. other.ravel(), result.ravel())
  426. return result
  427. def _mul_sparse_matrix(self, other):
  428. M, K1 = self.shape
  429. K2, N = other.shape
  430. major_axis = self._swap((M, N))[0]
  431. other = self.__class__(other) # convert to this format
  432. idx_dtype = get_index_dtype((self.indptr, self.indices,
  433. other.indptr, other.indices),
  434. maxval=M*N)
  435. indptr = np.empty(major_axis + 1, dtype=idx_dtype)
  436. fn = getattr(_sparsetools, self.format + '_matmat_pass1')
  437. fn(M, N,
  438. np.asarray(self.indptr, dtype=idx_dtype),
  439. np.asarray(self.indices, dtype=idx_dtype),
  440. np.asarray(other.indptr, dtype=idx_dtype),
  441. np.asarray(other.indices, dtype=idx_dtype),
  442. indptr)
  443. nnz = indptr[-1]
  444. idx_dtype = get_index_dtype((self.indptr, self.indices,
  445. other.indptr, other.indices),
  446. maxval=nnz)
  447. indptr = np.asarray(indptr, dtype=idx_dtype)
  448. indices = np.empty(nnz, dtype=idx_dtype)
  449. data = np.empty(nnz, dtype=upcast(self.dtype, other.dtype))
  450. fn = getattr(_sparsetools, self.format + '_matmat_pass2')
  451. fn(M, N, np.asarray(self.indptr, dtype=idx_dtype),
  452. np.asarray(self.indices, dtype=idx_dtype),
  453. self.data,
  454. np.asarray(other.indptr, dtype=idx_dtype),
  455. np.asarray(other.indices, dtype=idx_dtype),
  456. other.data,
  457. indptr, indices, data)
  458. return self.__class__((data, indices, indptr), shape=(M, N))
  459. def diagonal(self, k=0):
  460. rows, cols = self.shape
  461. if k <= -rows or k >= cols:
  462. raise ValueError("k exceeds matrix dimensions")
  463. fn = getattr(_sparsetools, self.format + "_diagonal")
  464. y = np.empty(min(rows + min(k, 0), cols - max(k, 0)),
  465. dtype=upcast(self.dtype))
  466. fn(k, self.shape[0], self.shape[1], self.indptr, self.indices,
  467. self.data, y)
  468. return y
  469. diagonal.__doc__ = spmatrix.diagonal.__doc__
  470. #####################
  471. # Other binary ops #
  472. #####################
  473. def _maximum_minimum(self, other, npop, op_name, dense_check):
  474. if isscalarlike(other):
  475. if dense_check(other):
  476. warn("Taking maximum (minimum) with > 0 (< 0) number results"
  477. " to a dense matrix.", SparseEfficiencyWarning,
  478. stacklevel=3)
  479. other_arr = np.empty(self.shape, dtype=np.asarray(other).dtype)
  480. other_arr.fill(other)
  481. other_arr = self.__class__(other_arr)
  482. return self._binopt(other_arr, op_name)
  483. else:
  484. self.sum_duplicates()
  485. new_data = npop(self.data, np.asarray(other))
  486. mat = self.__class__((new_data, self.indices, self.indptr),
  487. dtype=new_data.dtype, shape=self.shape)
  488. return mat
  489. elif isdense(other):
  490. return npop(self.todense(), other)
  491. elif isspmatrix(other):
  492. return self._binopt(other, op_name)
  493. else:
  494. raise ValueError("Operands not compatible.")
  495. def maximum(self, other):
  496. return self._maximum_minimum(other, np.maximum,
  497. '_maximum_', lambda x: np.asarray(x) > 0)
  498. maximum.__doc__ = spmatrix.maximum.__doc__
  499. def minimum(self, other):
  500. return self._maximum_minimum(other, np.minimum,
  501. '_minimum_', lambda x: np.asarray(x) < 0)
  502. minimum.__doc__ = spmatrix.minimum.__doc__
  503. #####################
  504. # Reduce operations #
  505. #####################
  506. def sum(self, axis=None, dtype=None, out=None):
  507. """Sum the matrix over the given axis. If the axis is None, sum
  508. over both rows and columns, returning a scalar.
  509. """
  510. # The spmatrix base class already does axis=0 and axis=1 efficiently
  511. # so we only do the case axis=None here
  512. if (not hasattr(self, 'blocksize') and
  513. axis in self._swap(((1, -1), (0, 2)))[0]):
  514. # faster than multiplication for large minor axis in CSC/CSR
  515. res_dtype = get_sum_dtype(self.dtype)
  516. ret = np.zeros(len(self.indptr) - 1, dtype=res_dtype)
  517. major_index, value = self._minor_reduce(np.add)
  518. ret[major_index] = value
  519. ret = np.asmatrix(ret)
  520. if axis % 2 == 1:
  521. ret = ret.T
  522. if out is not None and out.shape != ret.shape:
  523. raise ValueError('dimensions do not match')
  524. return ret.sum(axis=(), dtype=dtype, out=out)
  525. # spmatrix will handle the remaining situations when axis
  526. # is in {None, -1, 0, 1}
  527. else:
  528. return spmatrix.sum(self, axis=axis, dtype=dtype, out=out)
  529. sum.__doc__ = spmatrix.sum.__doc__
  530. def _minor_reduce(self, ufunc, data=None):
  531. """Reduce nonzeros with a ufunc over the minor axis when non-empty
  532. Can be applied to a function of self.data by supplying data parameter.
  533. Warning: this does not call sum_duplicates()
  534. Returns
  535. -------
  536. major_index : array of ints
  537. Major indices where nonzero
  538. value : array of self.dtype
  539. Reduce result for nonzeros in each major_index
  540. """
  541. if data is None:
  542. data = self.data
  543. major_index = np.flatnonzero(np.diff(self.indptr))
  544. value = ufunc.reduceat(data,
  545. downcast_intp_index(self.indptr[major_index]))
  546. return major_index, value
  547. #######################
  548. # Getting and Setting #
  549. #######################
    def __setitem__(self, index, x):
        """Assign `x` to the entries selected by `index`.

        `x` may be a sparse matrix (broadcast along a length-1 axis if
        needed) or anything ``np.asarray`` accepts (broadcast against the
        index arrays).

        Raises
        ------
        ValueError
            ("shape mismatch in assignment") when `x` cannot be matched to
            the indexed region.
        """
        # Process arrays from IndexMixin
        i, j = self._unpack_index(index)
        i, j = self._index_to_arrays(i, j)

        if isspmatrix(x):
            # a length-1 axis of x is broadcast when the target axis is
            # longer than 1
            broadcast_row = x.shape[0] == 1 and i.shape[0] != 1
            broadcast_col = x.shape[1] == 1 and i.shape[1] != 1
            if not ((broadcast_row or x.shape[0] == i.shape[0]) and
                    (broadcast_col or x.shape[1] == i.shape[1])):
                raise ValueError("shape mismatch in assignment")

            # clear entries that will be overwritten
            ci, cj = self._swap((i.ravel(), j.ravel()))
            self._zero_many(ci, cj)

            x = x.tocoo(copy=True)
            x.sum_duplicates()
            r, c = x.row, x.col
            # from here on `x` is the 1-D array of stored values
            x = np.asarray(x.data, dtype=self.dtype)
            if broadcast_row:
                r = np.repeat(np.arange(i.shape[0]), len(r))
                c = np.tile(c, i.shape[0])
                x = np.tile(x, i.shape[0])
            if broadcast_col:
                r = np.repeat(r, i.shape[1])
                c = np.tile(np.arange(i.shape[1]), len(c))
                x = np.repeat(x, i.shape[1])
            # only assign entries in the new sparsity structure
            i = i[r, c]
            j = j[r, c]
        else:
            # Make x and i into the same shape
            x = np.asarray(x, dtype=self.dtype)
            x, _ = np.broadcast_arrays(x, i)

            if x.shape != i.shape:
                raise ValueError("shape mismatch in assignment")

        # nothing to assign
        if np.size(x) == 0:
            return
        # _set_many expects (major, minor) index order
        i, j = self._swap((i.ravel(), j.ravel()))
        self._set_many(i, j, x.ravel())
  588. def _setdiag(self, values, k):
  589. if 0 in self.shape:
  590. return
  591. M, N = self.shape
  592. broadcast = (values.ndim == 0)
  593. if k < 0:
  594. if broadcast:
  595. max_index = min(M + k, N)
  596. else:
  597. max_index = min(M + k, N, len(values))
  598. i = np.arange(max_index, dtype=self.indices.dtype)
  599. j = np.arange(max_index, dtype=self.indices.dtype)
  600. i -= k
  601. else:
  602. if broadcast:
  603. max_index = min(M, N - k)
  604. else:
  605. max_index = min(M, N - k, len(values))
  606. i = np.arange(max_index, dtype=self.indices.dtype)
  607. j = np.arange(max_index, dtype=self.indices.dtype)
  608. j += k
  609. if not broadcast:
  610. values = values[:len(i)]
  611. self[i, j] = values
  612. def _prepare_indices(self, i, j):
  613. M, N = self._swap(self.shape)
  614. def check_bounds(indices, bound):
  615. idx = indices.max()
  616. if idx >= bound:
  617. raise IndexError('index (%d) out of range (>= %d)' %
  618. (idx, bound))
  619. idx = indices.min()
  620. if idx < -bound:
  621. raise IndexError('index (%d) out of range (< -%d)' %
  622. (idx, bound))
  623. check_bounds(i, M)
  624. check_bounds(j, N)
  625. i = np.asarray(i, dtype=self.indices.dtype)
  626. j = np.asarray(j, dtype=self.indices.dtype)
  627. return i, j, M, N
  628. def _set_many(self, i, j, x):
  629. """Sets value at each (i, j) to x
  630. Here (i,j) index major and minor respectively, and must not contain
  631. duplicate entries.
  632. """
  633. i, j, M, N = self._prepare_indices(i, j)
  634. n_samples = len(x)
  635. offsets = np.empty(n_samples, dtype=self.indices.dtype)
  636. ret = _sparsetools.csr_sample_offsets(M, N, self.indptr, self.indices,
  637. n_samples, i, j, offsets)
  638. if ret == 1:
  639. # rinse and repeat
  640. self.sum_duplicates()
  641. _sparsetools.csr_sample_offsets(M, N, self.indptr,
  642. self.indices, n_samples, i, j,
  643. offsets)
  644. if -1 not in offsets:
  645. # only affects existing non-zero cells
  646. self.data[offsets] = x
  647. return
  648. else:
  649. warn("Changing the sparsity structure of a {}_matrix is expensive."
  650. " lil_matrix is more efficient.".format(self.format),
  651. SparseEfficiencyWarning, stacklevel=3)
  652. # replace where possible
  653. mask = offsets > -1
  654. self.data[offsets[mask]] = x[mask]
  655. # only insertions remain
  656. mask = ~mask
  657. i = i[mask]
  658. i[i < 0] += M
  659. j = j[mask]
  660. j[j < 0] += N
  661. self._insert_many(i, j, x[mask])
  662. def _zero_many(self, i, j):
  663. """Sets value at each (i, j) to zero, preserving sparsity structure.
  664. Here (i,j) index major and minor respectively.
  665. """
  666. i, j, M, N = self._prepare_indices(i, j)
  667. n_samples = len(i)
  668. offsets = np.empty(n_samples, dtype=self.indices.dtype)
  669. ret = _sparsetools.csr_sample_offsets(M, N, self.indptr, self.indices,
  670. n_samples, i, j, offsets)
  671. if ret == 1:
  672. # rinse and repeat
  673. self.sum_duplicates()
  674. _sparsetools.csr_sample_offsets(M, N, self.indptr,
  675. self.indices, n_samples, i, j,
  676. offsets)
  677. # only assign zeros to the existing sparsity structure
  678. self.data[offsets[offsets > -1]] = 0
    def _insert_many(self, i, j, x):
        """Inserts new nonzero at each (i, j) with value x

        Here (i,j) index major and minor respectively.
        i, j and x must be non-empty, 1d arrays.
        Inserts each major group (e.g. all entries per row) at a time.
        Maintains has_sorted_indices property.

        NOTE(review): this docstring used to say "Modifies i, j, x in
        place", but the visible code rebinds the names to the results of
        ``take``/``np.asarray`` and never writes into them -- confirm
        before relying on either reading.
        """
        # group the new entries by major index
        order = np.argsort(i, kind='mergesort')  # stable for duplicates
        i = i.take(order, mode='clip')
        j = j.take(order, mode='clip')
        x = x.take(order, mode='clip')

        # remember whether sorting must be restored at the end
        do_sort = self.has_sorted_indices

        # Update index data type
        idx_dtype = get_index_dtype((self.indices, self.indptr),
                                    maxval=(self.indptr[-1] + x.size))
        self.indptr = np.asarray(self.indptr, dtype=idx_dtype)
        self.indices = np.asarray(self.indices, dtype=idx_dtype)
        i = np.asarray(i, dtype=idx_dtype)
        j = np.asarray(j, dtype=idx_dtype)

        # Collate old and new in chunks by major index
        indices_parts = []
        data_parts = []
        # ui: distinct major indices; ui_indptr: where each group starts
        # in the sorted i/j/x (with the total length appended as end mark)
        ui, ui_indptr = np.unique(i, return_index=True)
        ui_indptr = np.append(ui_indptr, len(j))
        new_nnzs = np.diff(ui_indptr)
        prev = 0
        for c, (ii, js, je) in enumerate(izip(ui, ui_indptr, ui_indptr[1:])):
            # old entries
            start = self.indptr[prev]
            stop = self.indptr[ii]
            indices_parts.append(self.indices[start:stop])
            data_parts.append(self.data[start:stop])

            # handle duplicate j: keep last setting
            # (the group is reversed so that the first occurrence found by
            # np.unique is the last one supplied)
            uj, uj_indptr = np.unique(j[js:je][::-1], return_index=True)
            if len(uj) == je - js:
                indices_parts.append(j[js:je])
                data_parts.append(x[js:je])
            else:
                indices_parts.append(j[js:je][::-1][uj_indptr])
                data_parts.append(x[js:je][::-1][uj_indptr])
                new_nnzs[c] = len(uj)

            prev = ii

        # remaining old entries
        # (`ii` is the last loop value, hence the non-empty requirement)
        start = self.indptr[ii]
        indices_parts.append(self.indices[start:])
        data_parts.append(self.data[start:])

        # update attributes
        self.indices = np.concatenate(indices_parts)
        self.data = np.concatenate(data_parts)
        # rebuild indptr as the running total of per-slot counts
        nnzs = np.empty(self.indptr.shape, dtype=idx_dtype)
        nnzs[0] = idx_dtype(0)
        indptr_diff = np.diff(self.indptr)
        indptr_diff[ui] += new_nnzs
        nnzs[1:] = indptr_diff
        self.indptr = np.cumsum(nnzs, out=nnzs)

        if do_sort:
            # TODO: only sort where necessary
            self.has_sorted_indices = False
            self.sort_indices()

        self.check_format(full_check=False)
  def _get_single_element(self, row, col):
      """Return the value stored at (row, col).

      Negative indices are wrapped by adding the matching dimension.
      All stored entries at the position are summed, so duplicates add
      up and a position without entries yields a zero of ``self.dtype``.

      Raises IndexError when the wrapped position lies outside the shape.
      """
      M, N = self.shape
      if row < 0:
          row += M
      if col < 0:
          col += N
      if not (0 <= row < M and 0 <= col < N):
          raise IndexError("index out of bounds: 0<=%d<%d, 0<=%d<%d" %
                           (row, M, col, N))

      major, minor = self._swap((row, col))
      lo = self.indptr[major]
      hi = self.indptr[major + 1]

      if not self.has_sorted_indices:
          # unsorted: compare against every index of the major slice
          matches = minor == self.indices[lo:hi]
          return np.compress(matches, self.data[lo:hi]).sum(dtype=self.dtype)

      # Sorted: bisect for the run of entries equal to ``minor``.
      # Copies may be made, if dtypes of indices are not identical
      minor = self.indices.dtype.type(minor)
      candidates = self.indices[lo:hi]
      first = np.searchsorted(candidates, minor, side='left')
      past = first + np.searchsorted(candidates[first:], minor,
                                     side='right')
      return self.data[lo + first:lo + past].sum(dtype=self.dtype)
  def _get_submatrix(self, slice0, slice1):
      """Return a submatrix of this matrix (new matrix is created).

      Each argument may be a slice, a scalar, or a (start, stop) pair.
      Only ``start`` and ``stop`` of a slice are read; its step is not
      consulted.  Raises IndexError for an empty or out-of-range span.
      """
      slice0, slice1 = self._swap((slice0, slice1))
      shape0, shape1 = self._swap(self.shape)

      def _process_slice(sl, num):
          # normalise to a (start, stop) pair along an axis of length num
          if isinstance(sl, slice):
              lo, hi = sl.start, sl.stop
              if lo is None:
                  lo = 0
              elif lo < 0:
                  lo = num + lo
              if hi is None:
                  hi = num
              elif hi < 0:
                  hi = num + hi
              return lo, hi
          if np.isscalar(sl):
              if sl < 0:
                  sl += num
              return sl, sl + 1
          return sl[0], sl[1]

      def _in_bounds(lo, hi, num):
          if not (0 <= lo < num and 0 < hi <= num and lo < hi):
              raise IndexError("index out of bounds:"
                               " 0<={i0}<{num}, 0<={i1}<{num}, {i0}<{i1}"
                               "".format(i0=lo, num=num, i1=hi))

      i0, i1 = _process_slice(slice0, shape0)
      j0, j1 = _process_slice(slice1, shape1)
      _in_bounds(i0, i1, shape0)
      _in_bounds(j0, j1, shape1)

      parts = _sparsetools.get_csr_submatrix(shape0, shape1,
                                             self.indptr, self.indices,
                                             self.data,
                                             i0, i1, j0, j1)
      indptr, indices, data = parts[0], parts[1], parts[2]

      shape = self._swap((i1 - i0, j1 - j0))
      return self.__class__((data, indices, indptr), shape=shape)
  803. ######################
  804. # Conversion methods #
  805. ######################
  def tocoo(self, copy=True):
      # The docstring is taken from spmatrix.tocoo (assigned below).
      major_dim = self._swap(self.shape)[0]
      minor_indices = self.indices

      # expand indptr into one explicit major index per stored entry
      major_indices = np.empty(len(minor_indices), dtype=self.indices.dtype)
      _sparsetools.expandptr(major_dim, self.indptr, major_indices)
      row, col = self._swap((major_indices, minor_indices))

      from .coo import coo_matrix
      return coo_matrix((self.data, (row, col)), self.shape, copy=copy,
                        dtype=self.dtype)

  tocoo.__doc__ = spmatrix.tocoo.__doc__
  def toarray(self, order=None, out=None):
      # The docstring is taken from spmatrix.toarray (assigned below).
      if out is None and order is None:
          # nothing requested: take the first of the swapped 'cf' pair
          order = self._swap('cf')[0]
      out = self._process_toarray_args(order, out)

      flags = out.flags
      if not (flags.c_contiguous or flags.f_contiguous):
          raise ValueError('Output array must be C or F contiguous')

      # align ideal order with output array order
      if flags.c_contiguous:
          source, target = self.tocsr(), out
      else:
          source, target = self.tocsc(), out.T

      M, N = source._swap(source.shape)
      _sparsetools.csr_todense(M, N, source.indptr, source.indices,
                               source.data, target)
      return out

  toarray.__doc__ = spmatrix.toarray.__doc__
  833. ##############################################################
  834. # methods that examine or modify the internal data structure #
  835. ##############################################################
  def eliminate_zeros(self):
      """Remove zero entries from the matrix.

      This is an *in place* operation: the index and data arrays are
      compacted first and then trimmed with ``prune``.
      """
      major_dim, minor_dim = self._swap(self.shape)
      _sparsetools.csr_eliminate_zeros(major_dim, minor_dim, self.indptr,
                                       self.indices, self.data)
      # nnz may have changed
      self.prune()
  def __get_has_canonical_format(self):
      """Determine whether the matrix has sorted indices and no duplicates

      Returns
          - True: if the above applies
          - False: otherwise

      has_canonical_format implies has_sorted_indices, so if the latter flag
      is False, so will the former be; if the former is found True, the
      latter flag is also set.
      """
      sorted_flag = getattr(self, '_has_sorted_indices', True)
      if not sorted_flag:
          # known to be unsorted, hence not canonical
          self._has_canonical_format = False
      elif not hasattr(self, '_has_canonical_format'):
          # nothing cached yet: compute it and store it through the setter
          n_major = len(self.indptr) - 1
          self.has_canonical_format = _sparsetools.csr_has_canonical_format(
              n_major, self.indptr, self.indices)
      return self._has_canonical_format

  def __set_has_canonical_format(self, val):
      # cache the flag; a canonical matrix also has sorted indices
      self._has_canonical_format = bool(val)
      if val:
          self.has_sorted_indices = True

  has_canonical_format = property(__get_has_canonical_format,
                                  __set_has_canonical_format)
  def sum_duplicates(self):
      """Eliminate duplicate matrix entries by adding them together

      This is an *in place* operation.  Nothing is done when the matrix
      already reports canonical format.
      """
      if not self.has_canonical_format:
          self.sort_indices()

          major_dim, minor_dim = self._swap(self.shape)
          _sparsetools.csr_sum_duplicates(major_dim, minor_dim, self.indptr,
                                          self.indices, self.data)

          # nnz may have changed
          self.prune()
          self.has_canonical_format = True
  def __get_sorted(self):
      """Determine whether the matrix has sorted indices

      Returns
          - True: if the indices of the matrix are in sorted order
          - False: otherwise
      """
      # reuse the cached answer when there is one
      if hasattr(self, '_has_sorted_indices'):
          return self._has_sorted_indices

      n_major = len(self.indptr) - 1
      self._has_sorted_indices = _sparsetools.csr_has_sorted_indices(
          n_major, self.indptr, self.indices)
      return self._has_sorted_indices

  def __set_sorted(self, val):
      # store the flag as a plain bool
      self._has_sorted_indices = bool(val)

  has_sorted_indices = property(__get_sorted, __set_sorted)
  def sorted_indices(self):
      """Return a copy of this matrix with sorted indices
      """
      # An alternative with linear complexity would be
      #     return self.toother().toother()
      # although copying and sorting in place is typically faster.
      duplicate = self.copy()
      duplicate.sort_indices()
      return duplicate
  def sort_indices(self):
      """Sort the indices of this matrix *in place*

      Does nothing when the indices are already flagged as sorted.
      """
      if self.has_sorted_indices:
          return

      n_major = len(self.indptr) - 1
      _sparsetools.csr_sort_indices(n_major, self.indptr,
                                    self.indices, self.data)
      self.has_sorted_indices = True
  def prune(self):
      """Remove empty space after all non-zero elements.

      Raises ValueError when ``indptr`` does not have one more element
      than the major dimension, or when ``indices`` or ``data`` hold
      fewer than ``nnz`` elements.
      """
      major_dim = self._swap(self.shape)[0]
      if len(self.indptr) != major_dim + 1:
          raise ValueError('index pointer has invalid length')

      nnz = self.nnz
      if len(self.indices) < nnz:
          raise ValueError('indices array has fewer than nnz elements')
      if len(self.data) < nnz:
          raise ValueError('data array has fewer than nnz elements')

      # keep only the first nnz entries of each array
      self.indices = _prune_array(self.indices[:nnz])
      self.data = _prune_array(self.data[:nnz])
  def resize(self, *shape):
      # The docstring is taken from spmatrix.resize (assigned below).
      shape = check_shape(shape)

      if hasattr(self, 'blocksize'):
          # with a blocksize, dimensions are counted in whole blocks
          bm, bn = self.blocksize
          new_M, rm = divmod(shape[0], bm)
          new_N, rn = divmod(shape[1], bn)
          if rm or rn:
              raise ValueError("shape must be divisible into %s blocks. "
                               "Got %s" % (self.blocksize, shape))
          M, N = self.shape[0] // bm, self.shape[1] // bn
      else:
          new_M, new_N = self._swap(shape)
          M, N = self._swap(self.shape)

      if new_M < M:
          # shrink the major axis: cut everything past the new last group
          cut = self.indptr[new_M]
          self.indices = self.indices[:cut]
          self.data = self.data[:cut]
          self.indptr = self.indptr[:new_M + 1]
      elif new_M > M:
          # grow the major axis: new groups are empty, so they repeat
          # the old end pointer
          self.indptr = np.resize(self.indptr, new_M + 1)
          self.indptr[M + 1:].fill(self.indptr[M])

      if new_N < N:
          # shrink the minor axis: drop entries whose index is too large
          keep = self.indices < new_N
          if not np.all(keep):
              self.indices = self.indices[keep]
              self.data = self.data[keep]
              # recount the surviving entries per major group
              major_index, val = self._minor_reduce(np.add, keep)
              self.indptr.fill(0)
              self.indptr[1:][major_index] = val
              np.cumsum(self.indptr, out=self.indptr)

      self._shape = shape

  resize.__doc__ = spmatrix.resize.__doc__
  952. ###################
  953. # utility methods #
  954. ###################
  955. # needed by _data_matrix
  def _with_data(self, data, copy=True):
      """Return a matrix with the sparsity structure of self and new data.

      With ``copy=True`` (the default) the structure arrays ``indices``
      and ``indptr`` are copied; otherwise the new matrix is built from
      the very same arrays.  The result takes its dtype from ``data``.
      """
      if copy:
          indices, indptr = self.indices.copy(), self.indptr.copy()
      else:
          indices, indptr = self.indices, self.indptr
      return self.__class__((data, indices, indptr),
                            shape=self.shape, dtype=data.dtype)
  def _binopt(self, other, op):
      """Apply the binary operation ``op`` to self and ``other``.

      ``other`` is first converted to the class of self.  ``op`` is the
      middle part of a ``_sparsetools`` routine name, e.g. ``'_plus_'``
      selects ``csr_plus_csr`` for a csr matrix.  The result is a new,
      pruned matrix of the same class and shape as self; comparison
      operations produce boolean data.
      """
      other = self.__class__(other)

      # e.g. csr_plus_csr, csr_minus_csr, etc.
      kernel = getattr(_sparsetools, self.format + op + self.format)

      # the result can never hold more entries than both operands together
      maxnnz = self.nnz + other.nnz
      idx_dtype = get_index_dtype((self.indptr, self.indices,
                                   other.indptr, other.indices),
                                  maxval=maxnnz)

      if op in ('_ne_', '_lt_', '_gt_', '_le_', '_ge_'):
          result_dtype = np.bool_
      else:
          result_dtype = upcast(self.dtype, other.dtype)

      indptr = np.empty(self.indptr.shape, dtype=idx_dtype)
      indices = np.empty(maxnnz, dtype=idx_dtype)
      data = np.empty(maxnnz, dtype=result_dtype)

      kernel(self.shape[0], self.shape[1],
             np.asarray(self.indptr, dtype=idx_dtype),
             np.asarray(self.indices, dtype=idx_dtype),
             self.data,
             np.asarray(other.indptr, dtype=idx_dtype),
             np.asarray(other.indices, dtype=idx_dtype),
             other.data,
             indptr, indices, data)

      result = self.__class__((data, indices, indptr), shape=self.shape)
      result.prune()
      return result
  def _divide_sparse(self, other):
      """
      Divide this matrix by a second sparse matrix.

      Parameters
      ----------
      other : sparse matrix
          Divisor; must have the same shape as self.

      Returns
      -------
      A dense ``np.matrix`` when the element-wise quotient has an inexact
      (floating point or complex) dtype, otherwise the sparse quotient
      itself.

      Raises
      ------
      ValueError
          If the shapes of self and other differ.
      """
      if other.shape != self.shape:
          raise ValueError('inconsistent shapes')

      r = self._binopt(other, '_eldiv_')

      if np.issubdtype(r.dtype, np.inexact):
          # Eldiv leaves entries outside the combined sparsity
          # pattern empty, so they must be filled manually.
          # Everything outside of other's sparsity is NaN, and everything
          # inside it is either zero or defined by eldiv.
          #
          # The buffer takes the dtype of the quotient, not of self: this
          # branch is selected on r.dtype, and only that dtype is known to
          # hold NaN and the values computed by eldiv (self may be an
          # integer matrix, or real while the quotient is complex).
          out = np.empty(self.shape, dtype=r.dtype)
          out.fill(np.nan)
          row, col = other.nonzero()
          out[row, col] = 0
          r = r.tocoo()
          out[r.row, r.col] = r.data
          out = np.matrix(out)
      else:
          # integers types go with nan <-> 0
          out = r

      return out