From 0a2da688d111a45110b3338cd55e6be3d24647a2 Mon Sep 17 00:00:00 2001 From: hahajinbu <12061025@buaa.edu.cn> Date: Mon, 14 Oct 2019 14:50:44 +0800 Subject: [PATCH] Update similarity.py update predict method to support batch prediction. --- similarity.py | 39 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/similarity.py b/similarity.py index 6a77293..0cf12d1 100644 --- a/similarity.py +++ b/similarity.py @@ -103,7 +103,10 @@ def get_test_examples(self, data_dir): return test_data def get_sentence_examples(self, questions): - for index, data in enumerate(questions): + questions = questions[0] #self.input_queue.put([(sentences1, sentences2)]) + print(len(questions)) + for index,data in enumerate(zip(questions[0],questions[1])): +# print(data) guid = 'test-%d' % index text_a = tokenization.convert_to_unicode(str(data[0])) text_b = tokenization.convert_to_unicode(str(data[1])) @@ -301,7 +304,7 @@ def queue_predict_input_fn(self): 'input_ids': (None, self.max_seq_length), 'input_mask': (None, self.max_seq_length), 'segment_ids': (None, self.max_seq_length), - 'label_ids': (1,)}).prefetch(10)) + 'label_ids': (None,)}).prefetch(10)) def convert_examples_to_features(self, examples, label_list, max_seq_length, tokenizer): """Convert a set of `InputExample`s to a list of `InputFeatures`.""" @@ -668,7 +671,37 @@ def predict(self, sentence1, sentence2): sim.train() sim.set_mode(tf.estimator.ModeKeys.EVAL) sim.eval() - # sim.set_mode(tf.estimator.ModeKeys.PREDICT) + + #####预测测试 + sim.set_mode(tf.estimator.ModeKeys.PREDICT) + import time + results_1 = [] + t1 = time.time() + for i in range(1000): + if i % 2 ==0: + x = bs.predict(["你{}好".format(i)],["您{}好".format(i)])[0][1] + else: + x = bs.predict(["你{}好".format(i)],["不{}好".format(i)])[0][1] + results_1.append(x) + t2 = time.time() + print('predict one by one cost: {} seconds.'.format(str(t2 - t1))) + t3 = time.time() + ########=====predict batch============= + sentences_1 = [] + sentences_2 = [] + for i in range(1000): + if i % 2 ==0: + sentences_1.append("你{}好".format(i)) + sentences_2.append("您{}好".format(i)) + else: + sentences_1.append("你{}好".format(i)) + sentences_2.append("不{}好".format(i)) + batch_results_1 = bs.predict(sentences_1,sentences_2) + batch_results_1 = batch_results_1[:,1] + t4 = time.time() + print('predict batch cost: {} seconds.'.format(str(t4 - t3))) + from scipy.stats import pearsonr + print(pearsonr(results_1,batch_results_1)) ###(1.0,0.0) # while True: # sentence1 = input('sentence1: ') # sentence2 = input('sentence2: ')