-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
executable file
·60 lines (50 loc) · 1.68 KB
/
Copy pathevaluate.py
File metadata and controls
executable file
·60 lines (50 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#!/usr/bin/env python3
#
# Created on Mon Apr 12 2021
#
# Arthur Lang
# evaluate.py
#
import sys
import numpy
import matplotlib.pyplot as plt
import seaborn
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from src.Model import Model
from src.dataset import load
def checkArg(av) -> bool:
    """Tell whether the command line holds the expected number of arguments.

    Args:
        av: Argument vector; the caller in this file passes ``sys.argv``.

    Returns:
        True when ``av`` contains exactly two entries, False otherwise.
    """
    return len(av) == 2
def displayConfusionMatrix(matrix):
    """Print a confusion matrix and show it as an annotated heatmap.

    Args:
        matrix: 2-D array-like of confusion-matrix cells. The callers in this
            file pass the output of ``confusion_matrix`` or an element-wise
            sum of several of them.
    """
    print(matrix)
    # seaborn annotates cells with the ".2g" format by default, which renders
    # counts of 100 or more in scientific notation (e.g. 1.2e+02). Use plain
    # integers when the cells are integer counts and keep the default format
    # for any other dtype.
    isCount = numpy.issubdtype(numpy.asarray(matrix).dtype, numpy.integer)
    seaborn.set(font_scale=1.4)  # for label size
    seaborn.heatmap(
        matrix,
        annot=True,
        fmt="d" if isCount else ".2g",
        annot_kws={"size": 16},  # font size of the cell annotations
    )
    plt.show()
def crossValidation(data, target, model, split_size=5):
    """Evaluate ``model`` with k-fold cross validation.

    For every fold the model is trained on the training part, asked to predict
    the held-out part, and one confusion matrix is computed from that
    prediction.

    Args:
        data: Samples, indexable with an array of integer indices (the caller
            in this file passes a numpy array).
        target: Labels aligned with ``data``, indexable the same way.
        model: Object exposing ``trainWithoutValidation(data, target)`` and
            ``predict(data)``.
        split_size: Number of folds handed to ``KFold``.

    Returns:
        A list with one confusion matrix per fold. Every matrix is built over
        the same label set, so the list can be summed element-wise.
    """
    # Fix the label set from the full target. Without it, a fold whose
    # held-out part lacks a class yields a smaller matrix, and the per-fold
    # matrices can no longer be summed by the caller. Predicted labels that
    # never occur in ``target`` are left out of the matrices.
    classes = numpy.unique(target)
    kf = KFold(n_splits=split_size)
    results = []
    for trainIdx, valIdx in kf.split(data, target):
        # NOTE(review): the same model instance is trained again on every
        # fold; whether trainWithoutValidation restarts from scratch is not
        # visible from here — confirm, otherwise folds are not independent.
        model.trainWithoutValidation(data[trainIdx], target[trainIdx])
        predict = model.predict(data[valIdx])
        results.append(confusion_matrix(target[valIdx], predict, labels=classes))
    return results
def evaluate():
    """Load a model and the dataset, then display the model's confusion matrix.

    The model path is read from ``sys.argv[1]``. The user is then asked
    whether to run a cross validation or to predict on the loaded dataset
    directly; in both cases the resulting confusion matrix is displayed.

    Returns:
        -1 when the command-line arguments or the answer to the prompt are
        invalid, None otherwise.
    """
    av = sys.argv
    if not checkArg(av):
        print(
            "Error: invalid number of arguments. Please specified a valid path of a model.",
            file=sys.stderr,
        )
        return -1
    model = Model(modelPath=av[1])
    data, labels = load(preprocess=True)
    # Accept the answer regardless of case and surrounding whitespace.
    val = input("Run a cross validation ? [Y]es, [N]o\t").strip().upper()
    if val in ("Y", "YES"):
        result = crossValidation(numpy.array(data), numpy.array(labels), model)
        # Sum the per-fold matrices cell by cell into a single matrix.
        displayConfusionMatrix(numpy.sum(result, axis=0))
    elif val in ("N", "NO"):
        prediction = model.predict(data)
        displayConfusionMatrix(confusion_matrix(labels, prediction))
    else:
        # Previously any other answer made the script exit without a word.
        print("Error: invalid answer. Please answer Y or N.", file=sys.stderr)
        return -1
    return None
if __name__ == "__main__":
    # evaluate() returns -1 on error and None on success; hand that to
    # sys.exit so a failure is visible in the process exit status.
    sys.exit(evaluate())