base.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. def check_arguments(fun, y0, support_complex):
  4. """Helper function for checking arguments common to all solvers."""
  5. y0 = np.asarray(y0)
  6. if np.issubdtype(y0.dtype, np.complexfloating):
  7. if not support_complex:
  8. raise ValueError("`y0` is complex, but the chosen solver does "
  9. "not support integration in a complex domain.")
  10. dtype = complex
  11. else:
  12. dtype = float
  13. y0 = y0.astype(dtype, copy=False)
  14. if y0.ndim != 1:
  15. raise ValueError("`y0` must be 1-dimensional.")
  16. def fun_wrapped(t, y):
  17. return np.asarray(fun(t, y), dtype=dtype)
  18. return fun_wrapped, y0
  19. class OdeSolver(object):
  20. """Base class for ODE solvers.
  21. In order to implement a new solver you need to follow the guidelines:
  22. 1. A constructor must accept parameters presented in the base class
  23. (listed below) along with any other parameters specific to a solver.
  24. 2. A constructor must accept arbitrary extraneous arguments
  25. ``**extraneous``, but warn that these arguments are irrelevant
  26. using `common.warn_extraneous` function. Do not pass these
  27. arguments to the base class.
  28. 3. A solver must implement a private method `_step_impl(self)` which
  29. propagates a solver one step further. It must return tuple
  30. ``(success, message)``, where ``success`` is a boolean indicating
  31. whether a step was successful, and ``message`` is a string
  32. containing description of a failure if a step failed or None
  33. otherwise.
  34. 4. A solver must implement a private method `_dense_output_impl(self)`
  35. which returns a `DenseOutput` object covering the last successful
  36. step.
  37. 5. A solver must have attributes listed below in Attributes section.
  38. Note that `t_old` and `step_size` are updated automatically.
  39. 6. Use `fun(self, t, y)` method for the system rhs evaluation, this
  40. way the number of function evaluations (`nfev`) will be tracked
  41. automatically.
  42. 7. For convenience a base class provides `fun_single(self, t, y)` and
  43. `fun_vectorized(self, t, y)` for evaluating the rhs in
  44. non-vectorized and vectorized fashions respectively (regardless of
  45. how `fun` from the constructor is implemented). These calls don't
  46. increment `nfev`.
  47. 8. If a solver uses a Jacobian matrix and LU decompositions, it should
  48. track the number of Jacobian evaluations (`njev`) and the number of
  49. LU decompositions (`nlu`).
  50. 9. By convention the function evaluations used to compute a finite
  51. difference approximation of the Jacobian should not be counted in
  52. `nfev`, thus use `fun_single(self, t, y)` or
  53. `fun_vectorized(self, t, y)` when computing a finite difference
  54. approximation of the Jacobian.
  55. Parameters
  56. ----------
  57. fun : callable
  58. Right-hand side of the system. The calling signature is ``fun(t, y)``.
  59. Here ``t`` is a scalar and there are two options for ndarray ``y``.
  60. It can either have shape (n,), then ``fun`` must return array_like with
  61. shape (n,). Or alternatively it can have shape (n, n_points), then
  62. ``fun`` must return array_like with shape (n, n_points) (each column
  63. corresponds to a single column in ``y``). The choice between the two
  64. options is determined by `vectorized` argument (see below).
  65. t0 : float
  66. Initial time.
  67. y0 : array_like, shape (n,)
  68. Initial state.
  69. t_bound : float
  70. Boundary time --- the integration won't continue beyond it. It also
  71. determines the direction of the integration.
  72. vectorized : bool
  73. Whether `fun` is implemented in a vectorized fashion.
  74. support_complex : bool, optional
  75. Whether integration in a complex domain should be supported.
  76. Generally determined by a derived solver class capabilities.
  77. Default is False.
  78. Attributes
  79. ----------
  80. n : int
  81. Number of equations.
  82. status : string
  83. Current status of the solver: 'running', 'finished' or 'failed'.
  84. t_bound : float
  85. Boundary time.
  86. direction : float
  87. Integration direction: +1 or -1.
  88. t : float
  89. Current time.
  90. y : ndarray
  91. Current state.
  92. t_old : float
  93. Previous time. None if no steps were made yet.
  94. step_size : float
  95. Size of the last successful step. None if no steps were made yet.
  96. nfev : int
  97. Number of the system's rhs evaluations.
  98. njev : int
  99. Number of the Jacobian evaluations.
  100. nlu : int
  101. Number of LU decompositions.
  102. """
  103. TOO_SMALL_STEP = "Required step size is less than spacing between numbers."
  104. def __init__(self, fun, t0, y0, t_bound, vectorized,
  105. support_complex=False):
  106. self.t_old = None
  107. self.t = t0
  108. self._fun, self.y = check_arguments(fun, y0, support_complex)
  109. self.t_bound = t_bound
  110. self.vectorized = vectorized
  111. if vectorized:
  112. def fun_single(t, y):
  113. return self._fun(t, y[:, None]).ravel()
  114. fun_vectorized = self._fun
  115. else:
  116. fun_single = self._fun
  117. def fun_vectorized(t, y):
  118. f = np.empty_like(y)
  119. for i, yi in enumerate(y.T):
  120. f[:, i] = self._fun(t, yi)
  121. return f
  122. def fun(t, y):
  123. self.nfev += 1
  124. return self.fun_single(t, y)
  125. self.fun = fun
  126. self.fun_single = fun_single
  127. self.fun_vectorized = fun_vectorized
  128. self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
  129. self.n = self.y.size
  130. self.status = 'running'
  131. self.nfev = 0
  132. self.njev = 0
  133. self.nlu = 0
  134. @property
  135. def step_size(self):
  136. if self.t_old is None:
  137. return None
  138. else:
  139. return np.abs(self.t - self.t_old)
  140. def step(self):
  141. """Perform one integration step.
  142. Returns
  143. -------
  144. message : string or None
  145. Report from the solver. Typically a reason for a failure if
  146. `self.status` is 'failed' after the step was taken or None
  147. otherwise.
  148. """
  149. if self.status != 'running':
  150. raise RuntimeError("Attempt to step on a failed or finished "
  151. "solver.")
  152. if self.n == 0 or self.t == self.t_bound:
  153. # Handle corner cases of empty solver or no integration.
  154. self.t_old = self.t
  155. self.t = self.t_bound
  156. message = None
  157. self.status = 'finished'
  158. else:
  159. t = self.t
  160. success, message = self._step_impl()
  161. if not success:
  162. self.status = 'failed'
  163. else:
  164. self.t_old = t
  165. if self.direction * (self.t - self.t_bound) >= 0:
  166. self.status = 'finished'
  167. return message
  168. def dense_output(self):
  169. """Compute a local interpolant over the last successful step.
  170. Returns
  171. -------
  172. sol : `DenseOutput`
  173. Local interpolant over the last successful step.
  174. """
  175. if self.t_old is None:
  176. raise RuntimeError("Dense output is available after a successful "
  177. "step was made.")
  178. if self.n == 0 or self.t == self.t_old:
  179. # Handle corner cases of empty solver and no integration.
  180. return ConstantDenseOutput(self.t_old, self.t, self.y)
  181. else:
  182. return self._dense_output_impl()
  183. def _step_impl(self):
  184. raise NotImplementedError
  185. def _dense_output_impl(self):
  186. raise NotImplementedError
  187. class DenseOutput(object):
  188. """Base class for local interpolant over step made by an ODE solver.
  189. It interpolates between `t_min` and `t_max` (see Attributes below).
  190. Evaluation outside this interval is not forbidden, but the accuracy is not
  191. guaranteed.
  192. Attributes
  193. ----------
  194. t_min, t_max : float
  195. Time range of the interpolation.
  196. """
  197. def __init__(self, t_old, t):
  198. self.t_old = t_old
  199. self.t = t
  200. self.t_min = min(t, t_old)
  201. self.t_max = max(t, t_old)
  202. def __call__(self, t):
  203. """Evaluate the interpolant.
  204. Parameters
  205. ----------
  206. t : float or array_like with shape (n_points,)
  207. Points to evaluate the solution at.
  208. Returns
  209. -------
  210. y : ndarray, shape (n,) or (n, n_points)
  211. Computed values. Shape depends on whether `t` was a scalar or a
  212. 1-d array.
  213. """
  214. t = np.asarray(t)
  215. if t.ndim > 1:
  216. raise ValueError("`t` must be float or 1-d array.")
  217. return self._call_impl(t)
  218. def _call_impl(self, t):
  219. raise NotImplementedError
  220. class ConstantDenseOutput(DenseOutput):
  221. """Constant value interpolator.
  222. This class used for degenerate integration cases: equal integration limits
  223. or a system with 0 equations.
  224. """
  225. def __init__(self, t_old, t, value):
  226. super(ConstantDenseOutput, self).__init__(t_old, t)
  227. self.value = value
  228. def _call_impl(self, t):
  229. if t.ndim == 0:
  230. return self.value
  231. else:
  232. ret = np.empty((self.value.shape[0], t.shape[0]))
  233. ret[:] = self.value[:, None]
  234. return ret