test_integrate.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835
  1. # Authors: Nils Wagner, Ed Schofield, Pauli Virtanen, John Travers
  2. """
  3. Tests for numerical integration.
  4. """
  5. from __future__ import division, print_function, absolute_import
  6. import numpy as np
  7. from numpy import (arange, zeros, array, dot, sqrt, cos, sin, eye, pi, exp,
  8. allclose)
  9. from scipy._lib._numpy_compat import _assert_warns
  10. from scipy._lib.six import xrange
  11. from numpy.testing import (
  12. assert_, assert_array_almost_equal,
  13. assert_allclose, assert_array_equal, assert_equal)
  14. from pytest import raises as assert_raises
  15. from scipy.integrate import odeint, ode, complex_ode
  16. #------------------------------------------------------------------------------
  17. # Test ODE integrators
  18. #------------------------------------------------------------------------------
  19. class TestOdeint(object):
  20. # Check integrate.odeint
  21. def _do_problem(self, problem):
  22. t = arange(0.0, problem.stop_t, 0.05)
  23. # Basic case
  24. z, infodict = odeint(problem.f, problem.z0, t, full_output=True)
  25. assert_(problem.verify(z, t))
  26. # Use tfirst=True
  27. z, infodict = odeint(lambda t, y: problem.f(y, t), problem.z0, t,
  28. full_output=True, tfirst=True)
  29. assert_(problem.verify(z, t))
  30. if hasattr(problem, 'jac'):
  31. # Use Dfun
  32. z, infodict = odeint(problem.f, problem.z0, t, Dfun=problem.jac,
  33. full_output=True)
  34. assert_(problem.verify(z, t))
  35. # Use Dfun and tfirst=True
  36. z, infodict = odeint(lambda t, y: problem.f(y, t), problem.z0, t,
  37. Dfun=lambda t, y: problem.jac(y, t),
  38. full_output=True, tfirst=True)
  39. assert_(problem.verify(z, t))
  40. def test_odeint(self):
  41. for problem_cls in PROBLEMS:
  42. problem = problem_cls()
  43. if problem.cmplx:
  44. continue
  45. self._do_problem(problem)
  46. class TestODEClass(object):
  47. ode_class = None # Set in subclass.
  48. def _do_problem(self, problem, integrator, method='adams'):
  49. # ode has callback arguments in different order than odeint
  50. f = lambda t, z: problem.f(z, t)
  51. jac = None
  52. if hasattr(problem, 'jac'):
  53. jac = lambda t, z: problem.jac(z, t)
  54. integrator_params = {}
  55. if problem.lband is not None or problem.uband is not None:
  56. integrator_params['uband'] = problem.uband
  57. integrator_params['lband'] = problem.lband
  58. ig = self.ode_class(f, jac)
  59. ig.set_integrator(integrator,
  60. atol=problem.atol/10,
  61. rtol=problem.rtol/10,
  62. method=method,
  63. **integrator_params)
  64. ig.set_initial_value(problem.z0, t=0.0)
  65. z = ig.integrate(problem.stop_t)
  66. assert_array_equal(z, ig.y)
  67. assert_(ig.successful(), (problem, method))
  68. assert_(ig.get_return_code() > 0, (problem, method))
  69. assert_(problem.verify(array([z]), problem.stop_t), (problem, method))
  70. class TestOde(TestODEClass):
  71. ode_class = ode
  72. def test_vode(self):
  73. # Check the vode solver
  74. for problem_cls in PROBLEMS:
  75. problem = problem_cls()
  76. if problem.cmplx:
  77. continue
  78. if not problem.stiff:
  79. self._do_problem(problem, 'vode', 'adams')
  80. self._do_problem(problem, 'vode', 'bdf')
  81. def test_zvode(self):
  82. # Check the zvode solver
  83. for problem_cls in PROBLEMS:
  84. problem = problem_cls()
  85. if not problem.stiff:
  86. self._do_problem(problem, 'zvode', 'adams')
  87. self._do_problem(problem, 'zvode', 'bdf')
  88. def test_lsoda(self):
  89. # Check the lsoda solver
  90. for problem_cls in PROBLEMS:
  91. problem = problem_cls()
  92. if problem.cmplx:
  93. continue
  94. self._do_problem(problem, 'lsoda')
  95. def test_dopri5(self):
  96. # Check the dopri5 solver
  97. for problem_cls in PROBLEMS:
  98. problem = problem_cls()
  99. if problem.cmplx:
  100. continue
  101. if problem.stiff:
  102. continue
  103. if hasattr(problem, 'jac'):
  104. continue
  105. self._do_problem(problem, 'dopri5')
  106. def test_dop853(self):
  107. # Check the dop853 solver
  108. for problem_cls in PROBLEMS:
  109. problem = problem_cls()
  110. if problem.cmplx:
  111. continue
  112. if problem.stiff:
  113. continue
  114. if hasattr(problem, 'jac'):
  115. continue
  116. self._do_problem(problem, 'dop853')
  117. def test_concurrent_fail(self):
  118. for sol in ('vode', 'zvode', 'lsoda'):
  119. f = lambda t, y: 1.0
  120. r = ode(f).set_integrator(sol)
  121. r.set_initial_value(0, 0)
  122. r2 = ode(f).set_integrator(sol)
  123. r2.set_initial_value(0, 0)
  124. r.integrate(r.t + 0.1)
  125. r2.integrate(r2.t + 0.1)
  126. assert_raises(RuntimeError, r.integrate, r.t + 0.1)
  127. def test_concurrent_ok(self):
  128. f = lambda t, y: 1.0
  129. for k in xrange(3):
  130. for sol in ('vode', 'zvode', 'lsoda', 'dopri5', 'dop853'):
  131. r = ode(f).set_integrator(sol)
  132. r.set_initial_value(0, 0)
  133. r2 = ode(f).set_integrator(sol)
  134. r2.set_initial_value(0, 0)
  135. r.integrate(r.t + 0.1)
  136. r2.integrate(r2.t + 0.1)
  137. r2.integrate(r2.t + 0.1)
  138. assert_allclose(r.y, 0.1)
  139. assert_allclose(r2.y, 0.2)
  140. for sol in ('dopri5', 'dop853'):
  141. r = ode(f).set_integrator(sol)
  142. r.set_initial_value(0, 0)
  143. r2 = ode(f).set_integrator(sol)
  144. r2.set_initial_value(0, 0)
  145. r.integrate(r.t + 0.1)
  146. r.integrate(r.t + 0.1)
  147. r2.integrate(r2.t + 0.1)
  148. r.integrate(r.t + 0.1)
  149. r2.integrate(r2.t + 0.1)
  150. assert_allclose(r.y, 0.3)
  151. assert_allclose(r2.y, 0.2)
  152. class TestComplexOde(TestODEClass):
  153. ode_class = complex_ode
  154. def test_vode(self):
  155. # Check the vode solver
  156. for problem_cls in PROBLEMS:
  157. problem = problem_cls()
  158. if not problem.stiff:
  159. self._do_problem(problem, 'vode', 'adams')
  160. else:
  161. self._do_problem(problem, 'vode', 'bdf')
  162. def test_lsoda(self):
  163. # Check the lsoda solver
  164. for problem_cls in PROBLEMS:
  165. problem = problem_cls()
  166. self._do_problem(problem, 'lsoda')
  167. def test_dopri5(self):
  168. # Check the dopri5 solver
  169. for problem_cls in PROBLEMS:
  170. problem = problem_cls()
  171. if problem.stiff:
  172. continue
  173. if hasattr(problem, 'jac'):
  174. continue
  175. self._do_problem(problem, 'dopri5')
  176. def test_dop853(self):
  177. # Check the dop853 solver
  178. for problem_cls in PROBLEMS:
  179. problem = problem_cls()
  180. if problem.stiff:
  181. continue
  182. if hasattr(problem, 'jac'):
  183. continue
  184. self._do_problem(problem, 'dop853')
  185. class TestSolout(object):
  186. # Check integrate.ode correctly handles solout for dopri5 and dop853
  187. def _run_solout_test(self, integrator):
  188. # Check correct usage of solout
  189. ts = []
  190. ys = []
  191. t0 = 0.0
  192. tend = 10.0
  193. y0 = [1.0, 2.0]
  194. def solout(t, y):
  195. ts.append(t)
  196. ys.append(y.copy())
  197. def rhs(t, y):
  198. return [y[0] + y[1], -y[1]**2]
  199. ig = ode(rhs).set_integrator(integrator)
  200. ig.set_solout(solout)
  201. ig.set_initial_value(y0, t0)
  202. ret = ig.integrate(tend)
  203. assert_array_equal(ys[0], y0)
  204. assert_array_equal(ys[-1], ret)
  205. assert_equal(ts[0], t0)
  206. assert_equal(ts[-1], tend)
  207. def test_solout(self):
  208. for integrator in ('dopri5', 'dop853'):
  209. self._run_solout_test(integrator)
  210. def _run_solout_after_initial_test(self, integrator):
  211. # Check if solout works even if it is set after the initial value.
  212. ts = []
  213. ys = []
  214. t0 = 0.0
  215. tend = 10.0
  216. y0 = [1.0, 2.0]
  217. def solout(t, y):
  218. ts.append(t)
  219. ys.append(y.copy())
  220. def rhs(t, y):
  221. return [y[0] + y[1], -y[1]**2]
  222. ig = ode(rhs).set_integrator(integrator)
  223. ig.set_initial_value(y0, t0)
  224. ig.set_solout(solout)
  225. ret = ig.integrate(tend)
  226. assert_array_equal(ys[0], y0)
  227. assert_array_equal(ys[-1], ret)
  228. assert_equal(ts[0], t0)
  229. assert_equal(ts[-1], tend)
  230. def test_solout_after_initial(self):
  231. for integrator in ('dopri5', 'dop853'):
  232. self._run_solout_after_initial_test(integrator)
  233. def _run_solout_break_test(self, integrator):
  234. # Check correct usage of stopping via solout
  235. ts = []
  236. ys = []
  237. t0 = 0.0
  238. tend = 10.0
  239. y0 = [1.0, 2.0]
  240. def solout(t, y):
  241. ts.append(t)
  242. ys.append(y.copy())
  243. if t > tend/2.0:
  244. return -1
  245. def rhs(t, y):
  246. return [y[0] + y[1], -y[1]**2]
  247. ig = ode(rhs).set_integrator(integrator)
  248. ig.set_solout(solout)
  249. ig.set_initial_value(y0, t0)
  250. ret = ig.integrate(tend)
  251. assert_array_equal(ys[0], y0)
  252. assert_array_equal(ys[-1], ret)
  253. assert_equal(ts[0], t0)
  254. assert_(ts[-1] > tend/2.0)
  255. assert_(ts[-1] < tend)
  256. def test_solout_break(self):
  257. for integrator in ('dopri5', 'dop853'):
  258. self._run_solout_break_test(integrator)
  259. class TestComplexSolout(object):
  260. # Check integrate.ode correctly handles solout for dopri5 and dop853
  261. def _run_solout_test(self, integrator):
  262. # Check correct usage of solout
  263. ts = []
  264. ys = []
  265. t0 = 0.0
  266. tend = 20.0
  267. y0 = [0.0]
  268. def solout(t, y):
  269. ts.append(t)
  270. ys.append(y.copy())
  271. def rhs(t, y):
  272. return [1.0/(t - 10.0 - 1j)]
  273. ig = complex_ode(rhs).set_integrator(integrator)
  274. ig.set_solout(solout)
  275. ig.set_initial_value(y0, t0)
  276. ret = ig.integrate(tend)
  277. assert_array_equal(ys[0], y0)
  278. assert_array_equal(ys[-1], ret)
  279. assert_equal(ts[0], t0)
  280. assert_equal(ts[-1], tend)
  281. def test_solout(self):
  282. for integrator in ('dopri5', 'dop853'):
  283. self._run_solout_test(integrator)
  284. def _run_solout_break_test(self, integrator):
  285. # Check correct usage of stopping via solout
  286. ts = []
  287. ys = []
  288. t0 = 0.0
  289. tend = 20.0
  290. y0 = [0.0]
  291. def solout(t, y):
  292. ts.append(t)
  293. ys.append(y.copy())
  294. if t > tend/2.0:
  295. return -1
  296. def rhs(t, y):
  297. return [1.0/(t - 10.0 - 1j)]
  298. ig = complex_ode(rhs).set_integrator(integrator)
  299. ig.set_solout(solout)
  300. ig.set_initial_value(y0, t0)
  301. ret = ig.integrate(tend)
  302. assert_array_equal(ys[0], y0)
  303. assert_array_equal(ys[-1], ret)
  304. assert_equal(ts[0], t0)
  305. assert_(ts[-1] > tend/2.0)
  306. assert_(ts[-1] < tend)
  307. def test_solout_break(self):
  308. for integrator in ('dopri5', 'dop853'):
  309. self._run_solout_break_test(integrator)
  310. #------------------------------------------------------------------------------
  311. # Test problems
  312. #------------------------------------------------------------------------------
  313. class ODE:
  314. """
  315. ODE problem
  316. """
  317. stiff = False
  318. cmplx = False
  319. stop_t = 1
  320. z0 = []
  321. lband = None
  322. uband = None
  323. atol = 1e-6
  324. rtol = 1e-5
  325. class SimpleOscillator(ODE):
  326. r"""
  327. Free vibration of a simple oscillator::
  328. m \ddot{u} + k u = 0, u(0) = u_0 \dot{u}(0) \dot{u}_0
  329. Solution::
  330. u(t) = u_0*cos(sqrt(k/m)*t)+\dot{u}_0*sin(sqrt(k/m)*t)/sqrt(k/m)
  331. """
  332. stop_t = 1 + 0.09
  333. z0 = array([1.0, 0.1], float)
  334. k = 4.0
  335. m = 1.0
  336. def f(self, z, t):
  337. tmp = zeros((2, 2), float)
  338. tmp[0, 1] = 1.0
  339. tmp[1, 0] = -self.k / self.m
  340. return dot(tmp, z)
  341. def verify(self, zs, t):
  342. omega = sqrt(self.k / self.m)
  343. u = self.z0[0]*cos(omega*t) + self.z0[1]*sin(omega*t)/omega
  344. return allclose(u, zs[:, 0], atol=self.atol, rtol=self.rtol)
  345. class ComplexExp(ODE):
  346. r"""The equation :lm:`\dot u = i u`"""
  347. stop_t = 1.23*pi
  348. z0 = exp([1j, 2j, 3j, 4j, 5j])
  349. cmplx = True
  350. def f(self, z, t):
  351. return 1j*z
  352. def jac(self, z, t):
  353. return 1j*eye(5)
  354. def verify(self, zs, t):
  355. u = self.z0 * exp(1j*t)
  356. return allclose(u, zs, atol=self.atol, rtol=self.rtol)
  357. class Pi(ODE):
  358. r"""Integrate 1/(t + 1j) from t=-10 to t=10"""
  359. stop_t = 20
  360. z0 = [0]
  361. cmplx = True
  362. def f(self, z, t):
  363. return array([1./(t - 10 + 1j)])
  364. def verify(self, zs, t):
  365. u = -2j * np.arctan(10)
  366. return allclose(u, zs[-1, :], atol=self.atol, rtol=self.rtol)
  367. class CoupledDecay(ODE):
  368. r"""
  369. 3 coupled decays suited for banded treatment
  370. (banded mode makes it necessary when N>>3)
  371. """
  372. stiff = True
  373. stop_t = 0.5
  374. z0 = [5.0, 7.0, 13.0]
  375. lband = 1
  376. uband = 0
  377. lmbd = [0.17, 0.23, 0.29] # fictitious decay constants
  378. def f(self, z, t):
  379. lmbd = self.lmbd
  380. return np.array([-lmbd[0]*z[0],
  381. -lmbd[1]*z[1] + lmbd[0]*z[0],
  382. -lmbd[2]*z[2] + lmbd[1]*z[1]])
  383. def jac(self, z, t):
  384. # The full Jacobian is
  385. #
  386. # [-lmbd[0] 0 0 ]
  387. # [ lmbd[0] -lmbd[1] 0 ]
  388. # [ 0 lmbd[1] -lmbd[2]]
  389. #
  390. # The lower and upper bandwidths are lband=1 and uband=0, resp.
  391. # The representation of this array in packed format is
  392. #
  393. # [-lmbd[0] -lmbd[1] -lmbd[2]]
  394. # [ lmbd[0] lmbd[1] 0 ]
  395. lmbd = self.lmbd
  396. j = np.zeros((self.lband + self.uband + 1, 3), order='F')
  397. def set_j(ri, ci, val):
  398. j[self.uband + ri - ci, ci] = val
  399. set_j(0, 0, -lmbd[0])
  400. set_j(1, 0, lmbd[0])
  401. set_j(1, 1, -lmbd[1])
  402. set_j(2, 1, lmbd[1])
  403. set_j(2, 2, -lmbd[2])
  404. return j
  405. def verify(self, zs, t):
  406. # Formulae derived by hand
  407. lmbd = np.array(self.lmbd)
  408. d10 = lmbd[1] - lmbd[0]
  409. d21 = lmbd[2] - lmbd[1]
  410. d20 = lmbd[2] - lmbd[0]
  411. e0 = np.exp(-lmbd[0] * t)
  412. e1 = np.exp(-lmbd[1] * t)
  413. e2 = np.exp(-lmbd[2] * t)
  414. u = np.vstack((
  415. self.z0[0] * e0,
  416. self.z0[1] * e1 + self.z0[0] * lmbd[0] / d10 * (e0 - e1),
  417. self.z0[2] * e2 + self.z0[1] * lmbd[1] / d21 * (e1 - e2) +
  418. lmbd[1] * lmbd[0] * self.z0[0] / d10 *
  419. (1 / d20 * (e0 - e2) - 1 / d21 * (e1 - e2)))).transpose()
  420. return allclose(u, zs, atol=self.atol, rtol=self.rtol)
  421. PROBLEMS = [SimpleOscillator, ComplexExp, Pi, CoupledDecay]
  422. #------------------------------------------------------------------------------
  423. def f(t, x):
  424. dxdt = [x[1], -x[0]]
  425. return dxdt
  426. def jac(t, x):
  427. j = array([[0.0, 1.0],
  428. [-1.0, 0.0]])
  429. return j
  430. def f1(t, x, omega):
  431. dxdt = [omega*x[1], -omega*x[0]]
  432. return dxdt
  433. def jac1(t, x, omega):
  434. j = array([[0.0, omega],
  435. [-omega, 0.0]])
  436. return j
  437. def f2(t, x, omega1, omega2):
  438. dxdt = [omega1*x[1], -omega2*x[0]]
  439. return dxdt
  440. def jac2(t, x, omega1, omega2):
  441. j = array([[0.0, omega1],
  442. [-omega2, 0.0]])
  443. return j
  444. def fv(t, x, omega):
  445. dxdt = [omega[0]*x[1], -omega[1]*x[0]]
  446. return dxdt
  447. def jacv(t, x, omega):
  448. j = array([[0.0, omega[0]],
  449. [-omega[1], 0.0]])
  450. return j
  451. class ODECheckParameterUse(object):
  452. """Call an ode-class solver with several cases of parameter use."""
  453. # solver_name must be set before tests can be run with this class.
  454. # Set these in subclasses.
  455. solver_name = ''
  456. solver_uses_jac = False
  457. def _get_solver(self, f, jac):
  458. solver = ode(f, jac)
  459. if self.solver_uses_jac:
  460. solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7,
  461. with_jacobian=self.solver_uses_jac)
  462. else:
  463. # XXX Shouldn't set_integrator *always* accept the keyword arg
  464. # 'with_jacobian', and perhaps raise an exception if it is set
  465. # to True if the solver can't actually use it?
  466. solver.set_integrator(self.solver_name, atol=1e-9, rtol=1e-7)
  467. return solver
  468. def _check_solver(self, solver):
  469. ic = [1.0, 0.0]
  470. solver.set_initial_value(ic, 0.0)
  471. solver.integrate(pi)
  472. assert_array_almost_equal(solver.y, [-1.0, 0.0])
  473. def test_no_params(self):
  474. solver = self._get_solver(f, jac)
  475. self._check_solver(solver)
  476. def test_one_scalar_param(self):
  477. solver = self._get_solver(f1, jac1)
  478. omega = 1.0
  479. solver.set_f_params(omega)
  480. if self.solver_uses_jac:
  481. solver.set_jac_params(omega)
  482. self._check_solver(solver)
  483. def test_two_scalar_params(self):
  484. solver = self._get_solver(f2, jac2)
  485. omega1 = 1.0
  486. omega2 = 1.0
  487. solver.set_f_params(omega1, omega2)
  488. if self.solver_uses_jac:
  489. solver.set_jac_params(omega1, omega2)
  490. self._check_solver(solver)
  491. def test_vector_param(self):
  492. solver = self._get_solver(fv, jacv)
  493. omega = [1.0, 1.0]
  494. solver.set_f_params(omega)
  495. if self.solver_uses_jac:
  496. solver.set_jac_params(omega)
  497. self._check_solver(solver)
  498. def test_warns_on_failure(self):
  499. # Set nsteps small to ensure failure
  500. solver = self._get_solver(f, jac)
  501. solver.set_integrator(self.solver_name, nsteps=1)
  502. ic = [1.0, 0.0]
  503. solver.set_initial_value(ic, 0.0)
  504. _assert_warns(UserWarning, solver.integrate, pi)
  505. class TestDOPRI5CheckParameterUse(ODECheckParameterUse):
  506. solver_name = 'dopri5'
  507. solver_uses_jac = False
  508. class TestDOP853CheckParameterUse(ODECheckParameterUse):
  509. solver_name = 'dop853'
  510. solver_uses_jac = False
  511. class TestVODECheckParameterUse(ODECheckParameterUse):
  512. solver_name = 'vode'
  513. solver_uses_jac = True
  514. class TestZVODECheckParameterUse(ODECheckParameterUse):
  515. solver_name = 'zvode'
  516. solver_uses_jac = True
  517. class TestLSODACheckParameterUse(ODECheckParameterUse):
  518. solver_name = 'lsoda'
  519. solver_uses_jac = True
  520. def test_odeint_trivial_time():
  521. # Test that odeint succeeds when given a single time point
  522. # and full_output=True. This is a regression test for gh-4282.
  523. y0 = 1
  524. t = [0]
  525. y, info = odeint(lambda y, t: -y, y0, t, full_output=True)
  526. assert_array_equal(y, np.array([[y0]]))
  527. def test_odeint_banded_jacobian():
  528. # Test the use of the `Dfun`, `ml` and `mu` options of odeint.
  529. def func(y, t, c):
  530. return c.dot(y)
  531. def jac(y, t, c):
  532. return c
  533. def jac_transpose(y, t, c):
  534. return c.T.copy(order='C')
  535. def bjac_rows(y, t, c):
  536. jac = np.row_stack((np.r_[0, np.diag(c, 1)],
  537. np.diag(c),
  538. np.r_[np.diag(c, -1), 0],
  539. np.r_[np.diag(c, -2), 0, 0]))
  540. return jac
  541. def bjac_cols(y, t, c):
  542. return bjac_rows(y, t, c).T.copy(order='C')
  543. c = array([[-205, 0.01, 0.00, 0.0],
  544. [0.1, -2.50, 0.02, 0.0],
  545. [1e-3, 0.01, -2.0, 0.01],
  546. [0.00, 0.00, 0.1, -1.0]])
  547. y0 = np.ones(4)
  548. t = np.array([0, 5, 10, 100])
  549. # Use the full Jacobian.
  550. sol1, info1 = odeint(func, y0, t, args=(c,), full_output=True,
  551. atol=1e-13, rtol=1e-11, mxstep=10000,
  552. Dfun=jac)
  553. # Use the transposed full Jacobian, with col_deriv=True.
  554. sol2, info2 = odeint(func, y0, t, args=(c,), full_output=True,
  555. atol=1e-13, rtol=1e-11, mxstep=10000,
  556. Dfun=jac_transpose, col_deriv=True)
  557. # Use the banded Jacobian.
  558. sol3, info3 = odeint(func, y0, t, args=(c,), full_output=True,
  559. atol=1e-13, rtol=1e-11, mxstep=10000,
  560. Dfun=bjac_rows, ml=2, mu=1)
  561. # Use the transposed banded Jacobian, with col_deriv=True.
  562. sol4, info4 = odeint(func, y0, t, args=(c,), full_output=True,
  563. atol=1e-13, rtol=1e-11, mxstep=10000,
  564. Dfun=bjac_cols, ml=2, mu=1, col_deriv=True)
  565. assert_allclose(sol1, sol2, err_msg="sol1 != sol2")
  566. assert_allclose(sol1, sol3, atol=1e-12, err_msg="sol1 != sol3")
  567. assert_allclose(sol3, sol4, err_msg="sol3 != sol4")
  568. # Verify that the number of jacobian evaluations was the same for the
  569. # calls of odeint with a full jacobian and with a banded jacobian. This is
  570. # a regression test--there was a bug in the handling of banded jacobians
  571. # that resulted in an incorrect jacobian matrix being passed to the LSODA
  572. # code. That would cause errors or excessive jacobian evaluations.
  573. assert_array_equal(info1['nje'], info2['nje'])
  574. assert_array_equal(info3['nje'], info4['nje'])
  575. # Test the use of tfirst
  576. sol1ty, info1ty = odeint(lambda t, y, c: func(y, t, c), y0, t, args=(c,),
  577. full_output=True, atol=1e-13, rtol=1e-11,
  578. mxstep=10000,
  579. Dfun=lambda t, y, c: jac(y, t, c), tfirst=True)
  580. # The code should execute the exact same sequence of floating point
  581. # calculations, so these should be exactly equal. We'll be safe and use
  582. # a small tolerance.
  583. assert_allclose(sol1, sol1ty, rtol=1e-12, err_msg="sol1 != sol1ty")
  584. def test_odeint_errors():
  585. def sys1d(x, t):
  586. return -100*x
  587. def bad1(x, t):
  588. return 1.0/0
  589. def bad2(x, t):
  590. return "foo"
  591. def bad_jac1(x, t):
  592. return 1.0/0
  593. def bad_jac2(x, t):
  594. return [["foo"]]
  595. def sys2d(x, t):
  596. return [-100*x[0], -0.1*x[1]]
  597. def sys2d_bad_jac(x, t):
  598. return [[1.0/0, 0], [0, -0.1]]
  599. assert_raises(ZeroDivisionError, odeint, bad1, 1.0, [0, 1])
  600. assert_raises(ValueError, odeint, bad2, 1.0, [0, 1])
  601. assert_raises(ZeroDivisionError, odeint, sys1d, 1.0, [0, 1], Dfun=bad_jac1)
  602. assert_raises(ValueError, odeint, sys1d, 1.0, [0, 1], Dfun=bad_jac2)
  603. assert_raises(ZeroDivisionError, odeint, sys2d, [1.0, 1.0], [0, 1],
  604. Dfun=sys2d_bad_jac)
  605. def test_odeint_bad_shapes():
  606. # Tests of some errors that can occur with odeint.
  607. def badrhs(x, t):
  608. return [1, -1]
  609. def sys1(x, t):
  610. return -100*x
  611. def badjac(x, t):
  612. return [[0, 0, 0]]
  613. # y0 must be at most 1-d.
  614. bad_y0 = [[0, 0], [0, 0]]
  615. assert_raises(ValueError, odeint, sys1, bad_y0, [0, 1])
  616. # t must be at most 1-d.
  617. bad_t = [[0, 1], [2, 3]]
  618. assert_raises(ValueError, odeint, sys1, [10.0], bad_t)
  619. # y0 is 10, but badrhs(x, t) returns [1, -1].
  620. assert_raises(RuntimeError, odeint, badrhs, 10, [0, 1])
  621. # shape of array returned by badjac(x, t) is not correct.
  622. assert_raises(RuntimeError, odeint, sys1, [10, 10], [0, 1], Dfun=badjac)
  623. def test_repeated_t_values():
  624. """Regression test for gh-8217."""
  625. def func(x, t):
  626. return -0.25*x
  627. t = np.zeros(10)
  628. sol = odeint(func, [1.], t)
  629. assert_array_equal(sol, np.ones((len(t), 1)))
  630. tau = 4*np.log(2)
  631. t = [0]*9 + [tau, 2*tau, 2*tau, 3*tau]
  632. sol = odeint(func, [1, 2], t, rtol=1e-12, atol=1e-12)
  633. expected_sol = np.array([[1.0, 2.0]]*9 +
  634. [[0.5, 1.0],
  635. [0.25, 0.5],
  636. [0.25, 0.5],
  637. [0.125, 0.25]])
  638. assert_allclose(sol, expected_sol)
  639. # Edge case: empty t sequence.
  640. sol = odeint(func, [1.], [])
  641. assert_array_equal(sol, np.array([], dtype=np.float64).reshape((0, 1)))
  642. # t values are not monotonic.
  643. assert_raises(ValueError, odeint, func, [1.], [0, 1, 0.5, 0])
  644. assert_raises(ValueError, odeint, func, [1, 2, 3], [0, -1, -2, 3])