test_interpolative.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #******************************************************************************
  2. # Copyright (C) 2013 Kenneth L. Ho
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions are met:
  5. #
  6. # Redistributions of source code must retain the above copyright notice, this
  7. # list of conditions and the following disclaimer. Redistributions in binary
  8. # form must reproduce the above copyright notice, this list of conditions and
  9. # the following disclaimer in the documentation and/or other materials
  10. # provided with the distribution.
  11. #
  12. # None of the names of the copyright holders may be used to endorse or
  13. # promote products derived from this software without specific prior written
  14. # permission.
  15. #
  16. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  20. # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. # POSSIBILITY OF SUCH DAMAGE.
  27. #******************************************************************************
  28. import scipy.linalg.interpolative as pymatrixid
  29. import numpy as np
  30. from scipy.linalg import hilbert, svdvals, norm
  31. from scipy.sparse.linalg import aslinearoperator
  32. import time
  33. from numpy.testing import assert_, assert_allclose
  34. from pytest import raises as assert_raises
  35. def _debug_print(s):
  36. if 0:
  37. print(s)
  38. class TestInterpolativeDecomposition(object):
  39. def test_id(self):
  40. for dtype in [np.float64, np.complex128]:
  41. self.check_id(dtype)
  42. def check_id(self, dtype):
  43. # Test ID routines on a Hilbert matrix.
  44. # set parameters
  45. n = 300
  46. eps = 1e-12
  47. # construct Hilbert matrix
  48. A = hilbert(n).astype(dtype)
  49. if np.issubdtype(dtype, np.complexfloating):
  50. A = A * (1 + 1j)
  51. L = aslinearoperator(A)
  52. # find rank
  53. S = np.linalg.svd(A, compute_uv=False)
  54. try:
  55. rank = np.nonzero(S < eps)[0][0]
  56. except IndexError:
  57. rank = n
  58. # print input summary
  59. _debug_print("Hilbert matrix dimension: %8i" % n)
  60. _debug_print("Working precision: %8.2e" % eps)
  61. _debug_print("Rank to working precision: %8i" % rank)
  62. # set print format
  63. fmt = "%8.2e (s) / %5s"
  64. # test real ID routines
  65. _debug_print("-----------------------------------------")
  66. _debug_print("Real ID routines")
  67. _debug_print("-----------------------------------------")
  68. # fixed precision
  69. _debug_print("Calling iddp_id / idzp_id ...",)
  70. t0 = time.time()
  71. k, idx, proj = pymatrixid.interp_decomp(A, eps, rand=False)
  72. t = time.time() - t0
  73. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  74. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  75. assert_(np.allclose(A, B, eps))
  76. _debug_print("Calling iddp_aid / idzp_aid ...",)
  77. t0 = time.time()
  78. k, idx, proj = pymatrixid.interp_decomp(A, eps)
  79. t = time.time() - t0
  80. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  81. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  82. assert_(np.allclose(A, B, eps))
  83. _debug_print("Calling iddp_rid / idzp_rid ...",)
  84. t0 = time.time()
  85. k, idx, proj = pymatrixid.interp_decomp(L, eps)
  86. t = time.time() - t0
  87. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  88. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  89. assert_(np.allclose(A, B, eps))
  90. # fixed rank
  91. k = rank
  92. _debug_print("Calling iddr_id / idzr_id ...",)
  93. t0 = time.time()
  94. idx, proj = pymatrixid.interp_decomp(A, k, rand=False)
  95. t = time.time() - t0
  96. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  97. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  98. assert_(np.allclose(A, B, eps))
  99. _debug_print("Calling iddr_aid / idzr_aid ...",)
  100. t0 = time.time()
  101. idx, proj = pymatrixid.interp_decomp(A, k)
  102. t = time.time() - t0
  103. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  104. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  105. assert_(np.allclose(A, B, eps))
  106. _debug_print("Calling iddr_rid / idzr_rid ...",)
  107. t0 = time.time()
  108. idx, proj = pymatrixid.interp_decomp(L, k)
  109. t = time.time() - t0
  110. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  111. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  112. assert_(np.allclose(A, B, eps))
  113. # check skeleton and interpolation matrices
  114. idx, proj = pymatrixid.interp_decomp(A, k, rand=False)
  115. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  116. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  117. assert_(np.allclose(B, A[:,idx[:k]], eps))
  118. assert_(np.allclose(B.dot(P), A, eps))
  119. # test SVD routines
  120. _debug_print("-----------------------------------------")
  121. _debug_print("SVD routines")
  122. _debug_print("-----------------------------------------")
  123. # fixed precision
  124. _debug_print("Calling iddp_svd / idzp_svd ...",)
  125. t0 = time.time()
  126. U, S, V = pymatrixid.svd(A, eps, rand=False)
  127. t = time.time() - t0
  128. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  129. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  130. assert_(np.allclose(A, B, eps))
  131. _debug_print("Calling iddp_asvd / idzp_asvd...",)
  132. t0 = time.time()
  133. U, S, V = pymatrixid.svd(A, eps)
  134. t = time.time() - t0
  135. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  136. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  137. assert_(np.allclose(A, B, eps))
  138. _debug_print("Calling iddp_rsvd / idzp_rsvd...",)
  139. t0 = time.time()
  140. U, S, V = pymatrixid.svd(L, eps)
  141. t = time.time() - t0
  142. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  143. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  144. assert_(np.allclose(A, B, eps))
  145. # fixed rank
  146. k = rank
  147. _debug_print("Calling iddr_svd / idzr_svd ...",)
  148. t0 = time.time()
  149. U, S, V = pymatrixid.svd(A, k, rand=False)
  150. t = time.time() - t0
  151. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  152. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  153. assert_(np.allclose(A, B, eps))
  154. _debug_print("Calling iddr_asvd / idzr_asvd ...",)
  155. t0 = time.time()
  156. U, S, V = pymatrixid.svd(A, k)
  157. t = time.time() - t0
  158. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  159. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  160. assert_(np.allclose(A, B, eps))
  161. _debug_print("Calling iddr_rsvd / idzr_rsvd ...",)
  162. t0 = time.time()
  163. U, S, V = pymatrixid.svd(L, k)
  164. t = time.time() - t0
  165. B = np.dot(U, np.dot(np.diag(S), V.T.conj()))
  166. _debug_print(fmt % (t, np.allclose(A, B, eps)))
  167. assert_(np.allclose(A, B, eps))
  168. # ID to SVD
  169. idx, proj = pymatrixid.interp_decomp(A, k, rand=False)
  170. Up, Sp, Vp = pymatrixid.id_to_svd(A[:, idx[:k]], idx, proj)
  171. B = U.dot(np.diag(S).dot(V.T.conj()))
  172. assert_(np.allclose(A, B, eps))
  173. # Norm estimates
  174. s = svdvals(A)
  175. norm_2_est = pymatrixid.estimate_spectral_norm(A)
  176. assert_(np.allclose(norm_2_est, s[0], 1e-6))
  177. B = A.copy()
  178. B[:,0] *= 1.2
  179. s = svdvals(A - B)
  180. norm_2_est = pymatrixid.estimate_spectral_norm_diff(A, B)
  181. assert_(np.allclose(norm_2_est, s[0], 1e-6))
  182. # Rank estimates
  183. B = np.array([[1, 1, 0], [0, 0, 1], [0, 0, 1]], dtype=dtype)
  184. for M in [A, B]:
  185. ML = aslinearoperator(M)
  186. rank_tol = 1e-9
  187. rank_np = np.linalg.matrix_rank(M, norm(M, 2)*rank_tol)
  188. rank_est = pymatrixid.estimate_rank(M, rank_tol)
  189. rank_est_2 = pymatrixid.estimate_rank(ML, rank_tol)
  190. assert_(rank_est >= rank_np)
  191. assert_(rank_est <= rank_np + 10)
  192. assert_(rank_est_2 >= rank_np - 4)
  193. assert_(rank_est_2 <= rank_np + 4)
  194. def test_rand(self):
  195. pymatrixid.seed('default')
  196. assert_(np.allclose(pymatrixid.rand(2), [0.8932059, 0.64500803], 1e-4))
  197. pymatrixid.seed(1234)
  198. x1 = pymatrixid.rand(2)
  199. assert_(np.allclose(x1, [0.7513823, 0.06861718], 1e-4))
  200. np.random.seed(1234)
  201. pymatrixid.seed()
  202. x2 = pymatrixid.rand(2)
  203. np.random.seed(1234)
  204. pymatrixid.seed(np.random.rand(55))
  205. x3 = pymatrixid.rand(2)
  206. assert_allclose(x1, x2)
  207. assert_allclose(x1, x3)
  208. def test_badcall(self):
  209. A = hilbert(5).astype(np.float32)
  210. assert_raises(ValueError, pymatrixid.interp_decomp, A, 1e-6, rand=False)
  211. def test_rank_too_large(self):
  212. # svd(array, k) should not segfault
  213. a = np.ones((4, 3))
  214. with assert_raises(ValueError):
  215. pymatrixid.svd(a, 4)
  216. def test_full_rank(self):
  217. eps = 1.0e-12
  218. # fixed precision
  219. A = np.random.rand(16, 8)
  220. k, idx, proj = pymatrixid.interp_decomp(A, eps)
  221. assert_(k == A.shape[1])
  222. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  223. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  224. assert_allclose(A, B.dot(P))
  225. # fixed rank
  226. idx, proj = pymatrixid.interp_decomp(A, k)
  227. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  228. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  229. assert_allclose(A, B.dot(P))