lsmr.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. """
  2. Copyright (C) 2010 David Fong and Michael Saunders
  3. LSMR uses an iterative method.
  4. 07 Jun 2010: Documentation updated
  5. 03 Jun 2010: First release version in Python
  6. David Chin-lung Fong clfong@stanford.edu
  7. Institute for Computational and Mathematical Engineering
  8. Stanford University
  9. Michael Saunders saunders@stanford.edu
  10. Systems Optimization Laboratory
  11. Dept of MS&E, Stanford University.
  12. """
  13. from __future__ import division, print_function, absolute_import
  14. __all__ = ['lsmr']
  15. from numpy import zeros, infty, atleast_1d
  16. from numpy.linalg import norm
  17. from math import sqrt
  18. from scipy.sparse.linalg.interface import aslinearoperator
  19. from .lsqr import _sym_ortho
  20. def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
  21. maxiter=None, show=False, x0=None):
  22. """Iterative solver for least-squares problems.
  23. lsmr solves the system of linear equations ``Ax = b``. If the system
  24. is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
  25. A is a rectangular matrix of dimension m-by-n, where all cases are
  26. allowed: m = n, m > n, or m < n. B is a vector of length m.
  27. The matrix A may be dense or sparse (usually sparse).
  28. Parameters
  29. ----------
  30. A : {matrix, sparse matrix, ndarray, LinearOperator}
  31. Matrix A in the linear system.
  32. b : array_like, shape (m,)
  33. Vector b in the linear system.
  34. damp : float
  35. Damping factor for regularized least-squares. `lsmr` solves
  36. the regularized least-squares problem::
  37. min ||(b) - ( A )x||
  38. ||(0) (damp*I) ||_2
  39. where damp is a scalar. If damp is None or 0, the system
  40. is solved without regularization.
  41. atol, btol : float, optional
  42. Stopping tolerances. `lsmr` continues iterations until a
  43. certain backward error estimate is smaller than some quantity
  44. depending on atol and btol. Let ``r = b - Ax`` be the
  45. residual vector for the current approximate solution ``x``.
  46. If ``Ax = b`` seems to be consistent, ``lsmr`` terminates
  47. when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
  48. Otherwise, lsmr terminates when ``norm(A^{T} r) <=
  49. atol * norm(A) * norm(r)``. If both tolerances are 1.0e-6 (say),
  50. the final ``norm(r)`` should be accurate to about 6
  51. digits. (The final x will usually have fewer correct digits,
  52. depending on ``cond(A)`` and the size of LAMBDA.) If `atol`
  53. or `btol` is None, a default value of 1.0e-6 will be used.
  54. Ideally, they should be estimates of the relative error in the
  55. entries of A and B respectively. For example, if the entries
  56. of `A` have 7 correct digits, set atol = 1e-7. This prevents
  57. the algorithm from doing unnecessary work beyond the
  58. uncertainty of the input data.
  59. conlim : float, optional
  60. `lsmr` terminates if an estimate of ``cond(A)`` exceeds
  61. `conlim`. For compatible systems ``Ax = b``, conlim could be
  62. as large as 1.0e+12 (say). For least-squares problems,
  63. `conlim` should be less than 1.0e+8. If `conlim` is None, the
  64. default value is 1e+8. Maximum precision can be obtained by
  65. setting ``atol = btol = conlim = 0``, but the number of
  66. iterations may then be excessive.
  67. maxiter : int, optional
  68. `lsmr` terminates if the number of iterations reaches
  69. `maxiter`. The default is ``maxiter = min(m, n)``. For
  70. ill-conditioned systems, a larger value of `maxiter` may be
  71. needed.
  72. show : bool, optional
  73. Print iterations logs if ``show=True``.
  74. x0 : array_like, shape (n,), optional
  75. Initial guess of x, if None zeros are used.
  76. .. versionadded:: 1.0.0
  77. Returns
  78. -------
  79. x : ndarray of float
  80. Least-square solution returned.
  81. istop : int
  82. istop gives the reason for stopping::
  83. istop = 0 means x=0 is a solution. If x0 was given, then x=x0 is a
  84. solution.
  85. = 1 means x is an approximate solution to A*x = B,
  86. according to atol and btol.
  87. = 2 means x approximately solves the least-squares problem
  88. according to atol.
  89. = 3 means COND(A) seems to be greater than CONLIM.
  90. = 4 is the same as 1 with atol = btol = eps (machine
  91. precision)
  92. = 5 is the same as 2 with atol = eps.
  93. = 6 is the same as 3 with CONLIM = 1/eps.
  94. = 7 means ITN reached maxiter before the other stopping
  95. conditions were satisfied.
  96. itn : int
  97. Number of iterations used.
  98. normr : float
  99. ``norm(b-Ax)``
  100. normar : float
  101. ``norm(A^T (b - Ax))``
  102. norma : float
  103. ``norm(A)``
  104. conda : float
  105. Condition number of A.
  106. normx : float
  107. ``norm(x)``
  108. Notes
  109. -----
  110. .. versionadded:: 0.11.0
  111. References
  112. ----------
  113. .. [1] D. C.-L. Fong and M. A. Saunders,
  114. "LSMR: An iterative algorithm for sparse least-squares problems",
  115. SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
  116. https://arxiv.org/abs/1006.0758
  117. .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/
  118. Examples
  119. --------
  120. >>> from scipy.sparse import csc_matrix
  121. >>> from scipy.sparse.linalg import lsmr
  122. >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)
  123. The first example has the trivial solution `[0, 0]`
  124. >>> b = np.array([0., 0., 0.], dtype=float)
  125. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  126. >>> istop
  127. 0
  128. >>> x
  129. array([ 0., 0.])
  130. The stopping code `istop=0` returned indicates that a vector of zeros was
  131. found as a solution. The returned solution `x` indeed contains `[0., 0.]`.
  132. The next example has a non-trivial solution:
  133. >>> b = np.array([1., 0., -1.], dtype=float)
  134. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  135. >>> istop
  136. 1
  137. >>> x
  138. array([ 1., -1.])
  139. >>> itn
  140. 1
  141. >>> normr
  142. 4.440892098500627e-16
  143. As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
  144. limits. The given solution `[1., -1.]` obviously solves the equation. The
  145. remaining return values include information about the number of iterations
  146. (`itn=1`) and the remaining difference of left and right side of the solved
  147. equation.
  148. The final example demonstrates the behavior in the case where there is no
  149. solution for the equation:
  150. >>> b = np.array([1., 0.01, -1.], dtype=float)
  151. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  152. >>> istop
  153. 2
  154. >>> x
  155. array([ 1.00333333, -0.99666667])
  156. >>> A.dot(x)-b
  157. array([ 0.00333333, -0.00333333, 0.00333333])
  158. >>> normr
  159. 0.005773502691896255
  160. `istop` indicates that the system is inconsistent and thus `x` is rather an
  161. approximate solution to the corresponding least-squares problem. `normr`
  162. contains the minimal distance that was found.
  163. """
  164. A = aslinearoperator(A)
  165. b = atleast_1d(b)
  166. if b.ndim > 1:
  167. b = b.squeeze()
  168. msg = ('The exact solution is x = 0, or x = x0, if x0 was given ',
  169. 'Ax - b is small enough, given atol, btol ',
  170. 'The least-squares solution is good enough, given atol ',
  171. 'The estimate of cond(Abar) has exceeded conlim ',
  172. 'Ax - b is small enough for this machine ',
  173. 'The least-squares solution is good enough for this machine',
  174. 'Cond(Abar) seems to be too large for this machine ',
  175. 'The iteration limit has been reached ')
  176. hdg1 = ' itn x(1) norm r norm A''r'
  177. hdg2 = ' compatible LS norm A cond A'
  178. pfreq = 20 # print frequency (for repeating the heading)
  179. pcount = 0 # print counter
  180. m, n = A.shape
  181. # stores the num of singular values
  182. minDim = min([m, n])
  183. if maxiter is None:
  184. maxiter = minDim
  185. if show:
  186. print(' ')
  187. print('LSMR Least-squares solution of Ax = b\n')
  188. print('The matrix A has %8g rows and %8g cols' % (m, n))
  189. print('damp = %20.14e\n' % (damp))
  190. print('atol = %8.2e conlim = %8.2e\n' % (atol, conlim))
  191. print('btol = %8.2e maxiter = %8g\n' % (btol, maxiter))
  192. u = b
  193. normb = norm(b)
  194. if x0 is None:
  195. x = zeros(n)
  196. beta = normb.copy()
  197. else:
  198. x = atleast_1d(x0)
  199. u = u - A.matvec(x)
  200. beta = norm(u)
  201. if beta > 0:
  202. u = (1 / beta) * u
  203. v = A.rmatvec(u)
  204. alpha = norm(v)
  205. else:
  206. v = zeros(n)
  207. alpha = 0
  208. if alpha > 0:
  209. v = (1 / alpha) * v
  210. # Initialize variables for 1st iteration.
  211. itn = 0
  212. zetabar = alpha * beta
  213. alphabar = alpha
  214. rho = 1
  215. rhobar = 1
  216. cbar = 1
  217. sbar = 0
  218. h = v.copy()
  219. hbar = zeros(n)
  220. # Initialize variables for estimation of ||r||.
  221. betadd = beta
  222. betad = 0
  223. rhodold = 1
  224. tautildeold = 0
  225. thetatilde = 0
  226. zeta = 0
  227. d = 0
  228. # Initialize variables for estimation of ||A|| and cond(A)
  229. normA2 = alpha * alpha
  230. maxrbar = 0
  231. minrbar = 1e+100
  232. normA = sqrt(normA2)
  233. condA = 1
  234. normx = 0
  235. # Items for use in stopping rules, normb set earlier
  236. istop = 0
  237. ctol = 0
  238. if conlim > 0:
  239. ctol = 1 / conlim
  240. normr = beta
  241. # Reverse the order here from the original matlab code because
  242. # there was an error on return when arnorm==0
  243. normar = alpha * beta
  244. if normar == 0:
  245. if show:
  246. print(msg[0])
  247. return x, istop, itn, normr, normar, normA, condA, normx
  248. if show:
  249. print(' ')
  250. print(hdg1, hdg2)
  251. test1 = 1
  252. test2 = alpha / beta
  253. str1 = '%6g %12.5e' % (itn, x[0])
  254. str2 = ' %10.3e %10.3e' % (normr, normar)
  255. str3 = ' %8.1e %8.1e' % (test1, test2)
  256. print(''.join([str1, str2, str3]))
  257. # Main iteration loop.
  258. while itn < maxiter:
  259. itn = itn + 1
  260. # Perform the next step of the bidiagonalization to obtain the
  261. # next beta, u, alpha, v. These satisfy the relations
  262. # beta*u = a*v - alpha*u,
  263. # alpha*v = A'*u - beta*v.
  264. u = A.matvec(v) - alpha * u
  265. beta = norm(u)
  266. if beta > 0:
  267. u = (1 / beta) * u
  268. v = A.rmatvec(u) - beta * v
  269. alpha = norm(v)
  270. if alpha > 0:
  271. v = (1 / alpha) * v
  272. # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
  273. # Construct rotation Qhat_{k,2k+1}.
  274. chat, shat, alphahat = _sym_ortho(alphabar, damp)
  275. # Use a plane rotation (Q_i) to turn B_i to R_i
  276. rhoold = rho
  277. c, s, rho = _sym_ortho(alphahat, beta)
  278. thetanew = s*alpha
  279. alphabar = c*alpha
  280. # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
  281. rhobarold = rhobar
  282. zetaold = zeta
  283. thetabar = sbar * rho
  284. rhotemp = cbar * rho
  285. cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
  286. zeta = cbar * zetabar
  287. zetabar = - sbar * zetabar
  288. # Update h, h_hat, x.
  289. hbar = h - (thetabar * rho / (rhoold * rhobarold)) * hbar
  290. x = x + (zeta / (rho * rhobar)) * hbar
  291. h = v - (thetanew / rho) * h
  292. # Estimate of ||r||.
  293. # Apply rotation Qhat_{k,2k+1}.
  294. betaacute = chat * betadd
  295. betacheck = -shat * betadd
  296. # Apply rotation Q_{k,k+1}.
  297. betahat = c * betaacute
  298. betadd = -s * betaacute
  299. # Apply rotation Qtilde_{k-1}.
  300. # betad = betad_{k-1} here.
  301. thetatildeold = thetatilde
  302. ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
  303. thetatilde = stildeold * rhobar
  304. rhodold = ctildeold * rhobar
  305. betad = - stildeold * betad + ctildeold * betahat
  306. # betad = betad_k here.
  307. # rhodold = rhod_k here.
  308. tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
  309. taud = (zeta - thetatilde * tautildeold) / rhodold
  310. d = d + betacheck * betacheck
  311. normr = sqrt(d + (betad - taud)**2 + betadd * betadd)
  312. # Estimate ||A||.
  313. normA2 = normA2 + beta * beta
  314. normA = sqrt(normA2)
  315. normA2 = normA2 + alpha * alpha
  316. # Estimate cond(A).
  317. maxrbar = max(maxrbar, rhobarold)
  318. if itn > 1:
  319. minrbar = min(minrbar, rhobarold)
  320. condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)
  321. # Test for convergence.
  322. # Compute norms for convergence testing.
  323. normar = abs(zetabar)
  324. normx = norm(x)
  325. # Now use these norms to estimate certain other quantities,
  326. # some of which will be small near a solution.
  327. test1 = normr / normb
  328. if (normA * normr) != 0:
  329. test2 = normar / (normA * normr)
  330. else:
  331. test2 = infty
  332. test3 = 1 / condA
  333. t1 = test1 / (1 + normA * normx / normb)
  334. rtol = btol + atol * normA * normx / normb
  335. # The following tests guard against extremely small values of
  336. # atol, btol or ctol. (The user may have set any or all of
  337. # the parameters atol, btol, conlim to 0.)
  338. # The effect is equivalent to the normAl tests using
  339. # atol = eps, btol = eps, conlim = 1/eps.
  340. if itn >= maxiter:
  341. istop = 7
  342. if 1 + test3 <= 1:
  343. istop = 6
  344. if 1 + test2 <= 1:
  345. istop = 5
  346. if 1 + t1 <= 1:
  347. istop = 4
  348. # Allow for tolerances set by the user.
  349. if test3 <= ctol:
  350. istop = 3
  351. if test2 <= atol:
  352. istop = 2
  353. if test1 <= rtol:
  354. istop = 1
  355. # See if it is time to print something.
  356. if show:
  357. if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
  358. (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
  359. (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
  360. (istop != 0):
  361. if pcount >= pfreq:
  362. pcount = 0
  363. print(' ')
  364. print(hdg1, hdg2)
  365. pcount = pcount + 1
  366. str1 = '%6g %12.5e' % (itn, x[0])
  367. str2 = ' %10.3e %10.3e' % (normr, normar)
  368. str3 = ' %8.1e %8.1e' % (test1, test2)
  369. str4 = ' %8.1e %8.1e' % (normA, condA)
  370. print(''.join([str1, str2, str3, str4]))
  371. if istop > 0:
  372. break
  373. # Print the stopping condition.
  374. if show:
  375. print(' ')
  376. print('LSMR finished')
  377. print(msg[istop])
  378. print('istop =%8g normr =%8.1e' % (istop, normr))
  379. print(' normA =%8.1e normAr =%8.1e' % (normA, normar))
  380. print('itn =%8g condA =%8.1e' % (itn, condA))
  381. print(' normx =%8.1e' % (normx))
  382. print(str1, str2)
  383. print(str3, str4)
  384. return x, istop, itn, normr, normar, normA, condA, normx