decomp_cholesky.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """Cholesky decomposition functions."""
  2. from __future__ import division, print_function, absolute_import
  3. from numpy import asarray_chkfinite, asarray, atleast_2d
  4. # Local imports
  5. from .misc import LinAlgError, _datacopied
  6. from .lapack import get_lapack_funcs
  7. __all__ = ['cholesky', 'cho_factor', 'cho_solve', 'cholesky_banded',
  8. 'cho_solve_banded']
  9. def _cholesky(a, lower=False, overwrite_a=False, clean=True,
  10. check_finite=True):
  11. """Common code for cholesky() and cho_factor()."""
  12. a1 = asarray_chkfinite(a) if check_finite else asarray(a)
  13. a1 = atleast_2d(a1)
  14. # Dimension check
  15. if a1.ndim != 2:
  16. raise ValueError('Input array needs to be 2 dimensional but received '
  17. 'a {}d-array.'.format(a1.ndim))
  18. # Squareness check
  19. if a1.shape[0] != a1.shape[1]:
  20. raise ValueError('Input array is expected to be square but has '
  21. 'the shape: {}.'.format(a1.shape))
  22. # Quick return for square empty array
  23. if a1.size == 0:
  24. return a1.copy(), lower
  25. overwrite_a = overwrite_a or _datacopied(a1, a)
  26. potrf, = get_lapack_funcs(('potrf',), (a1,))
  27. c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean)
  28. if info > 0:
  29. raise LinAlgError("%d-th leading minor of the array is not positive "
  30. "definite" % info)
  31. if info < 0:
  32. raise ValueError('LAPACK reported an illegal value in {}-th argument'
  33. 'on entry to "POTRF".'.format(-info))
  34. return c, lower
  35. def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
  36. """
  37. Compute the Cholesky decomposition of a matrix.
  38. Returns the Cholesky decomposition, :math:`A = L L^*` or
  39. :math:`A = U^* U` of a Hermitian positive-definite matrix A.
  40. Parameters
  41. ----------
  42. a : (M, M) array_like
  43. Matrix to be decomposed
  44. lower : bool, optional
  45. Whether to compute the upper or lower triangular Cholesky
  46. factorization. Default is upper-triangular.
  47. overwrite_a : bool, optional
  48. Whether to overwrite data in `a` (may improve performance).
  49. check_finite : bool, optional
  50. Whether to check that the input matrix contains only finite numbers.
  51. Disabling may give a performance gain, but may result in problems
  52. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  53. Returns
  54. -------
  55. c : (M, M) ndarray
  56. Upper- or lower-triangular Cholesky factor of `a`.
  57. Raises
  58. ------
  59. LinAlgError : if decomposition fails.
  60. Examples
  61. --------
  62. >>> from scipy.linalg import cholesky
  63. >>> a = np.array([[1,-2j],[2j,5]])
  64. >>> L = cholesky(a, lower=True)
  65. >>> L
  66. array([[ 1.+0.j, 0.+0.j],
  67. [ 0.+2.j, 1.+0.j]])
  68. >>> L @ L.T.conj()
  69. array([[ 1.+0.j, 0.-2.j],
  70. [ 0.+2.j, 5.+0.j]])
  71. """
  72. c, lower = _cholesky(a, lower=lower, overwrite_a=overwrite_a, clean=True,
  73. check_finite=check_finite)
  74. return c
  75. def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
  76. """
  77. Compute the Cholesky decomposition of a matrix, to use in cho_solve
  78. Returns a matrix containing the Cholesky decomposition,
  79. ``A = L L*`` or ``A = U* U`` of a Hermitian positive-definite matrix `a`.
  80. The return value can be directly used as the first parameter to cho_solve.
  81. .. warning::
  82. The returned matrix also contains random data in the entries not
  83. used by the Cholesky decomposition. If you need to zero these
  84. entries, use the function `cholesky` instead.
  85. Parameters
  86. ----------
  87. a : (M, M) array_like
  88. Matrix to be decomposed
  89. lower : bool, optional
  90. Whether to compute the upper or lower triangular Cholesky factorization
  91. (Default: upper-triangular)
  92. overwrite_a : bool, optional
  93. Whether to overwrite data in a (may improve performance)
  94. check_finite : bool, optional
  95. Whether to check that the input matrix contains only finite numbers.
  96. Disabling may give a performance gain, but may result in problems
  97. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  98. Returns
  99. -------
  100. c : (M, M) ndarray
  101. Matrix whose upper or lower triangle contains the Cholesky factor
  102. of `a`. Other parts of the matrix contain random data.
  103. lower : bool
  104. Flag indicating whether the factor is in the lower or upper triangle
  105. Raises
  106. ------
  107. LinAlgError
  108. Raised if decomposition fails.
  109. See also
  110. --------
  111. cho_solve : Solve a linear set equations using the Cholesky factorization
  112. of a matrix.
  113. Examples
  114. --------
  115. >>> from scipy.linalg import cho_factor
  116. >>> A = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]])
  117. >>> c, low = cho_factor(A)
  118. >>> c
  119. array([[3. , 1. , 0.33333333, 1.66666667],
  120. [3. , 2.44948974, 1.90515869, -0.27216553],
  121. [1. , 5. , 2.29330749, 0.8559528 ],
  122. [5. , 1. , 2. , 1.55418563]])
  123. >>> np.allclose(np.triu(c).T @ np. triu(c) - A, np.zeros((4, 4)))
  124. True
  125. """
  126. c, lower = _cholesky(a, lower=lower, overwrite_a=overwrite_a, clean=False,
  127. check_finite=check_finite)
  128. return c, lower
  129. def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
  130. """Solve the linear equations A x = b, given the Cholesky factorization of A.
  131. Parameters
  132. ----------
  133. (c, lower) : tuple, (array, bool)
  134. Cholesky factorization of a, as given by cho_factor
  135. b : array
  136. Right-hand side
  137. overwrite_b : bool, optional
  138. Whether to overwrite data in b (may improve performance)
  139. check_finite : bool, optional
  140. Whether to check that the input matrices contain only finite numbers.
  141. Disabling may give a performance gain, but may result in problems
  142. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  143. Returns
  144. -------
  145. x : array
  146. The solution to the system A x = b
  147. See also
  148. --------
  149. cho_factor : Cholesky factorization of a matrix
  150. Examples
  151. --------
  152. >>> from scipy.linalg import cho_factor, cho_solve
  153. >>> A = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]])
  154. >>> c, low = cho_factor(A)
  155. >>> x = cho_solve((c, low), [1, 1, 1, 1])
  156. >>> np.allclose(A @ x - [1, 1, 1, 1], np.zeros(4))
  157. True
  158. """
  159. (c, lower) = c_and_lower
  160. if check_finite:
  161. b1 = asarray_chkfinite(b)
  162. c = asarray_chkfinite(c)
  163. else:
  164. b1 = asarray(b)
  165. c = asarray(c)
  166. if c.ndim != 2 or c.shape[0] != c.shape[1]:
  167. raise ValueError("The factored matrix c is not square.")
  168. if c.shape[1] != b1.shape[0]:
  169. raise ValueError("incompatible dimensions.")
  170. overwrite_b = overwrite_b or _datacopied(b1, b)
  171. potrs, = get_lapack_funcs(('potrs',), (c, b1))
  172. x, info = potrs(c, b1, lower=lower, overwrite_b=overwrite_b)
  173. if info != 0:
  174. raise ValueError('illegal value in %d-th argument of internal potrs'
  175. % -info)
  176. return x
  177. def cholesky_banded(ab, overwrite_ab=False, lower=False, check_finite=True):
  178. """
  179. Cholesky decompose a banded Hermitian positive-definite matrix
  180. The matrix a is stored in ab either in lower diagonal or upper
  181. diagonal ordered form::
  182. ab[u + i - j, j] == a[i,j] (if upper form; i <= j)
  183. ab[ i - j, j] == a[i,j] (if lower form; i >= j)
  184. Example of ab (shape of a is (6,6), u=2)::
  185. upper form:
  186. * * a02 a13 a24 a35
  187. * a01 a12 a23 a34 a45
  188. a00 a11 a22 a33 a44 a55
  189. lower form:
  190. a00 a11 a22 a33 a44 a55
  191. a10 a21 a32 a43 a54 *
  192. a20 a31 a42 a53 * *
  193. Parameters
  194. ----------
  195. ab : (u + 1, M) array_like
  196. Banded matrix
  197. overwrite_ab : bool, optional
  198. Discard data in ab (may enhance performance)
  199. lower : bool, optional
  200. Is the matrix in the lower form. (Default is upper form)
  201. check_finite : bool, optional
  202. Whether to check that the input matrix contains only finite numbers.
  203. Disabling may give a performance gain, but may result in problems
  204. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  205. Returns
  206. -------
  207. c : (u + 1, M) ndarray
  208. Cholesky factorization of a, in the same banded format as ab
  209. See also
  210. --------
  211. cho_solve_banded : Solve a linear set equations, given the Cholesky factorization
  212. of a banded hermitian.
  213. Examples
  214. --------
  215. >>> from scipy.linalg import cholesky_banded
  216. >>> from numpy import allclose, zeros, diag
  217. >>> Ab = np.array([[0, 0, 1j, 2, 3j], [0, -1, -2, 3, 4], [9, 8, 7, 6, 9]])
  218. >>> A = np.diag(Ab[0,2:], k=2) + np.diag(Ab[1,1:], k=1)
  219. >>> A = A + A.conj().T + np.diag(Ab[2, :])
  220. >>> c = cholesky_banded(Ab)
  221. >>> C = np.diag(c[0, 2:], k=2) + np.diag(c[1, 1:], k=1) + np.diag(c[2, :])
  222. >>> np.allclose(C.conj().T @ C - A, np.zeros((5, 5)))
  223. True
  224. """
  225. if check_finite:
  226. ab = asarray_chkfinite(ab)
  227. else:
  228. ab = asarray(ab)
  229. pbtrf, = get_lapack_funcs(('pbtrf',), (ab,))
  230. c, info = pbtrf(ab, lower=lower, overwrite_ab=overwrite_ab)
  231. if info > 0:
  232. raise LinAlgError("%d-th leading minor not positive definite" % info)
  233. if info < 0:
  234. raise ValueError('illegal value in %d-th argument of internal pbtrf'
  235. % -info)
  236. return c
  237. def cho_solve_banded(cb_and_lower, b, overwrite_b=False, check_finite=True):
  238. """
  239. Solve the linear equations ``A x = b``, given the Cholesky factorization of
  240. the banded hermitian ``A``.
  241. Parameters
  242. ----------
  243. (cb, lower) : tuple, (ndarray, bool)
  244. `cb` is the Cholesky factorization of A, as given by cholesky_banded.
  245. `lower` must be the same value that was given to cholesky_banded.
  246. b : array_like
  247. Right-hand side
  248. overwrite_b : bool, optional
  249. If True, the function will overwrite the values in `b`.
  250. check_finite : bool, optional
  251. Whether to check that the input matrices contain only finite numbers.
  252. Disabling may give a performance gain, but may result in problems
  253. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  254. Returns
  255. -------
  256. x : array
  257. The solution to the system A x = b
  258. See also
  259. --------
  260. cholesky_banded : Cholesky factorization of a banded matrix
  261. Notes
  262. -----
  263. .. versionadded:: 0.8.0
  264. Examples
  265. --------
  266. >>> from scipy.linalg import cholesky_banded, cho_solve_banded
  267. >>> Ab = np.array([[0, 0, 1j, 2, 3j], [0, -1, -2, 3, 4], [9, 8, 7, 6, 9]])
  268. >>> A = np.diag(Ab[0,2:], k=2) + np.diag(Ab[1,1:], k=1)
  269. >>> A = A + A.conj().T + np.diag(Ab[2, :])
  270. >>> c = cholesky_banded(Ab)
  271. >>> x = cho_solve_banded((c, False), np.ones(5))
  272. >>> np.allclose(A @ x - np.ones(5), np.zeros(5))
  273. True
  274. """
  275. (cb, lower) = cb_and_lower
  276. if check_finite:
  277. cb = asarray_chkfinite(cb)
  278. b = asarray_chkfinite(b)
  279. else:
  280. cb = asarray(cb)
  281. b = asarray(b)
  282. # Validate shapes.
  283. if cb.shape[-1] != b.shape[0]:
  284. raise ValueError("shapes of cb and b are not compatible.")
  285. pbtrs, = get_lapack_funcs(('pbtrs',), (cb, b))
  286. x, info = pbtrs(cb, b, lower=lower, overwrite_b=overwrite_b)
  287. if info > 0:
  288. raise LinAlgError("%d-th leading minor not positive definite" % info)
  289. if info < 0:
  290. raise ValueError('illegal value in %d-th argument of internal pbtrs'
  291. % -info)
  292. return x