test_matrix_io.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from __future__ import division, print_function, absolute_import
  2. import sys
  3. import os
  4. import numpy as np
  5. import tempfile
  6. import pytest
  7. from pytest import raises as assert_raises
  8. from numpy.testing import assert_equal, assert_
  9. from scipy._lib._version import NumpyVersion
  10. from scipy.sparse import (csc_matrix, csr_matrix, bsr_matrix, dia_matrix,
  11. coo_matrix, save_npz, load_npz, dok_matrix)
  # Directory holding the pre-generated .npz fixtures, next to this file.
  DATA_DIR = os.path.join(os.path.split(__file__)[0], 'data')
  13. def _save_and_load(matrix):
  14. fd, tmpfile = tempfile.mkstemp(suffix='.npz')
  15. os.close(fd)
  16. try:
  17. save_npz(tmpfile, matrix)
  18. loaded_matrix = load_npz(tmpfile)
  19. finally:
  20. os.remove(tmpfile)
  21. return loaded_matrix
  def _check_save_and_load(dense_matrix):
      """Assert that every supported sparse format survives a save/load cycle.

      For each of CSC, CSR, BSR, DIA and COO, the reloaded matrix must be of
      exactly the same class, and must match ``dense_matrix`` in shape, dtype
      and values.
      """
      formats = (csc_matrix, csr_matrix, bsr_matrix, dia_matrix, coo_matrix)
      for fmt in formats:
          original = fmt(dense_matrix)
          restored = _save_and_load(original)
          # Exact class check (not isinstance): the format itself must round-trip.
          assert_(type(restored) is fmt)
          assert_(restored.shape == dense_matrix.shape)
          assert_(restored.dtype == dense_matrix.dtype)
          assert_equal(restored.toarray(), dense_matrix)
  def test_save_and_load_random():
      """Round-trip a seeded random 10x10 matrix through every sparse format.

      Entries greater than 0.7 are zeroed so the matrix mixes nonzero values
      with zeros.
      """
      N = 10
      # A local generator produces the same values as np.random.seed(0)
      # followed by np.random.random, but does not reset the global RNG
      # state that other tests may rely on.
      rng = np.random.RandomState(0)
      dense_matrix = rng.random_sample((N, N))
      dense_matrix[dense_matrix > 0.7] = 0
      _check_save_and_load(dense_matrix)
  def test_save_and_load_empty():
      """An all-zero 4x6 matrix must round-trip through every sparse format."""
      _check_save_and_load(np.zeros((4, 6)))
  def test_save_and_load_one_entry():
      """A 4x6 matrix with a single 1 at row 1, column 2 must round-trip."""
      single = np.zeros((4, 6))
      single[1, 2] = 1
      _check_save_and_load(single)
  @pytest.mark.skipif(NumpyVersion(np.__version__) < '1.10.0',
                      reason='disabling unpickling requires numpy >= 1.10.0')
  def test_malicious_load():
      """load_npz must reject a pickled payload instead of executing it.

      The archive stores an object whose unpickling would call
      ``assert_(False, ...)``. The test expects load_npz to raise ValueError
      rather than run that code.
      """
      class Executor(object):
          def __reduce__(self):
              # Unpickling this object would invoke assert_(False, ...).
              return (assert_, (False, 'unexpected code execution'))

      handle, path = tempfile.mkstemp(suffix='.npz')
      os.close(handle)
      try:
          np.savez(path, format=Executor())
          with assert_raises(ValueError):
              load_npz(path)
      finally:
          os.remove(path)
  def test_py23_compatibility():
      """Fixtures written under both Python 2 and Python 3 must load.

      The two files are not identical, because archives saved by SciPy
      versions older than 1.0.0 may contain unicode. Both are expected to
      equal the 1x1 CSC matrix [[0]].
      """
      loaded = [load_npz(os.path.join(DATA_DIR, name))
                for name in ('csc_py2.npz', 'csc_py3.npz')]
      expected = csc_matrix([[0]]).toarray()
      for m in loaded:
          assert_equal(m.toarray(), expected)
  def test_implemented_error():
      """save_npz must raise NotImplementedError for an unsupported format.

      A DOK matrix is an unsupported type for save_npz, so saving one should
      fail. The target is a throwaway temporary file that is always removed,
      rather than a hard-coded 'x.npz' in the current working directory.
      That way, a regression that actually writes the file cannot leave
      stray output behind.
      """
      x = dok_matrix((2, 3))
      x[0, 1] = 1
      fd, tmpfile = tempfile.mkstemp(suffix='.npz')
      os.close(fd)
      try:
          assert_raises(NotImplementedError, save_npz, tmpfile, x)
      finally:
          os.remove(tmpfile)