test_interpolate.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769
  1. from __future__ import division, print_function, absolute_import
  2. import itertools
  3. from numpy.testing import (assert_, assert_equal, assert_almost_equal,
  4. assert_array_almost_equal, assert_array_equal,
  5. assert_allclose)
  6. from pytest import raises as assert_raises
  7. import pytest
  8. from numpy import mgrid, pi, sin, ogrid, poly1d, linspace
  9. import numpy as np
  10. from scipy._lib.six import xrange
  11. from scipy._lib._numpy_compat import _assert_warns, suppress_warnings
  12. from scipy.interpolate import (interp1d, interp2d, lagrange, PPoly, BPoly,
  13. splrep, splev, splantider, splint, sproot, Akima1DInterpolator,
  14. RegularGridInterpolator, LinearNDInterpolator, NearestNDInterpolator,
  15. RectBivariateSpline, interpn, NdPPoly, BSpline)
  16. from scipy.special import poch, gamma
  17. from scipy.interpolate import _ppoly
  18. from scipy._lib._gcutils import assert_deallocated, IS_PYPY
  19. from scipy.integrate import nquad
  20. from scipy.special import binom
  21. class TestInterp2D(object):
  22. def test_interp2d(self):
  23. y, x = mgrid[0:2:20j, 0:pi:21j]
  24. z = sin(x+0.5*y)
  25. I = interp2d(x, y, z)
  26. assert_almost_equal(I(1.0, 2.0), sin(2.0), decimal=2)
  27. v,u = ogrid[0:2:24j, 0:pi:25j]
  28. assert_almost_equal(I(u.ravel(), v.ravel()), sin(u+0.5*v), decimal=2)
  29. def test_interp2d_meshgrid_input(self):
  30. # Ticket #703
  31. x = linspace(0, 2, 16)
  32. y = linspace(0, pi, 21)
  33. z = sin(x[None,:] + y[:,None]/2.)
  34. I = interp2d(x, y, z)
  35. assert_almost_equal(I(1.0, 2.0), sin(2.0), decimal=2)
  36. def test_interp2d_meshgrid_input_unsorted(self):
  37. np.random.seed(1234)
  38. x = linspace(0, 2, 16)
  39. y = linspace(0, pi, 21)
  40. z = sin(x[None,:] + y[:,None]/2.)
  41. ip1 = interp2d(x.copy(), y.copy(), z, kind='cubic')
  42. np.random.shuffle(x)
  43. z = sin(x[None,:] + y[:,None]/2.)
  44. ip2 = interp2d(x.copy(), y.copy(), z, kind='cubic')
  45. np.random.shuffle(x)
  46. np.random.shuffle(y)
  47. z = sin(x[None,:] + y[:,None]/2.)
  48. ip3 = interp2d(x, y, z, kind='cubic')
  49. x = linspace(0, 2, 31)
  50. y = linspace(0, pi, 30)
  51. assert_equal(ip1(x, y), ip2(x, y))
  52. assert_equal(ip1(x, y), ip3(x, y))
  53. def test_interp2d_eval_unsorted(self):
  54. y, x = mgrid[0:2:20j, 0:pi:21j]
  55. z = sin(x + 0.5*y)
  56. func = interp2d(x, y, z)
  57. xe = np.array([3, 4, 5])
  58. ye = np.array([5.3, 7.1])
  59. assert_allclose(func(xe, ye), func(xe, ye[::-1]))
  60. assert_raises(ValueError, func, xe, ye[::-1], 0, 0, True)
  61. def test_interp2d_linear(self):
  62. # Ticket #898
  63. a = np.zeros([5, 5])
  64. a[2, 2] = 1.0
  65. x = y = np.arange(5)
  66. b = interp2d(x, y, a, 'linear')
  67. assert_almost_equal(b(2.0, 1.5), np.array([0.5]), decimal=2)
  68. assert_almost_equal(b(2.0, 2.5), np.array([0.5]), decimal=2)
  69. def test_interp2d_bounds(self):
  70. x = np.linspace(0, 1, 5)
  71. y = np.linspace(0, 2, 7)
  72. z = x[None, :]**2 + y[:, None]
  73. ix = np.linspace(-1, 3, 31)
  74. iy = np.linspace(-1, 3, 33)
  75. b = interp2d(x, y, z, bounds_error=True)
  76. assert_raises(ValueError, b, ix, iy)
  77. b = interp2d(x, y, z, fill_value=np.nan)
  78. iz = b(ix, iy)
  79. mx = (ix < 0) | (ix > 1)
  80. my = (iy < 0) | (iy > 2)
  81. assert_(np.isnan(iz[my,:]).all())
  82. assert_(np.isnan(iz[:,mx]).all())
  83. assert_(np.isfinite(iz[~my,:][:,~mx]).all())
  84. class TestInterp1D(object):
  85. def setup_method(self):
  86. self.x5 = np.arange(5.)
  87. self.x10 = np.arange(10.)
  88. self.y10 = np.arange(10.)
  89. self.x25 = self.x10.reshape((2,5))
  90. self.x2 = np.arange(2.)
  91. self.y2 = np.arange(2.)
  92. self.x1 = np.array([0.])
  93. self.y1 = np.array([0.])
  94. self.y210 = np.arange(20.).reshape((2, 10))
  95. self.y102 = np.arange(20.).reshape((10, 2))
  96. self.y225 = np.arange(20.).reshape((2, 2, 5))
  97. self.y25 = np.arange(10.).reshape((2, 5))
  98. self.y235 = np.arange(30.).reshape((2, 3, 5))
  99. self.y325 = np.arange(30.).reshape((3, 2, 5))
  100. self.fill_value = -100.0
  101. def test_validation(self):
  102. # Make sure that appropriate exceptions are raised when invalid values
  103. # are given to the constructor.
  104. # These should all work.
  105. for kind in ('nearest', 'zero', 'linear', 'slinear', 'quadratic',
  106. 'cubic', 'previous', 'next'):
  107. interp1d(self.x10, self.y10, kind=kind)
  108. interp1d(self.x10, self.y10, kind=kind, fill_value="extrapolate")
  109. interp1d(self.x10, self.y10, kind='linear', fill_value=(-1, 1))
  110. interp1d(self.x10, self.y10, kind='linear',
  111. fill_value=np.array([-1]))
  112. interp1d(self.x10, self.y10, kind='linear',
  113. fill_value=(-1,))
  114. interp1d(self.x10, self.y10, kind='linear',
  115. fill_value=-1)
  116. interp1d(self.x10, self.y10, kind='linear',
  117. fill_value=(-1, -1))
  118. interp1d(self.x10, self.y10, kind=0)
  119. interp1d(self.x10, self.y10, kind=1)
  120. interp1d(self.x10, self.y10, kind=2)
  121. interp1d(self.x10, self.y10, kind=3)
  122. interp1d(self.x10, self.y210, kind='linear', axis=-1,
  123. fill_value=(-1, -1))
  124. interp1d(self.x2, self.y210, kind='linear', axis=0,
  125. fill_value=np.ones(10))
  126. interp1d(self.x2, self.y210, kind='linear', axis=0,
  127. fill_value=(np.ones(10), np.ones(10)))
  128. interp1d(self.x2, self.y210, kind='linear', axis=0,
  129. fill_value=(np.ones(10), -1))
  130. # x array must be 1D.
  131. assert_raises(ValueError, interp1d, self.x25, self.y10)
  132. # y array cannot be a scalar.
  133. assert_raises(ValueError, interp1d, self.x10, np.array(0))
  134. # Check for x and y arrays having the same length.
  135. assert_raises(ValueError, interp1d, self.x10, self.y2)
  136. assert_raises(ValueError, interp1d, self.x2, self.y10)
  137. assert_raises(ValueError, interp1d, self.x10, self.y102)
  138. interp1d(self.x10, self.y210)
  139. interp1d(self.x10, self.y102, axis=0)
  140. # Check for x and y having at least 1 element.
  141. assert_raises(ValueError, interp1d, self.x1, self.y10)
  142. assert_raises(ValueError, interp1d, self.x10, self.y1)
  143. assert_raises(ValueError, interp1d, self.x1, self.y1)
  144. # Bad fill values
  145. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  146. fill_value=(-1, -1, -1)) # doesn't broadcast
  147. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  148. fill_value=[-1, -1, -1]) # doesn't broadcast
  149. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  150. fill_value=np.array((-1, -1, -1))) # doesn't broadcast
  151. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  152. fill_value=[[-1]]) # doesn't broadcast
  153. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  154. fill_value=[-1, -1]) # doesn't broadcast
  155. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  156. fill_value=np.array([])) # doesn't broadcast
  157. assert_raises(ValueError, interp1d, self.x10, self.y10, kind='linear',
  158. fill_value=()) # doesn't broadcast
  159. assert_raises(ValueError, interp1d, self.x2, self.y210, kind='linear',
  160. axis=0, fill_value=[-1, -1]) # doesn't broadcast
  161. assert_raises(ValueError, interp1d, self.x2, self.y210, kind='linear',
  162. axis=0, fill_value=(0., [-1, -1])) # above doesn't bc
  163. def test_init(self):
  164. # Check that the attributes are initialized appropriately by the
  165. # constructor.
  166. assert_(interp1d(self.x10, self.y10).copy)
  167. assert_(not interp1d(self.x10, self.y10, copy=False).copy)
  168. assert_(interp1d(self.x10, self.y10).bounds_error)
  169. assert_(not interp1d(self.x10, self.y10, bounds_error=False).bounds_error)
  170. assert_(np.isnan(interp1d(self.x10, self.y10).fill_value))
  171. assert_equal(interp1d(self.x10, self.y10, fill_value=3.0).fill_value,
  172. 3.0)
  173. assert_equal(interp1d(self.x10, self.y10, fill_value=(1.0, 2.0)).fill_value,
  174. (1.0, 2.0))
  175. assert_equal(interp1d(self.x10, self.y10).axis, 0)
  176. assert_equal(interp1d(self.x10, self.y210).axis, 1)
  177. assert_equal(interp1d(self.x10, self.y102, axis=0).axis, 0)
  178. assert_array_equal(interp1d(self.x10, self.y10).x, self.x10)
  179. assert_array_equal(interp1d(self.x10, self.y10).y, self.y10)
  180. assert_array_equal(interp1d(self.x10, self.y210).y, self.y210)
  181. def test_assume_sorted(self):
  182. # Check for unsorted arrays
  183. interp10 = interp1d(self.x10, self.y10)
  184. interp10_unsorted = interp1d(self.x10[::-1], self.y10[::-1])
  185. assert_array_almost_equal(interp10_unsorted(self.x10), self.y10)
  186. assert_array_almost_equal(interp10_unsorted(1.2), np.array([1.2]))
  187. assert_array_almost_equal(interp10_unsorted([2.4, 5.6, 6.0]),
  188. interp10([2.4, 5.6, 6.0]))
  189. # Check assume_sorted keyword (defaults to False)
  190. interp10_assume_kw = interp1d(self.x10[::-1], self.y10[::-1],
  191. assume_sorted=False)
  192. assert_array_almost_equal(interp10_assume_kw(self.x10), self.y10)
  193. interp10_assume_kw2 = interp1d(self.x10[::-1], self.y10[::-1],
  194. assume_sorted=True)
  195. # Should raise an error for unsorted input if assume_sorted=True
  196. assert_raises(ValueError, interp10_assume_kw2, self.x10)
  197. # Check that if y is a 2-D array, things are still consistent
  198. interp10_y_2d = interp1d(self.x10, self.y210)
  199. interp10_y_2d_unsorted = interp1d(self.x10[::-1], self.y210[:, ::-1])
  200. assert_array_almost_equal(interp10_y_2d(self.x10),
  201. interp10_y_2d_unsorted(self.x10))
  202. def test_linear(self):
  203. for kind in ['linear', 'slinear']:
  204. self._check_linear(kind)
  205. def _check_linear(self, kind):
  206. # Check the actual implementation of linear interpolation.
  207. interp10 = interp1d(self.x10, self.y10, kind=kind)
  208. assert_array_almost_equal(interp10(self.x10), self.y10)
  209. assert_array_almost_equal(interp10(1.2), np.array([1.2]))
  210. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  211. np.array([2.4, 5.6, 6.0]))
  212. # test fill_value="extrapolate"
  213. extrapolator = interp1d(self.x10, self.y10, kind=kind,
  214. fill_value='extrapolate')
  215. assert_allclose(extrapolator([-1., 0, 9, 11]),
  216. [-1, 0, 9, 11], rtol=1e-14)
  217. opts = dict(kind=kind,
  218. fill_value='extrapolate',
  219. bounds_error=True)
  220. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  221. def test_linear_dtypes(self):
  222. # regression test for gh-5898, where 1D linear interpolation has been
  223. # delegated to numpy.interp for all float dtypes, and the latter was
  224. # not handling e.g. np.float128.
  225. for dtyp in np.sctypes["float"]:
  226. x = np.arange(8, dtype=dtyp)
  227. y = x
  228. yp = interp1d(x, y, kind='linear')(x)
  229. assert_equal(yp.dtype, dtyp)
  230. assert_allclose(yp, y, atol=1e-15)
  231. def test_slinear_dtypes(self):
  232. # regression test for gh-7273: 1D slinear interpolation fails with
  233. # float32 inputs
  234. dt_r = [np.float16, np.float32, np.float64]
  235. dt_rc = dt_r + [np.complex64, np.complex128]
  236. spline_kinds = ['slinear', 'zero', 'quadratic', 'cubic']
  237. for dtx in dt_r:
  238. x = np.arange(0, 10, dtype=dtx)
  239. for dty in dt_rc:
  240. y = np.exp(-x/3.0).astype(dty)
  241. for dtn in dt_r:
  242. xnew = x.astype(dtn)
  243. for kind in spline_kinds:
  244. f = interp1d(x, y, kind=kind, bounds_error=False)
  245. assert_allclose(f(xnew), y, atol=1e-7,
  246. err_msg="%s, %s %s" % (dtx, dty, dtn))
  247. def test_cubic(self):
  248. # Check the actual implementation of spline interpolation.
  249. interp10 = interp1d(self.x10, self.y10, kind='cubic')
  250. assert_array_almost_equal(interp10(self.x10), self.y10)
  251. assert_array_almost_equal(interp10(1.2), np.array([1.2]))
  252. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  253. np.array([2.4, 5.6, 6.0]),)
  254. def test_nearest(self):
  255. # Check the actual implementation of nearest-neighbour interpolation.
  256. interp10 = interp1d(self.x10, self.y10, kind='nearest')
  257. assert_array_almost_equal(interp10(self.x10), self.y10)
  258. assert_array_almost_equal(interp10(1.2), np.array(1.))
  259. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  260. np.array([2., 6., 6.]),)
  261. # test fill_value="extrapolate"
  262. extrapolator = interp1d(self.x10, self.y10, kind='nearest',
  263. fill_value='extrapolate')
  264. assert_allclose(extrapolator([-1., 0, 9, 11]),
  265. [0, 0, 9, 9], rtol=1e-14)
  266. opts = dict(kind='nearest',
  267. fill_value='extrapolate',
  268. bounds_error=True)
  269. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  270. def test_previous(self):
  271. # Check the actual implementation of previous interpolation.
  272. interp10 = interp1d(self.x10, self.y10, kind='previous')
  273. assert_array_almost_equal(interp10(self.x10), self.y10)
  274. assert_array_almost_equal(interp10(1.2), np.array(1.))
  275. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  276. np.array([2., 5., 6.]),)
  277. # test fill_value="extrapolate"
  278. extrapolator = interp1d(self.x10, self.y10, kind='previous',
  279. fill_value='extrapolate')
  280. assert_allclose(extrapolator([-1., 0, 9, 11]),
  281. [0, 0, 9, 9], rtol=1e-14)
  282. opts = dict(kind='previous',
  283. fill_value='extrapolate',
  284. bounds_error=True)
  285. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  286. def test_next(self):
  287. # Check the actual implementation of next interpolation.
  288. interp10 = interp1d(self.x10, self.y10, kind='next')
  289. assert_array_almost_equal(interp10(self.x10), self.y10)
  290. assert_array_almost_equal(interp10(1.2), np.array(2.))
  291. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  292. np.array([3., 6., 6.]),)
  293. # test fill_value="extrapolate"
  294. extrapolator = interp1d(self.x10, self.y10, kind='next',
  295. fill_value='extrapolate')
  296. assert_allclose(extrapolator([-1., 0, 9, 11]),
  297. [0, 0, 9, 9], rtol=1e-14)
  298. opts = dict(kind='next',
  299. fill_value='extrapolate',
  300. bounds_error=True)
  301. assert_raises(ValueError, interp1d, self.x10, self.y10, **opts)
  302. def test_zero(self):
  303. # Check the actual implementation of zero-order spline interpolation.
  304. interp10 = interp1d(self.x10, self.y10, kind='zero')
  305. assert_array_almost_equal(interp10(self.x10), self.y10)
  306. assert_array_almost_equal(interp10(1.2), np.array(1.))
  307. assert_array_almost_equal(interp10([2.4, 5.6, 6.0]),
  308. np.array([2., 5., 6.]))
  309. def _bounds_check(self, kind='linear'):
  310. # Test that our handling of out-of-bounds input is correct.
  311. extrap10 = interp1d(self.x10, self.y10, fill_value=self.fill_value,
  312. bounds_error=False, kind=kind)
  313. assert_array_equal(extrap10(11.2), np.array(self.fill_value))
  314. assert_array_equal(extrap10(-3.4), np.array(self.fill_value))
  315. assert_array_equal(extrap10([[[11.2], [-3.4], [12.6], [19.3]]]),
  316. np.array(self.fill_value),)
  317. assert_array_equal(extrap10._check_bounds(
  318. np.array([-1.0, 0.0, 5.0, 9.0, 11.0])),
  319. np.array([[True, False, False, False, False],
  320. [False, False, False, False, True]]))
  321. raises_bounds_error = interp1d(self.x10, self.y10, bounds_error=True,
  322. kind=kind)
  323. assert_raises(ValueError, raises_bounds_error, -1.0)
  324. assert_raises(ValueError, raises_bounds_error, 11.0)
  325. raises_bounds_error([0.0, 5.0, 9.0])
  326. def _bounds_check_int_nan_fill(self, kind='linear'):
  327. x = np.arange(10).astype(np.int_)
  328. y = np.arange(10).astype(np.int_)
  329. c = interp1d(x, y, kind=kind, fill_value=np.nan, bounds_error=False)
  330. yi = c(x - 1)
  331. assert_(np.isnan(yi[0]))
  332. assert_array_almost_equal(yi, np.r_[np.nan, y[:-1]])
  333. def test_bounds(self):
  334. for kind in ('linear', 'cubic', 'nearest', 'previous', 'next',
  335. 'slinear', 'zero', 'quadratic'):
  336. self._bounds_check(kind)
  337. self._bounds_check_int_nan_fill(kind)
  338. def _check_fill_value(self, kind):
  339. interp = interp1d(self.x10, self.y10, kind=kind,
  340. fill_value=(-100, 100), bounds_error=False)
  341. assert_array_almost_equal(interp(10), 100)
  342. assert_array_almost_equal(interp(-10), -100)
  343. assert_array_almost_equal(interp([-10, 10]), [-100, 100])
  344. # Proper broadcasting:
  345. # interp along axis of length 5
  346. # other dim=(2, 3), (3, 2), (2, 2), or (2,)
  347. # one singleton fill_value (works for all)
  348. for y in (self.y235, self.y325, self.y225, self.y25):
  349. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  350. fill_value=100, bounds_error=False)
  351. assert_array_almost_equal(interp(10), 100)
  352. assert_array_almost_equal(interp(-10), 100)
  353. assert_array_almost_equal(interp([-10, 10]), 100)
  354. # singleton lower, singleton upper
  355. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  356. fill_value=(-100, 100), bounds_error=False)
  357. assert_array_almost_equal(interp(10), 100)
  358. assert_array_almost_equal(interp(-10), -100)
  359. if y.ndim == 3:
  360. result = [[[-100, 100]] * y.shape[1]] * y.shape[0]
  361. else:
  362. result = [[-100, 100]] * y.shape[0]
  363. assert_array_almost_equal(interp([-10, 10]), result)
  364. # one broadcastable (3,) fill_value
  365. fill_value = [100, 200, 300]
  366. for y in (self.y325, self.y225):
  367. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  368. axis=-1, fill_value=fill_value, bounds_error=False)
  369. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  370. fill_value=fill_value, bounds_error=False)
  371. assert_array_almost_equal(interp(10), [[100, 200, 300]] * 2)
  372. assert_array_almost_equal(interp(-10), [[100, 200, 300]] * 2)
  373. assert_array_almost_equal(interp([-10, 10]), [[[100, 100],
  374. [200, 200],
  375. [300, 300]]] * 2)
  376. # one broadcastable (2,) fill_value
  377. fill_value = [100, 200]
  378. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  379. axis=-1, fill_value=fill_value, bounds_error=False)
  380. for y in (self.y225, self.y325, self.y25):
  381. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  382. fill_value=fill_value, bounds_error=False)
  383. result = [100, 200]
  384. if y.ndim == 3:
  385. result = [result] * y.shape[0]
  386. assert_array_almost_equal(interp(10), result)
  387. assert_array_almost_equal(interp(-10), result)
  388. result = [[100, 100], [200, 200]]
  389. if y.ndim == 3:
  390. result = [result] * y.shape[0]
  391. assert_array_almost_equal(interp([-10, 10]), result)
  392. # broadcastable (3,) lower, singleton upper
  393. fill_value = (np.array([-100, -200, -300]), 100)
  394. for y in (self.y325, self.y225):
  395. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  396. axis=-1, fill_value=fill_value, bounds_error=False)
  397. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  398. fill_value=fill_value, bounds_error=False)
  399. assert_array_almost_equal(interp(10), 100)
  400. assert_array_almost_equal(interp(-10), [[-100, -200, -300]] * 2)
  401. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  402. [-200, 100],
  403. [-300, 100]]] * 2)
  404. # broadcastable (2,) lower, singleton upper
  405. fill_value = (np.array([-100, -200]), 100)
  406. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  407. axis=-1, fill_value=fill_value, bounds_error=False)
  408. for y in (self.y225, self.y325, self.y25):
  409. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  410. fill_value=fill_value, bounds_error=False)
  411. assert_array_almost_equal(interp(10), 100)
  412. result = [-100, -200]
  413. if y.ndim == 3:
  414. result = [result] * y.shape[0]
  415. assert_array_almost_equal(interp(-10), result)
  416. result = [[-100, 100], [-200, 100]]
  417. if y.ndim == 3:
  418. result = [result] * y.shape[0]
  419. assert_array_almost_equal(interp([-10, 10]), result)
  420. # broadcastable (3,) lower, broadcastable (3,) upper
  421. fill_value = ([-100, -200, -300], [100, 200, 300])
  422. for y in (self.y325, self.y225):
  423. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  424. axis=-1, fill_value=fill_value, bounds_error=False)
  425. for ii in range(2): # check ndarray as well as list here
  426. if ii == 1:
  427. fill_value = tuple(np.array(f) for f in fill_value)
  428. interp = interp1d(self.x5, self.y235, kind=kind, axis=-1,
  429. fill_value=fill_value, bounds_error=False)
  430. assert_array_almost_equal(interp(10), [[100, 200, 300]] * 2)
  431. assert_array_almost_equal(interp(-10), [[-100, -200, -300]] * 2)
  432. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  433. [-200, 200],
  434. [-300, 300]]] * 2)
  435. # broadcastable (2,) lower, broadcastable (2,) upper
  436. fill_value = ([-100, -200], [100, 200])
  437. assert_raises(ValueError, interp1d, self.x5, self.y235, kind=kind,
  438. axis=-1, fill_value=fill_value, bounds_error=False)
  439. for y in (self.y325, self.y225, self.y25):
  440. interp = interp1d(self.x5, y, kind=kind, axis=-1,
  441. fill_value=fill_value, bounds_error=False)
  442. result = [100, 200]
  443. if y.ndim == 3:
  444. result = [result] * y.shape[0]
  445. assert_array_almost_equal(interp(10), result)
  446. result = [-100, -200]
  447. if y.ndim == 3:
  448. result = [result] * y.shape[0]
  449. assert_array_almost_equal(interp(-10), result)
  450. result = [[-100, 100], [-200, 200]]
  451. if y.ndim == 3:
  452. result = [result] * y.shape[0]
  453. assert_array_almost_equal(interp([-10, 10]), result)
  454. # one broadcastable (2, 2) array-like
  455. fill_value = [[100, 200], [1000, 2000]]
  456. for y in (self.y235, self.y325, self.y25):
  457. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  458. axis=-1, fill_value=fill_value, bounds_error=False)
  459. for ii in range(2):
  460. if ii == 1:
  461. fill_value = np.array(fill_value)
  462. interp = interp1d(self.x5, self.y225, kind=kind, axis=-1,
  463. fill_value=fill_value, bounds_error=False)
  464. assert_array_almost_equal(interp(10), [[100, 200], [1000, 2000]])
  465. assert_array_almost_equal(interp(-10), [[100, 200], [1000, 2000]])
  466. assert_array_almost_equal(interp([-10, 10]), [[[100, 100],
  467. [200, 200]],
  468. [[1000, 1000],
  469. [2000, 2000]]])
  470. # broadcastable (2, 2) lower, broadcastable (2, 2) upper
  471. fill_value = ([[-100, -200], [-1000, -2000]],
  472. [[100, 200], [1000, 2000]])
  473. for y in (self.y235, self.y325, self.y25):
  474. assert_raises(ValueError, interp1d, self.x5, y, kind=kind,
  475. axis=-1, fill_value=fill_value, bounds_error=False)
  476. for ii in range(2):
  477. if ii == 1:
  478. fill_value = (np.array(fill_value[0]), np.array(fill_value[1]))
  479. interp = interp1d(self.x5, self.y225, kind=kind, axis=-1,
  480. fill_value=fill_value, bounds_error=False)
  481. assert_array_almost_equal(interp(10), [[100, 200], [1000, 2000]])
  482. assert_array_almost_equal(interp(-10), [[-100, -200],
  483. [-1000, -2000]])
  484. assert_array_almost_equal(interp([-10, 10]), [[[-100, 100],
  485. [-200, 200]],
  486. [[-1000, 1000],
  487. [-2000, 2000]]])
  488. def test_fill_value(self):
  489. # test that two-element fill value works
  490. for kind in ('linear', 'nearest', 'cubic', 'slinear', 'quadratic',
  491. 'zero', 'previous', 'next'):
  492. self._check_fill_value(kind)
  493. def test_fill_value_writeable(self):
  494. # backwards compat: fill_value is a public writeable attribute
  495. interp = interp1d(self.x10, self.y10, fill_value=123.0)
  496. assert_equal(interp.fill_value, 123.0)
  497. interp.fill_value = 321.0
  498. assert_equal(interp.fill_value, 321.0)
  499. def _nd_check_interp(self, kind='linear'):
  500. # Check the behavior when the inputs and outputs are multidimensional.
  501. # Multidimensional input.
  502. interp10 = interp1d(self.x10, self.y10, kind=kind)
  503. assert_array_almost_equal(interp10(np.array([[3., 5.], [2., 7.]])),
  504. np.array([[3., 5.], [2., 7.]]))
  505. # Scalar input -> 0-dim scalar array output
  506. assert_(isinstance(interp10(1.2), np.ndarray))
  507. assert_equal(interp10(1.2).shape, ())
  508. # Multidimensional outputs.
  509. interp210 = interp1d(self.x10, self.y210, kind=kind)
  510. assert_array_almost_equal(interp210(1.), np.array([1., 11.]))
  511. assert_array_almost_equal(interp210(np.array([1., 2.])),
  512. np.array([[1., 2.], [11., 12.]]))
  513. interp102 = interp1d(self.x10, self.y102, axis=0, kind=kind)
  514. assert_array_almost_equal(interp102(1.), np.array([2.0, 3.0]))
  515. assert_array_almost_equal(interp102(np.array([1., 3.])),
  516. np.array([[2., 3.], [6., 7.]]))
  517. # Both at the same time!
  518. x_new = np.array([[3., 5.], [2., 7.]])
  519. assert_array_almost_equal(interp210(x_new),
  520. np.array([[[3., 5.], [2., 7.]],
  521. [[13., 15.], [12., 17.]]]))
  522. assert_array_almost_equal(interp102(x_new),
  523. np.array([[[6., 7.], [10., 11.]],
  524. [[4., 5.], [14., 15.]]]))
  525. def _nd_check_shape(self, kind='linear'):
  526. # Check large ndim output shape
  527. a = [4, 5, 6, 7]
  528. y = np.arange(np.prod(a)).reshape(*a)
  529. for n, s in enumerate(a):
  530. x = np.arange(s)
  531. z = interp1d(x, y, axis=n, kind=kind)
  532. assert_array_almost_equal(z(x), y, err_msg=kind)
  533. x2 = np.arange(2*3*1).reshape((2,3,1)) / 12.
  534. b = list(a)
  535. b[n:n+1] = [2,3,1]
  536. assert_array_almost_equal(z(x2).shape, b, err_msg=kind)
  537. def test_nd(self):
  538. for kind in ('linear', 'cubic', 'slinear', 'quadratic', 'nearest',
  539. 'zero', 'previous', 'next'):
  540. self._nd_check_interp(kind)
  541. self._nd_check_shape(kind)
  542. def _check_complex(self, dtype=np.complex_, kind='linear'):
  543. x = np.array([1, 2.5, 3, 3.1, 4, 6.4, 7.9, 8.0, 9.5, 10])
  544. y = x * x ** (1 + 2j)
  545. y = y.astype(dtype)
  546. # simple test
  547. c = interp1d(x, y, kind=kind)
  548. assert_array_almost_equal(y[:-1], c(x)[:-1])
  549. # check against interpolating real+imag separately
  550. xi = np.linspace(1, 10, 31)
  551. cr = interp1d(x, y.real, kind=kind)
  552. ci = interp1d(x, y.imag, kind=kind)
  553. assert_array_almost_equal(c(xi).real, cr(xi))
  554. assert_array_almost_equal(c(xi).imag, ci(xi))
  555. def test_complex(self):
  556. for kind in ('linear', 'nearest', 'cubic', 'slinear', 'quadratic',
  557. 'zero', 'previous', 'next'):
  558. self._check_complex(np.complex64, kind)
  559. self._check_complex(np.complex128, kind)
  560. @pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
  561. def test_circular_refs(self):
  562. # Test interp1d can be automatically garbage collected
  563. x = np.linspace(0, 1)
  564. y = np.linspace(0, 1)
  565. # Confirm interp can be released from memory after use
  566. with assert_deallocated(interp1d, x, y) as interp:
  567. new_y = interp([0.1, 0.2])
  568. del interp
  569. def test_overflow_nearest(self):
  570. # Test that the x range doesn't overflow when given integers as input
  571. for kind in ('nearest', 'previous', 'next'):
  572. x = np.array([0, 50, 127], dtype=np.int8)
  573. ii = interp1d(x, x, kind=kind)
  574. assert_array_almost_equal(ii(x), x)
  575. def test_local_nans(self):
  576. # check that for local interpolation kinds (slinear, zero) a single nan
  577. # only affects its local neighborhood
  578. x = np.arange(10).astype(float)
  579. y = x.copy()
  580. y[6] = np.nan
  581. for kind in ('zero', 'slinear'):
  582. ir = interp1d(x, y, kind=kind)
  583. vals = ir([4.9, 7.0])
  584. assert_(np.isfinite(vals).all())
  585. def test_spline_nans(self):
  586. # Backwards compat: a single nan makes the whole spline interpolation
  587. # return nans in an array of the correct shape. And it doesn't raise,
  588. # just quiet nans because of backcompat.
  589. x = np.arange(8).astype(float)
  590. y = x.copy()
  591. yn = y.copy()
  592. yn[3] = np.nan
  593. for kind in ['quadratic', 'cubic']:
  594. ir = interp1d(x, y, kind=kind)
  595. irn = interp1d(x, yn, kind=kind)
  596. for xnew in (6, [1, 6], [[1, 6], [3, 5]]):
  597. xnew = np.asarray(xnew)
  598. out, outn = ir(x), irn(x)
  599. assert_(np.isnan(outn).all())
  600. assert_equal(out.shape, outn.shape)
  601. def test_read_only(self):
  602. x = np.arange(0, 10)
  603. y = np.exp(-x / 3.0)
  604. xnew = np.arange(0, 9, 0.1)
  605. # Check both read-only and not read-only:
  606. for writeable in (True, False):
  607. xnew.flags.writeable = writeable
  608. for kind in ('linear', 'nearest', 'zero', 'slinear', 'quadratic',
  609. 'cubic'):
  610. f = interp1d(x, y, kind=kind)
  611. vals = f(xnew)
  612. assert_(np.isfinite(vals).all())
  613. class TestLagrange(object):
  614. def test_lagrange(self):
  615. p = poly1d([5,2,1,4,3])
  616. xs = np.arange(len(p.coeffs))
  617. ys = p(xs)
  618. pl = lagrange(xs,ys)
  619. assert_array_almost_equal(p.coeffs,pl.coeffs)
  620. class TestAkima1DInterpolator(object):
  621. def test_eval(self):
  622. x = np.arange(0., 11.)
  623. y = np.array([0., 2., 1., 3., 2., 6., 5.5, 5.5, 2.7, 5.1, 3.])
  624. ak = Akima1DInterpolator(x, y)
  625. xi = np.array([0., 0.5, 1., 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2,
  626. 8.6, 9.9, 10.])
  627. yi = np.array([0., 1.375, 2., 1.5, 1.953125, 2.484375,
  628. 4.1363636363636366866103344, 5.9803623910336236590978842,
  629. 5.5067291516462386624652936, 5.2031367459745245795943447,
  630. 4.1796554159017080820603951, 3.4110386597938129327189927,
  631. 3.])
  632. assert_allclose(ak(xi), yi)
  633. def test_eval_2d(self):
  634. x = np.arange(0., 11.)
  635. y = np.array([0., 2., 1., 3., 2., 6., 5.5, 5.5, 2.7, 5.1, 3.])
  636. y = np.column_stack((y, 2. * y))
  637. ak = Akima1DInterpolator(x, y)
  638. xi = np.array([0., 0.5, 1., 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2,
  639. 8.6, 9.9, 10.])
  640. yi = np.array([0., 1.375, 2., 1.5, 1.953125, 2.484375,
  641. 4.1363636363636366866103344,
  642. 5.9803623910336236590978842,
  643. 5.5067291516462386624652936,
  644. 5.2031367459745245795943447,
  645. 4.1796554159017080820603951,
  646. 3.4110386597938129327189927, 3.])
  647. yi = np.column_stack((yi, 2. * yi))
  648. assert_allclose(ak(xi), yi)
  649. def test_eval_3d(self):
  650. x = np.arange(0., 11.)
  651. y_ = np.array([0., 2., 1., 3., 2., 6., 5.5, 5.5, 2.7, 5.1, 3.])
  652. y = np.empty((11, 2, 2))
  653. y[:, 0, 0] = y_
  654. y[:, 1, 0] = 2. * y_
  655. y[:, 0, 1] = 3. * y_
  656. y[:, 1, 1] = 4. * y_
  657. ak = Akima1DInterpolator(x, y)
  658. xi = np.array([0., 0.5, 1., 1.5, 2.5, 3.5, 4.5, 5.1, 6.5, 7.2,
  659. 8.6, 9.9, 10.])
  660. yi = np.empty((13, 2, 2))
  661. yi_ = np.array([0., 1.375, 2., 1.5, 1.953125, 2.484375,
  662. 4.1363636363636366866103344,
  663. 5.9803623910336236590978842,
  664. 5.5067291516462386624652936,
  665. 5.2031367459745245795943447,
  666. 4.1796554159017080820603951,
  667. 3.4110386597938129327189927, 3.])
  668. yi[:, 0, 0] = yi_
  669. yi[:, 1, 0] = 2. * yi_
  670. yi[:, 0, 1] = 3. * yi_
  671. yi[:, 1, 1] = 4. * yi_
  672. assert_allclose(ak(xi), yi)
  673. def test_degenerate_case_multidimensional(self):
  674. # This test is for issue #5683.
  675. x = np.array([0, 1, 2])
  676. y = np.vstack((x, x**2)).T
  677. ak = Akima1DInterpolator(x, y)
  678. x_eval = np.array([0.5, 1.5])
  679. y_eval = ak(x_eval)
  680. assert_allclose(y_eval, np.vstack((x_eval, x_eval**2)).T)
  681. def test_extend(self):
  682. x = np.arange(0., 11.)
  683. y = np.array([0., 2., 1., 3., 2., 6., 5.5, 5.5, 2.7, 5.1, 3.])
  684. ak = Akima1DInterpolator(x, y)
  685. match = "Extending a 1D Akima interpolator is not yet implemented"
  686. with pytest.raises(NotImplementedError, match=match):
  687. ak.extend(None, None)
  688. class TestPPolyCommon(object):
  689. # test basic functionality for PPoly and BPoly
  690. def test_sort_check(self):
  691. c = np.array([[1, 4], [2, 5], [3, 6]])
  692. x = np.array([0, 1, 0.5])
  693. assert_raises(ValueError, PPoly, c, x)
  694. assert_raises(ValueError, BPoly, c, x)
  695. def test_ctor_c(self):
  696. # wrong shape: `c` must be at least 2-dimensional
  697. with assert_raises(ValueError):
  698. PPoly([1, 2], [0, 1])
  699. def test_extend(self):
  700. # Test adding new points to the piecewise polynomial
  701. np.random.seed(1234)
  702. order = 3
  703. x = np.unique(np.r_[0, 10 * np.random.rand(30), 10])
  704. c = 2*np.random.rand(order+1, len(x)-1, 2, 3) - 1
  705. for cls in (PPoly, BPoly):
  706. pp = cls(c[:,:9], x[:10])
  707. pp.extend(c[:,9:], x[10:])
  708. pp2 = cls(c[:, 10:], x[10:])
  709. pp2.extend(c[:, :10], x[:10])
  710. pp3 = cls(c, x)
  711. assert_array_equal(pp.c, pp3.c)
  712. assert_array_equal(pp.x, pp3.x)
  713. assert_array_equal(pp2.c, pp3.c)
  714. assert_array_equal(pp2.x, pp3.x)
  715. def test_extend_diff_orders(self):
  716. # Test extending polynomial with different order one
  717. np.random.seed(1234)
  718. x = np.linspace(0, 1, 6)
  719. c = np.random.rand(2, 5)
  720. x2 = np.linspace(1, 2, 6)
  721. c2 = np.random.rand(4, 5)
  722. for cls in (PPoly, BPoly):
  723. pp1 = cls(c, x)
  724. pp2 = cls(c2, x2)
  725. pp_comb = cls(c, x)
  726. pp_comb.extend(c2, x2[1:])
  727. # NB. doesn't match to pp1 at the endpoint, because pp1 is not
  728. # continuous with pp2 as we took random coefs.
  729. xi1 = np.linspace(0, 1, 300, endpoint=False)
  730. xi2 = np.linspace(1, 2, 300)
  731. assert_allclose(pp1(xi1), pp_comb(xi1))
  732. assert_allclose(pp2(xi2), pp_comb(xi2))
  733. def test_extend_descending(self):
  734. np.random.seed(0)
  735. order = 3
  736. x = np.sort(np.random.uniform(0, 10, 20))
  737. c = np.random.rand(order + 1, x.shape[0] - 1, 2, 3)
  738. for cls in (PPoly, BPoly):
  739. p = cls(c, x)
  740. p1 = cls(c[:, :9], x[:10])
  741. p1.extend(c[:, 9:], x[10:])
  742. p2 = cls(c[:, 10:], x[10:])
  743. p2.extend(c[:, :10], x[:10])
  744. assert_array_equal(p1.c, p.c)
  745. assert_array_equal(p1.x, p.x)
  746. assert_array_equal(p2.c, p.c)
  747. assert_array_equal(p2.x, p.x)
  748. def test_shape(self):
  749. np.random.seed(1234)
  750. c = np.random.rand(8, 12, 5, 6, 7)
  751. x = np.sort(np.random.rand(13))
  752. xp = np.random.rand(3, 4)
  753. for cls in (PPoly, BPoly):
  754. p = cls(c, x)
  755. assert_equal(p(xp).shape, (3, 4, 5, 6, 7))
  756. # 'scalars'
  757. for cls in (PPoly, BPoly):
  758. p = cls(c[..., 0, 0, 0], x)
  759. assert_equal(np.shape(p(0.5)), ())
  760. assert_equal(np.shape(p(np.array(0.5))), ())
  761. # can't use dtype=object (with any numpy; what fails is
  762. # constructing the object array here for old numpy)
  763. assert_raises(ValueError, p, np.array([[0.1, 0.2], [0.4]]))
  764. def test_complex_coef(self):
  765. np.random.seed(12345)
  766. x = np.sort(np.random.random(13))
  767. c = np.random.random((8, 12)) * (1. + 0.3j)
  768. c_re, c_im = c.real, c.imag
  769. xp = np.random.random(5)
  770. for cls in (PPoly, BPoly):
  771. p, p_re, p_im = cls(c, x), cls(c_re, x), cls(c_im, x)
  772. for nu in [0, 1, 2]:
  773. assert_allclose(p(xp, nu).real, p_re(xp, nu))
  774. assert_allclose(p(xp, nu).imag, p_im(xp, nu))
  775. def test_axis(self):
  776. np.random.seed(12345)
  777. c = np.random.rand(3, 4, 5, 6, 7, 8)
  778. c_s = c.shape
  779. xp = np.random.random((1, 2))
  780. for axis in (0, 1, 2, 3):
  781. k, m = c.shape[axis], c.shape[axis+1]
  782. x = np.sort(np.random.rand(m+1))
  783. for cls in (PPoly, BPoly):
  784. p = cls(c, x, axis=axis)
  785. assert_equal(p.c.shape,
  786. c_s[axis:axis+2] + c_s[:axis] + c_s[axis+2:])
  787. res = p(xp)
  788. targ_shape = c_s[:axis] + xp.shape + c_s[2+axis:]
  789. assert_equal(res.shape, targ_shape)
  790. # deriv/antideriv does not drop the axis
  791. for p1 in [cls(c, x, axis=axis).derivative(),
  792. cls(c, x, axis=axis).derivative(2),
  793. cls(c, x, axis=axis).antiderivative(),
  794. cls(c, x, axis=axis).antiderivative(2)]:
  795. assert_equal(p1.axis, p.axis)
  796. # c array needs two axes for the coefficients and intervals, so
  797. # 0 <= axis < c.ndim-1; raise otherwise
  798. for axis in (-1, 4, 5, 6):
  799. for cls in (BPoly, PPoly):
  800. assert_raises(ValueError, cls, **dict(c=c, x=x, axis=axis))
  801. class TestPolySubclassing(object):
  802. class P(PPoly):
  803. pass
  804. class B(BPoly):
  805. pass
  806. def _make_polynomials(self):
  807. np.random.seed(1234)
  808. x = np.sort(np.random.random(3))
  809. c = np.random.random((4, 2))
  810. return self.P(c, x), self.B(c, x)
  811. def test_derivative(self):
  812. pp, bp = self._make_polynomials()
  813. for p in (pp, bp):
  814. pd = p.derivative()
  815. assert_equal(p.__class__, pd.__class__)
  816. ppa = pp.antiderivative()
  817. assert_equal(pp.__class__, ppa.__class__)
  818. def test_from_spline(self):
  819. np.random.seed(1234)
  820. x = np.sort(np.r_[0, np.random.rand(11), 1])
  821. y = np.random.rand(len(x))
  822. spl = splrep(x, y, s=0)
  823. pp = self.P.from_spline(spl)
  824. assert_equal(pp.__class__, self.P)
  825. def test_conversions(self):
  826. pp, bp = self._make_polynomials()
  827. pp1 = self.P.from_bernstein_basis(bp)
  828. assert_equal(pp1.__class__, self.P)
  829. bp1 = self.B.from_power_basis(pp)
  830. assert_equal(bp1.__class__, self.B)
  831. def test_from_derivatives(self):
  832. x = [0, 1, 2]
  833. y = [[1], [2], [3]]
  834. bp = self.B.from_derivatives(x, y)
  835. assert_equal(bp.__class__, self.B)
class TestPPoly(object):
    """Tests for PPoly: piecewise polynomials in the local power basis.

    Column ``c[:, i]`` holds the coefficients of the piece on
    ``[x[i], x[i+1]]``, highest power first, in the local variable
    ``x - x[i]`` (this is what the expected values in ``test_simple``
    encode). Several tests depend on the exact order of seeded random
    draws, so statements should not be reordered.
    """

    def test_simple(self):
        # Two quadratic pieces: x**2 + 2x + 3 on [0, 0.5] and
        # 4t**2 + 5t + 6 with t = x - 0.5 on [0.5, 1].
        c = np.array([[1, 4], [2, 5], [3, 6]])
        x = np.array([0, 0.5, 1])
        p = PPoly(c, x)
        assert_allclose(p(0.3), 1*0.3**2 + 2*0.3 + 3)
        assert_allclose(p(0.7), 4*(0.7-0.5)**2 + 5*(0.7-0.5) + 6)

    def test_periodic(self):
        # Same pieces as test_simple; periodic extrapolation wraps with
        # period 1, so 1.3 -> 0.3 and -0.3 -> 0.7. The second argument of
        # p(...) is the derivative order.
        c = np.array([[1, 4], [2, 5], [3, 6]])
        x = np.array([0, 0.5, 1])
        p = PPoly(c, x, extrapolate='periodic')

        assert_allclose(p(1.3), 1 * 0.3 ** 2 + 2 * 0.3 + 3)
        assert_allclose(p(-0.3), 4 * (0.7 - 0.5) ** 2 + 5 * (0.7 - 0.5) + 6)

        assert_allclose(p(1.3, 1), 2 * 0.3 + 2)
        assert_allclose(p(-0.3, 1), 8 * (0.7 - 0.5) + 5)

    def test_descending(self):
        # A PPoly on descending breakpoints must agree with the equivalent
        # ascending one for values, derivatives, integrals and roots.
        def binom_matrix(power):
            # Binomial coefficients C(n, k) with both axes reversed, matching
            # the highest-power-first coefficient order.
            n = np.arange(power + 1).reshape(-1, 1)
            k = np.arange(power + 1)
            B = binom(n, k)
            return B[::-1, ::-1]

        np.random.seed(0)

        power = 3
        for m in [10, 20, 30]:
            x = np.sort(np.random.uniform(0, 10, m + 1))
            ca = np.random.uniform(-2, 2, size=(power + 1, m))

            # Re-expand each piece about its right endpoint x[i+1]. First
            # scale to t = (x - x[i]) / h, then shift t -> t - 1 with the
            # binomial matrix, then undo the scaling. With reversed
            # breakpoints, each piece's first breakpoint is that right
            # endpoint, hence the reversed columns below.
            h = np.diff(x)
            h_powers = h[None, :] ** np.arange(power + 1)[::-1, None]
            B = binom_matrix(power)
            cap = ca * h_powers
            cdp = np.dot(B.T, cap)
            cd = cdp / h_powers

            pa = PPoly(ca, x, extrapolate=True)
            pd = PPoly(cd[:, ::-1], x[::-1], extrapolate=True)

            x_test = np.random.uniform(-10, 20, 100)
            assert_allclose(pa(x_test), pd(x_test), rtol=1e-13)
            assert_allclose(pa(x_test, 1), pd(x_test, 1), rtol=1e-13)

            pa_d = pa.derivative()
            pd_d = pd.derivative()

            assert_allclose(pa_d(x_test), pd_d(x_test), rtol=1e-13)

            # Antiderivatives won't be equal because fixing continuity is
            # done in the reverse order, but surely the differences should be
            # equal.
            pa_i = pa.antiderivative()
            pd_i = pd.antiderivative()
            for a, b in np.random.uniform(-10, 20, (5, 2)):
                int_a = pa.integrate(a, b)
                int_d = pd.integrate(a, b)
                assert_allclose(int_a, int_d, rtol=1e-13)
                assert_allclose(pa_i(b) - pa_i(a), pd_i(b) - pd_i(a),
                                rtol=1e-13)

            # The descending polynomial's roots are sorted before comparing.
            roots_d = pd.roots()
            roots_a = pa.roots()
            assert_allclose(roots_a, np.sort(roots_d), rtol=1e-12)

    def test_multi_shape(self):
        # 6 coefficients (degree 5), 2 intervals, trailing value shape
        # (1, 2, 3). Derivative drops one coefficient row; antiderivative
        # adds one.
        c = np.random.rand(6, 2, 1, 2, 3)
        x = np.array([0, 0.5, 1])
        p = PPoly(c, x)
        assert_equal(p.x.shape, x.shape)
        assert_equal(p.c.shape, c.shape)
        assert_equal(p(0.3).shape, c.shape[2:])

        assert_equal(p(np.random.rand(5, 6)).shape, (5, 6) + c.shape[2:])

        dp = p.derivative()
        assert_equal(dp.c.shape, (5, 2, 1, 2, 3))
        ip = p.antiderivative()
        assert_equal(ip.c.shape, (7, 2, 1, 2, 3))

    def test_construct_fast(self):
        # construct_fast must give the same values as PPoly(c, x) for the
        # pieces in test_simple (c is passed as float here).
        # NOTE(review): the seed below is unused; no random draws follow.
        np.random.seed(1234)
        c = np.array([[1, 4], [2, 5], [3, 6]], dtype=float)
        x = np.array([0, 0.5, 1])
        p = PPoly.construct_fast(c, x)
        assert_allclose(p(0.3), 1*0.3**2 + 2*0.3 + 3)
        assert_allclose(p(0.7), 4*(0.7-0.5)**2 + 5*(0.7-0.5) + 6)

    def test_vs_alternative_implementations(self):
        # Compare against the helpers _ppoly_eval_1 / _ppoly_eval_2 defined
        # elsewhere in this file (presumably independent reference
        # evaluators -- see their definitions).
        np.random.seed(1234)
        c = np.random.rand(3, 12, 22)
        x = np.sort(np.r_[0, np.random.rand(11), 1])

        p = PPoly(c, x)

        xp = np.r_[0.3, 0.5, 0.33, 0.6]
        expected = _ppoly_eval_1(c, x, xp)
        assert_allclose(p(xp), expected)

        expected = _ppoly_eval_2(c[:,:,0], x, xp)
        assert_allclose(p(xp)[:,0], expected)

    def test_from_spline(self):
        # PPoly.from_spline must reproduce splev for a (t, c, k) tuple and
        # BSpline evaluation for a BSpline object.
        np.random.seed(1234)
        x = np.sort(np.r_[0, np.random.rand(11), 1])
        y = np.random.rand(len(x))

        spl = splrep(x, y, s=0)
        pp = PPoly.from_spline(spl)

        xi = np.linspace(0, 1, 200)
        assert_allclose(pp(xi), splev(xi, spl))

        # make sure .from_spline accepts BSpline objects
        b = BSpline(*spl)
        ppp = PPoly.from_spline(b)
        assert_allclose(ppp(xi), b(xi))

        # BSpline's extrapolate attribute propagates unless overridden
        t, c, k = spl
        for extrap in (None, True, False):
            b = BSpline(t, c, k, extrapolate=extrap)
            p = PPoly.from_spline(b)
            assert_equal(p.extrapolate, b.extrapolate)

    def test_derivative_simple(self):
        # 4x**3 + 3x**2 + 2x + 1 -> 12x**2 + 6x + 2 -> 24x + 6
        # NOTE(review): the seed below is unused; no random draws follow.
        np.random.seed(1234)

        c = np.array([[4, 3, 2, 1]]).T
        dc = np.array([[3*4, 2*3, 2]]).T
        ddc = np.array([[2*3*4, 1*2*3]]).T
        x = np.array([0, 1])

        pp = PPoly(c, x)
        dpp = PPoly(dc, x)
        ddpp = PPoly(ddc, x)

        assert_allclose(pp.derivative().c, dpp.c)
        assert_allclose(pp.derivative(2).c, ddpp.c)

    def test_derivative_eval(self):
        # p(x, nu) must match splev's derivatives for a cubic spline.
        np.random.seed(1234)
        x = np.sort(np.r_[0, np.random.rand(11), 1])
        y = np.random.rand(len(x))

        spl = splrep(x, y, s=0)
        pp = PPoly.from_spline(spl)

        xi = np.linspace(0, 1, 200)
        for dx in range(0, 3):
            assert_allclose(pp(xi, dx), splev(xi, spl, dx))

    def test_derivative(self):
        # Evaluating with nu=dx must equal evaluating .derivative(dx),
        # including orders beyond the spline degree (k=5).
        np.random.seed(1234)
        x = np.sort(np.r_[0, np.random.rand(11), 1])
        y = np.random.rand(len(x))

        spl = splrep(x, y, s=0, k=5)
        pp = PPoly.from_spline(spl)

        xi = np.linspace(0, 1, 200)
        for dx in range(0, 10):
            assert_allclose(pp(xi, dx), pp.derivative(dx)(xi),
                            err_msg="dx=%d" % (dx,))

    def test_antiderivative_of_constant(self):
        # https://github.com/scipy/scipy/issues/4216
        p = PPoly([[1.]], [0, 1])
        assert_equal(p.antiderivative().c, PPoly([[1], [0]], [0, 1]).c)
        assert_equal(p.antiderivative().x, PPoly([[1], [0]], [0, 1]).x)

    def test_antiderivative_regression_4355(self):
        # https://github.com/scipy/scipy/issues/4355
        # Piecewise constant 1 on [0, 1] and 0.5 on [1, 2]; the second
        # piece's constant term (1) makes the antiderivative continuous.
        p = PPoly([[1., 0.5]], [0, 1, 2])
        q = p.antiderivative()
        assert_equal(q.c, [[1, 0.5], [0, 1]])
        assert_equal(q.x, [0, 1, 2])
        assert_allclose(p.integrate(0, 2), 1.5)
        assert_allclose(q(2) - q(0), 1.5)

    def test_antiderivative_simple(self):
        # NOTE(review): the seed below is unused; no random draws follow.
        np.random.seed(1234)
        # [ p1(x) = 3*x**2 + 2*x + 1,
        #   p2(x) = 1.6875]
        c = np.array([[3, 2, 1], [0, 0, 1.6875]]).T
        # [ pp1(x) = x**3 + x**2 + x,
        #   pp2(x) = 1.6875*(x - 0.25) + pp1(0.25)]
        ic = np.array([[1, 1, 1, 0], [0, 0, 1.6875, 0.328125]]).T
        # [ ppp1(x) = (1/4)*x**4 + (1/3)*x**3 + (1/2)*x**2,
        #   ppp2(x) = (1.6875/2)*(x - 0.25)**2 + pp1(0.25)*x + ppp1(0.25)]
        iic = np.array([[1/4, 1/3, 1/2, 0, 0],
                        [0, 0, 1.6875/2, 0.328125, 0.037434895833333336]]).T
        x = np.array([0, 0.25, 1])

        pp = PPoly(c, x)
        ipp = pp.antiderivative()
        iipp = pp.antiderivative(2)
        iipp2 = ipp.antiderivative()

        assert_allclose(ipp.x, x)
        assert_allclose(ipp.c.T, ic.T)
        assert_allclose(iipp.c.T, iic.T)
        assert_allclose(iipp2.c.T, iic.T)

    def test_antiderivative_vs_derivative(self):
        # derivative(dx) must undo antiderivative(dx), and every
        # intermediate derivative of the antiderivative must be continuous
        # at the breakpoints.
        np.random.seed(1234)
        x = np.linspace(0, 1, 30)**2
        y = np.random.rand(len(x))
        spl = splrep(x, y, s=0, k=5)
        pp = PPoly.from_spline(spl)

        for dx in range(0, 10):
            ipp = pp.antiderivative(dx)

            # check that derivative is inverse op
            pp2 = ipp.derivative(dx)
            assert_allclose(pp.c, pp2.c)

            # check continuity: compare the value at each right breakpoint
            # with the value a relative 1e-13 of the interval to its left
            for k in range(dx):
                pp2 = ipp.derivative(k)

                r = 1e-13
                endpoint = r*pp2.x[:-1] + (1 - r)*pp2.x[1:]

                assert_allclose(pp2(pp2.x[1:]), pp2(endpoint),
                                rtol=1e-7, err_msg="dx=%d k=%d" % (dx, k))

    def test_antiderivative_vs_spline(self):
        # PPoly.antiderivative must match FITPACK's splantider.
        np.random.seed(1234)
        x = np.sort(np.r_[0, np.random.rand(11), 1])
        y = np.random.rand(len(x))

        spl = splrep(x, y, s=0, k=5)
        pp = PPoly.from_spline(spl)

        for dx in range(0, 10):
            pp2 = pp.antiderivative(dx)
            spl2 = splantider(spl, dx)

            xi = np.linspace(0, 1, 200)
            assert_allclose(pp2(xi), splev(xi, spl2),
                            rtol=1e-7)

    def test_antiderivative_continuity(self):
        c = np.array([[2, 1, 2, 2], [2, 1, 3, 3]]).T
        x = np.array([0, 0.5, 1])

        p = PPoly(c, x)
        ip = p.antiderivative()

        # check continuity
        assert_allclose(ip(0.5 - 1e-9), ip(0.5 + 1e-9), rtol=1e-8)

        # check that only lowest order coefficients were changed
        p2 = ip.derivative()
        assert_allclose(p2.c, p.c)

    def test_integrate(self):
        # integrate(a, b) must match antiderivative differences and splint.
        # With a < x[0], extrapolate=False yields NaN.
        np.random.seed(1234)
        x = np.sort(np.r_[0, np.random.rand(11), 1])
        y = np.random.rand(len(x))

        spl = splrep(x, y, s=0, k=5)
        pp = PPoly.from_spline(spl)

        a, b = 0.3, 0.9
        ig = pp.integrate(a, b)

        ipp = pp.antiderivative()
        assert_allclose(ig, ipp(b) - ipp(a))
        assert_allclose(ig, splint(a, b, spl))

        a, b = -0.3, 0.9
        ig = pp.integrate(a, b, extrapolate=True)
        assert_allclose(ig, ipp(b) - ipp(a))

        assert_(np.isnan(pp.integrate(a, b, extrapolate=False)).all())

    def test_integrate_periodic(self):
        # Base period is [1, 4] (length 3). The expected values are
        # assembled from the antiderivative I on the base period, plus
        # whole-period multiples; reversed bounds give negated integrals.
        x = np.array([1, 2, 4])
        c = np.array([[0., 0.], [-1., -1.], [2., -0.], [1., 2.]])

        P = PPoly(c, x, extrapolate='periodic')
        I = P.antiderivative()

        period_int = I(4) - I(1)

        assert_allclose(P.integrate(1, 4), period_int)
        assert_allclose(P.integrate(-10, -7), period_int)
        assert_allclose(P.integrate(-10, -4), 2 * period_int)

        assert_allclose(P.integrate(1.5, 2.5), I(2.5) - I(1.5))
        assert_allclose(P.integrate(3.5, 5), I(2) - I(1) + I(4) - I(3.5))
        assert_allclose(P.integrate(3.5 + 12, 5 + 12),
                        I(2) - I(1) + I(4) - I(3.5))
        assert_allclose(P.integrate(3.5, 5 + 12),
                        I(2) - I(1) + I(4) - I(3.5) + 4 * period_int)

        assert_allclose(P.integrate(0, -1), I(2) - I(3))
        assert_allclose(P.integrate(-9, -10), I(2) - I(3))
        assert_allclose(P.integrate(0, -10), I(2) - I(3) - 3 * period_int)

    def test_roots(self):
        # Roots inside [0, 1] (with 1e-15 slack) must match FITPACK sproot.
        x = np.linspace(0, 1, 31)**2
        y = np.sin(30*x)

        spl = splrep(x, y, s=0, k=3)
        pp = PPoly.from_spline(spl)

        r = pp.roots()
        r = r[(r >= 0 - 1e-15) & (r <= 1 + 1e-15)]
        assert_allclose(r, sproot(spl), atol=1e-15)

    def test_roots_idzero(self):
        # Roots for piecewise polynomials with identically zero
        # sections. The zero section on [0.4, 0.6] is reported as its left
        # endpoint followed by nan.
        c = np.array([[-1, 0.25], [0, 0], [-1, 0.25]]).T
        x = np.array([0, 0.4, 0.6, 1.0])

        pp = PPoly(c, x)
        assert_array_equal(pp.roots(),
                           [0.25, 0.4, np.nan, 0.6 + 0.25])

        # ditto for p.solve(const) with sections identically equal const
        const = 2.
        c1 = c.copy()
        c1[1, :] += const
        pp1 = PPoly(c1, x)

        assert_array_equal(pp1.solve(const),
                           [0.25, 0.4, np.nan, 0.6 + 0.25])

    def test_roots_all_zero(self):
        # test the code path for the polynomial being identically zero everywhere
        c = [[0], [0]]
        x = [0, 1]
        p = PPoly(c, x)
        assert_array_equal(p.roots(), [0, np.nan])
        assert_array_equal(p.solve(0), [0, np.nan])
        assert_array_equal(p.solve(1), [])

        c = [[0, 0], [0, 0]]
        x = [0, 1, 2]
        p = PPoly(c, x)
        assert_array_equal(p.roots(), [0, np.nan, 1, np.nan])
        assert_array_equal(p.solve(0), [0, np.nan, 1, np.nan])
        assert_array_equal(p.solve(1), [])

    def test_roots_repeated(self):
        # Check roots repeated in multiple sections are reported only
        # once.

        # [(x + 1)**2 - 1, -x**2] ; x == 0 is a repeated root
        c = np.array([[1, 0, -1], [-1, 0, 0]]).T
        x = np.array([-1, 0, 1])

        pp = PPoly(c, x)
        assert_array_equal(pp.roots(), [-2, 0])
        assert_array_equal(pp.roots(extrapolate=False), [0])

    def test_roots_discont(self):
        # Check that a discontinuity across zero is reported as root
        # (piecewise constant 1 on [0, 0.5] and -1 on [0.5, 1])
        c = np.array([[1], [-1]]).T
        x = np.array([0, 0.5, 1])

        pp = PPoly(c, x)
        assert_array_equal(pp.roots(), [0.5])
        assert_array_equal(pp.roots(discontinuity=False), [])

        # ditto for a discontinuity across y:
        assert_array_equal(pp.solve(0.5), [0.5])
        assert_array_equal(pp.solve(0.5, discontinuity=False), [])

        assert_array_equal(pp.solve(1.5), [])
        assert_array_equal(pp.solve(1.5, discontinuity=False), [])

    def test_roots_random(self):
        # Check high-order polynomials with random coefficients. Each
        # reported solution r of p(r) = y is validated by the residual
        # (p(r) - y) scaled by p'(r).
        np.random.seed(1234)

        num = 0

        for extrapolate in (True, False):
            for order in range(0, 20):
                x = np.unique(np.r_[0, 10 * np.random.rand(30), 10])
                c = 2*np.random.rand(order+1, len(x)-1, 2, 3) - 1

                pp = PPoly(c, x)
                for y in [0, np.random.random()]:
                    r = pp.solve(y, discontinuity=False, extrapolate=extrapolate)

                    for i in range(2):
                        for j in range(3):
                            rr = r[i,j]
                            if rr.size > 0:
                                # Check that the reported roots indeed are roots
                                num += rr.size
                                val = pp(rr, extrapolate=extrapolate)[:,i,j]
                                cmpval = pp(rr, nu=1,
                                            extrapolate=extrapolate)[:,i,j]
                                msg = "(%r) r = %s" % (extrapolate, repr(rr),)
                                assert_allclose((val-y) / cmpval, 0, atol=1e-7,
                                                err_msg=msg)

        # Check that we checked a number of roots
        assert_(num > 100, repr(num))

    def test_roots_croots(self):
        # Test the complex root finding algorithm: _ppoly._croots_poly1
        # fills `w` with complex roots of each polynomial column of `c`
        # (all-NaN for constants, k == 1). Each root's residual, normalised
        # by the sum of absolute term sizes, must vanish.
        # NOTE(review): `y` is never used in the loop body, so the check
        # simply runs twice per k. Removing the loop would change the
        # random stream (np.random.random() is drawn), so it is left as is.
        np.random.seed(1234)

        for k in range(1, 15):
            c = np.random.rand(k, 1, 130)

            if k == 3:
                # add a case with zero discriminant
                c[:,0,0] = 1, 2, 1

            for y in [0, np.random.random()]:
                w = np.empty(c.shape, dtype=complex)
                _ppoly._croots_poly1(c, w)

                if k == 1:
                    assert_(np.isnan(w).all())
                    continue

                res = 0
                cres = 0
                for i in range(k):
                    res += c[i,None] * w**(k-1-i)
                    cres += abs(c[i,None] * w**(k-1-i))
                with np.errstate(invalid='ignore'):
                    res /= cres
                res = res.ravel()
                res = res[~np.isnan(res)]
                assert_allclose(res, 0, atol=1e-10)

    def test_extrapolate_attr(self):
        # [ 1 - x**2 ]
        # With extrapolate=False, evaluation outside [0, 1] gives NaN (also
        # for derivative/antiderivative) and only the in-range root 1 is
        # reported. With True/None, the polynomial is continued outside.
        c = np.array([[-1, 0, 1]]).T
        x = np.array([0, 1])

        for extrapolate in [True, False, None]:
            pp = PPoly(c, x, extrapolate=extrapolate)
            pp_d = pp.derivative()
            pp_i = pp.antiderivative()

            if extrapolate is False:
                assert_(np.isnan(pp([-0.1, 1.1])).all())
                assert_(np.isnan(pp_i([-0.1, 1.1])).all())
                assert_(np.isnan(pp_d([-0.1, 1.1])).all())
                assert_equal(pp.roots(), [1])
            else:
                assert_allclose(pp([-0.1, 1.1]), [1-0.1**2, 1-1.1**2])
                assert_(not np.isnan(pp_i([-0.1, 1.1])).any())
                assert_(not np.isnan(pp_d([-0.1, 1.1])).any())
                assert_allclose(pp.roots(), [1, -1])
  1199. class TestBPoly(object):
  1200. def test_simple(self):
  1201. x = [0, 1]
  1202. c = [[3]]
  1203. bp = BPoly(c, x)
  1204. assert_allclose(bp(0.1), 3.)
  1205. def test_simple2(self):
  1206. x = [0, 1]
  1207. c = [[3], [1]]
  1208. bp = BPoly(c, x) # 3*(1-x) + 1*x
  1209. assert_allclose(bp(0.1), 3*0.9 + 1.*0.1)
  1210. def test_simple3(self):
  1211. x = [0, 1]
  1212. c = [[3], [1], [4]]
  1213. bp = BPoly(c, x) # 3 * (1-x)**2 + 2 * x (1-x) + 4 * x**2
  1214. assert_allclose(bp(0.2),
  1215. 3 * 0.8*0.8 + 1 * 2*0.2*0.8 + 4 * 0.2*0.2)
  1216. def test_simple4(self):
  1217. x = [0, 1]
  1218. c = [[1], [1], [1], [2]]
  1219. bp = BPoly(c, x)
  1220. assert_allclose(bp(0.3), 0.7**3 +
  1221. 3 * 0.7**2 * 0.3 +
  1222. 3 * 0.7 * 0.3**2 +
  1223. 2 * 0.3**3)
  1224. def test_simple5(self):
  1225. x = [0, 1]
  1226. c = [[1], [1], [8], [2], [1]]
  1227. bp = BPoly(c, x)
  1228. assert_allclose(bp(0.3), 0.7**4 +
  1229. 4 * 0.7**3 * 0.3 +
  1230. 8 * 6 * 0.7**2 * 0.3**2 +
  1231. 2 * 4 * 0.7 * 0.3**3 +
  1232. 0.3**4)
  1233. def test_periodic(self):
  1234. x = [0, 1, 3]
  1235. c = [[3, 0], [0, 0], [0, 2]]
  1236. # [3*(1-x)**2, 2*((x-1)/2)**2]
  1237. bp = BPoly(c, x, extrapolate='periodic')
  1238. assert_allclose(bp(3.4), 3 * 0.6**2)
  1239. assert_allclose(bp(-1.3), 2 * (0.7/2)**2)
  1240. assert_allclose(bp(3.4, 1), -6 * 0.6)
  1241. assert_allclose(bp(-1.3, 1), 2 * (0.7/2))
  1242. def test_descending(self):
  1243. np.random.seed(0)
  1244. power = 3
  1245. for m in [10, 20, 30]:
  1246. x = np.sort(np.random.uniform(0, 10, m + 1))
  1247. ca = np.random.uniform(-0.1, 0.1, size=(power + 1, m))
  1248. # We need only to flip coefficients to get it right!
  1249. cd = ca[::-1].copy()
  1250. pa = BPoly(ca, x, extrapolate=True)
  1251. pd = BPoly(cd[:, ::-1], x[::-1], extrapolate=True)
  1252. x_test = np.random.uniform(-10, 20, 100)
  1253. assert_allclose(pa(x_test), pd(x_test), rtol=1e-13)
  1254. assert_allclose(pa(x_test, 1), pd(x_test, 1), rtol=1e-13)
  1255. pa_d = pa.derivative()
  1256. pd_d = pd.derivative()
  1257. assert_allclose(pa_d(x_test), pd_d(x_test), rtol=1e-13)
  1258. # Antiderivatives won't be equal because fixing continuity is
  1259. # done in the reverse order, but surely the differences should be
  1260. # equal.
  1261. pa_i = pa.antiderivative()
  1262. pd_i = pd.antiderivative()
  1263. for a, b in np.random.uniform(-10, 20, (5, 2)):
  1264. int_a = pa.integrate(a, b)
  1265. int_d = pd.integrate(a, b)
  1266. assert_allclose(int_a, int_d, rtol=1e-12)
  1267. assert_allclose(pa_i(b) - pa_i(a), pd_i(b) - pd_i(a),
  1268. rtol=1e-12)
  1269. def test_multi_shape(self):
  1270. c = np.random.rand(6, 2, 1, 2, 3)
  1271. x = np.array([0, 0.5, 1])
  1272. p = BPoly(c, x)
  1273. assert_equal(p.x.shape, x.shape)
  1274. assert_equal(p.c.shape, c.shape)
  1275. assert_equal(p(0.3).shape, c.shape[2:])
  1276. assert_equal(p(np.random.rand(5,6)).shape,
  1277. (5,6)+c.shape[2:])
  1278. dp = p.derivative()
  1279. assert_equal(dp.c.shape, (5, 2, 1, 2, 3))
  1280. def test_interval_length(self):
  1281. x = [0, 2]
  1282. c = [[3], [1], [4]]
  1283. bp = BPoly(c, x)
  1284. xval = 0.1
  1285. s = xval / 2 # s = (x - xa) / (xb - xa)
  1286. assert_allclose(bp(xval), 3 * (1-s)*(1-s) + 1 * 2*s*(1-s) + 4 * s*s)
  1287. def test_two_intervals(self):
  1288. x = [0, 1, 3]
  1289. c = [[3, 0], [0, 0], [0, 2]]
  1290. bp = BPoly(c, x) # [3*(1-x)**2, 2*((x-1)/2)**2]
  1291. assert_allclose(bp(0.4), 3 * 0.6*0.6)
  1292. assert_allclose(bp(1.7), 2 * (0.7/2)**2)
  1293. def test_extrapolate_attr(self):
  1294. x = [0, 2]
  1295. c = [[3], [1], [4]]
  1296. bp = BPoly(c, x)
  1297. for extrapolate in (True, False, None):
  1298. bp = BPoly(c, x, extrapolate=extrapolate)
  1299. bp_d = bp.derivative()
  1300. if extrapolate is False:
  1301. assert_(np.isnan(bp([-0.1, 2.1])).all())
  1302. assert_(np.isnan(bp_d([-0.1, 2.1])).all())
  1303. else:
  1304. assert_(not np.isnan(bp([-0.1, 2.1])).any())
  1305. assert_(not np.isnan(bp_d([-0.1, 2.1])).any())
  1306. class TestBPolyCalculus(object):
  1307. def test_derivative(self):
  1308. x = [0, 1, 3]
  1309. c = [[3, 0], [0, 0], [0, 2]]
  1310. bp = BPoly(c, x) # [3*(1-x)**2, 2*((x-1)/2)**2]
  1311. bp_der = bp.derivative()
  1312. assert_allclose(bp_der(0.4), -6*(0.6))
  1313. assert_allclose(bp_der(1.7), 0.7)
  1314. # derivatives in-place
  1315. assert_allclose([bp(0.4, nu=1), bp(0.4, nu=2), bp(0.4, nu=3)],
  1316. [-6*(1-0.4), 6., 0.])
  1317. assert_allclose([bp(1.7, nu=1), bp(1.7, nu=2), bp(1.7, nu=3)],
  1318. [0.7, 1., 0])
  1319. def test_derivative_ppoly(self):
  1320. # make sure it's consistent w/ power basis
  1321. np.random.seed(1234)
  1322. m, k = 5, 8 # number of intervals, order
  1323. x = np.sort(np.random.random(m))
  1324. c = np.random.random((k, m-1))
  1325. bp = BPoly(c, x)
  1326. pp = PPoly.from_bernstein_basis(bp)
  1327. for d in range(k):
  1328. bp = bp.derivative()
  1329. pp = pp.derivative()
  1330. xp = np.linspace(x[0], x[-1], 21)
  1331. assert_allclose(bp(xp), pp(xp))
  1332. def test_deriv_inplace(self):
  1333. np.random.seed(1234)
  1334. m, k = 5, 8 # number of intervals, order
  1335. x = np.sort(np.random.random(m))
  1336. c = np.random.random((k, m-1))
  1337. # test both real and complex coefficients
  1338. for cc in [c.copy(), c*(1. + 2.j)]:
  1339. bp = BPoly(cc, x)
  1340. xp = np.linspace(x[0], x[-1], 21)
  1341. for i in range(k):
  1342. assert_allclose(bp(xp, i), bp.derivative(i)(xp))
  1343. def test_antiderivative_simple(self):
  1344. # f(x) = x for x \in [0, 1),
  1345. # (x-1)/2 for x \in [1, 3]
  1346. #
  1347. # antiderivative is then
  1348. # F(x) = x**2 / 2 for x \in [0, 1),
  1349. # 0.5*x*(x/2 - 1) + A for x \in [1, 3]
  1350. # where A = 3/4 for continuity at x = 1.
  1351. x = [0, 1, 3]
  1352. c = [[0, 0], [1, 1]]
  1353. bp = BPoly(c, x)
  1354. bi = bp.antiderivative()
  1355. xx = np.linspace(0, 3, 11)
  1356. assert_allclose(bi(xx),
  1357. np.where(xx < 1, xx**2 / 2.,
  1358. 0.5 * xx * (xx/2. - 1) + 3./4),
  1359. atol=1e-12, rtol=1e-12)
  1360. def test_der_antider(self):
  1361. np.random.seed(1234)
  1362. x = np.sort(np.random.random(11))
  1363. c = np.random.random((4, 10, 2, 3))
  1364. bp = BPoly(c, x)
  1365. xx = np.linspace(x[0], x[-1], 100)
  1366. assert_allclose(bp.antiderivative().derivative()(xx),
  1367. bp(xx), atol=1e-12, rtol=1e-12)
  1368. def test_antider_ppoly(self):
  1369. np.random.seed(1234)
  1370. x = np.sort(np.random.random(11))
  1371. c = np.random.random((4, 10, 2, 3))
  1372. bp = BPoly(c, x)
  1373. pp = PPoly.from_bernstein_basis(bp)
  1374. xx = np.linspace(x[0], x[-1], 10)
  1375. assert_allclose(bp.antiderivative(2)(xx),
  1376. pp.antiderivative(2)(xx), atol=1e-12, rtol=1e-12)
  1377. def test_antider_continuous(self):
  1378. np.random.seed(1234)
  1379. x = np.sort(np.random.random(11))
  1380. c = np.random.random((4, 10))
  1381. bp = BPoly(c, x).antiderivative()
  1382. xx = bp.x[1:-1]
  1383. assert_allclose(bp(xx - 1e-14),
  1384. bp(xx + 1e-14), atol=1e-12, rtol=1e-12)
  1385. def test_integrate(self):
  1386. np.random.seed(1234)
  1387. x = np.sort(np.random.random(11))
  1388. c = np.random.random((4, 10))
  1389. bp = BPoly(c, x)
  1390. pp = PPoly.from_bernstein_basis(bp)
  1391. assert_allclose(bp.integrate(0, 1),
  1392. pp.integrate(0, 1), atol=1e-12, rtol=1e-12)
  1393. def test_integrate_extrap(self):
  1394. c = [[1]]
  1395. x = [0, 1]
  1396. b = BPoly(c, x)
  1397. # default is extrapolate=True
  1398. assert_allclose(b.integrate(0, 2), 2., atol=1e-14)
  1399. # .integrate argument overrides self.extrapolate
  1400. b1 = BPoly(c, x, extrapolate=False)
  1401. assert_(np.isnan(b1.integrate(0, 2)))
  1402. assert_allclose(b1.integrate(0, 2, extrapolate=True), 2., atol=1e-14)
  1403. def test_integrate_periodic(self):
  1404. x = np.array([1, 2, 4])
  1405. c = np.array([[0., 0.], [-1., -1.], [2., -0.], [1., 2.]])
  1406. P = BPoly.from_power_basis(PPoly(c, x), extrapolate='periodic')
  1407. I = P.antiderivative()
  1408. period_int = I(4) - I(1)
  1409. assert_allclose(P.integrate(1, 4), period_int)
  1410. assert_allclose(P.integrate(-10, -7), period_int)
  1411. assert_allclose(P.integrate(-10, -4), 2 * period_int)
  1412. assert_allclose(P.integrate(1.5, 2.5), I(2.5) - I(1.5))
  1413. assert_allclose(P.integrate(3.5, 5), I(2) - I(1) + I(4) - I(3.5))
  1414. assert_allclose(P.integrate(3.5 + 12, 5 + 12),
  1415. I(2) - I(1) + I(4) - I(3.5))
  1416. assert_allclose(P.integrate(3.5, 5 + 12),
  1417. I(2) - I(1) + I(4) - I(3.5) + 4 * period_int)
  1418. assert_allclose(P.integrate(0, -1), I(2) - I(3))
  1419. assert_allclose(P.integrate(-9, -10), I(2) - I(3))
  1420. assert_allclose(P.integrate(0, -10), I(2) - I(3) - 3 * period_int)
  1421. def test_antider_neg(self):
  1422. # .derivative(-nu) ==> .andiderivative(nu) and vice versa
  1423. c = [[1]]
  1424. x = [0, 1]
  1425. b = BPoly(c, x)
  1426. xx = np.linspace(0, 1, 21)
  1427. assert_allclose(b.derivative(-1)(xx), b.antiderivative()(xx),
  1428. atol=1e-12, rtol=1e-12)
  1429. assert_allclose(b.derivative(1)(xx), b.antiderivative(-1)(xx),
  1430. atol=1e-12, rtol=1e-12)
  1431. class TestPolyConversions(object):
  1432. def test_bp_from_pp(self):
  1433. x = [0, 1, 3]
  1434. c = [[3, 2], [1, 8], [4, 3]]
  1435. pp = PPoly(c, x)
  1436. bp = BPoly.from_power_basis(pp)
  1437. pp1 = PPoly.from_bernstein_basis(bp)
  1438. xp = [0.1, 1.4]
  1439. assert_allclose(pp(xp), bp(xp))
  1440. assert_allclose(pp(xp), pp1(xp))
  1441. def test_bp_from_pp_random(self):
  1442. np.random.seed(1234)
  1443. m, k = 5, 8 # number of intervals, order
  1444. x = np.sort(np.random.random(m))
  1445. c = np.random.random((k, m-1))
  1446. pp = PPoly(c, x)
  1447. bp = BPoly.from_power_basis(pp)
  1448. pp1 = PPoly.from_bernstein_basis(bp)
  1449. xp = np.linspace(x[0], x[-1], 21)
  1450. assert_allclose(pp(xp), bp(xp))
  1451. assert_allclose(pp(xp), pp1(xp))
  1452. def test_pp_from_bp(self):
  1453. x = [0, 1, 3]
  1454. c = [[3, 3], [1, 1], [4, 2]]
  1455. bp = BPoly(c, x)
  1456. pp = PPoly.from_bernstein_basis(bp)
  1457. bp1 = BPoly.from_power_basis(pp)
  1458. xp = [0.1, 1.4]
  1459. assert_allclose(bp(xp), pp(xp))
  1460. assert_allclose(bp(xp), bp1(xp))
  1461. class TestBPolyFromDerivatives(object):
  1462. def test_make_poly_1(self):
  1463. c1 = BPoly._construct_from_derivatives(0, 1, [2], [3])
  1464. assert_allclose(c1, [2., 3.])
  1465. def test_make_poly_2(self):
  1466. c1 = BPoly._construct_from_derivatives(0, 1, [1, 0], [1])
  1467. assert_allclose(c1, [1., 1., 1.])
  1468. # f'(0) = 3
  1469. c2 = BPoly._construct_from_derivatives(0, 1, [2, 3], [1])
  1470. assert_allclose(c2, [2., 7./2, 1.])
  1471. # f'(1) = 3
  1472. c3 = BPoly._construct_from_derivatives(0, 1, [2], [1, 3])
  1473. assert_allclose(c3, [2., -0.5, 1.])
  1474. def test_make_poly_3(self):
  1475. # f'(0)=2, f''(0)=3
  1476. c1 = BPoly._construct_from_derivatives(0, 1, [1, 2, 3], [4])
  1477. assert_allclose(c1, [1., 5./3, 17./6, 4.])
  1478. # f'(1)=2, f''(1)=3
  1479. c2 = BPoly._construct_from_derivatives(0, 1, [1], [4, 2, 3])
  1480. assert_allclose(c2, [1., 19./6, 10./3, 4.])
  1481. # f'(0)=2, f'(1)=3
  1482. c3 = BPoly._construct_from_derivatives(0, 1, [1, 2], [4, 3])
  1483. assert_allclose(c3, [1., 5./3, 3., 4.])
  1484. def test_make_poly_12(self):
  1485. np.random.seed(12345)
  1486. ya = np.r_[0, np.random.random(5)]
  1487. yb = np.r_[0, np.random.random(5)]
  1488. c = BPoly._construct_from_derivatives(0, 1, ya, yb)
  1489. pp = BPoly(c[:, None], [0, 1])
  1490. for j in range(6):
  1491. assert_allclose([pp(0.), pp(1.)], [ya[j], yb[j]])
  1492. pp = pp.derivative()
  1493. def test_raise_degree(self):
  1494. np.random.seed(12345)
  1495. x = [0, 1]
  1496. k, d = 8, 5
  1497. c = np.random.random((k, 1, 2, 3, 4))
  1498. bp = BPoly(c, x)
  1499. c1 = BPoly._raise_degree(c, d)
  1500. bp1 = BPoly(c1, x)
  1501. xp = np.linspace(0, 1, 11)
  1502. assert_allclose(bp(xp), bp1(xp))
  1503. def test_xi_yi(self):
  1504. assert_raises(ValueError, BPoly.from_derivatives, [0, 1], [0])
  1505. def test_coords_order(self):
  1506. xi = [0, 0, 1]
  1507. yi = [[0], [0], [0]]
  1508. assert_raises(ValueError, BPoly.from_derivatives, xi, yi)
  1509. def test_zeros(self):
  1510. xi = [0, 1, 2, 3]
  1511. yi = [[0, 0], [0], [0, 0], [0, 0]] # NB: will have to raise the degree
  1512. pp = BPoly.from_derivatives(xi, yi)
  1513. assert_(pp.c.shape == (4, 3))
  1514. ppd = pp.derivative()
  1515. for xp in [0., 0.1, 1., 1.1, 1.9, 2., 2.5]:
  1516. assert_allclose([pp(xp), ppd(xp)], [0., 0.])
  1517. def _make_random_mk(self, m, k):
  1518. # k derivatives at each breakpoint
  1519. np.random.seed(1234)
  1520. xi = np.asarray([1. * j**2 for j in range(m+1)])
  1521. yi = [np.random.random(k) for j in range(m+1)]
  1522. return xi, yi
  1523. def test_random_12(self):
  1524. m, k = 5, 12
  1525. xi, yi = self._make_random_mk(m, k)
  1526. pp = BPoly.from_derivatives(xi, yi)
  1527. for order in range(k//2):
  1528. assert_allclose(pp(xi), [yy[order] for yy in yi])
  1529. pp = pp.derivative()
  1530. def test_order_zero(self):
  1531. m, k = 5, 12
  1532. xi, yi = self._make_random_mk(m, k)
  1533. assert_raises(ValueError, BPoly.from_derivatives,
  1534. **dict(xi=xi, yi=yi, orders=0))
  1535. def test_orders_too_high(self):
  1536. m, k = 5, 12
  1537. xi, yi = self._make_random_mk(m, k)
  1538. pp = BPoly.from_derivatives(xi, yi, orders=2*k-1) # this is still ok
  1539. assert_raises(ValueError, BPoly.from_derivatives, # but this is not
  1540. **dict(xi=xi, yi=yi, orders=2*k))
  1541. def test_orders_global(self):
  1542. m, k = 5, 12
  1543. xi, yi = self._make_random_mk(m, k)
  1544. # ok, this is confusing. Local polynomials will be of the order 5
  1545. # which means that up to the 2nd derivatives will be used at each point
  1546. order = 5
  1547. pp = BPoly.from_derivatives(xi, yi, orders=order)
  1548. for j in range(order//2+1):
  1549. assert_allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1550. pp = pp.derivative()
  1551. assert_(not np.allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12)))
  1552. # now repeat with `order` being even: on each interval, it uses
  1553. # order//2 'derivatives' @ the right-hand endpoint and
  1554. # order//2+1 @ 'derivatives' the left-hand endpoint
  1555. order = 6
  1556. pp = BPoly.from_derivatives(xi, yi, orders=order)
  1557. for j in range(order//2):
  1558. assert_allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12))
  1559. pp = pp.derivative()
  1560. assert_(not np.allclose(pp(xi[1:-1] - 1e-12), pp(xi[1:-1] + 1e-12)))
  1561. def test_orders_local(self):
  1562. m, k = 7, 12
  1563. xi, yi = self._make_random_mk(m, k)
  1564. orders = [o + 1 for o in range(m)]
  1565. for i, x in enumerate(xi[1:-1]):
  1566. pp = BPoly.from_derivatives(xi, yi, orders=orders)
  1567. for j in range(orders[i] // 2 + 1):
  1568. assert_allclose(pp(x - 1e-12), pp(x + 1e-12))
  1569. pp = pp.derivative()
  1570. assert_(not np.allclose(pp(x - 1e-12), pp(x + 1e-12)))
  1571. def test_yi_trailing_dims(self):
  1572. m, k = 7, 5
  1573. xi = np.sort(np.random.random(m+1))
  1574. yi = np.random.random((m+1, k, 6, 7, 8))
  1575. pp = BPoly.from_derivatives(xi, yi)
  1576. assert_equal(pp.c.shape, (2*k, m, 6, 7, 8))
  1577. def test_gh_5430(self):
  1578. # At least one of these raises an error unless gh-5430 is
  1579. # fixed. In py2k an int is implemented using a C long, so
  1580. # which one fails depends on your system. In py3k there is only
  1581. # one arbitrary precision integer type, so both should fail.
  1582. orders = np.int32(1)
  1583. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1584. assert_almost_equal(p(0), 0)
  1585. orders = np.int64(1)
  1586. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1587. assert_almost_equal(p(0), 0)
  1588. orders = 1
  1589. # This worked before; make sure it still works
  1590. p = BPoly.from_derivatives([0, 1], [[0], [0]], orders=orders)
  1591. assert_almost_equal(p(0), 0)
  1592. orders = 1
  1593. class TestNdPPoly(object):
  1594. def test_simple_1d(self):
  1595. np.random.seed(1234)
  1596. c = np.random.rand(4, 5)
  1597. x = np.linspace(0, 1, 5+1)
  1598. xi = np.random.rand(200)
  1599. p = NdPPoly(c, (x,))
  1600. v1 = p((xi,))
  1601. v2 = _ppoly_eval_1(c[:,:,None], x, xi).ravel()
  1602. assert_allclose(v1, v2)
  1603. def test_simple_2d(self):
  1604. np.random.seed(1234)
  1605. c = np.random.rand(4, 5, 6, 7)
  1606. x = np.linspace(0, 1, 6+1)
  1607. y = np.linspace(0, 1, 7+1)**2
  1608. xi = np.random.rand(200)
  1609. yi = np.random.rand(200)
  1610. v1 = np.empty([len(xi), 1], dtype=c.dtype)
  1611. v1.fill(np.nan)
  1612. _ppoly.evaluate_nd(c.reshape(4*5, 6*7, 1),
  1613. (x, y),
  1614. np.array([4, 5], dtype=np.intc),
  1615. np.c_[xi, yi],
  1616. np.array([0, 0], dtype=np.intc),
  1617. 1,
  1618. v1)
  1619. v1 = v1.ravel()
  1620. v2 = _ppoly2d_eval(c, (x, y), xi, yi)
  1621. assert_allclose(v1, v2)
  1622. p = NdPPoly(c, (x, y))
  1623. for nu in (None, (0, 0), (0, 1), (1, 0), (2, 3), (9, 2)):
  1624. v1 = p(np.c_[xi, yi], nu=nu)
  1625. v2 = _ppoly2d_eval(c, (x, y), xi, yi, nu=nu)
  1626. assert_allclose(v1, v2, err_msg=repr(nu))
  1627. def test_simple_3d(self):
  1628. np.random.seed(1234)
  1629. c = np.random.rand(4, 5, 6, 7, 8, 9)
  1630. x = np.linspace(0, 1, 7+1)
  1631. y = np.linspace(0, 1, 8+1)**2
  1632. z = np.linspace(0, 1, 9+1)**3
  1633. xi = np.random.rand(40)
  1634. yi = np.random.rand(40)
  1635. zi = np.random.rand(40)
  1636. p = NdPPoly(c, (x, y, z))
  1637. for nu in (None, (0, 0, 0), (0, 1, 0), (1, 0, 0), (2, 3, 0),
  1638. (6, 0, 2)):
  1639. v1 = p((xi, yi, zi), nu=nu)
  1640. v2 = _ppoly3d_eval(c, (x, y, z), xi, yi, zi, nu=nu)
  1641. assert_allclose(v1, v2, err_msg=repr(nu))
  1642. def test_simple_4d(self):
  1643. np.random.seed(1234)
  1644. c = np.random.rand(4, 5, 6, 7, 8, 9, 10, 11)
  1645. x = np.linspace(0, 1, 8+1)
  1646. y = np.linspace(0, 1, 9+1)**2
  1647. z = np.linspace(0, 1, 10+1)**3
  1648. u = np.linspace(0, 1, 11+1)**4
  1649. xi = np.random.rand(20)
  1650. yi = np.random.rand(20)
  1651. zi = np.random.rand(20)
  1652. ui = np.random.rand(20)
  1653. p = NdPPoly(c, (x, y, z, u))
  1654. v1 = p((xi, yi, zi, ui))
  1655. v2 = _ppoly4d_eval(c, (x, y, z, u), xi, yi, zi, ui)
  1656. assert_allclose(v1, v2)
  1657. def test_deriv_1d(self):
  1658. np.random.seed(1234)
  1659. c = np.random.rand(4, 5)
  1660. x = np.linspace(0, 1, 5+1)
  1661. p = NdPPoly(c, (x,))
  1662. # derivative
  1663. dp = p.derivative(nu=[1])
  1664. p1 = PPoly(c, x)
  1665. dp1 = p1.derivative()
  1666. assert_allclose(dp.c, dp1.c)
  1667. # antiderivative
  1668. dp = p.antiderivative(nu=[2])
  1669. p1 = PPoly(c, x)
  1670. dp1 = p1.antiderivative(2)
  1671. assert_allclose(dp.c, dp1.c)
  1672. def test_deriv_3d(self):
  1673. np.random.seed(1234)
  1674. c = np.random.rand(4, 5, 6, 7, 8, 9)
  1675. x = np.linspace(0, 1, 7+1)
  1676. y = np.linspace(0, 1, 8+1)**2
  1677. z = np.linspace(0, 1, 9+1)**3
  1678. p = NdPPoly(c, (x, y, z))
  1679. # differentiate vs x
  1680. p1 = PPoly(c.transpose(0, 3, 1, 2, 4, 5), x)
  1681. dp = p.derivative(nu=[2])
  1682. dp1 = p1.derivative(2)
  1683. assert_allclose(dp.c,
  1684. dp1.c.transpose(0, 2, 3, 1, 4, 5))
  1685. # antidifferentiate vs y
  1686. p1 = PPoly(c.transpose(1, 4, 0, 2, 3, 5), y)
  1687. dp = p.antiderivative(nu=[0, 1, 0])
  1688. dp1 = p1.antiderivative(1)
  1689. assert_allclose(dp.c,
  1690. dp1.c.transpose(2, 0, 3, 4, 1, 5))
  1691. # differentiate vs z
  1692. p1 = PPoly(c.transpose(2, 5, 0, 1, 3, 4), z)
  1693. dp = p.derivative(nu=[0, 0, 3])
  1694. dp1 = p1.derivative(3)
  1695. assert_allclose(dp.c,
  1696. dp1.c.transpose(2, 3, 0, 4, 5, 1))
  1697. def test_deriv_3d_simple(self):
  1698. # Integrate to obtain function x y**2 z**4 / (2! 4!)
  1699. c = np.ones((1, 1, 1, 3, 4, 5))
  1700. x = np.linspace(0, 1, 3+1)**1
  1701. y = np.linspace(0, 1, 4+1)**2
  1702. z = np.linspace(0, 1, 5+1)**3
  1703. p = NdPPoly(c, (x, y, z))
  1704. ip = p.antiderivative((1, 0, 4))
  1705. ip = ip.antiderivative((0, 2, 0))
  1706. xi = np.random.rand(20)
  1707. yi = np.random.rand(20)
  1708. zi = np.random.rand(20)
  1709. assert_allclose(ip((xi, yi, zi)),
  1710. xi * yi**2 * zi**4 / (gamma(3)*gamma(5)))
  1711. def test_integrate_2d(self):
  1712. np.random.seed(1234)
  1713. c = np.random.rand(4, 5, 16, 17)
  1714. x = np.linspace(0, 1, 16+1)**1
  1715. y = np.linspace(0, 1, 17+1)**2
  1716. # make continuously differentiable so that nquad() has an
  1717. # easier time
  1718. c = c.transpose(0, 2, 1, 3)
  1719. cx = c.reshape(c.shape[0], c.shape[1], -1).copy()
  1720. _ppoly.fix_continuity(cx, x, 2)
  1721. c = cx.reshape(c.shape)
  1722. c = c.transpose(0, 2, 1, 3)
  1723. c = c.transpose(1, 3, 0, 2)
  1724. cx = c.reshape(c.shape[0], c.shape[1], -1).copy()
  1725. _ppoly.fix_continuity(cx, y, 2)
  1726. c = cx.reshape(c.shape)
  1727. c = c.transpose(2, 0, 3, 1).copy()
  1728. # Check integration
  1729. p = NdPPoly(c, (x, y))
  1730. for ranges in [[(0, 1), (0, 1)],
  1731. [(0, 0.5), (0, 1)],
  1732. [(0, 1), (0, 0.5)],
  1733. [(0.3, 0.7), (0.6, 0.2)]]:
  1734. ig = p.integrate(ranges)
  1735. ig2, err2 = nquad(lambda x, y: p((x, y)), ranges,
  1736. opts=[dict(epsrel=1e-5, epsabs=1e-5)]*2)
  1737. assert_allclose(ig, ig2, rtol=1e-5, atol=1e-5,
  1738. err_msg=repr(ranges))
  1739. def test_integrate_1d(self):
  1740. np.random.seed(1234)
  1741. c = np.random.rand(4, 5, 6, 16, 17, 18)
  1742. x = np.linspace(0, 1, 16+1)**1
  1743. y = np.linspace(0, 1, 17+1)**2
  1744. z = np.linspace(0, 1, 18+1)**3
  1745. # Check 1D integration
  1746. p = NdPPoly(c, (x, y, z))
  1747. u = np.random.rand(200)
  1748. v = np.random.rand(200)
  1749. a, b = 0.2, 0.7
  1750. px = p.integrate_1d(a, b, axis=0)
  1751. pax = p.antiderivative((1, 0, 0))
  1752. assert_allclose(px((u, v)), pax((b, u, v)) - pax((a, u, v)))
  1753. py = p.integrate_1d(a, b, axis=1)
  1754. pay = p.antiderivative((0, 1, 0))
  1755. assert_allclose(py((u, v)), pay((u, b, v)) - pay((u, a, v)))
  1756. pz = p.integrate_1d(a, b, axis=2)
  1757. paz = p.antiderivative((0, 0, 1))
  1758. assert_allclose(pz((u, v)), paz((u, v, b)) - paz((u, v, a)))
  1759. def _ppoly_eval_1(c, x, xps):
  1760. """Evaluate piecewise polynomial manually"""
  1761. out = np.zeros((len(xps), c.shape[2]))
  1762. for i, xp in enumerate(xps):
  1763. if xp < 0 or xp > 1:
  1764. out[i,:] = np.nan
  1765. continue
  1766. j = np.searchsorted(x, xp) - 1
  1767. d = xp - x[j]
  1768. assert_(x[j] <= xp < x[j+1])
  1769. r = sum(c[k,j] * d**(c.shape[0]-k-1)
  1770. for k in range(c.shape[0]))
  1771. out[i,:] = r
  1772. return out
  1773. def _ppoly_eval_2(coeffs, breaks, xnew, fill=np.nan):
  1774. """Evaluate piecewise polynomial manually (another way)"""
  1775. a = breaks[0]
  1776. b = breaks[-1]
  1777. K = coeffs.shape[0]
  1778. saveshape = np.shape(xnew)
  1779. xnew = np.ravel(xnew)
  1780. res = np.empty_like(xnew)
  1781. mask = (xnew >= a) & (xnew <= b)
  1782. res[~mask] = fill
  1783. xx = xnew.compress(mask)
  1784. indxs = np.searchsorted(breaks, xx)-1
  1785. indxs = indxs.clip(0, len(breaks))
  1786. pp = coeffs
  1787. diff = xx - breaks.take(indxs)
  1788. V = np.vander(diff, N=K)
  1789. values = np.array([np.dot(V[k, :], pp[:, indxs[k]]) for k in xrange(len(xx))])
  1790. res[mask] = values
  1791. res.shape = saveshape
  1792. return res
  1793. def _dpow(x, y, n):
  1794. """
  1795. d^n (x**y) / dx^n
  1796. """
  1797. if n < 0:
  1798. raise ValueError("invalid derivative order")
  1799. elif n > y:
  1800. return 0
  1801. else:
  1802. return poch(y - n + 1, n) * x**(y - n)
  1803. def _ppoly2d_eval(c, xs, xnew, ynew, nu=None):
  1804. """
  1805. Straightforward evaluation of 2D piecewise polynomial
  1806. """
  1807. if nu is None:
  1808. nu = (0, 0)
  1809. out = np.empty((len(xnew),), dtype=c.dtype)
  1810. nx, ny = c.shape[:2]
  1811. for jout, (x, y) in enumerate(zip(xnew, ynew)):
  1812. if not ((xs[0][0] <= x <= xs[0][-1]) and
  1813. (xs[1][0] <= y <= xs[1][-1])):
  1814. out[jout] = np.nan
  1815. continue
  1816. j1 = np.searchsorted(xs[0], x) - 1
  1817. j2 = np.searchsorted(xs[1], y) - 1
  1818. s1 = x - xs[0][j1]
  1819. s2 = y - xs[1][j2]
  1820. val = 0
  1821. for k1 in range(c.shape[0]):
  1822. for k2 in range(c.shape[1]):
  1823. val += (c[nx-k1-1,ny-k2-1,j1,j2]
  1824. * _dpow(s1, k1, nu[0])
  1825. * _dpow(s2, k2, nu[1]))
  1826. out[jout] = val
  1827. return out
  1828. def _ppoly3d_eval(c, xs, xnew, ynew, znew, nu=None):
  1829. """
  1830. Straightforward evaluation of 3D piecewise polynomial
  1831. """
  1832. if nu is None:
  1833. nu = (0, 0, 0)
  1834. out = np.empty((len(xnew),), dtype=c.dtype)
  1835. nx, ny, nz = c.shape[:3]
  1836. for jout, (x, y, z) in enumerate(zip(xnew, ynew, znew)):
  1837. if not ((xs[0][0] <= x <= xs[0][-1]) and
  1838. (xs[1][0] <= y <= xs[1][-1]) and
  1839. (xs[2][0] <= z <= xs[2][-1])):
  1840. out[jout] = np.nan
  1841. continue
  1842. j1 = np.searchsorted(xs[0], x) - 1
  1843. j2 = np.searchsorted(xs[1], y) - 1
  1844. j3 = np.searchsorted(xs[2], z) - 1
  1845. s1 = x - xs[0][j1]
  1846. s2 = y - xs[1][j2]
  1847. s3 = z - xs[2][j3]
  1848. val = 0
  1849. for k1 in range(c.shape[0]):
  1850. for k2 in range(c.shape[1]):
  1851. for k3 in range(c.shape[2]):
  1852. val += (c[nx-k1-1,ny-k2-1,nz-k3-1,j1,j2,j3]
  1853. * _dpow(s1, k1, nu[0])
  1854. * _dpow(s2, k2, nu[1])
  1855. * _dpow(s3, k3, nu[2]))
  1856. out[jout] = val
  1857. return out
  1858. def _ppoly4d_eval(c, xs, xnew, ynew, znew, unew, nu=None):
  1859. """
  1860. Straightforward evaluation of 4D piecewise polynomial
  1861. """
  1862. if nu is None:
  1863. nu = (0, 0, 0, 0)
  1864. out = np.empty((len(xnew),), dtype=c.dtype)
  1865. mx, my, mz, mu = c.shape[:4]
  1866. for jout, (x, y, z, u) in enumerate(zip(xnew, ynew, znew, unew)):
  1867. if not ((xs[0][0] <= x <= xs[0][-1]) and
  1868. (xs[1][0] <= y <= xs[1][-1]) and
  1869. (xs[2][0] <= z <= xs[2][-1]) and
  1870. (xs[3][0] <= u <= xs[3][-1])):
  1871. out[jout] = np.nan
  1872. continue
  1873. j1 = np.searchsorted(xs[0], x) - 1
  1874. j2 = np.searchsorted(xs[1], y) - 1
  1875. j3 = np.searchsorted(xs[2], z) - 1
  1876. j4 = np.searchsorted(xs[3], u) - 1
  1877. s1 = x - xs[0][j1]
  1878. s2 = y - xs[1][j2]
  1879. s3 = z - xs[2][j3]
  1880. s4 = u - xs[3][j4]
  1881. val = 0
  1882. for k1 in range(c.shape[0]):
  1883. for k2 in range(c.shape[1]):
  1884. for k3 in range(c.shape[2]):
  1885. for k4 in range(c.shape[3]):
  1886. val += (c[mx-k1-1,my-k2-1,mz-k3-1,mu-k4-1,j1,j2,j3,j4]
  1887. * _dpow(s1, k1, nu[0])
  1888. * _dpow(s2, k2, nu[1])
  1889. * _dpow(s3, k3, nu[2])
  1890. * _dpow(s4, k4, nu[3]))
  1891. out[jout] = val
  1892. return out
  1893. class TestRegularGridInterpolator(object):
  1894. def _get_sample_4d(self):
  1895. # create a 4d grid of 3 points in each dimension
  1896. points = [(0., .5, 1.)] * 4
  1897. values = np.asarray([0., .5, 1.])
  1898. values0 = values[:, np.newaxis, np.newaxis, np.newaxis]
  1899. values1 = values[np.newaxis, :, np.newaxis, np.newaxis]
  1900. values2 = values[np.newaxis, np.newaxis, :, np.newaxis]
  1901. values3 = values[np.newaxis, np.newaxis, np.newaxis, :]
  1902. values = (values0 + values1 * 10 + values2 * 100 + values3 * 1000)
  1903. return points, values
  1904. def _get_sample_4d_2(self):
  1905. # create another 4d grid of 3 points in each dimension
  1906. points = [(0., .5, 1.)] * 2 + [(0., 5., 10.)] * 2
  1907. values = np.asarray([0., .5, 1.])
  1908. values0 = values[:, np.newaxis, np.newaxis, np.newaxis]
  1909. values1 = values[np.newaxis, :, np.newaxis, np.newaxis]
  1910. values2 = values[np.newaxis, np.newaxis, :, np.newaxis]
  1911. values3 = values[np.newaxis, np.newaxis, np.newaxis, :]
  1912. values = (values0 + values1 * 10 + values2 * 100 + values3 * 1000)
  1913. return points, values
  1914. def test_list_input(self):
  1915. points, values = self._get_sample_4d()
  1916. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  1917. [0.5, 0.5, .5, .5]])
  1918. for method in ['linear', 'nearest']:
  1919. interp = RegularGridInterpolator(points,
  1920. values.tolist(),
  1921. method=method)
  1922. v1 = interp(sample.tolist())
  1923. interp = RegularGridInterpolator(points,
  1924. values,
  1925. method=method)
  1926. v2 = interp(sample)
  1927. assert_allclose(v1, v2)
  1928. def test_complex(self):
  1929. points, values = self._get_sample_4d()
  1930. values = values - 2j*values
  1931. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  1932. [0.5, 0.5, .5, .5]])
  1933. for method in ['linear', 'nearest']:
  1934. interp = RegularGridInterpolator(points, values,
  1935. method=method)
  1936. rinterp = RegularGridInterpolator(points, values.real,
  1937. method=method)
  1938. iinterp = RegularGridInterpolator(points, values.imag,
  1939. method=method)
  1940. v1 = interp(sample)
  1941. v2 = rinterp(sample) + 1j*iinterp(sample)
  1942. assert_allclose(v1, v2)
  1943. def test_linear_xi1d(self):
  1944. points, values = self._get_sample_4d_2()
  1945. interp = RegularGridInterpolator(points, values)
  1946. sample = np.asarray([0.1, 0.1, 10., 9.])
  1947. wanted = 1001.1
  1948. assert_array_almost_equal(interp(sample), wanted)
  1949. def test_linear_xi3d(self):
  1950. points, values = self._get_sample_4d()
  1951. interp = RegularGridInterpolator(points, values)
  1952. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  1953. [0.5, 0.5, .5, .5]])
  1954. wanted = np.asarray([1001.1, 846.2, 555.5])
  1955. assert_array_almost_equal(interp(sample), wanted)
  1956. def test_nearest(self):
  1957. points, values = self._get_sample_4d()
  1958. interp = RegularGridInterpolator(points, values, method="nearest")
  1959. sample = np.asarray([0.1, 0.1, .9, .9])
  1960. wanted = 1100.
  1961. assert_array_almost_equal(interp(sample), wanted)
  1962. sample = np.asarray([0.1, 0.1, 0.1, 0.1])
  1963. wanted = 0.
  1964. assert_array_almost_equal(interp(sample), wanted)
  1965. sample = np.asarray([0., 0., 0., 0.])
  1966. wanted = 0.
  1967. assert_array_almost_equal(interp(sample), wanted)
  1968. sample = np.asarray([1., 1., 1., 1.])
  1969. wanted = 1111.
  1970. assert_array_almost_equal(interp(sample), wanted)
  1971. sample = np.asarray([0.1, 0.4, 0.6, 0.9])
  1972. wanted = 1055.
  1973. assert_array_almost_equal(interp(sample), wanted)
  1974. def test_linear_edges(self):
  1975. points, values = self._get_sample_4d()
  1976. interp = RegularGridInterpolator(points, values)
  1977. sample = np.asarray([[0., 0., 0., 0.], [1., 1., 1., 1.]])
  1978. wanted = np.asarray([0., 1111.])
  1979. assert_array_almost_equal(interp(sample), wanted)
  1980. def test_valid_create(self):
  1981. # create a 2d grid of 3 points in each dimension
  1982. points = [(0., .5, 1.), (0., 1., .5)]
  1983. values = np.asarray([0., .5, 1.])
  1984. values0 = values[:, np.newaxis]
  1985. values1 = values[np.newaxis, :]
  1986. values = (values0 + values1 * 10)
  1987. assert_raises(ValueError, RegularGridInterpolator, points, values)
  1988. points = [((0., .5, 1.), ), (0., .5, 1.)]
  1989. assert_raises(ValueError, RegularGridInterpolator, points, values)
  1990. points = [(0., .5, .75, 1.), (0., .5, 1.)]
  1991. assert_raises(ValueError, RegularGridInterpolator, points, values)
  1992. points = [(0., .5, 1.), (0., .5, 1.), (0., .5, 1.)]
  1993. assert_raises(ValueError, RegularGridInterpolator, points, values)
  1994. points = [(0., .5, 1.), (0., .5, 1.)]
  1995. assert_raises(ValueError, RegularGridInterpolator, points, values,
  1996. method="undefmethod")
  1997. def test_valid_call(self):
  1998. points, values = self._get_sample_4d()
  1999. interp = RegularGridInterpolator(points, values)
  2000. sample = np.asarray([[0., 0., 0., 0.], [1., 1., 1., 1.]])
  2001. assert_raises(ValueError, interp, sample, "undefmethod")
  2002. sample = np.asarray([[0., 0., 0.], [1., 1., 1.]])
  2003. assert_raises(ValueError, interp, sample)
  2004. sample = np.asarray([[0., 0., 0., 0.], [1., 1., 1., 1.1]])
  2005. assert_raises(ValueError, interp, sample)
  2006. def test_out_of_bounds_extrap(self):
  2007. points, values = self._get_sample_4d()
  2008. interp = RegularGridInterpolator(points, values, bounds_error=False,
  2009. fill_value=None)
  2010. sample = np.asarray([[-.1, -.1, -.1, -.1], [1.1, 1.1, 1.1, 1.1],
  2011. [21, 2.1, -1.1, -11], [2.1, 2.1, -1.1, -1.1]])
  2012. wanted = np.asarray([0., 1111., 11., 11.])
  2013. assert_array_almost_equal(interp(sample, method="nearest"), wanted)
  2014. wanted = np.asarray([-111.1, 1222.1, -11068., -1186.9])
  2015. assert_array_almost_equal(interp(sample, method="linear"), wanted)
  2016. def test_out_of_bounds_extrap2(self):
  2017. points, values = self._get_sample_4d_2()
  2018. interp = RegularGridInterpolator(points, values, bounds_error=False,
  2019. fill_value=None)
  2020. sample = np.asarray([[-.1, -.1, -.1, -.1], [1.1, 1.1, 1.1, 1.1],
  2021. [21, 2.1, -1.1, -11], [2.1, 2.1, -1.1, -1.1]])
  2022. wanted = np.asarray([0., 11., 11., 11.])
  2023. assert_array_almost_equal(interp(sample, method="nearest"), wanted)
  2024. wanted = np.asarray([-12.1, 133.1, -1069., -97.9])
  2025. assert_array_almost_equal(interp(sample, method="linear"), wanted)
  2026. def test_out_of_bounds_fill(self):
  2027. points, values = self._get_sample_4d()
  2028. interp = RegularGridInterpolator(points, values, bounds_error=False,
  2029. fill_value=np.nan)
  2030. sample = np.asarray([[-.1, -.1, -.1, -.1], [1.1, 1.1, 1.1, 1.1],
  2031. [2.1, 2.1, -1.1, -1.1]])
  2032. wanted = np.asarray([np.nan, np.nan, np.nan])
  2033. assert_array_almost_equal(interp(sample, method="nearest"), wanted)
  2034. assert_array_almost_equal(interp(sample, method="linear"), wanted)
  2035. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  2036. [0.5, 0.5, .5, .5]])
  2037. wanted = np.asarray([1001.1, 846.2, 555.5])
  2038. assert_array_almost_equal(interp(sample), wanted)
  2039. def test_nearest_compare_qhull(self):
  2040. points, values = self._get_sample_4d()
  2041. interp = RegularGridInterpolator(points, values, method="nearest")
  2042. points_qhull = itertools.product(*points)
  2043. points_qhull = [p for p in points_qhull]
  2044. points_qhull = np.asarray(points_qhull)
  2045. values_qhull = values.reshape(-1)
  2046. interp_qhull = NearestNDInterpolator(points_qhull, values_qhull)
  2047. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  2048. [0.5, 0.5, .5, .5]])
  2049. assert_array_almost_equal(interp(sample), interp_qhull(sample))
  2050. def test_linear_compare_qhull(self):
  2051. points, values = self._get_sample_4d()
  2052. interp = RegularGridInterpolator(points, values)
  2053. points_qhull = itertools.product(*points)
  2054. points_qhull = [p for p in points_qhull]
  2055. points_qhull = np.asarray(points_qhull)
  2056. values_qhull = values.reshape(-1)
  2057. interp_qhull = LinearNDInterpolator(points_qhull, values_qhull)
  2058. sample = np.asarray([[0.1, 0.1, 1., .9], [0.2, 0.1, .45, .8],
  2059. [0.5, 0.5, .5, .5]])
  2060. assert_array_almost_equal(interp(sample), interp_qhull(sample))
  2061. def test_duck_typed_values(self):
  2062. x = np.linspace(0, 2, 5)
  2063. y = np.linspace(0, 1, 7)
  2064. values = MyValue((5, 7))
  2065. for method in ('nearest', 'linear'):
  2066. interp = RegularGridInterpolator((x, y), values,
  2067. method=method)
  2068. v1 = interp([0.4, 0.7])
  2069. interp = RegularGridInterpolator((x, y), values._v,
  2070. method=method)
  2071. v2 = interp([0.4, 0.7])
  2072. assert_allclose(v1, v2)
  2073. def test_invalid_fill_value(self):
  2074. np.random.seed(1234)
  2075. x = np.linspace(0, 2, 5)
  2076. y = np.linspace(0, 1, 7)
  2077. values = np.random.rand(5, 7)
  2078. # integers can be cast to floats
  2079. RegularGridInterpolator((x, y), values, fill_value=1)
  2080. # complex values cannot
  2081. assert_raises(ValueError, RegularGridInterpolator,
  2082. (x, y), values, fill_value=1+2j)
  2083. def test_fillvalue_type(self):
  2084. # from #3703; test that interpolator object construction succeeds
  2085. values = np.ones((10, 20, 30), dtype='>f4')
  2086. points = [np.arange(n) for n in values.shape]
  2087. xi = [(1, 1, 1)]
  2088. interpolator = RegularGridInterpolator(points, values)
  2089. interpolator = RegularGridInterpolator(points, values, fill_value=0.)
  2090. class MyValue(object):
  2091. """
  2092. Minimal indexable object
  2093. """
  2094. def __init__(self, shape):
  2095. self.ndim = 2
  2096. self.shape = shape
  2097. self._v = np.arange(np.prod(shape)).reshape(shape)
  2098. def __getitem__(self, idx):
  2099. return self._v[idx]
  2100. def __array_interface__(self):
  2101. return None
  2102. def __array__(self):
  2103. raise RuntimeError("No array representation")
  2104. class TestInterpN(object):
  2105. def _sample_2d_data(self):
  2106. x = np.arange(1, 6)
  2107. x = np.array([.5, 2., 3., 4., 5.5])
  2108. y = np.arange(1, 6)
  2109. y = np.array([.5, 2., 3., 4., 5.5])
  2110. z = np.array([[1, 2, 1, 2, 1], [1, 2, 1, 2, 1], [1, 2, 3, 2, 1],
  2111. [1, 2, 2, 2, 1], [1, 2, 1, 2, 1]])
  2112. return x, y, z
  2113. def test_spline_2d(self):
  2114. x, y, z = self._sample_2d_data()
  2115. lut = RectBivariateSpline(x, y, z)
  2116. xi = np.array([[1, 2.3, 5.3, 0.5, 3.3, 1.2, 3],
  2117. [1, 3.3, 1.2, 4.0, 5.0, 1.0, 3]]).T
  2118. assert_array_almost_equal(interpn((x, y), z, xi, method="splinef2d"),
  2119. lut.ev(xi[:, 0], xi[:, 1]))
  2120. def test_list_input(self):
  2121. x, y, z = self._sample_2d_data()
  2122. xi = np.array([[1, 2.3, 5.3, 0.5, 3.3, 1.2, 3],
  2123. [1, 3.3, 1.2, 4.0, 5.0, 1.0, 3]]).T
  2124. for method in ['nearest', 'linear', 'splinef2d']:
  2125. v1 = interpn((x, y), z, xi, method=method)
  2126. v2 = interpn((x.tolist(), y.tolist()), z.tolist(),
  2127. xi.tolist(), method=method)
  2128. assert_allclose(v1, v2, err_msg=method)
  2129. def test_spline_2d_outofbounds(self):
  2130. x = np.array([.5, 2., 3., 4., 5.5])
  2131. y = np.array([.5, 2., 3., 4., 5.5])
  2132. z = np.array([[1, 2, 1, 2, 1], [1, 2, 1, 2, 1], [1, 2, 3, 2, 1],
  2133. [1, 2, 2, 2, 1], [1, 2, 1, 2, 1]])
  2134. lut = RectBivariateSpline(x, y, z)
  2135. xi = np.array([[1, 2.3, 6.3, 0.5, 3.3, 1.2, 3],
  2136. [1, 3.3, 1.2, -4.0, 5.0, 1.0, 3]]).T
  2137. actual = interpn((x, y), z, xi, method="splinef2d",
  2138. bounds_error=False, fill_value=999.99)
  2139. expected = lut.ev(xi[:, 0], xi[:, 1])
  2140. expected[2:4] = 999.99
  2141. assert_array_almost_equal(actual, expected)
  2142. # no extrapolation for splinef2d
  2143. assert_raises(ValueError, interpn, (x, y), z, xi, method="splinef2d",
  2144. bounds_error=False, fill_value=None)
  2145. def _sample_4d_data(self):
  2146. points = [(0., .5, 1.)] * 2 + [(0., 5., 10.)] * 2
  2147. values = np.asarray([0., .5, 1.])
  2148. values0 = values[:, np.newaxis, np.newaxis, np.newaxis]
  2149. values1 = values[np.newaxis, :, np.newaxis, np.newaxis]
  2150. values2 = values[np.newaxis, np.newaxis, :, np.newaxis]
  2151. values3 = values[np.newaxis, np.newaxis, np.newaxis, :]
  2152. values = (values0 + values1 * 10 + values2 * 100 + values3 * 1000)
  2153. return points, values
  2154. def test_linear_4d(self):
  2155. # create a 4d grid of 3 points in each dimension
  2156. points, values = self._sample_4d_data()
  2157. interp_rg = RegularGridInterpolator(points, values)
  2158. sample = np.asarray([[0.1, 0.1, 10., 9.]])
  2159. wanted = interpn(points, values, sample, method="linear")
  2160. assert_array_almost_equal(interp_rg(sample), wanted)
  2161. def test_4d_linear_outofbounds(self):
  2162. # create a 4d grid of 3 points in each dimension
  2163. points, values = self._sample_4d_data()
  2164. sample = np.asarray([[0.1, -0.1, 10.1, 9.]])
  2165. wanted = 999.99
  2166. actual = interpn(points, values, sample, method="linear",
  2167. bounds_error=False, fill_value=999.99)
  2168. assert_array_almost_equal(actual, wanted)
  2169. def test_nearest_4d(self):
  2170. # create a 4d grid of 3 points in each dimension
  2171. points, values = self._sample_4d_data()
  2172. interp_rg = RegularGridInterpolator(points, values, method="nearest")
  2173. sample = np.asarray([[0.1, 0.1, 10., 9.]])
  2174. wanted = interpn(points, values, sample, method="nearest")
  2175. assert_array_almost_equal(interp_rg(sample), wanted)
  2176. def test_4d_nearest_outofbounds(self):
  2177. # create a 4d grid of 3 points in each dimension
  2178. points, values = self._sample_4d_data()
  2179. sample = np.asarray([[0.1, -0.1, 10.1, 9.]])
  2180. wanted = 999.99
  2181. actual = interpn(points, values, sample, method="nearest",
  2182. bounds_error=False, fill_value=999.99)
  2183. assert_array_almost_equal(actual, wanted)
  2184. def test_xi_1d(self):
  2185. # verify that 1D xi works as expected
  2186. points, values = self._sample_4d_data()
  2187. sample = np.asarray([0.1, 0.1, 10., 9.])
  2188. v1 = interpn(points, values, sample, bounds_error=False)
  2189. v2 = interpn(points, values, sample[None,:], bounds_error=False)
  2190. assert_allclose(v1, v2)
  2191. def test_xi_nd(self):
  2192. # verify that higher-d xi works as expected
  2193. points, values = self._sample_4d_data()
  2194. np.random.seed(1234)
  2195. sample = np.random.rand(2, 3, 4)
  2196. v1 = interpn(points, values, sample, method='nearest',
  2197. bounds_error=False)
  2198. assert_equal(v1.shape, (2, 3))
  2199. v2 = interpn(points, values, sample.reshape(-1, 4),
  2200. method='nearest', bounds_error=False)
  2201. assert_allclose(v1, v2.reshape(v1.shape))
  2202. def test_xi_broadcast(self):
  2203. # verify that the interpolators broadcast xi
  2204. x, y, values = self._sample_2d_data()
  2205. points = (x, y)
  2206. xi = np.linspace(0, 1, 2)
  2207. yi = np.linspace(0, 3, 3)
  2208. for method in ['nearest', 'linear', 'splinef2d']:
  2209. sample = (xi[:,None], yi[None,:])
  2210. v1 = interpn(points, values, sample, method=method,
  2211. bounds_error=False)
  2212. assert_equal(v1.shape, (2, 3))
  2213. xx, yy = np.meshgrid(xi, yi)
  2214. sample = np.c_[xx.T.ravel(), yy.T.ravel()]
  2215. v2 = interpn(points, values, sample,
  2216. method=method, bounds_error=False)
  2217. assert_allclose(v1, v2.reshape(v1.shape))
  2218. def test_nonscalar_values(self):
  2219. # Verify that non-scalar valued values also works
  2220. points, values = self._sample_4d_data()
  2221. np.random.seed(1234)
  2222. values = np.random.rand(3, 3, 3, 3, 6)
  2223. sample = np.random.rand(7, 11, 4)
  2224. for method in ['nearest', 'linear']:
  2225. v = interpn(points, values, sample, method=method,
  2226. bounds_error=False)
  2227. assert_equal(v.shape, (7, 11, 6), err_msg=method)
  2228. vs = [interpn(points, values[...,j], sample, method=method,
  2229. bounds_error=False)
  2230. for j in range(6)]
  2231. v2 = np.array(vs).transpose(1, 2, 0)
  2232. assert_allclose(v, v2, err_msg=method)
  2233. # Vector-valued splines supported with fitpack
  2234. assert_raises(ValueError, interpn, points, values, sample,
  2235. method='splinef2d')
  2236. def test_complex(self):
  2237. x, y, values = self._sample_2d_data()
  2238. points = (x, y)
  2239. values = values - 2j*values
  2240. sample = np.array([[1, 2.3, 5.3, 0.5, 3.3, 1.2, 3],
  2241. [1, 3.3, 1.2, 4.0, 5.0, 1.0, 3]]).T
  2242. for method in ['linear', 'nearest']:
  2243. v1 = interpn(points, values, sample, method=method)
  2244. v2r = interpn(points, values.real, sample, method=method)
  2245. v2i = interpn(points, values.imag, sample, method=method)
  2246. v2 = v2r + 1j*v2i
  2247. assert_allclose(v1, v2)
  2248. # Complex-valued data not supported by spline2fd
  2249. _assert_warns(np.ComplexWarning, interpn, points, values,
  2250. sample, method='splinef2d')
  2251. def test_duck_typed_values(self):
  2252. x = np.linspace(0, 2, 5)
  2253. y = np.linspace(0, 1, 7)
  2254. values = MyValue((5, 7))
  2255. for method in ('nearest', 'linear'):
  2256. v1 = interpn((x, y), values, [0.4, 0.7], method=method)
  2257. v2 = interpn((x, y), values._v, [0.4, 0.7], method=method)
  2258. assert_allclose(v1, v2)
  2259. def test_matrix_input(self):
  2260. x = np.linspace(0, 2, 5)
  2261. y = np.linspace(0, 1, 7)
  2262. values = np.matrix(np.random.rand(5, 7))
  2263. sample = np.random.rand(3, 7, 2)
  2264. for method in ('nearest', 'linear', 'splinef2d'):
  2265. v1 = interpn((x, y), values, sample, method=method)
  2266. v2 = interpn((x, y), np.asarray(values), sample, method=method)
  2267. assert_allclose(v1, np.asmatrix(v2))