-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
238 lines (194 loc) · 8.42 KB
/
test_model.py
File metadata and controls
238 lines (194 loc) · 8.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
from pathlib import Path
import math
import csv
import os
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, confusion_matrix
# --- Configuration ---
# Input folder of raw binaries, trained weights to load, and where predictions go.
TEST_DIR = Path("data_separated/32bit/challenge")
MODEL_PATH = "32bit_custom_cnn_interrupted.pth"  # point this at the checkpoint you actually saved
OUTPUT_CSV = "challenge_predictions.csv"
IMG_SIZE = 256  # side length (pixels) of the square network input

# Hardware settings
BATCH_SIZE = 64  # chosen to fit a 16 GB GPU
NUM_WORKERS = 4  # kept low so parallel reads of large files do not saturate disk I/O
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Files strictly larger than this many bytes (100 MiB) are memory-mapped
# instead of being read fully into RAM.
MEMMAP_THRESHOLD = 100 << 20
# Turn off Pillow's decompression-bomb guard: very large binaries legitimately
# turn into very large images.
Image.MAX_IMAGE_PIXELS = None

# Lower-cased folder names recognised as ground truth (0 = goodware, 1 = malware).
CLASS_MAP = {
    'goodware': 0,
    'malware': 1,
    'benign': 0,
    'malicious': 1,
}
class MalwareCNN(nn.Module):
    """Four-stage CNN that maps a 3-channel image to two class logits.

    Each stage is Conv3x3 (padding 1) -> BatchNorm -> ReLU -> MaxPool(2), so
    the spatial size is halved four times. The head expects 256 channels on a
    16x16 grid, i.e. it is sized for 256x256 inputs. Logit index 1 is the
    "malware" class.
    """

    def __init__(self):
        super().__init__()

        def stage(in_channels, out_channels):
            # The conv preserves spatial size; only the pool shrinks it.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        # Attribute names define the state_dict keys and must match the checkpoint.
        self.conv1 = stage(3, 32)
        self.conv2 = stage(32, 64)
        self.conv3 = stage(64, 128)
        self.conv4 = stage(128, 256)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(256 * 16 * 16, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 2),
        )

    def forward(self, x):
        for feature_stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = feature_stage(x)
        return self.fc(self.flatten(x))
# --- STAMINA Logic ---
def get_width(file_size_kb):
    """Return the image width (bytes per row) used for a file of ``file_size_kb`` KB.

    The width is looked up from a fixed table of size buckets. Each bucket's
    upper bound is exclusive, and files of 1500 KB or more get 2048.
    """
    size_buckets = (
        (10, 32),
        (30, 64),
        (60, 128),
        (100, 256),
        (200, 384),
        (1000, 512),
        (1500, 1024),
    )
    for upper_kb, width in size_buckets:
        if file_size_kb < upper_kb:
            return width
    return 2048
class SafeInferenceDataset(Dataset):
    """Dataset that renders every file under ``root_dir`` as a grayscale image.

    A file's raw bytes are laid out row-major as an 8-bit image whose width
    comes from ``get_width`` (based on the file size). Trailing bytes that do
    not fill a whole row are dropped. Files shorter than one row are zero-padded
    to a single row. Files larger than ``MEMMAP_THRESHOLD`` bytes are
    memory-mapped instead of read into RAM.

    Items are ``(image, file_name, label)``. ``label`` comes from a
    class-named folder (see ``CLASS_MAP``), or is -1 when no such folder
    exists or the file could not be processed.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        # rglob() order is filesystem-dependent. Sort so the (unshuffled)
        # loader and the output CSV have a reproducible order.
        self.file_paths = sorted(p for p in self.root_dir.rglob("*") if p.is_file())
        if not self.file_paths:
            print(f"[Warning] No files found in {root_dir}")

    def __len__(self):
        return len(self.file_paths)

    def get_label_from_path(self, path):
        """Return the class id of the first class-named folder on ``path``, else -1.

        Only ``root_dir`` itself and the directories below it are inspected,
        top-down. Ancestors of ``root_dir`` (e.g. ``/datasets/malware/...``)
        and the file's own name therefore cannot leak into the label.
        """
        path = Path(path)
        try:
            folders = (self.root_dir.name,) + path.parent.relative_to(self.root_dir).parts
        except ValueError:
            # Path not under root_dir: fall back to all of its directory components.
            folders = path.parent.parts
        for folder in folders:
            label = CLASS_MAP.get(folder.lower())
            if label is not None:
                return label
        return -1

    @staticmethod
    def _load_image(file_path):
        """Read ``file_path`` and return its bytes as an 8-bit ('L') PIL image."""
        file_size = os.stat(file_path).st_size
        if file_size == 0:
            # Empty file: black placeholder of the network's input size.
            return Image.new('L', (IMG_SIZE, IMG_SIZE), 0)

        width = get_width(file_size / 1024)
        height = file_size // width

        if height == 0:
            # Shorter than one row: zero-pad up to a single full row.
            with open(file_path, 'rb') as f:
                data = np.frombuffer(f.read(), dtype=np.uint8)
            img_array = np.pad(data, (0, width - len(data)), 'constant').reshape((1, width))
        elif file_size > MEMMAP_THRESHOLD:
            # Large file: read-only map of just the first height*width bytes.
            img_array = np.memmap(file_path, dtype=np.uint8, mode='r', shape=(height, width))
        else:
            with open(file_path, 'rb') as f:
                data = np.frombuffer(f.read(height * width), dtype=np.uint8)
            img_array = data.reshape((height, width))

        return Image.fromarray(img_array, 'L')

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        file_name = file_path.name
        label = self.get_label_from_path(file_path)
        try:
            img = self._load_image(file_path)
            if self.transform:
                img = self.transform(img)
            return img, str(file_name), label
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            # The placeholder gets label -1, so metrics skip it.
            print(f"Error processing {file_name}: {e}")
            return torch.zeros((3, IMG_SIZE, IMG_SIZE)), str(file_name), -1
def load_model(model_path):
    """Build a MalwareCNN, load the weights in ``model_path`` and return it ready for inference.

    The file is passed straight to ``load_state_dict``, so it is expected to
    contain a plain MalwareCNN state_dict. The returned model is on DEVICE and
    in eval mode.

    NOTE: torch.load unpickles the file. Only load checkpoints from trusted sources.
    """
    print(f"[Model] Loading {model_path}...")
    network = MalwareCNN()
    weights = torch.load(model_path, map_location=DEVICE)
    network.load_state_dict(weights)
    # Module.to() and Module.eval() both return the module itself.
    return network.to(DEVICE).eval()
def calculate_final_metrics(y_true, y_probs, y_preds):
    """Print accuracy, ROC-AUC, FPR and FNR over the samples that have ground truth.

    Args:
        y_true: per-sample labels; 0 = goodware, 1 = malware, -1 = unknown (skipped).
        y_probs: per-sample predicted probability of class 1 (malware).
        y_preds: per-sample predicted class (0 or 1).

    ROC-AUC is undefined when roc_auc_score rejects the data (e.g. only one
    class among the labelled samples). It is then reported as "N/A" instead of
    a fabricated 0.5.

    Returns:
        dict with keys "accuracy", "auc" (None when undefined), "fpr" and
        "fnr", or None when no sample has a ground-truth label.
    """
    valid_indices = [i for i, label in enumerate(y_true) if label != -1]
    if not valid_indices:
        print("\n[Metrics Warning] No ground truth labels found in folder names.")
        return None

    y_true_valid = [y_true[i] for i in valid_indices]
    y_probs_valid = [y_probs[i] for i in valid_indices]
    y_preds_valid = [y_preds[i] for i in valid_indices]

    try:
        auc = float(roc_auc_score(y_true_valid, y_probs_valid))
    except ValueError:
        # roc_auc_score raises ValueError e.g. when only one class is present.
        auc = None

    tn, fp, fn, tp = confusion_matrix(y_true_valid, y_preds_valid, labels=[0, 1]).ravel()
    fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
    fnr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
    acc = (tp + tn) / (tp + tn + fp + fn)  # denominator > 0: valid_indices is non-empty

    auc_text = f"{auc:.4f}" if auc is not None else "N/A"
    print("\n" + "="*30)
    print(" TEST SET METRICS")
    print("="*30)
    print(f"Accuracy : {acc:.4f}")
    print(f"AUC : {auc_text}")
    print(f"FPR : {fpr:.4f} ({fpr*100:.2f}%)")
    print(f"FNR : {fnr:.4f} ({fnr*100:.2f}%)")
    print("="*30)
    return {"accuracy": float(acc), "auc": auc, "fpr": float(fpr), "fnr": float(fnr)}
def main():
    """Classify every file under TEST_DIR and write per-file predictions to OUTPUT_CSV.

    Metrics are printed afterwards for files whose folder names supply a
    ground-truth label.
    """
    # Images are resized to IMG_SIZE x IMG_SIZE (256) and expanded to 3 channels.
    # Resize runs BEFORE Grayscale(3): on the single-channel ('L') images this
    # dataset produces, each output channel is just a copy of the resized
    # grayscale image. The old order (Grayscale first) built a full-resolution
    # 3-channel copy of every file, which for huge memory-mapped binaries
    # defeated the MEMMAP_THRESHOLD OOM guard.
    data_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # maps [0, 1] to [-1, 1]
    ])

    print(f"[Data] Scanning {TEST_DIR}...")
    dataset = SafeInferenceDataset(TEST_DIR, transform=data_transform)
    # Pinned host memory only helps (and only makes sense) for CUDA transfers.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == "cuda"))

    model = load_model(MODEL_PATH)

    results = []
    all_labels = []
    all_probs = []
    all_preds = []

    print(f"[Inference] Starting processing on {DEVICE}...")
    with torch.no_grad():
        for images, filenames, labels in tqdm(loader, desc="Testing"):
            images = images.to(DEVICE)
            # Standard float32 inference (matches the training precision).
            outputs = model(images)
            probs_np = torch.softmax(outputs, dim=1).cpu().numpy()
            _, preds = torch.max(outputs, 1)
            preds_np = preds.cpu().numpy()

            all_labels.extend(labels.numpy())
            all_probs.extend(probs_np[:, 1])
            all_preds.extend(preds_np)

            for fname, pred_class, row in zip(filenames, preds_np, probs_np):
                label_str = "Malware" if pred_class == 1 else "Goodware"
                confidence = row[pred_class]
                malware_prob = row[1]
                results.append([fname, label_str, f"{confidence:.4f}", f"{malware_prob:.4f}"])

    # Explicit UTF-8 so non-ASCII filenames don't fail on platforms whose
    # default encoding cannot represent them.
    with open(OUTPUT_CSV, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Filename", "Prediction", "Confidence", "Malware_Probability"])
        writer.writerows(results)
    print(f"[Output] Predictions saved to {OUTPUT_CSV}")

    calculate_final_metrics(all_labels, all_probs, all_preds)
if __name__ == "__main__":
    # DataLoader workers may be started with "spawn" (always on Windows), which
    # is why this guard is needed. freeze_support() is a no-op unless the script
    # is run as a frozen Windows executable.
    from multiprocessing import freeze_support

    freeze_support()
    main()