__init__.py

from __future__ import absolute_import, unicode_literals

import pickle
import re

import jieba
from .viterbi import viterbi
from .._compat import *

PROB_START_P = "prob_start.p"
PROB_TRANS_P = "prob_trans.p"
PROB_EMIT_P = "prob_emit.p"
CHAR_STATE_TAB_P = "char_state_tab.p"

re_han_detail = re.compile("([\u4E00-\u9FD5]+)")
re_skip_detail = re.compile("([\.0-9]+|[a-zA-Z0-9]+)")
re_han_internal = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._]+)")
re_skip_internal = re.compile("(\r\n|\s)")
re_eng = re.compile("[a-zA-Z0-9]+")
re_num = re.compile("[\.0-9]+")
re_eng1 = re.compile('^[a-zA-Z0-9]$', re.U)
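
# Roughly: the re_han_* patterns capture runs of CJK text that are handed to
# the dictionary/HMM cutters below, the re_skip_* patterns capture whitespace
# and ASCII runs, and re_num / re_eng / re_eng1 decide whether a leftover
# token is tagged 'm' (numeral), 'eng' (alphanumeric) or 'x' (other).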


def load_model():
    # For Jython
    start_p = pickle.load(get_module_res("posseg", PROB_START_P))
    trans_p = pickle.load(get_module_res("posseg", PROB_TRANS_P))
    emit_p = pickle.load(get_module_res("posseg", PROB_EMIT_P))
    state = pickle.load(get_module_res("posseg", CHAR_STATE_TAB_P))
    return state, start_p, trans_p, emit_p


if sys.platform.startswith("java"):
    char_state_tab_P, start_P, trans_P, emit_P = load_model()
else:
    from .char_state_tab import P as char_state_tab_P
    from .prob_start import P as start_P
    from .prob_trans import P as trans_P
    from .prob_emit import P as emit_P
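
# char_state_tab_P, start_P, trans_P and emit_P are the HMM parameters
# (per-character state table, start, transition and emission probabilities)
# consumed by viterbi() in POSTokenizer.__cut below.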


class pair(object):

    def __init__(self, word, flag):
        self.word = word
        self.flag = flag

    def __unicode__(self):
        return '%s/%s' % (self.word, self.flag)

    def __repr__(self):
        return 'pair(%r, %r)' % (self.word, self.flag)

    def __str__(self):
        if PY2:
            return self.__unicode__().encode(default_encoding)
        else:
            return self.__unicode__()

    def __iter__(self):
        return iter((self.word, self.flag))

    def __lt__(self, other):
        return self.word < other.word

    def __eq__(self, other):
        return isinstance(other, pair) and self.word == other.word and self.flag == other.flag

    def __hash__(self):
        return hash(self.word)

    def encode(self, arg):
        return self.__unicode__().encode(arg)
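
# Illustrative note: thanks to __iter__ and __unicode__, a pair behaves like a
# (word, flag) tuple when unpacked and prints as "word/flag", e.g.
#     w = pair('word', 'n'); word, flag = w; str(w) == 'word/n'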


class POSTokenizer(object):

    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer or jieba.Tokenizer()
        self.load_word_tag(self.tokenizer.get_dict_file())

    def __repr__(self):
        return '<POSTokenizer tokenizer=%r>' % self.tokenizer

    def __getattr__(self, name):
        if name in ('cut_for_search', 'lcut_for_search', 'tokenize'):
            # may be possible?
            raise NotImplementedError
        return getattr(self.tokenizer, name)

    def initialize(self, dictionary=None):
        self.tokenizer.initialize(dictionary)
        self.load_word_tag(self.tokenizer.get_dict_file())

    def load_word_tag(self, f):
        # Each dictionary line is expected to hold "word frequency tag";
        # the frequency field is ignored here and only the POS tag is kept.
        self.word_tag_tab = {}
        f_name = resolve_filename(f)
        for lineno, line in enumerate(f, 1):
            try:
                line = line.strip().decode("utf-8")
                if not line:
                    continue
                word, _, tag = line.split(" ")
                self.word_tag_tab[word] = tag
            except Exception:
                raise ValueError(
                    'invalid POS dictionary entry in %s at Line %s: %s' % (f_name, lineno, line))
        f.close()

    def makesure_userdict_loaded(self):
        if self.tokenizer.user_word_tag_tab:
            self.word_tag_tab.update(self.tokenizer.user_word_tag_tab)
            self.tokenizer.user_word_tag_tab = {}
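
    # __cut runs the character-level HMM: viterbi() assigns each character a
    # (position, POS) label, where positions 'B'/'E' mark the beginning/end of
    # a multi-character word and 'S' marks a single-character word.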
    def __cut(self, sentence):
        prob, pos_list = viterbi(
            sentence, char_state_tab_P, start_P, trans_P, emit_P)
        begin, nexti = 0, 0

        for i, char in enumerate(sentence):
            pos = pos_list[i][0]
            if pos == 'B':
                begin = i
            elif pos == 'E':
                yield pair(sentence[begin:i + 1], pos_list[i][1])
                nexti = i + 1
            elif pos == 'S':
                yield pair(char, pos_list[i][1])
                nexti = i + 1
        if nexti < len(sentence):
            yield pair(sentence[nexti:], pos_list[nexti][1])
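
    # __cut_detail splits a buffer into CJK blocks, which go through the HMM
    # cutter above, and non-CJK runs, which are tagged 'm', 'eng' or 'x'.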
    def __cut_detail(self, sentence):
        blocks = re_han_detail.split(sentence)
        for blk in blocks:
            if re_han_detail.match(blk):
                for word in self.__cut(blk):
                    yield word
            else:
                tmp = re_skip_detail.split(blk)
                for x in tmp:
                    if x:
                        if re_num.match(x):
                            yield pair(x, 'm')
                        elif re_eng.match(x):
                            yield pair(x, 'eng')
                        else:
                            yield pair(x, 'x')
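
    # __cut_DAG_NO_HMM follows the dictionary DAG route only; consecutive
    # single ASCII letters/digits are merged into one 'eng' token, everything
    # else is tagged from word_tag_tab (falling back to 'x').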
    def __cut_DAG_NO_HMM(self, sentence):
        DAG = self.tokenizer.get_DAG(sentence)
        route = {}
        self.tokenizer.calc(sentence, DAG, route)
        x = 0
        N = len(sentence)
        buf = ''
        while x < N:
            y = route[x][1] + 1
            l_word = sentence[x:y]
            if re_eng1.match(l_word):
                buf += l_word
                x = y
            else:
                if buf:
                    yield pair(buf, 'eng')
                    buf = ''
                yield pair(l_word, self.word_tag_tab.get(l_word, 'x'))
                x = y
        if buf:
            yield pair(buf, 'eng')
            buf = ''
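
    # __cut_DAG also follows the dictionary DAG route, but buffers runs of
    # single characters; a buffered run that is not in the dictionary (FREQ)
    # is re-segmented by __cut_detail, i.e. by the HMM.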
    def __cut_DAG(self, sentence):
        DAG = self.tokenizer.get_DAG(sentence)
        route = {}
        self.tokenizer.calc(sentence, DAG, route)
        x = 0
        buf = ''
        N = len(sentence)
        while x < N:
            y = route[x][1] + 1
            l_word = sentence[x:y]
            if y - x == 1:
                buf += l_word
            else:
                if buf:
                    if len(buf) == 1:
                        yield pair(buf, self.word_tag_tab.get(buf, 'x'))
                    elif not self.tokenizer.FREQ.get(buf):
                        recognized = self.__cut_detail(buf)
                        for t in recognized:
                            yield t
                    else:
                        for elem in buf:
                            yield pair(elem, self.word_tag_tab.get(elem, 'x'))
                    buf = ''
                yield pair(l_word, self.word_tag_tab.get(l_word, 'x'))
            x = y

        if buf:
            if len(buf) == 1:
                yield pair(buf, self.word_tag_tab.get(buf, 'x'))
            elif not self.tokenizer.FREQ.get(buf):
                recognized = self.__cut_detail(buf)
                for t in recognized:
                    yield t
            else:
                for elem in buf:
                    yield pair(elem, self.word_tag_tab.get(elem, 'x'))
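
    # __cut_internal is the per-call driver: it merges any pending user
    # dictionary tags, decodes the input, and dispatches CJK blocks to
    # __cut_DAG (HMM=True) or __cut_DAG_NO_HMM, tagging the remaining
    # characters as 'm', 'eng' or 'x'.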
    def __cut_internal(self, sentence, HMM=True):
        self.makesure_userdict_loaded()
        sentence = strdecode(sentence)
        blocks = re_han_internal.split(sentence)
        if HMM:
            cut_blk = self.__cut_DAG
        else:
            cut_blk = self.__cut_DAG_NO_HMM

        for blk in blocks:
            if re_han_internal.match(blk):
                for word in cut_blk(blk):
                    yield word
            else:
                tmp = re_skip_internal.split(blk)
                for x in tmp:
                    if re_skip_internal.match(x):
                        yield pair(x, 'x')
                    else:
                        for xx in x:
                            if re_num.match(xx):
                                yield pair(xx, 'm')
                            elif re_eng.match(x):
                                yield pair(xx, 'eng')
                            else:
                                yield pair(xx, 'x')

    def _lcut_internal(self, sentence):
        return list(self.__cut_internal(sentence))

    def _lcut_internal_no_hmm(self, sentence):
        return list(self.__cut_internal(sentence, False))

    def cut(self, sentence, HMM=True):
        for w in self.__cut_internal(sentence, HMM=HMM):
            yield w

    def lcut(self, *args, **kwargs):
        return list(self.cut(*args, **kwargs))


# default Tokenizer instance
dt = POSTokenizer(jieba.dt)

# global functions
initialize = dt.initialize


def _lcut_internal(s):
    return dt._lcut_internal(s)


def _lcut_internal_no_hmm(s):
    return dt._lcut_internal_no_hmm(s)
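
# These thin module-level wrappers delegate to dt so that the parallel branch
# of cut() below can hand them to jieba.pool.map (a plain function pickles
# cleanly for the worker processes).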


def cut(sentence, HMM=True, use_paddle=False):
    """
    Global `cut` function that supports parallel processing.

    Note that this only works with dt; custom POSTokenizer
    instances are not supported.
    """
    is_paddle_installed = check_paddle_install['is_paddle_installed']
    if use_paddle and is_paddle_installed:
        # An empty sentence would raise a core exception inside paddle.
        if sentence is None or sentence == "" or sentence == u"":
            return
        import jieba.lac_small.predict as predict
        sents, tags = predict.get_result(strdecode(sentence))
        for i, sent in enumerate(sents):
            if sent is None or tags[i] is None:
                continue
            yield pair(sent, tags[i])
        return
    global dt
    if jieba.pool is None:
        for w in dt.cut(sentence, HMM=HMM):
            yield w
    else:
        parts = strdecode(sentence).splitlines(True)
        if HMM:
            result = jieba.pool.map(_lcut_internal, parts)
        else:
            result = jieba.pool.map(_lcut_internal_no_hmm, parts)
        for r in result:
            for w in r:
                yield w


def lcut(sentence, HMM=True, use_paddle=False):
    if use_paddle:
        return list(cut(sentence, use_paddle=True))
    return list(cut(sentence, HMM))
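
# Illustrative usage (not part of the module; the sentence below is just an
# example string):
#
#     import jieba.posseg as pseg
#     for word, flag in pseg.cut("我爱北京天安门"):
#         print('%s %s' % (word, flag))
#
# lcut() returns the same output as a list of pair objects; HMM=False skips
# the HMM re-segmentation of runs that are missing from the dictionary.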