predict.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # -*- coding: UTF-8 -*-
  2. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import os
  17. import time
  18. import sys
  19. import paddle.fluid as fluid
  20. import paddle
  21. import jieba.lac_small.utils as utils
  22. import jieba.lac_small.creator as creator
  23. import jieba.lac_small.reader_small as reader_small
  24. import numpy
  25. word_emb_dim=128
  26. grnn_hidden_dim=128
  27. bigru_num=2
  28. use_cuda=False
  29. basepath = os.path.abspath(__file__)
  30. folder = os.path.dirname(basepath)
  31. init_checkpoint = os.path.join(folder, "model_baseline")
  32. batch_size=1
  33. dataset = reader_small.Dataset()
  34. infer_program = fluid.Program()
  35. with fluid.program_guard(infer_program, fluid.default_startup_program()):
  36. with fluid.unique_name.guard():
  37. infer_ret = creator.create_model(dataset.vocab_size, dataset.num_labels, mode='infer')
  38. infer_program = infer_program.clone(for_test=True)
  39. place = fluid.CPUPlace()
  40. exe = fluid.Executor(place)
  41. exe.run(fluid.default_startup_program())
  42. utils.init_checkpoint(exe, init_checkpoint, infer_program)
  43. results = []
  44. def get_sent(str1):
  45. feed_data=dataset.get_vars(str1)
  46. a = numpy.array(feed_data).astype(numpy.int64)
  47. a=a.reshape(-1,1)
  48. c = fluid.create_lod_tensor(a, [[a.shape[0]]], place)
  49. words, crf_decode = exe.run(
  50. infer_program,
  51. fetch_list=[infer_ret['words'], infer_ret['crf_decode']],
  52. feed={"words":c, },
  53. return_numpy=False,
  54. use_program_cache=True)
  55. sents=[]
  56. sent,tag = utils.parse_result(words, crf_decode, dataset)
  57. sents = sents + sent
  58. return sents
  59. def get_result(str1):
  60. feed_data=dataset.get_vars(str1)
  61. a = numpy.array(feed_data).astype(numpy.int64)
  62. a=a.reshape(-1,1)
  63. c = fluid.create_lod_tensor(a, [[a.shape[0]]], place)
  64. words, crf_decode = exe.run(
  65. infer_program,
  66. fetch_list=[infer_ret['words'], infer_ret['crf_decode']],
  67. feed={"words":c, },
  68. return_numpy=False,
  69. use_program_cache=True)
  70. results=[]
  71. results += utils.parse_result(words, crf_decode, dataset)
  72. return results