test_orthogonal_eval.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import assert_, assert_allclose
  4. import scipy.special.orthogonal as orth
  5. from scipy.special._testutils import FuncData
  6. def test_eval_chebyt():
  7. n = np.arange(0, 10000, 7)
  8. x = 2*np.random.rand() - 1
  9. v1 = np.cos(n*np.arccos(x))
  10. v2 = orth.eval_chebyt(n, x)
  11. assert_(np.allclose(v1, v2, rtol=1e-15))
  12. def test_eval_genlaguerre_restriction():
  13. # check it returns nan for alpha <= -1
  14. assert_(np.isnan(orth.eval_genlaguerre(0, -1, 0)))
  15. assert_(np.isnan(orth.eval_genlaguerre(0.1, -1, 0)))
  16. def test_warnings():
  17. # ticket 1334
  18. olderr = np.seterr(all='raise')
  19. try:
  20. # these should raise no fp warnings
  21. orth.eval_legendre(1, 0)
  22. orth.eval_laguerre(1, 1)
  23. orth.eval_gegenbauer(1, 1, 0)
  24. finally:
  25. np.seterr(**olderr)
  26. class TestPolys(object):
  27. """
  28. Check that the eval_* functions agree with the constructed polynomials
  29. """
  30. def check_poly(self, func, cls, param_ranges=[], x_range=[], nn=10,
  31. nparam=10, nx=10, rtol=1e-8):
  32. np.random.seed(1234)
  33. dataset = []
  34. for n in np.arange(nn):
  35. params = [a + (b-a)*np.random.rand(nparam) for a,b in param_ranges]
  36. params = np.asarray(params).T
  37. if not param_ranges:
  38. params = [0]
  39. for p in params:
  40. if param_ranges:
  41. p = (n,) + tuple(p)
  42. else:
  43. p = (n,)
  44. x = x_range[0] + (x_range[1] - x_range[0])*np.random.rand(nx)
  45. x[0] = x_range[0] # always include domain start point
  46. x[1] = x_range[1] # always include domain end point
  47. poly = np.poly1d(cls(*p).coef)
  48. z = np.c_[np.tile(p, (nx,1)), x, poly(x)]
  49. dataset.append(z)
  50. dataset = np.concatenate(dataset, axis=0)
  51. def polyfunc(*p):
  52. p = (p[0].astype(int),) + p[1:]
  53. return func(*p)
  54. olderr = np.seterr(all='raise')
  55. try:
  56. ds = FuncData(polyfunc, dataset, list(range(len(param_ranges)+2)), -1,
  57. rtol=rtol)
  58. ds.check()
  59. finally:
  60. np.seterr(**olderr)
  61. def test_jacobi(self):
  62. self.check_poly(orth.eval_jacobi, orth.jacobi,
  63. param_ranges=[(-0.99, 10), (-0.99, 10)], x_range=[-1, 1],
  64. rtol=1e-5)
  65. def test_sh_jacobi(self):
  66. self.check_poly(orth.eval_sh_jacobi, orth.sh_jacobi,
  67. param_ranges=[(1, 10), (0, 1)], x_range=[0, 1],
  68. rtol=1e-5)
  69. def test_gegenbauer(self):
  70. self.check_poly(orth.eval_gegenbauer, orth.gegenbauer,
  71. param_ranges=[(-0.499, 10)], x_range=[-1, 1],
  72. rtol=1e-7)
  73. def test_chebyt(self):
  74. self.check_poly(orth.eval_chebyt, orth.chebyt,
  75. param_ranges=[], x_range=[-1, 1])
  76. def test_chebyu(self):
  77. self.check_poly(orth.eval_chebyu, orth.chebyu,
  78. param_ranges=[], x_range=[-1, 1])
  79. def test_chebys(self):
  80. self.check_poly(orth.eval_chebys, orth.chebys,
  81. param_ranges=[], x_range=[-2, 2])
  82. def test_chebyc(self):
  83. self.check_poly(orth.eval_chebyc, orth.chebyc,
  84. param_ranges=[], x_range=[-2, 2])
  85. def test_sh_chebyt(self):
  86. olderr = np.seterr(all='ignore')
  87. try:
  88. self.check_poly(orth.eval_sh_chebyt, orth.sh_chebyt,
  89. param_ranges=[], x_range=[0, 1])
  90. finally:
  91. np.seterr(**olderr)
  92. def test_sh_chebyu(self):
  93. self.check_poly(orth.eval_sh_chebyu, orth.sh_chebyu,
  94. param_ranges=[], x_range=[0, 1])
  95. def test_legendre(self):
  96. self.check_poly(orth.eval_legendre, orth.legendre,
  97. param_ranges=[], x_range=[-1, 1])
  98. def test_sh_legendre(self):
  99. olderr = np.seterr(all='ignore')
  100. try:
  101. self.check_poly(orth.eval_sh_legendre, orth.sh_legendre,
  102. param_ranges=[], x_range=[0, 1])
  103. finally:
  104. np.seterr(**olderr)
  105. def test_genlaguerre(self):
  106. self.check_poly(orth.eval_genlaguerre, orth.genlaguerre,
  107. param_ranges=[(-0.99, 10)], x_range=[0, 100])
  108. def test_laguerre(self):
  109. self.check_poly(orth.eval_laguerre, orth.laguerre,
  110. param_ranges=[], x_range=[0, 100])
  111. def test_hermite(self):
  112. self.check_poly(orth.eval_hermite, orth.hermite,
  113. param_ranges=[], x_range=[-100, 100])
  114. def test_hermitenorm(self):
  115. self.check_poly(orth.eval_hermitenorm, orth.hermitenorm,
  116. param_ranges=[], x_range=[-100, 100])
  117. class TestRecurrence(object):
  118. """
  119. Check that the eval_* functions sig='ld->d' and 'dd->d' agree.
  120. """
  121. def check_poly(self, func, param_ranges=[], x_range=[], nn=10,
  122. nparam=10, nx=10, rtol=1e-8):
  123. np.random.seed(1234)
  124. dataset = []
  125. for n in np.arange(nn):
  126. params = [a + (b-a)*np.random.rand(nparam) for a,b in param_ranges]
  127. params = np.asarray(params).T
  128. if not param_ranges:
  129. params = [0]
  130. for p in params:
  131. if param_ranges:
  132. p = (n,) + tuple(p)
  133. else:
  134. p = (n,)
  135. x = x_range[0] + (x_range[1] - x_range[0])*np.random.rand(nx)
  136. x[0] = x_range[0] # always include domain start point
  137. x[1] = x_range[1] # always include domain end point
  138. kw = dict(sig=(len(p)+1)*'d'+'->d')
  139. z = np.c_[np.tile(p, (nx,1)), x, func(*(p + (x,)), **kw)]
  140. dataset.append(z)
  141. dataset = np.concatenate(dataset, axis=0)
  142. def polyfunc(*p):
  143. p = (p[0].astype(int),) + p[1:]
  144. kw = dict(sig='l'+(len(p)-1)*'d'+'->d')
  145. return func(*p, **kw)
  146. olderr = np.seterr(all='raise')
  147. try:
  148. ds = FuncData(polyfunc, dataset, list(range(len(param_ranges)+2)), -1,
  149. rtol=rtol)
  150. ds.check()
  151. finally:
  152. np.seterr(**olderr)
  153. def test_jacobi(self):
  154. self.check_poly(orth.eval_jacobi,
  155. param_ranges=[(-0.99, 10), (-0.99, 10)], x_range=[-1, 1])
  156. def test_sh_jacobi(self):
  157. self.check_poly(orth.eval_sh_jacobi,
  158. param_ranges=[(1, 10), (0, 1)], x_range=[0, 1])
  159. def test_gegenbauer(self):
  160. self.check_poly(orth.eval_gegenbauer,
  161. param_ranges=[(-0.499, 10)], x_range=[-1, 1])
  162. def test_chebyt(self):
  163. self.check_poly(orth.eval_chebyt,
  164. param_ranges=[], x_range=[-1, 1])
  165. def test_chebyu(self):
  166. self.check_poly(orth.eval_chebyu,
  167. param_ranges=[], x_range=[-1, 1])
  168. def test_chebys(self):
  169. self.check_poly(orth.eval_chebys,
  170. param_ranges=[], x_range=[-2, 2])
  171. def test_chebyc(self):
  172. self.check_poly(orth.eval_chebyc,
  173. param_ranges=[], x_range=[-2, 2])
  174. def test_sh_chebyt(self):
  175. self.check_poly(orth.eval_sh_chebyt,
  176. param_ranges=[], x_range=[0, 1])
  177. def test_sh_chebyu(self):
  178. self.check_poly(orth.eval_sh_chebyu,
  179. param_ranges=[], x_range=[0, 1])
  180. def test_legendre(self):
  181. self.check_poly(orth.eval_legendre,
  182. param_ranges=[], x_range=[-1, 1])
  183. def test_sh_legendre(self):
  184. self.check_poly(orth.eval_sh_legendre,
  185. param_ranges=[], x_range=[0, 1])
  186. def test_genlaguerre(self):
  187. self.check_poly(orth.eval_genlaguerre,
  188. param_ranges=[(-0.99, 10)], x_range=[0, 100])
  189. def test_laguerre(self):
  190. self.check_poly(orth.eval_laguerre,
  191. param_ranges=[], x_range=[0, 100])
  192. def test_hermite(self):
  193. v = orth.eval_hermite(70, 1.0)
  194. a = -1.457076485701412e60
  195. assert_allclose(v,a)