test_quadrature.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy import cos, sin, pi
  4. from numpy.testing import assert_equal, \
  5. assert_almost_equal, assert_allclose, assert_
  6. from scipy._lib._numpy_compat import suppress_warnings
  7. from scipy.integrate import (quadrature, romberg, romb, newton_cotes,
  8. cumtrapz, quad, simps, fixed_quad)
  9. from scipy.integrate.quadrature import AccuracyWarning
  10. class TestFixedQuad(object):
  11. def test_scalar(self):
  12. n = 4
  13. func = lambda x: x**(2*n - 1)
  14. expected = 1/(2*n)
  15. got, _ = fixed_quad(func, 0, 1, n=n)
  16. # quadrature exact for this input
  17. assert_allclose(got, expected, rtol=1e-12)
  18. def test_vector(self):
  19. n = 4
  20. p = np.arange(1, 2*n)
  21. func = lambda x: x**p[:,None]
  22. expected = 1/(p + 1)
  23. got, _ = fixed_quad(func, 0, 1, n=n)
  24. assert_allclose(got, expected, rtol=1e-12)
  25. class TestQuadrature(object):
  26. def quad(self, x, a, b, args):
  27. raise NotImplementedError
  28. def test_quadrature(self):
  29. # Typical function with two extra arguments:
  30. def myfunc(x, n, z): # Bessel function integrand
  31. return cos(n*x-z*sin(x))/pi
  32. val, err = quadrature(myfunc, 0, pi, (2, 1.8))
  33. table_val = 0.30614353532540296487
  34. assert_almost_equal(val, table_val, decimal=7)
  35. def test_quadrature_rtol(self):
  36. def myfunc(x, n, z): # Bessel function integrand
  37. return 1e90 * cos(n*x-z*sin(x))/pi
  38. val, err = quadrature(myfunc, 0, pi, (2, 1.8), rtol=1e-10)
  39. table_val = 1e90 * 0.30614353532540296487
  40. assert_allclose(val, table_val, rtol=1e-10)
  41. def test_quadrature_miniter(self):
  42. # Typical function with two extra arguments:
  43. def myfunc(x, n, z): # Bessel function integrand
  44. return cos(n*x-z*sin(x))/pi
  45. table_val = 0.30614353532540296487
  46. for miniter in [5, 52]:
  47. val, err = quadrature(myfunc, 0, pi, (2, 1.8), miniter=miniter)
  48. assert_almost_equal(val, table_val, decimal=7)
  49. assert_(err < 1.0)
  50. def test_quadrature_single_args(self):
  51. def myfunc(x, n):
  52. return 1e90 * cos(n*x-1.8*sin(x))/pi
  53. val, err = quadrature(myfunc, 0, pi, args=2, rtol=1e-10)
  54. table_val = 1e90 * 0.30614353532540296487
  55. assert_allclose(val, table_val, rtol=1e-10)
  56. def test_romberg(self):
  57. # Typical function with two extra arguments:
  58. def myfunc(x, n, z): # Bessel function integrand
  59. return cos(n*x-z*sin(x))/pi
  60. val = romberg(myfunc, 0, pi, args=(2, 1.8))
  61. table_val = 0.30614353532540296487
  62. assert_almost_equal(val, table_val, decimal=7)
  63. def test_romberg_rtol(self):
  64. # Typical function with two extra arguments:
  65. def myfunc(x, n, z): # Bessel function integrand
  66. return 1e19*cos(n*x-z*sin(x))/pi
  67. val = romberg(myfunc, 0, pi, args=(2, 1.8), rtol=1e-10)
  68. table_val = 1e19*0.30614353532540296487
  69. assert_allclose(val, table_val, rtol=1e-10)
  70. def test_romb(self):
  71. assert_equal(romb(np.arange(17)), 128)
  72. def test_romb_gh_3731(self):
  73. # Check that romb makes maximal use of data points
  74. x = np.arange(2**4+1)
  75. y = np.cos(0.2*x)
  76. val = romb(y)
  77. val2, err = quad(lambda x: np.cos(0.2*x), x.min(), x.max())
  78. assert_allclose(val, val2, rtol=1e-8, atol=0)
  79. # should be equal to romb with 2**k+1 samples
  80. with suppress_warnings() as sup:
  81. sup.filter(AccuracyWarning, "divmax .4. exceeded")
  82. val3 = romberg(lambda x: np.cos(0.2*x), x.min(), x.max(), divmax=4)
  83. assert_allclose(val, val3, rtol=1e-12, atol=0)
  84. def test_non_dtype(self):
  85. # Check that we work fine with functions returning float
  86. import math
  87. valmath = romberg(math.sin, 0, 1)
  88. expected_val = 0.45969769413185085
  89. assert_almost_equal(valmath, expected_val, decimal=7)
  90. def test_newton_cotes(self):
  91. """Test the first few degrees, for evenly spaced points."""
  92. n = 1
  93. wts, errcoff = newton_cotes(n, 1)
  94. assert_equal(wts, n*np.array([0.5, 0.5]))
  95. assert_almost_equal(errcoff, -n**3/12.0)
  96. n = 2
  97. wts, errcoff = newton_cotes(n, 1)
  98. assert_almost_equal(wts, n*np.array([1.0, 4.0, 1.0])/6.0)
  99. assert_almost_equal(errcoff, -n**5/2880.0)
  100. n = 3
  101. wts, errcoff = newton_cotes(n, 1)
  102. assert_almost_equal(wts, n*np.array([1.0, 3.0, 3.0, 1.0])/8.0)
  103. assert_almost_equal(errcoff, -n**5/6480.0)
  104. n = 4
  105. wts, errcoff = newton_cotes(n, 1)
  106. assert_almost_equal(wts, n*np.array([7.0, 32.0, 12.0, 32.0, 7.0])/90.0)
  107. assert_almost_equal(errcoff, -n**7/1935360.0)
  108. def test_newton_cotes2(self):
  109. """Test newton_cotes with points that are not evenly spaced."""
  110. x = np.array([0.0, 1.5, 2.0])
  111. y = x**2
  112. wts, errcoff = newton_cotes(x)
  113. exact_integral = 8.0/3
  114. numeric_integral = np.dot(wts, y)
  115. assert_almost_equal(numeric_integral, exact_integral)
  116. x = np.array([0.0, 1.4, 2.1, 3.0])
  117. y = x**2
  118. wts, errcoff = newton_cotes(x)
  119. exact_integral = 9.0
  120. numeric_integral = np.dot(wts, y)
  121. assert_almost_equal(numeric_integral, exact_integral)
  122. def test_simps(self):
  123. y = np.arange(17)
  124. assert_equal(simps(y), 128)
  125. assert_equal(simps(y, dx=0.5), 64)
  126. assert_equal(simps(y, x=np.linspace(0, 4, 17)), 32)
  127. y = np.arange(4)
  128. x = 2**y
  129. assert_equal(simps(y, x=x, even='avg'), 13.875)
  130. assert_equal(simps(y, x=x, even='first'), 13.75)
  131. assert_equal(simps(y, x=x, even='last'), 14)
  132. class TestCumtrapz(object):
  133. def test_1d(self):
  134. x = np.linspace(-2, 2, num=5)
  135. y = x
  136. y_int = cumtrapz(y, x, initial=0)
  137. y_expected = [0., -1.5, -2., -1.5, 0.]
  138. assert_allclose(y_int, y_expected)
  139. y_int = cumtrapz(y, x, initial=None)
  140. assert_allclose(y_int, y_expected[1:])
  141. def test_y_nd_x_nd(self):
  142. x = np.arange(3 * 2 * 4).reshape(3, 2, 4)
  143. y = x
  144. y_int = cumtrapz(y, x, initial=0)
  145. y_expected = np.array([[[0., 0.5, 2., 4.5],
  146. [0., 4.5, 10., 16.5]],
  147. [[0., 8.5, 18., 28.5],
  148. [0., 12.5, 26., 40.5]],
  149. [[0., 16.5, 34., 52.5],
  150. [0., 20.5, 42., 64.5]]])
  151. assert_allclose(y_int, y_expected)
  152. # Try with all axes
  153. shapes = [(2, 2, 4), (3, 1, 4), (3, 2, 3)]
  154. for axis, shape in zip([0, 1, 2], shapes):
  155. y_int = cumtrapz(y, x, initial=3.45, axis=axis)
  156. assert_equal(y_int.shape, (3, 2, 4))
  157. y_int = cumtrapz(y, x, initial=None, axis=axis)
  158. assert_equal(y_int.shape, shape)
  159. def test_y_nd_x_1d(self):
  160. y = np.arange(3 * 2 * 4).reshape(3, 2, 4)
  161. x = np.arange(4)**2
  162. # Try with all axes
  163. ys_expected = (
  164. np.array([[[4., 5., 6., 7.],
  165. [8., 9., 10., 11.]],
  166. [[40., 44., 48., 52.],
  167. [56., 60., 64., 68.]]]),
  168. np.array([[[2., 3., 4., 5.]],
  169. [[10., 11., 12., 13.]],
  170. [[18., 19., 20., 21.]]]),
  171. np.array([[[0.5, 5., 17.5],
  172. [4.5, 21., 53.5]],
  173. [[8.5, 37., 89.5],
  174. [12.5, 53., 125.5]],
  175. [[16.5, 69., 161.5],
  176. [20.5, 85., 197.5]]]))
  177. for axis, y_expected in zip([0, 1, 2], ys_expected):
  178. y_int = cumtrapz(y, x=x[:y.shape[axis]], axis=axis, initial=None)
  179. assert_allclose(y_int, y_expected)
  180. def test_x_none(self):
  181. y = np.linspace(-2, 2, num=5)
  182. y_int = cumtrapz(y)
  183. y_expected = [-1.5, -2., -1.5, 0.]
  184. assert_allclose(y_int, y_expected)
  185. y_int = cumtrapz(y, initial=1.23)
  186. y_expected = [1.23, -1.5, -2., -1.5, 0.]
  187. assert_allclose(y_int, y_expected)
  188. y_int = cumtrapz(y, dx=3)
  189. y_expected = [-4.5, -6., -4.5, 0.]
  190. assert_allclose(y_int, y_expected)
  191. y_int = cumtrapz(y, dx=3, initial=1.23)
  192. y_expected = [1.23, -4.5, -6., -4.5, 0.]
  193. assert_allclose(y_int, y_expected)