Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 6 additions & 11 deletions backends/arm/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
):
super().__init__()
self.nbr_convs = nbr_conv
self.groups = groups

# Handle default values
in_channels = [2] * nbr_conv if in_channels is None else in_channels
Expand Down Expand Up @@ -396,7 +397,8 @@ def forward(self, x):


def _get_dtype_count(model: torch.nn.Module):
nbr_convs: int = model.nbr_convs # noqa
# Set nbr_conv to be the amount of groups set if necessary.
nbr_convs: int = model.nbr_convs if model.groups is None else model.groups # noqa
return {
"CONST": {"INT4": nbr_convs * 2}, # One for the weight, one for the zp.
"CONV2D": {"INT32": nbr_convs},
Expand Down Expand Up @@ -430,16 +432,7 @@ def test_convolution_2d_tosa_INT(test_data):
pipeline.run()


@common.parametrize(
"test_data",
test_data_INT,
xfails={
"groups,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
"groups,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
"groups_bias,per_channel_quant=True": "Int4 not supported for grouped convolutions. MLETORCH-1726",
"groups_bias,per_channel_quant=False": "Int4 not supported for grouped convolutions. MLETORCH-1726",
},
)
@common.parametrize("test_data", test_data_INT)
def test_convolution_2d_tosa_INT_a8w4(test_data):
model, per_channel_quantization = test_data()
pipeline = TosaPipelineINT[input_t](
Expand Down Expand Up @@ -475,6 +468,7 @@ def test_convolution_2d_u55_INT(test_data):


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone300
def test_convolution_2d_u55_INT_a8w4(test_data):
model, per_channel_quantization = test_data()
pipeline = EthosU55PipelineINT[input_t](
Expand Down Expand Up @@ -504,6 +498,7 @@ def test_convolution_u85_INT(test_data):


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone320
def test_convolution_2d_u85_INT_a8w4(test_data):
model, per_channel_quantization = test_data()
pipeline = EthosU85PipelineINT[input_t](
Expand Down
Loading