
ARM Ethos-U crashes in partitioning over conv->relu->permute->reshape(5D) #16739

@itaiberman

Description


🐛 Describe the bug

Observed Behavior

The run crashes during to_edge_transform_and_lower when exporting to the Ethos-U backend. The crash appears to be triggered by a Conv -> ReLU -> Permute -> Reshape sequence, but only when the reshape target is 5D.
The same model lowers successfully with XNNPACK, so the problem is probably in the Ethos-U partitioner.

Reproduce Steps

import torch
import torch.nn as nn
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
from executorch.backends.arm.quantizer import EthosUQuantizer
from executorch.backends.arm.quantizer.arm_quantizer import \
    get_symmetric_quantization_config as get_arm_symmetric_qconfig
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e


class ReshapePermuteModel(torch.nn.Module):
    def __init__(self):
        super(ReshapePermuteModel, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        x1 = self.block(x)
        out = x1.permute(0, 2, 3, 1).reshape(1, 1, 49, 64, 256).permute(0, 1, 3, 2, 4)
        return out


device = 'cpu'  # 'mps'
float_model = ReshapePermuteModel().eval().to(device)
compile_spec = EthosUCompileSpec(target="ethos-u85-2048")
quantizer = EthosUQuantizer(compile_spec)
partitioner = EthosUPartitioner(compile_spec)
operator_config = get_arm_symmetric_qconfig(is_per_channel=True)
quantizer.set_global(operator_config)
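# Export the float model, then calibrate and convert it with the PT2E quantization flow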
exported_program = torch.export.export(float_model, (torch.randn(1, 3, 224, 224).to(device), ))
graph_module = exported_program.module(check_guards=False)
prepared = prepare_pt2e(graph_module, quantizer)
with torch.no_grad():
    prepared(torch.randn(1, 3, 224, 224).to(device))
prepared = prepared.to("cpu")
quantized_graph_module = convert_pt2e(prepared, fold_quantize=True)
quantized_exported_program = torch.export.export(quantized_graph_module, (torch.randn(1, 3, 224, 224).to(device), ))
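# to_edge_transform_and_lower is where the SIGSEGV occurs (see Error Message below)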
edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(
        _check_ir_validity=False,
    ),
)
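
For reference, the XNNPACK comparison mentioned above was a lowering call of roughly this form (a sketch only: the XnnpackPartitioner import path is the standard ExecuTorch one, and reusing the Ethos-U-quantized program here is an assumption about how the comparison was run):

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Same lowering call as above, but with the XNNPACK partitioner; per the report above,
# this path does not crash.
xnnpack_edge_program_manager = to_edge_transform_and_lower(
    quantized_exported_program,
    partitioner=[XnnpackPartitioner()],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)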

Error Message

The run crashes with:
Process finished with exit code 139 (interrupted by signal 11:SIGSEGV)

Workaround

Changing the 5D reshape to a 4D reshape followed by unsqueeze(1) works. This only covers this particular case because the extra dimension in the original 5D shape is 1, so the reshape effectively just unsqueezes; if that dimension were larger than 1, this rewrite would not be possible (see the hypothetical example after the snippet below).

def forward(self, x):
    x1 = self.block(x)
    out = x1.permute(0, 2, 3, 1).reshape(1, 49, 64, 256).permute(0, 2, 1, 3)
    out = out.unsqueeze(1)
    return out
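
For the hypothetical case where the extra dimension is larger than 1 (not part of the original model), there is no equivalent 4D reshape + unsqueeze(1) rewrite, e.g.:

def forward(self, x):
    x1 = self.block(x)
    # Hypothetical 5D target shape with a non-unit second dimension (2 * 49 * 64 * 128
    # still matches the 112 * 112 * 64 elements produced by the conv block); this cannot
    # be rewritten as a 4D reshape followed by unsqueeze(1), so the workaround above
    # does not help here.
    out = x1.permute(0, 2, 3, 1).reshape(1, 2, 49, 64, 128).permute(0, 1, 3, 2, 4)
    return out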

Versions

PyTorch version: 2.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 4.2.0
Libc version: N/A

Python version: 3.12.10 (v3.12.10:0cc81280367, Apr 8 2025, 08:46:59) [Clang 13.0.0 (clang-1300.0.29.30)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Max

Versions of relevant libraries:
[pip3] executorch==1.0.1
[pip3] numpy==2.4.0
[pip3] onnx==1.20.0
[pip3] onnx-ir==0.1.14
[pip3] onnxscript==0.5.7
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.9.1
[pip3] torchao==0.14.0
[pip3] torchaudio==2.9.1
[pip3] torchsampler==0.1.2
[pip3] torchvision==0.24.0
[conda] Could not collect

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Labels

module: arm (Issues related to arm backend), partner: arm (For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm)
