decomp_schur.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. """Schur decomposition functions."""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy
  4. from numpy import asarray_chkfinite, single, asarray, array
  5. from numpy.linalg import norm
  6. from scipy._lib.six import callable
  7. # Local imports.
  8. from .misc import LinAlgError, _datacopied
  9. from .lapack import get_lapack_funcs
  10. from .decomp import eigvals
  11. __all__ = ['schur', 'rsf2csf']
  12. _double_precision = ['i', 'l', 'd']
  13. def schur(a, output='real', lwork=None, overwrite_a=False, sort=None,
  14. check_finite=True):
  15. """
  16. Compute Schur decomposition of a matrix.
  17. The Schur decomposition is::
  18. A = Z T Z^H
  19. where Z is unitary and T is either upper-triangular, or for real
  20. Schur decomposition (output='real'), quasi-upper triangular. In
  21. the quasi-triangular form, 2x2 blocks describing complex-valued
  22. eigenvalue pairs may extrude from the diagonal.
  23. Parameters
  24. ----------
  25. a : (M, M) array_like
  26. Matrix to decompose
  27. output : {'real', 'complex'}, optional
  28. Construct the real or complex Schur decomposition (for real matrices).
  29. lwork : int, optional
  30. Work array size. If None or -1, it is automatically computed.
  31. overwrite_a : bool, optional
  32. Whether to overwrite data in a (may improve performance).
  33. sort : {None, callable, 'lhp', 'rhp', 'iuc', 'ouc'}, optional
  34. Specifies whether the upper eigenvalues should be sorted. A callable
  35. may be passed that, given a eigenvalue, returns a boolean denoting
  36. whether the eigenvalue should be sorted to the top-left (True).
  37. Alternatively, string parameters may be used::
  38. 'lhp' Left-hand plane (x.real < 0.0)
  39. 'rhp' Right-hand plane (x.real > 0.0)
  40. 'iuc' Inside the unit circle (x*x.conjugate() <= 1.0)
  41. 'ouc' Outside the unit circle (x*x.conjugate() > 1.0)
  42. Defaults to None (no sorting).
  43. check_finite : bool, optional
  44. Whether to check that the input matrix contains only finite numbers.
  45. Disabling may give a performance gain, but may result in problems
  46. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  47. Returns
  48. -------
  49. T : (M, M) ndarray
  50. Schur form of A. It is real-valued for the real Schur decomposition.
  51. Z : (M, M) ndarray
  52. An unitary Schur transformation matrix for A.
  53. It is real-valued for the real Schur decomposition.
  54. sdim : int
  55. If and only if sorting was requested, a third return value will
  56. contain the number of eigenvalues satisfying the sort condition.
  57. Raises
  58. ------
  59. LinAlgError
  60. Error raised under three conditions:
  61. 1. The algorithm failed due to a failure of the QR algorithm to
  62. compute all eigenvalues
  63. 2. If eigenvalue sorting was requested, the eigenvalues could not be
  64. reordered due to a failure to separate eigenvalues, usually because
  65. of poor conditioning
  66. 3. If eigenvalue sorting was requested, roundoff errors caused the
  67. leading eigenvalues to no longer satisfy the sorting condition
  68. See also
  69. --------
  70. rsf2csf : Convert real Schur form to complex Schur form
  71. Examples
  72. --------
  73. >>> from scipy.linalg import schur, eigvals
  74. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  75. >>> T, Z = schur(A)
  76. >>> T
  77. array([[ 2.65896708, 1.42440458, -1.92933439],
  78. [ 0. , -0.32948354, -0.49063704],
  79. [ 0. , 1.31178921, -0.32948354]])
  80. >>> Z
  81. array([[0.72711591, -0.60156188, 0.33079564],
  82. [0.52839428, 0.79801892, 0.28976765],
  83. [0.43829436, 0.03590414, -0.89811411]])
  84. >>> T2, Z2 = schur(A, output='complex')
  85. >>> T2
  86. array([[ 2.65896708, -1.22839825+1.32378589j, 0.42590089+1.51937378j],
  87. [ 0. , -0.32948354+0.80225456j, -0.59877807+0.56192146j],
  88. [ 0. , 0. , -0.32948354-0.80225456j]])
  89. >>> eigvals(T2)
  90. array([2.65896708, -0.32948354+0.80225456j, -0.32948354-0.80225456j])
  91. An arbitrary custom eig-sorting condition, having positive imaginary part,
  92. which is satisfied by only one eigenvalue
  93. >>> T3, Z3, sdim = schur(A, output='complex', sort=lambda x: x.imag > 0)
  94. >>> sdim
  95. 1
  96. """
  97. if output not in ['real', 'complex', 'r', 'c']:
  98. raise ValueError("argument must be 'real', or 'complex'")
  99. if check_finite:
  100. a1 = asarray_chkfinite(a)
  101. else:
  102. a1 = asarray(a)
  103. if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
  104. raise ValueError('expected square matrix')
  105. typ = a1.dtype.char
  106. if output in ['complex', 'c'] and typ not in ['F', 'D']:
  107. if typ in _double_precision:
  108. a1 = a1.astype('D')
  109. typ = 'D'
  110. else:
  111. a1 = a1.astype('F')
  112. typ = 'F'
  113. overwrite_a = overwrite_a or (_datacopied(a1, a))
  114. gees, = get_lapack_funcs(('gees',), (a1,))
  115. if lwork is None or lwork == -1:
  116. # get optimal work array
  117. result = gees(lambda x: None, a1, lwork=-1)
  118. lwork = result[-2][0].real.astype(numpy.int)
  119. if sort is None:
  120. sort_t = 0
  121. sfunction = lambda x: None
  122. else:
  123. sort_t = 1
  124. if callable(sort):
  125. sfunction = sort
  126. elif sort == 'lhp':
  127. sfunction = lambda x: (x.real < 0.0)
  128. elif sort == 'rhp':
  129. sfunction = lambda x: (x.real >= 0.0)
  130. elif sort == 'iuc':
  131. sfunction = lambda x: (abs(x) <= 1.0)
  132. elif sort == 'ouc':
  133. sfunction = lambda x: (abs(x) > 1.0)
  134. else:
  135. raise ValueError("'sort' parameter must either be 'None', or a "
  136. "callable, or one of ('lhp','rhp','iuc','ouc')")
  137. result = gees(sfunction, a1, lwork=lwork, overwrite_a=overwrite_a,
  138. sort_t=sort_t)
  139. info = result[-1]
  140. if info < 0:
  141. raise ValueError('illegal value in {}-th argument of internal gees'
  142. ''.format(-info))
  143. elif info == a1.shape[0] + 1:
  144. raise LinAlgError('Eigenvalues could not be separated for reordering.')
  145. elif info == a1.shape[0] + 2:
  146. raise LinAlgError('Leading eigenvalues do not satisfy sort condition.')
  147. elif info > 0:
  148. raise LinAlgError("Schur form not found. Possibly ill-conditioned.")
  149. if sort_t == 0:
  150. return result[0], result[-3]
  151. else:
  152. return result[0], result[-3], result[1]
  153. eps = numpy.finfo(float).eps
  154. feps = numpy.finfo(single).eps
  155. _array_kind = {'b': 0, 'h': 0, 'B': 0, 'i': 0, 'l': 0,
  156. 'f': 0, 'd': 0, 'F': 1, 'D': 1}
  157. _array_precision = {'i': 1, 'l': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1}
  158. _array_type = [['f', 'd'], ['F', 'D']]
  159. def _commonType(*arrays):
  160. kind = 0
  161. precision = 0
  162. for a in arrays:
  163. t = a.dtype.char
  164. kind = max(kind, _array_kind[t])
  165. precision = max(precision, _array_precision[t])
  166. return _array_type[kind][precision]
  167. def _castCopy(type, *arrays):
  168. cast_arrays = ()
  169. for a in arrays:
  170. if a.dtype.char == type:
  171. cast_arrays = cast_arrays + (a.copy(),)
  172. else:
  173. cast_arrays = cast_arrays + (a.astype(type),)
  174. if len(cast_arrays) == 1:
  175. return cast_arrays[0]
  176. else:
  177. return cast_arrays
  178. def rsf2csf(T, Z, check_finite=True):
  179. """
  180. Convert real Schur form to complex Schur form.
  181. Convert a quasi-diagonal real-valued Schur form to the upper triangular
  182. complex-valued Schur form.
  183. Parameters
  184. ----------
  185. T : (M, M) array_like
  186. Real Schur form of the original array
  187. Z : (M, M) array_like
  188. Schur transformation matrix
  189. check_finite : bool, optional
  190. Whether to check that the input arrays contain only finite numbers.
  191. Disabling may give a performance gain, but may result in problems
  192. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  193. Returns
  194. -------
  195. T : (M, M) ndarray
  196. Complex Schur form of the original array
  197. Z : (M, M) ndarray
  198. Schur transformation matrix corresponding to the complex form
  199. See Also
  200. --------
  201. schur : Schur decomposition of an array
  202. Examples
  203. --------
  204. >>> from scipy.linalg import schur, rsf2csf
  205. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  206. >>> T, Z = schur(A)
  207. >>> T
  208. array([[ 2.65896708, 1.42440458, -1.92933439],
  209. [ 0. , -0.32948354, -0.49063704],
  210. [ 0. , 1.31178921, -0.32948354]])
  211. >>> Z
  212. array([[0.72711591, -0.60156188, 0.33079564],
  213. [0.52839428, 0.79801892, 0.28976765],
  214. [0.43829436, 0.03590414, -0.89811411]])
  215. >>> T2 , Z2 = rsf2csf(T, Z)
  216. >>> T2
  217. array([[2.65896708+0.j, -1.64592781+0.743164187j, -1.21516887+1.00660462j],
  218. [0.+0.j , -0.32948354+8.02254558e-01j, -0.82115218-2.77555756e-17j],
  219. [0.+0.j , 0.+0.j, -0.32948354-0.802254558j]])
  220. >>> Z2
  221. array([[0.72711591+0.j, 0.28220393-0.31385693j, 0.51319638-0.17258824j],
  222. [0.52839428+0.j, 0.24720268+0.41635578j, -0.68079517-0.15118243j],
  223. [0.43829436+0.j, -0.76618703+0.01873251j, -0.03063006+0.46857912j]])
  224. """
  225. if check_finite:
  226. Z, T = map(asarray_chkfinite, (Z, T))
  227. else:
  228. Z, T = map(asarray, (Z, T))
  229. for ind, X in enumerate([Z, T]):
  230. if X.ndim != 2 or X.shape[0] != X.shape[1]:
  231. raise ValueError("Input '{}' must be square.".format('ZT'[ind]))
  232. if T.shape[0] != Z.shape[0]:
  233. raise ValueError("Input array shapes must match: Z: {} vs. T: {}"
  234. "".format(Z.shape, T.shape))
  235. N = T.shape[0]
  236. t = _commonType(Z, T, array([3.0], 'F'))
  237. Z, T = _castCopy(t, Z, T)
  238. for m in range(N-1, 0, -1):
  239. if abs(T[m, m-1]) > eps*(abs(T[m-1, m-1]) + abs(T[m, m])):
  240. mu = eigvals(T[m-1:m+1, m-1:m+1]) - T[m, m]
  241. r = norm([mu[0], T[m, m-1]])
  242. c = mu[0] / r
  243. s = T[m, m-1] / r
  244. G = array([[c.conj(), s], [-s, c]], dtype=t)
  245. T[m-1:m+1, m-1:] = G.dot(T[m-1:m+1, m-1:])
  246. T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1].dot(G.conj().T)
  247. Z[:, m-1:m+1] = Z[:, m-1:m+1].dot(G.conj().T)
  248. T[m, m-1] = 0.0
  249. return T, Z