decomp_lu.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. """LU decomposition functions."""
  2. from __future__ import division, print_function, absolute_import
  3. from warnings import warn
  4. from numpy import asarray, asarray_chkfinite
  5. # Local imports
  6. from .misc import _datacopied, LinAlgWarning
  7. from .lapack import get_lapack_funcs
  8. from .flinalg import get_flinalg_funcs
  9. __all__ = ['lu', 'lu_solve', 'lu_factor']
  10. def lu_factor(a, overwrite_a=False, check_finite=True):
  11. """
  12. Compute pivoted LU decomposition of a matrix.
  13. The decomposition is::
  14. A = P L U
  15. where P is a permutation matrix, L lower triangular with unit
  16. diagonal elements, and U upper triangular.
  17. Parameters
  18. ----------
  19. a : (M, M) array_like
  20. Matrix to decompose
  21. overwrite_a : bool, optional
  22. Whether to overwrite data in A (may increase performance)
  23. check_finite : bool, optional
  24. Whether to check that the input matrix contains only finite numbers.
  25. Disabling may give a performance gain, but may result in problems
  26. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  27. Returns
  28. -------
  29. lu : (N, N) ndarray
  30. Matrix containing U in its upper triangle, and L in its lower triangle.
  31. The unit diagonal elements of L are not stored.
  32. piv : (N,) ndarray
  33. Pivot indices representing the permutation matrix P:
  34. row i of matrix was interchanged with row piv[i].
  35. See also
  36. --------
  37. lu_solve : solve an equation system using the LU factorization of a matrix
  38. Notes
  39. -----
  40. This is a wrapper to the ``*GETRF`` routines from LAPACK.
  41. Examples
  42. --------
  43. >>> from scipy.linalg import lu_factor
  44. >>> from numpy import tril, triu, allclose, zeros, eye
  45. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  46. >>> lu, piv = lu_factor(A)
  47. >>> piv
  48. array([2, 2, 3, 3], dtype=int32)
  49. Convert LAPACK's ``piv`` array to NumPy index and test the permutation
  50. >>> piv_py = [2, 0, 3, 1]
  51. >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
  52. >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
  53. True
  54. """
  55. if check_finite:
  56. a1 = asarray_chkfinite(a)
  57. else:
  58. a1 = asarray(a)
  59. if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
  60. raise ValueError('expected square matrix')
  61. overwrite_a = overwrite_a or (_datacopied(a1, a))
  62. getrf, = get_lapack_funcs(('getrf',), (a1,))
  63. lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
  64. if info < 0:
  65. raise ValueError('illegal value in %d-th argument of '
  66. 'internal getrf (lu_factor)' % -info)
  67. if info > 0:
  68. warn("Diagonal number %d is exactly zero. Singular matrix." % info,
  69. LinAlgWarning, stacklevel=2)
  70. return lu, piv
  71. def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  72. """Solve an equation system, a x = b, given the LU factorization of a
  73. Parameters
  74. ----------
  75. (lu, piv)
  76. Factorization of the coefficient matrix a, as given by lu_factor
  77. b : array
  78. Right-hand side
  79. trans : {0, 1, 2}, optional
  80. Type of system to solve:
  81. ===== =========
  82. trans system
  83. ===== =========
  84. 0 a x = b
  85. 1 a^T x = b
  86. 2 a^H x = b
  87. ===== =========
  88. overwrite_b : bool, optional
  89. Whether to overwrite data in b (may increase performance)
  90. check_finite : bool, optional
  91. Whether to check that the input matrices contain only finite numbers.
  92. Disabling may give a performance gain, but may result in problems
  93. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  94. Returns
  95. -------
  96. x : array
  97. Solution to the system
  98. See also
  99. --------
  100. lu_factor : LU factorize a matrix
  101. Examples
  102. --------
  103. >>> from scipy.linalg import lu_factor, lu_solve
  104. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  105. >>> b = np.array([1, 1, 1, 1])
  106. >>> lu, piv = lu_factor(A)
  107. >>> x = lu_solve((lu, piv), b)
  108. >>> np.allclose(A @ x - b, np.zeros((4,)))
  109. True
  110. """
  111. (lu, piv) = lu_and_piv
  112. if check_finite:
  113. b1 = asarray_chkfinite(b)
  114. else:
  115. b1 = asarray(b)
  116. overwrite_b = overwrite_b or _datacopied(b1, b)
  117. if lu.shape[0] != b1.shape[0]:
  118. raise ValueError("incompatible dimensions.")
  119. getrs, = get_lapack_funcs(('getrs',), (lu, b1))
  120. x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
  121. if info == 0:
  122. return x
  123. raise ValueError('illegal value in %d-th argument of internal gesv|posv'
  124. % -info)
  125. def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
  126. """
  127. Compute pivoted LU decomposition of a matrix.
  128. The decomposition is::
  129. A = P L U
  130. where P is a permutation matrix, L lower triangular with unit
  131. diagonal elements, and U upper triangular.
  132. Parameters
  133. ----------
  134. a : (M, N) array_like
  135. Array to decompose
  136. permute_l : bool, optional
  137. Perform the multiplication P*L (Default: do not permute)
  138. overwrite_a : bool, optional
  139. Whether to overwrite data in a (may improve performance)
  140. check_finite : bool, optional
  141. Whether to check that the input matrix contains only finite numbers.
  142. Disabling may give a performance gain, but may result in problems
  143. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  144. Returns
  145. -------
  146. **(If permute_l == False)**
  147. p : (M, M) ndarray
  148. Permutation matrix
  149. l : (M, K) ndarray
  150. Lower triangular or trapezoidal matrix with unit diagonal.
  151. K = min(M, N)
  152. u : (K, N) ndarray
  153. Upper triangular or trapezoidal matrix
  154. **(If permute_l == True)**
  155. pl : (M, K) ndarray
  156. Permuted L matrix.
  157. K = min(M, N)
  158. u : (K, N) ndarray
  159. Upper triangular or trapezoidal matrix
  160. Notes
  161. -----
  162. This is a LU factorization routine written for Scipy.
  163. Examples
  164. --------
  165. >>> from scipy.linalg import lu
  166. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  167. >>> p, l, u = lu(A)
  168. >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
  169. True
  170. """
  171. if check_finite:
  172. a1 = asarray_chkfinite(a)
  173. else:
  174. a1 = asarray(a)
  175. if len(a1.shape) != 2:
  176. raise ValueError('expected matrix')
  177. overwrite_a = overwrite_a or (_datacopied(a1, a))
  178. flu, = get_flinalg_funcs(('lu',), (a1,))
  179. p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
  180. if info < 0:
  181. raise ValueError('illegal value in %d-th argument of '
  182. 'internal lu.getrf' % -info)
  183. if permute_l:
  184. return l, u
  185. return p, l, u