test_hb.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import division, print_function, absolute_import
  2. import sys
  3. if sys.version_info[0] >= 3:
  4. from io import StringIO
  5. else:
  6. from StringIO import StringIO
  7. import tempfile
  8. import numpy as np
  9. from numpy.testing import assert_equal, \
  10. assert_array_almost_equal_nulp
  11. from scipy.sparse import coo_matrix, csc_matrix, rand
  12. from scipy.io import hb_read, hb_write
  13. SIMPLE = """\
  14. No Title |No Key
  15. 9 4 1 4
  16. RUA 100 100 10 0
  17. (26I3) (26I3) (3E23.15)
  18. 1 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
  19. 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
  20. 3 3 3 3 3 3 3 4 4 4 6 6 6 6 6 6 6 6 6 6 6 8 9 9 9 9
  21. 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 9 11
  22. 37 71 89 18 30 45 70 19 25 52
  23. 2.971243799687726e-01 3.662366682877375e-01 4.786962174699534e-01
  24. 6.490068647991184e-01 6.617490424831662e-02 8.870370343191623e-01
  25. 4.196478590163001e-01 5.649603072111251e-01 9.934423887087086e-01
  26. 6.912334991524289e-01
  27. """
  28. SIMPLE_MATRIX = coo_matrix(
  29. ((0.297124379969, 0.366236668288, 0.47869621747, 0.649006864799,
  30. 0.0661749042483, 0.887037034319, 0.419647859016,
  31. 0.564960307211, 0.993442388709, 0.691233499152,),
  32. (np.array([[36, 70, 88, 17, 29, 44, 69, 18, 24, 51],
  33. [0, 4, 58, 61, 61, 72, 72, 73, 99, 99]]))))
  34. def assert_csc_almost_equal(r, l):
  35. r = csc_matrix(r)
  36. l = csc_matrix(l)
  37. assert_equal(r.indptr, l.indptr)
  38. assert_equal(r.indices, l.indices)
  39. assert_array_almost_equal_nulp(r.data, l.data, 10000)
  40. class TestHBReader(object):
  41. def test_simple(self):
  42. m = hb_read(StringIO(SIMPLE))
  43. assert_csc_almost_equal(m, SIMPLE_MATRIX)
  44. class TestHBReadWrite(object):
  45. def check_save_load(self, value):
  46. with tempfile.NamedTemporaryFile(mode='w+t') as file:
  47. hb_write(file, value)
  48. file.file.seek(0)
  49. value_loaded = hb_read(file)
  50. assert_csc_almost_equal(value, value_loaded)
  51. def test_simple(self):
  52. random_matrix = rand(10, 100, 0.1)
  53. for matrix_format in ('coo', 'csc', 'csr', 'bsr', 'dia', 'dok', 'lil'):
  54. matrix = random_matrix.asformat(matrix_format, copy=False)
  55. self.check_save_load(matrix)