-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathpltModelTimegap.py
More file actions
107 lines (100 loc) · 3.29 KB
/
pltModelTimegap.py
File metadata and controls
107 lines (100 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Read result.txt and plot a time-gap vs. loss figure for each model found in it.
"""
# When running headless (e.g. over SSH) a non-GUI backend is required; plt.switch_backend('agg') below selects one (matplotlib.use('Agg') is an equivalent alternative).
import matplotlib
# matplotlib.use('Agg')
import re
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import argparse
from constants import RES_DIR
def getList(filename):
    """Parse a result log and collect the per-run metrics it records.

    Each line has its ``':'`` characters replaced by spaces and is then split
    on single spaces. The second token selects the metric and the fourth
    token holds its value:

    * ``model``      -> model name (str)
    * ``train``      -> train error (float)
    * ``validation`` -> validation error (float)
    * ``test``       -> test error (float)
    * ``gap``        -> time gap (int)

    Lines with fewer than two tokens, or whose second token is not one of the
    keywords above, are ignored.

    Args:
        filename: Path of the result text file to read.

    Returns:
        A tuple ``(model, trainError, valError, testError, timeGap)`` of
        lists, each holding its values in file order.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a metric value cannot be converted to float/int.
        IndexError: if a recognised line has fewer than four tokens.
    """
    model = []
    trainError = []
    valError = []
    testError = []
    timeGap = []
    # keyword (2nd token) -> (destination list, converter for the 4th token)
    handlers = {
        "model": (model, str),
        "train": (trainError, float),
        "validation": (valError, float),
        "test": (testError, float),
        "gap": (timeGap, int),
    }
    # 'with' guarantees the file is closed even if a conversion raises.
    with open(filename) as f:
        for line in f:
            # Presumably lines look like "<word> <key>: <value>"; after the
            # ':' -> ' ' replacement the value lands at token index 3.
            fields = re.split(" ", line.strip('\n').replace(':', ' '))
            # skip lines that cannot carry a keyword
            if len(fields) < 2:
                continue
            handler = handlers.get(fields[1])
            if handler is None:
                continue
            dest, convert = handler
            dest.append(convert(fields[3]))
    return model, trainError, valError, testError, timeGap
def find_repeat(source, elmt):
    """Return, in ascending order, every position at which ``elmt`` occurs in ``source``.

    The search is driven by repeated ``source.index`` calls, each starting
    one position past the previous hit. An empty list means ``elmt`` was not
    found at all.
    """
    hits = []
    cursor, stop = 0, len(source)
    while cursor < stop:
        try:
            cursor = source.index(elmt, cursor, stop)
        except ValueError:
            # no further occurrence in the remaining range
            return hits
        hits.append(cursor)
        cursor += 1
    return hits
def pltFigure(modelname, matrix, idx, outdir='results', plot_all=False):
    """Plot loss against time gap for one model and save the figure as a PNG.

    Args:
        modelname: Model name, used in the title and the output file name.
        matrix: ``[timeGaps, trainErrors, testErrors, valErrors]``, four
            equally long sequences in that order (as ``main`` builds it).
        idx: matplotlib figure number to draw into.
        outdir: Directory that receives the PNG; created if it is missing.
            Defaults to ``'results'``, the previously hard-coded location.
        plot_all: If True, plot train, test and validation loss; otherwise
            (the default) plot only the test loss.

    The figure is written to ``<outdir>/<modelname>_loss_comparation.png``
    and then closed, so repeated calls do not accumulate open figures.
    """
    import os

    styles = ['r-', 'b-', 'g-', 'p-']
    labels = ['train loss', 'test loss', 'val loss']
    # Draw points in increasing time-gap order so the line does not zig-zag
    # when the log lists runs out of order (stable: ties keep file order).
    order = sorted(range(len(matrix[0])), key=lambda k: matrix[0][k])
    xs = [matrix[0][k] for k in order]
    # rows 1..3 are train/test/val; labels/styles are indexed by row - 1
    rows = [1, 2, 3] if plot_all else [2]
    fig = plt.figure(idx)
    for row in rows:
        ys = [matrix[row][k] for k in order]
        plt.plot(xs, ys, styles[row - 1], label=labels[row - 1])
    plt.title(modelname + " - timeGap - loss")
    plt.xlabel("Time gap")
    plt.ylabel("loss")
    plt.legend(loc='upper right')
    # savefig does not create missing directories.
    os.makedirs(outdir, exist_ok=True)
    plt.savefig(os.path.join(outdir, modelname + '_' + 'loss_comparation' + '.png'))
    plt.close(fig)
def main(filename):
    """Read ``filename`` and save one time-gap/loss figure per model.

    Runs are grouped by model name in order of first appearance in the
    file, so figure numbering is deterministic from run to run (iterating a
    ``set`` of names was not, due to string hash randomization).

    Args:
        filename: Path of the result text file passed to ``getList``.

    Raises:
        IndexError: if the file lists more models than values for some metric
            (the per-metric lists are indexed by the model's position).
    """
    models, trainErrors, valErrors, testErrors, timeGaps = getList(filename)
    # model name -> positions of its runs; dicts preserve insertion order
    positions = {}
    for pos, name in enumerate(models):
        positions.setdefault(name, []).append(pos)
    for fig_idx, (name, idxs) in enumerate(positions.items()):
        # row order expected by pltFigure: gap, train, test, validation
        resMatrix = [
            [timeGaps[k] for k in idxs],
            [trainErrors[k] for k in idxs],
            [testErrors[k] for k in idxs],
            [valErrors[k] for k in idxs],
        ]
        pltFigure(name, resMatrix, fig_idx)
if __name__ == '__main__':
    # Command-line entry point: parse the result file path and plot it.
    parser = argparse.ArgumentParser(
        description='Plot time gap vs. loss for each model found in a result file')
    parser.add_argument('-f', '--file', help='Input file', default="result.txt", type=str)
    args = parser.parse_args()
    print("start ploting ...")
    main(args.file)
    print("finished ploting !")