-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
108 lines (91 loc) · 3.73 KB
/
test.py
File metadata and controls
108 lines (91 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import numpy as np
from sklearn.metrics import confusion_matrix
from transformers import CvtForImageClassification
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from pathlib import Path
# --- Config ---
# Hugging Face model id whose architecture the checkpoint was fine-tuned from
# (main() passes it to from_pretrained).
BASE_MODEL = "microsoft/cvt-13"

# Fine-tuned weights under diagnosis: a .pth file (not a directory) whose
# contents main() hands to load_state_dict.
CHECKPOINT_PATH = "32bit_checkpoints/Finetune_best.pth"

# Challenge split: one subdirectory per class, each holding .npy samples.
DATA_DIR = "data_processed/32bit/challenge"

# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# --- Minimal Dataset Re-implementation ---
class SimpleDataset(Dataset):
    """Folder-per-class dataset of ``.npy`` arrays.

    Each immediate subdirectory of ``root`` is a class; every ``*.npy`` file
    directly inside it is one sample labelled with that class's index.

    Attributes:
        root: ``Path`` to the dataset root.
        classes: Sorted subdirectory names; a class's label is its index here.
        class_to_idx: Mapping from class name to integer label.
        samples: ``(path, label)`` pairs, ordered by class, then file name.
    """

    def __init__(self, root):
        self.root = Path(root)
        # Sorting gives a stable class -> index mapping across runs and machines.
        self.classes = sorted(d.name for d in self.root.iterdir() if d.is_dir())
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        # glob() yields files in filesystem-dependent order; sort so the sample
        # (and therefore batch/prediction) order is reproducible.
        self.samples = [
            (p, self.class_to_idx[c])
            for c in self.classes
            for p in sorted((self.root / c).glob("*.npy"))
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Return ``(image, label)`` for sample ``i``.

        The array is converted to a float32 tensor and divided by 255
        (assumes 0-255 pixel values — confirm against the preprocessing that
        wrote the ``.npy`` files). A 2-D array gets a leading channel axis, and
        a single-channel image is replicated to 3 channels for the 3-channel
        backbone. Arrays that already have more than one leading channel are
        returned unchanged instead of being tiled (3 channels would become 9).
        """
        p, label = self.samples[i]
        t = torch.from_numpy(np.load(p)).float().div(255.0)
        if t.dim() == 2:
            t = t.unsqueeze(0)
        if t.shape[0] == 1:
            t = t.repeat(3, 1, 1)
        return t, label
def _load_model(num_labels):
    """Build the CvT classifier and load the fine-tuned checkpoint into it.

    Args:
        num_labels: Size of the classification head; must match the head
            stored in ``CHECKPOINT_PATH`` or ``load_state_dict`` will fail.

    Returns:
        The model on ``DEVICE`` in eval mode, or ``None`` if the weights could
        not be loaded (the error is printed).
    """
    print(f"Initializing {BASE_MODEL} architecture...")
    # ignore_mismatched_sizes lets from_pretrained swap in a freshly initialised
    # head of size num_labels when it differs from the pretrained one; the strict
    # load_state_dict below then overwrites every parameter from the checkpoint.
    model = CvtForImageClassification.from_pretrained(
        BASE_MODEL,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
    )

    print("Loading trained weights...")
    try:
        # NOTE(review): torch.load unpickles the file — only point this at
        # trusted checkpoints.
        state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
    except Exception as e:  # top-level diagnostic boundary: report and bail out
        print(f"\n[CRITICAL ERROR] Could not load weights: {e}")
        print("Ensure you are pointing to the .pth file, not a directory.")
        return None

    model.to(DEVICE)
    model.eval()
    return model


def _predict(model, ds):
    """Run ``model`` over ``ds`` in dataset order.

    Returns:
        ``(labels, preds)``: two equal-length lists of integer class indices.
    """
    # NOTE(review): these are the ImageNet mean/std; presumably they match the
    # normalisation used during fine-tuning — confirm.
    norm = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    dl = DataLoader(ds, batch_size=32, shuffle=False)
    labels, preds = [], []
    with torch.no_grad():
        for x, y in dl:
            logits = model(pixel_values=norm(x.to(DEVICE))).logits
            preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            labels.extend(y.tolist())
    return labels, preds


def _print_binary_report(cm, classes):
    """Print the 2x2 confusion-matrix breakdown and a heuristic diagnosis.

    Class index 0 is reported as Goodware (negative) and index 1 as Malware
    (positive), i.e. ``classes[0]`` / ``classes[1]`` in sorted folder order.
    """
    tn, fp, fn, tp = cm.ravel()
    print("\n" + "=" * 40)
    print("--- CHALLENGE CONFUSION MATRIX ---")
    print("=" * 40)
    # Make the index -> Goodware/Malware assumption visible in the output.
    print(f"Goodware = '{classes[0]}', Malware = '{classes[1]}'")
    print(f"True Negatives (Goodware Correct) : {tn}")
    print(f"False Positives (Goodware -> Malware): {fp}")
    print(f"False Negatives (Malware -> Goodware): {fn}")
    print(f"True Positives (Malware Correct) : {tp}")
    print("-" * 40)
    if fp > tp and fp > tn:
        print("\n[DIAGNOSIS]: HIGH FALSE POSITIVES DETECTED.")
        print("The model is flagging benign files as malware.")
        print("Likely cause: Packed/High-Entropy Goodware in Challenge Set.")
    elif fn > tp:
        print("\n[DIAGNOSIS]: HIGH FALSE NEGATIVES.")
        print("The model is missing malware.")
    else:
        print("\n[DIAGNOSIS]: Results are mixed.")


def main():
    """Evaluate the fine-tuned checkpoint on the challenge set and print a diagnosis."""
    print(f"Diagnosing Model: {CHECKPOINT_PATH}")

    # 1. Load data first: the class folders determine the head size.
    ds = SimpleDataset(DATA_DIR)
    print(f"Classes found: {ds.classes}")
    if len(ds) == 0:
        print(f"\n[CRITICAL ERROR] No .npy samples found under {DATA_DIR}.")
        return

    # 2-3. Build the architecture (one output per class folder, which must
    # match the checkpoint's head) and load the fine-tuned weights.
    model = _load_model(len(ds.classes))
    if model is None:
        return

    # 4. Inference.
    print("Running inference on Challenge set...")
    all_labels, all_preds = _predict(model, ds)

    # 5. Results. Passing every class index keeps the matrix n x n even when a
    # class never appears in labels or predictions; without it sklearn only
    # uses the labels it sees, shrinks the matrix, and the 2x2 unpack fails.
    cm = confusion_matrix(
        all_labels, all_preds, labels=list(range(len(ds.classes)))
    )
    if cm.shape != (2, 2):
        print(
            f"\nGoodware/Malware report needs exactly 2 classes; "
            f"found {len(ds.classes)}."
        )
        print(f"Rows = true, columns = predicted, in this class order: {ds.classes}")
        print(cm)
        return
    _print_binary_report(cm, ds.classes)


if __name__ == "__main__":
    main()