wavelets.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.dual import eig
  4. from scipy.special import comb
  5. from scipy import linspace, pi, exp
  6. from scipy.signal import convolve
  7. __all__ = ['daub', 'qmf', 'cascade', 'morlet', 'ricker', 'cwt']
  8. def daub(p):
  9. """
  10. The coefficients for the FIR low-pass filter producing Daubechies wavelets.
  11. p>=1 gives the order of the zero at f=1/2.
  12. There are 2p filter coefficients.
  13. Parameters
  14. ----------
  15. p : int
  16. Order of the zero at f=1/2, can have values from 1 to 34.
  17. Returns
  18. -------
  19. daub : ndarray
  20. Return
  21. """
  22. sqrt = np.sqrt
  23. if p < 1:
  24. raise ValueError("p must be at least 1.")
  25. if p == 1:
  26. c = 1 / sqrt(2)
  27. return np.array([c, c])
  28. elif p == 2:
  29. f = sqrt(2) / 8
  30. c = sqrt(3)
  31. return f * np.array([1 + c, 3 + c, 3 - c, 1 - c])
  32. elif p == 3:
  33. tmp = 12 * sqrt(10)
  34. z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6
  35. z1c = np.conj(z1)
  36. f = sqrt(2) / 8
  37. d0 = np.real((1 - z1) * (1 - z1c))
  38. a0 = np.real(z1 * z1c)
  39. a1 = 2 * np.real(z1)
  40. return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1,
  41. a0 - 3 * a1 + 3, 3 - a1, 1])
  42. elif p < 35:
  43. # construct polynomial and factor it
  44. if p < 35:
  45. P = [comb(p - 1 + k, k, exact=1) for k in range(p)][::-1]
  46. yj = np.roots(P)
  47. else: # try different polynomial --- needs work
  48. P = [comb(p - 1 + k, k, exact=1) / 4.0**k
  49. for k in range(p)][::-1]
  50. yj = np.roots(P) / 4
  51. # for each root, compute two z roots, select the one with |z|>1
  52. # Build up final polynomial
  53. c = np.poly1d([1, 1])**p
  54. q = np.poly1d([1])
  55. for k in range(p - 1):
  56. yval = yj[k]
  57. part = 2 * sqrt(yval * (yval - 1))
  58. const = 1 - 2 * yval
  59. z1 = const + part
  60. if (abs(z1)) < 1:
  61. z1 = const - part
  62. q = q * [1, -z1]
  63. q = c * np.real(q)
  64. # Normalize result
  65. q = q / np.sum(q) * sqrt(2)
  66. return q.c[::-1]
  67. else:
  68. raise ValueError("Polynomial factorization does not work "
  69. "well for p too large.")
  70. def qmf(hk):
  71. """
  72. Return high-pass qmf filter from low-pass
  73. Parameters
  74. ----------
  75. hk : array_like
  76. Coefficients of high-pass filter.
  77. """
  78. N = len(hk) - 1
  79. asgn = [{0: 1, 1: -1}[k % 2] for k in range(N + 1)]
  80. return hk[::-1] * np.array(asgn)
  81. def cascade(hk, J=7):
  82. """
  83. Return (x, phi, psi) at dyadic points ``K/2**J`` from filter coefficients.
  84. Parameters
  85. ----------
  86. hk : array_like
  87. Coefficients of low-pass filter.
  88. J : int, optional
  89. Values will be computed at grid points ``K/2**J``. Default is 7.
  90. Returns
  91. -------
  92. x : ndarray
  93. The dyadic points ``K/2**J`` for ``K=0...N * (2**J)-1`` where
  94. ``len(hk) = len(gk) = N+1``.
  95. phi : ndarray
  96. The scaling function ``phi(x)`` at `x`:
  97. ``phi(x) = sum(hk * phi(2x-k))``, where k is from 0 to N.
  98. psi : ndarray, optional
  99. The wavelet function ``psi(x)`` at `x`:
  100. ``phi(x) = sum(gk * phi(2x-k))``, where k is from 0 to N.
  101. `psi` is only returned if `gk` is not None.
  102. Notes
  103. -----
  104. The algorithm uses the vector cascade algorithm described by Strang and
  105. Nguyen in "Wavelets and Filter Banks". It builds a dictionary of values
  106. and slices for quick reuse. Then inserts vectors into final vector at the
  107. end.
  108. """
  109. N = len(hk) - 1
  110. if (J > 30 - np.log2(N + 1)):
  111. raise ValueError("Too many levels.")
  112. if (J < 1):
  113. raise ValueError("Too few levels.")
  114. # construct matrices needed
  115. nn, kk = np.ogrid[:N, :N]
  116. s2 = np.sqrt(2)
  117. # append a zero so that take works
  118. thk = np.r_[hk, 0]
  119. gk = qmf(hk)
  120. tgk = np.r_[gk, 0]
  121. indx1 = np.clip(2 * nn - kk, -1, N + 1)
  122. indx2 = np.clip(2 * nn - kk + 1, -1, N + 1)
  123. m = np.zeros((2, 2, N, N), 'd')
  124. m[0, 0] = np.take(thk, indx1, 0)
  125. m[0, 1] = np.take(thk, indx2, 0)
  126. m[1, 0] = np.take(tgk, indx1, 0)
  127. m[1, 1] = np.take(tgk, indx2, 0)
  128. m *= s2
  129. # construct the grid of points
  130. x = np.arange(0, N * (1 << J), dtype=float) / (1 << J)
  131. phi = 0 * x
  132. psi = 0 * x
  133. # find phi0, and phi1
  134. lam, v = eig(m[0, 0])
  135. ind = np.argmin(np.absolute(lam - 1))
  136. # a dictionary with a binary representation of the
  137. # evaluation points x < 1 -- i.e. position is 0.xxxx
  138. v = np.real(v[:, ind])
  139. # need scaling function to integrate to 1 so find
  140. # eigenvector normalized to sum(v,axis=0)=1
  141. sm = np.sum(v)
  142. if sm < 0: # need scaling function to integrate to 1
  143. v = -v
  144. sm = -sm
  145. bitdic = {'0': v / sm}
  146. bitdic['1'] = np.dot(m[0, 1], bitdic['0'])
  147. step = 1 << J
  148. phi[::step] = bitdic['0']
  149. phi[(1 << (J - 1))::step] = bitdic['1']
  150. psi[::step] = np.dot(m[1, 0], bitdic['0'])
  151. psi[(1 << (J - 1))::step] = np.dot(m[1, 1], bitdic['0'])
  152. # descend down the levels inserting more and more values
  153. # into bitdic -- store the values in the correct location once we
  154. # have computed them -- stored in the dictionary
  155. # for quicker use later.
  156. prevkeys = ['1']
  157. for level in range(2, J + 1):
  158. newkeys = ['%d%s' % (xx, yy) for xx in [0, 1] for yy in prevkeys]
  159. fac = 1 << (J - level)
  160. for key in newkeys:
  161. # convert key to number
  162. num = 0
  163. for pos in range(level):
  164. if key[pos] == '1':
  165. num += (1 << (level - 1 - pos))
  166. pastphi = bitdic[key[1:]]
  167. ii = int(key[0])
  168. temp = np.dot(m[0, ii], pastphi)
  169. bitdic[key] = temp
  170. phi[num * fac::step] = temp
  171. psi[num * fac::step] = np.dot(m[1, ii], pastphi)
  172. prevkeys = newkeys
  173. return x, phi, psi
  174. def morlet(M, w=5.0, s=1.0, complete=True):
  175. """
  176. Complex Morlet wavelet.
  177. Parameters
  178. ----------
  179. M : int
  180. Length of the wavelet.
  181. w : float, optional
  182. Omega0. Default is 5
  183. s : float, optional
  184. Scaling factor, windowed from ``-s*2*pi`` to ``+s*2*pi``. Default is 1.
  185. complete : bool, optional
  186. Whether to use the complete or the standard version.
  187. Returns
  188. -------
  189. morlet : (M,) ndarray
  190. See Also
  191. --------
  192. scipy.signal.gausspulse
  193. Notes
  194. -----
  195. The standard version::
  196. pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))
  197. This commonly used wavelet is often referred to simply as the
  198. Morlet wavelet. Note that this simplified version can cause
  199. admissibility problems at low values of `w`.
  200. The complete version::
  201. pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))
  202. This version has a correction
  203. term to improve admissibility. For `w` greater than 5, the
  204. correction term is negligible.
  205. Note that the energy of the return wavelet is not normalised
  206. according to `s`.
  207. The fundamental frequency of this wavelet in Hz is given
  208. by ``f = 2*s*w*r / M`` where `r` is the sampling rate.
  209. Note: This function was created before `cwt` and is not compatible
  210. with it.
  211. """
  212. x = linspace(-s * 2 * pi, s * 2 * pi, M)
  213. output = exp(1j * w * x)
  214. if complete:
  215. output -= exp(-0.5 * (w**2))
  216. output *= exp(-0.5 * (x**2)) * pi**(-0.25)
  217. return output
  218. def ricker(points, a):
  219. """
  220. Return a Ricker wavelet, also known as the "Mexican hat wavelet".
  221. It models the function:
  222. ``A (1 - x^2/a^2) exp(-x^2/2 a^2)``,
  223. where ``A = 2/sqrt(3a)pi^1/4``.
  224. Parameters
  225. ----------
  226. points : int
  227. Number of points in `vector`.
  228. Will be centered around 0.
  229. a : scalar
  230. Width parameter of the wavelet.
  231. Returns
  232. -------
  233. vector : (N,) ndarray
  234. Array of length `points` in shape of ricker curve.
  235. Examples
  236. --------
  237. >>> from scipy import signal
  238. >>> import matplotlib.pyplot as plt
  239. >>> points = 100
  240. >>> a = 4.0
  241. >>> vec2 = signal.ricker(points, a)
  242. >>> print(len(vec2))
  243. 100
  244. >>> plt.plot(vec2)
  245. >>> plt.show()
  246. """
  247. A = 2 / (np.sqrt(3 * a) * (np.pi**0.25))
  248. wsq = a**2
  249. vec = np.arange(0, points) - (points - 1.0) / 2
  250. xsq = vec**2
  251. mod = (1 - xsq / wsq)
  252. gauss = np.exp(-xsq / (2 * wsq))
  253. total = A * mod * gauss
  254. return total
  255. def cwt(data, wavelet, widths):
  256. """
  257. Continuous wavelet transform.
  258. Performs a continuous wavelet transform on `data`,
  259. using the `wavelet` function. A CWT performs a convolution
  260. with `data` using the `wavelet` function, which is characterized
  261. by a width parameter and length parameter.
  262. Parameters
  263. ----------
  264. data : (N,) ndarray
  265. data on which to perform the transform.
  266. wavelet : function
  267. Wavelet function, which should take 2 arguments.
  268. The first argument is the number of points that the returned vector
  269. will have (len(wavelet(length,width)) == length).
  270. The second is a width parameter, defining the size of the wavelet
  271. (e.g. standard deviation of a gaussian). See `ricker`, which
  272. satisfies these requirements.
  273. widths : (M,) sequence
  274. Widths to use for transform.
  275. Returns
  276. -------
  277. cwt: (M, N) ndarray
  278. Will have shape of (len(widths), len(data)).
  279. Notes
  280. -----
  281. ::
  282. length = min(10 * width[ii], len(data))
  283. cwt[ii,:] = signal.convolve(data, wavelet(length,
  284. width[ii]), mode='same')
  285. Examples
  286. --------
  287. >>> from scipy import signal
  288. >>> import matplotlib.pyplot as plt
  289. >>> t = np.linspace(-1, 1, 200, endpoint=False)
  290. >>> sig = np.cos(2 * np.pi * 7 * t) + signal.gausspulse(t - 0.4, fc=2)
  291. >>> widths = np.arange(1, 31)
  292. >>> cwtmatr = signal.cwt(sig, signal.ricker, widths)
  293. >>> plt.imshow(cwtmatr, extent=[-1, 1, 31, 1], cmap='PRGn', aspect='auto',
  294. ... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())
  295. >>> plt.show()
  296. """
  297. output = np.zeros([len(widths), len(data)])
  298. for ind, width in enumerate(widths):
  299. wavelet_data = wavelet(min(10 * width, len(data)), width)
  300. output[ind, :] = convolve(data, wavelet_data,
  301. mode='same')
  302. return output