# test_pseudo_diffs.py (13 KB)
  1. # Created by Pearu Peterson, September 2002
  2. from __future__ import division, print_function, absolute_import
  3. __usage__ = """
  4. Build fftpack:
  5. python setup_fftpack.py build
  6. Run tests if scipy is installed:
  7. python -c 'import scipy;scipy.fftpack.test(<level>)'
  8. Run tests if fftpack is not installed:
  9. python tests/test_pseudo_diffs.py [<level>]
  10. """
  11. from numpy.testing import (assert_equal, assert_almost_equal,
  12. assert_array_almost_equal)
  13. from scipy.fftpack import (diff, fft, ifft, tilbert, itilbert, hilbert,
  14. ihilbert, shift, fftfreq, cs_diff, sc_diff,
  15. ss_diff, cc_diff)
  16. import numpy as np
  17. from numpy import arange, sin, cos, pi, exp, tanh, sum, sign
  18. from numpy.random import random
  19. def direct_diff(x,k=1,period=None):
  20. fx = fft(x)
  21. n = len(fx)
  22. if period is None:
  23. period = 2*pi
  24. w = fftfreq(n)*2j*pi/period*n
  25. if k < 0:
  26. w = 1 / w**k
  27. w[0] = 0.0
  28. else:
  29. w = w**k
  30. if n > 2000:
  31. w[250:n-250] = 0.0
  32. return ifft(w*fx).real
  33. def direct_tilbert(x,h=1,period=None):
  34. fx = fft(x)
  35. n = len(fx)
  36. if period is None:
  37. period = 2*pi
  38. w = fftfreq(n)*h*2*pi/period*n
  39. w[0] = 1
  40. w = 1j/tanh(w)
  41. w[0] = 0j
  42. return ifft(w*fx)
  43. def direct_itilbert(x,h=1,period=None):
  44. fx = fft(x)
  45. n = len(fx)
  46. if period is None:
  47. period = 2*pi
  48. w = fftfreq(n)*h*2*pi/period*n
  49. w = -1j*tanh(w)
  50. return ifft(w*fx)
  51. def direct_hilbert(x):
  52. fx = fft(x)
  53. n = len(fx)
  54. w = fftfreq(n)*n
  55. w = 1j*sign(w)
  56. return ifft(w*fx)
  57. def direct_ihilbert(x):
  58. return -direct_hilbert(x)
  59. def direct_shift(x,a,period=None):
  60. n = len(x)
  61. if period is None:
  62. k = fftfreq(n)*1j*n
  63. else:
  64. k = fftfreq(n)*2j*pi/period*n
  65. return ifft(fft(x)*exp(k*a)).real
  66. class TestDiff(object):
  67. def test_definition(self):
  68. for n in [16,17,64,127,32]:
  69. x = arange(n)*2*pi/n
  70. assert_array_almost_equal(diff(sin(x)),direct_diff(sin(x)))
  71. assert_array_almost_equal(diff(sin(x),2),direct_diff(sin(x),2))
  72. assert_array_almost_equal(diff(sin(x),3),direct_diff(sin(x),3))
  73. assert_array_almost_equal(diff(sin(x),4),direct_diff(sin(x),4))
  74. assert_array_almost_equal(diff(sin(x),5),direct_diff(sin(x),5))
  75. assert_array_almost_equal(diff(sin(2*x),3),direct_diff(sin(2*x),3))
  76. assert_array_almost_equal(diff(sin(2*x),4),direct_diff(sin(2*x),4))
  77. assert_array_almost_equal(diff(cos(x)),direct_diff(cos(x)))
  78. assert_array_almost_equal(diff(cos(x),2),direct_diff(cos(x),2))
  79. assert_array_almost_equal(diff(cos(x),3),direct_diff(cos(x),3))
  80. assert_array_almost_equal(diff(cos(x),4),direct_diff(cos(x),4))
  81. assert_array_almost_equal(diff(cos(2*x)),direct_diff(cos(2*x)))
  82. assert_array_almost_equal(diff(sin(x*n/8)),direct_diff(sin(x*n/8)))
  83. assert_array_almost_equal(diff(cos(x*n/8)),direct_diff(cos(x*n/8)))
  84. for k in range(5):
  85. assert_array_almost_equal(diff(sin(4*x),k),direct_diff(sin(4*x),k))
  86. assert_array_almost_equal(diff(cos(4*x),k),direct_diff(cos(4*x),k))
  87. def test_period(self):
  88. for n in [17,64]:
  89. x = arange(n)/float(n)
  90. assert_array_almost_equal(diff(sin(2*pi*x),period=1),
  91. 2*pi*cos(2*pi*x))
  92. assert_array_almost_equal(diff(sin(2*pi*x),3,period=1),
  93. -(2*pi)**3*cos(2*pi*x))
  94. def test_sin(self):
  95. for n in [32,64,77]:
  96. x = arange(n)*2*pi/n
  97. assert_array_almost_equal(diff(sin(x)),cos(x))
  98. assert_array_almost_equal(diff(cos(x)),-sin(x))
  99. assert_array_almost_equal(diff(sin(x),2),-sin(x))
  100. assert_array_almost_equal(diff(sin(x),4),sin(x))
  101. assert_array_almost_equal(diff(sin(4*x)),4*cos(4*x))
  102. assert_array_almost_equal(diff(sin(sin(x))),cos(x)*cos(sin(x)))
  103. def test_expr(self):
  104. for n in [64,77,100,128,256,512,1024,2048,4096,8192][:5]:
  105. x = arange(n)*2*pi/n
  106. f = sin(x)*cos(4*x)+exp(sin(3*x))
  107. df = cos(x)*cos(4*x)-4*sin(x)*sin(4*x)+3*cos(3*x)*exp(sin(3*x))
  108. ddf = -17*sin(x)*cos(4*x)-8*cos(x)*sin(4*x)\
  109. - 9*sin(3*x)*exp(sin(3*x))+9*cos(3*x)**2*exp(sin(3*x))
  110. d1 = diff(f)
  111. assert_array_almost_equal(d1,df)
  112. assert_array_almost_equal(diff(df),ddf)
  113. assert_array_almost_equal(diff(f,2),ddf)
  114. assert_array_almost_equal(diff(ddf,-1),df)
  115. def test_expr_large(self):
  116. for n in [2048,4096]:
  117. x = arange(n)*2*pi/n
  118. f = sin(x)*cos(4*x)+exp(sin(3*x))
  119. df = cos(x)*cos(4*x)-4*sin(x)*sin(4*x)+3*cos(3*x)*exp(sin(3*x))
  120. ddf = -17*sin(x)*cos(4*x)-8*cos(x)*sin(4*x)\
  121. - 9*sin(3*x)*exp(sin(3*x))+9*cos(3*x)**2*exp(sin(3*x))
  122. assert_array_almost_equal(diff(f),df)
  123. assert_array_almost_equal(diff(df),ddf)
  124. assert_array_almost_equal(diff(ddf,-1),df)
  125. assert_array_almost_equal(diff(f,2),ddf)
  126. def test_int(self):
  127. n = 64
  128. x = arange(n)*2*pi/n
  129. assert_array_almost_equal(diff(sin(x),-1),-cos(x))
  130. assert_array_almost_equal(diff(sin(x),-2),-sin(x))
  131. assert_array_almost_equal(diff(sin(x),-4),sin(x))
  132. assert_array_almost_equal(diff(2*cos(2*x),-1),sin(2*x))
  133. def test_random_even(self):
  134. for k in [0,2,4,6]:
  135. for n in [60,32,64,56,55]:
  136. f = random((n,))
  137. af = sum(f,axis=0)/n
  138. f = f-af
  139. # zeroing Nyquist mode:
  140. f = diff(diff(f,1),-1)
  141. assert_almost_equal(sum(f,axis=0),0.0)
  142. assert_array_almost_equal(diff(diff(f,k),-k),f)
  143. assert_array_almost_equal(diff(diff(f,-k),k),f)
  144. def test_random_odd(self):
  145. for k in [0,1,2,3,4,5,6]:
  146. for n in [33,65,55]:
  147. f = random((n,))
  148. af = sum(f,axis=0)/n
  149. f = f-af
  150. assert_almost_equal(sum(f,axis=0),0.0)
  151. assert_array_almost_equal(diff(diff(f,k),-k),f)
  152. assert_array_almost_equal(diff(diff(f,-k),k),f)
  153. def test_zero_nyquist(self):
  154. for k in [0,1,2,3,4,5,6]:
  155. for n in [32,33,64,56,55]:
  156. f = random((n,))
  157. af = sum(f,axis=0)/n
  158. f = f-af
  159. # zeroing Nyquist mode:
  160. f = diff(diff(f,1),-1)
  161. assert_almost_equal(sum(f,axis=0),0.0)
  162. assert_array_almost_equal(diff(diff(f,k),-k),f)
  163. assert_array_almost_equal(diff(diff(f,-k),k),f)
  164. class TestTilbert(object):
  165. def test_definition(self):
  166. for h in [0.1,0.5,1,5.5,10]:
  167. for n in [16,17,64,127]:
  168. x = arange(n)*2*pi/n
  169. y = tilbert(sin(x),h)
  170. y1 = direct_tilbert(sin(x),h)
  171. assert_array_almost_equal(y,y1)
  172. assert_array_almost_equal(tilbert(sin(x),h),
  173. direct_tilbert(sin(x),h))
  174. assert_array_almost_equal(tilbert(sin(2*x),h),
  175. direct_tilbert(sin(2*x),h))
  176. def test_random_even(self):
  177. for h in [0.1,0.5,1,5.5,10]:
  178. for n in [32,64,56]:
  179. f = random((n,))
  180. af = sum(f,axis=0)/n
  181. f = f-af
  182. assert_almost_equal(sum(f,axis=0),0.0)
  183. assert_array_almost_equal(direct_tilbert(direct_itilbert(f,h),h),f)
  184. def test_random_odd(self):
  185. for h in [0.1,0.5,1,5.5,10]:
  186. for n in [33,65,55]:
  187. f = random((n,))
  188. af = sum(f,axis=0)/n
  189. f = f-af
  190. assert_almost_equal(sum(f,axis=0),0.0)
  191. assert_array_almost_equal(itilbert(tilbert(f,h),h),f)
  192. assert_array_almost_equal(tilbert(itilbert(f,h),h),f)
  193. class TestITilbert(object):
  194. def test_definition(self):
  195. for h in [0.1,0.5,1,5.5,10]:
  196. for n in [16,17,64,127]:
  197. x = arange(n)*2*pi/n
  198. y = itilbert(sin(x),h)
  199. y1 = direct_itilbert(sin(x),h)
  200. assert_array_almost_equal(y,y1)
  201. assert_array_almost_equal(itilbert(sin(x),h),
  202. direct_itilbert(sin(x),h))
  203. assert_array_almost_equal(itilbert(sin(2*x),h),
  204. direct_itilbert(sin(2*x),h))
  205. class TestHilbert(object):
  206. def test_definition(self):
  207. for n in [16,17,64,127]:
  208. x = arange(n)*2*pi/n
  209. y = hilbert(sin(x))
  210. y1 = direct_hilbert(sin(x))
  211. assert_array_almost_equal(y,y1)
  212. assert_array_almost_equal(hilbert(sin(2*x)),
  213. direct_hilbert(sin(2*x)))
  214. def test_tilbert_relation(self):
  215. for n in [16,17,64,127]:
  216. x = arange(n)*2*pi/n
  217. f = sin(x)+cos(2*x)*sin(x)
  218. y = hilbert(f)
  219. y1 = direct_hilbert(f)
  220. assert_array_almost_equal(y,y1)
  221. y2 = tilbert(f,h=10)
  222. assert_array_almost_equal(y,y2)
  223. def test_random_odd(self):
  224. for n in [33,65,55]:
  225. f = random((n,))
  226. af = sum(f,axis=0)/n
  227. f = f-af
  228. assert_almost_equal(sum(f,axis=0),0.0)
  229. assert_array_almost_equal(ihilbert(hilbert(f)),f)
  230. assert_array_almost_equal(hilbert(ihilbert(f)),f)
  231. def test_random_even(self):
  232. for n in [32,64,56]:
  233. f = random((n,))
  234. af = sum(f,axis=0)/n
  235. f = f-af
  236. # zeroing Nyquist mode:
  237. f = diff(diff(f,1),-1)
  238. assert_almost_equal(sum(f,axis=0),0.0)
  239. assert_array_almost_equal(direct_hilbert(direct_ihilbert(f)),f)
  240. assert_array_almost_equal(hilbert(ihilbert(f)),f)
  241. class TestIHilbert(object):
  242. def test_definition(self):
  243. for n in [16,17,64,127]:
  244. x = arange(n)*2*pi/n
  245. y = ihilbert(sin(x))
  246. y1 = direct_ihilbert(sin(x))
  247. assert_array_almost_equal(y,y1)
  248. assert_array_almost_equal(ihilbert(sin(2*x)),
  249. direct_ihilbert(sin(2*x)))
  250. def test_itilbert_relation(self):
  251. for n in [16,17,64,127]:
  252. x = arange(n)*2*pi/n
  253. f = sin(x)+cos(2*x)*sin(x)
  254. y = ihilbert(f)
  255. y1 = direct_ihilbert(f)
  256. assert_array_almost_equal(y,y1)
  257. y2 = itilbert(f,h=10)
  258. assert_array_almost_equal(y,y2)
  259. class TestShift(object):
  260. def test_definition(self):
  261. for n in [18,17,64,127,32,2048,256]:
  262. x = arange(n)*2*pi/n
  263. for a in [0.1,3]:
  264. assert_array_almost_equal(shift(sin(x),a),direct_shift(sin(x),a))
  265. assert_array_almost_equal(shift(sin(x),a),sin(x+a))
  266. assert_array_almost_equal(shift(cos(x),a),cos(x+a))
  267. assert_array_almost_equal(shift(cos(2*x)+sin(x),a),
  268. cos(2*(x+a))+sin(x+a))
  269. assert_array_almost_equal(shift(exp(sin(x)),a),exp(sin(x+a)))
  270. assert_array_almost_equal(shift(sin(x),2*pi),sin(x))
  271. assert_array_almost_equal(shift(sin(x),pi),-sin(x))
  272. assert_array_almost_equal(shift(sin(x),pi/2),cos(x))
  273. class TestOverwrite(object):
  274. """Check input overwrite behavior """
  275. real_dtypes = [np.float32, np.float64]
  276. dtypes = real_dtypes + [np.complex64, np.complex128]
  277. def _check(self, x, routine, *args, **kwargs):
  278. x2 = x.copy()
  279. routine(x2, *args, **kwargs)
  280. sig = routine.__name__
  281. if args:
  282. sig += repr(args)
  283. if kwargs:
  284. sig += repr(kwargs)
  285. assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
  286. def _check_1d(self, routine, dtype, shape, *args, **kwargs):
  287. np.random.seed(1234)
  288. if np.issubdtype(dtype, np.complexfloating):
  289. data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
  290. else:
  291. data = np.random.randn(*shape)
  292. data = data.astype(dtype)
  293. self._check(data, routine, *args, **kwargs)
  294. def test_diff(self):
  295. for dtype in self.dtypes:
  296. self._check_1d(diff, dtype, (16,))
  297. def test_tilbert(self):
  298. for dtype in self.dtypes:
  299. self._check_1d(tilbert, dtype, (16,), 1.6)
  300. def test_itilbert(self):
  301. for dtype in self.dtypes:
  302. self._check_1d(itilbert, dtype, (16,), 1.6)
  303. def test_hilbert(self):
  304. for dtype in self.dtypes:
  305. self._check_1d(hilbert, dtype, (16,))
  306. def test_cs_diff(self):
  307. for dtype in self.dtypes:
  308. self._check_1d(cs_diff, dtype, (16,), 1.0, 4.0)
  309. def test_sc_diff(self):
  310. for dtype in self.dtypes:
  311. self._check_1d(sc_diff, dtype, (16,), 1.0, 4.0)
  312. def test_ss_diff(self):
  313. for dtype in self.dtypes:
  314. self._check_1d(ss_diff, dtype, (16,), 1.0, 4.0)
  315. def test_cc_diff(self):
  316. for dtype in self.dtypes:
  317. self._check_1d(cc_diff, dtype, (16,), 1.0, 4.0)
  318. def test_shift(self):
  319. for dtype in self.dtypes:
  320. self._check_1d(shift, dtype, (16,), 1.0)