ivp.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. from __future__ import division, print_function, absolute_import
  2. import inspect
  3. import numpy as np
  4. from .bdf import BDF
  5. from .radau import Radau
  6. from .rk import RK23, RK45
  7. from .lsoda import LSODA
  8. from scipy.optimize import OptimizeResult
  9. from .common import EPS, OdeSolution
  10. from .base import OdeSolver
  11. METHODS = {'RK23': RK23,
  12. 'RK45': RK45,
  13. 'Radau': Radau,
  14. 'BDF': BDF,
  15. 'LSODA': LSODA}
  16. MESSAGES = {0: "The solver successfully reached the end of the integration interval.",
  17. 1: "A termination event occurred."}
  18. class OdeResult(OptimizeResult):
  19. pass
  20. def prepare_events(events):
  21. """Standardize event functions and extract is_terminal and direction."""
  22. if callable(events):
  23. events = (events,)
  24. if events is not None:
  25. is_terminal = np.empty(len(events), dtype=bool)
  26. direction = np.empty(len(events))
  27. for i, event in enumerate(events):
  28. try:
  29. is_terminal[i] = event.terminal
  30. except AttributeError:
  31. is_terminal[i] = False
  32. try:
  33. direction[i] = event.direction
  34. except AttributeError:
  35. direction[i] = 0
  36. else:
  37. is_terminal = None
  38. direction = None
  39. return events, is_terminal, direction
  40. def solve_event_equation(event, sol, t_old, t):
  41. """Solve an equation corresponding to an ODE event.
  42. The equation is ``event(t, y(t)) = 0``, here ``y(t)`` is known from an
  43. ODE solver using some sort of interpolation. It is solved by
  44. `scipy.optimize.brentq` with xtol=atol=4*EPS.
  45. Parameters
  46. ----------
  47. event : callable
  48. Function ``event(t, y)``.
  49. sol : callable
  50. Function ``sol(t)`` which evaluates an ODE solution between `t_old`
  51. and `t`.
  52. t_old, t : float
  53. Previous and new values of time. They will be used as a bracketing
  54. interval.
  55. Returns
  56. -------
  57. root : float
  58. Found solution.
  59. """
  60. from scipy.optimize import brentq
  61. return brentq(lambda t: event(t, sol(t)), t_old, t,
  62. xtol=4 * EPS, rtol=4 * EPS)
  63. def handle_events(sol, events, active_events, is_terminal, t_old, t):
  64. """Helper function to handle events.
  65. Parameters
  66. ----------
  67. sol : DenseOutput
  68. Function ``sol(t)`` which evaluates an ODE solution between `t_old`
  69. and `t`.
  70. events : list of callables, length n_events
  71. Event functions with signatures ``event(t, y)``.
  72. active_events : ndarray
  73. Indices of events which occurred.
  74. is_terminal : ndarray, shape (n_events,)
  75. Which events are terminal.
  76. t_old, t : float
  77. Previous and new values of time.
  78. Returns
  79. -------
  80. root_indices : ndarray
  81. Indices of events which take zero between `t_old` and `t` and before
  82. a possible termination.
  83. roots : ndarray
  84. Values of t at which events occurred.
  85. terminate : bool
  86. Whether a terminal event occurred.
  87. """
  88. roots = []
  89. for event_index in active_events:
  90. roots.append(solve_event_equation(events[event_index], sol, t_old, t))
  91. roots = np.asarray(roots)
  92. if np.any(is_terminal[active_events]):
  93. if t > t_old:
  94. order = np.argsort(roots)
  95. else:
  96. order = np.argsort(-roots)
  97. active_events = active_events[order]
  98. roots = roots[order]
  99. t = np.nonzero(is_terminal[active_events])[0][0]
  100. active_events = active_events[:t + 1]
  101. roots = roots[:t + 1]
  102. terminate = True
  103. else:
  104. terminate = False
  105. return active_events, roots, terminate
  106. def find_active_events(g, g_new, direction):
  107. """Find which event occurred during an integration step.
  108. Parameters
  109. ----------
  110. g, g_new : array_like, shape (n_events,)
  111. Values of event functions at a current and next points.
  112. direction : ndarray, shape (n_events,)
  113. Event "direction" according to the definition in `solve_ivp`.
  114. Returns
  115. -------
  116. active_events : ndarray
  117. Indices of events which occurred during the step.
  118. """
  119. g, g_new = np.asarray(g), np.asarray(g_new)
  120. up = (g <= 0) & (g_new >= 0)
  121. down = (g >= 0) & (g_new <= 0)
  122. either = up | down
  123. mask = (up & (direction > 0) |
  124. down & (direction < 0) |
  125. either & (direction == 0))
  126. return np.nonzero(mask)[0]
  127. def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
  128. events=None, vectorized=False, **options):
  129. """Solve an initial value problem for a system of ODEs.
  130. This function numerically integrates a system of ordinary differential
  131. equations given an initial value::
  132. dy / dt = f(t, y)
  133. y(t0) = y0
  134. Here t is a one-dimensional independent variable (time), y(t) is an
  135. n-dimensional vector-valued function (state), and an n-dimensional
  136. vector-valued function f(t, y) determines the differential equations.
  137. The goal is to find y(t) approximately satisfying the differential
  138. equations, given an initial value y(t0)=y0.
  139. Some of the solvers support integration in the complex domain, but note that
  140. for stiff ODE solvers, the right-hand side must be complex-differentiable
  141. (satisfy Cauchy-Riemann equations [11]_). To solve a problem in the complex
  142. domain, pass y0 with a complex data type. Another option is always to
  143. rewrite your problem for real and imaginary parts separately.
  144. Parameters
  145. ----------
  146. fun : callable
  147. Right-hand side of the system. The calling signature is ``fun(t, y)``.
  148. Here ``t`` is a scalar, and there are two options for the ndarray ``y``:
  149. It can either have shape (n,); then ``fun`` must return array_like with
  150. shape (n,). Alternatively it can have shape (n, k); then ``fun``
  151. must return an array_like with shape (n, k), i.e. each column
  152. corresponds to a single column in ``y``. The choice between the two
  153. options is determined by `vectorized` argument (see below). The
  154. vectorized implementation allows a faster approximation of the Jacobian
  155. by finite differences (required for stiff solvers).
  156. t_span : 2-tuple of floats
  157. Interval of integration (t0, tf). The solver starts with t=t0 and
  158. integrates until it reaches t=tf.
  159. y0 : array_like, shape (n,)
  160. Initial state. For problems in the complex domain, pass `y0` with a
  161. complex data type (even if the initial guess is purely real).
  162. method : string or `OdeSolver`, optional
  163. Integration method to use:
  164. * 'RK45' (default): Explicit Runge-Kutta method of order 5(4) [1]_.
  165. The error is controlled assuming accuracy of the fourth-order
  166. method, but steps are taken using the fifth-order accurate formula
  167. (local extrapolation is done). A quartic interpolation polynomial
  168. is used for the dense output [2]_. Can be applied in the complex domain.
  169. * 'RK23': Explicit Runge-Kutta method of order 3(2) [3]_. The error
  170. is controlled assuming accuracy of the second-order method, but
  171. steps are taken using the third-order accurate formula (local
  172. extrapolation is done). A cubic Hermite polynomial is used for the
  173. dense output. Can be applied in the complex domain.
  174. * 'Radau': Implicit Runge-Kutta method of the Radau IIA family of
  175. order 5 [4]_. The error is controlled with a third-order accurate
  176. embedded formula. A cubic polynomial which satisfies the
  177. collocation conditions is used for the dense output.
  178. * 'BDF': Implicit multi-step variable-order (1 to 5) method based
  179. on a backward differentiation formula for the derivative
  180. approximation [5]_. The implementation follows the one described
  181. in [6]_. A quasi-constant step scheme is used and accuracy is
  182. enhanced using the NDF modification. Can be applied in the complex
  183. domain.
  184. * 'LSODA': Adams/BDF method with automatic stiffness detection and
  185. switching [7]_, [8]_. This is a wrapper of the Fortran solver
  186. from ODEPACK.
  187. You should use the 'RK45' or 'RK23' method for non-stiff problems and
  188. 'Radau' or 'BDF' for stiff problems [9]_. If not sure, first try to run
  189. 'RK45'. If needs unusually many iterations, diverges, or fails, your
  190. problem is likely to be stiff and you should use 'Radau' or 'BDF'.
  191. 'LSODA' can also be a good universal choice, but it might be somewhat
  192. less convenient to work with as it wraps old Fortran code.
  193. You can also pass an arbitrary class derived from `OdeSolver` which
  194. implements the solver.
  195. dense_output : bool, optional
  196. Whether to compute a continuous solution. Default is False.
  197. t_eval : array_like or None, optional
  198. Times at which to store the computed solution, must be sorted and lie
  199. within `t_span`. If None (default), use points selected by the solver.
  200. events : callable, list of callables or None, optional
  201. Types of events to track. Each is defined by a continuous function of
  202. time and state that becomes zero value in case of an event. Each function
  203. must have the signature ``event(t, y)`` and return a float. The solver will
  204. find an accurate value of ``t`` at which ``event(t, y(t)) = 0`` using a
  205. root-finding algorithm. Additionally each ``event`` function might have
  206. the following attributes:
  207. * terminal: bool, whether to terminate integration if this
  208. event occurs. Implicitly False if not assigned.
  209. * direction: float, direction of a zero crossing. If `direction`
  210. is positive, `event` must go from negative to positive, and
  211. vice versa if `direction` is negative. If 0, then either direction
  212. will count. Implicitly 0 if not assigned.
  213. You can assign attributes like ``event.terminal = True`` to any
  214. function in Python. If None (default), events won't be tracked.
  215. vectorized : bool, optional
  216. Whether `fun` is implemented in a vectorized fashion. Default is False.
  217. options
  218. Options passed to a chosen solver. All options available for already
  219. implemented solvers are listed below.
  220. first_step : float or None, optional
  221. Initial step size. Default is ``None`` which means that the algorithm
  222. should choose.
  223. max_step : float, optional
  224. Maximum allowed step size. Default is np.inf, i.e. the step size is not
  225. bounded and determined solely by the solver.
  226. rtol, atol : float and array_like, optional
  227. Relative and absolute tolerances. The solver keeps the local error
  228. estimates less than ``atol + rtol * abs(y)``. Here `rtol` controls a
  229. relative accuracy (number of correct digits). But if a component of `y`
  230. is approximately below `atol`, the error only needs to fall within
  231. the same `atol` threshold, and the number of correct digits is not
  232. guaranteed. If components of y have different scales, it might be
  233. beneficial to set different `atol` values for different components by
  234. passing array_like with shape (n,) for `atol`. Default values are
  235. 1e-3 for `rtol` and 1e-6 for `atol`.
  236. jac : {None, array_like, sparse_matrix, callable}, optional
  237. Jacobian matrix of the right-hand side of the system with respect to
  238. y, required by the 'Radau', 'BDF' and 'LSODA' method. The Jacobian matrix
  239. has shape (n, n) and its element (i, j) is equal to ``d f_i / d y_j``.
  240. There are three ways to define the Jacobian:
  241. * If array_like or sparse_matrix, the Jacobian is assumed to
  242. be constant. Not supported by 'LSODA'.
  243. * If callable, the Jacobian is assumed to depend on both
  244. t and y; it will be called as ``jac(t, y)`` as necessary.
  245. For the 'Radau' and 'BDF' methods, the return value might be a
  246. sparse matrix.
  247. * If None (default), the Jacobian will be approximated by
  248. finite differences.
  249. It is generally recommended to provide the Jacobian rather than
  250. relying on a finite-difference approximation.
  251. jac_sparsity : {None, array_like, sparse matrix}, optional
  252. Defines a sparsity structure of the Jacobian matrix for a
  253. finite-difference approximation. Its shape must be (n, n). This argument
  254. is ignored if `jac` is not `None`. If the Jacobian has only few non-zero
  255. elements in *each* row, providing the sparsity structure will greatly
  256. speed up the computations [10]_. A zero entry means that a corresponding
  257. element in the Jacobian is always zero. If None (default), the Jacobian
  258. is assumed to be dense.
  259. Not supported by 'LSODA', see `lband` and `uband` instead.
  260. lband, uband : int or None
  261. Parameters defining the bandwidth of the Jacobian for the 'LSODA' method,
  262. i.e., ``jac[i, j] != 0 only for i - lband <= j <= i + uband``. Setting
  263. these requires your jac routine to return the Jacobian in the packed format:
  264. the returned array must have ``n`` columns and ``uband + lband + 1``
  265. rows in which Jacobian diagonals are written. Specifically
  266. ``jac_packed[uband + i - j , j] = jac[i, j]``. The same format is used
  267. in `scipy.linalg.solve_banded` (check for an illustration).
  268. These parameters can be also used with ``jac=None`` to reduce the
  269. number of Jacobian elements estimated by finite differences.
  270. min_step : float, optional
  271. The minimum allowed step size for 'LSODA' method.
  272. By default `min_step` is zero.
  273. Returns
  274. -------
  275. Bunch object with the following fields defined:
  276. t : ndarray, shape (n_points,)
  277. Time points.
  278. y : ndarray, shape (n, n_points)
  279. Values of the solution at `t`.
  280. sol : `OdeSolution` or None
  281. Found solution as `OdeSolution` instance; None if `dense_output` was
  282. set to False.
  283. t_events : list of ndarray or None
  284. Contains for each event type a list of arrays at which an event of
  285. that type event was detected. None if `events` was None.
  286. nfev : int
  287. Number of evaluations of the right-hand side.
  288. njev : int
  289. Number of evaluations of the Jacobian.
  290. nlu : int
  291. Number of LU decompositions.
  292. status : int
  293. Reason for algorithm termination:
  294. * -1: Integration step failed.
  295. * 0: The solver successfully reached the end of `tspan`.
  296. * 1: A termination event occurred.
  297. message : string
  298. Human-readable description of the termination reason.
  299. success : bool
  300. True if the solver reached the interval end or a termination event
  301. occurred (``status >= 0``).
  302. References
  303. ----------
  304. .. [1] J. R. Dormand, P. J. Prince, "A family of embedded Runge-Kutta
  305. formulae", Journal of Computational and Applied Mathematics, Vol. 6,
  306. No. 1, pp. 19-26, 1980.
  307. .. [2] L. W. Shampine, "Some Practical Runge-Kutta Formulas", Mathematics
  308. of Computation,, Vol. 46, No. 173, pp. 135-150, 1986.
  309. .. [3] P. Bogacki, L.F. Shampine, "A 3(2) Pair of Runge-Kutta Formulas",
  310. Appl. Math. Lett. Vol. 2, No. 4. pp. 321-325, 1989.
  311. .. [4] E. Hairer, G. Wanner, "Solving Ordinary Differential Equations II:
  312. Stiff and Differential-Algebraic Problems", Sec. IV.8.
  313. .. [5] `Backward Differentiation Formula
  314. <https://en.wikipedia.org/wiki/Backward_differentiation_formula>`_
  315. on Wikipedia.
  316. .. [6] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI.
  317. COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997.
  318. .. [7] A. C. Hindmarsh, "ODEPACK, A Systematized Collection of ODE
  319. Solvers," IMACS Transactions on Scientific Computation, Vol 1.,
  320. pp. 55-64, 1983.
  321. .. [8] L. Petzold, "Automatic selection of methods for solving stiff and
  322. nonstiff systems of ordinary differential equations", SIAM Journal
  323. on Scientific and Statistical Computing, Vol. 4, No. 1, pp. 136-148,
  324. 1983.
  325. .. [9] `Stiff equation <https://en.wikipedia.org/wiki/Stiff_equation>`_ on
  326. Wikipedia.
  327. .. [10] A. Curtis, M. J. D. Powell, and J. Reid, "On the estimation of
  328. sparse Jacobian matrices", Journal of the Institute of Mathematics
  329. and its Applications, 13, pp. 117-120, 1974.
  330. .. [11] `Cauchy-Riemann equations
  331. <https://en.wikipedia.org/wiki/Cauchy-Riemann_equations>`_ on
  332. Wikipedia.
  333. Examples
  334. --------
  335. Basic exponential decay showing automatically chosen time points.
  336. >>> from scipy.integrate import solve_ivp
  337. >>> def exponential_decay(t, y): return -0.5 * y
  338. >>> sol = solve_ivp(exponential_decay, [0, 10], [2, 4, 8])
  339. >>> print(sol.t)
  340. [ 0. 0.11487653 1.26364188 3.06061781 4.85759374
  341. 6.65456967 8.4515456 10. ]
  342. >>> print(sol.y)
  343. [[2. 1.88836035 1.06327177 0.43319312 0.17648948 0.0719045
  344. 0.02929499 0.01350938]
  345. [4. 3.7767207 2.12654355 0.86638624 0.35297895 0.143809
  346. 0.05858998 0.02701876]
  347. [8. 7.5534414 4.25308709 1.73277247 0.7059579 0.287618
  348. 0.11717996 0.05403753]]
  349. Specifying points where the solution is desired.
  350. >>> sol = solve_ivp(exponential_decay, [0, 10], [2, 4, 8],
  351. ... t_eval=[0, 1, 2, 4, 10])
  352. >>> print(sol.t)
  353. [ 0 1 2 4 10]
  354. >>> print(sol.y)
  355. [[2. 1.21305369 0.73534021 0.27066736 0.01350938]
  356. [4. 2.42610739 1.47068043 0.54133472 0.02701876]
  357. [8. 4.85221478 2.94136085 1.08266944 0.05403753]]
  358. Cannon fired upward with terminal event upon impact. The ``terminal`` and
  359. ``direction`` fields of an event are applied by monkey patching a function.
  360. Here ``y[0]`` is position and ``y[1]`` is velocity. The projectile starts at
  361. position 0 with velocity +10. Note that the integration never reaches t=100
  362. because the event is terminal.
  363. >>> def upward_cannon(t, y): return [y[1], -0.5]
  364. >>> def hit_ground(t, y): return y[1]
  365. >>> hit_ground.terminal = True
  366. >>> hit_ground.direction = -1
  367. >>> sol = solve_ivp(upward_cannon, [0, 100], [0, 10], events=hit_ground)
  368. >>> print(sol.t_events)
  369. [array([ 20.])]
  370. >>> print(sol.t)
  371. [0.00000000e+00 9.99900010e-05 1.09989001e-03 1.10988901e-02
  372. 1.11088891e-01 1.11098890e+00 1.11099890e+01 2.00000000e+01]
  373. """
  374. if method not in METHODS and not (
  375. inspect.isclass(method) and issubclass(method, OdeSolver)):
  376. raise ValueError("`method` must be one of {} or OdeSolver class."
  377. .format(METHODS))
  378. t0, tf = float(t_span[0]), float(t_span[1])
  379. if t_eval is not None:
  380. t_eval = np.asarray(t_eval)
  381. if t_eval.ndim != 1:
  382. raise ValueError("`t_eval` must be 1-dimensional.")
  383. if np.any(t_eval < min(t0, tf)) or np.any(t_eval > max(t0, tf)):
  384. raise ValueError("Values in `t_eval` are not within `t_span`.")
  385. d = np.diff(t_eval)
  386. if tf > t0 and np.any(d <= 0) or tf < t0 and np.any(d >= 0):
  387. raise ValueError("Values in `t_eval` are not properly sorted.")
  388. if tf > t0:
  389. t_eval_i = 0
  390. else:
  391. # Make order of t_eval decreasing to use np.searchsorted.
  392. t_eval = t_eval[::-1]
  393. # This will be an upper bound for slices.
  394. t_eval_i = t_eval.shape[0]
  395. if method in METHODS:
  396. method = METHODS[method]
  397. solver = method(fun, t0, y0, tf, vectorized=vectorized, **options)
  398. if t_eval is None:
  399. ts = [t0]
  400. ys = [y0]
  401. elif t_eval is not None and dense_output:
  402. ts = []
  403. ti = [t0]
  404. ys = []
  405. else:
  406. ts = []
  407. ys = []
  408. interpolants = []
  409. events, is_terminal, event_dir = prepare_events(events)
  410. if events is not None:
  411. g = [event(t0, y0) for event in events]
  412. t_events = [[] for _ in range(len(events))]
  413. else:
  414. t_events = None
  415. status = None
  416. while status is None:
  417. message = solver.step()
  418. if solver.status == 'finished':
  419. status = 0
  420. elif solver.status == 'failed':
  421. status = -1
  422. break
  423. t_old = solver.t_old
  424. t = solver.t
  425. y = solver.y
  426. if dense_output:
  427. sol = solver.dense_output()
  428. interpolants.append(sol)
  429. else:
  430. sol = None
  431. if events is not None:
  432. g_new = [event(t, y) for event in events]
  433. active_events = find_active_events(g, g_new, event_dir)
  434. if active_events.size > 0:
  435. if sol is None:
  436. sol = solver.dense_output()
  437. root_indices, roots, terminate = handle_events(
  438. sol, events, active_events, is_terminal, t_old, t)
  439. for e, te in zip(root_indices, roots):
  440. t_events[e].append(te)
  441. if terminate:
  442. status = 1
  443. t = roots[-1]
  444. y = sol(t)
  445. g = g_new
  446. if t_eval is None:
  447. ts.append(t)
  448. ys.append(y)
  449. else:
  450. # The value in t_eval equal to t will be included.
  451. if solver.direction > 0:
  452. t_eval_i_new = np.searchsorted(t_eval, t, side='right')
  453. t_eval_step = t_eval[t_eval_i:t_eval_i_new]
  454. else:
  455. t_eval_i_new = np.searchsorted(t_eval, t, side='left')
  456. # It has to be done with two slice operations, because
  457. # you can't slice to 0-th element inclusive using backward
  458. # slicing.
  459. t_eval_step = t_eval[t_eval_i_new:t_eval_i][::-1]
  460. if t_eval_step.size > 0:
  461. if sol is None:
  462. sol = solver.dense_output()
  463. ts.append(t_eval_step)
  464. ys.append(sol(t_eval_step))
  465. t_eval_i = t_eval_i_new
  466. if t_eval is not None and dense_output:
  467. ti.append(t)
  468. message = MESSAGES.get(status, message)
  469. if t_events is not None:
  470. t_events = [np.asarray(te) for te in t_events]
  471. if t_eval is None:
  472. ts = np.array(ts)
  473. ys = np.vstack(ys).T
  474. else:
  475. ts = np.hstack(ts)
  476. ys = np.hstack(ys)
  477. if dense_output:
  478. if t_eval is None:
  479. sol = OdeSolution(ts, interpolants)
  480. else:
  481. sol = OdeSolution(ti, interpolants)
  482. else:
  483. sol = None
  484. return OdeResult(t=ts, y=ys, sol=sol, t_events=t_events, nfev=solver.nfev,
  485. njev=solver.njev, nlu=solver.nlu, status=status,
  486. message=message, success=status >= 0)