test_spanning_tree.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """Test the minimum spanning tree function"""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy as np
  4. from numpy.testing import assert_
  5. import numpy.testing as npt
  6. from scipy.sparse import csr_matrix
  7. from scipy.sparse.csgraph import minimum_spanning_tree
  8. def test_minimum_spanning_tree():
  9. # Create a graph with two connected components.
  10. graph = [[0,1,0,0,0],
  11. [1,0,0,0,0],
  12. [0,0,0,8,5],
  13. [0,0,8,0,1],
  14. [0,0,5,1,0]]
  15. graph = np.asarray(graph)
  16. # Create the expected spanning tree.
  17. expected = [[0,1,0,0,0],
  18. [0,0,0,0,0],
  19. [0,0,0,0,5],
  20. [0,0,0,0,1],
  21. [0,0,0,0,0]]
  22. expected = np.asarray(expected)
  23. # Ensure minimum spanning tree code gives this expected output.
  24. csgraph = csr_matrix(graph)
  25. mintree = minimum_spanning_tree(csgraph)
  26. npt.assert_array_equal(mintree.todense(), expected,
  27. 'Incorrect spanning tree found.')
  28. # Ensure that the original graph was not modified.
  29. npt.assert_array_equal(csgraph.todense(), graph,
  30. 'Original graph was modified.')
  31. # Now let the algorithm modify the csgraph in place.
  32. mintree = minimum_spanning_tree(csgraph, overwrite=True)
  33. npt.assert_array_equal(mintree.todense(), expected,
  34. 'Graph was not properly modified to contain MST.')
  35. np.random.seed(1234)
  36. for N in (5, 10, 15, 20):
  37. # Create a random graph.
  38. graph = 3 + np.random.random((N, N))
  39. csgraph = csr_matrix(graph)
  40. # The spanning tree has at most N - 1 edges.
  41. mintree = minimum_spanning_tree(csgraph)
  42. assert_(mintree.nnz < N)
  43. # Set the sub diagonal to 1 to create a known spanning tree.
  44. idx = np.arange(N-1)
  45. graph[idx,idx+1] = 1
  46. csgraph = csr_matrix(graph)
  47. mintree = minimum_spanning_tree(csgraph)
  48. # We expect to see this pattern in the spanning tree and otherwise
  49. # have this zero.
  50. expected = np.zeros((N, N))
  51. expected[idx, idx+1] = 1
  52. npt.assert_array_equal(mintree.todense(), expected,
  53. 'Incorrect spanning tree found.')