# test_lsqr.py (4.2 KB)
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import (assert_, assert_equal, assert_almost_equal,
  4. assert_array_almost_equal)
  5. from scipy._lib.six import xrange
  6. import scipy.sparse
  7. import scipy.sparse.linalg
  8. from scipy.sparse.linalg import lsqr
  9. from time import time
  10. # Set up a test problem
  11. n = 35
  12. G = np.eye(n)
  13. normal = np.random.normal
  14. norm = np.linalg.norm
  15. for jj in xrange(5):
  16. gg = normal(size=n)
  17. hh = gg * gg.T
  18. G += (hh + hh.T) * 0.5
  19. G += normal(size=n) * normal(size=n)
  20. b = normal(size=n)
  21. tol = 1e-10
  22. show = False
  23. maxit = None
  24. def test_basic():
  25. b_copy = b.copy()
  26. X = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit)
  27. assert_(np.all(b_copy == b))
  28. svx = np.linalg.solve(G, b)
  29. xo = X[0]
  30. assert_(norm(svx - xo) < 1e-5)
  31. def test_gh_2466():
  32. row = np.array([0, 0])
  33. col = np.array([0, 1])
  34. val = np.array([1, -1])
  35. A = scipy.sparse.coo_matrix((val, (row, col)), shape=(1, 2))
  36. b = np.asarray([4])
  37. lsqr(A, b)
  38. def test_well_conditioned_problems():
  39. # Test that sparse the lsqr solver returns the right solution
  40. # on various problems with different random seeds.
  41. # This is a non-regression test for a potential ZeroDivisionError
  42. # raised when computing the `test2` & `test3` convergence conditions.
  43. n = 10
  44. A_sparse = scipy.sparse.eye(n, n)
  45. A_dense = A_sparse.toarray()
  46. with np.errstate(invalid='raise'):
  47. for seed in range(30):
  48. rng = np.random.RandomState(seed + 10)
  49. beta = rng.rand(n)
  50. beta[beta == 0] = 0.00001 # ensure that all the betas are not null
  51. b = A_sparse * beta[:, np.newaxis]
  52. output = lsqr(A_sparse, b, show=show)
  53. # Check that the termination condition corresponds to an approximate
  54. # solution to Ax = b
  55. assert_equal(output[1], 1)
  56. solution = output[0]
  57. # Check that we recover the ground truth solution
  58. assert_array_almost_equal(solution, beta)
  59. # Sanity check: compare to the dense array solver
  60. reference_solution = np.linalg.solve(A_dense, b).ravel()
  61. assert_array_almost_equal(solution, reference_solution)
  62. def test_b_shapes():
  63. # Test b being a scalar.
  64. A = np.array([[1.0, 2.0]])
  65. b = 3.0
  66. x = lsqr(A, b)[0]
  67. assert_almost_equal(norm(A.dot(x) - b), 0)
  68. # Test b being a column vector.
  69. A = np.eye(10)
  70. b = np.ones((10, 1))
  71. x = lsqr(A, b)[0]
  72. assert_almost_equal(norm(A.dot(x) - b.ravel()), 0)
  73. def test_initialization():
  74. # Test the default setting is the same as zeros
  75. b_copy = b.copy()
  76. x_ref = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit)
  77. x0 = np.zeros(x_ref[0].shape)
  78. x = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit, x0=x0)
  79. assert_(np.all(b_copy == b))
  80. assert_array_almost_equal(x_ref[0], x[0])
  81. # Test warm-start with single iteration
  82. x0 = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=1)[0]
  83. x = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit, x0=x0)
  84. assert_array_almost_equal(x_ref[0], x[0])
  85. assert_(np.all(b_copy == b))
  86. if __name__ == "__main__":
  87. svx = np.linalg.solve(G, b)
  88. tic = time()
  89. X = lsqr(G, b, show=show, atol=tol, btol=tol, iter_lim=maxit)
  90. xo = X[0]
  91. phio = X[3]
  92. psio = X[7]
  93. k = X[2]
  94. chio = X[8]
  95. mg = np.amax(G - G.T)
  96. if mg > 1e-14:
  97. sym = 'No'
  98. else:
  99. sym = 'Yes'
  100. print('LSQR')
  101. print("Is linear operator symmetric? " + sym)
  102. print("n: %3g iterations: %3g" % (n, k))
  103. print("Norms computed in %.2fs by LSQR" % (time() - tic))
  104. print(" ||x|| %9.4e ||r|| %9.4e ||Ar|| %9.4e " % (chio, phio, psio))
  105. print("Residual norms computed directly:")
  106. print(" ||x|| %9.4e ||r|| %9.4e ||Ar|| %9.4e" % (norm(xo),
  107. norm(G*xo - b),
  108. norm(G.T*(G*xo-b))))
  109. print("Direct solution norms:")
  110. print(" ||x|| %9.4e ||r|| %9.4e " % (norm(svx), norm(G*svx - b)))
  111. print("")
  112. print(" || x_{direct} - x_{LSQR}|| %9.4e " % norm(svx-xo))
  113. print("")