forked from zxjzxj9/PyTorchIntroduction
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfcn.py
More file actions
96 lines (79 loc) · 2.91 KB
/
Copy pathfcn.py
File metadata and controls
96 lines (79 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
""" 该代码改编自PyTorch官网torchvision源代码
https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/fcn.py
"""
# FCN feature-extraction part: build the ResNet backbone and attach heads.
def _segm_resnet(name, backbone_name, num_classes, aux,
                 pretrained_backbone=True):
    """Assemble a ResNet-backed segmentation model.

    Args:
        name: Architecture key, either ``'deeplabv3'`` or ``'fcn'``.
        backbone_name: Key looked up in ``resnet.__dict__`` to obtain the
            backbone constructor (e.g. ``'resnet50'``).
        num_classes: Number of output channels of the classifier head(s).
        aux: If true, also expose ``layer3`` as ``'aux'`` and attach an
            ``FCNHead`` auxiliary classifier to it.
        pretrained_backbone: Forwarded as ``pretrained=`` to the backbone
            constructor.

    Returns:
        ``ModelClass(backbone, classifier, aux_classifier)``, where
        ``ModelClass`` is the model type registered for ``name``.

    Raises:
        KeyError: if ``backbone_name`` or ``name`` is not a known key.
    """
    # Ask the ResNet constructor to trade striding for dilation in the later
    # stages (presumably layer3/layer4 per torchvision's flag order - confirm).
    backbone = resnet.__dict__[backbone_name](
        pretrained=pretrained_backbone,
        replace_stride_with_dilation=[False, True, True])

    # Map of backbone sub-module name -> key under which its output is returned.
    taps = {'layer4': 'out'}
    if aux:
        taps['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=taps)

    # 1024 / 2048 are the channel widths assumed for layer3 / layer4 outputs
    # (ResNet-50/101-style bottleneck backbones - verify for other backbones).
    aux_classifier = FCNHead(1024, num_classes) if aux else None

    head_cls, model_cls = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }[name]
    classifier = head_cls(2048, num_classes)
    return model_cls(backbone, classifier, aux_classifier)
# FCN model part. The shared base class must be defined before FCN: a class
# statement evaluates its base-class list at the moment it executes.
class _SimpleSegmentationModel(nn.Module):
    """Backbone + classifier segmentation model with an optional auxiliary head.

    Args:
        backbone: Module whose forward returns a mapping that contains the key
            ``"out"`` (and ``"aux"`` when ``aux_classifier`` is given).
        classifier: Module applied to ``features["out"]``.
        aux_classifier: Optional module applied to ``features["aux"]``.
    """

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        """Run backbone and heads, resizing each head output to the input size.

        Args:
            x: Input tensor; its last two dimensions are used as the target
                spatial size (presumably ``(N, C, H, W)`` - bilinear
                interpolation expects a 4-D tensor).

        Returns:
            OrderedDict with key ``"out"`` and, when an auxiliary classifier
            is set, ``"aux"``; each value is bilinearly resized to the input's
            spatial size.
        """
        input_shape = x.shape[-2:]
        features = self.backbone(x)
        result = OrderedDict()

        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear',
                          align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear',
                              align_corners=False)
            result["aux"] = x

        return result


class FCN(_SimpleSegmentationModel):
    """Fully Convolutional Network.

    All behaviour (constructor and forward) is inherited from
    ``_SimpleSegmentationModel``; this subclass only gives the model its name.
    """
# FCNHead part.
class FCNHead(nn.Sequential):
    """Segmentation head: 3x3 conv -> BatchNorm -> ReLU -> Dropout -> 1x1 conv.

    Args:
        in_channels: Channel count of the incoming feature map.
        channels: Output channel count (e.g. the number of classes).

    The hidden width is ``in_channels // 4``. The 3x3 conv uses padding 1
    (stride 1), so spatial size is preserved, and omits its bias because the
    BatchNorm that follows has its own shift.
    """

    def __init__(self, in_channels, channels):
        hidden = in_channels // 4
        super().__init__(
            nn.Conv2d(in_channels, hidden, 3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(hidden, channels, 1),
        )
# Input-image preprocessing for FCN.
def get_transform(train, *, base_size=520, crop_size=480):
    """Build the preprocessing pipeline for training or evaluation.

    ``T`` is presumably the segmentation reference transforms module that
    operates on (image, target) pairs - confirm against the caller.

    Args:
        train: If true, resize to a random size in
            ``[0.5 * base_size, 2.0 * base_size]``, then flip horizontally with
            probability 0.5 and randomly crop to ``crop_size``. Otherwise
            ``T.RandomResize`` is given ``min_size == max_size == base_size``.
        base_size: Reference size for ``T.RandomResize``. Keyword-only;
            defaults to 520, the previously hard-coded value.
        crop_size: Random-crop size, used only when ``train`` is true.
            Keyword-only; defaults to 480, the previously hard-coded value.

    Returns:
        ``T.Compose`` of the selected transforms, always ending with
        ``T.ToTensor()`` and ``T.Normalize`` using the usual ImageNet
        channel mean/std values.
    """
    min_size = int((0.5 if train else 1.0) * base_size)
    max_size = int((2.0 if train else 1.0) * base_size)

    transforms = [T.RandomResize(min_size, max_size)]
    if train:
        # Augmentations applied only during training.
        transforms.append(T.RandomHorizontalFlip(0.5))
        transforms.append(T.RandomCrop(crop_size))
    transforms.append(T.ToTensor())
    transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]))
    return T.Compose(transforms)