_cubic.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. """Interpolation algorithms using piecewise cubic polynomials."""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy as np
  4. from scipy._lib.six import string_types
  5. from . import BPoly, PPoly
  6. from .polyint import _isscalar
  7. from scipy._lib._util import _asarray_validated
  8. from scipy.linalg import solve_banded, solve
  9. __all__ = ["PchipInterpolator", "pchip_interpolate", "pchip",
  10. "Akima1DInterpolator", "CubicSpline"]
  class PchipInterpolator(BPoly):
      r"""PCHIP 1-d monotonic cubic interpolation.

      `x` and `y` are arrays of values used to approximate some function f,
      with ``y = f(x)``. The interpolant uses monotonic cubic splines
      to find the value of new points. (PCHIP stands for Piecewise Cubic
      Hermite Interpolating Polynomial).

      Parameters
      ----------
      x : ndarray
          A 1-D array of strictly increasing real values with at least two
          elements. `x` cannot include duplicate values (otherwise f is
          overspecified).
      y : ndarray
          A 1-D array of real values. `y`'s length along the interpolation
          axis must be equal to the length of `x`. If N-D array, use `axis`
          parameter to select correct axis.
      axis : int, optional
          Axis in the y array corresponding to the x-coordinate values.
      extrapolate : bool, optional
          Whether to extrapolate to out-of-bounds points based on first
          and last intervals, or to return NaNs.

      Raises
      ------
      ValueError
          If `x` is not 1-dimensional, has fewer than two elements, does not
          match the length of `y` along `axis`, or is not strictly increasing.

      Methods
      -------
      __call__
      derivative
      antiderivative
      roots

      See Also
      --------
      Akima1DInterpolator
      CubicSpline
      BPoly

      Notes
      -----
      The interpolator preserves monotonicity in the interpolation data and does
      not overshoot if the data is not smooth.

      The first derivatives are guaranteed to be continuous, but the second
      derivatives may jump at :math:`x_k`.

      Determines the derivatives at the points :math:`x_k`, :math:`f'_k`,
      by using PCHIP algorithm [1]_.

      Let :math:`h_k = x_{k+1} - x_k`, and :math:`d_k = (y_{k+1} - y_k) / h_k`
      are the slopes at internal points :math:`x_k`.
      If the signs of :math:`d_k` and :math:`d_{k-1}` are different or either of
      them equals zero, then :math:`f'_k = 0`. Otherwise, it is given by the
      weighted harmonic mean

      .. math::

          \frac{w_1 + w_2}{f'_k} = \frac{w_1}{d_{k-1}} + \frac{w_2}{d_k}

      where :math:`w_1 = 2 h_k + h_{k-1}` and :math:`w_2 = h_k + 2 h_{k-1}`.

      The end slopes are set using a one-sided scheme [2]_.

      References
      ----------
      .. [1] F. N. Fritsch and R. E. Carlson, Monotone Piecewise Cubic Interpolation,
             SIAM J. Numer. Anal., 17(2), 238 (1980).
             :doi:`10.1137/0717021`.
      .. [2] see, e.g., C. Moler, Numerical Computing with Matlab, 2004.
             :doi:`10.1137/1.9780898717952`

      """

      def __init__(self, x, y, axis=0, extrapolate=None):
          x = _asarray_validated(x, check_finite=False, as_inexact=True)
          y = _asarray_validated(y, check_finite=False, as_inexact=True)

          axis = axis % y.ndim

          # Validate the abscissae up front. Without these checks unsorted or
          # duplicated `x` values flow into the slope computation (division
          # by `hk`) and silently produce a meaningless interpolant.
          if x.ndim != 1:
              raise ValueError("`x` must be 1-dimensional.")
          if x.shape[0] < 2:
              raise ValueError("`x` must contain at least 2 elements.")
          if x.shape[0] != y.shape[axis]:
              raise ValueError("The length of `y` along `axis`={0} doesn't "
                               "match the length of `x`".format(axis))
          if np.any(np.diff(x) <= 0):
              raise ValueError("`x` must be strictly increasing sequence.")

          # Give x trailing singleton axes so it broadcasts against y once the
          # interpolation axis of y has been moved to the front.
          xp = x.reshape((x.shape[0],) + (1,)*(y.ndim-1))
          yp = np.rollaxis(y, axis)

          dk = self._find_derivatives(xp, yp)
          # For every breakpoint stack (value, first derivative) along a new
          # second axis and hand the pairs to BPoly.from_derivatives.
          data = np.hstack((yp[:, None, ...], dk[:, None, ...]))

          _b = BPoly.from_derivatives(x, data, orders=None)
          super(PchipInterpolator, self).__init__(_b.c, _b.x,
                                                  extrapolate=extrapolate)
          self.axis = axis

      def roots(self):
          """
          Return the roots of the interpolated function.
          """
          return (PPoly.from_bernstein_basis(self)).roots()

      @staticmethod
      def _edge_case(h0, h1, m0, m1):
          """Return the shape-preserving derivative estimate at an end point.

          `h0`, `m0` are the width and slope of the interval touching the end
          point; `h1`, `m1` belong to the neighbouring interval.
          """
          # one-sided three-point estimate for the derivative
          d = ((2*h0 + h1)*m0 - h0*m1) / (h0 + h1)

          # try to preserve shape
          mask = np.sign(d) != np.sign(m0)
          mask2 = (np.sign(m0) != np.sign(m1)) & (np.abs(d) > 3.*np.abs(m0))
          mmm = (~mask) & mask2

          d[mask] = 0.
          d[mmm] = 3.*m0[mmm]

          return d

      @staticmethod
      def _find_derivatives(x, y):
          """Return the PCHIP derivative estimates at every breakpoint.

          `x` and `y` must broadcast against each other with the interpolation
          axis first; the result has the same shape as `y`.
          """
          # Determine the derivatives d_k at the points x_k by using the
          # PCHIP algorithm:
          # Let m_k be the slope of the kth segment (between k and k+1)
          # If m_k=0 or m_{k-1}=0 or sgn(m_k) != sgn(m_{k-1}) then d_k == 0
          # else use weighted harmonic mean:
          #   w_1 = 2h_k + h_{k-1}, w_2 = h_k + 2h_{k-1}
          #   1/d_k = 1/(w_1 + w_2)*(w_1 / m_{k-1} + w_2 / m_k)
          #   where h_k is the spacing between x_k and x_{k+1}
          y_shape = y.shape
          if y.ndim == 1:
              # So that _edge_case doesn't end up assigning to scalars
              x = x[:, None]
              y = y[:, None]

          hk = x[1:] - x[:-1]
          mk = (y[1:] - y[:-1]) / hk

          if y.shape[0] == 2:
              # edge case: only have two points, use linear interpolation
              dk = np.zeros_like(y)
              dk[0] = mk
              dk[1] = mk
              return dk.reshape(y_shape)

          smk = np.sign(mk)
          condition = (smk[1:] != smk[:-1]) | (mk[1:] == 0) | (mk[:-1] == 0)

          w1 = 2*hk[1:] + hk[:-1]
          w2 = hk[1:] + 2*hk[:-1]

          # values where division by zero occurs will be excluded
          # by 'condition' afterwards; the same holds for inf + (-inf), which
          # can only arise when one of the two slopes is a (signed) zero.
          with np.errstate(divide='ignore', invalid='ignore'):
              whmean = (w1/mk[:-1] + w2/mk[1:]) / (w1 + w2)

          dk = np.zeros_like(y)
          dk[1:-1][condition] = 0.0
          dk[1:-1][~condition] = 1.0 / whmean[~condition]

          # special case endpoints, as suggested in
          # Cleve Moler, Numerical Computing with MATLAB, Chap 3.4
          dk[0] = PchipInterpolator._edge_case(hk[0], hk[1], mk[0], mk[1])
          dk[-1] = PchipInterpolator._edge_case(hk[-1], hk[-2], mk[-1], mk[-2])

          return dk.reshape(y_shape)
  def pchip_interpolate(xi, yi, x, der=0, axis=0):
      """
      Evaluate a PCHIP interpolant (or its derivatives) in a single call.

      Builds a `PchipInterpolator` from the data ``(xi, yi)`` and evaluates
      it, or the requested derivative(s) of it, at the points `x`.
      See `PchipInterpolator` for details.

      Parameters
      ----------
      xi : array_like
          A sorted list of x-coordinates, of length N.
      yi : array_like
          A 1-D array of real values. `yi`'s length along the interpolation
          axis must be equal to the length of `xi`. If N-D array, use axis
          parameter to select correct axis.
      x : scalar or array_like
          Points at which to evaluate, of length M.
      der : int or list, optional
          Derivatives to extract. The 0-th derivative can be included to
          return the function value.
      axis : int, optional
          Axis in the yi array corresponding to the x-coordinate values.

      Returns
      -------
      y : scalar or array_like
          The interpolated values when `der` is 0, the values of the
          requested derivative when `der` is any other scalar, or a list
          with one entry per requested order when `der` is a sequence.

      See Also
      --------
      PchipInterpolator

      """
      interpolant = PchipInterpolator(xi, yi, axis=axis)

      # Plain function values: no differentiation required.
      if der == 0:
          return interpolant(x)

      # A single derivative order.
      if _isscalar(der):
          return interpolant.derivative(der)(x)

      # A sequence of derivative orders: one result per entry.
      return [interpolant.derivative(order)(x) for order in der]


  # Backwards compatibility
  pchip = PchipInterpolator
  class Akima1DInterpolator(PPoly):
      """
      Akima interpolator

      Fit piecewise cubic polynomials, given vectors x and y. The interpolation
      method by Akima uses a continuously differentiable sub-spline built from
      piecewise cubic polynomials. The resultant curve passes through the given
      data points and will appear smooth and natural.

      Parameters
      ----------
      x : ndarray, shape (m, )
          1-D array of strictly increasing real values.
      y : ndarray, shape (m, ...)
          N-D array of real values. The length of `y` along the interpolation
          axis must be equal to the length of `x`.
      axis : int, optional
          Specifies the axis of `y` along which to interpolate. Interpolation
          defaults to the first axis of `y`.

      Raises
      ------
      ValueError
          If `x` is not 1-dimensional, has fewer than two elements, is not
          strictly ascending, or does not match the length of `y` along
          `axis`.

      Methods
      -------
      __call__
      derivative
      antiderivative
      roots

      See Also
      --------
      PchipInterpolator
      CubicSpline
      PPoly

      Notes
      -----
      .. versionadded:: 0.14

      Use only for precise data, as the fitted curve passes through the given
      points exactly. This routine is useful for plotting a pleasingly smooth
      curve through a few given points for purposes of plotting.

      With exactly two breakpoints the interpolant is the straight line
      through them.

      References
      ----------
      [1] A new method of interpolation and smooth curve fitting based
          on local procedures. Hiroshi Akima, J. ACM, October 1970, 17(4),
          589-602.

      """

      def __init__(self, x, y, axis=0):
          # Original implementation in MATLAB by N. Shamsundar (BSD licensed), see
          # https://www.mathworks.com/matlabcentral/fileexchange/1814-akima-interpolation
          x, y = map(np.asarray, (x, y))
          axis = axis % y.ndim

          if x.ndim != 1:
              raise ValueError("x must be 1-dimensional")
          if x.size < 2:
              raise ValueError("at least 2 breakpoints are needed")
          # `<=` rather than `<`: repeated breakpoints give zero-width
          # intervals and would divide by zero below.
          if np.any(np.diff(x) <= 0.):
              raise ValueError("x must be strictly ascending")
          if x.size != y.shape[axis]:
              raise ValueError("x.shape must equal y.shape[%s]" % axis)

          # move interpolation axis to front
          y = np.rollaxis(y, axis)

          # determine slopes between breakpoints
          m = np.empty((x.size + 3, ) + y.shape[1:])
          dx = np.diff(x)
          dx = dx[(slice(None), ) + (None, ) * (y.ndim - 1)]
          m[2:-2] = np.diff(y, axis=0) / dx

          if x.size == 2:
              # Only one real slope exists, so the extrapolation formulas
              # below would read uninitialised entries of `m`. Use the single
              # slope everywhere, which yields the straight line through the
              # two points.
              m[...] = m[2].copy()
          else:
              # add two additional points on the left ...
              m[1] = 2. * m[2] - m[3]
              m[0] = 2. * m[1] - m[2]
              # ... and on the right
              m[-2] = 2. * m[-3] - m[-4]
              m[-1] = 2. * m[-2] - m[-3]

          # if m1 == m2 != m3 == m4, the slope at the breakpoint is not defined.
          # This is the fill value:
          t = .5 * (m[3:] + m[:-3])
          # get the denominator of the slope t
          dm = np.abs(np.diff(m, axis=0))
          f1 = dm[2:]
          f2 = dm[:-2]
          f12 = f1 + f2
          # These are the mask of where the slope at breakpoint is defined:
          ind = np.nonzero(f12 > 1e-9 * np.max(f12))
          x_ind, y_ind = ind[0], ind[1:]
          # Set the slope at breakpoint
          t[ind] = (f1[ind] * m[(x_ind + 1,) + y_ind] +
                    f2[ind] * m[(x_ind + 2,) + y_ind]) / f12[ind]

          # calculate the higher order coefficients
          c = (3. * m[2:-2] - 2. * t[:-1] - t[1:]) / dx
          d = (t[:-1] + t[1:] - 2. * m[2:-2]) / dx ** 2

          # PPoly stores coefficients from the highest power down.
          coeff = np.zeros((4, x.size - 1) + y.shape[1:])
          coeff[3] = y[:-1]
          coeff[2] = t[:-1]
          coeff[1] = c
          coeff[0] = d

          super(Akima1DInterpolator, self).__init__(coeff, x, extrapolate=False)
          self.axis = axis

      def extend(self, c, x, right=True):
          raise NotImplementedError("Extending a 1D Akima interpolator is not "
                                    "yet implemented")

      # These are inherited from PPoly, but they do not produce an Akima
      # interpolator. Hence stub them out.
      @classmethod
      def from_spline(cls, tck, extrapolate=None):
          raise NotImplementedError("This method does not make sense for "
                                    "an Akima interpolator.")

      @classmethod
      def from_bernstein_basis(cls, bp, extrapolate=None):
          raise NotImplementedError("This method does not make sense for "
                                    "an Akima interpolator.")
  class CubicSpline(PPoly):
      """Cubic spline data interpolator.

      Interpolate data with a piecewise cubic polynomial which is twice
      continuously differentiable [1]_. The result is represented as a `PPoly`
      instance with breakpoints matching the given data.

      Parameters
      ----------
      x : array_like, shape (n,)
          1-d array containing values of the independent variable.
          Values must be real, finite and in strictly increasing order.
      y : array_like
          Array containing values of the dependent variable. It can have
          arbitrary number of dimensions, but the length along `axis` (see below)
          must match the length of `x`. Values must be finite.
      axis : int, optional
          Axis along which `y` is assumed to be varying. Meaning that for
          ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``.
          Default is 0.
      bc_type : string or 2-tuple, optional
          Boundary condition type. Two additional equations, given by the
          boundary conditions, are required to determine all coefficients of
          polynomials on each segment [2]_.

          If `bc_type` is a string, then the specified condition will be applied
          at both ends of a spline. Available conditions are:

          * 'not-a-knot' (default): The first and second segment at a curve end
            are the same polynomial. It is a good default when there is no
            information on boundary conditions.
          * 'periodic': The interpolated functions is assumed to be periodic
            of period ``x[-1] - x[0]``. The first and last value of `y` must be
            identical: ``y[0] == y[-1]``. This boundary condition will result in
            ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``.
          * 'clamped': The first derivative at curves ends are zero. Assuming
            a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition.
          * 'natural': The second derivative at curve ends are zero. Assuming
            a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition.

          If `bc_type` is a 2-tuple, the first and the second value will be
          applied at the curve start and end respectively. The tuple values can
          be one of the previously mentioned strings (except 'periodic') or a
          tuple `(order, deriv_values)` allowing to specify arbitrary
          derivatives at curve ends:

          * `order`: the derivative order, 1 or 2.
          * `deriv_value`: array_like containing derivative values, shape must
            be the same as `y`, excluding `axis` dimension. For example, if `y`
            is 1D, then `deriv_value` must be a scalar. If `y` is 3D with the
            shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2D
            and have the shape (n0, n1).
      extrapolate : {bool, 'periodic', None}, optional
          If bool, determines whether to extrapolate to out-of-bounds points
          based on first and last intervals, or to return NaNs. If 'periodic',
          periodic extrapolation is used. If None (default), `extrapolate` is
          set to 'periodic' for ``bc_type='periodic'`` and to True otherwise.

      Attributes
      ----------
      x : ndarray, shape (n,)
          Breakpoints. The same `x` which was passed to the constructor.
      c : ndarray, shape (4, n-1, ...)
          Coefficients of the polynomials on each segment. The trailing
          dimensions match the dimensions of `y`, excluding `axis`. For example,
          if `y` is 1-d, then ``c[k, i]`` is a coefficient for
          ``(x-x[i])**(3-k)`` on the segment between ``x[i]`` and ``x[i+1]``.
      axis : int
          Interpolation axis. The same `axis` which was passed to the
          constructor.

      Methods
      -------
      __call__
      derivative
      antiderivative
      integrate
      roots

      See Also
      --------
      Akima1DInterpolator
      PchipInterpolator
      PPoly

      Notes
      -----
      Parameters `bc_type` and `extrapolate` work independently, i.e. the former
      controls only construction of a spline, and the latter only evaluation.

      When a boundary condition is 'not-a-knot' and n = 2, it is replaced by
      a condition that the first derivative is equal to the linear interpolant
      slope. When both boundary conditions are 'not-a-knot' and n = 3, the
      solution is sought as a parabola passing through given points.

      When the boundary condition is 'periodic' and n = 3, the first
      derivative is the same at all three breakpoints and is computed in
      closed form.

      When 'not-a-knot' boundary conditions is applied to both ends, the
      resulting spline will be the same as returned by `splrep` (with ``s=0``)
      and `InterpolatedUnivariateSpline`, but these two methods use a
      representation in B-spline basis.

      .. versionadded:: 0.18.0

      Examples
      --------
      In this example the cubic spline is used to interpolate a sampled sinusoid.
      You can see that the spline continuity property holds for the first and
      second derivatives and violates only for the third derivative.

      >>> from scipy.interpolate import CubicSpline
      >>> import matplotlib.pyplot as plt
      >>> x = np.arange(10)
      >>> y = np.sin(x)
      >>> cs = CubicSpline(x, y)
      >>> xs = np.arange(-0.5, 9.6, 0.1)
      >>> fig, ax = plt.subplots(figsize=(6.5, 4))
      >>> ax.plot(x, y, 'o', label='data')
      >>> ax.plot(xs, np.sin(xs), label='true')
      >>> ax.plot(xs, cs(xs), label="S")
      >>> ax.plot(xs, cs(xs, 1), label="S'")
      >>> ax.plot(xs, cs(xs, 2), label="S''")
      >>> ax.plot(xs, cs(xs, 3), label="S'''")
      >>> ax.set_xlim(-0.5, 9.5)
      >>> ax.legend(loc='lower left', ncol=2)
      >>> plt.show()

      In the second example, the unit circle is interpolated with a spline. A
      periodic boundary condition is used. You can see that the first derivative
      values, ds/dx=0, ds/dy=1 at the periodic point (1, 0) are correctly
      computed. Note that a circle cannot be exactly represented by a cubic
      spline. To increase precision, more breakpoints would be required.

      >>> theta = 2 * np.pi * np.linspace(0, 1, 5)
      >>> y = np.c_[np.cos(theta), np.sin(theta)]
      >>> cs = CubicSpline(theta, y, bc_type='periodic')
      >>> print("ds/dx={:.1f} ds/dy={:.1f}".format(cs(0, 1)[0], cs(0, 1)[1]))
      ds/dx=0.0 ds/dy=1.0
      >>> xs = 2 * np.pi * np.linspace(0, 1, 100)
      >>> fig, ax = plt.subplots(figsize=(6.5, 4))
      >>> ax.plot(y[:, 0], y[:, 1], 'o', label='data')
      >>> ax.plot(np.cos(xs), np.sin(xs), label='true')
      >>> ax.plot(cs(xs)[:, 0], cs(xs)[:, 1], label='spline')
      >>> ax.axes.set_aspect('equal')
      >>> ax.legend(loc='center')
      >>> plt.show()

      The third example is the interpolation of a polynomial y = x**3 on the
      interval 0 <= x<= 1. A cubic spline can represent this function exactly.
      To achieve that we need to specify values and first derivatives at
      endpoints of the interval. Note that y' = 3 * x**2 and thus y'(0) = 0 and
      y'(1) = 3.

      >>> cs = CubicSpline([0, 1], [0, 1], bc_type=((1, 0), (1, 3)))
      >>> x = np.linspace(0, 1)
      >>> np.allclose(x**3, cs(x))
      True

      References
      ----------
      .. [1] `Cubic Spline Interpolation
              <https://en.wikiversity.org/wiki/Cubic_Spline_Interpolation>`_
              on Wikiversity.
      .. [2] Carl de Boor, "A Practical Guide to Splines", Springer-Verlag, 1978.
      """

      def __init__(self, x, y, axis=0, bc_type='not-a-knot', extrapolate=None):
          x, y = map(np.asarray, (x, y))

          if np.issubdtype(x.dtype, np.complexfloating):
              raise ValueError("`x` must contain real values.")

          if np.issubdtype(y.dtype, np.complexfloating):
              dtype = complex
          else:
              dtype = float
          y = y.astype(dtype, copy=False)

          axis = axis % y.ndim
          if x.ndim != 1:
              raise ValueError("`x` must be 1-dimensional.")
          if x.shape[0] < 2:
              raise ValueError("`x` must contain at least 2 elements.")
          if x.shape[0] != y.shape[axis]:
              raise ValueError("The length of `y` along `axis`={0} doesn't "
                               "match the length of `x`".format(axis))

          if not np.all(np.isfinite(x)):
              raise ValueError("`x` must contain only finite values.")
          if not np.all(np.isfinite(y)):
              raise ValueError("`y` must contain only finite values.")

          dx = np.diff(x)
          if np.any(dx <= 0):
              raise ValueError("`x` must be strictly increasing sequence.")

          n = x.shape[0]
          y = np.rollaxis(y, axis)

          bc, y = self._validate_bc(bc_type, y, y.shape[1:], axis)

          if extrapolate is None:
              if bc[0] == 'periodic':
                  extrapolate = 'periodic'
              else:
                  extrapolate = True

          # dx with trailing singleton axes so it broadcasts against y.
          dxr = dx.reshape([dx.shape[0]] + [1] * (y.ndim - 1))
          slope = np.diff(y, axis=0) / dxr

          # If bc is 'not-a-knot' this change is just a convention.
          # If bc is 'periodic' then we already checked that y[0] == y[-1],
          # and the spline is just a constant, we handle this case in the same
          # way by setting the first derivatives to slope, which is 0.
          if n == 2:
              if bc[0] in ['not-a-knot', 'periodic']:
                  bc[0] = (1, slope[0])
              if bc[1] in ['not-a-knot', 'periodic']:
                  bc[1] = (1, slope[0])

          # This is a very special case, when both conditions are 'not-a-knot'
          # and n == 3. In this case 'not-a-knot' can't be handled regularly
          # as the both conditions are identical. We handle this case by
          # constructing a parabola passing through given points.
          if n == 3 and bc[0] == 'not-a-knot' and bc[1] == 'not-a-knot':
              A = np.zeros((3, 3))  # This is a standard matrix.
              b = np.empty((3,) + y.shape[1:], dtype=y.dtype)

              A[0, 0] = 1
              A[0, 1] = 1
              A[1, 0] = dx[1]
              A[1, 1] = 2 * (dx[0] + dx[1])
              A[1, 2] = dx[0]
              A[2, 1] = 1
              A[2, 2] = 1

              b[0] = 2 * slope[0]
              b[1] = 3 * (dxr[0] * slope[1] + dxr[1] * slope[0])
              b[2] = 2 * slope[1]

              s = solve(A, b, overwrite_a=True, overwrite_b=True,
                        check_finite=False)
          elif n == 3 and bc[0] == 'periodic':
              # With three points the cyclic system has only two unknowns and
              # the condensed-matrix reduction used for the general periodic
              # case degenerates (its corner entries coincide with the
              # off-diagonals). Both equations then yield the same derivative
              # at every breakpoint:
              #   s = (dx[1]*slope[0] + dx[0]*slope[1]) / (dx[0] + dx[1])
              t = (slope / dxr).sum(0) / (1. / dxr).sum(0)
              s = np.empty((n,) + y.shape[1:], dtype=y.dtype)
              s[...] = t
          else:
              # Find derivative values at each x[i] by solving a tridiagonal
              # system.
              A = np.zeros((3, n))  # This is a banded matrix representation.
              b = np.empty((n,) + y.shape[1:], dtype=y.dtype)

              # Filling the system for i=1..n-2
              #                         (x[i-1] - x[i]) * s[i-1] +\
              # 2 * ((x[i] - x[i-1]) + (x[i+1] - x[i])) * s[i]   +\
              #                         (x[i] - x[i-1]) * s[i+1] =\
              #       3 * ((x[i+1] - x[i])*(y[i] - y[i-1])/(x[i] - x[i-1]) +\
              #           (x[i] - x[i-1])*(y[i+1] - y[i])/(x[i+1] - x[i]))

              A[1, 1:-1] = 2 * (dx[:-1] + dx[1:])  # The diagonal
              A[0, 2:] = dx[:-1]                   # The upper diagonal
              A[-1, :-2] = dx[1:]                  # The lower diagonal

              b[1:-1] = 3 * (dxr[1:] * slope[:-1] + dxr[:-1] * slope[1:])

              bc_start, bc_end = bc

              if bc_start == 'periodic':
                  # Due to the periodicity, and because y[-1] = y[0], the linear
                  # system has (n-1) unknowns/equations instead of n:
                  A = A[:, 0:-1]
                  A[1, 0] = 2 * (dx[-1] + dx[0])
                  A[0, 1] = dx[-1]

                  b = b[:-1]

                  # Also, due to the periodicity, the system is not tri-diagonal.
                  # We need to compute a "condensed" matrix of shape (n-2, n-2).
                  # See https://web.archive.org/web/20151220180652/http://www.cfm.brown.edu/people/gk/chap6/node14.html
                  # for more explanations.
                  # The condensed matrix is obtained by removing the last column
                  # and last row of the (n-1, n-1) system matrix. The removed
                  # values are saved in scalar variables with the (n-1, n-1)
                  # system matrix indices forming their names:
                  a_m1_0 = dx[-2]  # lower left corner value: A[-1, 0]
                  a_m1_m2 = dx[-1]
                  a_m1_m1 = 2 * (dx[-1] + dx[-2])
                  a_m2_m1 = dx[-2]
                  a_0_m1 = dx[0]

                  b[0] = 3 * (dxr[0] * slope[-1] + dxr[-1] * slope[0])
                  b[-1] = 3 * (dxr[-1] * slope[-2] + dxr[-2] * slope[-1])

                  Ac = A[:, :-1]
                  b1 = b[:-1]
                  b2 = np.zeros_like(b1)
                  b2[0] = -a_0_m1
                  b2[-1] = -a_m2_m1

                  # s1 and s2 are the solutions of (n-2, n-2) system
                  s1 = solve_banded((1, 1), Ac, b1, overwrite_ab=False,
                                    overwrite_b=False, check_finite=False)

                  s2 = solve_banded((1, 1), Ac, b2, overwrite_ab=False,
                                    overwrite_b=False, check_finite=False)

                  # computing the s[n-2] solution:
                  s_m1 = ((b[-1] - a_m1_0 * s1[0] - a_m1_m2 * s1[-1]) /
                          (a_m1_m1 + a_m1_0 * s2[0] + a_m1_m2 * s2[-1]))

                  # s is the solution of the (n, n) system:
                  s = np.empty((n,) + y.shape[1:], dtype=y.dtype)
                  s[:-2] = s1 + s_m1 * s2
                  s[-2] = s_m1
                  s[-1] = s[0]
              else:
                  if bc_start == 'not-a-knot':
                      A[1, 0] = dx[1]
                      A[0, 1] = x[2] - x[0]
                      d = x[2] - x[0]
                      b[0] = ((dxr[0] + 2*d) * dxr[1] * slope[0] +
                              dxr[0]**2 * slope[1]) / d
                  elif bc_start[0] == 1:
                      A[1, 0] = 1
                      A[0, 1] = 0
                      b[0] = bc_start[1]
                  elif bc_start[0] == 2:
                      A[1, 0] = 2 * dx[0]
                      A[0, 1] = dx[0]
                      b[0] = -0.5 * bc_start[1] * dx[0]**2 + 3 * (y[1] - y[0])

                  if bc_end == 'not-a-knot':
                      A[1, -1] = dx[-2]
                      A[-1, -2] = x[-1] - x[-3]
                      d = x[-1] - x[-3]
                      b[-1] = ((dxr[-1]**2*slope[-2] +
                               (2*d + dxr[-1])*dxr[-2]*slope[-1]) / d)
                  elif bc_end[0] == 1:
                      A[1, -1] = 1
                      A[-1, -2] = 0
                      b[-1] = bc_end[1]
                  elif bc_end[0] == 2:
                      A[1, -1] = 2 * dx[-1]
                      A[-1, -2] = dx[-1]
                      b[-1] = 0.5 * bc_end[1] * dx[-1]**2 + 3 * (y[-1] - y[-2])

                  s = solve_banded((1, 1), A, b, overwrite_ab=True,
                                   overwrite_b=True, check_finite=False)

          # Compute coefficients in PPoly form.
          t = (s[:-1] + s[1:] - 2 * slope) / dxr
          c = np.empty((4, n - 1) + y.shape[1:], dtype=t.dtype)
          c[0] = t / dxr
          c[1] = (slope - s[:-1]) / dxr - t
          c[2] = s[:-1]
          c[3] = y[:-1]

          super(CubicSpline, self).__init__(c, x, extrapolate=extrapolate)
          self.axis = axis

      @staticmethod
      def _validate_bc(bc_type, y, expected_deriv_shape, axis):
          """Validate and prepare boundary conditions.

          Parameters
          ----------
          bc_type : string or 2-tuple
              Boundary condition specification as passed to the constructor.
          y : ndarray
              Data values with the interpolation axis first.
          expected_deriv_shape : tuple
              Shape every explicitly given derivative value must have.
          axis : int
              Interpolation axis, used only in error messages.

          Returns
          -------
          validated_bc : list of 2 elements
              Boundary conditions for a curve start and end. Each entry is
              either the string 'not-a-knot' / 'periodic' or a tuple
              ``(order, value)``.
          y : ndarray
              y casted to complex dtype if one of the boundary conditions has
              complex dtype.

          Raises
          ------
          ValueError
              If the specification is malformed, names an unknown condition,
              mixes 'periodic' with another condition, gives a derivative
              order other than 1 or 2, gives a derivative value of the wrong
              shape, or requests 'periodic' for data whose first and last
              values differ.
          """
          if isinstance(bc_type, string_types):
              if bc_type == 'periodic':
                  if not np.allclose(y[0], y[-1], rtol=1e-15, atol=1e-15):
                      raise ValueError(
                          "The first and last `y` point along axis {} must "
                          "be identical (within machine precision) when "
                          "bc_type='periodic'.".format(axis))

              # A single string applies to both ends.
              bc_type = (bc_type, bc_type)

          else:
              if len(bc_type) != 2:
                  raise ValueError("`bc_type` must contain 2 elements to "
                                   "specify start and end conditions.")

              if 'periodic' in bc_type:
                  raise ValueError("'periodic' `bc_type` is defined for both "
                                   "curve ends and cannot be used with other "
                                   "boundary conditions.")

          validated_bc = []
          for bc in bc_type:
              if isinstance(bc, string_types):
                  if bc == 'clamped':
                      validated_bc.append((1, np.zeros(expected_deriv_shape)))
                  elif bc == 'natural':
                      validated_bc.append((2, np.zeros(expected_deriv_shape)))
                  elif bc in ['not-a-knot', 'periodic']:
                      validated_bc.append(bc)
                  else:
                      raise ValueError("bc_type={} is not allowed.".format(bc))
              else:
                  try:
                      deriv_order, deriv_value = bc
                  except Exception:
                      raise ValueError("A specified derivative value must be "
                                       "given in the form (order, value).")

                  if deriv_order not in [1, 2]:
                      raise ValueError("The specified derivative order must "
                                       "be 1 or 2.")

                  deriv_value = np.asarray(deriv_value)
                  if deriv_value.shape != expected_deriv_shape:
                      raise ValueError(
                          "`deriv_value` shape {} is not the expected one {}."
                          .format(deriv_value.shape, expected_deriv_shape))

                  if np.issubdtype(deriv_value.dtype, np.complexfloating):
                      y = y.astype(complex, copy=False)

                  validated_bc.append((deriv_order, deriv_value))

          return validated_bc, y
  630. if np.issubdtype(deriv_value.dtype, np.complexfloating):
  631. y = y.astype(complex, copy=False)
  632. validated_bc.append((deriv_order, deriv_value))
  633. return validated_bc, y