-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseq_scripts.py
More file actions
executable file
·137 lines (125 loc) · 5.62 KB
/
Copy pathseq_scripts.py
File metadata and controls
executable file
·137 lines (125 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from evaluation.slr_eval.wer_calculation import evaluate
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
def seq_train(loader, model, optimizer, device, epoch_idx, recoder, loss_weights=None):
    """Run one mixed-precision training epoch and return the per-batch losses.

    Args:
        loader: Sized iterable of batches. Each batch is indexed as
            ``data[0]`` (video), ``data[1]`` (video lengths), ``data[2]``
            (labels) and ``data[3]`` (label lengths).
        model: Object providing ``train()``, a forward call
            ``model(vid, vid_lgt, label=..., label_lgt=...)`` and
            ``criterion_calculation(ret_dict, label, label_lgt)`` that
            returns ``(loss, loss_dict)``.
        optimizer: Wrapper exposing ``optimizer`` (with ``param_groups``),
            ``scheduler`` and ``zero_grad()``.
        device: Object exposing ``data_to_device``.
        epoch_idx: Epoch number; only used in log messages.
        recoder: Logger exposing ``log_interval`` and ``print_log``.
        loss_weights: Optional mapping whose keys name the loss components
            that are always reported. ``None`` means no component is
            pre-declared; components found in ``loss_dict`` are still
            accumulated and reported.

    Returns:
        list: ``loss.item()`` of every batch that was used for an update.
        Batches with a non-finite loss are skipped and not included.
    """
    model.train()
    # Accept the documented default: no pre-declared loss components.
    loss_names = list(loss_weights) if loss_weights is not None else []
    loss_value = []
    total_loss_dict = dict.fromkeys(loss_names, 0)  # running sum per loss type
    accumulated = 0  # batches summed into total_loss_dict since the last report
    # Learning rates are read once, before the scheduler steps for this epoch.
    clr = [group['lr'] for group in optimizer.optimizer.param_groups]
    scaler = GradScaler()
    for batch_idx, data in enumerate(loader):
        vid = device.data_to_device(data[0])
        vid_lgt = device.data_to_device(data[1])
        label = device.data_to_device(data[2])
        label_lgt = device.data_to_device(data[3])
        optimizer.zero_grad()
        with autocast():
            ret_dict = model(vid, vid_lgt, label=label, label_lgt=label_lgt)
            loss, loss_dict = model.criterion_calculation(ret_dict, label, label_lgt)
        # Read the scalar once instead of calling .item() repeatedly.
        loss_scalar = loss.item()
        if np.isinf(loss_scalar) or np.isnan(loss_scalar):
            # Skip the update entirely so a bad batch cannot corrupt the weights.
            print('loss is nan')
            print(str(data[1])+' frames')
            print(str(data[3])+' glosses')
            continue
        scaler.scale(loss).backward()
        scaler.step(optimizer.optimizer)
        scaler.update()
        loss_value.append(loss_scalar)
        for item, value in loss_dict.items():
            # .get() tolerates components that were not listed in loss_weights.
            total_loss_dict[item] = total_loss_dict.get(item, 0) + value
        accumulated += 1
        if batch_idx % recoder.log_interval == 0:
            recoder.print_log(
                '\tEpoch: {}, Batch({}/{}) done. Loss: {:.8f} lr:{:.6f}'
                .format(epoch_idx, batch_idx, len(loader), loss_scalar, clr[0]))
            for item, value in total_loss_dict.items():
                # Average over the batches actually accumulated; the window is
                # shorter than log_interval at batch 0 and after skipped batches.
                recoder.print_log(f'\t Mean {item} loss: {value/accumulated:.5f}')
            total_loss_dict = dict.fromkeys(loss_names, 0)
            accumulated = 0
        del ret_dict
        del loss
    optimizer.scheduler.step()
    recoder.print_log('\tMean training loss: {:.10f}.'.format(np.mean(loss_value)))
    return loss_value
def seq_eval(cfg, loader, model, device, mode, epoch, work_dir, recoder):
    """Decode every batch of ``loader``, write a CTM hypothesis file and score it.

    Args:
        cfg: Configuration object; ``cfg.dataset_info['evaluation_dir']`` and
            ``cfg.dataset_info['evaluation_prefix']`` are forwarded to
            ``evaluate``.
        loader: Iterable of batches indexed as ``data[0]`` (video),
            ``data[1]`` (video lengths), ``data[2]`` (labels), ``data[3]``
            (label lengths) and ``data[-1]`` (per-sample info strings).
        model: Object providing ``eval()`` and a forward call whose result
            contains the key ``'recognized_sents'``.
        device: Object exposing ``data_to_device``.
        mode: Split name; used in the hypothesis file name and the log path.
        epoch: Epoch number; used in the result directory name and the log.
        work_dir: Output directory prefix. It is concatenated directly with
            the hypothesis file name, so it presumably ends with a path
            separator (confirm against the caller).
        recoder: Logger exposing ``record_timer`` and ``print_log``.

    Returns:
        float: The number parsed from between ``=`` and ``%`` in the string
        returned by ``evaluate``, or ``100.0`` when writing the hypothesis
        file or running the evaluation raises an exception.
    """
    model.eval()
    total_sent = []
    total_info = []
    for data in tqdm(loader):
        recoder.record_timer("device")
        vid = device.data_to_device(data[0])
        vid_lgt = device.data_to_device(data[1])
        label = device.data_to_device(data[2])
        label_lgt = device.data_to_device(data[3])
        with torch.no_grad():
            ret_dict = model(vid, vid_lgt, label=label, label_lgt=label_lgt)
        # Keep only the part of each info string before the first '|'.
        total_info += [file_name.split("|")[0] for file_name in data[-1]]
        total_sent += ret_dict['recognized_sents']

    hypothesis_file = "output-hypothesis-{}.ctm".format(mode)
    try:
        write2file(work_dir + hypothesis_file, total_info, total_sent)
        ret = evaluate(prefix=work_dir, mode=mode, output_file=hypothesis_file,
                       evaluate_dir=cfg.dataset_info['evaluation_dir'],
                       evaluate_prefix=cfg.dataset_info['evaluation_prefix'],
                       output_dir="epoch_{}_result/".format(epoch))
    except Exception as err:
        # Best effort: a failed evaluation is reported as the worst score
        # instead of aborting the run. Catching Exception (not a bare except)
        # lets KeyboardInterrupt and SystemExit propagate.
        print("Unexpected error:", sys.exc_info()[0], err)
        return 100.0
    recoder.print_log("Epoch {}, {} {}".format(epoch, mode, ret),
                      '{}/{}.txt'.format(work_dir, mode))
    return float(ret.split("=")[1].split("%")[0])
def seq_feature_generation(loader, model, device, mode, work_dir, recoder):
    """Dump per-sample frame-wise features to disk and link them under ./features.

    Features are written to ``<work_dir><mode>/`` as ``<name>_features.npy``
    and a symlink ``./features/<mode>`` is pointed at that directory.

    The function returns early without running the model when
    ``./features/<mode>`` is already a link to an absolute path containing
    ``work_dir[1:]``, or when the source directory already holds exactly
    ``len(loader.dataset)`` entries (in which case only the link is created).

    Args:
        loader: Iterable of batches with a ``dataset`` attribute supporting
            ``len()``. Batches are indexed as ``data[0]`` (video), ``data[1]``
            (video lengths), ``data[2]`` (labels), ``data[3]`` (label lengths)
            and ``data[-1]`` (per-sample info strings).
        model: Object providing ``eval()`` and a forward call
            ``model(vid, vid_lgt)`` whose result contains
            ``'framewise_features'``.
        device: Object exposing ``data_to_device``.
        mode: Split name; used as the directory and link name.
        work_dir: Directory prefix, concatenated directly with ``mode``.
        recoder: Logger exposing ``record_timer``.

    Raises:
        AssertionError: If the label lengths of a batch do not add up to the
            number of labels in that batch.
    """
    model.eval()
    src_path = os.path.abspath(f"{work_dir}{mode}")
    tgt_path = os.path.abspath(f"./features/{mode}")
    # os.symlink fails with FileNotFoundError when the link's parent directory
    # is missing, so make sure ./features exists before any link is created.
    os.makedirs(os.path.dirname(tgt_path), exist_ok=True)

    if os.path.islink(tgt_path):
        curr_path = os.readlink(tgt_path)
        if work_dir[1:] in curr_path and os.path.isabs(curr_path):
            return
        # The existing link points somewhere else: drop it and regenerate.
        os.unlink(tgt_path)
    elif os.path.exists(src_path) and len(loader.dataset) == len(os.listdir(src_path)):
        # Entry count matches the dataset size; only the link is missing.
        os.symlink(src_path, tgt_path)
        return

    # Create the output directory once instead of re-checking on every batch.
    os.makedirs(src_path, exist_ok=True)
    for batch_idx, data in tqdm(enumerate(loader)):
        recoder.record_timer("device")
        vid = device.data_to_device(data[0])
        vid_lgt = device.data_to_device(data[1])
        with torch.no_grad():
            ret_dict = model(vid, vid_lgt)
        # data[2] is sliced by running offsets built from data[3], i.e. it is
        # treated as the labels of all samples in the batch laid end to end.
        start = 0
        end = 0  # stays defined even if the batch holds no samples
        for sample_idx in range(len(vid)):
            end = start + data[3][sample_idx]
            filename = f"{src_path}/{data[-1][sample_idx].split('|')[0]}_features.npy"
            save_file = {
                "label": data[2][start:end],
                # Cut the second axis to this sample's length, then transpose.
                "features": ret_dict['framewise_features'][sample_idx][:, :vid_lgt[sample_idx]].T.cpu().detach(),
            }
            np.save(filename, save_file)
            start = end
        # Every label of the batch must have been assigned to some sample.
        assert end == len(data[2]), "label lengths do not cover all labels of the batch"
    # Publish the link only after every batch has been written.
    os.symlink(src_path, tgt_path)
def write2file(path, info, output):
    """Write recognized sentences to ``path``, one line per word.

    Each line has the form ``"<info> 1 <start> <end> <word>"`` where
    ``start``/``end`` are ``word_idx / 100`` and ``(word_idx + 1) / 100``
    printed with two decimals.

    Args:
        path: Destination file; created or truncated.
        info: Sequence of per-sample identifiers, indexed in parallel with
            ``output``.
        output: Sequence of sentences; each sentence is a sequence of
            indexable items whose element ``[0]`` is written as the word.
    """
    # The context manager guarantees the file is flushed and closed, even if
    # formatting one of the entries raises.
    with open(path, "w") as ctm_file:
        ctm_file.writelines(
            "{} 1 {:.2f} {:.2f} {}\n".format(info[sample_idx],
                                             word_idx * 1.0 / 100,
                                             (word_idx + 1) * 1.0 / 100,
                                             word[0])
            for sample_idx, sample in enumerate(output)
            for word_idx, word in enumerate(sample)
        )