77from typing import Tuple , Union
88
99import mindspore .common .initializer as init
10- from mindspore import Tensor , nn , ops
10+ from mindspore import Tensor , mint , nn
1111
1212from .helpers import load_pretrained
13- from .layers .compatibility import Dropout
13+ from .layers .flatten import Flatten
1414from .layers .pooling import GlobalAvgPooling
1515from .registry import register_model
1616
@@ -45,12 +45,12 @@ def __init__(
4545 kernel_size : int = 1 ,
4646 stride : int = 1 ,
4747 padding : int = 0 ,
48- pad_mode : str = "same " ,
48+ pad_mode : str = "zeros " ,
4949 ) -> None :
5050 super ().__init__ ()
51- self .conv = nn .Conv2d (in_channels , out_channels , kernel_size , stride ,
52- padding = padding , pad_mode = pad_mode )
53- self .relu = nn .ReLU ()
51+ self .conv = mint . nn .Conv2d (
52+ in_channels , out_channels , kernel_size , stride , padding = padding , padding_mode = pad_mode , bias = False )
53+ self .relu = mint . nn .ReLU ()
5454
5555 def construct (self , x : Tensor ) -> Tensor :
5656 x = self .conv (x )
@@ -75,14 +75,14 @@ def __init__(
7575 self .b1 = BasicConv2d (in_channels , ch1x1 , kernel_size = 1 )
7676 self .b2 = nn .SequentialCell ([
7777 BasicConv2d (in_channels , ch3x3red , kernel_size = 1 ),
78- BasicConv2d (ch3x3red , ch3x3 , kernel_size = 3 ),
78+ BasicConv2d (ch3x3red , ch3x3 , kernel_size = 3 , padding = 1 ),
7979 ])
8080 self .b3 = nn .SequentialCell ([
8181 BasicConv2d (in_channels , ch5x5red , kernel_size = 1 ),
82- BasicConv2d (ch5x5red , ch5x5 , kernel_size = 5 ),
82+ BasicConv2d (ch5x5red , ch5x5 , kernel_size = 5 , padding = 2 ),
8383 ])
8484 self .b4 = nn .SequentialCell ([
85- nn .MaxPool2d (kernel_size = 3 , stride = 1 , pad_mode = "same" ),
85+ mint . nn .MaxPool2d (kernel_size = 3 , stride = 1 , padding = 1 , ceil_mode = True ),
8686 BasicConv2d (in_channels , pool_proj , kernel_size = 1 ),
8787 ])
8888
@@ -91,7 +91,7 @@ def construct(self, x: Tensor) -> Tensor:
9191 branch2 = self .b2 (x )
9292 branch3 = self .b3 (x )
9393 branch4 = self .b4 (x )
94- return ops .concat ((branch1 , branch2 , branch3 , branch4 ), axis = 1 )
94+ return mint .concat ((branch1 , branch2 , branch3 , branch4 ), dim = 1 )
9595
9696
9797class InceptionAux (nn .Cell ):
@@ -104,13 +104,13 @@ def __init__(
104104 drop_rate : float = 0.7 ,
105105 ) -> None :
106106 super ().__init__ ()
107- self .avg_pool = nn .AvgPool2d (kernel_size = 5 , stride = 3 )
107+ self .avg_pool = mint . nn .AvgPool2d (kernel_size = 5 , stride = 3 )
108108 self .conv = BasicConv2d (in_channels , 128 , kernel_size = 1 )
109- self .fc1 = nn .Dense (2048 , 1024 )
110- self .fc2 = nn .Dense (1024 , num_classes )
111- self .flatten = nn . Flatten ()
112- self .relu = nn .ReLU ()
113- self .dropout = Dropout (p = drop_rate )
109+ self .fc1 = mint . nn .Linear (2048 , 1024 )
110+ self .fc2 = mint . nn .Linear (1024 , num_classes )
111+ self .flatten = Flatten ()
112+ self .relu = mint . nn .ReLU ()
113+ self .dropout = mint . nn . Dropout (p = drop_rate )
114114
115115 def construct (self , x : Tensor ) -> Tensor :
116116 x = self .avg_pool (x )
@@ -145,23 +145,23 @@ def __init__(
145145 ) -> None :
146146 super ().__init__ ()
147147 self .aux_logits = aux_logits
148- self .conv1 = BasicConv2d (in_channels , 64 , kernel_size = 7 , stride = 2 )
149- self .maxpool1 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , pad_mode = "same" )
148+ self .conv1 = BasicConv2d (in_channels , 64 , kernel_size = 7 , stride = 2 , padding = 3 )
149+ self .maxpool1 = mint . nn .MaxPool2d (kernel_size = 3 , stride = 2 , ceil_mode = True )
150150
151151 self .conv2 = BasicConv2d (64 , 64 , kernel_size = 1 )
152- self .conv3 = BasicConv2d (64 , 192 , kernel_size = 3 )
153- self .maxpool2 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , pad_mode = "same" )
152+ self .conv3 = BasicConv2d (64 , 192 , kernel_size = 3 , padding = 1 )
153+ self .maxpool2 = mint . nn .MaxPool2d (kernel_size = 3 , stride = 2 , ceil_mode = True )
154154
155155 self .inception3a = Inception (192 , 64 , 96 , 128 , 16 , 32 , 32 )
156156 self .inception3b = Inception (256 , 128 , 128 , 192 , 32 , 96 , 64 )
157- self .maxpool3 = nn .MaxPool2d (kernel_size = 3 , stride = 2 , pad_mode = "same" )
157+ self .maxpool3 = mint . nn .MaxPool2d (kernel_size = 3 , stride = 2 , ceil_mode = True )
158158
159159 self .inception4a = Inception (480 , 192 , 96 , 208 , 16 , 48 , 64 )
160160 self .inception4b = Inception (512 , 160 , 112 , 224 , 24 , 64 , 64 )
161161 self .inception4c = Inception (512 , 128 , 128 , 256 , 24 , 64 , 64 )
162162 self .inception4d = Inception (512 , 112 , 144 , 288 , 32 , 64 , 64 )
163163 self .inception4e = Inception (528 , 256 , 160 , 320 , 32 , 128 , 128 )
164- self .maxpool4 = nn .MaxPool2d (kernel_size = 2 , stride = 2 , pad_mode = "same" )
164+ self .maxpool4 = mint . nn .MaxPool2d (kernel_size = 2 , stride = 2 , ceil_mode = True )
165165
166166 self .inception5a = Inception (832 , 256 , 160 , 320 , 32 , 128 , 128 )
167167 self .inception5b = Inception (832 , 384 , 192 , 384 , 48 , 128 , 128 )
@@ -171,22 +171,24 @@ def __init__(
171171 self .aux2 = InceptionAux (528 , num_classes , drop_rate = drop_rate_aux )
172172
173173 self .pool = GlobalAvgPooling ()
174- self .dropout = Dropout (p = drop_rate )
175- self .classifier = nn .Dense (1024 , num_classes )
174+ self .dropout = mint . nn . Dropout (p = drop_rate )
175+ self .classifier = mint . nn .Linear (1024 , num_classes )
176176 self ._initialize_weights ()
177177
178178 def _initialize_weights (self ):
179179 for _ , cell in self .cells_and_names ():
180- if isinstance (cell , nn .Conv2d ):
180+ if isinstance (cell , mint . nn .Conv2d ):
181181 cell .weight .set_data (init .initializer (init .HeNormal (0 , mode = 'fan_in' , nonlinearity = 'leaky_relu' ),
182182 cell .weight .shape , cell .weight .dtype ))
183183 if cell .bias is not None :
184184 cell .bias .set_data (init .initializer (init .Constant (0 ), cell .bias .shape , cell .bias .dtype ))
185- elif isinstance (cell , nn .BatchNorm2d ) or isinstance (cell , nn .BatchNorm1d ):
186- cell .gamma .set_data (init .initializer (init .Constant (1 ), cell .gamma .shape , cell .gamma .dtype ))
187- if cell .beta is not None :
188- cell .beta .set_data (init .initializer (init .Constant (0 ), cell .beta .shape , cell .gamma .dtype ))
189- elif isinstance (cell , nn .Dense ):
185+ elif isinstance (cell , mint .nn .BatchNorm2d ) or isinstance (cell , mint .nn .BatchNorm1d ):
186+ cell .weight .set_data (
187+ init .initializer (init .Constant (1 ), cell .weight .shape , cell .weight .dtype ))
188+ if cell .bias is not None :
189+ cell .bias .set_data (
190+ init .initializer (init .Constant (0 ), cell .bias .shape , cell .weight .dtype ))
191+ elif isinstance (cell , mint .nn .Linear ):
190192 cell .weight .set_data (
191193 init .initializer (init .HeUniform (math .sqrt (5 ), mode = 'fan_in' , nonlinearity = 'leaky_relu' ),
192194 cell .weight .shape , cell .weight .dtype ))