test_lapack.py 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245
  1. #
  2. # Created by: Pearu Peterson, September 2002
  3. #
  4. from __future__ import division, print_function, absolute_import
  5. import sys
  6. import subprocess
  7. import time
  8. from functools import reduce
  9. from numpy.testing import (assert_equal, assert_array_almost_equal, assert_,
  10. assert_allclose, assert_almost_equal,
  11. assert_array_equal)
  12. import pytest
  13. from pytest import raises as assert_raises
  14. import numpy as np
  15. from numpy import (eye, ones, zeros, zeros_like, triu, tril, tril_indices,
  16. triu_indices)
  17. from numpy.random import rand, seed
  18. from scipy.linalg import _flapack as flapack
  19. from scipy.linalg import inv, svd, cholesky, solve
  20. from scipy.linalg.lapack import _compute_lwork
  21. try:
  22. from scipy.linalg import _clapack as clapack
  23. except ImportError:
  24. clapack = None
  25. from scipy.linalg.lapack import get_lapack_funcs
  26. from scipy.linalg.blas import get_blas_funcs
  27. REAL_DTYPES = [np.float32, np.float64]
  28. COMPLEX_DTYPES = [np.complex64, np.complex128]
  29. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  30. class TestFlapackSimple(object):
  31. def test_gebal(self):
  32. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  33. a1 = [[1, 0, 0, 3e-4],
  34. [4, 0, 0, 2e-3],
  35. [7, 1, 0, 0],
  36. [0, 1, 0, 0]]
  37. for p in 'sdzc':
  38. f = getattr(flapack, p+'gebal', None)
  39. if f is None:
  40. continue
  41. ba, lo, hi, pivscale, info = f(a)
  42. assert_(not info, repr(info))
  43. assert_array_almost_equal(ba, a)
  44. assert_equal((lo, hi), (0, len(a[0])-1))
  45. assert_array_almost_equal(pivscale, np.ones(len(a)))
  46. ba, lo, hi, pivscale, info = f(a1, permute=1, scale=1)
  47. assert_(not info, repr(info))
  48. # print(a1)
  49. # print(ba, lo, hi, pivscale)
  50. def test_gehrd(self):
  51. a = [[-149, -50, -154],
  52. [537, 180, 546],
  53. [-27, -9, -25]]
  54. for p in 'd':
  55. f = getattr(flapack, p+'gehrd', None)
  56. if f is None:
  57. continue
  58. ht, tau, info = f(a)
  59. assert_(not info, repr(info))
  60. def test_trsyl(self):
  61. a = np.array([[1, 2], [0, 4]])
  62. b = np.array([[5, 6], [0, 8]])
  63. c = np.array([[9, 10], [11, 12]])
  64. trans = 'T'
  65. # Test single and double implementations, including most
  66. # of the options
  67. for dtype in 'fdFD':
  68. a1, b1, c1 = a.astype(dtype), b.astype(dtype), c.astype(dtype)
  69. trsyl, = get_lapack_funcs(('trsyl',), (a1,))
  70. if dtype.isupper(): # is complex dtype
  71. a1[0] += 1j
  72. trans = 'C'
  73. x, scale, info = trsyl(a1, b1, c1)
  74. assert_array_almost_equal(np.dot(a1, x) + np.dot(x, b1),
  75. scale * c1)
  76. x, scale, info = trsyl(a1, b1, c1, trana=trans, tranb=trans)
  77. assert_array_almost_equal(
  78. np.dot(a1.conjugate().T, x) + np.dot(x, b1.conjugate().T),
  79. scale * c1, decimal=4)
  80. x, scale, info = trsyl(a1, b1, c1, isgn=-1)
  81. assert_array_almost_equal(np.dot(a1, x) - np.dot(x, b1),
  82. scale * c1, decimal=4)
  83. def test_lange(self):
  84. a = np.array([
  85. [-149, -50, -154],
  86. [537, 180, 546],
  87. [-27, -9, -25]])
  88. for dtype in 'fdFD':
  89. for norm in 'Mm1OoIiFfEe':
  90. a1 = a.astype(dtype)
  91. if dtype.isupper():
  92. # is complex dtype
  93. a1[0, 0] += 1j
  94. lange, = get_lapack_funcs(('lange',), (a1,))
  95. value = lange(norm, a1)
  96. if norm in 'FfEe':
  97. if dtype in 'Ff':
  98. decimal = 3
  99. else:
  100. decimal = 7
  101. ref = np.sqrt(np.sum(np.square(np.abs(a1))))
  102. assert_almost_equal(value, ref, decimal)
  103. else:
  104. if norm in 'Mm':
  105. ref = np.max(np.abs(a1))
  106. elif norm in '1Oo':
  107. ref = np.max(np.sum(np.abs(a1), axis=0))
  108. elif norm in 'Ii':
  109. ref = np.max(np.sum(np.abs(a1), axis=1))
  110. assert_equal(value, ref)
  111. class TestLapack(object):
  112. def test_flapack(self):
  113. if hasattr(flapack, 'empty_module'):
  114. # flapack module is empty
  115. pass
  116. def test_clapack(self):
  117. if hasattr(clapack, 'empty_module'):
  118. # clapack module is empty
  119. pass
  120. class TestLeastSquaresSolvers(object):
  121. def test_gels(self):
  122. seed(1234)
  123. # Test fat/tall matrix argument handling - gh-issue #8329
  124. for ind, dtype in enumerate(DTYPES):
  125. m = 10
  126. n = 20
  127. nrhs = 1
  128. a1 = rand(m, n).astype(dtype)
  129. b1 = rand(n).astype(dtype)
  130. gls, glslw = get_lapack_funcs(('gels', 'gels_lwork'), dtype=dtype)
  131. # Request of sizes
  132. lwork = _compute_lwork(glslw, m, n, nrhs)
  133. _, _, info = gls(a1, b1, lwork=lwork)
  134. assert_(info >= 0)
  135. _, _, info = gls(a1, b1, trans='TTCC'[ind], lwork=lwork)
  136. assert_(info >= 0)
  137. for dtype in REAL_DTYPES:
  138. a1 = np.array([[1.0, 2.0],
  139. [4.0, 5.0],
  140. [7.0, 8.0]], dtype=dtype)
  141. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  142. gels, gels_lwork, geqrf = get_lapack_funcs(
  143. ('gels', 'gels_lwork', 'geqrf'), (a1, b1))
  144. m, n = a1.shape
  145. if len(b1.shape) == 2:
  146. nrhs = b1.shape[1]
  147. else:
  148. nrhs = 1
  149. # Request of sizes
  150. lwork = _compute_lwork(gels_lwork, m, n, nrhs)
  151. lqr, x, info = gels(a1, b1, lwork=lwork)
  152. assert_allclose(x[:-1], np.array([-14.333333333333323,
  153. 14.999999999999991],
  154. dtype=dtype),
  155. rtol=25*np.finfo(dtype).eps)
  156. lqr_truth, _, _, _ = geqrf(a1)
  157. assert_array_equal(lqr, lqr_truth)
  158. for dtype in COMPLEX_DTYPES:
  159. a1 = np.array([[1.0+4.0j, 2.0],
  160. [4.0+0.5j, 5.0-3.0j],
  161. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  162. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  163. gels, gels_lwork, geqrf = get_lapack_funcs(
  164. ('gels', 'gels_lwork', 'geqrf'), (a1, b1))
  165. m, n = a1.shape
  166. if len(b1.shape) == 2:
  167. nrhs = b1.shape[1]
  168. else:
  169. nrhs = 1
  170. # Request of sizes
  171. lwork = _compute_lwork(gels_lwork, m, n, nrhs)
  172. lqr, x, info = gels(a1, b1, lwork=lwork)
  173. assert_allclose(x[:-1],
  174. np.array([1.161753632288328-1.901075709391912j,
  175. 1.735882340522193+1.521240901196909j],
  176. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  177. lqr_truth, _, _, _ = geqrf(a1)
  178. assert_array_equal(lqr, lqr_truth)
  179. def test_gelsd(self):
  180. for dtype in REAL_DTYPES:
  181. a1 = np.array([[1.0, 2.0],
  182. [4.0, 5.0],
  183. [7.0, 8.0]], dtype=dtype)
  184. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  185. gelsd, gelsd_lwork = get_lapack_funcs(('gelsd', 'gelsd_lwork'),
  186. (a1, b1))
  187. m, n = a1.shape
  188. if len(b1.shape) == 2:
  189. nrhs = b1.shape[1]
  190. else:
  191. nrhs = 1
  192. # Request of sizes
  193. work, iwork, info = gelsd_lwork(m, n, nrhs, -1)
  194. lwork = int(np.real(work))
  195. iwork_size = iwork
  196. x, s, rank, info = gelsd(a1, b1, lwork, iwork_size,
  197. -1, False, False)
  198. assert_allclose(x[:-1], np.array([-14.333333333333323,
  199. 14.999999999999991], dtype=dtype),
  200. rtol=25*np.finfo(dtype).eps)
  201. assert_allclose(s, np.array([12.596017180511966,
  202. 0.583396253199685], dtype=dtype),
  203. rtol=25*np.finfo(dtype).eps)
  204. for dtype in COMPLEX_DTYPES:
  205. a1 = np.array([[1.0+4.0j, 2.0],
  206. [4.0+0.5j, 5.0-3.0j],
  207. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  208. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  209. gelsd, gelsd_lwork = get_lapack_funcs(('gelsd', 'gelsd_lwork'),
  210. (a1, b1))
  211. m, n = a1.shape
  212. if len(b1.shape) == 2:
  213. nrhs = b1.shape[1]
  214. else:
  215. nrhs = 1
  216. # Request of sizes
  217. work, rwork, iwork, info = gelsd_lwork(m, n, nrhs, -1)
  218. lwork = int(np.real(work))
  219. rwork_size = int(rwork)
  220. iwork_size = iwork
  221. x, s, rank, info = gelsd(a1, b1, lwork, rwork_size, iwork_size,
  222. -1, False, False)
  223. assert_allclose(x[:-1],
  224. np.array([1.161753632288328-1.901075709391912j,
  225. 1.735882340522193+1.521240901196909j],
  226. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  227. assert_allclose(s,
  228. np.array([13.035514762572043, 4.337666985231382],
  229. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  230. def test_gelss(self):
  231. for dtype in REAL_DTYPES:
  232. a1 = np.array([[1.0, 2.0],
  233. [4.0, 5.0],
  234. [7.0, 8.0]], dtype=dtype)
  235. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  236. gelss, gelss_lwork = get_lapack_funcs(('gelss', 'gelss_lwork'),
  237. (a1, b1))
  238. m, n = a1.shape
  239. if len(b1.shape) == 2:
  240. nrhs = b1.shape[1]
  241. else:
  242. nrhs = 1
  243. # Request of sizes
  244. work, info = gelss_lwork(m, n, nrhs, -1)
  245. lwork = int(np.real(work))
  246. v, x, s, rank, work, info = gelss(a1, b1, -1, lwork, False, False)
  247. assert_allclose(x[:-1], np.array([-14.333333333333323,
  248. 14.999999999999991], dtype=dtype),
  249. rtol=25*np.finfo(dtype).eps)
  250. assert_allclose(s, np.array([12.596017180511966,
  251. 0.583396253199685], dtype=dtype),
  252. rtol=25*np.finfo(dtype).eps)
  253. for dtype in COMPLEX_DTYPES:
  254. a1 = np.array([[1.0+4.0j, 2.0],
  255. [4.0+0.5j, 5.0-3.0j],
  256. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  257. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  258. gelss, gelss_lwork = get_lapack_funcs(('gelss', 'gelss_lwork'),
  259. (a1, b1))
  260. m, n = a1.shape
  261. if len(b1.shape) == 2:
  262. nrhs = b1.shape[1]
  263. else:
  264. nrhs = 1
  265. # Request of sizes
  266. work, info = gelss_lwork(m, n, nrhs, -1)
  267. lwork = int(np.real(work))
  268. v, x, s, rank, work, info = gelss(a1, b1, -1, lwork, False, False)
  269. assert_allclose(x[:-1],
  270. np.array([1.161753632288328-1.901075709391912j,
  271. 1.735882340522193+1.521240901196909j],
  272. dtype=dtype),
  273. rtol=25*np.finfo(dtype).eps)
  274. assert_allclose(s, np.array([13.035514762572043,
  275. 4.337666985231382], dtype=dtype),
  276. rtol=25*np.finfo(dtype).eps)
  277. def test_gelsy(self):
  278. for dtype in REAL_DTYPES:
  279. a1 = np.array([[1.0, 2.0],
  280. [4.0, 5.0],
  281. [7.0, 8.0]], dtype=dtype)
  282. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  283. gelsy, gelsy_lwork = get_lapack_funcs(('gelsy', 'gelss_lwork'),
  284. (a1, b1))
  285. m, n = a1.shape
  286. if len(b1.shape) == 2:
  287. nrhs = b1.shape[1]
  288. else:
  289. nrhs = 1
  290. # Request of sizes
  291. work, info = gelsy_lwork(m, n, nrhs, 10*np.finfo(dtype).eps)
  292. lwork = int(np.real(work))
  293. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  294. v, x, j, rank, info = gelsy(a1, b1, jptv, np.finfo(dtype).eps,
  295. lwork, False, False)
  296. assert_allclose(x[:-1], np.array([-14.333333333333323,
  297. 14.999999999999991], dtype=dtype),
  298. rtol=25*np.finfo(dtype).eps)
  299. for dtype in COMPLEX_DTYPES:
  300. a1 = np.array([[1.0+4.0j, 2.0],
  301. [4.0+0.5j, 5.0-3.0j],
  302. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  303. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  304. gelsy, gelsy_lwork = get_lapack_funcs(('gelsy', 'gelss_lwork'),
  305. (a1, b1))
  306. m, n = a1.shape
  307. if len(b1.shape) == 2:
  308. nrhs = b1.shape[1]
  309. else:
  310. nrhs = 1
  311. # Request of sizes
  312. work, info = gelsy_lwork(m, n, nrhs, 10*np.finfo(dtype).eps)
  313. lwork = int(np.real(work))
  314. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  315. v, x, j, rank, info = gelsy(a1, b1, jptv, np.finfo(dtype).eps,
  316. lwork, False, False)
  317. assert_allclose(x[:-1],
  318. np.array([1.161753632288328-1.901075709391912j,
  319. 1.735882340522193+1.521240901196909j],
  320. dtype=dtype),
  321. rtol=25*np.finfo(dtype).eps)
  322. class TestRegression(object):
  323. def test_ticket_1645(self):
  324. # Check that RQ routines have correct lwork
  325. for dtype in DTYPES:
  326. a = np.zeros((300, 2), dtype=dtype)
  327. gerqf, = get_lapack_funcs(['gerqf'], [a])
  328. assert_raises(Exception, gerqf, a, lwork=2)
  329. rq, tau, work, info = gerqf(a)
  330. if dtype in REAL_DTYPES:
  331. orgrq, = get_lapack_funcs(['orgrq'], [a])
  332. assert_raises(Exception, orgrq, rq[-2:], tau, lwork=1)
  333. orgrq(rq[-2:], tau, lwork=2)
  334. elif dtype in COMPLEX_DTYPES:
  335. ungrq, = get_lapack_funcs(['ungrq'], [a])
  336. assert_raises(Exception, ungrq, rq[-2:], tau, lwork=1)
  337. ungrq(rq[-2:], tau, lwork=2)
  338. class TestDpotr(object):
  339. def test_gh_2691(self):
  340. # 'lower' argument of dportf/dpotri
  341. for lower in [True, False]:
  342. for clean in [True, False]:
  343. np.random.seed(42)
  344. x = np.random.normal(size=(3, 3))
  345. a = x.dot(x.T)
  346. dpotrf, dpotri = get_lapack_funcs(("potrf", "potri"), (a, ))
  347. c, info = dpotrf(a, lower, clean=clean)
  348. dpt = dpotri(c, lower)[0]
  349. if lower:
  350. assert_allclose(np.tril(dpt), np.tril(inv(a)))
  351. else:
  352. assert_allclose(np.triu(dpt), np.triu(inv(a)))
  353. class TestDlasd4(object):
  354. def test_sing_val_update(self):
  355. sigmas = np.array([4., 3., 2., 0])
  356. m_vec = np.array([3.12, 5.7, -4.8, -2.2])
  357. M = np.hstack((np.vstack((np.diag(sigmas[0:-1]),
  358. np.zeros((1, len(m_vec) - 1)))), m_vec[:, np.newaxis]))
  359. SM = svd(M, full_matrices=False, compute_uv=False, overwrite_a=False,
  360. check_finite=False)
  361. it_len = len(sigmas)
  362. sgm = np.concatenate((sigmas[::-1], (sigmas[0] +
  363. it_len*np.sqrt(np.sum(np.power(m_vec, 2))),)))
  364. mvc = np.concatenate((m_vec[::-1], (0,)))
  365. lasd4 = get_lapack_funcs('lasd4', (sigmas,))
  366. roots = []
  367. for i in range(0, it_len):
  368. res = lasd4(i, sgm, mvc)
  369. roots.append(res[1])
  370. assert_((res[3] <= 0), "LAPACK root finding dlasd4 failed to find \
  371. the singular value %i" % i)
  372. roots = np.array(roots)[::-1]
  373. assert_((not np.any(np.isnan(roots)), "There are NaN roots"))
  374. assert_allclose(SM, roots, atol=100*np.finfo(np.float64).eps,
  375. rtol=100*np.finfo(np.float64).eps)
  376. def test_lartg():
  377. for dtype in 'fdFD':
  378. lartg = get_lapack_funcs('lartg', dtype=dtype)
  379. f = np.array(3, dtype)
  380. g = np.array(4, dtype)
  381. if np.iscomplexobj(g):
  382. g *= 1j
  383. cs, sn, r = lartg(f, g)
  384. assert_allclose(cs, 3.0/5.0)
  385. assert_allclose(r, 5.0)
  386. if np.iscomplexobj(g):
  387. assert_allclose(sn, -4.0j/5.0)
  388. assert_(type(r) == complex)
  389. assert_(type(cs) == float)
  390. else:
  391. assert_allclose(sn, 4.0/5.0)
  392. def test_rot():
  393. # srot, drot from blas and crot and zrot from lapack.
  394. for dtype in 'fdFD':
  395. c = 0.6
  396. s = 0.8
  397. u = np.ones(4, dtype) * 3
  398. v = np.ones(4, dtype) * 4
  399. atol = 10**-(np.finfo(dtype).precision-1)
  400. if dtype in 'fd':
  401. rot = get_blas_funcs('rot', dtype=dtype)
  402. f = 4
  403. else:
  404. rot = get_lapack_funcs('rot', dtype=dtype)
  405. s *= -1j
  406. v *= 1j
  407. f = 4j
  408. assert_allclose(rot(u, v, c, s), [[5, 5, 5, 5],
  409. [0, 0, 0, 0]], atol=atol)
  410. assert_allclose(rot(u, v, c, s, n=2), [[5, 5, 3, 3],
  411. [0, 0, f, f]], atol=atol)
  412. assert_allclose(rot(u, v, c, s, offx=2, offy=2),
  413. [[3, 3, 5, 5], [f, f, 0, 0]], atol=atol)
  414. assert_allclose(rot(u, v, c, s, incx=2, offy=2, n=2),
  415. [[5, 3, 5, 3], [f, f, 0, 0]], atol=atol)
  416. assert_allclose(rot(u, v, c, s, offx=2, incy=2, n=2),
  417. [[3, 3, 5, 5], [0, f, 0, f]], atol=atol)
  418. assert_allclose(rot(u, v, c, s, offx=2, incx=2, offy=2, incy=2, n=1),
  419. [[3, 3, 5, 3], [f, f, 0, f]], atol=atol)
  420. assert_allclose(rot(u, v, c, s, incx=-2, incy=-2, n=2),
  421. [[5, 3, 5, 3], [0, f, 0, f]], atol=atol)
  422. a, b = rot(u, v, c, s, overwrite_x=1, overwrite_y=1)
  423. assert_(a is u)
  424. assert_(b is v)
  425. assert_allclose(a, [5, 5, 5, 5], atol=atol)
  426. assert_allclose(b, [0, 0, 0, 0], atol=atol)
  427. def test_larfg_larf():
  428. np.random.seed(1234)
  429. a0 = np.random.random((4, 4))
  430. a0 = a0.T.dot(a0)
  431. a0j = np.random.random((4, 4)) + 1j*np.random.random((4, 4))
  432. a0j = a0j.T.conj().dot(a0j)
  433. # our test here will be to do one step of reducing a hermetian matrix to
  434. # tridiagonal form using householder transforms.
  435. for dtype in 'fdFD':
  436. larfg, larf = get_lapack_funcs(['larfg', 'larf'], dtype=dtype)
  437. if dtype in 'FD':
  438. a = a0j.copy()
  439. else:
  440. a = a0.copy()
  441. # generate a householder transform to clear a[2:,0]
  442. alpha, x, tau = larfg(a.shape[0]-1, a[1, 0], a[2:, 0])
  443. # create expected output
  444. expected = np.zeros_like(a[:, 0])
  445. expected[0] = a[0, 0]
  446. expected[1] = alpha
  447. # assemble householder vector
  448. v = np.zeros_like(a[1:, 0])
  449. v[0] = 1.0
  450. v[1:] = x
  451. # apply transform from the left
  452. a[1:, :] = larf(v, tau.conjugate(), a[1:, :], np.zeros(a.shape[1]))
  453. # apply transform from the right
  454. a[:, 1:] = larf(v, tau, a[:, 1:], np.zeros(a.shape[0]), side='R')
  455. assert_allclose(a[:, 0], expected, atol=1e-5)
  456. assert_allclose(a[0, :], expected, atol=1e-5)
  457. @pytest.mark.xslow
  458. def test_sgesdd_lwork_bug_workaround():
  459. # Test that SGESDD lwork is sufficiently large for LAPACK.
  460. #
  461. # This checks that workaround around an apparent LAPACK bug
  462. # actually works. cf. gh-5401
  463. #
  464. # xslow: requires 1GB+ of memory
  465. p = subprocess.Popen([sys.executable, '-c',
  466. 'import numpy as np; '
  467. 'from scipy.linalg import svd; '
  468. 'a = np.zeros([9537, 9537], dtype=np.float32); '
  469. 'svd(a)'],
  470. stdout=subprocess.PIPE,
  471. stderr=subprocess.STDOUT)
  472. # Check if it an error occurred within 5 sec; the computation can
  473. # take substantially longer, and we will not wait for it to finish
  474. for j in range(50):
  475. time.sleep(0.1)
  476. if p.poll() is not None:
  477. returncode = p.returncode
  478. break
  479. else:
  480. # Didn't exit in time -- probably entered computation. The
  481. # error is raised before entering computation, so things are
  482. # probably OK.
  483. returncode = 0
  484. p.terminate()
  485. assert_equal(returncode, 0,
  486. "Code apparently failed: " + p.stdout.read())
  487. class TestSytrd(object):
  488. def test_sytrd(self):
  489. for dtype in REAL_DTYPES:
  490. # Assert that a 0x0 matrix raises an error
  491. A = np.zeros((0, 0), dtype=dtype)
  492. sytrd, sytrd_lwork = \
  493. get_lapack_funcs(('sytrd', 'sytrd_lwork'), (A,))
  494. assert_raises(ValueError, sytrd, A)
  495. # Tests for n = 1 currently fail with
  496. # ```
  497. # ValueError: failed to create intent(cache|hide)|optional array--
  498. # must have defined dimensions but got (0,)
  499. # ```
  500. # This is a NumPy issue
  501. # <https://github.com/numpy/numpy/issues/9617>.
  502. # TODO once the issue has been resolved, test for n=1
  503. # some upper triangular array
  504. n = 3
  505. A = np.zeros((n, n), dtype=dtype)
  506. A[np.triu_indices_from(A)] = \
  507. np.arange(1, n*(n+1)//2+1, dtype=dtype)
  508. # query lwork
  509. lwork, info = sytrd_lwork(n)
  510. assert_equal(info, 0)
  511. # check lower=1 behavior (shouldn't do much since the matrix is
  512. # upper triangular)
  513. data, d, e, tau, info = sytrd(A, lower=1, lwork=lwork)
  514. assert_equal(info, 0)
  515. assert_allclose(data, A, atol=5*np.finfo(dtype).eps, rtol=1.0)
  516. assert_allclose(d, np.diag(A))
  517. assert_allclose(e, 0.0)
  518. assert_allclose(tau, 0.0)
  519. # and now for the proper test (lower=0 is the default)
  520. data, d, e, tau, info = sytrd(A, lwork=lwork)
  521. assert_equal(info, 0)
  522. # assert Q^T*A*Q = tridiag(e, d, e)
  523. # build tridiagonal matrix
  524. T = np.zeros_like(A, dtype=dtype)
  525. k = np.arange(A.shape[0])
  526. T[k, k] = d
  527. k2 = np.arange(A.shape[0]-1)
  528. T[k2+1, k2] = e
  529. T[k2, k2+1] = e
  530. # build Q
  531. Q = np.eye(n, n, dtype=dtype)
  532. for i in range(n-1):
  533. v = np.zeros(n, dtype=dtype)
  534. v[:i] = data[:i, i+1]
  535. v[i] = 1.0
  536. H = np.eye(n, n, dtype=dtype) - tau[i] * np.outer(v, v)
  537. Q = np.dot(H, Q)
  538. # Make matrix fully symmetric
  539. i_lower = np.tril_indices(n, -1)
  540. A[i_lower] = A.T[i_lower]
  541. QTAQ = np.dot(Q.T, np.dot(A, Q))
  542. # disable rtol here since some values in QTAQ and T are very close
  543. # to 0.
  544. assert_allclose(QTAQ, T, atol=5*np.finfo(dtype).eps, rtol=1.0)
  545. class TestHetrd(object):
  546. def test_hetrd(self):
  547. for real_dtype, complex_dtype in zip(REAL_DTYPES, COMPLEX_DTYPES):
  548. # Assert that a 0x0 matrix raises an error
  549. A = np.zeros((0, 0), dtype=complex_dtype)
  550. hetrd, hetrd_lwork = \
  551. get_lapack_funcs(('hetrd', 'hetrd_lwork'), (A,))
  552. assert_raises(ValueError, hetrd, A)
  553. # Tests for n = 1 currently fail with
  554. # ```
  555. # ValueError: failed to create intent(cache|hide)|optional array--
  556. # must have defined dimensions but got (0,)
  557. # ```
  558. # This is a NumPy issue
  559. # <https://github.com/numpy/numpy/issues/9617>.
  560. # TODO once the issue has been resolved, test for n=1
  561. # some upper triangular array
  562. n = 3
  563. A = np.zeros((n, n), dtype=complex_dtype)
  564. A[np.triu_indices_from(A)] = (
  565. np.arange(1, n*(n+1)//2+1, dtype=real_dtype)
  566. + 1j * np.arange(1, n*(n+1)//2+1, dtype=real_dtype)
  567. )
  568. np.fill_diagonal(A, np.real(np.diag(A)))
  569. # query lwork
  570. lwork, info = hetrd_lwork(n)
  571. assert_equal(info, 0)
  572. # check lower=1 behavior (shouldn't do much since the matrix is
  573. # upper triangular)
  574. data, d, e, tau, info = hetrd(A, lower=1, lwork=lwork)
  575. assert_equal(info, 0)
  576. assert_allclose(data, A, atol=5*np.finfo(real_dtype).eps, rtol=1.0)
  577. assert_allclose(d, np.real(np.diag(A)))
  578. assert_allclose(e, 0.0)
  579. assert_allclose(tau, 0.0)
  580. # and now for the proper test (lower=0 is the default)
  581. data, d, e, tau, info = hetrd(A, lwork=lwork)
  582. assert_equal(info, 0)
  583. # assert Q^T*A*Q = tridiag(e, d, e)
  584. # build tridiagonal matrix
  585. T = np.zeros_like(A, dtype=real_dtype)
  586. k = np.arange(A.shape[0], dtype=int)
  587. T[k, k] = d
  588. k2 = np.arange(A.shape[0]-1, dtype=int)
  589. T[k2+1, k2] = e
  590. T[k2, k2+1] = e
  591. # build Q
  592. Q = np.eye(n, n, dtype=complex_dtype)
  593. for i in range(n-1):
  594. v = np.zeros(n, dtype=complex_dtype)
  595. v[:i] = data[:i, i+1]
  596. v[i] = 1.0
  597. H = np.eye(n, n, dtype=complex_dtype) \
  598. - tau[i] * np.outer(v, np.conj(v))
  599. Q = np.dot(H, Q)
  600. # Make matrix fully Hermetian
  601. i_lower = np.tril_indices(n, -1)
  602. A[i_lower] = np.conj(A.T[i_lower])
  603. QHAQ = np.dot(np.conj(Q.T), np.dot(A, Q))
  604. # disable rtol here since some values in QTAQ and T are very close
  605. # to 0.
  606. assert_allclose(
  607. QHAQ, T, atol=10*np.finfo(real_dtype).eps, rtol=1.0
  608. )
  609. def test_gglse():
  610. # Example data taken from NAG manual
  611. for ind, dtype in enumerate(DTYPES):
  612. # DTYPES = <s,d,c,z> gglse
  613. func, func_lwork = get_lapack_funcs(('gglse', 'gglse_lwork'),
  614. dtype=dtype)
  615. lwork = _compute_lwork(func_lwork, m=6, n=4, p=2)
  616. # For <s,d>gglse
  617. if ind < 2:
  618. a = np.array([[-0.57, -1.28, -0.39, 0.25],
  619. [-1.93, 1.08, -0.31, -2.14],
  620. [2.30, 0.24, 0.40, -0.35],
  621. [-1.93, 0.64, -0.66, 0.08],
  622. [0.15, 0.30, 0.15, -2.13],
  623. [-0.02, 1.03, -1.43, 0.50]], dtype=dtype)
  624. c = np.array([-1.50, -2.14, 1.23, -0.54, -1.68, 0.82], dtype=dtype)
  625. d = np.array([0., 0.], dtype=dtype)
  626. # For <s,d>gglse
  627. else:
  628. a = np.array([[0.96-0.81j, -0.03+0.96j, -0.91+2.06j, -0.05+0.41j],
  629. [-0.98+1.98j, -1.20+0.19j, -0.66+0.42j, -0.81+0.56j],
  630. [0.62-0.46j, 1.01+0.02j, 0.63-0.17j, -1.11+0.60j],
  631. [0.37+0.38j, 0.19-0.54j, -0.98-0.36j, 0.22-0.20j],
  632. [0.83+0.51j, 0.20+0.01j, -0.17-0.46j, 1.47+1.59j],
  633. [1.08-0.28j, 0.20-0.12j, -0.07+1.23j, 0.26+0.26j]])
  634. c = np.array([[-2.54+0.09j],
  635. [1.65-2.26j],
  636. [-2.11-3.96j],
  637. [1.82+3.30j],
  638. [-6.41+3.77j],
  639. [2.07+0.66j]])
  640. d = np.zeros(2, dtype=dtype)
  641. b = np.array([[1., 0., -1., 0.], [0., 1., 0., -1.]], dtype=dtype)
  642. _, _, _, result, _ = func(a, b, c, d, lwork=lwork)
  643. if ind < 2:
  644. expected = np.array([0.48904455,
  645. 0.99754786,
  646. 0.48904455,
  647. 0.99754786])
  648. else:
  649. expected = np.array([1.08742917-1.96205783j,
  650. -0.74093902+3.72973919j,
  651. 1.08742917-1.96205759j,
  652. -0.74093896+3.72973895j])
  653. assert_array_almost_equal(result, expected, decimal=4)
  654. def test_sycon_hecon():
  655. seed(1234)
  656. for ind, dtype in enumerate(DTYPES+COMPLEX_DTYPES):
  657. # DTYPES + COMPLEX DTYPES = <s,d,c,z> sycon + <c,z>hecon
  658. n = 10
  659. # For <s,d,c,z>sycon
  660. if ind < 4:
  661. func_lwork = get_lapack_funcs('sytrf_lwork', dtype=dtype)
  662. funcon, functrf = get_lapack_funcs(('sycon', 'sytrf'), dtype=dtype)
  663. A = (rand(n, n)).astype(dtype)
  664. # For <c,z>hecon
  665. else:
  666. func_lwork = get_lapack_funcs('hetrf_lwork', dtype=dtype)
  667. funcon, functrf = get_lapack_funcs(('hecon', 'hetrf'), dtype=dtype)
  668. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  669. # Since sycon only refers to upper/lower part, conj() is safe here.
  670. A = (A + A.conj().T)/2 + 2*np.eye(n, dtype=dtype)
  671. anorm = np.linalg.norm(A, 1)
  672. lwork = _compute_lwork(func_lwork, n)
  673. ldu, ipiv, _ = functrf(A, lwork=lwork, lower=1)
  674. rcond, _ = funcon(a=ldu, ipiv=ipiv, anorm=anorm, lower=1)
  675. # The error is at most 1-fold
  676. assert_(abs(1/rcond - np.linalg.cond(A, p=1))*rcond < 1)
  677. def test_sygst():
  678. seed(1234)
  679. for ind, dtype in enumerate(REAL_DTYPES):
  680. # DTYPES = <s,d> sygst
  681. n = 10
  682. potrf, sygst, syevd, sygvd = get_lapack_funcs(('potrf', 'sygst',
  683. 'syevd', 'sygvd'),
  684. dtype=dtype)
  685. A = rand(n, n).astype(dtype)
  686. A = (A + A.T)/2
  687. # B must be positive definite
  688. B = rand(n, n).astype(dtype)
  689. B = (B + B.T)/2 + 2 * np.eye(n, dtype=dtype)
  690. # Perform eig (sygvd)
  691. _, eig_gvd, info = sygvd(A, B)
  692. assert_(info == 0)
  693. # Convert to std problem potrf
  694. b, info = potrf(B)
  695. assert_(info == 0)
  696. a, info = sygst(A, b)
  697. assert_(info == 0)
  698. eig, _, info = syevd(a)
  699. assert_(info == 0)
  700. assert_allclose(eig, eig_gvd, rtol=1e-4)
  701. def test_hegst():
  702. seed(1234)
  703. for ind, dtype in enumerate(COMPLEX_DTYPES):
  704. # DTYPES = <c,z> hegst
  705. n = 10
  706. potrf, hegst, heevd, hegvd = get_lapack_funcs(('potrf', 'hegst',
  707. 'heevd', 'hegvd'),
  708. dtype=dtype)
  709. A = rand(n, n).astype(dtype) + 1j * rand(n, n).astype(dtype)
  710. A = (A + A.conj().T)/2
  711. # B must be positive definite
  712. B = rand(n, n).astype(dtype) + 1j * rand(n, n).astype(dtype)
  713. B = (B + B.conj().T)/2 + 2 * np.eye(n, dtype=dtype)
  714. # Perform eig (hegvd)
  715. _, eig_gvd, info = hegvd(A, B)
  716. assert_(info == 0)
  717. # Convert to std problem potrf
  718. b, info = potrf(B)
  719. assert_(info == 0)
  720. a, info = hegst(A, b)
  721. assert_(info == 0)
  722. eig, _, info = heevd(a)
  723. assert_(info == 0)
  724. assert_allclose(eig, eig_gvd, rtol=1e-4)
  725. def test_tzrzf():
  726. """
  727. This test performs an RZ decomposition in which an m x n upper trapezoidal
  728. array M (m <= n) is factorized as M = [R 0] * Z where R is upper triangular
  729. and Z is unitary.
  730. """
  731. seed(1234)
  732. m, n = 10, 15
  733. for ind, dtype in enumerate(DTYPES):
  734. tzrzf, tzrzf_lw = get_lapack_funcs(('tzrzf', 'tzrzf_lwork'),
  735. dtype=dtype)
  736. lwork = _compute_lwork(tzrzf_lw, m, n)
  737. if ind < 2:
  738. A = triu(rand(m, n).astype(dtype))
  739. else:
  740. A = triu((rand(m, n) + rand(m, n)*1j).astype(dtype))
  741. # assert wrong shape arg, f2py returns generic error
  742. assert_raises(Exception, tzrzf, A.T)
  743. rz, tau, info = tzrzf(A, lwork=lwork)
  744. # Check success
  745. assert_(info == 0)
  746. # Get Z manually for comparison
  747. R = np.hstack((rz[:, :m], np.zeros((m, n-m), dtype=dtype)))
  748. V = np.hstack((np.eye(m, dtype=dtype), rz[:, m:]))
  749. Id = np.eye(n, dtype=dtype)
  750. ref = [Id-tau[x]*V[[x], :].T.dot(V[[x], :].conj()) for x in range(m)]
  751. Z = reduce(np.dot, ref)
  752. assert_allclose(R.dot(Z) - A, zeros_like(A, dtype=dtype),
  753. atol=10*np.spacing(dtype(1.0).real), rtol=0.)
  754. def test_tfsm():
  755. """
  756. Test for solving a linear system with the coefficient matrix is a
  757. triangular array stored in Full Packed (RFP) format.
  758. """
  759. seed(1234)
  760. for ind, dtype in enumerate(DTYPES):
  761. n = 20
  762. if ind > 1:
  763. A = triu(rand(n, n) + rand(n, n)*1j + eye(n)).astype(dtype)
  764. trans = 'C'
  765. else:
  766. A = triu(rand(n, n) + eye(n)).astype(dtype)
  767. trans = 'T'
  768. trttf, tfttr, tfsm = get_lapack_funcs(('trttf', 'tfttr', 'tfsm'),
  769. dtype=dtype)
  770. Afp, _ = trttf(A)
  771. B = rand(n, 2).astype(dtype)
  772. soln = tfsm(-1, Afp, B)
  773. assert_array_almost_equal(soln, solve(-A, B),
  774. decimal=4 if ind % 2 == 0 else 6)
  775. soln = tfsm(-1, Afp, B, trans=trans)
  776. assert_array_almost_equal(soln, solve(-A.conj().T, B),
  777. decimal=4 if ind % 2 == 0 else 6)
  778. # Make A, unit diagonal
  779. A[np.arange(n), np.arange(n)] = dtype(1.)
  780. soln = tfsm(-1, Afp, B, trans=trans, diag='U')
  781. assert_array_almost_equal(soln, solve(-A.conj().T, B),
  782. decimal=4 if ind % 2 == 0 else 6)
  783. # Change side
  784. B2 = rand(3, n).astype(dtype)
  785. soln = tfsm(-1, Afp, B2, trans=trans, diag='U', side='R')
  786. assert_array_almost_equal(soln, solve(-A, B2.T).conj().T,
  787. decimal=4 if ind % 2 == 0 else 6)
  788. def test_ormrz_unmrz():
  789. """
  790. This test performs a matrix multiplication with an arbitrary m x n matric C
  791. and a unitary matrix Q without explicitly forming the array. The array data
  792. is encoded in the rectangular part of A which is obtained from ?TZRZF. Q
  793. size is inferred by m, n, side keywords.
  794. """
  795. seed(1234)
  796. qm, qn, cn = 10, 15, 15
  797. for ind, dtype in enumerate(DTYPES):
  798. tzrzf, tzrzf_lw = get_lapack_funcs(('tzrzf', 'tzrzf_lwork'),
  799. dtype=dtype)
  800. lwork_rz = _compute_lwork(tzrzf_lw, qm, qn)
  801. if ind < 2:
  802. A = triu(rand(qm, qn).astype(dtype))
  803. C = rand(cn, cn).astype(dtype)
  804. orun_mrz, orun_mrz_lw = get_lapack_funcs(('ormrz', 'ormrz_lwork'),
  805. dtype=dtype)
  806. else:
  807. A = triu((rand(qm, qn) + rand(qm, qn)*1j).astype(dtype))
  808. C = (rand(cn, cn) + rand(cn, cn)*1j).astype(dtype)
  809. orun_mrz, orun_mrz_lw = get_lapack_funcs(('unmrz', 'unmrz_lwork'),
  810. dtype=dtype)
  811. lwork_mrz = _compute_lwork(orun_mrz_lw, cn, cn)
  812. rz, tau, info = tzrzf(A, lwork=lwork_rz)
  813. # Get Q manually for comparison
  814. V = np.hstack((np.eye(qm, dtype=dtype), rz[:, qm:]))
  815. Id = np.eye(qn, dtype=dtype)
  816. ref = [Id-tau[x]*V[[x], :].T.dot(V[[x], :].conj()) for x in range(qm)]
  817. Q = reduce(np.dot, ref)
  818. # Now that we have Q, we can test whether lapack results agree with
  819. # each case of CQ, CQ^H, QC, and QC^H
  820. trans = 'T' if ind < 2 else 'C'
  821. tol = 10*np.spacing(dtype(1.0).real)
  822. cq, info = orun_mrz(rz, tau, C, lwork=lwork_mrz)
  823. assert_(info == 0)
  824. assert_allclose(cq - Q.dot(C), zeros_like(C), atol=tol, rtol=0.)
  825. cq, info = orun_mrz(rz, tau, C, trans=trans, lwork=lwork_mrz)
  826. assert_(info == 0)
  827. assert_allclose(cq - Q.conj().T.dot(C), zeros_like(C), atol=tol,
  828. rtol=0.)
  829. cq, info = orun_mrz(rz, tau, C, side='R', lwork=lwork_mrz)
  830. assert_(info == 0)
  831. assert_allclose(cq - C.dot(Q), zeros_like(C), atol=tol, rtol=0.)
  832. cq, info = orun_mrz(rz, tau, C, side='R', trans=trans, lwork=lwork_mrz)
  833. assert_(info == 0)
  834. assert_allclose(cq - C.dot(Q.conj().T), zeros_like(C), atol=tol,
  835. rtol=0.)
  836. def test_tfttr_trttf():
  837. """
  838. Test conversion routines between the Rectengular Full Packed (RFP) format
  839. and Standard Triangular Array (TR)
  840. """
  841. seed(1234)
  842. for ind, dtype in enumerate(DTYPES):
  843. n = 20
  844. if ind > 1:
  845. A_full = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  846. transr = 'C'
  847. else:
  848. A_full = (rand(n, n)).astype(dtype)
  849. transr = 'T'
  850. trttf, tfttr = get_lapack_funcs(('trttf', 'tfttr'), dtype=dtype)
  851. A_tf_U, info = trttf(A_full)
  852. assert_(info == 0)
  853. A_tf_L, info = trttf(A_full, uplo='L')
  854. assert_(info == 0)
  855. A_tf_U_T, info = trttf(A_full, transr=transr, uplo='U')
  856. assert_(info == 0)
  857. A_tf_L_T, info = trttf(A_full, transr=transr, uplo='L')
  858. assert_(info == 0)
  859. # Create the RFP array manually (n is even!)
  860. A_tf_U_m = zeros((n+1, n//2), dtype=dtype)
  861. A_tf_U_m[:-1, :] = triu(A_full)[:, n//2:]
  862. A_tf_U_m[n//2+1:, :] += triu(A_full)[:n//2, :n//2].conj().T
  863. A_tf_L_m = zeros((n+1, n//2), dtype=dtype)
  864. A_tf_L_m[1:, :] = tril(A_full)[:, :n//2]
  865. A_tf_L_m[:n//2, :] += tril(A_full)[n//2:, n//2:].conj().T
  866. assert_array_almost_equal(A_tf_U, A_tf_U_m.reshape(-1, order='F'))
  867. assert_array_almost_equal(A_tf_U_T,
  868. A_tf_U_m.conj().T.reshape(-1, order='F'))
  869. assert_array_almost_equal(A_tf_L, A_tf_L_m.reshape(-1, order='F'))
  870. assert_array_almost_equal(A_tf_L_T,
  871. A_tf_L_m.conj().T.reshape(-1, order='F'))
  872. # Get the original array from RFP
  873. A_tr_U, info = tfttr(n, A_tf_U)
  874. assert_(info == 0)
  875. A_tr_L, info = tfttr(n, A_tf_L, uplo='L')
  876. assert_(info == 0)
  877. A_tr_U_T, info = tfttr(n, A_tf_U_T, transr=transr, uplo='U')
  878. assert_(info == 0)
  879. A_tr_L_T, info = tfttr(n, A_tf_L_T, transr=transr, uplo='L')
  880. assert_(info == 0)
  881. assert_array_almost_equal(A_tr_U, triu(A_full))
  882. assert_array_almost_equal(A_tr_U_T, triu(A_full))
  883. assert_array_almost_equal(A_tr_L, tril(A_full))
  884. assert_array_almost_equal(A_tr_L_T, tril(A_full))
  885. def test_tpttr_trttp():
  886. """
  887. Test conversion routines between the Rectengular Full Packed (RFP) format
  888. and Standard Triangular Array (TR)
  889. """
  890. seed(1234)
  891. for ind, dtype in enumerate(DTYPES):
  892. n = 20
  893. if ind > 1:
  894. A_full = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  895. else:
  896. A_full = (rand(n, n)).astype(dtype)
  897. trttp, tpttr = get_lapack_funcs(('trttp', 'tpttr'), dtype=dtype)
  898. A_tp_U, info = trttp(A_full)
  899. assert_(info == 0)
  900. A_tp_L, info = trttp(A_full, uplo='L')
  901. assert_(info == 0)
  902. # Create the TP array manually
  903. inds = tril_indices(n)
  904. A_tp_U_m = zeros(n*(n+1)//2, dtype=dtype)
  905. A_tp_U_m[:] = (triu(A_full).T)[inds]
  906. inds = triu_indices(n)
  907. A_tp_L_m = zeros(n*(n+1)//2, dtype=dtype)
  908. A_tp_L_m[:] = (tril(A_full).T)[inds]
  909. assert_array_almost_equal(A_tp_U, A_tp_U_m)
  910. assert_array_almost_equal(A_tp_L, A_tp_L_m)
  911. # Get the original array from TP
  912. A_tr_U, info = tpttr(n, A_tp_U)
  913. assert_(info == 0)
  914. A_tr_L, info = tpttr(n, A_tp_L, uplo='L')
  915. assert_(info == 0)
  916. assert_array_almost_equal(A_tr_U, triu(A_full))
  917. assert_array_almost_equal(A_tr_L, tril(A_full))
  918. def test_pftrf():
  919. """
  920. Test Cholesky factorization of a positive definite Rectengular Full
  921. Packed (RFP) format array
  922. """
  923. seed(1234)
  924. for ind, dtype in enumerate(DTYPES):
  925. n = 20
  926. if ind > 1:
  927. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  928. A = A + A.conj().T + n*eye(n)
  929. else:
  930. A = (rand(n, n)).astype(dtype)
  931. A = A + A.T + n*eye(n)
  932. pftrf, trttf, tfttr = get_lapack_funcs(('pftrf', 'trttf', 'tfttr'),
  933. dtype=dtype)
  934. # Get the original array from TP
  935. Afp, info = trttf(A)
  936. Achol_rfp, info = pftrf(n, Afp)
  937. assert_(info == 0)
  938. A_chol_r, _ = tfttr(n, Achol_rfp)
  939. Achol = cholesky(A)
  940. assert_array_almost_equal(A_chol_r, Achol)
  941. def test_pftri():
  942. """
  943. Test Cholesky factorization of a positive definite Rectengular Full
  944. Packed (RFP) format array to find its inverse
  945. """
  946. seed(1234)
  947. for ind, dtype in enumerate(DTYPES):
  948. n = 20
  949. if ind > 1:
  950. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  951. A = A + A.conj().T + n*eye(n)
  952. else:
  953. A = (rand(n, n)).astype(dtype)
  954. A = A + A.T + n*eye(n)
  955. pftri, pftrf, trttf, tfttr = get_lapack_funcs(('pftri',
  956. 'pftrf',
  957. 'trttf',
  958. 'tfttr'),
  959. dtype=dtype)
  960. # Get the original array from TP
  961. Afp, info = trttf(A)
  962. A_chol_rfp, info = pftrf(n, Afp)
  963. A_inv_rfp, info = pftri(n, A_chol_rfp)
  964. assert_(info == 0)
  965. A_inv_r, _ = tfttr(n, A_inv_rfp)
  966. Ainv = inv(A)
  967. assert_array_almost_equal(A_inv_r, triu(Ainv),
  968. decimal=4 if ind % 2 == 0 else 6)
  969. def test_pftrs():
  970. """
  971. Test Cholesky factorization of a positive definite Rectengular Full
  972. Packed (RFP) format array and solve a linear system
  973. """
  974. seed(1234)
  975. for ind, dtype in enumerate(DTYPES):
  976. n = 20
  977. if ind > 1:
  978. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  979. A = A + A.conj().T + n*eye(n)
  980. else:
  981. A = (rand(n, n)).astype(dtype)
  982. A = A + A.T + n*eye(n)
  983. B = ones((n, 3), dtype=dtype)
  984. Bf1 = ones((n+2, 3), dtype=dtype)
  985. Bf2 = ones((n-2, 3), dtype=dtype)
  986. pftrs, pftrf, trttf, tfttr = get_lapack_funcs(('pftrs',
  987. 'pftrf',
  988. 'trttf',
  989. 'tfttr'),
  990. dtype=dtype)
  991. # Get the original array from TP
  992. Afp, info = trttf(A)
  993. A_chol_rfp, info = pftrf(n, Afp)
  994. # larger B arrays shouldn't segfault
  995. soln, info = pftrs(n, A_chol_rfp, Bf1)
  996. assert_(info == 0)
  997. assert_raises(Exception, pftrs, n, A_chol_rfp, Bf2)
  998. soln, info = pftrs(n, A_chol_rfp, B)
  999. assert_(info == 0)
  1000. assert_array_almost_equal(solve(A, B), soln,
  1001. decimal=4 if ind % 2 == 0 else 6)
  1002. def test_sfrk_hfrk():
  1003. """
  1004. Test for performing a symmetric rank-k operation for matrix in RFP format.
  1005. """
  1006. seed(1234)
  1007. for ind, dtype in enumerate(DTYPES):
  1008. n = 20
  1009. if ind > 1:
  1010. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1011. A = A + A.conj().T + n*eye(n)
  1012. else:
  1013. A = (rand(n, n)).astype(dtype)
  1014. A = A + A.T + n*eye(n)
  1015. prefix = 's'if ind < 2 else 'h'
  1016. trttf, tfttr, shfrk = get_lapack_funcs(('trttf', 'tfttr', '{}frk'
  1017. ''.format(prefix)),
  1018. dtype=dtype)
  1019. Afp, _ = trttf(A)
  1020. C = np.random.rand(n, 2).astype(dtype)
  1021. Afp_out = shfrk(n, 2, -1, C, 2, Afp)
  1022. A_out, _ = tfttr(n, Afp_out)
  1023. assert_array_almost_equal(A_out, triu(-C.dot(C.conj().T) + 2*A),
  1024. decimal=4 if ind % 2 == 0 else 6)