Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -29,6 +29,7 @@
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
from .decompose_any_pass import DecomposeAnyPass # noqa
from .decompose_as_strided_copy_pass import DecomposeAsStridedCopyPass # noqa
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
from .decompose_atan_pass import DecomposeAtanPass # noqa
Expand Down
4 changes: 3 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -36,6 +36,7 @@
DecomposeAnyPass,
DecomposeAsinAndAcosPass,
DecomposeAsinhPass,
DecomposeAsStridedCopyPass,
DecomposeAtanhPass,
DecomposeAtanPass,
DecomposeAvgPool2dPass,
Expand Down Expand Up @@ -321,6 +322,7 @@ def _tosa_pipeline(
ConvertExpandCopyToRepeatPass(),
UnsqueezeBeforeRepeatPass(),
DecomposeCumsumPass(exported_program),
DecomposeAsStridedCopyPass(),
DecomposeMaxPool2dPass(),
SizeAdjustInputPass(),
DecomposeSelectPass(),
Expand Down
113 changes: 113 additions & 0 deletions backends/arm/_passes/decompose_as_strided_copy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Set, Tuple, Type

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm.common.as_strided_utils import (
contiguous_strides,
maybe_static_sequence,
to_int,
to_int_tuple,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeAsStridedCopyPass(ArmPass):
    """Replace contiguous ``aten.as_strided_copy`` with ``aten.view_copy``.

    The TOSA backend only supports the contiguous-as-strided case where the
    stride matches row-major layout and the storage offset is zero. In that
    scenario the operator is equivalent to a reshape with copy semantics and
    can be lowered via ``view_copy``. Every other call is forwarded unchanged.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    _EDGE_OPS = (exir_ops.edge.aten.as_strided_copy.default,)
    _ATEN_OPS = (torch.ops.aten.as_strided_copy.default,)

    @staticmethod
    def _lookup_arg(
        args: Tuple[object, ...], kwargs: dict, position: int, keyword: str
    ) -> Optional[object]:
        """Return an argument given by keyword or position, else ``None``."""
        if keyword in kwargs:
            return kwargs[keyword]
        if len(args) > position:
            return args[position]
        return None

    @staticmethod
    def _is_row_major(size: Tuple[int, ...], stride: Tuple[int, ...]) -> bool:
        """Return True when ``stride`` addresses ``size`` like a contiguous tensor.

        A dimension of extent 1 is never stepped over, so its stride is
        irrelevant. Every other dimension, including the innermost one, must
        use exactly the row-major stride; a larger stride would skip elements
        and could not be expressed as a plain reshape.
        """
        expected_strides = contiguous_strides(size)
        return all(
            dim == 1 or actual == expected
            for dim, actual, expected in zip(size, stride, expected_strides)
        )

    def _extract_args(
        self, args: Tuple[object, ...], kwargs: dict
    ) -> Optional[Tuple[Tuple[int, ...], Tuple[int, ...], int]]:
        """Return (size, stride, storage_offset) when they are statically known.

        Size, stride and storage offset are accepted either positionally or by
        keyword. ``None`` is returned when size or stride is missing, is not a
        sequence of integers, when their lengths differ, when any stride is
        negative, or when the storage offset is not an integer. A missing
        storage offset is reported as 0.
        """
        size_arg = self._lookup_arg(args, kwargs, 1, "size")
        stride_arg = self._lookup_arg(args, kwargs, 2, "stride")
        offset_arg = self._lookup_arg(args, kwargs, 3, "storage_offset")

        # Reject non-sequences up front so no element conversion is attempted
        # unless both arguments look like sequences.
        if (
            maybe_static_sequence(size_arg) is None
            or maybe_static_sequence(stride_arg) is None
        ):
            return None

        size_tuple = to_int_tuple(size_arg)
        stride_tuple = to_int_tuple(stride_arg)
        if size_tuple is None or stride_tuple is None:
            return None

        if len(size_tuple) != len(stride_tuple):
            return None

        if any(stride < 0 for stride in stride_tuple):
            return None

        if offset_arg is None:
            return size_tuple, stride_tuple, 0

        storage_offset = to_int(offset_arg)
        if storage_offset is None:
            return None
        return size_tuple, stride_tuple, storage_offset

    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
        """Rewrite supported ``as_strided_copy`` calls into ``view_copy``.

        Calls to other operators, and ``as_strided_copy`` calls whose arguments
        are not static, have a non-zero storage offset, or are not row-major,
        are passed to the parent implementation unchanged.
        """
        if op not in (*self._EDGE_OPS, *self._ATEN_OPS):
            return super().call_operator(op, args, kwargs, meta, updated)

        extracted = self._extract_args(args, kwargs)
        # The input tensor is taken from the first positional argument below.
        if extracted is None or not args:
            return super().call_operator(op, args, kwargs, meta, updated)

        size_tuple, stride_tuple, storage_offset = extracted
        if storage_offset != 0 or not self._is_row_major(size_tuple, stride_tuple):
            return super().call_operator(op, args, kwargs, meta, updated)

        # NOTE(review): the rewrite assumes the input holds exactly as many
        # elements as ``size`` describes - TODO confirm against callers.
        view_args = (args[0], size_tuple)
        view_kwargs: Dict[str, object] = {}

        # Stay in the dialect (edge vs. aten) of the operator being replaced.
        view_op = (
            exir_ops.edge.aten.view_copy.default
            if op in self._EDGE_OPS
            else torch.ops.aten.view_copy.default
        )

        return super().call_operator(
            view_op, view_args, view_kwargs, meta, updated=True
        )
70 changes: 70 additions & 0 deletions backends/arm/common/as_strided_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Utility helpers shared across as_strided_copy handling."""

from __future__ import annotations

import numbers

from collections.abc import Sequence
from typing import Optional, Tuple, TypeVar

import torch
import torch.fx as fx

T = TypeVar("T", bound=Sequence)


def to_int(value: object) -> Optional[int]:
    """Convert integral values and ``torch.SymInt`` to ``int``; give ``None`` otherwise."""
    integer_like = isinstance(value, numbers.Integral) or isinstance(
        value, torch.SymInt
    )
    return int(value) if integer_like else None


def maybe_static_sequence(value: object) -> Optional[Sequence]:
    """
    Resolve ``value`` to a Python sequence, or ``None`` when it is not one.

    Plain ``str``/``bytes`` inputs are never treated as sequences. For an FX
    node the candidate is whatever is stored under ``node.meta["val"]``; for
    anything else the candidate is the value itself. The candidate is returned
    only when it is a ``Sequence``.
    """
    if isinstance(value, (str, bytes)):
        return None

    # FX nodes carry their materialised value in the metadata dictionary.
    candidate = value.meta.get("val") if isinstance(value, fx.Node) else value
    return candidate if isinstance(candidate, Sequence) else None


def to_int_tuple(value: object) -> Optional[Tuple[int, ...]]:
    """Turn a sequence of integers/SymInts into ``tuple[int, ...]``, or ``None`` on failure."""
    items = maybe_static_sequence(value)
    if items is None:
        return None

    numbers_out: list[int] = []
    # ``map`` is lazy, so conversion stops at the first unsupported element.
    for number in map(to_int, items):
        if number is None:
            return None
        numbers_out.append(number)
    return tuple(numbers_out)


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Compute row-major contiguous strides for the provided shape.

    The innermost dimension has stride 1 and each outer stride is the product
    of the inner extents. Extents of 0 are counted as 1 in that product and
    receive the same running stride as any other dimension, which matches the
    strides PyTorch reports for contiguous tensors, e.g. ``(2, 0, 3)`` gives
    ``(3, 3, 1)``.

    Args:
        shape: Extent of every dimension, outermost first.

    Returns:
        One stride per dimension; an empty tuple for an empty shape.
    """
    strides = [1] * len(shape)
    running = 1
    for idx in reversed(range(len(shape))):
        strides[idx] = running
        # Clamp so a zero-sized dimension does not collapse outer strides to 0.
        running *= max(shape[idx], 1)
    return tuple(strides)
3 changes: 2 additions & 1 deletion backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from . import ( # noqa
as_strided_copy_support,
clone_dim_order_support,
control_flow_support,
convolution_support,
Expand Down
117 changes: 117 additions & 0 deletions backends/arm/operator_support/as_strided_copy_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Declare operator support for aten.as_strided_copy in the TOSA backend."""

from collections.abc import Mapping
from typing import Any, Optional

import torch.fx as fx
from executorch.backends.arm.common.as_strided_utils import (
contiguous_strides,
maybe_static_sequence,
to_int,
to_int_tuple,
)

from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


def _arg_from_node(node: fx.Node, position: int, keyword: str) -> object | None:
"""Fetch an argument either by keyword or positional index."""
kwargs: Mapping[str, Any] = node.kwargs
if keyword in kwargs:
return kwargs[keyword]
if len(node.args) > position:
return node.args[position]
return None


def _extract_static_args(
    node: fx.Node,
) -> Optional[tuple[tuple[int, ...], tuple[int, ...], int]]:
    """Collect (size, stride, storage_offset) from ``node`` when all are static.

    ``None`` is returned when size or stride is not a sequence of integers,
    when their lengths differ, when a stride is negative, or when the storage
    offset cannot be converted to an integer. A missing offset counts as 0.
    """
    raw_size = _arg_from_node(node, 1, "size")
    raw_stride = _arg_from_node(node, 2, "stride")
    raw_offset = _arg_from_node(node, 3, "storage_offset")

    # Both arguments must resolve to sequences before any element is converted.
    if any(maybe_static_sequence(raw) is None for raw in (raw_size, raw_stride)):
        return None

    sizes = to_int_tuple(raw_size)
    strides = to_int_tuple(raw_stride)
    if sizes is None or strides is None:
        return None

    if len(sizes) != len(strides) or min(strides, default=0) < 0:
        return None

    offset = 0 if raw_offset is None else to_int(raw_offset)
    if offset is None:
        return None

    return sizes, strides, offset


@register_tosa_support_check
class AsStridedCopySupported(SupportedTOSAOperatorCheck):
    """Support check ensuring as_strided_copy is contiguous with zero offset."""

    targets = [exir_ops.edge.aten.as_strided_copy.default]
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification  # noqa: D417
    ) -> bool:
        """Return True when ``node`` is a contiguous, zero-offset as_strided_copy.

        The node is rejected, with the reason reported through
        ``self.reporter``, when its size/stride are not static non-negative
        integers, when the storage offset is non-zero, or when the strides are
        not the row-major strides of the requested shape.
        """
        extracted = _extract_static_args(node)
        if extracted is None:
            self.reporter.report_reject(
                node, "Size/stride must be static with non-negative strides."
            )
            return False

        size_tuple, stride_tuple, storage_offset = extracted

        if storage_offset != 0:
            self.reporter.report_reject(
                node, "Non-zero storage offsets are unsupported."
            )
            return False

        expected_strides = contiguous_strides(size_tuple)

        # A dimension of extent 1 is never stepped over, so its stride is
        # irrelevant. Every other dimension, including the innermost one, must
        # use exactly the row-major stride: a larger stride skips elements and
        # is not a plain reshape of the input.
        is_row_major = all(
            dim == 1 or stride == expected
            for dim, stride, expected in zip(
                size_tuple, stride_tuple, expected_strides
            )
        )
        if not is_row_major:
            self.reporter.report_reject(
                node,
                f"Stride {stride_tuple} is not contiguous for shape {size_tuple}.",
            )
            return False

        return True
3 changes: 2 additions & 1 deletion backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -429,6 +429,7 @@ def _match_pattern(
torch.ops.aten.unflatten.int,
torch.ops.aten.index_select.default,
torch.ops.aten.index.Tensor,
torch.ops.aten.as_strided_copy.default,
# Neg operator flips the range, but keeps the magnitude the same.
# That is why we force it to use the same qparams and avoid
# dequant -> neg -> requant chain.
Expand Down
7 changes: 1 addition & 6 deletions backends/arm/test/models/test_mobilenet_v2_arm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# Copyright 2024-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -64,11 +64,6 @@ def test_mv2_tosa_FP_channels_last():
exir_op=[],
use_to_edge_transform_and_lower=True,
)
# Changing memory format leads to an unsupported as_strided_copy op being inserted into the graph,
# leading to a graph break.
pipeline.change_args(
"check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 2}
)
pipeline.run()


Expand Down
Loading
Loading