utils.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. util tools
  16. """
  17. from __future__ import print_function
  18. import os
  19. import sys
  20. import numpy as np
  21. import paddle.fluid as fluid
  22. import io
  23. def str2bool(v):
  24. """
  25. argparse does not support True or False in python
  26. """
  27. return v.lower() in ("true", "t", "1")
  28. def parse_result(words, crf_decode, dataset):
  29. """ parse result """
  30. offset_list = (crf_decode.lod())[0]
  31. words = np.array(words)
  32. crf_decode = np.array(crf_decode)
  33. batch_size = len(offset_list) - 1
  34. for sent_index in range(batch_size):
  35. begin, end = offset_list[sent_index], offset_list[sent_index + 1]
  36. sent=[]
  37. for id in words[begin:end]:
  38. if dataset.id2word_dict[str(id[0])]=='OOV':
  39. sent.append(' ')
  40. else:
  41. sent.append(dataset.id2word_dict[str(id[0])])
  42. tags = [
  43. dataset.id2label_dict[str(id[0])] for id in crf_decode[begin:end]
  44. ]
  45. sent_out = []
  46. tags_out = []
  47. parital_word = ""
  48. for ind, tag in enumerate(tags):
  49. # for the first word
  50. if parital_word == "":
  51. parital_word = sent[ind]
  52. tags_out.append(tag.split('-')[0])
  53. continue
  54. # for the beginning of word
  55. if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
  56. sent_out.append(parital_word)
  57. tags_out.append(tag.split('-')[0])
  58. parital_word = sent[ind]
  59. continue
  60. parital_word += sent[ind]
  61. # append the last word, except for len(tags)=0
  62. if len(sent_out) < len(tags_out):
  63. sent_out.append(parital_word)
  64. return sent_out,tags_out
  65. def parse_padding_result(words, crf_decode, seq_lens, dataset):
  66. """ parse padding result """
  67. words = np.squeeze(words)
  68. batch_size = len(seq_lens)
  69. batch_out = []
  70. for sent_index in range(batch_size):
  71. sent=[]
  72. for id in words[begin:end]:
  73. if dataset.id2word_dict[str(id[0])]=='OOV':
  74. sent.append(' ')
  75. else:
  76. sent.append(dataset.id2word_dict[str(id[0])])
  77. tags = [
  78. dataset.id2label_dict[str(id)]
  79. for id in crf_decode[sent_index][1:seq_lens[sent_index] - 1]
  80. ]
  81. sent_out = []
  82. tags_out = []
  83. parital_word = ""
  84. for ind, tag in enumerate(tags):
  85. # for the first word
  86. if parital_word == "":
  87. parital_word = sent[ind]
  88. tags_out.append(tag.split('-')[0])
  89. continue
  90. # for the beginning of word
  91. if tag.endswith("-B") or (tag == "O" and tags[ind - 1] != "O"):
  92. sent_out.append(parital_word)
  93. tags_out.append(tag.split('-')[0])
  94. parital_word = sent[ind]
  95. continue
  96. parital_word += sent[ind]
  97. # append the last word, except for len(tags)=0
  98. if len(sent_out) < len(tags_out):
  99. sent_out.append(parital_word)
  100. batch_out.append([sent_out, tags_out])
  101. return batch_out
  102. def init_checkpoint(exe, init_checkpoint_path, main_program):
  103. """
  104. Init CheckPoint
  105. """
  106. assert os.path.exists(
  107. init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
  108. def existed_persitables(var):
  109. """
  110. If existed presitabels
  111. """
  112. if not fluid.io.is_persistable(var):
  113. return False
  114. return os.path.exists(os.path.join(init_checkpoint_path, var.name))
  115. fluid.io.load_vars(
  116. exe,
  117. init_checkpoint_path,
  118. main_program=main_program,
  119. predicate=existed_persitables)