csr.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. """Compressed Sparse Row matrix format"""
  2. from __future__ import division, print_function, absolute_import
  3. __docformat__ = "restructuredtext en"
  4. __all__ = ['csr_matrix', 'isspmatrix_csr']
  5. import numpy as np
  6. from scipy._lib.six import xrange
  7. from .base import spmatrix
  8. from ._sparsetools import csr_tocsc, csr_tobsr, csr_count_blocks, \
  9. get_csr_submatrix, csr_sample_values
  10. from .sputils import (upcast, isintlike, IndexMixin, issequence,
  11. get_index_dtype, ismatrix)
  12. from .compressed import _cs_matrix
  13. class csr_matrix(_cs_matrix, IndexMixin):
  14. """
  15. Compressed Sparse Row matrix
  16. This can be instantiated in several ways:
  17. csr_matrix(D)
  18. with a dense matrix or rank-2 ndarray D
  19. csr_matrix(S)
  20. with another sparse matrix S (equivalent to S.tocsr())
  21. csr_matrix((M, N), [dtype])
  22. to construct an empty matrix with shape (M, N)
  23. dtype is optional, defaulting to dtype='d'.
  24. csr_matrix((data, (row_ind, col_ind)), [shape=(M, N)])
  25. where ``data``, ``row_ind`` and ``col_ind`` satisfy the
  26. relationship ``a[row_ind[k], col_ind[k]] = data[k]``.
  27. csr_matrix((data, indices, indptr), [shape=(M, N)])
  28. is the standard CSR representation where the column indices for
  29. row i are stored in ``indices[indptr[i]:indptr[i+1]]`` and their
  30. corresponding values are stored in ``data[indptr[i]:indptr[i+1]]``.
  31. If the shape parameter is not supplied, the matrix dimensions
  32. are inferred from the index arrays.
  33. Attributes
  34. ----------
  35. dtype : dtype
  36. Data type of the matrix
  37. shape : 2-tuple
  38. Shape of the matrix
  39. ndim : int
  40. Number of dimensions (this is always 2)
  41. nnz
  42. Number of nonzero elements
  43. data
  44. CSR format data array of the matrix
  45. indices
  46. CSR format index array of the matrix
  47. indptr
  48. CSR format index pointer array of the matrix
  49. has_sorted_indices
  50. Whether indices are sorted
  51. Notes
  52. -----
  53. Sparse matrices can be used in arithmetic operations: they support
  54. addition, subtraction, multiplication, division, and matrix power.
  55. Advantages of the CSR format
  56. - efficient arithmetic operations CSR + CSR, CSR * CSR, etc.
  57. - efficient row slicing
  58. - fast matrix vector products
  59. Disadvantages of the CSR format
  60. - slow column slicing operations (consider CSC)
  61. - changes to the sparsity structure are expensive (consider LIL or DOK)
  62. Examples
  63. --------
  64. >>> import numpy as np
  65. >>> from scipy.sparse import csr_matrix
  66. >>> csr_matrix((3, 4), dtype=np.int8).toarray()
  67. array([[0, 0, 0, 0],
  68. [0, 0, 0, 0],
  69. [0, 0, 0, 0]], dtype=int8)
  70. >>> row = np.array([0, 0, 1, 2, 2, 2])
  71. >>> col = np.array([0, 2, 2, 0, 1, 2])
  72. >>> data = np.array([1, 2, 3, 4, 5, 6])
  73. >>> csr_matrix((data, (row, col)), shape=(3, 3)).toarray()
  74. array([[1, 0, 2],
  75. [0, 0, 3],
  76. [4, 5, 6]])
  77. >>> indptr = np.array([0, 2, 3, 6])
  78. >>> indices = np.array([0, 2, 2, 0, 1, 2])
  79. >>> data = np.array([1, 2, 3, 4, 5, 6])
  80. >>> csr_matrix((data, indices, indptr), shape=(3, 3)).toarray()
  81. array([[1, 0, 2],
  82. [0, 0, 3],
  83. [4, 5, 6]])
  84. As an example of how to construct a CSR matrix incrementally,
  85. the following snippet builds a term-document matrix from texts:
  86. >>> docs = [["hello", "world", "hello"], ["goodbye", "cruel", "world"]]
  87. >>> indptr = [0]
  88. >>> indices = []
  89. >>> data = []
  90. >>> vocabulary = {}
  91. >>> for d in docs:
  92. ... for term in d:
  93. ... index = vocabulary.setdefault(term, len(vocabulary))
  94. ... indices.append(index)
  95. ... data.append(1)
  96. ... indptr.append(len(indices))
  97. ...
  98. >>> csr_matrix((data, indices, indptr), dtype=int).toarray()
  99. array([[2, 1, 0, 0],
  100. [0, 1, 1, 1]])
  101. """
  102. format = 'csr'
  103. def transpose(self, axes=None, copy=False):
  104. if axes is not None:
  105. raise ValueError(("Sparse matrices do not support "
  106. "an 'axes' parameter because swapping "
  107. "dimensions is the only logical permutation."))
  108. M, N = self.shape
  109. from .csc import csc_matrix
  110. return csc_matrix((self.data, self.indices,
  111. self.indptr), shape=(N, M), copy=copy)
  112. transpose.__doc__ = spmatrix.transpose.__doc__
  113. def tolil(self, copy=False):
  114. from .lil import lil_matrix
  115. lil = lil_matrix(self.shape,dtype=self.dtype)
  116. self.sum_duplicates()
  117. ptr,ind,dat = self.indptr,self.indices,self.data
  118. rows, data = lil.rows, lil.data
  119. for n in xrange(self.shape[0]):
  120. start = ptr[n]
  121. end = ptr[n+1]
  122. rows[n] = ind[start:end].tolist()
  123. data[n] = dat[start:end].tolist()
  124. return lil
  125. tolil.__doc__ = spmatrix.tolil.__doc__
  126. def tocsr(self, copy=False):
  127. if copy:
  128. return self.copy()
  129. else:
  130. return self
  131. tocsr.__doc__ = spmatrix.tocsr.__doc__
  132. def tocsc(self, copy=False):
  133. idx_dtype = get_index_dtype((self.indptr, self.indices),
  134. maxval=max(self.nnz, self.shape[0]))
  135. indptr = np.empty(self.shape[1] + 1, dtype=idx_dtype)
  136. indices = np.empty(self.nnz, dtype=idx_dtype)
  137. data = np.empty(self.nnz, dtype=upcast(self.dtype))
  138. csr_tocsc(self.shape[0], self.shape[1],
  139. self.indptr.astype(idx_dtype),
  140. self.indices.astype(idx_dtype),
  141. self.data,
  142. indptr,
  143. indices,
  144. data)
  145. from .csc import csc_matrix
  146. A = csc_matrix((data, indices, indptr), shape=self.shape)
  147. A.has_sorted_indices = True
  148. return A
  149. tocsc.__doc__ = spmatrix.tocsc.__doc__
  150. def tobsr(self, blocksize=None, copy=True):
  151. from .bsr import bsr_matrix
  152. if blocksize is None:
  153. from .spfuncs import estimate_blocksize
  154. return self.tobsr(blocksize=estimate_blocksize(self))
  155. elif blocksize == (1,1):
  156. arg1 = (self.data.reshape(-1,1,1),self.indices,self.indptr)
  157. return bsr_matrix(arg1, shape=self.shape, copy=copy)
  158. else:
  159. R,C = blocksize
  160. M,N = self.shape
  161. if R < 1 or C < 1 or M % R != 0 or N % C != 0:
  162. raise ValueError('invalid blocksize %s' % blocksize)
  163. blks = csr_count_blocks(M,N,R,C,self.indptr,self.indices)
  164. idx_dtype = get_index_dtype((self.indptr, self.indices),
  165. maxval=max(N//C, blks))
  166. indptr = np.empty(M//R+1, dtype=idx_dtype)
  167. indices = np.empty(blks, dtype=idx_dtype)
  168. data = np.zeros((blks,R,C), dtype=self.dtype)
  169. csr_tobsr(M, N, R, C,
  170. self.indptr.astype(idx_dtype),
  171. self.indices.astype(idx_dtype),
  172. self.data,
  173. indptr, indices, data.ravel())
  174. return bsr_matrix((data,indices,indptr), shape=self.shape)
  175. tobsr.__doc__ = spmatrix.tobsr.__doc__
  176. # these functions are used by the parent class (_cs_matrix)
  177. # to remove redudancy between csc_matrix and csr_matrix
  178. def _swap(self, x):
  179. """swap the members of x if this is a column-oriented matrix
  180. """
  181. return x
  182. def __getitem__(self, key):
  183. def asindices(x):
  184. try:
  185. x = np.asarray(x)
  186. # Check index contents to avoid creating 64bit arrays needlessly
  187. idx_dtype = get_index_dtype((x,), check_contents=True)
  188. if idx_dtype != x.dtype:
  189. x = x.astype(idx_dtype)
  190. except Exception:
  191. raise IndexError('invalid index')
  192. else:
  193. return x
  194. def check_bounds(indices, N):
  195. if indices.size == 0:
  196. return (0, 0)
  197. max_indx = indices.max()
  198. if max_indx >= N:
  199. raise IndexError('index (%d) out of range' % max_indx)
  200. min_indx = indices.min()
  201. if min_indx < -N:
  202. raise IndexError('index (%d) out of range' % (N + min_indx))
  203. return min_indx, max_indx
  204. def extractor(indices,N):
  205. """Return a sparse matrix P so that P*self implements
  206. slicing of the form self[[1,2,3],:]
  207. """
  208. indices = asindices(indices).copy()
  209. min_indx, max_indx = check_bounds(indices, N)
  210. if min_indx < 0:
  211. indices[indices < 0] += N
  212. indptr = np.arange(len(indices)+1, dtype=indices.dtype)
  213. data = np.ones(len(indices), dtype=self.dtype)
  214. shape = (len(indices),N)
  215. return csr_matrix((data,indices,indptr), shape=shape,
  216. dtype=self.dtype, copy=False)
  217. row, col = self._unpack_index(key)
  218. # First attempt to use original row optimized methods
  219. # [1, ?]
  220. if isintlike(row):
  221. # [i, j]
  222. if isintlike(col):
  223. return self._get_single_element(row, col)
  224. # [i, 1:2]
  225. elif isinstance(col, slice):
  226. return self._get_row_slice(row, col)
  227. # [i, [1, 2]]
  228. elif issequence(col):
  229. P = extractor(col,self.shape[1]).T
  230. return self[row, :] * P
  231. elif isinstance(row, slice):
  232. # [1:2,??]
  233. if ((isintlike(col) and row.step in (1, None)) or
  234. (isinstance(col, slice) and
  235. col.step in (1, None) and
  236. row.step in (1, None))):
  237. # col is int or slice with step 1, row is slice with step 1.
  238. return self._get_submatrix(row, col)
  239. elif issequence(col):
  240. # row is slice, col is sequence.
  241. P = extractor(col,self.shape[1]).T # [1:2,[1,2]]
  242. sliced = self
  243. if row != slice(None, None, None):
  244. sliced = sliced[row,:]
  245. return sliced * P
  246. elif issequence(row):
  247. # [[1,2],??]
  248. if isintlike(col) or isinstance(col,slice):
  249. P = extractor(row, self.shape[0]) # [[1,2],j] or [[1,2],1:2]
  250. extracted = P * self
  251. if col == slice(None, None, None):
  252. return extracted
  253. else:
  254. return extracted[:,col]
  255. elif ismatrix(row) and issequence(col):
  256. if len(row[0]) == 1 and isintlike(row[0][0]):
  257. # [[[1],[2]], [1,2]], outer indexing
  258. row = asindices(row)
  259. P_row = extractor(row[:,0], self.shape[0])
  260. P_col = extractor(col, self.shape[1]).T
  261. return P_row * self * P_col
  262. if not (issequence(col) and issequence(row)):
  263. # Sample elementwise
  264. row, col = self._index_to_arrays(row, col)
  265. row = asindices(row)
  266. col = asindices(col)
  267. if row.shape != col.shape:
  268. raise IndexError('number of row and column indices differ')
  269. assert row.ndim <= 2
  270. num_samples = np.size(row)
  271. if num_samples == 0:
  272. return csr_matrix(np.atleast_2d(row).shape, dtype=self.dtype)
  273. check_bounds(row, self.shape[0])
  274. check_bounds(col, self.shape[1])
  275. val = np.empty(num_samples, dtype=self.dtype)
  276. csr_sample_values(self.shape[0], self.shape[1],
  277. self.indptr, self.indices, self.data,
  278. num_samples, row.ravel(), col.ravel(), val)
  279. if row.ndim == 1:
  280. # row and col are 1d
  281. return np.asmatrix(val)
  282. return self.__class__(val.reshape(row.shape))
  283. def __iter__(self):
  284. indptr = np.zeros(2, dtype=self.indptr.dtype)
  285. shape = (1, self.shape[1])
  286. i0 = 0
  287. for i1 in self.indptr[1:]:
  288. indptr[1] = i1 - i0
  289. indices = self.indices[i0:i1]
  290. data = self.data[i0:i1]
  291. yield csr_matrix((data, indices, indptr), shape=shape, copy=True)
  292. i0 = i1
  293. def getrow(self, i):
  294. """Returns a copy of row i of the matrix, as a (1 x n)
  295. CSR matrix (row vector).
  296. """
  297. M, N = self.shape
  298. i = int(i)
  299. if i < 0:
  300. i += M
  301. if i < 0 or i >= M:
  302. raise IndexError('index (%d) out of range' % i)
  303. idx = slice(*self.indptr[i:i+2])
  304. data = self.data[idx].copy()
  305. indices = self.indices[idx].copy()
  306. indptr = np.array([0, len(indices)], dtype=self.indptr.dtype)
  307. return csr_matrix((data, indices, indptr), shape=(1, N),
  308. dtype=self.dtype, copy=False)
  309. def getcol(self, i):
  310. """Returns a copy of column i of the matrix, as a (m x 1)
  311. CSR matrix (column vector).
  312. """
  313. return self._get_submatrix(slice(None), i)
  314. def _get_row_slice(self, i, cslice):
  315. """Returns a copy of row self[i, cslice]
  316. """
  317. M, N = self.shape
  318. if i < 0:
  319. i += M
  320. if i < 0 or i >= M:
  321. raise IndexError('index (%d) out of range' % i)
  322. start, stop, stride = cslice.indices(N)
  323. if stride == 1:
  324. # for stride == 1, get_csr_submatrix is faster
  325. row_indptr, row_indices, row_data = get_csr_submatrix(
  326. M, N, self.indptr, self.indices, self.data, i, i + 1,
  327. start, stop)
  328. else:
  329. # other strides need new code
  330. row_indices = self.indices[self.indptr[i]:self.indptr[i + 1]]
  331. row_data = self.data[self.indptr[i]:self.indptr[i + 1]]
  332. if stride > 0:
  333. ind = (row_indices >= start) & (row_indices < stop)
  334. else:
  335. ind = (row_indices <= start) & (row_indices > stop)
  336. if abs(stride) > 1:
  337. ind &= (row_indices - start) % stride == 0
  338. row_indices = (row_indices[ind] - start) // stride
  339. row_data = row_data[ind]
  340. row_indptr = np.array([0, len(row_indices)])
  341. if stride < 0:
  342. row_data = row_data[::-1]
  343. row_indices = abs(row_indices[::-1])
  344. shape = (1, int(np.ceil(float(stop - start) / stride)))
  345. return csr_matrix((row_data, row_indices, row_indptr), shape=shape,
  346. dtype=self.dtype, copy=False)
  347. def _get_submatrix(self, row_slice, col_slice):
  348. """Return a submatrix of this matrix (new matrix is created)."""
  349. def process_slice(sl, num):
  350. if isinstance(sl, slice):
  351. i0, i1, stride = sl.indices(num)
  352. if stride != 1:
  353. raise ValueError('slicing with step != 1 not supported')
  354. elif isintlike(sl):
  355. if sl < 0:
  356. sl += num
  357. i0, i1 = sl, sl + 1
  358. else:
  359. raise TypeError('expected slice or scalar')
  360. if not (0 <= i0 <= num) or not (0 <= i1 <= num) or not (i0 <= i1):
  361. raise IndexError(
  362. "index out of bounds: 0 <= %d <= %d, 0 <= %d <= %d,"
  363. " %d <= %d" % (i0, num, i1, num, i0, i1))
  364. return i0, i1
  365. M,N = self.shape
  366. i0, i1 = process_slice(row_slice, M)
  367. j0, j1 = process_slice(col_slice, N)
  368. indptr, indices, data = get_csr_submatrix(
  369. M, N, self.indptr, self.indices, self.data, i0, i1, j0, j1)
  370. shape = (i1 - i0, j1 - j0)
  371. return self.__class__((data, indices, indptr), shape=shape,
  372. dtype=self.dtype, copy=False)
  373. def isspmatrix_csr(x):
  374. """Is x of csr_matrix type?
  375. Parameters
  376. ----------
  377. x
  378. object to check for being a csr matrix
  379. Returns
  380. -------
  381. bool
  382. True if x is a csr matrix, False otherwise
  383. Examples
  384. --------
  385. >>> from scipy.sparse import csr_matrix, isspmatrix_csr
  386. >>> isspmatrix_csr(csr_matrix([[5]]))
  387. True
  388. >>> from scipy.sparse import csc_matrix, csr_matrix, isspmatrix_csc
  389. >>> isspmatrix_csr(csc_matrix([[5]]))
  390. False
  391. """
  392. return isinstance(x, csr_matrix)