diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 4c4d824e5d5..fe000ba209d 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -29,6 +29,7 @@
 from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass  # noqa
 from .decompose_addmm_pass import DecomposeAddmmPass  # noqa
 from .decompose_any_pass import DecomposeAnyPass  # noqa
+from .decompose_as_strided_copy_pass import DecomposeAsStridedCopyPass  # noqa
 from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass  # noqa
 from .decompose_asinh_pass import DecomposeAsinhPass  # noqa
 from .decompose_atan_pass import DecomposeAtanPass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index c04a6329f4a..03a9308d898 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -36,6 +36,7 @@
     DecomposeAnyPass,
     DecomposeAsinAndAcosPass,
     DecomposeAsinhPass,
+    DecomposeAsStridedCopyPass,
     DecomposeAtanhPass,
     DecomposeAtanPass,
     DecomposeAvgPool2dPass,
@@ -321,6 +322,7 @@ def _tosa_pipeline(
             ConvertExpandCopyToRepeatPass(),
             UnsqueezeBeforeRepeatPass(),
             DecomposeCumsumPass(exported_program),
+            DecomposeAsStridedCopyPass(),
             DecomposeMaxPool2dPass(),
             SizeAdjustInputPass(),
             DecomposeSelectPass(),
diff --git a/backends/arm/_passes/decompose_as_strided_copy_pass.py b/backends/arm/_passes/decompose_as_strided_copy_pass.py
new file mode 100644
index 00000000000..e4555555be6
--- /dev/null
+++ b/backends/arm/_passes/decompose_as_strided_copy_pass.py
@@ -0,0 +1,113 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, Optional, Set, Tuple, Type
+
+import torch
+
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm.common.as_strided_utils import (
+    contiguous_strides,
+    maybe_static_sequence,
+    to_int,
+    to_int_tuple,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+
+class DecomposeAsStridedCopyPass(ArmPass):
+    """
+    Replace contiguous `aten.as_strided_copy` with `aten.view_copy`.
+
+    The TOSA backend only supports the contiguous-as-strided case where the stride
+    matches row-major layout and the storage offset is zero. In that scenario the
+    operator is equivalent to a reshape with copy semantics and can be lowered via `view_copy`.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    _EDGE_OPS = (exir_ops.edge.aten.as_strided_copy.default,)
+    _ATEN_OPS = (torch.ops.aten.as_strided_copy.default,)
+
+    def _extract_args(
+        self, args: Tuple[object, ...], kwargs: dict
+    ) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...], int]]:
+        """Return (size, stride, storage_offset) when they are statically known."""
+        if len(args) < 3:
+            return None
+
+        size_arg = args[1]
+        stride_arg = args[2]
+        offset_arg = (
+            kwargs.get("storage_offset") if "storage_offset" in kwargs else None
+        )
+        if offset_arg is None and len(args) > 3:
+            offset_arg = args[3]
+
+        size_seq = maybe_static_sequence(size_arg)
+        stride_seq = maybe_static_sequence(stride_arg)
+        if size_seq is None or stride_seq is None:
+            return None
+
+        size_tuple = to_int_tuple(size_seq)
+        stride_tuple = to_int_tuple(stride_seq)
+        if size_tuple is None or stride_tuple is None:
+            return None
+
+        if len(size_tuple) != len(stride_tuple):
+            return None
+
+        if any(stride < 0 for stride in stride_tuple):
+            return None
+
+        if offset_arg is None:
+            storage_offset = 0
+        else:
+            parsed_offset = to_int(offset_arg)
+            if parsed_offset is None:
+                return None
+            storage_offset = parsed_offset
+
+        return size_tuple, stride_tuple, storage_offset
+
+    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
+        if op not in (*self._EDGE_OPS, *self._ATEN_OPS):
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        extracted = self._extract_args(args, kwargs)
+        if extracted is None:
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        size_tuple, stride_tuple, storage_offset = extracted
+        if storage_offset != 0:
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        expected_strides = contiguous_strides(size_tuple)
+
+        def _stride_matches(idx: int, dim: int) -> bool:
+            stride = stride_tuple[idx]
+            expected = expected_strides[idx]
+            if idx == len(size_tuple) - 1:
+                return stride >= expected
+            if dim == 1 or expected == 0:
+                return True
+            return stride == expected
+
+        if any(not _stride_matches(i, dim) for i, dim in enumerate(size_tuple)):
+            return super().call_operator(op, args, kwargs, meta, updated)
+
+        view_args = (args[0], tuple(size_tuple))
+        view_kwargs: Dict[str, object] = {}
+
+        view_op = (
+            exir_ops.edge.aten.view_copy.default
+            if op in self._EDGE_OPS
+            else torch.ops.aten.view_copy.default
+        )
+
+        return super().call_operator(
+            view_op, view_args, view_kwargs, meta, updated=True
+        )
diff --git a/backends/arm/common/as_strided_utils.py b/backends/arm/common/as_strided_utils.py
new file mode 100644
index 00000000000..69112383ca7
--- /dev/null
+++ b/backends/arm/common/as_strided_utils.py
@@ -0,0 +1,70 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Utility helpers shared across as_strided_copy handling."""
+
+from __future__ import annotations
+
+import numbers
+
+from collections.abc import Sequence
+from typing import Optional, Tuple, TypeVar
+
+import torch
+import torch.fx as fx
+
+T = TypeVar("T", bound=Sequence)
+
+
+def to_int(value: object) -> Optional[int]:
+    """Return an int for supported numeric types, otherwise None."""
+    if isinstance(value, (numbers.Integral, torch.SymInt)):
+        return int(value)
+    return None
+
+
+def maybe_static_sequence(value: object) -> Optional[Sequence]:
+    """
+    Return a Python sequence for literal or FX-constant values.
+
+    FX exporters often wrap constant lists in nodes where the materialised
+    value is stored in ``node.meta["val"]``. This helper unwraps that so the
+    rest of the logic can treat them uniformly.
+    """
+    if isinstance(value, (str, bytes)):
+        return None
+    if isinstance(value, fx.Node):
+        const_val = value.meta.get("val")
+        if isinstance(const_val, Sequence):
+            return const_val
+        return None
+    if isinstance(value, Sequence):
+        return value
+    return None
+
+
+def to_int_tuple(value: object) -> Optional[Tuple[int, ...]]:
+    """Best-effort conversion of a sequence of integers/SymInts to a tuple[int, ...]."""
+    seq = maybe_static_sequence(value)
+    if seq is None:
+        return None
+
+    result: list[int] = []
+    for item in seq:
+        converted = to_int(item)
+        if converted is None:
+            return None
+        result.append(converted)
+    return tuple(result)
+
+
+def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
+    """Compute row-major contiguous strides for the provided shape."""
+    strides = [0] * len(shape)
+    running = 1
+    for idx in reversed(range(len(shape))):
+        dim_val = shape[idx]
+        strides[idx] = running if dim_val != 0 else 1
+        running *= max(dim_val, 1)
+    return tuple(strides)
diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index 01d936be7ce..a72ba3f0530 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -1,10 +1,11 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 
 from . import (  # noqa
+    as_strided_copy_support,
     clone_dim_order_support,
     control_flow_support,
     convolution_support,
diff --git a/backends/arm/operator_support/as_strided_copy_support.py b/backends/arm/operator_support/as_strided_copy_support.py
new file mode 100644
index 00000000000..f71fe36b9b2
--- /dev/null
+++ b/backends/arm/operator_support/as_strided_copy_support.py
@@ -0,0 +1,117 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""Declare operator support for aten.as_strided_copy in the TOSA backend."""
+
+from collections.abc import Mapping
+from typing import Any, Optional
+
+import torch.fx as fx
+from executorch.backends.arm.common.as_strided_utils import (
+    contiguous_strides,
+    maybe_static_sequence,
+    to_int,
+    to_int_tuple,
+)
+
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa.specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def _arg_from_node(node: fx.Node, position: int, keyword: str) -> object | None:
+    """Fetch an argument either by keyword or positional index."""
+    kwargs: Mapping[str, Any] = node.kwargs
+    if keyword in kwargs:
+        return kwargs[keyword]
+    if len(node.args) > position:
+        return node.args[position]
+    return None
+
+
+def _extract_static_args(
+    node: fx.Node,
+) -> Optional[tuple[tuple[int, ...], tuple[int, ...], int]]:
+    """Return static size/stride/offset if they are compatible."""
+    size_arg = _arg_from_node(node, 1, "size")
+    stride_arg = _arg_from_node(node, 2, "stride")
+    offset_arg = _arg_from_node(node, 3, "storage_offset")
+
+    if (
+        maybe_static_sequence(size_arg) is None
+        or maybe_static_sequence(stride_arg) is None
+    ):
+        return None
+
+    size_tuple = to_int_tuple(size_arg)
+    stride_tuple = to_int_tuple(stride_arg)
+    if size_tuple is None or stride_tuple is None:
+        return None
+
+    if len(size_tuple) != len(stride_tuple):
+        return None
+
+    if any(stride < 0 for stride in stride_tuple):
+        return None
+
+    storage_offset = 0
+    if offset_arg is not None:
+        parsed_offset = to_int(offset_arg)
+        if parsed_offset is None:
+            return None
+        storage_offset = parsed_offset
+
+    return size_tuple, stride_tuple, storage_offset
+
+
+@register_tosa_support_check
+class AsStridedCopySupported(SupportedTOSAOperatorCheck):
+    """Support check ensuring as_strided_copy is contiguous with zero offset."""
+
+    targets = [exir_ops.edge.aten.as_strided_copy.default]
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification  # noqa: D417
+    ) -> bool:
+        extracted = _extract_static_args(node)
+        if extracted is None:
+            self.reporter.report_reject(
+                node, "Size/stride must be static with non-negative strides."
+            )
+            return False
+
+        size_tuple, stride_tuple, storage_offset = extracted
+
+        if storage_offset != 0:
+            self.reporter.report_reject(
+                node, "Non-zero storage offsets are unsupported."
+            )
+            return False
+
+        expected_strides = contiguous_strides(size_tuple)
+
+        def _stride_matches(idx: int, dim: int) -> bool:
+            stride = stride_tuple[idx]
+            expected = expected_strides[idx]
+            if idx == len(size_tuple) - 1:
+                return stride >= expected
+            if dim == 1:
+                return True
+            return stride == expected
+
+        if any(not _stride_matches(i, dim) for i, dim in enumerate(size_tuple)):
+            self.reporter.report_reject(
+                node,
+                f"Stride {stride_tuple} is not contiguous for shape {size_tuple}.",
+            )
+            return False
+
+        return True
diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index 66e799aadc4..4a2c0709c77 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -429,6 +429,7 @@ def _match_pattern(
     torch.ops.aten.unflatten.int,
     torch.ops.aten.index_select.default,
     torch.ops.aten.index.Tensor,
+    torch.ops.aten.as_strided_copy.default,
     # Neg operator flips the range, but keeps the magnitude the same.
     # That is why we force it to use the same qparams and avoid
     # dequant -> neg -> requant chain.
diff --git a/backends/arm/test/models/test_mobilenet_v2_arm.py b/backends/arm/test/models/test_mobilenet_v2_arm.py
index a0a758a760a..7bd33ade240 100644
--- a/backends/arm/test/models/test_mobilenet_v2_arm.py
+++ b/backends/arm/test/models/test_mobilenet_v2_arm.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
+# Copyright 2024-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -64,11 +64,6 @@ def test_mv2_tosa_FP_channels_last():
         exir_op=[],
         use_to_edge_transform_and_lower=True,
     )
-    # Changing memory format leads to an unsupported as_strided_copy op being inserted into the graph,
-    # leading to a graph break.
-    pipeline.change_args(
-        "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
-    )
     pipeline.run()
 
 
diff --git a/backends/arm/test/ops/test_as_strided_copy.py b/backends/arm/test/ops/test_as_strided_copy.py
new file mode 100644
index 00000000000..9ed0b52da49
--- /dev/null
+++ b/backends/arm/test/ops/test_as_strided_copy.py
@@ -0,0 +1,137 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.common.as_strided_utils import contiguous_strides
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    OpNotSupportedPipeline,
+    TosaPipelineFP,
+    TosaPipelineINT,
+    VgfPipeline,
+)
+
+
+aten_op = "torch.ops.aten.as_strided_copy.default"
+input_t = Tuple[torch.Tensor]
+
+
+class AsStridedCopyModule(torch.nn.Module):
+    def __init__(
+        self,
+        size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        storage_offset: int = 0,
+    ):
+        super().__init__()
+        self.size = size
+        self.stride = stride
+        self.storage_offset = storage_offset
+
+    def forward(self, x: torch.Tensor):
+        y = torch.ops.aten.as_strided_copy.default(
+            x, self.size, self.stride, self.storage_offset
+        )
+        return y
+
+
+def _make_case(
+    tensor_shape: Tuple[int, ...],
+    target_shape: Tuple[int, ...],
+) -> Tuple[torch.Tensor, Tuple[int, ...], Tuple[int, ...]]:
+    tensor = torch.rand(tensor_shape)
+    stride = contiguous_strides(target_shape)
+    return tensor, target_shape, stride
+
+
+delegated_cases = {
+    "reshape_2d": lambda: _make_case((4, 6), (3, 8)),
+    "flatten": lambda: _make_case((2, 3, 4), (6, 4)),
+    "expand_rank": lambda: _make_case((2, 3, 4), (2, 3, 4)),
+}
+
+unsupported_cases = {
+    "non_contiguous_stride": lambda: (
+        torch.rand(3, 3),
+        (3, 3),
+        (1, 3),  # Not a contiguous stride layout for (3, 3)
+    ),
+    "non_zero_offset": lambda: (
+        torch.rand(4, 4),
+        (4, 4),
+        contiguous_strides((4, 4)),
+        4,
+    ),
+}
+
+
+@common.parametrize("test_data", delegated_cases)
+def test_as_strided_copy_tosa_FP(test_data):
+    tensor, size, stride = test_data()
+    module = AsStridedCopyModule(size, stride)
+    pipeline = TosaPipelineFP[input_t](
+        module,
+        (tensor,),
+        aten_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", delegated_cases)
+def test_as_strided_copy_tosa_INT(test_data):
+    tensor, size, stride = test_data()
+    module = AsStridedCopyModule(size, stride)
+    pipeline = TosaPipelineINT[input_t](
+        module,
+        (tensor,),
+        aten_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", delegated_cases)
+@common.SkipIfNoModelConverter
+def test_as_strided_copy_vgf_no_quant(test_data):
+    tensor, size, stride = test_data()
+    module = AsStridedCopyModule(size, stride)
+    pipeline = VgfPipeline[input_t](
+        module,
+        (tensor,),
+        aten_op,
+        tosa_version="TOSA-1.0+FP",
+        quantize=False,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", delegated_cases)
+@common.SkipIfNoModelConverter
+def test_as_strided_copy_vgf_quant(test_data):
+    tensor, size, stride = test_data()
+    module = AsStridedCopyModule(size, stride)
+    pipeline = VgfPipeline[input_t](
+        module,
+        (tensor,),
+        aten_op,
+        tosa_version="TOSA-1.0+INT",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", unsupported_cases)
+def test_as_strided_copy_not_delegated(test_data):
+    tensor, size, stride, *rest = test_data()
+    storage_offset = rest[0] if rest else 0
+    module = AsStridedCopyModule(size, stride, storage_offset=storage_offset)
+    pipeline = OpNotSupportedPipeline[input_t](
+        module,
+        (tensor,),
+        {"executorch_exir_dialects_edge__ops_aten_as_strided_copy_default": 1},
+        n_expected_delegates=0,
+    )
+    pipeline.run()
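
As a quick illustration of the invariant that both new files check, here is a minimal standalone sketch (not part of the diff above; `contiguous_strides` is re-inlined rather than imported from the new `backends/arm/common/as_strided_utils.py`): with a zero storage offset and row-major strides, `as_strided_copy` is value-equivalent to `view_copy`, which is exactly the rewrite `DecomposeAsStridedCopyPass` performs.

# Standalone sanity check of the contiguous as_strided == reshape equivalence.
import torch


def contiguous_strides(shape):
    # Row-major strides: the last dim has stride 1 and every earlier dim
    # strides over the product of the later dims (zero-sized dims get
    # stride 1, mirroring the helper in as_strided_utils.py).
    strides, running = [0] * len(shape), 1
    for idx in reversed(range(len(shape))):
        strides[idx] = running if shape[idx] != 0 else 1
        running *= max(shape[idx], 1)
    return tuple(strides)


x = torch.arange(24.0).reshape(4, 6)
target = (3, 8)
strided = torch.ops.aten.as_strided_copy.default(
    x, target, contiguous_strides(target), 0
)
viewed = torch.ops.aten.view_copy.default(x, target)
assert torch.equal(strided, viewed)  # the rewrite is value-preserving

Swapping in a non-row-major stride such as (1, 3) for a (3, 3) target makes the assertion fail; that is the same layout as the `non_contiguous_stride` case that `AsStridedCopySupported` rejects.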