test_csr.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal, assert_
  4. from scipy.sparse import csr_matrix
  5. def _check_csr_rowslice(i, sl, X, Xcsr):
  6. np_slice = X[i, sl]
  7. csr_slice = Xcsr[i, sl]
  8. assert_array_almost_equal(np_slice, csr_slice.toarray()[0])
  9. assert_(type(csr_slice) is csr_matrix)
  10. def test_csr_rowslice():
  11. N = 10
  12. np.random.seed(0)
  13. X = np.random.random((N, N))
  14. X[X > 0.7] = 0
  15. Xcsr = csr_matrix(X)
  16. slices = [slice(None, None, None),
  17. slice(None, None, -1),
  18. slice(1, -2, 2),
  19. slice(-2, 1, -2)]
  20. for i in range(N):
  21. for sl in slices:
  22. _check_csr_rowslice(i, sl, X, Xcsr)
  23. def test_csr_getrow():
  24. N = 10
  25. np.random.seed(0)
  26. X = np.random.random((N, N))
  27. X[X > 0.7] = 0
  28. Xcsr = csr_matrix(X)
  29. for i in range(N):
  30. arr_row = X[i:i + 1, :]
  31. csr_row = Xcsr.getrow(i)
  32. assert_array_almost_equal(arr_row, csr_row.toarray())
  33. assert_(type(csr_row) is csr_matrix)
  34. def test_csr_getcol():
  35. N = 10
  36. np.random.seed(0)
  37. X = np.random.random((N, N))
  38. X[X > 0.7] = 0
  39. Xcsr = csr_matrix(X)
  40. for i in range(N):
  41. arr_col = X[:, i:i + 1]
  42. csr_col = Xcsr.getcol(i)
  43. assert_array_almost_equal(arr_col, csr_col.toarray())
  44. assert_(type(csr_col) is csr_matrix)