test_hungarian.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Author: Brian M. Clapper, G. Varoquaux, Lars Buitinck
  2. # License: BSD
  3. from numpy.testing import assert_array_equal
  4. from pytest import raises as assert_raises
  5. import numpy as np
  6. from scipy.optimize import linear_sum_assignment
  7. def test_linear_sum_assignment():
  8. for cost_matrix, expected_cost in [
  9. # Square
  10. ([[400, 150, 400],
  11. [400, 450, 600],
  12. [300, 225, 300]],
  13. [150, 400, 300]
  14. ),
  15. # Rectangular variant
  16. ([[400, 150, 400, 1],
  17. [400, 450, 600, 2],
  18. [300, 225, 300, 3]],
  19. [150, 2, 300]),
  20. # Square
  21. ([[10, 10, 8],
  22. [9, 8, 1],
  23. [9, 7, 4]],
  24. [10, 1, 7]),
  25. # Rectangular variant
  26. ([[10, 10, 8, 11],
  27. [9, 8, 1, 1],
  28. [9, 7, 4, 10]],
  29. [10, 1, 4]),
  30. # n == 2, m == 0 matrix
  31. ([[], []],
  32. []),
  33. ]:
  34. cost_matrix = np.array(cost_matrix)
  35. row_ind, col_ind = linear_sum_assignment(cost_matrix)
  36. assert_array_equal(row_ind, np.sort(row_ind))
  37. assert_array_equal(expected_cost, cost_matrix[row_ind, col_ind])
  38. cost_matrix = cost_matrix.T
  39. row_ind, col_ind = linear_sum_assignment(cost_matrix)
  40. assert_array_equal(row_ind, np.sort(row_ind))
  41. assert_array_equal(np.sort(expected_cost),
  42. np.sort(cost_matrix[row_ind, col_ind]))
  43. def test_linear_sum_assignment_input_validation():
  44. assert_raises(ValueError, linear_sum_assignment, [1, 2, 3])
  45. C = [[1, 2, 3], [4, 5, 6]]
  46. assert_array_equal(linear_sum_assignment(C),
  47. linear_sum_assignment(np.asarray(C)))
  48. assert_array_equal(linear_sum_assignment(C),
  49. linear_sum_assignment(np.matrix(C)))
  50. I = np.identity(3)
  51. assert_array_equal(linear_sum_assignment(I.astype(np.bool)),
  52. linear_sum_assignment(I))
  53. assert_raises(ValueError, linear_sum_assignment, I.astype(str))
  54. I[0][0] = np.nan
  55. assert_raises(ValueError, linear_sum_assignment, I)
  56. I = np.identity(3)
  57. I[1][1] = np.inf
  58. assert_raises(ValueError, linear_sum_assignment, I)