test_iterative.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. """ Test functions for the sparse.linalg.isolve module
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. import itertools
  5. import numpy as np
  6. from numpy.testing import (assert_equal, assert_array_equal,
  7. assert_, assert_allclose)
  8. import pytest
  9. from pytest import raises as assert_raises
  10. from scipy._lib._numpy_compat import suppress_warnings
  11. from numpy import zeros, arange, array, ones, eye, iscomplexobj
  12. from scipy.linalg import norm
  13. from scipy.sparse import spdiags, csr_matrix, SparseEfficiencyWarning
  14. from scipy.sparse.linalg import LinearOperator, aslinearoperator
  15. from scipy.sparse.linalg.isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk
  16. # TODO check that method preserve shape and type
  17. # TODO test both preconditioner methods
  18. class Case(object):
  19. def __init__(self, name, A, b=None, skip=None, nonconvergence=None):
  20. self.name = name
  21. self.A = A
  22. if b is None:
  23. self.b = arange(A.shape[0], dtype=float)
  24. else:
  25. self.b = b
  26. if skip is None:
  27. self.skip = []
  28. else:
  29. self.skip = skip
  30. if nonconvergence is None:
  31. self.nonconvergence = []
  32. else:
  33. self.nonconvergence = nonconvergence
  34. def __repr__(self):
  35. return "<%s>" % self.name
  36. class IterativeParams(object):
  37. def __init__(self):
  38. # list of tuples (solver, symmetric, positive_definite )
  39. solvers = [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk]
  40. sym_solvers = [minres, cg]
  41. posdef_solvers = [cg]
  42. real_solvers = [minres]
  43. self.solvers = solvers
  44. # list of tuples (A, symmetric, positive_definite )
  45. self.cases = []
  46. # Symmetric and Positive Definite
  47. N = 40
  48. data = ones((3,N))
  49. data[0,:] = 2
  50. data[1,:] = -1
  51. data[2,:] = -1
  52. Poisson1D = spdiags(data, [0,-1,1], N, N, format='csr')
  53. self.Poisson1D = Case("poisson1d", Poisson1D)
  54. self.cases.append(Case("poisson1d", Poisson1D))
  55. # note: minres fails for single precision
  56. self.cases.append(Case("poisson1d", Poisson1D.astype('f'),
  57. skip=[minres]))
  58. # Symmetric and Negative Definite
  59. self.cases.append(Case("neg-poisson1d", -Poisson1D,
  60. skip=posdef_solvers))
  61. # note: minres fails for single precision
  62. self.cases.append(Case("neg-poisson1d", (-Poisson1D).astype('f'),
  63. skip=posdef_solvers + [minres]))
  64. # Symmetric and Indefinite
  65. data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
  66. RandDiag = spdiags(data, [0], 10, 10, format='csr')
  67. self.cases.append(Case("rand-diag", RandDiag, skip=posdef_solvers))
  68. self.cases.append(Case("rand-diag", RandDiag.astype('f'),
  69. skip=posdef_solvers))
  70. # Random real-valued
  71. np.random.seed(1234)
  72. data = np.random.rand(4, 4)
  73. self.cases.append(Case("rand", data, skip=posdef_solvers+sym_solvers))
  74. self.cases.append(Case("rand", data.astype('f'),
  75. skip=posdef_solvers+sym_solvers))
  76. # Random symmetric real-valued
  77. np.random.seed(1234)
  78. data = np.random.rand(4, 4)
  79. data = data + data.T
  80. self.cases.append(Case("rand-sym", data, skip=posdef_solvers))
  81. self.cases.append(Case("rand-sym", data.astype('f'),
  82. skip=posdef_solvers))
  83. # Random pos-def symmetric real
  84. np.random.seed(1234)
  85. data = np.random.rand(9, 9)
  86. data = np.dot(data.conj(), data.T)
  87. self.cases.append(Case("rand-sym-pd", data))
  88. # note: minres fails for single precision
  89. self.cases.append(Case("rand-sym-pd", data.astype('f'),
  90. skip=[minres]))
  91. # Random complex-valued
  92. np.random.seed(1234)
  93. data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
  94. self.cases.append(Case("rand-cmplx", data,
  95. skip=posdef_solvers+sym_solvers+real_solvers))
  96. self.cases.append(Case("rand-cmplx", data.astype('F'),
  97. skip=posdef_solvers+sym_solvers+real_solvers))
  98. # Random hermitian complex-valued
  99. np.random.seed(1234)
  100. data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
  101. data = data + data.T.conj()
  102. self.cases.append(Case("rand-cmplx-herm", data,
  103. skip=posdef_solvers+real_solvers))
  104. self.cases.append(Case("rand-cmplx-herm", data.astype('F'),
  105. skip=posdef_solvers+real_solvers))
  106. # Random pos-def hermitian complex-valued
  107. np.random.seed(1234)
  108. data = np.random.rand(9, 9) + 1j*np.random.rand(9, 9)
  109. data = np.dot(data.conj(), data.T)
  110. self.cases.append(Case("rand-cmplx-sym-pd", data, skip=real_solvers))
  111. self.cases.append(Case("rand-cmplx-sym-pd", data.astype('F'),
  112. skip=real_solvers))
  113. # Non-symmetric and Positive Definite
  114. #
  115. # cgs, qmr, and bicg fail to converge on this one
  116. # -- algorithmic limitation apparently
  117. data = ones((2,10))
  118. data[0,:] = 2
  119. data[1,:] = -1
  120. A = spdiags(data, [0,-1], 10, 10, format='csr')
  121. self.cases.append(Case("nonsymposdef", A,
  122. skip=sym_solvers+[cgs, qmr, bicg]))
  123. self.cases.append(Case("nonsymposdef", A.astype('F'),
  124. skip=sym_solvers+[cgs, qmr, bicg]))
  125. # Symmetric, non-pd, hitting cgs/bicg/bicgstab/qmr breakdown
  126. A = np.array([[0, 0, 0, 0, 0, 1, -1, -0, -0, -0, -0],
  127. [0, 0, 0, 0, 0, 2, -0, -1, -0, -0, -0],
  128. [0, 0, 0, 0, 0, 2, -0, -0, -1, -0, -0],
  129. [0, 0, 0, 0, 0, 2, -0, -0, -0, -1, -0],
  130. [0, 0, 0, 0, 0, 1, -0, -0, -0, -0, -1],
  131. [1, 2, 2, 2, 1, 0, -0, -0, -0, -0, -0],
  132. [-1, 0, 0, 0, 0, 0, -1, -0, -0, -0, -0],
  133. [0, -1, 0, 0, 0, 0, -0, -1, -0, -0, -0],
  134. [0, 0, -1, 0, 0, 0, -0, -0, -1, -0, -0],
  135. [0, 0, 0, -1, 0, 0, -0, -0, -0, -1, -0],
  136. [0, 0, 0, 0, -1, 0, -0, -0, -0, -0, -1]], dtype=float)
  137. b = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], dtype=float)
  138. assert (A == A.T).all()
  139. self.cases.append(Case("sym-nonpd", A, b,
  140. skip=posdef_solvers,
  141. nonconvergence=[cgs,bicg,bicgstab,qmr]))
  142. params = IterativeParams()
  143. def check_maxiter(solver, case):
  144. A = case.A
  145. tol = 1e-12
  146. b = case.b
  147. x0 = 0*b
  148. residuals = []
  149. def callback(x):
  150. residuals.append(norm(b - case.A*x))
  151. x, info = solver(A, b, x0=x0, tol=tol, maxiter=1, callback=callback)
  152. assert_equal(len(residuals), 1)
  153. assert_equal(info, 1)
  154. def test_maxiter():
  155. case = params.Poisson1D
  156. for solver in params.solvers:
  157. if solver in case.skip:
  158. continue
  159. with suppress_warnings() as sup:
  160. sup.filter(DeprecationWarning, ".*called without specifying.*")
  161. check_maxiter(solver, case)
  162. def assert_normclose(a, b, tol=1e-8):
  163. residual = norm(a - b)
  164. tolerance = tol*norm(b)
  165. msg = "residual (%g) not smaller than tolerance %g" % (residual, tolerance)
  166. assert_(residual < tolerance, msg=msg)
  167. def check_convergence(solver, case):
  168. A = case.A
  169. if A.dtype.char in "dD":
  170. tol = 1e-8
  171. else:
  172. tol = 1e-2
  173. b = case.b
  174. x0 = 0*b
  175. x, info = solver(A, b, x0=x0, tol=tol)
  176. assert_array_equal(x0, 0*b) # ensure that x0 is not overwritten
  177. if solver not in case.nonconvergence:
  178. assert_equal(info,0)
  179. assert_normclose(A.dot(x), b, tol=tol)
  180. else:
  181. assert_(info != 0)
  182. assert_(np.linalg.norm(A.dot(x) - b) <= np.linalg.norm(b))
  183. def test_convergence():
  184. for solver in params.solvers:
  185. for case in params.cases:
  186. if solver in case.skip:
  187. continue
  188. with suppress_warnings() as sup:
  189. sup.filter(DeprecationWarning, ".*called without specifying.*")
  190. check_convergence(solver, case)
  191. def check_precond_dummy(solver, case):
  192. tol = 1e-8
  193. def identity(b,which=None):
  194. """trivial preconditioner"""
  195. return b
  196. A = case.A
  197. M,N = A.shape
  198. D = spdiags([1.0/A.diagonal()], [0], M, N)
  199. b = case.b
  200. x0 = 0*b
  201. precond = LinearOperator(A.shape, identity, rmatvec=identity)
  202. if solver is qmr:
  203. x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
  204. else:
  205. x, info = solver(A, b, M=precond, x0=x0, tol=tol)
  206. assert_equal(info,0)
  207. assert_normclose(A.dot(x), b, tol)
  208. A = aslinearoperator(A)
  209. A.psolve = identity
  210. A.rpsolve = identity
  211. x, info = solver(A, b, x0=x0, tol=tol)
  212. assert_equal(info,0)
  213. assert_normclose(A*x, b, tol=tol)
  214. def test_precond_dummy():
  215. case = params.Poisson1D
  216. for solver in params.solvers:
  217. if solver in case.skip:
  218. continue
  219. with suppress_warnings() as sup:
  220. sup.filter(DeprecationWarning, ".*called without specifying.*")
  221. check_precond_dummy(solver, case)
  222. def check_precond_inverse(solver, case):
  223. tol = 1e-8
  224. def inverse(b,which=None):
  225. """inverse preconditioner"""
  226. A = case.A
  227. if not isinstance(A, np.ndarray):
  228. A = A.todense()
  229. return np.linalg.solve(A, b)
  230. def rinverse(b,which=None):
  231. """inverse preconditioner"""
  232. A = case.A
  233. if not isinstance(A, np.ndarray):
  234. A = A.todense()
  235. return np.linalg.solve(A.T, b)
  236. matvec_count = [0]
  237. def matvec(b):
  238. matvec_count[0] += 1
  239. return case.A.dot(b)
  240. def rmatvec(b):
  241. matvec_count[0] += 1
  242. return case.A.T.dot(b)
  243. b = case.b
  244. x0 = 0*b
  245. A = LinearOperator(case.A.shape, matvec, rmatvec=rmatvec)
  246. precond = LinearOperator(case.A.shape, inverse, rmatvec=rinverse)
  247. # Solve with preconditioner
  248. matvec_count = [0]
  249. x, info = solver(A, b, M=precond, x0=x0, tol=tol)
  250. assert_equal(info, 0)
  251. assert_normclose(case.A.dot(x), b, tol)
  252. # Solution should be nearly instant
  253. assert_(matvec_count[0] <= 3, repr(matvec_count))
  254. def test_precond_inverse():
  255. case = params.Poisson1D
  256. for solver in params.solvers:
  257. if solver in case.skip:
  258. continue
  259. if solver is qmr:
  260. continue
  261. with suppress_warnings() as sup:
  262. sup.filter(DeprecationWarning, ".*called without specifying.*")
  263. check_precond_inverse(solver, case)
  264. def test_gmres_basic():
  265. A = np.vander(np.arange(10) + 1)[:, ::-1]
  266. b = np.zeros(10)
  267. b[0] = 1
  268. x = np.linalg.solve(A, b)
  269. with suppress_warnings() as sup:
  270. sup.filter(DeprecationWarning, ".*called without specifying.*")
  271. x_gm, err = gmres(A, b, restart=5, maxiter=1)
  272. assert_allclose(x_gm[0], 0.359, rtol=1e-2)
  273. def test_reentrancy():
  274. non_reentrant = [cg, cgs, bicg, bicgstab, gmres, qmr]
  275. reentrant = [lgmres, minres, gcrotmk]
  276. for solver in reentrant + non_reentrant:
  277. with suppress_warnings() as sup:
  278. sup.filter(DeprecationWarning, ".*called without specifying.*")
  279. _check_reentrancy(solver, solver in reentrant)
  280. def _check_reentrancy(solver, is_reentrant):
  281. def matvec(x):
  282. A = np.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]])
  283. y, info = solver(A, x)
  284. assert_equal(info, 0)
  285. return y
  286. b = np.array([1, 1./2, 1./3])
  287. op = LinearOperator((3, 3), matvec=matvec, rmatvec=matvec,
  288. dtype=b.dtype)
  289. if not is_reentrant:
  290. assert_raises(RuntimeError, solver, op, b)
  291. else:
  292. y, info = solver(op, b)
  293. assert_equal(info, 0)
  294. assert_allclose(y, [1, 1, 1])
  295. @pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr, lgmres, gcrotmk])
  296. def test_atol(solver):
  297. # TODO: minres. It didn't historically use absolute tolerances, so
  298. # fixing it is less urgent.
  299. np.random.seed(1234)
  300. A = np.random.rand(10, 10)
  301. A = A.dot(A.T) + 10 * np.eye(10)
  302. b = 1e3 * np.random.rand(10)
  303. b_norm = np.linalg.norm(b)
  304. tols = np.r_[0, np.logspace(np.log10(1e-10), np.log10(1e2), 7), np.inf]
  305. # Check effect of badly scaled preconditioners
  306. M0 = np.random.randn(10, 10)
  307. M0 = M0.dot(M0.T)
  308. Ms = [None, 1e-6 * M0, 1e6 * M0]
  309. for M, tol, atol in itertools.product(Ms, tols, tols):
  310. if tol == 0 and atol == 0:
  311. continue
  312. if solver is qmr:
  313. if M is not None:
  314. M = aslinearoperator(M)
  315. M2 = aslinearoperator(np.eye(10))
  316. else:
  317. M2 = None
  318. x, info = solver(A, b, M1=M, M2=M2, tol=tol, atol=atol)
  319. else:
  320. x, info = solver(A, b, M=M, tol=tol, atol=atol)
  321. assert_equal(info, 0)
  322. residual = A.dot(x) - b
  323. err = np.linalg.norm(residual)
  324. atol2 = tol * b_norm
  325. assert_(err <= max(atol, atol2))
  326. @pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk])
  327. def test_zero_rhs(solver):
  328. np.random.seed(1234)
  329. A = np.random.rand(10, 10)
  330. A = A.dot(A.T) + 10 * np.eye(10)
  331. b = np.zeros(10)
  332. tols = np.r_[np.logspace(np.log10(1e-10), np.log10(1e2), 7)]
  333. for tol in tols:
  334. with suppress_warnings() as sup:
  335. sup.filter(DeprecationWarning, ".*called without specifying.*")
  336. x, info = solver(A, b, tol=tol)
  337. assert_equal(info, 0)
  338. assert_allclose(x, 0, atol=1e-15)
  339. x, info = solver(A, b, tol=tol, x0=ones(10))
  340. assert_equal(info, 0)
  341. assert_allclose(x, 0, atol=tol)
  342. if solver is not minres:
  343. x, info = solver(A, b, tol=tol, atol=0, x0=ones(10))
  344. if info == 0:
  345. assert_allclose(x, 0)
  346. x, info = solver(A, b, tol=tol, atol=tol)
  347. assert_equal(info, 0)
  348. assert_allclose(x, 0, atol=1e-300)
  349. x, info = solver(A, b, tol=tol, atol=0)
  350. assert_equal(info, 0)
  351. assert_allclose(x, 0, atol=1e-300)
  352. @pytest.mark.parametrize("solver", [
  353. gmres, qmr, lgmres,
  354. pytest.param(cgs, marks=pytest.mark.xfail),
  355. pytest.param(bicg, marks=pytest.mark.xfail),
  356. pytest.param(bicgstab, marks=pytest.mark.xfail),
  357. pytest.param(gcrotmk, marks=pytest.mark.xfail)])
  358. def test_maxiter_worsening(solver):
  359. # Check error does not grow (boundlessly) with increasing maxiter.
  360. # This can occur due to the solvers hitting close to breakdown,
  361. # which they should detect and halt as necessary.
  362. # cf. gh-9100
  363. # Singular matrix, rhs numerically not in range
  364. A = np.array([[-0.1112795288033378, 0, 0, 0.16127952880333685],
  365. [0, -0.13627952880333782+6.283185307179586j, 0, 0],
  366. [0, 0, -0.13627952880333782-6.283185307179586j, 0],
  367. [0.1112795288033368, 0j, 0j, -0.16127952880333785]])
  368. v = np.ones(4)
  369. best_error = np.inf
  370. for maxiter in range(1, 20):
  371. x, info = solver(A, v, maxiter=maxiter, tol=1e-8, atol=0)
  372. if info == 0:
  373. assert_(np.linalg.norm(A.dot(x) - v) <= 1e-8*np.linalg.norm(v))
  374. error = np.linalg.norm(A.dot(x) - v)
  375. best_error = min(best_error, error)
  376. # Check with slack
  377. assert_(error <= 5*best_error)
  378. #------------------------------------------------------------------------------
  379. class TestQMR(object):
  380. def test_leftright_precond(self):
  381. """Check that QMR works with left and right preconditioners"""
  382. from scipy.sparse.linalg.dsolve import splu
  383. from scipy.sparse.linalg.interface import LinearOperator
  384. n = 100
  385. dat = ones(n)
  386. A = spdiags([-2*dat, 4*dat, -dat], [-1,0,1],n,n)
  387. b = arange(n,dtype='d')
  388. L = spdiags([-dat/2, dat], [-1,0], n, n)
  389. U = spdiags([4*dat, -dat], [0,1], n, n)
  390. with suppress_warnings() as sup:
  391. sup.filter(SparseEfficiencyWarning, "splu requires CSC matrix format")
  392. L_solver = splu(L)
  393. U_solver = splu(U)
  394. def L_solve(b):
  395. return L_solver.solve(b)
  396. def U_solve(b):
  397. return U_solver.solve(b)
  398. def LT_solve(b):
  399. return L_solver.solve(b,'T')
  400. def UT_solve(b):
  401. return U_solver.solve(b,'T')
  402. M1 = LinearOperator((n,n), matvec=L_solve, rmatvec=LT_solve)
  403. M2 = LinearOperator((n,n), matvec=U_solve, rmatvec=UT_solve)
  404. with suppress_warnings() as sup:
  405. sup.filter(DeprecationWarning, ".*called without specifying.*")
  406. x,info = qmr(A, b, tol=1e-8, maxiter=15, M1=M1, M2=M2)
  407. assert_equal(info,0)
  408. assert_normclose(A*x, b, tol=1e-8)
  409. class TestGMRES(object):
  410. def test_callback(self):
  411. def store_residual(r, rvec):
  412. rvec[rvec.nonzero()[0].max()+1] = r
  413. # Define, A,b
  414. A = csr_matrix(array([[-2,1,0,0,0,0],[1,-2,1,0,0,0],[0,1,-2,1,0,0],[0,0,1,-2,1,0],[0,0,0,1,-2,1],[0,0,0,0,1,-2]]))
  415. b = ones((A.shape[0],))
  416. maxiter = 1
  417. rvec = zeros(maxiter+1)
  418. rvec[0] = 1.0
  419. callback = lambda r:store_residual(r, rvec)
  420. with suppress_warnings() as sup:
  421. sup.filter(DeprecationWarning, ".*called without specifying.*")
  422. x,flag = gmres(A, b, x0=zeros(A.shape[0]), tol=1e-16, maxiter=maxiter, callback=callback)
  423. # Expected output from Scipy 1.0.0
  424. assert_allclose(rvec, array([1.0, 0.81649658092772603]), rtol=1e-10)
  425. # Test preconditioned callback
  426. M = 1e-3 * np.eye(A.shape[0])
  427. rvec = zeros(maxiter+1)
  428. rvec[0] = 1.0
  429. with suppress_warnings() as sup:
  430. sup.filter(DeprecationWarning, ".*called without specifying.*")
  431. x, flag = gmres(A, b, M=M, tol=1e-16, maxiter=maxiter, callback=callback)
  432. # Expected output from Scipy 1.0.0 (callback has preconditioned residual!)
  433. assert_allclose(rvec, array([1.0, 1e-3 * 0.81649658092772603]), rtol=1e-10)
  434. def test_abi(self):
  435. # Check we don't segfault on gmres with complex argument
  436. A = eye(2)
  437. b = ones(2)
  438. with suppress_warnings() as sup:
  439. sup.filter(DeprecationWarning, ".*called without specifying.*")
  440. r_x, r_info = gmres(A, b)
  441. r_x = r_x.astype(complex)
  442. x, info = gmres(A.astype(complex), b.astype(complex))
  443. assert_(iscomplexobj(x))
  444. assert_allclose(r_x, x)
  445. assert_(r_info == info)
  446. def test_atol_legacy(self):
  447. with suppress_warnings() as sup:
  448. sup.filter(DeprecationWarning, ".*called without specifying.*")
  449. # Check the strange legacy behavior: the tolerance is interpreted
  450. # as atol, but only for the initial residual
  451. A = eye(2)
  452. b = 1e-6 * ones(2)
  453. x, info = gmres(A, b, tol=1e-5)
  454. assert_array_equal(x, np.zeros(2))
  455. A = eye(2)
  456. b = ones(2)
  457. x, info = gmres(A, b, tol=1e-5)
  458. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-5*np.linalg.norm(b))
  459. assert_allclose(x, b, atol=0, rtol=1e-8)
  460. rndm = np.random.RandomState(12345)
  461. A = rndm.rand(30, 30)
  462. b = 1e-6 * ones(30)
  463. x, info = gmres(A, b, tol=1e-7, restart=20)
  464. assert_(np.linalg.norm(A.dot(x) - b) > 1e-7)
  465. A = eye(2)
  466. b = 1e-10 * ones(2)
  467. x, info = gmres(A, b, tol=1e-8, atol=0)
  468. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-8*np.linalg.norm(b))
  469. def test_defective_precond_breakdown(self):
  470. # Breakdown due to defective preconditioner
  471. M = np.eye(3)
  472. M[2,2] = 0
  473. b = np.array([0, 1, 1])
  474. x = np.array([1, 0, 0])
  475. A = np.diag([2, 3, 4])
  476. x, info = gmres(A, b, x0=x, M=M, tol=1e-15, atol=0)
  477. # Should not return nans, nor terminate with false success
  478. assert_(not np.isnan(x).any())
  479. if info == 0:
  480. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-15*np.linalg.norm(b))
  481. # The solution should be OK outside null space of M
  482. assert_allclose(M.dot(A.dot(x)), M.dot(b))
  483. def test_defective_matrix_breakdown(self):
  484. # Breakdown due to defective matrix
  485. A = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]])
  486. b = np.array([1, 0, 1])
  487. x, info = gmres(A, b, tol=1e-8, atol=0)
  488. # Should not return nans, nor terminate with false success
  489. assert_(not np.isnan(x).any())
  490. if info == 0:
  491. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-8*np.linalg.norm(b))
  492. # The solution should be OK outside null space of A
  493. assert_allclose(A.dot(A.dot(x)), A.dot(b))