test_decomp.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838
"""Test functions for the linalg.decomp module."""
from __future__ import division, print_function, absolute_import

# Build/run instructions kept as a module attribute.
# NOTE(review): the setup_linalg.py step may be outdated -- confirm.
__usage__ = """
Build linalg:
python setup_linalg.py build
Run tests if scipy is installed:
python -c 'import scipy;scipy.linalg.test()'
"""
  10. import itertools
  11. import numpy as np
  12. from numpy.testing import (assert_equal, assert_almost_equal,
  13. assert_array_almost_equal, assert_array_equal,
  14. assert_, assert_allclose)
  15. import pytest
  16. from pytest import raises as assert_raises
  17. from scipy._lib.six import xrange
  18. from scipy.linalg import (eig, eigvals, lu, svd, svdvals, cholesky, qr,
  19. schur, rsf2csf, lu_solve, lu_factor, solve, diagsvd, hessenberg, rq,
  20. eig_banded, eigvals_banded, eigh, eigvalsh, qr_multiply, qz, orth, ordqz,
  21. subspace_angles, hadamard, eigvalsh_tridiagonal, eigh_tridiagonal,
  22. null_space, cdf2rdf)
  23. from scipy.linalg.lapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs, \
  24. dsbev, dsbevd, dsbevx, zhbevd, zhbevx
  25. from scipy.linalg.misc import norm
  26. from scipy.linalg._decomp_qz import _select_function
  27. from numpy import array, transpose, sometrue, diag, ones, linalg, \
  28. argsort, zeros, arange, float32, complex64, dot, conj, identity, \
  29. ravel, sqrt, iscomplex, shape, sort, conjugate, bmat, sign, \
  30. asarray, matrix, isfinite, all, ndarray, outer, eye, dtype, empty,\
  31. triu, tril
  32. from numpy.random import normal, seed, random
  33. from scipy.linalg._testutils import assert_no_overwrite
  34. # digit precision to use in asserts for different types
  35. DIGITS = {'d':11, 'D':11, 'f':4, 'F':4}
  36. def clear_fuss(ar, fuss_binary_bits=7):
  37. """Clears trailing `fuss_binary_bits` of mantissa of a floating number"""
  38. x = np.asanyarray(ar)
  39. if np.iscomplexobj(x):
  40. return clear_fuss(x.real) + 1j * clear_fuss(x.imag)
  41. significant_binary_bits = np.finfo(x.dtype).nmant
  42. x_mant, x_exp = np.frexp(x)
  43. f = 2.0**(significant_binary_bits - fuss_binary_bits)
  44. x_mant *= f
  45. np.rint(x_mant, out=x_mant)
  46. x_mant /= f
  47. return np.ldexp(x_mant, x_exp)
  48. # XXX: This function should be available through numpy.testing
  49. def assert_dtype_equal(act, des):
  50. if isinstance(act, ndarray):
  51. act = act.dtype
  52. else:
  53. act = dtype(act)
  54. if isinstance(des, ndarray):
  55. des = des.dtype
  56. else:
  57. des = dtype(des)
  58. assert_(act == des, 'dtype mismatch: "%s" (should be "%s") ' % (act, des))
  59. # XXX: This function should not be defined here, but somewhere in
  60. # scipy.linalg namespace
  61. def symrand(dim_or_eigv):
  62. """Return a random symmetric (Hermitian) matrix.
  63. If 'dim_or_eigv' is an integer N, return a NxN matrix, with eigenvalues
  64. uniformly distributed on (-1,1).
  65. If 'dim_or_eigv' is 1-D real array 'a', return a matrix whose
  66. eigenvalues are 'a'.
  67. """
  68. if isinstance(dim_or_eigv, int):
  69. dim = dim_or_eigv
  70. d = random(dim)*2 - 1
  71. elif (isinstance(dim_or_eigv, ndarray) and
  72. len(dim_or_eigv.shape) == 1):
  73. dim = dim_or_eigv.shape[0]
  74. d = dim_or_eigv
  75. else:
  76. raise TypeError("input type not supported.")
  77. v = random_rot(dim)
  78. h = dot(dot(v.T.conj(), diag(d)), v)
  79. # to avoid roundoff errors, symmetrize the matrix (again)
  80. h = 0.5*(h.T+h)
  81. return h
  82. # XXX: This function should not be defined here, but somewhere in
  83. # scipy.linalg namespace
  84. def random_rot(dim):
  85. """Return a random rotation matrix, drawn from the Haar distribution
  86. (the only uniform distribution on SO(n)).
  87. The algorithm is described in the paper
  88. Stewart, G.W., 'The efficient generation of random orthogonal
  89. matrices with an application to condition estimators', SIAM Journal
  90. on Numerical Analysis, 17(3), pp. 403-409, 1980.
  91. For more information see
  92. https://en.wikipedia.org/wiki/Orthogonal_matrix#Randomization"""
  93. H = eye(dim)
  94. D = ones((dim,))
  95. for n in range(1, dim):
  96. x = normal(size=(dim-n+1,))
  97. D[n-1] = sign(x[0])
  98. x[0] -= D[n-1]*sqrt((x*x).sum())
  99. # Householder transformation
  100. Hx = eye(dim-n+1) - 2.*outer(x, x)/(x*x).sum()
  101. mat = eye(dim)
  102. mat[n-1:,n-1:] = Hx
  103. H = dot(H, mat)
  104. # Fix the last sign such that the determinant is 1
  105. D[-1] = -D.prod()
  106. H = (D*H.T).T
  107. return H
  108. class TestEigVals(object):
  109. def test_simple(self):
  110. a = [[1,2,3],[1,2,3],[2,5,6]]
  111. w = eigvals(a)
  112. exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
  113. assert_array_almost_equal(w,exact_w)
  114. def test_simple_tr(self):
  115. a = array([[1,2,3],[1,2,3],[2,5,6]],'d')
  116. a = transpose(a).copy()
  117. a = transpose(a)
  118. w = eigvals(a)
  119. exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
  120. assert_array_almost_equal(w,exact_w)
  121. def test_simple_complex(self):
  122. a = [[1,2,3],[1,2,3],[2,5,6+1j]]
  123. w = eigvals(a)
  124. exact_w = [(9+1j+sqrt(92+6j))/2,
  125. 0,
  126. (9+1j-sqrt(92+6j))/2]
  127. assert_array_almost_equal(w,exact_w)
  128. def test_finite(self):
  129. a = [[1,2,3],[1,2,3],[2,5,6]]
  130. w = eigvals(a, check_finite=False)
  131. exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
  132. assert_array_almost_equal(w,exact_w)
  133. class TestEig(object):
  134. def test_simple(self):
  135. a = [[1,2,3],[1,2,3],[2,5,6]]
  136. w,v = eig(a)
  137. exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
  138. v0 = array([1,1,(1+sqrt(93)/3)/2])
  139. v1 = array([3.,0,-1])
  140. v2 = array([1,1,(1-sqrt(93)/3)/2])
  141. v0 = v0 / sqrt(dot(v0,transpose(v0)))
  142. v1 = v1 / sqrt(dot(v1,transpose(v1)))
  143. v2 = v2 / sqrt(dot(v2,transpose(v2)))
  144. assert_array_almost_equal(w,exact_w)
  145. assert_array_almost_equal(v0,v[:,0]*sign(v[0,0]))
  146. assert_array_almost_equal(v1,v[:,1]*sign(v[0,1]))
  147. assert_array_almost_equal(v2,v[:,2]*sign(v[0,2]))
  148. for i in range(3):
  149. assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i])
  150. w,v = eig(a,left=1,right=0)
  151. for i in range(3):
  152. assert_array_almost_equal(dot(transpose(a),v[:,i]),w[i]*v[:,i])
  153. def test_simple_complex_eig(self):
  154. a = [[1,2],[-2,1]]
  155. w,vl,vr = eig(a,left=1,right=1)
  156. assert_array_almost_equal(w, array([1+2j, 1-2j]))
  157. for i in range(2):
  158. assert_array_almost_equal(dot(a,vr[:,i]),w[i]*vr[:,i])
  159. for i in range(2):
  160. assert_array_almost_equal(dot(conjugate(transpose(a)),vl[:,i]),
  161. conjugate(w[i])*vl[:,i])
  162. def test_simple_complex(self):
  163. a = [[1,2,3],[1,2,3],[2,5,6+1j]]
  164. w,vl,vr = eig(a,left=1,right=1)
  165. for i in range(3):
  166. assert_array_almost_equal(dot(a,vr[:,i]),w[i]*vr[:,i])
  167. for i in range(3):
  168. assert_array_almost_equal(dot(conjugate(transpose(a)),vl[:,i]),
  169. conjugate(w[i])*vl[:,i])
  170. def test_gh_3054(self):
  171. a = [[1]]
  172. b = [[0]]
  173. w, vr = eig(a, b, homogeneous_eigvals=True)
  174. assert_allclose(w[1,0], 0)
  175. assert_(w[0,0] != 0)
  176. assert_allclose(vr, 1)
  177. w, vr = eig(a, b)
  178. assert_equal(w, np.inf)
  179. assert_allclose(vr, 1)
  180. def _check_gen_eig(self, A, B):
  181. if B is not None:
  182. A, B = asarray(A), asarray(B)
  183. B0 = B
  184. else:
  185. A = asarray(A)
  186. B0 = B
  187. B = np.eye(*A.shape)
  188. msg = "\n%r\n%r" % (A, B)
  189. # Eigenvalues in homogeneous coordinates
  190. w, vr = eig(A, B0, homogeneous_eigvals=True)
  191. wt = eigvals(A, B0, homogeneous_eigvals=True)
  192. val1 = dot(A, vr) * w[1,:]
  193. val2 = dot(B, vr) * w[0,:]
  194. for i in range(val1.shape[1]):
  195. assert_allclose(val1[:,i], val2[:,i], rtol=1e-13, atol=1e-13, err_msg=msg)
  196. if B0 is None:
  197. assert_allclose(w[1,:], 1)
  198. assert_allclose(wt[1,:], 1)
  199. perm = np.lexsort(w)
  200. permt = np.lexsort(wt)
  201. assert_allclose(w[:,perm], wt[:,permt], atol=1e-7, rtol=1e-7,
  202. err_msg=msg)
  203. length = np.empty(len(vr))
  204. for i in xrange(len(vr)):
  205. length[i] = norm(vr[:,i])
  206. assert_allclose(length, np.ones(length.size), err_msg=msg,
  207. atol=1e-7, rtol=1e-7)
  208. # Convert homogeneous coordinates
  209. beta_nonzero = (w[1,:] != 0)
  210. wh = w[0,beta_nonzero] / w[1,beta_nonzero]
  211. # Eigenvalues in standard coordinates
  212. w, vr = eig(A, B0)
  213. wt = eigvals(A, B0)
  214. val1 = dot(A, vr)
  215. val2 = dot(B, vr) * w
  216. res = val1 - val2
  217. for i in range(res.shape[1]):
  218. if all(isfinite(res[:,i])):
  219. assert_allclose(res[:,i], 0, rtol=1e-13, atol=1e-13, err_msg=msg)
  220. w_fin = w[isfinite(w)]
  221. wt_fin = wt[isfinite(wt)]
  222. perm = argsort(clear_fuss(w_fin))
  223. permt = argsort(clear_fuss(wt_fin))
  224. assert_allclose(w[perm], wt[permt],
  225. atol=1e-7, rtol=1e-7, err_msg=msg)
  226. length = np.empty(len(vr))
  227. for i in xrange(len(vr)):
  228. length[i] = norm(vr[:,i])
  229. assert_allclose(length, np.ones(length.size), err_msg=msg)
  230. # Compare homogeneous and nonhomogeneous versions
  231. assert_allclose(sort(wh), sort(w[np.isfinite(w)]))
  232. @pytest.mark.xfail(reason="See gh-2254.")
  233. def test_singular(self):
  234. # Example taken from
  235. # https://web.archive.org/web/20040903121217/http://www.cs.umu.se/research/nla/singular_pairs/guptri/matlab.html
  236. A = array(([22,34,31,31,17], [45,45,42,19,29], [39,47,49,26,34],
  237. [27,31,26,21,15], [38,44,44,24,30]))
  238. B = array(([13,26,25,17,24], [31,46,40,26,37], [26,40,19,25,25],
  239. [16,25,27,14,23], [24,35,18,21,22]))
  240. olderr = np.seterr(all='ignore')
  241. try:
  242. self._check_gen_eig(A, B)
  243. finally:
  244. np.seterr(**olderr)
  245. def test_falker(self):
  246. # Test matrices giving some Nan generalized eigenvalues.
  247. M = diag(array(([1,0,3])))
  248. K = array(([2,-1,-1],[-1,2,-1],[-1,-1,2]))
  249. D = array(([1,-1,0],[-1,1,0],[0,0,0]))
  250. Z = zeros((3,3))
  251. I3 = identity(3)
  252. A = bmat([[I3, Z], [Z, -K]])
  253. B = bmat([[Z, I3], [M, D]])
  254. olderr = np.seterr(all='ignore')
  255. try:
  256. self._check_gen_eig(A, B)
  257. finally:
  258. np.seterr(**olderr)
  259. def test_bad_geneig(self):
  260. # Ticket #709 (strange return values from DGGEV)
  261. def matrices(omega):
  262. c1 = -9 + omega**2
  263. c2 = 2*omega
  264. A = [[1, 0, 0, 0],
  265. [0, 1, 0, 0],
  266. [0, 0, c1, 0],
  267. [0, 0, 0, c1]]
  268. B = [[0, 0, 1, 0],
  269. [0, 0, 0, 1],
  270. [1, 0, 0, -c2],
  271. [0, 1, c2, 0]]
  272. return A, B
  273. # With a buggy LAPACK, this can fail for different omega on different
  274. # machines -- so we need to test several values
  275. olderr = np.seterr(all='ignore')
  276. try:
  277. for k in xrange(100):
  278. A, B = matrices(omega=k*5./100)
  279. self._check_gen_eig(A, B)
  280. finally:
  281. np.seterr(**olderr)
  282. def test_make_eigvals(self):
  283. # Step through all paths in _make_eigvals
  284. seed(1234)
  285. # Real eigenvalues
  286. A = symrand(3)
  287. self._check_gen_eig(A, None)
  288. B = symrand(3)
  289. self._check_gen_eig(A, B)
  290. # Complex eigenvalues
  291. A = random((3, 3)) + 1j*random((3, 3))
  292. self._check_gen_eig(A, None)
  293. B = random((3, 3)) + 1j*random((3, 3))
  294. self._check_gen_eig(A, B)
  295. def test_check_finite(self):
  296. a = [[1,2,3],[1,2,3],[2,5,6]]
  297. w,v = eig(a, check_finite=False)
  298. exact_w = [(9+sqrt(93))/2,0,(9-sqrt(93))/2]
  299. v0 = array([1,1,(1+sqrt(93)/3)/2])
  300. v1 = array([3.,0,-1])
  301. v2 = array([1,1,(1-sqrt(93)/3)/2])
  302. v0 = v0 / sqrt(dot(v0,transpose(v0)))
  303. v1 = v1 / sqrt(dot(v1,transpose(v1)))
  304. v2 = v2 / sqrt(dot(v2,transpose(v2)))
  305. assert_array_almost_equal(w,exact_w)
  306. assert_array_almost_equal(v0,v[:,0]*sign(v[0,0]))
  307. assert_array_almost_equal(v1,v[:,1]*sign(v[0,1]))
  308. assert_array_almost_equal(v2,v[:,2]*sign(v[0,2]))
  309. for i in range(3):
  310. assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i])
  311. def test_not_square_error(self):
  312. """Check that passing a non-square array raises a ValueError."""
  313. A = np.arange(6).reshape(3,2)
  314. assert_raises(ValueError, eig, A)
  315. def test_shape_mismatch(self):
  316. """Check that passing arrays of with different shapes raises a ValueError."""
  317. A = identity(2)
  318. B = np.arange(9.0).reshape(3,3)
  319. assert_raises(ValueError, eig, A, B)
  320. assert_raises(ValueError, eig, B, A)
  321. class TestEigBanded(object):
  322. def setup_method(self):
  323. self.create_bandmat()
  324. def create_bandmat(self):
  325. """Create the full matrix `self.fullmat` and
  326. the corresponding band matrix `self.bandmat`."""
  327. N = 10
  328. self.KL = 2 # number of subdiagonals (below the diagonal)
  329. self.KU = 2 # number of superdiagonals (above the diagonal)
  330. # symmetric band matrix
  331. self.sym_mat = (diag(1.0*ones(N))
  332. + diag(-1.0*ones(N-1), -1) + diag(-1.0*ones(N-1), 1)
  333. + diag(-2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2))
  334. # hermitian band matrix
  335. self.herm_mat = (diag(-1.0*ones(N))
  336. + 1j*diag(1.0*ones(N-1), -1) - 1j*diag(1.0*ones(N-1), 1)
  337. + diag(-2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2))
  338. # general real band matrix
  339. self.real_mat = (diag(1.0*ones(N))
  340. + diag(-1.0*ones(N-1), -1) + diag(-3.0*ones(N-1), 1)
  341. + diag(2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2))
  342. # general complex band matrix
  343. self.comp_mat = (1j*diag(1.0*ones(N))
  344. + diag(-1.0*ones(N-1), -1) + 1j*diag(-3.0*ones(N-1), 1)
  345. + diag(2.0*ones(N-2), -2) + diag(-2.0*ones(N-2), 2))
  346. # Eigenvalues and -vectors from linalg.eig
  347. ew, ev = linalg.eig(self.sym_mat)
  348. ew = ew.real
  349. args = argsort(ew)
  350. self.w_sym_lin = ew[args]
  351. self.evec_sym_lin = ev[:,args]
  352. ew, ev = linalg.eig(self.herm_mat)
  353. ew = ew.real
  354. args = argsort(ew)
  355. self.w_herm_lin = ew[args]
  356. self.evec_herm_lin = ev[:,args]
  357. # Extract upper bands from symmetric and hermitian band matrices
  358. # (for use in dsbevd, dsbevx, zhbevd, zhbevx
  359. # and their single precision versions)
  360. LDAB = self.KU + 1
  361. self.bandmat_sym = zeros((LDAB, N), dtype=float)
  362. self.bandmat_herm = zeros((LDAB, N), dtype=complex)
  363. for i in xrange(LDAB):
  364. self.bandmat_sym[LDAB-i-1,i:N] = diag(self.sym_mat, i)
  365. self.bandmat_herm[LDAB-i-1,i:N] = diag(self.herm_mat, i)
  366. # Extract bands from general real and complex band matrix
  367. # (for use in dgbtrf, dgbtrs and their single precision versions)
  368. LDAB = 2*self.KL + self.KU + 1
  369. self.bandmat_real = zeros((LDAB, N), dtype=float)
  370. self.bandmat_real[2*self.KL,:] = diag(self.real_mat) # diagonal
  371. for i in xrange(self.KL):
  372. # superdiagonals
  373. self.bandmat_real[2*self.KL-1-i,i+1:N] = diag(self.real_mat, i+1)
  374. # subdiagonals
  375. self.bandmat_real[2*self.KL+1+i,0:N-1-i] = diag(self.real_mat,-i-1)
  376. self.bandmat_comp = zeros((LDAB, N), dtype=complex)
  377. self.bandmat_comp[2*self.KL,:] = diag(self.comp_mat) # diagonal
  378. for i in xrange(self.KL):
  379. # superdiagonals
  380. self.bandmat_comp[2*self.KL-1-i,i+1:N] = diag(self.comp_mat, i+1)
  381. # subdiagonals
  382. self.bandmat_comp[2*self.KL+1+i,0:N-1-i] = diag(self.comp_mat,-i-1)
  383. # absolute value for linear equation system A*x = b
  384. self.b = 1.0*arange(N)
  385. self.bc = self.b * (1 + 1j)
  386. #####################################################################
  387. def test_dsbev(self):
  388. """Compare dsbev eigenvalues and eigenvectors with
  389. the result of linalg.eig."""
  390. w, evec, info = dsbev(self.bandmat_sym, compute_v=1)
  391. evec_ = evec[:,argsort(w)]
  392. assert_array_almost_equal(sort(w), self.w_sym_lin)
  393. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  394. def test_dsbevd(self):
  395. """Compare dsbevd eigenvalues and eigenvectors with
  396. the result of linalg.eig."""
  397. w, evec, info = dsbevd(self.bandmat_sym, compute_v=1)
  398. evec_ = evec[:,argsort(w)]
  399. assert_array_almost_equal(sort(w), self.w_sym_lin)
  400. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  401. def test_dsbevx(self):
  402. """Compare dsbevx eigenvalues and eigenvectors
  403. with the result of linalg.eig."""
  404. N,N = shape(self.sym_mat)
  405. ## Achtung: Argumente 0.0,0.0,range?
  406. w, evec, num, ifail, info = dsbevx(self.bandmat_sym, 0.0, 0.0, 1, N,
  407. compute_v=1, range=2)
  408. evec_ = evec[:,argsort(w)]
  409. assert_array_almost_equal(sort(w), self.w_sym_lin)
  410. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  411. def test_zhbevd(self):
  412. """Compare zhbevd eigenvalues and eigenvectors
  413. with the result of linalg.eig."""
  414. w, evec, info = zhbevd(self.bandmat_herm, compute_v=1)
  415. evec_ = evec[:,argsort(w)]
  416. assert_array_almost_equal(sort(w), self.w_herm_lin)
  417. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  418. def test_zhbevx(self):
  419. """Compare zhbevx eigenvalues and eigenvectors
  420. with the result of linalg.eig."""
  421. N,N = shape(self.herm_mat)
  422. ## Achtung: Argumente 0.0,0.0,range?
  423. w, evec, num, ifail, info = zhbevx(self.bandmat_herm, 0.0, 0.0, 1, N,
  424. compute_v=1, range=2)
  425. evec_ = evec[:,argsort(w)]
  426. assert_array_almost_equal(sort(w), self.w_herm_lin)
  427. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  428. def test_eigvals_banded(self):
  429. """Compare eigenvalues of eigvals_banded with those of linalg.eig."""
  430. w_sym = eigvals_banded(self.bandmat_sym)
  431. w_sym = w_sym.real
  432. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  433. w_herm = eigvals_banded(self.bandmat_herm)
  434. w_herm = w_herm.real
  435. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  436. # extracting eigenvalues with respect to an index range
  437. ind1 = 2
  438. ind2 = np.longlong(6)
  439. w_sym_ind = eigvals_banded(self.bandmat_sym,
  440. select='i', select_range=(ind1, ind2))
  441. assert_array_almost_equal(sort(w_sym_ind),
  442. self.w_sym_lin[ind1:ind2+1])
  443. w_herm_ind = eigvals_banded(self.bandmat_herm,
  444. select='i', select_range=(ind1, ind2))
  445. assert_array_almost_equal(sort(w_herm_ind),
  446. self.w_herm_lin[ind1:ind2+1])
  447. # extracting eigenvalues with respect to a value range
  448. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  449. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  450. w_sym_val = eigvals_banded(self.bandmat_sym,
  451. select='v', select_range=(v_lower, v_upper))
  452. assert_array_almost_equal(sort(w_sym_val),
  453. self.w_sym_lin[ind1:ind2+1])
  454. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  455. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  456. w_herm_val = eigvals_banded(self.bandmat_herm,
  457. select='v', select_range=(v_lower, v_upper))
  458. assert_array_almost_equal(sort(w_herm_val),
  459. self.w_herm_lin[ind1:ind2+1])
  460. w_sym = eigvals_banded(self.bandmat_sym, check_finite=False)
  461. w_sym = w_sym.real
  462. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  463. def test_eig_banded(self):
  464. """Compare eigenvalues and eigenvectors of eig_banded
  465. with those of linalg.eig. """
  466. w_sym, evec_sym = eig_banded(self.bandmat_sym)
  467. evec_sym_ = evec_sym[:,argsort(w_sym.real)]
  468. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  469. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  470. w_herm, evec_herm = eig_banded(self.bandmat_herm)
  471. evec_herm_ = evec_herm[:,argsort(w_herm.real)]
  472. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  473. assert_array_almost_equal(abs(evec_herm_), abs(self.evec_herm_lin))
  474. # extracting eigenvalues with respect to an index range
  475. ind1 = 2
  476. ind2 = 6
  477. w_sym_ind, evec_sym_ind = eig_banded(self.bandmat_sym,
  478. select='i', select_range=(ind1, ind2))
  479. assert_array_almost_equal(sort(w_sym_ind),
  480. self.w_sym_lin[ind1:ind2+1])
  481. assert_array_almost_equal(abs(evec_sym_ind),
  482. abs(self.evec_sym_lin[:,ind1:ind2+1]))
  483. w_herm_ind, evec_herm_ind = eig_banded(self.bandmat_herm,
  484. select='i', select_range=(ind1, ind2))
  485. assert_array_almost_equal(sort(w_herm_ind),
  486. self.w_herm_lin[ind1:ind2+1])
  487. assert_array_almost_equal(abs(evec_herm_ind),
  488. abs(self.evec_herm_lin[:,ind1:ind2+1]))
  489. # extracting eigenvalues with respect to a value range
  490. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  491. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  492. w_sym_val, evec_sym_val = eig_banded(self.bandmat_sym,
  493. select='v', select_range=(v_lower, v_upper))
  494. assert_array_almost_equal(sort(w_sym_val),
  495. self.w_sym_lin[ind1:ind2+1])
  496. assert_array_almost_equal(abs(evec_sym_val),
  497. abs(self.evec_sym_lin[:,ind1:ind2+1]))
  498. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  499. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  500. w_herm_val, evec_herm_val = eig_banded(self.bandmat_herm,
  501. select='v', select_range=(v_lower, v_upper))
  502. assert_array_almost_equal(sort(w_herm_val),
  503. self.w_herm_lin[ind1:ind2+1])
  504. assert_array_almost_equal(abs(evec_herm_val),
  505. abs(self.evec_herm_lin[:,ind1:ind2+1]))
  506. w_sym, evec_sym = eig_banded(self.bandmat_sym, check_finite=False)
  507. evec_sym_ = evec_sym[:,argsort(w_sym.real)]
  508. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  509. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  510. def test_dgbtrf(self):
  511. """Compare dgbtrf LU factorisation with the LU factorisation result
  512. of linalg.lu."""
  513. M,N = shape(self.real_mat)
  514. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  515. # extract matrix u from lu_symm_band
  516. u = diag(lu_symm_band[2*self.KL,:])
  517. for i in xrange(self.KL + self.KU):
  518. u += diag(lu_symm_band[2*self.KL-1-i,i+1:N], i+1)
  519. p_lin, l_lin, u_lin = lu(self.real_mat, permute_l=0)
  520. assert_array_almost_equal(u, u_lin)
  521. def test_zgbtrf(self):
  522. """Compare zgbtrf LU factorisation with the LU factorisation result
  523. of linalg.lu."""
  524. M,N = shape(self.comp_mat)
  525. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  526. # extract matrix u from lu_symm_band
  527. u = diag(lu_symm_band[2*self.KL,:])
  528. for i in xrange(self.KL + self.KU):
  529. u += diag(lu_symm_band[2*self.KL-1-i,i+1:N], i+1)
  530. p_lin, l_lin, u_lin = lu(self.comp_mat, permute_l=0)
  531. assert_array_almost_equal(u, u_lin)
  532. def test_dgbtrs(self):
  533. """Compare dgbtrs solutions for linear equation system A*x = b
  534. with solutions of linalg.solve."""
  535. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  536. y, info = dgbtrs(lu_symm_band, self.KL, self.KU, self.b, ipiv)
  537. y_lin = linalg.solve(self.real_mat, self.b)
  538. assert_array_almost_equal(y, y_lin)
  539. def test_zgbtrs(self):
  540. """Compare zgbtrs solutions for linear equation system A*x = b
  541. with solutions of linalg.solve."""
  542. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  543. y, info = zgbtrs(lu_symm_band, self.KL, self.KU, self.bc, ipiv)
  544. y_lin = linalg.solve(self.comp_mat, self.bc)
  545. assert_array_almost_equal(y, y_lin)
  546. class TestEigTridiagonal(object):
  547. def setup_method(self):
  548. self.create_trimat()
  549. def create_trimat(self):
  550. """Create the full matrix `self.fullmat`, `self.d`, and `self.e`."""
  551. N = 10
  552. # symmetric band matrix
  553. self.d = 1.0*ones(N)
  554. self.e = -1.0*ones(N-1)
  555. self.full_mat = (diag(self.d) + diag(self.e, -1) + diag(self.e, 1))
  556. ew, ev = linalg.eig(self.full_mat)
  557. ew = ew.real
  558. args = argsort(ew)
  559. self.w = ew[args]
  560. self.evec = ev[:, args]
  561. def test_degenerate(self):
  562. """Test error conditions."""
  563. # Wrong sizes
  564. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e[:-1])
  565. # Must be real
  566. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e * 1j)
  567. # Bad driver
  568. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e,
  569. lapack_driver=1.)
  570. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  571. lapack_driver='foo')
  572. # Bad bounds
  573. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  574. select='i', select_range=(0, -1))
  575. def test_eigvalsh_tridiagonal(self):
  576. """Compare eigenvalues of eigvalsh_tridiagonal with those of eig."""
  577. # can't use ?STERF with subselection
  578. for driver in ('sterf', 'stev', 'stebz', 'stemr', 'auto'):
  579. w = eigvalsh_tridiagonal(self.d, self.e, lapack_driver=driver)
  580. assert_array_almost_equal(sort(w), self.w)
  581. for driver in ('sterf', 'stev'):
  582. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  583. lapack_driver='stev', select='i',
  584. select_range=(0, 1))
  585. for driver in ('stebz', 'stemr', 'auto'):
  586. # extracting eigenvalues with respect to the full index range
  587. w_ind = eigvalsh_tridiagonal(
  588. self.d, self.e, select='i', select_range=(0, len(self.d)-1),
  589. lapack_driver=driver)
  590. assert_array_almost_equal(sort(w_ind), self.w)
  591. # extracting eigenvalues with respect to an index range
  592. ind1 = 2
  593. ind2 = 6
  594. w_ind = eigvalsh_tridiagonal(
  595. self.d, self.e, select='i', select_range=(ind1, ind2),
  596. lapack_driver=driver)
  597. assert_array_almost_equal(sort(w_ind), self.w[ind1:ind2+1])
  598. # extracting eigenvalues with respect to a value range
  599. v_lower = self.w[ind1] - 1.0e-5
  600. v_upper = self.w[ind2] + 1.0e-5
  601. w_val = eigvalsh_tridiagonal(
  602. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  603. lapack_driver=driver)
  604. assert_array_almost_equal(sort(w_val), self.w[ind1:ind2+1])
  605. def test_eigh_tridiagonal(self):
  606. """Compare eigenvalues and eigenvectors of eigh_tridiagonal
  607. with those of eig. """
  608. # can't use ?STERF when eigenvectors are requested
  609. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  610. lapack_driver='sterf')
  611. for driver in ('stebz', 'stev', 'stemr', 'auto'):
  612. w, evec = eigh_tridiagonal(self.d, self.e, lapack_driver=driver)
  613. evec_ = evec[:, argsort(w)]
  614. assert_array_almost_equal(sort(w), self.w)
  615. assert_array_almost_equal(abs(evec_), abs(self.evec))
  616. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  617. lapack_driver='stev', select='i', select_range=(0, 1))
  618. for driver in ('stebz', 'stemr', 'auto'):
  619. # extracting eigenvalues with respect to an index range
  620. ind1 = 0
  621. ind2 = len(self.d)-1
  622. w, evec = eigh_tridiagonal(
  623. self.d, self.e, select='i', select_range=(ind1, ind2),
  624. lapack_driver=driver)
  625. assert_array_almost_equal(sort(w), self.w)
  626. assert_array_almost_equal(abs(evec), abs(self.evec))
  627. ind1 = 2
  628. ind2 = 6
  629. w, evec = eigh_tridiagonal(
  630. self.d, self.e, select='i', select_range=(ind1, ind2),
  631. lapack_driver=driver)
  632. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  633. assert_array_almost_equal(abs(evec),
  634. abs(self.evec[:, ind1:ind2+1]))
  635. # extracting eigenvalues with respect to a value range
  636. v_lower = self.w[ind1] - 1.0e-5
  637. v_upper = self.w[ind2] + 1.0e-5
  638. w, evec = eigh_tridiagonal(
  639. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  640. lapack_driver=driver)
  641. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  642. assert_array_almost_equal(abs(evec),
  643. abs(self.evec[:, ind1:ind2+1]))
  644. def test_eigh():
  645. DIM = 6
  646. v = {'dim': (DIM,),
  647. 'dtype': ('f','d','F','D'),
  648. 'overwrite': (True, False),
  649. 'lower': (True, False),
  650. 'turbo': (True, False),
  651. 'eigvals': (None, (2, DIM-2))}
  652. for dim in v['dim']:
  653. for typ in v['dtype']:
  654. for overwrite in v['overwrite']:
  655. for turbo in v['turbo']:
  656. for eigenvalues in v['eigvals']:
  657. for lower in v['lower']:
  658. eigenhproblem_standard(
  659. 'ordinary',
  660. dim, typ, overwrite, lower,
  661. turbo, eigenvalues)
  662. eigenhproblem_general(
  663. 'general ',
  664. dim, typ, overwrite, lower,
  665. turbo, eigenvalues)
  666. def test_eigh_of_sparse():
  667. # This tests the rejection of inputs that eigh cannot currently handle.
  668. import scipy.sparse
  669. a = scipy.sparse.identity(2).tocsc()
  670. b = np.atleast_2d(a)
  671. assert_raises(ValueError, eigh, a)
  672. assert_raises(ValueError, eigh, b)
  673. def _complex_symrand(dim, dtype):
  674. a1, a2 = symrand(dim), symrand(dim)
  675. # add antisymmetric matrix as imag part
  676. a = a1 + 1j*(triu(a2)-tril(a2))
  677. return a.astype(dtype)
  678. def eigenhproblem_standard(desc, dim, dtype,
  679. overwrite, lower, turbo,
  680. eigenvalues):
  681. """Solve a standard eigenvalue problem."""
  682. if iscomplex(empty(1, dtype=dtype)):
  683. a = _complex_symrand(dim, dtype)
  684. else:
  685. a = symrand(dim).astype(dtype)
  686. if overwrite:
  687. a_c = a.copy()
  688. else:
  689. a_c = a
  690. w, z = eigh(a, overwrite_a=overwrite, lower=lower, eigvals=eigenvalues)
  691. assert_dtype_equal(z.dtype, dtype)
  692. w = w.astype(dtype)
  693. diag_ = diag(dot(z.T.conj(), dot(a_c, z))).real
  694. assert_array_almost_equal(diag_, w, DIGITS[dtype])
  695. def eigenhproblem_general(desc, dim, dtype,
  696. overwrite, lower, turbo,
  697. eigenvalues):
  698. """Solve a generalized eigenvalue problem."""
  699. if iscomplex(empty(1, dtype=dtype)):
  700. a = _complex_symrand(dim, dtype)
  701. b = _complex_symrand(dim, dtype)+diag([2.1]*dim).astype(dtype)
  702. else:
  703. a = symrand(dim).astype(dtype)
  704. b = symrand(dim).astype(dtype)+diag([2.1]*dim).astype(dtype)
  705. if overwrite:
  706. a_c, b_c = a.copy(), b.copy()
  707. else:
  708. a_c, b_c = a, b
  709. w, z = eigh(a, b, overwrite_a=overwrite, lower=lower,
  710. overwrite_b=overwrite, turbo=turbo, eigvals=eigenvalues)
  711. assert_dtype_equal(z.dtype, dtype)
  712. w = w.astype(dtype)
  713. diag1_ = diag(dot(z.T.conj(), dot(a_c, z))).real
  714. assert_array_almost_equal(diag1_, w, DIGITS[dtype])
  715. diag2_ = diag(dot(z.T.conj(), dot(b_c, z))).real
  716. assert_array_almost_equal(diag2_, ones(diag2_.shape[0]), DIGITS[dtype])
  717. def test_eigh_integer():
  718. a = array([[1,2],[2,7]])
  719. b = array([[3,1],[1,5]])
  720. w,z = eigh(a)
  721. w,z = eigh(a,b)
  722. class TestLU(object):
  723. def setup_method(self):
  724. self.a = array([[1,2,3],[1,2,3],[2,5,6]])
  725. self.ca = array([[1,2,3],[1,2,3],[2,5j,6]])
  726. # Those matrices are more robust to detect problems in permutation
  727. # matrices than the ones above
  728. self.b = array([[1,2,3],[4,5,6],[7,8,9]])
  729. self.cb = array([[1j,2j,3j],[4j,5j,6j],[7j,8j,9j]])
  730. # Reectangular matrices
  731. self.hrect = array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
  732. self.chrect = 1.j * array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
  733. self.vrect = array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
  734. self.cvrect = 1.j * array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
  735. # Medium sizes matrices
  736. self.med = random((30, 40))
  737. self.cmed = random((30, 40)) + 1.j * random((30, 40))
  738. def _test_common(self, data):
  739. p,l,u = lu(data)
  740. assert_array_almost_equal(dot(dot(p,l),u),data)
  741. pl,u = lu(data,permute_l=1)
  742. assert_array_almost_equal(dot(pl,u),data)
  743. # Simple tests
  744. def test_simple(self):
  745. self._test_common(self.a)
  746. def test_simple_complex(self):
  747. self._test_common(self.ca)
  748. def test_simple2(self):
  749. self._test_common(self.b)
  750. def test_simple2_complex(self):
  751. self._test_common(self.cb)
  752. # rectangular matrices tests
  753. def test_hrectangular(self):
  754. self._test_common(self.hrect)
  755. def test_vrectangular(self):
  756. self._test_common(self.vrect)
  757. def test_hrectangular_complex(self):
  758. self._test_common(self.chrect)
  759. def test_vrectangular_complex(self):
  760. self._test_common(self.cvrect)
  761. # Bigger matrices
  762. def test_medium1(self):
  763. """Check lu decomposition on medium size, rectangular matrix."""
  764. self._test_common(self.med)
  765. def test_medium1_complex(self):
  766. """Check lu decomposition on medium size, rectangular matrix."""
  767. self._test_common(self.cmed)
  768. def test_check_finite(self):
  769. p, l, u = lu(self.a, check_finite=False)
  770. assert_array_almost_equal(dot(dot(p,l),u), self.a)
  771. def test_simple_known(self):
  772. # Ticket #1458
  773. for order in ['C', 'F']:
  774. A = np.array([[2, 1],[0, 1.]], order=order)
  775. LU, P = lu_factor(A)
  776. assert_array_almost_equal(LU, np.array([[2, 1], [0, 1]]))
  777. assert_array_equal(P, np.array([0, 1]))
  778. class TestLUSingle(TestLU):
  779. """LU testers for single precision, real and double"""
  780. def setup_method(self):
  781. TestLU.setup_method(self)
  782. self.a = self.a.astype(float32)
  783. self.ca = self.ca.astype(complex64)
  784. self.b = self.b.astype(float32)
  785. self.cb = self.cb.astype(complex64)
  786. self.hrect = self.hrect.astype(float32)
  787. self.chrect = self.hrect.astype(complex64)
  788. self.vrect = self.vrect.astype(float32)
  789. self.cvrect = self.vrect.astype(complex64)
  790. self.med = self.vrect.astype(float32)
  791. self.cmed = self.vrect.astype(complex64)
  792. class TestLUSolve(object):
  793. def setup_method(self):
  794. seed(1234)
  795. def test_lu(self):
  796. a0 = random((10,10))
  797. b = random((10,))
  798. for order in ['C', 'F']:
  799. a = np.array(a0, order=order)
  800. x1 = solve(a,b)
  801. lu_a = lu_factor(a)
  802. x2 = lu_solve(lu_a,b)
  803. assert_array_almost_equal(x1,x2)
  804. def test_check_finite(self):
  805. a = random((10,10))
  806. b = random((10,))
  807. x1 = solve(a,b)
  808. lu_a = lu_factor(a, check_finite=False)
  809. x2 = lu_solve(lu_a,b, check_finite=False)
  810. assert_array_almost_equal(x1,x2)
  811. class TestSVD_GESDD(object):
  812. def setup_method(self):
  813. self.lapack_driver = 'gesdd'
  814. seed(1234)
  815. def test_degenerate(self):
  816. assert_raises(TypeError, svd, [[1.]], lapack_driver=1.)
  817. assert_raises(ValueError, svd, [[1.]], lapack_driver='foo')
  818. def test_simple(self):
  819. a = [[1,2,3],[1,20,3],[2,5,6]]
  820. for full_matrices in (True, False):
  821. u,s,vh = svd(a, full_matrices=full_matrices,
  822. lapack_driver=self.lapack_driver)
  823. assert_array_almost_equal(dot(transpose(u),u),identity(3))
  824. assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
  825. sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
  826. for i in range(len(s)):
  827. sigma[i,i] = s[i]
  828. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  829. def test_simple_singular(self):
  830. a = [[1,2,3],[1,2,3],[2,5,6]]
  831. for full_matrices in (True, False):
  832. u,s,vh = svd(a, full_matrices=full_matrices,
  833. lapack_driver=self.lapack_driver)
  834. assert_array_almost_equal(dot(transpose(u),u),identity(3))
  835. assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
  836. sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
  837. for i in range(len(s)):
  838. sigma[i,i] = s[i]
  839. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  840. def test_simple_underdet(self):
  841. a = [[1,2,3],[4,5,6]]
  842. for full_matrices in (True, False):
  843. u,s,vh = svd(a, full_matrices=full_matrices,
  844. lapack_driver=self.lapack_driver)
  845. assert_array_almost_equal(dot(transpose(u),u),identity(u.shape[0]))
  846. sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
  847. for i in range(len(s)):
  848. sigma[i,i] = s[i]
  849. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  850. def test_simple_overdet(self):
  851. a = [[1,2],[4,5],[3,4]]
  852. for full_matrices in (True, False):
  853. u,s,vh = svd(a, full_matrices=full_matrices,
  854. lapack_driver=self.lapack_driver)
  855. assert_array_almost_equal(dot(transpose(u),u), identity(u.shape[1]))
  856. assert_array_almost_equal(dot(transpose(vh),vh),identity(2))
  857. sigma = zeros((u.shape[1],vh.shape[0]),s.dtype.char)
  858. for i in range(len(s)):
  859. sigma[i,i] = s[i]
  860. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  861. def test_random(self):
  862. n = 20
  863. m = 15
  864. for i in range(3):
  865. for a in [random([n,m]),random([m,n])]:
  866. for full_matrices in (True, False):
  867. u,s,vh = svd(a, full_matrices=full_matrices,
  868. lapack_driver=self.lapack_driver)
  869. assert_array_almost_equal(dot(transpose(u),u),identity(u.shape[1]))
  870. assert_array_almost_equal(dot(vh, transpose(vh)),identity(vh.shape[0]))
  871. sigma = zeros((u.shape[1],vh.shape[0]),s.dtype.char)
  872. for i in range(len(s)):
  873. sigma[i,i] = s[i]
  874. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  875. def test_simple_complex(self):
  876. a = [[1,2,3],[1,2j,3],[2,5,6]]
  877. for full_matrices in (True, False):
  878. u,s,vh = svd(a, full_matrices=full_matrices,
  879. lapack_driver=self.lapack_driver)
  880. assert_array_almost_equal(dot(conj(transpose(u)),u),identity(u.shape[1]))
  881. assert_array_almost_equal(dot(conj(transpose(vh)),vh),identity(vh.shape[0]))
  882. sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
  883. for i in range(len(s)):
  884. sigma[i,i] = s[i]
  885. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  886. def test_random_complex(self):
  887. n = 20
  888. m = 15
  889. for i in range(3):
  890. for full_matrices in (True, False):
  891. for a in [random([n,m]),random([m,n])]:
  892. a = a + 1j*random(list(a.shape))
  893. u,s,vh = svd(a, full_matrices=full_matrices,
  894. lapack_driver=self.lapack_driver)
  895. assert_array_almost_equal(dot(conj(transpose(u)),u),identity(u.shape[1]))
  896. # This fails when [m,n]
  897. # assert_array_almost_equal(dot(conj(transpose(vh)),vh),identity(len(vh),dtype=vh.dtype.char))
  898. sigma = zeros((u.shape[1],vh.shape[0]),s.dtype.char)
  899. for i in range(len(s)):
  900. sigma[i,i] = s[i]
  901. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  902. def test_crash_1580(self):
  903. sizes = [(13, 23), (30, 50), (60, 100)]
  904. np.random.seed(1234)
  905. for sz in sizes:
  906. for dt in [np.float32, np.float64, np.complex64, np.complex128]:
  907. a = np.random.rand(*sz).astype(dt)
  908. # should not crash
  909. svd(a, lapack_driver=self.lapack_driver)
  910. def test_check_finite(self):
  911. a = [[1,2,3],[1,20,3],[2,5,6]]
  912. u,s,vh = svd(a, check_finite=False, lapack_driver=self.lapack_driver)
  913. assert_array_almost_equal(dot(transpose(u),u),identity(3))
  914. assert_array_almost_equal(dot(transpose(vh),vh),identity(3))
  915. sigma = zeros((u.shape[0],vh.shape[0]),s.dtype.char)
  916. for i in range(len(s)):
  917. sigma[i,i] = s[i]
  918. assert_array_almost_equal(dot(dot(u,sigma),vh),a)
  919. def test_gh_5039(self):
  920. # This is a smoke test for https://github.com/scipy/scipy/issues/5039
  921. #
  922. # The following is reported to raise "ValueError: On entry to DGESDD
  923. # parameter number 12 had an illegal value".
  924. # `interp1d([1,2,3,4], [1,2,3,4], kind='cubic')`
  925. # This is reported to only show up on LAPACK 3.0.3.
  926. #
  927. # The matrix below is taken from the call to
  928. # `B = _fitpack._bsplmat(order, xk)` in interpolate._find_smoothest
  929. b = np.array(
  930. [[0.16666667, 0.66666667, 0.16666667, 0., 0., 0.],
  931. [0., 0.16666667, 0.66666667, 0.16666667, 0., 0.],
  932. [0., 0., 0.16666667, 0.66666667, 0.16666667, 0.],
  933. [0., 0., 0., 0.16666667, 0.66666667, 0.16666667]])
  934. svd(b, lapack_driver=self.lapack_driver)
  935. class TestSVD_GESVD(TestSVD_GESDD):
  936. def setup_method(self):
  937. self.lapack_driver = 'gesvd'
  938. seed(1234)
  939. class TestSVDVals(object):
  940. def test_empty(self):
  941. for a in [[]], np.empty((2, 0)), np.ones((0, 3)):
  942. s = svdvals(a)
  943. assert_equal(s, np.empty(0))
  944. def test_simple(self):
  945. a = [[1,2,3],[1,2,3],[2,5,6]]
  946. s = svdvals(a)
  947. assert_(len(s) == 3)
  948. assert_(s[0] >= s[1] >= s[2])
  949. def test_simple_underdet(self):
  950. a = [[1,2,3],[4,5,6]]
  951. s = svdvals(a)
  952. assert_(len(s) == 2)
  953. assert_(s[0] >= s[1])
  954. def test_simple_overdet(self):
  955. a = [[1,2],[4,5],[3,4]]
  956. s = svdvals(a)
  957. assert_(len(s) == 2)
  958. assert_(s[0] >= s[1])
  959. def test_simple_complex(self):
  960. a = [[1,2,3],[1,20,3j],[2,5,6]]
  961. s = svdvals(a)
  962. assert_(len(s) == 3)
  963. assert_(s[0] >= s[1] >= s[2])
  964. def test_simple_underdet_complex(self):
  965. a = [[1,2,3],[4,5j,6]]
  966. s = svdvals(a)
  967. assert_(len(s) == 2)
  968. assert_(s[0] >= s[1])
  969. def test_simple_overdet_complex(self):
  970. a = [[1,2],[4,5],[3j,4]]
  971. s = svdvals(a)
  972. assert_(len(s) == 2)
  973. assert_(s[0] >= s[1])
  974. def test_check_finite(self):
  975. a = [[1,2,3],[1,2,3],[2,5,6]]
  976. s = svdvals(a, check_finite=False)
  977. assert_(len(s) == 3)
  978. assert_(s[0] >= s[1] >= s[2])
  979. @pytest.mark.slow
  980. def test_crash_2609(self):
  981. np.random.seed(1234)
  982. a = np.random.rand(1500, 2800)
  983. # Shouldn't crash:
  984. svdvals(a)
  985. class TestDiagSVD(object):
  986. def test_simple(self):
  987. assert_array_almost_equal(diagsvd([1,0,0],3,3),[[1,0,0],[0,0,0],[0,0,0]])
  988. class TestQR(object):
  989. def setup_method(self):
  990. seed(1234)
  991. def test_simple(self):
  992. a = [[8,2,3],[2,9,3],[5,3,6]]
  993. q,r = qr(a)
  994. assert_array_almost_equal(dot(transpose(q),q),identity(3))
  995. assert_array_almost_equal(dot(q,r),a)
  996. def test_simple_left(self):
  997. a = [[8,2,3],[2,9,3],[5,3,6]]
  998. q,r = qr(a)
  999. c = [1, 2, 3]
  1000. qc,r2 = qr_multiply(a, c, "left")
  1001. assert_array_almost_equal(dot(q, c), qc)
  1002. assert_array_almost_equal(r, r2)
  1003. qc,r2 = qr_multiply(a, identity(3), "left")
  1004. assert_array_almost_equal(q, qc)
  1005. def test_simple_right(self):
  1006. a = [[8,2,3],[2,9,3],[5,3,6]]
  1007. q,r = qr(a)
  1008. c = [1, 2, 3]
  1009. qc,r2 = qr_multiply(a, c)
  1010. assert_array_almost_equal(dot(c, q), qc)
  1011. assert_array_almost_equal(r, r2)
  1012. qc,r = qr_multiply(a, identity(3))
  1013. assert_array_almost_equal(q, qc)
  1014. def test_simple_pivoting(self):
  1015. a = np.asarray([[8,2,3],[2,9,3],[5,3,6]])
  1016. q,r,p = qr(a, pivoting=True)
  1017. d = abs(diag(r))
  1018. assert_(all(d[1:] <= d[:-1]))
  1019. assert_array_almost_equal(dot(transpose(q),q),identity(3))
  1020. assert_array_almost_equal(dot(q,r),a[:,p])
  1021. q2,r2 = qr(a[:,p])
  1022. assert_array_almost_equal(q,q2)
  1023. assert_array_almost_equal(r,r2)
  1024. def test_simple_left_pivoting(self):
  1025. a = [[8,2,3],[2,9,3],[5,3,6]]
  1026. q,r,jpvt = qr(a, pivoting=True)
  1027. c = [1, 2, 3]
  1028. qc,r,jpvt = qr_multiply(a, c, "left", True)
  1029. assert_array_almost_equal(dot(q, c), qc)
  1030. def test_simple_right_pivoting(self):
  1031. a = [[8,2,3],[2,9,3],[5,3,6]]
  1032. q,r,jpvt = qr(a, pivoting=True)
  1033. c = [1, 2, 3]
  1034. qc,r,jpvt = qr_multiply(a, c, pivoting=True)
  1035. assert_array_almost_equal(dot(c, q), qc)
  1036. def test_simple_trap(self):
  1037. a = [[8,2,3],[2,9,3]]
  1038. q,r = qr(a)
  1039. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1040. assert_array_almost_equal(dot(q,r),a)
  1041. def test_simple_trap_pivoting(self):
  1042. a = np.asarray([[8,2,3],[2,9,3]])
  1043. q,r,p = qr(a, pivoting=True)
  1044. d = abs(diag(r))
  1045. assert_(all(d[1:] <= d[:-1]))
  1046. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1047. assert_array_almost_equal(dot(q,r),a[:,p])
  1048. q2,r2 = qr(a[:,p])
  1049. assert_array_almost_equal(q,q2)
  1050. assert_array_almost_equal(r,r2)
  1051. def test_simple_tall(self):
  1052. # full version
  1053. a = [[8,2],[2,9],[5,3]]
  1054. q,r = qr(a)
  1055. assert_array_almost_equal(dot(transpose(q),q),identity(3))
  1056. assert_array_almost_equal(dot(q,r),a)
  1057. def test_simple_tall_pivoting(self):
  1058. # full version pivoting
  1059. a = np.asarray([[8,2],[2,9],[5,3]])
  1060. q,r,p = qr(a, pivoting=True)
  1061. d = abs(diag(r))
  1062. assert_(all(d[1:] <= d[:-1]))
  1063. assert_array_almost_equal(dot(transpose(q),q),identity(3))
  1064. assert_array_almost_equal(dot(q,r),a[:,p])
  1065. q2,r2 = qr(a[:,p])
  1066. assert_array_almost_equal(q,q2)
  1067. assert_array_almost_equal(r,r2)
  1068. def test_simple_tall_e(self):
  1069. # economy version
  1070. a = [[8,2],[2,9],[5,3]]
  1071. q,r = qr(a, mode='economic')
  1072. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1073. assert_array_almost_equal(dot(q,r),a)
  1074. assert_equal(q.shape, (3,2))
  1075. assert_equal(r.shape, (2,2))
  1076. def test_simple_tall_e_pivoting(self):
  1077. # economy version pivoting
  1078. a = np.asarray([[8,2],[2,9],[5,3]])
  1079. q,r,p = qr(a, pivoting=True, mode='economic')
  1080. d = abs(diag(r))
  1081. assert_(all(d[1:] <= d[:-1]))
  1082. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1083. assert_array_almost_equal(dot(q,r),a[:,p])
  1084. q2,r2 = qr(a[:,p], mode='economic')
  1085. assert_array_almost_equal(q,q2)
  1086. assert_array_almost_equal(r,r2)
  1087. def test_simple_tall_left(self):
  1088. a = [[8,2],[2,9],[5,3]]
  1089. q,r = qr(a, mode="economic")
  1090. c = [1, 2]
  1091. qc,r2 = qr_multiply(a, c, "left")
  1092. assert_array_almost_equal(dot(q, c), qc)
  1093. assert_array_almost_equal(r, r2)
  1094. c = array([1,2,0])
  1095. qc,r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1096. assert_array_almost_equal(dot(q, c[:2]), qc)
  1097. qc,r = qr_multiply(a, identity(2), "left")
  1098. assert_array_almost_equal(qc, q)
  1099. def test_simple_tall_left_pivoting(self):
  1100. a = [[8,2],[2,9],[5,3]]
  1101. q,r,jpvt = qr(a, mode="economic", pivoting=True)
  1102. c = [1, 2]
  1103. qc,r,kpvt = qr_multiply(a, c, "left", True)
  1104. assert_array_equal(jpvt, kpvt)
  1105. assert_array_almost_equal(dot(q, c), qc)
  1106. qc,r,jpvt = qr_multiply(a, identity(2), "left", True)
  1107. assert_array_almost_equal(qc, q)
  1108. def test_simple_tall_right(self):
  1109. a = [[8,2],[2,9],[5,3]]
  1110. q,r = qr(a, mode="economic")
  1111. c = [1, 2, 3]
  1112. cq,r2 = qr_multiply(a, c)
  1113. assert_array_almost_equal(dot(c, q), cq)
  1114. assert_array_almost_equal(r, r2)
  1115. cq,r = qr_multiply(a, identity(3))
  1116. assert_array_almost_equal(cq, q)
  1117. def test_simple_tall_right_pivoting(self):
  1118. a = [[8,2],[2,9],[5,3]]
  1119. q,r,jpvt = qr(a, pivoting=True, mode="economic")
  1120. c = [1, 2, 3]
  1121. cq,r,jpvt = qr_multiply(a, c, pivoting=True)
  1122. assert_array_almost_equal(dot(c, q), cq)
  1123. cq,r,jpvt = qr_multiply(a, identity(3), pivoting=True)
  1124. assert_array_almost_equal(cq, q)
  1125. def test_simple_fat(self):
  1126. # full version
  1127. a = [[8,2,5],[2,9,3]]
  1128. q,r = qr(a)
  1129. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1130. assert_array_almost_equal(dot(q,r),a)
  1131. assert_equal(q.shape, (2,2))
  1132. assert_equal(r.shape, (2,3))
  1133. def test_simple_fat_pivoting(self):
  1134. # full version pivoting
  1135. a = np.asarray([[8,2,5],[2,9,3]])
  1136. q,r,p = qr(a, pivoting=True)
  1137. d = abs(diag(r))
  1138. assert_(all(d[1:] <= d[:-1]))
  1139. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1140. assert_array_almost_equal(dot(q,r),a[:,p])
  1141. assert_equal(q.shape, (2,2))
  1142. assert_equal(r.shape, (2,3))
  1143. q2,r2 = qr(a[:,p])
  1144. assert_array_almost_equal(q,q2)
  1145. assert_array_almost_equal(r,r2)
  1146. def test_simple_fat_e(self):
  1147. # economy version
  1148. a = [[8,2,3],[2,9,5]]
  1149. q,r = qr(a, mode='economic')
  1150. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1151. assert_array_almost_equal(dot(q,r),a)
  1152. assert_equal(q.shape, (2,2))
  1153. assert_equal(r.shape, (2,3))
  1154. def test_simple_fat_e_pivoting(self):
  1155. # economy version pivoting
  1156. a = np.asarray([[8,2,3],[2,9,5]])
  1157. q,r,p = qr(a, pivoting=True, mode='economic')
  1158. d = abs(diag(r))
  1159. assert_(all(d[1:] <= d[:-1]))
  1160. assert_array_almost_equal(dot(transpose(q),q),identity(2))
  1161. assert_array_almost_equal(dot(q,r),a[:,p])
  1162. assert_equal(q.shape, (2,2))
  1163. assert_equal(r.shape, (2,3))
  1164. q2,r2 = qr(a[:,p], mode='economic')
  1165. assert_array_almost_equal(q,q2)
  1166. assert_array_almost_equal(r,r2)
  1167. def test_simple_fat_left(self):
  1168. a = [[8,2,3],[2,9,5]]
  1169. q,r = qr(a, mode="economic")
  1170. c = [1, 2]
  1171. qc,r2 = qr_multiply(a, c, "left")
  1172. assert_array_almost_equal(dot(q, c), qc)
  1173. assert_array_almost_equal(r, r2)
  1174. qc,r = qr_multiply(a, identity(2), "left")
  1175. assert_array_almost_equal(qc, q)
  1176. def test_simple_fat_left_pivoting(self):
  1177. a = [[8,2,3],[2,9,5]]
  1178. q,r,jpvt = qr(a, mode="economic", pivoting=True)
  1179. c = [1, 2]
  1180. qc,r,jpvt = qr_multiply(a, c, "left", True)
  1181. assert_array_almost_equal(dot(q, c), qc)
  1182. qc,r,jpvt = qr_multiply(a, identity(2), "left", True)
  1183. assert_array_almost_equal(qc, q)
  1184. def test_simple_fat_right(self):
  1185. a = [[8,2,3],[2,9,5]]
  1186. q,r = qr(a, mode="economic")
  1187. c = [1, 2]
  1188. cq,r2 = qr_multiply(a, c)
  1189. assert_array_almost_equal(dot(c, q), cq)
  1190. assert_array_almost_equal(r, r2)
  1191. cq,r = qr_multiply(a, identity(2))
  1192. assert_array_almost_equal(cq, q)
  1193. def test_simple_fat_right_pivoting(self):
  1194. a = [[8,2,3],[2,9,5]]
  1195. q,r,jpvt = qr(a, pivoting=True, mode="economic")
  1196. c = [1, 2]
  1197. cq,r,jpvt = qr_multiply(a, c, pivoting=True)
  1198. assert_array_almost_equal(dot(c, q), cq)
  1199. cq,r,jpvt = qr_multiply(a, identity(2), pivoting=True)
  1200. assert_array_almost_equal(cq, q)
  1201. def test_simple_complex(self):
  1202. a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
  1203. q,r = qr(a)
  1204. assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
  1205. assert_array_almost_equal(dot(q,r),a)
  1206. def test_simple_complex_left(self):
  1207. a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
  1208. q,r = qr(a)
  1209. c = [1, 2, 3+4j]
  1210. qc,r = qr_multiply(a, c, "left")
  1211. assert_array_almost_equal(dot(q, c), qc)
  1212. qc,r = qr_multiply(a, identity(3), "left")
  1213. assert_array_almost_equal(q, qc)
  1214. def test_simple_complex_right(self):
  1215. a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
  1216. q,r = qr(a)
  1217. c = [1, 2, 3+4j]
  1218. qc,r = qr_multiply(a, c)
  1219. assert_array_almost_equal(dot(c, q), qc)
  1220. qc,r = qr_multiply(a, identity(3))
  1221. assert_array_almost_equal(q, qc)
  1222. def test_simple_tall_complex_left(self):
  1223. a = [[8,2+3j],[2,9],[5+7j,3]]
  1224. q,r = qr(a, mode="economic")
  1225. c = [1, 2+2j]
  1226. qc,r2 = qr_multiply(a, c, "left")
  1227. assert_array_almost_equal(dot(q, c), qc)
  1228. assert_array_almost_equal(r, r2)
  1229. c = array([1,2,0])
  1230. qc,r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1231. assert_array_almost_equal(dot(q, c[:2]), qc)
  1232. qc,r = qr_multiply(a, identity(2), "left")
  1233. assert_array_almost_equal(qc, q)
  1234. def test_simple_complex_left_conjugate(self):
  1235. a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
  1236. q,r = qr(a)
  1237. c = [1, 2, 3+4j]
  1238. qc,r = qr_multiply(a, c, "left", conjugate=True)
  1239. assert_array_almost_equal(dot(q.conjugate(), c), qc)
  1240. def test_simple_complex_tall_left_conjugate(self):
  1241. a = [[3,3+4j],[5,2+2j],[3,2]]
  1242. q,r = qr(a, mode='economic')
  1243. c = [1, 3+4j]
  1244. qc,r = qr_multiply(a, c, "left", conjugate=True)
  1245. assert_array_almost_equal(dot(q.conjugate(), c), qc)
  1246. def test_simple_complex_right_conjugate(self):
  1247. a = [[3,3+4j,5],[5,2,2+7j],[3,2,7]]
  1248. q,r = qr(a)
  1249. c = [1, 2, 3+4j]
  1250. qc,r = qr_multiply(a, c, conjugate=True)
  1251. assert_array_almost_equal(dot(c, q.conjugate()), qc)
  1252. def test_simple_complex_pivoting(self):
  1253. a = np.asarray([[3,3+4j,5],[5,2,2+7j],[3,2,7]])
  1254. q,r,p = qr(a, pivoting=True)
  1255. d = abs(diag(r))
  1256. assert_(all(d[1:] <= d[:-1]))
  1257. assert_array_almost_equal(dot(conj(transpose(q)),q),identity(3))
  1258. assert_array_almost_equal(dot(q,r),a[:,p])
  1259. q2,r2 = qr(a[:,p])
  1260. assert_array_almost_equal(q,q2)
  1261. assert_array_almost_equal(r,r2)
  1262. def test_simple_complex_left_pivoting(self):
  1263. a = np.asarray([[3,3+4j,5],[5,2,2+7j],[3,2,7]])
  1264. q,r,jpvt = qr(a, pivoting=True)
  1265. c = [1, 2, 3+4j]
  1266. qc,r,jpvt = qr_multiply(a, c, "left", True)
  1267. assert_array_almost_equal(dot(q, c), qc)
  1268. def test_simple_complex_right_pivoting(self):
  1269. a = np.asarray([[3,3+4j,5],[5,2,2+7j],[3,2,7]])
  1270. q,r,jpvt = qr(a, pivoting=True)
  1271. c = [1, 2, 3+4j]
  1272. qc,r,jpvt = qr_multiply(a, c, pivoting=True)
  1273. assert_array_almost_equal(dot(c, q), qc)
  1274. def test_random(self):
  1275. n = 20
  1276. for k in range(2):
  1277. a = random([n,n])
  1278. q,r = qr(a)
  1279. assert_array_almost_equal(dot(transpose(q),q),identity(n))
  1280. assert_array_almost_equal(dot(q,r),a)
  1281. def test_random_left(self):
  1282. n = 20
  1283. for k in range(2):
  1284. a = random([n,n])
  1285. q,r = qr(a)
  1286. c = random([n])
  1287. qc,r = qr_multiply(a, c, "left")
  1288. assert_array_almost_equal(dot(q, c), qc)
  1289. qc,r = qr_multiply(a, identity(n), "left")
  1290. assert_array_almost_equal(q, qc)
  1291. def test_random_right(self):
  1292. n = 20
  1293. for k in range(2):
  1294. a = random([n,n])
  1295. q,r = qr(a)
  1296. c = random([n])
  1297. cq,r = qr_multiply(a, c)
  1298. assert_array_almost_equal(dot(c, q), cq)
  1299. cq,r = qr_multiply(a, identity(n))
  1300. assert_array_almost_equal(q, cq)
  1301. def test_random_pivoting(self):
  1302. n = 20
  1303. for k in range(2):
  1304. a = random([n,n])
  1305. q,r,p = qr(a, pivoting=True)
  1306. d = abs(diag(r))
  1307. assert_(all(d[1:] <= d[:-1]))
  1308. assert_array_almost_equal(dot(transpose(q),q),identity(n))
  1309. assert_array_almost_equal(dot(q,r),a[:,p])
  1310. q2,r2 = qr(a[:,p])
  1311. assert_array_almost_equal(q,q2)
  1312. assert_array_almost_equal(r,r2)
  1313. def test_random_tall(self):
  1314. # full version
  1315. m = 200
  1316. n = 100
  1317. for k in range(2):
  1318. a = random([m,n])
  1319. q,r = qr(a)
  1320. assert_array_almost_equal(dot(transpose(q),q),identity(m))
  1321. assert_array_almost_equal(dot(q,r),a)
  1322. def test_random_tall_left(self):
  1323. # full version
  1324. m = 200
  1325. n = 100
  1326. for k in range(2):
  1327. a = random([m,n])
  1328. q,r = qr(a, mode="economic")
  1329. c = random([n])
  1330. qc,r = qr_multiply(a, c, "left")
  1331. assert_array_almost_equal(dot(q, c), qc)
  1332. qc,r = qr_multiply(a, identity(n), "left")
  1333. assert_array_almost_equal(qc, q)
  1334. def test_random_tall_right(self):
  1335. # full version
  1336. m = 200
  1337. n = 100
  1338. for k in range(2):
  1339. a = random([m,n])
  1340. q,r = qr(a, mode="economic")
  1341. c = random([m])
  1342. cq,r = qr_multiply(a, c)
  1343. assert_array_almost_equal(dot(c, q), cq)
  1344. cq,r = qr_multiply(a, identity(m))
  1345. assert_array_almost_equal(cq, q)
  1346. def test_random_tall_pivoting(self):
  1347. # full version pivoting
  1348. m = 200
  1349. n = 100
  1350. for k in range(2):
  1351. a = random([m,n])
  1352. q,r,p = qr(a, pivoting=True)
  1353. d = abs(diag(r))
  1354. assert_(all(d[1:] <= d[:-1]))
  1355. assert_array_almost_equal(dot(transpose(q),q),identity(m))
  1356. assert_array_almost_equal(dot(q,r),a[:,p])
  1357. q2,r2 = qr(a[:,p])
  1358. assert_array_almost_equal(q,q2)
  1359. assert_array_almost_equal(r,r2)
  1360. def test_random_tall_e(self):
  1361. # economy version
  1362. m = 200
  1363. n = 100
  1364. for k in range(2):
  1365. a = random([m,n])
  1366. q,r = qr(a, mode='economic')
  1367. assert_array_almost_equal(dot(transpose(q),q),identity(n))
  1368. assert_array_almost_equal(dot(q,r),a)
  1369. assert_equal(q.shape, (m,n))
  1370. assert_equal(r.shape, (n,n))
  1371. def test_random_tall_e_pivoting(self):
  1372. # economy version pivoting
  1373. m = 200
  1374. n = 100
  1375. for k in range(2):
  1376. a = random([m,n])
  1377. q,r,p = qr(a, pivoting=True, mode='economic')
  1378. d = abs(diag(r))
  1379. assert_(all(d[1:] <= d[:-1]))
  1380. assert_array_almost_equal(dot(transpose(q),q),identity(n))
  1381. assert_array_almost_equal(dot(q,r),a[:,p])
  1382. assert_equal(q.shape, (m,n))
  1383. assert_equal(r.shape, (n,n))
  1384. q2,r2 = qr(a[:,p], mode='economic')
  1385. assert_array_almost_equal(q,q2)
  1386. assert_array_almost_equal(r,r2)
  1387. def test_random_trap(self):
  1388. m = 100
  1389. n = 200
  1390. for k in range(2):
  1391. a = random([m,n])
  1392. q,r = qr(a)
  1393. assert_array_almost_equal(dot(transpose(q),q),identity(m))
  1394. assert_array_almost_equal(dot(q,r),a)
  1395. def test_random_trap_pivoting(self):
  1396. m = 100
  1397. n = 200
  1398. for k in range(2):
  1399. a = random([m,n])
  1400. q,r,p = qr(a, pivoting=True)
  1401. d = abs(diag(r))
  1402. assert_(all(d[1:] <= d[:-1]))
  1403. assert_array_almost_equal(dot(transpose(q),q),identity(m))
  1404. assert_array_almost_equal(dot(q,r),a[:,p])
  1405. q2,r2 = qr(a[:,p])
  1406. assert_array_almost_equal(q,q2)
  1407. assert_array_almost_equal(r,r2)
  1408. def test_random_complex(self):
  1409. n = 20
  1410. for k in range(2):
  1411. a = random([n,n])+1j*random([n,n])
  1412. q,r = qr(a)
  1413. assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
  1414. assert_array_almost_equal(dot(q,r),a)
  1415. def test_random_complex_left(self):
  1416. n = 20
  1417. for k in range(2):
  1418. a = random([n,n])+1j*random([n,n])
  1419. q,r = qr(a)
  1420. c = random([n])+1j*random([n])
  1421. qc,r = qr_multiply(a, c, "left")
  1422. assert_array_almost_equal(dot(q, c), qc)
  1423. qc,r = qr_multiply(a, identity(n), "left")
  1424. assert_array_almost_equal(q, qc)
  1425. def test_random_complex_right(self):
  1426. n = 20
  1427. for k in range(2):
  1428. a = random([n,n])+1j*random([n,n])
  1429. q,r = qr(a)
  1430. c = random([n])+1j*random([n])
  1431. cq,r = qr_multiply(a, c)
  1432. assert_array_almost_equal(dot(c, q), cq)
  1433. cq,r = qr_multiply(a, identity(n))
  1434. assert_array_almost_equal(q, cq)
  1435. def test_random_complex_pivoting(self):
  1436. n = 20
  1437. for k in range(2):
  1438. a = random([n,n])+1j*random([n,n])
  1439. q,r,p = qr(a, pivoting=True)
  1440. d = abs(diag(r))
  1441. assert_(all(d[1:] <= d[:-1]))
  1442. assert_array_almost_equal(dot(conj(transpose(q)),q),identity(n))
  1443. assert_array_almost_equal(dot(q,r),a[:,p])
  1444. q2,r2 = qr(a[:,p])
  1445. assert_array_almost_equal(q,q2)
  1446. assert_array_almost_equal(r,r2)
  1447. def test_check_finite(self):
  1448. a = [[8,2,3],[2,9,3],[5,3,6]]
  1449. q,r = qr(a, check_finite=False)
  1450. assert_array_almost_equal(dot(transpose(q),q),identity(3))
  1451. assert_array_almost_equal(dot(q,r),a)
  1452. def test_lwork(self):
  1453. a = [[8,2,3],[2,9,3],[5,3,6]]
  1454. # Get comparison values
  1455. q,r = qr(a, lwork=None)
  1456. # Test against minimum valid lwork
  1457. q2,r2 = qr(a, lwork=3)
  1458. assert_array_almost_equal(q2,q)
  1459. assert_array_almost_equal(r2,r)
  1460. # Test against larger lwork
  1461. q3,r3 = qr(a, lwork=10)
  1462. assert_array_almost_equal(q3,q)
  1463. assert_array_almost_equal(r3,r)
  1464. # Test against explicit lwork=-1
  1465. q4,r4 = qr(a, lwork=-1)
  1466. assert_array_almost_equal(q4,q)
  1467. assert_array_almost_equal(r4,r)
  1468. # Test against invalid lwork
  1469. assert_raises(Exception, qr, (a,), {'lwork':0})
  1470. assert_raises(Exception, qr, (a,), {'lwork':2})
  class TestRQ(object):
      """Tests for scipy.linalg.rq, which factors A as R Q."""

      def setup_method(self):
          seed(1234)

      def test_simple(self):
          mat = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
          r_fac, q_fac = rq(mat)
          assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(3))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_r(self):
          mat = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
          r_full, _ = rq(mat)
          # R from mode='r' must match R from the full factorization
          assert_array_almost_equal(r_full, rq(mat, mode='r'))

      def test_random(self):
          size = 20
          for _ in range(2):
              mat = random([size, size])
              r_fac, q_fac = rq(mat)
              assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(size))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_simple_trap(self):
          mat = [[8, 2, 3], [2, 9, 3]]
          r_fac, q_fac = rq(mat)
          assert_array_almost_equal(dot(transpose(q_fac), q_fac), identity(3))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_simple_tall(self):
          mat = [[8, 2], [2, 9], [5, 3]]
          r_fac, q_fac = rq(mat)
          assert_array_almost_equal(dot(transpose(q_fac), q_fac), identity(2))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_simple_fat(self):
          mat = [[8, 2, 5], [2, 9, 3]]
          r_fac, q_fac = rq(mat)
          assert_array_almost_equal(dot(transpose(q_fac), q_fac), identity(3))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_simple_complex(self):
          mat = [[3, 3 + 4j, 5], [5, 2, 2 + 7j], [3, 2, 7]]
          r_fac, q_fac = rq(mat)
          assert_array_almost_equal(dot(q_fac, conj(transpose(q_fac))), identity(3))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_random_tall(self):
          rows, cols = 200, 100
          for _ in range(2):
              mat = random([rows, cols])
              r_fac, q_fac = rq(mat)
              assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(cols))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_random_trap(self):
          rows, cols = 100, 200
          for _ in range(2):
              mat = random([rows, cols])
              r_fac, q_fac = rq(mat)
              assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(cols))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_random_trap_economic(self):
          rows, cols = 100, 200
          for _ in range(2):
              mat = random([rows, cols])
              r_fac, q_fac = rq(mat, mode='economic')
              assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(rows))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)
              assert_equal(q_fac.shape, (rows, cols))
              assert_equal(r_fac.shape, (rows, rows))

      def test_random_complex(self):
          size = 20
          for _ in range(2):
              mat = random([size, size]) + 1j * random([size, size])
              r_fac, q_fac = rq(mat)
              assert_array_almost_equal(dot(q_fac, conj(transpose(q_fac))), identity(size))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)

      def test_random_complex_economic(self):
          rows, cols = 100, 200
          for _ in range(2):
              mat = random([rows, cols]) + 1j * random([rows, cols])
              r_fac, q_fac = rq(mat, mode='economic')
              assert_array_almost_equal(dot(q_fac, conj(transpose(q_fac))), identity(rows))
              assert_array_almost_equal(dot(r_fac, q_fac), mat)
              assert_equal(q_fac.shape, (rows, cols))
              assert_equal(r_fac.shape, (rows, rows))

      def test_check_finite(self):
          mat = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
          r_fac, q_fac = rq(mat, check_finite=False)
          assert_array_almost_equal(dot(q_fac, transpose(q_fac)), identity(3))
          assert_array_almost_equal(dot(r_fac, q_fac), mat)
  # Short aliases used by the Schur/Hessenberg tests below.
  transp = transpose
  # ``numpy.sometrue`` is a deprecated alias of ``numpy.any``; bind the
  # canonical function directly.  NOTE: this module-level name shadows the
  # builtin ``any`` for the rest of the module.
  any = np.any
  class TestSchur(object):
      """Tests for scipy.linalg.schur and rsf2csf."""

      def test_simple(self):
          a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]

          def reconstruct(z, t):
              # Z T Z^H must reproduce the input matrix
              return dot(dot(z, t), conj(z).T)

          t, z = schur(a)
          assert_array_almost_equal(reconstruct(z, t), a)
          tc, zc = schur(a, 'complex')
          # the complex Schur form must actually carry complex entries
          assert_(iscomplex(zc).any() and iscomplex(tc).any())
          assert_array_almost_equal(reconstruct(zc, tc), a)
          tc2, zc2 = rsf2csf(tc, zc)
          assert_array_almost_equal(reconstruct(zc2, tc2), a)

      def test_sort(self):
          a = [[4., 3., 1., -1.], [-4.5, -3.5, -1., 1.], [9., 6., -4., 4.5],
               [6., 4., -3., 3.5]]
          rhp_u = [[0.4862, -0.4930, 0.1434, -0.7071],
                   [-0.4862, 0.4930, -0.1434, -0.7071],
                   [0.6042, 0.3944, -0.6924, 0.],
                   [0.4028, 0.5986, 0.6924, 0.]]
          rhp_s = [[1.4142, -0.9270, 4.5368, -14.4130],
                   [0., 0.5, 6.5809, -3.1870],
                   [0., 0., -1.4142, 0.9270],
                   [0., 0., 0., -0.5]]
          # (sort criterion, expected U, expected S); every case selects 2
          # eigenvalues.  The callable x >= 0 is expected to reproduce 'rhp'.
          cases = [
              ('lhp',
               [[0.1134, 0.5436, 0.8316, 0.],
                [-0.1134, -0.8245, 0.5544, 0.],
                [-0.8213, 0.1308, 0.0265, -0.5547],
                [-0.5475, 0.0872, 0.0177, 0.8321]],
               [[-1.4142, 0.1456, -11.5816, -7.7174],
                [0., -0.5000, 9.4472, -0.7184],
                [0., 0., 1.4142, -0.1456],
                [0., 0., 0., 0.5]]),
              ('rhp', rhp_u, rhp_s),
              ('iuc',
               [[0.5547, 0., -0.5721, -0.6042],
                [-0.8321, 0., -0.3814, -0.4028],
                [0., 0.7071, -0.5134, 0.4862],
                [0., 0.7071, 0.5134, -0.4862]],
               [[-0.5000, 0.0000, -6.5809, -4.0974],
                [0., 0.5000, -3.3191, -14.4130],
                [0., 0., 1.4142, 2.1573],
                [0., 0., 0., -1.4142]]),
              ('ouc',
               [[0.4862, -0.5134, 0.7071, 0.],
                [-0.4862, 0.5134, 0.7071, 0.],
                [0.6042, 0.5721, 0., -0.5547],
                [0.4028, 0.3814, 0., 0.8321]],
               [[1.4142, -2.1573, 14.4130, 4.0974],
                [0., -1.4142, 3.3191, 6.5809],
                [0., 0., -0.5000, 0.],
                [0., 0., 0., 0.5000]]),
              (lambda x: x >= 0.0, rhp_u, rhp_s),
          ]
          for criterion, expected_u, expected_s in cases:
              s, u, sdim = schur(a, sort=criterion)
              assert_array_almost_equal(expected_u, u, 3)
              assert_array_almost_equal(expected_s, s, 3)
              assert_equal(2, sdim)

      def test_sort_errors(self):
          a = [[4., 3., 1., -1.], [-4.5, -3.5, -1., 1.], [9., 6., -4., 4.5],
               [6., 4., -3., 3.5]]
          # unknown string criteria and non-callable objects are rejected
          assert_raises(ValueError, schur, a, sort='unsupported')
          assert_raises(ValueError, schur, a, sort=1)

      def test_check_finite(self):
          a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]
          t, z = schur(a, check_finite=False)
          assert_array_almost_equal(dot(dot(z, t), conj(z).T), a)
  class TestHessenberg(object):
      """Tests for scipy.linalg.hessenberg: Q^H A Q must reproduce H."""

      def test_simple(self):
          a = [[-149, -50, -154],
               [537, 180, 546],
               [-27, -9, -25]]
          h_ref = [[-149.0000, 42.2037, -156.3165],
                   [-537.6783, 152.5511, -554.9272],
                   [0, 0.0728, 2.4489]]
          h, q = hessenberg(a, calc_q=1)
          assert_array_almost_equal(dot(q.T, dot(a, q)), h)
          assert_array_almost_equal(h, h_ref, decimal=4)

      def test_simple_complex(self):
          a = [[-149, -50, -154],
               [537, 180j, 546],
               [-27j, -9, -25]]
          h, q = hessenberg(a, calc_q=1)
          similar = dot(conj(q).T, dot(a, q))
          assert_array_almost_equal(similar, h)

      def test_simple2(self):
          a = [[1, 2, 3, 4, 5, 6, 7],
               [0, 2, 3, 4, 6, 7, 2],
               [0, 2, 2, 3, 0, 3, 2],
               [0, 0, 2, 8, 0, 0, 2],
               [0, 3, 1, 2, 0, 1, 2],
               [0, 1, 2, 3, 0, 1, 0],
               [0, 0, 0, 0, 0, 1, 2]]
          h, q = hessenberg(a, calc_q=1)
          assert_array_almost_equal(dot(q.T, dot(a, q)), h)

      def test_simple3(self):
          a = np.eye(3)
          a[-1, 0] = 2
          h, q = hessenberg(a, calc_q=1)
          assert_array_almost_equal(dot(q.T, dot(a, q)), h)

      def test_random(self):
          size = 20
          for _ in range(2):
              mat = random([size, size])
              h, q = hessenberg(mat, calc_q=1)
              assert_array_almost_equal(dot(q.T, dot(mat, q)), h)

      def test_random_complex(self):
          size = 20
          for _ in range(2):
              mat = random([size, size]) + 1j * random([size, size])
              h, q = hessenberg(mat, calc_q=1)
              similar = dot(conj(q).T, dot(mat, q))
              assert_array_almost_equal(similar, h)

      def test_check_finite(self):
          a = [[-149, -50, -154],
               [537, 180, 546],
               [-27, -9, -25]]
          h_ref = [[-149.0000, 42.2037, -156.3165],
                   [-537.6783, 152.5511, -554.9272],
                   [0, 0.0728, 2.4489]]
          h, q = hessenberg(a, calc_q=1, check_finite=False)
          assert_array_almost_equal(dot(q.T, dot(a, q)), h)
          assert_array_almost_equal(h, h_ref, decimal=4)

      def test_2x2(self):
          # a 2x2 matrix is already Hessenberg, so Q is expected to be I
          real_mat = [[2, 1], [7, 12]]
          h, q = hessenberg(real_mat, calc_q=1)
          assert_array_almost_equal(q, np.eye(2))
          assert_array_almost_equal(h, real_mat)
          cplx_mat = [[2 - 7j, 1 + 2j], [7 + 3j, 12 - 2j]]
          h2, q2 = hessenberg(cplx_mat, calc_q=1)
          assert_array_almost_equal(q2, np.eye(2))
          assert_array_almost_equal(h2, cplx_mat)
  class TestQZ(object):
      """Tests for scipy.linalg.qz: Q AA Z^H == A, Q BB Z^H == B, unitary Q, Z."""

      def setup_method(self):
          seed(12345)

      def test_qz_single(self):
          n = 5
          A = random([n, n]).astype(float32)
          B = random([n, n]).astype(float32)
          AA, BB, Q, Z = qz(A, B)
          assert_array_almost_equal(dot(dot(Q, AA), Z.T), A, decimal=5)
          assert_array_almost_equal(dot(dot(Q, BB), Z.T), B, decimal=5)
          assert_array_almost_equal(dot(Q, Q.T), eye(n), decimal=5)
          assert_array_almost_equal(dot(Z, Z.T), eye(n), decimal=5)
          assert_(all(diag(BB) >= 0))

      def test_qz_double(self):
          n = 5
          A = random([n, n])
          B = random([n, n])
          AA, BB, Q, Z = qz(A, B)
          assert_array_almost_equal(dot(dot(Q, AA), Z.T), A)
          assert_array_almost_equal(dot(dot(Q, BB), Z.T), B)
          assert_array_almost_equal(dot(Q, Q.T), eye(n))
          assert_array_almost_equal(dot(Z, Z.T), eye(n))
          assert_(all(diag(BB) >= 0))

      def test_qz_complex(self):
          n = 5
          A = random([n, n]) + 1j*random([n, n])
          B = random([n, n]) + 1j*random([n, n])
          AA, BB, Q, Z = qz(A, B)
          assert_array_almost_equal(dot(dot(Q, AA), Z.conjugate().T), A)
          assert_array_almost_equal(dot(dot(Q, BB), Z.conjugate().T), B)
          assert_array_almost_equal(dot(Q, Q.conjugate().T), eye(n))
          assert_array_almost_equal(dot(Z, Z.conjugate().T), eye(n))
          assert_(all(diag(BB) >= 0))
          assert_(all(diag(BB).imag == 0))

      def test_qz_complex64(self):
          n = 5
          A = (random([n, n]) + 1j*random([n, n])).astype(complex64)
          B = (random([n, n]) + 1j*random([n, n])).astype(complex64)
          AA, BB, Q, Z = qz(A, B)
          assert_array_almost_equal(dot(dot(Q, AA), Z.conjugate().T), A, decimal=5)
          assert_array_almost_equal(dot(dot(Q, BB), Z.conjugate().T), B, decimal=5)
          assert_array_almost_equal(dot(Q, Q.conjugate().T), eye(n), decimal=5)
          assert_array_almost_equal(dot(Z, Z.conjugate().T), eye(n), decimal=5)
          assert_(all(diag(BB) >= 0))
          assert_(all(diag(BB).imag == 0))

      def test_qz_double_complex(self):
          n = 5
          A = random([n, n])
          B = random([n, n])
          AA, BB, Q, Z = qz(A, B, output='complex')
          # a real pencil decomposed in complex form must reconstruct to
          # real matrices
          aa = dot(dot(Q, AA), Z.conjugate().T)
          assert_array_almost_equal(aa.real, A)
          assert_array_almost_equal(aa.imag, 0)
          bb = dot(dot(Q, BB), Z.conjugate().T)
          assert_array_almost_equal(bb.real, B)
          assert_array_almost_equal(bb.imag, 0)
          assert_array_almost_equal(dot(Q, Q.conjugate().T), eye(n))
          assert_array_almost_equal(dot(Z, Z.conjugate().T), eye(n))
          assert_(all(diag(BB) >= 0))

      def test_qz_double_sort(self):
          # Matrices derived from https://www.nag.com/lapack-ex/node119.html
          # with several entries changed.  NOTE: the unmodified example
          # matrices may be ill-conditioned and were reported to seg fault
          # on certain python versions when compiled with sse2 or sse3 older
          # ATLAS/LAPACK binaries for windows.
          A = np.array([[3.9, 12.5, -34.5, 2.5],
                        [4.3, 21.5, -47.5, 7.5],
                        [4.3, 1.5, -43.5, 3.5],
                        [4.4, 6.0, -46.0, 6.0]])
          B = np.array([[1.0, 1.0, -3.0, 1.0],
                        [1.0, 3.0, -5.0, 4.4],
                        [1.0, 2.0, -4.0, 1.0],
                        [1.2, 3.0, -4.0, 4.0]])
          sort = lambda ar, ai, beta: ai == 0
          # qz rejects a ``sort`` argument with ValueError; reordered
          # decompositions are exercised through ordqz in TestOrdQZ.
          assert_raises(ValueError, qz, A, B, sort=sort)

      def test_check_finite(self):
          n = 5
          A = random([n, n])
          B = random([n, n])
          AA, BB, Q, Z = qz(A, B, check_finite=False)
          assert_array_almost_equal(dot(dot(Q, AA), Z.T), A)
          assert_array_almost_equal(dot(dot(Q, BB), Z.T), B)
          assert_array_almost_equal(dot(Q, Q.T), eye(n))
          assert_array_almost_equal(dot(Z, Z.T), eye(n))
          assert_(all(diag(BB) >= 0))
  1866. def _make_pos(X):
  1867. # the decompositions can have different signs than verified results
  1868. return np.sign(X)*X
  class TestOrdQZ(object):
      """Tests for scipy.linalg.ordqz (generalized Schur form with reordering)."""

      @classmethod
      def setup_class(cls):
          # https://www.nag.com/lapack-ex/node119.html
          A1 = np.array([[-21.10 - 22.50j, 53.5 - 50.5j, -34.5 + 127.5j,
                          7.5 + 0.5j],
                         [-0.46 - 7.78j, -3.5 - 37.5j, -15.5 + 58.5j,
                          -10.5 - 1.5j],
                         [4.30 - 5.50j, 39.7 - 17.1j, -68.5 + 12.5j,
                          -7.5 - 3.5j],
                         [5.50 + 4.40j, 14.4 + 43.3j, -32.5 - 46.0j,
                          -19.0 - 32.5j]])
          B1 = np.array([[1.0 - 5.0j, 1.6 + 1.2j, -3 + 0j, 0.0 - 1.0j],
                         [0.8 - 0.6j, .0 - 5.0j, -4 + 3j, -2.4 - 3.2j],
                         [1.0 + 0.0j, 2.4 + 1.8j, -4 - 5j, 0.0 - 3.0j],
                         [0.0 + 1.0j, -1.8 + 2.4j, 0 - 4j, 4.0 - 5.0j]])

          # https://www.nag.com/numeric/fl/nagdoc_fl23/xhtml/F08/f08yuf.xml
          A2 = np.array([[3.9, 12.5, -34.5, -0.5],
                         [4.3, 21.5, -47.5, 7.5],
                         [4.3, 21.5, -43.5, 3.5],
                         [4.4, 26.0, -46.0, 6.0]])

          B2 = np.array([[1, 2, -3, 1],
                         [1, 3, -5, 4],
                         [1, 3, -4, 3],
                         [1, 3, -4, 4]])

          # example with the eigenvalues
          # -0.33891648, 1.61217396+0.74013521j, 1.61217396-0.74013521j,
          # 0.61244091
          # thus featuring:
          #  * one complex conjugate eigenvalue pair,
          #  * one eigenvalue in the lhp
          #  * 2 eigenvalues in the unit circle
          #  * 2 non-real eigenvalues
          A3 = np.array([[5., 1., 3., 3.],
                         [4., 4., 2., 7.],
                         [7., 4., 1., 3.],
                         [0., 4., 8., 7.]])
          B3 = np.array([[8., 10., 6., 10.],
                         [7., 7., 2., 9.],
                         [9., 1., 6., 6.],
                         [5., 1., 4., 7.]])

          # example with infinite eigenvalues
          A4 = np.eye(2)
          B4 = np.diag([0, 1])

          # example with (alpha, beta) = (0, 0)
          A5 = np.diag([1, 0])
          B5 = np.diag([1, 0])

          cls.A = [A1, A2, A3, A4, A5]
          # B5 (not A5) pairs with A5: the values are equal, but reusing A5
          # aliased one array object into both lists and left B5 unused.
          cls.B = [B1, B2, B3, B4, B5]

      def qz_decomp(self, sort):
          """Run ordqz on every (A, B) pair, raising on floating-point errors."""
          with np.errstate(all='raise'):
              ret = [ordqz(Ai, Bi, sort=sort) for Ai, Bi in zip(self.A, self.B)]
          return tuple(ret)

      def check(self, A, B, sort, AA, BB, alpha, beta, Q, Z):
          """Validate one ordqz result (AA, BB, alpha, beta, Q, Z) for (A, B).

          Checks that Q and Z are unitary, that Q AA == A Z and Q BB == B Z,
          that AA is quasi-upper-triangular and BB upper-triangular, that the
          diagonal (blocks) agree with alpha/beta, and that eigenvalues
          selected by ``sort`` all precede unselected ones.
          """
          Id = np.eye(*A.shape)
          # make sure Q and Z are orthogonal
          assert_array_almost_equal(Q.dot(Q.T.conj()), Id)
          assert_array_almost_equal(Z.dot(Z.T.conj()), Id)
          # check factorization
          assert_array_almost_equal(Q.dot(AA), A.dot(Z))
          assert_array_almost_equal(Q.dot(BB), B.dot(Z))
          # check shape of AA and BB
          assert_array_equal(np.tril(AA, -2), np.zeros(AA.shape))
          assert_array_equal(np.tril(BB, -1), np.zeros(BB.shape))
          # check eigenvalues
          for i in range(A.shape[0]):
              # does the current diagonal element belong to a 2-by-2 block
              # that was already checked?  The block structure lives in the
              # quasi-triangular AA; testing the dense input A instead would
              # skip most diagonal elements.
              if i > 0 and AA[i, i - 1] != 0:
                  continue
              # take care of 2-by-2 blocks
              if i < AA.shape[0] - 1 and AA[i + 1, i] != 0:
                  evals, _ = eig(AA[i:i + 2, i:i + 2], BB[i:i + 2, i:i + 2])
                  # make sure the pair of complex conjugate eigenvalues
                  # is ordered consistently (positive imaginary part first)
                  if evals[0].imag < 0:
                      evals = evals[[1, 0]]
                  tmp = alpha[i:i + 2]/beta[i:i + 2]
                  if tmp[0].imag < 0:
                      tmp = tmp[[1, 0]]
                  assert_array_almost_equal(evals, tmp)
              else:
                  if alpha[i] == 0 and beta[i] == 0:
                      assert_equal(AA[i, i], 0)
                      assert_equal(BB[i, i], 0)
                  elif beta[i] == 0:
                      assert_equal(BB[i, i], 0)
                  else:
                      assert_almost_equal(AA[i, i]/BB[i, i], alpha[i]/beta[i])
          sortfun = _select_function(sort)
          lastsort = True
          for i in range(A.shape[0]):
              cursort = sortfun(np.array([alpha[i]]), np.array([beta[i]]))
              # once the sorting criterion was not matched all subsequent
              # eigenvalues also shouldn't match
              if not lastsort:
                  assert not cursort
              lastsort = cursort

      def check_all(self, sort):
          ret = self.qz_decomp(sort)

          for reti, Ai, Bi in zip(ret, self.A, self.B):
              self.check(Ai, Bi, sort, *reti)

      def test_lhp(self):
          self.check_all('lhp')

      def test_rhp(self):
          self.check_all('rhp')

      def test_iuc(self):
          self.check_all('iuc')

      def test_ouc(self):
          self.check_all('ouc')

      def test_ref(self):
          # real eigenvalues first (top-left corner)
          def sort(x, y):
              out = np.empty_like(x, dtype=bool)
              nonzero = (y != 0)
              out[~nonzero] = False
              out[nonzero] = (x[nonzero]/y[nonzero]).imag == 0
              return out

          self.check_all(sort)

      def test_cef(self):
          # complex eigenvalues first (top-left corner)
          def sort(x, y):
              out = np.empty_like(x, dtype=bool)
              nonzero = (y != 0)
              out[~nonzero] = False
              out[nonzero] = (x[nonzero]/y[nonzero]).imag != 0
              return out

          self.check_all(sort)

      def test_diff_input_types(self):
          ret = ordqz(self.A[1], self.B[2], sort='lhp')
          self.check(self.A[1], self.B[2], 'lhp', *ret)

          ret = ordqz(self.B[2], self.A[1], sort='lhp')
          self.check(self.B[2], self.A[1], 'lhp', *ret)

      def test_sort_explicit(self):
          # Test order of the eigenvalues in the 2 x 2 case where we can
          # explicitly compute the solution
          A1 = np.eye(2)
          B1 = np.diag([-2, 0.5])
          expected1 = [('lhp', [-0.5, 2]),
                       ('rhp', [2, -0.5]),
                       ('iuc', [-0.5, 2]),
                       ('ouc', [2, -0.5])]
          A2 = np.eye(2)
          B2 = np.diag([-2 + 1j, 0.5 + 0.5j])
          expected2 = [('lhp', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
                       ('rhp', [1/(0.5 + 0.5j), 1/(-2 + 1j)]),
                       ('iuc', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
                       ('ouc', [1/(0.5 + 0.5j), 1/(-2 + 1j)])]
          # 'lhp' is ambiguous so don't test it
          A3 = np.eye(2)
          B3 = np.diag([2, 0])
          expected3 = [('rhp', [0.5, np.inf]),
                       ('iuc', [0.5, np.inf]),
                       ('ouc', [np.inf, 0.5])]
          # 'rhp' is ambiguous so don't test it
          A4 = np.eye(2)
          B4 = np.diag([-2, 0])
          expected4 = [('lhp', [-0.5, np.inf]),
                       ('iuc', [-0.5, np.inf]),
                       ('ouc', [np.inf, -0.5])]
          A5 = np.diag([0, 1])
          B5 = np.diag([0, 0.5])
          # 'lhp' and 'iuc' are ambiguous so don't test them
          expected5 = [('rhp', [2, np.nan]),
                       ('ouc', [2, np.nan])]

          A = [A1, A2, A3, A4, A5]
          B = [B1, B2, B3, B4, B5]
          expected = [expected1, expected2, expected3, expected4, expected5]
          for Ai, Bi, expectedi in zip(A, B, expected):
              for sortstr, expected_eigvals in expectedi:
                  _, _, alpha, beta, _, _ = ordqz(Ai, Bi, sort=sortstr)
                  azero = (alpha == 0)
                  bzero = (beta == 0)
                  x = np.empty_like(alpha)
                  x[azero & bzero] = np.nan
                  x[~azero & bzero] = np.inf
                  x[~bzero] = alpha[~bzero]/beta[~bzero]
                  assert_allclose(expected_eigvals, x)
  class TestOrdQZWorkspaceSize(object):
      """Regression tests: ordqz on larger (N=202) pencils must not fail."""

      def setup_method(self):
          seed(12345)

      def test_decompose(self):
          N = 202

          # raises error if lwork parameter to dtrsen is too small
          for ddtype in [np.float32, np.float64]:
              A = random((N, N)).astype(ddtype)
              B = random((N, N)).astype(ddtype)
              sort = lambda alpha, beta: alpha < beta
              ordqz(A, B, sort=sort, output='real')

          # ``np.complex`` was a deprecated alias of the builtin ``complex``
          # (removed in NumPy 1.24); np.complex128 is the dtype it produced.
          for ddtype in [np.complex128, np.complex64]:
              A = random((N, N)).astype(ddtype)
              B = random((N, N)).astype(ddtype)
              sort = lambda alpha, beta: alpha < beta
              ordqz(A, B, sort=sort, output='complex')

      @pytest.mark.slow
      def test_decompose_ouc(self):
          N = 202

          # segfaults if lwork parameter to dtrsen is too small
          for ddtype in [np.float32, np.float64, np.complex128, np.complex64]:
              A = random((N, N)).astype(ddtype)
              B = random((N, N)).astype(ddtype)
              ordqz(A, B, sort='ouc')
  class TestDatacopied(object):
      """Tests for the private helper ``_datacopied``."""

      def test_datacopied(self):
          from scipy.linalg.decomp import _datacopied

          mat = matrix([[0, 1], [2, 3]])
          base = asarray(mat)
          as_list = mat.tolist()
          mat_copy = mat.copy()

          class Fake1:
              # exposes ``base`` through the __array__ protocol
              def __array__(self):
                  return base

          class Fake2:
              # exposes ``base``'s buffer through the array interface
              __array_interface__ = base.__array_interface__

          via_protocol = Fake1()
          via_interface = Fake2()

          # Only the plain nested list is expected to be reported as copied.
          expectations = [(mat, False), (base, False), (as_list, True),
                          (mat_copy, False), (via_protocol, False),
                          (via_interface, False)]
          for item, status in expectations:
              converted = asarray(item)
              assert_equal(_datacopied(converted, item), status,
                           err_msg=repr(item))
  def test_aligned_mem_float():
      """Check linalg works with non-aligned memory (float32 data)."""
      # Allocate 402 bytes of memory (allocated on boundary): a 2-byte
      # offset followed by 100 float32 values (400 bytes).
      a = arange(402, dtype=np.uint8)

      # Create an array with boundary offset 2, so the float32 items are not
      # aligned to their 4-byte size relative to the start of the buffer.
      z = np.frombuffer(a.data, offset=2, count=100, dtype=float32)
      z.shape = 10, 10

      eig(z, overwrite_a=True)
      eig(z.T, overwrite_a=True)
  def test_aligned_mem():
      """Ensure ``eig`` accepts float64 data that starts at a 4-byte offset."""
      # 804 bytes = 4-byte offset + 100 float64 values.
      raw = arange(804, dtype=np.uint8)
      # View the bytes from offset 4 onward as a 10x10 float matrix, which is
      # misaligned relative to the start of the buffer.
      mat = np.frombuffer(raw.data, offset=4, count=100, dtype=float)
      mat = mat.reshape(10, 10)
      for operand in (mat, mat.T):
          eig(operand, overwrite_a=True)
  def test_aligned_mem_complex():
      """Verify ``eig`` copes with complex data aligned only to 8 bytes."""
      # 1608 zero bytes = 8-byte offset + 100 complex values (16 bytes each).
      buf = zeros(1608, dtype=np.uint8)
      # Start the complex view 8 bytes into the buffer.
      mat = np.frombuffer(buf.data, offset=8, count=100, dtype=complex)
      mat = mat.reshape(10, 10)
      eig(mat, overwrite_a=True)
      # The transposed view needs no special handling either.
      eig(mat.T, overwrite_a=True)
  def check_lapack_misaligned(func, args, kwargs):
      """Call ``func`` with each ndarray argument copied to misaligned memory.

      For every position in ``args`` that holds an ``np.ndarray``, the array
      is copied into a byte buffer at a 4-byte offset and
      ``func(*args, **kwargs)`` is called with that copy substituted.  For
      arrays with more than one dimension the call is repeated with the
      transposed copy.  Non-array arguments are passed through unchanged.
      """
      base_args = list(args)
      for pos, original in enumerate(base_args):
          if not isinstance(original, np.ndarray):
              continue
          call_args = list(base_args)
          # 8 spare bytes leave room for the 4-byte shift below.
          backing = np.zeros(original.size * original.dtype.itemsize + 8,
                             dtype=np.uint8)
          shifted = np.frombuffer(backing.data, offset=4, count=original.size,
                                  dtype=original.dtype)
          shifted.shape = original.shape
          shifted[...] = original
          call_args[pos] = shifted
          func(*call_args, **kwargs)
          if len(shifted.shape) > 1:
              call_args[pos] = shifted.T
              func(*call_args, **kwargs)
  @pytest.mark.xfail(run=False, reason="Ticket #1152, triggers a segfault in rare cases.")
  def test_lapack_misaligned():
      """Drive several linalg routines through ``check_lapack_misaligned``."""
      M = np.eye(10, dtype=float)
      R = np.arange(100).reshape(10, 10)
      # float64 matrix viewed 4 bytes into a byte buffer.
      raw = np.arange(20000, dtype=np.uint8)
      S = np.frombuffer(raw.data, offset=4, count=100, dtype=float).reshape(10, 10)
      b = np.ones(10)
      LU, piv = lu_factor(S)
      # (routine, positional args, keyword args); trailing notes record
      # whether the call was observed to crash.
      cases = [
          (eig, (S,), {'overwrite_a': True}),             # crash
          (eigvals, (S,), {'overwrite_a': True}),         # no crash
          (lu, (S,), {'overwrite_a': True}),              # no crash
          (lu_factor, (S,), {'overwrite_a': True}),       # no crash
          (lu_solve, ((LU, piv), b), {'overwrite_b': True}),
          (solve, (S, b), {'overwrite_a': True, 'overwrite_b': True}),
          (svd, (M,), {'overwrite_a': True}),             # no crash
          (svd, (R,), {'overwrite_a': True}),             # no crash
          (svd, (S,), {'overwrite_a': True}),             # crash
          (svdvals, (S,), {}),                            # no crash
          (svdvals, (S,), {'overwrite_a': True}),         # crash
          (cholesky, (M,), {'overwrite_a': True}),        # no crash
          (qr, (S,), {'overwrite_a': True}),              # crash
          (rq, (S,), {'overwrite_a': True}),              # crash
          (hessenberg, (S,), {'overwrite_a': True}),      # crash
          (schur, (S,), {'overwrite_a': True}),           # crash
      ]
      for routine, positional, keywords in cases:
          check_lapack_misaligned(routine, positional, keywords)
  # Not yet properly tested (presumably for misaligned memory):
  # cholesky, rsf2csf, lu_solve, solve, eig_banded, eigvals_banded, eigh, diagsvd
  class TestOverwrite(object):
      """Call ``assert_no_overwrite`` on a range of linalg routines.

      The helper presumably verifies that each routine leaves its input
      arrays unchanged; the shapes below describe the generated inputs.
      """

      # Input shapes handed to ``assert_no_overwrite``.
      _SQUARE = (3, 3)
      _BANDED = (3, 2)

      def test_eig(self):
          for shapes in ([self._SQUARE], [self._SQUARE, self._SQUARE]):
              assert_no_overwrite(eig, shapes)

      def test_eigh(self):
          for shapes in ([self._SQUARE], [self._SQUARE, self._SQUARE]):
              assert_no_overwrite(eigh, shapes)

      def test_eig_banded(self):
          assert_no_overwrite(eig_banded, [self._BANDED])

      def test_eigvals(self):
          assert_no_overwrite(eigvals, [self._SQUARE])

      def test_eigvalsh(self):
          assert_no_overwrite(eigvalsh, [self._SQUARE])

      def test_eigvals_banded(self):
          assert_no_overwrite(eigvals_banded, [self._BANDED])

      def test_hessenberg(self):
          assert_no_overwrite(hessenberg, [self._SQUARE])

      def test_lu_factor(self):
          assert_no_overwrite(lu_factor, [self._SQUARE])

      def test_lu_solve(self):
          factors = lu_factor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 8]]))

          def solve_with_factors(b):
              return lu_solve(factors, b)

          assert_no_overwrite(solve_with_factors, [(3,)])

      def test_lu(self):
          assert_no_overwrite(lu, [self._SQUARE])

      def test_qr(self):
          assert_no_overwrite(qr, [self._SQUARE])

      def test_rq(self):
          assert_no_overwrite(rq, [self._SQUARE])

      def test_schur(self):
          assert_no_overwrite(schur, [self._SQUARE])

      def test_schur_complex(self):
          def complex_schur(a):
              return schur(a, output='complex')

          assert_no_overwrite(complex_schur, [self._SQUARE],
                              dtypes=[np.float32, np.float64])

      def test_svd(self):
          def gesvd(a):
              return svd(a, lapack_driver='gesvd')

          for routine in (svd, gesvd):
              assert_no_overwrite(routine, [self._SQUARE])

      def test_svdvals(self):
          assert_no_overwrite(svdvals, [self._SQUARE])
  2207. def _check_orth(n, dtype, skip_big=False):
  2208. X = np.ones((n, 2), dtype=float).astype(dtype)
  2209. eps = np.finfo(dtype).eps
  2210. tol = 1000 * eps
  2211. Y = orth(X)
  2212. assert_equal(Y.shape, (n, 1))
  2213. assert_allclose(Y, Y.mean(), atol=tol)
  2214. Y = orth(X.T)
  2215. assert_equal(Y.shape, (2, 1))
  2216. assert_allclose(Y, Y.mean(), atol=tol)
  2217. if n > 5 and not skip_big:
  2218. np.random.seed(1)
  2219. X = np.random.rand(n, 5).dot(np.random.rand(5, n))
  2220. X = X + 1e-4 * np.random.rand(n, 1).dot(np.random.rand(1, n))
  2221. X = X.astype(dtype)
  2222. Y = orth(X, rcond=1e-3)
  2223. assert_equal(Y.shape, (n, 5))
  2224. Y = orth(X, rcond=1e-6)
  2225. assert_equal(Y.shape, (n, 5 + 1))
  @pytest.mark.slow
  @pytest.mark.skipif(np.dtype(np.intp).itemsize < 8, reason="test only on 64-bit, else too slow")
  def test_orth_memory_efficiency():
      """Check that ``orth`` on a tall ``n x 2`` input avoids O(n**2) memory."""
      # n makes 16*n bytes (the n x 2 float64 input) reasonable while
      # 8*n*n bytes would be unreasonable.  Slow tests likely run where 4Gb+
      # of memory is available (as for tests about 32 bit overflow).
      n = 10 ** 7
      try:
          _check_orth(n, np.float64, skip_big=True)
      except MemoryError:
          raise AssertionError('memory error perhaps caused by orth regression')
  def test_orth():
      """Run ``_check_orth`` for every combination of dtype and size."""
      for dt in (np.float32, np.float64, np.complex64, np.complex128):
          for n in (1, 2, 3, 10, 100):
              _check_orth(n, dt)
  def test_null_space():
      """Check ``null_space`` dimensions and orthogonality across dtypes."""
      np.random.seed(1)
      for dt in (np.float32, np.float64, np.complex64, np.complex128):
          tol = 1000 * np.finfo(dt).eps
          for n in (1, 2, 3, 10, 100):
              ones = np.ones((2, n), dtype=dt)

              # The all-ones 2 x n matrix has rank 1.
              basis = null_space(ones)
              assert_equal(basis.shape, (n, n-1))
              assert_allclose(ones.dot(basis), 0, atol=tol)

              basis = null_space(ones.T)
              assert_equal(basis.shape, (2, 1))
              assert_allclose(ones.T.dot(basis), 0, atol=tol)

              # Gaussian wide matrix: expected to have full row rank.
              wide = np.random.randn(1 + n//2, n)
              basis = null_space(wide)
              assert_equal(basis.shape, (n, n - 1 - n//2))
              assert_allclose(wide.dot(basis), 0, atol=tol)

              if n <= 5:
                  continue
              np.random.seed(1)
              left = np.random.rand(n, 5)
              right = np.random.rand(5, n)
              col = np.random.rand(n, 1)
              row = np.random.rand(1, n)
              # Rank-5 product plus a 1e-4 rank-1 perturbation.
              X = (left.dot(right) + 1e-4 * col.dot(row)).astype(dt)
              basis = null_space(X, rcond=1e-3)
              assert_equal(basis.shape, (n, n - 5))
              basis = null_space(X, rcond=1e-6)
              assert_equal(basis.shape, (n, n - 6))
  def test_subspace_angles():
      """Check ``subspace_angles`` on known, MATLAB-derived and edge cases."""
      H = hadamard(8, float)
      A = H[:, :3]
      B = H[:, 3:]
      # Hadamard columns are mutually orthogonal: every angle is pi/2, and a
      # subspace makes zero angles with itself.
      for p, q in ((A, B), (B, A)):
          assert_allclose(subspace_angles(p, q), [np.pi / 2.] * 3, atol=1e-14)
      for x in (A, B):
          assert_allclose(subspace_angles(x, x), np.zeros(x.shape[1]),
                          atol=1e-14)

      # From MATLAB function "subspace", which effectively only returns the
      # last value that we calculate
      x = np.array(
          [[0.537667139546100, 0.318765239858981, 3.578396939725760, 0.725404224946106],  # noqa: E501
           [1.833885014595086, -1.307688296305273, 2.769437029884877, -0.063054873189656],  # noqa: E501
           [-2.258846861003648, -0.433592022305684, -1.349886940156521, 0.714742903826096],  # noqa: E501
           [0.862173320368121, 0.342624466538650, 3.034923466331855, -0.204966058299775]])  # noqa: E501
      head, tail = x[:, :2], x[:, 2:]
      for p, q in ((head, tail), (tail, head)):
          assert_allclose(subspace_angles(p, q)[0], 1.481454682101605,
                          rtol=1e-12)

      # Multi-column vs single-column subspaces, checked in both orders.
      for wide, single, expected in (
              (x[:, :2], x[:, [2]], 0.746361174247302),
              (x[:, :3], x[:, [3]], 0.487163718534313)):
          assert_allclose(subspace_angles(wide, single), expected, rtol=1e-12)
          assert_allclose(subspace_angles(single, wide), expected, rtol=1e-12)

      assert_allclose(subspace_angles(x[:, :2], x[:, 1:]),
                      [0.328950515907756, 0], atol=1e-12)

      # Degenerate conditions
      for bad_a, bad_b in ((x[0], x), (x, x[0]), (x[:-1], x)):
          assert_raises(ValueError, subspace_angles, bad_a, bad_b)

      # Test branch if mask.any is True:
      A = np.array([[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1],
                    [0, 0, 0],
                    [0, 0, 0]])
      B = np.array([[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 0],
                    [0, 0, 0],
                    [0, 0, 1]])
      assert_allclose(subspace_angles(A, B), np.array([np.pi/2, 0, 0]),
                      rtol=1e-12)

      # Complex: the second column of ``b`` does not affect the result; it
      # lets ``b`` have more columns than ``a`` and vice-versa so that both
      # conditional code paths run.
      a = [[1 + 1j], [0]]
      b = [[1 - 1j, 0], [0, 1]]
      for p, q in ((a, b), (b, a)):
          assert_allclose(subspace_angles(p, q), 0., atol=1e-14)
  class TestCDF2RDF(object):
      """Tests for ``cdf2rdf``.

      Valid cases pass ``(w, v)`` from ``np.linalg.eig(X)`` through
      ``cdf2rdf`` and check that the result ``(wr, vr)`` still satisfies
      ``vr @ wr == X @ vr``.  The remaining cases check that malformed input
      raises ``ValueError``.
      """

      # Fixed seed for the random stacked-array tests, so that any failure
      # can be reproduced (unseeded global draws made them non-deterministic).
      _SEED = 999999999

      def matmul(self, a, b):
          """Return ``a @ b`` over the last two axes, broadcasting the rest."""
          return np.einsum('...ij,...jk->...ik', a, b)

      def assert_eig_valid(self, w, v, x):
          """Assert that ``v @ w`` approximately equals ``x @ v``."""
          assert_array_almost_equal(
              self.matmul(v, w),
              self.matmul(x, v)
          )

      def test_single_array0x0real(self):
          # eig doesn't support 0x0 in old versions of numpy
          X = np.empty((0, 0))
          w, v = np.empty(0), np.empty((0, 0))
          wr, vr = cdf2rdf(w, v)
          self.assert_eig_valid(wr, vr, X)

      def test_single_array2x2_real(self):
          X = np.array([[1, 2], [3, -1]])
          w, v = np.linalg.eig(X)
          wr, vr = cdf2rdf(w, v)
          self.assert_eig_valid(wr, vr, X)

      def test_single_array2x2_complex(self):
          X = np.array([[1, 2], [-2, 1]])
          w, v = np.linalg.eig(X)
          wr, vr = cdf2rdf(w, v)
          self.assert_eig_valid(wr, vr, X)

      def test_single_array3x3_real(self):
          X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
          w, v = np.linalg.eig(X)
          wr, vr = cdf2rdf(w, v)
          self.assert_eig_valid(wr, vr, X)

      def test_single_array3x3_complex(self):
          X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
          w, v = np.linalg.eig(X)
          wr, vr = cdf2rdf(w, v)
          self.assert_eig_valid(wr, vr, X)

      def test_random_1d_stacked_arrays(self):
          # cannot test M == 0 due to bug in old numpy
          for M in range(1, 7):
              rng = np.random.RandomState(self._SEED)
              X = rng.rand(100, M, M)
              w, v = np.linalg.eig(X)
              wr, vr = cdf2rdf(w, v)
              self.assert_eig_valid(wr, vr, X)

      def test_random_2d_stacked_arrays(self):
          # cannot test M == 0 due to bug in old numpy
          for M in range(1, 7):
              rng = np.random.RandomState(self._SEED)
              X = rng.rand(10, 10, M, M)
              w, v = np.linalg.eig(X)
              wr, vr = cdf2rdf(w, v)
              self.assert_eig_valid(wr, vr, X)

      def test_low_dimensionality_error(self):
          w, v = np.empty(()), np.array((2,))
          assert_raises(ValueError, cdf2rdf, w, v)

      def test_not_square_error(self):
          # Check that passing a non-square array raises a ValueError.
          w, v = np.arange(3), np.arange(6).reshape(3, 2)
          assert_raises(ValueError, cdf2rdf, w, v)

      def test_swapped_v_w_error(self):
          # Check that exchanging places of w and v raises ValueError.
          X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
          w, v = np.linalg.eig(X)
          assert_raises(ValueError, cdf2rdf, v, w)

      def test_non_associated_error(self):
          # Check that passing non-associated eigenvectors raises a ValueError.
          w, v = np.arange(3), np.arange(16).reshape(4, 4)
          assert_raises(ValueError, cdf2rdf, w, v)

      def test_not_conjugate_pairs(self):
          # Check that passing non-conjugate pairs raises a ValueError.
          X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]])
          w, v = np.linalg.eig(X)
          assert_raises(ValueError, cdf2rdf, w, v)

          # different arrays in the stack, so not conjugate
          X = np.array([
              [[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]],
              [[1, 2, 3], [1, 2, 3], [2, 5, 6-1j]],
          ])
          w, v = np.linalg.eig(X)
          assert_raises(ValueError, cdf2rdf, w, v)