test_expm_multiply.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. """Test functions for the sparse.linalg._expm_multiply module
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. import numpy as np
  5. from numpy.testing import assert_allclose, assert_, assert_equal
  6. from scipy._lib._numpy_compat import suppress_warnings
  7. from scipy.sparse import SparseEfficiencyWarning
  8. import scipy.linalg
  9. from scipy.sparse.linalg._expm_multiply import (_theta, _compute_p_max,
  10. _onenormest_matrix_power, expm_multiply, _expm_multiply_simple,
  11. _expm_multiply_interval)
  12. def less_than_or_close(a, b):
  13. return np.allclose(a, b) or (a < b)
  14. class TestExpmActionSimple(object):
  15. """
  16. These tests do not consider the case of multiple time steps in one call.
  17. """
  18. def test_theta_monotonicity(self):
  19. pairs = sorted(_theta.items())
  20. for (m_a, theta_a), (m_b, theta_b) in zip(pairs[:-1], pairs[1:]):
  21. assert_(theta_a < theta_b)
  22. def test_p_max_default(self):
  23. m_max = 55
  24. expected_p_max = 8
  25. observed_p_max = _compute_p_max(m_max)
  26. assert_equal(observed_p_max, expected_p_max)
  27. def test_p_max_range(self):
  28. for m_max in range(1, 55+1):
  29. p_max = _compute_p_max(m_max)
  30. assert_(p_max*(p_max - 1) <= m_max + 1)
  31. p_too_big = p_max + 1
  32. assert_(p_too_big*(p_too_big - 1) > m_max + 1)
  33. def test_onenormest_matrix_power(self):
  34. np.random.seed(1234)
  35. n = 40
  36. nsamples = 10
  37. for i in range(nsamples):
  38. A = scipy.linalg.inv(np.random.randn(n, n))
  39. for p in range(4):
  40. if not p:
  41. M = np.identity(n)
  42. else:
  43. M = np.dot(M, A)
  44. estimated = _onenormest_matrix_power(A, p)
  45. exact = np.linalg.norm(M, 1)
  46. assert_(less_than_or_close(estimated, exact))
  47. assert_(less_than_or_close(exact, 3*estimated))
  48. def test_expm_multiply(self):
  49. np.random.seed(1234)
  50. n = 40
  51. k = 3
  52. nsamples = 10
  53. for i in range(nsamples):
  54. A = scipy.linalg.inv(np.random.randn(n, n))
  55. B = np.random.randn(n, k)
  56. observed = expm_multiply(A, B)
  57. expected = np.dot(scipy.linalg.expm(A), B)
  58. assert_allclose(observed, expected)
  59. def test_matrix_vector_multiply(self):
  60. np.random.seed(1234)
  61. n = 40
  62. nsamples = 10
  63. for i in range(nsamples):
  64. A = scipy.linalg.inv(np.random.randn(n, n))
  65. v = np.random.randn(n)
  66. observed = expm_multiply(A, v)
  67. expected = np.dot(scipy.linalg.expm(A), v)
  68. assert_allclose(observed, expected)
  69. def test_scaled_expm_multiply(self):
  70. np.random.seed(1234)
  71. n = 40
  72. k = 3
  73. nsamples = 10
  74. for i in range(nsamples):
  75. for t in (0.2, 1.0, 1.5):
  76. with np.errstate(invalid='ignore'):
  77. A = scipy.linalg.inv(np.random.randn(n, n))
  78. B = np.random.randn(n, k)
  79. observed = _expm_multiply_simple(A, B, t=t)
  80. expected = np.dot(scipy.linalg.expm(t*A), B)
  81. assert_allclose(observed, expected)
  82. def test_scaled_expm_multiply_single_timepoint(self):
  83. np.random.seed(1234)
  84. t = 0.1
  85. n = 5
  86. k = 2
  87. A = np.random.randn(n, n)
  88. B = np.random.randn(n, k)
  89. observed = _expm_multiply_simple(A, B, t=t)
  90. expected = scipy.linalg.expm(t*A).dot(B)
  91. assert_allclose(observed, expected)
  92. def test_sparse_expm_multiply(self):
  93. np.random.seed(1234)
  94. n = 40
  95. k = 3
  96. nsamples = 10
  97. for i in range(nsamples):
  98. A = scipy.sparse.rand(n, n, density=0.05)
  99. B = np.random.randn(n, k)
  100. observed = expm_multiply(A, B)
  101. with suppress_warnings() as sup:
  102. sup.filter(SparseEfficiencyWarning,
  103. "splu requires CSC matrix format")
  104. sup.filter(SparseEfficiencyWarning,
  105. "spsolve is more efficient when sparse b is in the CSC matrix format")
  106. expected = scipy.linalg.expm(A).dot(B)
  107. assert_allclose(observed, expected)
  108. def test_complex(self):
  109. A = np.array([
  110. [1j, 1j],
  111. [0, 1j]], dtype=complex)
  112. B = np.array([1j, 1j])
  113. observed = expm_multiply(A, B)
  114. expected = np.array([
  115. 1j * np.exp(1j) + 1j * (1j*np.cos(1) - np.sin(1)),
  116. 1j * np.exp(1j)], dtype=complex)
  117. assert_allclose(observed, expected)
  118. class TestExpmActionInterval(object):
  119. def test_sparse_expm_multiply_interval(self):
  120. np.random.seed(1234)
  121. start = 0.1
  122. stop = 3.2
  123. n = 40
  124. k = 3
  125. endpoint = True
  126. for num in (14, 13, 2):
  127. A = scipy.sparse.rand(n, n, density=0.05)
  128. B = np.random.randn(n, k)
  129. v = np.random.randn(n)
  130. for target in (B, v):
  131. X = expm_multiply(A, target,
  132. start=start, stop=stop, num=num, endpoint=endpoint)
  133. samples = np.linspace(start=start, stop=stop,
  134. num=num, endpoint=endpoint)
  135. with suppress_warnings() as sup:
  136. sup.filter(SparseEfficiencyWarning,
  137. "splu requires CSC matrix format")
  138. sup.filter(SparseEfficiencyWarning,
  139. "spsolve is more efficient when sparse b is in the CSC matrix format")
  140. for solution, t in zip(X, samples):
  141. assert_allclose(solution,
  142. scipy.linalg.expm(t*A).dot(target))
  143. def test_expm_multiply_interval_vector(self):
  144. np.random.seed(1234)
  145. start = 0.1
  146. stop = 3.2
  147. endpoint = True
  148. for num in (14, 13, 2):
  149. for n in (1, 2, 5, 20, 40):
  150. A = scipy.linalg.inv(np.random.randn(n, n))
  151. v = np.random.randn(n)
  152. X = expm_multiply(A, v,
  153. start=start, stop=stop, num=num, endpoint=endpoint)
  154. samples = np.linspace(start=start, stop=stop,
  155. num=num, endpoint=endpoint)
  156. for solution, t in zip(X, samples):
  157. assert_allclose(solution, scipy.linalg.expm(t*A).dot(v))
  158. def test_expm_multiply_interval_matrix(self):
  159. np.random.seed(1234)
  160. start = 0.1
  161. stop = 3.2
  162. endpoint = True
  163. for num in (14, 13, 2):
  164. for n in (1, 2, 5, 20, 40):
  165. for k in (1, 2):
  166. A = scipy.linalg.inv(np.random.randn(n, n))
  167. B = np.random.randn(n, k)
  168. X = expm_multiply(A, B,
  169. start=start, stop=stop, num=num, endpoint=endpoint)
  170. samples = np.linspace(start=start, stop=stop,
  171. num=num, endpoint=endpoint)
  172. for solution, t in zip(X, samples):
  173. assert_allclose(solution, scipy.linalg.expm(t*A).dot(B))
  174. def test_sparse_expm_multiply_interval_dtypes(self):
  175. # Test A & B int
  176. A = scipy.sparse.diags(np.arange(5),format='csr', dtype=int)
  177. B = np.ones(5, dtype=int)
  178. Aexpm = scipy.sparse.diags(np.exp(np.arange(5)),format='csr')
  179. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  180. # Test A complex, B int
  181. A = scipy.sparse.diags(-1j*np.arange(5),format='csr', dtype=complex)
  182. B = np.ones(5, dtype=int)
  183. Aexpm = scipy.sparse.diags(np.exp(-1j*np.arange(5)),format='csr')
  184. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  185. # Test A int, B complex
  186. A = scipy.sparse.diags(np.arange(5),format='csr', dtype=int)
  187. B = 1j*np.ones(5, dtype=complex)
  188. Aexpm = scipy.sparse.diags(np.exp(np.arange(5)),format='csr')
  189. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  190. def test_expm_multiply_interval_status_0(self):
  191. self._help_test_specific_expm_interval_status(0)
  192. def test_expm_multiply_interval_status_1(self):
  193. self._help_test_specific_expm_interval_status(1)
  194. def test_expm_multiply_interval_status_2(self):
  195. self._help_test_specific_expm_interval_status(2)
  196. def _help_test_specific_expm_interval_status(self, target_status):
  197. np.random.seed(1234)
  198. start = 0.1
  199. stop = 3.2
  200. num = 13
  201. endpoint = True
  202. n = 5
  203. k = 2
  204. nrepeats = 10
  205. nsuccesses = 0
  206. for num in [14, 13, 2] * nrepeats:
  207. A = np.random.randn(n, n)
  208. B = np.random.randn(n, k)
  209. status = _expm_multiply_interval(A, B,
  210. start=start, stop=stop, num=num, endpoint=endpoint,
  211. status_only=True)
  212. if status == target_status:
  213. X, status = _expm_multiply_interval(A, B,
  214. start=start, stop=stop, num=num, endpoint=endpoint,
  215. status_only=False)
  216. assert_equal(X.shape, (num, n, k))
  217. samples = np.linspace(start=start, stop=stop,
  218. num=num, endpoint=endpoint)
  219. for solution, t in zip(X, samples):
  220. assert_allclose(solution, scipy.linalg.expm(t*A).dot(B))
  221. nsuccesses += 1
  222. if not nsuccesses:
  223. msg = 'failed to find a status-' + str(target_status) + ' interval'
  224. raise Exception(msg)