test_bsplines.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247
  1. from __future__ import division, absolute_import, print_function
  2. import numpy as np
  3. from numpy.testing import assert_equal, assert_allclose, assert_
  4. from scipy._lib._numpy_compat import suppress_warnings
  5. from pytest import raises as assert_raises
  6. import pytest
  7. from scipy.interpolate import (BSpline, BPoly, PPoly, make_interp_spline,
  8. make_lsq_spline, _bspl, splev, splrep, splprep, splder, splantider,
  9. sproot, splint, insert)
  10. import scipy.linalg as sl
  11. from scipy._lib._version import NumpyVersion
  12. from scipy.interpolate._bsplines import _not_a_knot, _augknt
  13. import scipy.interpolate._fitpack_impl as _impl
  14. from scipy.interpolate._fitpack import _splint
  15. class TestBSpline(object):
  16. def test_ctor(self):
  17. # knots should be an ordered 1D array of finite real numbers
  18. assert_raises((TypeError, ValueError), BSpline,
  19. **dict(t=[1, 1.j], c=[1.], k=0))
  20. with np.errstate(invalid='ignore'):
  21. assert_raises(ValueError, BSpline, **dict(t=[1, np.nan], c=[1.], k=0))
  22. assert_raises(ValueError, BSpline, **dict(t=[1, np.inf], c=[1.], k=0))
  23. assert_raises(ValueError, BSpline, **dict(t=[1, -1], c=[1.], k=0))
  24. assert_raises(ValueError, BSpline, **dict(t=[[1], [1]], c=[1.], k=0))
  25. # for n+k+1 knots and degree k need at least n coefficients
  26. assert_raises(ValueError, BSpline, **dict(t=[0, 1, 2], c=[1], k=0))
  27. assert_raises(ValueError, BSpline,
  28. **dict(t=[0, 1, 2, 3, 4], c=[1., 1.], k=2))
  29. # non-integer orders
  30. assert_raises(TypeError, BSpline,
  31. **dict(t=[0., 0., 1., 2., 3., 4.], c=[1., 1., 1.], k="cubic"))
  32. assert_raises(TypeError, BSpline,
  33. **dict(t=[0., 0., 1., 2., 3., 4.], c=[1., 1., 1.], k=2.5))
  34. # basic interval cannot have measure zero (here: [1..1])
  35. assert_raises(ValueError, BSpline,
  36. **dict(t=[0., 0, 1, 1, 2, 3], c=[1., 1, 1], k=2))
  37. # tck vs self.tck
  38. n, k = 11, 3
  39. t = np.arange(n+k+1)
  40. c = np.random.random(n)
  41. b = BSpline(t, c, k)
  42. assert_allclose(t, b.t)
  43. assert_allclose(c, b.c)
  44. assert_equal(k, b.k)
  45. def test_tck(self):
  46. b = _make_random_spline()
  47. tck = b.tck
  48. assert_allclose(b.t, tck[0], atol=1e-15, rtol=1e-15)
  49. assert_allclose(b.c, tck[1], atol=1e-15, rtol=1e-15)
  50. assert_equal(b.k, tck[2])
  51. # b.tck is read-only
  52. with pytest.raises(AttributeError):
  53. b.tck = 'foo'
  54. def test_degree_0(self):
  55. xx = np.linspace(0, 1, 10)
  56. b = BSpline(t=[0, 1], c=[3.], k=0)
  57. assert_allclose(b(xx), 3)
  58. b = BSpline(t=[0, 0.35, 1], c=[3, 4], k=0)
  59. assert_allclose(b(xx), np.where(xx < 0.35, 3, 4))
  60. def test_degree_1(self):
  61. t = [0, 1, 2, 3, 4]
  62. c = [1, 2, 3]
  63. k = 1
  64. b = BSpline(t, c, k)
  65. x = np.linspace(1, 3, 50)
  66. assert_allclose(c[0]*B_012(x) + c[1]*B_012(x-1) + c[2]*B_012(x-2),
  67. b(x), atol=1e-14)
  68. assert_allclose(splev(x, (t, c, k)), b(x), atol=1e-14)
  69. def test_bernstein(self):
  70. # a special knot vector: Bernstein polynomials
  71. k = 3
  72. t = np.asarray([0]*(k+1) + [1]*(k+1))
  73. c = np.asarray([1., 2., 3., 4.])
  74. bp = BPoly(c.reshape(-1, 1), [0, 1])
  75. bspl = BSpline(t, c, k)
  76. xx = np.linspace(-1., 2., 10)
  77. assert_allclose(bp(xx, extrapolate=True),
  78. bspl(xx, extrapolate=True), atol=1e-14)
  79. assert_allclose(splev(xx, (t, c, k)),
  80. bspl(xx), atol=1e-14)
  81. def test_rndm_naive_eval(self):
  82. # test random coefficient spline *on the base interval*,
  83. # t[k] <= x < t[-k-1]
  84. b = _make_random_spline()
  85. t, c, k = b.tck
  86. xx = np.linspace(t[k], t[-k-1], 50)
  87. y_b = b(xx)
  88. y_n = [_naive_eval(x, t, c, k) for x in xx]
  89. assert_allclose(y_b, y_n, atol=1e-14)
  90. y_n2 = [_naive_eval_2(x, t, c, k) for x in xx]
  91. assert_allclose(y_b, y_n2, atol=1e-14)
  92. def test_rndm_splev(self):
  93. b = _make_random_spline()
  94. t, c, k = b.tck
  95. xx = np.linspace(t[k], t[-k-1], 50)
  96. assert_allclose(b(xx), splev(xx, (t, c, k)), atol=1e-14)
  97. def test_rndm_splrep(self):
  98. np.random.seed(1234)
  99. x = np.sort(np.random.random(20))
  100. y = np.random.random(20)
  101. tck = splrep(x, y)
  102. b = BSpline(*tck)
  103. t, k = b.t, b.k
  104. xx = np.linspace(t[k], t[-k-1], 80)
  105. assert_allclose(b(xx), splev(xx, tck), atol=1e-14)
  106. def test_rndm_unity(self):
  107. b = _make_random_spline()
  108. b.c = np.ones_like(b.c)
  109. xx = np.linspace(b.t[b.k], b.t[-b.k-1], 100)
  110. assert_allclose(b(xx), 1.)
  111. def test_vectorization(self):
  112. n, k = 22, 3
  113. t = np.sort(np.random.random(n))
  114. c = np.random.random(size=(n, 6, 7))
  115. b = BSpline(t, c, k)
  116. tm, tp = t[k], t[-k-1]
  117. xx = tm + (tp - tm) * np.random.random((3, 4, 5))
  118. assert_equal(b(xx).shape, (3, 4, 5, 6, 7))
  119. def test_len_c(self):
  120. # for n+k+1 knots, only first n coefs are used.
  121. # and BTW this is consistent with FITPACK
  122. n, k = 33, 3
  123. t = np.sort(np.random.random(n+k+1))
  124. c = np.random.random(n)
  125. # pad coefficients with random garbage
  126. c_pad = np.r_[c, np.random.random(k+1)]
  127. b, b_pad = BSpline(t, c, k), BSpline(t, c_pad, k)
  128. dt = t[-1] - t[0]
  129. xx = np.linspace(t[0] - dt, t[-1] + dt, 50)
  130. assert_allclose(b(xx), b_pad(xx), atol=1e-14)
  131. assert_allclose(b(xx), splev(xx, (t, c, k)), atol=1e-14)
  132. assert_allclose(b(xx), splev(xx, (t, c_pad, k)), atol=1e-14)
  133. def test_endpoints(self):
  134. # base interval is closed
  135. b = _make_random_spline()
  136. t, _, k = b.tck
  137. tm, tp = t[k], t[-k-1]
  138. for extrap in (True, False):
  139. assert_allclose(b([tm, tp], extrap),
  140. b([tm + 1e-10, tp - 1e-10], extrap), atol=1e-9)
  141. def test_continuity(self):
  142. # assert continuity at internal knots
  143. b = _make_random_spline()
  144. t, _, k = b.tck
  145. assert_allclose(b(t[k+1:-k-1] - 1e-10), b(t[k+1:-k-1] + 1e-10),
  146. atol=1e-9)
  147. def test_extrap(self):
  148. b = _make_random_spline()
  149. t, c, k = b.tck
  150. dt = t[-1] - t[0]
  151. xx = np.linspace(t[k] - dt, t[-k-1] + dt, 50)
  152. mask = (t[k] < xx) & (xx < t[-k-1])
  153. # extrap has no effect within the base interval
  154. assert_allclose(b(xx[mask], extrapolate=True),
  155. b(xx[mask], extrapolate=False))
  156. # extrapolated values agree with FITPACK
  157. assert_allclose(b(xx, extrapolate=True),
  158. splev(xx, (t, c, k), ext=0))
  159. def test_default_extrap(self):
  160. # BSpline defaults to extrapolate=True
  161. b = _make_random_spline()
  162. t, _, k = b.tck
  163. xx = [t[0] - 1, t[-1] + 1]
  164. yy = b(xx)
  165. assert_(not np.all(np.isnan(yy)))
  166. def test_periodic_extrap(self):
  167. np.random.seed(1234)
  168. t = np.sort(np.random.random(8))
  169. c = np.random.random(4)
  170. k = 3
  171. b = BSpline(t, c, k, extrapolate='periodic')
  172. n = t.size - (k + 1)
  173. dt = t[-1] - t[0]
  174. xx = np.linspace(t[k] - dt, t[n] + dt, 50)
  175. xy = t[k] + (xx - t[k]) % (t[n] - t[k])
  176. assert_allclose(b(xx), splev(xy, (t, c, k)))
  177. # Direct check
  178. xx = [-1, 0, 0.5, 1]
  179. xy = t[k] + (xx - t[k]) % (t[n] - t[k])
  180. assert_equal(b(xx, extrapolate='periodic'), b(xy, extrapolate=True))
  181. def test_ppoly(self):
  182. b = _make_random_spline()
  183. t, c, k = b.tck
  184. pp = PPoly.from_spline((t, c, k))
  185. xx = np.linspace(t[k], t[-k], 100)
  186. assert_allclose(b(xx), pp(xx), atol=1e-14, rtol=1e-14)
  187. def test_derivative_rndm(self):
  188. b = _make_random_spline()
  189. t, c, k = b.tck
  190. xx = np.linspace(t[0], t[-1], 50)
  191. xx = np.r_[xx, t]
  192. for der in range(1, k+1):
  193. yd = splev(xx, (t, c, k), der=der)
  194. assert_allclose(yd, b(xx, nu=der), atol=1e-14)
  195. # higher derivatives all vanish
  196. assert_allclose(b(xx, nu=k+1), 0, atol=1e-14)
  197. def test_derivative_jumps(self):
  198. # example from de Boor, Chap IX, example (24)
  199. # NB: knots augmented & corresp coefs are zeroed out
  200. # in agreement with the convention (29)
  201. k = 2
  202. t = [-1, -1, 0, 1, 1, 3, 4, 6, 6, 6, 7, 7]
  203. np.random.seed(1234)
  204. c = np.r_[0, 0, np.random.random(5), 0, 0]
  205. b = BSpline(t, c, k)
  206. # b is continuous at x != 6 (triple knot)
  207. x = np.asarray([1, 3, 4, 6])
  208. assert_allclose(b(x[x != 6] - 1e-10),
  209. b(x[x != 6] + 1e-10))
  210. assert_(not np.allclose(b(6.-1e-10), b(6+1e-10)))
  211. # 1st derivative jumps at double knots, 1 & 6:
  212. x0 = np.asarray([3, 4])
  213. assert_allclose(b(x0 - 1e-10, nu=1),
  214. b(x0 + 1e-10, nu=1))
  215. x1 = np.asarray([1, 6])
  216. assert_(not np.all(np.allclose(b(x1 - 1e-10, nu=1),
  217. b(x1 + 1e-10, nu=1))))
  218. # 2nd derivative is not guaranteed to be continuous either
  219. assert_(not np.all(np.allclose(b(x - 1e-10, nu=2),
  220. b(x + 1e-10, nu=2))))
  221. def test_basis_element_quadratic(self):
  222. xx = np.linspace(-1, 4, 20)
  223. b = BSpline.basis_element(t=[0, 1, 2, 3])
  224. assert_allclose(b(xx),
  225. splev(xx, (b.t, b.c, b.k)), atol=1e-14)
  226. assert_allclose(b(xx),
  227. B_0123(xx), atol=1e-14)
  228. b = BSpline.basis_element(t=[0, 1, 1, 2])
  229. xx = np.linspace(0, 2, 10)
  230. assert_allclose(b(xx),
  231. np.where(xx < 1, xx*xx, (2.-xx)**2), atol=1e-14)
  232. def test_basis_element_rndm(self):
  233. b = _make_random_spline()
  234. t, c, k = b.tck
  235. xx = np.linspace(t[k], t[-k-1], 20)
  236. assert_allclose(b(xx), _sum_basis_elements(xx, t, c, k), atol=1e-14)
  237. def test_cmplx(self):
  238. b = _make_random_spline()
  239. t, c, k = b.tck
  240. cc = c * (1. + 3.j)
  241. b = BSpline(t, cc, k)
  242. b_re = BSpline(t, b.c.real, k)
  243. b_im = BSpline(t, b.c.imag, k)
  244. xx = np.linspace(t[k], t[-k-1], 20)
  245. assert_allclose(b(xx).real, b_re(xx), atol=1e-14)
  246. assert_allclose(b(xx).imag, b_im(xx), atol=1e-14)
  247. def test_nan(self):
  248. # nan in, nan out.
  249. b = BSpline.basis_element([0, 1, 1, 2])
  250. assert_(np.isnan(b(np.nan)))
  251. def test_derivative_method(self):
  252. b = _make_random_spline(k=5)
  253. t, c, k = b.tck
  254. b0 = BSpline(t, c, k)
  255. xx = np.linspace(t[k], t[-k-1], 20)
  256. for j in range(1, k):
  257. b = b.derivative()
  258. assert_allclose(b0(xx, j), b(xx), atol=1e-12, rtol=1e-12)
  259. def test_antiderivative_method(self):
  260. b = _make_random_spline()
  261. t, c, k = b.tck
  262. xx = np.linspace(t[k], t[-k-1], 20)
  263. assert_allclose(b.antiderivative().derivative()(xx),
  264. b(xx), atol=1e-14, rtol=1e-14)
  265. # repeat with n-D array for c
  266. c = np.c_[c, c, c]
  267. c = np.dstack((c, c))
  268. b = BSpline(t, c, k)
  269. assert_allclose(b.antiderivative().derivative()(xx),
  270. b(xx), atol=1e-14, rtol=1e-14)
  271. def test_integral(self):
  272. b = BSpline.basis_element([0, 1, 2]) # x for x < 1 else 2 - x
  273. assert_allclose(b.integrate(0, 1), 0.5)
  274. assert_allclose(b.integrate(1, 0), -1 * 0.5)
  275. assert_allclose(b.integrate(1, 0), -0.5)
  276. # extrapolate or zeros outside of [0, 2]; default is yes
  277. assert_allclose(b.integrate(-1, 1), 0)
  278. assert_allclose(b.integrate(-1, 1, extrapolate=True), 0)
  279. assert_allclose(b.integrate(-1, 1, extrapolate=False), 0.5)
  280. assert_allclose(b.integrate(1, -1, extrapolate=False), -1 * 0.5)
  281. # Test ``_fitpack._splint()``
  282. t, c, k = b.tck
  283. assert_allclose(b.integrate(1, -1, extrapolate=False),
  284. _splint(t, c, k, 1, -1)[0])
  285. # Test ``extrapolate='periodic'``.
  286. b.extrapolate = 'periodic'
  287. i = b.antiderivative()
  288. period_int = i(2) - i(0)
  289. assert_allclose(b.integrate(0, 2), period_int)
  290. assert_allclose(b.integrate(2, 0), -1 * period_int)
  291. assert_allclose(b.integrate(-9, -7), period_int)
  292. assert_allclose(b.integrate(-8, -4), 2 * period_int)
  293. assert_allclose(b.integrate(0.5, 1.5), i(1.5) - i(0.5))
  294. assert_allclose(b.integrate(1.5, 3), i(1) - i(0) + i(2) - i(1.5))
  295. assert_allclose(b.integrate(1.5 + 12, 3 + 12),
  296. i(1) - i(0) + i(2) - i(1.5))
  297. assert_allclose(b.integrate(1.5, 3 + 12),
  298. i(1) - i(0) + i(2) - i(1.5) + 6 * period_int)
  299. assert_allclose(b.integrate(0, -1), i(0) - i(1))
  300. assert_allclose(b.integrate(-9, -10), i(0) - i(1))
  301. assert_allclose(b.integrate(0, -9), i(1) - i(2) - 4 * period_int)
  302. def test_integrate_ppoly(self):
  303. # test .integrate method to be consistent with PPoly.integrate
  304. x = [0, 1, 2, 3, 4]
  305. b = make_interp_spline(x, x)
  306. b.extrapolate = 'periodic'
  307. p = PPoly.from_spline(b)
  308. for x0, x1 in [(-5, 0.5), (0.5, 5), (-4, 13)]:
  309. assert_allclose(b.integrate(x0, x1),
  310. p.integrate(x0, x1))
  311. def test_subclassing(self):
  312. # classmethods should not decay to the base class
  313. class B(BSpline):
  314. pass
  315. b = B.basis_element([0, 1, 2, 2])
  316. assert_equal(b.__class__, B)
  317. assert_equal(b.derivative().__class__, B)
  318. assert_equal(b.antiderivative().__class__, B)
  319. def test_axis(self):
  320. n, k = 22, 3
  321. t = np.linspace(0, 1, n + k + 1)
  322. sh0 = [6, 7, 8]
  323. for axis in range(4):
  324. sh = sh0[:]
  325. sh.insert(axis, n) # [22, 6, 7, 8] etc
  326. c = np.random.random(size=sh)
  327. b = BSpline(t, c, k, axis=axis)
  328. assert_equal(b.c.shape,
  329. [sh[axis],] + sh[:axis] + sh[axis+1:])
  330. xp = np.random.random((3, 4, 5))
  331. assert_equal(b(xp).shape,
  332. sh[:axis] + list(xp.shape) + sh[axis+1:])
  333. #0 <= axis < c.ndim
  334. for ax in [-1, c.ndim]:
  335. assert_raises(ValueError, BSpline, **dict(t=t, c=c, k=k, axis=ax))
  336. # derivative, antiderivative keeps the axis
  337. for b1 in [BSpline(t, c, k, axis=axis).derivative(),
  338. BSpline(t, c, k, axis=axis).derivative(2),
  339. BSpline(t, c, k, axis=axis).antiderivative(),
  340. BSpline(t, c, k, axis=axis).antiderivative(2)]:
  341. assert_equal(b1.axis, b.axis)
  342. def test_knots_multiplicity():
  343. # Take a spline w/ random coefficients, throw in knots of varying
  344. # multiplicity.
  345. def check_splev(b, j, der=0, atol=1e-14, rtol=1e-14):
  346. # check evaluations against FITPACK, incl extrapolations
  347. t, c, k = b.tck
  348. x = np.unique(t)
  349. x = np.r_[t[0]-0.1, 0.5*(x[1:] + x[:1]), t[-1]+0.1]
  350. assert_allclose(splev(x, (t, c, k), der), b(x, der),
  351. atol=atol, rtol=rtol, err_msg='der = %s k = %s' % (der, b.k))
  352. # test loop itself
  353. # [the index `j` is for interpreting the traceback in case of a failure]
  354. for k in [1, 2, 3, 4, 5]:
  355. b = _make_random_spline(k=k)
  356. for j, b1 in enumerate(_make_multiples(b)):
  357. check_splev(b1, j)
  358. for der in range(1, k+1):
  359. check_splev(b1, j, der, 1e-12, 1e-12)
  360. ### stolen from @pv, verbatim
  361. def _naive_B(x, k, i, t):
  362. """
  363. Naive way to compute B-spline basis functions. Useful only for testing!
  364. computes B(x; t[i],..., t[i+k+1])
  365. """
  366. if k == 0:
  367. return 1.0 if t[i] <= x < t[i+1] else 0.0
  368. if t[i+k] == t[i]:
  369. c1 = 0.0
  370. else:
  371. c1 = (x - t[i])/(t[i+k] - t[i]) * _naive_B(x, k-1, i, t)
  372. if t[i+k+1] == t[i+1]:
  373. c2 = 0.0
  374. else:
  375. c2 = (t[i+k+1] - x)/(t[i+k+1] - t[i+1]) * _naive_B(x, k-1, i+1, t)
  376. return (c1 + c2)
  377. ### stolen from @pv, verbatim
  378. def _naive_eval(x, t, c, k):
  379. """
  380. Naive B-spline evaluation. Useful only for testing!
  381. """
  382. if x == t[k]:
  383. i = k
  384. else:
  385. i = np.searchsorted(t, x) - 1
  386. assert t[i] <= x <= t[i+1]
  387. assert i >= k and i < len(t) - k
  388. return sum(c[i-j] * _naive_B(x, k, i-j, t) for j in range(0, k+1))
  389. def _naive_eval_2(x, t, c, k):
  390. """Naive B-spline evaluation, another way."""
  391. n = len(t) - (k+1)
  392. assert n >= k+1
  393. assert len(c) >= n
  394. assert t[k] <= x <= t[n]
  395. return sum(c[i] * _naive_B(x, k, i, t) for i in range(n))
  396. def _sum_basis_elements(x, t, c, k):
  397. n = len(t) - (k+1)
  398. assert n >= k+1
  399. assert len(c) >= n
  400. s = 0.
  401. for i in range(n):
  402. b = BSpline.basis_element(t[i:i+k+2], extrapolate=False)(x)
  403. s += c[i] * np.nan_to_num(b) # zero out out-of-bounds elements
  404. return s
  405. def B_012(x):
  406. """ A linear B-spline function B(x | 0, 1, 2)."""
  407. x = np.atleast_1d(x)
  408. return np.piecewise(x, [(x < 0) | (x > 2),
  409. (x >= 0) & (x < 1),
  410. (x >= 1) & (x <= 2)],
  411. [lambda x: 0., lambda x: x, lambda x: 2.-x])
  412. def B_0123(x, der=0):
  413. """A quadratic B-spline function B(x | 0, 1, 2, 3)."""
  414. x = np.atleast_1d(x)
  415. conds = [x < 1, (x > 1) & (x < 2), x > 2]
  416. if der == 0:
  417. funcs = [lambda x: x*x/2.,
  418. lambda x: 3./4 - (x-3./2)**2,
  419. lambda x: (3.-x)**2 / 2]
  420. elif der == 2:
  421. funcs = [lambda x: 1.,
  422. lambda x: -2.,
  423. lambda x: 1.]
  424. else:
  425. raise ValueError('never be here: der=%s' % der)
  426. pieces = np.piecewise(x, conds, funcs)
  427. return pieces
  428. def _make_random_spline(n=35, k=3):
  429. np.random.seed(123)
  430. t = np.sort(np.random.random(n+k+1))
  431. c = np.random.random(n)
  432. return BSpline.construct_fast(t, c, k)
  433. def _make_multiples(b):
  434. """Increase knot multiplicity."""
  435. c, k = b.c, b.k
  436. t1 = b.t.copy()
  437. t1[17:19] = t1[17]
  438. t1[22] = t1[21]
  439. yield BSpline(t1, c, k)
  440. t1 = b.t.copy()
  441. t1[:k+1] = t1[0]
  442. yield BSpline(t1, c, k)
  443. t1 = b.t.copy()
  444. t1[-k-1:] = t1[-1]
  445. yield BSpline(t1, c, k)
  446. class TestInterop(object):
  447. #
  448. # Test that FITPACK-based spl* functions can deal with BSpline objects
  449. #
  450. def setup_method(self):
  451. xx = np.linspace(0, 4.*np.pi, 41)
  452. yy = np.cos(xx)
  453. b = make_interp_spline(xx, yy)
  454. self.tck = (b.t, b.c, b.k)
  455. self.xx, self.yy, self.b = xx, yy, b
  456. self.xnew = np.linspace(0, 4.*np.pi, 21)
  457. c2 = np.c_[b.c, b.c, b.c]
  458. self.c2 = np.dstack((c2, c2))
  459. self.b2 = BSpline(b.t, self.c2, b.k)
  460. def test_splev(self):
  461. xnew, b, b2 = self.xnew, self.b, self.b2
  462. # check that splev works with 1D array of coefficients
  463. # for array and scalar `x`
  464. assert_allclose(splev(xnew, b),
  465. b(xnew), atol=1e-15, rtol=1e-15)
  466. assert_allclose(splev(xnew, b.tck),
  467. b(xnew), atol=1e-15, rtol=1e-15)
  468. assert_allclose([splev(x, b) for x in xnew],
  469. b(xnew), atol=1e-15, rtol=1e-15)
  470. # With n-D coefficients, there's a quirck:
  471. # splev(x, BSpline) is equivalent to BSpline(x)
  472. with suppress_warnings() as sup:
  473. sup.filter(DeprecationWarning,
  474. "Calling splev.. with BSpline objects with c.ndim > 1 is not recommended.")
  475. assert_allclose(splev(xnew, b2), b2(xnew), atol=1e-15, rtol=1e-15)
  476. # However, splev(x, BSpline.tck) needs some transposes. This is because
  477. # BSpline interpolates along the first axis, while the legacy FITPACK
  478. # wrapper does list(map(...)) which effectively interpolates along the
  479. # last axis. Like so:
  480. sh = tuple(range(1, b2.c.ndim)) + (0,) # sh = (1, 2, 0)
  481. cc = b2.c.transpose(sh)
  482. tck = (b2.t, cc, b2.k)
  483. assert_allclose(splev(xnew, tck),
  484. b2(xnew).transpose(sh), atol=1e-15, rtol=1e-15)
  485. def test_splrep(self):
  486. x, y = self.xx, self.yy
  487. # test that "new" splrep is equivalent to _impl.splrep
  488. tck = splrep(x, y)
  489. t, c, k = _impl.splrep(x, y)
  490. assert_allclose(tck[0], t, atol=1e-15)
  491. assert_allclose(tck[1], c, atol=1e-15)
  492. assert_equal(tck[2], k)
  493. # also cover the `full_output=True` branch
  494. tck_f, _, _, _ = splrep(x, y, full_output=True)
  495. assert_allclose(tck_f[0], t, atol=1e-15)
  496. assert_allclose(tck_f[1], c, atol=1e-15)
  497. assert_equal(tck_f[2], k)
  498. # test that the result of splrep roundtrips with splev:
  499. # evaluate the spline on the original `x` points
  500. yy = splev(x, tck)
  501. assert_allclose(y, yy, atol=1e-15)
  502. # ... and also it roundtrips if wrapped in a BSpline
  503. b = BSpline(*tck)
  504. assert_allclose(y, b(x), atol=1e-15)
  505. @pytest.mark.xfail(NumpyVersion(np.__version__) < '1.14.0',
  506. reason='requires NumPy >= 1.14.0')
  507. def test_splrep_errors(self):
  508. # test that both "old" and "new" splrep raise for an n-D ``y`` array
  509. # with n > 1
  510. x, y = self.xx, self.yy
  511. y2 = np.c_[y, y]
  512. with assert_raises(ValueError):
  513. splrep(x, y2)
  514. with assert_raises(ValueError):
  515. _impl.splrep(x, y2)
  516. # input below minimum size
  517. with assert_raises(TypeError, match="m > k must hold"):
  518. splrep(x[:3], y[:3])
  519. with assert_raises(TypeError, match="m > k must hold"):
  520. _impl.splrep(x[:3], y[:3])
  521. def test_splprep(self):
  522. x = np.arange(15).reshape((3, 5))
  523. b, u = splprep(x)
  524. tck, u1 = _impl.splprep(x)
  525. # test the roundtrip with splev for both "old" and "new" output
  526. assert_allclose(u, u1, atol=1e-15)
  527. assert_allclose(splev(u, b), x, atol=1e-15)
  528. assert_allclose(splev(u, tck), x, atol=1e-15)
  529. # cover the ``full_output=True`` branch
  530. (b_f, u_f), _, _, _ = splprep(x, s=0, full_output=True)
  531. assert_allclose(u, u_f, atol=1e-15)
  532. assert_allclose(splev(u_f, b_f), x, atol=1e-15)
  533. def test_splprep_errors(self):
  534. # test that both "old" and "new" code paths raise for x.ndim > 2
  535. x = np.arange(3*4*5).reshape((3, 4, 5))
  536. with assert_raises(ValueError, match="too many values to unpack"):
  537. splprep(x)
  538. with assert_raises(ValueError, match="too many values to unpack"):
  539. _impl.splprep(x)
  540. # input below minimum size
  541. x = np.linspace(0, 40, num=3)
  542. with assert_raises(TypeError, match="m > k must hold"):
  543. splprep([x])
  544. with assert_raises(TypeError, match="m > k must hold"):
  545. _impl.splprep([x])
  546. # automatically calculated parameters are non-increasing
  547. # see gh-7589
  548. x = [-50.49072266, -50.49072266, -54.49072266, -54.49072266]
  549. with assert_raises(ValueError, match="Invalid inputs"):
  550. splprep([x])
  551. with assert_raises(ValueError, match="Invalid inputs"):
  552. _impl.splprep([x])
  553. # given non-increasing parameter values u
  554. x = [1, 3, 2, 4]
  555. u = [0, 0.3, 0.2, 1]
  556. with assert_raises(ValueError, match="Invalid inputs"):
  557. splprep(*[[x], None, u])
  558. def test_sproot(self):
  559. b, b2 = self.b, self.b2
  560. roots = np.array([0.5, 1.5, 2.5, 3.5])*np.pi
  561. # sproot accepts a BSpline obj w/ 1D coef array
  562. assert_allclose(sproot(b), roots, atol=1e-7, rtol=1e-7)
  563. assert_allclose(sproot((b.t, b.c, b.k)), roots, atol=1e-7, rtol=1e-7)
  564. # ... and deals with trailing dimensions if coef array is n-D
  565. with suppress_warnings() as sup:
  566. sup.filter(DeprecationWarning,
  567. "Calling sproot.. with BSpline objects with c.ndim > 1 is not recommended.")
  568. r = sproot(b2, mest=50)
  569. r = np.asarray(r)
  570. assert_equal(r.shape, (3, 2, 4))
  571. assert_allclose(r - roots, 0, atol=1e-12)
  572. # and legacy behavior is preserved for a tck tuple w/ n-D coef
  573. c2r = b2.c.transpose(1, 2, 0)
  574. rr = np.asarray(sproot((b2.t, c2r, b2.k), mest=50))
  575. assert_equal(rr.shape, (3, 2, 4))
  576. assert_allclose(rr - roots, 0, atol=1e-12)
  577. def test_splint(self):
  578. # test that splint accepts BSpline objects
  579. b, b2 = self.b, self.b2
  580. assert_allclose(splint(0, 1, b),
  581. splint(0, 1, b.tck), atol=1e-14)
  582. assert_allclose(splint(0, 1, b),
  583. b.integrate(0, 1), atol=1e-14)
  584. # ... and deals with n-D arrays of coefficients
  585. with suppress_warnings() as sup:
  586. sup.filter(DeprecationWarning,
  587. "Calling splint.. with BSpline objects with c.ndim > 1 is not recommended.")
  588. assert_allclose(splint(0, 1, b2), b2.integrate(0, 1), atol=1e-14)
  589. # and the legacy behavior is preserved for a tck tuple w/ n-D coef
  590. c2r = b2.c.transpose(1, 2, 0)
  591. integr = np.asarray(splint(0, 1, (b2.t, c2r, b2.k)))
  592. assert_equal(integr.shape, (3, 2))
  593. assert_allclose(integr,
  594. splint(0, 1, b), atol=1e-14)
  595. def test_splder(self):
  596. for b in [self.b, self.b2]:
  597. # pad the c array (FITPACK convention)
  598. ct = len(b.t) - len(b.c)
  599. if ct > 0:
  600. b.c = np.r_[b.c, np.zeros((ct,) + b.c.shape[1:])]
  601. for n in [1, 2, 3]:
  602. bd = splder(b)
  603. tck_d = _impl.splder((b.t, b.c, b.k))
  604. assert_allclose(bd.t, tck_d[0], atol=1e-15)
  605. assert_allclose(bd.c, tck_d[1], atol=1e-15)
  606. assert_equal(bd.k, tck_d[2])
  607. assert_(isinstance(bd, BSpline))
  608. assert_(isinstance(tck_d, tuple)) # back-compat: tck in and out
  609. def test_splantider(self):
  610. for b in [self.b, self.b2]:
  611. # pad the c array (FITPACK convention)
  612. ct = len(b.t) - len(b.c)
  613. if ct > 0:
  614. b.c = np.r_[b.c, np.zeros((ct,) + b.c.shape[1:])]
  615. for n in [1, 2, 3]:
  616. bd = splantider(b)
  617. tck_d = _impl.splantider((b.t, b.c, b.k))
  618. assert_allclose(bd.t, tck_d[0], atol=1e-15)
  619. assert_allclose(bd.c, tck_d[1], atol=1e-15)
  620. assert_equal(bd.k, tck_d[2])
  621. assert_(isinstance(bd, BSpline))
  622. assert_(isinstance(tck_d, tuple)) # back-compat: tck in and out
  623. def test_insert(self):
  624. b, b2, xx = self.b, self.b2, self.xx
  625. j = b.t.size // 2
  626. tn = 0.5*(b.t[j] + b.t[j+1])
  627. bn, tck_n = insert(tn, b), insert(tn, (b.t, b.c, b.k))
  628. assert_allclose(splev(xx, bn),
  629. splev(xx, tck_n), atol=1e-15)
  630. assert_(isinstance(bn, BSpline))
  631. assert_(isinstance(tck_n, tuple)) # back-compat: tck in, tck out
  632. # for n-D array of coefficients, BSpline.c needs to be transposed
  633. # after that, the results are equivalent.
  634. sh = tuple(range(b2.c.ndim))
  635. c_ = b2.c.transpose(sh[1:] + (0,))
  636. tck_n2 = insert(tn, (b2.t, c_, b2.k))
  637. bn2 = insert(tn, b2)
  638. # need a transpose for comparing the results, cf test_splev
  639. assert_allclose(np.asarray(splev(xx, tck_n2)).transpose(2, 0, 1),
  640. bn2(xx), atol=1e-15)
  641. assert_(isinstance(bn2, BSpline))
  642. assert_(isinstance(tck_n2, tuple)) # back-compat: tck in, tck out
  643. class TestInterp(object):
  644. #
  645. # Test basic ways of constructing interpolating splines.
  646. #
  647. xx = np.linspace(0., 2.*np.pi)
  648. yy = np.sin(xx)
  649. def test_non_int_order(self):
  650. with assert_raises(TypeError):
  651. make_interp_spline(self.xx, self.yy, k=2.5)
  652. def test_order_0(self):
  653. b = make_interp_spline(self.xx, self.yy, k=0)
  654. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  655. b = make_interp_spline(self.xx, self.yy, k=0, axis=-1)
  656. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  657. def test_linear(self):
  658. b = make_interp_spline(self.xx, self.yy, k=1)
  659. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  660. b = make_interp_spline(self.xx, self.yy, k=1, axis=-1)
  661. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  662. def test_not_a_knot(self):
  663. for k in [3, 5]:
  664. b = make_interp_spline(self.xx, self.yy, k)
  665. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  666. def test_quadratic_deriv(self):
  667. der = [(1, 8.)] # order, value: f'(x) = 8.
  668. # derivative at right-hand edge
  669. b = make_interp_spline(self.xx, self.yy, k=2, bc_type=(None, der))
  670. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  671. assert_allclose(b(self.xx[-1], 1), der[0][1], atol=1e-14, rtol=1e-14)
  672. # derivative at left-hand edge
  673. b = make_interp_spline(self.xx, self.yy, k=2, bc_type=(der, None))
  674. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  675. assert_allclose(b(self.xx[0], 1), der[0][1], atol=1e-14, rtol=1e-14)
  676. def test_cubic_deriv(self):
  677. k = 3
  678. # first derivatives at left & right edges:
  679. der_l, der_r = [(1, 3.)], [(1, 4.)]
  680. b = make_interp_spline(self.xx, self.yy, k, bc_type=(der_l, der_r))
  681. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  682. assert_allclose([b(self.xx[0], 1), b(self.xx[-1], 1)],
  683. [der_l[0][1], der_r[0][1]], atol=1e-14, rtol=1e-14)
  684. # 'natural' cubic spline, zero out 2nd derivatives at the boundaries
  685. der_l, der_r = [(2, 0)], [(2, 0)]
  686. b = make_interp_spline(self.xx, self.yy, k, bc_type=(der_l, der_r))
  687. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  688. def test_quintic_derivs(self):
  689. k, n = 5, 7
  690. x = np.arange(n).astype(np.float_)
  691. y = np.sin(x)
  692. der_l = [(1, -12.), (2, 1)]
  693. der_r = [(1, 8.), (2, 3.)]
  694. b = make_interp_spline(x, y, k=k, bc_type=(der_l, der_r))
  695. assert_allclose(b(x), y, atol=1e-14, rtol=1e-14)
  696. assert_allclose([b(x[0], 1), b(x[0], 2)],
  697. [val for (nu, val) in der_l])
  698. assert_allclose([b(x[-1], 1), b(x[-1], 2)],
  699. [val for (nu, val) in der_r])
  700. @pytest.mark.xfail(reason='unstable')
  701. def test_cubic_deriv_unstable(self):
  702. # 1st and 2nd derivative at x[0], no derivative information at x[-1]
  703. # The problem is not that it fails [who would use this anyway],
  704. # the problem is that it fails *silently*, and I've no idea
  705. # how to detect this sort of instability.
  706. # In this particular case: it's OK for len(t) < 20, goes haywire
  707. # at larger `len(t)`.
  708. k = 3
  709. t = _augknt(self.xx, k)
  710. der_l = [(1, 3.), (2, 4.)]
  711. b = make_interp_spline(self.xx, self.yy, k, t, bc_type=(der_l, None))
  712. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  713. def test_knots_not_data_sites(self):
  714. # Knots need not coincide with the data sites.
  715. # use a quadratic spline, knots are at data averages,
  716. # two additional constraints are zero 2nd derivs at edges
  717. k = 2
  718. t = np.r_[(self.xx[0],)*(k+1),
  719. (self.xx[1:] + self.xx[:-1]) / 2.,
  720. (self.xx[-1],)*(k+1)]
  721. b = make_interp_spline(self.xx, self.yy, k, t,
  722. bc_type=([(2, 0)], [(2, 0)]))
  723. assert_allclose(b(self.xx), self.yy, atol=1e-14, rtol=1e-14)
  724. assert_allclose([b(self.xx[0], 2), b(self.xx[-1], 2)], [0., 0.],
  725. atol=1e-14)
  726. def test_minimum_points_and_deriv(self):
  727. # interpolation of f(x) = x**3 between 0 and 1. f'(x) = 3 * xx**2 and
  728. # f'(0) = 0, f'(1) = 3.
  729. k = 3
  730. x = [0., 1.]
  731. y = [0., 1.]
  732. b = make_interp_spline(x, y, k, bc_type=([(1, 0.)], [(1, 3.)]))
  733. xx = np.linspace(0., 1.)
  734. yy = xx**3
  735. assert_allclose(b(xx), yy, atol=1e-14, rtol=1e-14)
  736. def test_deriv_spec(self):
  737. # If one of the derivatives is omitted, the spline definition is
  738. # incomplete.
  739. x = y = [1.0, 2, 3, 4, 5, 6]
  740. with assert_raises(ValueError):
  741. make_interp_spline(x, y, bc_type=([(1, 0.)], None))
  742. with assert_raises(ValueError):
  743. make_interp_spline(x, y, bc_type=(1, 0.))
  744. with assert_raises(ValueError):
  745. make_interp_spline(x, y, bc_type=[(1, 0.)])
  746. with assert_raises(ValueError):
  747. make_interp_spline(x, y, bc_type=42)
  748. # CubicSpline expects`bc_type=(left_pair, right_pair)`, while
  749. # here we expect `bc_type=(iterable, iterable)`.
  750. l, r = (1, 0.0), (1, 0.0)
  751. with assert_raises(ValueError):
  752. make_interp_spline(x, y, bc_type=(l, r))
  753. def test_complex(self):
  754. k = 3
  755. xx = self.xx
  756. yy = self.yy + 1.j*self.yy
  757. # first derivatives at left & right edges:
  758. der_l, der_r = [(1, 3.j)], [(1, 4.+2.j)]
  759. b = make_interp_spline(xx, yy, k, bc_type=(der_l, der_r))
  760. assert_allclose(b(xx), yy, atol=1e-14, rtol=1e-14)
  761. assert_allclose([b(xx[0], 1), b(xx[-1], 1)],
  762. [der_l[0][1], der_r[0][1]], atol=1e-14, rtol=1e-14)
  763. # also test zero and first order
  764. for k in (0, 1):
  765. b = make_interp_spline(xx, yy, k=k)
  766. assert_allclose(b(xx), yy, atol=1e-14, rtol=1e-14)
  767. def test_int_xy(self):
  768. x = np.arange(10).astype(np.int_)
  769. y = np.arange(10).astype(np.int_)
  770. # cython chokes on "buffer type mismatch" (construction) or
  771. # "no matching signature found" (evaluation)
  772. for k in (0, 1, 2, 3):
  773. b = make_interp_spline(x, y, k=k)
  774. b(x)
  775. def test_sliced_input(self):
  776. # cython code chokes on non C contiguous arrays
  777. xx = np.linspace(-1, 1, 100)
  778. x = xx[::5]
  779. y = xx[::5]
  780. for k in (0, 1, 2, 3):
  781. make_interp_spline(x, y, k=k)
  782. def test_check_finite(self):
  783. # check_finite defaults to True; nans and such trigger a ValueError
  784. x = np.arange(10).astype(float)
  785. y = x**2
  786. for z in [np.nan, np.inf, -np.inf]:
  787. y[-1] = z
  788. assert_raises(ValueError, make_interp_spline, x, y)
  789. @pytest.mark.parametrize('k', [1, 2, 3, 5])
  790. def test_list_input(self, k):
  791. # regression test for gh-8714: TypeError for x, y being lists and k=2
  792. x = list(range(10))
  793. y = [a**2 for a in x]
  794. make_interp_spline(x, y, k=k)
  795. def test_multiple_rhs(self):
  796. yy = np.c_[np.sin(self.xx), np.cos(self.xx)]
  797. der_l = [(1, [1., 2.])]
  798. der_r = [(1, [3., 4.])]
  799. b = make_interp_spline(self.xx, yy, k=3, bc_type=(der_l, der_r))
  800. assert_allclose(b(self.xx), yy, atol=1e-14, rtol=1e-14)
  801. assert_allclose(b(self.xx[0], 1), der_l[0][1], atol=1e-14, rtol=1e-14)
  802. assert_allclose(b(self.xx[-1], 1), der_r[0][1], atol=1e-14, rtol=1e-14)
  803. def test_shapes(self):
  804. np.random.seed(1234)
  805. k, n = 3, 22
  806. x = np.sort(np.random.random(size=n))
  807. y = np.random.random(size=(n, 5, 6, 7))
  808. b = make_interp_spline(x, y, k)
  809. assert_equal(b.c.shape, (n, 5, 6, 7))
  810. # now throw in some derivatives
  811. d_l = [(1, np.random.random((5, 6, 7)))]
  812. d_r = [(1, np.random.random((5, 6, 7)))]
  813. b = make_interp_spline(x, y, k, bc_type=(d_l, d_r))
  814. assert_equal(b.c.shape, (n + k - 1, 5, 6, 7))
  815. def test_string_aliases(self):
  816. yy = np.sin(self.xx)
  817. # a single string is duplicated
  818. b1 = make_interp_spline(self.xx, yy, k=3, bc_type='natural')
  819. b2 = make_interp_spline(self.xx, yy, k=3, bc_type=([(2, 0)], [(2, 0)]))
  820. assert_allclose(b1.c, b2.c, atol=1e-15)
  821. # two strings are handled
  822. b1 = make_interp_spline(self.xx, yy, k=3,
  823. bc_type=('natural', 'clamped'))
  824. b2 = make_interp_spline(self.xx, yy, k=3,
  825. bc_type=([(2, 0)], [(1, 0)]))
  826. assert_allclose(b1.c, b2.c, atol=1e-15)
  827. # one-sided BCs are OK
  828. b1 = make_interp_spline(self.xx, yy, k=2, bc_type=(None, 'clamped'))
  829. b2 = make_interp_spline(self.xx, yy, k=2, bc_type=(None, [(1, 0.0)]))
  830. assert_allclose(b1.c, b2.c, atol=1e-15)
  831. # 'not-a-knot' is equivalent to None
  832. b1 = make_interp_spline(self.xx, yy, k=3, bc_type='not-a-knot')
  833. b2 = make_interp_spline(self.xx, yy, k=3, bc_type=None)
  834. assert_allclose(b1.c, b2.c, atol=1e-15)
  835. # unknown strings do not pass
  836. with assert_raises(ValueError):
  837. make_interp_spline(self.xx, yy, k=3, bc_type='typo')
  838. # string aliases are handled for 2D values
  839. yy = np.c_[np.sin(self.xx), np.cos(self.xx)]
  840. der_l = [(1, [0., 0.])]
  841. der_r = [(2, [0., 0.])]
  842. b2 = make_interp_spline(self.xx, yy, k=3, bc_type=(der_l, der_r))
  843. b1 = make_interp_spline(self.xx, yy, k=3,
  844. bc_type=('clamped', 'natural'))
  845. assert_allclose(b1.c, b2.c, atol=1e-15)
  846. # ... and for n-D values:
  847. np.random.seed(1234)
  848. k, n = 3, 22
  849. x = np.sort(np.random.random(size=n))
  850. y = np.random.random(size=(n, 5, 6, 7))
  851. # now throw in some derivatives
  852. d_l = [(1, np.zeros((5, 6, 7)))]
  853. d_r = [(1, np.zeros((5, 6, 7)))]
  854. b1 = make_interp_spline(x, y, k, bc_type=(d_l, d_r))
  855. b2 = make_interp_spline(x, y, k, bc_type='clamped')
  856. assert_allclose(b1.c, b2.c, atol=1e-15)
  857. def test_full_matrix(self):
  858. np.random.seed(1234)
  859. k, n = 3, 7
  860. x = np.sort(np.random.random(size=n))
  861. y = np.random.random(size=n)
  862. t = _not_a_knot(x, k)
  863. b = make_interp_spline(x, y, k, t)
  864. cf = make_interp_full_matr(x, y, t, k)
  865. assert_allclose(b.c, cf, atol=1e-14, rtol=1e-14)
  866. def make_interp_full_matr(x, y, t, k):
  867. """Assemble an spline order k with knots t to interpolate
  868. y(x) using full matrices.
  869. Not-a-knot BC only.
  870. This routine is here for testing only (even though it's functional).
  871. """
  872. assert x.size == y.size
  873. assert t.size == x.size + k + 1
  874. n = x.size
  875. A = np.zeros((n, n), dtype=np.float_)
  876. for j in range(n):
  877. xval = x[j]
  878. if xval == t[k]:
  879. left = k
  880. else:
  881. left = np.searchsorted(t, xval) - 1
  882. # fill a row
  883. bb = _bspl.evaluate_all_bspl(t, k, xval, left)
  884. A[j, left-k:left+1] = bb
  885. c = sl.solve(A, y)
  886. return c
  887. ### XXX: 'periodic' interp spline using full matrices
  888. def make_interp_per_full_matr(x, y, t, k):
  889. x, y, t = map(np.asarray, (x, y, t))
  890. n = x.size
  891. nt = t.size - k - 1
  892. # have `n` conditions for `nt` coefficients; need nt-n derivatives
  893. assert nt - n == k - 1
  894. # LHS: the collocation matrix + derivatives at edges
  895. A = np.zeros((nt, nt), dtype=np.float_)
  896. # derivatives at x[0]:
  897. offset = 0
  898. if x[0] == t[k]:
  899. left = k
  900. else:
  901. left = np.searchsorted(t, x[0]) - 1
  902. if x[-1] == t[k]:
  903. left2 = k
  904. else:
  905. left2 = np.searchsorted(t, x[-1]) - 1
  906. for i in range(k-1):
  907. bb = _bspl.evaluate_all_bspl(t, k, x[0], left, nu=i+1)
  908. A[i, left-k:left+1] = bb
  909. bb = _bspl.evaluate_all_bspl(t, k, x[-1], left2, nu=i+1)
  910. A[i, left2-k:left2+1] = -bb
  911. offset += 1
  912. # RHS
  913. y = np.r_[[0]*(k-1), y]
  914. # collocation matrix
  915. for j in range(n):
  916. xval = x[j]
  917. # find interval
  918. if xval == t[k]:
  919. left = k
  920. else:
  921. left = np.searchsorted(t, xval) - 1
  922. # fill a row
  923. bb = _bspl.evaluate_all_bspl(t, k, xval, left)
  924. A[j + offset, left-k:left+1] = bb
  925. c = sl.solve(A, y)
  926. return c
  927. def make_lsq_full_matrix(x, y, t, k=3):
  928. """Make the least-square spline, full matrices."""
  929. x, y, t = map(np.asarray, (x, y, t))
  930. m = x.size
  931. n = t.size - k - 1
  932. A = np.zeros((m, n), dtype=np.float_)
  933. for j in range(m):
  934. xval = x[j]
  935. # find interval
  936. if xval == t[k]:
  937. left = k
  938. else:
  939. left = np.searchsorted(t, xval) - 1
  940. # fill a row
  941. bb = _bspl.evaluate_all_bspl(t, k, xval, left)
  942. A[j, left-k:left+1] = bb
  943. # have observation matrix, can solve the LSQ problem
  944. B = np.dot(A.T, A)
  945. Y = np.dot(A.T, y)
  946. c = sl.solve(B, Y)
  947. return c, (A, Y)
  948. class TestLSQ(object):
  949. #
  950. # Test make_lsq_spline
  951. #
  952. np.random.seed(1234)
  953. n, k = 13, 3
  954. x = np.sort(np.random.random(n))
  955. y = np.random.random(n)
  956. t = _augknt(np.linspace(x[0], x[-1], 7), k)
  957. def test_lstsq(self):
  958. # check LSQ construction vs a full matrix version
  959. x, y, t, k = self.x, self.y, self.t, self.k
  960. c0, AY = make_lsq_full_matrix(x, y, t, k)
  961. b = make_lsq_spline(x, y, t, k)
  962. assert_allclose(b.c, c0)
  963. assert_equal(b.c.shape, (t.size - k - 1,))
  964. # also check against numpy.lstsq
  965. aa, yy = AY
  966. c1, _, _, _ = np.linalg.lstsq(aa, y, rcond=-1)
  967. assert_allclose(b.c, c1)
  968. def test_weights(self):
  969. # weights = 1 is same as None
  970. x, y, t, k = self.x, self.y, self.t, self.k
  971. w = np.ones_like(x)
  972. b = make_lsq_spline(x, y, t, k)
  973. b_w = make_lsq_spline(x, y, t, k, w=w)
  974. assert_allclose(b.t, b_w.t, atol=1e-14)
  975. assert_allclose(b.c, b_w.c, atol=1e-14)
  976. assert_equal(b.k, b_w.k)
  977. def test_multiple_rhs(self):
  978. x, t, k, n = self.x, self.t, self.k, self.n
  979. y = np.random.random(size=(n, 5, 6, 7))
  980. b = make_lsq_spline(x, y, t, k)
  981. assert_equal(b.c.shape, (t.size-k-1, 5, 6, 7))
  982. def test_complex(self):
  983. # cmplx-valued `y`
  984. x, t, k = self.x, self.t, self.k
  985. yc = self.y * (1. + 2.j)
  986. b = make_lsq_spline(x, yc, t, k)
  987. b_re = make_lsq_spline(x, yc.real, t, k)
  988. b_im = make_lsq_spline(x, yc.imag, t, k)
  989. assert_allclose(b(x), b_re(x) + 1.j*b_im(x), atol=1e-15, rtol=1e-15)
  990. def test_int_xy(self):
  991. x = np.arange(10).astype(np.int_)
  992. y = np.arange(10).astype(np.int_)
  993. t = _augknt(x, k=1)
  994. # cython chokes on "buffer type mismatch"
  995. make_lsq_spline(x, y, t, k=1)
  996. def test_sliced_input(self):
  997. # cython code chokes on non C contiguous arrays
  998. xx = np.linspace(-1, 1, 100)
  999. x = xx[::3]
  1000. y = xx[::3]
  1001. t = _augknt(x, 1)
  1002. make_lsq_spline(x, y, t, k=1)
  1003. def test_checkfinite(self):
  1004. # check_finite defaults to True; nans and such trigger a ValueError
  1005. x = np.arange(12).astype(float)
  1006. y = x**2
  1007. t = _augknt(x, 3)
  1008. for z in [np.nan, np.inf, -np.inf]:
  1009. y[-1] = z
  1010. assert_raises(ValueError, make_lsq_spline, x, y, t)