test_decomp_update.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697
  1. from __future__ import division, print_function, absolute_import
  2. import itertools
  3. import numpy as np
  4. from numpy.testing import assert_, assert_allclose, assert_equal
  5. from pytest import raises as assert_raises
  6. from scipy import linalg
  7. import scipy.linalg._decomp_update as _decomp_update
  8. from scipy.linalg._decomp_update import *
  9. def assert_unitary(a, rtol=None, atol=None, assert_sqr=True):
  10. if rtol is None:
  11. rtol = 10.0 ** -(np.finfo(a.dtype).precision-2)
  12. if atol is None:
  13. atol = 2*np.finfo(a.dtype).eps
  14. if assert_sqr:
  15. assert_(a.shape[0] == a.shape[1], 'unitary matrices must be square')
  16. aTa = np.dot(a.T.conj(), a)
  17. assert_allclose(aTa, np.eye(a.shape[1]), rtol=rtol, atol=atol)
  18. def assert_upper_tri(a, rtol=None, atol=None):
  19. if rtol is None:
  20. rtol = 10.0 ** -(np.finfo(a.dtype).precision-2)
  21. if atol is None:
  22. atol = 2*np.finfo(a.dtype).eps
  23. mask = np.tri(a.shape[0], a.shape[1], -1, np.bool_)
  24. assert_allclose(a[mask], 0.0, rtol=rtol, atol=atol)
  25. def check_qr(q, r, a, rtol, atol, assert_sqr=True):
  26. assert_unitary(q, rtol, atol, assert_sqr)
  27. assert_upper_tri(r, rtol, atol)
  28. assert_allclose(q.dot(r), a, rtol=rtol, atol=atol)
  29. def make_strided(arrs):
  30. strides = [(3, 7), (2, 2), (3, 4), (4, 2), (5, 4), (2, 3), (2, 1), (4, 5)]
  31. kmax = len(strides)
  32. k = 0
  33. ret = []
  34. for a in arrs:
  35. if a.ndim == 1:
  36. s = strides[k % kmax]
  37. k += 1
  38. base = np.zeros(s[0]*a.shape[0]+s[1], a.dtype)
  39. view = base[s[1]::s[0]]
  40. view[...] = a
  41. elif a.ndim == 2:
  42. s = strides[k % kmax]
  43. t = strides[(k+1) % kmax]
  44. k += 2
  45. base = np.zeros((s[0]*a.shape[0]+s[1], t[0]*a.shape[1]+t[1]), a.dtype)
  46. view = base[s[1]::s[0], t[1]::t[0]]
  47. view[...] = a
  48. else:
  49. raise ValueError('make_strided only works for ndim = 1 or 2 arrays')
  50. ret.append(view)
  51. return ret
  52. def negate_strides(arrs):
  53. ret = []
  54. for a in arrs:
  55. b = np.zeros_like(a)
  56. if b.ndim == 2:
  57. b = b[::-1, ::-1]
  58. elif b.ndim == 1:
  59. b = b[::-1]
  60. else:
  61. raise ValueError('negate_strides only works for ndim = 1 or 2 arrays')
  62. b[...] = a
  63. ret.append(b)
  64. return ret
  65. def nonitemsize_strides(arrs):
  66. out = []
  67. for a in arrs:
  68. a_dtype = a.dtype
  69. b = np.zeros(a.shape, [('a', a_dtype), ('junk', 'S1')])
  70. c = b.getfield(a_dtype)
  71. c[...] = a
  72. out.append(c)
  73. return out
  74. def make_nonnative(arrs):
  75. out = []
  76. for a in arrs:
  77. out.append(a.astype(a.dtype.newbyteorder()))
  78. return out
  79. class BaseQRdeltas(object):
  80. def setup_method(self):
  81. self.rtol = 10.0 ** -(np.finfo(self.dtype).precision-2)
  82. self.atol = 10 * np.finfo(self.dtype).eps
  83. def generate(self, type, mode='full'):
  84. np.random.seed(29382)
  85. shape = {'sqr': (8, 8), 'tall': (12, 7), 'fat': (7, 12),
  86. 'Mx1': (8, 1), '1xN': (1, 8), '1x1': (1, 1)}[type]
  87. a = np.random.random(shape)
  88. if np.iscomplexobj(self.dtype.type(1)):
  89. b = np.random.random(shape)
  90. a = a + 1j * b
  91. a = a.astype(self.dtype)
  92. q, r = linalg.qr(a, mode=mode)
  93. return a, q, r
  94. class BaseQRdelete(BaseQRdeltas):
  95. def test_sqr_1_row(self):
  96. a, q, r = self.generate('sqr')
  97. for row in range(r.shape[0]):
  98. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  99. a1 = np.delete(a, row, 0)
  100. check_qr(q1, r1, a1, self.rtol, self.atol)
  101. def test_sqr_p_row(self):
  102. a, q, r = self.generate('sqr')
  103. for ndel in range(2, 6):
  104. for row in range(a.shape[0]-ndel):
  105. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  106. a1 = np.delete(a, slice(row, row+ndel), 0)
  107. check_qr(q1, r1, a1, self.rtol, self.atol)
  108. def test_sqr_1_col(self):
  109. a, q, r = self.generate('sqr')
  110. for col in range(r.shape[1]):
  111. q1, r1 = qr_delete(q, r, col, which='col', overwrite_qr=False)
  112. a1 = np.delete(a, col, 1)
  113. check_qr(q1, r1, a1, self.rtol, self.atol)
  114. def test_sqr_p_col(self):
  115. a, q, r = self.generate('sqr')
  116. for ndel in range(2, 6):
  117. for col in range(r.shape[1]-ndel):
  118. q1, r1 = qr_delete(q, r, col, ndel, which='col',
  119. overwrite_qr=False)
  120. a1 = np.delete(a, slice(col, col+ndel), 1)
  121. check_qr(q1, r1, a1, self.rtol, self.atol)
  122. def test_tall_1_row(self):
  123. a, q, r = self.generate('tall')
  124. for row in range(r.shape[0]):
  125. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  126. a1 = np.delete(a, row, 0)
  127. check_qr(q1, r1, a1, self.rtol, self.atol)
  128. def test_tall_p_row(self):
  129. a, q, r = self.generate('tall')
  130. for ndel in range(2, 6):
  131. for row in range(a.shape[0]-ndel):
  132. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  133. a1 = np.delete(a, slice(row, row+ndel), 0)
  134. check_qr(q1, r1, a1, self.rtol, self.atol)
  135. def test_tall_1_col(self):
  136. a, q, r = self.generate('tall')
  137. for col in range(r.shape[1]):
  138. q1, r1 = qr_delete(q, r, col, which='col', overwrite_qr=False)
  139. a1 = np.delete(a, col, 1)
  140. check_qr(q1, r1, a1, self.rtol, self.atol)
  141. def test_tall_p_col(self):
  142. a, q, r = self.generate('tall')
  143. for ndel in range(2, 6):
  144. for col in range(r.shape[1]-ndel):
  145. q1, r1 = qr_delete(q, r, col, ndel, which='col',
  146. overwrite_qr=False)
  147. a1 = np.delete(a, slice(col, col+ndel), 1)
  148. check_qr(q1, r1, a1, self.rtol, self.atol)
  149. def test_fat_1_row(self):
  150. a, q, r = self.generate('fat')
  151. for row in range(r.shape[0]):
  152. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  153. a1 = np.delete(a, row, 0)
  154. check_qr(q1, r1, a1, self.rtol, self.atol)
  155. def test_fat_p_row(self):
  156. a, q, r = self.generate('fat')
  157. for ndel in range(2, 6):
  158. for row in range(a.shape[0]-ndel):
  159. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  160. a1 = np.delete(a, slice(row, row+ndel), 0)
  161. check_qr(q1, r1, a1, self.rtol, self.atol)
  162. def test_fat_1_col(self):
  163. a, q, r = self.generate('fat')
  164. for col in range(r.shape[1]):
  165. q1, r1 = qr_delete(q, r, col, which='col', overwrite_qr=False)
  166. a1 = np.delete(a, col, 1)
  167. check_qr(q1, r1, a1, self.rtol, self.atol)
  168. def test_fat_p_col(self):
  169. a, q, r = self.generate('fat')
  170. for ndel in range(2, 6):
  171. for col in range(r.shape[1]-ndel):
  172. q1, r1 = qr_delete(q, r, col, ndel, which='col',
  173. overwrite_qr=False)
  174. a1 = np.delete(a, slice(col, col+ndel), 1)
  175. check_qr(q1, r1, a1, self.rtol, self.atol)
  176. def test_economic_1_row(self):
  177. # this test always starts and ends with an economic decomp.
  178. a, q, r = self.generate('tall', 'economic')
  179. for row in range(r.shape[0]):
  180. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  181. a1 = np.delete(a, row, 0)
  182. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  183. # for economic row deletes
  184. # eco - prow = eco
  185. # eco - prow = sqr
  186. # eco - prow = fat
  187. def base_economic_p_row_xxx(self, ndel):
  188. a, q, r = self.generate('tall', 'economic')
  189. for row in range(a.shape[0]-ndel):
  190. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  191. a1 = np.delete(a, slice(row, row+ndel), 0)
  192. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  193. def test_economic_p_row_economic(self):
  194. # (12, 7) - (3, 7) = (9,7) --> stays economic
  195. self.base_economic_p_row_xxx(3)
  196. def test_economic_p_row_sqr(self):
  197. # (12, 7) - (5, 7) = (7, 7) --> becomes square
  198. self.base_economic_p_row_xxx(5)
  199. def test_economic_p_row_fat(self):
  200. # (12, 7) - (7,7) = (5, 7) --> becomes fat
  201. self.base_economic_p_row_xxx(7)
  202. def test_economic_1_col(self):
  203. a, q, r = self.generate('tall', 'economic')
  204. for col in range(r.shape[1]):
  205. q1, r1 = qr_delete(q, r, col, which='col', overwrite_qr=False)
  206. a1 = np.delete(a, col, 1)
  207. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  208. def test_economic_p_col(self):
  209. a, q, r = self.generate('tall', 'economic')
  210. for ndel in range(2, 6):
  211. for col in range(r.shape[1]-ndel):
  212. q1, r1 = qr_delete(q, r, col, ndel, which='col',
  213. overwrite_qr=False)
  214. a1 = np.delete(a, slice(col, col+ndel), 1)
  215. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  216. def test_Mx1_1_row(self):
  217. a, q, r = self.generate('Mx1')
  218. for row in range(r.shape[0]):
  219. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  220. a1 = np.delete(a, row, 0)
  221. check_qr(q1, r1, a1, self.rtol, self.atol)
  222. def test_Mx1_p_row(self):
  223. a, q, r = self.generate('Mx1')
  224. for ndel in range(2, 6):
  225. for row in range(a.shape[0]-ndel):
  226. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  227. a1 = np.delete(a, slice(row, row+ndel), 0)
  228. check_qr(q1, r1, a1, self.rtol, self.atol)
  229. def test_1xN_1_col(self):
  230. a, q, r = self.generate('1xN')
  231. for col in range(r.shape[1]):
  232. q1, r1 = qr_delete(q, r, col, which='col', overwrite_qr=False)
  233. a1 = np.delete(a, col, 1)
  234. check_qr(q1, r1, a1, self.rtol, self.atol)
  235. def test_1xN_p_col(self):
  236. a, q, r = self.generate('1xN')
  237. for ndel in range(2, 6):
  238. for col in range(r.shape[1]-ndel):
  239. q1, r1 = qr_delete(q, r, col, ndel, which='col',
  240. overwrite_qr=False)
  241. a1 = np.delete(a, slice(col, col+ndel), 1)
  242. check_qr(q1, r1, a1, self.rtol, self.atol)
  243. def test_Mx1_economic_1_row(self):
  244. a, q, r = self.generate('Mx1', 'economic')
  245. for row in range(r.shape[0]):
  246. q1, r1 = qr_delete(q, r, row, overwrite_qr=False)
  247. a1 = np.delete(a, row, 0)
  248. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  249. def test_Mx1_economic_p_row(self):
  250. a, q, r = self.generate('Mx1', 'economic')
  251. for ndel in range(2, 6):
  252. for row in range(a.shape[0]-ndel):
  253. q1, r1 = qr_delete(q, r, row, ndel, overwrite_qr=False)
  254. a1 = np.delete(a, slice(row, row+ndel), 0)
  255. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  256. def test_delete_last_1_row(self):
  257. # full and eco are the same for 1xN
  258. a, q, r = self.generate('1xN')
  259. q1, r1 = qr_delete(q, r, 0, 1, 'row')
  260. assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
  261. assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
  262. def test_delete_last_p_row(self):
  263. a, q, r = self.generate('tall', 'full')
  264. q1, r1 = qr_delete(q, r, 0, a.shape[0], 'row')
  265. assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
  266. assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
  267. a, q, r = self.generate('tall', 'economic')
  268. q1, r1 = qr_delete(q, r, 0, a.shape[0], 'row')
  269. assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
  270. assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
  271. def test_delete_last_1_col(self):
  272. a, q, r = self.generate('Mx1', 'economic')
  273. q1, r1 = qr_delete(q, r, 0, 1, 'col')
  274. assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
  275. assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))
  276. a, q, r = self.generate('Mx1', 'full')
  277. q1, r1 = qr_delete(q, r, 0, 1, 'col')
  278. assert_unitary(q1)
  279. assert_(q1.dtype == q.dtype)
  280. assert_(q1.shape == q.shape)
  281. assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))
  282. def test_delete_last_p_col(self):
  283. a, q, r = self.generate('tall', 'full')
  284. q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
  285. assert_unitary(q1)
  286. assert_(q1.dtype == q.dtype)
  287. assert_(q1.shape == q.shape)
  288. assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))
  289. a, q, r = self.generate('tall', 'economic')
  290. q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
  291. assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
  292. assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))
  293. def test_delete_1x1_row_col(self):
  294. a, q, r = self.generate('1x1')
  295. q1, r1 = qr_delete(q, r, 0, 1, 'row')
  296. assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
  297. assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))
  298. a, q, r = self.generate('1x1')
  299. q1, r1 = qr_delete(q, r, 0, 1, 'col')
  300. assert_unitary(q1)
  301. assert_(q1.dtype == q.dtype)
  302. assert_(q1.shape == q.shape)
  303. assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))
  304. # all full qr, row deletes and single column deletes should be able to
  305. # handle any non negative strides. (only row and column vector
  306. # operations are used.) p column delete require fortran ordered
  307. # Q and R and will make a copy as necessary. Economic qr row deletes
  308. # requre a contigous q.
  309. def base_non_simple_strides(self, adjust_strides, ks, p, which, overwriteable):
  310. if which == 'row':
  311. qind = (slice(p,None), slice(p,None))
  312. rind = (slice(p,None), slice(None))
  313. else:
  314. qind = (slice(None), slice(None))
  315. rind = (slice(None), slice(None,-p))
  316. for type, k in itertools.product(['sqr', 'tall', 'fat'], ks):
  317. a, q0, r0, = self.generate(type)
  318. qs, rs = adjust_strides((q0, r0))
  319. if p == 1:
  320. a1 = np.delete(a, k, 0 if which == 'row' else 1)
  321. else:
  322. s = slice(k,k+p)
  323. if k < 0:
  324. s = slice(k, k + p + (a.shape[0] if which == 'row' else a.shape[1]))
  325. a1 = np.delete(a, s, 0 if which == 'row' else 1)
  326. # for each variable, q, r we try with it strided and
  327. # overwrite=False. Then we try with overwrite=True, and make
  328. # sure that q and r are still overwritten.
  329. q = q0.copy('F')
  330. r = r0.copy('F')
  331. q1, r1 = qr_delete(qs, r, k, p, which, False)
  332. check_qr(q1, r1, a1, self.rtol, self.atol)
  333. q1o, r1o = qr_delete(qs, r, k, p, which, True)
  334. check_qr(q1o, r1o, a1, self.rtol, self.atol)
  335. if overwriteable:
  336. assert_allclose(q1o, qs[qind], rtol=self.rtol, atol=self.atol)
  337. assert_allclose(r1o, r[rind], rtol=self.rtol, atol=self.atol)
  338. q = q0.copy('F')
  339. r = r0.copy('F')
  340. q2, r2 = qr_delete(q, rs, k, p, which, False)
  341. check_qr(q2, r2, a1, self.rtol, self.atol)
  342. q2o, r2o = qr_delete(q, rs, k, p, which, True)
  343. check_qr(q2o, r2o, a1, self.rtol, self.atol)
  344. if overwriteable:
  345. assert_allclose(q2o, q[qind], rtol=self.rtol, atol=self.atol)
  346. assert_allclose(r2o, rs[rind], rtol=self.rtol, atol=self.atol)
  347. q = q0.copy('F')
  348. r = r0.copy('F')
  349. # since some of these were consumed above
  350. qs, rs = adjust_strides((q, r))
  351. q3, r3 = qr_delete(qs, rs, k, p, which, False)
  352. check_qr(q3, r3, a1, self.rtol, self.atol)
  353. q3o, r3o = qr_delete(qs, rs, k, p, which, True)
  354. check_qr(q3o, r3o, a1, self.rtol, self.atol)
  355. if overwriteable:
  356. assert_allclose(q2o, qs[qind], rtol=self.rtol, atol=self.atol)
  357. assert_allclose(r3o, rs[rind], rtol=self.rtol, atol=self.atol)
  358. def test_non_unit_strides_1_row(self):
  359. self.base_non_simple_strides(make_strided, [0], 1, 'row', True)
  360. def test_non_unit_strides_p_row(self):
  361. self.base_non_simple_strides(make_strided, [0], 3, 'row', True)
  362. def test_non_unit_strides_1_col(self):
  363. self.base_non_simple_strides(make_strided, [0], 1, 'col', True)
  364. def test_non_unit_strides_p_col(self):
  365. self.base_non_simple_strides(make_strided, [0], 3, 'col', False)
  366. def test_neg_strides_1_row(self):
  367. self.base_non_simple_strides(negate_strides, [0], 1, 'row', False)
  368. def test_neg_strides_p_row(self):
  369. self.base_non_simple_strides(negate_strides, [0], 3, 'row', False)
  370. def test_neg_strides_1_col(self):
  371. self.base_non_simple_strides(negate_strides, [0], 1, 'col', False)
  372. def test_neg_strides_p_col(self):
  373. self.base_non_simple_strides(negate_strides, [0], 3, 'col', False)
  374. def test_non_itemize_strides_1_row(self):
  375. self.base_non_simple_strides(nonitemsize_strides, [0], 1, 'row', False)
  376. def test_non_itemize_strides_p_row(self):
  377. self.base_non_simple_strides(nonitemsize_strides, [0], 3, 'row', False)
  378. def test_non_itemize_strides_1_col(self):
  379. self.base_non_simple_strides(nonitemsize_strides, [0], 1, 'col', False)
  380. def test_non_itemize_strides_p_col(self):
  381. self.base_non_simple_strides(nonitemsize_strides, [0], 3, 'col', False)
  382. def test_non_native_byte_order_1_row(self):
  383. self.base_non_simple_strides(make_nonnative, [0], 1, 'row', False)
  384. def test_non_native_byte_order_p_row(self):
  385. self.base_non_simple_strides(make_nonnative, [0], 3, 'row', False)
  386. def test_non_native_byte_order_1_col(self):
  387. self.base_non_simple_strides(make_nonnative, [0], 1, 'col', False)
  388. def test_non_native_byte_order_p_col(self):
  389. self.base_non_simple_strides(make_nonnative, [0], 3, 'col', False)
  390. def test_neg_k(self):
  391. a, q, r = self.generate('sqr')
  392. for k, p, w in itertools.product([-3, -7], [1, 3], ['row', 'col']):
  393. q1, r1 = qr_delete(q, r, k, p, w, overwrite_qr=False)
  394. if w == 'row':
  395. a1 = np.delete(a, slice(k+a.shape[0], k+p+a.shape[0]), 0)
  396. else:
  397. a1 = np.delete(a, slice(k+a.shape[0], k+p+a.shape[1]), 1)
  398. check_qr(q1, r1, a1, self.rtol, self.atol)
  399. def base_overwrite_qr(self, which, p, test_C, test_F, mode='full'):
  400. assert_sqr = True if mode == 'full' else False
  401. if which == 'row':
  402. qind = (slice(p,None), slice(p,None))
  403. rind = (slice(p,None), slice(None))
  404. else:
  405. qind = (slice(None), slice(None))
  406. rind = (slice(None), slice(None,-p))
  407. a, q0, r0 = self.generate('sqr', mode)
  408. if p == 1:
  409. a1 = np.delete(a, 3, 0 if which == 'row' else 1)
  410. else:
  411. a1 = np.delete(a, slice(3, 3+p), 0 if which == 'row' else 1)
  412. # don't overwrite
  413. q = q0.copy('F')
  414. r = r0.copy('F')
  415. q1, r1 = qr_delete(q, r, 3, p, which, False)
  416. check_qr(q1, r1, a1, self.rtol, self.atol, assert_sqr)
  417. check_qr(q, r, a, self.rtol, self.atol, assert_sqr)
  418. if test_F:
  419. q = q0.copy('F')
  420. r = r0.copy('F')
  421. q2, r2 = qr_delete(q, r, 3, p, which, True)
  422. check_qr(q2, r2, a1, self.rtol, self.atol, assert_sqr)
  423. # verify the overwriting
  424. assert_allclose(q2, q[qind], rtol=self.rtol, atol=self.atol)
  425. assert_allclose(r2, r[rind], rtol=self.rtol, atol=self.atol)
  426. if test_C:
  427. q = q0.copy('C')
  428. r = r0.copy('C')
  429. q3, r3 = qr_delete(q, r, 3, p, which, True)
  430. check_qr(q3, r3, a1, self.rtol, self.atol, assert_sqr)
  431. assert_allclose(q3, q[qind], rtol=self.rtol, atol=self.atol)
  432. assert_allclose(r3, r[rind], rtol=self.rtol, atol=self.atol)
  433. def test_overwrite_qr_1_row(self):
  434. # any positively strided q and r.
  435. self.base_overwrite_qr('row', 1, True, True)
  436. def test_overwrite_economic_qr_1_row(self):
  437. # Any contiguous q and positively strided r.
  438. self.base_overwrite_qr('row', 1, True, True, 'economic')
  439. def test_overwrite_qr_1_col(self):
  440. # any positively strided q and r.
  441. # full and eco share code paths
  442. self.base_overwrite_qr('col', 1, True, True)
  443. def test_overwrite_qr_p_row(self):
  444. # any positively strided q and r.
  445. self.base_overwrite_qr('row', 3, True, True)
  446. def test_overwrite_economic_qr_p_row(self):
  447. # any contiguous q and positively strided r
  448. self.base_overwrite_qr('row', 3, True, True, 'economic')
  449. def test_overwrite_qr_p_col(self):
  450. # only F orderd q and r can be overwritten for cols
  451. # full and eco share code paths
  452. self.base_overwrite_qr('col', 3, False, True)
  453. def test_bad_which(self):
  454. a, q, r = self.generate('sqr')
  455. assert_raises(ValueError, qr_delete, q, r, 0, which='foo')
  456. def test_bad_k(self):
  457. a, q, r = self.generate('tall')
  458. assert_raises(ValueError, qr_delete, q, r, q.shape[0], 1)
  459. assert_raises(ValueError, qr_delete, q, r, -q.shape[0]-1, 1)
  460. assert_raises(ValueError, qr_delete, q, r, r.shape[0], 1, 'col')
  461. assert_raises(ValueError, qr_delete, q, r, -r.shape[0]-1, 1, 'col')
  462. def test_bad_p(self):
  463. a, q, r = self.generate('tall')
  464. # p must be positive
  465. assert_raises(ValueError, qr_delete, q, r, 0, -1)
  466. assert_raises(ValueError, qr_delete, q, r, 0, -1, 'col')
  467. # and nonzero
  468. assert_raises(ValueError, qr_delete, q, r, 0, 0)
  469. assert_raises(ValueError, qr_delete, q, r, 0, 0, 'col')
  470. # must have at least k+p rows or cols, depending.
  471. assert_raises(ValueError, qr_delete, q, r, 3, q.shape[0]-2)
  472. assert_raises(ValueError, qr_delete, q, r, 3, r.shape[1]-2, 'col')
  473. def test_empty_q(self):
  474. a, q, r = self.generate('tall')
  475. # same code path for 'row' and 'col'
  476. assert_raises(ValueError, qr_delete, np.array([]), r, 0, 1)
  477. def test_empty_r(self):
  478. a, q, r = self.generate('tall')
  479. # same code path for 'row' and 'col'
  480. assert_raises(ValueError, qr_delete, q, np.array([]), 0, 1)
  481. def test_mismatched_q_and_r(self):
  482. a, q, r = self.generate('tall')
  483. r = r[1:]
  484. assert_raises(ValueError, qr_delete, q, r, 0, 1)
  485. def test_unsupported_dtypes(self):
  486. dts = ['int8', 'int16', 'int32', 'int64',
  487. 'uint8', 'uint16', 'uint32', 'uint64',
  488. 'float16', 'longdouble', 'longcomplex',
  489. 'bool']
  490. a, q0, r0 = self.generate('tall')
  491. for dtype in dts:
  492. q = q0.real.astype(dtype)
  493. r = r0.real.astype(dtype)
  494. assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'row')
  495. assert_raises(ValueError, qr_delete, q, r0, 0, 2, 'row')
  496. assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'col')
  497. assert_raises(ValueError, qr_delete, q, r0, 0, 2, 'col')
  498. assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'row')
  499. assert_raises(ValueError, qr_delete, q0, r, 0, 2, 'row')
  500. assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'col')
  501. assert_raises(ValueError, qr_delete, q0, r, 0, 2, 'col')
  502. def test_check_finite(self):
  503. a0, q0, r0 = self.generate('tall')
  504. q = q0.copy('F')
  505. q[1,1] = np.nan
  506. assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'row')
  507. assert_raises(ValueError, qr_delete, q, r0, 0, 3, 'row')
  508. assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'col')
  509. assert_raises(ValueError, qr_delete, q, r0, 0, 3, 'col')
  510. r = r0.copy('F')
  511. r[1,1] = np.nan
  512. assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'row')
  513. assert_raises(ValueError, qr_delete, q0, r, 0, 3, 'row')
  514. assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'col')
  515. assert_raises(ValueError, qr_delete, q0, r, 0, 3, 'col')
  516. def test_qr_scalar(self):
  517. a, q, r = self.generate('1x1')
  518. assert_raises(ValueError, qr_delete, q[0, 0], r, 0, 1, 'row')
  519. assert_raises(ValueError, qr_delete, q, r[0, 0], 0, 1, 'row')
  520. assert_raises(ValueError, qr_delete, q[0, 0], r, 0, 1, 'col')
  521. assert_raises(ValueError, qr_delete, q, r[0, 0], 0, 1, 'col')
  522. class TestQRdelete_f(BaseQRdelete):
  523. dtype = np.dtype('f')
  524. class TestQRdelete_F(BaseQRdelete):
  525. dtype = np.dtype('F')
  526. class TestQRdelete_d(BaseQRdelete):
  527. dtype = np.dtype('d')
  528. class TestQRdelete_D(BaseQRdelete):
  529. dtype = np.dtype('D')
  530. class BaseQRinsert(BaseQRdeltas):
  531. def generate(self, type, mode='full', which='row', p=1):
  532. a, q, r = super(BaseQRinsert, self).generate(type, mode)
  533. assert_(p > 0)
  534. # super call set the seed...
  535. if which == 'row':
  536. if p == 1:
  537. u = np.random.random(a.shape[1])
  538. else:
  539. u = np.random.random((p, a.shape[1]))
  540. elif which == 'col':
  541. if p == 1:
  542. u = np.random.random(a.shape[0])
  543. else:
  544. u = np.random.random((a.shape[0], p))
  545. else:
  546. ValueError('which should be either "row" or "col"')
  547. if np.iscomplexobj(self.dtype.type(1)):
  548. b = np.random.random(u.shape)
  549. u = u + 1j * b
  550. u = u.astype(self.dtype)
  551. return a, q, r, u
  def test_sqr_1_row(self):
      # square + 1 row, inserted at every position 0 .. r.shape[0]
      a, q, r, u = self.generate('sqr', which='row')
      n_positions = r.shape[0] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_sqr_p_row(self):
      # square + 3 rows --> always tall (more rows than columns)
      a, q, r, u = self.generate('sqr', which='row', p=3)
      n_positions = r.shape[0] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos)
          where = np.full(3, pos, dtype=np.intp)
          expected = np.insert(a, where, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_sqr_1_col(self):
      # square + 1 column, inserted at every position 0 .. r.shape[1]
      a, q, r, u = self.generate('sqr', which='col')
      n_positions = r.shape[1] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_sqr_p_col(self):
      # square + 3 columns --> always fat
      a, q, r, u = self.generate('sqr', which='col', p=3)
      n_positions = r.shape[1] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          where = np.full(3, pos, dtype=np.intp)
          expected = np.insert(a, where, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_1_row(self):
      # tall + 1 row, inserted at every row position
      a, q, r, u = self.generate('tall', which='row')
      n_positions = r.shape[0] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_p_row(self):
      # tall + 3 rows --> always tall
      a, q, r, u = self.generate('tall', which='row', p=3)
      n_positions = r.shape[0] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos)
          where = np.full(3, pos, dtype=np.intp)
          expected = np.insert(a, where, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_1_col(self):
      # tall + 1 column, inserted at every column position
      a, q, r, u = self.generate('tall', which='col')
      n_positions = r.shape[1] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)
  # Adding p columns to a tall matrix can produce any of three shapes:
  #   tall + p cols --> tall
  #   tall + p cols --> sqr
  #   tall + p cols --> fat
  def base_tall_p_col_xxx(self, p):
      """Insert ``p`` columns into a tall QR at every column position."""
      a, q, r, u = self.generate('tall', which='col', p=p)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(p, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_p_col_tall(self):
      # 12x7 + 12x3 = 12x10 --> stays tall
      self.base_tall_p_col_xxx(3)

  def test_tall_p_col_sqr(self):
      # 12x7 + 12x5 = 12x12 --> becomes square
      self.base_tall_p_col_xxx(5)

  def test_tall_p_col_fat(self):
      # 12x7 + 12x7 = 12x14 --> becomes fat
      self.base_tall_p_col_xxx(7)
  def test_fat_1_row(self):
      # fat + 1 row, inserted at every row position
      a, q, r, u = self.generate('fat', which='row')
      n_positions = r.shape[0] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  # Adding p rows to a fat matrix can produce any of three shapes:
  #   fat + p rows --> fat
  #   fat + p rows --> sqr
  #   fat + p rows --> tall
  def base_fat_p_row_xxx(self, p):
      """Insert ``p`` rows into a fat QR at every row position."""
      a, q, r, u = self.generate('fat', which='row', p=p)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, np.full(p, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_fat_p_row_fat(self):
      # 7x12 + 3x12 = 10x12 --> stays fat
      self.base_fat_p_row_xxx(3)

  def test_fat_p_row_sqr(self):
      # 7x12 + 5x12 = 12x12 --> becomes square
      self.base_fat_p_row_xxx(5)

  def test_fat_p_row_tall(self):
      # 7x12 + 7x12 = 14x12 --> becomes tall
      self.base_fat_p_row_xxx(7)
  def test_fat_1_col(self):
      # fat + 1 column, inserted at every column position
      a, q, r, u = self.generate('fat', which='col')
      n_positions = r.shape[1] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_fat_p_col(self):
      # fat + 3 columns --> always fat
      a, q, r, u = self.generate('fat', which='col', p=3)
      n_positions = r.shape[1] + 1
      for pos in range(n_positions):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          where = np.full(3, pos, dtype=np.intp)
          expected = np.insert(a, where, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)
  def test_economic_1_row(self):
      # economic (tall) + 1 row; the trailing False passed to check_qr turns
      # off its square-Q assertion, since an economic Q is not square
      a, q, r, u = self.generate('tall', 'economic', 'row')
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_economic_p_row(self):
      # tall + rows --> always tall
      a, q, r, u = self.generate('tall', 'economic', 'row', 3)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, overwrite_qru=False)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_economic_1_col(self):
      # economic + 1 column; u is passed as a copy so the original stays
      # available for building the expected matrix
      a, q, r, u = self.generate('tall', 'economic', which='col')
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u.copy(), pos, 'col',
                                   overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_economic_1_col_bad_update(self):
      # A new column that already lies in span(Q) cannot be appended
      # meaningfully; qr_insert must detect this and raise LinAlgError.
      q = np.eye(5, 3, dtype=self.dtype)
      r = np.eye(3, dtype=self.dtype)
      u = np.zeros(5, self.dtype)
      u[0] = 1
      with assert_raises(linalg.LinAlgError):
          qr_insert(q, r, u, 0, 'col')
  # Adding p columns to an economic (tall) factorization can give:
  #   eco + p cols --> eco
  #   eco + p cols --> sqr
  #   eco + p cols --> fat
  def base_economic_p_col_xxx(self, p):
      """Insert ``p`` columns into an economic QR at every column position."""
      a, q, r, u = self.generate('tall', 'economic', which='col', p=p)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(p, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_economic_p_col_eco(self):
      # 12x7 + 12x3 = 12x10 --> stays economic
      self.base_economic_p_col_xxx(3)

  def test_economic_p_col_sqr(self):
      # 12x7 + 12x5 = 12x12 --> becomes square
      self.base_economic_p_col_xxx(5)

  def test_economic_p_col_fat(self):
      # 12x7 + 12x7 = 12x14 --> becomes fat
      self.base_economic_p_col_xxx(7)
  def test_Mx1_1_row(self):
      # single-column matrix + 1 row
      a, q, r, u = self.generate('Mx1', which='row')
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_p_row(self):
      # single-column matrix + 3 rows
      a, q, r, u = self.generate('Mx1', which='row', p=3)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_1_col(self):
      # single-column matrix + 1 column
      a, q, r, u = self.generate('Mx1', which='col')
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_p_col(self):
      # single-column matrix + 3 columns
      a, q, r, u = self.generate('Mx1', which='col', p=3)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_economic_1_row(self):
      # economic variants: check_qr's square-Q assertion is switched off
      a, q, r, u = self.generate('Mx1', 'economic', 'row')
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_Mx1_economic_p_row(self):
      a, q, r, u = self.generate('Mx1', 'economic', 'row', 3)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_Mx1_economic_1_col(self):
      a, q, r, u = self.generate('Mx1', 'economic', 'col')
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_Mx1_economic_p_col(self):
      a, q, r, u = self.generate('Mx1', 'economic', 'col', 3)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
  def test_1xN_1_row(self):
      # single-row matrix + 1 row
      a, q, r, u = self.generate('1xN', which='row')
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1xN_p_row(self):
      # single-row matrix + 3 rows
      a, q, r, u = self.generate('1xN', which='row', p=3)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1xN_1_col(self):
      # single-row matrix + 1 column
      a, q, r, u = self.generate('1xN', which='col')
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1xN_p_col(self):
      # single-row matrix + 3 columns
      a, q, r, u = self.generate('1xN', which='col', p=3)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)
  def test_1x1_1_row(self):
      # 1x1 matrix + 1 row
      a, q, r, u = self.generate('1x1', which='row')
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, pos, u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_p_row(self):
      # 1x1 matrix + 3 rows
      a, q, r, u = self.generate('1x1', which='row', p=3)
      for pos in range(r.shape[0] + 1):
          q_new, r_new = qr_insert(q, r, u, pos)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=0)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_1_col(self):
      # 1x1 matrix + 1 column
      a, q, r, u = self.generate('1x1', which='col')
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, pos, u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_p_col(self):
      # 1x1 matrix + 3 columns
      a, q, r, u = self.generate('1x1', which='col', p=3)
      for pos in range(r.shape[1] + 1):
          q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
          expected = np.insert(a, np.full(3, pos, dtype=np.intp), u, axis=1)
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_1_scalar(self):
      # A 0-d (scalar) q, r or u must be rejected for row and column inserts.
      a, q, r, u = self.generate('1x1', which='row')
      for which in ('row', 'col'):
          for args in ((q[0, 0], r, u), (q, r[0, 0], u), (q, r, u[0])):
              assert_raises(ValueError, qr_insert, *args, 0, which)
  def base_non_simple_strides(self, adjust_strides, k, p, which):
      """Check qr_insert on inputs with non-simple memory layouts.

      Parameters
      ----------
      adjust_strides : callable
          Maps a tuple ``(q, r, u)`` to a tuple of arrays with the same
          values but a different memory layout (the callers pass helpers
          producing non-unit, negative or non-itemsize strides, or
          non-native byte order).
      k : int
          Insertion position passed to ``qr_insert``.
      p : int
          Number of rows/columns inserted.
      which : {'row', 'col'}
          Passed through to ``qr_insert``.

      For each of the 'sqr', 'tall' and 'fat' cases, q, r and u are swapped
      for their adjusted versions one at a time, and then all together.
      Every combination runs with ``overwrite_qru=False`` and then ``True``.
      Only the resulting factorizations are checked, not whether inputs
      were overwritten, since only an F-ordered Q can be overwritten when
      adding columns.
      """
      axis = 0 if which == 'row' else 1
      for kind in ['sqr', 'tall', 'fat']:
          a, q0, r0, u0 = self.generate(kind, which=which, p=p)
          adjusted = tuple(adjust_strides((q0, r0, u0)))
          if p == 1:
              ai = np.insert(a, k, u0, axis)
          else:
              # a scalar k with a block of values would broadcast wrongly,
              # so give one index per inserted row/column
              ai = np.insert(a, k*np.ones(p, np.intp), u0, axis)

          # slot 0/1/2: only q/r/u is adjusted, the rest are fresh
          # F-ordered copies. slot None: all three are adjusted, rebuilt
          # from fresh copies because earlier overwrite=True calls may have
          # consumed the first adjusted set.
          for slot in (0, 1, 2, None):
              args = [q0.copy('F'), r0.copy('F'), u0.copy('F')]
              if slot is None:
                  args = list(adjust_strides(tuple(args)))
              else:
                  args[slot] = adjusted[slot]
              q, r, u = args
              q1, r1 = qr_insert(q, r, u, k, which, overwrite_qru=False)
              check_qr(q1, r1, ai, self.rtol, self.atol)
              q1o, r1o = qr_insert(q, r, u, k, which, overwrite_qru=True)
              check_qr(q1o, r1o, ai, self.rtol, self.atol)
  # Each test below runs base_non_simple_strides for one layout helper,
  # inserting at position 0 either one (p=1) or three (p=3) rows/columns.
  def test_non_unit_strides_1_row(self):
      self.base_non_simple_strides(make_strided, k=0, p=1, which='row')

  def test_non_unit_strides_p_row(self):
      self.base_non_simple_strides(make_strided, k=0, p=3, which='row')

  def test_non_unit_strides_1_col(self):
      self.base_non_simple_strides(make_strided, k=0, p=1, which='col')

  def test_non_unit_strides_p_col(self):
      self.base_non_simple_strides(make_strided, k=0, p=3, which='col')

  def test_neg_strides_1_row(self):
      self.base_non_simple_strides(negate_strides, k=0, p=1, which='row')

  def test_neg_strides_p_row(self):
      self.base_non_simple_strides(negate_strides, k=0, p=3, which='row')

  def test_neg_strides_1_col(self):
      self.base_non_simple_strides(negate_strides, k=0, p=1, which='col')

  def test_neg_strides_p_col(self):
      self.base_non_simple_strides(negate_strides, k=0, p=3, which='col')

  def test_non_itemsize_strides_1_row(self):
      self.base_non_simple_strides(nonitemsize_strides, k=0, p=1,
                                   which='row')

  def test_non_itemsize_strides_p_row(self):
      self.base_non_simple_strides(nonitemsize_strides, k=0, p=3,
                                   which='row')

  def test_non_itemsize_strides_1_col(self):
      self.base_non_simple_strides(nonitemsize_strides, k=0, p=1,
                                   which='col')

  def test_non_itemsize_strides_p_col(self):
      self.base_non_simple_strides(nonitemsize_strides, k=0, p=3,
                                   which='col')

  def test_non_native_byte_order_1_row(self):
      self.base_non_simple_strides(make_nonnative, k=0, p=1, which='row')

  def test_non_native_byte_order_p_row(self):
      self.base_non_simple_strides(make_nonnative, k=0, p=3, which='row')

  def test_non_native_byte_order_1_col(self):
      self.base_non_simple_strides(make_nonnative, k=0, p=1, which='col')

  def test_non_native_byte_order_p_col(self):
      self.base_non_simple_strides(make_nonnative, k=0, p=3, which='col')
  def test_overwrite_qu_rank_1(self):
      # Row inserts resize both Q and R, so only column inserts can
      # overwrite Q. Any contiguous Q is overwritten when inserting one
      # column; u is only overwritten (with its conjugate) for complex
      # column inserts with a C-ordered Q.
      a, q0, r, u = self.generate('sqr', which='col', p=1)
      q_c = q0.copy('C')
      u_orig = u.copy()

      # overwrite_qru=False: the inputs must still factor the original `a`
      q1, r1 = qr_insert(q_c, r, u, 0, 'col', overwrite_qru=False)
      a_new = np.insert(a, 0, u_orig, 1)
      check_qr(q1, r1, a_new, self.rtol, self.atol)
      check_qr(q_c, r, a, self.rtol, self.atol)

      # overwrite_qru=True with a C-ordered Q
      q2, r2 = qr_insert(q_c, r, u, 0, 'col', overwrite_qru=True)
      check_qr(q2, r2, a_new, self.rtol, self.atol)
      # q_c must now hold the new Q, and u its own conjugate
      assert_allclose(q2, q_c, rtol=self.rtol, atol=self.atol)
      assert_allclose(u, u_orig.conj(), rtol=self.rtol, atol=self.atol)

      # repeat with a Fortran-ordered Q and a fresh copy of u
      q_f = q0.copy('F')
      u_f = u_orig.copy()
      q3, r3 = qr_insert(q_f, r, u_f, 0, 'col', overwrite_qru=False)
      check_qr(q3, r3, a_new, self.rtol, self.atol)
      check_qr(q_f, r, a, self.rtol, self.atol)
      q4, r4 = qr_insert(q_f, r, u_f, 0, 'col', overwrite_qru=True)
      check_qr(q4, r4, a_new, self.rtol, self.atol)
      assert_allclose(q4, q_f, rtol=self.rtol, atol=self.atol)
  def test_overwrite_qu_rank_p(self):
      # Row inserts resize both Q and R, so only column inserts can
      # potentially overwrite Q; for a rank-p insert this only happens with
      # an F-ordered Q.
      a, q0, r, u = self.generate('sqr', which='col', p=3)
      q_f = q0.copy('F')
      a_new = np.insert(a, np.zeros(3, dtype=np.intp), u, 1)

      # overwrite_qru=False: the inputs must still factor the original `a`
      q1, r1 = qr_insert(q_f, r, u, 0, 'col', overwrite_qru=False)
      check_qr(q1, r1, a_new, self.rtol, self.atol)
      check_qr(q_f, r, a, self.rtol, self.atol)

      # overwrite_qru=True: q_f must now hold the returned Q
      q2, r2 = qr_insert(q_f, r, u, 0, 'col', overwrite_qru=True)
      check_qr(q2, r2, a_new, self.rtol, self.atol)
      assert_allclose(q2, q_f, rtol=self.rtol, atol=self.atol)
  def test_empty_inputs(self):
      # An empty q, r or u must be rejected for row and column inserts.
      a, q, r, u = self.generate('sqr', which='row')
      for which in ('row', 'col'):
          assert_raises(ValueError, qr_insert, np.array([]), r, u, 0, which)
          assert_raises(ValueError, qr_insert, q, np.array([]), u, 0, which)
          assert_raises(ValueError, qr_insert, q, r, np.array([]), 0, which)

  def test_mismatched_shapes(self):
      # Inconsistent q/r/u shapes must be rejected for both insert kinds.
      a, q, r, u = self.generate('tall', which='row')
      for which in ('row', 'col'):
          assert_raises(ValueError, qr_insert, q, r[1:], u, 0, which)
          assert_raises(ValueError, qr_insert, q[:-2], r, u, 0, which)
          assert_raises(ValueError, qr_insert, q, r, u[1:], 0, which)
  def test_unsupported_dtypes(self):
      """qr_insert must raise ValueError when q, r or u has an unsupported dtype.

      Each of q, r and u is swapped in turn for its real part cast to the
      tested dtype, for both row and column inserts.
      """
      # 'clongdouble' is the canonical name of extended-precision complex;
      # the legacy 'longcomplex' alias was dropped (with np.longcomplex) in
      # NumPy 2.0, so it can no longer be used as a dtype name.
      dts = ['int8', 'int16', 'int32', 'int64',
             'uint8', 'uint16', 'uint32', 'uint64',
             'float16', 'longdouble', 'clongdouble',
             'bool']
      _, q0, r0, u0 = self.generate('sqr', which='row')
      for dtype in dts:
          q = q0.real.astype(dtype)
          r = r0.real.astype(dtype)
          u = u0.real.astype(dtype)
          assert_raises(ValueError, qr_insert, q, r0, u0, 0, 'row')
          assert_raises(ValueError, qr_insert, q, r0, u0, 0, 'col')
          assert_raises(ValueError, qr_insert, q0, r, u0, 0, 'row')
          assert_raises(ValueError, qr_insert, q0, r, u0, 0, 'col')
          assert_raises(ValueError, qr_insert, q0, r0, u, 0, 'row')
          assert_raises(ValueError, qr_insert, q0, r0, u, 0, 'col')
  def test_check_finite(self):
      # A NaN in q, r or u must raise ValueError, whether u is passed as a
      # single vector or as a block, for both row and column inserts.
      a0, q0, r0, u0 = self.generate('sqr', which='row', p=3)

      q = q0.copy('F')
      q[1, 1] = np.nan
      for which in ('row', 'col'):
          assert_raises(ValueError, qr_insert, q, r0, u0[:, 0], 0, which)
          assert_raises(ValueError, qr_insert, q, r0, u0, 0, which)

      r = r0.copy('F')
      r[1, 1] = np.nan
      for which in ('row', 'col'):
          assert_raises(ValueError, qr_insert, q0, r, u0[:, 0], 0, which)
          assert_raises(ValueError, qr_insert, q0, r, u0, 0, which)

      u = u0.copy('F')
      u[0, 0] = np.nan
      for which in ('row', 'col'):
          assert_raises(ValueError, qr_insert, q0, r0, u[:, 0], 0, which)
          assert_raises(ValueError, qr_insert, q0, r0, u, 0, which)
  class TestQRinsert_f(BaseQRinsert):
      """Run the BaseQRinsert suite in single-precision real."""
      dtype = np.dtype(np.float32)


  class TestQRinsert_F(BaseQRinsert):
      """Run the BaseQRinsert suite in single-precision complex."""
      dtype = np.dtype(np.complex64)


  class TestQRinsert_d(BaseQRinsert):
      """Run the BaseQRinsert suite in double-precision real."""
      dtype = np.dtype(np.float64)


  class TestQRinsert_D(BaseQRinsert):
      """Run the BaseQRinsert suite in double-precision complex."""
      dtype = np.dtype(np.complex128)
  983. class BaseQRupdate(BaseQRdeltas):
  def generate(self, type, mode='full', p=1):
      """Return ``(a, q, r, u, v)`` for a rank-``p`` update ``a + u v^H``.

      ``u`` has ``q.shape[0]`` rows and ``v`` has ``r.shape[1]`` rows; both
      are 1-D when ``p == 1`` and have ``p`` columns otherwise. For complex
      ``self.dtype`` an imaginary part is added before casting.
      """
      a, q, r = super(BaseQRupdate, self).generate(type, mode)
      # The parent call has already set the seed, so these draws are
      # reproducible. Their order (u, v, then the imaginary parts) matters.
      m, n = q.shape[0], r.shape[1]
      if p == 1:
          u_shape, v_shape = (m,), (n,)
      else:
          u_shape, v_shape = (m, p), (n, p)
      u = np.random.random(u_shape)
      v = np.random.random(v_shape)
      if np.iscomplexobj(self.dtype.type(1)):
          u = u + 1j * np.random.random(u.shape)
          v = v + 1j * np.random.random(v.shape)
      return a, q, r, u.astype(self.dtype), v.astype(self.dtype)
  def test_sqr_rank_1(self):
      # square, rank-1 update with 1-D u and v
      a, q, r, u, v = self.generate('sqr')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_sqr_rank_p(self):
      # square, rank-p update; p == 1 also covers 2-D (ndim == 2) u and v
      for p in (1, 2, 3, 5):
          a, q, r, u, v = self.generate('sqr', p=p)
          if p == 1:
              u = u.reshape(-1, 1)
              v = v.reshape(-1, 1)
          q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
          expected = a + np.dot(u, v.T.conj())
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_rank_1(self):
      a, q, r, u, v = self.generate('tall')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_tall_rank_p(self):
      for p in (1, 2, 3, 5):
          a, q, r, u, v = self.generate('tall', p=p)
          if p == 1:
              u = u.reshape(-1, 1)
              v = v.reshape(-1, 1)
          q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
          expected = a + np.dot(u, v.T.conj())
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_fat_rank_1(self):
      a, q, r, u, v = self.generate('fat')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_fat_rank_p(self):
      for p in (1, 2, 3, 5):
          a, q, r, u, v = self.generate('fat', p=p)
          if p == 1:
              u = u.reshape(-1, 1)
              v = v.reshape(-1, 1)
          q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
          expected = a + np.dot(u, v.T.conj())
          check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_economic_rank_1(self):
      # economic: check_qr's square-Q assertion is switched off
      a, q, r, u, v = self.generate('tall', 'economic')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_economic_rank_p(self):
      for p in (1, 2, 3, 5):
          a, q, r, u, v = self.generate('tall', 'economic', p)
          if p == 1:
              u = u.reshape(-1, 1)
              v = v.reshape(-1, 1)
          q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
          expected = a + np.dot(u, v.T.conj())
          check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
  def test_Mx1_rank_1(self):
      a, q, r, u, v = self.generate('Mx1')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_rank_p(self):
      # When M or N == 1 only a rank-1 update is allowed. This is not a
      # fundamental limitation, but the code does not support it; so only
      # the 2-D (column) form of a rank-1 update is checked here.
      a, q, r, u, v = self.generate('Mx1', p=1)
      u = u.reshape(-1, 1)
      v = v.reshape(-1, 1)
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.dot(u, v.T.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_Mx1_economic_rank_1(self):
      a, q, r, u, v = self.generate('Mx1', 'economic')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_Mx1_economic_rank_p(self):
      # When M or N == 1 only a rank-1 update is allowed (not a fundamental
      # limitation, but unsupported by the code).
      a, q, r, u, v = self.generate('Mx1', 'economic', p=1)
      u = u.reshape(-1, 1)
      v = v.reshape(-1, 1)
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.dot(u, v.T.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol, False)

  def test_1xN_rank_1(self):
      a, q, r, u, v = self.generate('1xN')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1xN_rank_p(self):
      # When M or N == 1 only a rank-1 update is allowed (not a fundamental
      # limitation, but unsupported by the code).
      a, q, r, u, v = self.generate('1xN', p=1)
      u = u.reshape(-1, 1)
      v = v.reshape(-1, 1)
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.dot(u, v.T.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_rank_1(self):
      a, q, r, u, v = self.generate('1x1')
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.outer(u, v.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_rank_p(self):
      # When M or N == 1 only a rank-1 update is allowed (not a fundamental
      # limitation, but unsupported by the code).
      a, q, r, u, v = self.generate('1x1', p=1)
      u = u.reshape(-1, 1)
      v = v.reshape(-1, 1)
      q_new, r_new = qr_update(q, r, u, v, overwrite_qruv=False)
      expected = a + np.dot(u, v.T.conj())
      check_qr(q_new, r_new, expected, self.rtol, self.atol)

  def test_1x1_rank_1_scalar(self):
      # A 0-d (scalar) q, r, u or v must be rejected.
      a, q, r, u, v = self.generate('1x1')
      full_args = [q, r, u, v]
      scalar_args = [q[0, 0], r[0, 0], u[0], v[0]]
      for i, scalar in enumerate(scalar_args):
          args = list(full_args)
          args[i] = scalar
          assert_raises(ValueError, qr_update, *args)
  def base_non_simple_strides(self, adjust_strides, mode, p, overwriteable):
      """Check qr_update on inputs with non-simple memory layouts.

      Parameters
      ----------
      adjust_strides : callable
          Maps a tuple ``(q, r, u, v)`` to a tuple of arrays with the same
          values but a different memory layout (the callers pass helpers
          producing non-unit, negative or non-itemsize strides, or
          non-native byte order).
      mode : str
          Factorization mode passed to ``generate``; for 'economic' the
          square-Q assertion of ``check_qr`` is disabled.
      p : int
          Rank of the update.
      overwriteable : bool
          If true, the ``overwrite_qruv=True`` call must also have updated
          r in place and replaced v with its conjugate.

      For each of the 'sqr', 'tall' and 'fat' cases, q, r, u and v are
      swapped for their adjusted versions one at a time, and then all
      together. Every combination runs with overwrite False and then True.
      A strided q and u must always be copied by qr_update.
      """
      assert_sqr = mode != 'economic'
      for kind in ['sqr', 'tall', 'fat']:
          a, q0, r0, u0, v0 = self.generate(kind, mode, p)
          adjusted = tuple(adjust_strides((q0, r0, u0, v0)))
          if p == 1:
              aup = a + np.outer(u0, v0.conj())
          else:
              aup = a + np.dot(u0, v0.T.conj())

          # slot 0..3: only q/r/u/v is adjusted, the rest are fresh copies
          # (v C-ordered, the others F-ordered). slot None: all four are
          # adjusted, rebuilt from fresh copies because earlier
          # overwrite=True calls may have consumed the first adjusted set.
          for slot in (0, 1, 2, 3, None):
              args = [q0.copy('F'), r0.copy('F'), u0.copy('F'), v0.copy('C')]
              if slot is None:
                  args = list(adjust_strides(tuple(args)))
              else:
                  args[slot] = adjusted[slot]
              q, r, u, v = args
              q1, r1 = qr_update(q, r, u, v, False)
              check_qr(q1, r1, aup, self.rtol, self.atol, assert_sqr)
              q1o, r1o = qr_update(q, r, u, v, True)
              check_qr(q1o, r1o, aup, self.rtol, self.atol, assert_sqr)
              if overwriteable:
                  assert_allclose(r1o, r, rtol=self.rtol, atol=self.atol)
                  assert_allclose(v, v0.conj(), rtol=self.rtol,
                                  atol=self.atol)
  # Each test below runs base_non_simple_strides for one layout helper.
  # Only the positively strided rank-1 cases are expected to overwrite r/v.
  def test_non_unit_strides_rank_1(self):
      self.base_non_simple_strides(make_strided, mode='full', p=1,
                                   overwriteable=True)

  def test_non_unit_strides_economic_rank_1(self):
      self.base_non_simple_strides(make_strided, mode='economic', p=1,
                                   overwriteable=True)

  def test_non_unit_strides_rank_p(self):
      self.base_non_simple_strides(make_strided, mode='full', p=3,
                                   overwriteable=False)

  def test_non_unit_strides_economic_rank_p(self):
      self.base_non_simple_strides(make_strided, mode='economic', p=3,
                                   overwriteable=False)

  def test_neg_strides_rank_1(self):
      self.base_non_simple_strides(negate_strides, mode='full', p=1,
                                   overwriteable=False)

  def test_neg_strides_economic_rank_1(self):
      self.base_non_simple_strides(negate_strides, mode='economic', p=1,
                                   overwriteable=False)

  def test_neg_strides_rank_p(self):
      self.base_non_simple_strides(negate_strides, mode='full', p=3,
                                   overwriteable=False)

  def test_neg_strides_economic_rank_p(self):
      self.base_non_simple_strides(negate_strides, mode='economic', p=3,
                                   overwriteable=False)

  def test_non_itemsize_strides_rank_1(self):
      self.base_non_simple_strides(nonitemsize_strides, mode='full', p=1,
                                   overwriteable=False)

  def test_non_itemsize_strides_economic_rank_1(self):
      self.base_non_simple_strides(nonitemsize_strides, mode='economic',
                                   p=1, overwriteable=False)

  def test_non_itemsize_strides_rank_p(self):
      self.base_non_simple_strides(nonitemsize_strides, mode='full', p=3,
                                   overwriteable=False)

  def test_non_itemsize_strides_economic_rank_p(self):
      self.base_non_simple_strides(nonitemsize_strides, mode='economic',
                                   p=3, overwriteable=False)

  def test_non_native_byte_order_rank_1(self):
      self.base_non_simple_strides(make_nonnative, mode='full', p=1,
                                   overwriteable=False)

  def test_non_native_byte_order_economic_rank_1(self):
      self.base_non_simple_strides(make_nonnative, mode='economic', p=1,
                                   overwriteable=False)

  def test_non_native_byte_order_rank_p(self):
      self.base_non_simple_strides(make_nonnative, mode='full', p=3,
                                   overwriteable=False)

  def test_non_native_byte_order_economic_rank_p(self):
      self.base_non_simple_strides(make_nonnative, mode='economic', p=3,
                                   overwriteable=False)
  def test_overwrite_qruv_rank_1(self):
      # For a rank-1 update any positively strided q, r, u and v may be
      # overwritten; only the C- and F-contiguous cases are exercised here.
      a, q0, r0, u0, v0 = self.generate('sqr')
      a_new = a + np.outer(u0, v0.conj())

      # Fortran-ordered inputs
      q, r, u, v = q0.copy('F'), r0.copy('F'), u0.copy('F'), v0.copy('F')
      # without overwriting, the inputs must still factor the original `a`
      q1, r1 = qr_update(q, r, u, v, False)
      check_qr(q1, r1, a_new, self.rtol, self.atol)
      check_qr(q, r, a, self.rtol, self.atol)
      # with overwriting, q and r must now hold the updated factors; there
      # is no simple way to check what u and v were overwritten with
      q2, r2 = qr_update(q, r, u, v, True)
      check_qr(q2, r2, a_new, self.rtol, self.atol)
      assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
      assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)

      # C-ordered inputs, overwrite only
      q, r, u, v = q0.copy('C'), r0.copy('C'), u0.copy('C'), v0.copy('C')
      q3, r3 = qr_update(q, r, u, v, True)
      check_qr(q3, r3, a_new, self.rtol, self.atol)
      assert_allclose(q3, q, rtol=self.rtol, atol=self.atol)
      assert_allclose(r3, r, rtol=self.rtol, atol=self.atol)
  def test_overwrite_qruv_rank_1_economic(self):
      # Updating an economic decomposition may overwrite q, r and u in
      # place; v is only ever read. Only C- and F-contiguous inputs are
      # exercised. The trailing False disables check_qr's square-Q check.
      a, q0, r0, u0, v0 = self.generate('tall', 'economic')
      a_new = a + np.outer(u0, v0.conj())

      # Fortran-ordered inputs
      q, r, u, v = q0.copy('F'), r0.copy('F'), u0.copy('F'), v0.copy('F')
      # without overwriting, the inputs must still factor the original `a`
      q1, r1 = qr_update(q, r, u, v, False)
      check_qr(q1, r1, a_new, self.rtol, self.atol, False)
      check_qr(q, r, a, self.rtol, self.atol, False)
      # with overwriting, q and r must now hold the updated factors; there
      # is no simple way to check u and v
      q2, r2 = qr_update(q, r, u, v, True)
      check_qr(q2, r2, a_new, self.rtol, self.atol, False)
      assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
      assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)

      # C-ordered inputs, overwrite only
      q, r, u, v = q0.copy('C'), r0.copy('C'), u0.copy('C'), v0.copy('C')
      q3, r3 = qr_update(q, r, u, v, True)
      check_qr(q3, r3, a_new, self.rtol, self.atol, False)
      assert_allclose(q3, q, rtol=self.rtol, atol=self.atol)
      assert_allclose(r3, r, rtol=self.rtol, atol=self.atol)
  1275. def test_overwrite_qruv_rank_p(self):
  1276. # for rank p updates, q r must be F contiguous, v must be C (v.T --> F)
  1277. # and u can be C or F, but is only overwritten if Q is C and complex
  1278. a, q0, r0, u0, v0 = self.generate('sqr', p=3)
  1279. a1 = a + np.dot(u0, v0.T.conj())
  1280. q = q0.copy('F')
  1281. r = r0.copy('F')
  1282. u = u0.copy('F')
  1283. v = v0.copy('C')
  1284. # don't overwrite
  1285. q1, r1 = qr_update(q, r, u, v, False)
  1286. check_qr(q1, r1, a1, self.rtol, self.atol)
  1287. check_qr(q, r, a, self.rtol, self.atol)
  1288. q2, r2 = qr_update(q, r, u, v, True)
  1289. check_qr(q2, r2, a1, self.rtol, self.atol)
  1290. # verify the overwriting, no good way to check u and v.
  1291. assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
  1292. assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)
  1293. def test_empty_inputs(self):
  1294. a, q, r, u, v = self.generate('tall')
  1295. assert_raises(ValueError, qr_update, np.array([]), r, u, v)
  1296. assert_raises(ValueError, qr_update, q, np.array([]), u, v)
  1297. assert_raises(ValueError, qr_update, q, r, np.array([]), v)
  1298. assert_raises(ValueError, qr_update, q, r, u, np.array([]))
  1299. def test_mismatched_shapes(self):
  1300. a, q, r, u, v = self.generate('tall')
  1301. assert_raises(ValueError, qr_update, q, r[1:], u, v)
  1302. assert_raises(ValueError, qr_update, q[:-2], r, u, v)
  1303. assert_raises(ValueError, qr_update, q, r, u[1:], v)
  1304. assert_raises(ValueError, qr_update, q, r, u, v[1:])
  1305. def test_unsupported_dtypes(self):
  1306. dts = ['int8', 'int16', 'int32', 'int64',
  1307. 'uint8', 'uint16', 'uint32', 'uint64',
  1308. 'float16', 'longdouble', 'longcomplex',
  1309. 'bool']
  1310. a, q0, r0, u0, v0 = self.generate('tall')
  1311. for dtype in dts:
  1312. q = q0.real.astype(dtype)
  1313. r = r0.real.astype(dtype)
  1314. u = u0.real.astype(dtype)
  1315. v = v0.real.astype(dtype)
  1316. assert_raises(ValueError, qr_update, q, r0, u0, v0)
  1317. assert_raises(ValueError, qr_update, q0, r, u0, v0)
  1318. assert_raises(ValueError, qr_update, q0, r0, u, v0)
  1319. assert_raises(ValueError, qr_update, q0, r0, u0, v)
  1320. def test_integer_input(self):
  1321. q = np.arange(16).reshape(4, 4)
  1322. r = q.copy() # doesn't matter
  1323. u = q[:, 0].copy()
  1324. v = r[0, :].copy()
  1325. assert_raises(ValueError, qr_update, q, r, u, v)
  1326. def test_check_finite(self):
  1327. a0, q0, r0, u0, v0 = self.generate('tall', p=3)
  1328. q = q0.copy('F')
  1329. q[1,1] = np.nan
  1330. assert_raises(ValueError, qr_update, q, r0, u0[:,0], v0[:,0])
  1331. assert_raises(ValueError, qr_update, q, r0, u0, v0)
  1332. r = r0.copy('F')
  1333. r[1,1] = np.nan
  1334. assert_raises(ValueError, qr_update, q0, r, u0[:,0], v0[:,0])
  1335. assert_raises(ValueError, qr_update, q0, r, u0, v0)
  1336. u = u0.copy('F')
  1337. u[0,0] = np.nan
  1338. assert_raises(ValueError, qr_update, q0, r0, u[:,0], v0[:,0])
  1339. assert_raises(ValueError, qr_update, q0, r0, u, v0)
  1340. v = v0.copy('F')
  1341. v[0,0] = np.nan
  1342. assert_raises(ValueError, qr_update, q0, r0, u[:,0], v[:,0])
  1343. assert_raises(ValueError, qr_update, q0, r0, u, v)
  1344. def test_economic_check_finite(self):
  1345. a0, q0, r0, u0, v0 = self.generate('tall', mode='economic', p=3)
  1346. q = q0.copy('F')
  1347. q[1,1] = np.nan
  1348. assert_raises(ValueError, qr_update, q, r0, u0[:,0], v0[:,0])
  1349. assert_raises(ValueError, qr_update, q, r0, u0, v0)
  1350. r = r0.copy('F')
  1351. r[1,1] = np.nan
  1352. assert_raises(ValueError, qr_update, q0, r, u0[:,0], v0[:,0])
  1353. assert_raises(ValueError, qr_update, q0, r, u0, v0)
  1354. u = u0.copy('F')
  1355. u[0,0] = np.nan
  1356. assert_raises(ValueError, qr_update, q0, r0, u[:,0], v0[:,0])
  1357. assert_raises(ValueError, qr_update, q0, r0, u, v0)
  1358. v = v0.copy('F')
  1359. v[0,0] = np.nan
  1360. assert_raises(ValueError, qr_update, q0, r0, u[:,0], v[:,0])
  1361. assert_raises(ValueError, qr_update, q0, r0, u, v)
  1362. def test_u_exactly_in_span_q(self):
  1363. q = np.array([[0, 0], [0, 0], [1, 0], [0, 1]], self.dtype)
  1364. r = np.array([[1, 0], [0, 1]], self.dtype)
  1365. u = np.array([0, 0, 0, -1], self.dtype)
  1366. v = np.array([1, 2], self.dtype)
  1367. q1, r1 = qr_update(q, r, u, v)
  1368. a1 = np.dot(q, r) + np.outer(u, v.conj())
  1369. check_qr(q1, r1, a1, self.rtol, self.atol, False)
  1370. class TestQRupdate_f(BaseQRupdate):
  1371. dtype = np.dtype('f')
  1372. class TestQRupdate_F(BaseQRupdate):
  1373. dtype = np.dtype('F')
  1374. class TestQRupdate_d(BaseQRupdate):
  1375. dtype = np.dtype('d')
  1376. class TestQRupdate_D(BaseQRupdate):
  1377. dtype = np.dtype('D')
  1378. def test_form_qTu():
  1379. # We want to ensure that all of the code paths through this function are
  1380. # tested. Most of them should be hit with the rest of test suite, but
  1381. # explicit tests make clear precisely what is being tested.
  1382. #
  1383. # This function expects that Q is either C or F contiguous and square.
  1384. # Economic mode decompositions (Q is (M, N), M != N) do not go through this
  1385. # function. U may have any positive strides.
  1386. #
  1387. # Some of these test are duplicates, since contiguous 1d arrays are both C
  1388. # and F.
  1389. q_order = ['F', 'C']
  1390. q_shape = [(8, 8), ]
  1391. u_order = ['F', 'C', 'A'] # here A means is not F not C
  1392. u_shape = [1, 3]
  1393. dtype = ['f', 'd', 'F', 'D']
  1394. for qo, qs, uo, us, d in \
  1395. itertools.product(q_order, q_shape, u_order, u_shape, dtype):
  1396. if us == 1:
  1397. check_form_qTu(qo, qs, uo, us, 1, d)
  1398. check_form_qTu(qo, qs, uo, us, 2, d)
  1399. else:
  1400. check_form_qTu(qo, qs, uo, us, 2, d)
  1401. def check_form_qTu(q_order, q_shape, u_order, u_shape, u_ndim, dtype):
  1402. np.random.seed(47)
  1403. if u_shape == 1 and u_ndim == 1:
  1404. u_shape = (q_shape[0],)
  1405. else:
  1406. u_shape = (q_shape[0], u_shape)
  1407. dtype = np.dtype(dtype)
  1408. if dtype.char in 'fd':
  1409. q = np.random.random(q_shape)
  1410. u = np.random.random(u_shape)
  1411. elif dtype.char in 'FD':
  1412. q = np.random.random(q_shape) + 1j*np.random.random(q_shape)
  1413. u = np.random.random(u_shape) + 1j*np.random.random(u_shape)
  1414. else:
  1415. ValueError("form_qTu doesn't support this dtype")
  1416. q = np.require(q, dtype, q_order)
  1417. if u_order != 'A':
  1418. u = np.require(u, dtype, u_order)
  1419. else:
  1420. u, = make_strided((u.astype(dtype),))
  1421. rtol = 10.0 ** -(np.finfo(dtype).precision-2)
  1422. atol = 2*np.finfo(dtype).eps
  1423. expected = np.dot(q.T.conj(), u)
  1424. res = _decomp_update._form_qTu(q, u)
  1425. assert_allclose(res, expected, rtol=rtol, atol=atol)