# test_minres.py -- tests for the minres iterative solver
from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import assert_equal, assert_allclose, assert_
from scipy.sparse.linalg.isolve import minres
import pytest
from pytest import raises as assert_raises

from .test_iterative import assert_normclose
  8. def get_sample_problem():
  9. # A random 10 x 10 symmetric matrix
  10. np.random.seed(1234)
  11. matrix = np.random.rand(10, 10)
  12. matrix = matrix + matrix.T
  13. # A random vector of length 10
  14. vector = np.random.rand(10)
  15. return matrix, vector
  16. def test_singular():
  17. A, b = get_sample_problem()
  18. A[0, ] = 0
  19. b[0] = 0
  20. xp, info = minres(A, b)
  21. assert_equal(info, 0)
  22. assert_normclose(A.dot(xp), b, tol=1e-5)
  23. @pytest.mark.skip(reason="Skip Until gh #6843 is fixed")
  24. def test_gh_6843():
  25. """check if x0 is being used by tracing iterates"""
  26. A, b = get_sample_problem()
  27. # Random x0 to feed minres
  28. np.random.seed(12345)
  29. x0 = np.random.rand(10)
  30. trace = []
  31. def trace_iterates(xk):
  32. trace.append(xk)
  33. minres(A, b, x0=x0, callback=trace_iterates)
  34. trace_with_x0 = trace
  35. trace = []
  36. minres(A, b, callback=trace_iterates)
  37. assert_(not np.array_equal(trace_with_x0[0], trace[0]))
  38. def test_shift():
  39. A, b = get_sample_problem()
  40. shift = 0.5
  41. shifted_A = A - shift * np.eye(10)
  42. x1, info1 = minres(A, b, shift=shift)
  43. x2, info2 = minres(shifted_A, b)
  44. assert_equal(info1, 0)
  45. assert_allclose(x1, x2, rtol=1e-5)
  46. def test_asymmetric_fail():
  47. """Asymmetric matrix should raise `ValueError` when check=True"""
  48. A, b = get_sample_problem()
  49. A[1, 2] = 1
  50. A[2, 1] = 2
  51. with assert_raises(ValueError):
  52. xp, info = minres(A, b, check=True)