config/eval.py: 34 changes (24 additions, 10 deletions)

@@ -1,5 +1,8 @@
 
 import numpy as np
+from typing import List, Tuple, Dict
 
+from collections import defaultdict, Counter
+
 class Span:
 
@@ -62,11 +65,16 @@ def evaluate(insts):
     return [precision, recall, fscore]
 
 
-def evaluate_num(batch_insts, batch_pred_ids, batch_gold_ids, word_seq_lens, idx2label):
+def evaluate_num(batch_insts, batch_pred_ids, batch_gold_ids, word_seq_lens, idx2label) -> Tuple[Dict, Dict, Dict]:
 
-    p = 0
-    total_entity = 0
-    total_predict = 0
+    # p = 0
+    # total_entity = 0
+    # total_predict = 0
+
+    batch_p_dict = defaultdict(int)
+    batch_total_entity_dict = defaultdict(int)
+    batch_total_predict_dict = defaultdict(int)
+
     word_seq_lens = word_seq_lens.tolist()
     for idx in range(len(batch_pred_ids)):
         length = word_seq_lens[idx]
@@ -85,24 +93,30 @@ def evaluate_num(batch_insts, batch_pred_ids, batch_gold_ids, word_seq_lens, idx
             if output[i].startswith("E-"):
                 end = i
                 output_spans.add(Span(start, end, output[i][2:]))
+                batch_total_entity_dict[output[i][2:]] += 1
             if output[i].startswith("S-"):
                 output_spans.add(Span(i, i, output[i][2:]))
+                batch_total_entity_dict[output[i][2:]] += 1
         predict_spans = set()
         for i in range(len(prediction)):
             if prediction[i].startswith("B-"):
                 start = i
             if prediction[i].startswith("E-"):
                 end = i
                 predict_spans.add(Span(start, end, prediction[i][2:]))
+                batch_total_predict_dict[prediction[i][2:]] += 1
             if prediction[i].startswith("S-"):
                 predict_spans.add(Span(i, i, prediction[i][2:]))
+                batch_total_predict_dict[prediction[i][2:]] += 1
-
-        total_entity += len(output_spans)
-        total_predict += len(predict_spans)
-        p += len(predict_spans.intersection(output_spans))
-
+        # total_entity += len(output_spans)
+        # total_predict += len(predict_spans)
+        # p += len(predict_spans.intersection(output_spans))
+        correct_spans = predict_spans.intersection(output_spans)
+        for span in correct_spans:
+            batch_p_dict[span.type] += 1
     # precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
     # recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
     # fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
 
-    return np.asarray([p, total_predict, total_entity], dtype=int)
+    # return np.asarray([p, total_predict, total_entity], dtype=int)
+    return Counter(batch_p_dict), Counter(batch_total_predict_dict), Counter(batch_total_entity_dict)
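
This switch from plain ints to per-type tallies is the heart of the change: evaluate_num now returns three Counter objects (correct, predicted, and gold spans per entity type) instead of a 3-element array. Counter is used rather than the internal defaultdicts because Counter supports element-wise addition with + and +=, which is what the updated evaluate() in main.py relies on to pool batches. Below is a minimal, self-contained sketch of the same per-type counting on toy BIOES tag sequences; the labels (PER, LOC) and data are invented for illustration and are not from the repository.

    # Sketch only: toy BIOES sequences with invented labels (PER, LOC),
    # not the repository's real data or API.
    from collections import Counter

    def spans_of(tags):
        """Collect (start, end, type) spans from a BIOES tag sequence."""
        spans, start = set(), 0
        for i, tag in enumerate(tags):
            if tag.startswith("B-"):
                start = i
            if tag.startswith("E-"):
                spans.add((start, i, tag[2:]))
            if tag.startswith("S-"):
                spans.add((i, i, tag[2:]))
        return spans

    gold = ["B-PER", "E-PER", "O", "S-LOC"]
    pred = ["B-PER", "E-PER", "O", "S-PER"]
    gold_spans, pred_spans = spans_of(gold), spans_of(pred)

    p = Counter(t for _, _, t in gold_spans & pred_spans)  # correct, per type
    total_pred = Counter(t for _, _, t in pred_spans)      # predicted, per type
    total_gold = Counter(t for _, _, t in gold_spans)      # gold, per type
    assert p == Counter({"PER": 1})
    assert total_pred == Counter({"PER": 2})
    assert total_gold == Counter({"PER": 1, "LOC": 1})

    # Counter addition is what lets evaluate() in main.py fold batches together:
    assert p + Counter({"PER": 2, "LOC": 1}) == Counter({"PER": 3, "LOC": 1})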
main.py: 27 changes (21 additions, 6 deletions)

@@ -4,6 +4,7 @@
 import numpy as np
 from config.reader import Reader
 from config import eval
+from config.eval import evaluate_num
 from config.config import Config, ContextEmb, DepModelType
 import time
 from model.lstmcrf import NNCRF
@@ -15,6 +16,7 @@
 from common.instance import Instance
 from termcolor import colored
 import adabound
+from collections import Counter
 
 
 def setSeed(opt, seed):
@@ -158,6 +160,7 @@ def learn_from_insts(config:Config, epoch: int, train_insts, dev_insts, test_ins
         if i + 1 >= config.eval_epoch:
             model.eval()
             dev_metrics = evaluate(config, model, dev_batches, "dev", dev_insts)
+            print()
             test_metrics = evaluate(config, model, test_batches, "test", test_insts)
             if dev_metrics[2] > best_dev[0]:
                 print("saving the best model...")
@@ -181,18 +184,31 @@ def learn_from_insts(config:Config, epoch: int, train_insts, dev_insts, test_ins
 
 def evaluate(config:Config, model: NNCRF, batch_insts_ids, name:str, insts: List[Instance]):
     ## evaluation
-    metrics = np.asarray([0, 0, 0], dtype=int)
+    # metrics = np.asarray([0, 0, 0], dtype=int)
+    p_dict, total_predict_dict, total_entity_dict = Counter(), Counter(), Counter()
     batch_id = 0
     batch_size = config.batch_size
     for batch in batch_insts_ids:
         one_batch_insts = insts[batch_id * batch_size:(batch_id + 1) * batch_size]
         sorted_batch_insts = sorted(one_batch_insts, key=lambda inst: len(inst.input.words), reverse=True)
         batch_max_scores, batch_max_ids = model.decode(batch)
-        metrics += eval.evaluate_num(sorted_batch_insts, batch_max_ids, batch[-2], batch[1], config.idx2labels)
+        # metrics += eval.evaluate_num(sorted_batch_insts, batch_max_ids, batch[-2], batch[1], config.idx2labels)
+        batch_p, batch_predict, batch_total = eval.evaluate_num(one_batch_insts, batch_max_ids, batch[-2], batch[1], config.idx2labels)
+        p_dict += batch_p
+        total_predict_dict += batch_predict
+        total_entity_dict += batch_total
         batch_id += 1
-    p, total_predict, total_entity = metrics[0], metrics[1], metrics[2]
-    precision = p * 1.0 / total_predict * 100 if total_predict != 0 else 0
-    recall = p * 1.0 / total_entity * 100 if total_entity != 0 else 0
+    for key in total_entity_dict:
+        precision = p_dict[key] * 1.0 / total_predict_dict[key] * 100 if total_predict_dict[key] != 0 else 0
+        recall = p_dict[key] * 1.0 / total_entity_dict[key] * 100 if total_entity_dict[key] != 0 else 0
+        fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
+        print("[%s] Prec.: %.2f, Rec.: %.2f, F1: %.2f" % (key, precision, recall, fscore))
+    total_p = sum(list(p_dict.values()))
+    total_entity = sum(list(total_entity_dict.values()))
+    total_predict = sum(list(total_predict_dict.values()))
+    # total_predict, total_entity = metrics[0], metrics[1], metrics[2]
+    precision = total_p * 1.0 / total_predict * 100 if total_predict != 0 else 0
+    recall = total_p * 1.0 / total_entity * 100 if total_entity != 0 else 0
     fscore = 2.0 * precision * recall / (precision + recall) if precision != 0 or recall != 0 else 0
     print("[%s set] Precision: %.2f, Recall: %.2f, F1: %.2f" % (name, precision, recall,fscore), flush=True)
     return [precision, recall, fscore]
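
After the per-label lines, evaluate() prints overall figures computed by pooling the per-type counts first and then computing precision and recall once over the pooled totals; that is a micro-average, and it reproduces what the old scalar metrics gave. A small sketch of the arithmetic with invented counts (label names and numbers are illustrative only):

    # Sketch only: invented counts for two hypothetical labels, showing the
    # micro-average arithmetic the updated evaluate() performs.
    from collections import Counter

    p_dict = Counter({"PER": 8, "LOC": 3})               # correctly predicted spans
    total_predict_dict = Counter({"PER": 10, "LOC": 6})  # all predicted spans
    total_entity_dict = Counter({"PER": 9, "LOC": 5})    # all gold spans

    def prf(p, predicted, gold):
        precision = p * 100.0 / predicted if predicted else 0
        recall = p * 100.0 / gold if gold else 0
        f1 = 2 * precision * recall / (precision + recall) if precision or recall else 0
        return precision, recall, f1

    for key in total_entity_dict:  # one line per label, as in the new code
        print("[%s] Prec.: %.2f, Rec.: %.2f, F1: %.2f"
              % ((key,) + prf(p_dict[key], total_predict_dict[key], total_entity_dict[key])))

    # Micro-average: pool the counts across labels, then compute P/R/F1 once.
    overall = prf(sum(p_dict.values()), sum(total_predict_dict.values()),
                  sum(total_entity_dict.values()))
    print("[overall] Precision: %.2f, Recall: %.2f, F1: %.2f" % overall)

Because the counts are pooled before dividing, every entity carries equal weight, so frequent labels dominate the overall F1; the per-label loop is what surfaces the rarer types.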
@@ -239,7 +255,6 @@ def write_results(filename:str, insts):
             dep_labels = inst.input.dep_labels
             output = inst.output
             prediction = inst.prediction
-            assert len(output) == len(prediction)
             f.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(i, words[i], tags[i], heads[i], dep_labels[i], output[i], prediction[i]))
         f.write("\n")
     f.close()