# test_polyint.py (24 KB)

  1. from __future__ import division, print_function, absolute_import
  2. import warnings
  3. import numpy as np
  4. from numpy.testing import (
  5. assert_almost_equal, assert_array_equal, assert_array_almost_equal,
  6. assert_allclose, assert_equal, assert_)
  7. from pytest import raises as assert_raises
  8. from scipy.interpolate import (
  9. KroghInterpolator, krogh_interpolate,
  10. BarycentricInterpolator, barycentric_interpolate,
  11. approximate_taylor_polynomial, pchip, PchipInterpolator,
  12. pchip_interpolate, Akima1DInterpolator, CubicSpline, make_interp_spline)
  13. from scipy._lib.six import xrange
  14. def check_shape(interpolator_cls, x_shape, y_shape, deriv_shape=None, axis=0,
  15. extra_args={}):
  16. np.random.seed(1234)
  17. x = [-1, 0, 1, 2, 3, 4]
  18. s = list(range(1, len(y_shape)+1))
  19. s.insert(axis % (len(y_shape)+1), 0)
  20. y = np.random.rand(*((6,) + y_shape)).transpose(s)
  21. # Cython code chokes on y.shape = (0, 3) etc, skip them
  22. if y.size == 0:
  23. return
  24. xi = np.zeros(x_shape)
  25. yi = interpolator_cls(x, y, axis=axis, **extra_args)(xi)
  26. target_shape = ((deriv_shape or ()) + y.shape[:axis]
  27. + x_shape + y.shape[axis:][1:])
  28. assert_equal(yi.shape, target_shape)
  29. # check it works also with lists
  30. if x_shape and y.size > 0:
  31. interpolator_cls(list(x), list(y), axis=axis, **extra_args)(list(xi))
  32. # check also values
  33. if xi.size > 0 and deriv_shape is None:
  34. bs_shape = y.shape[:axis] + (1,)*len(x_shape) + y.shape[axis:][1:]
  35. yv = y[((slice(None,),)*(axis % y.ndim)) + (1,)]
  36. yv = yv.reshape(bs_shape)
  37. yi, y = np.broadcast_arrays(yi, yv)
  38. assert_allclose(yi, y)
  39. SHAPES = [(), (0,), (1,), (6, 2, 5)]
  40. def test_shapes():
  41. def spl_interp(x, y, axis):
  42. return make_interp_spline(x, y, axis=axis)
  43. for ip in [KroghInterpolator, BarycentricInterpolator, pchip,
  44. Akima1DInterpolator, CubicSpline, spl_interp]:
  45. for s1 in SHAPES:
  46. for s2 in SHAPES:
  47. for axis in range(-len(s2), len(s2)):
  48. if ip != CubicSpline:
  49. check_shape(ip, s1, s2, None, axis)
  50. else:
  51. for bc in ['natural', 'clamped']:
  52. extra = {'bc_type': bc}
  53. check_shape(ip, s1, s2, None, axis, extra)
  54. def test_derivs_shapes():
  55. def krogh_derivs(x, y, axis=0):
  56. return KroghInterpolator(x, y, axis).derivatives
  57. for s1 in SHAPES:
  58. for s2 in SHAPES:
  59. for axis in range(-len(s2), len(s2)):
  60. check_shape(krogh_derivs, s1, s2, (6,), axis)
  61. def test_deriv_shapes():
  62. def krogh_deriv(x, y, axis=0):
  63. return KroghInterpolator(x, y, axis).derivative
  64. def pchip_deriv(x, y, axis=0):
  65. return pchip(x, y, axis).derivative()
  66. def pchip_deriv2(x, y, axis=0):
  67. return pchip(x, y, axis).derivative(2)
  68. def pchip_antideriv(x, y, axis=0):
  69. return pchip(x, y, axis).derivative()
  70. def pchip_antideriv2(x, y, axis=0):
  71. return pchip(x, y, axis).derivative(2)
  72. def pchip_deriv_inplace(x, y, axis=0):
  73. class P(PchipInterpolator):
  74. def __call__(self, x):
  75. return PchipInterpolator.__call__(self, x, 1)
  76. pass
  77. return P(x, y, axis)
  78. def akima_deriv(x, y, axis=0):
  79. return Akima1DInterpolator(x, y, axis).derivative()
  80. def akima_antideriv(x, y, axis=0):
  81. return Akima1DInterpolator(x, y, axis).antiderivative()
  82. def cspline_deriv(x, y, axis=0):
  83. return CubicSpline(x, y, axis).derivative()
  84. def cspline_antideriv(x, y, axis=0):
  85. return CubicSpline(x, y, axis).antiderivative()
  86. def bspl_deriv(x, y, axis=0):
  87. return make_interp_spline(x, y, axis=axis).derivative()
  88. def bspl_antideriv(x, y, axis=0):
  89. return make_interp_spline(x, y, axis=axis).antiderivative()
  90. for ip in [krogh_deriv, pchip_deriv, pchip_deriv2, pchip_deriv_inplace,
  91. pchip_antideriv, pchip_antideriv2, akima_deriv, akima_antideriv,
  92. cspline_deriv, cspline_antideriv, bspl_deriv, bspl_antideriv]:
  93. for s1 in SHAPES:
  94. for s2 in SHAPES:
  95. for axis in range(-len(s2), len(s2)):
  96. check_shape(ip, s1, s2, (), axis)
  97. def _check_complex(ip):
  98. x = [1, 2, 3, 4]
  99. y = [1, 2, 1j, 3]
  100. p = ip(x, y)
  101. assert_allclose(y, p(x))
  102. def test_complex():
  103. for ip in [KroghInterpolator, BarycentricInterpolator, pchip, CubicSpline]:
  104. _check_complex(ip)
  105. class TestKrogh(object):
  106. def setup_method(self):
  107. self.true_poly = np.poly1d([-2,3,1,5,-4])
  108. self.test_xs = np.linspace(-1,1,100)
  109. self.xs = np.linspace(-1,1,5)
  110. self.ys = self.true_poly(self.xs)
  111. def test_lagrange(self):
  112. P = KroghInterpolator(self.xs,self.ys)
  113. assert_almost_equal(self.true_poly(self.test_xs),P(self.test_xs))
  114. def test_scalar(self):
  115. P = KroghInterpolator(self.xs,self.ys)
  116. assert_almost_equal(self.true_poly(7),P(7))
  117. assert_almost_equal(self.true_poly(np.array(7)), P(np.array(7)))
  118. def test_derivatives(self):
  119. P = KroghInterpolator(self.xs,self.ys)
  120. D = P.derivatives(self.test_xs)
  121. for i in xrange(D.shape[0]):
  122. assert_almost_equal(self.true_poly.deriv(i)(self.test_xs),
  123. D[i])
  124. def test_low_derivatives(self):
  125. P = KroghInterpolator(self.xs,self.ys)
  126. D = P.derivatives(self.test_xs,len(self.xs)+2)
  127. for i in xrange(D.shape[0]):
  128. assert_almost_equal(self.true_poly.deriv(i)(self.test_xs),
  129. D[i])
  130. def test_derivative(self):
  131. P = KroghInterpolator(self.xs,self.ys)
  132. m = 10
  133. r = P.derivatives(self.test_xs,m)
  134. for i in xrange(m):
  135. assert_almost_equal(P.derivative(self.test_xs,i),r[i])
  136. def test_high_derivative(self):
  137. P = KroghInterpolator(self.xs,self.ys)
  138. for i in xrange(len(self.xs),2*len(self.xs)):
  139. assert_almost_equal(P.derivative(self.test_xs,i),
  140. np.zeros(len(self.test_xs)))
  141. def test_hermite(self):
  142. xs = [0,0,0,1,1,1,2]
  143. ys = [self.true_poly(0),
  144. self.true_poly.deriv(1)(0),
  145. self.true_poly.deriv(2)(0),
  146. self.true_poly(1),
  147. self.true_poly.deriv(1)(1),
  148. self.true_poly.deriv(2)(1),
  149. self.true_poly(2)]
  150. P = KroghInterpolator(self.xs,self.ys)
  151. assert_almost_equal(self.true_poly(self.test_xs),P(self.test_xs))
  152. def test_vector(self):
  153. xs = [0, 1, 2]
  154. ys = np.array([[0,1],[1,0],[2,1]])
  155. P = KroghInterpolator(xs,ys)
  156. Pi = [KroghInterpolator(xs,ys[:,i]) for i in xrange(ys.shape[1])]
  157. test_xs = np.linspace(-1,3,100)
  158. assert_almost_equal(P(test_xs),
  159. np.rollaxis(np.asarray([p(test_xs) for p in Pi]),-1))
  160. assert_almost_equal(P.derivatives(test_xs),
  161. np.transpose(np.asarray([p.derivatives(test_xs) for p in Pi]),
  162. (1,2,0)))
  163. def test_empty(self):
  164. P = KroghInterpolator(self.xs,self.ys)
  165. assert_array_equal(P([]), [])
  166. def test_shapes_scalarvalue(self):
  167. P = KroghInterpolator(self.xs,self.ys)
  168. assert_array_equal(np.shape(P(0)), ())
  169. assert_array_equal(np.shape(P(np.array(0))), ())
  170. assert_array_equal(np.shape(P([0])), (1,))
  171. assert_array_equal(np.shape(P([0,1])), (2,))
  172. def test_shapes_scalarvalue_derivative(self):
  173. P = KroghInterpolator(self.xs,self.ys)
  174. n = P.n
  175. assert_array_equal(np.shape(P.derivatives(0)), (n,))
  176. assert_array_equal(np.shape(P.derivatives(np.array(0))), (n,))
  177. assert_array_equal(np.shape(P.derivatives([0])), (n,1))
  178. assert_array_equal(np.shape(P.derivatives([0,1])), (n,2))
  179. def test_shapes_vectorvalue(self):
  180. P = KroghInterpolator(self.xs,np.outer(self.ys,np.arange(3)))
  181. assert_array_equal(np.shape(P(0)), (3,))
  182. assert_array_equal(np.shape(P([0])), (1,3))
  183. assert_array_equal(np.shape(P([0,1])), (2,3))
  184. def test_shapes_1d_vectorvalue(self):
  185. P = KroghInterpolator(self.xs,np.outer(self.ys,[1]))
  186. assert_array_equal(np.shape(P(0)), (1,))
  187. assert_array_equal(np.shape(P([0])), (1,1))
  188. assert_array_equal(np.shape(P([0,1])), (2,1))
  189. def test_shapes_vectorvalue_derivative(self):
  190. P = KroghInterpolator(self.xs,np.outer(self.ys,np.arange(3)))
  191. n = P.n
  192. assert_array_equal(np.shape(P.derivatives(0)), (n,3))
  193. assert_array_equal(np.shape(P.derivatives([0])), (n,1,3))
  194. assert_array_equal(np.shape(P.derivatives([0,1])), (n,2,3))
  195. def test_wrapper(self):
  196. P = KroghInterpolator(self.xs, self.ys)
  197. ki = krogh_interpolate
  198. assert_almost_equal(P(self.test_xs), ki(self.xs, self.ys, self.test_xs))
  199. assert_almost_equal(P.derivative(self.test_xs, 2),
  200. ki(self.xs, self.ys, self.test_xs, der=2))
  201. assert_almost_equal(P.derivatives(self.test_xs, 2),
  202. ki(self.xs, self.ys, self.test_xs, der=[0, 1]))
  203. def test_int_inputs(self):
  204. # Check input args are cast correctly to floats, gh-3669
  205. x = [0, 234, 468, 702, 936, 1170, 1404, 2340, 3744, 6084, 8424,
  206. 13104, 60000]
  207. offset_cdf = np.array([-0.95, -0.86114777, -0.8147762, -0.64072425,
  208. -0.48002351, -0.34925329, -0.26503107,
  209. -0.13148093, -0.12988833, -0.12979296,
  210. -0.12973574, -0.08582937, 0.05])
  211. f = KroghInterpolator(x, offset_cdf)
  212. assert_allclose(abs((f(x) - offset_cdf) / f.derivative(x, 1)),
  213. 0, atol=1e-10)
  214. def test_derivatives_complex(self):
  215. # regression test for gh-7381: krogh.derivatives(0) fails complex y
  216. x, y = np.array([-1, -1, 0, 1, 1]), np.array([1, 1.0j, 0, -1, 1.0j])
  217. func = KroghInterpolator(x, y)
  218. cmplx = func.derivatives(0)
  219. cmplx2 = (KroghInterpolator(x, y.real).derivatives(0) +
  220. 1j*KroghInterpolator(x, y.imag).derivatives(0))
  221. assert_allclose(cmplx, cmplx2, atol=1e-15)
  222. class TestTaylor(object):
  223. def test_exponential(self):
  224. degree = 5
  225. p = approximate_taylor_polynomial(np.exp, 0, degree, 1, 15)
  226. for i in xrange(degree+1):
  227. assert_almost_equal(p(0),1)
  228. p = p.deriv()
  229. assert_almost_equal(p(0),0)
  230. class TestBarycentric(object):
  231. def setup_method(self):
  232. self.true_poly = np.poly1d([-2, 3, 1, 5, -4])
  233. self.test_xs = np.linspace(-1, 1, 100)
  234. self.xs = np.linspace(-1, 1, 5)
  235. self.ys = self.true_poly(self.xs)
  236. def test_lagrange(self):
  237. P = BarycentricInterpolator(self.xs, self.ys)
  238. assert_almost_equal(self.true_poly(self.test_xs), P(self.test_xs))
  239. def test_scalar(self):
  240. P = BarycentricInterpolator(self.xs, self.ys)
  241. assert_almost_equal(self.true_poly(7), P(7))
  242. assert_almost_equal(self.true_poly(np.array(7)), P(np.array(7)))
  243. def test_delayed(self):
  244. P = BarycentricInterpolator(self.xs)
  245. P.set_yi(self.ys)
  246. assert_almost_equal(self.true_poly(self.test_xs), P(self.test_xs))
  247. def test_append(self):
  248. P = BarycentricInterpolator(self.xs[:3], self.ys[:3])
  249. P.add_xi(self.xs[3:], self.ys[3:])
  250. assert_almost_equal(self.true_poly(self.test_xs), P(self.test_xs))
  251. def test_vector(self):
  252. xs = [0, 1, 2]
  253. ys = np.array([[0, 1], [1, 0], [2, 1]])
  254. BI = BarycentricInterpolator
  255. P = BI(xs, ys)
  256. Pi = [BI(xs, ys[:, i]) for i in xrange(ys.shape[1])]
  257. test_xs = np.linspace(-1, 3, 100)
  258. assert_almost_equal(P(test_xs),
  259. np.rollaxis(np.asarray([p(test_xs) for p in Pi]), -1))
  260. def test_shapes_scalarvalue(self):
  261. P = BarycentricInterpolator(self.xs, self.ys)
  262. assert_array_equal(np.shape(P(0)), ())
  263. assert_array_equal(np.shape(P(np.array(0))), ())
  264. assert_array_equal(np.shape(P([0])), (1,))
  265. assert_array_equal(np.shape(P([0, 1])), (2,))
  266. def test_shapes_vectorvalue(self):
  267. P = BarycentricInterpolator(self.xs, np.outer(self.ys, np.arange(3)))
  268. assert_array_equal(np.shape(P(0)), (3,))
  269. assert_array_equal(np.shape(P([0])), (1, 3))
  270. assert_array_equal(np.shape(P([0, 1])), (2, 3))
  271. def test_shapes_1d_vectorvalue(self):
  272. P = BarycentricInterpolator(self.xs, np.outer(self.ys, [1]))
  273. assert_array_equal(np.shape(P(0)), (1,))
  274. assert_array_equal(np.shape(P([0])), (1, 1))
  275. assert_array_equal(np.shape(P([0,1])), (2, 1))
  276. def test_wrapper(self):
  277. P = BarycentricInterpolator(self.xs, self.ys)
  278. values = barycentric_interpolate(self.xs, self.ys, self.test_xs)
  279. assert_almost_equal(P(self.test_xs), values)
  280. class TestPCHIP(object):
  281. def _make_random(self, npts=20):
  282. np.random.seed(1234)
  283. xi = np.sort(np.random.random(npts))
  284. yi = np.random.random(npts)
  285. return pchip(xi, yi), xi, yi
  286. def test_overshoot(self):
  287. # PCHIP should not overshoot
  288. p, xi, yi = self._make_random()
  289. for i in range(len(xi)-1):
  290. x1, x2 = xi[i], xi[i+1]
  291. y1, y2 = yi[i], yi[i+1]
  292. if y1 > y2:
  293. y1, y2 = y2, y1
  294. xp = np.linspace(x1, x2, 10)
  295. yp = p(xp)
  296. assert_(((y1 <= yp) & (yp <= y2)).all())
  297. def test_monotone(self):
  298. # PCHIP should preserve monotonicty
  299. p, xi, yi = self._make_random()
  300. for i in range(len(xi)-1):
  301. x1, x2 = xi[i], xi[i+1]
  302. y1, y2 = yi[i], yi[i+1]
  303. xp = np.linspace(x1, x2, 10)
  304. yp = p(xp)
  305. assert_(((y2-y1) * (yp[1:] - yp[:1]) > 0).all())
  306. def test_cast(self):
  307. # regression test for integer input data, see gh-3453
  308. data = np.array([[0, 4, 12, 27, 47, 60, 79, 87, 99, 100],
  309. [-33, -33, -19, -2, 12, 26, 38, 45, 53, 55]])
  310. xx = np.arange(100)
  311. curve = pchip(data[0], data[1])(xx)
  312. data1 = data * 1.0
  313. curve1 = pchip(data1[0], data1[1])(xx)
  314. assert_allclose(curve, curve1, atol=1e-14, rtol=1e-14)
  315. def test_nag(self):
  316. # Example from NAG C implementation,
  317. # http://nag.com/numeric/cl/nagdoc_cl25/html/e01/e01bec.html
  318. # suggested in gh-5326 as a smoke test for the way the derivatives
  319. # are computed (see also gh-3453)
  320. from scipy._lib.six import StringIO
  321. dataStr = '''
  322. 7.99 0.00000E+0
  323. 8.09 0.27643E-4
  324. 8.19 0.43750E-1
  325. 8.70 0.16918E+0
  326. 9.20 0.46943E+0
  327. 10.00 0.94374E+0
  328. 12.00 0.99864E+0
  329. 15.00 0.99992E+0
  330. 20.00 0.99999E+0
  331. '''
  332. data = np.loadtxt(StringIO(dataStr))
  333. pch = pchip(data[:,0], data[:,1])
  334. resultStr = '''
  335. 7.9900 0.0000
  336. 9.1910 0.4640
  337. 10.3920 0.9645
  338. 11.5930 0.9965
  339. 12.7940 0.9992
  340. 13.9950 0.9998
  341. 15.1960 0.9999
  342. 16.3970 1.0000
  343. 17.5980 1.0000
  344. 18.7990 1.0000
  345. 20.0000 1.0000
  346. '''
  347. result = np.loadtxt(StringIO(resultStr))
  348. assert_allclose(result[:,1], pch(result[:,0]), rtol=0., atol=5e-5)
  349. def test_endslopes(self):
  350. # this is a smoke test for gh-3453: PCHIP interpolator should not
  351. # set edge slopes to zero if the data do not suggest zero edge derivatives
  352. x = np.array([0.0, 0.1, 0.25, 0.35])
  353. y1 = np.array([279.35, 0.5e3, 1.0e3, 2.5e3])
  354. y2 = np.array([279.35, 2.5e3, 1.50e3, 1.0e3])
  355. for pp in (pchip(x, y1), pchip(x, y2)):
  356. for t in (x[0], x[-1]):
  357. assert_(pp(t, 1) != 0)
  358. def test_all_zeros(self):
  359. x = np.arange(10)
  360. y = np.zeros_like(x)
  361. # this should work and not generate any warnings
  362. with warnings.catch_warnings():
  363. warnings.filterwarnings('error')
  364. pch = pchip(x, y)
  365. xx = np.linspace(0, 9, 101)
  366. assert_equal(pch(xx), 0.)
  367. def test_two_points(self):
  368. # regression test for gh-6222: pchip([0, 1], [0, 1]) fails because
  369. # it tries to use a three-point scheme to estimate edge derivatives,
  370. # while there are only two points available.
  371. # Instead, it should construct a linear interpolator.
  372. x = np.linspace(0, 1, 11)
  373. p = pchip([0, 1], [0, 2])
  374. assert_allclose(p(x), 2*x, atol=1e-15)
  375. def test_pchip_interpolate(self):
  376. assert_array_almost_equal(
  377. pchip_interpolate([1,2,3], [4,5,6], [0.5], der=1),
  378. [1.])
  379. assert_array_almost_equal(
  380. pchip_interpolate([1,2,3], [4,5,6], [0.5], der=0),
  381. [3.5])
  382. assert_array_almost_equal(
  383. pchip_interpolate([1,2,3], [4,5,6], [0.5], der=[0, 1]),
  384. [[3.5], [1]])
  385. def test_roots(self):
  386. # regression test for gh-6357: .roots method should work
  387. p = pchip([0, 1], [-1, 1])
  388. r = p.roots()
  389. assert_allclose(r, 0.5)
  390. class TestCubicSpline(object):
  391. @staticmethod
  392. def check_correctness(S, bc_start='not-a-knot', bc_end='not-a-knot',
  393. tol=1e-14):
  394. """Check that spline coefficients satisfy the continuity and boundary
  395. conditions."""
  396. x = S.x
  397. c = S.c
  398. dx = np.diff(x)
  399. dx = dx.reshape([dx.shape[0]] + [1] * (c.ndim - 2))
  400. dxi = dx[:-1]
  401. # Check C2 continuity.
  402. assert_allclose(c[3, 1:], c[0, :-1] * dxi**3 + c[1, :-1] * dxi**2 +
  403. c[2, :-1] * dxi + c[3, :-1], rtol=tol, atol=tol)
  404. assert_allclose(c[2, 1:], 3 * c[0, :-1] * dxi**2 +
  405. 2 * c[1, :-1] * dxi + c[2, :-1], rtol=tol, atol=tol)
  406. assert_allclose(c[1, 1:], 3 * c[0, :-1] * dxi + c[1, :-1],
  407. rtol=tol, atol=tol)
  408. # Check that we found a parabola, the third derivative is 0.
  409. if x.size == 3 and bc_start == 'not-a-knot' and bc_end == 'not-a-knot':
  410. assert_allclose(c[0], 0, rtol=tol, atol=tol)
  411. return
  412. # Check periodic boundary conditions.
  413. if bc_start == 'periodic':
  414. assert_allclose(S(x[0], 0), S(x[-1], 0), rtol=tol, atol=tol)
  415. assert_allclose(S(x[0], 1), S(x[-1], 1), rtol=tol, atol=tol)
  416. assert_allclose(S(x[0], 2), S(x[-1], 2), rtol=tol, atol=tol)
  417. return
  418. # Check other boundary conditions.
  419. if bc_start == 'not-a-knot':
  420. if x.size == 2:
  421. slope = (S(x[1]) - S(x[0])) / dx[0]
  422. assert_allclose(S(x[0], 1), slope, rtol=tol, atol=tol)
  423. else:
  424. assert_allclose(c[0, 0], c[0, 1], rtol=tol, atol=tol)
  425. elif bc_start == 'clamped':
  426. assert_allclose(S(x[0], 1), 0, rtol=tol, atol=tol)
  427. elif bc_start == 'natural':
  428. assert_allclose(S(x[0], 2), 0, rtol=tol, atol=tol)
  429. else:
  430. order, value = bc_start
  431. assert_allclose(S(x[0], order), value, rtol=tol, atol=tol)
  432. if bc_end == 'not-a-knot':
  433. if x.size == 2:
  434. slope = (S(x[1]) - S(x[0])) / dx[0]
  435. assert_allclose(S(x[1], 1), slope, rtol=tol, atol=tol)
  436. else:
  437. assert_allclose(c[0, -1], c[0, -2], rtol=tol, atol=tol)
  438. elif bc_end == 'clamped':
  439. assert_allclose(S(x[-1], 1), 0, rtol=tol, atol=tol)
  440. elif bc_end == 'natural':
  441. assert_allclose(S(x[-1], 2), 0, rtol=2*tol, atol=2*tol)
  442. else:
  443. order, value = bc_end
  444. assert_allclose(S(x[-1], order), value, rtol=tol, atol=tol)
  445. def check_all_bc(self, x, y, axis):
  446. deriv_shape = list(y.shape)
  447. del deriv_shape[axis]
  448. first_deriv = np.empty(deriv_shape)
  449. first_deriv.fill(2)
  450. second_deriv = np.empty(deriv_shape)
  451. second_deriv.fill(-1)
  452. bc_all = [
  453. 'not-a-knot',
  454. 'natural',
  455. 'clamped',
  456. (1, first_deriv),
  457. (2, second_deriv)
  458. ]
  459. for bc in bc_all[:3]:
  460. S = CubicSpline(x, y, axis=axis, bc_type=bc)
  461. self.check_correctness(S, bc, bc)
  462. for bc_start in bc_all:
  463. for bc_end in bc_all:
  464. S = CubicSpline(x, y, axis=axis, bc_type=(bc_start, bc_end))
  465. self.check_correctness(S, bc_start, bc_end, tol=2e-14)
  466. def test_general(self):
  467. x = np.array([-1, 0, 0.5, 2, 4, 4.5, 5.5, 9])
  468. y = np.array([0, -0.5, 2, 3, 2.5, 1, 1, 0.5])
  469. for n in [2, 3, x.size]:
  470. self.check_all_bc(x[:n], y[:n], 0)
  471. Y = np.empty((2, n, 2))
  472. Y[0, :, 0] = y[:n]
  473. Y[0, :, 1] = y[:n] - 1
  474. Y[1, :, 0] = y[:n] + 2
  475. Y[1, :, 1] = y[:n] + 3
  476. self.check_all_bc(x[:n], Y, 1)
  477. def test_periodic(self):
  478. for n in [2, 3, 5]:
  479. x = np.linspace(0, 2 * np.pi, n)
  480. y = np.cos(x)
  481. S = CubicSpline(x, y, bc_type='periodic')
  482. self.check_correctness(S, 'periodic', 'periodic')
  483. Y = np.empty((2, n, 2))
  484. Y[0, :, 0] = y
  485. Y[0, :, 1] = y + 2
  486. Y[1, :, 0] = y - 1
  487. Y[1, :, 1] = y + 5
  488. S = CubicSpline(x, Y, axis=1, bc_type='periodic')
  489. self.check_correctness(S, 'periodic', 'periodic')
  490. def test_periodic_eval(self):
  491. x = np.linspace(0, 2 * np.pi, 10)
  492. y = np.cos(x)
  493. S = CubicSpline(x, y, bc_type='periodic')
  494. assert_almost_equal(S(1), S(1 + 2 * np.pi), decimal=15)
  495. def test_dtypes(self):
  496. x = np.array([0, 1, 2, 3], dtype=int)
  497. y = np.array([-5, 2, 3, 1], dtype=int)
  498. S = CubicSpline(x, y)
  499. self.check_correctness(S)
  500. y = np.array([-1+1j, 0.0, 1-1j, 0.5-1.5j])
  501. S = CubicSpline(x, y)
  502. self.check_correctness(S)
  503. S = CubicSpline(x, x ** 3, bc_type=("natural", (1, 2j)))
  504. self.check_correctness(S, "natural", (1, 2j))
  505. y = np.array([-5, 2, 3, 1])
  506. S = CubicSpline(x, y, bc_type=[(1, 2 + 0.5j), (2, 0.5 - 1j)])
  507. self.check_correctness(S, (1, 2 + 0.5j), (2, 0.5 - 1j))
  508. def test_small_dx(self):
  509. rng = np.random.RandomState(0)
  510. x = np.sort(rng.uniform(size=100))
  511. y = 1e4 + rng.uniform(size=100)
  512. S = CubicSpline(x, y)
  513. self.check_correctness(S, tol=1e-13)
  514. def test_incorrect_inputs(self):
  515. x = np.array([1, 2, 3, 4])
  516. y = np.array([1, 2, 3, 4])
  517. xc = np.array([1 + 1j, 2, 3, 4])
  518. xn = np.array([np.nan, 2, 3, 4])
  519. xo = np.array([2, 1, 3, 4])
  520. yn = np.array([np.nan, 2, 3, 4])
  521. y3 = [1, 2, 3]
  522. x1 = [1]
  523. y1 = [1]
  524. assert_raises(ValueError, CubicSpline, xc, y)
  525. assert_raises(ValueError, CubicSpline, xn, y)
  526. assert_raises(ValueError, CubicSpline, x, yn)
  527. assert_raises(ValueError, CubicSpline, xo, y)
  528. assert_raises(ValueError, CubicSpline, x, y3)
  529. assert_raises(ValueError, CubicSpline, x[:, np.newaxis], y)
  530. assert_raises(ValueError, CubicSpline, x1, y1)
  531. wrong_bc = [('periodic', 'clamped'),
  532. ((2, 0), (3, 10)),
  533. ((1, 0), ),
  534. (0., 0.),
  535. 'not-a-typo']
  536. for bc_type in wrong_bc:
  537. assert_raises(ValueError, CubicSpline, x, y, 0, bc_type, True)
  538. # Shapes mismatch when giving arbitrary derivative values:
  539. Y = np.c_[y, y]
  540. bc1 = ('clamped', (1, 0))
  541. bc2 = ('clamped', (1, [0, 0, 0]))
  542. bc3 = ('clamped', (1, [[0, 0]]))
  543. assert_raises(ValueError, CubicSpline, x, Y, 0, bc1, True)
  544. assert_raises(ValueError, CubicSpline, x, Y, 0, bc2, True)
  545. assert_raises(ValueError, CubicSpline, x, Y, 0, bc3, True)
  546. # periodic condition, y[-1] must be equal to y[0]:
  547. assert_raises(ValueError, CubicSpline, x, y, 0, 'periodic', True)