From adc4fcb640fc23959dfb1bc4123f003f56f790ef Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Thu, 15 Jan 2026 14:05:58 +0000 Subject: [PATCH] Arm backend: Update a8w4 _get_dtype_count to account for groups - When using groups, we need to update the count layers. Signed-off-by: Saoirse Stewart --- backends/arm/test/ops/test_conv2d.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 55eee293f95..d4953dbb5bc 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -50,6 +50,7 @@ def __init__( ): super().__init__() self.nbr_convs = nbr_conv + self.groups = groups # Handle default values in_channels = [2] * nbr_conv if in_channels is None else in_channels @@ -396,7 +397,8 @@ def forward(self, x): def _get_dtype_count(model: torch.nn.Module): - nbr_convs: int = model.nbr_convs # noqa + # Set nbr_conv to be the amount of groups set if necessary. + nbr_convs: int = model.nbr_convs if model.groups is None else model.groups # noqa return { "CONST": {"INT4": nbr_convs * 2}, # One for the weight, one for the zp. "CONV2D": {"INT32": nbr_convs}, @@ -430,16 +432,7 @@ def test_convolution_2d_tosa_INT(test_data): pipeline.run() -@common.parametrize( - "test_data", - test_data_INT, - xfails={ - "groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", - "groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", - "groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726", - "groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726", - }, -) +@common.parametrize("test_data", test_data_INT) def test_convolution_2d_tosa_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = TosaPipelineINT[input_t]( @@ -475,6 +468,7 @@ def test_convolution_2d_u55_INT(test_data): @common.parametrize("test_data", test_data_INT) +@common.XfailIfNoCorstone300 def test_convolution_2d_u55_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU55PipelineINT[input_t]( @@ -504,6 +498,7 @@ def test_convolution_u85_INT(test_data): @common.parametrize("test_data", test_data_INT) +@common.XfailIfNoCorstone320 def test_convolution_2d_u85_INT_a8w4(test_data): model, per_channel_quantization = test_data() pipeline = EthosU85PipelineINT[input_t](