Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions backends/arm/_passes/decompose_grouped_conv_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -75,15 +75,24 @@ def _split_per_channel_qparams(qarg, index, output_slice_size):
@staticmethod
def _get_meta_copy(meta, i, output_slice_size):
meta_copy = meta.copy()

if "input_qparams" in meta.data and len(meta.data["input_qparams"]) > 0:
# Handle per-channel quantization by splitting quantization params
# similarly to how activations/weights/biases are split.
new_qparams = meta.data.get("input_qparams").copy()
# Get quantization params of the weights and slice them.
qarg = new_qparams[1]
w_qarg = new_qparams[1]
new_qparams[1] = DecomposeGroupedConvPass._split_per_channel_qparams(
qarg, index=i, output_slice_size=output_slice_size
w_qarg, index=i, output_slice_size=output_slice_size
)
# Special case for int16 grouped conv2d with a bias: the bias is added
# afterwards in DecomposeConv2dWithInt16ActivationPass, so the bias
# quantization parameters must be split as well.
if new_qparams[0].dtype == torch.int16 and len(new_qparams) > 2:
b_qarg = new_qparams[2]
new_qparams[2] = DecomposeGroupedConvPass._split_per_channel_qparams(
b_qarg, index=i, output_slice_size=output_slice_size
)

meta_copy.data["input_qparams"] = new_qparams

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,10 @@ def is_node_supported( # noqa: C901
):
for input_node in node.all_input_nodes:
dtype = _try_determine_dtype(input_node)
if dtype is not None and dtype != torch.int8:
if dtype is not None and dtype not in (torch.int8, torch.int16):
self.reporter.report_reject(
input_node,
f"Input {input_node.name} has unsupported dtype {dtype} (Supports i8).",
f"Input {input_node.name} has unsupported dtype {dtype} (Supports i8, i16).",
)
return False

Expand Down
32 changes: 31 additions & 1 deletion backends/arm/test/ops/test_amax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -189,3 +189,33 @@ def test_max_dim_vgf_quant_to_amax(test_data: Max.input_t):
quantize=True,
)
pipeline.run()


@common.parametrize("test_data", Amax.test_data)
def test_amax_tosa_INT_a16w8(test_data: Amax.input_t):
    """Run amax through the TOSA INT pipeline with the int16 extension enabled.

    Each parametrized case is a callable returning ``(data, dim, keep_dims)``.
    """
    data, dim, keep_dims = test_data()
    module = Amax(dim, keep_dims)
    # Use Amax's own input type for the pipeline generic so it agrees with
    # the ``test_data`` annotation and the module under test.
    pipeline = TosaPipelineINT[Amax.input_t](
        module,
        data,
        "torch.ops.aten.amax",
        tosa_extensions=["int16"],
    )
    pipeline.run()


@common.parametrize("test_data", Amax.test_data)
@common.XfailIfNoCorstone320
def test_amax_u85_INT_a16w8(test_data: Amax.input_t):
    """Run amax through the Ethos-U85 INT pipeline with a16w8 quantization.

    Each parametrized case is a callable returning ``(data, dim, keep_dims)``.
    """
    data, dim, keep_dims = test_data()
    module = Amax(dim, keep_dims)
    # Use Amax's own input type for the pipeline generic so it agrees with
    # the ``test_data`` annotation and the module under test.
    pipeline = EthosU85PipelineINT[Amax.input_t](
        module,
        data,
        "torch.ops.aten.amax",
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.run()
31 changes: 30 additions & 1 deletion backends/arm/test/ops/test_amin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -203,3 +203,32 @@ def test_min_dim_vgf_quant_to_amin(test_data: Min.input_t):
quantize=True,
)
pipeline.run()


@common.parametrize("test_data", Amin.test_data)
def test_amin_tosa_INT_a16w8(test_data: Amin.input_t):
    """Run each Amin case through the TOSA INT pipeline with the int16 extension."""
    inputs, axis, keep = test_data()
    amin_module = Amin(axis, keep)
    TosaPipelineINT[Amin.input_t](
        amin_module,
        inputs,
        Amin.aten_op,
        tosa_extensions=["int16"],
    ).run()


@common.parametrize("test_data", Amin.test_data)
@common.XfailIfNoCorstone320
def test_amin_u85_INT_a16w8(test_data: Amin.input_t):
    """Run amin through the Ethos-U85 INT pipeline with a16w8 quantization.

    Each parametrized case is a callable returning ``(data, dim, keep_dims)``.
    Per-channel quantization is explicitly disabled for this test.
    """
    # The annotation uses Amin.input_t (not Min.input_t) to match the
    # parametrization source and the pipeline's generic argument.
    data, dim, keep_dims = test_data()
    pipeline = EthosU85PipelineINT[Amin.input_t](
        Amin(dim, keep_dims),
        data,
        Amin.aten_op,
        per_channel_quantization=False,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.run()
49 changes: 49 additions & 0 deletions backends/arm/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,52 @@ def test_convolution_2d_u55_INT_not_delegated(module: Conv2d):
quantize=True,
u55_subset=True,
).run()


@common.parametrize("test_data", test_data_INT)
def test_conv2d_tosa_INT_a16w8(test_data: input_t):
    """Run each conv2d case through the TOSA INT pipeline with the int16 extension.

    The case callable yields the model and its per-channel quantization flag.
    """
    conv_module, is_per_channel = test_data()
    example_inputs = conv_module.get_inputs()
    TosaPipelineINT[input_t](
        conv_module,
        example_inputs,
        aten_op,
        exir_op,
        tosa_extensions=["int16"],
        per_channel_quantization=is_per_channel,
    ).run()


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone300
def test_conv2d_u55_INT_a16w8(test_data: input_t):
    """Run each conv2d case through the Ethos-U55 INT pipeline with a16w8 quantization.

    The case callable yields the model and its per-channel quantization flag.
    """
    conv_module, is_per_channel = test_data()
    example_inputs = conv_module.get_inputs()
    EthosU55PipelineINT[input_t](
        conv_module,
        example_inputs,
        aten_op,
        exir_op,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=is_per_channel,
    ).run()


@common.parametrize("test_data", test_data_INT)
@common.XfailIfNoCorstone320
def test_conv2d_u85_INT_a16w8(test_data: input_t):
    """Run each conv2d case through the Ethos-U85 INT pipeline with a16w8 quantization.

    The case callable yields the model and its per-channel quantization flag.
    """
    conv_module, is_per_channel = test_data()
    example_inputs = conv_module.get_inputs()
    EthosU85PipelineINT[input_t](
        conv_module,
        example_inputs,
        aten_op,
        exir_op,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=is_per_channel,
    ).run()
48 changes: 48 additions & 0 deletions backends/arm/test/ops/test_depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,51 @@ def test_convolution_1d_u85_INT_a8w4_depthwise(test_data):
get_symmetric_a8w4_quantization_config(is_per_channel=per_channel_quantization)
)
pipeline.run()


@common.parametrize("test_data", test_data_conv2d_INT)
def test_convolution_2d_tosa_INT_a16w8_depthwise(test_data: input_t):
    """Run each depthwise conv2d case through the TOSA INT pipeline with int16 enabled.

    No ATen ops are listed for checking (``aten_op=[]``); only ``exir_op`` is passed.
    """
    depthwise_module, is_per_channel = test_data()
    example_inputs = depthwise_module.get_inputs()
    TosaPipelineINT[input_t](
        depthwise_module,
        example_inputs,
        aten_op=[],
        exir_op=exir_op,
        tosa_extensions=["int16"],
        per_channel_quantization=is_per_channel,
    ).run()


@common.parametrize("test_data", test_data_conv2d_INT)
@common.XfailIfNoCorstone300
def test_convolution_2d_u55_INT_a16w8_depthwise(test_data: input_t):
    """Run each depthwise conv2d case on the Ethos-U55 INT pipeline with a16w8 quantization.

    The function name now matches its target: it uses the Corstone-300 xfail
    guard and ``EthosU55PipelineINT``, so it is the U55 test.
    """
    model, per_channel_quantization = test_data()
    pipeline = EthosU55PipelineINT[input_t](
        model,
        model.get_inputs(),
        aten_ops=[],
        exir_ops=exir_op,
        per_channel_quantization=per_channel_quantization,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_conv2d_INT)
@common.XfailIfNoCorstone320
def test_convolution_2d_u85_INT_a16w8_depthwise(test_data: input_t):
    """Run each depthwise conv2d case on the Ethos-U85 INT pipeline with a16w8 quantization.

    The function name now matches its target: it uses the Corstone-320 xfail
    guard and ``EthosU85PipelineINT``, so it is the U85 test.
    """
    model, per_channel_quantization = test_data()
    pipeline = EthosU85PipelineINT[input_t](
        model,
        model.get_inputs(),
        aten_ops=[],
        exir_ops=exir_op,
        # Forward the per-case flag; previously it was unpacked but never
        # used, so every case ran with the pipeline's default setting.
        per_channel_quantization=per_channel_quantization,
        a16w8_quantization=True,
        # Passed explicitly for consistency with the other Ethos a16w8 tests.
        use_to_edge_transform_and_lower=True,
    )
    pipeline.run()
72 changes: 55 additions & 17 deletions backends/arm/test/ops/test_matmul.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
Expand Down Expand Up @@ -134,10 +136,7 @@ def test_matmul_u55_INT(test_data: input_t1):
pipeline.run()


@common.parametrize(
"test_data",
MatMulSingleInput.test_data_generators,
)
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.XfailIfNoCorstone300
def test_matmul_u55_INT_single_input(test_data: input_t1):
pipeline = EthosU55PipelineINT[input_t1](
Expand All @@ -150,10 +149,7 @@ def test_matmul_u55_INT_single_input(test_data: input_t1):
pipeline.run()


@common.parametrize(
"test_data",
MatMulCombo.test_data_generators,
)
@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.XfailIfNoCorstone300
def test_matmul_u55_INT_combo(test_data: input_t1):
pipeline = EthosU55PipelineINT[input_t1](
Expand All @@ -179,10 +175,7 @@ def test_matmul_u85_INT(test_data: input_t1):
pipeline.run()


@common.parametrize(
"test_data",
MatMulSingleInput.test_data_generators,
)
@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_u85_INT_single_input(test_data: input_t1):
pipeline = EthosU85PipelineINT[input_t1](
Expand All @@ -195,10 +188,7 @@ def test_matmul_u85_INT_single_input(test_data: input_t1):
pipeline.run()


@common.parametrize(
"test_data",
MatMulCombo.test_data_generators,
)
@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_u85_INT_combo(test_data: input_t1):
pipeline = EthosU85PipelineINT[input_t1](
Expand Down Expand Up @@ -287,3 +277,51 @@ def test_matmul_vgf_quant_combo(test_data: input_t1):
quantize=True,
)
pipeline.run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
def test_matmul_tosa_INT_a16w8(test_data: input_t1):
    """Run MatMulCombo through the TOSA INT pipeline with the int16 extension."""
    combo_module = MatMulCombo()
    example_inputs = test_data()
    TosaPipelineINT[Tuple[torch.Tensor]](
        combo_module,
        example_inputs,
        aten_op_mm,
        exir_op_mm,
        tosa_extensions=["int16"],
    ).run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Non-passthrough operation' for int16 matmul operations"
)
@common.XfailIfNoCorstone300
def test_matmul_u55_INT_a16w8(test_data: input_t1):
    """Run MatMulCombo on the Ethos-U55 INT pipeline with a16w8 quantization.

    Marked xfail; see the decorator's reason. Per-channel quantization is disabled.
    """
    combo_module = MatMulCombo()
    example_inputs = test_data()
    EthosU55PipelineINT[Tuple[torch.Tensor]](
        combo_module,
        example_inputs,
        aten_op_mm,
        exir_op_mm,
        per_channel_quantization=False,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_u85_INT_a16w8(test_data: input_t1):
    """Run MatMulCombo on the Ethos-U85 INT pipeline with a16w8 quantization.

    Per-channel quantization is disabled for this test.
    """
    combo_module = MatMulCombo()
    example_inputs = test_data()
    EthosU85PipelineINT[Tuple[torch.Tensor]](
        combo_module,
        example_inputs,
        aten_op_mm,
        exir_op_mm,
        per_channel_quantization=False,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    ).run()
44 changes: 43 additions & 1 deletion backends/arm/test/ops/test_maximum.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -96,3 +96,45 @@ def test_maximum_vgf_quant(test_data: Tuple):
quantize=True,
)
pipeline.run()


@common.parametrize("test_data", Maximum.test_parameters)
def test_maximum_tosa_INT_a16w8(test_data: test_t):
    """Run Maximum through the TOSA INT pipeline with the int16 extension."""
    maximum_module = Maximum()
    example_inputs = test_data()
    TosaPipelineINT[test_t](
        maximum_module,
        example_inputs,
        aten_op,
        tosa_extensions=["int16"],
    ).run()


@common.parametrize("test_data", Maximum.test_parameters)
@common.XfailIfNoCorstone300
def test_maximum_u55_INT_a16w8(test_data: test_t):
    """Run Maximum on the Ethos-U55 INT pipeline with a16w8 quantization.

    Per-channel quantization is disabled for this test.
    """
    maximum_module = Maximum()
    example_inputs = test_data()
    EthosU55PipelineINT[test_t](
        maximum_module,
        example_inputs,
        aten_op,
        per_channel_quantization=False,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", Maximum.test_parameters)
@common.XfailIfNoCorstone320
def test_maximum_u85_INT_a16w8(test_data: test_t):
    """Run Maximum on the Ethos-U85 INT pipeline with a16w8 quantization.

    Per-channel quantization is disabled for this test.
    """
    maximum_module = Maximum()
    example_inputs = test_data()
    EthosU85PipelineINT[test_t](
        maximum_module,
        example_inputs,
        aten_op,
        per_channel_quantization=False,
        a16w8_quantization=True,
        use_to_edge_transform_and_lower=True,
    ).run()
Loading
Loading