viterbi.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import sys
  2. import operator
  3. MIN_FLOAT = -3.14e100
  4. MIN_INF = float("-inf")
  5. if sys.version_info[0] > 2:
  6. xrange = range
  7. def get_top_states(t_state_v, K=4):
  8. return sorted(t_state_v, key=t_state_v.__getitem__, reverse=True)[:K]
  9. def viterbi(obs, states, start_p, trans_p, emit_p):
  10. V = [{}] # tabular
  11. mem_path = [{}]
  12. all_states = trans_p.keys()
  13. for y in states.get(obs[0], all_states): # init
  14. V[0][y] = start_p[y] + emit_p[y].get(obs[0], MIN_FLOAT)
  15. mem_path[0][y] = ''
  16. for t in xrange(1, len(obs)):
  17. V.append({})
  18. mem_path.append({})
  19. #prev_states = get_top_states(V[t-1])
  20. prev_states = [
  21. x for x in mem_path[t - 1].keys() if len(trans_p[x]) > 0]
  22. prev_states_expect_next = set(
  23. (y for x in prev_states for y in trans_p[x].keys()))
  24. obs_states = set(
  25. states.get(obs[t], all_states)) & prev_states_expect_next
  26. if not obs_states:
  27. obs_states = prev_states_expect_next if prev_states_expect_next else all_states
  28. for y in obs_states:
  29. prob, state = max((V[t - 1][y0] + trans_p[y0].get(y, MIN_INF) +
  30. emit_p[y].get(obs[t], MIN_FLOAT), y0) for y0 in prev_states)
  31. V[t][y] = prob
  32. mem_path[t][y] = state
  33. last = [(V[-1][y], y) for y in mem_path[-1].keys()]
  34. # if len(last)==0:
  35. # print obs
  36. prob, state = max(last)
  37. route = [None] * len(obs)
  38. i = len(obs) - 1
  39. while i >= 0:
  40. route[i] = state
  41. state = mem_path[i][state]
  42. i -= 1
  43. return (prob, route)