linsolve.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. from __future__ import division, print_function, absolute_import
  2. from warnings import warn
  3. import numpy as np
  4. from numpy import asarray
  5. from scipy.sparse import (isspmatrix_csc, isspmatrix_csr, isspmatrix,
  6. SparseEfficiencyWarning, csc_matrix, csr_matrix)
  7. from scipy.linalg import LinAlgError
  8. from . import _superlu
  9. noScikit = False
  10. try:
  11. import scikits.umfpack as umfpack
  12. except ImportError:
  13. noScikit = True
  14. useUmfpack = not noScikit
  15. __all__ = ['use_solver', 'spsolve', 'splu', 'spilu', 'factorized',
  16. 'MatrixRankWarning', 'spsolve_triangular']
  17. class MatrixRankWarning(UserWarning):
  18. pass
  19. def use_solver(**kwargs):
  20. """
  21. Select default sparse direct solver to be used.
  22. Parameters
  23. ----------
  24. useUmfpack : bool, optional
  25. Use UMFPACK over SuperLU. Has effect only if scikits.umfpack is
  26. installed. Default: True
  27. assumeSortedIndices : bool, optional
  28. Allow UMFPACK to skip the step of sorting indices for a CSR/CSC matrix.
  29. Has effect only if useUmfpack is True and scikits.umfpack is installed.
  30. Default: False
  31. Notes
  32. -----
  33. The default sparse solver is umfpack when available
  34. (scikits.umfpack is installed). This can be changed by passing
  35. useUmfpack = False, which then causes the always present SuperLU
  36. based solver to be used.
  37. Umfpack requires a CSR/CSC matrix to have sorted column/row indices. If
  38. sure that the matrix fulfills this, pass ``assumeSortedIndices=True``
  39. to gain some speed.
  40. """
  41. if 'useUmfpack' in kwargs:
  42. globals()['useUmfpack'] = kwargs['useUmfpack']
  43. if useUmfpack and 'assumeSortedIndices' in kwargs:
  44. umfpack.configure(assumeSortedIndices=kwargs['assumeSortedIndices'])
  45. def _get_umf_family(A):
  46. """Get umfpack family string given the sparse matrix dtype."""
  47. _families = {
  48. (np.float64, np.int32): 'di',
  49. (np.complex128, np.int32): 'zi',
  50. (np.float64, np.int64): 'dl',
  51. (np.complex128, np.int64): 'zl'
  52. }
  53. f_type = np.sctypeDict[A.dtype.name]
  54. i_type = np.sctypeDict[A.indices.dtype.name]
  55. try:
  56. family = _families[(f_type, i_type)]
  57. except KeyError:
  58. msg = 'only float64 or complex128 matrices with int32 or int64' \
  59. ' indices are supported! (got: matrix: %s, indices: %s)' \
  60. % (f_type, i_type)
  61. raise ValueError(msg)
  62. return family
  63. def spsolve(A, b, permc_spec=None, use_umfpack=True):
  64. """Solve the sparse linear system Ax=b, where b may be a vector or a matrix.
  65. Parameters
  66. ----------
  67. A : ndarray or sparse matrix
  68. The square matrix A will be converted into CSC or CSR form
  69. b : ndarray or sparse matrix
  70. The matrix or vector representing the right hand side of the equation.
  71. If a vector, b.shape must be (n,) or (n, 1).
  72. permc_spec : str, optional
  73. How to permute the columns of the matrix for sparsity preservation.
  74. (default: 'COLAMD')
  75. - ``NATURAL``: natural ordering.
  76. - ``MMD_ATA``: minimum degree ordering on the structure of A^T A.
  77. - ``MMD_AT_PLUS_A``: minimum degree ordering on the structure of A^T+A.
  78. - ``COLAMD``: approximate minimum degree column ordering
  79. use_umfpack : bool, optional
  80. if True (default) then use umfpack for the solution. This is
  81. only referenced if b is a vector and ``scikit-umfpack`` is installed.
  82. Returns
  83. -------
  84. x : ndarray or sparse matrix
  85. the solution of the sparse linear equation.
  86. If b is a vector, then x is a vector of size A.shape[1]
  87. If b is a matrix, then x is a matrix of size (A.shape[1], b.shape[1])
  88. Notes
  89. -----
  90. For solving the matrix expression AX = B, this solver assumes the resulting
  91. matrix X is sparse, as is often the case for very sparse inputs. If the
  92. resulting X is dense, the construction of this sparse result will be
  93. relatively expensive. In that case, consider converting A to a dense
  94. matrix and using scipy.linalg.solve or its variants.
  95. Examples
  96. --------
  97. >>> from scipy.sparse import csc_matrix
  98. >>> from scipy.sparse.linalg import spsolve
  99. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  100. >>> B = csc_matrix([[2, 0], [-1, 0], [2, 0]], dtype=float)
  101. >>> x = spsolve(A, B)
  102. >>> np.allclose(A.dot(x).todense(), B.todense())
  103. True
  104. """
  105. if not (isspmatrix_csc(A) or isspmatrix_csr(A)):
  106. A = csc_matrix(A)
  107. warn('spsolve requires A be CSC or CSR matrix format',
  108. SparseEfficiencyWarning)
  109. # b is a vector only if b have shape (n,) or (n, 1)
  110. b_is_sparse = isspmatrix(b)
  111. if not b_is_sparse:
  112. b = asarray(b)
  113. b_is_vector = ((b.ndim == 1) or (b.ndim == 2 and b.shape[1] == 1))
  114. # sum duplicates for non-canonical format
  115. A.sum_duplicates()
  116. A = A.asfptype() # upcast to a floating point format
  117. result_dtype = np.promote_types(A.dtype, b.dtype)
  118. if A.dtype != result_dtype:
  119. A = A.astype(result_dtype)
  120. if b.dtype != result_dtype:
  121. b = b.astype(result_dtype)
  122. # validate input shapes
  123. M, N = A.shape
  124. if (M != N):
  125. raise ValueError("matrix must be square (has shape %s)" % ((M, N),))
  126. if M != b.shape[0]:
  127. raise ValueError("matrix - rhs dimension mismatch (%s - %s)"
  128. % (A.shape, b.shape[0]))
  129. use_umfpack = use_umfpack and useUmfpack
  130. if b_is_vector and use_umfpack:
  131. if b_is_sparse:
  132. b_vec = b.toarray()
  133. else:
  134. b_vec = b
  135. b_vec = asarray(b_vec, dtype=A.dtype).ravel()
  136. if noScikit:
  137. raise RuntimeError('Scikits.umfpack not installed.')
  138. if A.dtype.char not in 'dD':
  139. raise ValueError("convert matrix data to double, please, using"
  140. " .astype(), or set linsolve.useUmfpack = False")
  141. umf = umfpack.UmfpackContext(_get_umf_family(A))
  142. x = umf.linsolve(umfpack.UMFPACK_A, A, b_vec,
  143. autoTranspose=True)
  144. else:
  145. if b_is_vector and b_is_sparse:
  146. b = b.toarray()
  147. b_is_sparse = False
  148. if not b_is_sparse:
  149. if isspmatrix_csc(A):
  150. flag = 1 # CSC format
  151. else:
  152. flag = 0 # CSR format
  153. options = dict(ColPerm=permc_spec)
  154. x, info = _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr,
  155. b, flag, options=options)
  156. if info != 0:
  157. warn("Matrix is exactly singular", MatrixRankWarning)
  158. x.fill(np.nan)
  159. if b_is_vector:
  160. x = x.ravel()
  161. else:
  162. # b is sparse
  163. Afactsolve = factorized(A)
  164. if not isspmatrix_csc(b):
  165. warn('spsolve is more efficient when sparse b '
  166. 'is in the CSC matrix format', SparseEfficiencyWarning)
  167. b = csc_matrix(b)
  168. # Create a sparse output matrix by repeatedly applying
  169. # the sparse factorization to solve columns of b.
  170. data_segs = []
  171. row_segs = []
  172. col_segs = []
  173. for j in range(b.shape[1]):
  174. bj = b[:, j].A.ravel()
  175. xj = Afactsolve(bj)
  176. w = np.flatnonzero(xj)
  177. segment_length = w.shape[0]
  178. row_segs.append(w)
  179. col_segs.append(np.full(segment_length, j, dtype=int))
  180. data_segs.append(np.asarray(xj[w], dtype=A.dtype))
  181. sparse_data = np.concatenate(data_segs)
  182. sparse_row = np.concatenate(row_segs)
  183. sparse_col = np.concatenate(col_segs)
  184. x = A.__class__((sparse_data, (sparse_row, sparse_col)),
  185. shape=b.shape, dtype=A.dtype)
  186. return x
  187. def splu(A, permc_spec=None, diag_pivot_thresh=None,
  188. relax=None, panel_size=None, options=dict()):
  189. """
  190. Compute the LU decomposition of a sparse, square matrix.
  191. Parameters
  192. ----------
  193. A : sparse matrix
  194. Sparse matrix to factorize. Should be in CSR or CSC format.
  195. permc_spec : str, optional
  196. How to permute the columns of the matrix for sparsity preservation.
  197. (default: 'COLAMD')
  198. - ``NATURAL``: natural ordering.
  199. - ``MMD_ATA``: minimum degree ordering on the structure of A^T A.
  200. - ``MMD_AT_PLUS_A``: minimum degree ordering on the structure of A^T+A.
  201. - ``COLAMD``: approximate minimum degree column ordering
  202. diag_pivot_thresh : float, optional
  203. Threshold used for a diagonal entry to be an acceptable pivot.
  204. See SuperLU user's guide for details [1]_
  205. relax : int, optional
  206. Expert option for customizing the degree of relaxing supernodes.
  207. See SuperLU user's guide for details [1]_
  208. panel_size : int, optional
  209. Expert option for customizing the panel size.
  210. See SuperLU user's guide for details [1]_
  211. options : dict, optional
  212. Dictionary containing additional expert options to SuperLU.
  213. See SuperLU user guide [1]_ (section 2.4 on the 'Options' argument)
  214. for more details. For example, you can specify
  215. ``options=dict(Equil=False, IterRefine='SINGLE'))``
  216. to turn equilibration off and perform a single iterative refinement.
  217. Returns
  218. -------
  219. invA : scipy.sparse.linalg.SuperLU
  220. Object, which has a ``solve`` method.
  221. See also
  222. --------
  223. spilu : incomplete LU decomposition
  224. Notes
  225. -----
  226. This function uses the SuperLU library.
  227. References
  228. ----------
  229. .. [1] SuperLU http://crd.lbl.gov/~xiaoye/SuperLU/
  230. Examples
  231. --------
  232. >>> from scipy.sparse import csc_matrix
  233. >>> from scipy.sparse.linalg import splu
  234. >>> A = csc_matrix([[1., 0., 0.], [5., 0., 2.], [0., -1., 0.]], dtype=float)
  235. >>> B = splu(A)
  236. >>> x = np.array([1., 2., 3.], dtype=float)
  237. >>> B.solve(x)
  238. array([ 1. , -3. , -1.5])
  239. >>> A.dot(B.solve(x))
  240. array([ 1., 2., 3.])
  241. >>> B.solve(A.dot(x))
  242. array([ 1., 2., 3.])
  243. """
  244. if not isspmatrix_csc(A):
  245. A = csc_matrix(A)
  246. warn('splu requires CSC matrix format', SparseEfficiencyWarning)
  247. # sum duplicates for non-canonical format
  248. A.sum_duplicates()
  249. A = A.asfptype() # upcast to a floating point format
  250. M, N = A.shape
  251. if (M != N):
  252. raise ValueError("can only factor square matrices") # is this true?
  253. _options = dict(DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
  254. PanelSize=panel_size, Relax=relax)
  255. if options is not None:
  256. _options.update(options)
  257. return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
  258. ilu=False, options=_options)
  259. def spilu(A, drop_tol=None, fill_factor=None, drop_rule=None, permc_spec=None,
  260. diag_pivot_thresh=None, relax=None, panel_size=None, options=None):
  261. """
  262. Compute an incomplete LU decomposition for a sparse, square matrix.
  263. The resulting object is an approximation to the inverse of `A`.
  264. Parameters
  265. ----------
  266. A : (N, N) array_like
  267. Sparse matrix to factorize
  268. drop_tol : float, optional
  269. Drop tolerance (0 <= tol <= 1) for an incomplete LU decomposition.
  270. (default: 1e-4)
  271. fill_factor : float, optional
  272. Specifies the fill ratio upper bound (>= 1.0) for ILU. (default: 10)
  273. drop_rule : str, optional
  274. Comma-separated string of drop rules to use.
  275. Available rules: ``basic``, ``prows``, ``column``, ``area``,
  276. ``secondary``, ``dynamic``, ``interp``. (Default: ``basic,area``)
  277. See SuperLU documentation for details.
  278. Remaining other options
  279. Same as for `splu`
  280. Returns
  281. -------
  282. invA_approx : scipy.sparse.linalg.SuperLU
  283. Object, which has a ``solve`` method.
  284. See also
  285. --------
  286. splu : complete LU decomposition
  287. Notes
  288. -----
  289. To improve the better approximation to the inverse, you may need to
  290. increase `fill_factor` AND decrease `drop_tol`.
  291. This function uses the SuperLU library.
  292. Examples
  293. --------
  294. >>> from scipy.sparse import csc_matrix
  295. >>> from scipy.sparse.linalg import spilu
  296. >>> A = csc_matrix([[1., 0., 0.], [5., 0., 2.], [0., -1., 0.]], dtype=float)
  297. >>> B = spilu(A)
  298. >>> x = np.array([1., 2., 3.], dtype=float)
  299. >>> B.solve(x)
  300. array([ 1. , -3. , -1.5])
  301. >>> A.dot(B.solve(x))
  302. array([ 1., 2., 3.])
  303. >>> B.solve(A.dot(x))
  304. array([ 1., 2., 3.])
  305. """
  306. if not isspmatrix_csc(A):
  307. A = csc_matrix(A)
  308. warn('splu requires CSC matrix format', SparseEfficiencyWarning)
  309. # sum duplicates for non-canonical format
  310. A.sum_duplicates()
  311. A = A.asfptype() # upcast to a floating point format
  312. M, N = A.shape
  313. if (M != N):
  314. raise ValueError("can only factor square matrices") # is this true?
  315. _options = dict(ILU_DropRule=drop_rule, ILU_DropTol=drop_tol,
  316. ILU_FillFactor=fill_factor,
  317. DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
  318. PanelSize=panel_size, Relax=relax)
  319. if options is not None:
  320. _options.update(options)
  321. return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
  322. ilu=True, options=_options)
  323. def factorized(A):
  324. """
  325. Return a function for solving a sparse linear system, with A pre-factorized.
  326. Parameters
  327. ----------
  328. A : (N, N) array_like
  329. Input.
  330. Returns
  331. -------
  332. solve : callable
  333. To solve the linear system of equations given in `A`, the `solve`
  334. callable should be passed an ndarray of shape (N,).
  335. Examples
  336. --------
  337. >>> from scipy.sparse.linalg import factorized
  338. >>> A = np.array([[ 3. , 2. , -1. ],
  339. ... [ 2. , -2. , 4. ],
  340. ... [-1. , 0.5, -1. ]])
  341. >>> solve = factorized(A) # Makes LU decomposition.
  342. >>> rhs1 = np.array([1, -2, 0])
  343. >>> solve(rhs1) # Uses the LU factors.
  344. array([ 1., -2., -2.])
  345. """
  346. if useUmfpack:
  347. if noScikit:
  348. raise RuntimeError('Scikits.umfpack not installed.')
  349. if not isspmatrix_csc(A):
  350. A = csc_matrix(A)
  351. warn('splu requires CSC matrix format', SparseEfficiencyWarning)
  352. A = A.asfptype() # upcast to a floating point format
  353. if A.dtype.char not in 'dD':
  354. raise ValueError("convert matrix data to double, please, using"
  355. " .astype(), or set linsolve.useUmfpack = False")
  356. umf = umfpack.UmfpackContext(_get_umf_family(A))
  357. # Make LU decomposition.
  358. umf.numeric(A)
  359. def solve(b):
  360. return umf.solve(umfpack.UMFPACK_A, A, b, autoTranspose=True)
  361. return solve
  362. else:
  363. return splu(A).solve
  364. def spsolve_triangular(A, b, lower=True, overwrite_A=False, overwrite_b=False):
  365. """
  366. Solve the equation `A x = b` for `x`, assuming A is a triangular matrix.
  367. Parameters
  368. ----------
  369. A : (M, M) sparse matrix
  370. A sparse square triangular matrix. Should be in CSR format.
  371. b : (M,) or (M, N) array_like
  372. Right-hand side matrix in `A x = b`
  373. lower : bool, optional
  374. Whether `A` is a lower or upper triangular matrix.
  375. Default is lower triangular matrix.
  376. overwrite_A : bool, optional
  377. Allow changing `A`. The indices of `A` are going to be sorted and zero
  378. entries are going to be removed.
  379. Enabling gives a performance gain. Default is False.
  380. overwrite_b : bool, optional
  381. Allow overwriting data in `b`.
  382. Enabling gives a performance gain. Default is False.
  383. If `overwrite_b` is True, it should be ensured that
  384. `b` has an appropriate dtype to be able to store the result.
  385. Returns
  386. -------
  387. x : (M,) or (M, N) ndarray
  388. Solution to the system `A x = b`. Shape of return matches shape of `b`.
  389. Raises
  390. ------
  391. LinAlgError
  392. If `A` is singular or not triangular.
  393. ValueError
  394. If shape of `A` or shape of `b` do not match the requirements.
  395. Notes
  396. -----
  397. .. versionadded:: 0.19.0
  398. Examples
  399. --------
  400. >>> from scipy.sparse import csr_matrix
  401. >>> from scipy.sparse.linalg import spsolve_triangular
  402. >>> A = csr_matrix([[3, 0, 0], [1, -1, 0], [2, 0, 1]], dtype=float)
  403. >>> B = np.array([[2, 0], [-1, 0], [2, 0]], dtype=float)
  404. >>> x = spsolve_triangular(A, B)
  405. >>> np.allclose(A.dot(x), B)
  406. True
  407. """
  408. # Check the input for correct type and format.
  409. if not isspmatrix_csr(A):
  410. warn('CSR matrix format is required. Converting to CSR matrix.',
  411. SparseEfficiencyWarning)
  412. A = csr_matrix(A)
  413. elif not overwrite_A:
  414. A = A.copy()
  415. if A.shape[0] != A.shape[1]:
  416. raise ValueError(
  417. 'A must be a square matrix but its shape is {}.'.format(A.shape))
  418. # sum duplicates for non-canonical format
  419. A.sum_duplicates()
  420. b = np.asanyarray(b)
  421. if b.ndim not in [1, 2]:
  422. raise ValueError(
  423. 'b must have 1 or 2 dims but its shape is {}.'.format(b.shape))
  424. if A.shape[0] != b.shape[0]:
  425. raise ValueError(
  426. 'The size of the dimensions of A must be equal to '
  427. 'the size of the first dimension of b but the shape of A is '
  428. '{} and the shape of b is {}.'.format(A.shape, b.shape))
  429. # Init x as (a copy of) b.
  430. x_dtype = np.result_type(A.data, b, np.float)
  431. if overwrite_b:
  432. if np.can_cast(b.dtype, x_dtype, casting='same_kind'):
  433. x = b
  434. else:
  435. raise ValueError(
  436. 'Cannot overwrite b (dtype {}) with result '
  437. 'of type {}.'.format(b.dtype, x_dtype))
  438. else:
  439. x = b.astype(x_dtype, copy=True)
  440. # Choose forward or backward order.
  441. if lower:
  442. row_indices = range(len(b))
  443. else:
  444. row_indices = range(len(b) - 1, -1, -1)
  445. # Fill x iteratively.
  446. for i in row_indices:
  447. # Get indices for i-th row.
  448. indptr_start = A.indptr[i]
  449. indptr_stop = A.indptr[i + 1]
  450. if lower:
  451. A_diagonal_index_row_i = indptr_stop - 1
  452. A_off_diagonal_indices_row_i = slice(indptr_start, indptr_stop - 1)
  453. else:
  454. A_diagonal_index_row_i = indptr_start
  455. A_off_diagonal_indices_row_i = slice(indptr_start + 1, indptr_stop)
  456. # Check regularity and triangularity of A.
  457. if indptr_stop <= indptr_start or A.indices[A_diagonal_index_row_i] < i:
  458. raise LinAlgError(
  459. 'A is singular: diagonal {} is zero.'.format(i))
  460. if A.indices[A_diagonal_index_row_i] > i:
  461. raise LinAlgError(
  462. 'A is not triangular: A[{}, {}] is nonzero.'
  463. ''.format(i, A.indices[A_diagonal_index_row_i]))
  464. # Incorporate off-diagonal entries.
  465. A_column_indices_in_row_i = A.indices[A_off_diagonal_indices_row_i]
  466. A_values_in_row_i = A.data[A_off_diagonal_indices_row_i]
  467. x[i] -= np.dot(x[A_column_indices_in_row_i].T, A_values_in_row_i)
  468. # Compute i-th entry of x.
  469. x[i] /= A.data[A_diagonal_index_row_i]
  470. return x