test_banded_ode_solvers.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import division, print_function, absolute_import
  2. import itertools
  3. import numpy as np
  4. from numpy.testing import assert_allclose
  5. from scipy.integrate import ode
  6. def _band_count(a):
  7. """Returns ml and mu, the lower and upper band sizes of a."""
  8. nrows, ncols = a.shape
  9. ml = 0
  10. for k in range(-nrows+1, 0):
  11. if np.diag(a, k).any():
  12. ml = -k
  13. break
  14. mu = 0
  15. for k in range(nrows-1, 0, -1):
  16. if np.diag(a, k).any():
  17. mu = k
  18. break
  19. return ml, mu
  20. def _linear_func(t, y, a):
  21. """Linear system dy/dt = a * y"""
  22. return a.dot(y)
  23. def _linear_jac(t, y, a):
  24. """Jacobian of a * y is a."""
  25. return a
  26. def _linear_banded_jac(t, y, a):
  27. """Banded Jacobian."""
  28. ml, mu = _band_count(a)
  29. bjac = []
  30. for k in range(mu, 0, -1):
  31. bjac.append(np.r_[[0] * k, np.diag(a, k)])
  32. bjac.append(np.diag(a))
  33. for k in range(-1, -ml-1, -1):
  34. bjac.append(np.r_[np.diag(a, k), [0] * (-k)])
  35. return bjac
  36. def _solve_linear_sys(a, y0, tend=1, dt=0.1,
  37. solver=None, method='bdf', use_jac=True,
  38. with_jacobian=False, banded=False):
  39. """Use scipy.integrate.ode to solve a linear system of ODEs.
  40. a : square ndarray
  41. Matrix of the linear system to be solved.
  42. y0 : ndarray
  43. Initial condition
  44. tend : float
  45. Stop time.
  46. dt : float
  47. Step size of the output.
  48. solver : str
  49. If not None, this must be "vode", "lsoda" or "zvode".
  50. method : str
  51. Either "bdf" or "adams".
  52. use_jac : bool
  53. Determines if the jacobian function is passed to ode().
  54. with_jacobian : bool
  55. Passed to ode.set_integrator().
  56. banded : bool
  57. Determines whether a banded or full jacobian is used.
  58. If `banded` is True, `lband` and `uband` are determined by the
  59. values in `a`.
  60. """
  61. if banded:
  62. lband, uband = _band_count(a)
  63. else:
  64. lband = None
  65. uband = None
  66. if use_jac:
  67. if banded:
  68. r = ode(_linear_func, _linear_banded_jac)
  69. else:
  70. r = ode(_linear_func, _linear_jac)
  71. else:
  72. r = ode(_linear_func)
  73. if solver is None:
  74. if np.iscomplexobj(a):
  75. solver = "zvode"
  76. else:
  77. solver = "vode"
  78. r.set_integrator(solver,
  79. with_jacobian=with_jacobian,
  80. method=method,
  81. lband=lband, uband=uband,
  82. rtol=1e-9, atol=1e-10,
  83. )
  84. t0 = 0
  85. r.set_initial_value(y0, t0)
  86. r.set_f_params(a)
  87. r.set_jac_params(a)
  88. t = [t0]
  89. y = [y0]
  90. while r.successful() and r.t < tend:
  91. r.integrate(r.t + dt)
  92. t.append(r.t)
  93. y.append(r.y)
  94. t = np.array(t)
  95. y = np.array(y)
  96. return t, y
  97. def _analytical_solution(a, y0, t):
  98. """
  99. Analytical solution to the linear differential equations dy/dt = a*y.
  100. The solution is only valid if `a` is diagonalizable.
  101. Returns a 2-d array with shape (len(t), len(y0)).
  102. """
  103. lam, v = np.linalg.eig(a)
  104. c = np.linalg.solve(v, y0)
  105. e = c * np.exp(lam * t.reshape(-1, 1))
  106. sol = e.dot(v.T)
  107. return sol
  108. def test_banded_ode_solvers():
  109. # Test the "lsoda", "vode" and "zvode" solvers of the `ode` class
  110. # with a system that has a banded Jacobian matrix.
  111. t_exact = np.linspace(0, 1.0, 5)
  112. # --- Real arrays for testing the "lsoda" and "vode" solvers ---
  113. # lband = 2, uband = 1:
  114. a_real = np.array([[-0.6, 0.1, 0.0, 0.0, 0.0],
  115. [0.2, -0.5, 0.9, 0.0, 0.0],
  116. [0.1, 0.1, -0.4, 0.1, 0.0],
  117. [0.0, 0.3, -0.1, -0.9, -0.3],
  118. [0.0, 0.0, 0.1, 0.1, -0.7]])
  119. # lband = 0, uband = 1:
  120. a_real_upper = np.triu(a_real)
  121. # lband = 2, uband = 0:
  122. a_real_lower = np.tril(a_real)
  123. # lband = 0, uband = 0:
  124. a_real_diag = np.triu(a_real_lower)
  125. real_matrices = [a_real, a_real_upper, a_real_lower, a_real_diag]
  126. real_solutions = []
  127. for a in real_matrices:
  128. y0 = np.arange(1, a.shape[0] + 1)
  129. y_exact = _analytical_solution(a, y0, t_exact)
  130. real_solutions.append((y0, t_exact, y_exact))
  131. def check_real(idx, solver, meth, use_jac, with_jac, banded):
  132. a = real_matrices[idx]
  133. y0, t_exact, y_exact = real_solutions[idx]
  134. t, y = _solve_linear_sys(a, y0,
  135. tend=t_exact[-1],
  136. dt=t_exact[1] - t_exact[0],
  137. solver=solver,
  138. method=meth,
  139. use_jac=use_jac,
  140. with_jacobian=with_jac,
  141. banded=banded)
  142. assert_allclose(t, t_exact)
  143. assert_allclose(y, y_exact)
  144. for idx in range(len(real_matrices)):
  145. p = [['vode', 'lsoda'], # solver
  146. ['bdf', 'adams'], # method
  147. [False, True], # use_jac
  148. [False, True], # with_jacobian
  149. [False, True]] # banded
  150. for solver, meth, use_jac, with_jac, banded in itertools.product(*p):
  151. check_real(idx, solver, meth, use_jac, with_jac, banded)
  152. # --- Complex arrays for testing the "zvode" solver ---
  153. # complex, lband = 2, uband = 1:
  154. a_complex = a_real - 0.5j * a_real
  155. # complex, lband = 0, uband = 0:
  156. a_complex_diag = np.diag(np.diag(a_complex))
  157. complex_matrices = [a_complex, a_complex_diag]
  158. complex_solutions = []
  159. for a in complex_matrices:
  160. y0 = np.arange(1, a.shape[0] + 1) + 1j
  161. y_exact = _analytical_solution(a, y0, t_exact)
  162. complex_solutions.append((y0, t_exact, y_exact))
  163. def check_complex(idx, solver, meth, use_jac, with_jac, banded):
  164. a = complex_matrices[idx]
  165. y0, t_exact, y_exact = complex_solutions[idx]
  166. t, y = _solve_linear_sys(a, y0,
  167. tend=t_exact[-1],
  168. dt=t_exact[1] - t_exact[0],
  169. solver=solver,
  170. method=meth,
  171. use_jac=use_jac,
  172. with_jacobian=with_jac,
  173. banded=banded)
  174. assert_allclose(t, t_exact)
  175. assert_allclose(y, y_exact)
  176. for idx in range(len(complex_matrices)):
  177. p = [['bdf', 'adams'], # method
  178. [False, True], # use_jac
  179. [False, True], # with_jacobian
  180. [False, True]] # banded
  181. for meth, use_jac, with_jac, banded in itertools.product(*p):
  182. check_complex(idx, "zvode", meth, use_jac, with_jac, banded)