batch_creator.py
forked from Legal-NLP-EkStep/rhetorical-role-baseline
import numpy as np
import torch
from allennlp.common.util import pad_sequence_to_length

import bucketing
class BatchCreator:
    def __init__(self, dataset, tokenizer, labels, batch_sentence_size, max_seq_length):
        """dataset: iterable over documents.
        tokenizer: WordPiece tokenizer of BERT. If None, sentences are assumed to be already tokenized.
        labels: possible labels of the sentences.
        batch_sentence_size: sentence budget per batch (passed to bucketing.bucket_records).
        max_seq_length: max number of tokens for each sentence. Only used if a tokenizer is provided.
        """
        self.dataset = dataset
        self.labels = labels
        self.batch_sentence_size = batch_sentence_size
        self.max_sequence_length = max_seq_length
        self.tokenizer = tokenizer
        self.batches = None
    def get_batches(self, task_name=None):
        # Materialize every batch as a dict of tensors; optionally tag each
        # batch with a task name.
        self.build_batches()
        result = []
        for b in self.batches:
            batch = self.batch_to_tensor(b)
            if task_name is not None:
                batch["task"] = task_name
            result.append(batch)
        return result
    def build_batches(self):
        if self.batches is None:
            # Wrap each document in a bucketing.Record keyed by its sentence
            # count so documents of similar length land in the same batch.
            def wrap_document(doc):
                return bucketing.Record(doc.get_sentence_count(), doc)

            mapped_ds = map(wrap_document, self.dataset)
            self.batches = bucketing.bucket_records(mapped_ds, self.batch_sentence_size)
        return len(self.batches)

    def get_batches_count(self):
        return self.build_batches()
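    # Example (illustrative, assuming bucket_records packs documents until the
    # sentence budget is reached): with batch_sentence_size=4, documents with
    # sentence counts [3, 1, 2, 2] could end up grouped as [[3, 1], [2, 2]].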
    def batch_to_tensor(self, b):
        # dictionary of arrays, one entry per feature
        tensors_dict_arrays = b.to_tensor(self.document_to_sequence_example, merge_records)
        # convert to a dictionary of tensors, padding where necessary
        result = {}
        for k, v in tensors_dict_arrays.items():
            if k in ["input_ids", "attention_mask"]:
                # determine the max sentence length in the batch
                max_sentence_len = max(
                    (len(sentence) for doc in v for sentence in doc), default=-1
                )
                # pad every sentence to that length
                for doc in v:
                    for i, sentence in enumerate(doc):
                        doc[i] = pad_sequence_to_length(sentence, desired_length=max_sentence_len)
            if k != "doc_name":
                result[k] = torch.tensor(v)
            else:
                result[k] = v
        return result
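    # Example (illustrative): for a batch of 2 documents padded to 3 sentences
    # each, with the longest sentence at 5 tokens, the returned dict holds
    #   input_ids      -> tensor of shape (2, 3, 5)
    #   attention_mask -> tensor of shape (2, 3, 5)
    #   label_ids      -> tensor of shape (2, 3)
    #   sentence_mask  -> tensor of shape (2, 3)
    #   doc_name       -> plain list of the 2 document names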
    def document_to_sequence_example(self, document, sentence_padding_len):
        sentences = list(document.data.sentences)
        labels = list(document.data.labels)
        # pad the number of sentences up to the bucket's sentence count
        for _ in range(len(document.data.labels), sentence_padding_len):
            sentences.append("")
            labels.append("mask")
        token_ids = []
        attention_masks = []
        label_ids = []
        for sentence, label in zip(sentences, labels):
            if self.tokenizer is None:
                # sentence is already tokenized: either a list of token ids
                # or a whitespace-separated string of token ids
                if isinstance(sentence, list):
                    tok_ids = sentence
                else:
                    tok_ids = [int(t) for t in sentence.split()]
            else:
                # truncation=True makes max_length actually truncate in
                # current transformers versions (otherwise it only warns)
                tok_ids = self.tokenizer.encode(sentence, add_special_tokens=True,
                                                max_length=self.max_sequence_length,
                                                truncation=True)
            attention_mask = [1] * len(tok_ids)
            # map the label string to its index
            label_id = self.labels.index(label)
            token_ids.append(tok_ids)
            attention_masks.append(attention_mask)
            label_ids.append(label_id)
        return {
            "sentence_mask": pad_sequence_to_length([1] * document.length, desired_length=sentence_padding_len),
            "input_ids": token_ids,
            "attention_mask": attention_masks,
            "label_ids": label_ids,
            "doc_name": document.data.doc_name,
        }
def merge_records(merged, r):
    # Fold a single per-document dict into the running batch dict:
    # each key accumulates a list with one entry per document.
    if merged is None:
        merged = {k: [] for k in r}
    for k in r:
        merged[k].append(r[k])
    return merged
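# Example (illustrative): folding two per-document dicts into one batch dict.
#   merged = merge_records(None, {"label_ids": [0, 1]})   # {"label_ids": [[0, 1]]}
#   merged = merge_records(merged, {"label_ids": [2]})    # {"label_ids": [[0, 1], [2]]}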
def one_hot(num, v):
    # Return a one-hot vector of length `num` with position `v` set.
    r = np.zeros(num, dtype=int)
    r[v] = 1
    return r
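# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline). `_Doc` is a
# hypothetical stand-in for the document interface this module assumes:
# get_sentence_count(), .sentences, .labels and .doc_name. Sentences are given
# as whitespace-separated token-id strings so that tokenizer=None works, and
# the label names are example values only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _Doc:
        def __init__(self, doc_name, sentences, labels):
            self.doc_name = doc_name
            self.sentences = sentences
            self.labels = labels

        def get_sentence_count(self):
            return len(self.sentences)

    docs = [
        _Doc("doc_a", ["101 2023 102", "101 2003 1996 102"], ["PREAMBLE", "FAC"]),
        _Doc("doc_b", ["101 2028 102"], ["FAC"]),
    ]
    # "mask" must be in the label set: it is assigned to padding sentences.
    creator = BatchCreator(docs, tokenizer=None,
                           labels=["mask", "PREAMBLE", "FAC"],
                           batch_sentence_size=4, max_seq_length=32)
    for batch in creator.get_batches(task_name="rhetorical_role"):
        print(batch["doc_name"], batch["input_ids"].shape)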