_hungarian.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # Hungarian algorithm (Kuhn-Munkres) for solving the linear sum assignment
  2. # problem. Taken from scikit-learn. Based on original code by Brian Clapper,
  3. # adapted to NumPy by Gael Varoquaux.
  4. # Further improvements by Ben Root, Vlad Niculae and Lars Buitinck.
  5. #
  6. # Copyright (c) 2008 Brian M. Clapper <bmc@clapper.org>, Gael Varoquaux
  7. # Author: Brian M. Clapper, Gael Varoquaux
  8. # License: 3-clause BSD
  9. import numpy as np
  10. def linear_sum_assignment(cost_matrix):
  11. """Solve the linear sum assignment problem.
  12. The linear sum assignment problem is also known as minimum weight matching
  13. in bipartite graphs. A problem instance is described by a matrix C, where
  14. each C[i,j] is the cost of matching vertex i of the first partite set
  15. (a "worker") and vertex j of the second set (a "job"). The goal is to find
  16. a complete assignment of workers to jobs of minimal cost.
  17. Formally, let X be a boolean matrix where :math:`X[i,j] = 1` iff row i is
  18. assigned to column j. Then the optimal assignment has cost
  19. .. math::
  20. \\min \\sum_i \\sum_j C_{i,j} X_{i,j}
  21. s.t. each row is assignment to at most one column, and each column to at
  22. most one row.
  23. This function can also solve a generalization of the classic assignment
  24. problem where the cost matrix is rectangular. If it has more rows than
  25. columns, then not every row needs to be assigned to a column, and vice
  26. versa.
  27. The method used is the Hungarian algorithm, also known as the Munkres or
  28. Kuhn-Munkres algorithm.
  29. Parameters
  30. ----------
  31. cost_matrix : array
  32. The cost matrix of the bipartite graph.
  33. Returns
  34. -------
  35. row_ind, col_ind : array
  36. An array of row indices and one of corresponding column indices giving
  37. the optimal assignment. The cost of the assignment can be computed
  38. as ``cost_matrix[row_ind, col_ind].sum()``. The row indices will be
  39. sorted; in the case of a square cost matrix they will be equal to
  40. ``numpy.arange(cost_matrix.shape[0])``.
  41. Notes
  42. -----
  43. .. versionadded:: 0.17.0
  44. Examples
  45. --------
  46. >>> cost = np.array([[4, 1, 3], [2, 0, 5], [3, 2, 2]])
  47. >>> from scipy.optimize import linear_sum_assignment
  48. >>> row_ind, col_ind = linear_sum_assignment(cost)
  49. >>> col_ind
  50. array([1, 0, 2])
  51. >>> cost[row_ind, col_ind].sum()
  52. 5
  53. References
  54. ----------
  55. 1. http://csclab.murraystate.edu/bob.pilgrim/445/munkres.html
  56. 2. Harold W. Kuhn. The Hungarian Method for the assignment problem.
  57. *Naval Research Logistics Quarterly*, 2:83-97, 1955.
  58. 3. Harold W. Kuhn. Variants of the Hungarian method for assignment
  59. problems. *Naval Research Logistics Quarterly*, 3: 253-258, 1956.
  60. 4. Munkres, J. Algorithms for the Assignment and Transportation Problems.
  61. *J. SIAM*, 5(1):32-38, March, 1957.
  62. 5. https://en.wikipedia.org/wiki/Hungarian_algorithm
  63. """
  64. cost_matrix = np.asarray(cost_matrix)
  65. if len(cost_matrix.shape) != 2:
  66. raise ValueError("expected a matrix (2-d array), got a %r array"
  67. % (cost_matrix.shape,))
  68. if not (np.issubdtype(cost_matrix.dtype, np.number) or
  69. cost_matrix.dtype == np.dtype(np.bool)):
  70. raise ValueError("expected a matrix containing numerical entries, got %s"
  71. % (cost_matrix.dtype,))
  72. if np.any(np.isinf(cost_matrix) | np.isnan(cost_matrix)):
  73. raise ValueError("matrix contains invalid numeric entries")
  74. if cost_matrix.dtype == np.dtype(np.bool):
  75. cost_matrix = cost_matrix.astype(np.int)
  76. # The algorithm expects more columns than rows in the cost matrix.
  77. if cost_matrix.shape[1] < cost_matrix.shape[0]:
  78. cost_matrix = cost_matrix.T
  79. transposed = True
  80. else:
  81. transposed = False
  82. state = _Hungary(cost_matrix)
  83. # No need to bother with assignments if one of the dimensions
  84. # of the cost matrix is zero-length.
  85. step = None if 0 in cost_matrix.shape else _step1
  86. while step is not None:
  87. step = step(state)
  88. if transposed:
  89. marked = state.marked.T
  90. else:
  91. marked = state.marked
  92. return np.nonzero(marked == 1)
  93. class _Hungary(object):
  94. """State of the Hungarian algorithm.
  95. Parameters
  96. ----------
  97. cost_matrix : 2D matrix
  98. The cost matrix. Must have shape[1] >= shape[0].
  99. """
  100. def __init__(self, cost_matrix):
  101. self.C = cost_matrix.copy()
  102. n, m = self.C.shape
  103. self.row_uncovered = np.ones(n, dtype=bool)
  104. self.col_uncovered = np.ones(m, dtype=bool)
  105. self.Z0_r = 0
  106. self.Z0_c = 0
  107. self.path = np.zeros((n + m, 2), dtype=int)
  108. self.marked = np.zeros((n, m), dtype=int)
  109. def _clear_covers(self):
  110. """Clear all covered matrix cells"""
  111. self.row_uncovered[:] = True
  112. self.col_uncovered[:] = True
  113. # Individual steps of the algorithm follow, as a state machine: they return
  114. # the next step to be taken (function to be called), if any.
  115. def _step1(state):
  116. """Steps 1 and 2 in the Wikipedia page."""
  117. # Step 1: For each row of the matrix, find the smallest element and
  118. # subtract it from every element in its row.
  119. state.C -= state.C.min(axis=1)[:, np.newaxis]
  120. # Step 2: Find a zero (Z) in the resulting matrix. If there is no
  121. # starred zero in its row or column, star Z. Repeat for each element
  122. # in the matrix.
  123. for i, j in zip(*np.nonzero(state.C == 0)):
  124. if state.col_uncovered[j] and state.row_uncovered[i]:
  125. state.marked[i, j] = 1
  126. state.col_uncovered[j] = False
  127. state.row_uncovered[i] = False
  128. state._clear_covers()
  129. return _step3
  130. def _step3(state):
  131. """
  132. Cover each column containing a starred zero. If n columns are covered,
  133. the starred zeros describe a complete set of unique assignments.
  134. In this case, Go to DONE, otherwise, Go to Step 4.
  135. """
  136. marked = (state.marked == 1)
  137. state.col_uncovered[np.any(marked, axis=0)] = False
  138. if marked.sum() < state.C.shape[0]:
  139. return _step4
  140. def _step4(state):
  141. """
  142. Find a noncovered zero and prime it. If there is no starred zero
  143. in the row containing this primed zero, Go to Step 5. Otherwise,
  144. cover this row and uncover the column containing the starred
  145. zero. Continue in this manner until there are no uncovered zeros
  146. left. Save the smallest uncovered value and Go to Step 6.
  147. """
  148. # We convert to int as numpy operations are faster on int
  149. C = (state.C == 0).astype(int)
  150. covered_C = C * state.row_uncovered[:, np.newaxis]
  151. covered_C *= np.asarray(state.col_uncovered, dtype=int)
  152. n = state.C.shape[0]
  153. m = state.C.shape[1]
  154. while True:
  155. # Find an uncovered zero
  156. row, col = np.unravel_index(np.argmax(covered_C), (n, m))
  157. if covered_C[row, col] == 0:
  158. return _step6
  159. else:
  160. state.marked[row, col] = 2
  161. # Find the first starred element in the row
  162. star_col = np.argmax(state.marked[row] == 1)
  163. if state.marked[row, star_col] != 1:
  164. # Could not find one
  165. state.Z0_r = row
  166. state.Z0_c = col
  167. return _step5
  168. else:
  169. col = star_col
  170. state.row_uncovered[row] = False
  171. state.col_uncovered[col] = True
  172. covered_C[:, col] = C[:, col] * (
  173. np.asarray(state.row_uncovered, dtype=int))
  174. covered_C[row] = 0
  175. def _step5(state):
  176. """
  177. Construct a series of alternating primed and starred zeros as follows.
  178. Let Z0 represent the uncovered primed zero found in Step 4.
  179. Let Z1 denote the starred zero in the column of Z0 (if any).
  180. Let Z2 denote the primed zero in the row of Z1 (there will always be one).
  181. Continue until the series terminates at a primed zero that has no starred
  182. zero in its column. Unstar each starred zero of the series, star each
  183. primed zero of the series, erase all primes and uncover every line in the
  184. matrix. Return to Step 3
  185. """
  186. count = 0
  187. path = state.path
  188. path[count, 0] = state.Z0_r
  189. path[count, 1] = state.Z0_c
  190. while True:
  191. # Find the first starred element in the col defined by
  192. # the path.
  193. row = np.argmax(state.marked[:, path[count, 1]] == 1)
  194. if state.marked[row, path[count, 1]] != 1:
  195. # Could not find one
  196. break
  197. else:
  198. count += 1
  199. path[count, 0] = row
  200. path[count, 1] = path[count - 1, 1]
  201. # Find the first prime element in the row defined by the
  202. # first path step
  203. col = np.argmax(state.marked[path[count, 0]] == 2)
  204. if state.marked[row, col] != 2:
  205. col = -1
  206. count += 1
  207. path[count, 0] = path[count - 1, 0]
  208. path[count, 1] = col
  209. # Convert paths
  210. for i in range(count + 1):
  211. if state.marked[path[i, 0], path[i, 1]] == 1:
  212. state.marked[path[i, 0], path[i, 1]] = 0
  213. else:
  214. state.marked[path[i, 0], path[i, 1]] = 1
  215. state._clear_covers()
  216. # Erase all prime markings
  217. state.marked[state.marked == 2] = 0
  218. return _step3
  219. def _step6(state):
  220. """
  221. Add the value found in Step 4 to every element of each covered row,
  222. and subtract it from every element of each uncovered column.
  223. Return to Step 4 without altering any stars, primes, or covered lines.
  224. """
  225. # the smallest uncovered value in the matrix
  226. if np.any(state.row_uncovered) and np.any(state.col_uncovered):
  227. minval = np.min(state.C[state.row_uncovered], axis=0)
  228. minval = np.min(minval[state.col_uncovered])
  229. state.C[~state.row_uncovered] += minval
  230. state.C[:, state.col_uncovered] -= minval
  231. return _step4