# basic.py
  1. #
  2. # Author: Pearu Peterson, March 2002
  3. #
  4. # w/ additions by Travis Oliphant, March 2002
  5. # and Jake Vanderplas, August 2012
  6. from __future__ import division, print_function, absolute_import
  7. from warnings import warn
  8. import numpy as np
  9. from numpy import atleast_1d, atleast_2d
  10. from .flinalg import get_flinalg_funcs
  11. from .lapack import get_lapack_funcs, _compute_lwork
  12. from .misc import LinAlgError, _datacopied, LinAlgWarning
  13. from .decomp import _asarray_validated
  14. from . import decomp, decomp_svd
  15. from ._solve_toeplitz import levinson
  16. __all__ = ['solve', 'solve_triangular', 'solveh_banded', 'solve_banded',
  17. 'solve_toeplitz', 'solve_circulant', 'inv', 'det', 'lstsq',
  18. 'pinv', 'pinv2', 'pinvh', 'matrix_balance']
  19. # Linear equations
  20. def _solve_check(n, info, lamch=None, rcond=None):
  21. """ Check arguments during the different steps of the solution phase """
  22. if info < 0:
  23. raise ValueError('LAPACK reported an illegal value in {}-th argument'
  24. '.'.format(-info))
  25. elif 0 < info:
  26. raise LinAlgError('Matrix is singular.')
  27. if lamch is None:
  28. return
  29. E = lamch('E')
  30. if rcond < E:
  31. warn('Ill-conditioned matrix (rcond={:.6g}): '
  32. 'result may not be accurate.'.format(rcond),
  33. LinAlgWarning, stacklevel=3)
def solve(a, b, sym_pos=False, lower=False, overwrite_a=False,
          overwrite_b=False, debug=None, check_finite=True, assume_a='gen',
          transposed=False):
    """
    Solves the linear equation set ``a * x = b`` for the unknown ``x``
    for square ``a`` matrix.

    If the data matrix is known to be a particular type then supplying the
    corresponding string to ``assume_a`` key chooses the dedicated solver.
    The available options are

    ===================  ========
     generic matrix       'gen'
     symmetric            'sym'
     hermitian            'her'
     positive definite    'pos'
    ===================  ========

    If omitted, ``'gen'`` is the default structure.

    The datatype of the arrays defines which solver is called regardless
    of the values. In other words, even when the complex array entries have
    precisely zero imaginary parts, the complex solver will be called based
    on the data type of the array.

    Parameters
    ----------
    a : (N, N) array_like
        Square input data
    b : (N, NRHS) array_like
        Input data for the right hand side.
    sym_pos : bool, optional
        Assume `a` is symmetric and positive definite. This key is deprecated
        and assume_a = 'pos' keyword is recommended instead. The functionality
        is the same. It will be removed in the future.
    lower : bool, optional
        If True, only the data contained in the lower triangle of `a` is
        used. Default is to use the upper triangle. (Ignored for ``'gen'``.)
    overwrite_a : bool, optional
        Allow overwriting data in `a` (may enhance performance).
        Default is False.
    overwrite_b : bool, optional
        Allow overwriting data in `b` (may enhance performance).
        Default is False.
    debug : None, optional
        Deprecated. Any value other than None emits a DeprecationWarning;
        the value is otherwise ignored.
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.
    assume_a : str, optional
        Valid entries are explained above.
    transposed: bool, optional
        If True, solve ``a^T x = b`` for real matrices; raises
        `NotImplementedError` for complex matrices (only when True).

    Returns
    -------
    x : (N, NRHS) ndarray
        The solution array.

    Raises
    ------
    ValueError
        If size mismatches detected or input a is not square, or if
        `assume_a` is not one of the recognized structures.
    LinAlgError
        If the matrix is singular.
    LinAlgWarning
        If an ill-conditioned input a is detected.
    NotImplementedError
        If transposed is True and input a is a complex matrix.

    Examples
    --------
    Given `a` and `b`, solve for `x`:

    >>> a = np.array([[3, 2, 0], [1, -1, 0], [0, 5, 1]])
    >>> b = np.array([2, 4, -1])
    >>> from scipy import linalg
    >>> x = linalg.solve(a, b)
    >>> x
    array([ 2., -2.,  9.])
    >>> np.allclose(np.dot(a, x), b)
    True

    Notes
    -----
    If the input b matrix is a 1D array with N elements, when supplied
    together with an NxN input a, it is assumed as a valid column vector
    despite the apparent size mismatch. This is compatible with the
    numpy.dot() behavior and the returned result is still 1D array.

    The generic, symmetric, hermitian and positive definite solutions are
    obtained via calling ?GESV, ?SYSV, ?HESV, and ?POSV routines of
    LAPACK respectively.
    """
    # Flags for 1D or nD right hand side
    b_is_1D = False

    a1 = atleast_2d(_asarray_validated(a, check_finite=check_finite))
    b1 = atleast_1d(_asarray_validated(b, check_finite=check_finite))
    n = a1.shape[0]

    # If validation already produced private copies, LAPACK may overwrite them.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if a1.shape[0] != a1.shape[1]:
        raise ValueError('Input a needs to be a square matrix.')

    if n != b1.shape[0]:
        # Last chance to catch 1x1 scalar a and 1D b arrays
        if not (n == 1 and b1.size != 0):
            raise ValueError('Input b has to have same number of rows as '
                             'input a')

    # accommodate empty arrays
    if b1.size == 0:
        return np.asfortranarray(b1.copy())

    # regularize 1D b arrays to 2D; with a 1x1 `a`, a 1D b of length k is
    # treated as a single row holding k right-hand sides.
    if b1.ndim == 1:
        if n == 1:
            b1 = b1[None, :]
        else:
            b1 = b1[:, None]
        b_is_1D = True

    # Backwards compatibility - old keyword.
    if sym_pos:
        assume_a = 'pos'

    if assume_a not in ('gen', 'sym', 'her', 'pos'):
        raise ValueError('{} is not a recognized matrix structure'
                         ''.format(assume_a))

    # Deprecate keyword "debug"
    if debug is not None:
        warn('Use of the "debug" keyword is deprecated '
             'and this keyword will be removed in future '
             'versions of SciPy.', DeprecationWarning, stacklevel=2)

    # Get the correct lamch function.
    # LAPACK's ?lamch exists only for real single (s) and double (d)
    # precision, so complex inputs use the real routine of the same
    # precision; every non-single dtype falls back to double.
    if a1.dtype.char in 'fF':  # single precision
        lamch = get_lapack_funcs('lamch', dtype='f')
    else:
        lamch = get_lapack_funcs('lamch', dtype='d')

    # Currently we do not have the other forms of the norm calculators
    #   lansy, lanpo, lanhe.
    # However, in any case they only reduce computations slightly...
    lange = get_lapack_funcs('lange', (a1,))

    # Since the I-norm and 1-norm are the same for symmetric matrices
    # we can collect them all in this one call.
    # For the transposed 'gen' solve the infinity norm of `a` is used,
    # because it equals the 1-norm of a^T, the matrix actually solved.
    if transposed:
        trans = 1
        norm = 'I'
        if np.iscomplexobj(a1):
            raise NotImplementedError('scipy.linalg.solve can currently '
                                      'not solve a^T x = b or a^H x = b '
                                      'for complex matrices.')
    else:
        trans = 0
        norm = '1'

    anorm = lange(norm, a1)

    # Generalized case 'gesv'
    if assume_a == 'gen':
        gecon, getrf, getrs = get_lapack_funcs(('gecon', 'getrf', 'getrs'),
                                               (a1, b1))
        lu, ipvt, info = getrf(a1, overwrite_a=overwrite_a)
        _solve_check(n, info)
        x, info = getrs(lu, ipvt, b1,
                        trans=trans, overwrite_b=overwrite_b)
        _solve_check(n, info)
        rcond, info = gecon(lu, anorm, norm=norm)
    # Hermitian case 'hesv'
    elif assume_a == 'her':
        hecon, hesv, hesv_lw = get_lapack_funcs(('hecon', 'hesv',
                                                 'hesv_lwork'), (a1, b1))
        lwork = _compute_lwork(hesv_lw, n, lower)
        lu, ipvt, x, info = hesv(a1, b1, lwork=lwork,
                                 lower=lower,
                                 overwrite_a=overwrite_a,
                                 overwrite_b=overwrite_b)
        _solve_check(n, info)
        rcond, info = hecon(lu, ipvt, anorm)
    # Symmetric case 'sysv'
    elif assume_a == 'sym':
        sycon, sysv, sysv_lw = get_lapack_funcs(('sycon', 'sysv',
                                                 'sysv_lwork'), (a1, b1))
        lwork = _compute_lwork(sysv_lw, n, lower)
        lu, ipvt, x, info = sysv(a1, b1, lwork=lwork,
                                 lower=lower,
                                 overwrite_a=overwrite_a,
                                 overwrite_b=overwrite_b)
        _solve_check(n, info)
        rcond, info = sycon(lu, ipvt, anorm)
    # Positive definite case 'posv'
    else:
        pocon, posv = get_lapack_funcs(('pocon', 'posv'),
                                       (a1, b1))
        lu, x, info = posv(a1, b1, lower=lower,
                           overwrite_a=overwrite_a,
                           overwrite_b=overwrite_b)
        _solve_check(n, info)
        rcond, info = pocon(lu, anorm)

    # Final check: status of the condition estimator plus an
    # ill-conditioning warning when rcond drops below machine epsilon.
    _solve_check(n, info, lamch, rcond)

    if b_is_1D:
        x = x.ravel()

    return x
  223. def solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False,
  224. overwrite_b=False, debug=None, check_finite=True):
  225. """
  226. Solve the equation `a x = b` for `x`, assuming a is a triangular matrix.
  227. Parameters
  228. ----------
  229. a : (M, M) array_like
  230. A triangular matrix
  231. b : (M,) or (M, N) array_like
  232. Right-hand side matrix in `a x = b`
  233. lower : bool, optional
  234. Use only data contained in the lower triangle of `a`.
  235. Default is to use upper triangle.
  236. trans : {0, 1, 2, 'N', 'T', 'C'}, optional
  237. Type of system to solve:
  238. ======== =========
  239. trans system
  240. ======== =========
  241. 0 or 'N' a x = b
  242. 1 or 'T' a^T x = b
  243. 2 or 'C' a^H x = b
  244. ======== =========
  245. unit_diagonal : bool, optional
  246. If True, diagonal elements of `a` are assumed to be 1 and
  247. will not be referenced.
  248. overwrite_b : bool, optional
  249. Allow overwriting data in `b` (may enhance performance)
  250. check_finite : bool, optional
  251. Whether to check that the input matrices contain only finite numbers.
  252. Disabling may give a performance gain, but may result in problems
  253. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  254. Returns
  255. -------
  256. x : (M,) or (M, N) ndarray
  257. Solution to the system `a x = b`. Shape of return matches `b`.
  258. Raises
  259. ------
  260. LinAlgError
  261. If `a` is singular
  262. Notes
  263. -----
  264. .. versionadded:: 0.9.0
  265. Examples
  266. --------
  267. Solve the lower triangular system a x = b, where::
  268. [3 0 0 0] [4]
  269. a = [2 1 0 0] b = [2]
  270. [1 0 1 0] [4]
  271. [1 1 1 1] [2]
  272. >>> from scipy.linalg import solve_triangular
  273. >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]])
  274. >>> b = np.array([4, 2, 4, 2])
  275. >>> x = solve_triangular(a, b, lower=True)
  276. >>> x
  277. array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333])
  278. >>> a.dot(x) # Check the result
  279. array([ 4., 2., 4., 2.])
  280. """
  281. # Deprecate keyword "debug"
  282. if debug is not None:
  283. warn('Use of the "debug" keyword is deprecated '
  284. 'and this keyword will be removed in the future '
  285. 'versions of SciPy.', DeprecationWarning, stacklevel=2)
  286. a1 = _asarray_validated(a, check_finite=check_finite)
  287. b1 = _asarray_validated(b, check_finite=check_finite)
  288. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  289. raise ValueError('expected square matrix')
  290. if a1.shape[0] != b1.shape[0]:
  291. raise ValueError('incompatible dimensions')
  292. overwrite_b = overwrite_b or _datacopied(b1, b)
  293. if debug:
  294. print('solve:overwrite_b=', overwrite_b)
  295. trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
  296. trtrs, = get_lapack_funcs(('trtrs',), (a1, b1))
  297. x, info = trtrs(a1, b1, overwrite_b=overwrite_b, lower=lower,
  298. trans=trans, unitdiag=unit_diagonal)
  299. if info == 0:
  300. return x
  301. if info > 0:
  302. raise LinAlgError("singular matrix: resolution failed at diagonal %d" %
  303. (info-1))
  304. raise ValueError('illegal value in %d-th argument of internal trtrs' %
  305. (-info))
  306. def solve_banded(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
  307. debug=None, check_finite=True):
  308. """
  309. Solve the equation a x = b for x, assuming a is banded matrix.
  310. The matrix a is stored in `ab` using the matrix diagonal ordered form::
  311. ab[u + i - j, j] == a[i,j]
  312. Example of `ab` (shape of a is (6,6), `u` =1, `l` =2)::
  313. * a01 a12 a23 a34 a45
  314. a00 a11 a22 a33 a44 a55
  315. a10 a21 a32 a43 a54 *
  316. a20 a31 a42 a53 * *
  317. Parameters
  318. ----------
  319. (l, u) : (integer, integer)
  320. Number of non-zero lower and upper diagonals
  321. ab : (`l` + `u` + 1, M) array_like
  322. Banded matrix
  323. b : (M,) or (M, K) array_like
  324. Right-hand side
  325. overwrite_ab : bool, optional
  326. Discard data in `ab` (may enhance performance)
  327. overwrite_b : bool, optional
  328. Discard data in `b` (may enhance performance)
  329. check_finite : bool, optional
  330. Whether to check that the input matrices contain only finite numbers.
  331. Disabling may give a performance gain, but may result in problems
  332. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  333. Returns
  334. -------
  335. x : (M,) or (M, K) ndarray
  336. The solution to the system a x = b. Returned shape depends on the
  337. shape of `b`.
  338. Examples
  339. --------
  340. Solve the banded system a x = b, where::
  341. [5 2 -1 0 0] [0]
  342. [1 4 2 -1 0] [1]
  343. a = [0 1 3 2 -1] b = [2]
  344. [0 0 1 2 2] [2]
  345. [0 0 0 1 1] [3]
  346. There is one nonzero diagonal below the main diagonal (l = 1), and
  347. two above (u = 2). The diagonal banded form of the matrix is::
  348. [* * -1 -1 -1]
  349. ab = [* 2 2 2 2]
  350. [5 4 3 2 1]
  351. [1 1 1 1 *]
  352. >>> from scipy.linalg import solve_banded
  353. >>> ab = np.array([[0, 0, -1, -1, -1],
  354. ... [0, 2, 2, 2, 2],
  355. ... [5, 4, 3, 2, 1],
  356. ... [1, 1, 1, 1, 0]])
  357. >>> b = np.array([0, 1, 2, 2, 3])
  358. >>> x = solve_banded((1, 2), ab, b)
  359. >>> x
  360. array([-2.37288136, 3.93220339, -4. , 4.3559322 , -1.3559322 ])
  361. """
  362. # Deprecate keyword "debug"
  363. if debug is not None:
  364. warn('Use of the "debug" keyword is deprecated '
  365. 'and this keyword will be removed in the future '
  366. 'versions of SciPy.', DeprecationWarning, stacklevel=2)
  367. a1 = _asarray_validated(ab, check_finite=check_finite, as_inexact=True)
  368. b1 = _asarray_validated(b, check_finite=check_finite, as_inexact=True)
  369. # Validate shapes.
  370. if a1.shape[-1] != b1.shape[0]:
  371. raise ValueError("shapes of ab and b are not compatible.")
  372. (nlower, nupper) = l_and_u
  373. if nlower + nupper + 1 != a1.shape[0]:
  374. raise ValueError("invalid values for the number of lower and upper "
  375. "diagonals: l+u+1 (%d) does not equal ab.shape[0] "
  376. "(%d)" % (nlower + nupper + 1, ab.shape[0]))
  377. overwrite_b = overwrite_b or _datacopied(b1, b)
  378. if a1.shape[-1] == 1:
  379. b2 = np.array(b1, copy=(not overwrite_b))
  380. b2 /= a1[1, 0]
  381. return b2
  382. if nlower == nupper == 1:
  383. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  384. gtsv, = get_lapack_funcs(('gtsv',), (a1, b1))
  385. du = a1[0, 1:]
  386. d = a1[1, :]
  387. dl = a1[2, :-1]
  388. du2, d, du, x, info = gtsv(dl, d, du, b1, overwrite_ab, overwrite_ab,
  389. overwrite_ab, overwrite_b)
  390. else:
  391. gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
  392. a2 = np.zeros((2*nlower + nupper + 1, a1.shape[1]), dtype=gbsv.dtype)
  393. a2[nlower:, :] = a1
  394. lu, piv, x, info = gbsv(nlower, nupper, a2, b1, overwrite_ab=True,
  395. overwrite_b=overwrite_b)
  396. if info == 0:
  397. return x
  398. if info > 0:
  399. raise LinAlgError("singular matrix")
  400. raise ValueError('illegal value in %d-th argument of internal '
  401. 'gbsv/gtsv' % -info)
  402. def solveh_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
  403. check_finite=True):
  404. """
  405. Solve equation a x = b. a is Hermitian positive-definite banded matrix.
  406. The matrix a is stored in `ab` either in lower diagonal or upper
  407. diagonal ordered form:
  408. ab[u + i - j, j] == a[i,j] (if upper form; i <= j)
  409. ab[ i - j, j] == a[i,j] (if lower form; i >= j)
  410. Example of `ab` (shape of a is (6, 6), `u` =2)::
  411. upper form:
  412. * * a02 a13 a24 a35
  413. * a01 a12 a23 a34 a45
  414. a00 a11 a22 a33 a44 a55
  415. lower form:
  416. a00 a11 a22 a33 a44 a55
  417. a10 a21 a32 a43 a54 *
  418. a20 a31 a42 a53 * *
  419. Cells marked with * are not used.
  420. Parameters
  421. ----------
  422. ab : (`u` + 1, M) array_like
  423. Banded matrix
  424. b : (M,) or (M, K) array_like
  425. Right-hand side
  426. overwrite_ab : bool, optional
  427. Discard data in `ab` (may enhance performance)
  428. overwrite_b : bool, optional
  429. Discard data in `b` (may enhance performance)
  430. lower : bool, optional
  431. Is the matrix in the lower form. (Default is upper form)
  432. check_finite : bool, optional
  433. Whether to check that the input matrices contain only finite numbers.
  434. Disabling may give a performance gain, but may result in problems
  435. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  436. Returns
  437. -------
  438. x : (M,) or (M, K) ndarray
  439. The solution to the system a x = b. Shape of return matches shape
  440. of `b`.
  441. Examples
  442. --------
  443. Solve the banded system A x = b, where::
  444. [ 4 2 -1 0 0 0] [1]
  445. [ 2 5 2 -1 0 0] [2]
  446. A = [-1 2 6 2 -1 0] b = [2]
  447. [ 0 -1 2 7 2 -1] [3]
  448. [ 0 0 -1 2 8 2] [3]
  449. [ 0 0 0 -1 2 9] [3]
  450. >>> from scipy.linalg import solveh_banded
  451. `ab` contains the main diagonal and the nonzero diagonals below the
  452. main diagonal. That is, we use the lower form:
  453. >>> ab = np.array([[ 4, 5, 6, 7, 8, 9],
  454. ... [ 2, 2, 2, 2, 2, 0],
  455. ... [-1, -1, -1, -1, 0, 0]])
  456. >>> b = np.array([1, 2, 2, 3, 3, 3])
  457. >>> x = solveh_banded(ab, b, lower=True)
  458. >>> x
  459. array([ 0.03431373, 0.45938375, 0.05602241, 0.47759104, 0.17577031,
  460. 0.34733894])
  461. Solve the Hermitian banded system H x = b, where::
  462. [ 8 2-1j 0 0 ] [ 1 ]
  463. H = [2+1j 5 1j 0 ] b = [1+1j]
  464. [ 0 -1j 9 -2-1j] [1-2j]
  465. [ 0 0 -2+1j 6 ] [ 0 ]
  466. In this example, we put the upper diagonals in the array `hb`:
  467. >>> hb = np.array([[0, 2-1j, 1j, -2-1j],
  468. ... [8, 5, 9, 6 ]])
  469. >>> b = np.array([1, 1+1j, 1-2j, 0])
  470. >>> x = solveh_banded(hb, b)
  471. >>> x
  472. array([ 0.07318536-0.02939412j, 0.11877624+0.17696461j,
  473. 0.10077984-0.23035393j, -0.00479904-0.09358128j])
  474. """
  475. a1 = _asarray_validated(ab, check_finite=check_finite)
  476. b1 = _asarray_validated(b, check_finite=check_finite)
  477. # Validate shapes.
  478. if a1.shape[-1] != b1.shape[0]:
  479. raise ValueError("shapes of ab and b are not compatible.")
  480. overwrite_b = overwrite_b or _datacopied(b1, b)
  481. overwrite_ab = overwrite_ab or _datacopied(a1, ab)
  482. if a1.shape[0] == 2:
  483. ptsv, = get_lapack_funcs(('ptsv',), (a1, b1))
  484. if lower:
  485. d = a1[0, :].real
  486. e = a1[1, :-1]
  487. else:
  488. d = a1[1, :].real
  489. e = a1[0, 1:].conj()
  490. d, du, x, info = ptsv(d, e, b1, overwrite_ab, overwrite_ab,
  491. overwrite_b)
  492. else:
  493. pbsv, = get_lapack_funcs(('pbsv',), (a1, b1))
  494. c, x, info = pbsv(a1, b1, lower=lower, overwrite_ab=overwrite_ab,
  495. overwrite_b=overwrite_b)
  496. if info > 0:
  497. raise LinAlgError("%d-th leading minor not positive definite" % info)
  498. if info < 0:
  499. raise ValueError('illegal value in %d-th argument of internal '
  500. 'pbsv' % -info)
  501. return x
  502. def solve_toeplitz(c_or_cr, b, check_finite=True):
  503. """Solve a Toeplitz system using Levinson Recursion
  504. The Toeplitz matrix has constant diagonals, with c as its first column
  505. and r as its first row. If r is not given, ``r == conjugate(c)`` is
  506. assumed.
  507. Parameters
  508. ----------
  509. c_or_cr : array_like or tuple of (array_like, array_like)
  510. The vector ``c``, or a tuple of arrays (``c``, ``r``). Whatever the
  511. actual shape of ``c``, it will be converted to a 1-D array. If not
  512. supplied, ``r = conjugate(c)`` is assumed; in this case, if c[0] is
  513. real, the Toeplitz matrix is Hermitian. r[0] is ignored; the first row
  514. of the Toeplitz matrix is ``[c[0], r[1:]]``. Whatever the actual shape
  515. of ``r``, it will be converted to a 1-D array.
  516. b : (M,) or (M, K) array_like
  517. Right-hand side in ``T x = b``.
  518. check_finite : bool, optional
  519. Whether to check that the input matrices contain only finite numbers.
  520. Disabling may give a performance gain, but may result in problems
  521. (result entirely NaNs) if the inputs do contain infinities or NaNs.
  522. Returns
  523. -------
  524. x : (M,) or (M, K) ndarray
  525. The solution to the system ``T x = b``. Shape of return matches shape
  526. of `b`.
  527. See Also
  528. --------
  529. toeplitz : Toeplitz matrix
  530. Notes
  531. -----
  532. The solution is computed using Levinson-Durbin recursion, which is faster
  533. than generic least-squares methods, but can be less numerically stable.
  534. Examples
  535. --------
  536. Solve the Toeplitz system T x = b, where::
  537. [ 1 -1 -2 -3] [1]
  538. T = [ 3 1 -1 -2] b = [2]
  539. [ 6 3 1 -1] [2]
  540. [10 6 3 1] [5]
  541. To specify the Toeplitz matrix, only the first column and the first
  542. row are needed.
  543. >>> c = np.array([1, 3, 6, 10]) # First column of T
  544. >>> r = np.array([1, -1, -2, -3]) # First row of T
  545. >>> b = np.array([1, 2, 2, 5])
  546. >>> from scipy.linalg import solve_toeplitz, toeplitz
  547. >>> x = solve_toeplitz((c, r), b)
  548. >>> x
  549. array([ 1.66666667, -1. , -2.66666667, 2.33333333])
  550. Check the result by creating the full Toeplitz matrix and
  551. multiplying it by `x`. We should get `b`.
  552. >>> T = toeplitz(c, r)
  553. >>> T.dot(x)
  554. array([ 1., 2., 2., 5.])
  555. """
  556. # If numerical stability of this algorithm is a problem, a future
  557. # developer might consider implementing other O(N^2) Toeplitz solvers,
  558. # such as GKO (https://www.jstor.org/stable/2153371) or Bareiss.
  559. if isinstance(c_or_cr, tuple):
  560. c, r = c_or_cr
  561. c = _asarray_validated(c, check_finite=check_finite).ravel()
  562. r = _asarray_validated(r, check_finite=check_finite).ravel()
  563. else:
  564. c = _asarray_validated(c_or_cr, check_finite=check_finite).ravel()
  565. r = c.conjugate()
  566. # Form a 1D array of values to be used in the matrix, containing a reversed
  567. # copy of r[1:], followed by c.
  568. vals = np.concatenate((r[-1:0:-1], c))
  569. if b is None:
  570. raise ValueError('illegal value, `b` is a required argument')
  571. b = _asarray_validated(b)
  572. if vals.shape[0] != (2*b.shape[0] - 1):
  573. raise ValueError('incompatible dimensions')
  574. if np.iscomplexobj(vals) or np.iscomplexobj(b):
  575. vals = np.asarray(vals, dtype=np.complex128, order='c')
  576. b = np.asarray(b, dtype=np.complex128)
  577. else:
  578. vals = np.asarray(vals, dtype=np.double, order='c')
  579. b = np.asarray(b, dtype=np.double)
  580. if b.ndim == 1:
  581. x, _ = levinson(vals, np.ascontiguousarray(b))
  582. else:
  583. b_shape = b.shape
  584. b = b.reshape(b.shape[0], -1)
  585. x = np.column_stack([levinson(vals, np.ascontiguousarray(b[:, i]))[0]
  586. for i in range(b.shape[1])])
  587. x = x.reshape(*b_shape)
  588. return x
  589. def _get_axis_len(aname, a, axis):
  590. ax = axis
  591. if ax < 0:
  592. ax += a.ndim
  593. if 0 <= ax < a.ndim:
  594. return a.shape[ax]
  595. raise ValueError("'%saxis' entry is out of bounds" % (aname,))
  596. def solve_circulant(c, b, singular='raise', tol=None,
  597. caxis=-1, baxis=0, outaxis=0):
  598. """Solve C x = b for x, where C is a circulant matrix.
  599. `C` is the circulant matrix associated with the vector `c`.
  600. The system is solved by doing division in Fourier space. The
  601. calculation is::
  602. x = ifft(fft(b) / fft(c))
  603. where `fft` and `ifft` are the fast Fourier transform and its inverse,
  604. respectively. For a large vector `c`, this is *much* faster than
  605. solving the system with the full circulant matrix.
  606. Parameters
  607. ----------
  608. c : array_like
  609. The coefficients of the circulant matrix.
  610. b : array_like
  611. Right-hand side matrix in ``a x = b``.
  612. singular : str, optional
  613. This argument controls how a near singular circulant matrix is
  614. handled. If `singular` is "raise" and the circulant matrix is
  615. near singular, a `LinAlgError` is raised. If `singular` is
  616. "lstsq", the least squares solution is returned. Default is "raise".
  617. tol : float, optional
  618. If any eigenvalue of the circulant matrix has an absolute value
  619. that is less than or equal to `tol`, the matrix is considered to be
  620. near singular. If not given, `tol` is set to::
  621. tol = abs_eigs.max() * abs_eigs.size * np.finfo(np.float64).eps
  622. where `abs_eigs` is the array of absolute values of the eigenvalues
  623. of the circulant matrix.
  624. caxis : int
  625. When `c` has dimension greater than 1, it is viewed as a collection
  626. of circulant vectors. In this case, `caxis` is the axis of `c` that
  627. holds the vectors of circulant coefficients.
  628. baxis : int
  629. When `b` has dimension greater than 1, it is viewed as a collection
  630. of vectors. In this case, `baxis` is the axis of `b` that holds the
  631. right-hand side vectors.
  632. outaxis : int
  633. When `c` or `b` are multidimensional, the value returned by
  634. `solve_circulant` is multidimensional. In this case, `outaxis` is
  635. the axis of the result that holds the solution vectors.
  636. Returns
  637. -------
  638. x : ndarray
  639. Solution to the system ``C x = b``.
  640. Raises
  641. ------
  642. LinAlgError
  643. If the circulant matrix associated with `c` is near singular.
  644. See Also
  645. --------
  646. circulant : circulant matrix
  647. Notes
  648. -----
  649. For a one-dimensional vector `c` with length `m`, and an array `b`
  650. with shape ``(m, ...)``,
  651. solve_circulant(c, b)
  652. returns the same result as
  653. solve(circulant(c), b)
  654. where `solve` and `circulant` are from `scipy.linalg`.
  655. .. versionadded:: 0.16.0
  656. Examples
  657. --------
  658. >>> from scipy.linalg import solve_circulant, solve, circulant, lstsq
  659. >>> c = np.array([2, 2, 4])
  660. >>> b = np.array([1, 2, 3])
  661. >>> solve_circulant(c, b)
  662. array([ 0.75, -0.25, 0.25])
  663. Compare that result to solving the system with `scipy.linalg.solve`:
  664. >>> solve(circulant(c), b)
  665. array([ 0.75, -0.25, 0.25])
  666. A singular example:
  667. >>> c = np.array([1, 1, 0, 0])
  668. >>> b = np.array([1, 2, 3, 4])
  669. Calling ``solve_circulant(c, b)`` will raise a `LinAlgError`. For the
  670. least square solution, use the option ``singular='lstsq'``:
  671. >>> solve_circulant(c, b, singular='lstsq')
  672. array([ 0.25, 1.25, 2.25, 1.25])
  673. Compare to `scipy.linalg.lstsq`:
  674. >>> x, resid, rnk, s = lstsq(circulant(c), b)
  675. >>> x
  676. array([ 0.25, 1.25, 2.25, 1.25])
  677. A broadcasting example:
  678. Suppose we have the vectors of two circulant matrices stored in an array
  679. with shape (2, 5), and three `b` vectors stored in an array with shape
  680. (3, 5). For example,
  681. >>> c = np.array([[1.5, 2, 3, 0, 0], [1, 1, 4, 3, 2]])
  682. >>> b = np.arange(15).reshape(-1, 5)
  683. We want to solve all combinations of circulant matrices and `b` vectors,
  684. with the result stored in an array with shape (2, 3, 5). When we
  685. disregard the axes of `c` and `b` that hold the vectors of coefficients,
  686. the shapes of the collections are (2,) and (3,), respectively, which are
  687. not compatible for broadcasting. To have a broadcast result with shape
  688. (2, 3), we add a trivial dimension to `c`: ``c[:, np.newaxis, :]`` has
  689. shape (2, 1, 5). The last dimension holds the coefficients of the
  690. circulant matrices, so when we call `solve_circulant`, we can use the
  691. default ``caxis=-1``. The coefficients of the `b` vectors are in the last
  692. dimension of the array `b`, so we use ``baxis=-1``. If we use the
  693. default `outaxis`, the result will have shape (5, 2, 3), so we'll use
  694. ``outaxis=-1`` to put the solution vectors in the last dimension.
  695. >>> x = solve_circulant(c[:, np.newaxis, :], b, baxis=-1, outaxis=-1)
  696. >>> x.shape
  697. (2, 3, 5)
  698. >>> np.set_printoptions(precision=3) # For compact output of numbers.
  699. >>> x
  700. array([[[-0.118, 0.22 , 1.277, -0.142, 0.302],
  701. [ 0.651, 0.989, 2.046, 0.627, 1.072],
  702. [ 1.42 , 1.758, 2.816, 1.396, 1.841]],
  703. [[ 0.401, 0.304, 0.694, -0.867, 0.377],
  704. [ 0.856, 0.758, 1.149, -0.412, 0.831],
  705. [ 1.31 , 1.213, 1.603, 0.042, 1.286]]])
  706. Check by solving one pair of `c` and `b` vectors (cf. ``x[1, 1, :]``):
  707. >>> solve_circulant(c[1], b[1, :])
  708. array([ 0.856, 0.758, 1.149, -0.412, 0.831])
  709. """
  710. c = np.atleast_1d(c)
  711. nc = _get_axis_len("c", c, caxis)
  712. b = np.atleast_1d(b)
  713. nb = _get_axis_len("b", b, baxis)
  714. if nc != nb:
  715. raise ValueError('Incompatible c and b axis lengths')
  716. fc = np.fft.fft(np.rollaxis(c, caxis, c.ndim), axis=-1)
  717. abs_fc = np.abs(fc)
  718. if tol is None:
  719. # This is the same tolerance as used in np.linalg.matrix_rank.
  720. tol = abs_fc.max(axis=-1) * nc * np.finfo(np.float64).eps
  721. if tol.shape != ():
  722. tol.shape = tol.shape + (1,)
  723. else:
  724. tol = np.atleast_1d(tol)
  725. near_zeros = abs_fc <= tol
  726. is_near_singular = np.any(near_zeros)
  727. if is_near_singular:
  728. if singular == 'raise':
  729. raise LinAlgError("near singular circulant matrix.")
  730. else:
  731. # Replace the small values with 1 to avoid errors in the
  732. # division fb/fc below.
  733. fc[near_zeros] = 1
  734. fb = np.fft.fft(np.rollaxis(b, baxis, b.ndim), axis=-1)
  735. q = fb / fc
  736. if is_near_singular:
  737. # `near_zeros` is a boolean array, same shape as `c`, that is
  738. # True where `fc` is (near) zero. `q` is the broadcasted result
  739. # of fb / fc, so to set the values of `q` to 0 where `fc` is near
  740. # zero, we use a mask that is the broadcast result of an array
  741. # of True values shaped like `b` with `near_zeros`.
  742. mask = np.ones_like(b, dtype=bool) & near_zeros
  743. q[mask] = 0
  744. x = np.fft.ifft(q, axis=-1)
  745. if not (np.iscomplexobj(c) or np.iscomplexobj(b)):
  746. x = x.real
  747. if outaxis != -1:
  748. x = np.rollaxis(x, -1, outaxis)
  749. return x
  750. # matrix inversion
  751. def inv(a, overwrite_a=False, check_finite=True):
  752. """
  753. Compute the inverse of a matrix.
  754. Parameters
  755. ----------
  756. a : array_like
  757. Square matrix to be inverted.
  758. overwrite_a : bool, optional
  759. Discard data in `a` (may improve performance). Default is False.
  760. check_finite : bool, optional
  761. Whether to check that the input matrix contains only finite numbers.
  762. Disabling may give a performance gain, but may result in problems
  763. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  764. Returns
  765. -------
  766. ainv : ndarray
  767. Inverse of the matrix `a`.
  768. Raises
  769. ------
  770. LinAlgError
  771. If `a` is singular.
  772. ValueError
  773. If `a` is not square, or not 2-dimensional.
  774. Examples
  775. --------
  776. >>> from scipy import linalg
  777. >>> a = np.array([[1., 2.], [3., 4.]])
  778. >>> linalg.inv(a)
  779. array([[-2. , 1. ],
  780. [ 1.5, -0.5]])
  781. >>> np.dot(a, linalg.inv(a))
  782. array([[ 1., 0.],
  783. [ 0., 1.]])
  784. """
  785. a1 = _asarray_validated(a, check_finite=check_finite)
  786. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  787. raise ValueError('expected square matrix')
  788. overwrite_a = overwrite_a or _datacopied(a1, a)
  789. # XXX: I found no advantage or disadvantage of using finv.
  790. # finv, = get_flinalg_funcs(('inv',),(a1,))
  791. # if finv is not None:
  792. # a_inv,info = finv(a1,overwrite_a=overwrite_a)
  793. # if info==0:
  794. # return a_inv
  795. # if info>0: raise LinAlgError, "singular matrix"
  796. # if info<0: raise ValueError('illegal value in %d-th argument of '
  797. # 'internal inv.getrf|getri'%(-info))
  798. getrf, getri, getri_lwork = get_lapack_funcs(('getrf', 'getri',
  799. 'getri_lwork'),
  800. (a1,))
  801. lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
  802. if info == 0:
  803. lwork = _compute_lwork(getri_lwork, a1.shape[0])
  804. # XXX: the following line fixes curious SEGFAULT when
  805. # benchmarking 500x500 matrix inverse. This seems to
  806. # be a bug in LAPACK ?getri routine because if lwork is
  807. # minimal (when using lwork[0] instead of lwork[1]) then
  808. # all tests pass. Further investigation is required if
  809. # more such SEGFAULTs occur.
  810. lwork = int(1.01 * lwork)
  811. inv_a, info = getri(lu, piv, lwork=lwork, overwrite_lu=1)
  812. if info > 0:
  813. raise LinAlgError("singular matrix")
  814. if info < 0:
  815. raise ValueError('illegal value in %d-th argument of internal '
  816. 'getrf|getri' % -info)
  817. return inv_a
  818. # Determinant
  819. def det(a, overwrite_a=False, check_finite=True):
  820. """
  821. Compute the determinant of a matrix
  822. The determinant of a square matrix is a value derived arithmetically
  823. from the coefficients of the matrix.
  824. The determinant for a 3x3 matrix, for example, is computed as follows::
  825. a b c
  826. d e f = A
  827. g h i
  828. det(A) = a*e*i + b*f*g + c*d*h - c*e*g - b*d*i - a*f*h
  829. Parameters
  830. ----------
  831. a : (M, M) array_like
  832. A square matrix.
  833. overwrite_a : bool, optional
  834. Allow overwriting data in a (may enhance performance).
  835. check_finite : bool, optional
  836. Whether to check that the input matrix contains only finite numbers.
  837. Disabling may give a performance gain, but may result in problems
  838. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  839. Returns
  840. -------
  841. det : float or complex
  842. Determinant of `a`.
  843. Notes
  844. -----
  845. The determinant is computed via LU factorization, LAPACK routine z/dgetrf.
  846. Examples
  847. --------
  848. >>> from scipy import linalg
  849. >>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
  850. >>> linalg.det(a)
  851. 0.0
  852. >>> a = np.array([[0,2,3], [4,5,6], [7,8,9]])
  853. >>> linalg.det(a)
  854. 3.0
  855. """
  856. a1 = _asarray_validated(a, check_finite=check_finite)
  857. if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
  858. raise ValueError('expected square matrix')
  859. overwrite_a = overwrite_a or _datacopied(a1, a)
  860. fdet, = get_flinalg_funcs(('det',), (a1,))
  861. a_det, info = fdet(a1, overwrite_a=overwrite_a)
  862. if info < 0:
  863. raise ValueError('illegal value in %d-th argument of internal '
  864. 'det.getrf' % -info)
  865. return a_det
  866. # Linear Least Squares
  867. class LstsqLapackError(LinAlgError):
  868. pass
  869. def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
  870. check_finite=True, lapack_driver=None):
  871. """
  872. Compute least-squares solution to equation Ax = b.
  873. Compute a vector x such that the 2-norm ``|b - A x|`` is minimized.
  874. Parameters
  875. ----------
  876. a : (M, N) array_like
  877. Left hand side matrix (2-D array).
  878. b : (M,) or (M, K) array_like
  879. Right hand side matrix or vector (1-D or 2-D array).
  880. cond : float, optional
  881. Cutoff for 'small' singular values; used to determine effective
  882. rank of a. Singular values smaller than
  883. ``rcond * largest_singular_value`` are considered zero.
  884. overwrite_a : bool, optional
  885. Discard data in `a` (may enhance performance). Default is False.
  886. overwrite_b : bool, optional
  887. Discard data in `b` (may enhance performance). Default is False.
  888. check_finite : bool, optional
  889. Whether to check that the input matrices contain only finite numbers.
  890. Disabling may give a performance gain, but may result in problems
  891. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  892. lapack_driver : str, optional
  893. Which LAPACK driver is used to solve the least-squares problem.
  894. Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default
  895. (``'gelsd'``) is a good choice. However, ``'gelsy'`` can be slightly
  896. faster on many problems. ``'gelss'`` was used historically. It is
  897. generally slow but uses less memory.
  898. .. versionadded:: 0.17.0
  899. Returns
  900. -------
  901. x : (N,) or (N, K) ndarray
  902. Least-squares solution. Return shape matches shape of `b`.
  903. residues : (0,) or () or (K,) ndarray
  904. Sums of residues, squared 2-norm for each column in ``b - a x``.
  905. If rank of matrix a is ``< N`` or ``N > M``, or ``'gelsy'`` is used,
  906. this is a length zero array. If b was 1-D, this is a () shape array
  907. (numpy scalar), otherwise the shape is (K,).
  908. rank : int
  909. Effective rank of matrix `a`.
  910. s : (min(M,N),) ndarray or None
  911. Singular values of `a`. The condition number of a is
  912. ``abs(s[0] / s[-1])``. None is returned when ``'gelsy'`` is used.
  913. Raises
  914. ------
  915. LinAlgError
  916. If computation does not converge.
  917. ValueError
  918. When parameters are wrong.
  919. See Also
  920. --------
  921. optimize.nnls : linear least squares with non-negativity constraint
  922. Examples
  923. --------
  924. >>> from scipy.linalg import lstsq
  925. >>> import matplotlib.pyplot as plt
  926. Suppose we have the following data:
  927. >>> x = np.array([1, 2.5, 3.5, 4, 5, 7, 8.5])
  928. >>> y = np.array([0.3, 1.1, 1.5, 2.0, 3.2, 6.6, 8.6])
  929. We want to fit a quadratic polynomial of the form ``y = a + b*x**2``
  930. to this data. We first form the "design matrix" M, with a constant
  931. column of 1s and a column containing ``x**2``:
  932. >>> M = x[:, np.newaxis]**[0, 2]
  933. >>> M
  934. array([[ 1. , 1. ],
  935. [ 1. , 6.25],
  936. [ 1. , 12.25],
  937. [ 1. , 16. ],
  938. [ 1. , 25. ],
  939. [ 1. , 49. ],
  940. [ 1. , 72.25]])
  941. We want to find the least-squares solution to ``M.dot(p) = y``,
  942. where ``p`` is a vector with length 2 that holds the parameters
  943. ``a`` and ``b``.
  944. >>> p, res, rnk, s = lstsq(M, y)
  945. >>> p
  946. array([ 0.20925829, 0.12013861])
  947. Plot the data and the fitted curve.
  948. >>> plt.plot(x, y, 'o', label='data')
  949. >>> xx = np.linspace(0, 9, 101)
  950. >>> yy = p[0] + p[1]*xx**2
  951. >>> plt.plot(xx, yy, label='least squares fit, $y = a + bx^2$')
  952. >>> plt.xlabel('x')
  953. >>> plt.ylabel('y')
  954. >>> plt.legend(framealpha=1, shadow=True)
  955. >>> plt.grid(alpha=0.25)
  956. >>> plt.show()
  957. """
  958. a1 = _asarray_validated(a, check_finite=check_finite)
  959. b1 = _asarray_validated(b, check_finite=check_finite)
  960. if len(a1.shape) != 2:
  961. raise ValueError('expected matrix')
  962. m, n = a1.shape
  963. if len(b1.shape) == 2:
  964. nrhs = b1.shape[1]
  965. else:
  966. nrhs = 1
  967. if m != b1.shape[0]:
  968. raise ValueError('incompatible dimensions')
  969. if m == 0 or n == 0: # Zero-sized problem, confuses LAPACK
  970. x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
  971. if n == 0:
  972. residues = np.linalg.norm(b1, axis=0)**2
  973. else:
  974. residues = np.empty((0,))
  975. return x, residues, 0, np.empty((0,))
  976. driver = lapack_driver
  977. if driver is None:
  978. driver = lstsq.default_lapack_driver
  979. if driver not in ('gelsd', 'gelsy', 'gelss'):
  980. raise ValueError('LAPACK driver "%s" is not found' % driver)
  981. lapack_func, lapack_lwork = get_lapack_funcs((driver,
  982. '%s_lwork' % driver),
  983. (a1, b1))
  984. real_data = True if (lapack_func.dtype.kind == 'f') else False
  985. if m < n:
  986. # need to extend b matrix as it will be filled with
  987. # a larger solution matrix
  988. if len(b1.shape) == 2:
  989. b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
  990. b2[:m, :] = b1
  991. else:
  992. b2 = np.zeros(n, dtype=lapack_func.dtype)
  993. b2[:m] = b1
  994. b1 = b2
  995. overwrite_a = overwrite_a or _datacopied(a1, a)
  996. overwrite_b = overwrite_b or _datacopied(b1, b)
  997. if cond is None:
  998. cond = np.finfo(lapack_func.dtype).eps
  999. if driver in ('gelss', 'gelsd'):
  1000. if driver == 'gelss':
  1001. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1002. v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
  1003. overwrite_a=overwrite_a,
  1004. overwrite_b=overwrite_b)
  1005. elif driver == 'gelsd':
  1006. if real_data:
  1007. lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1008. if iwork == 0:
  1009. # this is LAPACK bug 0038: dgelsd does not provide the
  1010. # size of the iwork array in query mode. This bug was
  1011. # fixed in LAPACK 3.2.2, released July 21, 2010.
  1012. mesg = ("internal gelsd driver lwork query error, "
  1013. "required iwork dimension not returned. "
  1014. "This is likely the result of LAPACK bug "
  1015. "0038, fixed in LAPACK 3.2.2 (released "
  1016. "July 21, 2010). ")
  1017. if lapack_driver is None:
  1018. # restart with gelss
  1019. lstsq.default_lapack_driver = 'gelss'
  1020. mesg += "Falling back to 'gelss' driver."
  1021. warn(mesg, RuntimeWarning, stacklevel=2)
  1022. return lstsq(a, b, cond, overwrite_a, overwrite_b,
  1023. check_finite, lapack_driver='gelss')
  1024. # can't proceed, bail out
  1025. mesg += ("Use a different lapack_driver when calling lstsq"
  1026. " or upgrade LAPACK.")
  1027. raise LstsqLapackError(mesg)
  1028. x, s, rank, info = lapack_func(a1, b1, lwork,
  1029. iwork, cond, False, False)
  1030. else: # complex data
  1031. lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
  1032. nrhs, cond)
  1033. x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
  1034. cond, False, False)
  1035. if info > 0:
  1036. raise LinAlgError("SVD did not converge in Linear Least Squares")
  1037. if info < 0:
  1038. raise ValueError('illegal value in %d-th argument of internal %s'
  1039. % (-info, lapack_driver))
  1040. resids = np.asarray([], dtype=x.dtype)
  1041. if m > n:
  1042. x1 = x[:n]
  1043. if rank == n:
  1044. resids = np.sum(np.abs(x[n:])**2, axis=0)
  1045. x = x1
  1046. return x, resids, rank, s
  1047. elif driver == 'gelsy':
  1048. lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
  1049. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  1050. v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
  1051. lwork, False, False)
  1052. if info < 0:
  1053. raise ValueError("illegal value in %d-th argument of internal "
  1054. "gelsy" % -info)
  1055. if m > n:
  1056. x1 = x[:n]
  1057. x = x1
  1058. return x, np.array([], x.dtype), rank, None
  1059. lstsq.default_lapack_driver = 'gelsd'
  1060. def pinv(a, cond=None, rcond=None, return_rank=False, check_finite=True):
  1061. """
  1062. Compute the (Moore-Penrose) pseudo-inverse of a matrix.
  1063. Calculate a generalized inverse of a matrix using a least-squares
  1064. solver.
  1065. Parameters
  1066. ----------
  1067. a : (M, N) array_like
  1068. Matrix to be pseudo-inverted.
  1069. cond, rcond : float, optional
  1070. Cutoff for 'small' singular values in the least-squares solver.
  1071. Singular values smaller than ``rcond * largest_singular_value``
  1072. are considered zero.
  1073. return_rank : bool, optional
  1074. if True, return the effective rank of the matrix
  1075. check_finite : bool, optional
  1076. Whether to check that the input matrix contains only finite numbers.
  1077. Disabling may give a performance gain, but may result in problems
  1078. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1079. Returns
  1080. -------
  1081. B : (N, M) ndarray
  1082. The pseudo-inverse of matrix `a`.
  1083. rank : int
  1084. The effective rank of the matrix. Returned if return_rank == True
  1085. Raises
  1086. ------
  1087. LinAlgError
  1088. If computation does not converge.
  1089. Examples
  1090. --------
  1091. >>> from scipy import linalg
  1092. >>> a = np.random.randn(9, 6)
  1093. >>> B = linalg.pinv(a)
  1094. >>> np.allclose(a, np.dot(a, np.dot(B, a)))
  1095. True
  1096. >>> np.allclose(B, np.dot(B, np.dot(a, B)))
  1097. True
  1098. """
  1099. a = _asarray_validated(a, check_finite=check_finite)
  1100. b = np.identity(a.shape[0], dtype=a.dtype)
  1101. if rcond is not None:
  1102. cond = rcond
  1103. x, resids, rank, s = lstsq(a, b, cond=cond, check_finite=False)
  1104. if return_rank:
  1105. return x, rank
  1106. else:
  1107. return x
  1108. def pinv2(a, cond=None, rcond=None, return_rank=False, check_finite=True):
  1109. """
  1110. Compute the (Moore-Penrose) pseudo-inverse of a matrix.
  1111. Calculate a generalized inverse of a matrix using its
  1112. singular-value decomposition and including all 'large' singular
  1113. values.
  1114. Parameters
  1115. ----------
  1116. a : (M, N) array_like
  1117. Matrix to be pseudo-inverted.
  1118. cond, rcond : float or None
  1119. Cutoff for 'small' singular values.
  1120. Singular values smaller than ``rcond*largest_singular_value``
  1121. are considered zero.
  1122. If None or -1, suitable machine precision is used.
  1123. return_rank : bool, optional
  1124. if True, return the effective rank of the matrix
  1125. check_finite : bool, optional
  1126. Whether to check that the input matrix contains only finite numbers.
  1127. Disabling may give a performance gain, but may result in problems
  1128. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1129. Returns
  1130. -------
  1131. B : (N, M) ndarray
  1132. The pseudo-inverse of matrix `a`.
  1133. rank : int
  1134. The effective rank of the matrix. Returned if return_rank == True
  1135. Raises
  1136. ------
  1137. LinAlgError
  1138. If SVD computation does not converge.
  1139. Examples
  1140. --------
  1141. >>> from scipy import linalg
  1142. >>> a = np.random.randn(9, 6)
  1143. >>> B = linalg.pinv2(a)
  1144. >>> np.allclose(a, np.dot(a, np.dot(B, a)))
  1145. True
  1146. >>> np.allclose(B, np.dot(B, np.dot(a, B)))
  1147. True
  1148. """
  1149. a = _asarray_validated(a, check_finite=check_finite)
  1150. u, s, vh = decomp_svd.svd(a, full_matrices=False, check_finite=False)
  1151. if rcond is not None:
  1152. cond = rcond
  1153. if cond in [None, -1]:
  1154. t = u.dtype.char.lower()
  1155. factor = {'f': 1E3, 'd': 1E6}
  1156. cond = factor[t] * np.finfo(t).eps
  1157. rank = np.sum(s > cond * np.max(s))
  1158. u = u[:, :rank]
  1159. u /= s[:rank]
  1160. B = np.transpose(np.conjugate(np.dot(u, vh[:rank])))
  1161. if return_rank:
  1162. return B, rank
  1163. else:
  1164. return B
  1165. def pinvh(a, cond=None, rcond=None, lower=True, return_rank=False,
  1166. check_finite=True):
  1167. """
  1168. Compute the (Moore-Penrose) pseudo-inverse of a Hermitian matrix.
  1169. Calculate a generalized inverse of a Hermitian or real symmetric matrix
  1170. using its eigenvalue decomposition and including all eigenvalues with
  1171. 'large' absolute value.
  1172. Parameters
  1173. ----------
  1174. a : (N, N) array_like
  1175. Real symmetric or complex hermetian matrix to be pseudo-inverted
  1176. cond, rcond : float or None
  1177. Cutoff for 'small' eigenvalues.
  1178. Singular values smaller than rcond * largest_eigenvalue are considered
  1179. zero.
  1180. If None or -1, suitable machine precision is used.
  1181. lower : bool, optional
  1182. Whether the pertinent array data is taken from the lower or upper
  1183. triangle of a. (Default: lower)
  1184. return_rank : bool, optional
  1185. if True, return the effective rank of the matrix
  1186. check_finite : bool, optional
  1187. Whether to check that the input matrix contains only finite numbers.
  1188. Disabling may give a performance gain, but may result in problems
  1189. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  1190. Returns
  1191. -------
  1192. B : (N, N) ndarray
  1193. The pseudo-inverse of matrix `a`.
  1194. rank : int
  1195. The effective rank of the matrix. Returned if return_rank == True
  1196. Raises
  1197. ------
  1198. LinAlgError
  1199. If eigenvalue does not converge
  1200. Examples
  1201. --------
  1202. >>> from scipy.linalg import pinvh
  1203. >>> a = np.random.randn(9, 6)
  1204. >>> a = np.dot(a, a.T)
  1205. >>> B = pinvh(a)
  1206. >>> np.allclose(a, np.dot(a, np.dot(B, a)))
  1207. True
  1208. >>> np.allclose(B, np.dot(B, np.dot(a, B)))
  1209. True
  1210. """
  1211. a = _asarray_validated(a, check_finite=check_finite)
  1212. s, u = decomp.eigh(a, lower=lower, check_finite=False)
  1213. if rcond is not None:
  1214. cond = rcond
  1215. if cond in [None, -1]:
  1216. t = u.dtype.char.lower()
  1217. factor = {'f': 1E3, 'd': 1E6}
  1218. cond = factor[t] * np.finfo(t).eps
  1219. # For Hermitian matrices, singular values equal abs(eigenvalues)
  1220. above_cutoff = (abs(s) > cond * np.max(abs(s)))
  1221. psigma_diag = 1.0 / s[above_cutoff]
  1222. u = u[:, above_cutoff]
  1223. B = np.dot(u * psigma_diag, np.conjugate(u).T)
  1224. if return_rank:
  1225. return B, len(psigma_diag)
  1226. else:
  1227. return B
  1228. def matrix_balance(A, permute=True, scale=True, separate=False,
  1229. overwrite_a=False):
  1230. """
  1231. Compute a diagonal similarity transformation for row/column balancing.
  1232. The balancing tries to equalize the row and column 1-norms by applying
  1233. a similarity transformation such that the magnitude variation of the
  1234. matrix entries is reflected to the scaling matrices.
  1235. Moreover, if enabled, the matrix is first permuted to isolate the upper
  1236. triangular parts of the matrix and, again if scaling is also enabled,
  1237. only the remaining subblocks are subjected to scaling.
  1238. The balanced matrix satisfies the following equality
  1239. .. math::
  1240. B = T^{-1} A T
  1241. The scaling coefficients are approximated to the nearest power of 2
  1242. to avoid round-off errors.
  1243. Parameters
  1244. ----------
  1245. A : (n, n) array_like
  1246. Square data matrix for the balancing.
  1247. permute : bool, optional
  1248. The selector to define whether permutation of A is also performed
  1249. prior to scaling.
  1250. scale : bool, optional
  1251. The selector to turn on and off the scaling. If False, the matrix
  1252. will not be scaled.
  1253. separate : bool, optional
  1254. This switches from returning a full matrix of the transformation
  1255. to a tuple of two separate 1D permutation and scaling arrays.
  1256. overwrite_a : bool, optional
  1257. This is passed to xGEBAL directly. Essentially, overwrites the result
  1258. to the data. It might increase the space efficiency. See LAPACK manual
  1259. for details. This is False by default.
  1260. Returns
  1261. -------
  1262. B : (n, n) ndarray
  1263. Balanced matrix
  1264. T : (n, n) ndarray
  1265. A possibly permuted diagonal matrix whose nonzero entries are
  1266. integer powers of 2 to avoid numerical truncation errors.
  1267. scale, perm : (n,) ndarray
  1268. If ``separate`` keyword is set to True then instead of the array
  1269. ``T`` above, the scaling and the permutation vectors are given
  1270. separately as a tuple without allocating the full array ``T``.
  1271. Notes
  1272. -----
  1273. This algorithm is particularly useful for eigenvalue and matrix
  1274. decompositions and in many cases it is already called by various
  1275. LAPACK routines.
  1276. The algorithm is based on the well-known technique of [1]_ and has
  1277. been modified to account for special cases. See [2]_ for details
  1278. which have been implemented since LAPACK v3.5.0. Before this version
  1279. there are corner cases where balancing can actually worsen the
  1280. conditioning. See [3]_ for such examples.
  1281. The code is a wrapper around LAPACK's xGEBAL routine family for matrix
  1282. balancing.
  1283. .. versionadded:: 0.19.0
  1284. Examples
  1285. --------
  1286. >>> from scipy import linalg
  1287. >>> x = np.array([[1,2,0], [9,1,0.01], [1,2,10*np.pi]])
  1288. >>> y, permscale = linalg.matrix_balance(x)
  1289. >>> np.abs(x).sum(axis=0) / np.abs(x).sum(axis=1)
  1290. array([ 3.66666667, 0.4995005 , 0.91312162])
  1291. >>> np.abs(y).sum(axis=0) / np.abs(y).sum(axis=1)
  1292. array([ 1.2 , 1.27041742, 0.92658316]) # may vary
  1293. >>> permscale # only powers of 2 (0.5 == 2^(-1))
  1294. array([[ 0.5, 0. , 0. ], # may vary
  1295. [ 0. , 1. , 0. ],
  1296. [ 0. , 0. , 1. ]])
  1297. References
  1298. ----------
  1299. .. [1] : B.N. Parlett and C. Reinsch, "Balancing a Matrix for
  1300. Calculation of Eigenvalues and Eigenvectors", Numerische Mathematik,
  1301. Vol.13(4), 1969, DOI:10.1007/BF02165404
  1302. .. [2] : R. James, J. Langou, B.R. Lowery, "On matrix balancing and
  1303. eigenvector computation", 2014, Available online:
  1304. https://arxiv.org/abs/1401.5766
  1305. .. [3] : D.S. Watkins. A case where balancing is harmful.
  1306. Electron. Trans. Numer. Anal, Vol.23, 2006.
  1307. """
  1308. A = np.atleast_2d(_asarray_validated(A, check_finite=True))
  1309. if not np.equal(*A.shape):
  1310. raise ValueError('The data matrix for balancing should be square.')
  1311. gebal = get_lapack_funcs(('gebal'), (A,))
  1312. B, lo, hi, ps, info = gebal(A, scale=scale, permute=permute,
  1313. overwrite_a=overwrite_a)
  1314. if info < 0:
  1315. raise ValueError('xGEBAL exited with the internal error '
  1316. '"illegal value in argument number {}.". See '
  1317. 'LAPACK documentation for the xGEBAL error codes.'
  1318. ''.format(-info))
  1319. # Separate the permutations from the scalings and then convert to int
  1320. scaling = np.ones_like(ps, dtype=float)
  1321. scaling[lo:hi+1] = ps[lo:hi+1]
  1322. # gebal uses 1-indexing
  1323. ps = ps.astype(int, copy=False) - 1
  1324. n = A.shape[0]
  1325. perm = np.arange(n)
  1326. # LAPACK permutes with the ordering n --> hi, then 0--> lo
  1327. if hi < n:
  1328. for ind, x in enumerate(ps[hi+1:][::-1], 1):
  1329. if n-ind == x:
  1330. continue
  1331. perm[[x, n-ind]] = perm[[n-ind, x]]
  1332. if lo > 0:
  1333. for ind, x in enumerate(ps[:lo]):
  1334. if ind == x:
  1335. continue
  1336. perm[[x, ind]] = perm[[ind, x]]
  1337. if separate:
  1338. return B, (scaling, perm)
  1339. # get the inverse permutation
  1340. iperm = np.empty_like(perm)
  1341. iperm[perm] = np.arange(n)
  1342. return B, np.diag(scaling)[iperm, :]