_validation.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from scipy.sparse import csr_matrix, isspmatrix, isspmatrix_csc
  4. from ._tools import csgraph_to_dense, csgraph_from_dense,\
  5. csgraph_masked_from_dense, csgraph_from_masked
  6. DTYPE = np.float64
  7. def validate_graph(csgraph, directed, dtype=DTYPE,
  8. csr_output=True, dense_output=True,
  9. copy_if_dense=False, copy_if_sparse=False,
  10. null_value_in=0, null_value_out=np.inf,
  11. infinity_null=True, nan_null=True):
  12. """Routine for validation and conversion of csgraph inputs"""
  13. if not (csr_output or dense_output):
  14. raise ValueError("Internal: dense or csr output must be true")
  15. # if undirected and csc storage, then transposing in-place
  16. # is quicker than later converting to csr.
  17. if (not directed) and isspmatrix_csc(csgraph):
  18. csgraph = csgraph.T
  19. if isspmatrix(csgraph):
  20. if csr_output:
  21. csgraph = csr_matrix(csgraph, dtype=DTYPE, copy=copy_if_sparse)
  22. else:
  23. csgraph = csgraph_to_dense(csgraph, null_value=null_value_out)
  24. elif np.ma.isMaskedArray(csgraph):
  25. if dense_output:
  26. mask = csgraph.mask
  27. csgraph = np.array(csgraph.data, dtype=DTYPE, copy=copy_if_dense)
  28. csgraph[mask] = null_value_out
  29. else:
  30. csgraph = csgraph_from_masked(csgraph)
  31. else:
  32. if dense_output:
  33. csgraph = csgraph_masked_from_dense(csgraph,
  34. copy=copy_if_dense,
  35. null_value=null_value_in,
  36. nan_null=nan_null,
  37. infinity_null=infinity_null)
  38. mask = csgraph.mask
  39. csgraph = np.asarray(csgraph.data, dtype=DTYPE)
  40. csgraph[mask] = null_value_out
  41. else:
  42. csgraph = csgraph_from_dense(csgraph, null_value=null_value_in,
  43. infinity_null=infinity_null,
  44. nan_null=nan_null)
  45. if csgraph.ndim != 2:
  46. raise ValueError("compressed-sparse graph must be two dimensional")
  47. if csgraph.shape[0] != csgraph.shape[1]:
  48. raise ValueError("compressed-sparse graph must be shape (N, N)")
  49. return csgraph