diff --git a/engine/deim/hybrid_encoder.py b/engine/deim/hybrid_encoder.py index 77a7472..388cef8 100644 --- a/engine/deim/hybrid_encoder.py +++ b/engine/deim/hybrid_encoder.py @@ -200,9 +200,11 @@ def __init__(self, c1, c2, c3, c4, n=3, self.c = c3//2 self.cv1 = ConvNormLayer_fuse(c1, c3, 1, 1, bias=bias, act=act) if csp_type == 'csp2': - CSPLayer = CSPLayer2 - self.cv2 = nn.Sequential(CSPLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act)) - self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act)) + CSPLayerType = CSPLayer2 + else: + CSPLayerType = CSPLayer + self.cv2 = nn.Sequential(CSPLayerType(c3//2, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act)) + self.cv3 = nn.Sequential(CSPLayerType(c4, c4, n, 1, bias=bias, act=act, bottletype=VGGBlock), ConvNormLayer_fuse(c4, c4, 3, 1, bias=bias, act=act)) self.cv4 = ConvNormLayer_fuse(c3+(2*c4), c2, 1, 1, bias=bias, act=act) def forward_chunk(self, x):