test_lgmres.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. """Tests for the linalg.isolve.lgmres module
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. from numpy.testing import assert_, assert_allclose, assert_equal
  5. import pytest
  6. from platform import python_implementation
  7. import numpy as np
  8. from numpy import zeros, array, allclose
  9. from scipy.linalg import norm
  10. from scipy.sparse import csr_matrix, eye, rand
  11. from scipy.sparse.linalg.interface import LinearOperator
  12. from scipy.sparse.linalg import splu
  13. from scipy.sparse.linalg.isolve import lgmres, gmres
  14. from scipy._lib._numpy_compat import suppress_warnings
  15. Am = csr_matrix(array([[-2, 1, 0, 0, 0, 9],
  16. [1, -2, 1, 0, 5, 0],
  17. [0, 1, -2, 1, 0, 0],
  18. [0, 0, 1, -2, 1, 0],
  19. [0, 3, 0, 1, -2, 1],
  20. [1, 0, 0, 0, 1, -2]]))
  21. b = array([1, 2, 3, 4, 5, 6])
  22. count = [0]
  23. def matvec(v):
  24. count[0] += 1
  25. return Am*v
  26. A = LinearOperator(matvec=matvec, shape=Am.shape, dtype=Am.dtype)
  27. def do_solve(**kw):
  28. count[0] = 0
  29. with suppress_warnings() as sup:
  30. sup.filter(DeprecationWarning, ".*called without specifying.*")
  31. x0, flag = lgmres(A, b, x0=zeros(A.shape[0]),
  32. inner_m=6, tol=1e-14, **kw)
  33. count_0 = count[0]
  34. assert_(allclose(A*x0, b, rtol=1e-12, atol=1e-12), norm(A*x0-b))
  35. return x0, count_0
  36. class TestLGMRES(object):
  37. def test_preconditioner(self):
  38. # Check that preconditioning works
  39. pc = splu(Am.tocsc())
  40. M = LinearOperator(matvec=pc.solve, shape=A.shape, dtype=A.dtype)
  41. x0, count_0 = do_solve()
  42. x1, count_1 = do_solve(M=M)
  43. assert_(count_1 == 3)
  44. assert_(count_1 < count_0/2)
  45. assert_(allclose(x1, x0, rtol=1e-14))
  46. def test_outer_v(self):
  47. # Check that the augmentation vectors behave as expected
  48. outer_v = []
  49. x0, count_0 = do_solve(outer_k=6, outer_v=outer_v)
  50. assert_(len(outer_v) > 0)
  51. assert_(len(outer_v) <= 6)
  52. x1, count_1 = do_solve(outer_k=6, outer_v=outer_v,
  53. prepend_outer_v=True)
  54. assert_(count_1 == 2, count_1)
  55. assert_(count_1 < count_0/2)
  56. assert_(allclose(x1, x0, rtol=1e-14))
  57. # ---
  58. outer_v = []
  59. x0, count_0 = do_solve(outer_k=6, outer_v=outer_v,
  60. store_outer_Av=False)
  61. assert_(array([v[1] is None for v in outer_v]).all())
  62. assert_(len(outer_v) > 0)
  63. assert_(len(outer_v) <= 6)
  64. x1, count_1 = do_solve(outer_k=6, outer_v=outer_v,
  65. prepend_outer_v=True)
  66. assert_(count_1 == 3, count_1)
  67. assert_(count_1 < count_0/2)
  68. assert_(allclose(x1, x0, rtol=1e-14))
  69. @pytest.mark.skipif(python_implementation() == 'PyPy',
  70. reason="Fails on PyPy CI runs. See #9507")
  71. def test_arnoldi(self):
  72. np.random.rand(1234)
  73. A = eye(2000) + rand(2000, 2000, density=5e-4)
  74. b = np.random.rand(2000)
  75. # The inner arnoldi should be equivalent to gmres
  76. with suppress_warnings() as sup:
  77. sup.filter(DeprecationWarning, ".*called without specifying.*")
  78. x0, flag0 = lgmres(A, b, x0=zeros(A.shape[0]),
  79. inner_m=15, maxiter=1)
  80. x1, flag1 = gmres(A, b, x0=zeros(A.shape[0]),
  81. restart=15, maxiter=1)
  82. assert_equal(flag0, 1)
  83. assert_equal(flag1, 1)
  84. assert_(np.linalg.norm(A.dot(x0) - b) > 4e-4)
  85. assert_allclose(x0, x1)
  86. def test_cornercase(self):
  87. np.random.seed(1234)
  88. # Rounding error may prevent convergence with tol=0 --- ensure
  89. # that the return values in this case are correct, and no
  90. # exceptions are raised
  91. for n in [3, 5, 10, 100]:
  92. A = 2*eye(n)
  93. with suppress_warnings() as sup:
  94. sup.filter(DeprecationWarning, ".*called without specifying.*")
  95. b = np.ones(n)
  96. x, info = lgmres(A, b, maxiter=10)
  97. assert_equal(info, 0)
  98. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  99. x, info = lgmres(A, b, tol=0, maxiter=10)
  100. if info == 0:
  101. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  102. b = np.random.rand(n)
  103. x, info = lgmres(A, b, maxiter=10)
  104. assert_equal(info, 0)
  105. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  106. x, info = lgmres(A, b, tol=0, maxiter=10)
  107. if info == 0:
  108. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  109. def test_nans(self):
  110. A = eye(3, format='lil')
  111. A[1, 1] = np.nan
  112. b = np.ones(3)
  113. with suppress_warnings() as sup:
  114. sup.filter(DeprecationWarning, ".*called without specifying.*")
  115. x, info = lgmres(A, b, tol=0, maxiter=10)
  116. assert_equal(info, 1)
  117. def test_breakdown_with_outer_v(self):
  118. A = np.array([[1, 2], [3, 4]], dtype=float)
  119. b = np.array([1, 2])
  120. x = np.linalg.solve(A, b)
  121. v0 = np.array([1, 0])
  122. # The inner iteration should converge to the correct solution,
  123. # since it's in the outer vector list
  124. with suppress_warnings() as sup:
  125. sup.filter(DeprecationWarning, ".*called without specifying.*")
  126. xp, info = lgmres(A, b, outer_v=[(v0, None), (x, None)], maxiter=1)
  127. assert_allclose(xp, x, atol=1e-12)
  128. def test_breakdown_underdetermined(self):
  129. # Should find LSQ solution in the Krylov span in one inner
  130. # iteration, despite solver breakdown from nilpotent A.
  131. A = np.array([[0, 1, 1, 1],
  132. [0, 0, 1, 1],
  133. [0, 0, 0, 1],
  134. [0, 0, 0, 0]], dtype=float)
  135. bs = [
  136. np.array([1, 1, 1, 1]),
  137. np.array([1, 1, 1, 0]),
  138. np.array([1, 1, 0, 0]),
  139. np.array([1, 0, 0, 0]),
  140. ]
  141. for b in bs:
  142. with suppress_warnings() as sup:
  143. sup.filter(DeprecationWarning, ".*called without specifying.*")
  144. xp, info = lgmres(A, b, maxiter=1)
  145. resp = np.linalg.norm(A.dot(xp) - b)
  146. K = np.c_[b, A.dot(b), A.dot(A.dot(b)), A.dot(A.dot(A.dot(b)))]
  147. y, _, _, _ = np.linalg.lstsq(A.dot(K), b, rcond=-1)
  148. x = K.dot(y)
  149. res = np.linalg.norm(A.dot(x) - b)
  150. assert_allclose(resp, res, err_msg=repr(b))
  151. def test_denormals(self):
  152. # Check that no warnings are emitted if the matrix contains
  153. # numbers for which 1/x has no float representation, and that
  154. # the solver behaves properly.
  155. A = np.array([[1, 2], [3, 4]], dtype=float)
  156. A *= 100 * np.nextafter(0, 1)
  157. b = np.array([1, 1])
  158. with suppress_warnings() as sup:
  159. sup.filter(DeprecationWarning, ".*called without specifying.*")
  160. xp, info = lgmres(A, b)
  161. if info == 0:
  162. assert_allclose(A.dot(xp), b)