rk.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from .base import OdeSolver, DenseOutput
  4. from .common import (validate_max_step, validate_tol, select_initial_step,
  5. norm, warn_extraneous, validate_first_step)
  6. # Multiply steps computed from asymptotic behaviour of errors by this.
  7. SAFETY = 0.9
  8. MIN_FACTOR = 0.2 # Minimum allowed decrease in a step size.
  9. MAX_FACTOR = 10 # Maximum allowed increase in a step size.
  10. def rk_step(fun, t, y, f, h, A, B, C, E, K):
  11. """Perform a single Runge-Kutta step.
  12. This function computes a prediction of an explicit Runge-Kutta method and
  13. also estimates the error of a less accurate method.
  14. Notation for Butcher tableau is as in [1]_.
  15. Parameters
  16. ----------
  17. fun : callable
  18. Right-hand side of the system.
  19. t : float
  20. Current time.
  21. y : ndarray, shape (n,)
  22. Current state.
  23. f : ndarray, shape (n,)
  24. Current value of the derivative, i.e. ``fun(x, y)``.
  25. h : float
  26. Step to use.
  27. A : list of ndarray, length n_stages - 1
  28. Coefficients for combining previous RK stages to compute the next
  29. stage. For explicit methods the coefficients above the main diagonal
  30. are zeros, so `A` is stored as a list of arrays of increasing lengths.
  31. The first stage is always just `f`, thus no coefficients for it
  32. are required.
  33. B : ndarray, shape (n_stages,)
  34. Coefficients for combining RK stages for computing the final
  35. prediction.
  36. C : ndarray, shape (n_stages - 1,)
  37. Coefficients for incrementing time for consecutive RK stages.
  38. The value for the first stage is always zero, thus it is not stored.
  39. E : ndarray, shape (n_stages + 1,)
  40. Coefficients for estimating the error of a less accurate method. They
  41. are computed as the difference between b's in an extended tableau.
  42. K : ndarray, shape (n_stages + 1, n)
  43. Storage array for putting RK stages here. Stages are stored in rows.
  44. Returns
  45. -------
  46. y_new : ndarray, shape (n,)
  47. Solution at t + h computed with a higher accuracy.
  48. f_new : ndarray, shape (n,)
  49. Derivative ``fun(t + h, y_new)``.
  50. error : ndarray, shape (n,)
  51. Error estimate of a less accurate method.
  52. References
  53. ----------
  54. .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
  55. Equations I: Nonstiff Problems", Sec. II.4.
  56. """
  57. K[0] = f
  58. for s, (a, c) in enumerate(zip(A, C)):
  59. dy = np.dot(K[:s + 1].T, a) * h
  60. K[s + 1] = fun(t + c * h, y + dy)
  61. y_new = y + h * np.dot(K[:-1].T, B)
  62. f_new = fun(t + h, y_new)
  63. K[-1] = f_new
  64. error = np.dot(K.T, E) * h
  65. return y_new, f_new, error
  66. class RungeKutta(OdeSolver):
  67. """Base class for explicit Runge-Kutta methods."""
  68. C = NotImplemented
  69. A = NotImplemented
  70. B = NotImplemented
  71. E = NotImplemented
  72. P = NotImplemented
  73. order = NotImplemented
  74. n_stages = NotImplemented
  75. def __init__(self, fun, t0, y0, t_bound, max_step=np.inf,
  76. rtol=1e-3, atol=1e-6, vectorized=False,
  77. first_step=None, **extraneous):
  78. warn_extraneous(extraneous)
  79. super(RungeKutta, self).__init__(fun, t0, y0, t_bound, vectorized,
  80. support_complex=True)
  81. self.y_old = None
  82. self.max_step = validate_max_step(max_step)
  83. self.rtol, self.atol = validate_tol(rtol, atol, self.n)
  84. self.f = self.fun(self.t, self.y)
  85. if first_step is None:
  86. self.h_abs = select_initial_step(
  87. self.fun, self.t, self.y, self.f, self.direction,
  88. self.order, self.rtol, self.atol)
  89. else:
  90. self.h_abs = validate_first_step(first_step, t0, t_bound)
  91. self.K = np.empty((self.n_stages + 1, self.n), dtype=self.y.dtype)
  92. def _step_impl(self):
  93. t = self.t
  94. y = self.y
  95. max_step = self.max_step
  96. rtol = self.rtol
  97. atol = self.atol
  98. min_step = 10 * np.abs(np.nextafter(t, self.direction * np.inf) - t)
  99. if self.h_abs > max_step:
  100. h_abs = max_step
  101. elif self.h_abs < min_step:
  102. h_abs = min_step
  103. else:
  104. h_abs = self.h_abs
  105. order = self.order
  106. step_accepted = False
  107. while not step_accepted:
  108. if h_abs < min_step:
  109. return False, self.TOO_SMALL_STEP
  110. h = h_abs * self.direction
  111. t_new = t + h
  112. if self.direction * (t_new - self.t_bound) > 0:
  113. t_new = self.t_bound
  114. h = t_new - t
  115. h_abs = np.abs(h)
  116. y_new, f_new, error = rk_step(self.fun, t, y, self.f, h, self.A,
  117. self.B, self.C, self.E, self.K)
  118. scale = atol + np.maximum(np.abs(y), np.abs(y_new)) * rtol
  119. error_norm = norm(error / scale)
  120. if error_norm == 0.0:
  121. h_abs *= MAX_FACTOR
  122. step_accepted = True
  123. elif error_norm < 1:
  124. h_abs *= min(MAX_FACTOR,
  125. max(1, SAFETY * error_norm ** (-1 / (order + 1))))
  126. step_accepted = True
  127. else:
  128. h_abs *= max(MIN_FACTOR,
  129. SAFETY * error_norm ** (-1 / (order + 1)))
  130. self.y_old = y
  131. self.t = t_new
  132. self.y = y_new
  133. self.h_abs = h_abs
  134. self.f = f_new
  135. return True, None
  136. def _dense_output_impl(self):
  137. Q = self.K.T.dot(self.P)
  138. return RkDenseOutput(self.t_old, self.t, self.y_old, Q)
  139. class RK23(RungeKutta):
  140. """Explicit Runge-Kutta method of order 3(2).
  141. This uses the Bogacki-Shampine pair of formulas [1]_. The error is controlled
  142. assuming accuracy of the second-order method, but steps are taken using the
  143. third-order accurate formula (local extrapolation is done). A cubic Hermite
  144. polynomial is used for the dense output.
  145. Can be applied in the complex domain.
  146. Parameters
  147. ----------
  148. fun : callable
  149. Right-hand side of the system. The calling signature is ``fun(t, y)``.
  150. Here ``t`` is a scalar and there are two options for ndarray ``y``.
  151. It can either have shape (n,), then ``fun`` must return array_like with
  152. shape (n,). Or alternatively it can have shape (n, k), then ``fun``
  153. must return array_like with shape (n, k), i.e. each column
  154. corresponds to a single column in ``y``. The choice between the two
  155. options is determined by `vectorized` argument (see below).
  156. t0 : float
  157. Initial time.
  158. y0 : array_like, shape (n,)
  159. Initial state.
  160. t_bound : float
  161. Boundary time - the integration won't continue beyond it. It also
  162. determines the direction of the integration.
  163. first_step : float or None, optional
  164. Initial step size. Default is ``None`` which means that the algorithm
  165. should choose.
  166. max_step : float, optional
  167. Maximum allowed step size. Default is np.inf, i.e. the step size is not
  168. bounded and determined solely by the solver.
  169. rtol, atol : float and array_like, optional
  170. Relative and absolute tolerances. The solver keeps the local error
  171. estimates less than ``atol + rtol * abs(y)``. Here `rtol` controls a
  172. relative accuracy (number of correct digits). But if a component of `y`
  173. is approximately below `atol`, the error only needs to fall within
  174. the same `atol` threshold, and the number of correct digits is not
  175. guaranteed. If components of y have different scales, it might be
  176. beneficial to set different `atol` values for different components by
  177. passing array_like with shape (n,) for `atol`. Default values are
  178. 1e-3 for `rtol` and 1e-6 for `atol`.
  179. vectorized : bool, optional
  180. Whether `fun` is implemented in a vectorized fashion. Default is False.
  181. Attributes
  182. ----------
  183. n : int
  184. Number of equations.
  185. status : string
  186. Current status of the solver: 'running', 'finished' or 'failed'.
  187. t_bound : float
  188. Boundary time.
  189. direction : float
  190. Integration direction: +1 or -1.
  191. t : float
  192. Current time.
  193. y : ndarray
  194. Current state.
  195. t_old : float
  196. Previous time. None if no steps were made yet.
  197. step_size : float
  198. Size of the last successful step. None if no steps were made yet.
  199. nfev : int
  200. Number evaluations of the system's right-hand side.
  201. njev : int
  202. Number of evaluations of the Jacobian. Is always 0 for this solver as it does not use the Jacobian.
  203. nlu : int
  204. Number of LU decompositions. Is always 0 for this solver.
  205. References
  206. ----------
  207. .. [1] P. Bogacki, L.F. Shampine, "A 3(2) Pair of Runge-Kutta Formulas",
  208. Appl. Math. Lett. Vol. 2, No. 4. pp. 321-325, 1989.
  209. """
  210. order = 2
  211. n_stages = 3
  212. C = np.array([1/2, 3/4])
  213. A = [np.array([1/2]),
  214. np.array([0, 3/4])]
  215. B = np.array([2/9, 1/3, 4/9])
  216. E = np.array([5/72, -1/12, -1/9, 1/8])
  217. P = np.array([[1, -4 / 3, 5 / 9],
  218. [0, 1, -2/3],
  219. [0, 4/3, -8/9],
  220. [0, -1, 1]])
  221. class RK45(RungeKutta):
  222. """Explicit Runge-Kutta method of order 5(4).
  223. This uses the Dormand-Prince pair of formulas [1]_. The error is controlled
  224. assuming accuracy of the fourth-order method accuracy, but steps are taken
  225. using the fifth-order accurate formula (local extrapolation is done).
  226. A quartic interpolation polynomial is used for the dense output [2]_.
  227. Can be applied in the complex domain.
  228. Parameters
  229. ----------
  230. fun : callable
  231. Right-hand side of the system. The calling signature is ``fun(t, y)``.
  232. Here ``t`` is a scalar, and there are two options for the ndarray ``y``:
  233. It can either have shape (n,); then ``fun`` must return array_like with
  234. shape (n,). Alternatively it can have shape (n, k); then ``fun``
  235. must return an array_like with shape (n, k), i.e. each column
  236. corresponds to a single column in ``y``. The choice between the two
  237. options is determined by `vectorized` argument (see below).
  238. t0 : float
  239. Initial time.
  240. y0 : array_like, shape (n,)
  241. Initial state.
  242. t_bound : float
  243. Boundary time - the integration won't continue beyond it. It also
  244. determines the direction of the integration.
  245. first_step : float or None, optional
  246. Initial step size. Default is ``None`` which means that the algorithm
  247. should choose.
  248. max_step : float, optional
  249. Maximum allowed step size. Default is np.inf, i.e. the step size is not
  250. bounded and determined solely by the solver.
  251. rtol, atol : float and array_like, optional
  252. Relative and absolute tolerances. The solver keeps the local error
  253. estimates less than ``atol + rtol * abs(y)``. Here `rtol` controls a
  254. relative accuracy (number of correct digits). But if a component of `y`
  255. is approximately below `atol`, the error only needs to fall within
  256. the same `atol` threshold, and the number of correct digits is not
  257. guaranteed. If components of y have different scales, it might be
  258. beneficial to set different `atol` values for different components by
  259. passing array_like with shape (n,) for `atol`. Default values are
  260. 1e-3 for `rtol` and 1e-6 for `atol`.
  261. vectorized : bool, optional
  262. Whether `fun` is implemented in a vectorized fashion. Default is False.
  263. Attributes
  264. ----------
  265. n : int
  266. Number of equations.
  267. status : string
  268. Current status of the solver: 'running', 'finished' or 'failed'.
  269. t_bound : float
  270. Boundary time.
  271. direction : float
  272. Integration direction: +1 or -1.
  273. t : float
  274. Current time.
  275. y : ndarray
  276. Current state.
  277. t_old : float
  278. Previous time. None if no steps were made yet.
  279. step_size : float
  280. Size of the last successful step. None if no steps were made yet.
  281. nfev : int
  282. Number evaluations of the system's right-hand side.
  283. njev : int
  284. Number of evaluations of the Jacobian. Is always 0 for this solver as it does not use the Jacobian.
  285. nlu : int
  286. Number of LU decompositions. Is always 0 for this solver.
  287. References
  288. ----------
  289. .. [1] J. R. Dormand, P. J. Prince, "A family of embedded Runge-Kutta
  290. formulae", Journal of Computational and Applied Mathematics, Vol. 6,
  291. No. 1, pp. 19-26, 1980.
  292. .. [2] L. W. Shampine, "Some Practical Runge-Kutta Formulas", Mathematics
  293. of Computation,, Vol. 46, No. 173, pp. 135-150, 1986.
  294. """
  295. order = 4
  296. n_stages = 6
  297. C = np.array([1/5, 3/10, 4/5, 8/9, 1])
  298. A = [np.array([1/5]),
  299. np.array([3/40, 9/40]),
  300. np.array([44/45, -56/15, 32/9]),
  301. np.array([19372/6561, -25360/2187, 64448/6561, -212/729]),
  302. np.array([9017/3168, -355/33, 46732/5247, 49/176, -5103/18656])]
  303. B = np.array([35/384, 0, 500/1113, 125/192, -2187/6784, 11/84])
  304. E = np.array([-71/57600, 0, 71/16695, -71/1920, 17253/339200, -22/525,
  305. 1/40])
  306. # Corresponds to the optimum value of c_6 from [2]_.
  307. P = np.array([
  308. [1, -8048581381/2820520608, 8663915743/2820520608,
  309. -12715105075/11282082432],
  310. [0, 0, 0, 0],
  311. [0, 131558114200/32700410799, -68118460800/10900136933,
  312. 87487479700/32700410799],
  313. [0, -1754552775/470086768, 14199869525/1410260304,
  314. -10690763975/1880347072],
  315. [0, 127303824393/49829197408, -318862633887/49829197408,
  316. 701980252875 / 199316789632],
  317. [0, -282668133/205662961, 2019193451/616988883, -1453857185/822651844],
  318. [0, 40617522/29380423, -110615467/29380423, 69997945/29380423]])
  319. class RkDenseOutput(DenseOutput):
  320. def __init__(self, t_old, t, y_old, Q):
  321. super(RkDenseOutput, self).__init__(t_old, t)
  322. self.h = t - t_old
  323. self.Q = Q
  324. self.order = Q.shape[1] - 1
  325. self.y_old = y_old
  326. def _call_impl(self, t):
  327. x = (t - self.t_old) / self.h
  328. if t.ndim == 0:
  329. p = np.tile(x, self.order + 1)
  330. p = np.cumprod(p)
  331. else:
  332. p = np.tile(x, (self.order + 1, 1))
  333. p = np.cumprod(p, axis=0)
  334. y = self.h * np.dot(self.Q, p)
  335. if y.ndim == 2:
  336. y += self.y_old[:, None]
  337. else:
  338. y += self.y_old
  339. return y