test_upfirdn.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # Code adapted from "upfirdn" python library with permission:
  2. #
  3. # Copyright (c) 2009, Motorola, Inc
  4. #
  5. # All Rights Reserved.
  6. #
  7. # Redistribution and use in source and binary forms, with or without
  8. # modification, are permitted provided that the following conditions are
  9. # met:
  10. #
  11. # * Redistributions of source code must retain the above copyright notice,
  12. # this list of conditions and the following disclaimer.
  13. #
  14. # * Redistributions in binary form must reproduce the above copyright
  15. # notice, this list of conditions and the following disclaimer in the
  16. # documentation and/or other materials provided with the distribution.
  17. #
  18. # * Neither the name of Motorola nor the names of its contributors may be
  19. # used to endorse or promote products derived from this software without
  20. # specific prior written permission.
  21. #
  22. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
  23. # IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
  24. # THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
  25. # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
  26. # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
  27. # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  28. # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
  29. # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  30. # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  31. # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  32. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  33. import numpy as np
  34. from itertools import product
  35. from numpy.testing import assert_equal, assert_allclose
  36. from pytest import raises as assert_raises
  37. from scipy.signal import upfirdn, firwin, lfilter
  38. from scipy.signal._upfirdn import _output_len
  39. def upfirdn_naive(x, h, up=1, down=1):
  40. """Naive upfirdn processing in Python
  41. Note: arg order (x, h) differs to facilitate apply_along_axis use.
  42. """
  43. h = np.asarray(h)
  44. out = np.zeros(len(x) * up, x.dtype)
  45. out[::up] = x
  46. out = np.convolve(h, out)[::down][:_output_len(len(h), len(x), up, down)]
  47. return out
  48. class UpFIRDnCase(object):
  49. """Test _UpFIRDn object"""
  50. def __init__(self, up, down, h, x_dtype):
  51. self.up = up
  52. self.down = down
  53. self.h = np.atleast_1d(h)
  54. self.x_dtype = x_dtype
  55. self.rng = np.random.RandomState(17)
  56. def __call__(self):
  57. # tiny signal
  58. self.scrub(np.ones(1, self.x_dtype))
  59. # ones
  60. self.scrub(np.ones(10, self.x_dtype)) # ones
  61. # randn
  62. x = self.rng.randn(10).astype(self.x_dtype)
  63. if self.x_dtype in (np.complex64, np.complex128):
  64. x += 1j * self.rng.randn(10)
  65. self.scrub(x)
  66. # ramp
  67. self.scrub(np.arange(10).astype(self.x_dtype))
  68. # 3D, random
  69. size = (2, 3, 5)
  70. x = self.rng.randn(*size).astype(self.x_dtype)
  71. if self.x_dtype in (np.complex64, np.complex128):
  72. x += 1j * self.rng.randn(*size)
  73. for axis in range(len(size)):
  74. self.scrub(x, axis=axis)
  75. x = x[:, ::2, 1::3].T
  76. for axis in range(len(size)):
  77. self.scrub(x, axis=axis)
  78. def scrub(self, x, axis=-1):
  79. yr = np.apply_along_axis(upfirdn_naive, axis, x,
  80. self.h, self.up, self.down)
  81. y = upfirdn(self.h, x, self.up, self.down, axis=axis)
  82. dtypes = (self.h.dtype, x.dtype)
  83. if all(d == np.complex64 for d in dtypes):
  84. assert_equal(y.dtype, np.complex64)
  85. elif np.complex64 in dtypes and np.float32 in dtypes:
  86. assert_equal(y.dtype, np.complex64)
  87. elif all(d == np.float32 for d in dtypes):
  88. assert_equal(y.dtype, np.float32)
  89. elif np.complex128 in dtypes or np.complex64 in dtypes:
  90. assert_equal(y.dtype, np.complex128)
  91. else:
  92. assert_equal(y.dtype, np.float64)
  93. assert_allclose(yr, y)
  94. class TestUpfirdn(object):
  95. def test_valid_input(self):
  96. assert_raises(ValueError, upfirdn, [1], [1], 1, 0) # up or down < 1
  97. assert_raises(ValueError, upfirdn, [], [1], 1, 1) # h.ndim != 1
  98. assert_raises(ValueError, upfirdn, [[1]], [1], 1, 1)
  99. def test_vs_lfilter(self):
  100. # Check that up=1.0 gives same answer as lfilter + slicing
  101. random_state = np.random.RandomState(17)
  102. try_types = (int, np.float32, np.complex64, float, complex)
  103. size = 10000
  104. down_factors = [2, 11, 79]
  105. for dtype in try_types:
  106. x = random_state.randn(size).astype(dtype)
  107. if dtype in (np.complex64, np.complex128):
  108. x += 1j * random_state.randn(size)
  109. for down in down_factors:
  110. h = firwin(31, 1. / down, window='hamming')
  111. yl = lfilter(h, 1.0, x)[::down]
  112. y = upfirdn(h, x, up=1, down=down)
  113. assert_allclose(yl, y[:yl.size], atol=1e-7, rtol=1e-7)
  114. def test_vs_naive(self):
  115. tests = []
  116. try_types = (int, np.float32, np.complex64, float, complex)
  117. # Simple combinations of factors
  118. for x_dtype, h in product(try_types, (1., 1j)):
  119. tests.append(UpFIRDnCase(1, 1, h, x_dtype))
  120. tests.append(UpFIRDnCase(2, 2, h, x_dtype))
  121. tests.append(UpFIRDnCase(3, 2, h, x_dtype))
  122. tests.append(UpFIRDnCase(2, 3, h, x_dtype))
  123. # mixture of big, small, and both directions (net up and net down)
  124. # use all combinations of data and filter dtypes
  125. factors = (100, 10) # up/down factors
  126. cases = product(factors, factors, try_types, try_types)
  127. for case in cases:
  128. tests += self._random_factors(*case)
  129. for test in tests:
  130. test()
  131. def _random_factors(self, p_max, q_max, h_dtype, x_dtype):
  132. n_rep = 3
  133. longest_h = 25
  134. random_state = np.random.RandomState(17)
  135. tests = []
  136. for _ in range(n_rep):
  137. # Randomize the up/down factors somewhat
  138. p_add = q_max if p_max > q_max else 1
  139. q_add = p_max if q_max > p_max else 1
  140. p = random_state.randint(p_max) + p_add
  141. q = random_state.randint(q_max) + q_add
  142. # Generate random FIR coefficients
  143. len_h = random_state.randint(longest_h) + 1
  144. h = np.atleast_1d(random_state.randint(len_h))
  145. h = h.astype(h_dtype)
  146. if h_dtype == complex:
  147. h += 1j * random_state.randint(len_h)
  148. tests.append(UpFIRDnCase(p, q, h, x_dtype))
  149. return tests