minres.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from __future__ import division, print_function, absolute_import
  2. from numpy import sqrt, inner, zeros, inf, finfo
  3. from numpy.linalg import norm
  4. from .utils import make_system
  5. __all__ = ['minres']
  6. def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None,
  7. M=None, callback=None, show=False, check=False):
  8. """
  9. Use MINimum RESidual iteration to solve Ax=b
  10. MINRES minimizes norm(A*x - b) for a real symmetric matrix A. Unlike
  11. the Conjugate Gradient method, A can be indefinite or singular.
  12. If shift != 0 then the method solves (A - shift*I)x = b
  13. Parameters
  14. ----------
  15. A : {sparse matrix, dense matrix, LinearOperator}
  16. The real symmetric N-by-N matrix of the linear system
  17. b : {array, matrix}
  18. Right hand side of the linear system. Has shape (N,) or (N,1).
  19. Returns
  20. -------
  21. x : {array, matrix}
  22. The converged solution.
  23. info : integer
  24. Provides convergence information:
  25. 0 : successful exit
  26. >0 : convergence to tolerance not achieved, number of iterations
  27. <0 : illegal input or breakdown
  28. Other Parameters
  29. ----------------
  30. x0 : {array, matrix}
  31. Starting guess for the solution.
  32. tol : float
  33. Tolerance to achieve. The algorithm terminates when the relative
  34. residual is below `tol`.
  35. maxiter : integer
  36. Maximum number of iterations. Iteration will stop after maxiter
  37. steps even if the specified tolerance has not been achieved.
  38. M : {sparse matrix, dense matrix, LinearOperator}
  39. Preconditioner for A. The preconditioner should approximate the
  40. inverse of A. Effective preconditioning dramatically improves the
  41. rate of convergence, which implies that fewer iterations are needed
  42. to reach a given error tolerance.
  43. callback : function
  44. User-supplied function to call after each iteration. It is called
  45. as callback(xk), where xk is the current solution vector.
  46. References
  47. ----------
  48. Solution of sparse indefinite systems of linear equations,
  49. C. C. Paige and M. A. Saunders (1975),
  50. SIAM J. Numer. Anal. 12(4), pp. 617-629.
  51. https://web.stanford.edu/group/SOL/software/minres/
  52. This file is a translation of the following MATLAB implementation:
  53. https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip
  54. """
  55. A, M, x, b, postprocess = make_system(A, M, x0, b)
  56. matvec = A.matvec
  57. psolve = M.matvec
  58. first = 'Enter minres. '
  59. last = 'Exit minres. '
  60. n = A.shape[0]
  61. if maxiter is None:
  62. maxiter = 5 * n
  63. msg = [' beta2 = 0. If M = I, b and x are eigenvectors ', # -1
  64. ' beta1 = 0. The exact solution is x = 0 ', # 0
  65. ' A solution to Ax = b was found, given rtol ', # 1
  66. ' A least-squares solution was found, given rtol ', # 2
  67. ' Reasonable accuracy achieved, given eps ', # 3
  68. ' x has converged to an eigenvector ', # 4
  69. ' acond has exceeded 0.1/eps ', # 5
  70. ' The iteration limit was reached ', # 6
  71. ' A does not define a symmetric matrix ', # 7
  72. ' M does not define a symmetric matrix ', # 8
  73. ' M does not define a pos-def preconditioner '] # 9
  74. if show:
  75. print(first + 'Solution of symmetric Ax = b')
  76. print(first + 'n = %3g shift = %23.14e' % (n,shift))
  77. print(first + 'itnlim = %3g rtol = %11.2e' % (maxiter,tol))
  78. print()
  79. istop = 0
  80. itn = 0
  81. Anorm = 0
  82. Acond = 0
  83. rnorm = 0
  84. ynorm = 0
  85. xtype = x.dtype
  86. eps = finfo(xtype).eps
  87. x = zeros(n, dtype=xtype)
  88. # Set up y and v for the first Lanczos vector v1.
  89. # y = beta1 P' v1, where P = C**(-1).
  90. # v is really P' v1.
  91. y = b
  92. r1 = b
  93. y = psolve(b)
  94. beta1 = inner(b,y)
  95. if beta1 < 0:
  96. raise ValueError('indefinite preconditioner')
  97. elif beta1 == 0:
  98. return (postprocess(x), 0)
  99. beta1 = sqrt(beta1)
  100. if check:
  101. # are these too strict?
  102. # see if A is symmetric
  103. w = matvec(y)
  104. r2 = matvec(w)
  105. s = inner(w,w)
  106. t = inner(y,r2)
  107. z = abs(s - t)
  108. epsa = (s + eps) * eps**(1.0/3.0)
  109. if z > epsa:
  110. raise ValueError('non-symmetric matrix')
  111. # see if M is symmetric
  112. r2 = psolve(y)
  113. s = inner(y,y)
  114. t = inner(r1,r2)
  115. z = abs(s - t)
  116. epsa = (s + eps) * eps**(1.0/3.0)
  117. if z > epsa:
  118. raise ValueError('non-symmetric preconditioner')
  119. # Initialize other quantities
  120. oldb = 0
  121. beta = beta1
  122. dbar = 0
  123. epsln = 0
  124. qrnorm = beta1
  125. phibar = beta1
  126. rhs1 = beta1
  127. rhs2 = 0
  128. tnorm2 = 0
  129. gmax = 0
  130. gmin = finfo(xtype).max
  131. cs = -1
  132. sn = 0
  133. w = zeros(n, dtype=xtype)
  134. w2 = zeros(n, dtype=xtype)
  135. r2 = r1
  136. if show:
  137. print()
  138. print()
  139. print(' Itn x(1) Compatible LS norm(A) cond(A) gbar/|A|')
  140. while itn < maxiter:
  141. itn += 1
  142. s = 1.0/beta
  143. v = s*y
  144. y = matvec(v)
  145. y = y - shift * v
  146. if itn >= 2:
  147. y = y - (beta/oldb)*r1
  148. alfa = inner(v,y)
  149. y = y - (alfa/beta)*r2
  150. r1 = r2
  151. r2 = y
  152. y = psolve(r2)
  153. oldb = beta
  154. beta = inner(r2,y)
  155. if beta < 0:
  156. raise ValueError('non-symmetric matrix')
  157. beta = sqrt(beta)
  158. tnorm2 += alfa**2 + oldb**2 + beta**2
  159. if itn == 1:
  160. if beta/beta1 <= 10*eps:
  161. istop = -1 # Terminate later
  162. # Apply previous rotation Qk-1 to get
  163. # [deltak epslnk+1] = [cs sn][dbark 0 ]
  164. # [gbar k dbar k+1] [sn -cs][alfak betak+1].
  165. oldeps = epsln
  166. delta = cs * dbar + sn * alfa # delta1 = 0 deltak
  167. gbar = sn * dbar - cs * alfa # gbar 1 = alfa1 gbar k
  168. epsln = sn * beta # epsln2 = 0 epslnk+1
  169. dbar = - cs * beta # dbar 2 = beta2 dbar k+1
  170. root = norm([gbar, dbar])
  171. Arnorm = phibar * root
  172. # Compute the next plane rotation Qk
  173. gamma = norm([gbar, beta]) # gammak
  174. gamma = max(gamma, eps)
  175. cs = gbar / gamma # ck
  176. sn = beta / gamma # sk
  177. phi = cs * phibar # phik
  178. phibar = sn * phibar # phibark+1
  179. # Update x.
  180. denom = 1.0/gamma
  181. w1 = w2
  182. w2 = w
  183. w = (v - oldeps*w1 - delta*w2) * denom
  184. x = x + phi*w
  185. # Go round again.
  186. gmax = max(gmax, gamma)
  187. gmin = min(gmin, gamma)
  188. z = rhs1 / gamma
  189. rhs1 = rhs2 - delta*z
  190. rhs2 = - epsln*z
  191. # Estimate various norms and test for convergence.
  192. Anorm = sqrt(tnorm2)
  193. ynorm = norm(x)
  194. epsa = Anorm * eps
  195. epsx = Anorm * ynorm * eps
  196. epsr = Anorm * ynorm * tol
  197. diag = gbar
  198. if diag == 0:
  199. diag = epsa
  200. qrnorm = phibar
  201. rnorm = qrnorm
  202. if ynorm == 0 or Anorm == 0:
  203. test1 = inf
  204. else:
  205. test1 = rnorm / (Anorm*ynorm) # ||r|| / (||A|| ||x||)
  206. if Anorm == 0:
  207. test2 = inf
  208. else:
  209. test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
  210. # Estimate cond(A).
  211. # In this version we look at the diagonals of R in the
  212. # factorization of the lower Hessenberg matrix, Q * H = R,
  213. # where H is the tridiagonal matrix from Lanczos with one
  214. # extra row, beta(k+1) e_k^T.
  215. Acond = gmax/gmin
  216. # See if any of the stopping criteria are satisfied.
  217. # In rare cases, istop is already -1 from above (Abar = const*I).
  218. if istop == 0:
  219. t1 = 1 + test1 # These tests work if tol < eps
  220. t2 = 1 + test2
  221. if t2 <= 1:
  222. istop = 2
  223. if t1 <= 1:
  224. istop = 1
  225. if itn >= maxiter:
  226. istop = 6
  227. if Acond >= 0.1/eps:
  228. istop = 4
  229. if epsx >= beta1:
  230. istop = 3
  231. # if rnorm <= epsx : istop = 2
  232. # if rnorm <= epsr : istop = 1
  233. if test2 <= tol:
  234. istop = 2
  235. if test1 <= tol:
  236. istop = 1
  237. # See if it is time to print something.
  238. prnt = False
  239. if n <= 40:
  240. prnt = True
  241. if itn <= 10:
  242. prnt = True
  243. if itn >= maxiter-10:
  244. prnt = True
  245. if itn % 10 == 0:
  246. prnt = True
  247. if qrnorm <= 10*epsx:
  248. prnt = True
  249. if qrnorm <= 10*epsr:
  250. prnt = True
  251. if Acond <= 1e-2/eps:
  252. prnt = True
  253. if istop != 0:
  254. prnt = True
  255. if show and prnt:
  256. str1 = '%6g %12.5e %10.3e' % (itn, x[0], test1)
  257. str2 = ' %10.3e' % (test2,)
  258. str3 = ' %8.1e %8.1e %8.1e' % (Anorm, Acond, gbar/Anorm)
  259. print(str1 + str2 + str3)
  260. if itn % 10 == 0:
  261. print()
  262. if callback is not None:
  263. callback(x)
  264. if istop != 0:
  265. break # TODO check this
  266. if show:
  267. print()
  268. print(last + ' istop = %3g itn =%5g' % (istop,itn))
  269. print(last + ' Anorm = %12.4e Acond = %12.4e' % (Anorm,Acond))
  270. print(last + ' rnorm = %12.4e ynorm = %12.4e' % (rnorm,ynorm))
  271. print(last + ' Arnorm = %12.4e' % (Arnorm,))
  272. print(last + msg[istop+1])
  273. if istop == 6:
  274. info = maxiter
  275. else:
  276. info = 0
  277. return (postprocess(x),info)
  278. if __name__ == '__main__':
  279. from scipy import ones, arange
  280. from scipy.linalg import norm
  281. from scipy.sparse import spdiags
  282. n = 10
  283. residuals = []
  284. def cb(x):
  285. residuals.append(norm(b - A*x))
  286. # A = poisson((10,),format='csr')
  287. A = spdiags([arange(1,n+1,dtype=float)], [0], n, n, format='csr')
  288. M = spdiags([1.0/arange(1,n+1,dtype=float)], [0], n, n, format='csr')
  289. A.psolve = M.matvec
  290. b = 0*ones(A.shape[0])
  291. x = minres(A,b,tol=1e-12,maxiter=None,callback=cb)
  292. # x = cg(A,b,x0=b,tol=1e-12,maxiter=None,callback=cb)[0]