reader_small.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. The file_reader converts raw corpus to input.
  16. """
  17. import os
  18. import __future__
  19. import io
  20. import paddle
  21. import paddle.fluid as fluid
  22. def load_kv_dict(dict_path,
  23. reverse=False,
  24. delimiter="\t",
  25. key_func=None,
  26. value_func=None):
  27. """
  28. Load key-value dict from file
  29. """
  30. result_dict = {}
  31. for line in io.open(dict_path, "r", encoding='utf8'):
  32. terms = line.strip("\n").split(delimiter)
  33. if len(terms) != 2:
  34. continue
  35. if reverse:
  36. value, key = terms
  37. else:
  38. key, value = terms
  39. if key in result_dict:
  40. raise KeyError("key duplicated with [%s]" % (key))
  41. if key_func:
  42. key = key_func(key)
  43. if value_func:
  44. value = value_func(value)
  45. result_dict[key] = value
  46. return result_dict
  47. class Dataset(object):
  48. """data reader"""
  49. def __init__(self):
  50. # read dict
  51. basepath = os.path.abspath(__file__)
  52. folder = os.path.dirname(basepath)
  53. word_dict_path = os.path.join(folder, "word.dic")
  54. label_dict_path = os.path.join(folder, "tag.dic")
  55. self.word2id_dict = load_kv_dict(
  56. word_dict_path, reverse=True, value_func=int)
  57. self.id2word_dict = load_kv_dict(word_dict_path)
  58. self.label2id_dict = load_kv_dict(
  59. label_dict_path, reverse=True, value_func=int)
  60. self.id2label_dict = load_kv_dict(label_dict_path)
  61. @property
  62. def vocab_size(self):
  63. """vocabulary size"""
  64. return max(self.word2id_dict.values()) + 1
  65. @property
  66. def num_labels(self):
  67. """num_labels"""
  68. return max(self.label2id_dict.values()) + 1
  69. def word_to_ids(self, words):
  70. """convert word to word index"""
  71. word_ids = []
  72. for word in words:
  73. if word not in self.word2id_dict:
  74. word = "OOV"
  75. word_id = self.word2id_dict[word]
  76. word_ids.append(word_id)
  77. return word_ids
  78. def label_to_ids(self, labels):
  79. """convert label to label index"""
  80. label_ids = []
  81. for label in labels:
  82. if label not in self.label2id_dict:
  83. label = "O"
  84. label_id = self.label2id_dict[label]
  85. label_ids.append(label_id)
  86. return label_ids
  87. def get_vars(self,str1):
  88. words = str1.strip()
  89. word_ids = self.word_to_ids(words)
  90. return word_ids