test_signaltools.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732
  1. # -*- coding: utf-8 -*-
  2. from __future__ import division, print_function, absolute_import
  3. import sys
  4. from decimal import Decimal
  5. from itertools import product
  6. import warnings
  7. import pytest
  8. from pytest import raises as assert_raises
  9. from numpy.testing import (
  10. assert_equal,
  11. assert_almost_equal, assert_array_equal, assert_array_almost_equal,
  12. assert_allclose, assert_, assert_warns, assert_array_less)
  13. from scipy._lib._numpy_compat import suppress_warnings
  14. from numpy import array, arange
  15. import numpy as np
  16. from scipy.ndimage.filters import correlate1d
  17. from scipy.optimize import fmin
  18. from scipy import signal
  19. from scipy.signal import (
  20. correlate, convolve, convolve2d, fftconvolve, choose_conv_method,
  21. hilbert, hilbert2, lfilter, lfilter_zi, filtfilt, butter, zpk2tf, zpk2sos,
  22. invres, invresz, vectorstrength, lfiltic, tf2sos, sosfilt, sosfiltfilt,
  23. sosfilt_zi, tf2zpk, BadCoefficients, unique_roots)
  24. from scipy.signal.windows import hann
  25. from scipy.signal.signaltools import _filtfilt_gust
  26. if sys.version_info.major >= 3 and sys.version_info.minor >= 5:
  27. from math import gcd
  28. else:
  29. from fractions import gcd
  30. class _TestConvolve(object):
  31. def test_basic(self):
  32. a = [3, 4, 5, 6, 5, 4]
  33. b = [1, 2, 3]
  34. c = convolve(a, b)
  35. assert_array_equal(c, array([3, 10, 22, 28, 32, 32, 23, 12]))
  36. def test_same(self):
  37. a = [3, 4, 5]
  38. b = [1, 2, 3, 4]
  39. c = convolve(a, b, mode="same")
  40. assert_array_equal(c, array([10, 22, 34]))
  41. def test_same_eq(self):
  42. a = [3, 4, 5]
  43. b = [1, 2, 3]
  44. c = convolve(a, b, mode="same")
  45. assert_array_equal(c, array([10, 22, 22]))
  46. def test_complex(self):
  47. x = array([1 + 1j, 2 + 1j, 3 + 1j])
  48. y = array([1 + 1j, 2 + 1j])
  49. z = convolve(x, y)
  50. assert_array_equal(z, array([2j, 2 + 6j, 5 + 8j, 5 + 5j]))
  51. def test_zero_rank(self):
  52. a = 1289
  53. b = 4567
  54. c = convolve(a, b)
  55. assert_equal(c, a * b)
  56. def test_single_element(self):
  57. a = array([4967])
  58. b = array([3920])
  59. c = convolve(a, b)
  60. assert_equal(c, a * b)
  61. def test_2d_arrays(self):
  62. a = [[1, 2, 3], [3, 4, 5]]
  63. b = [[2, 3, 4], [4, 5, 6]]
  64. c = convolve(a, b)
  65. d = array([[2, 7, 16, 17, 12],
  66. [10, 30, 62, 58, 38],
  67. [12, 31, 58, 49, 30]])
  68. assert_array_equal(c, d)
  69. def test_input_swapping(self):
  70. small = arange(8).reshape(2, 2, 2)
  71. big = 1j * arange(27).reshape(3, 3, 3)
  72. big += arange(27)[::-1].reshape(3, 3, 3)
  73. out_array = array(
  74. [[[0 + 0j, 26 + 0j, 25 + 1j, 24 + 2j],
  75. [52 + 0j, 151 + 5j, 145 + 11j, 93 + 11j],
  76. [46 + 6j, 133 + 23j, 127 + 29j, 81 + 23j],
  77. [40 + 12j, 98 + 32j, 93 + 37j, 54 + 24j]],
  78. [[104 + 0j, 247 + 13j, 237 + 23j, 135 + 21j],
  79. [282 + 30j, 632 + 96j, 604 + 124j, 330 + 86j],
  80. [246 + 66j, 548 + 180j, 520 + 208j, 282 + 134j],
  81. [142 + 66j, 307 + 161j, 289 + 179j, 153 + 107j]],
  82. [[68 + 36j, 157 + 103j, 147 + 113j, 81 + 75j],
  83. [174 + 138j, 380 + 348j, 352 + 376j, 186 + 230j],
  84. [138 + 174j, 296 + 432j, 268 + 460j, 138 + 278j],
  85. [70 + 138j, 145 + 323j, 127 + 341j, 63 + 197j]],
  86. [[32 + 72j, 68 + 166j, 59 + 175j, 30 + 100j],
  87. [68 + 192j, 139 + 433j, 117 + 455j, 57 + 255j],
  88. [38 + 222j, 73 + 499j, 51 + 521j, 21 + 291j],
  89. [12 + 144j, 20 + 318j, 7 + 331j, 0 + 182j]]])
  90. assert_array_equal(convolve(small, big, 'full'), out_array)
  91. assert_array_equal(convolve(big, small, 'full'), out_array)
  92. assert_array_equal(convolve(small, big, 'same'),
  93. out_array[1:3, 1:3, 1:3])
  94. assert_array_equal(convolve(big, small, 'same'),
  95. out_array[0:3, 0:3, 0:3])
  96. assert_array_equal(convolve(small, big, 'valid'),
  97. out_array[1:3, 1:3, 1:3])
  98. assert_array_equal(convolve(big, small, 'valid'),
  99. out_array[1:3, 1:3, 1:3])
  100. def test_invalid_params(self):
  101. a = [3, 4, 5]
  102. b = [1, 2, 3]
  103. assert_raises(ValueError, convolve, a, b, mode='spam')
  104. assert_raises(ValueError, convolve, a, b, mode='eggs', method='fft')
  105. assert_raises(ValueError, convolve, a, b, mode='ham', method='direct')
  106. assert_raises(ValueError, convolve, a, b, mode='full', method='bacon')
  107. assert_raises(ValueError, convolve, a, b, mode='same', method='bacon')
  108. class TestConvolve(_TestConvolve):
  109. def test_valid_mode2(self):
  110. # See gh-5897
  111. a = [1, 2, 3, 6, 5, 3]
  112. b = [2, 3, 4, 5, 3, 4, 2, 2, 1]
  113. expected = [70, 78, 73, 65]
  114. out = convolve(a, b, 'valid')
  115. assert_array_equal(out, expected)
  116. out = convolve(b, a, 'valid')
  117. assert_array_equal(out, expected)
  118. a = [1 + 5j, 2 - 1j, 3 + 0j]
  119. b = [2 - 3j, 1 + 0j]
  120. expected = [2 - 3j, 8 - 10j]
  121. out = convolve(a, b, 'valid')
  122. assert_array_equal(out, expected)
  123. out = convolve(b, a, 'valid')
  124. assert_array_equal(out, expected)
  125. def test_same_mode(self):
  126. a = [1, 2, 3, 3, 1, 2]
  127. b = [1, 4, 3, 4, 5, 6, 7, 4, 3, 2, 1, 1, 3]
  128. c = convolve(a, b, 'same')
  129. d = array([57, 61, 63, 57, 45, 36])
  130. assert_array_equal(c, d)
  131. def test_invalid_shapes(self):
  132. # By "invalid," we mean that no one
  133. # array has dimensions that are all at
  134. # least as large as the corresponding
  135. # dimensions of the other array. This
  136. # setup should throw a ValueError.
  137. a = np.arange(1, 7).reshape((2, 3))
  138. b = np.arange(-6, 0).reshape((3, 2))
  139. assert_raises(ValueError, convolve, *(a, b), **{'mode': 'valid'})
  140. assert_raises(ValueError, convolve, *(b, a), **{'mode': 'valid'})
  141. def test_convolve_method(self, n=100):
  142. types = sum([t for _, t in np.sctypes.items()], [])
  143. types = {np.dtype(t).name for t in types}
  144. # These types include 'bool' and all precisions (int8, float32, etc)
  145. # The removed types throw errors in correlate or fftconvolve
  146. for dtype in ['complex256', 'complex192', 'float128', 'float96',
  147. 'str', 'void', 'bytes', 'object', 'unicode', 'string']:
  148. if dtype in types:
  149. types.remove(dtype)
  150. args = [(t1, t2, mode) for t1 in types for t2 in types
  151. for mode in ['valid', 'full', 'same']]
  152. # These are random arrays, which means test is much stronger than
  153. # convolving testing by convolving two np.ones arrays
  154. np.random.seed(42)
  155. array_types = {'i': np.random.choice([0, 1], size=n),
  156. 'f': np.random.randn(n)}
  157. array_types['b'] = array_types['u'] = array_types['i']
  158. array_types['c'] = array_types['f'] + 0.5j*array_types['f']
  159. for t1, t2, mode in args:
  160. x1 = array_types[np.dtype(t1).kind].astype(t1)
  161. x2 = array_types[np.dtype(t2).kind].astype(t2)
  162. results = {key: convolve(x1, x2, method=key, mode=mode)
  163. for key in ['fft', 'direct']}
  164. assert_equal(results['fft'].dtype, results['direct'].dtype)
  165. if 'bool' in t1 and 'bool' in t2:
  166. assert_equal(choose_conv_method(x1, x2), 'direct')
  167. continue
  168. # Found by experiment. Found approx smallest value for (rtol, atol)
  169. # threshold to have tests pass.
  170. if any([t in {'complex64', 'float32'} for t in [t1, t2]]):
  171. kwargs = {'rtol': 1.0e-4, 'atol': 1e-6}
  172. elif 'float16' in [t1, t2]:
  173. # atol is default for np.allclose
  174. kwargs = {'rtol': 1e-3, 'atol': 1e-8}
  175. else:
  176. # defaults for np.allclose (different from assert_allclose)
  177. kwargs = {'rtol': 1e-5, 'atol': 1e-8}
  178. assert_allclose(results['fft'], results['direct'], **kwargs)
  179. def test_convolve_method_large_input(self):
  180. # This is really a test that convolving two large integers goes to the
  181. # direct method even if they're in the fft method.
  182. for n in [10, 20, 50, 51, 52, 53, 54, 60, 62]:
  183. z = np.array([2**n], dtype=np.int64)
  184. fft = convolve(z, z, method='fft')
  185. direct = convolve(z, z, method='direct')
  186. # this is the case when integer precision gets to us
  187. # issue #6076 has more detail, hopefully more tests after resolved
  188. if n < 50:
  189. assert_equal(fft, direct)
  190. assert_equal(fft, 2**(2*n))
  191. assert_equal(direct, 2**(2*n))
  192. def test_mismatched_dims(self):
  193. # Input arrays should have the same number of dimensions
  194. assert_raises(ValueError, convolve, [1], 2, method='direct')
  195. assert_raises(ValueError, convolve, 1, [2], method='direct')
  196. assert_raises(ValueError, convolve, [1], 2, method='fft')
  197. assert_raises(ValueError, convolve, 1, [2], method='fft')
  198. assert_raises(ValueError, convolve, [1], [[2]])
  199. assert_raises(ValueError, convolve, [3], 2)
  200. class _TestConvolve2d(object):
  201. def test_2d_arrays(self):
  202. a = [[1, 2, 3], [3, 4, 5]]
  203. b = [[2, 3, 4], [4, 5, 6]]
  204. d = array([[2, 7, 16, 17, 12],
  205. [10, 30, 62, 58, 38],
  206. [12, 31, 58, 49, 30]])
  207. e = convolve2d(a, b)
  208. assert_array_equal(e, d)
  209. def test_valid_mode(self):
  210. e = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  211. f = [[1, 2, 3], [3, 4, 5]]
  212. h = array([[62, 80, 98, 116, 134]])
  213. g = convolve2d(e, f, 'valid')
  214. assert_array_equal(g, h)
  215. # See gh-5897
  216. g = convolve2d(f, e, 'valid')
  217. assert_array_equal(g, h)
  218. def test_valid_mode_complx(self):
  219. e = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  220. f = np.array([[1, 2, 3], [3, 4, 5]], dtype=complex) + 1j
  221. h = array([[62.+24.j, 80.+30.j, 98.+36.j, 116.+42.j, 134.+48.j]])
  222. g = convolve2d(e, f, 'valid')
  223. assert_array_almost_equal(g, h)
  224. # See gh-5897
  225. g = convolve2d(f, e, 'valid')
  226. assert_array_equal(g, h)
  227. def test_fillvalue(self):
  228. a = [[1, 2, 3], [3, 4, 5]]
  229. b = [[2, 3, 4], [4, 5, 6]]
  230. fillval = 1
  231. c = convolve2d(a, b, 'full', 'fill', fillval)
  232. d = array([[24, 26, 31, 34, 32],
  233. [28, 40, 62, 64, 52],
  234. [32, 46, 67, 62, 48]])
  235. assert_array_equal(c, d)
  236. def test_fillvalue_deprecations(self):
  237. # Deprecated 2017-07, scipy version 1.0.0
  238. with suppress_warnings() as sup:
  239. sup.filter(np.ComplexWarning, "Casting complex values to real")
  240. r = sup.record(DeprecationWarning, "could not cast `fillvalue`")
  241. convolve2d([[1]], [[1, 2]], fillvalue=1j)
  242. assert_(len(r) == 1)
  243. warnings.filterwarnings(
  244. "error", message="could not cast `fillvalue`",
  245. category=DeprecationWarning)
  246. assert_raises(DeprecationWarning, convolve2d, [[1]], [[1, 2]],
  247. fillvalue=1j)
  248. with suppress_warnings():
  249. warnings.filterwarnings(
  250. "always", message="`fillvalue` must be scalar or an array ",
  251. category=DeprecationWarning)
  252. assert_warns(DeprecationWarning, convolve2d, [[1]], [[1, 2]],
  253. fillvalue=[1, 2])
  254. warnings.filterwarnings(
  255. "error", message="`fillvalue` must be scalar or an array ",
  256. category=DeprecationWarning)
  257. assert_raises(DeprecationWarning, convolve2d, [[1]], [[1, 2]],
  258. fillvalue=[1, 2])
  259. def test_fillvalue_empty(self):
  260. # Check that fillvalue being empty raises an error:
  261. assert_raises(ValueError, convolve2d, [[1]], [[1, 2]],
  262. fillvalue=[])
  263. def test_wrap_boundary(self):
  264. a = [[1, 2, 3], [3, 4, 5]]
  265. b = [[2, 3, 4], [4, 5, 6]]
  266. c = convolve2d(a, b, 'full', 'wrap')
  267. d = array([[80, 80, 74, 80, 80],
  268. [68, 68, 62, 68, 68],
  269. [80, 80, 74, 80, 80]])
  270. assert_array_equal(c, d)
  271. def test_sym_boundary(self):
  272. a = [[1, 2, 3], [3, 4, 5]]
  273. b = [[2, 3, 4], [4, 5, 6]]
  274. c = convolve2d(a, b, 'full', 'symm')
  275. d = array([[34, 30, 44, 62, 66],
  276. [52, 48, 62, 80, 84],
  277. [82, 78, 92, 110, 114]])
  278. assert_array_equal(c, d)
  279. def test_invalid_shapes(self):
  280. # By "invalid," we mean that no one
  281. # array has dimensions that are all at
  282. # least as large as the corresponding
  283. # dimensions of the other array. This
  284. # setup should throw a ValueError.
  285. a = np.arange(1, 7).reshape((2, 3))
  286. b = np.arange(-6, 0).reshape((3, 2))
  287. assert_raises(ValueError, convolve2d, *(a, b), **{'mode': 'valid'})
  288. assert_raises(ValueError, convolve2d, *(b, a), **{'mode': 'valid'})
  289. class TestConvolve2d(_TestConvolve2d):
  290. def test_same_mode(self):
  291. e = [[1, 2, 3], [3, 4, 5]]
  292. f = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  293. g = convolve2d(e, f, 'same')
  294. h = array([[22, 28, 34],
  295. [80, 98, 116]])
  296. assert_array_equal(g, h)
  297. def test_valid_mode2(self):
  298. # See gh-5897
  299. e = [[1, 2, 3], [3, 4, 5]]
  300. f = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  301. expected = [[62, 80, 98, 116, 134]]
  302. out = convolve2d(e, f, 'valid')
  303. assert_array_equal(out, expected)
  304. out = convolve2d(f, e, 'valid')
  305. assert_array_equal(out, expected)
  306. e = [[1 + 1j, 2 - 3j], [3 + 1j, 4 + 0j]]
  307. f = [[2 - 1j, 3 + 2j, 4 + 0j], [4 - 0j, 5 + 1j, 6 - 3j]]
  308. expected = [[27 - 1j, 46. + 2j]]
  309. out = convolve2d(e, f, 'valid')
  310. assert_array_equal(out, expected)
  311. # See gh-5897
  312. out = convolve2d(f, e, 'valid')
  313. assert_array_equal(out, expected)
  314. def test_consistency_convolve_funcs(self):
  315. # Compare np.convolve, signal.convolve, signal.convolve2d
  316. a = np.arange(5)
  317. b = np.array([3.2, 1.4, 3])
  318. for mode in ['full', 'valid', 'same']:
  319. assert_almost_equal(np.convolve(a, b, mode=mode),
  320. signal.convolve(a, b, mode=mode))
  321. assert_almost_equal(np.squeeze(
  322. signal.convolve2d([a], [b], mode=mode)),
  323. signal.convolve(a, b, mode=mode))
  324. def test_invalid_dims(self):
  325. assert_raises(ValueError, convolve2d, 3, 4)
  326. assert_raises(ValueError, convolve2d, [3], [4])
  327. assert_raises(ValueError, convolve2d, [[[3]]], [[[4]]])
  328. class TestFFTConvolve(object):
  329. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  330. def test_real(self, axes):
  331. a = array([1, 2, 3])
  332. expected = array([1, 4, 10, 12, 9.])
  333. if axes == '':
  334. out = fftconvolve(a, a)
  335. else:
  336. out = fftconvolve(a, a, axes=axes)
  337. assert_array_almost_equal(out, expected)
  338. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  339. def test_real_axes(self, axes):
  340. a = array([1, 2, 3])
  341. expected = array([1, 4, 10, 12, 9.])
  342. a = np.tile(a, [2, 1])
  343. expected = np.tile(expected, [2, 1])
  344. out = fftconvolve(a, a, axes=axes)
  345. assert_array_almost_equal(out, expected)
  346. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  347. def test_complex(self, axes):
  348. a = array([1 + 1j, 2 + 2j, 3 + 3j])
  349. expected = array([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])
  350. if axes == '':
  351. out = fftconvolve(a, a)
  352. else:
  353. out = fftconvolve(a, a, axes=axes)
  354. assert_array_almost_equal(out, expected)
  355. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  356. def test_complex_axes(self, axes):
  357. a = array([1 + 1j, 2 + 2j, 3 + 3j])
  358. expected = array([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])
  359. a = np.tile(a, [2, 1])
  360. expected = np.tile(expected, [2, 1])
  361. out = fftconvolve(a, a, axes=axes)
  362. assert_array_almost_equal(out, expected)
  363. @pytest.mark.parametrize('axes', ['',
  364. None,
  365. [0, 1],
  366. [1, 0],
  367. [0, -1],
  368. [-1, 0],
  369. [-2, 1],
  370. [1, -2],
  371. [-2, -1],
  372. [-1, -2]])
  373. def test_2d_real_same(self, axes):
  374. a = array([[1, 2, 3],
  375. [4, 5, 6]])
  376. expected = array([[1, 4, 10, 12, 9],
  377. [8, 26, 56, 54, 36],
  378. [16, 40, 73, 60, 36]])
  379. if axes == '':
  380. out = fftconvolve(a, a)
  381. else:
  382. out = fftconvolve(a, a, axes=axes)
  383. assert_array_almost_equal(out, expected)
  384. @pytest.mark.parametrize('axes', [[1, 2],
  385. [2, 1],
  386. [1, -1],
  387. [-1, 1],
  388. [-2, 2],
  389. [2, -2],
  390. [-2, -1],
  391. [-1, -2]])
  392. def test_2d_real_same_axes(self, axes):
  393. a = array([[1, 2, 3],
  394. [4, 5, 6]])
  395. expected = array([[1, 4, 10, 12, 9],
  396. [8, 26, 56, 54, 36],
  397. [16, 40, 73, 60, 36]])
  398. a = np.tile(a, [2, 1, 1])
  399. expected = np.tile(expected, [2, 1, 1])
  400. out = fftconvolve(a, a, axes=axes)
  401. assert_array_almost_equal(out, expected)
  402. @pytest.mark.parametrize('axes', ['',
  403. None,
  404. [0, 1],
  405. [1, 0],
  406. [0, -1],
  407. [-1, 0],
  408. [-2, 1],
  409. [1, -2],
  410. [-2, -1],
  411. [-1, -2]])
  412. def test_2d_complex_same(self, axes):
  413. a = array([[1 + 2j, 3 + 4j, 5 + 6j],
  414. [2 + 1j, 4 + 3j, 6 + 5j]])
  415. expected = array([
  416. [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
  417. [10j, 44j, 118j, 156j, 122j],
  418. [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
  419. ])
  420. if axes == '':
  421. out = fftconvolve(a, a)
  422. else:
  423. out = fftconvolve(a, a, axes=axes)
  424. assert_array_almost_equal(out, expected)
  425. @pytest.mark.parametrize('axes', [[1, 2],
  426. [2, 1],
  427. [1, -1],
  428. [-1, 1],
  429. [-2, 2],
  430. [2, -2],
  431. [-2, -1],
  432. [-1, -2]])
  433. def test_2d_complex_same_axes(self, axes):
  434. a = array([[1 + 2j, 3 + 4j, 5 + 6j],
  435. [2 + 1j, 4 + 3j, 6 + 5j]])
  436. expected = array([
  437. [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
  438. [10j, 44j, 118j, 156j, 122j],
  439. [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
  440. ])
  441. a = np.tile(a, [2, 1, 1])
  442. expected = np.tile(expected, [2, 1, 1])
  443. out = fftconvolve(a, a, axes=axes)
  444. assert_array_almost_equal(out, expected)
  445. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  446. def test_real_same_mode(self, axes):
  447. a = array([1, 2, 3])
  448. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  449. expected_1 = array([35., 41., 47.])
  450. expected_2 = array([9., 20., 25., 35., 41., 47., 39., 28., 2.])
  451. if axes == '':
  452. out = fftconvolve(a, b, 'same')
  453. else:
  454. out = fftconvolve(a, b, 'same', axes=axes)
  455. assert_array_almost_equal(out, expected_1)
  456. if axes == '':
  457. out = fftconvolve(b, a, 'same')
  458. else:
  459. out = fftconvolve(b, a, 'same', axes=axes)
  460. assert_array_almost_equal(out, expected_2)
  461. @pytest.mark.parametrize('axes', [1, -1, [1], [-1]])
  462. def test_real_same_mode_axes(self, axes):
  463. a = array([1, 2, 3])
  464. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  465. expected_1 = array([35., 41., 47.])
  466. expected_2 = array([9., 20., 25., 35., 41., 47., 39., 28., 2.])
  467. a = np.tile(a, [2, 1])
  468. b = np.tile(b, [2, 1])
  469. expected_1 = np.tile(expected_1, [2, 1])
  470. expected_2 = np.tile(expected_2, [2, 1])
  471. out = fftconvolve(a, b, 'same', axes=axes)
  472. assert_array_almost_equal(out, expected_1)
  473. out = fftconvolve(b, a, 'same', axes=axes)
  474. assert_array_almost_equal(out, expected_2)
  475. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  476. def test_valid_mode_real(self, axes):
  477. # See gh-5897
  478. a = array([3, 2, 1])
  479. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  480. expected = array([24., 31., 41., 43., 49., 25., 12.])
  481. if axes == '':
  482. out = fftconvolve(a, b, 'valid')
  483. else:
  484. out = fftconvolve(a, b, 'valid', axes=axes)
  485. assert_array_almost_equal(out, expected)
  486. if axes == '':
  487. out = fftconvolve(b, a, 'valid')
  488. else:
  489. out = fftconvolve(b, a, 'valid', axes=axes)
  490. assert_array_almost_equal(out, expected)
  491. @pytest.mark.parametrize('axes', [1, [1]])
  492. def test_valid_mode_real_axes(self, axes):
  493. # See gh-5897
  494. a = array([3, 2, 1])
  495. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  496. expected = array([24., 31., 41., 43., 49., 25., 12.])
  497. a = np.tile(a, [2, 1])
  498. b = np.tile(b, [2, 1])
  499. expected = np.tile(expected, [2, 1])
  500. out = fftconvolve(a, b, 'valid', axes=axes)
  501. assert_array_almost_equal(out, expected)
  502. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  503. def test_valid_mode_complex(self, axes):
  504. a = array([3 - 1j, 2 + 7j, 1 + 0j])
  505. b = array([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
  506. expected = array([45. + 12.j, 30. + 23.j, 48 + 32.j])
  507. if axes == '':
  508. out = fftconvolve(a, b, 'valid')
  509. else:
  510. out = fftconvolve(a, b, 'valid', axes=axes)
  511. assert_array_almost_equal(out, expected)
  512. if axes == '':
  513. out = fftconvolve(b, a, 'valid')
  514. else:
  515. out = fftconvolve(b, a, 'valid', axes=axes)
  516. assert_array_almost_equal(out, expected)
  517. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  518. def test_valid_mode_complex_axes(self, axes):
  519. a = array([3 - 1j, 2 + 7j, 1 + 0j])
  520. b = array([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
  521. expected = array([45. + 12.j, 30. + 23.j, 48 + 32.j])
  522. a = np.tile(a, [2, 1])
  523. b = np.tile(b, [2, 1])
  524. expected = np.tile(expected, [2, 1])
  525. out = fftconvolve(a, b, 'valid', axes=axes)
  526. assert_array_almost_equal(out, expected)
  527. out = fftconvolve(b, a, 'valid', axes=axes)
  528. assert_array_almost_equal(out, expected)
  529. def test_empty(self):
  530. # Regression test for #1745: crashes with 0-length input.
  531. assert_(fftconvolve([], []).size == 0)
  532. assert_(fftconvolve([5, 6], []).size == 0)
  533. assert_(fftconvolve([], [7]).size == 0)
  534. def test_zero_rank(self):
  535. a = array(4967)
  536. b = array(3920)
  537. out = fftconvolve(a, b)
  538. assert_equal(out, a * b)
  539. def test_single_element(self):
  540. a = array([4967])
  541. b = array([3920])
  542. out = fftconvolve(a, b)
  543. assert_equal(out, a * b)
  544. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  545. def test_random_data(self, axes):
  546. np.random.seed(1234)
  547. a = np.random.rand(1233) + 1j * np.random.rand(1233)
  548. b = np.random.rand(1321) + 1j * np.random.rand(1321)
  549. expected = np.convolve(a, b, 'full')
  550. if axes == '':
  551. out = fftconvolve(a, b, 'full')
  552. else:
  553. out = fftconvolve(a, b, 'full', axes=axes)
  554. assert_(np.allclose(out, expected, rtol=1e-10))
  555. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  556. def test_random_data_axes(self, axes):
  557. np.random.seed(1234)
  558. a = np.random.rand(1233) + 1j * np.random.rand(1233)
  559. b = np.random.rand(1321) + 1j * np.random.rand(1321)
  560. expected = np.convolve(a, b, 'full')
  561. a = np.tile(a, [2, 1])
  562. b = np.tile(b, [2, 1])
  563. expected = np.tile(expected, [2, 1])
  564. out = fftconvolve(a, b, 'full', axes=axes)
  565. assert_(np.allclose(out, expected, rtol=1e-10))
  566. @pytest.mark.parametrize('axes', [[1, 4],
  567. [4, 1],
  568. [1, -1],
  569. [-1, 1],
  570. [-4, 4],
  571. [4, -4],
  572. [-4, -1],
  573. [-1, -4]])
  574. def test_random_data_multidim_axes(self, axes):
  575. np.random.seed(1234)
  576. a = np.random.rand(123, 222) + 1j * np.random.rand(123, 222)
  577. b = np.random.rand(132, 111) + 1j * np.random.rand(132, 111)
  578. expected = convolve2d(a, b, 'full')
  579. a = a[:, :, None, None, None]
  580. b = b[:, :, None, None, None]
  581. expected = expected[:, :, None, None, None]
  582. a = np.rollaxis(a.swapaxes(0, 2), 1, 5)
  583. b = np.rollaxis(b.swapaxes(0, 2), 1, 5)
  584. expected = np.rollaxis(expected.swapaxes(0, 2), 1, 5)
  585. # use 1 for dimension 2 in a and 3 in b to test broadcasting
  586. a = np.tile(a, [2, 1, 3, 1, 1])
  587. b = np.tile(b, [2, 1, 1, 4, 1])
  588. expected = np.tile(expected, [2, 1, 3, 4, 1])
  589. out = fftconvolve(a, b, 'full', axes=axes)
  590. assert_(np.allclose(out, expected, rtol=1e-10))
  591. @pytest.mark.slow
  592. @pytest.mark.parametrize(
  593. 'n',
  594. list(range(1, 100)) +
  595. list(range(1000, 1500)) +
  596. np.random.RandomState(1234).randint(1001, 10000, 5).tolist())
  597. def test_many_sizes(self, n):
  598. a = np.random.rand(n) + 1j * np.random.rand(n)
  599. b = np.random.rand(n) + 1j * np.random.rand(n)
  600. expected = np.convolve(a, b, 'full')
  601. out = fftconvolve(a, b, 'full')
  602. assert_allclose(out, expected, atol=1e-10)
  603. out = fftconvolve(a, b, 'full', axes=[0])
  604. assert_allclose(out, expected, atol=1e-10)
  605. def test_invalid_shapes(self):
  606. a = np.arange(1, 7).reshape((2, 3))
  607. b = np.arange(-6, 0).reshape((3, 2))
  608. with assert_raises(ValueError,
  609. match="For 'valid' mode, one must be at least "
  610. "as large as the other in every dimension"):
  611. fftconvolve(a, b, mode='valid')
  612. def test_invalid_shapes_axes(self):
  613. a = np.zeros([5, 6, 2, 1])
  614. b = np.zeros([5, 6, 3, 1])
  615. with assert_raises(ValueError,
  616. match=r"incompatible shapes for in1 and in2:"
  617. r" \(5L?, 6L?, 2L?, 1L?\) and"
  618. r" \(5L?, 6L?, 3L?, 1L?\)"):
  619. fftconvolve(a, b, axes=[0, 1])
  620. @pytest.mark.parametrize('a,b',
  621. [([1], 2),
  622. (1, [2]),
  623. ([3], [[2]])])
  624. def test_mismatched_dims(self, a, b):
  625. with assert_raises(ValueError,
  626. match="in1 and in2 should have the same"
  627. " dimensionality"):
  628. fftconvolve(a, b)
  629. def test_invalid_flags(self):
  630. with assert_raises(ValueError,
  631. match="acceptable mode flags are 'valid',"
  632. " 'same', or 'full'"):
  633. fftconvolve([1], [2], mode='chips')
  634. with assert_raises(ValueError,
  635. match="when provided, axes cannot be empty"):
  636. fftconvolve([1], [2], axes=[])
  637. with assert_raises(ValueError,
  638. match="when given, axes values must be a scalar"
  639. " or vector"):
  640. fftconvolve([1], [2], axes=[[1, 2], [3, 4]])
  641. with assert_raises(ValueError,
  642. match="when given, axes values must be integers"):
  643. fftconvolve([1], [2], axes=[1., 2., 3., 4.])
  644. with assert_raises(ValueError,
  645. match="axes exceeds dimensionality of input"):
  646. fftconvolve([1], [2], axes=[1])
  647. with assert_raises(ValueError,
  648. match="axes exceeds dimensionality of input"):
  649. fftconvolve([1], [2], axes=[-2])
  650. with assert_raises(ValueError,
  651. match="all axes must be unique"):
  652. fftconvolve([1], [2], axes=[0, 0])
  653. class TestMedFilt(object):
  654. def test_basic(self):
  655. f = [[50, 50, 50, 50, 50, 92, 18, 27, 65, 46],
  656. [50, 50, 50, 50, 50, 0, 72, 77, 68, 66],
  657. [50, 50, 50, 50, 50, 46, 47, 19, 64, 77],
  658. [50, 50, 50, 50, 50, 42, 15, 29, 95, 35],
  659. [50, 50, 50, 50, 50, 46, 34, 9, 21, 66],
  660. [70, 97, 28, 68, 78, 77, 61, 58, 71, 42],
  661. [64, 53, 44, 29, 68, 32, 19, 68, 24, 84],
  662. [3, 33, 53, 67, 1, 78, 74, 55, 12, 83],
  663. [7, 11, 46, 70, 60, 47, 24, 43, 61, 26],
  664. [32, 61, 88, 7, 39, 4, 92, 64, 45, 61]]
  665. d = signal.medfilt(f, [7, 3])
  666. e = signal.medfilt2d(np.array(f, float), [7, 3])
  667. assert_array_equal(d, [[0, 50, 50, 50, 42, 15, 15, 18, 27, 0],
  668. [0, 50, 50, 50, 50, 42, 19, 21, 29, 0],
  669. [50, 50, 50, 50, 50, 47, 34, 34, 46, 35],
  670. [50, 50, 50, 50, 50, 50, 42, 47, 64, 42],
  671. [50, 50, 50, 50, 50, 50, 46, 55, 64, 35],
  672. [33, 50, 50, 50, 50, 47, 46, 43, 55, 26],
  673. [32, 50, 50, 50, 50, 47, 46, 45, 55, 26],
  674. [7, 46, 50, 50, 47, 46, 46, 43, 45, 21],
  675. [0, 32, 33, 39, 32, 32, 43, 43, 43, 0],
  676. [0, 7, 11, 7, 4, 4, 19, 19, 24, 0]])
  677. assert_array_equal(d, e)
  678. def test_none(self):
  679. # Ticket #1124. Ensure this does not segfault.
  680. signal.medfilt(None)
  681. # Expand on this test to avoid a regression with possible contiguous
  682. # numpy arrays that have odd strides. The stride value below gets
  683. # us into wrong memory if used (but it does not need to be used)
  684. dummy = np.arange(10, dtype=np.float64)
  685. a = dummy[5:6]
  686. a.strides = 16
  687. assert_(signal.medfilt(a, 1) == 5.)
  688. def test_refcounting(self):
  689. # Check a refcounting-related crash
  690. a = Decimal(123)
  691. x = np.array([a, a], dtype=object)
  692. if hasattr(sys, 'getrefcount'):
  693. n = 2 * sys.getrefcount(a)
  694. else:
  695. n = 10
  696. # Shouldn't segfault:
  697. for j in range(n):
  698. signal.medfilt(x)
  699. if hasattr(sys, 'getrefcount'):
  700. assert_(sys.getrefcount(a) < n)
  701. assert_equal(x, [a, a])
  702. class TestWiener(object):
  703. def test_basic(self):
  704. g = array([[5, 6, 4, 3],
  705. [3, 5, 6, 2],
  706. [2, 3, 5, 6],
  707. [1, 6, 9, 7]], 'd')
  708. h = array([[2.16374269, 3.2222222222, 2.8888888889, 1.6666666667],
  709. [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
  710. [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
  711. [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
  712. assert_array_almost_equal(signal.wiener(g), h, decimal=6)
  713. assert_array_almost_equal(signal.wiener(g, mysize=3), h, decimal=6)
  714. class TestResample(object):
  715. def test_basic(self):
  716. # Some basic tests
  717. # Regression test for issue #3603.
  718. # window.shape must equal to sig.shape[0]
  719. sig = np.arange(128)
  720. num = 256
  721. win = signal.get_window(('kaiser', 8.0), 160)
  722. assert_raises(ValueError, signal.resample, sig, num, window=win)
  723. # Other degenerate conditions
  724. assert_raises(ValueError, signal.resample_poly, sig, 'yo', 1)
  725. assert_raises(ValueError, signal.resample_poly, sig, 1, 0)
  726. # test for issue #6505 - should not modify window.shape when axis ≠ 0
  727. sig2 = np.tile(np.arange(160), (2,1))
  728. signal.resample(sig2, num, axis=-1, window=win)
  729. assert_(win.shape == (160,))
  730. def test_fft(self):
  731. # Test FFT-based resampling
  732. self._test_data(method='fft')
  733. def test_polyphase(self):
  734. # Test polyphase resampling
  735. self._test_data(method='polyphase')
  736. def test_polyphase_extfilter(self):
  737. # Test external specification of downsampling filter
  738. self._test_data(method='polyphase', ext=True)
  739. def test_mutable_window(self):
  740. # Test that a mutable window is not modified
  741. impulse = np.zeros(3)
  742. window = np.random.RandomState(0).randn(2)
  743. window_orig = window.copy()
  744. signal.resample_poly(impulse, 5, 1, window=window)
  745. assert_array_equal(window, window_orig)
  746. def test_output_float32(self):
  747. # Test that float32 inputs yield a float32 output
  748. x = np.arange(10, dtype=np.float32)
  749. h = np.array([1,1,1], dtype=np.float32)
  750. y = signal.resample_poly(x, 1, 2, window=h)
  751. assert_(y.dtype == np.float32)
  752. def _test_data(self, method, ext=False):
  753. # Test resampling of sinusoids and random noise (1-sec)
  754. rate = 100
  755. rates_to = [49, 50, 51, 99, 100, 101, 199, 200, 201]
  756. # Sinusoids, windowed to avoid edge artifacts
  757. t = np.arange(rate) / float(rate)
  758. freqs = np.array((1., 10., 40.))[:, np.newaxis]
  759. x = np.sin(2 * np.pi * freqs * t) * hann(rate)
  760. for rate_to in rates_to:
  761. t_to = np.arange(rate_to) / float(rate_to)
  762. y_tos = np.sin(2 * np.pi * freqs * t_to) * hann(rate_to)
  763. if method == 'fft':
  764. y_resamps = signal.resample(x, rate_to, axis=-1)
  765. else:
  766. if ext and rate_to != rate:
  767. # Match default window design
  768. g = gcd(rate_to, rate)
  769. up = rate_to // g
  770. down = rate // g
  771. max_rate = max(up, down)
  772. f_c = 1. / max_rate
  773. half_len = 10 * max_rate
  774. window = signal.firwin(2 * half_len + 1, f_c,
  775. window=('kaiser', 5.0))
  776. polyargs = {'window': window}
  777. else:
  778. polyargs = {}
  779. y_resamps = signal.resample_poly(x, rate_to, rate, axis=-1,
  780. **polyargs)
  781. for y_to, y_resamp, freq in zip(y_tos, y_resamps, freqs):
  782. if freq >= 0.5 * rate_to:
  783. y_to.fill(0.) # mostly low-passed away
  784. assert_allclose(y_resamp, y_to, atol=1e-3)
  785. else:
  786. assert_array_equal(y_to.shape, y_resamp.shape)
  787. corr = np.corrcoef(y_to, y_resamp)[0, 1]
  788. assert_(corr > 0.99, msg=(corr, rate, rate_to))
  789. # Random data
  790. rng = np.random.RandomState(0)
  791. x = hann(rate) * np.cumsum(rng.randn(rate)) # low-pass, wind
  792. for rate_to in rates_to:
  793. # random data
  794. t_to = np.arange(rate_to) / float(rate_to)
  795. y_to = np.interp(t_to, t, x)
  796. if method == 'fft':
  797. y_resamp = signal.resample(x, rate_to)
  798. else:
  799. y_resamp = signal.resample_poly(x, rate_to, rate)
  800. assert_array_equal(y_to.shape, y_resamp.shape)
  801. corr = np.corrcoef(y_to, y_resamp)[0, 1]
  802. assert_(corr > 0.99, msg=corr)
  803. # More tests of fft method (Master 0.18.1 fails these)
  804. if method == 'fft':
  805. x1 = np.array([1.+0.j,0.+0.j])
  806. y1_test = signal.resample(x1,4)
  807. y1_true = np.array([1.+0.j,0.5+0.j,0.+0.j,0.5+0.j]) # upsampling a complex array
  808. assert_allclose(y1_test, y1_true, atol=1e-12)
  809. x2 = np.array([1.,0.5,0.,0.5])
  810. y2_test = signal.resample(x2,2) # downsampling a real array
  811. y2_true = np.array([1.,0.])
  812. assert_allclose(y2_test, y2_true, atol=1e-12)
  813. def test_poly_vs_filtfilt(self):
  814. # Check that up=1.0 gives same answer as filtfilt + slicing
  815. random_state = np.random.RandomState(17)
  816. try_types = (int, np.float32, np.complex64, float, complex)
  817. size = 10000
  818. down_factors = [2, 11, 79]
  819. for dtype in try_types:
  820. x = random_state.randn(size).astype(dtype)
  821. if dtype in (np.complex64, np.complex128):
  822. x += 1j * random_state.randn(size)
  823. # resample_poly assumes zeros outside of signl, whereas filtfilt
  824. # can only constant-pad. Make them equivalent:
  825. x[0] = 0
  826. x[-1] = 0
  827. for down in down_factors:
  828. h = signal.firwin(31, 1. / down, window='hamming')
  829. yf = filtfilt(h, 1.0, x, padtype='constant')[::down]
  830. # Need to pass convolved version of filter to resample_poly,
  831. # since filtfilt does forward and backward, but resample_poly
  832. # only goes forward
  833. hc = convolve(h, h[::-1])
  834. y = signal.resample_poly(x, 1, down, window=hc)
  835. assert_allclose(yf, y, atol=1e-7, rtol=1e-7)
  836. def test_correlate1d(self):
  837. for down in [2, 4]:
  838. for nx in range(1, 40, down):
  839. for nweights in (32, 33):
  840. x = np.random.random((nx,))
  841. weights = np.random.random((nweights,))
  842. y_g = correlate1d(x, weights[::-1], mode='constant')
  843. y_s = signal.resample_poly(x, up=1, down=down, window=weights)
  844. assert_allclose(y_g[::down], y_s)
  845. class TestCSpline1DEval(object):
  846. def test_basic(self):
  847. y = array([1, 2, 3, 4, 3, 2, 1, 2, 3.0])
  848. x = arange(len(y))
  849. dx = x[1] - x[0]
  850. cj = signal.cspline1d(y)
  851. x2 = arange(len(y) * 10.0) / 10.0
  852. y2 = signal.cspline1d_eval(cj, x2, dx=dx, x0=x[0])
  853. # make sure interpolated values are on knot points
  854. assert_array_almost_equal(y2[::10], y, decimal=5)
  855. def test_complex(self):
  856. # create some smoothly varying complex signal to interpolate
  857. x = np.arange(2)
  858. y = np.zeros(x.shape, dtype=np.complex64)
  859. T = 10.0
  860. f = 1.0 / T
  861. y = np.exp(2.0J * np.pi * f * x)
  862. # get the cspline transform
  863. cy = signal.cspline1d(y)
  864. # determine new test x value and interpolate
  865. xnew = np.array([0.5])
  866. ynew = signal.cspline1d_eval(cy, xnew)
  867. assert_equal(ynew.dtype, y.dtype)
  868. class TestOrderFilt(object):
  869. def test_basic(self):
  870. assert_array_equal(signal.order_filter([1, 2, 3], [1, 0, 1], 1),
  871. [2, 3, 2])
  872. class _TestLinearFilter(object):
  873. def generate(self, shape):
  874. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  875. return self.convert_dtype(x)
  876. def convert_dtype(self, arr):
  877. if self.dtype == np.dtype('O'):
  878. arr = np.asarray(arr)
  879. out = np.empty(arr.shape, self.dtype)
  880. iter = np.nditer([arr, out], ['refs_ok','zerosize_ok'],
  881. [['readonly'],['writeonly']])
  882. for x, y in iter:
  883. y[...] = self.type(x[()])
  884. return out
  885. else:
  886. return np.array(arr, self.dtype, copy=False)
  887. def test_rank_1_IIR(self):
  888. x = self.generate((6,))
  889. b = self.convert_dtype([1, -1])
  890. a = self.convert_dtype([0.5, -0.5])
  891. y_r = self.convert_dtype([0, 2, 4, 6, 8, 10.])
  892. assert_array_almost_equal(lfilter(b, a, x), y_r)
  893. def test_rank_1_FIR(self):
  894. x = self.generate((6,))
  895. b = self.convert_dtype([1, 1])
  896. a = self.convert_dtype([1])
  897. y_r = self.convert_dtype([0, 1, 3, 5, 7, 9.])
  898. assert_array_almost_equal(lfilter(b, a, x), y_r)
  899. def test_rank_1_IIR_init_cond(self):
  900. x = self.generate((6,))
  901. b = self.convert_dtype([1, 0, -1])
  902. a = self.convert_dtype([0.5, -0.5])
  903. zi = self.convert_dtype([1, 2])
  904. y_r = self.convert_dtype([1, 5, 9, 13, 17, 21])
  905. zf_r = self.convert_dtype([13, -10])
  906. y, zf = lfilter(b, a, x, zi=zi)
  907. assert_array_almost_equal(y, y_r)
  908. assert_array_almost_equal(zf, zf_r)
  909. def test_rank_1_FIR_init_cond(self):
  910. x = self.generate((6,))
  911. b = self.convert_dtype([1, 1, 1])
  912. a = self.convert_dtype([1])
  913. zi = self.convert_dtype([1, 1])
  914. y_r = self.convert_dtype([1, 2, 3, 6, 9, 12.])
  915. zf_r = self.convert_dtype([9, 5])
  916. y, zf = lfilter(b, a, x, zi=zi)
  917. assert_array_almost_equal(y, y_r)
  918. assert_array_almost_equal(zf, zf_r)
  919. def test_rank_2_IIR_axis_0(self):
  920. x = self.generate((4, 3))
  921. b = self.convert_dtype([1, -1])
  922. a = self.convert_dtype([0.5, 0.5])
  923. y_r2_a0 = self.convert_dtype([[0, 2, 4], [6, 4, 2], [0, 2, 4],
  924. [6, 4, 2]])
  925. y = lfilter(b, a, x, axis=0)
  926. assert_array_almost_equal(y_r2_a0, y)
  927. def test_rank_2_IIR_axis_1(self):
  928. x = self.generate((4, 3))
  929. b = self.convert_dtype([1, -1])
  930. a = self.convert_dtype([0.5, 0.5])
  931. y_r2_a1 = self.convert_dtype([[0, 2, 0], [6, -4, 6], [12, -10, 12],
  932. [18, -16, 18]])
  933. y = lfilter(b, a, x, axis=1)
  934. assert_array_almost_equal(y_r2_a1, y)
  935. def test_rank_2_IIR_axis_0_init_cond(self):
  936. x = self.generate((4, 3))
  937. b = self.convert_dtype([1, -1])
  938. a = self.convert_dtype([0.5, 0.5])
  939. zi = self.convert_dtype(np.ones((4,1)))
  940. y_r2_a0_1 = self.convert_dtype([[1, 1, 1], [7, -5, 7], [13, -11, 13],
  941. [19, -17, 19]])
  942. zf_r = self.convert_dtype([-5, -17, -29, -41])[:, np.newaxis]
  943. y, zf = lfilter(b, a, x, axis=1, zi=zi)
  944. assert_array_almost_equal(y_r2_a0_1, y)
  945. assert_array_almost_equal(zf, zf_r)
  946. def test_rank_2_IIR_axis_1_init_cond(self):
  947. x = self.generate((4,3))
  948. b = self.convert_dtype([1, -1])
  949. a = self.convert_dtype([0.5, 0.5])
  950. zi = self.convert_dtype(np.ones((1,3)))
  951. y_r2_a0_0 = self.convert_dtype([[1, 3, 5], [5, 3, 1],
  952. [1, 3, 5], [5, 3, 1]])
  953. zf_r = self.convert_dtype([[-23, -23, -23]])
  954. y, zf = lfilter(b, a, x, axis=0, zi=zi)
  955. assert_array_almost_equal(y_r2_a0_0, y)
  956. assert_array_almost_equal(zf, zf_r)
  957. def test_rank_3_IIR(self):
  958. x = self.generate((4, 3, 2))
  959. b = self.convert_dtype([1, -1])
  960. a = self.convert_dtype([0.5, 0.5])
  961. for axis in range(x.ndim):
  962. y = lfilter(b, a, x, axis)
  963. y_r = np.apply_along_axis(lambda w: lfilter(b, a, w), axis, x)
  964. assert_array_almost_equal(y, y_r)
  965. def test_rank_3_IIR_init_cond(self):
  966. x = self.generate((4, 3, 2))
  967. b = self.convert_dtype([1, -1])
  968. a = self.convert_dtype([0.5, 0.5])
  969. for axis in range(x.ndim):
  970. zi_shape = list(x.shape)
  971. zi_shape[axis] = 1
  972. zi = self.convert_dtype(np.ones(zi_shape))
  973. zi1 = self.convert_dtype([1])
  974. y, zf = lfilter(b, a, x, axis, zi)
  975. lf0 = lambda w: lfilter(b, a, w, zi=zi1)[0]
  976. lf1 = lambda w: lfilter(b, a, w, zi=zi1)[1]
  977. y_r = np.apply_along_axis(lf0, axis, x)
  978. zf_r = np.apply_along_axis(lf1, axis, x)
  979. assert_array_almost_equal(y, y_r)
  980. assert_array_almost_equal(zf, zf_r)
  981. def test_rank_3_FIR(self):
  982. x = self.generate((4, 3, 2))
  983. b = self.convert_dtype([1, 0, -1])
  984. a = self.convert_dtype([1])
  985. for axis in range(x.ndim):
  986. y = lfilter(b, a, x, axis)
  987. y_r = np.apply_along_axis(lambda w: lfilter(b, a, w), axis, x)
  988. assert_array_almost_equal(y, y_r)
  989. def test_rank_3_FIR_init_cond(self):
  990. x = self.generate((4, 3, 2))
  991. b = self.convert_dtype([1, 0, -1])
  992. a = self.convert_dtype([1])
  993. for axis in range(x.ndim):
  994. zi_shape = list(x.shape)
  995. zi_shape[axis] = 2
  996. zi = self.convert_dtype(np.ones(zi_shape))
  997. zi1 = self.convert_dtype([1, 1])
  998. y, zf = lfilter(b, a, x, axis, zi)
  999. lf0 = lambda w: lfilter(b, a, w, zi=zi1)[0]
  1000. lf1 = lambda w: lfilter(b, a, w, zi=zi1)[1]
  1001. y_r = np.apply_along_axis(lf0, axis, x)
  1002. zf_r = np.apply_along_axis(lf1, axis, x)
  1003. assert_array_almost_equal(y, y_r)
  1004. assert_array_almost_equal(zf, zf_r)
  1005. def test_zi_pseudobroadcast(self):
  1006. x = self.generate((4, 5, 20))
  1007. b,a = signal.butter(8, 0.2, output='ba')
  1008. b = self.convert_dtype(b)
  1009. a = self.convert_dtype(a)
  1010. zi_size = b.shape[0] - 1
  1011. # lfilter requires x.ndim == zi.ndim exactly. However, zi can have
  1012. # length 1 dimensions.
  1013. zi_full = self.convert_dtype(np.ones((4, 5, zi_size)))
  1014. zi_sing = self.convert_dtype(np.ones((1, 1, zi_size)))
  1015. y_full, zf_full = lfilter(b, a, x, zi=zi_full)
  1016. y_sing, zf_sing = lfilter(b, a, x, zi=zi_sing)
  1017. assert_array_almost_equal(y_sing, y_full)
  1018. assert_array_almost_equal(zf_full, zf_sing)
  1019. # lfilter does not prepend ones
  1020. assert_raises(ValueError, lfilter, b, a, x, -1, np.ones(zi_size))
  1021. def test_scalar_a(self):
  1022. # a can be a scalar.
  1023. x = self.generate(6)
  1024. b = self.convert_dtype([1, 0, -1])
  1025. a = self.convert_dtype([1])
  1026. y_r = self.convert_dtype([0, 1, 2, 2, 2, 2])
  1027. y = lfilter(b, a[0], x)
  1028. assert_array_almost_equal(y, y_r)
  1029. def test_zi_some_singleton_dims(self):
  1030. # lfilter doesn't really broadcast (no prepending of 1's). But does
  1031. # do singleton expansion if x and zi have the same ndim. This was
  1032. # broken only if a subset of the axes were singletons (gh-4681).
  1033. x = self.convert_dtype(np.zeros((3,2,5), 'l'))
  1034. b = self.convert_dtype(np.ones(5, 'l'))
  1035. a = self.convert_dtype(np.array([1,0,0]))
  1036. zi = np.ones((3,1,4), 'l')
  1037. zi[1,:,:] *= 2
  1038. zi[2,:,:] *= 3
  1039. zi = self.convert_dtype(zi)
  1040. zf_expected = self.convert_dtype(np.zeros((3,2,4), 'l'))
  1041. y_expected = np.zeros((3,2,5), 'l')
  1042. y_expected[:,:,:4] = [[[1]], [[2]], [[3]]]
  1043. y_expected = self.convert_dtype(y_expected)
  1044. # IIR
  1045. y_iir, zf_iir = lfilter(b, a, x, -1, zi)
  1046. assert_array_almost_equal(y_iir, y_expected)
  1047. assert_array_almost_equal(zf_iir, zf_expected)
  1048. # FIR
  1049. y_fir, zf_fir = lfilter(b, a[0], x, -1, zi)
  1050. assert_array_almost_equal(y_fir, y_expected)
  1051. assert_array_almost_equal(zf_fir, zf_expected)
  1052. def base_bad_size_zi(self, b, a, x, axis, zi):
  1053. b = self.convert_dtype(b)
  1054. a = self.convert_dtype(a)
  1055. x = self.convert_dtype(x)
  1056. zi = self.convert_dtype(zi)
  1057. assert_raises(ValueError, lfilter, b, a, x, axis, zi)
  1058. def test_bad_size_zi(self):
  1059. # rank 1
  1060. x1 = np.arange(6)
  1061. self.base_bad_size_zi([1], [1], x1, -1, [1])
  1062. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1])
  1063. self.base_bad_size_zi([1, 1], [1], x1, -1, [[0]])
  1064. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1, 2])
  1065. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [[0]])
  1066. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [0, 1, 2])
  1067. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1])
  1068. self.base_bad_size_zi([1], [1, 1], x1, -1, [[0]])
  1069. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1, 2])
  1070. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0])
  1071. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [[0], [1]])
  1072. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2])
  1073. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2, 3])
  1074. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0])
  1075. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [[0], [1]])
  1076. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2])
  1077. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2, 3])
  1078. # rank 2
  1079. x2 = np.arange(12).reshape((4,3))
  1080. # for axis=0 zi.shape should == (max(len(a),len(b))-1, 3)
  1081. self.base_bad_size_zi([1], [1], x2, 0, [0])
  1082. # for each of these there are 5 cases tested (in this order):
  1083. # 1. not deep enough, right # elements
  1084. # 2. too deep, right # elements
  1085. # 3. right depth, right # elements, transposed
  1086. # 4. right depth, too few elements
  1087. # 5. right depth, too many elements
  1088. self.base_bad_size_zi([1, 1], [1], x2, 0, [0,1,2])
  1089. self.base_bad_size_zi([1, 1], [1], x2, 0, [[[0,1,2]]])
  1090. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0], [1], [2]])
  1091. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0,1]])
  1092. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0,1,2,3]])
  1093. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [0,1,2,3,4,5])
  1094. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[[0,1,2],[3,4,5]]])
  1095. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1],[2,3],[4,5]])
  1096. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1],[2,3]])
  1097. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1098. self.base_bad_size_zi([1], [1, 1], x2, 0, [0,1,2])
  1099. self.base_bad_size_zi([1], [1, 1], x2, 0, [[[0,1,2]]])
  1100. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0], [1], [2]])
  1101. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0,1]])
  1102. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0,1,2,3]])
  1103. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [0,1,2,3,4,5])
  1104. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[[0,1,2],[3,4,5]]])
  1105. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1],[2,3],[4,5]])
  1106. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1],[2,3]])
  1107. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1108. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [0,1,2,3,4,5])
  1109. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[[0,1,2],[3,4,5]]])
  1110. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1],[2,3],[4,5]])
  1111. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1],[2,3]])
  1112. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1113. # for axis=1 zi.shape should == (4, max(len(a),len(b))-1)
  1114. self.base_bad_size_zi([1], [1], x2, 1, [0])
  1115. self.base_bad_size_zi([1, 1], [1], x2, 1, [0,1,2,3])
  1116. self.base_bad_size_zi([1, 1], [1], x2, 1, [[[0],[1],[2],[3]]])
  1117. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0, 1, 2, 3]])
  1118. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0],[1],[2]])
  1119. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0],[1],[2],[3],[4]])
  1120. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [0,1,2,3,4,5,6,7])
  1121. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1122. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1123. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1],[2,3],[4,5]])
  1124. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1125. self.base_bad_size_zi([1], [1, 1], x2, 1, [0,1,2,3])
  1126. self.base_bad_size_zi([1], [1, 1], x2, 1, [[[0],[1],[2],[3]]])
  1127. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0, 1, 2, 3]])
  1128. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0],[1],[2]])
  1129. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0],[1],[2],[3],[4]])
  1130. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [0,1,2,3,4,5,6,7])
  1131. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1132. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1133. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1],[2,3],[4,5]])
  1134. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1135. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [0,1,2,3,4,5,6,7])
  1136. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1137. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1138. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1],[2,3],[4,5]])
  1139. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1140. def test_empty_zi(self):
  1141. # Regression test for #880: empty array for zi crashes.
  1142. x = self.generate((5,))
  1143. a = self.convert_dtype([1])
  1144. b = self.convert_dtype([1])
  1145. zi = self.convert_dtype([])
  1146. y, zf = lfilter(b, a, x, zi=zi)
  1147. assert_array_almost_equal(y, x)
  1148. assert_equal(zf.dtype, self.dtype)
  1149. assert_equal(zf.size, 0)
  1150. def test_lfiltic_bad_zi(self):
  1151. # Regression test for #3699: bad initial conditions
  1152. a = self.convert_dtype([1])
  1153. b = self.convert_dtype([1])
  1154. # "y" sets the datatype of zi, so it truncates if int
  1155. zi = lfiltic(b, a, [1., 0])
  1156. zi_1 = lfiltic(b, a, [1, 0])
  1157. zi_2 = lfiltic(b, a, [True, False])
  1158. assert_array_equal(zi, zi_1)
  1159. assert_array_equal(zi, zi_2)
  1160. def test_short_x_FIR(self):
  1161. # regression test for #5116
  1162. # x shorter than b, with non None zi fails
  1163. a = self.convert_dtype([1])
  1164. b = self.convert_dtype([1, 0, -1])
  1165. zi = self.convert_dtype([2, 7])
  1166. x = self.convert_dtype([72])
  1167. ye = self.convert_dtype([74])
  1168. zfe = self.convert_dtype([7, -72])
  1169. y, zf = lfilter(b, a, x, zi=zi)
  1170. assert_array_almost_equal(y, ye)
  1171. assert_array_almost_equal(zf, zfe)
  1172. def test_short_x_IIR(self):
  1173. # regression test for #5116
  1174. # x shorter than b, with non None zi fails
  1175. a = self.convert_dtype([1, 1])
  1176. b = self.convert_dtype([1, 0, -1])
  1177. zi = self.convert_dtype([2, 7])
  1178. x = self.convert_dtype([72])
  1179. ye = self.convert_dtype([74])
  1180. zfe = self.convert_dtype([-67, -72])
  1181. y, zf = lfilter(b, a, x, zi=zi)
  1182. assert_array_almost_equal(y, ye)
  1183. assert_array_almost_equal(zf, zfe)
  1184. def test_do_not_modify_a_b_IIR(self):
  1185. x = self.generate((6,))
  1186. b = self.convert_dtype([1, -1])
  1187. b0 = b.copy()
  1188. a = self.convert_dtype([0.5, -0.5])
  1189. a0 = a.copy()
  1190. y_r = self.convert_dtype([0, 2, 4, 6, 8, 10.])
  1191. y_f = lfilter(b, a, x)
  1192. assert_array_almost_equal(y_f, y_r)
  1193. assert_equal(b, b0)
  1194. assert_equal(a, a0)
  1195. def test_do_not_modify_a_b_FIR(self):
  1196. x = self.generate((6,))
  1197. b = self.convert_dtype([1, 0, 1])
  1198. b0 = b.copy()
  1199. a = self.convert_dtype([2])
  1200. a0 = a.copy()
  1201. y_r = self.convert_dtype([0, 0.5, 1, 2, 3, 4.])
  1202. y_f = lfilter(b, a, x)
  1203. assert_array_almost_equal(y_f, y_r)
  1204. assert_equal(b, b0)
  1205. assert_equal(a, a0)
  1206. class TestLinearFilterFloat32(_TestLinearFilter):
  1207. dtype = np.dtype('f')
  1208. class TestLinearFilterFloat64(_TestLinearFilter):
  1209. dtype = np.dtype('d')
  1210. class TestLinearFilterFloatExtended(_TestLinearFilter):
  1211. dtype = np.dtype('g')
  1212. class TestLinearFilterComplex64(_TestLinearFilter):
  1213. dtype = np.dtype('F')
  1214. class TestLinearFilterComplex128(_TestLinearFilter):
  1215. dtype = np.dtype('D')
  1216. class TestLinearFilterComplexExtended(_TestLinearFilter):
  1217. dtype = np.dtype('G')
  1218. class TestLinearFilterDecimal(_TestLinearFilter):
  1219. dtype = np.dtype('O')
  1220. def type(self, x):
  1221. return Decimal(str(x))
  1222. class TestLinearFilterObject(_TestLinearFilter):
  1223. dtype = np.dtype('O')
  1224. type = float
  1225. def test_lfilter_bad_object():
  1226. # lfilter: object arrays with non-numeric objects raise TypeError.
  1227. # Regression test for ticket #1452.
  1228. assert_raises(TypeError, lfilter, [1.0], [1.0], [1.0, None, 2.0])
  1229. assert_raises(TypeError, lfilter, [1.0], [None], [1.0, 2.0, 3.0])
  1230. assert_raises(TypeError, lfilter, [None], [1.0], [1.0, 2.0, 3.0])
  1231. def test_lfilter_notimplemented_input():
  1232. # Should not crash, gh-7991
  1233. assert_raises(NotImplementedError, lfilter, [2,3], [4,5], [1,2,3,4,5])
  1234. @pytest.mark.parametrize('dt', [np.ubyte, np.byte, np.ushort, np.short,
  1235. np.uint, int, np.ulonglong, np.ulonglong,
  1236. np.float32, np.float64, np.longdouble,
  1237. Decimal])
  1238. class TestCorrelateReal(object):
  1239. def _setup_rank1(self, dt):
  1240. a = np.linspace(0, 3, 4).astype(dt)
  1241. b = np.linspace(1, 2, 2).astype(dt)
  1242. y_r = np.array([0, 2, 5, 8, 3]).astype(dt)
  1243. return a, b, y_r
  1244. def equal_tolerance(self, res_dt):
  1245. # default value of keyword
  1246. decimal = 6
  1247. try:
  1248. dt_info = np.finfo(res_dt)
  1249. if hasattr(dt_info, 'resolution'):
  1250. decimal = int(-0.5*np.log10(dt_info.resolution))
  1251. except Exception:
  1252. pass
  1253. return decimal
  1254. def equal_tolerance_fft(self, res_dt):
  1255. # FFT implementations convert longdouble arguments down to
  1256. # double so don't expect better precision, see gh-9520
  1257. if res_dt == np.longdouble:
  1258. return self.equal_tolerance(np.double)
  1259. else:
  1260. return self.equal_tolerance(res_dt)
  1261. def test_method(self, dt):
  1262. if dt == Decimal:
  1263. method = choose_conv_method([Decimal(4)], [Decimal(3)])
  1264. assert_equal(method, 'direct')
  1265. else:
  1266. a, b, y_r = self._setup_rank3(dt)
  1267. y_fft = correlate(a, b, method='fft')
  1268. y_direct = correlate(a, b, method='direct')
  1269. assert_array_almost_equal(y_r, y_fft, decimal=self.equal_tolerance_fft(y_fft.dtype))
  1270. assert_array_almost_equal(y_r, y_direct, decimal=self.equal_tolerance(y_direct.dtype))
  1271. assert_equal(y_fft.dtype, dt)
  1272. assert_equal(y_direct.dtype, dt)
  1273. def test_rank1_valid(self, dt):
  1274. a, b, y_r = self._setup_rank1(dt)
  1275. y = correlate(a, b, 'valid')
  1276. assert_array_almost_equal(y, y_r[1:4])
  1277. assert_equal(y.dtype, dt)
  1278. # See gh-5897
  1279. y = correlate(b, a, 'valid')
  1280. assert_array_almost_equal(y, y_r[1:4][::-1])
  1281. assert_equal(y.dtype, dt)
  1282. def test_rank1_same(self, dt):
  1283. a, b, y_r = self._setup_rank1(dt)
  1284. y = correlate(a, b, 'same')
  1285. assert_array_almost_equal(y, y_r[:-1])
  1286. assert_equal(y.dtype, dt)
  1287. def test_rank1_full(self, dt):
  1288. a, b, y_r = self._setup_rank1(dt)
  1289. y = correlate(a, b, 'full')
  1290. assert_array_almost_equal(y, y_r)
  1291. assert_equal(y.dtype, dt)
  1292. def _setup_rank3(self, dt):
  1293. a = np.linspace(0, 39, 40).reshape((2, 4, 5), order='F').astype(
  1294. dt)
  1295. b = np.linspace(0, 23, 24).reshape((2, 3, 4), order='F').astype(
  1296. dt)
  1297. y_r = array([[[0., 184., 504., 912., 1360., 888., 472., 160.],
  1298. [46., 432., 1062., 1840., 2672., 1698., 864., 266.],
  1299. [134., 736., 1662., 2768., 3920., 2418., 1168., 314.],
  1300. [260., 952., 1932., 3056., 4208., 2580., 1240., 332.],
  1301. [202., 664., 1290., 1984., 2688., 1590., 712., 150.],
  1302. [114., 344., 642., 960., 1280., 726., 296., 38.]],
  1303. [[23., 400., 1035., 1832., 2696., 1737., 904., 293.],
  1304. [134., 920., 2166., 3680., 5280., 3306., 1640., 474.],
  1305. [325., 1544., 3369., 5512., 7720., 4683., 2192., 535.],
  1306. [571., 1964., 3891., 6064., 8272., 4989., 2324., 565.],
  1307. [434., 1360., 2586., 3920., 5264., 3054., 1312., 230.],
  1308. [241., 700., 1281., 1888., 2496., 1383., 532., 39.]],
  1309. [[22., 214., 528., 916., 1332., 846., 430., 132.],
  1310. [86., 484., 1098., 1832., 2600., 1602., 772., 206.],
  1311. [188., 802., 1698., 2732., 3788., 2256., 1018., 218.],
  1312. [308., 1006., 1950., 2996., 4052., 2400., 1078., 230.],
  1313. [230., 692., 1290., 1928., 2568., 1458., 596., 78.],
  1314. [126., 354., 636., 924., 1212., 654., 234., 0.]]],
  1315. dtype=dt)
  1316. return a, b, y_r
  1317. def test_rank3_valid(self, dt):
  1318. a, b, y_r = self._setup_rank3(dt)
  1319. y = correlate(a, b, "valid")
  1320. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5])
  1321. assert_equal(y.dtype, dt)
  1322. # See gh-5897
  1323. y = correlate(b, a, "valid")
  1324. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5][::-1, ::-1, ::-1])
  1325. assert_equal(y.dtype, dt)
  1326. def test_rank3_same(self, dt):
  1327. a, b, y_r = self._setup_rank3(dt)
  1328. y = correlate(a, b, "same")
  1329. assert_array_almost_equal(y, y_r[0:-1, 1:-1, 1:-2])
  1330. assert_equal(y.dtype, dt)
  1331. def test_rank3_all(self, dt):
  1332. a, b, y_r = self._setup_rank3(dt)
  1333. y = correlate(a, b)
  1334. assert_array_almost_equal(y, y_r)
  1335. assert_equal(y.dtype, dt)
  1336. class TestCorrelate(object):
  1337. # Tests that don't depend on dtype
  1338. def test_invalid_shapes(self):
  1339. # By "invalid," we mean that no one
  1340. # array has dimensions that are all at
  1341. # least as large as the corresponding
  1342. # dimensions of the other array. This
  1343. # setup should throw a ValueError.
  1344. a = np.arange(1, 7).reshape((2, 3))
  1345. b = np.arange(-6, 0).reshape((3, 2))
  1346. assert_raises(ValueError, correlate, *(a, b), **{'mode': 'valid'})
  1347. assert_raises(ValueError, correlate, *(b, a), **{'mode': 'valid'})
  1348. def test_invalid_params(self):
  1349. a = [3, 4, 5]
  1350. b = [1, 2, 3]
  1351. assert_raises(ValueError, correlate, a, b, mode='spam')
  1352. assert_raises(ValueError, correlate, a, b, mode='eggs', method='fft')
  1353. assert_raises(ValueError, correlate, a, b, mode='ham', method='direct')
  1354. assert_raises(ValueError, correlate, a, b, mode='full', method='bacon')
  1355. assert_raises(ValueError, correlate, a, b, mode='same', method='bacon')
  1356. def test_mismatched_dims(self):
  1357. # Input arrays should have the same number of dimensions
  1358. assert_raises(ValueError, correlate, [1], 2, method='direct')
  1359. assert_raises(ValueError, correlate, 1, [2], method='direct')
  1360. assert_raises(ValueError, correlate, [1], 2, method='fft')
  1361. assert_raises(ValueError, correlate, 1, [2], method='fft')
  1362. assert_raises(ValueError, correlate, [1], [[2]])
  1363. assert_raises(ValueError, correlate, [3], 2)
  1364. def test_numpy_fastpath(self):
  1365. a = [1, 2, 3]
  1366. b = [4, 5]
  1367. assert_allclose(correlate(a, b, mode='same'), [5, 14, 23])
  1368. a = [1, 2, 3]
  1369. b = [4, 5, 6]
  1370. assert_allclose(correlate(a, b, mode='same'), [17, 32, 23])
  1371. assert_allclose(correlate(a, b, mode='full'), [6, 17, 32, 23, 12])
  1372. assert_allclose(correlate(a, b, mode='valid'), [32])
  1373. @pytest.mark.parametrize('dt', [np.csingle, np.cdouble, np.clongdouble])
  1374. class TestCorrelateComplex(object):
  1375. # The decimal precision to be used for comparing results.
  1376. # This value will be passed as the 'decimal' keyword argument of
  1377. # assert_array_almost_equal().
  1378. # Since correlate may chose to use FFT method which converts
  1379. # longdoubles to doubles internally don't expect better precision
  1380. # for longdouble than for double (see gh-9520).
  1381. def decimal(self, dt):
  1382. if dt == np.clongdouble:
  1383. dt = np.cdouble
  1384. return int(2 * np.finfo(dt).precision / 3)
  1385. def _setup_rank1(self, dt, mode):
  1386. np.random.seed(9)
  1387. a = np.random.randn(10).astype(dt)
  1388. a += 1j * np.random.randn(10).astype(dt)
  1389. b = np.random.randn(8).astype(dt)
  1390. b += 1j * np.random.randn(8).astype(dt)
  1391. y_r = (correlate(a.real, b.real, mode=mode) +
  1392. correlate(a.imag, b.imag, mode=mode)).astype(dt)
  1393. y_r += 1j * (-correlate(a.real, b.imag, mode=mode) +
  1394. correlate(a.imag, b.real, mode=mode))
  1395. return a, b, y_r
  1396. def test_rank1_valid(self, dt):
  1397. a, b, y_r = self._setup_rank1(dt, 'valid')
  1398. y = correlate(a, b, 'valid')
  1399. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1400. assert_equal(y.dtype, dt)
  1401. # See gh-5897
  1402. y = correlate(b, a, 'valid')
  1403. assert_array_almost_equal(y, y_r[::-1].conj(), decimal=self.decimal(dt))
  1404. assert_equal(y.dtype, dt)
  1405. def test_rank1_same(self, dt):
  1406. a, b, y_r = self._setup_rank1(dt, 'same')
  1407. y = correlate(a, b, 'same')
  1408. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1409. assert_equal(y.dtype, dt)
  1410. def test_rank1_full(self, dt):
  1411. a, b, y_r = self._setup_rank1(dt, 'full')
  1412. y = correlate(a, b, 'full')
  1413. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1414. assert_equal(y.dtype, dt)
  1415. def test_swap_full(self, dt):
  1416. d = np.array([0.+0.j, 1.+1.j, 2.+2.j], dtype=dt)
  1417. k = np.array([1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j], dtype=dt)
  1418. y = correlate(d, k)
  1419. assert_equal(y, [0.+0.j, 10.-2.j, 28.-6.j, 22.-6.j, 16.-6.j, 8.-4.j])
  1420. def test_swap_same(self, dt):
  1421. d = [0.+0.j, 1.+1.j, 2.+2.j]
  1422. k = [1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j]
  1423. y = correlate(d, k, mode="same")
  1424. assert_equal(y, [10.-2.j, 28.-6.j, 22.-6.j])
  1425. def test_rank3(self, dt):
  1426. a = np.random.randn(10, 8, 6).astype(dt)
  1427. a += 1j * np.random.randn(10, 8, 6).astype(dt)
  1428. b = np.random.randn(8, 6, 4).astype(dt)
  1429. b += 1j * np.random.randn(8, 6, 4).astype(dt)
  1430. y_r = (correlate(a.real, b.real)
  1431. + correlate(a.imag, b.imag)).astype(dt)
  1432. y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))
  1433. y = correlate(a, b, 'full')
  1434. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
  1435. assert_equal(y.dtype, dt)
  1436. def test_rank0(self, dt):
  1437. a = np.array(np.random.randn()).astype(dt)
  1438. a += 1j * np.array(np.random.randn()).astype(dt)
  1439. b = np.array(np.random.randn()).astype(dt)
  1440. b += 1j * np.array(np.random.randn()).astype(dt)
  1441. y_r = (correlate(a.real, b.real)
  1442. + correlate(a.imag, b.imag)).astype(dt)
  1443. y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))
  1444. y = correlate(a, b, 'full')
  1445. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
  1446. assert_equal(y.dtype, dt)
  1447. assert_equal(correlate([1], [2j]), correlate(1, 2j))
  1448. assert_equal(correlate([2j], [3j]), correlate(2j, 3j))
  1449. assert_equal(correlate([3j], [4]), correlate(3j, 4))
  1450. class TestCorrelate2d(object):
  1451. def test_consistency_correlate_funcs(self):
  1452. # Compare np.correlate, signal.correlate, signal.correlate2d
  1453. a = np.arange(5)
  1454. b = np.array([3.2, 1.4, 3])
  1455. for mode in ['full', 'valid', 'same']:
  1456. assert_almost_equal(np.correlate(a, b, mode=mode),
  1457. signal.correlate(a, b, mode=mode))
  1458. assert_almost_equal(np.squeeze(signal.correlate2d([a], [b],
  1459. mode=mode)),
  1460. signal.correlate(a, b, mode=mode))
  1461. # See gh-5897
  1462. if mode == 'valid':
  1463. assert_almost_equal(np.correlate(b, a, mode=mode),
  1464. signal.correlate(b, a, mode=mode))
  1465. assert_almost_equal(np.squeeze(signal.correlate2d([b], [a],
  1466. mode=mode)),
  1467. signal.correlate(b, a, mode=mode))
  1468. def test_invalid_shapes(self):
  1469. # By "invalid," we mean that no one
  1470. # array has dimensions that are all at
  1471. # least as large as the corresponding
  1472. # dimensions of the other array. This
  1473. # setup should throw a ValueError.
  1474. a = np.arange(1, 7).reshape((2, 3))
  1475. b = np.arange(-6, 0).reshape((3, 2))
  1476. assert_raises(ValueError, signal.correlate2d, *(a, b), **{'mode': 'valid'})
  1477. assert_raises(ValueError, signal.correlate2d, *(b, a), **{'mode': 'valid'})
  1478. def test_complex_input(self):
  1479. assert_equal(signal.correlate2d([[1]], [[2j]]), -2j)
  1480. assert_equal(signal.correlate2d([[2j]], [[3j]]), 6)
  1481. assert_equal(signal.correlate2d([[3j]], [[4]]), 12j)
  1482. class TestLFilterZI(object):
  1483. def test_basic(self):
  1484. a = np.array([1.0, -1.0, 0.5])
  1485. b = np.array([1.0, 0.0, 2.0])
  1486. zi_expected = np.array([5.0, -1.0])
  1487. zi = lfilter_zi(b, a)
  1488. assert_array_almost_equal(zi, zi_expected)
  1489. def test_scale_invariance(self):
  1490. # Regression test. There was a bug in which b was not correctly
  1491. # rescaled when a[0] was nonzero.
  1492. b = np.array([2, 8, 5])
  1493. a = np.array([1, 1, 8])
  1494. zi1 = lfilter_zi(b, a)
  1495. zi2 = lfilter_zi(2*b, 2*a)
  1496. assert_allclose(zi2, zi1, rtol=1e-12)
  1497. class TestFiltFilt(object):
  1498. filtfilt_kind = 'tf'
  1499. def filtfilt(self, zpk, x, axis=-1, padtype='odd', padlen=None,
  1500. method='pad', irlen=None):
  1501. if self.filtfilt_kind == 'tf':
  1502. b, a = zpk2tf(*zpk)
  1503. return filtfilt(b, a, x, axis, padtype, padlen, method, irlen)
  1504. elif self.filtfilt_kind == 'sos':
  1505. sos = zpk2sos(*zpk)
  1506. return sosfiltfilt(sos, x, axis, padtype, padlen)
  1507. def test_basic(self):
  1508. zpk = tf2zpk([1, 2, 3], [1, 2, 3])
  1509. out = self.filtfilt(zpk, np.arange(12))
  1510. assert_allclose(out, arange(12), atol=1e-11)
  1511. def test_sine(self):
  1512. rate = 2000
  1513. t = np.linspace(0, 1.0, rate + 1)
  1514. # A signal with low frequency and a high frequency.
  1515. xlow = np.sin(5 * 2 * np.pi * t)
  1516. xhigh = np.sin(250 * 2 * np.pi * t)
  1517. x = xlow + xhigh
  1518. zpk = butter(8, 0.125, output='zpk')
  1519. # r is the magnitude of the largest pole.
  1520. r = np.abs(zpk[1]).max()
  1521. eps = 1e-5
  1522. # n estimates the number of steps for the
  1523. # transient to decay by a factor of eps.
  1524. n = int(np.ceil(np.log(eps) / np.log(r)))
  1525. # High order lowpass filter...
  1526. y = self.filtfilt(zpk, x, padlen=n)
  1527. # Result should be just xlow.
  1528. err = np.abs(y - xlow).max()
  1529. assert_(err < 1e-4)
  1530. # A 2D case.
  1531. x2d = np.vstack([xlow, xlow + xhigh])
  1532. y2d = self.filtfilt(zpk, x2d, padlen=n, axis=1)
  1533. assert_equal(y2d.shape, x2d.shape)
  1534. err = np.abs(y2d - xlow).max()
  1535. assert_(err < 1e-4)
  1536. # Use the previous result to check the use of the axis keyword.
  1537. # (Regression test for ticket #1620)
  1538. y2dt = self.filtfilt(zpk, x2d.T, padlen=n, axis=0)
  1539. assert_equal(y2d, y2dt.T)
  1540. def test_axis(self):
  1541. # Test the 'axis' keyword on a 3D array.
  1542. x = np.arange(10.0 * 11.0 * 12.0).reshape(10, 11, 12)
  1543. zpk = butter(3, 0.125, output='zpk')
  1544. y0 = self.filtfilt(zpk, x, padlen=0, axis=0)
  1545. y1 = self.filtfilt(zpk, np.swapaxes(x, 0, 1), padlen=0, axis=1)
  1546. assert_array_equal(y0, np.swapaxes(y1, 0, 1))
  1547. y2 = self.filtfilt(zpk, np.swapaxes(x, 0, 2), padlen=0, axis=2)
  1548. assert_array_equal(y0, np.swapaxes(y2, 0, 2))
  1549. def test_acoeff(self):
  1550. if self.filtfilt_kind != 'tf':
  1551. return # only necessary for TF
  1552. # test for 'a' coefficient as single number
  1553. out = signal.filtfilt([.5, .5], 1, np.arange(10))
  1554. assert_allclose(out, np.arange(10), rtol=1e-14, atol=1e-14)
  1555. def test_gust_simple(self):
  1556. if self.filtfilt_kind != 'tf':
  1557. pytest.skip('gust only implemented for TF systems')
  1558. # The input array has length 2. The exact solution for this case
  1559. # was computed "by hand".
  1560. x = np.array([1.0, 2.0])
  1561. b = np.array([0.5])
  1562. a = np.array([1.0, -0.5])
  1563. y, z1, z2 = _filtfilt_gust(b, a, x)
  1564. assert_allclose([z1[0], z2[0]],
  1565. [0.3*x[0] + 0.2*x[1], 0.2*x[0] + 0.3*x[1]])
  1566. assert_allclose(y, [z1[0] + 0.25*z2[0] + 0.25*x[0] + 0.125*x[1],
  1567. 0.25*z1[0] + z2[0] + 0.125*x[0] + 0.25*x[1]])
  1568. def test_gust_scalars(self):
  1569. if self.filtfilt_kind != 'tf':
  1570. pytest.skip('gust only implemented for TF systems')
  1571. # The filter coefficients are both scalars, so the filter simply
  1572. # multiplies its input by b/a. When it is used in filtfilt, the
  1573. # factor is (b/a)**2.
  1574. x = np.arange(12)
  1575. b = 3.0
  1576. a = 2.0
  1577. y = filtfilt(b, a, x, method="gust")
  1578. expected = (b/a)**2 * x
  1579. assert_allclose(y, expected)
  1580. class TestSOSFiltFilt(TestFiltFilt):
  1581. filtfilt_kind = 'sos'
  1582. def test_equivalence(self):
  1583. """Test equivalence between sosfiltfilt and filtfilt"""
  1584. x = np.random.RandomState(0).randn(1000)
  1585. for order in range(1, 6):
  1586. zpk = signal.butter(order, 0.35, output='zpk')
  1587. b, a = zpk2tf(*zpk)
  1588. sos = zpk2sos(*zpk)
  1589. y = filtfilt(b, a, x)
  1590. y_sos = sosfiltfilt(sos, x)
  1591. assert_allclose(y, y_sos, atol=1e-12, err_msg='order=%s' % order)
  1592. def filtfilt_gust_opt(b, a, x):
  1593. """
  1594. An alternative implementation of filtfilt with Gustafsson edges.
  1595. This function computes the same result as
  1596. `scipy.signal.signaltools._filtfilt_gust`, but only 1-d arrays
  1597. are accepted. The problem is solved using `fmin` from `scipy.optimize`.
  1598. `_filtfilt_gust` is significanly faster than this implementation.
  1599. """
  1600. def filtfilt_gust_opt_func(ics, b, a, x):
  1601. """Objective function used in filtfilt_gust_opt."""
  1602. m = max(len(a), len(b)) - 1
  1603. z0f = ics[:m]
  1604. z0b = ics[m:]
  1605. y_f = lfilter(b, a, x, zi=z0f)[0]
  1606. y_fb = lfilter(b, a, y_f[::-1], zi=z0b)[0][::-1]
  1607. y_b = lfilter(b, a, x[::-1], zi=z0b)[0][::-1]
  1608. y_bf = lfilter(b, a, y_b, zi=z0f)[0]
  1609. value = np.sum((y_fb - y_bf)**2)
  1610. return value
  1611. m = max(len(a), len(b)) - 1
  1612. zi = lfilter_zi(b, a)
  1613. ics = np.concatenate((x[:m].mean()*zi, x[-m:].mean()*zi))
  1614. result = fmin(filtfilt_gust_opt_func, ics, args=(b, a, x),
  1615. xtol=1e-10, ftol=1e-12,
  1616. maxfun=10000, maxiter=10000,
  1617. full_output=True, disp=False)
  1618. opt, fopt, niter, funcalls, warnflag = result
  1619. if warnflag > 0:
  1620. raise RuntimeError("minimization failed in filtfilt_gust_opt: "
  1621. "warnflag=%d" % warnflag)
  1622. z0f = opt[:m]
  1623. z0b = opt[m:]
  1624. # Apply the forward-backward filter using the computed initial
  1625. # conditions.
  1626. y_b = lfilter(b, a, x[::-1], zi=z0b)[0][::-1]
  1627. y = lfilter(b, a, y_b, zi=z0f)[0]
  1628. return y, z0f, z0b
  1629. def check_filtfilt_gust(b, a, shape, axis, irlen=None):
  1630. # Generate x, the data to be filtered.
  1631. np.random.seed(123)
  1632. x = np.random.randn(*shape)
  1633. # Apply filtfilt to x. This is the main calculation to be checked.
  1634. y = filtfilt(b, a, x, axis=axis, method="gust", irlen=irlen)
  1635. # Also call the private function so we can test the ICs.
  1636. yg, zg1, zg2 = _filtfilt_gust(b, a, x, axis=axis, irlen=irlen)
  1637. # filtfilt_gust_opt is an independent implementation that gives the
  1638. # expected result, but it only handles 1-d arrays, so use some looping
  1639. # and reshaping shenanigans to create the expected output arrays.
  1640. xx = np.swapaxes(x, axis, -1)
  1641. out_shape = xx.shape[:-1]
  1642. yo = np.empty_like(xx)
  1643. m = max(len(a), len(b)) - 1
  1644. zo1 = np.empty(out_shape + (m,))
  1645. zo2 = np.empty(out_shape + (m,))
  1646. for indx in product(*[range(d) for d in out_shape]):
  1647. yo[indx], zo1[indx], zo2[indx] = filtfilt_gust_opt(b, a, xx[indx])
  1648. yo = np.swapaxes(yo, -1, axis)
  1649. zo1 = np.swapaxes(zo1, -1, axis)
  1650. zo2 = np.swapaxes(zo2, -1, axis)
  1651. assert_allclose(y, yo, rtol=1e-9, atol=1e-10)
  1652. assert_allclose(yg, yo, rtol=1e-9, atol=1e-10)
  1653. assert_allclose(zg1, zo1, rtol=1e-9, atol=1e-10)
  1654. assert_allclose(zg2, zo2, rtol=1e-9, atol=1e-10)
  1655. def test_choose_conv_method():
  1656. for mode in ['valid', 'same', 'full']:
  1657. for ndims in [1, 2]:
  1658. n, k, true_method = 8, 6, 'direct'
  1659. x = np.random.randn(*((n,) * ndims))
  1660. h = np.random.randn(*((k,) * ndims))
  1661. method = choose_conv_method(x, h, mode=mode)
  1662. assert_equal(method, true_method)
  1663. method_try, times = choose_conv_method(x, h, mode=mode, measure=True)
  1664. assert_(method_try in {'fft', 'direct'})
  1665. assert_(type(times) is dict)
  1666. assert_('fft' in times.keys() and 'direct' in times.keys())
  1667. n = 10
  1668. for not_fft_conv_supp in ["complex256", "complex192"]:
  1669. if hasattr(np, not_fft_conv_supp):
  1670. x = np.ones(n, dtype=not_fft_conv_supp)
  1671. h = x.copy()
  1672. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  1673. x = np.array([2**51], dtype=np.int64)
  1674. h = x.copy()
  1675. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  1676. x = [Decimal(3), Decimal(2)]
  1677. h = [Decimal(1), Decimal(4)]
  1678. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  1679. def test_filtfilt_gust():
  1680. # Design a filter.
  1681. z, p, k = signal.ellip(3, 0.01, 120, 0.0875, output='zpk')
  1682. # Find the approximate impulse response length of the filter.
  1683. eps = 1e-10
  1684. r = np.max(np.abs(p))
  1685. approx_impulse_len = int(np.ceil(np.log(eps) / np.log(r)))
  1686. np.random.seed(123)
  1687. b, a = zpk2tf(z, p, k)
  1688. for irlen in [None, approx_impulse_len]:
  1689. signal_len = 5 * approx_impulse_len
  1690. # 1-d test case
  1691. check_filtfilt_gust(b, a, (signal_len,), 0, irlen)
  1692. # 3-d test case; test each axis.
  1693. for axis in range(3):
  1694. shape = [2, 2, 2]
  1695. shape[axis] = signal_len
  1696. check_filtfilt_gust(b, a, shape, axis, irlen)
  1697. # Test case with length less than 2*approx_impulse_len.
  1698. # In this case, `filtfilt_gust` should behave the same as if
  1699. # `irlen=None` was given.
  1700. length = 2*approx_impulse_len - 50
  1701. check_filtfilt_gust(b, a, (length,), 0, approx_impulse_len)
  1702. class TestDecimate(object):
  1703. def test_bad_args(self):
  1704. x = np.arange(12)
  1705. assert_raises(TypeError, signal.decimate, x, q=0.5, n=1)
  1706. assert_raises(TypeError, signal.decimate, x, q=2, n=0.5)
  1707. def test_basic_IIR(self):
  1708. x = np.arange(12)
  1709. y = signal.decimate(x, 2, n=1, ftype='iir', zero_phase=False).round()
  1710. assert_array_equal(y, x[::2])
  1711. def test_basic_FIR(self):
  1712. x = np.arange(12)
  1713. y = signal.decimate(x, 2, n=1, ftype='fir', zero_phase=False).round()
  1714. assert_array_equal(y, x[::2])
  1715. def test_shape(self):
  1716. # Regression test for ticket #1480.
  1717. z = np.zeros((30, 30))
  1718. d0 = signal.decimate(z, 2, axis=0, zero_phase=False)
  1719. assert_equal(d0.shape, (15, 30))
  1720. d1 = signal.decimate(z, 2, axis=1, zero_phase=False)
  1721. assert_equal(d1.shape, (30, 15))
  1722. def test_phaseshift_FIR(self):
  1723. with suppress_warnings() as sup:
  1724. sup.filter(BadCoefficients, "Badly conditioned filter")
  1725. self._test_phaseshift(method='fir', zero_phase=False)
  1726. def test_zero_phase_FIR(self):
  1727. with suppress_warnings() as sup:
  1728. sup.filter(BadCoefficients, "Badly conditioned filter")
  1729. self._test_phaseshift(method='fir', zero_phase=True)
  1730. def test_phaseshift_IIR(self):
  1731. self._test_phaseshift(method='iir', zero_phase=False)
  1732. def test_zero_phase_IIR(self):
  1733. self._test_phaseshift(method='iir', zero_phase=True)
  1734. def _test_phaseshift(self, method, zero_phase):
  1735. rate = 120
  1736. rates_to = [15, 20, 30, 40] # q = 8, 6, 4, 3
  1737. t_tot = int(100) # Need to let antialiasing filters settle
  1738. t = np.arange(rate*t_tot+1) / float(rate)
  1739. # Sinusoids at 0.8*nyquist, windowed to avoid edge artifacts
  1740. freqs = np.array(rates_to) * 0.8 / 2
  1741. d = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t)
  1742. * signal.windows.tukey(t.size, 0.1))
  1743. for rate_to in rates_to:
  1744. q = rate // rate_to
  1745. t_to = np.arange(rate_to*t_tot+1) / float(rate_to)
  1746. d_tos = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t_to)
  1747. * signal.windows.tukey(t_to.size, 0.1))
  1748. # Set up downsampling filters, match v0.17 defaults
  1749. if method == 'fir':
  1750. n = 30
  1751. system = signal.dlti(signal.firwin(n + 1, 1. / q,
  1752. window='hamming'), 1.)
  1753. elif method == 'iir':
  1754. n = 8
  1755. wc = 0.8*np.pi/q
  1756. system = signal.dlti(*signal.cheby1(n, 0.05, wc/np.pi))
  1757. # Calculate expected phase response, as unit complex vector
  1758. if zero_phase is False:
  1759. _, h_resps = signal.freqz(system.num, system.den,
  1760. freqs/rate*2*np.pi)
  1761. h_resps /= np.abs(h_resps)
  1762. else:
  1763. h_resps = np.ones_like(freqs)
  1764. y_resamps = signal.decimate(d.real, q, n, ftype=system,
  1765. zero_phase=zero_phase)
  1766. # Get phase from complex inner product, like CSD
  1767. h_resamps = np.sum(d_tos.conj() * y_resamps, axis=-1)
  1768. h_resamps /= np.abs(h_resamps)
  1769. subnyq = freqs < 0.5*rate_to
  1770. # Complex vectors should be aligned, only compare below nyquist
  1771. assert_allclose(np.angle(h_resps.conj()*h_resamps)[subnyq], 0,
  1772. atol=1e-3, rtol=1e-3)
  1773. def test_auto_n(self):
  1774. # Test that our value of n is a reasonable choice (depends on
  1775. # the downsampling factor)
  1776. sfreq = 100.
  1777. n = 1000
  1778. t = np.arange(n) / sfreq
  1779. # will alias for decimations (>= 15)
  1780. x = np.sqrt(2. / n) * np.sin(2 * np.pi * (sfreq / 30.) * t)
  1781. assert_allclose(np.linalg.norm(x), 1., rtol=1e-3)
  1782. x_out = signal.decimate(x, 30, ftype='fir')
  1783. assert_array_less(np.linalg.norm(x_out), 0.01)
  1784. class TestHilbert(object):
  1785. def test_bad_args(self):
  1786. x = np.array([1.0 + 0.0j])
  1787. assert_raises(ValueError, hilbert, x)
  1788. x = np.arange(8.0)
  1789. assert_raises(ValueError, hilbert, x, N=0)
  1790. def test_hilbert_theoretical(self):
  1791. # test cases by Ariel Rokem
  1792. decimal = 14
  1793. pi = np.pi
  1794. t = np.arange(0, 2 * pi, pi / 256)
  1795. a0 = np.sin(t)
  1796. a1 = np.cos(t)
  1797. a2 = np.sin(2 * t)
  1798. a3 = np.cos(2 * t)
  1799. a = np.vstack([a0, a1, a2, a3])
  1800. h = hilbert(a)
  1801. h_abs = np.abs(h)
  1802. h_angle = np.angle(h)
  1803. h_real = np.real(h)
  1804. # The real part should be equal to the original signals:
  1805. assert_almost_equal(h_real, a, decimal)
  1806. # The absolute value should be one everywhere, for this input:
  1807. assert_almost_equal(h_abs, np.ones(a.shape), decimal)
  1808. # For the 'slow' sine - the phase should go from -pi/2 to pi/2 in
  1809. # the first 256 bins:
  1810. assert_almost_equal(h_angle[0, :256],
  1811. np.arange(-pi / 2, pi / 2, pi / 256),
  1812. decimal)
  1813. # For the 'slow' cosine - the phase should go from 0 to pi in the
  1814. # same interval:
  1815. assert_almost_equal(
  1816. h_angle[1, :256], np.arange(0, pi, pi / 256), decimal)
  1817. # The 'fast' sine should make this phase transition in half the time:
  1818. assert_almost_equal(h_angle[2, :128],
  1819. np.arange(-pi / 2, pi / 2, pi / 128),
  1820. decimal)
  1821. # Ditto for the 'fast' cosine:
  1822. assert_almost_equal(
  1823. h_angle[3, :128], np.arange(0, pi, pi / 128), decimal)
  1824. # The imaginary part of hilbert(cos(t)) = sin(t) Wikipedia
  1825. assert_almost_equal(h[1].imag, a0, decimal)
  1826. def test_hilbert_axisN(self):
  1827. # tests for axis and N arguments
  1828. a = np.arange(18).reshape(3, 6)
  1829. # test axis
  1830. aa = hilbert(a, axis=-1)
  1831. assert_equal(hilbert(a.T, axis=0), aa.T)
  1832. # test 1d
  1833. assert_almost_equal(hilbert(a[0]), aa[0], 14)
  1834. # test N
  1835. aan = hilbert(a, N=20, axis=-1)
  1836. assert_equal(aan.shape, [3, 20])
  1837. assert_equal(hilbert(a.T, N=20, axis=0).shape, [20, 3])
  1838. # the next test is just a regression test,
  1839. # no idea whether numbers make sense
  1840. a0hilb = np.array([0.000000000000000e+00 - 1.72015830311905j,
  1841. 1.000000000000000e+00 - 2.047794505137069j,
  1842. 1.999999999999999e+00 - 2.244055555687583j,
  1843. 3.000000000000000e+00 - 1.262750302935009j,
  1844. 4.000000000000000e+00 - 1.066489252384493j,
  1845. 5.000000000000000e+00 + 2.918022706971047j,
  1846. 8.881784197001253e-17 + 3.845658908989067j,
  1847. -9.444121133484362e-17 + 0.985044202202061j,
  1848. -1.776356839400251e-16 + 1.332257797702019j,
  1849. -3.996802888650564e-16 + 0.501905089898885j,
  1850. 1.332267629550188e-16 + 0.668696078880782j,
  1851. -1.192678053963799e-16 + 0.235487067862679j,
  1852. -1.776356839400251e-16 + 0.286439612812121j,
  1853. 3.108624468950438e-16 + 0.031676888064907j,
  1854. 1.332267629550188e-16 - 0.019275656884536j,
  1855. -2.360035624836702e-16 - 0.1652588660287j,
  1856. 0.000000000000000e+00 - 0.332049855010597j,
  1857. 3.552713678800501e-16 - 0.403810179797771j,
  1858. 8.881784197001253e-17 - 0.751023775297729j,
  1859. 9.444121133484362e-17 - 0.79252210110103j])
  1860. assert_almost_equal(aan[0], a0hilb, 14, 'N regression')
  1861. class TestHilbert2(object):
  1862. def test_bad_args(self):
  1863. # x must be real.
  1864. x = np.array([[1.0 + 0.0j]])
  1865. assert_raises(ValueError, hilbert2, x)
  1866. # x must be rank 2.
  1867. x = np.arange(24).reshape(2, 3, 4)
  1868. assert_raises(ValueError, hilbert2, x)
  1869. # Bad value for N.
  1870. x = np.arange(16).reshape(4, 4)
  1871. assert_raises(ValueError, hilbert2, x, N=0)
  1872. assert_raises(ValueError, hilbert2, x, N=(2, 0))
  1873. assert_raises(ValueError, hilbert2, x, N=(2,))
  1874. class TestPartialFractionExpansion(object):
  1875. def test_invresz_one_coefficient_bug(self):
  1876. # Regression test for issue in gh-4646.
  1877. r = [1]
  1878. p = [2]
  1879. k = [0]
  1880. a_expected = [1.0, 0.0]
  1881. b_expected = [1.0, -2.0]
  1882. a_observed, b_observed = invresz(r, p, k)
  1883. assert_allclose(a_observed, a_expected)
  1884. assert_allclose(b_observed, b_expected)
  1885. def test_invres_distinct_roots(self):
  1886. # This test was inspired by github issue 2496.
  1887. r = [3 / 10, -1 / 6, -2 / 15]
  1888. p = [0, -2, -5]
  1889. k = []
  1890. a_expected = [1, 3]
  1891. b_expected = [1, 7, 10, 0]
  1892. a_observed, b_observed = invres(r, p, k)
  1893. assert_allclose(a_observed, a_expected)
  1894. assert_allclose(b_observed, b_expected)
  1895. rtypes = ('avg', 'mean', 'min', 'minimum', 'max', 'maximum')
  1896. # With the default tolerance, the rtype does not matter
  1897. # for this example.
  1898. for rtype in rtypes:
  1899. a_observed, b_observed = invres(r, p, k, rtype=rtype)
  1900. assert_allclose(a_observed, a_expected)
  1901. assert_allclose(b_observed, b_expected)
  1902. # With unrealistically large tolerances, repeated roots may be inferred
  1903. # and the rtype comes into play.
  1904. ridiculous_tolerance = 1e10
  1905. for rtype in rtypes:
  1906. a, b = invres(r, p, k, tol=ridiculous_tolerance, rtype=rtype)
  1907. def test_invres_repeated_roots(self):
  1908. r = [3 / 20, -7 / 36, -1 / 6, 2 / 45]
  1909. p = [0, -2, -2, -5]
  1910. k = []
  1911. a_expected = [1, 3]
  1912. b_expected = [1, 9, 24, 20, 0]
  1913. rtypes = ('avg', 'mean', 'min', 'minimum', 'max', 'maximum')
  1914. for rtype in rtypes:
  1915. a_observed, b_observed = invres(r, p, k, rtype=rtype)
  1916. assert_allclose(a_observed, a_expected)
  1917. assert_allclose(b_observed, b_expected)
  1918. def test_invres_bad_rtype(self):
  1919. r = [3 / 20, -7 / 36, -1 / 6, 2 / 45]
  1920. p = [0, -2, -2, -5]
  1921. k = []
  1922. assert_raises(ValueError, invres, r, p, k, rtype='median')
  1923. class TestVectorstrength(object):
  1924. def test_single_1dperiod(self):
  1925. events = np.array([.5])
  1926. period = 5.
  1927. targ_strength = 1.
  1928. targ_phase = .1
  1929. strength, phase = vectorstrength(events, period)
  1930. assert_equal(strength.ndim, 0)
  1931. assert_equal(phase.ndim, 0)
  1932. assert_almost_equal(strength, targ_strength)
  1933. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1934. def test_single_2dperiod(self):
  1935. events = np.array([.5])
  1936. period = [1, 2, 5.]
  1937. targ_strength = [1.] * 3
  1938. targ_phase = np.array([.5, .25, .1])
  1939. strength, phase = vectorstrength(events, period)
  1940. assert_equal(strength.ndim, 1)
  1941. assert_equal(phase.ndim, 1)
  1942. assert_array_almost_equal(strength, targ_strength)
  1943. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1944. def test_equal_1dperiod(self):
  1945. events = np.array([.25, .25, .25, .25, .25, .25])
  1946. period = 2
  1947. targ_strength = 1.
  1948. targ_phase = .125
  1949. strength, phase = vectorstrength(events, period)
  1950. assert_equal(strength.ndim, 0)
  1951. assert_equal(phase.ndim, 0)
  1952. assert_almost_equal(strength, targ_strength)
  1953. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1954. def test_equal_2dperiod(self):
  1955. events = np.array([.25, .25, .25, .25, .25, .25])
  1956. period = [1, 2, ]
  1957. targ_strength = [1.] * 2
  1958. targ_phase = np.array([.25, .125])
  1959. strength, phase = vectorstrength(events, period)
  1960. assert_equal(strength.ndim, 1)
  1961. assert_equal(phase.ndim, 1)
  1962. assert_almost_equal(strength, targ_strength)
  1963. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1964. def test_spaced_1dperiod(self):
  1965. events = np.array([.1, 1.1, 2.1, 4.1, 10.1])
  1966. period = 1
  1967. targ_strength = 1.
  1968. targ_phase = .1
  1969. strength, phase = vectorstrength(events, period)
  1970. assert_equal(strength.ndim, 0)
  1971. assert_equal(phase.ndim, 0)
  1972. assert_almost_equal(strength, targ_strength)
  1973. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1974. def test_spaced_2dperiod(self):
  1975. events = np.array([.1, 1.1, 2.1, 4.1, 10.1])
  1976. period = [1, .5]
  1977. targ_strength = [1.] * 2
  1978. targ_phase = np.array([.1, .2])
  1979. strength, phase = vectorstrength(events, period)
  1980. assert_equal(strength.ndim, 1)
  1981. assert_equal(phase.ndim, 1)
  1982. assert_almost_equal(strength, targ_strength)
  1983. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1984. def test_partial_1dperiod(self):
  1985. events = np.array([.25, .5, .75])
  1986. period = 1
  1987. targ_strength = 1. / 3.
  1988. targ_phase = .5
  1989. strength, phase = vectorstrength(events, period)
  1990. assert_equal(strength.ndim, 0)
  1991. assert_equal(phase.ndim, 0)
  1992. assert_almost_equal(strength, targ_strength)
  1993. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  1994. def test_partial_2dperiod(self):
  1995. events = np.array([.25, .5, .75])
  1996. period = [1., 1., 1., 1.]
  1997. targ_strength = [1. / 3.] * 4
  1998. targ_phase = np.array([.5, .5, .5, .5])
  1999. strength, phase = vectorstrength(events, period)
  2000. assert_equal(strength.ndim, 1)
  2001. assert_equal(phase.ndim, 1)
  2002. assert_almost_equal(strength, targ_strength)
  2003. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2004. def test_opposite_1dperiod(self):
  2005. events = np.array([0, .25, .5, .75])
  2006. period = 1.
  2007. targ_strength = 0
  2008. strength, phase = vectorstrength(events, period)
  2009. assert_equal(strength.ndim, 0)
  2010. assert_equal(phase.ndim, 0)
  2011. assert_almost_equal(strength, targ_strength)
  2012. def test_opposite_2dperiod(self):
  2013. events = np.array([0, .25, .5, .75])
  2014. period = [1.] * 10
  2015. targ_strength = [0.] * 10
  2016. strength, phase = vectorstrength(events, period)
  2017. assert_equal(strength.ndim, 1)
  2018. assert_equal(phase.ndim, 1)
  2019. assert_almost_equal(strength, targ_strength)
  2020. def test_2d_events_ValueError(self):
  2021. events = np.array([[1, 2]])
  2022. period = 1.
  2023. assert_raises(ValueError, vectorstrength, events, period)
  2024. def test_2d_period_ValueError(self):
  2025. events = 1.
  2026. period = np.array([[1]])
  2027. assert_raises(ValueError, vectorstrength, events, period)
  2028. def test_zero_period_ValueError(self):
  2029. events = 1.
  2030. period = 0
  2031. assert_raises(ValueError, vectorstrength, events, period)
  2032. def test_negative_period_ValueError(self):
  2033. events = 1.
  2034. period = -1
  2035. assert_raises(ValueError, vectorstrength, events, period)
  2036. class TestSOSFilt(object):
  2037. # For sosfilt we only test a single datatype. Since sosfilt wraps
  2038. # to lfilter under the hood, it's hopefully good enough to ensure
  2039. # lfilter is extensively tested.
  2040. dt = np.float64
  2041. # The test_rank* tests are pulled from _TestLinearFilter
  2042. def test_rank1(self):
  2043. x = np.linspace(0, 5, 6).astype(self.dt)
  2044. b = np.array([1, -1]).astype(self.dt)
  2045. a = np.array([0.5, -0.5]).astype(self.dt)
  2046. # Test simple IIR
  2047. y_r = np.array([0, 2, 4, 6, 8, 10.]).astype(self.dt)
  2048. assert_array_almost_equal(sosfilt(tf2sos(b, a), x), y_r)
  2049. # Test simple FIR
  2050. b = np.array([1, 1]).astype(self.dt)
  2051. # NOTE: This was changed (rel. to TestLinear...) to add a pole @zero:
  2052. a = np.array([1, 0]).astype(self.dt)
  2053. y_r = np.array([0, 1, 3, 5, 7, 9.]).astype(self.dt)
  2054. assert_array_almost_equal(sosfilt(tf2sos(b, a), x), y_r)
  2055. b = [1, 1, 0]
  2056. a = [1, 0, 0]
  2057. x = np.ones(8)
  2058. sos = np.concatenate((b, a))
  2059. sos.shape = (1, 6)
  2060. y = sosfilt(sos, x)
  2061. assert_allclose(y, [1, 2, 2, 2, 2, 2, 2, 2])
  2062. def test_rank2(self):
  2063. shape = (4, 3)
  2064. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  2065. x = x.astype(self.dt)
  2066. b = np.array([1, -1]).astype(self.dt)
  2067. a = np.array([0.5, 0.5]).astype(self.dt)
  2068. y_r2_a0 = np.array([[0, 2, 4], [6, 4, 2], [0, 2, 4], [6, 4, 2]],
  2069. dtype=self.dt)
  2070. y_r2_a1 = np.array([[0, 2, 0], [6, -4, 6], [12, -10, 12],
  2071. [18, -16, 18]], dtype=self.dt)
  2072. y = sosfilt(tf2sos(b, a), x, axis=0)
  2073. assert_array_almost_equal(y_r2_a0, y)
  2074. y = sosfilt(tf2sos(b, a), x, axis=1)
  2075. assert_array_almost_equal(y_r2_a1, y)
  2076. def test_rank3(self):
  2077. shape = (4, 3, 2)
  2078. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  2079. b = np.array([1, -1]).astype(self.dt)
  2080. a = np.array([0.5, 0.5]).astype(self.dt)
  2081. # Test last axis
  2082. y = sosfilt(tf2sos(b, a), x)
  2083. for i in range(x.shape[0]):
  2084. for j in range(x.shape[1]):
  2085. assert_array_almost_equal(y[i, j], lfilter(b, a, x[i, j]))
  2086. def test_initial_conditions(self):
  2087. b1, a1 = signal.butter(2, 0.25, 'low')
  2088. b2, a2 = signal.butter(2, 0.75, 'low')
  2089. b3, a3 = signal.butter(2, 0.75, 'low')
  2090. b = np.convolve(np.convolve(b1, b2), b3)
  2091. a = np.convolve(np.convolve(a1, a2), a3)
  2092. sos = np.array((np.r_[b1, a1], np.r_[b2, a2], np.r_[b3, a3]))
  2093. x = np.random.rand(50)
  2094. # Stopping filtering and continuing
  2095. y_true, zi = lfilter(b, a, x[:20], zi=np.zeros(6))
  2096. y_true = np.r_[y_true, lfilter(b, a, x[20:], zi=zi)[0]]
  2097. assert_allclose(y_true, lfilter(b, a, x))
  2098. y_sos, zi = sosfilt(sos, x[:20], zi=np.zeros((3, 2)))
  2099. y_sos = np.r_[y_sos, sosfilt(sos, x[20:], zi=zi)[0]]
  2100. assert_allclose(y_true, y_sos)
  2101. # Use a step function
  2102. zi = sosfilt_zi(sos)
  2103. x = np.ones(8)
  2104. y, zf = sosfilt(sos, x, zi=zi)
  2105. assert_allclose(y, np.ones(8))
  2106. assert_allclose(zf, zi)
  2107. # Initial condition shape matching
  2108. x.shape = (1, 1) + x.shape # 3D
  2109. assert_raises(ValueError, sosfilt, sos, x, zi=zi)
  2110. zi_nd = zi.copy()
  2111. zi_nd.shape = (zi.shape[0], 1, 1, zi.shape[-1])
  2112. assert_raises(ValueError, sosfilt, sos, x,
  2113. zi=zi_nd[:, :, :, [0, 1, 1]])
  2114. y, zf = sosfilt(sos, x, zi=zi_nd)
  2115. assert_allclose(y[0, 0], np.ones(8))
  2116. assert_allclose(zf[:, 0, 0, :], zi)
  2117. def test_initial_conditions_3d_axis1(self):
  2118. # Test the use of zi when sosfilt is applied to axis 1 of a 3-d input.
  2119. # Input array is x.
  2120. x = np.random.RandomState(159).randint(0, 5, size=(2, 15, 3))
  2121. # Design a filter in ZPK format and convert to SOS
  2122. zpk = signal.butter(6, 0.35, output='zpk')
  2123. sos = zpk2sos(*zpk)
  2124. nsections = sos.shape[0]
  2125. # Filter along this axis.
  2126. axis = 1
  2127. # Initial conditions, all zeros.
  2128. shp = list(x.shape)
  2129. shp[axis] = 2
  2130. shp = [nsections] + shp
  2131. z0 = np.zeros(shp)
  2132. # Apply the filter to x.
  2133. yf, zf = sosfilt(sos, x, axis=axis, zi=z0)
  2134. # Apply the filter to x in two stages.
  2135. y1, z1 = sosfilt(sos, x[:, :5, :], axis=axis, zi=z0)
  2136. y2, z2 = sosfilt(sos, x[:, 5:, :], axis=axis, zi=z1)
  2137. # y should equal yf, and z2 should equal zf.
  2138. y = np.concatenate((y1, y2), axis=axis)
  2139. assert_allclose(y, yf, rtol=1e-10, atol=1e-13)
  2140. assert_allclose(z2, zf, rtol=1e-10, atol=1e-13)
  2141. # let's try the "step" initial condition
  2142. zi = sosfilt_zi(sos)
  2143. zi.shape = [nsections, 1, 2, 1]
  2144. zi = zi * x[:, 0:1, :]
  2145. y = sosfilt(sos, x, axis=axis, zi=zi)[0]
  2146. # check it against the TF form
  2147. b, a = zpk2tf(*zpk)
  2148. zi = lfilter_zi(b, a)
  2149. zi.shape = [1, zi.size, 1]
  2150. zi = zi * x[:, 0:1, :]
  2151. y_tf = lfilter(b, a, x, axis=axis, zi=zi)[0]
  2152. assert_allclose(y, y_tf, rtol=1e-10, atol=1e-13)
  2153. def test_bad_zi_shape(self):
  2154. # The shape of zi is checked before using any values in the
  2155. # arguments, so np.empty is fine for creating the arguments.
  2156. x = np.empty((3, 15, 3))
  2157. sos = np.empty((4, 6))
  2158. zi = np.empty((4, 3, 3, 2)) # Correct shape is (4, 3, 2, 3)
  2159. assert_raises(ValueError, sosfilt, sos, x, zi=zi, axis=1)
  2160. def test_sosfilt_zi(self):
  2161. sos = signal.butter(6, 0.2, output='sos')
  2162. zi = sosfilt_zi(sos)
  2163. y, zf = sosfilt(sos, np.ones(40), zi=zi)
  2164. assert_allclose(zf, zi, rtol=1e-13)
  2165. # Expected steady state value of the step response of this filter:
  2166. ss = np.prod(sos[:, :3].sum(axis=-1) / sos[:, 3:].sum(axis=-1))
  2167. assert_allclose(y, ss, rtol=1e-13)
  2168. class TestDeconvolve(object):
  2169. def test_basic(self):
  2170. # From docstring example
  2171. original = [0, 1, 0, 0, 1, 1, 0, 0]
  2172. impulse_response = [2, 1]
  2173. recorded = [0, 2, 1, 0, 2, 3, 1, 0, 0]
  2174. recovered, remainder = signal.deconvolve(recorded, impulse_response)
  2175. assert_allclose(recovered, original)
  2176. class TestUniqueRoots(object):
  2177. def test_real_no_repeat(self):
  2178. p = [-1.0, -0.5, 0.3, 1.2, 10.0]
  2179. unique, multiplicity = unique_roots(p)
  2180. assert_almost_equal(unique, p, decimal=15)
  2181. assert_equal(multiplicity, np.ones(len(p)))
  2182. def test_real_repeat(self):
  2183. p = [-1.0, -0.95, -0.89, -0.8, 0.5, 1.0, 1.05]
  2184. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
  2185. assert_almost_equal(unique, [-1.0, -0.89, 0.5, 1.0], decimal=15)
  2186. assert_equal(multiplicity, [2, 2, 1, 2])
  2187. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
  2188. assert_almost_equal(unique, [-0.95, -0.8, 0.5, 1.05], decimal=15)
  2189. assert_equal(multiplicity, [2, 2, 1, 2])
  2190. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
  2191. assert_almost_equal(unique, [-0.975, -0.845, 0.5, 1.025], decimal=15)
  2192. assert_equal(multiplicity, [2, 2, 1, 2])
  2193. def test_complex_no_repeat(self):
  2194. p = [-1.0, 1.0j, 0.5 + 0.5j, -1.0 - 1.0j, 3.0 + 2.0j]
  2195. unique, multiplicity = unique_roots(p)
  2196. assert_almost_equal(unique, p, decimal=15)
  2197. assert_equal(multiplicity, np.ones(len(p)))
  2198. def test_complex_repeat(self):
  2199. p = [-1.0, -1.0 + 0.05j, -0.95 + 0.15j, -0.90 + 0.15j, 0.0,
  2200. 0.5 + 0.5j, 0.45 + 0.55j]
  2201. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
  2202. assert_almost_equal(unique, [-1.0, -0.95 + 0.15j, 0.0, 0.45 + 0.55j],
  2203. decimal=15)
  2204. assert_equal(multiplicity, [2, 2, 1, 2])
  2205. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
  2206. assert_almost_equal(unique,
  2207. [-1.0 + 0.05j, -0.90 + 0.15j, 0.0, 0.5 + 0.5j],
  2208. decimal=15)
  2209. assert_equal(multiplicity, [2, 2, 1, 2])
  2210. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
  2211. assert_almost_equal(
  2212. unique, [-1.0 + 0.025j, -0.925 + 0.15j, 0.0, 0.475 + 0.525j],
  2213. decimal=15)
  2214. assert_equal(multiplicity, [2, 2, 1, 2])
  2215. def test_gh_4915(self):
  2216. p = np.roots(np.convolve(np.ones(5), np.ones(5)))
  2217. true_roots = [-(-1 + 0j)**(1/5),
  2218. (-1 + 0j)**(4/5),
  2219. -(-1 + 0j)**(3/5),
  2220. (-1 + 0j)**(2/5)]
  2221. unique, multiplicity = unique_roots(p)
  2222. unique = np.sort(unique)
  2223. assert_almost_equal(np.sort(unique), true_roots, decimal=7)
  2224. assert_equal(multiplicity, [2, 2, 2, 2])
  2225. def test_complex_roots_extra(self):
  2226. unique, multiplicity = unique_roots([1.0, 1.0j, 1.0])
  2227. assert_almost_equal(unique, [1.0, 1.0j], decimal=15)
  2228. assert_equal(multiplicity, [2, 1])
  2229. unique, multiplicity = unique_roots([1, 1 + 2e-9, 1e-9 + 1j], tol=0.1)
  2230. assert_almost_equal(unique, [1.0, 1e-9 + 1.0j], decimal=15)
  2231. assert_equal(multiplicity, [2, 1])
  2232. def test_single_unique_root(self):
  2233. p = np.random.rand(100) + 1j * np.random.rand(100)
  2234. unique, multiplicity = unique_roots(p, 2)
  2235. assert_almost_equal(unique, [np.min(p)], decimal=15)
  2236. assert_equal(multiplicity, [100])