test_classes.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642
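"""Test the polynomial convenience classes in numpy.polynomial.

Covers inter-conversion between the classes (the ``convert`` and ``cast``
methods) as well as construction, comparison, arithmetic, calculus, fitting,
interpolation and the LaTeX representation of every polynomial class.
"""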
  4. from __future__ import division, absolute_import, print_function
  5. import operator as op
  6. from numbers import Number
  7. import pytest
  8. import numpy as np
  9. from numpy.polynomial import (
  10. Polynomial, Legendre, Chebyshev, Laguerre, Hermite, HermiteE)
  11. from numpy.testing import (
  12. assert_almost_equal, assert_raises, assert_equal, assert_,
  13. )
  14. from numpy.compat import long
  15. #
  16. # fixtures
  17. #
  18. classes = (
  19. Polynomial, Legendre, Chebyshev, Laguerre,
  20. Hermite, HermiteE
  21. )
  22. classids = tuple(cls.__name__ for cls in classes)
  23. @pytest.fixture(params=classes, ids=classids)
  24. def Poly(request):
  25. return request.param
  26. #
  27. # helper functions
  28. #
  29. random = np.random.random
  30. def assert_poly_almost_equal(p1, p2, msg=""):
  31. try:
  32. assert_(np.all(p1.domain == p2.domain))
  33. assert_(np.all(p1.window == p2.window))
  34. assert_almost_equal(p1.coef, p2.coef)
  35. except AssertionError:
  36. msg = "Result: %s\nTarget: %s", (p1, p2)
  37. raise AssertionError(msg)
  38. #
  39. # Test conversion methods that depend on combinations of two classes.
  40. #
  41. Poly1 = Poly
  42. Poly2 = Poly
  43. def test_conversion(Poly1, Poly2):
  44. x = np.linspace(0, 1, 10)
  45. coef = random((3,))
  46. d1 = Poly1.domain + random((2,))*.25
  47. w1 = Poly1.window + random((2,))*.25
  48. p1 = Poly1(coef, domain=d1, window=w1)
  49. d2 = Poly2.domain + random((2,))*.25
  50. w2 = Poly2.window + random((2,))*.25
  51. p2 = p1.convert(kind=Poly2, domain=d2, window=w2)
  52. assert_almost_equal(p2.domain, d2)
  53. assert_almost_equal(p2.window, w2)
  54. assert_almost_equal(p2(x), p1(x))
  55. def test_cast(Poly1, Poly2):
  56. x = np.linspace(0, 1, 10)
  57. coef = random((3,))
  58. d1 = Poly1.domain + random((2,))*.25
  59. w1 = Poly1.window + random((2,))*.25
  60. p1 = Poly1(coef, domain=d1, window=w1)
  61. d2 = Poly2.domain + random((2,))*.25
  62. w2 = Poly2.window + random((2,))*.25
  63. p2 = Poly2.cast(p1, domain=d2, window=w2)
  64. assert_almost_equal(p2.domain, d2)
  65. assert_almost_equal(p2.window, w2)
  66. assert_almost_equal(p2(x), p1(x))
  67. #
  68. # test methods that depend on one class
  69. #
  70. def test_identity(Poly):
  71. d = Poly.domain + random((2,))*.25
  72. w = Poly.window + random((2,))*.25
  73. x = np.linspace(d[0], d[1], 11)
  74. p = Poly.identity(domain=d, window=w)
  75. assert_equal(p.domain, d)
  76. assert_equal(p.window, w)
  77. assert_almost_equal(p(x), x)
  78. def test_basis(Poly):
  79. d = Poly.domain + random((2,))*.25
  80. w = Poly.window + random((2,))*.25
  81. p = Poly.basis(5, domain=d, window=w)
  82. assert_equal(p.domain, d)
  83. assert_equal(p.window, w)
  84. assert_equal(p.coef, [0]*5 + [1])
  85. def test_fromroots(Poly):
  86. # check that requested roots are zeros of a polynomial
  87. # of correct degree, domain, and window.
  88. d = Poly.domain + random((2,))*.25
  89. w = Poly.window + random((2,))*.25
  90. r = random((5,))
  91. p1 = Poly.fromroots(r, domain=d, window=w)
  92. assert_equal(p1.degree(), len(r))
  93. assert_equal(p1.domain, d)
  94. assert_equal(p1.window, w)
  95. assert_almost_equal(p1(r), 0)
  96. # check that polynomial is monic
  97. pdom = Polynomial.domain
  98. pwin = Polynomial.window
  99. p2 = Polynomial.cast(p1, domain=pdom, window=pwin)
  100. assert_almost_equal(p2.coef[-1], 1)
  101. def test_fit(Poly):
  102. def f(x):
  103. return x*(x - 1)*(x - 2)
  104. x = np.linspace(0, 3)
  105. y = f(x)
  106. # check default value of domain and window
  107. p = Poly.fit(x, y, 3)
  108. assert_almost_equal(p.domain, [0, 3])
  109. assert_almost_equal(p(x), y)
  110. assert_equal(p.degree(), 3)
  111. # check with given domains and window
  112. d = Poly.domain + random((2,))*.25
  113. w = Poly.window + random((2,))*.25
  114. p = Poly.fit(x, y, 3, domain=d, window=w)
  115. assert_almost_equal(p(x), y)
  116. assert_almost_equal(p.domain, d)
  117. assert_almost_equal(p.window, w)
  118. p = Poly.fit(x, y, [0, 1, 2, 3], domain=d, window=w)
  119. assert_almost_equal(p(x), y)
  120. assert_almost_equal(p.domain, d)
  121. assert_almost_equal(p.window, w)
  122. # check with class domain default
  123. p = Poly.fit(x, y, 3, [])
  124. assert_equal(p.domain, Poly.domain)
  125. assert_equal(p.window, Poly.window)
  126. p = Poly.fit(x, y, [0, 1, 2, 3], [])
  127. assert_equal(p.domain, Poly.domain)
  128. assert_equal(p.window, Poly.window)
  129. # check that fit accepts weights.
  130. w = np.zeros_like(x)
  131. z = y + random(y.shape)*.25
  132. w[::2] = 1
  133. p1 = Poly.fit(x[::2], z[::2], 3)
  134. p2 = Poly.fit(x, z, 3, w=w)
  135. p3 = Poly.fit(x, z, [0, 1, 2, 3], w=w)
  136. assert_almost_equal(p1(x), p2(x))
  137. assert_almost_equal(p2(x), p3(x))
  138. def test_equal(Poly):
  139. p1 = Poly([1, 2, 3], domain=[0, 1], window=[2, 3])
  140. p2 = Poly([1, 1, 1], domain=[0, 1], window=[2, 3])
  141. p3 = Poly([1, 2, 3], domain=[1, 2], window=[2, 3])
  142. p4 = Poly([1, 2, 3], domain=[0, 1], window=[1, 2])
  143. assert_(p1 == p1)
  144. assert_(not p1 == p2)
  145. assert_(not p1 == p3)
  146. assert_(not p1 == p4)
  147. def test_not_equal(Poly):
  148. p1 = Poly([1, 2, 3], domain=[0, 1], window=[2, 3])
  149. p2 = Poly([1, 1, 1], domain=[0, 1], window=[2, 3])
  150. p3 = Poly([1, 2, 3], domain=[1, 2], window=[2, 3])
  151. p4 = Poly([1, 2, 3], domain=[0, 1], window=[1, 2])
  152. assert_(not p1 != p1)
  153. assert_(p1 != p2)
  154. assert_(p1 != p3)
  155. assert_(p1 != p4)
  156. def test_add(Poly):
  157. # This checks commutation, not numerical correctness
  158. c1 = list(random((4,)) + .5)
  159. c2 = list(random((3,)) + .5)
  160. p1 = Poly(c1)
  161. p2 = Poly(c2)
  162. p3 = p1 + p2
  163. assert_poly_almost_equal(p2 + p1, p3)
  164. assert_poly_almost_equal(p1 + c2, p3)
  165. assert_poly_almost_equal(c2 + p1, p3)
  166. assert_poly_almost_equal(p1 + tuple(c2), p3)
  167. assert_poly_almost_equal(tuple(c2) + p1, p3)
  168. assert_poly_almost_equal(p1 + np.array(c2), p3)
  169. assert_poly_almost_equal(np.array(c2) + p1, p3)
  170. assert_raises(TypeError, op.add, p1, Poly([0], domain=Poly.domain + 1))
  171. assert_raises(TypeError, op.add, p1, Poly([0], window=Poly.window + 1))
  172. if Poly is Polynomial:
  173. assert_raises(TypeError, op.add, p1, Chebyshev([0]))
  174. else:
  175. assert_raises(TypeError, op.add, p1, Polynomial([0]))
  176. def test_sub(Poly):
  177. # This checks commutation, not numerical correctness
  178. c1 = list(random((4,)) + .5)
  179. c2 = list(random((3,)) + .5)
  180. p1 = Poly(c1)
  181. p2 = Poly(c2)
  182. p3 = p1 - p2
  183. assert_poly_almost_equal(p2 - p1, -p3)
  184. assert_poly_almost_equal(p1 - c2, p3)
  185. assert_poly_almost_equal(c2 - p1, -p3)
  186. assert_poly_almost_equal(p1 - tuple(c2), p3)
  187. assert_poly_almost_equal(tuple(c2) - p1, -p3)
  188. assert_poly_almost_equal(p1 - np.array(c2), p3)
  189. assert_poly_almost_equal(np.array(c2) - p1, -p3)
  190. assert_raises(TypeError, op.sub, p1, Poly([0], domain=Poly.domain + 1))
  191. assert_raises(TypeError, op.sub, p1, Poly([0], window=Poly.window + 1))
  192. if Poly is Polynomial:
  193. assert_raises(TypeError, op.sub, p1, Chebyshev([0]))
  194. else:
  195. assert_raises(TypeError, op.sub, p1, Polynomial([0]))
  196. def test_mul(Poly):
  197. c1 = list(random((4,)) + .5)
  198. c2 = list(random((3,)) + .5)
  199. p1 = Poly(c1)
  200. p2 = Poly(c2)
  201. p3 = p1 * p2
  202. assert_poly_almost_equal(p2 * p1, p3)
  203. assert_poly_almost_equal(p1 * c2, p3)
  204. assert_poly_almost_equal(c2 * p1, p3)
  205. assert_poly_almost_equal(p1 * tuple(c2), p3)
  206. assert_poly_almost_equal(tuple(c2) * p1, p3)
  207. assert_poly_almost_equal(p1 * np.array(c2), p3)
  208. assert_poly_almost_equal(np.array(c2) * p1, p3)
  209. assert_poly_almost_equal(p1 * 2, p1 * Poly([2]))
  210. assert_poly_almost_equal(2 * p1, p1 * Poly([2]))
  211. assert_raises(TypeError, op.mul, p1, Poly([0], domain=Poly.domain + 1))
  212. assert_raises(TypeError, op.mul, p1, Poly([0], window=Poly.window + 1))
  213. if Poly is Polynomial:
  214. assert_raises(TypeError, op.mul, p1, Chebyshev([0]))
  215. else:
  216. assert_raises(TypeError, op.mul, p1, Polynomial([0]))
  217. def test_floordiv(Poly):
  218. c1 = list(random((4,)) + .5)
  219. c2 = list(random((3,)) + .5)
  220. c3 = list(random((2,)) + .5)
  221. p1 = Poly(c1)
  222. p2 = Poly(c2)
  223. p3 = Poly(c3)
  224. p4 = p1 * p2 + p3
  225. c4 = list(p4.coef)
  226. assert_poly_almost_equal(p4 // p2, p1)
  227. assert_poly_almost_equal(p4 // c2, p1)
  228. assert_poly_almost_equal(c4 // p2, p1)
  229. assert_poly_almost_equal(p4 // tuple(c2), p1)
  230. assert_poly_almost_equal(tuple(c4) // p2, p1)
  231. assert_poly_almost_equal(p4 // np.array(c2), p1)
  232. assert_poly_almost_equal(np.array(c4) // p2, p1)
  233. assert_poly_almost_equal(2 // p2, Poly([0]))
  234. assert_poly_almost_equal(p2 // 2, 0.5*p2)
  235. assert_raises(
  236. TypeError, op.floordiv, p1, Poly([0], domain=Poly.domain + 1))
  237. assert_raises(
  238. TypeError, op.floordiv, p1, Poly([0], window=Poly.window + 1))
  239. if Poly is Polynomial:
  240. assert_raises(TypeError, op.floordiv, p1, Chebyshev([0]))
  241. else:
  242. assert_raises(TypeError, op.floordiv, p1, Polynomial([0]))
  243. def test_truediv(Poly):
  244. # true division is valid only if the denominator is a Number and
  245. # not a python bool.
  246. p1 = Poly([1,2,3])
  247. p2 = p1 * 5
  248. for stype in np.ScalarType:
  249. if not issubclass(stype, Number) or issubclass(stype, bool):
  250. continue
  251. s = stype(5)
  252. assert_poly_almost_equal(op.truediv(p2, s), p1)
  253. assert_raises(TypeError, op.truediv, s, p2)
  254. for stype in (int, long, float):
  255. s = stype(5)
  256. assert_poly_almost_equal(op.truediv(p2, s), p1)
  257. assert_raises(TypeError, op.truediv, s, p2)
  258. for stype in [complex]:
  259. s = stype(5, 0)
  260. assert_poly_almost_equal(op.truediv(p2, s), p1)
  261. assert_raises(TypeError, op.truediv, s, p2)
  262. for s in [tuple(), list(), dict(), bool(), np.array([1])]:
  263. assert_raises(TypeError, op.truediv, p2, s)
  264. assert_raises(TypeError, op.truediv, s, p2)
  265. for ptype in classes:
  266. assert_raises(TypeError, op.truediv, p2, ptype(1))
  267. def test_mod(Poly):
  268. # This checks commutation, not numerical correctness
  269. c1 = list(random((4,)) + .5)
  270. c2 = list(random((3,)) + .5)
  271. c3 = list(random((2,)) + .5)
  272. p1 = Poly(c1)
  273. p2 = Poly(c2)
  274. p3 = Poly(c3)
  275. p4 = p1 * p2 + p3
  276. c4 = list(p4.coef)
  277. assert_poly_almost_equal(p4 % p2, p3)
  278. assert_poly_almost_equal(p4 % c2, p3)
  279. assert_poly_almost_equal(c4 % p2, p3)
  280. assert_poly_almost_equal(p4 % tuple(c2), p3)
  281. assert_poly_almost_equal(tuple(c4) % p2, p3)
  282. assert_poly_almost_equal(p4 % np.array(c2), p3)
  283. assert_poly_almost_equal(np.array(c4) % p2, p3)
  284. assert_poly_almost_equal(2 % p2, Poly([2]))
  285. assert_poly_almost_equal(p2 % 2, Poly([0]))
  286. assert_raises(TypeError, op.mod, p1, Poly([0], domain=Poly.domain + 1))
  287. assert_raises(TypeError, op.mod, p1, Poly([0], window=Poly.window + 1))
  288. if Poly is Polynomial:
  289. assert_raises(TypeError, op.mod, p1, Chebyshev([0]))
  290. else:
  291. assert_raises(TypeError, op.mod, p1, Polynomial([0]))
  292. def test_divmod(Poly):
  293. # This checks commutation, not numerical correctness
  294. c1 = list(random((4,)) + .5)
  295. c2 = list(random((3,)) + .5)
  296. c3 = list(random((2,)) + .5)
  297. p1 = Poly(c1)
  298. p2 = Poly(c2)
  299. p3 = Poly(c3)
  300. p4 = p1 * p2 + p3
  301. c4 = list(p4.coef)
  302. quo, rem = divmod(p4, p2)
  303. assert_poly_almost_equal(quo, p1)
  304. assert_poly_almost_equal(rem, p3)
  305. quo, rem = divmod(p4, c2)
  306. assert_poly_almost_equal(quo, p1)
  307. assert_poly_almost_equal(rem, p3)
  308. quo, rem = divmod(c4, p2)
  309. assert_poly_almost_equal(quo, p1)
  310. assert_poly_almost_equal(rem, p3)
  311. quo, rem = divmod(p4, tuple(c2))
  312. assert_poly_almost_equal(quo, p1)
  313. assert_poly_almost_equal(rem, p3)
  314. quo, rem = divmod(tuple(c4), p2)
  315. assert_poly_almost_equal(quo, p1)
  316. assert_poly_almost_equal(rem, p3)
  317. quo, rem = divmod(p4, np.array(c2))
  318. assert_poly_almost_equal(quo, p1)
  319. assert_poly_almost_equal(rem, p3)
  320. quo, rem = divmod(np.array(c4), p2)
  321. assert_poly_almost_equal(quo, p1)
  322. assert_poly_almost_equal(rem, p3)
  323. quo, rem = divmod(p2, 2)
  324. assert_poly_almost_equal(quo, 0.5*p2)
  325. assert_poly_almost_equal(rem, Poly([0]))
  326. quo, rem = divmod(2, p2)
  327. assert_poly_almost_equal(quo, Poly([0]))
  328. assert_poly_almost_equal(rem, Poly([2]))
  329. assert_raises(TypeError, divmod, p1, Poly([0], domain=Poly.domain + 1))
  330. assert_raises(TypeError, divmod, p1, Poly([0], window=Poly.window + 1))
  331. if Poly is Polynomial:
  332. assert_raises(TypeError, divmod, p1, Chebyshev([0]))
  333. else:
  334. assert_raises(TypeError, divmod, p1, Polynomial([0]))
  335. def test_roots(Poly):
  336. d = Poly.domain * 1.25 + .25
  337. w = Poly.window
  338. tgt = np.linspace(d[0], d[1], 5)
  339. res = np.sort(Poly.fromroots(tgt, domain=d, window=w).roots())
  340. assert_almost_equal(res, tgt)
  341. # default domain and window
  342. res = np.sort(Poly.fromroots(tgt).roots())
  343. assert_almost_equal(res, tgt)
  344. def test_degree(Poly):
  345. p = Poly.basis(5)
  346. assert_equal(p.degree(), 5)
  347. def test_copy(Poly):
  348. p1 = Poly.basis(5)
  349. p2 = p1.copy()
  350. assert_(p1 == p2)
  351. assert_(p1 is not p2)
  352. assert_(p1.coef is not p2.coef)
  353. assert_(p1.domain is not p2.domain)
  354. assert_(p1.window is not p2.window)
  355. def test_integ(Poly):
  356. P = Polynomial
  357. # Check defaults
  358. p0 = Poly.cast(P([1*2, 2*3, 3*4]))
  359. p1 = P.cast(p0.integ())
  360. p2 = P.cast(p0.integ(2))
  361. assert_poly_almost_equal(p1, P([0, 2, 3, 4]))
  362. assert_poly_almost_equal(p2, P([0, 0, 1, 1, 1]))
  363. # Check with k
  364. p0 = Poly.cast(P([1*2, 2*3, 3*4]))
  365. p1 = P.cast(p0.integ(k=1))
  366. p2 = P.cast(p0.integ(2, k=[1, 1]))
  367. assert_poly_almost_equal(p1, P([1, 2, 3, 4]))
  368. assert_poly_almost_equal(p2, P([1, 1, 1, 1, 1]))
  369. # Check with lbnd
  370. p0 = Poly.cast(P([1*2, 2*3, 3*4]))
  371. p1 = P.cast(p0.integ(lbnd=1))
  372. p2 = P.cast(p0.integ(2, lbnd=1))
  373. assert_poly_almost_equal(p1, P([-9, 2, 3, 4]))
  374. assert_poly_almost_equal(p2, P([6, -9, 1, 1, 1]))
  375. # Check scaling
  376. d = 2*Poly.domain
  377. p0 = Poly.cast(P([1*2, 2*3, 3*4]), domain=d)
  378. p1 = P.cast(p0.integ())
  379. p2 = P.cast(p0.integ(2))
  380. assert_poly_almost_equal(p1, P([0, 2, 3, 4]))
  381. assert_poly_almost_equal(p2, P([0, 0, 1, 1, 1]))
  382. def test_deriv(Poly):
  383. # Check that the derivative is the inverse of integration. It is
  384. # assumes that the integration has been checked elsewhere.
  385. d = Poly.domain + random((2,))*.25
  386. w = Poly.window + random((2,))*.25
  387. p1 = Poly([1, 2, 3], domain=d, window=w)
  388. p2 = p1.integ(2, k=[1, 2])
  389. p3 = p1.integ(1, k=[1])
  390. assert_almost_equal(p2.deriv(1).coef, p3.coef)
  391. assert_almost_equal(p2.deriv(2).coef, p1.coef)
  392. # default domain and window
  393. p1 = Poly([1, 2, 3])
  394. p2 = p1.integ(2, k=[1, 2])
  395. p3 = p1.integ(1, k=[1])
  396. assert_almost_equal(p2.deriv(1).coef, p3.coef)
  397. assert_almost_equal(p2.deriv(2).coef, p1.coef)
  398. def test_linspace(Poly):
  399. d = Poly.domain + random((2,))*.25
  400. w = Poly.window + random((2,))*.25
  401. p = Poly([1, 2, 3], domain=d, window=w)
  402. # check default domain
  403. xtgt = np.linspace(d[0], d[1], 20)
  404. ytgt = p(xtgt)
  405. xres, yres = p.linspace(20)
  406. assert_almost_equal(xres, xtgt)
  407. assert_almost_equal(yres, ytgt)
  408. # check specified domain
  409. xtgt = np.linspace(0, 2, 20)
  410. ytgt = p(xtgt)
  411. xres, yres = p.linspace(20, domain=[0, 2])
  412. assert_almost_equal(xres, xtgt)
  413. assert_almost_equal(yres, ytgt)
  414. def test_pow(Poly):
  415. d = Poly.domain + random((2,))*.25
  416. w = Poly.window + random((2,))*.25
  417. tgt = Poly([1], domain=d, window=w)
  418. tst = Poly([1, 2, 3], domain=d, window=w)
  419. for i in range(5):
  420. assert_poly_almost_equal(tst**i, tgt)
  421. tgt = tgt * tst
  422. # default domain and window
  423. tgt = Poly([1])
  424. tst = Poly([1, 2, 3])
  425. for i in range(5):
  426. assert_poly_almost_equal(tst**i, tgt)
  427. tgt = tgt * tst
  428. # check error for invalid powers
  429. assert_raises(ValueError, op.pow, tgt, 1.5)
  430. assert_raises(ValueError, op.pow, tgt, -1)
  431. def test_call(Poly):
  432. P = Polynomial
  433. d = Poly.domain
  434. x = np.linspace(d[0], d[1], 11)
  435. # Check defaults
  436. p = Poly.cast(P([1, 2, 3]))
  437. tgt = 1 + x*(2 + 3*x)
  438. res = p(x)
  439. assert_almost_equal(res, tgt)
  440. def test_cutdeg(Poly):
  441. p = Poly([1, 2, 3])
  442. assert_raises(ValueError, p.cutdeg, .5)
  443. assert_raises(ValueError, p.cutdeg, -1)
  444. assert_equal(len(p.cutdeg(3)), 3)
  445. assert_equal(len(p.cutdeg(2)), 3)
  446. assert_equal(len(p.cutdeg(1)), 2)
  447. assert_equal(len(p.cutdeg(0)), 1)
  448. def test_truncate(Poly):
  449. p = Poly([1, 2, 3])
  450. assert_raises(ValueError, p.truncate, .5)
  451. assert_raises(ValueError, p.truncate, 0)
  452. assert_equal(len(p.truncate(4)), 3)
  453. assert_equal(len(p.truncate(3)), 3)
  454. assert_equal(len(p.truncate(2)), 2)
  455. assert_equal(len(p.truncate(1)), 1)
  456. def test_trim(Poly):
  457. c = [1, 1e-6, 1e-12, 0]
  458. p = Poly(c)
  459. assert_equal(p.trim().coef, c[:3])
  460. assert_equal(p.trim(1e-10).coef, c[:2])
  461. assert_equal(p.trim(1e-5).coef, c[:1])
  462. def test_mapparms(Poly):
  463. # check with defaults. Should be identity.
  464. d = Poly.domain
  465. w = Poly.window
  466. p = Poly([1], domain=d, window=w)
  467. assert_almost_equal([0, 1], p.mapparms())
  468. #
  469. w = 2*d + 1
  470. p = Poly([1], domain=d, window=w)
  471. assert_almost_equal([1, 2], p.mapparms())
  472. def test_ufunc_override(Poly):
  473. p = Poly([1, 2, 3])
  474. x = np.ones(3)
  475. assert_raises(TypeError, np.add, p, x)
  476. assert_raises(TypeError, np.add, x, p)
  477. class TestLatexRepr(object):
  478. """Test the latex repr used by ipython """
  479. def as_latex(self, obj):
  480. # right now we ignore the formatting of scalars in our tests, since
  481. # it makes them too verbose. Ideally, the formatting of scalars will
  482. # be fixed such that tests below continue to pass
  483. obj._repr_latex_scalar = lambda x: str(x)
  484. try:
  485. return obj._repr_latex_()
  486. finally:
  487. del obj._repr_latex_scalar
  488. def test_simple_polynomial(self):
  489. # default input
  490. p = Polynomial([1, 2, 3])
  491. assert_equal(self.as_latex(p),
  492. r'$x \mapsto 1.0 + 2.0\,x + 3.0\,x^{2}$')
  493. # translated input
  494. p = Polynomial([1, 2, 3], domain=[-2, 0])
  495. assert_equal(self.as_latex(p),
  496. r'$x \mapsto 1.0 + 2.0\,\left(1.0 + x\right) + 3.0\,\left(1.0 + x\right)^{2}$')
  497. # scaled input
  498. p = Polynomial([1, 2, 3], domain=[-0.5, 0.5])
  499. assert_equal(self.as_latex(p),
  500. r'$x \mapsto 1.0 + 2.0\,\left(2.0x\right) + 3.0\,\left(2.0x\right)^{2}$')
  501. # affine input
  502. p = Polynomial([1, 2, 3], domain=[-1, 0])
  503. assert_equal(self.as_latex(p),
  504. r'$x \mapsto 1.0 + 2.0\,\left(1.0 + 2.0x\right) + 3.0\,\left(1.0 + 2.0x\right)^{2}$')
  505. def test_basis_func(self):
  506. p = Chebyshev([1, 2, 3])
  507. assert_equal(self.as_latex(p),
  508. r'$x \mapsto 1.0\,{T}_{0}(x) + 2.0\,{T}_{1}(x) + 3.0\,{T}_{2}(x)$')
  509. # affine input - check no surplus parens are added
  510. p = Chebyshev([1, 2, 3], domain=[-1, 0])
  511. assert_equal(self.as_latex(p),
  512. r'$x \mapsto 1.0\,{T}_{0}(1.0 + 2.0x) + 2.0\,{T}_{1}(1.0 + 2.0x) + 3.0\,{T}_{2}(1.0 + 2.0x)$')
  513. def test_multichar_basis_func(self):
  514. p = HermiteE([1, 2, 3])
  515. assert_equal(self.as_latex(p),
  516. r'$x \mapsto 1.0\,{He}_{0}(x) + 2.0\,{He}_{1}(x) + 3.0\,{He}_{2}(x)$')
  517. #
  518. # Test class method that only exists for some classes
  519. #
  520. class TestInterpolate(object):
  521. def f(self, x):
  522. return x * (x - 1) * (x - 2)
  523. def test_raises(self):
  524. assert_raises(ValueError, Chebyshev.interpolate, self.f, -1)
  525. assert_raises(TypeError, Chebyshev.interpolate, self.f, 10.)
  526. def test_dimensions(self):
  527. for deg in range(1, 5):
  528. assert_(Chebyshev.interpolate(self.f, deg).degree() == deg)
  529. def test_approximation(self):
  530. def powx(x, p):
  531. return x**p
  532. x = np.linspace(0, 2, 10)
  533. for deg in range(0, 10):
  534. for t in range(0, deg + 1):
  535. p = Chebyshev.interpolate(powx, deg, domain=[0, 2], args=(t,))
  536. assert_almost_equal(p(x), powx(x, t), decimal=12)