_matrix_io.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from __future__ import division, print_function, absolute_import
  2. import sys
  3. import numpy as np
  4. import scipy.sparse
  5. from scipy._lib._version import NumpyVersion
  6. __all__ = ['save_npz', 'load_npz']
  7. if NumpyVersion(np.__version__) >= '1.10.0':
  8. # Make loading safe vs. malicious input
  9. PICKLE_KWARGS = dict(allow_pickle=False)
  10. else:
  11. PICKLE_KWARGS = dict()
  12. def save_npz(file, matrix, compressed=True):
  13. """ Save a sparse matrix to a file using ``.npz`` format.
  14. Parameters
  15. ----------
  16. file : str or file-like object
  17. Either the file name (string) or an open file (file-like object)
  18. where the data will be saved. If file is a string, the ``.npz``
  19. extension will be appended to the file name if it is not already
  20. there.
  21. matrix: spmatrix (format: ``csc``, ``csr``, ``bsr``, ``dia`` or coo``)
  22. The sparse matrix to save.
  23. compressed : bool, optional
  24. Allow compressing the file. Default: True
  25. See Also
  26. --------
  27. scipy.sparse.load_npz: Load a sparse matrix from a file using ``.npz`` format.
  28. numpy.savez: Save several arrays into a ``.npz`` archive.
  29. numpy.savez_compressed : Save several arrays into a compressed ``.npz`` archive.
  30. Examples
  31. --------
  32. Store sparse matrix to disk, and load it again:
  33. >>> import scipy.sparse
  34. >>> sparse_matrix = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
  35. >>> sparse_matrix
  36. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  37. with 2 stored elements in Compressed Sparse Column format>
  38. >>> sparse_matrix.todense()
  39. matrix([[0, 0, 3],
  40. [4, 0, 0]], dtype=int64)
  41. >>> scipy.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  42. >>> sparse_matrix = scipy.sparse.load_npz('/tmp/sparse_matrix.npz')
  43. >>> sparse_matrix
  44. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  45. with 2 stored elements in Compressed Sparse Column format>
  46. >>> sparse_matrix.todense()
  47. matrix([[0, 0, 3],
  48. [4, 0, 0]], dtype=int64)
  49. """
  50. arrays_dict = {}
  51. if matrix.format in ('csc', 'csr', 'bsr'):
  52. arrays_dict.update(indices=matrix.indices, indptr=matrix.indptr)
  53. elif matrix.format == 'dia':
  54. arrays_dict.update(offsets=matrix.offsets)
  55. elif matrix.format == 'coo':
  56. arrays_dict.update(row=matrix.row, col=matrix.col)
  57. else:
  58. raise NotImplementedError('Save is not implemented for sparse matrix of format {}.'.format(matrix.format))
  59. arrays_dict.update(
  60. format=matrix.format.encode('ascii'),
  61. shape=matrix.shape,
  62. data=matrix.data
  63. )
  64. if compressed:
  65. np.savez_compressed(file, **arrays_dict)
  66. else:
  67. np.savez(file, **arrays_dict)
  68. def load_npz(file):
  69. """ Load a sparse matrix from a file using ``.npz`` format.
  70. Parameters
  71. ----------
  72. file : str or file-like object
  73. Either the file name (string) or an open file (file-like object)
  74. where the data will be loaded.
  75. Returns
  76. -------
  77. result : csc_matrix, csr_matrix, bsr_matrix, dia_matrix or coo_matrix
  78. A sparse matrix containing the loaded data.
  79. Raises
  80. ------
  81. IOError
  82. If the input file does not exist or cannot be read.
  83. See Also
  84. --------
  85. scipy.sparse.save_npz: Save a sparse matrix to a file using ``.npz`` format.
  86. numpy.load: Load several arrays from a ``.npz`` archive.
  87. Examples
  88. --------
  89. Store sparse matrix to disk, and load it again:
  90. >>> import scipy.sparse
  91. >>> sparse_matrix = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
  92. >>> sparse_matrix
  93. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  94. with 2 stored elements in Compressed Sparse Column format>
  95. >>> sparse_matrix.todense()
  96. matrix([[0, 0, 3],
  97. [4, 0, 0]], dtype=int64)
  98. >>> scipy.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
  99. >>> sparse_matrix = scipy.sparse.load_npz('/tmp/sparse_matrix.npz')
  100. >>> sparse_matrix
  101. <2x3 sparse matrix of type '<class 'numpy.int64'>'
  102. with 2 stored elements in Compressed Sparse Column format>
  103. >>> sparse_matrix.todense()
  104. matrix([[0, 0, 3],
  105. [4, 0, 0]], dtype=int64)
  106. """
  107. with np.load(file, **PICKLE_KWARGS) as loaded:
  108. try:
  109. matrix_format = loaded['format']
  110. except KeyError:
  111. raise ValueError('The file {} does not contain a sparse matrix.'.format(file))
  112. matrix_format = matrix_format.item()
  113. if sys.version_info[0] >= 3 and not isinstance(matrix_format, str):
  114. # Play safe with Python 2 vs 3 backward compatibility;
  115. # files saved with Scipy < 1.0.0 may contain unicode or bytes.
  116. matrix_format = matrix_format.decode('ascii')
  117. try:
  118. cls = getattr(scipy.sparse, '{}_matrix'.format(matrix_format))
  119. except AttributeError:
  120. raise ValueError('Unknown matrix format "{}"'.format(matrix_format))
  121. if matrix_format in ('csc', 'csr', 'bsr'):
  122. return cls((loaded['data'], loaded['indices'], loaded['indptr']), shape=loaded['shape'])
  123. elif matrix_format == 'dia':
  124. return cls((loaded['data'], loaded['offsets']), shape=loaded['shape'])
  125. elif matrix_format == 'coo':
  126. return cls((loaded['data'], (loaded['row'], loaded['col'])), shape=loaded['shape'])
  127. else:
  128. raise NotImplementedError('Load is not implemented for '
  129. 'sparse matrix of format {}.'.format(matrix_format))