test_real_transforms.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829
  1. from __future__ import division, print_function, absolute_import
  2. from os.path import join, dirname
  3. import numpy as np
  4. from numpy.testing import assert_array_almost_equal, assert_equal
  5. import pytest
  6. from pytest import raises as assert_raises
  7. from scipy.fftpack.realtransforms import (
  8. dct, idct, dst, idst, dctn, idctn, dstn, idstn)
  9. # Matlab reference data
  10. MDATA = np.load(join(dirname(__file__), 'test.npz'))
  11. X = [MDATA['x%d' % i] for i in range(8)]
  12. Y = [MDATA['y%d' % i] for i in range(8)]
  13. # FFTW reference data: the data are organized as follows:
  14. # * SIZES is an array containing all available sizes
  15. # * for every type (1, 2, 3, 4) and every size, the array dct_type_size
  16. # contains the output of the DCT applied to the input np.linspace(0, size-1,
  17. # size)
  18. FFTWDATA_DOUBLE = np.load(join(dirname(__file__), 'fftw_double_ref.npz'))
  19. FFTWDATA_SINGLE = np.load(join(dirname(__file__), 'fftw_single_ref.npz'))
  20. FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
  21. def fftw_dct_ref(type, size, dt):
  22. x = np.linspace(0, size-1, size).astype(dt)
  23. dt = np.result_type(np.float32, dt)
  24. if dt == np.double:
  25. data = FFTWDATA_DOUBLE
  26. elif dt == np.float32:
  27. data = FFTWDATA_SINGLE
  28. else:
  29. raise ValueError()
  30. y = (data['dct_%d_%d' % (type, size)]).astype(dt)
  31. return x, y, dt
  32. def fftw_dst_ref(type, size, dt):
  33. x = np.linspace(0, size-1, size).astype(dt)
  34. dt = np.result_type(np.float32, dt)
  35. if dt == np.double:
  36. data = FFTWDATA_DOUBLE
  37. elif dt == np.float32:
  38. data = FFTWDATA_SINGLE
  39. else:
  40. raise ValueError()
  41. y = (data['dst_%d_%d' % (type, size)]).astype(dt)
  42. return x, y, dt
  43. def dct_2d_ref(x, **kwargs):
  44. """Calculate reference values for testing dct2."""
  45. x = np.array(x, copy=True)
  46. for row in range(x.shape[0]):
  47. x[row, :] = dct(x[row, :], **kwargs)
  48. for col in range(x.shape[1]):
  49. x[:, col] = dct(x[:, col], **kwargs)
  50. return x
  51. def idct_2d_ref(x, **kwargs):
  52. """Calculate reference values for testing idct2."""
  53. x = np.array(x, copy=True)
  54. for row in range(x.shape[0]):
  55. x[row, :] = idct(x[row, :], **kwargs)
  56. for col in range(x.shape[1]):
  57. x[:, col] = idct(x[:, col], **kwargs)
  58. return x
  59. def dst_2d_ref(x, **kwargs):
  60. """Calculate reference values for testing dst2."""
  61. x = np.array(x, copy=True)
  62. for row in range(x.shape[0]):
  63. x[row, :] = dst(x[row, :], **kwargs)
  64. for col in range(x.shape[1]):
  65. x[:, col] = dst(x[:, col], **kwargs)
  66. return x
  67. def idst_2d_ref(x, **kwargs):
  68. """Calculate reference values for testing idst2."""
  69. x = np.array(x, copy=True)
  70. for row in range(x.shape[0]):
  71. x[row, :] = idst(x[row, :], **kwargs)
  72. for col in range(x.shape[1]):
  73. x[:, col] = idst(x[:, col], **kwargs)
  74. return x
  75. def naive_dct1(x, norm=None):
  76. """Calculate textbook definition version of DCT-I."""
  77. x = np.array(x, copy=True)
  78. N = len(x)
  79. M = N-1
  80. y = np.zeros(N)
  81. m0, m = 1, 2
  82. if norm == 'ortho':
  83. m0 = np.sqrt(1.0/M)
  84. m = np.sqrt(2.0/M)
  85. for k in range(N):
  86. for n in range(1, N-1):
  87. y[k] += m*x[n]*np.cos(np.pi*n*k/M)
  88. y[k] += m0 * x[0]
  89. y[k] += m0 * x[N-1] * (1 if k % 2 == 0 else -1)
  90. if norm == 'ortho':
  91. y[0] *= 1/np.sqrt(2)
  92. y[N-1] *= 1/np.sqrt(2)
  93. return y
  94. def naive_dst1(x, norm=None):
  95. """Calculate textbook definition version of DST-I."""
  96. x = np.array(x, copy=True)
  97. N = len(x)
  98. M = N+1
  99. y = np.zeros(N)
  100. for k in range(N):
  101. for n in range(N):
  102. y[k] += 2*x[n]*np.sin(np.pi*(n+1.0)*(k+1.0)/M)
  103. if norm == 'ortho':
  104. y *= np.sqrt(0.5/M)
  105. return y
  106. def naive_dct4(x, norm=None):
  107. """Calculate textbook definition version of DCT-IV."""
  108. x = np.array(x, copy=True)
  109. N = len(x)
  110. y = np.zeros(N)
  111. for k in range(N):
  112. for n in range(N):
  113. y[k] += x[n]*np.cos(np.pi*(n+0.5)*(k+0.5)/(N))
  114. if norm == 'ortho':
  115. y *= np.sqrt(2.0/N)
  116. else:
  117. y *= 2
  118. return y
  119. def naive_dst4(x, norm=None):
  120. """Calculate textbook definition version of DST-IV."""
  121. x = np.array(x, copy=True)
  122. N = len(x)
  123. y = np.zeros(N)
  124. for k in range(N):
  125. for n in range(N):
  126. y[k] += x[n]*np.sin(np.pi*(n+0.5)*(k+0.5)/(N))
  127. if norm == 'ortho':
  128. y *= np.sqrt(2.0/N)
  129. else:
  130. y *= 2
  131. return y
  132. class TestComplex(object):
  133. def test_dct_complex64(self):
  134. y = dct(1j*np.arange(5, dtype=np.complex64))
  135. x = 1j*dct(np.arange(5))
  136. assert_array_almost_equal(x, y)
  137. def test_dct_complex(self):
  138. y = dct(np.arange(5)*1j)
  139. x = 1j*dct(np.arange(5))
  140. assert_array_almost_equal(x, y)
  141. def test_idct_complex(self):
  142. y = idct(np.arange(5)*1j)
  143. x = 1j*idct(np.arange(5))
  144. assert_array_almost_equal(x, y)
  145. def test_dst_complex64(self):
  146. y = dst(np.arange(5, dtype=np.complex64)*1j)
  147. x = 1j*dst(np.arange(5))
  148. assert_array_almost_equal(x, y)
  149. def test_dst_complex(self):
  150. y = dst(np.arange(5)*1j)
  151. x = 1j*dst(np.arange(5))
  152. assert_array_almost_equal(x, y)
  153. def test_idst_complex(self):
  154. y = idst(np.arange(5)*1j)
  155. x = 1j*idst(np.arange(5))
  156. assert_array_almost_equal(x, y)
  157. class _TestDCTBase(object):
  158. def setup_method(self):
  159. self.rdt = None
  160. self.dec = 14
  161. self.type = None
  162. def test_definition(self):
  163. for i in FFTWDATA_SIZES:
  164. x, yr, dt = fftw_dct_ref(self.type, i, self.rdt)
  165. y = dct(x, type=self.type)
  166. assert_equal(y.dtype, dt)
  167. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  168. # should really use something like assert_array_approx_equal. The
  169. # difference is due to fftw using a better algorithm w.r.t error
  170. # propagation compared to the ones from fftpack.
  171. assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
  172. err_msg="Size %d failed" % i)
  173. def test_axis(self):
  174. nt = 2
  175. for i in [7, 8, 9, 16, 32, 64]:
  176. x = np.random.randn(nt, i)
  177. y = dct(x, type=self.type)
  178. for j in range(nt):
  179. assert_array_almost_equal(y[j], dct(x[j], type=self.type),
  180. decimal=self.dec)
  181. x = x.T
  182. y = dct(x, axis=0, type=self.type)
  183. for j in range(nt):
  184. assert_array_almost_equal(y[:,j], dct(x[:,j], type=self.type),
  185. decimal=self.dec)
  186. class _TestDCTIBase(_TestDCTBase):
  187. def test_definition_ortho(self):
  188. # Test orthornomal mode.
  189. for i in range(len(X)):
  190. x = np.array(X[i], dtype=self.rdt)
  191. dt = np.result_type(np.float32, self.rdt)
  192. y = dct(x, norm='ortho', type=1)
  193. y2 = naive_dct1(x, norm='ortho')
  194. assert_equal(y.dtype, dt)
  195. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  196. class _TestDCTIIBase(_TestDCTBase):
  197. def test_definition_matlab(self):
  198. # Test correspondence with matlab (orthornomal mode).
  199. for i in range(len(X)):
  200. dt = np.result_type(np.float32, self.rdt)
  201. x = np.array(X[i], dtype=dt)
  202. yr = Y[i]
  203. y = dct(x, norm="ortho", type=2)
  204. assert_equal(y.dtype, dt)
  205. assert_array_almost_equal(y, yr, decimal=self.dec)
  206. class _TestDCTIIIBase(_TestDCTBase):
  207. def test_definition_ortho(self):
  208. # Test orthornomal mode.
  209. for i in range(len(X)):
  210. x = np.array(X[i], dtype=self.rdt)
  211. dt = np.result_type(np.float32, self.rdt)
  212. y = dct(x, norm='ortho', type=2)
  213. xi = dct(y, norm="ortho", type=3)
  214. assert_equal(xi.dtype, dt)
  215. assert_array_almost_equal(xi, x, decimal=self.dec)
  216. class _TestDCTIVBase(_TestDCTBase):
  217. def test_definition_ortho(self):
  218. # Test orthornomal mode.
  219. for i in range(len(X)):
  220. x = np.array(X[i], dtype=self.rdt)
  221. dt = np.result_type(np.float32, self.rdt)
  222. y = dct(x, norm='ortho', type=4)
  223. y2 = naive_dct4(x, norm='ortho')
  224. assert_equal(y.dtype, dt)
  225. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  226. class TestDCTIDouble(_TestDCTIBase):
  227. def setup_method(self):
  228. self.rdt = np.double
  229. self.dec = 10
  230. self.type = 1
  231. class TestDCTIFloat(_TestDCTIBase):
  232. def setup_method(self):
  233. self.rdt = np.float32
  234. self.dec = 4
  235. self.type = 1
  236. class TestDCTIInt(_TestDCTIBase):
  237. def setup_method(self):
  238. self.rdt = int
  239. self.dec = 5
  240. self.type = 1
  241. class TestDCTIIDouble(_TestDCTIIBase):
  242. def setup_method(self):
  243. self.rdt = np.double
  244. self.dec = 10
  245. self.type = 2
  246. class TestDCTIIFloat(_TestDCTIIBase):
  247. def setup_method(self):
  248. self.rdt = np.float32
  249. self.dec = 5
  250. self.type = 2
  251. class TestDCTIIInt(_TestDCTIIBase):
  252. def setup_method(self):
  253. self.rdt = int
  254. self.dec = 5
  255. self.type = 2
  256. class TestDCTIIIDouble(_TestDCTIIIBase):
  257. def setup_method(self):
  258. self.rdt = np.double
  259. self.dec = 14
  260. self.type = 3
  261. class TestDCTIIIFloat(_TestDCTIIIBase):
  262. def setup_method(self):
  263. self.rdt = np.float32
  264. self.dec = 5
  265. self.type = 3
  266. class TestDCTIIIInt(_TestDCTIIIBase):
  267. def setup_method(self):
  268. self.rdt = int
  269. self.dec = 5
  270. self.type = 3
  271. class TestDCTIVDouble(_TestDCTIVBase):
  272. def setup_method(self):
  273. self.rdt = np.double
  274. self.dec = 12
  275. self.type = 3
  276. class TestDCTIVFloat(_TestDCTIVBase):
  277. def setup_method(self):
  278. self.rdt = np.float32
  279. self.dec = 5
  280. self.type = 3
  281. class TestDCTIVInt(_TestDCTIVBase):
  282. def setup_method(self):
  283. self.rdt = int
  284. self.dec = 5
  285. self.type = 3
  286. class _TestIDCTBase(object):
  287. def setup_method(self):
  288. self.rdt = None
  289. self.dec = 14
  290. self.type = None
  291. def test_definition(self):
  292. for i in FFTWDATA_SIZES:
  293. xr, yr, dt = fftw_dct_ref(self.type, i, self.rdt)
  294. x = idct(yr, type=self.type)
  295. if self.type == 1:
  296. x /= 2 * (i-1)
  297. else:
  298. x /= 2 * i
  299. assert_equal(x.dtype, dt)
  300. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  301. # should really use something like assert_array_approx_equal. The
  302. # difference is due to fftw using a better algorithm w.r.t error
  303. # propagation compared to the ones from fftpack.
  304. assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
  305. err_msg="Size %d failed" % i)
  306. class TestIDCTIDouble(_TestIDCTBase):
  307. def setup_method(self):
  308. self.rdt = np.double
  309. self.dec = 10
  310. self.type = 1
  311. class TestIDCTIFloat(_TestIDCTBase):
  312. def setup_method(self):
  313. self.rdt = np.float32
  314. self.dec = 4
  315. self.type = 1
  316. class TestIDCTIInt(_TestIDCTBase):
  317. def setup_method(self):
  318. self.rdt = int
  319. self.dec = 4
  320. self.type = 1
  321. class TestIDCTIIDouble(_TestIDCTBase):
  322. def setup_method(self):
  323. self.rdt = np.double
  324. self.dec = 10
  325. self.type = 2
  326. class TestIDCTIIFloat(_TestIDCTBase):
  327. def setup_method(self):
  328. self.rdt = np.float32
  329. self.dec = 5
  330. self.type = 2
  331. class TestIDCTIIInt(_TestIDCTBase):
  332. def setup_method(self):
  333. self.rdt = int
  334. self.dec = 5
  335. self.type = 2
  336. class TestIDCTIIIDouble(_TestIDCTBase):
  337. def setup_method(self):
  338. self.rdt = np.double
  339. self.dec = 14
  340. self.type = 3
  341. class TestIDCTIIIFloat(_TestIDCTBase):
  342. def setup_method(self):
  343. self.rdt = np.float32
  344. self.dec = 5
  345. self.type = 3
  346. class TestIDCTIIIInt(_TestIDCTBase):
  347. def setup_method(self):
  348. self.rdt = int
  349. self.dec = 5
  350. self.type = 3
  351. class TestIDCTIVDouble(_TestIDCTBase):
  352. def setup_method(self):
  353. self.rdt = np.double
  354. self.dec = 12
  355. self.type = 4
  356. class TestIDCTIVFloat(_TestIDCTBase):
  357. def setup_method(self):
  358. self.rdt = np.float32
  359. self.dec = 5
  360. self.type = 4
  361. class TestIDCTIVInt(_TestIDCTBase):
  362. def setup_method(self):
  363. self.rdt = int
  364. self.dec = 5
  365. self.type = 4
  366. class _TestDSTBase(object):
  367. def setup_method(self):
  368. self.rdt = None # dtype
  369. self.dec = None # number of decimals to match
  370. self.type = None # dst type
  371. def test_definition(self):
  372. for i in FFTWDATA_SIZES:
  373. xr, yr, dt = fftw_dst_ref(self.type, i, self.rdt)
  374. y = dst(xr, type=self.type)
  375. assert_equal(y.dtype, dt)
  376. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  377. # should really use something like assert_array_approx_equal. The
  378. # difference is due to fftw using a better algorithm w.r.t error
  379. # propagation compared to the ones from fftpack.
  380. assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
  381. err_msg="Size %d failed" % i)
  382. class _TestDSTIBase(_TestDSTBase):
  383. def test_definition_ortho(self):
  384. # Test orthornomal mode.
  385. for i in range(len(X)):
  386. x = np.array(X[i], dtype=self.rdt)
  387. dt = np.result_type(np.float32, self.rdt)
  388. y = dst(x, norm='ortho', type=1)
  389. y2 = naive_dst1(x, norm='ortho')
  390. assert_equal(y.dtype, dt)
  391. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  392. class _TestDSTIVBase(_TestDSTBase):
  393. def test_definition_ortho(self):
  394. # Test orthornomal mode.
  395. for i in range(len(X)):
  396. x = np.array(X[i], dtype=self.rdt)
  397. dt = np.result_type(np.float32, self.rdt)
  398. y = dst(x, norm='ortho', type=4)
  399. y2 = naive_dst4(x, norm='ortho')
  400. assert_equal(y.dtype, dt)
  401. assert_array_almost_equal(y, y2, decimal=self.dec)
  402. class TestDSTIDouble(_TestDSTIBase):
  403. def setup_method(self):
  404. self.rdt = np.double
  405. self.dec = 12
  406. self.type = 1
  407. class TestDSTIFloat(_TestDSTIBase):
  408. def setup_method(self):
  409. self.rdt = np.float32
  410. self.dec = 4
  411. self.type = 1
  412. class TestDSTIInt(_TestDSTIBase):
  413. def setup_method(self):
  414. self.rdt = int
  415. self.dec = 5
  416. self.type = 1
  417. class TestDSTIIDouble(_TestDSTBase):
  418. def setup_method(self):
  419. self.rdt = np.double
  420. self.dec = 14
  421. self.type = 2
  422. class TestDSTIIFloat(_TestDSTBase):
  423. def setup_method(self):
  424. self.rdt = np.float32
  425. self.dec = 6
  426. self.type = 2
  427. class TestDSTIIInt(_TestDSTBase):
  428. def setup_method(self):
  429. self.rdt = int
  430. self.dec = 6
  431. self.type = 2
  432. class TestDSTIIIDouble(_TestDSTBase):
  433. def setup_method(self):
  434. self.rdt = np.double
  435. self.dec = 14
  436. self.type = 3
  437. class TestDSTIIIFloat(_TestDSTBase):
  438. def setup_method(self):
  439. self.rdt = np.float32
  440. self.dec = 7
  441. self.type = 3
  442. class TestDSTIIIInt(_TestDSTBase):
  443. def setup_method(self):
  444. self.rdt = int
  445. self.dec = 7
  446. self.type = 3
  447. class TestDSTIVDouble(_TestDSTIVBase):
  448. def setup_method(self):
  449. self.rdt = np.double
  450. self.dec = 12
  451. self.type = 4
  452. class TestDSTIVFloat(_TestDSTIVBase):
  453. def setup_method(self):
  454. self.rdt = np.float32
  455. self.dec = 4
  456. self.type = 4
  457. class TestDSTIVInt(_TestDSTIVBase):
  458. def setup_method(self):
  459. self.rdt = int
  460. self.dec = 5
  461. self.type = 4
  462. class _TestIDSTBase(object):
  463. def setup_method(self):
  464. self.rdt = None
  465. self.dec = None
  466. self.type = None
  467. def test_definition(self):
  468. for i in FFTWDATA_SIZES:
  469. xr, yr, dt = fftw_dst_ref(self.type, i, self.rdt)
  470. x = idst(yr, type=self.type)
  471. if self.type == 1:
  472. x /= 2 * (i+1)
  473. else:
  474. x /= 2 * i
  475. assert_equal(x.dtype, dt)
  476. # XXX: we divide by np.max(x) because the tests fail otherwise. We
  477. # should really use something like assert_array_approx_equal. The
  478. # difference is due to fftw using a better algorithm w.r.t error
  479. # propagation compared to the ones from fftpack.
  480. assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
  481. err_msg="Size %d failed" % i)
  482. class TestIDSTIDouble(_TestIDSTBase):
  483. def setup_method(self):
  484. self.rdt = np.double
  485. self.dec = 12
  486. self.type = 1
  487. class TestIDSTIFloat(_TestIDSTBase):
  488. def setup_method(self):
  489. self.rdt = np.float32
  490. self.dec = 4
  491. self.type = 1
  492. class TestIDSTIInt(_TestIDSTBase):
  493. def setup_method(self):
  494. self.rdt = int
  495. self.dec = 4
  496. self.type = 1
  497. class TestIDSTIIDouble(_TestIDSTBase):
  498. def setup_method(self):
  499. self.rdt = np.double
  500. self.dec = 14
  501. self.type = 2
  502. class TestIDSTIIFloat(_TestIDSTBase):
  503. def setup_method(self):
  504. self.rdt = np.float32
  505. self.dec = 6
  506. self.type = 2
  507. class TestIDSTIIInt(_TestIDSTBase):
  508. def setup_method(self):
  509. self.rdt = int
  510. self.dec = 6
  511. self.type = 2
  512. class TestIDSTIIIDouble(_TestIDSTBase):
  513. def setup_method(self):
  514. self.rdt = np.double
  515. self.dec = 14
  516. self.type = 3
  517. class TestIDSTIIIFloat(_TestIDSTBase):
  518. def setup_method(self):
  519. self.rdt = np.float32
  520. self.dec = 6
  521. self.type = 3
  522. class TestIDSTIIIInt(_TestIDSTBase):
  523. def setup_method(self):
  524. self.rdt = int
  525. self.dec = 6
  526. self.type = 3
  527. class TestIDSTIVDouble(_TestIDSTBase):
  528. def setup_method(self):
  529. self.rdt = np.double
  530. self.dec = 12
  531. self.type = 4
  532. class TestIDSTIVFloat(_TestIDSTBase):
  533. def setup_method(self):
  534. self.rdt = np.float32
  535. self.dec = 6
  536. self.type = 4
  537. class TestIDSTIVnt(_TestIDSTBase):
  538. def setup_method(self):
  539. self.rdt = int
  540. self.dec = 6
  541. self.type = 4
  542. class TestOverwrite(object):
  543. """Check input overwrite behavior."""
  544. real_dtypes = [np.float32, np.float64]
  545. def _check(self, x, routine, type, fftsize, axis, norm, overwrite_x,
  546. should_overwrite, **kw):
  547. x2 = x.copy()
  548. routine(x2, type, fftsize, axis, norm, overwrite_x=overwrite_x)
  549. sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
  550. routine.__name__, x.dtype, x.shape, fftsize, axis, overwrite_x)
  551. if not should_overwrite:
  552. assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
  553. def _check_1d(self, routine, dtype, shape, axis, overwritable_dtypes):
  554. np.random.seed(1234)
  555. if np.issubdtype(dtype, np.complexfloating):
  556. data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
  557. else:
  558. data = np.random.randn(*shape)
  559. data = data.astype(dtype)
  560. for type in [1, 2, 3, 4]:
  561. for overwrite_x in [True, False]:
  562. for norm in [None, 'ortho']:
  563. should_overwrite = (overwrite_x
  564. and dtype in overwritable_dtypes
  565. and (len(shape) == 1 or
  566. (axis % len(shape) == len(shape)-1
  567. )))
  568. self._check(data, routine, type, None, axis, norm,
  569. overwrite_x, should_overwrite)
  570. def test_dct(self):
  571. overwritable = self.real_dtypes
  572. for dtype in self.real_dtypes:
  573. self._check_1d(dct, dtype, (16,), -1, overwritable)
  574. self._check_1d(dct, dtype, (16, 2), 0, overwritable)
  575. self._check_1d(dct, dtype, (2, 16), 1, overwritable)
  576. def test_idct(self):
  577. overwritable = self.real_dtypes
  578. for dtype in self.real_dtypes:
  579. self._check_1d(idct, dtype, (16,), -1, overwritable)
  580. self._check_1d(idct, dtype, (16, 2), 0, overwritable)
  581. self._check_1d(idct, dtype, (2, 16), 1, overwritable)
  582. def test_dst(self):
  583. overwritable = self.real_dtypes
  584. for dtype in self.real_dtypes:
  585. self._check_1d(dst, dtype, (16,), -1, overwritable)
  586. self._check_1d(dst, dtype, (16, 2), 0, overwritable)
  587. self._check_1d(dst, dtype, (2, 16), 1, overwritable)
  588. def test_idst(self):
  589. overwritable = self.real_dtypes
  590. for dtype in self.real_dtypes:
  591. self._check_1d(idst, dtype, (16,), -1, overwritable)
  592. self._check_1d(idst, dtype, (16, 2), 0, overwritable)
  593. self._check_1d(idst, dtype, (2, 16), 1, overwritable)
  594. class Test_DCTN_IDCTN(object):
  595. dec = 14
  596. dct_type = [1, 2, 3, 4]
  597. norms = [None, 'ortho']
  598. rstate = np.random.RandomState(1234)
  599. shape = (32, 16)
  600. data = rstate.randn(*shape)
  601. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  602. (dstn, idstn)])
  603. @pytest.mark.parametrize('axes', [None,
  604. 1, (1,), [1],
  605. 0, (0,), [0],
  606. (0, 1), [0, 1],
  607. (-2, -1), [-2, -1]])
  608. @pytest.mark.parametrize('dct_type', dct_type)
  609. @pytest.mark.parametrize('norm', ['ortho'])
  610. def test_axes_round_trip(self, fforward, finverse, axes, dct_type, norm):
  611. tmp = fforward(self.data, type=dct_type, axes=axes, norm=norm)
  612. tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm)
  613. assert_array_almost_equal(self.data, tmp, decimal=12)
  614. @pytest.mark.parametrize('fforward,fforward_ref', [(dctn, dct_2d_ref),
  615. (dstn, dst_2d_ref)])
  616. @pytest.mark.parametrize('dct_type', dct_type)
  617. @pytest.mark.parametrize('norm', norms)
  618. def test_dctn_vs_2d_reference(self, fforward, fforward_ref,
  619. dct_type, norm):
  620. y1 = fforward(self.data, type=dct_type, axes=None, norm=norm)
  621. y2 = fforward_ref(self.data, type=dct_type, norm=norm)
  622. assert_array_almost_equal(y1, y2, decimal=11)
  623. @pytest.mark.parametrize('finverse,finverse_ref', [(idctn, idct_2d_ref),
  624. (idstn, idst_2d_ref)])
  625. @pytest.mark.parametrize('dct_type', dct_type)
  626. @pytest.mark.parametrize('norm', [None, 'ortho'])
  627. def test_idctn_vs_2d_reference(self, finverse, finverse_ref,
  628. dct_type, norm):
  629. fdata = dctn(self.data, type=dct_type, norm=norm)
  630. y1 = finverse(fdata, type=dct_type, norm=norm)
  631. y2 = finverse_ref(fdata, type=dct_type, norm=norm)
  632. assert_array_almost_equal(y1, y2, decimal=11)
  633. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  634. (dstn, idstn)])
  635. def test_axes_and_shape(self, fforward, finverse):
  636. with assert_raises(ValueError,
  637. match="when given, axes and shape arguments"
  638. " have to be of the same length"):
  639. fforward(self.data, shape=self.data.shape[0], axes=(0, 1))
  640. with assert_raises(ValueError,
  641. match="when given, axes and shape arguments"
  642. " have to be of the same length"):
  643. fforward(self.data, shape=self.data.shape[0], axes=None)
  644. with assert_raises(ValueError,
  645. match="when given, axes and shape arguments"
  646. " have to be of the same length"):
  647. fforward(self.data, shape=self.data.shape, axes=0)
  648. @pytest.mark.parametrize('fforward', [dctn, dstn])
  649. def test_shape(self, fforward):
  650. tmp = fforward(self.data, shape=(128, 128), axes=None)
  651. assert_equal(tmp.shape, (128, 128))
  652. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  653. (dstn, idstn)])
  654. @pytest.mark.parametrize('axes', [1, (1,), [1],
  655. 0, (0,), [0]])
  656. def test_shape_is_none_with_axes(self, fforward, finverse, axes):
  657. tmp = fforward(self.data, shape=None, axes=axes, norm='ortho')
  658. tmp = finverse(tmp, shape=None, axes=axes, norm='ortho')
  659. assert_array_almost_equal(self.data, tmp, decimal=self.dec)