-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy patheval_seg_kaggle.py
More file actions
executable file
·79 lines (67 loc) · 3.88 KB
/
eval_seg_kaggle.py
File metadata and controls
executable file
·79 lines (67 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import torch.optim as optim
from torch.optim import lr_scheduler
from seg_utils import *
from dec_utils import *
from seg_utils import seg_transforms, seg_dataset_kaggle, seg_eval_kaggle
from models import dec_net_seg, seg_net
import cv2
import os
def _build_parser():
    """Create the command-line parser used by this evaluation script.

    Every option is optional and falls back to the default listed in the
    table below.
    """
    # NOTE(review): the description mentions "Detection Training" although this
    # script runs evaluation; the text is kept unchanged here.
    cli = argparse.ArgumentParser(description='Detection Training (MultiGPU)')

    # One row per option: (flag, default value, value type, help text).
    option_table = (
        ('--testDir', "/home/grace/PycharmProjects/DataSets/kaggle/test", str, 'test image directory'),
        ('--annoDir', "data/root/mask", str, 'annotation image directory'),
        ('--imgSuffix', '.png', str, 'suffix of the input images'),
        ('--annoSuffix', '.png', str, 'suffix of the annotation images'),
        ('--img_height', 512, int, 'img height'),
        ('--img_width', 512, int, 'img width'),
        ('--num_classes', 2, int, 'dataset classes'),
        ('--top_k', 500, int, 'the number of detections to keep'),
        ('--conf_thresh', 0.3, float, 'confidence threshold'),
        ('--nms_thresh', 0.3, float, 'nms threshold'),
        ('--seg_thresh', 0.5, float, 'segmentation threshold'),
        ('--dec_weights', "dec_weights/kaggle/end_model.pth", str, 'detection weights'),
        ('--seg_weights', "seg_weights/kaggle/end_model.pth", str, 'segmentation weights'),
    )
    for flag, fallback, value_type, help_text in option_table:
        cli.add_argument(flag, default=fallback, type=value_type, help=help_text)
    return cli


parser = _build_parser()
def load_dec_weights(dec_model, dec_weights):
    """Load a detection checkpoint into ``dec_model`` and return the model.

    Keys that carry a leading ``module.`` prefix have that prefix removed
    before loading (presumably the checkpoint was saved from a wrapped,
    multi-GPU model -- confirm against the training script). Every other key,
    including ones that merely begin with the letters "module" such as
    ``module_list...``, is passed through unchanged.

    Args:
        dec_model: Module that receives the weights.
        dec_weights: Path of the checkpoint file containing a state dict.

    Returns:
        The same ``dec_model`` object, with the weights loaded.

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when the
            checkpoint keys do not match the model exactly.
    """
    print('Resuming detection weights from {} ...'.format(dec_weights))
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can also
    # be opened on a CPU-only host; the caller moves the model afterwards.
    dec_dict = torch.load(dec_weights, map_location='cpu')
    prefix = 'module.'
    # Strip only the exact "module." prefix instead of blindly cutting seven
    # characters from any key that happens to start with "module".
    dec_dict_update = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in dec_dict.items()
    }
    dec_model.load_state_dict(dec_dict_update, strict=True)
    return dec_model
def evaluation(args):
    """Build both networks, load their weights and run the Kaggle evaluation.

    Args:
        args: Parsed command-line options. The attributes read here are
            ``num_classes``, ``dec_weights``, ``seg_weights``, ``img_height``,
            ``img_width``, ``testDir``, ``annoDir``, ``imgSuffix``,
            ``annoSuffix``, ``top_k``, ``conf_thresh`` and ``nms_thresh``;
            the whole object is also forwarded to
            ``seg_eval_kaggle.do_python_eval``.

    Returns:
        None. Prints ``Finish`` once the evaluation call has returned.
    """
    # Use the first GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # ----------------- load detection model -------------------------
    dec_model = dec_net_seg.resnetssd50(pretrained=False, num_classes=args.num_classes)
    dec_model = load_dec_weights(dec_model, args.dec_weights)
    dec_model = dec_model.to(device)
    dec_model.eval()

    # ----------------- load segmentation model ----------------------
    seg_model = seg_net.SEG_NET(num_classes=args.num_classes)
    # Deserialize onto the CPU first: without map_location a checkpoint saved
    # from GPU tensors cannot be opened when the CPU fallback above is active.
    seg_state = torch.load(args.seg_weights, map_location='cpu')
    seg_model.load_state_dict(seg_state)
    seg_model = seg_model.to(device)
    seg_model.eval()

    # ----------------- test data ------------------------------------
    data_transforms = seg_transforms.Compose([
        seg_transforms.ConvertImgFloat(),
        seg_transforms.Resize(args.img_height, args.img_width),
        seg_transforms.ToTensor(),
    ])
    dsets = seg_dataset_kaggle.NucleiCell(args.testDir, args.annoDir, data_transforms,
                                          imgSuffix=args.imgSuffix, annoSuffix=args.annoSuffix)

    # ----------------- detection post-processing --------------------
    detector = Detect(num_classes=args.num_classes,
                      top_k=args.top_k,
                      conf_thresh=args.conf_thresh,
                      nms_thresh=args.nms_thresh,
                      variance=[0.1, 0.2])
    anchorGen = Anchors(args.img_height, args.img_width)
    anchors = anchorGen.forward()

    # The two results are not used further here; the names suggest average
    # precision at two thresholds -- TODO confirm in seg_eval_kaggle.
    ap_05, ap_07 = seg_eval_kaggle.do_python_eval(dsets=dsets, dec_model=dec_model, seg_model=seg_model,
                                                  detector=detector, anchors=anchors, device=device,
                                                  args=args, offline=True)
    print('Finish')
if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options
    # straight to the evaluation routine.
    evaluation(parser.parse_args())