-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathInference.py
More file actions
68 lines (54 loc) · 2.42 KB
/
Inference.py
File metadata and controls
68 lines (54 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
from models import MFNet, ConvBnAct, ResidualConv, SEBlock, MiniInceptionRes
from Dataset_and_utils import MultiSpectralDataset
def _write_submission(predictions, submission_path):
    """Serialise flattened masks to an ``id,pixels`` CSV and return the frame.

    Args:
        predictions: Sequence of 1-D 0/1 arrays, one per test sample, in loader order.
        submission_path: Destination path for the CSV file.

    Returns:
        The ``pandas.DataFrame`` that was written.
    """
    # Layout as per the competition rules: sequential integer ids, and each
    # mask rendered as a comma-separated string of its pixel values.
    submission = pd.DataFrame({
        "id": np.arange(len(predictions)),
        "pixels": [",".join(map(str, p)) for p in predictions],
    })
    submission.to_csv(submission_path, index=False)
    print("Saved", submission_path)
    return submission


def _report_and_reset_cuda_memory():
    """Print CUDA allocator usage, then release cached blocks and reset peak stats."""
    print(torch.cuda.memory_allocated())
    # memory_reserved() is the non-deprecated replacement for memory_cached().
    print(torch.cuda.memory_reserved())
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def run_inference(test_loader, model, device, threshold=0.099,
                  submission_path="submission.csv"):
    """Predict a binary mask for every test sample and write a submission CSV.

    Args:
        test_loader: Iterable of input batches. Each batch is moved to
            ``device`` and passed straight to ``model``; the loader is assumed
            to yield bare image tensors (no targets) — TODO confirm against
            the dataset's ``__getitem__``.
        model: Network whose forward pass returns a 3-tuple; only the first
            element (the main logits) is used.
        device: Device (``torch.device`` or device string) to run on.
        threshold: Probability cut-off; a pixel is 1 when
            ``sigmoid(logit) > threshold``. Defaults to 0.099, the value the
            script previously hard-coded (noted there as the tuned best_thr).
        submission_path: Where to write the CSV. Defaults to "submission.csv".

    Returns:
        The submission ``pandas.DataFrame`` (columns ``id`` and ``pixels``).
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for imgs in tqdm(test_loader, desc="Predicting "):
            imgs = imgs.to(device)
            logits, _, _ = model(imgs)
            probs = torch.sigmoid(logits).cpu().numpy()  # presumably (B, 1, H, W)
            preds = (probs > threshold).astype(np.uint8)
            # One row-major 0/1 vector per sample; reshape(-1) yields the same
            # element order as the former squeeze().flatten().
            predictions.extend(mask.reshape(-1) for mask in preds)

    submission = _write_submission(predictions, submission_path)

    # The CUDA allocator calls (reset_peak_memory_stats in particular) force
    # CUDA initialisation and raise on CPU-only machines, so only report when
    # inference actually ran on a GPU.
    if torch.device(device).type == "cuda":
        _report_and_reset_cuda_memory()
    return submission
if __name__ == "__main__":
    BEST_MODEL_PATH = "/content/best_model.pth"
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # mmap_mode='r' memory-maps the .npy file read-only instead of loading it
    # eagerly into RAM.
    X_test = np.load(
        '/root/.cache/kagglehub/competitions/kaggle-competition-dl-f-2025/X_test_256.npy',
        mmap_mode='r',
    )
    print("Test shape:", X_test.shape)
    indices = np.arange(len(X_test))
    # NOTE(review): compute_stats=True on the test split — confirm the dataset
    # is meant to derive its normalisation statistics from test data rather
    # than reuse the training statistics.
    test_ds = MultiSpectralDataset(
        X_test,
        Y=None,
        indices=indices,
        compute_stats=True,
        augment=False,
    )
    # Only unbinds this name; test_ds presumably still holds the array.
    del X_test
    test_loader = DataLoader(
        test_ds,
        batch_size=8,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    # Rebuild the architecture, then restore the trained weights into it.
    model = MFNet(
        in_ch=16,
        n_class=1,
        use_se=True,
        deep_supervision=True,
    ).to(device)
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    run_inference(test_loader, model, device)