test_fitpack.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. from __future__ import division, print_function, absolute_import
  2. import os
  3. import numpy as np
  4. from numpy.testing import (assert_equal, assert_allclose, assert_,
  5. assert_almost_equal, assert_array_almost_equal)
  6. from pytest import raises as assert_raises
  7. from numpy import array, asarray, pi, sin, cos, arange, dot, ravel, sqrt, round
  8. from scipy import interpolate
  9. from scipy.interpolate.fitpack import (splrep, splev, bisplrep, bisplev,
  10. sproot, splprep, splint, spalde, splder, splantider, insert, dblint)
  11. from scipy.interpolate.dfitpack import regrid_smth
  12. def data_file(basename):
  13. return os.path.join(os.path.abspath(os.path.dirname(__file__)),
  14. 'data', basename)
  15. def norm2(x):
  16. return sqrt(dot(x.T,x))
  17. def f1(x,d=0):
  18. if d is None:
  19. return "sin"
  20. if x is None:
  21. return "sin(x)"
  22. if d % 4 == 0:
  23. return sin(x)
  24. if d % 4 == 1:
  25. return cos(x)
  26. if d % 4 == 2:
  27. return -sin(x)
  28. if d % 4 == 3:
  29. return -cos(x)
  30. def f2(x,y=0,dx=0,dy=0):
  31. if x is None:
  32. return "sin(x+y)"
  33. d = dx+dy
  34. if d % 4 == 0:
  35. return sin(x+y)
  36. if d % 4 == 1:
  37. return cos(x+y)
  38. if d % 4 == 2:
  39. return -sin(x+y)
  40. if d % 4 == 3:
  41. return -cos(x+y)
  42. def makepairs(x, y):
  43. """Helper function to create an array of pairs of x and y."""
  44. # Or itertools.product (>= python 2.6)
  45. xy = array([[a, b] for a in asarray(x) for b in asarray(y)])
  46. return xy.T
  47. def put(*a):
  48. """Produce some output if file run directly"""
  49. import sys
  50. if hasattr(sys.modules['__main__'], '__put_prints'):
  51. sys.stderr.write("".join(map(str, a)) + "\n")
  52. class TestSmokeTests(object):
  53. """
  54. Smoke tests (with a few asserts) for fitpack routines -- mostly
  55. check that they are runnable
  56. """
  57. def check_1(self,f=f1,per=0,s=0,a=0,b=2*pi,N=20,at=0,xb=None,xe=None):
  58. if xb is None:
  59. xb = a
  60. if xe is None:
  61. xe = b
  62. x = a+(b-a)*arange(N+1,dtype=float)/float(N) # nodes
  63. x1 = a+(b-a)*arange(1,N,dtype=float)/float(N-1) # middle points of the nodes
  64. v,v1 = f(x),f(x1)
  65. nk = []
  66. def err_est(k, d):
  67. # Assume f has all derivatives < 1
  68. h = 1.0/float(N)
  69. tol = 5 * h**(.75*(k-d))
  70. if s > 0:
  71. tol += 1e5*s
  72. return tol
  73. for k in range(1,6):
  74. tck = splrep(x,v,s=s,per=per,k=k,xe=xe)
  75. if at:
  76. t = tck[0][k:-k]
  77. else:
  78. t = x1
  79. nd = []
  80. for d in range(k+1):
  81. tol = err_est(k, d)
  82. err = norm2(f(t,d)-splev(t,tck,d)) / norm2(f(t,d))
  83. assert_(err < tol, (k, d, err, tol))
  84. nd.append((err, tol))
  85. nk.append(nd)
  86. put("\nf = %s s=S_k(x;t,c) x in [%s, %s] > [%s, %s]" % (f(None),
  87. repr(round(xb,3)),repr(round(xe,3)),
  88. repr(round(a,3)),repr(round(b,3))))
  89. if at:
  90. str = "at knots"
  91. else:
  92. str = "at the middle of nodes"
  93. put(" per=%d s=%s Evaluation %s" % (per,repr(s),str))
  94. put(" k : |f-s|^2 |f'-s'| |f''-.. |f'''-. |f''''- |f'''''")
  95. k = 1
  96. for l in nk:
  97. put(' %d : ' % k)
  98. for r in l:
  99. put(' %.1e %.1e' % r)
  100. put('\n')
  101. k = k+1
  102. def check_2(self,f=f1,per=0,s=0,a=0,b=2*pi,N=20,xb=None,xe=None,
  103. ia=0,ib=2*pi,dx=0.2*pi):
  104. if xb is None:
  105. xb = a
  106. if xe is None:
  107. xe = b
  108. x = a+(b-a)*arange(N+1,dtype=float)/float(N) # nodes
  109. v = f(x)
  110. def err_est(k, d):
  111. # Assume f has all derivatives < 1
  112. h = 1.0/float(N)
  113. tol = 5 * h**(.75*(k-d))
  114. if s > 0:
  115. tol += 1e5*s
  116. return tol
  117. nk = []
  118. for k in range(1,6):
  119. tck = splrep(x,v,s=s,per=per,k=k,xe=xe)
  120. nk.append([splint(ia,ib,tck),spalde(dx,tck)])
  121. put("\nf = %s s=S_k(x;t,c) x in [%s, %s] > [%s, %s]" % (f(None),
  122. repr(round(xb,3)),repr(round(xe,3)),
  123. repr(round(a,3)),repr(round(b,3))))
  124. put(" per=%d s=%s N=%d [a, b] = [%s, %s] dx=%s" % (per,repr(s),N,repr(round(ia,3)),repr(round(ib,3)),repr(round(dx,3))))
  125. put(" k : int(s,[a,b]) Int.Error Rel. error of s^(d)(dx) d = 0, .., k")
  126. k = 1
  127. for r in nk:
  128. if r[0] < 0:
  129. sr = '-'
  130. else:
  131. sr = ' '
  132. put(" %d %s%.8f %.1e " % (k,sr,abs(r[0]),
  133. abs(r[0]-(f(ib,-1)-f(ia,-1)))))
  134. d = 0
  135. for dr in r[1]:
  136. err = abs(1-dr/f(dx,d))
  137. tol = err_est(k, d)
  138. assert_(err < tol, (k, d))
  139. put(" %.1e %.1e" % (err, tol))
  140. d = d+1
  141. put("\n")
  142. k = k+1
  143. def check_3(self,f=f1,per=0,s=0,a=0,b=2*pi,N=20,xb=None,xe=None,
  144. ia=0,ib=2*pi,dx=0.2*pi):
  145. if xb is None:
  146. xb = a
  147. if xe is None:
  148. xe = b
  149. x = a+(b-a)*arange(N+1,dtype=float)/float(N) # nodes
  150. v = f(x)
  151. put(" k : Roots of s(x) approx %s x in [%s,%s]:" %
  152. (f(None),repr(round(a,3)),repr(round(b,3))))
  153. for k in range(1,6):
  154. tck = splrep(x, v, s=s, per=per, k=k, xe=xe)
  155. if k == 3:
  156. roots = sproot(tck)
  157. assert_allclose(splev(roots, tck), 0, atol=1e-10, rtol=1e-10)
  158. assert_allclose(roots, pi*array([1, 2, 3, 4]), rtol=1e-3)
  159. put(' %d : %s' % (k, repr(roots.tolist())))
  160. else:
  161. assert_raises(ValueError, sproot, tck)
  162. def check_4(self,f=f1,per=0,s=0,a=0,b=2*pi,N=20,xb=None,xe=None,
  163. ia=0,ib=2*pi,dx=0.2*pi):
  164. if xb is None:
  165. xb = a
  166. if xe is None:
  167. xe = b
  168. x = a+(b-a)*arange(N+1,dtype=float)/float(N) # nodes
  169. x1 = a + (b-a)*arange(1,N,dtype=float)/float(N-1) # middle points of the nodes
  170. v,v1 = f(x),f(x1)
  171. put(" u = %s N = %d" % (repr(round(dx,3)),N))
  172. put(" k : [x(u), %s(x(u))] Error of splprep Error of splrep " % (f(0,None)))
  173. for k in range(1,6):
  174. tckp,u = splprep([x,v],s=s,per=per,k=k,nest=-1)
  175. tck = splrep(x,v,s=s,per=per,k=k)
  176. uv = splev(dx,tckp)
  177. err1 = abs(uv[1]-f(uv[0]))
  178. err2 = abs(splev(uv[0],tck)-f(uv[0]))
  179. assert_(err1 < 1e-2)
  180. assert_(err2 < 1e-2)
  181. put(" %d : %s %.1e %.1e" %
  182. (k,repr([round(z,3) for z in uv]),
  183. err1,
  184. err2))
  185. put("Derivatives of parametric cubic spline at u (first function):")
  186. k = 3
  187. tckp,u = splprep([x,v],s=s,per=per,k=k,nest=-1)
  188. for d in range(1,k+1):
  189. uv = splev(dx,tckp,d)
  190. put(" %s " % (repr(uv[0])))
  191. def check_5(self,f=f2,kx=3,ky=3,xb=0,xe=2*pi,yb=0,ye=2*pi,Nx=20,Ny=20,s=0):
  192. x = xb+(xe-xb)*arange(Nx+1,dtype=float)/float(Nx)
  193. y = yb+(ye-yb)*arange(Ny+1,dtype=float)/float(Ny)
  194. xy = makepairs(x,y)
  195. tck = bisplrep(xy[0],xy[1],f(xy[0],xy[1]),s=s,kx=kx,ky=ky)
  196. tt = [tck[0][kx:-kx],tck[1][ky:-ky]]
  197. t2 = makepairs(tt[0],tt[1])
  198. v1 = bisplev(tt[0],tt[1],tck)
  199. v2 = f2(t2[0],t2[1])
  200. v2.shape = len(tt[0]),len(tt[1])
  201. err = norm2(ravel(v1-v2))
  202. assert_(err < 1e-2, err)
  203. put(err)
  204. def test_smoke_splrep_splev(self):
  205. put("***************** splrep/splev")
  206. self.check_1(s=1e-6)
  207. self.check_1()
  208. self.check_1(at=1)
  209. self.check_1(per=1)
  210. self.check_1(per=1,at=1)
  211. self.check_1(b=1.5*pi)
  212. self.check_1(b=1.5*pi,xe=2*pi,per=1,s=1e-1)
  213. def test_smoke_splint_spalde(self):
  214. put("***************** splint/spalde")
  215. self.check_2()
  216. self.check_2(per=1)
  217. self.check_2(ia=0.2*pi,ib=pi)
  218. self.check_2(ia=0.2*pi,ib=pi,N=50)
  219. def test_smoke_sproot(self):
  220. put("***************** sproot")
  221. self.check_3(a=0.1,b=15)
  222. def test_smoke_splprep_splrep_splev(self):
  223. put("***************** splprep/splrep/splev")
  224. self.check_4()
  225. self.check_4(N=50)
  226. def test_smoke_bisplrep_bisplev(self):
  227. put("***************** bisplev")
  228. self.check_5()
  229. class TestSplev(object):
  230. def test_1d_shape(self):
  231. x = [1,2,3,4,5]
  232. y = [4,5,6,7,8]
  233. tck = splrep(x, y)
  234. z = splev([1], tck)
  235. assert_equal(z.shape, (1,))
  236. z = splev(1, tck)
  237. assert_equal(z.shape, ())
  238. def test_2d_shape(self):
  239. x = [1, 2, 3, 4, 5]
  240. y = [4, 5, 6, 7, 8]
  241. tck = splrep(x, y)
  242. t = np.array([[1.0, 1.5, 2.0, 2.5],
  243. [3.0, 3.5, 4.0, 4.5]])
  244. z = splev(t, tck)
  245. z0 = splev(t[0], tck)
  246. z1 = splev(t[1], tck)
  247. assert_equal(z, np.row_stack((z0, z1)))
  248. def test_extrapolation_modes(self):
  249. # test extrapolation modes
  250. # * if ext=0, return the extrapolated value.
  251. # * if ext=1, return 0
  252. # * if ext=2, raise a ValueError
  253. # * if ext=3, return the boundary value.
  254. x = [1,2,3]
  255. y = [0,2,4]
  256. tck = splrep(x, y, k=1)
  257. rstl = [[-2, 6], [0, 0], None, [0, 4]]
  258. for ext in (0, 1, 3):
  259. assert_array_almost_equal(splev([0, 4], tck, ext=ext), rstl[ext])
  260. assert_raises(ValueError, splev, [0, 4], tck, ext=2)
  261. class TestSplder(object):
  262. def setup_method(self):
  263. # non-uniform grid, just to make it sure
  264. x = np.linspace(0, 1, 100)**3
  265. y = np.sin(20 * x)
  266. self.spl = splrep(x, y)
  267. # double check that knots are non-uniform
  268. assert_(np.diff(self.spl[0]).ptp() > 0)
  269. def test_inverse(self):
  270. # Check that antiderivative + derivative is identity.
  271. for n in range(5):
  272. spl2 = splantider(self.spl, n)
  273. spl3 = splder(spl2, n)
  274. assert_allclose(self.spl[0], spl3[0])
  275. assert_allclose(self.spl[1], spl3[1])
  276. assert_equal(self.spl[2], spl3[2])
  277. def test_splder_vs_splev(self):
  278. # Check derivative vs. FITPACK
  279. for n in range(3+1):
  280. # Also extrapolation!
  281. xx = np.linspace(-1, 2, 2000)
  282. if n == 3:
  283. # ... except that FITPACK extrapolates strangely for
  284. # order 0, so let's not check that.
  285. xx = xx[(xx >= 0) & (xx <= 1)]
  286. dy = splev(xx, self.spl, n)
  287. spl2 = splder(self.spl, n)
  288. dy2 = splev(xx, spl2)
  289. if n == 1:
  290. assert_allclose(dy, dy2, rtol=2e-6)
  291. else:
  292. assert_allclose(dy, dy2)
  293. def test_splantider_vs_splint(self):
  294. # Check antiderivative vs. FITPACK
  295. spl2 = splantider(self.spl)
  296. # no extrapolation, splint assumes function is zero outside
  297. # range
  298. xx = np.linspace(0, 1, 20)
  299. for x1 in xx:
  300. for x2 in xx:
  301. y1 = splint(x1, x2, self.spl)
  302. y2 = splev(x2, spl2) - splev(x1, spl2)
  303. assert_allclose(y1, y2)
  304. def test_order0_diff(self):
  305. assert_raises(ValueError, splder, self.spl, 4)
  306. def test_kink(self):
  307. # Should refuse to differentiate splines with kinks
  308. spl2 = insert(0.5, self.spl, m=2)
  309. splder(spl2, 2) # Should work
  310. assert_raises(ValueError, splder, spl2, 3)
  311. spl2 = insert(0.5, self.spl, m=3)
  312. splder(spl2, 1) # Should work
  313. assert_raises(ValueError, splder, spl2, 2)
  314. spl2 = insert(0.5, self.spl, m=4)
  315. assert_raises(ValueError, splder, spl2, 1)
  316. def test_multidim(self):
  317. # c can have trailing dims
  318. for n in range(3):
  319. t, c, k = self.spl
  320. c2 = np.c_[c, c, c]
  321. c2 = np.dstack((c2, c2))
  322. spl2 = splantider((t, c2, k), n)
  323. spl3 = splder(spl2, n)
  324. assert_allclose(t, spl3[0])
  325. assert_allclose(c2, spl3[1])
  326. assert_equal(k, spl3[2])
  327. class TestBisplrep(object):
  328. def test_overflow(self):
  329. a = np.linspace(0, 1, 620)
  330. b = np.linspace(0, 1, 620)
  331. x, y = np.meshgrid(a, b)
  332. z = np.random.rand(*x.shape)
  333. assert_raises(OverflowError, bisplrep, x.ravel(), y.ravel(), z.ravel(), s=0)
  334. def test_regression_1310(self):
  335. # Regression test for gh-1310
  336. data = np.load(data_file('bug-1310.npz'))['data']
  337. # Shouldn't crash -- the input data triggers work array sizes
  338. # that caused previously some data to not be aligned on
  339. # sizeof(double) boundaries in memory, which made the Fortran
  340. # code to crash when compiled with -O3
  341. bisplrep(data[:,0], data[:,1], data[:,2], kx=3, ky=3, s=0,
  342. full_output=True)
  343. def test_dblint():
  344. # Basic test to see it runs and gives the correct result on a trivial
  345. # problem. Note that `dblint` is not exposed in the interpolate namespace.
  346. x = np.linspace(0, 1)
  347. y = np.linspace(0, 1)
  348. xx, yy = np.meshgrid(x, y)
  349. rect = interpolate.RectBivariateSpline(x, y, 4 * xx * yy)
  350. tck = list(rect.tck)
  351. tck.extend(rect.degrees)
  352. assert_almost_equal(dblint(0, 1, 0, 1, tck), 1)
  353. assert_almost_equal(dblint(0, 0.5, 0, 1, tck), 0.25)
  354. assert_almost_equal(dblint(0.5, 1, 0, 1, tck), 0.75)
  355. assert_almost_equal(dblint(-100, 100, -100, 100, tck), 1)
  356. def test_splev_der_k():
  357. # regression test for gh-2188: splev(x, tck, der=k) gives garbage or crashes
  358. # for x outside of knot range
  359. # test case from gh-2188
  360. tck = (np.array([0., 0., 2.5, 2.5]),
  361. np.array([-1.56679978, 2.43995873, 0., 0.]),
  362. 1)
  363. t, c, k = tck
  364. x = np.array([-3, 0, 2.5, 3])
  365. # an explicit form of the linear spline
  366. assert_allclose(splev(x, tck), c[0] + (c[1] - c[0]) * x/t[2])
  367. assert_allclose(splev(x, tck, 1), (c[1]-c[0]) / t[2])
  368. # now check a random spline vs splder
  369. np.random.seed(1234)
  370. x = np.sort(np.random.random(30))
  371. y = np.random.random(30)
  372. t, c, k = splrep(x, y)
  373. x = [t[0] - 1., t[-1] + 1.]
  374. tck2 = splder((t, c, k), k)
  375. assert_allclose(splev(x, (t, c, k), k), splev(x, tck2))
  376. def test_bisplev_integer_overflow():
  377. np.random.seed(1)
  378. x = np.linspace(0, 1, 11)
  379. y = x
  380. z = np.random.randn(11, 11).ravel()
  381. kx = 1
  382. ky = 1
  383. nx, tx, ny, ty, c, fp, ier = regrid_smth(
  384. x, y, z, None, None, None, None, kx=kx, ky=ky, s=0.0)
  385. tck = (tx[:nx], ty[:ny], c[:(nx - kx - 1) * (ny - ky - 1)], kx, ky)
  386. xp = np.zeros([2621440])
  387. yp = np.zeros([2621440])
  388. assert_raises((RuntimeError, MemoryError), bisplev, xp, yp, tck)