test_interface.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. """Test functions for the sparse.linalg.interface module
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. from functools import partial
  5. from itertools import product
  6. import operator
  7. import pytest
  8. from pytest import raises as assert_raises, warns
  9. from numpy.testing import assert_, assert_equal
  10. import numpy as np
  11. import scipy.sparse as sparse
  12. from scipy.sparse.linalg import interface
  13. # Only test matmul operator (A @ B) when available (Python 3.5+)
  14. TEST_MATMUL = hasattr(operator, 'matmul')
  15. class TestLinearOperator(object):
  16. def setup_method(self):
  17. self.A = np.array([[1,2,3],
  18. [4,5,6]])
  19. self.B = np.array([[1,2],
  20. [3,4],
  21. [5,6]])
  22. self.C = np.array([[1,2],
  23. [3,4]])
  24. def test_matvec(self):
  25. def get_matvecs(A):
  26. return [{
  27. 'shape': A.shape,
  28. 'matvec': lambda x: np.dot(A, x).reshape(A.shape[0]),
  29. 'rmatvec': lambda x: np.dot(A.T.conj(),
  30. x).reshape(A.shape[1])
  31. },
  32. {
  33. 'shape': A.shape,
  34. 'matvec': lambda x: np.dot(A, x),
  35. 'rmatvec': lambda x: np.dot(A.T.conj(), x),
  36. 'matmat': lambda x: np.dot(A, x)
  37. }]
  38. for matvecs in get_matvecs(self.A):
  39. A = interface.LinearOperator(**matvecs)
  40. assert_(A.args == ())
  41. assert_equal(A.matvec(np.array([1,2,3])), [14,32])
  42. assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
  43. assert_equal(A * np.array([1,2,3]), [14,32])
  44. assert_equal(A * np.array([[1],[2],[3]]), [[14],[32]])
  45. assert_equal(A.dot(np.array([1,2,3])), [14,32])
  46. assert_equal(A.dot(np.array([[1],[2],[3]])), [[14],[32]])
  47. assert_equal(A.matvec(np.matrix([[1],[2],[3]])), [[14],[32]])
  48. assert_equal(A * np.matrix([[1],[2],[3]]), [[14],[32]])
  49. assert_equal(A.dot(np.matrix([[1],[2],[3]])), [[14],[32]])
  50. assert_equal((2*A)*[1,1,1], [12,30])
  51. assert_equal((2*A).rmatvec([1,1]), [10, 14, 18])
  52. assert_equal((2*A).H.matvec([1,1]), [10, 14, 18])
  53. assert_equal((2*A)*[[1],[1],[1]], [[12],[30]])
  54. assert_equal((2*A).matmat([[1],[1],[1]]), [[12],[30]])
  55. assert_equal((A*2)*[1,1,1], [12,30])
  56. assert_equal((A*2)*[[1],[1],[1]], [[12],[30]])
  57. assert_equal((2j*A)*[1,1,1], [12j,30j])
  58. assert_equal((A+A)*[1,1,1], [12, 30])
  59. assert_equal((A+A).rmatvec([1,1]), [10, 14, 18])
  60. assert_equal((A+A).H.matvec([1,1]), [10, 14, 18])
  61. assert_equal((A+A)*[[1],[1],[1]], [[12], [30]])
  62. assert_equal((A+A).matmat([[1],[1],[1]]), [[12], [30]])
  63. assert_equal((-A)*[1,1,1], [-6,-15])
  64. assert_equal((-A)*[[1],[1],[1]], [[-6],[-15]])
  65. assert_equal((A-A)*[1,1,1], [0,0])
  66. assert_equal((A-A)*[[1],[1],[1]], [[0],[0]])
  67. z = A+A
  68. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] is A)
  69. z = 2*A
  70. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] == 2)
  71. assert_(isinstance(A.matvec([1, 2, 3]), np.ndarray))
  72. assert_(isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray))
  73. assert_(isinstance(A * np.array([1,2,3]), np.ndarray))
  74. assert_(isinstance(A * np.array([[1],[2],[3]]), np.ndarray))
  75. assert_(isinstance(A.dot(np.array([1,2,3])), np.ndarray))
  76. assert_(isinstance(A.dot(np.array([[1],[2],[3]])), np.ndarray))
  77. assert_(isinstance(A.matvec(np.matrix([[1],[2],[3]])), np.ndarray))
  78. assert_(isinstance(A * np.matrix([[1],[2],[3]]), np.ndarray))
  79. assert_(isinstance(A.dot(np.matrix([[1],[2],[3]])), np.ndarray))
  80. assert_(isinstance(2*A, interface._ScaledLinearOperator))
  81. assert_(isinstance(2j*A, interface._ScaledLinearOperator))
  82. assert_(isinstance(A+A, interface._SumLinearOperator))
  83. assert_(isinstance(-A, interface._ScaledLinearOperator))
  84. assert_(isinstance(A-A, interface._SumLinearOperator))
  85. assert_((2j*A).dtype == np.complex_)
  86. assert_raises(ValueError, A.matvec, np.array([1,2]))
  87. assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
  88. assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
  89. assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
  90. assert_raises(ValueError, lambda: A*A)
  91. assert_raises(ValueError, lambda: A**2)
  92. for matvecsA, matvecsB in product(get_matvecs(self.A),
  93. get_matvecs(self.B)):
  94. A = interface.LinearOperator(**matvecsA)
  95. B = interface.LinearOperator(**matvecsB)
  96. assert_equal((A*B)*[1,1], [50,113])
  97. assert_equal((A*B)*[[1],[1]], [[50],[113]])
  98. assert_equal((A*B).matmat([[1],[1]]), [[50],[113]])
  99. assert_equal((A*B).rmatvec([1,1]), [71,92])
  100. assert_equal((A*B).H.matvec([1,1]), [71,92])
  101. assert_(isinstance(A*B, interface._ProductLinearOperator))
  102. assert_raises(ValueError, lambda: A+B)
  103. assert_raises(ValueError, lambda: A**2)
  104. z = A*B
  105. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] is B)
  106. for matvecsC in get_matvecs(self.C):
  107. C = interface.LinearOperator(**matvecsC)
  108. assert_equal((C**2)*[1,1], [17,37])
  109. assert_equal((C**2).rmatvec([1,1]), [22,32])
  110. assert_equal((C**2).H.matvec([1,1]), [22,32])
  111. assert_equal((C**2).matmat([[1],[1]]), [[17],[37]])
  112. assert_(isinstance(C**2, interface._PowerLinearOperator))
  113. def test_matmul(self):
  114. if not TEST_MATMUL:
  115. pytest.skip("matmul is only tested in Python 3.5+")
  116. D = {'shape': self.A.shape,
  117. 'matvec': lambda x: np.dot(self.A, x).reshape(self.A.shape[0]),
  118. 'rmatvec': lambda x: np.dot(self.A.T.conj(),
  119. x).reshape(self.A.shape[1]),
  120. 'matmat': lambda x: np.dot(self.A, x)}
  121. A = interface.LinearOperator(**D)
  122. B = np.array([[1, 2, 3],
  123. [4, 5, 6],
  124. [7, 8, 9]])
  125. b = B[0]
  126. assert_equal(operator.matmul(A, b), A * b)
  127. assert_equal(operator.matmul(A, B), A * B)
  128. assert_raises(ValueError, operator.matmul, A, 2)
  129. assert_raises(ValueError, operator.matmul, 2, A)
  130. class TestAsLinearOperator(object):
  131. def setup_method(self):
  132. self.cases = []
  133. def make_cases(dtype):
  134. self.cases.append(np.matrix([[1,2,3],[4,5,6]], dtype=dtype))
  135. self.cases.append(np.array([[1,2,3],[4,5,6]], dtype=dtype))
  136. self.cases.append(sparse.csr_matrix([[1,2,3],[4,5,6]], dtype=dtype))
  137. # Test default implementations of _adjoint and _rmatvec, which
  138. # refer to each other.
  139. def mv(x, dtype):
  140. y = np.array([1 * x[0] + 2 * x[1] + 3 * x[2],
  141. 4 * x[0] + 5 * x[1] + 6 * x[2]], dtype=dtype)
  142. if len(x.shape) == 2:
  143. y = y.reshape(-1, 1)
  144. return y
  145. def rmv(x, dtype):
  146. return np.array([1 * x[0] + 4 * x[1],
  147. 2 * x[0] + 5 * x[1],
  148. 3 * x[0] + 6 * x[1]], dtype=dtype)
  149. class BaseMatlike(interface.LinearOperator):
  150. def __init__(self, dtype):
  151. self.dtype = np.dtype(dtype)
  152. self.shape = (2,3)
  153. def _matvec(self, x):
  154. return mv(x, self.dtype)
  155. class HasRmatvec(BaseMatlike):
  156. def _rmatvec(self,x):
  157. return rmv(x, self.dtype)
  158. class HasAdjoint(BaseMatlike):
  159. def _adjoint(self):
  160. shape = self.shape[1], self.shape[0]
  161. matvec = partial(rmv, dtype=self.dtype)
  162. rmatvec = partial(mv, dtype=self.dtype)
  163. return interface.LinearOperator(matvec=matvec,
  164. rmatvec=rmatvec,
  165. dtype=self.dtype,
  166. shape=shape)
  167. self.cases.append(HasRmatvec(dtype))
  168. self.cases.append(HasAdjoint(dtype))
  169. make_cases('int32')
  170. make_cases('float32')
  171. make_cases('float64')
  172. def test_basic(self):
  173. for M in self.cases:
  174. A = interface.aslinearoperator(M)
  175. M,N = A.shape
  176. assert_equal(A.matvec(np.array([1,2,3])), [14,32])
  177. assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
  178. assert_equal(A * np.array([1,2,3]), [14,32])
  179. assert_equal(A * np.array([[1],[2],[3]]), [[14],[32]])
  180. assert_equal(A.rmatvec(np.array([1,2])), [9,12,15])
  181. assert_equal(A.rmatvec(np.array([[1],[2]])), [[9],[12],[15]])
  182. assert_equal(A.H.matvec(np.array([1,2])), [9,12,15])
  183. assert_equal(A.H.matvec(np.array([[1],[2]])), [[9],[12],[15]])
  184. assert_equal(
  185. A.matmat(np.array([[1,4],[2,5],[3,6]])),
  186. [[14,32],[32,77]])
  187. assert_equal(A * np.array([[1,4],[2,5],[3,6]]), [[14,32],[32,77]])
  188. if hasattr(M,'dtype'):
  189. assert_equal(A.dtype, M.dtype)
  190. def test_dot(self):
  191. for M in self.cases:
  192. A = interface.aslinearoperator(M)
  193. M,N = A.shape
  194. assert_equal(A.dot(np.array([1,2,3])), [14,32])
  195. assert_equal(A.dot(np.array([[1],[2],[3]])), [[14],[32]])
  196. assert_equal(
  197. A.dot(np.array([[1,4],[2,5],[3,6]])),
  198. [[14,32],[32,77]])
  199. def test_repr():
  200. A = interface.LinearOperator(shape=(1, 1), matvec=lambda x: 1)
  201. repr_A = repr(A)
  202. assert_('unspecified dtype' not in repr_A, repr_A)
  203. def test_identity():
  204. ident = interface.IdentityOperator((3, 3))
  205. assert_equal(ident * [1, 2, 3], [1, 2, 3])
  206. assert_equal(ident.dot(np.arange(9).reshape(3, 3)).ravel(), np.arange(9))
  207. assert_raises(ValueError, ident.matvec, [1, 2, 3, 4])
  208. def test_attributes():
  209. A = interface.aslinearoperator(np.arange(16).reshape(4, 4))
  210. def always_four_ones(x):
  211. x = np.asarray(x)
  212. assert_(x.shape == (3,) or x.shape == (3, 1))
  213. return np.ones(4)
  214. B = interface.LinearOperator(shape=(4, 3), matvec=always_four_ones)
  215. for op in [A, B, A * B, A.H, A + A, B + B, A ** 4]:
  216. assert_(hasattr(op, "dtype"))
  217. assert_(hasattr(op, "shape"))
  218. assert_(hasattr(op, "_matvec"))
  219. def matvec(x):
  220. """ Needed for test_pickle as local functions are not pickleable """
  221. return np.zeros(3)
  222. def test_pickle():
  223. import pickle
  224. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  225. A = interface.LinearOperator((3, 3), matvec)
  226. s = pickle.dumps(A, protocol=protocol)
  227. B = pickle.loads(s)
  228. for k in A.__dict__:
  229. assert_equal(getattr(A, k), getattr(B, k))
  230. def test_inheritance():
  231. class Empty(interface.LinearOperator):
  232. pass
  233. with warns(RuntimeWarning, match="should implement at least"):
  234. assert_raises(TypeError, Empty)
  235. class Identity(interface.LinearOperator):
  236. def __init__(self, n):
  237. super(Identity, self).__init__(dtype=None, shape=(n, n))
  238. def _matvec(self, x):
  239. return x
  240. id3 = Identity(3)
  241. assert_equal(id3.matvec([1, 2, 3]), [1, 2, 3])
  242. assert_raises(NotImplementedError, id3.rmatvec, [4, 5, 6])
  243. class MatmatOnly(interface.LinearOperator):
  244. def __init__(self, A):
  245. super(MatmatOnly, self).__init__(A.dtype, A.shape)
  246. self.A = A
  247. def _matmat(self, x):
  248. return self.A.dot(x)
  249. mm = MatmatOnly(np.random.randn(5, 3))
  250. assert_equal(mm.matvec(np.random.randn(3)).shape, (5,))
  251. def test_dtypes_of_operator_sum():
  252. # gh-6078
  253. mat_complex = np.random.rand(2,2) + 1j * np.random.rand(2,2)
  254. mat_real = np.random.rand(2,2)
  255. complex_operator = interface.aslinearoperator(mat_complex)
  256. real_operator = interface.aslinearoperator(mat_real)
  257. sum_complex = complex_operator + complex_operator
  258. sum_real = real_operator + real_operator
  259. assert_equal(sum_real.dtype, np.float64)
  260. assert_equal(sum_complex.dtype, np.complex128)
  261. def test_no_double_init():
  262. call_count = [0]
  263. def matvec(v):
  264. call_count[0] += 1
  265. return v
  266. # It should call matvec exactly once (in order to determine the
  267. # operator dtype)
  268. A = interface.LinearOperator((2, 2), matvec=matvec)
  269. assert_equal(call_count[0], 1)
  270. def test_adjoint_conjugate():
  271. X = np.array([[1j]])
  272. A = interface.aslinearoperator(X)
  273. B = 1j * A
  274. Y = 1j * X
  275. v = np.array([1])
  276. assert_equal(B.dot(v), Y.dot(v))
  277. assert_equal(B.H.dot(v), Y.T.conj().dot(v))