decomp_qr.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. """QR decomposition functions."""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy
  4. # Local imports
  5. from .lapack import get_lapack_funcs
  6. from .misc import _datacopied
  7. __all__ = ['qr', 'qr_multiply', 'rq']
  8. def safecall(f, name, *args, **kwargs):
  9. """Call a LAPACK routine, determining lwork automatically and handling
  10. error return values"""
  11. lwork = kwargs.get("lwork", None)
  12. if lwork in (None, -1):
  13. kwargs['lwork'] = -1
  14. ret = f(*args, **kwargs)
  15. kwargs['lwork'] = ret[-2][0].real.astype(numpy.int)
  16. ret = f(*args, **kwargs)
  17. if ret[-1] < 0:
  18. raise ValueError("illegal value in %d-th argument of internal %s"
  19. % (-ret[-1], name))
  20. return ret[:-2]
  21. def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
  22. check_finite=True):
  23. """
  24. Compute QR decomposition of a matrix.
  25. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  26. and R upper triangular.
  27. Parameters
  28. ----------
  29. a : (M, N) array_like
  30. Matrix to be decomposed
  31. overwrite_a : bool, optional
  32. Whether data in a is overwritten (may improve performance)
  33. lwork : int, optional
  34. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  35. is computed.
  36. mode : {'full', 'r', 'economic', 'raw'}, optional
  37. Determines what information is to be returned: either both Q and R
  38. ('full', default), only R ('r') or both Q and R but computed in
  39. economy-size ('economic', see Notes). The final option 'raw'
  40. (added in Scipy 0.11) makes the function return two matrices
  41. (Q, TAU) in the internal format used by LAPACK.
  42. pivoting : bool, optional
  43. Whether or not factorization should include pivoting for rank-revealing
  44. qr decomposition. If pivoting, compute the decomposition
  45. ``A P = Q R`` as above, but where P is chosen such that the diagonal
  46. of R is non-increasing.
  47. check_finite : bool, optional
  48. Whether to check that the input matrix contains only finite numbers.
  49. Disabling may give a performance gain, but may result in problems
  50. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  51. Returns
  52. -------
  53. Q : float or complex ndarray
  54. Of shape (M, M), or (M, K) for ``mode='economic'``. Not returned
  55. if ``mode='r'``.
  56. R : float or complex ndarray
  57. Of shape (M, N), or (K, N) for ``mode='economic'``. ``K = min(M, N)``.
  58. P : int ndarray
  59. Of shape (N,) for ``pivoting=True``. Not returned if
  60. ``pivoting=False``.
  61. Raises
  62. ------
  63. LinAlgError
  64. Raised if decomposition fails
  65. Notes
  66. -----
  67. This is an interface to the LAPACK routines dgeqrf, zgeqrf,
  68. dorgqr, zungqr, dgeqp3, and zgeqp3.
  69. If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
  70. of (M,M) and (M,N), with ``K=min(M,N)``.
  71. Examples
  72. --------
  73. >>> from scipy import random, linalg, dot, diag, all, allclose
  74. >>> a = random.randn(9, 6)
  75. >>> q, r = linalg.qr(a)
  76. >>> allclose(a, np.dot(q, r))
  77. True
  78. >>> q.shape, r.shape
  79. ((9, 9), (9, 6))
  80. >>> r2 = linalg.qr(a, mode='r')
  81. >>> allclose(r, r2)
  82. True
  83. >>> q3, r3 = linalg.qr(a, mode='economic')
  84. >>> q3.shape, r3.shape
  85. ((9, 6), (6, 6))
  86. >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
  87. >>> d = abs(diag(r4))
  88. >>> all(d[1:] <= d[:-1])
  89. True
  90. >>> allclose(a[:, p4], dot(q4, r4))
  91. True
  92. >>> q4.shape, r4.shape, p4.shape
  93. ((9, 9), (9, 6), (6,))
  94. >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
  95. >>> q5.shape, r5.shape, p5.shape
  96. ((9, 6), (6, 6), (6,))
  97. """
  98. # 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
  99. # 'qr' are used below.
  100. # 'raw' is used internally by qr_multiply
  101. if mode not in ['full', 'qr', 'r', 'economic', 'raw']:
  102. raise ValueError("Mode argument should be one of ['full', 'r',"
  103. "'economic', 'raw']")
  104. if check_finite:
  105. a1 = numpy.asarray_chkfinite(a)
  106. else:
  107. a1 = numpy.asarray(a)
  108. if len(a1.shape) != 2:
  109. raise ValueError("expected 2D array")
  110. M, N = a1.shape
  111. overwrite_a = overwrite_a or (_datacopied(a1, a))
  112. if pivoting:
  113. geqp3, = get_lapack_funcs(('geqp3',), (a1,))
  114. qr, jpvt, tau = safecall(geqp3, "geqp3", a1, overwrite_a=overwrite_a)
  115. jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1
  116. else:
  117. geqrf, = get_lapack_funcs(('geqrf',), (a1,))
  118. qr, tau = safecall(geqrf, "geqrf", a1, lwork=lwork,
  119. overwrite_a=overwrite_a)
  120. if mode not in ['economic', 'raw'] or M < N:
  121. R = numpy.triu(qr)
  122. else:
  123. R = numpy.triu(qr[:N, :])
  124. if pivoting:
  125. Rj = R, jpvt
  126. else:
  127. Rj = R,
  128. if mode == 'r':
  129. return Rj
  130. elif mode == 'raw':
  131. return ((qr, tau),) + Rj
  132. gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr,))
  133. if M < N:
  134. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr[:, :M], tau,
  135. lwork=lwork, overwrite_a=1)
  136. elif mode == 'economic':
  137. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr, tau, lwork=lwork,
  138. overwrite_a=1)
  139. else:
  140. t = qr.dtype.char
  141. qqr = numpy.empty((M, M), dtype=t)
  142. qqr[:, :N] = qr
  143. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qqr, tau, lwork=lwork,
  144. overwrite_a=1)
  145. return (Q,) + Rj
  146. def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
  147. overwrite_a=False, overwrite_c=False):
  148. """
  149. Calculate the QR decomposition and multiply Q with a matrix.
  150. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  151. and R upper triangular. Multiply Q with a vector or a matrix c.
  152. Parameters
  153. ----------
  154. a : (M, N), array_like
  155. Input array
  156. c : array_like
  157. Input array to be multiplied by ``q``.
  158. mode : {'left', 'right'}, optional
  159. ``Q @ c`` is returned if mode is 'left', ``c @ Q`` is returned if
  160. mode is 'right'.
  161. The shape of c must be appropriate for the matrix multiplications,
  162. if mode is 'left', ``min(a.shape) == c.shape[0]``,
  163. if mode is 'right', ``a.shape[0] == c.shape[1]``.
  164. pivoting : bool, optional
  165. Whether or not factorization should include pivoting for rank-revealing
  166. qr decomposition, see the documentation of qr.
  167. conjugate : bool, optional
  168. Whether Q should be complex-conjugated. This might be faster
  169. than explicit conjugation.
  170. overwrite_a : bool, optional
  171. Whether data in a is overwritten (may improve performance)
  172. overwrite_c : bool, optional
  173. Whether data in c is overwritten (may improve performance).
  174. If this is used, c must be big enough to keep the result,
  175. i.e. ``c.shape[0]`` = ``a.shape[0]`` if mode is 'left'.
  176. Returns
  177. -------
  178. CQ : ndarray
  179. The product of ``Q`` and ``c``.
  180. R : (K, N), ndarray
  181. R array of the resulting QR factorization where ``K = min(M, N)``.
  182. P : (N,) ndarray
  183. Integer pivot array. Only returned when ``pivoting=True``.
  184. Raises
  185. ------
  186. LinAlgError
  187. Raised if QR decomposition fails.
  188. Notes
  189. -----
  190. This is an interface to the LAPACK routines ``?GEQRF``, ``?ORMQR``,
  191. ``?UNMQR``, and ``?GEQP3``.
  192. .. versionadded:: 0.11.0
  193. Examples
  194. --------
  195. >>> from scipy.linalg import qr_multiply, qr
  196. >>> A = np.array([[1, 3, 3], [2, 3, 2], [2, 3, 3], [1, 3, 2]])
  197. >>> qc, r1, piv1 = qr_multiply(A, 2*np.eye(4), pivoting=1)
  198. >>> qc
  199. array([[-1., 1., -1.],
  200. [-1., -1., 1.],
  201. [-1., -1., -1.],
  202. [-1., 1., 1.]])
  203. >>> r1
  204. array([[-6., -3., -5. ],
  205. [ 0., -1., -1.11022302e-16],
  206. [ 0., 0., -1. ]])
  207. >>> piv1
  208. array([1, 0, 2], dtype=int32)
  209. >>> q2, r2, piv2 = qr(A, mode='economic', pivoting=1)
  210. >>> np.allclose(2*q2 - qc, np.zeros((4, 3)))
  211. True
  212. """
  213. if mode not in ['left', 'right']:
  214. raise ValueError("Mode argument can only be 'left' or 'right' but "
  215. "not '{}'".format(mode))
  216. c = numpy.asarray_chkfinite(c)
  217. if c.ndim < 2:
  218. onedim = True
  219. c = numpy.atleast_2d(c)
  220. if mode == "left":
  221. c = c.T
  222. else:
  223. onedim = False
  224. a = numpy.atleast_2d(numpy.asarray(a)) # chkfinite done in qr
  225. M, N = a.shape
  226. if mode == 'left':
  227. if c.shape[0] != min(M, N + overwrite_c*(M-N)):
  228. raise ValueError('Array shapes are not compatible for Q @ c'
  229. ' operation: {} vs {}'.format(a.shape, c.shape))
  230. else:
  231. if M != c.shape[1]:
  232. raise ValueError('Array shapes are not compatible for c @ Q'
  233. ' operation: {} vs {}'.format(c.shape, a.shape))
  234. raw = qr(a, overwrite_a, None, "raw", pivoting)
  235. Q, tau = raw[0]
  236. gor_un_mqr, = get_lapack_funcs(('ormqr',), (Q,))
  237. if gor_un_mqr.typecode in ('s', 'd'):
  238. trans = "T"
  239. else:
  240. trans = "C"
  241. Q = Q[:, :min(M, N)]
  242. if M > N and mode == "left" and not overwrite_c:
  243. if conjugate:
  244. cc = numpy.zeros((c.shape[1], M), dtype=c.dtype, order="F")
  245. cc[:, :N] = c.T
  246. else:
  247. cc = numpy.zeros((M, c.shape[1]), dtype=c.dtype, order="F")
  248. cc[:N, :] = c
  249. trans = "N"
  250. if conjugate:
  251. lr = "R"
  252. else:
  253. lr = "L"
  254. overwrite_c = True
  255. elif c.flags["C_CONTIGUOUS"] and trans == "T" or conjugate:
  256. cc = c.T
  257. if mode == "left":
  258. lr = "R"
  259. else:
  260. lr = "L"
  261. else:
  262. trans = "N"
  263. cc = c
  264. if mode == "left":
  265. lr = "L"
  266. else:
  267. lr = "R"
  268. cQ, = safecall(gor_un_mqr, "gormqr/gunmqr", lr, trans, Q, tau, cc,
  269. overwrite_c=overwrite_c)
  270. if trans != "N":
  271. cQ = cQ.T
  272. if mode == "right":
  273. cQ = cQ[:, :min(M, N)]
  274. if onedim:
  275. cQ = cQ.ravel()
  276. return (cQ,) + raw[1:]
  277. def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
  278. """
  279. Compute RQ decomposition of a matrix.
  280. Calculate the decomposition ``A = R Q`` where Q is unitary/orthogonal
  281. and R upper triangular.
  282. Parameters
  283. ----------
  284. a : (M, N) array_like
  285. Matrix to be decomposed
  286. overwrite_a : bool, optional
  287. Whether data in a is overwritten (may improve performance)
  288. lwork : int, optional
  289. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  290. is computed.
  291. mode : {'full', 'r', 'economic'}, optional
  292. Determines what information is to be returned: either both Q and R
  293. ('full', default), only R ('r') or both Q and R but computed in
  294. economy-size ('economic', see Notes).
  295. check_finite : bool, optional
  296. Whether to check that the input matrix contains only finite numbers.
  297. Disabling may give a performance gain, but may result in problems
  298. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  299. Returns
  300. -------
  301. R : float or complex ndarray
  302. Of shape (M, N) or (M, K) for ``mode='economic'``. ``K = min(M, N)``.
  303. Q : float or complex ndarray
  304. Of shape (N, N) or (K, N) for ``mode='economic'``. Not returned
  305. if ``mode='r'``.
  306. Raises
  307. ------
  308. LinAlgError
  309. If decomposition fails.
  310. Notes
  311. -----
  312. This is an interface to the LAPACK routines sgerqf, dgerqf, cgerqf, zgerqf,
  313. sorgrq, dorgrq, cungrq and zungrq.
  314. If ``mode=economic``, the shapes of Q and R are (K, N) and (M, K) instead
  315. of (N,N) and (M,N), with ``K=min(M,N)``.
  316. Examples
  317. --------
  318. >>> from scipy import linalg
  319. >>> a = np.random.randn(6, 9)
  320. >>> r, q = linalg.rq(a)
  321. >>> np.allclose(a, r @ q)
  322. True
  323. >>> r.shape, q.shape
  324. ((6, 9), (9, 9))
  325. >>> r2 = linalg.rq(a, mode='r')
  326. >>> np.allclose(r, r2)
  327. True
  328. >>> r3, q3 = linalg.rq(a, mode='economic')
  329. >>> r3.shape, q3.shape
  330. ((6, 6), (6, 9))
  331. """
  332. if mode not in ['full', 'r', 'economic']:
  333. raise ValueError(
  334. "Mode argument should be one of ['full', 'r', 'economic']")
  335. if check_finite:
  336. a1 = numpy.asarray_chkfinite(a)
  337. else:
  338. a1 = numpy.asarray(a)
  339. if len(a1.shape) != 2:
  340. raise ValueError('expected matrix')
  341. M, N = a1.shape
  342. overwrite_a = overwrite_a or (_datacopied(a1, a))
  343. gerqf, = get_lapack_funcs(('gerqf',), (a1,))
  344. rq, tau = safecall(gerqf, 'gerqf', a1, lwork=lwork,
  345. overwrite_a=overwrite_a)
  346. if not mode == 'economic' or N < M:
  347. R = numpy.triu(rq, N-M)
  348. else:
  349. R = numpy.triu(rq[-M:, -M:])
  350. if mode == 'r':
  351. return R
  352. gor_un_grq, = get_lapack_funcs(('orgrq',), (rq,))
  353. if N < M:
  354. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq[-N:], tau, lwork=lwork,
  355. overwrite_a=1)
  356. elif mode == 'economic':
  357. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq, tau, lwork=lwork,
  358. overwrite_a=1)
  359. else:
  360. rq1 = numpy.empty((N, N), dtype=rq.dtype)
  361. rq1[-M:] = rq
  362. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq1, tau, lwork=lwork,
  363. overwrite_a=1)
  364. return R, Q