test_extract.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. """test sparse matrix construction functions"""
  2. from __future__ import division, print_function, absolute_import
  3. from numpy.testing import assert_equal
  4. from scipy.sparse import csr_matrix
  5. import numpy as np
  6. from scipy.sparse import extract
  7. class TestExtract(object):
  8. def setup_method(self):
  9. self.cases = [
  10. csr_matrix([[1,2]]),
  11. csr_matrix([[1,0]]),
  12. csr_matrix([[0,0]]),
  13. csr_matrix([[1],[2]]),
  14. csr_matrix([[1],[0]]),
  15. csr_matrix([[0],[0]]),
  16. csr_matrix([[1,2],[3,4]]),
  17. csr_matrix([[0,1],[0,0]]),
  18. csr_matrix([[0,0],[1,0]]),
  19. csr_matrix([[0,0],[0,0]]),
  20. csr_matrix([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]),
  21. csr_matrix([[1,2,0,0,3],[4,5,0,6,7],[0,0,8,9,0]]).T,
  22. ]
  23. def find(self):
  24. for A in self.cases:
  25. I,J,V = extract.find(A)
  26. assert_equal(A.toarray(), csr_matrix(((I,J),V), shape=A.shape))
  27. def test_tril(self):
  28. for A in self.cases:
  29. B = A.toarray()
  30. for k in [-3,-2,-1,0,1,2,3]:
  31. assert_equal(extract.tril(A,k=k).toarray(), np.tril(B,k=k))
  32. def test_triu(self):
  33. for A in self.cases:
  34. B = A.toarray()
  35. for k in [-3,-2,-1,0,1,2,3]:
  36. assert_equal(extract.triu(A,k=k).toarray(), np.triu(B,k=k))