test_basic.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653
  1. #
  2. # Created by: Pearu Peterson, March 2002
  3. #
  4. """ Test functions for linalg.basic module
  5. """
  6. from __future__ import division, print_function, absolute_import
  7. import warnings
  8. import itertools
  9. import numpy as np
  10. from numpy import (arange, array, dot, zeros, identity, conjugate, transpose,
  11. float32)
  12. import numpy.linalg as linalg
  13. from numpy.random import random
  14. from numpy.testing import (assert_equal, assert_almost_equal, assert_,
  15. assert_array_almost_equal, assert_allclose,
  16. assert_array_equal)
  17. import pytest
  18. from pytest import raises as assert_raises
  19. from scipy._lib._numpy_compat import suppress_warnings
  20. from scipy.linalg import (solve, inv, det, lstsq, pinv, pinv2, pinvh, norm,
  21. solve_banded, solveh_banded, solve_triangular,
  22. solve_circulant, circulant, LinAlgError, block_diag,
  23. matrix_balance, LinAlgWarning)
  24. from scipy.linalg.basic import LstsqLapackError
  25. from scipy.linalg._testutils import assert_no_overwrite
  26. from scipy._lib._version import NumpyVersion
# Historical notes carried over from the original module. This bare string is
# only an expression statement; it has no runtime effect.
# NOTE(review): "solve.check_random_sym_complex" presumably refers to what is
# now TestSolve.test_random_sym_complex -- confirm whether this still applies.
"""
Bugs:
1) solve.check_random_sym_complex fails if a is complex
and transpose(a) = conjugate(a) (a is Hermitian).
"""
# Free-form usage notes for running this module.
# NOTE(review): the commands below (e.g. the ``setup_linalg.py`` build script)
# look outdated -- verify before following them.
__usage__ = """
Build linalg:
python setup_linalg.py build
Run tests if scipy is installed:
python -c 'import scipy;scipy.linalg.test()'
Run tests if linalg is not installed:
python tests/test_basic.py
"""
  40. REAL_DTYPES = [np.float32, np.float64, np.longdouble]
  41. COMPLEX_DTYPES = [np.complex64, np.complex128, np.clongdouble]
  42. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  43. def _eps_cast(dtyp):
  44. """Get the epsilon for dtype, possibly downcast to BLAS types."""
  45. dt = dtyp
  46. if dt == np.longdouble:
  47. dt = np.float64
  48. elif dt == np.clongdouble:
  49. dt = np.complex128
  50. return np.finfo(dt).eps
  51. class TestSolveBanded(object):
  52. def test_real(self):
  53. a = array([[1.0, 20, 0, 0],
  54. [-30, 4, 6, 0],
  55. [2, 1, 20, 2],
  56. [0, -1, 7, 14]])
  57. ab = array([[0.0, 20, 6, 2],
  58. [1, 4, 20, 14],
  59. [-30, 1, 7, 0],
  60. [2, -1, 0, 0]])
  61. l, u = 2, 1
  62. b4 = array([10.0, 0.0, 2.0, 14.0])
  63. b4by1 = b4.reshape(-1, 1)
  64. b4by2 = array([[2, 1],
  65. [-30, 4],
  66. [2, 3],
  67. [1, 3]])
  68. b4by4 = array([[1, 0, 0, 0],
  69. [0, 0, 0, 1],
  70. [0, 1, 0, 0],
  71. [0, 1, 0, 0]])
  72. for b in [b4, b4by1, b4by2, b4by4]:
  73. x = solve_banded((l, u), ab, b)
  74. assert_array_almost_equal(dot(a, x), b)
  75. def test_complex(self):
  76. a = array([[1.0, 20, 0, 0],
  77. [-30, 4, 6, 0],
  78. [2j, 1, 20, 2j],
  79. [0, -1, 7, 14]])
  80. ab = array([[0.0, 20, 6, 2j],
  81. [1, 4, 20, 14],
  82. [-30, 1, 7, 0],
  83. [2j, -1, 0, 0]])
  84. l, u = 2, 1
  85. b4 = array([10.0, 0.0, 2.0, 14.0j])
  86. b4by1 = b4.reshape(-1, 1)
  87. b4by2 = array([[2, 1],
  88. [-30, 4],
  89. [2, 3],
  90. [1, 3]])
  91. b4by4 = array([[1, 0, 0, 0],
  92. [0, 0, 0, 1j],
  93. [0, 1, 0, 0],
  94. [0, 1, 0, 0]])
  95. for b in [b4, b4by1, b4by2, b4by4]:
  96. x = solve_banded((l, u), ab, b)
  97. assert_array_almost_equal(dot(a, x), b)
  98. def test_tridiag_real(self):
  99. ab = array([[0.0, 20, 6, 2],
  100. [1, 4, 20, 14],
  101. [-30, 1, 7, 0]])
  102. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  103. ab[2, :-1], -1)
  104. b4 = array([10.0, 0.0, 2.0, 14.0])
  105. b4by1 = b4.reshape(-1, 1)
  106. b4by2 = array([[2, 1],
  107. [-30, 4],
  108. [2, 3],
  109. [1, 3]])
  110. b4by4 = array([[1, 0, 0, 0],
  111. [0, 0, 0, 1],
  112. [0, 1, 0, 0],
  113. [0, 1, 0, 0]])
  114. for b in [b4, b4by1, b4by2, b4by4]:
  115. x = solve_banded((1, 1), ab, b)
  116. assert_array_almost_equal(dot(a, x), b)
  117. def test_tridiag_complex(self):
  118. ab = array([[0.0, 20, 6, 2j],
  119. [1, 4, 20, 14],
  120. [-30, 1, 7, 0]])
  121. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  122. ab[2, :-1], -1)
  123. b4 = array([10.0, 0.0, 2.0, 14.0j])
  124. b4by1 = b4.reshape(-1, 1)
  125. b4by2 = array([[2, 1],
  126. [-30, 4],
  127. [2, 3],
  128. [1, 3]])
  129. b4by4 = array([[1, 0, 0, 0],
  130. [0, 0, 0, 1],
  131. [0, 1, 0, 0],
  132. [0, 1, 0, 0]])
  133. for b in [b4, b4by1, b4by2, b4by4]:
  134. x = solve_banded((1, 1), ab, b)
  135. assert_array_almost_equal(dot(a, x), b)
  136. def test_check_finite(self):
  137. a = array([[1.0, 20, 0, 0],
  138. [-30, 4, 6, 0],
  139. [2, 1, 20, 2],
  140. [0, -1, 7, 14]])
  141. ab = array([[0.0, 20, 6, 2],
  142. [1, 4, 20, 14],
  143. [-30, 1, 7, 0],
  144. [2, -1, 0, 0]])
  145. l, u = 2, 1
  146. b4 = array([10.0, 0.0, 2.0, 14.0])
  147. x = solve_banded((l, u), ab, b4, check_finite=False)
  148. assert_array_almost_equal(dot(a, x), b4)
  149. def test_bad_shape(self):
  150. ab = array([[0.0, 20, 6, 2],
  151. [1, 4, 20, 14],
  152. [-30, 1, 7, 0],
  153. [2, -1, 0, 0]])
  154. l, u = 2, 1
  155. bad = array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 4)
  156. assert_raises(ValueError, solve_banded, (l, u), ab, bad)
  157. assert_raises(ValueError, solve_banded, (l, u), ab, [1.0, 2.0])
  158. # Values of (l,u) are not compatible with ab.
  159. assert_raises(ValueError, solve_banded, (1, 1), ab, [1.0, 2.0])
  160. def test_1x1(self):
  161. b = array([[1., 2., 3.]])
  162. x = solve_banded((1, 1), [[0], [2], [0]], b)
  163. assert_array_equal(x, [[0.5, 1.0, 1.5]])
  164. assert_equal(x.dtype, np.dtype('f8'))
  165. assert_array_equal(b, [[1.0, 2.0, 3.0]])
  166. def test_native_list_arguments(self):
  167. a = [[1.0, 20, 0, 0],
  168. [-30, 4, 6, 0],
  169. [2, 1, 20, 2],
  170. [0, -1, 7, 14]]
  171. ab = [[0.0, 20, 6, 2],
  172. [1, 4, 20, 14],
  173. [-30, 1, 7, 0],
  174. [2, -1, 0, 0]]
  175. l, u = 2, 1
  176. b = [10.0, 0.0, 2.0, 14.0]
  177. x = solve_banded((l, u), ab, b)
  178. assert_array_almost_equal(dot(a, x), b)
  179. class TestSolveHBanded(object):
  180. def test_01_upper(self):
  181. # Solve
  182. # [ 4 1 2 0] [1]
  183. # [ 1 4 1 2] X = [4]
  184. # [ 2 1 4 1] [1]
  185. # [ 0 2 1 4] [2]
  186. # with the RHS as a 1D array.
  187. ab = array([[0.0, 0.0, 2.0, 2.0],
  188. [-99, 1.0, 1.0, 1.0],
  189. [4.0, 4.0, 4.0, 4.0]])
  190. b = array([1.0, 4.0, 1.0, 2.0])
  191. x = solveh_banded(ab, b)
  192. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  193. def test_02_upper(self):
  194. # Solve
  195. # [ 4 1 2 0] [1 6]
  196. # [ 1 4 1 2] X = [4 2]
  197. # [ 2 1 4 1] [1 6]
  198. # [ 0 2 1 4] [2 1]
  199. #
  200. ab = array([[0.0, 0.0, 2.0, 2.0],
  201. [-99, 1.0, 1.0, 1.0],
  202. [4.0, 4.0, 4.0, 4.0]])
  203. b = array([[1.0, 6.0],
  204. [4.0, 2.0],
  205. [1.0, 6.0],
  206. [2.0, 1.0]])
  207. x = solveh_banded(ab, b)
  208. expected = array([[0.0, 1.0],
  209. [1.0, 0.0],
  210. [0.0, 1.0],
  211. [0.0, 0.0]])
  212. assert_array_almost_equal(x, expected)
  213. def test_03_upper(self):
  214. # Solve
  215. # [ 4 1 2 0] [1]
  216. # [ 1 4 1 2] X = [4]
  217. # [ 2 1 4 1] [1]
  218. # [ 0 2 1 4] [2]
  219. # with the RHS as a 2D array with shape (3,1).
  220. ab = array([[0.0, 0.0, 2.0, 2.0],
  221. [-99, 1.0, 1.0, 1.0],
  222. [4.0, 4.0, 4.0, 4.0]])
  223. b = array([1.0, 4.0, 1.0, 2.0]).reshape(-1, 1)
  224. x = solveh_banded(ab, b)
  225. assert_array_almost_equal(x, array([0., 1., 0., 0.]).reshape(-1, 1))
  226. def test_01_lower(self):
  227. # Solve
  228. # [ 4 1 2 0] [1]
  229. # [ 1 4 1 2] X = [4]
  230. # [ 2 1 4 1] [1]
  231. # [ 0 2 1 4] [2]
  232. #
  233. ab = array([[4.0, 4.0, 4.0, 4.0],
  234. [1.0, 1.0, 1.0, -99],
  235. [2.0, 2.0, 0.0, 0.0]])
  236. b = array([1.0, 4.0, 1.0, 2.0])
  237. x = solveh_banded(ab, b, lower=True)
  238. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  239. def test_02_lower(self):
  240. # Solve
  241. # [ 4 1 2 0] [1 6]
  242. # [ 1 4 1 2] X = [4 2]
  243. # [ 2 1 4 1] [1 6]
  244. # [ 0 2 1 4] [2 1]
  245. #
  246. ab = array([[4.0, 4.0, 4.0, 4.0],
  247. [1.0, 1.0, 1.0, -99],
  248. [2.0, 2.0, 0.0, 0.0]])
  249. b = array([[1.0, 6.0],
  250. [4.0, 2.0],
  251. [1.0, 6.0],
  252. [2.0, 1.0]])
  253. x = solveh_banded(ab, b, lower=True)
  254. expected = array([[0.0, 1.0],
  255. [1.0, 0.0],
  256. [0.0, 1.0],
  257. [0.0, 0.0]])
  258. assert_array_almost_equal(x, expected)
  259. def test_01_float32(self):
  260. # Solve
  261. # [ 4 1 2 0] [1]
  262. # [ 1 4 1 2] X = [4]
  263. # [ 2 1 4 1] [1]
  264. # [ 0 2 1 4] [2]
  265. #
  266. ab = array([[0.0, 0.0, 2.0, 2.0],
  267. [-99, 1.0, 1.0, 1.0],
  268. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  269. b = array([1.0, 4.0, 1.0, 2.0], dtype=float32)
  270. x = solveh_banded(ab, b)
  271. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  272. def test_02_float32(self):
  273. # Solve
  274. # [ 4 1 2 0] [1 6]
  275. # [ 1 4 1 2] X = [4 2]
  276. # [ 2 1 4 1] [1 6]
  277. # [ 0 2 1 4] [2 1]
  278. #
  279. ab = array([[0.0, 0.0, 2.0, 2.0],
  280. [-99, 1.0, 1.0, 1.0],
  281. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  282. b = array([[1.0, 6.0],
  283. [4.0, 2.0],
  284. [1.0, 6.0],
  285. [2.0, 1.0]], dtype=float32)
  286. x = solveh_banded(ab, b)
  287. expected = array([[0.0, 1.0],
  288. [1.0, 0.0],
  289. [0.0, 1.0],
  290. [0.0, 0.0]])
  291. assert_array_almost_equal(x, expected)
  292. def test_01_complex(self):
  293. # Solve
  294. # [ 4 -j 2 0] [2-j]
  295. # [ j 4 -j 2] X = [4-j]
  296. # [ 2 j 4 -j] [4+j]
  297. # [ 0 2 j 4] [2+j]
  298. #
  299. ab = array([[0.0, 0.0, 2.0, 2.0],
  300. [-99, -1.0j, -1.0j, -1.0j],
  301. [4.0, 4.0, 4.0, 4.0]])
  302. b = array([2-1.0j, 4.0-1j, 4+1j, 2+1j])
  303. x = solveh_banded(ab, b)
  304. assert_array_almost_equal(x, [0.0, 1.0, 1.0, 0.0])
  305. def test_02_complex(self):
  306. # Solve
  307. # [ 4 -j 2 0] [2-j 2+4j]
  308. # [ j 4 -j 2] X = [4-j -1-j]
  309. # [ 2 j 4 -j] [4+j 4+2j]
  310. # [ 0 2 j 4] [2+j j]
  311. #
  312. ab = array([[0.0, 0.0, 2.0, 2.0],
  313. [-99, -1.0j, -1.0j, -1.0j],
  314. [4.0, 4.0, 4.0, 4.0]])
  315. b = array([[2-1j, 2+4j],
  316. [4.0-1j, -1-1j],
  317. [4.0+1j, 4+2j],
  318. [2+1j, 1j]])
  319. x = solveh_banded(ab, b)
  320. expected = array([[0.0, 1.0j],
  321. [1.0, 0.0],
  322. [1.0, 1.0],
  323. [0.0, 0.0]])
  324. assert_array_almost_equal(x, expected)
  325. def test_tridiag_01_upper(self):
  326. # Solve
  327. # [ 4 1 0] [1]
  328. # [ 1 4 1] X = [4]
  329. # [ 0 1 4] [1]
  330. # with the RHS as a 1D array.
  331. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  332. b = array([1.0, 4.0, 1.0])
  333. x = solveh_banded(ab, b)
  334. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  335. def test_tridiag_02_upper(self):
  336. # Solve
  337. # [ 4 1 0] [1 4]
  338. # [ 1 4 1] X = [4 2]
  339. # [ 0 1 4] [1 4]
  340. #
  341. ab = array([[-99, 1.0, 1.0],
  342. [4.0, 4.0, 4.0]])
  343. b = array([[1.0, 4.0],
  344. [4.0, 2.0],
  345. [1.0, 4.0]])
  346. x = solveh_banded(ab, b)
  347. expected = array([[0.0, 1.0],
  348. [1.0, 0.0],
  349. [0.0, 1.0]])
  350. assert_array_almost_equal(x, expected)
  351. def test_tridiag_03_upper(self):
  352. # Solve
  353. # [ 4 1 0] [1]
  354. # [ 1 4 1] X = [4]
  355. # [ 0 1 4] [1]
  356. # with the RHS as a 2D array with shape (3,1).
  357. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  358. b = array([1.0, 4.0, 1.0]).reshape(-1, 1)
  359. x = solveh_banded(ab, b)
  360. assert_array_almost_equal(x, array([0.0, 1.0, 0.0]).reshape(-1, 1))
  361. def test_tridiag_01_lower(self):
  362. # Solve
  363. # [ 4 1 0] [1]
  364. # [ 1 4 1] X = [4]
  365. # [ 0 1 4] [1]
  366. #
  367. ab = array([[4.0, 4.0, 4.0],
  368. [1.0, 1.0, -99]])
  369. b = array([1.0, 4.0, 1.0])
  370. x = solveh_banded(ab, b, lower=True)
  371. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  372. def test_tridiag_02_lower(self):
  373. # Solve
  374. # [ 4 1 0] [1 4]
  375. # [ 1 4 1] X = [4 2]
  376. # [ 0 1 4] [1 4]
  377. #
  378. ab = array([[4.0, 4.0, 4.0],
  379. [1.0, 1.0, -99]])
  380. b = array([[1.0, 4.0],
  381. [4.0, 2.0],
  382. [1.0, 4.0]])
  383. x = solveh_banded(ab, b, lower=True)
  384. expected = array([[0.0, 1.0],
  385. [1.0, 0.0],
  386. [0.0, 1.0]])
  387. assert_array_almost_equal(x, expected)
  388. def test_tridiag_01_float32(self):
  389. # Solve
  390. # [ 4 1 0] [1]
  391. # [ 1 4 1] X = [4]
  392. # [ 0 1 4] [1]
  393. #
  394. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]], dtype=float32)
  395. b = array([1.0, 4.0, 1.0], dtype=float32)
  396. x = solveh_banded(ab, b)
  397. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  398. def test_tridiag_02_float32(self):
  399. # Solve
  400. # [ 4 1 0] [1 4]
  401. # [ 1 4 1] X = [4 2]
  402. # [ 0 1 4] [1 4]
  403. #
  404. ab = array([[-99, 1.0, 1.0],
  405. [4.0, 4.0, 4.0]], dtype=float32)
  406. b = array([[1.0, 4.0],
  407. [4.0, 2.0],
  408. [1.0, 4.0]], dtype=float32)
  409. x = solveh_banded(ab, b)
  410. expected = array([[0.0, 1.0],
  411. [1.0, 0.0],
  412. [0.0, 1.0]])
  413. assert_array_almost_equal(x, expected)
  414. def test_tridiag_01_complex(self):
  415. # Solve
  416. # [ 4 -j 0] [ -j]
  417. # [ j 4 -j] X = [4-j]
  418. # [ 0 j 4] [4+j]
  419. #
  420. ab = array([[-99, -1.0j, -1.0j], [4.0, 4.0, 4.0]])
  421. b = array([-1.0j, 4.0-1j, 4+1j])
  422. x = solveh_banded(ab, b)
  423. assert_array_almost_equal(x, [0.0, 1.0, 1.0])
  424. def test_tridiag_02_complex(self):
  425. # Solve
  426. # [ 4 -j 0] [ -j 4j]
  427. # [ j 4 -j] X = [4-j -1-j]
  428. # [ 0 j 4] [4+j 4 ]
  429. #
  430. ab = array([[-99, -1.0j, -1.0j],
  431. [4.0, 4.0, 4.0]])
  432. b = array([[-1j, 4.0j],
  433. [4.0-1j, -1.0-1j],
  434. [4.0+1j, 4.0]])
  435. x = solveh_banded(ab, b)
  436. expected = array([[0.0, 1.0j],
  437. [1.0, 0.0],
  438. [1.0, 1.0]])
  439. assert_array_almost_equal(x, expected)
  440. def test_check_finite(self):
  441. # Solve
  442. # [ 4 1 0] [1]
  443. # [ 1 4 1] X = [4]
  444. # [ 0 1 4] [1]
  445. # with the RHS as a 1D array.
  446. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  447. b = array([1.0, 4.0, 1.0])
  448. x = solveh_banded(ab, b, check_finite=False)
  449. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  450. def test_bad_shapes(self):
  451. ab = array([[-99, 1.0, 1.0],
  452. [4.0, 4.0, 4.0]])
  453. b = array([[1.0, 4.0],
  454. [4.0, 2.0]])
  455. assert_raises(ValueError, solveh_banded, ab, b)
  456. assert_raises(ValueError, solveh_banded, ab, [1.0, 2.0])
  457. assert_raises(ValueError, solveh_banded, ab, [1.0])
  458. def test_1x1(self):
  459. x = solveh_banded([[1]], [[1, 2, 3]])
  460. assert_array_equal(x, [[1.0, 2.0, 3.0]])
  461. assert_equal(x.dtype, np.dtype('f8'))
  462. def test_native_list_arguments(self):
  463. # Same as test_01_upper, using python's native list.
  464. ab = [[0.0, 0.0, 2.0, 2.0],
  465. [-99, 1.0, 1.0, 1.0],
  466. [4.0, 4.0, 4.0, 4.0]]
  467. b = [1.0, 4.0, 1.0, 2.0]
  468. x = solveh_banded(ab, b)
  469. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  470. class TestSolve(object):
  471. def setup_method(self):
  472. np.random.seed(1234)
  473. def test_20Feb04_bug(self):
  474. a = [[1, 1], [1.0, 0]] # ok
  475. x0 = solve(a, [1, 0j])
  476. assert_array_almost_equal(dot(a, x0), [1, 0])
  477. # gives failure with clapack.zgesv(..,rowmajor=0)
  478. a = [[1, 1], [1.2, 0]]
  479. b = [1, 0j]
  480. x0 = solve(a, b)
  481. assert_array_almost_equal(dot(a, x0), [1, 0])
  482. def test_simple(self):
  483. a = [[1, 20], [-30, 4]]
  484. for b in ([[1, 0], [0, 1]], [1, 0],
  485. [[2, 1], [-30, 4]]):
  486. x = solve(a, b)
  487. assert_array_almost_equal(dot(a, x), b)
  488. def test_simple_sym(self):
  489. a = [[2, 3], [3, 5]]
  490. for lower in [0, 1]:
  491. for b in ([[1, 0], [0, 1]], [1, 0]):
  492. x = solve(a, b, sym_pos=1, lower=lower)
  493. assert_array_almost_equal(dot(a, x), b)
  494. def test_simple_sym_complex(self):
  495. a = [[5, 2], [2, 4]]
  496. for b in [[1j, 0],
  497. [[1j, 1j],
  498. [0, 2]],
  499. ]:
  500. x = solve(a, b, sym_pos=1)
  501. assert_array_almost_equal(dot(a, x), b)
  502. def test_simple_complex(self):
  503. a = array([[5, 2], [2j, 4]], 'D')
  504. for b in [[1j, 0],
  505. [[1j, 1j],
  506. [0, 2]],
  507. [1, 0j],
  508. array([1, 0], 'D'),
  509. ]:
  510. x = solve(a, b)
  511. assert_array_almost_equal(dot(a, x), b)
  512. def test_nils_20Feb04(self):
  513. n = 2
  514. A = random([n, n])+random([n, n])*1j
  515. X = zeros((n, n), 'D')
  516. Ainv = inv(A)
  517. R = identity(n)+identity(n)*0j
  518. for i in arange(0, n):
  519. r = R[:, i]
  520. X[:, i] = solve(A, r)
  521. assert_array_almost_equal(X, Ainv)
  522. def test_random(self):
  523. n = 20
  524. a = random([n, n])
  525. for i in range(n):
  526. a[i, i] = 20*(.1+a[i, i])
  527. for i in range(4):
  528. b = random([n, 3])
  529. x = solve(a, b)
  530. assert_array_almost_equal(dot(a, x), b)
  531. def test_random_complex(self):
  532. n = 20
  533. a = random([n, n]) + 1j * random([n, n])
  534. for i in range(n):
  535. a[i, i] = 20*(.1+a[i, i])
  536. for i in range(2):
  537. b = random([n, 3])
  538. x = solve(a, b)
  539. assert_array_almost_equal(dot(a, x), b)
  540. def test_random_sym(self):
  541. n = 20
  542. a = random([n, n])
  543. for i in range(n):
  544. a[i, i] = abs(20*(.1+a[i, i]))
  545. for j in range(i):
  546. a[i, j] = a[j, i]
  547. for i in range(4):
  548. b = random([n])
  549. x = solve(a, b, sym_pos=1)
  550. assert_array_almost_equal(dot(a, x), b)
  551. def test_random_sym_complex(self):
  552. n = 20
  553. a = random([n, n])
  554. # XXX: with the following addition the accuracy will be very low
  555. a = a + 1j*random([n, n])
  556. for i in range(n):
  557. a[i, i] = abs(20*(.1+a[i, i]))
  558. for j in range(i):
  559. a[i, j] = conjugate(a[j, i])
  560. b = random([n])+2j*random([n])
  561. for i in range(2):
  562. x = solve(a, b, sym_pos=1)
  563. assert_array_almost_equal(dot(a, x), b)
  564. def test_check_finite(self):
  565. a = [[1, 20], [-30, 4]]
  566. for b in ([[1, 0], [0, 1]], [1, 0],
  567. [[2, 1], [-30, 4]]):
  568. x = solve(a, b, check_finite=False)
  569. assert_array_almost_equal(dot(a, x), b)
  570. def test_scalar_a_and_1D_b(self):
  571. a = 1
  572. b = [1, 2, 3]
  573. x = solve(a, b)
  574. assert_array_almost_equal(x.ravel(), b)
  575. assert_(x.shape == (3,), 'Scalar_a_1D_b test returned wrong shape')
  576. def test_simple2(self):
  577. a = np.array([[1.80, 2.88, 2.05, -0.89],
  578. [525.00, -295.00, -95.00, -380.00],
  579. [1.58, -2.69, -2.90, -1.04],
  580. [-1.11, -0.66, -0.59, 0.80]])
  581. b = np.array([[9.52, 18.47],
  582. [2435.00, 225.00],
  583. [0.77, -13.28],
  584. [-6.22, -6.21]])
  585. x = solve(a, b)
  586. assert_array_almost_equal(x, np.array([[1., -1, 3, -5],
  587. [3, 2, 4, 1]]).T)
  588. def test_simple_complex2(self):
  589. a = np.array([[-1.34+2.55j, 0.28+3.17j, -6.39-2.20j, 0.72-0.92j],
  590. [-1.70-14.10j, 33.10-1.50j, -1.50+13.40j, 12.90+13.80j],
  591. [-3.29-2.39j, -1.91+4.42j, -0.14-1.35j, 1.72+1.35j],
  592. [2.41+0.39j, -0.56+1.47j, -0.83-0.69j, -1.96+0.67j]])
  593. b = np.array([[26.26+51.78j, 31.32-6.70j],
  594. [64.30-86.80j, 158.60-14.20j],
  595. [-5.75+25.31j, -2.15+30.19j],
  596. [1.16+2.57j, -2.56+7.55j]])
  597. x = solve(a, b)
  598. assert_array_almost_equal(x, np. array([[1+1.j, -1-2.j],
  599. [2-3.j, 5+1.j],
  600. [-4-5.j, -3+4.j],
  601. [6.j, 2-3.j]]))
  602. def test_hermitian(self):
  603. # An upper triangular matrix will be used for hermitian matrix a
  604. a = np.array([[-1.84, 0.11-0.11j, -1.78-1.18j, 3.91-1.50j],
  605. [0, -4.63, -1.84+0.03j, 2.21+0.21j],
  606. [0, 0, -8.87, 1.58-0.90j],
  607. [0, 0, 0, -1.36]])
  608. b = np.array([[2.98-10.18j, 28.68-39.89j],
  609. [-9.58+3.88j, -24.79-8.40j],
  610. [-0.77-16.05j, 4.23-70.02j],
  611. [7.79+5.48j, -35.39+18.01j]])
  612. res = np.array([[2.+1j, -8+6j],
  613. [3.-2j, 7-2j],
  614. [-1+2j, -1+5j],
  615. [1.-1j, 3-4j]])
  616. x = solve(a, b, assume_a='her')
  617. assert_array_almost_equal(x, res)
  618. # Also conjugate a and test for lower triangular data
  619. x = solve(a.conj().T, b, assume_a='her', lower=True)
  620. assert_array_almost_equal(x, res)
  621. def test_pos_and_sym(self):
  622. A = np.arange(1, 10).reshape(3, 3)
  623. x = solve(np.tril(A)/9, np.ones(3), assume_a='pos')
  624. assert_array_almost_equal(x, [9., 1.8, 1.])
  625. x = solve(np.tril(A)/9, np.ones(3), assume_a='sym')
  626. assert_array_almost_equal(x, [9., 1.8, 1.])
  627. def test_singularity(self):
  628. a = np.array([[1, 0, 0, 0, 0, 0, 1, 0, 1],
  629. [1, 1, 1, 0, 0, 0, 1, 0, 1],
  630. [0, 1, 1, 0, 0, 0, 1, 0, 1],
  631. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  632. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  633. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  634. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  635. [1, 1, 1, 1, 1, 1, 1, 1, 1],
  636. [1, 1, 1, 1, 1, 1, 1, 1, 1]])
  637. b = np.arange(9)[:, None]
  638. assert_raises(LinAlgError, solve, a, b)
  639. def test_ill_condition_warning(self):
  640. a = np.array([[1, 1], [1+1e-16, 1-1e-16]])
  641. b = np.ones(2)
  642. with warnings.catch_warnings():
  643. warnings.simplefilter('error')
  644. assert_raises(LinAlgWarning, solve, a, b)
  645. def test_empty_rhs(self):
  646. a = np.eye(2)
  647. b = [[], []]
  648. x = solve(a, b)
  649. assert_(x.size == 0, 'Returned array is not empty')
  650. assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  651. def test_multiple_rhs(self):
  652. a = np.eye(2)
  653. b = np.random.rand(2, 3, 4)
  654. x = solve(a, b)
  655. assert_array_almost_equal(x, b)
  656. def test_transposed_keyword(self):
  657. A = np.arange(9).reshape(3, 3) + 1
  658. x = solve(np.tril(A)/9, np.ones(3), transposed=True)
  659. assert_array_almost_equal(x, [1.2, 0.2, 1])
  660. x = solve(np.tril(A)/9, np.ones(3), transposed=False)
  661. assert_array_almost_equal(x, [9, -5.4, -1.2])
  662. def test_transposed_notimplemented(self):
  663. a = np.eye(3).astype(complex)
  664. with assert_raises(NotImplementedError):
  665. solve(a, a, transposed=True)
  666. def test_nonsquare_a(self):
  667. assert_raises(ValueError, solve, [1, 2], 1)
  668. def test_size_mismatch_with_1D_b(self):
  669. assert_array_almost_equal(solve(np.eye(3), np.ones(3)), np.ones(3))
  670. assert_raises(ValueError, solve, np.eye(3), np.ones(4))
  671. def test_assume_a_keyword(self):
  672. assert_raises(ValueError, solve, 1, 1, assume_a='zxcv')
  673. @pytest.mark.skip(reason="Failure on OS X (gh-7500), "
  674. "crash on Windows (gh-8064)")
  675. def test_all_type_size_routine_combinations(self):
  676. sizes = [10, 100]
  677. assume_as = ['gen', 'sym', 'pos', 'her']
  678. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  679. for size, assume_a, dtype in itertools.product(sizes, assume_as,
  680. dtypes):
  681. is_complex = dtype in (np.complex64, np.complex128)
  682. if assume_a == 'her' and not is_complex:
  683. continue
  684. err_msg = ("Failed for size: {}, assume_a: {},"
  685. "dtype: {}".format(size, assume_a, dtype))
  686. a = np.random.randn(size, size).astype(dtype)
  687. b = np.random.randn(size).astype(dtype)
  688. if is_complex:
  689. a = a + (1j*np.random.randn(size, size)).astype(dtype)
  690. if assume_a == 'sym': # Can still be complex but only symmetric
  691. a = a + a.T
  692. elif assume_a == 'her': # Handle hermitian matrices here instead
  693. a = a + a.T.conj()
  694. elif assume_a == 'pos':
  695. a = a.conj().T.dot(a) + 0.1*np.eye(size)
  696. tol = 1e-12 if dtype in (np.float64, np.complex128) else 1e-6
  697. if assume_a in ['gen', 'sym', 'her']:
  698. # We revert the tolerance from before
  699. # 4b4a6e7c34fa4060533db38f9a819b98fa81476c
  700. if dtype in (np.float32, np.complex64):
  701. tol *= 10
  702. x = solve(a, b, assume_a=assume_a)
  703. assert_allclose(a.dot(x), b,
  704. atol=tol * size,
  705. rtol=tol * size,
  706. err_msg=err_msg)
  707. if assume_a == 'sym' and dtype not in (np.complex64,
  708. np.complex128):
  709. x = solve(a, b, assume_a=assume_a, transposed=True)
  710. assert_allclose(a.dot(x), b,
  711. atol=tol * size,
  712. rtol=tol * size,
  713. err_msg=err_msg)
  714. class TestSolveTriangular(object):
  715. def test_simple(self):
  716. """
  717. solve_triangular on a simple 2x2 matrix.
  718. """
  719. A = array([[1, 0], [1, 2]])
  720. b = [1, 1]
  721. sol = solve_triangular(A, b, lower=True)
  722. assert_array_almost_equal(sol, [1, 0])
  723. # check that it works also for non-contiguous matrices
  724. sol = solve_triangular(A.T, b, lower=False)
  725. assert_array_almost_equal(sol, [.5, .5])
  726. # and that it gives the same result as trans=1
  727. sol = solve_triangular(A, b, lower=True, trans=1)
  728. assert_array_almost_equal(sol, [.5, .5])
  729. b = identity(2)
  730. sol = solve_triangular(A, b, lower=True, trans=1)
  731. assert_array_almost_equal(sol, [[1., -.5], [0, 0.5]])
  732. def test_simple_complex(self):
  733. """
  734. solve_triangular on a simple 2x2 complex matrix
  735. """
  736. A = array([[1+1j, 0], [1j, 2]])
  737. b = identity(2)
  738. sol = solve_triangular(A, b, lower=True, trans=1)
  739. assert_array_almost_equal(sol, [[.5-.5j, -.25-.25j], [0, 0.5]])
  740. def test_check_finite(self):
  741. """
  742. solve_triangular on a simple 2x2 matrix.
  743. """
  744. A = array([[1, 0], [1, 2]])
  745. b = [1, 1]
  746. sol = solve_triangular(A, b, lower=True, check_finite=False)
  747. assert_array_almost_equal(sol, [1, 0])
  748. class TestInv(object):
  749. def setup_method(self):
  750. np.random.seed(1234)
  751. def test_simple(self):
  752. a = [[1, 2], [3, 4]]
  753. a_inv = inv(a)
  754. assert_array_almost_equal(dot(a, a_inv), np.eye(2))
  755. a = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
  756. a_inv = inv(a)
  757. assert_array_almost_equal(dot(a, a_inv), np.eye(3))
  758. def test_random(self):
  759. n = 20
  760. for i in range(4):
  761. a = random([n, n])
  762. for i in range(n):
  763. a[i, i] = 20*(.1+a[i, i])
  764. a_inv = inv(a)
  765. assert_array_almost_equal(dot(a, a_inv),
  766. identity(n))
  767. def test_simple_complex(self):
  768. a = [[1, 2], [3, 4j]]
  769. a_inv = inv(a)
  770. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  771. def test_random_complex(self):
  772. n = 20
  773. for i in range(4):
  774. a = random([n, n])+2j*random([n, n])
  775. for i in range(n):
  776. a[i, i] = 20*(.1+a[i, i])
  777. a_inv = inv(a)
  778. assert_array_almost_equal(dot(a, a_inv),
  779. identity(n))
  780. def test_check_finite(self):
  781. a = [[1, 2], [3, 4]]
  782. a_inv = inv(a, check_finite=False)
  783. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  784. class TestDet(object):
  785. def setup_method(self):
  786. np.random.seed(1234)
  787. def test_simple(self):
  788. a = [[1, 2], [3, 4]]
  789. a_det = det(a)
  790. assert_almost_equal(a_det, -2.0)
  791. def test_simple_complex(self):
  792. a = [[1, 2], [3, 4j]]
  793. a_det = det(a)
  794. assert_almost_equal(a_det, -6+4j)
  795. def test_random(self):
  796. basic_det = linalg.det
  797. n = 20
  798. for i in range(4):
  799. a = random([n, n])
  800. d1 = det(a)
  801. d2 = basic_det(a)
  802. assert_almost_equal(d1, d2)
  803. def test_random_complex(self):
  804. basic_det = linalg.det
  805. n = 20
  806. for i in range(4):
  807. a = random([n, n]) + 2j*random([n, n])
  808. d1 = det(a)
  809. d2 = basic_det(a)
  810. assert_allclose(d1, d2, rtol=1e-13)
  811. def test_check_finite(self):
  812. a = [[1, 2], [3, 4]]
  813. a_det = det(a, check_finite=False)
  814. assert_almost_equal(a_det, -2.0)
  815. def direct_lstsq(a, b, cmplx=0):
  816. at = transpose(a)
  817. if cmplx:
  818. at = conjugate(at)
  819. a1 = dot(at, a)
  820. b1 = dot(at, b)
  821. return solve(a1, b1)
  822. class TestLstsq(object):
  823. lapack_drivers = ('gelsd', 'gelss', 'gelsy', None)
  824. def setup_method(self):
  825. np.random.seed(1234)
  826. def test_simple_exact(self):
  827. for dtype in REAL_DTYPES:
  828. a = np.array([[1, 20], [-30, 4]], dtype=dtype)
  829. for lapack_driver in TestLstsq.lapack_drivers:
  830. for overwrite in (True, False):
  831. for bt in (((1, 0), (0, 1)), (1, 0),
  832. ((2, 1), (-30, 4))):
  833. # Store values in case they are overwritten
  834. # later
  835. a1 = a.copy()
  836. b = np.array(bt, dtype=dtype)
  837. b1 = b.copy()
  838. try:
  839. out = lstsq(a1, b1,
  840. lapack_driver=lapack_driver,
  841. overwrite_a=overwrite,
  842. overwrite_b=overwrite)
  843. except LstsqLapackError:
  844. if lapack_driver is None:
  845. mesg = ('LstsqLapackError raised with '
  846. 'lapack_driver being None.')
  847. raise AssertionError(mesg)
  848. else:
  849. # can't proceed, skip to the next iteration
  850. continue
  851. x = out[0]
  852. r = out[2]
  853. assert_(r == 2,
  854. 'expected efficient rank 2, got %s' % r)
  855. assert_allclose(
  856. dot(a, x), b,
  857. atol=25 * _eps_cast(a1.dtype),
  858. rtol=25 * _eps_cast(a1.dtype),
  859. err_msg="driver: %s" % lapack_driver)
  860. def test_simple_overdet(self):
  861. for dtype in REAL_DTYPES:
  862. a = np.array([[1, 2], [4, 5], [3, 4]], dtype=dtype)
  863. b = np.array([1, 2, 3], dtype=dtype)
  864. for lapack_driver in TestLstsq.lapack_drivers:
  865. for overwrite in (True, False):
  866. # Store values in case they are overwritten later
  867. a1 = a.copy()
  868. b1 = b.copy()
  869. try:
  870. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  871. overwrite_a=overwrite,
  872. overwrite_b=overwrite)
  873. except LstsqLapackError:
  874. if lapack_driver is None:
  875. mesg = ('LstsqLapackError raised with '
  876. 'lapack_driver being None.')
  877. raise AssertionError(mesg)
  878. else:
  879. # can't proceed, skip to the next iteration
  880. continue
  881. x = out[0]
  882. if lapack_driver == 'gelsy':
  883. residuals = np.sum((b - a.dot(x))**2)
  884. else:
  885. residuals = out[1]
  886. r = out[2]
  887. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  888. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  889. residuals,
  890. rtol=25 * _eps_cast(a1.dtype),
  891. atol=25 * _eps_cast(a1.dtype),
  892. err_msg="driver: %s" % lapack_driver)
  893. assert_allclose(x, (-0.428571428571429, 0.85714285714285),
  894. rtol=25 * _eps_cast(a1.dtype),
  895. atol=25 * _eps_cast(a1.dtype),
  896. err_msg="driver: %s" % lapack_driver)
  897. def test_simple_overdet_complex(self):
  898. for dtype in COMPLEX_DTYPES:
  899. a = np.array([[1+2j, 2], [4, 5], [3, 4]], dtype=dtype)
  900. b = np.array([1, 2+4j, 3], dtype=dtype)
  901. for lapack_driver in TestLstsq.lapack_drivers:
  902. for overwrite in (True, False):
  903. # Store values in case they are overwritten later
  904. a1 = a.copy()
  905. b1 = b.copy()
  906. try:
  907. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  908. overwrite_a=overwrite,
  909. overwrite_b=overwrite)
  910. except LstsqLapackError:
  911. if lapack_driver is None:
  912. mesg = ('LstsqLapackError raised with '
  913. 'lapack_driver being None.')
  914. raise AssertionError(mesg)
  915. else:
  916. # can't proceed, skip to the next iteration
  917. continue
  918. x = out[0]
  919. if lapack_driver == 'gelsy':
  920. res = b - a.dot(x)
  921. residuals = np.sum(res * res.conj())
  922. else:
  923. residuals = out[1]
  924. r = out[2]
  925. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  926. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  927. residuals,
  928. rtol=25 * _eps_cast(a1.dtype),
  929. atol=25 * _eps_cast(a1.dtype),
  930. err_msg="driver: %s" % lapack_driver)
  931. assert_allclose(
  932. x, (-0.4831460674157303 + 0.258426966292135j,
  933. 0.921348314606741 + 0.292134831460674j),
  934. rtol=25 * _eps_cast(a1.dtype),
  935. atol=25 * _eps_cast(a1.dtype),
  936. err_msg="driver: %s" % lapack_driver)
  937. def test_simple_underdet(self):
  938. for dtype in REAL_DTYPES:
  939. a = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
  940. b = np.array([1, 2], dtype=dtype)
  941. for lapack_driver in TestLstsq.lapack_drivers:
  942. for overwrite in (True, False):
  943. # Store values in case they are overwritten later
  944. a1 = a.copy()
  945. b1 = b.copy()
  946. try:
  947. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  948. overwrite_a=overwrite,
  949. overwrite_b=overwrite)
  950. except LstsqLapackError:
  951. if lapack_driver is None:
  952. mesg = ('LstsqLapackError raised with '
  953. 'lapack_driver being None.')
  954. raise AssertionError(mesg)
  955. else:
  956. # can't proceed, skip to the next iteration
  957. continue
  958. x = out[0]
  959. r = out[2]
  960. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  961. assert_allclose(x, (-0.055555555555555, 0.111111111111111,
  962. 0.277777777777777),
  963. rtol=25 * _eps_cast(a1.dtype),
  964. atol=25 * _eps_cast(a1.dtype),
  965. err_msg="driver: %s" % lapack_driver)
  966. def test_random_exact(self):
  967. for dtype in REAL_DTYPES:
  968. for n in (20, 200):
  969. for lapack_driver in TestLstsq.lapack_drivers:
  970. for overwrite in (True, False):
  971. a = np.asarray(random([n, n]), dtype=dtype)
  972. for i in range(n):
  973. a[i, i] = 20 * (0.1 + a[i, i])
  974. for i in range(4):
  975. b = np.asarray(random([n, 3]), dtype=dtype)
  976. # Store values in case they are overwritten later
  977. a1 = a.copy()
  978. b1 = b.copy()
  979. try:
  980. out = lstsq(a1, b1,
  981. lapack_driver=lapack_driver,
  982. overwrite_a=overwrite,
  983. overwrite_b=overwrite)
  984. except LstsqLapackError:
  985. if lapack_driver is None:
  986. mesg = ('LstsqLapackError raised with '
  987. 'lapack_driver being None.')
  988. raise AssertionError(mesg)
  989. else:
  990. # can't proceed, skip to the next iteration
  991. continue
  992. x = out[0]
  993. r = out[2]
  994. assert_(r == n, 'expected efficient rank %s, '
  995. 'got %s' % (n, r))
  996. if dtype is np.float32:
  997. assert_allclose(
  998. dot(a, x), b,
  999. rtol=500 * _eps_cast(a1.dtype),
  1000. atol=500 * _eps_cast(a1.dtype),
  1001. err_msg="driver: %s" % lapack_driver)
  1002. else:
  1003. assert_allclose(
  1004. dot(a, x), b,
  1005. rtol=1000 * _eps_cast(a1.dtype),
  1006. atol=1000 * _eps_cast(a1.dtype),
  1007. err_msg="driver: %s" % lapack_driver)
  1008. def test_random_complex_exact(self):
  1009. for dtype in COMPLEX_DTYPES:
  1010. for n in (20, 200):
  1011. for lapack_driver in TestLstsq.lapack_drivers:
  1012. for overwrite in (True, False):
  1013. a = np.asarray(random([n, n]) + 1j*random([n, n]),
  1014. dtype=dtype)
  1015. for i in range(n):
  1016. a[i, i] = 20 * (0.1 + a[i, i])
  1017. for i in range(2):
  1018. b = np.asarray(random([n, 3]), dtype=dtype)
  1019. # Store values in case they are overwritten later
  1020. a1 = a.copy()
  1021. b1 = b.copy()
  1022. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1023. overwrite_a=overwrite,
  1024. overwrite_b=overwrite)
  1025. x = out[0]
  1026. r = out[2]
  1027. assert_(r == n, 'expected efficient rank %s, '
  1028. 'got %s' % (n, r))
  1029. if dtype is np.complex64:
  1030. assert_allclose(
  1031. dot(a, x), b,
  1032. rtol=400 * _eps_cast(a1.dtype),
  1033. atol=400 * _eps_cast(a1.dtype),
  1034. err_msg="driver: %s" % lapack_driver)
  1035. else:
  1036. assert_allclose(
  1037. dot(a, x), b,
  1038. rtol=1000 * _eps_cast(a1.dtype),
  1039. atol=1000 * _eps_cast(a1.dtype),
  1040. err_msg="driver: %s" % lapack_driver)
  1041. def test_random_overdet(self):
  1042. for dtype in REAL_DTYPES:
  1043. for (n, m) in ((20, 15), (200, 2)):
  1044. for lapack_driver in TestLstsq.lapack_drivers:
  1045. for overwrite in (True, False):
  1046. a = np.asarray(random([n, m]), dtype=dtype)
  1047. for i in range(m):
  1048. a[i, i] = 20 * (0.1 + a[i, i])
  1049. for i in range(4):
  1050. b = np.asarray(random([n, 3]), dtype=dtype)
  1051. # Store values in case they are overwritten later
  1052. a1 = a.copy()
  1053. b1 = b.copy()
  1054. try:
  1055. out = lstsq(a1, b1,
  1056. lapack_driver=lapack_driver,
  1057. overwrite_a=overwrite,
  1058. overwrite_b=overwrite)
  1059. except LstsqLapackError:
  1060. if lapack_driver is None:
  1061. mesg = ('LstsqLapackError raised with '
  1062. 'lapack_driver being None.')
  1063. raise AssertionError(mesg)
  1064. else:
  1065. # can't proceed, skip to the next iteration
  1066. continue
  1067. x = out[0]
  1068. r = out[2]
  1069. assert_(r == m, 'expected efficient rank %s, '
  1070. 'got %s' % (m, r))
  1071. assert_allclose(
  1072. x, direct_lstsq(a, b, cmplx=0),
  1073. rtol=25 * _eps_cast(a1.dtype),
  1074. atol=25 * _eps_cast(a1.dtype),
  1075. err_msg="driver: %s" % lapack_driver)
  1076. def test_random_complex_overdet(self):
  1077. for dtype in COMPLEX_DTYPES:
  1078. for (n, m) in ((20, 15), (200, 2)):
  1079. for lapack_driver in TestLstsq.lapack_drivers:
  1080. for overwrite in (True, False):
  1081. a = np.asarray(random([n, m]) + 1j*random([n, m]),
  1082. dtype=dtype)
  1083. for i in range(m):
  1084. a[i, i] = 20 * (0.1 + a[i, i])
  1085. for i in range(2):
  1086. b = np.asarray(random([n, 3]), dtype=dtype)
  1087. # Store values in case they are overwritten
  1088. # later
  1089. a1 = a.copy()
  1090. b1 = b.copy()
  1091. out = lstsq(a1, b1,
  1092. lapack_driver=lapack_driver,
  1093. overwrite_a=overwrite,
  1094. overwrite_b=overwrite)
  1095. x = out[0]
  1096. r = out[2]
  1097. assert_(r == m, 'expected efficient rank %s, '
  1098. 'got %s' % (m, r))
  1099. assert_allclose(
  1100. x, direct_lstsq(a, b, cmplx=1),
  1101. rtol=25 * _eps_cast(a1.dtype),
  1102. atol=25 * _eps_cast(a1.dtype),
  1103. err_msg="driver: %s" % lapack_driver)
  1104. def test_check_finite(self):
  1105. with suppress_warnings() as sup:
  1106. # On (some) OSX this tests triggers a warning (gh-7538)
  1107. sup.filter(RuntimeWarning,
  1108. "internal gelsd driver lwork query error,.*"
  1109. "Falling back to 'gelss' driver.")
  1110. at = np.array(((1, 20), (-30, 4)))
  1111. for dtype, bt, lapack_driver, overwrite, check_finite in \
  1112. itertools.product(REAL_DTYPES,
  1113. (((1, 0), (0, 1)), (1, 0), ((2, 1), (-30, 4))),
  1114. TestLstsq.lapack_drivers,
  1115. (True, False),
  1116. (True, False)):
  1117. a = at.astype(dtype)
  1118. b = np.array(bt, dtype=dtype)
  1119. # Store values in case they are overwritten
  1120. # later
  1121. a1 = a.copy()
  1122. b1 = b.copy()
  1123. try:
  1124. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1125. check_finite=check_finite, overwrite_a=overwrite,
  1126. overwrite_b=overwrite)
  1127. except LstsqLapackError:
  1128. if lapack_driver is None:
  1129. raise AssertionError('LstsqLapackError raised with '
  1130. '"lapack_driver" being "None".')
  1131. else:
  1132. # can't proceed,
  1133. # skip to the next iteration
  1134. continue
  1135. x = out[0]
  1136. r = out[2]
  1137. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  1138. assert_allclose(dot(a, x), b,
  1139. rtol=25 * _eps_cast(a.dtype),
  1140. atol=25 * _eps_cast(a.dtype),
  1141. err_msg="driver: %s" % lapack_driver)
  1142. def test_zero_size(self):
  1143. for a_shape, b_shape in (((0, 2), (0,)),
  1144. ((0, 4), (0, 2)),
  1145. ((4, 0), (4,)),
  1146. ((4, 0), (4, 2))):
  1147. b = np.ones(b_shape)
  1148. x, residues, rank, s = lstsq(np.zeros(a_shape), b)
  1149. assert_equal(x, np.zeros((a_shape[1],) + b_shape[1:]))
  1150. residues_should_be = (np.empty((0,)) if a_shape[1]
  1151. else np.linalg.norm(b, axis=0)**2)
  1152. assert_equal(residues, residues_should_be)
  1153. assert_(rank == 0, 'expected rank 0')
  1154. assert_equal(s, np.empty((0,)))
  1155. class TestPinv(object):
  1156. def test_simple_real(self):
  1157. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1158. a_pinv = pinv(a)
  1159. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1160. a_pinv = pinv2(a)
  1161. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1162. def test_simple_complex(self):
  1163. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1164. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1165. dtype=float))
  1166. a_pinv = pinv(a)
  1167. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1168. a_pinv = pinv2(a)
  1169. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1170. def test_simple_singular(self):
  1171. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1172. a_pinv = pinv(a)
  1173. a_pinv2 = pinv2(a)
  1174. assert_array_almost_equal(a_pinv, a_pinv2)
  1175. def test_simple_cols(self):
  1176. a = array([[1, 2, 3], [4, 5, 6]], dtype=float)
  1177. a_pinv = pinv(a)
  1178. a_pinv2 = pinv2(a)
  1179. assert_array_almost_equal(a_pinv, a_pinv2)
  1180. def test_simple_rows(self):
  1181. a = array([[1, 2], [3, 4], [5, 6]], dtype=float)
  1182. a_pinv = pinv(a)
  1183. a_pinv2 = pinv2(a)
  1184. assert_array_almost_equal(a_pinv, a_pinv2)
  1185. def test_check_finite(self):
  1186. a = array([[1, 2, 3], [4, 5, 6.], [7, 8, 10]])
  1187. a_pinv = pinv(a, check_finite=False)
  1188. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1189. a_pinv = pinv2(a, check_finite=False)
  1190. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1191. def test_native_list_argument(self):
  1192. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  1193. a_pinv = pinv(a)
  1194. a_pinv2 = pinv2(a)
  1195. assert_array_almost_equal(a_pinv, a_pinv2)
  1196. class TestPinvSymmetric(object):
  1197. def test_simple_real(self):
  1198. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1199. a = np.dot(a, a.T)
  1200. a_pinv = pinvh(a)
  1201. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1202. def test_nonpositive(self):
  1203. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1204. a = np.dot(a, a.T)
  1205. u, s, vt = np.linalg.svd(a)
  1206. s[0] *= -1
  1207. a = np.dot(u * s, vt) # a is now symmetric non-positive and singular
  1208. a_pinv = pinv2(a)
  1209. a_pinvh = pinvh(a)
  1210. assert_array_almost_equal(a_pinv, a_pinvh)
  1211. def test_simple_complex(self):
  1212. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1213. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1214. dtype=float))
  1215. a = np.dot(a, a.conj().T)
  1216. a_pinv = pinvh(a)
  1217. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1218. def test_native_list_argument(self):
  1219. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1220. a = np.dot(a, a.T)
  1221. a_pinv = pinvh(a.tolist())
  1222. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1223. class TestVectorNorms(object):
  1224. def test_types(self):
  1225. for dtype in np.typecodes['AllFloat']:
  1226. x = np.array([1, 2, 3], dtype=dtype)
  1227. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1228. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1229. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1230. for dtype in np.typecodes['Complex']:
  1231. x = np.array([1j, 2j, 3j], dtype=dtype)
  1232. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1233. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1234. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1235. def test_overflow(self):
  1236. # unlike numpy's norm, this one is
  1237. # safer on overflow
  1238. a = array([1e20], dtype=float32)
  1239. assert_almost_equal(norm(a), a)
  1240. def test_stable(self):
  1241. # more stable than numpy's norm
  1242. a = array([1e4] + [1]*10000, dtype=float32)
  1243. try:
  1244. # snrm in double precision; we obtain the same as for float64
  1245. # -- large atol needed due to varying blas implementations
  1246. assert_allclose(norm(a) - 1e4, 0.5, atol=1e-2)
  1247. except AssertionError:
  1248. # snrm implemented in single precision, == np.linalg.norm result
  1249. msg = ": Result should equal either 0.0 or 0.5 (depending on " \
  1250. "implementation of snrm2)."
  1251. assert_almost_equal(norm(a) - 1e4, 0.0, err_msg=msg)
  1252. def test_zero_norm(self):
  1253. assert_equal(norm([1, 0, 3], 0), 2)
  1254. assert_equal(norm([1, 2, 3], 0), 3)
  1255. def test_axis_kwd(self):
  1256. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1257. assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2)
  1258. assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2)
  1259. @pytest.mark.skipif(NumpyVersion(np.__version__) < '1.10.0', reason="")
  1260. def test_keepdims_kwd(self):
  1261. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1262. b = norm(a, axis=1, keepdims=True)
  1263. assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2)
  1264. assert_(b.shape == (2, 1, 2))
  1265. assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.], [7.]]] * 2)
  1266. class TestMatrixNorms(object):
  1267. def test_matrix_norms(self):
  1268. # Not all of these are matrix norms in the most technical sense.
  1269. np.random.seed(1234)
  1270. for n, m in (1, 1), (1, 3), (3, 1), (4, 4), (4, 5), (5, 4):
  1271. for t in np.single, np.double, np.csingle, np.cdouble, np.int64:
  1272. A = 10 * np.random.randn(n, m).astype(t)
  1273. if np.issubdtype(A.dtype, np.complexfloating):
  1274. A = (A + 10j * np.random.randn(n, m)).astype(t)
  1275. t_high = np.cdouble
  1276. else:
  1277. t_high = np.double
  1278. for order in (None, 'fro', 1, -1, 2, -2, np.inf, -np.inf):
  1279. actual = norm(A, ord=order)
  1280. desired = np.linalg.norm(A, ord=order)
  1281. # SciPy may return higher precision matrix norms.
  1282. # This is a consequence of using LAPACK.
  1283. if not np.allclose(actual, desired):
  1284. desired = np.linalg.norm(A.astype(t_high), ord=order)
  1285. assert_allclose(actual, desired)
  1286. def test_axis_kwd(self):
  1287. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1288. b = norm(a, ord=np.inf, axis=(1, 0))
  1289. c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1))
  1290. d = norm(a, ord=1, axis=(0, 1))
  1291. assert_allclose(b, c)
  1292. assert_allclose(c, d)
  1293. assert_allclose(b, d)
  1294. assert_(b.shape == c.shape == d.shape)
  1295. b = norm(a, ord=1, axis=(1, 0))
  1296. c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1))
  1297. d = norm(a, ord=np.inf, axis=(0, 1))
  1298. assert_allclose(b, c)
  1299. assert_allclose(c, d)
  1300. assert_allclose(b, d)
  1301. assert_(b.shape == c.shape == d.shape)
  1302. @pytest.mark.skipif(NumpyVersion(np.__version__) < '1.10.0', reason="")
  1303. def test_keepdims_kwd(self):
  1304. a = np.arange(120, dtype='d').reshape(2, 3, 4, 5)
  1305. b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True)
  1306. c = norm(a, ord=1, axis=(0, 1), keepdims=True)
  1307. assert_allclose(b, c)
  1308. assert_(b.shape == c.shape)
  1309. class TestOverwrite(object):
  1310. def test_solve(self):
  1311. assert_no_overwrite(solve, [(3, 3), (3,)])
  1312. def test_solve_triangular(self):
  1313. assert_no_overwrite(solve_triangular, [(3, 3), (3,)])
  1314. def test_solve_banded(self):
  1315. assert_no_overwrite(lambda ab, b: solve_banded((2, 1), ab, b),
  1316. [(4, 6), (6,)])
  1317. def test_solveh_banded(self):
  1318. assert_no_overwrite(solveh_banded, [(2, 6), (6,)])
  1319. def test_inv(self):
  1320. assert_no_overwrite(inv, [(3, 3)])
  1321. def test_det(self):
  1322. assert_no_overwrite(det, [(3, 3)])
  1323. def test_lstsq(self):
  1324. assert_no_overwrite(lstsq, [(3, 2), (3,)])
  1325. def test_pinv(self):
  1326. assert_no_overwrite(pinv, [(3, 3)])
  1327. def test_pinv2(self):
  1328. assert_no_overwrite(pinv2, [(3, 3)])
  1329. def test_pinvh(self):
  1330. assert_no_overwrite(pinvh, [(3, 3)])
  1331. class TestSolveCirculant(object):
  1332. def test_basic1(self):
  1333. c = np.array([1, 2, 3, 5])
  1334. b = np.array([1, -1, 1, 0])
  1335. x = solve_circulant(c, b)
  1336. y = solve(circulant(c), b)
  1337. assert_allclose(x, y)
  1338. def test_basic2(self):
  1339. # b is a 2-d matrix.
  1340. c = np.array([1, 2, -3, -5])
  1341. b = np.arange(12).reshape(4, 3)
  1342. x = solve_circulant(c, b)
  1343. y = solve(circulant(c), b)
  1344. assert_allclose(x, y)
  1345. def test_basic3(self):
  1346. # b is a 3-d matrix.
  1347. c = np.array([1, 2, -3, -5])
  1348. b = np.arange(24).reshape(4, 3, 2)
  1349. x = solve_circulant(c, b)
  1350. y = solve(circulant(c), b)
  1351. assert_allclose(x, y)
  1352. def test_complex(self):
  1353. # Complex b and c
  1354. c = np.array([1+2j, -3, 4j, 5])
  1355. b = np.arange(8).reshape(4, 2) + 0.5j
  1356. x = solve_circulant(c, b)
  1357. y = solve(circulant(c), b)
  1358. assert_allclose(x, y)
  1359. def test_random_b_and_c(self):
  1360. # Random b and c
  1361. np.random.seed(54321)
  1362. c = np.random.randn(50)
  1363. b = np.random.randn(50)
  1364. x = solve_circulant(c, b)
  1365. y = solve(circulant(c), b)
  1366. assert_allclose(x, y)
  1367. def test_singular(self):
  1368. # c gives a singular circulant matrix.
  1369. c = np.array([1, 1, 0, 0])
  1370. b = np.array([1, 2, 3, 4])
  1371. x = solve_circulant(c, b, singular='lstsq')
  1372. y, res, rnk, s = lstsq(circulant(c), b)
  1373. assert_allclose(x, y)
  1374. assert_raises(LinAlgError, solve_circulant, x, y)
  1375. def test_axis_args(self):
  1376. # Test use of caxis, baxis and outaxis.
  1377. # c has shape (2, 1, 4)
  1378. c = np.array([[[-1, 2.5, 3, 3.5]], [[1, 6, 6, 6.5]]])
  1379. # b has shape (3, 4)
  1380. b = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [1, -1, 0, 0]])
  1381. x = solve_circulant(c, b, baxis=1)
  1382. assert_equal(x.shape, (4, 2, 3))
  1383. expected = np.empty_like(x)
  1384. expected[:, 0, :] = solve(circulant(c[0]), b.T)
  1385. expected[:, 1, :] = solve(circulant(c[1]), b.T)
  1386. assert_allclose(x, expected)
  1387. x = solve_circulant(c, b, baxis=1, outaxis=-1)
  1388. assert_equal(x.shape, (2, 3, 4))
  1389. assert_allclose(np.rollaxis(x, -1), expected)
  1390. # np.swapaxes(c, 1, 2) has shape (2, 4, 1); b.T has shape (4, 3).
  1391. x = solve_circulant(np.swapaxes(c, 1, 2), b.T, caxis=1)
  1392. assert_equal(x.shape, (4, 2, 3))
  1393. assert_allclose(x, expected)
  1394. def test_native_list_arguments(self):
  1395. # Same as test_basic1 using python's native list.
  1396. c = [1, 2, 3, 5]
  1397. b = [1, -1, 1, 0]
  1398. x = solve_circulant(c, b)
  1399. y = solve(circulant(c), b)
  1400. assert_allclose(x, y)
  1401. class TestMatrix_Balance(object):
  1402. def test_string_arg(self):
  1403. assert_raises(ValueError, matrix_balance, 'Some string for fail')
  1404. def test_infnan_arg(self):
  1405. assert_raises(ValueError, matrix_balance,
  1406. np.array([[1, 2], [3, np.inf]]))
  1407. assert_raises(ValueError, matrix_balance,
  1408. np.array([[1, 2], [3, np.nan]]))
  1409. def test_scaling(self):
  1410. _, y = matrix_balance(np.array([[1000, 1], [1000, 0]]))
  1411. # Pre/post LAPACK 3.5.0 gives the same result up to an offset
  1412. # since in each case col norm is x1000 greater and
  1413. # 1000 / 32 ~= 1 * 32 hence balanced with 2 ** 5.
  1414. assert_allclose(int(np.diff(np.log2(np.diag(y)))), 5)
  1415. def test_scaling_order(self):
  1416. A = np.array([[1, 0, 1e-4], [1, 1, 1e-2], [1e4, 1e2, 1]])
  1417. x, y = matrix_balance(A)
  1418. assert_allclose(solve(y, A).dot(y), x)
  1419. def test_separate(self):
  1420. _, (y, z) = matrix_balance(np.array([[1000, 1], [1000, 0]]),
  1421. separate=1)
  1422. assert_equal(int(np.diff(np.log2(y))), 5)
  1423. assert_allclose(z, np.arange(2))
  1424. def test_permutation(self):
  1425. A = block_diag(np.ones((2, 2)), np.tril(np.ones((2, 2))),
  1426. np.ones((3, 3)))
  1427. x, (y, z) = matrix_balance(A, separate=1)
  1428. assert_allclose(y, np.ones_like(y))
  1429. assert_allclose(z, np.array([0, 1, 6, 5, 4, 3, 2]))
  1430. def test_perm_and_scaling(self):
  1431. # Matrix with its diagonal removed
  1432. cases = ( # Case 0
  1433. np.array([[0., 0., 0., 0., 0.000002],
  1434. [0., 0., 0., 0., 0.],
  1435. [2., 2., 0., 0., 0.],
  1436. [2., 2., 0., 0., 0.],
  1437. [0., 0., 0.000002, 0., 0.]]),
  1438. # Case 1 user reported GH-7258
  1439. np.array([[-0.5, 0., 0., 0.],
  1440. [0., -1., 0., 0.],
  1441. [1., 0., -0.5, 0.],
  1442. [0., 1., 0., -1.]]),
  1443. # Case 2 user reported GH-7258
  1444. np.array([[-3., 0., 1., 0.],
  1445. [-1., -1., -0., 1.],
  1446. [-3., -0., -0., 0.],
  1447. [-1., -0., 1., -1.]])
  1448. )
  1449. for A in cases:
  1450. x, y = matrix_balance(A)
  1451. x, (s, p) = matrix_balance(A, separate=1)
  1452. ip = np.empty_like(p)
  1453. ip[p] = np.arange(A.shape[0])
  1454. assert_allclose(y, np.diag(s)[ip, :])
  1455. assert_allclose(solve(y, A).dot(y), x)