test_zeros.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
from __future__ import division, print_function, absolute_import

import warnings
from math import sqrt, exp, sin, cos

import pytest
import numpy as np
from numpy import finfo, power, nan, isclose
from numpy.testing import (assert_warns, assert_,
                           assert_allclose,
                           assert_equal,
                           assert_array_equal)

from scipy.optimize import zeros, newton, root_scalar
from scipy._lib._util import getargspec_no_self as _getargspec
# Import testing parameters
from scipy.optimize._tstutils import get_tests, functions as tstutils_functions, fstrings as tstutils_fstrings
from scipy._lib._numpy_compat import suppress_warnings
  15. TOL = 4*np.finfo(float).eps # tolerance
  16. _FLOAT_EPS = finfo(float).eps
  17. # A few test functions used frequently:
  18. # # A simple quadratic, (x-1)^2 - 1
  19. def f1(x):
  20. return x ** 2 - 2 * x - 1
  21. def f1_1(x):
  22. return 2 * x - 2
  23. def f1_2(x):
  24. return 2.0 + 0 * x
  25. def f1_and_p_and_pp(x):
  26. return f1(x), f1_1(x), f1_2(x)
  27. # Simple transcendental function
  28. def f2(x):
  29. return exp(x) - cos(x)
  30. def f2_1(x):
  31. return exp(x) + sin(x)
  32. def f2_2(x):
  33. return exp(x) + cos(x)
  34. class TestBasic(object):
  35. def run_check_by_name(self, name, smoothness=0, **kwargs):
  36. a = .5
  37. b = sqrt(3)
  38. xtol = 4*np.finfo(float).eps
  39. rtol = 4*np.finfo(float).eps
  40. for function, fname in zip(tstutils_functions, tstutils_fstrings):
  41. if smoothness > 0 and fname in ['f4', 'f5', 'f6']:
  42. continue
  43. r = root_scalar(function, method=name, bracket=[a, b], x0=a,
  44. xtol=xtol, rtol=rtol, **kwargs)
  45. zero = r.root
  46. assert_(r.converged)
  47. assert_allclose(zero, 1.0, atol=xtol, rtol=rtol,
  48. err_msg='method %s, function %s' % (name, fname))
  49. def run_check(self, method, name):
  50. a = .5
  51. b = sqrt(3)
  52. xtol = 4 * _FLOAT_EPS
  53. rtol = 4 * _FLOAT_EPS
  54. for function, fname in zip(tstutils_functions, tstutils_fstrings):
  55. zero, r = method(function, a, b, xtol=xtol, rtol=rtol,
  56. full_output=True)
  57. assert_(r.converged)
  58. assert_allclose(zero, 1.0, atol=xtol, rtol=rtol,
  59. err_msg='method %s, function %s' % (name, fname))
  60. def _run_one_test(self, tc, method, sig_args_keys=None,
  61. sig_kwargs_keys=None, **kwargs):
  62. method_args = []
  63. for k in sig_args_keys or []:
  64. if k not in tc:
  65. # If a,b not present use x0, x1. Similarly for f and func
  66. k = {'a': 'x0', 'b': 'x1', 'func': 'f'}.get(k, k)
  67. method_args.append(tc[k])
  68. method_kwargs = dict(**kwargs)
  69. method_kwargs.update({'full_output': True, 'disp': False})
  70. for k in sig_kwargs_keys or []:
  71. method_kwargs[k] = tc[k]
  72. root = tc.get('root')
  73. func_args = tc.get('args', ())
  74. try:
  75. r, rr = method(*method_args, args=func_args, **method_kwargs)
  76. return root, rr, tc
  77. except Exception:
  78. return root, zeros.RootResults(nan, -1, -1, zeros._EVALUEERR), tc
  79. def run_tests(self, tests, method, name,
  80. xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
  81. known_fail=None, **kwargs):
  82. r"""Run test-cases using the specified method and the supplied signature.
  83. Extract the arguments for the method call from the test case
  84. dictionary using the supplied keys for the method's signature."""
  85. # The methods have one of two base signatures:
  86. # (f, a, b, **kwargs) # newton
  87. # (func, x0, **kwargs) # bisect/brentq/...
  88. sig = _getargspec(method) # ArgSpec with args, varargs, varkw, defaults
  89. nDefaults = len(sig[3])
  90. nRequired = len(sig[0]) - nDefaults
  91. sig_args_keys = sig[0][:nRequired]
  92. sig_kwargs_keys = []
  93. if name in ['secant', 'newton', 'halley']:
  94. if name in ['newton', 'halley']:
  95. sig_kwargs_keys.append('fprime')
  96. if name in ['halley']:
  97. sig_kwargs_keys.append('fprime2')
  98. kwargs['tol'] = xtol
  99. else:
  100. kwargs['xtol'] = xtol
  101. kwargs['rtol'] = rtol
  102. results = [list(self._run_one_test(
  103. tc, method, sig_args_keys=sig_args_keys,
  104. sig_kwargs_keys=sig_kwargs_keys, **kwargs)) for tc in tests]
  105. # results= [[true root, full output, tc], ...]
  106. known_fail = known_fail or []
  107. notcvgd = [elt for elt in results if not elt[1].converged]
  108. notcvgd = [elt for elt in notcvgd if elt[-1]['ID'] not in known_fail]
  109. notcvged_IDS = [elt[-1]['ID'] for elt in notcvgd]
  110. assert_equal([len(notcvged_IDS), notcvged_IDS], [0, []])
  111. # The usable xtol and rtol depend on the test
  112. tols = {'xtol': 4 * _FLOAT_EPS, 'rtol': 4 * _FLOAT_EPS}
  113. tols.update(**kwargs)
  114. rtol = tols['rtol']
  115. atol = tols.get('tol', tols['xtol'])
  116. cvgd = [elt for elt in results if elt[1].converged]
  117. approx = [elt[1].root for elt in cvgd]
  118. correct = [elt[0] for elt in cvgd]
  119. notclose = [[a] + elt for a, c, elt in zip(approx, correct, cvgd) if
  120. not isclose(a, c, rtol=rtol, atol=atol)
  121. and elt[-1]['ID'] not in known_fail]
  122. # Evaluate the function and see if is 0 at the purported root
  123. fvs = [tc['f'](aroot, *(tc['args'])) for aroot, c, fullout, tc in notclose]
  124. notclose = [[fv] + elt for fv, elt in zip(fvs, notclose) if fv != 0]
  125. assert_equal([notclose, len(notclose)], [[], 0])
  126. def run_collection(self, collection, method, name, smoothness=None,
  127. known_fail=None,
  128. xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
  129. **kwargs):
  130. r"""Run a collection of tests using the specified method.
  131. The name is used to determine some optional arguments."""
  132. tests = get_tests(collection, smoothness=smoothness)
  133. self.run_tests(tests, method, name, xtol=xtol, rtol=rtol,
  134. known_fail=known_fail, **kwargs)
  135. def test_bisect(self):
  136. self.run_check(zeros.bisect, 'bisect')
  137. self.run_check_by_name('bisect')
  138. self.run_collection('aps', zeros.bisect, 'bisect', smoothness=1)
  139. def test_ridder(self):
  140. self.run_check(zeros.ridder, 'ridder')
  141. self.run_check_by_name('ridder')
  142. self.run_collection('aps', zeros.ridder, 'ridder', smoothness=1)
  143. def test_brentq(self):
  144. self.run_check(zeros.brentq, 'brentq')
  145. self.run_check_by_name('brentq')
  146. # Brentq/h needs a lower tolerance to be specified
  147. self.run_collection('aps', zeros.brentq, 'brentq', smoothness=1,
  148. xtol=1e-14, rtol=1e-14)
  149. def test_brenth(self):
  150. self.run_check(zeros.brenth, 'brenth')
  151. self.run_check_by_name('brenth')
  152. self.run_collection('aps', zeros.brenth, 'brenth', smoothness=1,
  153. xtol=1e-14, rtol=1e-14)
  154. def test_toms748(self):
  155. self.run_check(zeros.toms748, 'toms748')
  156. self.run_check_by_name('toms748')
  157. self.run_collection('aps', zeros.toms748, 'toms748', smoothness=1)
  158. def test_newton_collections(self):
  159. known_fail = ['aps.13.00']
  160. known_fail += ['aps.12.05', 'aps.12.17'] # fails under Windows Py27
  161. for collection in ['aps', 'complex']:
  162. self.run_collection(collection, zeros.newton, 'newton',
  163. smoothness=2, known_fail=known_fail)
  164. def test_halley_collections(self):
  165. known_fail = ['aps.12.06', 'aps.12.07', 'aps.12.08', 'aps.12.09',
  166. 'aps.12.10', 'aps.12.11', 'aps.12.12', 'aps.12.13',
  167. 'aps.12.14', 'aps.12.15', 'aps.12.16', 'aps.12.17',
  168. 'aps.12.18', 'aps.13.00']
  169. for collection in ['aps', 'complex']:
  170. self.run_collection(collection, zeros.newton, 'halley',
  171. smoothness=2, known_fail=known_fail)
  172. @staticmethod
  173. def f1(x):
  174. return x**2 - 2*x - 1 # == (x-1)**2 - 2
  175. @staticmethod
  176. def f1_1(x):
  177. return 2*x - 2
  178. @staticmethod
  179. def f1_2(x):
  180. return 2.0 + 0*x
  181. @staticmethod
  182. def f2(x):
  183. return exp(x) - cos(x)
  184. @staticmethod
  185. def f2_1(x):
  186. return exp(x) + sin(x)
  187. @staticmethod
  188. def f2_2(x):
  189. return exp(x) + cos(x)
  190. def test_newton(self):
  191. for f, f_1, f_2 in [(self.f1, self.f1_1, self.f1_2),
  192. (self.f2, self.f2_1, self.f2_2)]:
  193. x = zeros.newton(f, 3, tol=1e-6)
  194. assert_allclose(f(x), 0, atol=1e-6)
  195. x = zeros.newton(f, 3, x1=5, tol=1e-6) # secant, x0 and x1
  196. assert_allclose(f(x), 0, atol=1e-6)
  197. x = zeros.newton(f, 3, fprime=f_1, tol=1e-6) # newton
  198. assert_allclose(f(x), 0, atol=1e-6)
  199. x = zeros.newton(f, 3, fprime=f_1, fprime2=f_2, tol=1e-6) # halley
  200. assert_allclose(f(x), 0, atol=1e-6)
  201. def test_newton_by_name(self):
  202. r"""Invoke newton through root_scalar()"""
  203. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  204. r = root_scalar(f, method='newton', x0=3, fprime=f_1, xtol=1e-6)
  205. assert_allclose(f(r.root), 0, atol=1e-6)
  206. def test_secant_by_name(self):
  207. r"""Invoke secant through root_scalar()"""
  208. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  209. r = root_scalar(f, method='secant', x0=3, x1=2, xtol=1e-6)
  210. assert_allclose(f(r.root), 0, atol=1e-6)
  211. r = root_scalar(f, method='secant', x0=3, x1=5, xtol=1e-6)
  212. assert_allclose(f(r.root), 0, atol=1e-6)
  213. def test_halley_by_name(self):
  214. r"""Invoke halley through root_scalar()"""
  215. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  216. r = root_scalar(f, method='halley', x0=3,
  217. fprime=f_1, fprime2=f_2, xtol=1e-6)
  218. assert_allclose(f(r.root), 0, atol=1e-6)
  219. def test_root_scalar_fail(self):
  220. with pytest.raises(ValueError):
  221. root_scalar(f1, method='secant', x0=3, xtol=1e-6) # no x1
  222. with pytest.raises(ValueError):
  223. root_scalar(f1, method='newton', x0=3, xtol=1e-6) # no fprime
  224. with pytest.raises(ValueError):
  225. root_scalar(f1, method='halley', fprime=f1_1, x0=3, xtol=1e-6) # no fprime2
  226. with pytest.raises(ValueError):
  227. root_scalar(f1, method='halley', fprime2=f1_2, x0=3, xtol=1e-6) # no fprime
  228. def test_array_newton(self):
  229. """test newton with array"""
  230. def f1(x, *a):
  231. b = a[0] + x * a[3]
  232. return a[1] - a[2] * (np.exp(b / a[5]) - 1.0) - b / a[4] - x
  233. def f1_1(x, *a):
  234. b = a[3] / a[5]
  235. return -a[2] * np.exp(a[0] / a[5] + x * b) * b - a[3] / a[4] - 1
  236. def f1_2(x, *a):
  237. b = a[3] / a[5]
  238. return -a[2] * np.exp(a[0] / a[5] + x * b) * b**2
  239. a0 = np.array([
  240. 5.32725221, 5.48673747, 5.49539973,
  241. 5.36387202, 4.80237316, 1.43764452,
  242. 5.23063958, 5.46094772, 5.50512718,
  243. 5.42046290
  244. ])
  245. a1 = (np.sin(range(10)) + 1.0) * 7.0
  246. args = (a0, a1, 1e-09, 0.004, 10, 0.27456)
  247. x0 = [7.0] * 10
  248. x = zeros.newton(f1, x0, f1_1, args)
  249. x_expected = (
  250. 6.17264965, 11.7702805, 12.2219954,
  251. 7.11017681, 1.18151293, 0.143707955,
  252. 4.31928228, 10.5419107, 12.7552490,
  253. 8.91225749
  254. )
  255. assert_allclose(x, x_expected)
  256. # test halley's
  257. x = zeros.newton(f1, x0, f1_1, args, fprime2=f1_2)
  258. assert_allclose(x, x_expected)
  259. # test secant
  260. x = zeros.newton(f1, x0, args=args)
  261. assert_allclose(x, x_expected)
  262. def test_array_secant_active_zero_der(self):
  263. """test secant doesn't continue to iterate zero derivatives"""
  264. x = zeros.newton(lambda x, *a: x*x - a[0], x0=[4.123, 5],
  265. args=[np.array([17, 25])])
  266. assert_allclose(x, (4.123105625617661, 5.0))
  267. def test_array_newton_integers(self):
  268. # test secant with float
  269. x = zeros.newton(lambda y, z: z - y ** 2, [4.0] * 2,
  270. args=([15.0, 17.0],))
  271. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  272. # test integer becomes float
  273. x = zeros.newton(lambda y, z: z - y ** 2, [4] * 2, args=([15, 17],))
  274. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  275. def test_array_newton_zero_der_failures(self):
  276. # test derivative zero warning
  277. assert_warns(RuntimeWarning, zeros.newton,
  278. lambda y: y**2 - 2, [0., 0.], lambda y: 2 * y)
  279. # test failures and zero_der
  280. with pytest.warns(RuntimeWarning):
  281. results = zeros.newton(lambda y: y**2 - 2, [0., 0.],
  282. lambda y: 2*y, full_output=True)
  283. assert_allclose(results.root, 0)
  284. assert results.zero_der.all()
  285. assert not results.converged.any()
  286. def test_newton_combined(self):
  287. f1 = lambda x: x**2 - 2*x - 1
  288. f1_1 = lambda x: 2*x - 2
  289. f1_2 = lambda x: 2.0 + 0*x
  290. def f1_and_p_and_pp(x):
  291. return x**2 - 2*x-1, 2*x-2, 2.0
  292. sol0 = root_scalar(f1, method='newton', x0=3, fprime=f1_1)
  293. sol = root_scalar(f1_and_p_and_pp, method='newton', x0=3, fprime=True)
  294. assert_allclose(sol0.root, sol.root, atol=1e-8)
  295. assert_equal(2*sol.function_calls, sol0.function_calls)
  296. sol0 = root_scalar(f1, method='halley', x0=3, fprime=f1_1, fprime2=f1_2)
  297. sol = root_scalar(f1_and_p_and_pp, method='halley', x0=3, fprime2=True)
  298. assert_allclose(sol0.root, sol.root, atol=1e-8)
  299. assert_equal(3*sol.function_calls, sol0.function_calls)
  300. def test_newton_full_output(self):
  301. # Test the full_output capability, both when converging and not.
  302. # Use simple polynomials, to avoid hitting platform dependencies
  303. # (e.g. exp & trig) in number of iterations
  304. x0 = 3
  305. expected_counts = [(6, 7), (5, 10), (3, 9)]
  306. for derivs in range(3):
  307. kwargs = {'tol': 1e-6, 'full_output': True, }
  308. for k, v in [['fprime', self.f1_1], ['fprime2', self.f1_2]][:derivs]:
  309. kwargs[k] = v
  310. x, r = zeros.newton(self.f1, x0, disp=False, **kwargs)
  311. assert_(r.converged)
  312. assert_equal(x, r.root)
  313. assert_equal((r.iterations, r.function_calls), expected_counts[derivs])
  314. if derivs == 0:
  315. assert(r.function_calls <= r.iterations + 1)
  316. else:
  317. assert_equal(r.function_calls, (derivs + 1) * r.iterations)
  318. # Now repeat, allowing one fewer iteration to force convergence failure
  319. iters = r.iterations - 1
  320. x, r = zeros.newton(self.f1, x0, maxiter=iters, disp=False, **kwargs)
  321. assert_(not r.converged)
  322. assert_equal(x, r.root)
  323. assert_equal(r.iterations, iters)
  324. if derivs == 1:
  325. # Check that the correct Exception is raised and
  326. # validate the start of the message.
  327. with pytest.raises(
  328. RuntimeError,
  329. match='Failed to converge after %d iterations, value is .*' % (iters)):
  330. x, r = zeros.newton(self.f1, x0, maxiter=iters, disp=True, **kwargs)
  331. def test_deriv_zero_warning(self):
  332. func = lambda x: x**2 - 2.0
  333. dfunc = lambda x: 2*x
  334. assert_warns(RuntimeWarning, zeros.newton, func, 0.0, dfunc)
  335. def test_newton_does_not_modify_x0(self):
  336. # https://github.com/scipy/scipy/issues/9964
  337. x0 = np.array([0.1, 3])
  338. x0_copy = x0.copy() # Copy to test for equality.
  339. newton(np.sin, x0, np.cos)
  340. assert_array_equal(x0, x0_copy)
  341. def test_gh_5555():
  342. root = 0.1
  343. def f(x):
  344. return x - root
  345. methods = [zeros.bisect, zeros.ridder]
  346. xtol = rtol = TOL
  347. for method in methods:
  348. res = method(f, -1e8, 1e7, xtol=xtol, rtol=rtol)
  349. assert_allclose(root, res, atol=xtol, rtol=rtol,
  350. err_msg='method %s' % method.__name__)
  351. def test_gh_5557():
  352. # Show that without the changes in 5557 brentq and brenth might
  353. # only achieve a tolerance of 2*(xtol + rtol*|res|).
  354. # f linearly interpolates (0, -0.1), (0.5, -0.1), and (1,
  355. # 0.4). The important parts are that |f(0)| < |f(1)| (so that
  356. # brent takes 0 as the initial guess), |f(0)| < atol (so that
  357. # brent accepts 0 as the root), and that the exact root of f lies
  358. # more than atol away from 0 (so that brent doesn't achieve the
  359. # desired tolerance).
  360. def f(x):
  361. if x < 0.5:
  362. return -0.1
  363. else:
  364. return x - 0.6
  365. atol = 0.51
  366. rtol = 4 * _FLOAT_EPS
  367. methods = [zeros.brentq, zeros.brenth]
  368. for method in methods:
  369. res = method(f, 0, 1, xtol=atol, rtol=rtol)
  370. assert_allclose(0.6, res, atol=atol, rtol=rtol)
  371. class TestRootResults:
  372. def test_repr(self):
  373. r = zeros.RootResults(root=1.0,
  374. iterations=44,
  375. function_calls=46,
  376. flag=0)
  377. expected_repr = (" converged: True\n flag: 'converged'"
  378. "\n function_calls: 46\n iterations: 44\n"
  379. " root: 1.0")
  380. assert_equal(repr(r), expected_repr)
  381. def test_complex_halley():
  382. """Test Halley's works with complex roots"""
  383. def f(x, *a):
  384. return a[0] * x**2 + a[1] * x + a[2]
  385. def f_1(x, *a):
  386. return 2 * a[0] * x + a[1]
  387. def f_2(x, *a):
  388. retval = 2 * a[0]
  389. try:
  390. size = len(x)
  391. except TypeError:
  392. return retval
  393. else:
  394. return [retval] * size
  395. z = complex(1.0, 2.0)
  396. coeffs = (2.0, 3.0, 4.0)
  397. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  398. # (-0.75000000000000078+1.1989578808281789j)
  399. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  400. z = [z] * 10
  401. coeffs = (2.0, 3.0, 4.0)
  402. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  403. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  404. def test_zero_der_nz_dp():
  405. """Test secant method with a non-zero dp, but an infinite newton step"""
  406. # pick a symmetrical functions and choose a point on the side that with dx
  407. # makes a secant that is a flat line with zero slope, EG: f = (x - 100)**2,
  408. # which has a root at x = 100 and is symmetrical around the line x = 100
  409. # we have to pick a really big number so that it is consistently true
  410. # now find a point on each side so that the secant has a zero slope
  411. dx = np.finfo(float).eps ** 0.33
  412. # 100 - p0 = p1 - 100 = p0 * (1 + dx) + dx - 100
  413. # -> 200 = p0 * (2 + dx) + dx
  414. p0 = (200.0 - dx) / (2.0 + dx)
  415. with suppress_warnings() as sup:
  416. sup.filter(RuntimeWarning, "RMS of")
  417. x = zeros.newton(lambda y: (y - 100.0)**2, x0=[p0] * 10)
  418. assert_allclose(x, [100] * 10)
  419. # test scalar cases too
  420. p0 = (2.0 - 1e-4) / (2.0 + 1e-4)
  421. with suppress_warnings() as sup:
  422. sup.filter(RuntimeWarning, "Tolerance of")
  423. x = zeros.newton(lambda y: (y - 1.0) ** 2, x0=p0)
  424. assert_allclose(x, 1)
  425. p0 = (-2.0 + 1e-4) / (2.0 + 1e-4)
  426. with suppress_warnings() as sup:
  427. sup.filter(RuntimeWarning, "Tolerance of")
  428. x = zeros.newton(lambda y: (y + 1.0) ** 2, x0=p0)
  429. assert_allclose(x, -1)
  430. def test_array_newton_failures():
  431. """Test that array newton fails as expected"""
  432. # p = 0.68 # [MPa]
  433. # dp = -0.068 * 1e6 # [Pa]
  434. # T = 323 # [K]
  435. diameter = 0.10 # [m]
  436. # L = 100 # [m]
  437. roughness = 0.00015 # [m]
  438. rho = 988.1 # [kg/m**3]
  439. mu = 5.4790e-04 # [Pa*s]
  440. u = 2.488 # [m/s]
  441. reynolds_number = rho * u * diameter / mu # Reynolds number
  442. def colebrook_eqn(darcy_friction, re, dia):
  443. return (1 / np.sqrt(darcy_friction) +
  444. 2 * np.log10(roughness / 3.7 / dia +
  445. 2.51 / re / np.sqrt(darcy_friction)))
  446. # only some failures
  447. with pytest.warns(RuntimeWarning):
  448. result = zeros.newton(
  449. colebrook_eqn, x0=[0.01, 0.2, 0.02223, 0.3], maxiter=2,
  450. args=[reynolds_number, diameter], full_output=True
  451. )
  452. assert not result.converged.all()
  453. # they all fail
  454. with pytest.raises(RuntimeError):
  455. result = zeros.newton(
  456. colebrook_eqn, x0=[0.01] * 2, maxiter=2,
  457. args=[reynolds_number, diameter], full_output=True
  458. )
  459. # this test should **not** raise a RuntimeWarning
  460. def test_gh8904_zeroder_at_root_fails():
  461. """Test that Newton or Halley don't warn if zero derivative at root"""
  462. # a function that has a zero derivative at it's root
  463. def f_zeroder_root(x):
  464. return x**3 - x**2
  465. # should work with secant
  466. r = zeros.newton(f_zeroder_root, x0=0)
  467. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  468. # test again with array
  469. r = zeros.newton(f_zeroder_root, x0=[0]*10)
  470. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  471. # 1st derivative
  472. def fder(x):
  473. return 3 * x**2 - 2 * x
  474. # 2nd derivative
  475. def fder2(x):
  476. return 6*x - 2
  477. # should work with newton and halley
  478. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder)
  479. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  480. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder,
  481. fprime2=fder2)
  482. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  483. # test again with array
  484. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder)
  485. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  486. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder,
  487. fprime2=fder2)
  488. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  489. # also test that if a root is found we do not raise RuntimeWarning even if
  490. # the derivative is zero, EG: at x = 0.5, then fval = -0.125 and
  491. # fder = -0.25 so the next guess is 0.5 - (-0.125/-0.5) = 0 which is the
  492. # root, but if the solver continued with that guess, then it will calculate
  493. # a zero derivative, so it should return the root w/o RuntimeWarning
  494. r = zeros.newton(f_zeroder_root, x0=0.5, fprime=fder)
  495. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  496. # test again with array
  497. r = zeros.newton(f_zeroder_root, x0=[0.5]*10, fprime=fder)
  498. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  499. # doesn't apply to halley
  500. def test_gh_8881():
  501. r"""Test that Halley's method realizes that the 2nd order adjustment
  502. is too big and drops off to the 1st order adjustment."""
  503. n = 9
  504. def f(x):
  505. return power(x, 1.0/n) - power(n, 1.0/n)
  506. def fp(x):
  507. return power(x, (1.0-n)/n)/n
  508. def fpp(x):
  509. return power(x, (1.0-2*n)/n) * (1.0/n) * (1.0-n)/n
  510. x0 = 0.1
  511. # The root is at x=9.
  512. # The function has positive slope, x0 < root.
  513. # Newton succeeds in 8 iterations
  514. rt, r = newton(f, x0, fprime=fp, full_output=True)
  515. assert(r.converged)
  516. # Before the Issue 8881/PR 8882, halley would send x in the wrong direction.
  517. # Check that it now succeeds.
  518. rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
  519. assert(r.converged)
  520. def test_gh_9608_preserve_array_shape():
  521. """
  522. Test that shape is preserved for array inputs even if fprime or fprime2 is
  523. scalar
  524. """
  525. def f(x):
  526. return x**2
  527. def fp(x):
  528. return 2 * x
  529. def fpp(x):
  530. return 2
  531. x0 = np.array([-2], dtype=np.float32)
  532. rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
  533. assert(r.converged)
  534. x0_array = np.array([-2, -3], dtype=np.float32)
  535. # This next invocation should fail
  536. with pytest.raises(IndexError):
  537. result = zeros.newton(
  538. f, x0_array, fprime=fp, fprime2=fpp, full_output=True
  539. )
  540. def fpp_array(x):
  541. return 2*np.ones(np.shape(x), dtype=np.float32)
  542. result = zeros.newton(
  543. f, x0_array, fprime=fp, fprime2=fpp_array, full_output=True
  544. )
  545. assert result.converged.all()