-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
46 lines (38 loc) · 1.29 KB
/
evaluate.py
File metadata and controls
46 lines (38 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from tensorflow.keras.datasets import mnist # pyright: ignore[reportMissingImports]
import joblib
from models_helpers import (
normalize_and_flatten,
load_model,
print_metrics,
plot_confusion,
subset_classes
)
def evaluate_model(path, X_test, y_test, name, labels):
    """Restore a saved classifier and report how it does on a test set.

    Prints a ``=== name ===`` banner, then delegates to ``print_metrics``
    and ``plot_confusion``. Nothing is returned.

    Args:
        path: Location handed to ``load_model`` to restore the classifier.
        X_test: Features passed to the restored model's ``predict``.
        y_test: Ground-truth labels compared against the predictions.
        name: Title shown in the banner above the metrics.
        labels: Class labels forwarded to ``plot_confusion``.
    """
    predictions = load_model(path).predict(X_test)
    banner = f"\n=== {name} ==="
    print(banner)
    print_metrics(y_test, predictions)
    plot_confusion(y_test, predictions, labels)
def main():
    """Evaluate the persisted classifiers on a subset of the MNIST test split.

    Steps:
      1. Load the MNIST test split and reduce it with ``subset_classes``
         over all ten digits with ``per_class=500``.
      2. Preprocess with ``normalize_and_flatten`` and the scaler restored
         from ``scaler.pkl``.
      3. For each saved model, print its metrics and plot its confusion
         matrix via ``evaluate_model``.
    """
    # 1. Load test split (the training split is discarded).
    _, (X_test, y_test) = mnist.load_data()
    # Single source of truth for the evaluated classes; the confusion-matrix
    # labels are derived from it so the two lists cannot drift apart.
    keep_labels = list(range(10))
    labels = [str(c) for c in keep_labels]
    X_raw, y = subset_classes(
        X_test, y_test, keep_labels=keep_labels, per_class=500
    )

    # 2. Preprocess. NOTE(review): assumes scaler.pkl was fitted on features
    # produced by the same normalize_and_flatten(size=(28, 28)) call —
    # confirm against the training script.
    X_flat = normalize_and_flatten(X_raw, size=(28, 28))
    scaler = load_model("scaler.pkl")
    X_s = scaler.transform(X_flat)

    # 3. Evaluate each saved model; evaluate_model loads, predicts and
    # reports (metrics + confusion matrix) for one model file.
    models = (
        ("nb_model.pkl", "Naive Bayes"),
        ("dt_model.pkl", "Decision Tree"),
        ("mlp_model.pkl", "MLP Classifier"),
    )
    for fname, title in models:
        evaluate_model(fname, X_s, y, title, labels)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()