_arraytools.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """
  2. Functions for acting on an axis of an array.
  3. """
  4. from __future__ import division, print_function, absolute_import
  5. import numpy as np
  6. def axis_slice(a, start=None, stop=None, step=None, axis=-1):
  7. """Take a slice along axis 'axis' from 'a'.
  8. Parameters
  9. ----------
  10. a : numpy.ndarray
  11. The array to be sliced.
  12. start, stop, step : int or None
  13. The slice parameters.
  14. axis : int, optional
  15. The axis of `a` to be sliced.
  16. Examples
  17. --------
  18. >>> a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  19. >>> axis_slice(a, start=0, stop=1, axis=1)
  20. array([[1],
  21. [4],
  22. [7]])
  23. >>> axis_slice(a, start=1, axis=0)
  24. array([[4, 5, 6],
  25. [7, 8, 9]])
  26. Notes
  27. -----
  28. The keyword arguments start, stop and step are used by calling
  29. slice(start, stop, step). This implies axis_slice() does not
  30. handle its arguments the exactly the same as indexing. To select
  31. a single index k, for example, use
  32. axis_slice(a, start=k, stop=k+1)
  33. In this case, the length of the axis 'axis' in the result will
  34. be 1; the trivial dimension is not removed. (Use numpy.squeeze()
  35. to remove trivial axes.)
  36. """
  37. a_slice = [slice(None)] * a.ndim
  38. a_slice[axis] = slice(start, stop, step)
  39. b = a[tuple(a_slice)]
  40. return b
  41. def axis_reverse(a, axis=-1):
  42. """Reverse the 1-d slices of `a` along axis `axis`.
  43. Returns axis_slice(a, step=-1, axis=axis).
  44. """
  45. return axis_slice(a, step=-1, axis=axis)
  46. def odd_ext(x, n, axis=-1):
  47. """
  48. Odd extension at the boundaries of an array
  49. Generate a new ndarray by making an odd extension of `x` along an axis.
  50. Parameters
  51. ----------
  52. x : ndarray
  53. The array to be extended.
  54. n : int
  55. The number of elements by which to extend `x` at each end of the axis.
  56. axis : int, optional
  57. The axis along which to extend `x`. Default is -1.
  58. Examples
  59. --------
  60. >>> from scipy.signal._arraytools import odd_ext
  61. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  62. >>> odd_ext(a, 2)
  63. array([[-1, 0, 1, 2, 3, 4, 5, 6, 7],
  64. [-4, -1, 0, 1, 4, 9, 16, 23, 28]])
  65. Odd extension is a "180 degree rotation" at the endpoints of the original
  66. array:
  67. >>> t = np.linspace(0, 1.5, 100)
  68. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  69. >>> b = odd_ext(a, 40)
  70. >>> import matplotlib.pyplot as plt
  71. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='odd extension')
  72. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  73. >>> plt.legend(loc='best')
  74. >>> plt.show()
  75. """
  76. if n < 1:
  77. return x
  78. if n > x.shape[axis] - 1:
  79. raise ValueError(("The extension length n (%d) is too big. " +
  80. "It must not exceed x.shape[axis]-1, which is %d.")
  81. % (n, x.shape[axis] - 1))
  82. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  83. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  84. right_end = axis_slice(x, start=-1, axis=axis)
  85. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  86. ext = np.concatenate((2 * left_end - left_ext,
  87. x,
  88. 2 * right_end - right_ext),
  89. axis=axis)
  90. return ext
  91. def even_ext(x, n, axis=-1):
  92. """
  93. Even extension at the boundaries of an array
  94. Generate a new ndarray by making an even extension of `x` along an axis.
  95. Parameters
  96. ----------
  97. x : ndarray
  98. The array to be extended.
  99. n : int
  100. The number of elements by which to extend `x` at each end of the axis.
  101. axis : int, optional
  102. The axis along which to extend `x`. Default is -1.
  103. Examples
  104. --------
  105. >>> from scipy.signal._arraytools import even_ext
  106. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  107. >>> even_ext(a, 2)
  108. array([[ 3, 2, 1, 2, 3, 4, 5, 4, 3],
  109. [ 4, 1, 0, 1, 4, 9, 16, 9, 4]])
  110. Even extension is a "mirror image" at the boundaries of the original array:
  111. >>> t = np.linspace(0, 1.5, 100)
  112. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  113. >>> b = even_ext(a, 40)
  114. >>> import matplotlib.pyplot as plt
  115. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='even extension')
  116. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  117. >>> plt.legend(loc='best')
  118. >>> plt.show()
  119. """
  120. if n < 1:
  121. return x
  122. if n > x.shape[axis] - 1:
  123. raise ValueError(("The extension length n (%d) is too big. " +
  124. "It must not exceed x.shape[axis]-1, which is %d.")
  125. % (n, x.shape[axis] - 1))
  126. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  127. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  128. ext = np.concatenate((left_ext,
  129. x,
  130. right_ext),
  131. axis=axis)
  132. return ext
  133. def const_ext(x, n, axis=-1):
  134. """
  135. Constant extension at the boundaries of an array
  136. Generate a new ndarray that is a constant extension of `x` along an axis.
  137. The extension repeats the values at the first and last element of
  138. the axis.
  139. Parameters
  140. ----------
  141. x : ndarray
  142. The array to be extended.
  143. n : int
  144. The number of elements by which to extend `x` at each end of the axis.
  145. axis : int, optional
  146. The axis along which to extend `x`. Default is -1.
  147. Examples
  148. --------
  149. >>> from scipy.signal._arraytools import const_ext
  150. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  151. >>> const_ext(a, 2)
  152. array([[ 1, 1, 1, 2, 3, 4, 5, 5, 5],
  153. [ 0, 0, 0, 1, 4, 9, 16, 16, 16]])
  154. Constant extension continues with the same values as the endpoints of the
  155. array:
  156. >>> t = np.linspace(0, 1.5, 100)
  157. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  158. >>> b = const_ext(a, 40)
  159. >>> import matplotlib.pyplot as plt
  160. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='constant extension')
  161. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  162. >>> plt.legend(loc='best')
  163. >>> plt.show()
  164. """
  165. if n < 1:
  166. return x
  167. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  168. ones_shape = [1] * x.ndim
  169. ones_shape[axis] = n
  170. ones = np.ones(ones_shape, dtype=x.dtype)
  171. left_ext = ones * left_end
  172. right_end = axis_slice(x, start=-1, axis=axis)
  173. right_ext = ones * right_end
  174. ext = np.concatenate((left_ext,
  175. x,
  176. right_ext),
  177. axis=axis)
  178. return ext
  179. def zero_ext(x, n, axis=-1):
  180. """
  181. Zero padding at the boundaries of an array
  182. Generate a new ndarray that is a zero padded extension of `x` along
  183. an axis.
  184. Parameters
  185. ----------
  186. x : ndarray
  187. The array to be extended.
  188. n : int
  189. The number of elements by which to extend `x` at each end of the
  190. axis.
  191. axis : int, optional
  192. The axis along which to extend `x`. Default is -1.
  193. Examples
  194. --------
  195. >>> from scipy.signal._arraytools import zero_ext
  196. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  197. >>> zero_ext(a, 2)
  198. array([[ 0, 0, 1, 2, 3, 4, 5, 0, 0],
  199. [ 0, 0, 0, 1, 4, 9, 16, 0, 0]])
  200. """
  201. if n < 1:
  202. return x
  203. zeros_shape = list(x.shape)
  204. zeros_shape[axis] = n
  205. zeros = np.zeros(zeros_shape, dtype=x.dtype)
  206. ext = np.concatenate((zeros, x, zeros), axis=axis)
  207. return ext