From b42c3696f915acb96a12f0db00f9e4ef05486cfb Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 19 Jan 2026 12:51:20 +0100 Subject: [PATCH 1/4] Arm backend: Clean up TosaSpecification use in TosaArg There is no reason for TosaArg to have a tosa_spec attribute, it is available in self.tosa_spec for node visitors. Remove unused tosa_spec argument in map_dtype to unify __validate() approach. Signed-off-by: Erik Lundell Change-Id: I0c13964af42bd579709d8115f9777d484f118566 --- backends/arm/operators/op_abs.py | 4 +- backends/arm/operators/op_add.py | 4 +- backends/arm/operators/op_amax.py | 4 +- backends/arm/operators/op_amin.py | 4 +- backends/arm/operators/op_any.py | 4 +- backends/arm/operators/op_avg_pool2d.py | 4 +- backends/arm/operators/op_bitwise_not.py | 4 +- backends/arm/operators/op_cat.py | 6 +-- backends/arm/operators/op_ceil.py | 4 +- backends/arm/operators/op_clamp.py | 4 +- backends/arm/operators/op_constant_pad_nd.py | 4 +- backends/arm/operators/op_cos.py | 4 +- backends/arm/operators/op_eq.py | 6 +-- backends/arm/operators/op_erf.py | 4 +- backends/arm/operators/op_exp.py | 4 +- backends/arm/operators/op_floor.py | 4 +- backends/arm/operators/op_ge.py | 6 +-- backends/arm/operators/op_gt.py | 6 +-- backends/arm/operators/op_index_select.py | 4 +- backends/arm/operators/op_index_tensor.py | 8 ++-- backends/arm/operators/op_le.py | 6 +-- backends/arm/operators/op_log.py | 4 +- backends/arm/operators/op_logical_not.py | 4 +- backends/arm/operators/op_lt.py | 6 +-- backends/arm/operators/op_max_pool2d.py | 4 +- backends/arm/operators/op_maximum.py | 4 +- backends/arm/operators/op_minimum.py | 4 +- backends/arm/operators/op_mul.py | 4 +- backends/arm/operators/op_neg.py | 4 +- backends/arm/operators/op_permute.py | 4 +- backends/arm/operators/op_pow.py | 4 +- backends/arm/operators/op_reciprocal.py | 4 +- backends/arm/operators/op_repeat.py | 4 +- backends/arm/operators/op_rshift_tensor.py | 4 +- backends/arm/operators/op_rsqrt.py | 4 +- backends/arm/operators/op_sigmoid.py | 4 +- backends/arm/operators/op_sin.py | 4 +- backends/arm/operators/op_slice.py | 4 +- backends/arm/operators/op_sub.py | 4 +- backends/arm/operators/op_sum.py | 4 +- backends/arm/operators/op_tanh.py | 4 +- backends/arm/operators/op_tosa_matmul.py | 6 +-- backends/arm/operators/op_tosa_rescale.py | 6 +-- backends/arm/operators/op_tosa_resize.py | 6 +-- backends/arm/operators/op_tosa_scatter.py | 2 +- backends/arm/operators/op_tosa_table.py | 6 +-- backends/arm/operators/op_tosa_transpose.py | 4 +- backends/arm/operators/op_view.py | 4 +- backends/arm/operators/op_where.py | 10 ++-- backends/arm/operators/op_while.py | 6 +-- .../operators/operator_validation_utils.py | 4 +- backends/arm/operators/ops_binary.py | 6 +-- backends/arm/operators/ops_identity.py | 4 +- backends/arm/test/tester/arm_tester.py | 2 +- backends/arm/tosa/mapping.py | 46 ++++++------------- backends/arm/tosa/utils.py | 9 ++-- 56 files changed, 139 insertions(+), 162 deletions(-) diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index b5a58136395..fd16ca1eada 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index 6c1ff2e1449..fe90d5640a5 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_amax.py b/backends/arm/operators/op_amax.py index e4824fb59c2..d698f4ccd72 100644 --- a/backends/arm/operators/op_amax.py +++ b/backends/arm/operators/op_amax.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) input = inputs[0] diff --git a/backends/arm/operators/op_amin.py b/backends/arm/operators/op_amin.py index 34d4d37cdeb..c3b1bdd4786 100644 --- a/backends/arm/operators/op_amin.py +++ b/backends/arm/operators/op_amin.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) input = inputs[0] diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index 2a850c0cf52..d2cec0b8ea8 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -37,7 +37,7 @@ def define_node( validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( - self.target, [inputs[0], output], ts.DType.BOOL, output.tosa_spec + self.target, [inputs[0], output], ts.DType.BOOL, self.tosa_spec ) input_shape = list(inputs[0].shape) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index ec9d42915c1..1873351c465 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -122,7 +122,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: diff --git a/backends/arm/operators/op_bitwise_not.py b/backends/arm/operators/op_bitwise_not.py index ac0f758469d..c6063ee132e 100644 --- a/backends/arm/operators/op_bitwise_not.py +++ b/backends/arm/operators/op_bitwise_not.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 71c18530d55..60ef31017d0 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,13 +41,13 @@ def define_node( if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_num_inputs(self.target, inputs, [1, 2]) - input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special] + input_tosa_args = [TosaArg(arg, self.tosa_spec) for arg in inputs[0].special] validate_same_dtype(self.target, [*input_tosa_args, output], ts) validate_valid_dtype( self.target, [*input_tosa_args, output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) dim = 0 if len(inputs) < 2 else inputs[1].number diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py index 27ee81d0abe..d0f713c24a9 100644 --- a/backends/arm/operators/op_ceil.py +++ b/backends/arm/operators/op_ceil.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, inputs[0], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index d90f92f5e4b..931a842a656 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree @@ -94,7 +94,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) node_input_dtype = node.meta["val"].dtype diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 47d11fb5627..ee9f268cdc1 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -55,7 +55,7 @@ def define_node( ts.DType.FP32, ts.DType.BOOL, ], - output.tosa_spec, + self.tosa_spec, ) if inputs[0].dtype == ts.DType.INT8: diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py index e6039730b69..30127fca3b9 100644 --- a/backends/arm/operators/op_cos.py +++ b/backends/arm/operators/op_cos.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -40,7 +40,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.CosAttribute() diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index bd72c9491ca..0268a1d3ec3 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.EqualAttribute() diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e642a4059fe..676bac392f0 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,7 +43,7 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) # MI lowering diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 72e89b6906b..4d11558fe75 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_floor.py b/backends/arm/operators/op_floor.py index d9f831dfb35..ad0285e6fc7 100644 --- a/backends/arm/operators/op_floor.py +++ b/backends/arm/operators/op_floor.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, inputs[0], ts.DType.FP32, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 754778487e9..e23f63d263d 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterEqualAttribute() diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 2a483f735a7..31751ad0ed3 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterAttribute() diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index ba2aa03c7ff..a205bc29b53 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -56,7 +56,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) weights, _, indices = inputs diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index cd0809df95b..7fc2aa5f027 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def _get_tensor_info(self, tensor: Node): """ if isinstance(tensor, Node): - dtype, shape, _ = extract_tensor_meta(tensor.meta, self.tosa_spec) + dtype, shape, _ = extract_tensor_meta(tensor.meta) return tensor.name, dtype, shape else: return tensor.name, tensor.dtype, tensor.shape @@ -133,9 +133,7 @@ def define_node( index_nodes = indices.special # Broadcast indices - broadcasted_tensors = tutils.broadcast_tensors( - tosa_graph, index_nodes, self.tosa_spec - ) + broadcasted_tensors = tutils.broadcast_tensors(tosa_graph, index_nodes) # Calculate strides so we can shift indices down the line. values_strides = self._calculate_value_strides(values.shape) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index aa6b52b9982..5a49e1de420 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterEqualAttribute() diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 565d6d56027..7ec0e6e082e 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.LogAttribute() diff --git a/backends/arm/operators/op_logical_not.py b/backends/arm/operators/op_logical_not.py index 695af5f7a26..81b44326ae2 100644 --- a/backends/arm/operators/op_logical_not.py +++ b/backends/arm/operators/op_logical_not.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,7 +47,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.BOOL], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 4b2b1a1960b..d51ac28b5b9 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,9 +48,9 @@ def define_node( self.target, inputs, [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) attr = ts.TosaSerializerAttribute() attr.GreaterAttribute() diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index bee0cc3fb0c..8d4bd4f6635 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -51,7 +51,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) input_tensor = inputs[0] diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index d3ab305ea3b..71eca11801b 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -47,7 +47,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr_maximum = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 7f72d158d43..bc506e26fe7 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -48,7 +48,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr_minimum = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index 0e10443e523..dc7a2277738 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -45,7 +45,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift") diff --git a/backends/arm/operators/op_neg.py b/backends/arm/operators/op_neg.py index e0bb408e155..621c9872b31 100644 --- a/backends/arm/operators/op_neg.py +++ b/backends/arm/operators/op_neg.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -66,7 +66,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], supported_dtypes, output.tosa_spec + self.target, [*inputs, output], supported_dtypes, self.tosa_spec ) input_zp, output_zp = get_negate_zero_points( diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index fea0aea9298..b76c427d622 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -123,7 +123,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) # The permutation vector describes a permutation P in default Pytorch dim_order. diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 33cbc290d2c..e159fe66bcc 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.FP16, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() attr.PowAttribute() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 108a4fac0fb..77478edbf21 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.ReciprocalAttribute() diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index e44fede736d..29483501e5c 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) multiples = inputs[1].special diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 0b5717aa403..8971622b9a7 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index a86eaa40985..a98ace6b759 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.RsqrtAttribute() diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 908544ff00c..24d04be93c2 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -41,7 +41,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.SigmoidAttribute() diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py index faa249917c3..528ff0227ad 100644 --- a/backends/arm/operators/op_sin.py +++ b/backends/arm/operators/op_sin.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -40,7 +40,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.SinAttribute() diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 21c86e5f7c4..9dac73d84ba 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -80,7 +80,7 @@ def define_node( ts.DType.INT32, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) # See slice_copy_support.py diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 039a2f6bd68..db85e2d78c7 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index e956359736c..ca841e85a21 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -44,7 +44,7 @@ def define_node( self.target, [inputs[0], output], [ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, + self.tosa_spec, ) tensor = inputs[0] diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index c4603e90118..ecc7a261816 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,7 +42,7 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + self.target, [*inputs, output], ts.DType.FP32, self.tosa_spec ) attr = ts.TosaSerializerAttribute() attr.TanhAttribute() diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index 993caff9867..777b450abfb 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -58,7 +58,7 @@ def define_node( self.target, [*inputs], supported_input_dtypes, - output.tosa_spec, + self.tosa_spec, ) supported_output_dtypes = [ts.DType.INT32, ts.DType.FP32] if self.tosa_spec.support_extension("int16"): @@ -67,7 +67,7 @@ def define_node( self.target, [output], supported_output_dtypes, - output.tosa_spec, + self.tosa_spec, ) # We need to get the zero points and add an intermediate tensor for INT16 case diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index ae87dcc9c31..77f3569bbac 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -235,8 +235,8 @@ def define_node( if ( input_dtype not in [ - map_dtype(torch.int8, self.tosa_spec), - map_dtype(torch.int16, self.tosa_spec), + map_dtype(torch.int8), + map_dtype(torch.int16), ] and input_zp != 0 ): diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index e7e63f155d3..f40f3283eec 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -46,7 +46,7 @@ def define_node( self.target, [inputs[0]], supported_input_dtypes, - output.tosa_spec, + self.tosa_spec, ) supported_output_dtypes = [ts.DType.FP32] if node.kwargs.get("resize_mode") == "bilinear": @@ -63,7 +63,7 @@ def define_node( if self.tosa_spec.support_extension("int16"): supported_output_dtypes.append(ts.DType.INT16) validate_valid_dtype( - self.target, [output], supported_output_dtypes, output.tosa_spec + self.target, [output], supported_output_dtypes, self.tosa_spec ) # tosa_shape output is NHWC, take HW input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ diff --git a/backends/arm/operators/op_tosa_scatter.py b/backends/arm/operators/op_tosa_scatter.py index f0308e024ef..904965dfe94 100644 --- a/backends/arm/operators/op_tosa_scatter.py +++ b/backends/arm/operators/op_tosa_scatter.py @@ -46,7 +46,7 @@ def define_node( ts.DType.FP32, ts.DType.FP16, ], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index 7448898bddc..2711d18a98b 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,10 +43,10 @@ def define_node( supported_output_dtypes.append(ts.DType.INT32) validate_valid_dtype( - self.target, inputs, supported_input_dtypes, output.tosa_spec + self.target, inputs, supported_input_dtypes, self.tosa_spec ) validate_valid_dtype( - self.target, output, supported_output_dtypes, output.tosa_spec + self.target, output, supported_output_dtypes, self.tosa_spec ) # The name of the table constant is a bit complex. diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index c5aa66a85fd..d7db2a7b1d7 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -54,7 +54,7 @@ def define_node( ts.DType.FP16, ts.DType.FP32, ], - output.tosa_spec, + self.tosa_spec, ) output_rank = len(output.shape) diff --git a/backends/arm/operators/op_view.py b/backends/arm/operators/op_view.py index a32cb3aac06..7983afbe786 100644 --- a/backends/arm/operators/op_view.py +++ b/backends/arm/operators/op_view.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -50,7 +50,7 @@ def define_node( ts.DType.FP32, ts.DType.BOOL, ], - output.tosa_spec, + self.tosa_spec, ) tosa_graph = cast(ts.TosaSerializer, tosa_graph) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index f0b6538ac27..7e183e2bd1e 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -43,24 +43,24 @@ def define_node( ) -> None: supported_dtypes = [ts.DType.BOOL] - if output.tosa_spec.support_integer(): + if self.tosa_spec.support_integer(): supported_dtypes += [ ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ] - if output.tosa_spec.support_float(): + if self.tosa_spec.support_float(): supported_dtypes += [ts.DType.FP16, ts.DType.FP32] validate_num_inputs(self.target, inputs, 3) # Not first input, which is condition tensor. validate_same_dtype(self.target, inputs[1:], ts) - validate_valid_dtype(self.target, inputs[0], ts.DType.BOOL, output.tosa_spec) + validate_valid_dtype(self.target, inputs[0], ts.DType.BOOL, self.tosa_spec) validate_valid_dtype( self.target, [*inputs[1:], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_while.py b/backends/arm/operators/op_while.py index b4ac4f4f6f1..6c2bfcac6c9 100644 --- a/backends/arm/operators/op_while.py +++ b/backends/arm/operators/op_while.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -76,9 +76,7 @@ def define_node( ) in zip(outputs_needing_tensors, output_dim_orders, strict=True): tensor_name = output_needing_tensor.name + "_dummy" shape = output_needing_tensor.meta["val"].shape - dtype = map_dtype( - output_needing_tensor.meta["val"].dtype, self.tosa_spec - ) + dtype = map_dtype(output_needing_tensor.meta["val"].dtype) tosa_graph.currRegion.currBasicBlock.addTensor( tensor_name, diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py index 20ee10534d0..6b3271ee8e4 100644 --- a/backends/arm/operators/operator_validation_utils.py +++ b/backends/arm/operators/operator_validation_utils.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -127,7 +127,7 @@ def validate_valid_dtype( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) """ diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 3e8cda76b5a..15f07db2e1b 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -52,7 +52,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], - output.tosa_spec, + self.tosa_spec, ) if self.target in [ "aten.logical_and.default", @@ -63,7 +63,7 @@ def define_node( self.target, [*inputs, output], [ts.DType.BOOL], - output.tosa_spec, + self.tosa_spec, ) attr = ts.TosaSerializerAttribute() attr_builder(attr) diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index 0930d7e7997..21c524f2644 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -49,7 +49,7 @@ def define_node( ts.DType.INT16, ts.DType.INT32, ] - if output.tosa_spec.support_float(): + if self.tosa_spec.support_float(): supported_dtypes += [ts.DType.FP32] if self.tosa_spec.support_extension("int16"): supported_dtypes += [ts.DType.INT48] @@ -59,7 +59,7 @@ def define_node( self.target, [inputs[0], output], supported_dtypes, - output.tosa_spec, + self.tosa_spec, ) # Simply add an identityOp diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 66f90ead413..e12902b6e2b 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -1025,7 +1025,7 @@ def _get_dtype_distribution( placeholder_dtypes.append(str(node.meta["val"].dtype)) if node.op == "call_function": if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): - dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec) + dtype, _, _ = extract_tensor_meta(node.meta) call_function_dtypes.append(ts.DTypeNames[dtype]) return Counter(placeholder_dtypes), Counter(call_function_dtypes) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index c11a046cd66..206e244dbe8 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -11,7 +11,7 @@ import operator from enum import Enum -from typing import Any, Optional, Sequence +from typing import Any, Sequence import torch import tosa_serializer as ts @@ -76,12 +76,11 @@ def min(self): raise ValueError(f"Unrecognized TosaSpecialDtype {self}.") -def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: +def map_dtype(data_type: torch.dtype) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. Args: data_type (torch.dtype): PyTorch dtype to convert. - tosa_spec (TosaSpecification): Active spec (reserved for future checks). Returns: ts.DType: Matching serializer dtype. @@ -114,12 +113,11 @@ def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: # Returns the shape and type of a node # TODO: other types, can be # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None -def extract_tensor_meta(meta, tosa_spec: TosaSpecification): +def extract_tensor_meta(meta): """Extract dtype, shape, and dimension order from FX metadata. Args: meta (dict): FX node ``meta`` containing a ``val`` FakeTensor (or tuple). - tosa_spec (TosaSpecification): Active TOSA spec for dtype mapping. Returns: tuple[ts.DType, tuple[int, ...], tuple[int, ...]]: Tuple containing @@ -140,7 +138,7 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification): raise ValueError( f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}" ) - dtype = map_dtype(val.dtype, tosa_spec) + dtype = map_dtype(val.dtype) shape = tuple(val.size()) if meta.get("tosa_dim_order") is not None: @@ -165,7 +163,6 @@ class TosaArg: ``range(len(shape))``. special (list | None): Captured list when the argument is a sequence. number (float | int | None): Captured numeric value when provided. - tosa_spec (TosaSpecification): Active specification used for mapping. multiple_output_name (list[str]): Output node names when node has multiple outputs; empty otherwise. """ @@ -181,9 +178,8 @@ def __process_node(self, argument: torch.fx.Node): if "val" in argument.meta: output_dtype, self.shape, self.dim_order = extract_tensor_meta( - argument.meta, self.tosa_spec - ) - # Handle special case of types not representable in torch (i.e. i48_t) + argument.meta + ) # Handle special case of types not representable in torch (i.e. i48_t) if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): output_dtype = special_type.get_tosa_dtype() @@ -199,11 +195,6 @@ def __process_node(self, argument: torch.fx.Node): else: self.multiple_output_names = [] - if not self.__validate(): - raise ValueError( - f"{self.tosa_spec} doesn't support tensor {self.__repr__()}" - ) - def __process_list(self, argument): """Capture a sequence argument as ``special``. @@ -222,20 +213,18 @@ def __process_number(self, argument: float | int): """ self.number: float | int = argument - def __validate(self) -> bool: + def __validate(self, tosa_spec: TosaSpecification) -> bool: match getattr(self, "dtype", None): case ts.DType.FP32: - if not self.tosa_spec.support_float(): + if not tosa_spec.support_float(): return False case ts.DType.INT4: - if not self.tosa_spec.support_extension("int4"): + if not tosa_spec.support_extension("int4"): return False return True - def __init__( - self, argument: Any, tosa_spec: Optional[TosaSpecification] = None - ) -> None: + def __init__(self, argument: Any, tosa_spec: TosaSpecification) -> None: """Initialize the argument wrapper and populate fields. Args: @@ -245,20 +234,17 @@ def __init__( required for metadata extraction. Raises: - ValueError: If ``tosa_spec`` is missing or has the wrong type. RuntimeError: If ``argument`` is of an unsupported type. """ - if tosa_spec is None: - raise ValueError("tosa_spec is None") - elif not isinstance(tosa_spec, TosaSpecification): - raise ValueError( - f"Expected tosa_spec to be a TosaSpecification, but got {tosa_spec}" - ) - self.tosa_spec = tosa_spec if isinstance(argument, torch.fx.Node): self.__process_node(argument) + if not self.__validate(tosa_spec): + raise ValueError( + f"{tosa_spec} doesn't support tensor {self.__repr__()}" + ) + return if isinstance(argument, Sequence): self.__process_list(argument) @@ -302,8 +288,6 @@ def __repr__(self): attrs.append(f"special={self.special!r}") if hasattr(self, "number") and self.number is not None: attrs.append(f"number={self.number!r}") - if hasattr(self, "tosa_spec") and self.tosa_spec is not None: - attrs.append(f"tosa_spec={self.tosa_spec!r}") if hasattr(self, "multiple_output_names"): attrs.append(f"names={self.multiple_output_names!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index df77153e29f..6666e422039 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -15,7 +15,6 @@ import tosa_serializer as ts from executorch.backends.arm.tosa.mapping import extract_tensor_meta -from executorch.backends.arm.tosa.specification import TosaSpecification from torch._subclasses.fake_tensor import FakeTensor from torch.fx import Node @@ -64,9 +63,7 @@ def are_fake_tensors_broadcastable( return (True, list(reversed(broadcast_shape))) -def broadcast_tensors( - tosa_fb, nodes: list[Node], tosa_spec: TosaSpecification -) -> list[Any]: +def broadcast_tensors(tosa_fb, nodes: list[Node]) -> list[Any]: """Broadcast the FX nodes to a shared shape inside the TOSA graph. This mirrors ``reshape_for_broadcast`` but also emits the tile operators @@ -96,7 +93,7 @@ def broadcast_tensors( broadcast_tensors = [] for node in nodes: - tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta, tosa_spec) + tens_dtype, tens_shape, _ = extract_tensor_meta(node.meta) list_tens_shape = list(tens_shape) # Already in the right shape we can just add it to the list. if list_tens_shape == common_shape: From cb971077b386aeb667e5a1fa31d500e9de939909 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 13 Jan 2026 15:38:14 +0100 Subject: [PATCH 2/4] Arm backend: Implement BF16 infrastructure Update add and transpose node visitors as examples. Since TosaArg validates that bf16 tensors are only created for EXT-BF16, no need to check again in the node visitor. Though transpose support is extended to support rank-4 adds, tests and proper support will be done in a follow-up patch. Signed-off-by: Erik Lundell Change-Id: I48adb9d106293df7aea415fcfe25b652202bab28 --- backends/arm/_passes/match_arg_dtype_pass.py | 7 ++--- backends/arm/operators/op_add.py | 2 +- backends/arm/operators/op_tosa_transpose.py | 1 + backends/arm/test/ops/test_add.py | 25 +++++++++++++++--- backends/arm/test/runner_utils.py | 27 ++++++++++++-------- backends/arm/tosa/mapping.py | 6 ++++- 6 files changed, 48 insertions(+), 20 deletions(-) diff --git a/backends/arm/_passes/match_arg_dtype_pass.py b/backends/arm/_passes/match_arg_dtype_pass.py index f0aaa0cf5f9..edf06bcf890 100644 --- a/backends/arm/_passes/match_arg_dtype_pass.py +++ b/backends/arm/_passes/match_arg_dtype_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,8 +19,9 @@ torch.int32: 4, torch.int64: 5, torch.float16: 6, - torch.float32: 7, - torch.float64: 8, + torch.bfloat16: 7, + torch.float32: 8, + torch.float64: 9, } diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index fe90d5640a5..3cb0dc8b700 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -43,7 +43,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], self.tosa_spec, ) diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index d7db2a7b1d7..ada2e9820ae 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -53,6 +53,7 @@ def define_node( ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, + ts.DType.BF16, ], self.tosa_spec, ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 31c31c3e88a..4ad1f2b1ba0 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -49,7 +49,7 @@ class Add2(torch.nn.Module): def forward(self, x: torch.Tensor, y: torch.Tensor): return x + y - test_data: list[input_t2] = { + test_data = { "5d_float": lambda: ( torch.FloatTensor([1, 2, 3, 5, 7]), (torch.FloatTensor([2, 1, 2, 1, 10])), @@ -70,6 +70,12 @@ def forward(self, x: torch.Tensor, y: torch.Tensor): torch.randn(1, 10, 20, 30), ), } + test_data_bf16 = { + "4d_big_small_bf16": lambda: ( + (10e10) * torch.randn(1, 10, 20, 30, dtype=torch.bfloat16), + torch.randn(1, 10, 20, 30, dtype=torch.bfloat16), + ), + } class Add3(torch.nn.Module): @@ -89,6 +95,15 @@ def test_add_tensor_tosa_FP(test_data: input_t1): pipeline.run() +@common.parametrize("test_data", Add.test_data) +def test_add_tensor_tosa_FP_bf16(test_data: input_t1): + x = test_data()[0].to(torch.bfloat16) + pipeline = TosaPipelineFP[input_t1]( + Add(), (x,), aten_op, exir_op, tosa_extensions=["bf16"] + ) + pipeline.run() + + @common.parametrize("test_data", Add.test_data) def test_add_tensor_tosa_INT(test_data: input_t1): pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0) @@ -152,9 +167,11 @@ def test_add_tensor_u85_INT(test_data: input_t1): pipeline.run() -@common.parametrize("test_data", Add2.test_data) +@common.parametrize("test_data", Add2.test_data | Add2.test_data_bf16) def test_add_tensor_tosa_FP_2(test_data: input_t2): - pipeline = TosaPipelineFP[input_t2](Add2(), test_data(), aten_op, exir_op) + pipeline = TosaPipelineFP[input_t2]( + Add2(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"] + ) pipeline.run() diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 44e9e75d2d6..9aac49f9cb1 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -60,7 +60,7 @@ torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, - torch.bfloat16: np.float32, + torch.bfloat16: np.uint16, torch.complex32: np.complex64, torch.complex64: np.complex64, torch.complex128: np.complex128, @@ -172,16 +172,17 @@ def get_output_quantization_params( def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: - dtype = _torch_to_numpy_dtype_dict[tensor.dtype] - array = tensor.detach().numpy().astype(dtype) # type: ignore[var-annotated] dim_order = tensor.dim_order() if dim_order == NHWC_ORDER: - a = array.transpose(NHWC_ORDER) - return a + tensor = tensor.permute(NHWC_ORDER) elif dim_order == NNHWC_ORDER: - return array.transpose(NNHWC_ORDER) - else: - return array + tensor = tensor.permute(NNHWC_ORDER) + + tensor = tensor.detach() + if tensor.dtype == torch.bfloat16: + # Numpy doesn't support bfloat16, use, uint16 instead. Dtype is inferred from model anyways. + tensor = tensor.view(torch.uint16) + return tensor.numpy() def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: @@ -197,8 +198,12 @@ def numpy_to_torch_tensor(array: np.ndarray, output_node: Node) -> torch.Tensor: tensor = torch.from_numpy(array).reshape(shape_with_dim_order) return tensor.permute(NNHWC_INVERSE_ORDER).to(memory_format=torch.channels_last) else: - tensor = torch.from_numpy(array).reshape(shape) - return tensor + if type(array.dtype) is np.dtypes.VoidDType: + # If dtype is void, "cheat" and use the output_tensor dtype. + tensor = torch.frombuffer(array, dtype=output_tensor.dtype) + else: + tensor = torch.from_numpy(array) + return tensor.reshape(shape) class TosaReferenceModelDispatch(TorchFunctionMode): diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 206e244dbe8..494c0120f8b 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -221,6 +221,9 @@ def __validate(self, tosa_spec: TosaSpecification) -> bool: case ts.DType.INT4: if not tosa_spec.support_extension("int4"): return False + case ts.DType.BF16: + if not tosa_spec.support_extension("bf16"): + return False return True @@ -289,5 +292,6 @@ def __repr__(self): if hasattr(self, "number") and self.number is not None: attrs.append(f"number={self.number!r}") if hasattr(self, "multiple_output_names"): - attrs.append(f"names={self.multiple_output_names!r}") + if len(self.multiple_output_names) > 0: + attrs.append(f"names={self.multiple_output_names!r}") return f"{self.__class__.__name__}({', '.join(attrs)})" From 9c66acbe5317059938b4c61034eb8ca916e0c4d0 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 20 Jan 2026 11:03:35 +0100 Subject: [PATCH 3/4] Arm backend: Add BF16 support to operators pt.1 abs, average_pool_2d, cat, ceil and clamp mapping to TOSA operators ABS, ARGMAX, AVG_POOL2D, CAT, CEIL, CLAMP Required refactoring some serialization going through numpy since numpy doesn't have native support of bfloat16. Where possible, avoid dtype handling all together by viewing as uint8. In process node, make a special case for bf16 and use ml_dtypes. Signed-off-by: Erik Lundell Change-Id: I92c8cb204396d10d0809db93822dea7fe71ec8a0 --- .../arm/_passes/decompose_avg_pool2d_pass.py | 2 +- .../_passes/fuse_equal_placeholders_pass.py | 8 ++-- backends/arm/operators/op_abs.py | 2 +- backends/arm/operators/op_avg_pool2d.py | 2 +- backends/arm/operators/op_cat.py | 8 +++- backends/arm/operators/op_ceil.py | 2 +- backends/arm/operators/op_clamp.py | 19 ++++------ backends/arm/operators/op_maximum.py | 2 +- backends/arm/operators/op_minimum.py | 2 +- backends/arm/process_node.py | 37 +++++++++++++------ backends/arm/test/ops/test_abs.py | 12 ++++-- backends/arm/test/ops/test_avg_pool2d.py | 16 +++++++- backends/arm/test/ops/test_cat.py | 25 +++++++++++-- backends/arm/test/ops/test_ceil.py | 16 +++++++- backends/arm/test/ops/test_clamp.py | 36 ++++++++++++++++-- 15 files changed, 142 insertions(+), 47 deletions(-) diff --git a/backends/arm/_passes/decompose_avg_pool2d_pass.py b/backends/arm/_passes/decompose_avg_pool2d_pass.py index c46a54b0efa..a3fe049b8bb 100644 --- a/backends/arm/_passes/decompose_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_avg_pool2d_pass.py @@ -51,7 +51,7 @@ def call_operator(self, op, args, kwargs, meta): full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) x = args[0] - full_kwargs = {"device": x.data.device} + full_kwargs = {"device": x.data.device, "dtype": x.data.dtype} kernel_h, kernel_w = args[1] kernel_size = kernel_h * kernel_w if len(args) > 2 and args[2] is not None: diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 37cac8a8c56..f31675b8daf 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -54,15 +54,15 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # ensure we don't merge any special case int48_t tensors with int32_t tensors # since int48_t tensors needs to be instantiated separately. is_special_dtype = node.meta.get(TosaSpecialDtype.meta_key(), None) - t_cpu = tensor.detach().cpu().contiguous() + t_cpu = tensor.cpu().contiguous().flatten().view(dtype=torch.uint8) data_bytes = t_cpu.numpy().tobytes() key = ( is_special_dtype, - str(t_cpu.dtype), - tuple(t_cpu.shape), + str(tensor.dtype), + tuple(tensor.shape), hashlib.sha1(data_bytes, usedforsecurity=False).hexdigest(), ) - hash_buckets[key].append((node, t_cpu)) + hash_buckets[key].append((node, tensor)) # For each bucket with more than one entry, fuse: for nodes_tensors in hash_buckets.values(): diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index fd16ca1eada..b21407591a5 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -43,7 +43,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], self.tosa_spec, ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 1873351c465..cc180ec47b7 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -115,7 +115,7 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + supported_dtypes = [ts.DType.INT8, ts.DType.FP32, ts.DType.BF16] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 60ef31017d0..40debd29685 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -37,7 +37,13 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + supported_dtypes = [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT32, + ts.DType.FP32, + ts.DType.BF16, + ] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_num_inputs(self.target, inputs, [1, 2]) diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py index d0f713c24a9..80bdc0f86bd 100644 --- a/backends/arm/operators/op_ceil.py +++ b/backends/arm/operators/op_ceil.py @@ -45,7 +45,7 @@ def define_node( validate_valid_dtype( self.target, inputs[0], - ts.DType.FP32, + [ts.DType.FP32, ts.DType.BF16], self.tosa_spec, ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 931a842a656..faa40d836ef 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -6,7 +6,6 @@ from typing import Any, List, Tuple -import numpy as np import torch import tosa_serializer as ts @@ -67,16 +66,7 @@ def cast_type(value: Any) -> int | float: return min_arg, max_arg def _to_bytes(self, value: int | float, dtype: torch.dtype) -> bytes: - if dtype == torch.float32: - return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.float16: - return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.int8: - return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist() - elif dtype == torch.int16: - return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist() - else: - raise ValueError(f"Unsupported dtype for to_bytes: {dtype}") + return torch.full((1,), value, dtype=dtype).view(torch.uint8).numpy().tolist() def define_node( self, @@ -87,7 +77,12 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) - supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32] + supported_dtypes = [ + ts.DType.INT8, + ts.DType.FP16, + ts.DType.BF16, + ts.DType.FP32, + ] if self.tosa_spec.support_extension("int16"): supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 71eca11801b..a44e20f6657 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -46,7 +46,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], self.tosa_spec, ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index bc506e26fe7..3fb9d23ccfd 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16], self.tosa_spec, ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index b85b1b43013..042965fecc5 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,6 +11,7 @@ import torch import torch.fx import tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.specification import TosaSpecification @@ -26,6 +27,23 @@ from torch.export.exported_program import ExportedProgram +def _tensor_to_numpy_with_dim_order( + tensor: torch.Tensor, dim_order: tuple[int, ...] +) -> np.ndarray: + tensor = tensor.detach().cpu().contiguous() + if tensor.dtype == torch.bfloat16: + try: + import ml_dtypes + except ImportError as e: + raise RuntimeError( + "ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?" + ) from e + np_tensor = tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16) + else: + np_tensor = tensor.numpy() + return np.transpose(np_tensor, dim_order) + + def process_call_function( node: torch.fx.Node, tosa_graph: Any, @@ -114,9 +132,9 @@ def process_inputs_to_parameters( f"Expected parameter '{node.name}' to be a torch.Tensor, got " f"{type(parameter_data).__name__}" ) - parameter_values = parameter_data.detach().numpy() - - parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) + parameter_values = _tensor_to_numpy_with_dim_order( + parameter_data, tosa_arg.dim_order # type: ignore[arg-type] + ) tosa_graph.addConst( parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name @@ -144,12 +162,7 @@ def process_inputs_to_buffers( f"Expected buffer '{node.name}' to be a torch.Tensor, got " f"{type(buffer_data).__name__}" ) - buffer_values = buffer_data.detach().numpy() - - # TODO: fragile code for temporary fix - # the mean and var tensors are also stored here but they have shape (1, ) - # we only transpose weights here - buffer_values = np.transpose(buffer_values, tosa_arg.dim_order) + buffer_values = _tensor_to_numpy_with_dim_order(buffer_data, tosa_arg.dim_order) # type: ignore[arg-type] tosa_graph.addConst( buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name @@ -170,8 +183,10 @@ def process_inputs_to_lifted_tensor_constants( "Is the original torch function supported?" ) from e tensor = get_lifted_tensor_constant(edge_program, node) - tensor_data = tensor.detach().numpy() # type: ignore[union-attr] - tensor_values = np.transpose(tensor_data, tosa_arg.dim_order) + tensor_values = _tensor_to_numpy_with_dim_order( + tensor, # type: ignore[arg-type] + tosa_arg.dim_order, # type: ignore[arg-type] + ) tosa_graph.addConst( tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name diff --git a/backends/arm/test/ops/test_abs.py b/backends/arm/test/ops/test_abs.py index 9e8ad2e3d03..632ae24d5cd 100644 --- a/backends/arm/test/ops/test_abs.py +++ b/backends/arm/test/ops/test_abs.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -34,14 +34,20 @@ class Abs(torch.nn.Module): "randn_4d": lambda: (torch.randn(1, 2, 3, 4),), "torch_normal": lambda: (torch.normal(mean=0, std=10, size=(2, 3, 4)),), } + test_parameters_bf16 = { + "randn_1d_bf16": lambda: (torch.randn(8, dtype=torch.bfloat16),), + "randn_4d_bf16": lambda: (torch.randn(1, 2, 3, 4, dtype=torch.bfloat16),), + } def forward(self, x): return torch.abs(x) -@common.parametrize("test_data", Abs.test_parameters) +@common.parametrize("test_data", Abs.test_parameters | Abs.test_parameters_bf16) def test_abs_tosa_FP(test_data: torch.Tensor): - pipeline = TosaPipelineFP[input_t1](Abs(), test_data(), aten_op, exir_op) + pipeline = TosaPipelineFP[input_t1]( + Abs(), test_data(), aten_op, exir_op, tosa_extensions=["bf16"] + ) pipeline.run() diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 8885d19ddb4..a0791a88d9c 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -124,8 +124,19 @@ def forward(self, x: torch.Tensor): "becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)), } +test_modules_bf16 = { + "rand_bf16": lambda: ( + AvgPool2d(4, 2, 0, False), + (torch.rand(1, 16, 50, 32, dtype=torch.bfloat16),), + ), + "kernel_3x3_stride_1_pad_1_bf16": lambda: ( + AvgPool2d((3, 3), (1, 1), 1), + (torch.rand(1, 4, 12, 12, dtype=torch.bfloat16),), + ), +} -@common.parametrize("test_module", test_modules) + +@common.parametrize("test_module", test_modules | test_modules_bf16) def test_avg_pool2d_tosa_FP(test_module): model, input_tensor = test_module() @@ -134,6 +145,7 @@ def test_avg_pool2d_tosa_FP(test_module): input_tensor, aten_op, exir_op, + tosa_extensions=["bf16"], run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"), ) pipeline.run() diff --git a/backends/arm/test/ops/test_cat.py b/backends/arm/test/ops/test_cat.py index a037d0e366f..8a2f990c277 100644 --- a/backends/arm/test/ops/test_cat.py +++ b/backends/arm/test/ops/test_cat.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -66,6 +66,23 @@ class Cat(torch.nn.Module): ), } + test_parameters_bf16 = { + "cat_rand_two_tensors_bf16": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + torch.randn(1, 2, 4, 1, dtype=torch.bfloat16), + ), + 3, + ), + "cat_rand_dim0_bf16": lambda: ( + ( + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + torch.randn(1, 2, 4, 4, dtype=torch.bfloat16), + ), + 0, + ), + } + def __init__(self): super().__init__() @@ -73,19 +90,20 @@ def forward(self, t: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor: return torch.cat(t, dim=dim) -@common.parametrize("test_data", Cat.test_parameters) +@common.parametrize("test_data", Cat.test_parameters | Cat.test_parameters_bf16) def test_cat_tosa_FP(test_data: Tuple): pipeline = TosaPipelineFP[input_t1]( Cat(), test_data(), aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() def test_cat_tosa_FP_4d(): - square = torch.ones((2, 2, 2, 2)) + square = torch.ones((2, 2, 2, 2), dtype=torch.bfloat16) for dim in range(-3, 3): test_data = ((square, square.clone()), dim) pipeline = TosaPipelineFP[input_t1]( @@ -93,6 +111,7 @@ def test_cat_tosa_FP_4d(): test_data, aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() diff --git a/backends/arm/test/ops/test_ceil.py b/backends/arm/test/ops/test_ceil.py index 93b5f9cd009..7f03d3ce11f 100644 --- a/backends/arm/test/ops/test_ceil.py +++ b/backends/arm/test/ops/test_ceil.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -42,8 +42,19 @@ def forward(self, x: torch.Tensor): "ceil_ramp": lambda: (Ceil(), ramp), } +test_data_bf16 = { + "ceil_rand_bf16": lambda: ( + Ceil(), + (torch.rand(4, 4, dtype=torch.bfloat16) - 0.5), + ), + "ceil_ramp_bf16": lambda: ( + Ceil(), + torch.arange(-8, 8, 0.25, dtype=torch.bfloat16), + ), +} -@common.parametrize("test_data", test_data) + +@common.parametrize("test_data", test_data | test_data_bf16) def test_ceil_tosa_FP(test_data: input_t1): module, data = test_data() pipeline = TosaPipelineFP[input_t1]( @@ -51,6 +62,7 @@ def test_ceil_tosa_FP(test_data: input_t1): (data,), module.aten_op, module.exir_op, + tosa_extensions=["bf16"], ) pipeline.run() diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py index 13c3479daa3..ad01be27cf9 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -35,6 +35,19 @@ "rank_4_no_max": lambda: (torch.rand(1, 10, 10, 1) - 3, -3.3, None), } +test_data_suite_bf16 = { + "rank_2_bf16": lambda: ( + torch.rand(1, 35, dtype=torch.bfloat16), + 0.5, + 0.8, + ), + "rank_4_no_max_bf16": lambda: ( + torch.rand(1, 10, 10, 1, dtype=torch.bfloat16) - 3, + -3.3, + None, + ), +} + test_data_suite_int32 = { "int32_rank2": lambda: (torch.randint(-50, 50, (2, 3), dtype=torch.int32), -10, 10), "int32_rank3_no_min": lambda: ( @@ -70,7 +83,7 @@ def forward(self, x): return torch.clamp(x, self.clamp_min, self.clamp_max) -@common.parametrize("test_data", test_data_suite) +@common.parametrize("test_data", test_data_suite | test_data_suite_bf16) def test_clamp_tosa_FP(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -80,6 +93,7 @@ def test_clamp_tosa_FP(test_data): (input_tensor,), aten_op, exir_op, + tosa_extensions=["bf16"], ) pipeline.run() @@ -254,6 +268,19 @@ def test_clamp_vgf_quant(test_data): ), } +test_data_suite_tensor_bf16 = { + "rank_2_bf16": lambda: ( + torch.rand(1, 35, dtype=torch.bfloat16), + torch.tensor(0.5, dtype=torch.bfloat16), + torch.tensor(0.8, dtype=torch.bfloat16), + ), + "rank_4_no_max_bf16": lambda: ( + torch.rand(10, 20, 30, 40, dtype=torch.bfloat16) - 3, + torch.tensor(-0.1, dtype=torch.bfloat16), + None, + ), +} + test_data_suite_tensor_INT32 = { "int32_rank2": lambda: ( torch.randint(-50, 50, (2, 3), dtype=torch.int32), @@ -328,7 +355,9 @@ def test_clamp_vgf_quant(test_data): } -@common.parametrize("test_data", test_data_suite_tensor_FP) +@common.parametrize( + "test_data", test_data_suite_tensor_FP | test_data_suite_tensor_bf16 +) def test_clamp_tosa_FP_tensor(test_data): input_tensor, min_val, max_val = test_data() model = Clamp(min_val, max_val) @@ -338,6 +367,7 @@ def test_clamp_tosa_FP_tensor(test_data): (input_tensor,), aten_op_tensor, exir_op_tensor, + tosa_extensions=["bf16"], ) pipeline.run() From 4a8ef5e8af3b3395b836315c6f1571a4457ae212 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Thu, 22 Jan 2026 09:35:15 +0100 Subject: [PATCH 4/4] Add ignore to potentially missing import Signed-off-by: Erik Lundell Change-Id: I6c95f2c782913f8ea0a928102043547f0b87f5c4 --- backends/arm/process_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 042965fecc5..ac05dbcdc04 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -33,7 +33,7 @@ def _tensor_to_numpy_with_dim_order( tensor = tensor.detach().cpu().contiguous() if tensor.dtype == torch.bfloat16: try: - import ml_dtypes + import ml_dtypes # type: ignore[import-not-found] except ImportError as e: raise RuntimeError( "ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?"