textrank.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import absolute_import, unicode_literals
  4. import sys
  5. from operator import itemgetter
  6. from collections import defaultdict
  7. import jieba.posseg
  8. from .tfidf import KeywordExtractor
  9. from .._compat import *
  10. class UndirectWeightedGraph:
  11. d = 0.85
  12. def __init__(self):
  13. self.graph = defaultdict(list)
  14. def addEdge(self, start, end, weight):
  15. # use a tuple (start, end, weight) instead of a Edge object
  16. self.graph[start].append((start, end, weight))
  17. self.graph[end].append((end, start, weight))
  18. def rank(self):
  19. ws = defaultdict(float)
  20. outSum = defaultdict(float)
  21. wsdef = 1.0 / (len(self.graph) or 1.0)
  22. for n, out in self.graph.items():
  23. ws[n] = wsdef
  24. outSum[n] = sum((e[2] for e in out), 0.0)
  25. # this line for build stable iteration
  26. sorted_keys = sorted(self.graph.keys())
  27. for x in xrange(10): # 10 iters
  28. for n in sorted_keys:
  29. s = 0
  30. for e in self.graph[n]:
  31. s += e[2] / outSum[e[1]] * ws[e[1]]
  32. ws[n] = (1 - self.d) + self.d * s
  33. (min_rank, max_rank) = (sys.float_info[0], sys.float_info[3])
  34. for w in itervalues(ws):
  35. if w < min_rank:
  36. min_rank = w
  37. if w > max_rank:
  38. max_rank = w
  39. for n, w in ws.items():
  40. # to unify the weights, don't *100.
  41. ws[n] = (w - min_rank / 10.0) / (max_rank - min_rank / 10.0)
  42. return ws
  43. class TextRank(KeywordExtractor):
  44. def __init__(self):
  45. self.tokenizer = self.postokenizer = jieba.posseg.dt
  46. self.stop_words = self.STOP_WORDS.copy()
  47. self.pos_filt = frozenset(('ns', 'n', 'vn', 'v'))
  48. self.span = 5
  49. def pairfilter(self, wp):
  50. return (wp.flag in self.pos_filt and len(wp.word.strip()) >= 2
  51. and wp.word.lower() not in self.stop_words)
  52. def textrank(self, sentence, topK=20, withWeight=False, allowPOS=('ns', 'n', 'vn', 'v'), withFlag=False):
  53. """
  54. Extract keywords from sentence using TextRank algorithm.
  55. Parameter:
  56. - topK: return how many top keywords. `None` for all possible words.
  57. - withWeight: if True, return a list of (word, weight);
  58. if False, return a list of words.
  59. - allowPOS: the allowed POS list eg. ['ns', 'n', 'vn', 'v'].
  60. if the POS of w is not in this list, it will be filtered.
  61. - withFlag: if True, return a list of pair(word, weight) like posseg.cut
  62. if False, return a list of words
  63. """
  64. self.pos_filt = frozenset(allowPOS)
  65. g = UndirectWeightedGraph()
  66. cm = defaultdict(int)
  67. words = tuple(self.tokenizer.cut(sentence))
  68. for i, wp in enumerate(words):
  69. if self.pairfilter(wp):
  70. for j in xrange(i + 1, i + self.span):
  71. if j >= len(words):
  72. break
  73. if not self.pairfilter(words[j]):
  74. continue
  75. if allowPOS and withFlag:
  76. cm[(wp, words[j])] += 1
  77. else:
  78. cm[(wp.word, words[j].word)] += 1
  79. for terms, w in cm.items():
  80. g.addEdge(terms[0], terms[1], w)
  81. nodes_rank = g.rank()
  82. if withWeight:
  83. tags = sorted(nodes_rank.items(), key=itemgetter(1), reverse=True)
  84. else:
  85. tags = sorted(nodes_rank, key=nodes_rank.__getitem__, reverse=True)
  86. if topK:
  87. return tags[:topK]
  88. else:
  89. return tags
  90. extract_tags = textrank