"""
Unit tests for the root-finding routines in _root.py (scipy.optimize.root).
"""
from __future__ import division, print_function, absolute_import

import numpy as np
from numpy.testing import assert_
from pytest import raises as assert_raises

from scipy.optimize import root
  9. class TestRoot(object):
  10. def test_tol_parameter(self):
  11. # Check that the minimize() tol= argument does something
  12. def func(z):
  13. x, y = z
  14. return np.array([x**3 - 1, y**3 - 1])
  15. def dfunc(z):
  16. x, y = z
  17. return np.array([[3*x**2, 0], [0, 3*y**2]])
  18. for method in ['hybr', 'lm', 'broyden1', 'broyden2', 'anderson',
  19. 'diagbroyden', 'krylov']:
  20. if method in ('linearmixing', 'excitingmixing'):
  21. # doesn't converge
  22. continue
  23. if method in ('hybr', 'lm'):
  24. jac = dfunc
  25. else:
  26. jac = None
  27. sol1 = root(func, [1.1,1.1], jac=jac, tol=1e-4, method=method)
  28. sol2 = root(func, [1.1,1.1], jac=jac, tol=0.5, method=method)
  29. msg = "%s: %s vs. %s" % (method, func(sol1.x), func(sol2.x))
  30. assert_(sol1.success, msg)
  31. assert_(sol2.success, msg)
  32. assert_(abs(func(sol1.x)).max() < abs(func(sol2.x)).max(),
  33. msg)
  34. def test_minimize_scalar_coerce_args_param(self):
  35. # github issue #3503
  36. def func(z, f=1):
  37. x, y = z
  38. return np.array([x**3 - 1, y**3 - f])
  39. root(func, [1.1, 1.1], args=1.5)
  40. def test_f_size(self):
  41. # gh8320
  42. # check that decreasing the size of the returned array raises an error
  43. # and doesn't segfault
  44. class fun(object):
  45. def __init__(self):
  46. self.count = 0
  47. def __call__(self, x):
  48. self.count += 1
  49. if not (self.count % 5):
  50. ret = x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0
  51. else:
  52. ret = ([x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0,
  53. 0.5 * (x[1] - x[0]) ** 3 + x[1]])
  54. return ret
  55. F = fun()
  56. with assert_raises(ValueError):
  57. sol = root(F, [0.1, 0.0], method='lm')