test_waveforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import (assert_almost_equal, assert_equal,
  4. assert_, assert_allclose, assert_array_equal)
  5. from pytest import raises as assert_raises
  6. import scipy.signal.waveforms as waveforms
  7. # These chirp_* functions are the instantaneous frequencies of the signals
  8. # returned by chirp().
  9. def chirp_linear(t, f0, f1, t1):
  10. f = f0 + (f1 - f0) * t / t1
  11. return f
  12. def chirp_quadratic(t, f0, f1, t1, vertex_zero=True):
  13. if vertex_zero:
  14. f = f0 + (f1 - f0) * t**2 / t1**2
  15. else:
  16. f = f1 - (f1 - f0) * (t1 - t)**2 / t1**2
  17. return f
  18. def chirp_geometric(t, f0, f1, t1):
  19. f = f0 * (f1/f0)**(t/t1)
  20. return f
  21. def chirp_hyperbolic(t, f0, f1, t1):
  22. f = f0*f1*t1 / ((f0 - f1)*t + f1*t1)
  23. return f
  24. def compute_frequency(t, theta):
  25. """
  26. Compute theta'(t)/(2*pi), where theta'(t) is the derivative of theta(t).
  27. """
  28. # Assume theta and t are 1D numpy arrays.
  29. # Assume that t is uniformly spaced.
  30. dt = t[1] - t[0]
  31. f = np.diff(theta)/(2*np.pi) / dt
  32. tf = 0.5*(t[1:] + t[:-1])
  33. return tf, f
  34. class TestChirp(object):
  35. def test_linear_at_zero(self):
  36. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='linear')
  37. assert_almost_equal(w, 1.0)
  38. def test_linear_freq_01(self):
  39. method = 'linear'
  40. f0 = 1.0
  41. f1 = 2.0
  42. t1 = 1.0
  43. t = np.linspace(0, t1, 100)
  44. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  45. tf, f = compute_frequency(t, phase)
  46. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  47. assert_(abserr < 1e-6)
  48. def test_linear_freq_02(self):
  49. method = 'linear'
  50. f0 = 200.0
  51. f1 = 100.0
  52. t1 = 10.0
  53. t = np.linspace(0, t1, 100)
  54. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  55. tf, f = compute_frequency(t, phase)
  56. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  57. assert_(abserr < 1e-6)
  58. def test_quadratic_at_zero(self):
  59. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic')
  60. assert_almost_equal(w, 1.0)
  61. def test_quadratic_at_zero2(self):
  62. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic',
  63. vertex_zero=False)
  64. assert_almost_equal(w, 1.0)
  65. def test_quadratic_freq_01(self):
  66. method = 'quadratic'
  67. f0 = 1.0
  68. f1 = 2.0
  69. t1 = 1.0
  70. t = np.linspace(0, t1, 2000)
  71. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  72. tf, f = compute_frequency(t, phase)
  73. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  74. assert_(abserr < 1e-6)
  75. def test_quadratic_freq_02(self):
  76. method = 'quadratic'
  77. f0 = 20.0
  78. f1 = 10.0
  79. t1 = 10.0
  80. t = np.linspace(0, t1, 2000)
  81. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  82. tf, f = compute_frequency(t, phase)
  83. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  84. assert_(abserr < 1e-6)
  85. def test_logarithmic_at_zero(self):
  86. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='logarithmic')
  87. assert_almost_equal(w, 1.0)
  88. def test_logarithmic_freq_01(self):
  89. method = 'logarithmic'
  90. f0 = 1.0
  91. f1 = 2.0
  92. t1 = 1.0
  93. t = np.linspace(0, t1, 10000)
  94. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  95. tf, f = compute_frequency(t, phase)
  96. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  97. assert_(abserr < 1e-6)
  98. def test_logarithmic_freq_02(self):
  99. method = 'logarithmic'
  100. f0 = 200.0
  101. f1 = 100.0
  102. t1 = 10.0
  103. t = np.linspace(0, t1, 10000)
  104. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  105. tf, f = compute_frequency(t, phase)
  106. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  107. assert_(abserr < 1e-6)
  108. def test_logarithmic_freq_03(self):
  109. method = 'logarithmic'
  110. f0 = 100.0
  111. f1 = 100.0
  112. t1 = 10.0
  113. t = np.linspace(0, t1, 10000)
  114. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  115. tf, f = compute_frequency(t, phase)
  116. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  117. assert_(abserr < 1e-6)
  118. def test_hyperbolic_at_zero(self):
  119. w = waveforms.chirp(t=0, f0=10.0, f1=1.0, t1=1.0, method='hyperbolic')
  120. assert_almost_equal(w, 1.0)
  121. def test_hyperbolic_freq_01(self):
  122. method = 'hyperbolic'
  123. t1 = 1.0
  124. t = np.linspace(0, t1, 10000)
  125. # f0 f1
  126. cases = [[10.0, 1.0],
  127. [1.0, 10.0],
  128. [-10.0, -1.0],
  129. [-1.0, -10.0]]
  130. for f0, f1 in cases:
  131. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  132. tf, f = compute_frequency(t, phase)
  133. expected = chirp_hyperbolic(tf, f0, f1, t1)
  134. assert_allclose(f, expected)
  135. def test_hyperbolic_zero_freq(self):
  136. # f0=0 or f1=0 must raise a ValueError.
  137. method = 'hyperbolic'
  138. t1 = 1.0
  139. t = np.linspace(0, t1, 5)
  140. assert_raises(ValueError, waveforms.chirp, t, 0, t1, 1, method)
  141. assert_raises(ValueError, waveforms.chirp, t, 1, t1, 0, method)
  142. def test_unknown_method(self):
  143. method = "foo"
  144. f0 = 10.0
  145. f1 = 20.0
  146. t1 = 1.0
  147. t = np.linspace(0, t1, 10)
  148. assert_raises(ValueError, waveforms.chirp, t, f0, t1, f1, method)
  149. def test_integer_t1(self):
  150. f0 = 10.0
  151. f1 = 20.0
  152. t = np.linspace(-1, 1, 11)
  153. t1 = 3.0
  154. float_result = waveforms.chirp(t, f0, t1, f1)
  155. t1 = 3
  156. int_result = waveforms.chirp(t, f0, t1, f1)
  157. err_msg = "Integer input 't1=3' gives wrong result"
  158. assert_equal(int_result, float_result, err_msg=err_msg)
  159. def test_integer_f0(self):
  160. f1 = 20.0
  161. t1 = 3.0
  162. t = np.linspace(-1, 1, 11)
  163. f0 = 10.0
  164. float_result = waveforms.chirp(t, f0, t1, f1)
  165. f0 = 10
  166. int_result = waveforms.chirp(t, f0, t1, f1)
  167. err_msg = "Integer input 'f0=10' gives wrong result"
  168. assert_equal(int_result, float_result, err_msg=err_msg)
  169. def test_integer_f1(self):
  170. f0 = 10.0
  171. t1 = 3.0
  172. t = np.linspace(-1, 1, 11)
  173. f1 = 20.0
  174. float_result = waveforms.chirp(t, f0, t1, f1)
  175. f1 = 20
  176. int_result = waveforms.chirp(t, f0, t1, f1)
  177. err_msg = "Integer input 'f1=20' gives wrong result"
  178. assert_equal(int_result, float_result, err_msg=err_msg)
  179. def test_integer_all(self):
  180. f0 = 10
  181. t1 = 3
  182. f1 = 20
  183. t = np.linspace(-1, 1, 11)
  184. float_result = waveforms.chirp(t, float(f0), float(t1), float(f1))
  185. int_result = waveforms.chirp(t, f0, t1, f1)
  186. err_msg = "Integer input 'f0=10, t1=3, f1=20' gives wrong result"
  187. assert_equal(int_result, float_result, err_msg=err_msg)
  188. class TestSweepPoly(object):
  189. def test_sweep_poly_quad1(self):
  190. p = np.poly1d([1.0, 0.0, 1.0])
  191. t = np.linspace(0, 3.0, 10000)
  192. phase = waveforms._sweep_poly_phase(t, p)
  193. tf, f = compute_frequency(t, phase)
  194. expected = p(tf)
  195. abserr = np.max(np.abs(f - expected))
  196. assert_(abserr < 1e-6)
  197. def test_sweep_poly_const(self):
  198. p = np.poly1d(2.0)
  199. t = np.linspace(0, 3.0, 10000)
  200. phase = waveforms._sweep_poly_phase(t, p)
  201. tf, f = compute_frequency(t, phase)
  202. expected = p(tf)
  203. abserr = np.max(np.abs(f - expected))
  204. assert_(abserr < 1e-6)
  205. def test_sweep_poly_linear(self):
  206. p = np.poly1d([-1.0, 10.0])
  207. t = np.linspace(0, 3.0, 10000)
  208. phase = waveforms._sweep_poly_phase(t, p)
  209. tf, f = compute_frequency(t, phase)
  210. expected = p(tf)
  211. abserr = np.max(np.abs(f - expected))
  212. assert_(abserr < 1e-6)
  213. def test_sweep_poly_quad2(self):
  214. p = np.poly1d([1.0, 0.0, -2.0])
  215. t = np.linspace(0, 3.0, 10000)
  216. phase = waveforms._sweep_poly_phase(t, p)
  217. tf, f = compute_frequency(t, phase)
  218. expected = p(tf)
  219. abserr = np.max(np.abs(f - expected))
  220. assert_(abserr < 1e-6)
  221. def test_sweep_poly_cubic(self):
  222. p = np.poly1d([2.0, 1.0, 0.0, -2.0])
  223. t = np.linspace(0, 2.0, 10000)
  224. phase = waveforms._sweep_poly_phase(t, p)
  225. tf, f = compute_frequency(t, phase)
  226. expected = p(tf)
  227. abserr = np.max(np.abs(f - expected))
  228. assert_(abserr < 1e-6)
  229. def test_sweep_poly_cubic2(self):
  230. """Use an array of coefficients instead of a poly1d."""
  231. p = np.array([2.0, 1.0, 0.0, -2.0])
  232. t = np.linspace(0, 2.0, 10000)
  233. phase = waveforms._sweep_poly_phase(t, p)
  234. tf, f = compute_frequency(t, phase)
  235. expected = np.poly1d(p)(tf)
  236. abserr = np.max(np.abs(f - expected))
  237. assert_(abserr < 1e-6)
  238. def test_sweep_poly_cubic3(self):
  239. """Use a list of coefficients instead of a poly1d."""
  240. p = [2.0, 1.0, 0.0, -2.0]
  241. t = np.linspace(0, 2.0, 10000)
  242. phase = waveforms._sweep_poly_phase(t, p)
  243. tf, f = compute_frequency(t, phase)
  244. expected = np.poly1d(p)(tf)
  245. abserr = np.max(np.abs(f - expected))
  246. assert_(abserr < 1e-6)
  247. class TestGaussPulse(object):
  248. def test_integer_fc(self):
  249. float_result = waveforms.gausspulse('cutoff', fc=1000.0)
  250. int_result = waveforms.gausspulse('cutoff', fc=1000)
  251. err_msg = "Integer input 'fc=1000' gives wrong result"
  252. assert_equal(int_result, float_result, err_msg=err_msg)
  253. def test_integer_bw(self):
  254. float_result = waveforms.gausspulse('cutoff', bw=1.0)
  255. int_result = waveforms.gausspulse('cutoff', bw=1)
  256. err_msg = "Integer input 'bw=1' gives wrong result"
  257. assert_equal(int_result, float_result, err_msg=err_msg)
  258. def test_integer_bwr(self):
  259. float_result = waveforms.gausspulse('cutoff', bwr=-6.0)
  260. int_result = waveforms.gausspulse('cutoff', bwr=-6)
  261. err_msg = "Integer input 'bwr=-6' gives wrong result"
  262. assert_equal(int_result, float_result, err_msg=err_msg)
  263. def test_integer_tpr(self):
  264. float_result = waveforms.gausspulse('cutoff', tpr=-60.0)
  265. int_result = waveforms.gausspulse('cutoff', tpr=-60)
  266. err_msg = "Integer input 'tpr=-60' gives wrong result"
  267. assert_equal(int_result, float_result, err_msg=err_msg)
  268. class TestUnitImpulse(object):
  269. def test_no_index(self):
  270. assert_array_equal(waveforms.unit_impulse(7), [1, 0, 0, 0, 0, 0, 0])
  271. assert_array_equal(waveforms.unit_impulse((3, 3)),
  272. [[1, 0, 0], [0, 0, 0], [0, 0, 0]])
  273. def test_index(self):
  274. assert_array_equal(waveforms.unit_impulse(10, 3),
  275. [0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
  276. assert_array_equal(waveforms.unit_impulse((3, 3), (1, 1)),
  277. [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
  278. # Broadcasting
  279. imp = waveforms.unit_impulse((4, 4), 2)
  280. assert_array_equal(imp, np.array([[0, 0, 0, 0],
  281. [0, 0, 0, 0],
  282. [0, 0, 1, 0],
  283. [0, 0, 0, 0]]))
  284. def test_mid(self):
  285. assert_array_equal(waveforms.unit_impulse((3, 3), 'mid'),
  286. [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
  287. assert_array_equal(waveforms.unit_impulse(9, 'mid'),
  288. [0, 0, 0, 0, 1, 0, 0, 0, 0])
  289. def test_dtype(self):
  290. imp = waveforms.unit_impulse(7)
  291. assert_(np.issubdtype(imp.dtype, np.floating))
  292. imp = waveforms.unit_impulse(5, 3, dtype=int)
  293. assert_(np.issubdtype(imp.dtype, np.integer))
  294. imp = waveforms.unit_impulse((5, 2), (3, 1), dtype=complex)
  295. assert_(np.issubdtype(imp.dtype, np.complexfloating))